AK: Add OOM safe interface to HashTable/Map

This adds a new HashSetResult only returned by try_set, to signal
allocation failure during setting.
This commit is contained in:
Hediadyoin1 2021-08-14 02:07:39 +02:00 committed by Brian Gianforcaro
parent f36781a8bc
commit 1aa527f5b6
2 changed files with 68 additions and 31 deletions

View file

@ -57,6 +57,9 @@ public:
HashSetResult set(const K& key, const V& value) { return m_table.set({ key, value }); }
HashSetResult set(const K& key, V&& value) { return m_table.set({ key, move(value) }); }
HashSetResult try_set(const K& key, const V& value) { return m_table.try_set({ key, value }); }
HashSetResult try_set(const K& key, V&& value) { return m_table.try_set({ key, move(value) }); }
bool remove(const K& key)
{
auto it = find(key);
@ -96,6 +99,7 @@ public:
}
void ensure_capacity(size_t capacity) { m_table.ensure_capacity(capacity); }
bool try_ensure_capacity(size_t capacity) { return m_table.try_ensure_capacity(capacity); }
Optional<typename Traits<V>::PeekType> get(const K& key) const requires(!IsPointer<typename Traits<V>::PeekType>)
{

View file

@ -15,6 +15,7 @@
namespace AK {
enum class HashSetResult {
Failed = 0,
InsertedNewEntry,
ReplacedExistingEntry,
KeptExistingEntry
@ -184,11 +185,19 @@ public:
[[nodiscard]] size_t capacity() const { return m_capacity; }
template<typename U, size_t N>
void set_from(U (&from_array)[N])
bool try_set_from(U (&from_array)[N])
{
for (size_t i = 0; i < N; ++i) {
set(from_array[i]);
if (try_set(from_array[i]) == HashSetResult::Failed)
return false;
}
return true;
}
template<typename U, size_t N>
void set_from(U (&from_array)[N])
{
bool result = try_set_from(from_array);
VERIFY(result);
}
void ensure_capacity(size_t capacity)
@ -250,36 +259,45 @@ public:
}
template<typename U = T>
HashSetResult set(U&& value, HashSetExistingEntryBehavior existing_entry_behavior = HashSetExistingEntryBehavior::Replace)
HashSetResult try_set(U&& value, HashSetExistingEntryBehavior existing_entry_behavior = HashSetExistingEntryBehavior::Replace)
{
auto& bucket = lookup_for_writing(value);
if (bucket.used) {
auto* bucket = try_lookup_for_writing(value);
if (!bucket)
return HashSetResult::Failed;
if (bucket->used) {
if (existing_entry_behavior == HashSetExistingEntryBehavior::Keep)
return HashSetResult::KeptExistingEntry;
(*bucket.slot()) = forward<U>(value);
(*bucket->slot()) = forward<U>(value);
return HashSetResult::ReplacedExistingEntry;
}
new (bucket.slot()) T(forward<U>(value));
bucket.used = true;
if (bucket.deleted) {
bucket.deleted = false;
new (bucket->slot()) T(forward<U>(value));
bucket->used = true;
if (bucket->deleted) {
bucket->deleted = false;
--m_deleted_count;
}
if constexpr (IsOrdered) {
if (!m_collection_data.head) [[unlikely]] {
m_collection_data.head = &bucket;
m_collection_data.head = bucket;
} else {
bucket.previous = m_collection_data.tail;
m_collection_data.tail->next = &bucket;
bucket->previous = m_collection_data.tail;
m_collection_data.tail->next = bucket;
}
m_collection_data.tail = &bucket;
m_collection_data.tail = bucket;
}
++m_size;
return HashSetResult::InsertedNewEntry;
}
template<typename U = T>
HashSetResult set(U&& value, HashSetExistingEntryBehavior existing_entry_behaviour = HashSetExistingEntryBehavior::Replace)
{
auto result = try_set(forward<U>(value), existing_entry_behaviour);
VERIFY(result != HashSetResult::Failed);
return result;
}
template<typename TUnaryPredicate>
[[nodiscard]] Iterator find(unsigned hash, TUnaryPredicate predicate)
@ -369,7 +387,7 @@ private:
}
}
void rehash(size_t new_capacity)
bool try_rehash(size_t new_capacity)
{
new_capacity = max(new_capacity, static_cast<size_t>(4));
new_capacity = kmalloc_good_size(new_capacity * sizeof(BucketType)) / sizeof(BucketType);
@ -378,24 +396,23 @@ private:
auto old_capacity = m_capacity;
Iterator old_iter = begin();
if constexpr (IsOrdered) {
m_buckets = (BucketType*)kmalloc(size_in_bytes(new_capacity));
__builtin_memset(m_buckets, 0, size_in_bytes(new_capacity));
auto new_buckets = kmalloc(size_in_bytes(new_capacity));
if (!new_buckets)
return false;
m_collection_data = { nullptr, nullptr };
} else {
m_buckets = (BucketType*)kmalloc(size_in_bytes(new_capacity));
__builtin_memset(m_buckets, 0, size_in_bytes(new_capacity));
}
m_buckets = (BucketType*)new_buckets;
__builtin_memset(m_buckets, 0, size_in_bytes(new_capacity));
m_capacity = new_capacity;
m_deleted_count = 0;
if constexpr (!IsOrdered)
if constexpr (IsOrdered)
m_collection_data = { nullptr, nullptr };
else
m_buckets[m_capacity].end = true;
if (!old_buckets)
return;
return true;
for (auto it = move(old_iter); it != end(); ++it) {
insert_during_rehash(move(*it));
@ -403,6 +420,12 @@ private:
}
kfree_sized(old_buckets, size_in_bytes(old_capacity));
return true;
}
void rehash(size_t new_capacity)
{
bool result = try_rehash(new_capacity);
VERIFY(result);
}
template<typename TUnaryPredicate>
@ -424,30 +447,40 @@ private:
}
}
[[nodiscard]] BucketType& lookup_for_writing(T const& value)
[[nodiscard]] BucketType* try_lookup_for_writing(T const& value)
{
if (should_grow())
rehash(capacity() * 2);
// FIXME: Maybe overrun the "allowed" load factor to avoid OOM
// If we are allowed to do that, separate that logic from
// the normal lookup_for_writing
if (should_grow()) {
if (!try_rehash(capacity() * 2))
return nullptr;
}
auto hash = TraitsForT::hash(value);
BucketType* first_empty_bucket = nullptr;
for (;;) {
auto& bucket = m_buckets[hash % m_capacity];
if (bucket.used && TraitsForT::equals(*bucket.slot(), value))
return bucket;
return &bucket;
if (!bucket.used) {
if (!first_empty_bucket)
first_empty_bucket = &bucket;
if (!bucket.deleted)
return *const_cast<BucketType*>(first_empty_bucket);
return const_cast<BucketType*>(first_empty_bucket);
}
hash = double_hash(hash);
}
}
[[nodiscard]] BucketType& lookup_for_writing(T const& value)
{
auto* item = try_lookup_for_writing(value);
VERIFY(item);
return *item;
}
[[nodiscard]] size_t used_bucket_count() const { return m_size + m_deleted_count; }
[[nodiscard]] bool should_grow() const { return ((used_bucket_count() + 1) * 100) >= (m_capacity * load_factor_in_percent); }