diff --git a/AK/NonnullRefPtr.h b/AK/NonnullRefPtr.h index 316a61138c..46686927e9 100644 --- a/AK/NonnullRefPtr.h +++ b/AK/NonnullRefPtr.h @@ -28,6 +28,7 @@ #include #include +#include #include namespace AK { @@ -126,69 +127,56 @@ public: NonnullRefPtr& operator=(const NonnullRefPtr& other) { - if (m_ptr != other.m_ptr) { - deref_if_not_null(m_ptr); - m_ptr = const_cast(other.ptr()); - m_ptr->ref(); - } + NonnullRefPtr ptr(other); + swap(ptr); return *this; } template NonnullRefPtr& operator=(const NonnullRefPtr& other) { - if (m_ptr != other.ptr()) { - deref_if_not_null(m_ptr); - m_ptr = const_cast(static_cast(other.ptr())); - m_ptr->ref(); - } + NonnullRefPtr ptr(other); + swap(ptr); return *this; } NonnullRefPtr& operator=(NonnullRefPtr&& other) { - if (this != &other) { - deref_if_not_null(m_ptr); - m_ptr = &other.leak_ref(); - } + NonnullRefPtr ptr(move(other)); + swap(ptr); return *this; } template NonnullRefPtr& operator=(NonnullRefPtr&& other) { - if (this != static_cast(&other)) { - deref_if_not_null(m_ptr); - m_ptr = static_cast(&other.leak_ref()); - } + NonnullRefPtr ptr(move(other)); + swap(ptr); return *this; } NonnullRefPtr& operator=(const T& object) { - if (m_ptr != &object) { - deref_if_not_null(m_ptr); - m_ptr = const_cast(&object); - m_ptr->ref(); - } + NonnullRefPtr ptr(object); + swap(ptr); return *this; } [[nodiscard]] CALLABLE_WHEN(unconsumed) - SET_TYPESTATE(consumed) - T& leak_ref() + SET_TYPESTATE(consumed) + T& leak_ref() { ASSERT(m_ptr); return *exchange(m_ptr, nullptr); } - CALLABLE_WHEN("unconsumed","unknown") + CALLABLE_WHEN("unconsumed", "unknown") T* ptr() { ASSERT(m_ptr); return m_ptr; } - CALLABLE_WHEN("unconsumed","unknown") + CALLABLE_WHEN("unconsumed", "unknown") const T* ptr() const { ASSERT(m_ptr); @@ -250,6 +238,17 @@ public: operator bool() const = delete; bool operator!() const = delete; + void swap(NonnullRefPtr& other) + { + ::swap(m_ptr, other.m_ptr); + } + + template + void swap(NonnullRefPtr& other) + { + ::swap(m_ptr, other.m_ptr); + } + private: NonnullRefPtr() = delete; diff --git a/AK/Tests/TestNonnullRefPtr.cpp b/AK/Tests/TestNonnullRefPtr.cpp index a37e82ce2b..a8c77ae1ea 100644 --- a/AK/Tests/TestNonnullRefPtr.cpp +++ b/AK/Tests/TestNonnullRefPtr.cpp @@ -59,4 +59,18 @@ TEST_CASE(assign_reference) EXPECT_EQ(object->ref_count(), 1); } +TEST_CASE(assign_owner_of_self) +{ + struct Object : public RefCounted { + RefPtr parent; + }; + + auto parent = adopt(*new Object); + auto child = adopt(*new Object); + child->parent = move(parent); + + child = *child->parent; + EXPECT_EQ(child->ref_count(), 1); +} + TEST_MAIN(String)