AK: Allow AK::Variant::visit to return a value

This changes Variant::visit() to forward the value returned by the
selected visitor invocation. By perfectly forwarding the returned value,
this allows for the visitor to return by value or reference.

Note that all provided visitors must return the same type - the compiler
will otherwise fail with the message: "inconsistent deduction for auto
return type".
This commit is contained in:
Timothy Flynn 2021-05-19 11:28:27 -04:00 committed by Andreas Kling
parent 585e7890cd
commit 145e246a5e
2 changed files with 79 additions and 27 deletions

View file

@ -9,6 +9,7 @@
#include <AK/Array.h>
#include <AK/BitCast.h>
#include <AK/StdLibExtras.h>
#include <AK/TypeList.h>
namespace AK::Detail {
@ -68,24 +69,6 @@ struct Variant<IndexType, InitialIndex, F, Ts...> {
else
Variant<IndexType, InitialIndex + 1, Ts...>::copy_(old_id, old_data, new_data);
}
template<typename Visitor>
static void visit_(IndexType id, void* data, Visitor&& visitor)
{
if (id == current_index)
visitor(*bit_cast<F*>(data));
else
Variant<IndexType, InitialIndex + 1, Ts...>::visit_(id, data, forward<Visitor>(visitor));
}
template<typename Visitor>
static void visit_(IndexType id, const void* data, Visitor&& visitor)
{
if (id == current_index)
visitor(*bit_cast<const F*>(data));
else
Variant<IndexType, InitialIndex + 1, Ts...>::visit_(id, data, forward<Visitor>(visitor));
}
};
template<typename IndexType, IndexType InitialIndex>
@ -93,10 +76,23 @@ struct Variant<IndexType, InitialIndex> {
static void delete_(IndexType, void*) { }
static void move_(IndexType, void*, void*) { }
static void copy_(IndexType, const void*, void*) { }
template<typename Visitor>
static void visit_(IndexType, void*, Visitor&&) { }
template<typename Visitor>
static void visit_(IndexType, const void*, Visitor&&) { }
};
template<typename IndexType, typename... Ts>
struct VisitImpl {
template<typename Visitor, IndexType CurrentIndex = 0>
static constexpr inline decltype(auto) visit(IndexType id, const void* data, Visitor&& visitor) requires(CurrentIndex < sizeof...(Ts))
{
using T = typename TypeList<Ts...>::template Type<CurrentIndex>;
if (id == CurrentIndex)
return visitor(*bit_cast<T*>(data));
if constexpr ((CurrentIndex + 1) < sizeof...(Ts))
return visit<Visitor, CurrentIndex + 1>(id, data, forward<Visitor>(visitor));
else
VERIFY_NOT_REACHED();
}
};
struct VariantNoClearTag {
@ -310,17 +306,17 @@ public:
}
template<typename... Fs>
void visit(Fs&&... functions)
decltype(auto) visit(Fs&&... functions)
{
Visitor<Fs...> visitor { forward<Fs>(functions)... };
Helper::visit_(m_index, m_data, visitor);
return VisitHelper::visit(m_index, m_data, move(visitor));
}
template<typename... Fs>
void visit(Fs&&... functions) const
decltype(auto) visit(Fs&&... functions) const
{
Visitor<Fs...> visitor { forward<Fs>(functions)... };
Helper::visit_(m_index, m_data, visitor);
return VisitHelper::visit(m_index, m_data, move(visitor));
}
template<typename... NewTs>
@ -357,6 +353,7 @@ private:
static constexpr auto data_size = integer_sequence_generate_array<size_t>(0, IntegerSequence<size_t, sizeof(Ts)...>()).max();
static constexpr auto data_alignment = integer_sequence_generate_array<size_t>(0, IntegerSequence<size_t, alignof(Ts)...>()).max();
using Helper = Detail::Variant<IndexType, 0, Ts...>;
using VisitHelper = Detail::VisitImpl<IndexType, Ts...>;
explicit Variant(IndexType index, Detail::VariantConstructTag)
: Detail::MergeAndDeduplicatePacks<Detail::VariantConstructors<Ts, Variant<Ts...>>...>()
@ -367,7 +364,7 @@ private:
template<typename... Fs>
struct Visitor : Fs... {
Visitor(Fs&&... args)
: Fs(args)...
: Fs(forward<Fs>(args))...
{
}

View file

@ -6,8 +6,16 @@
#include <LibTest/TestSuite.h>
#include <AK/RefPtr.h>
#include <AK/Variant.h>
namespace {
struct Object : public RefCounted<Object> {
};
}
TEST_CASE(basic)
{
Variant<int, String> the_value { 42 };
@ -117,3 +125,50 @@ TEST_CASE(duplicated_types)
EXPECT(its_just_an_int.has<int>());
EXPECT_EQ(its_just_an_int.get<int>(), 42);
}
TEST_CASE(return_values)
{
using MyVariant = Variant<int, String, float>;
{
MyVariant the_value { 42.0f };
float value = the_value.visit(
[&](const int&) { return 1.0f; },
[&](const String&) { return 2.0f; },
[&](const float& f) { return f; });
EXPECT_EQ(value, 42.0f);
}
{
MyVariant the_value { 42 };
int value = the_value.visit(
[&](int& i) { return i; },
[&](String&) { return 2; },
[&](float&) { return 3; });
EXPECT_EQ(value, 42);
}
{
const MyVariant the_value { "str" };
String value = the_value.visit(
[&](const int&) { return String { "wrong" }; },
[&](const String& s) { return s; },
[&](const float&) { return String { "wrong" }; });
EXPECT_EQ(value, "str");
}
}
TEST_CASE(return_values_by_reference)
{
auto ref = adopt_ref_if_nonnull(new Object());
Variant<int, String, float> the_value { 42.0f };
auto& value = the_value.visit(
[&](const int&) -> RefPtr<Object>& { return ref; },
[&](const String&) -> RefPtr<Object>& { return ref; },
[&](const float&) -> RefPtr<Object>& { return ref; });
EXPECT_EQ(ref, value);
EXPECT_EQ(ref->ref_count(), 1u);
EXPECT_EQ(value->ref_count(), 1u);
}