/* * Copyright (c) 2021, Idan Horowitz * * SPDX-License-Identifier: BSD-2-Clause */ #pragma once #include namespace AK { template class IntrusiveRedBlackTreeNode; template V::*member> class IntrusiveRedBlackTree : public BaseRedBlackTree { public: IntrusiveRedBlackTree() = default; virtual ~IntrusiveRedBlackTree() override { clear(); } using BaseTree = BaseRedBlackTree; using TreeNode = IntrusiveRedBlackTreeNode; V* find(K key) { auto* node = static_cast(BaseTree::find(this->m_root, key)); if (!node) return nullptr; return node_to_value(*node); } V* find_largest_not_above(K key) { auto* node = static_cast(BaseTree::find_largest_not_above(this->m_root, key)); if (!node) return nullptr; return node_to_value(*node); } void insert(V& value) { auto& node = value.*member; BaseTree::insert(&node); node.m_in_tree = true; } template class BaseIterator { public: BaseIterator() = default; bool operator!=(const BaseIterator& other) const { return m_node != other.m_node; } BaseIterator& operator++() { if (!m_node) return *this; m_prev = m_node; // the complexity is O(logn) for each successor call, but the total complexity for all elements comes out to O(n), meaning the amortized cost for a single call is O(1) m_node = static_cast(BaseTree::successor(m_node)); return *this; } BaseIterator& operator--() { if (!m_prev) return *this; m_node = m_prev; m_prev = static_cast(BaseTree::predecessor(m_prev)); return *this; } ElementType& operator*() { VERIFY(m_node); return *node_to_value(*m_node); } ElementType* operator->() { VERIFY(m_node); return node_to_value(*m_node); } [[nodiscard]] bool is_end() const { return !m_node; } [[nodiscard]] bool is_begin() const { return !m_prev; } private: friend class IntrusiveRedBlackTree; explicit BaseIterator(TreeNode* node, TreeNode* prev = nullptr) : m_node(node) , m_prev(prev) { } TreeNode* m_node { nullptr }; TreeNode* m_prev { nullptr }; }; using Iterator = BaseIterator; Iterator begin() { return Iterator(static_cast(this->m_minimum)); } Iterator end() { return {}; } Iterator begin_from(K key) { return Iterator(static_cast(BaseTree::find(this->m_root, key))); } using ConstIterator = BaseIterator; ConstIterator begin() const { return ConstIterator(static_cast(this->m_minimum)); } ConstIterator end() const { return {}; } ConstIterator begin_from(K key) const { return ConstIterator(static_cast(BaseTree::find(this->m_rootF, key))); } bool remove(K key) { auto* node = static_cast(BaseTree::find(this->m_root, key)); if (!node) return false; BaseTree::remove(node); node->right_child = nullptr; node->left_child = nullptr; node->m_in_tree = false; return true; } void clear() { clear_nodes(static_cast(this->m_root)); this->m_root = nullptr; this->m_minimum = nullptr; this->m_size = 0; } private: static void clear_nodes(TreeNode* node) { if (!node) return; clear_nodes(static_cast(node->right_child)); node->right_child = nullptr; clear_nodes(static_cast(node->left_child)); node->left_child = nullptr; node->m_in_tree = false; } static V* node_to_value(TreeNode& node) { return (V*)((u8*)&node - ((u8*)&(((V*)nullptr)->*member) - (u8*)nullptr)); } }; template class IntrusiveRedBlackTreeNode : public BaseRedBlackTree::Node { public: IntrusiveRedBlackTreeNode(K key) : BaseRedBlackTree::Node(key) { } ~IntrusiveRedBlackTreeNode() { VERIFY(!is_in_tree()); } bool is_in_tree() { return m_in_tree; } private: template V::*member> friend class IntrusiveRedBlackTree; bool m_in_tree { false }; }; } using AK::IntrusiveRedBlackTree; using AK::IntrusiveRedBlackTreeNode;