AK: Introduce AK::Generator

This commit is contained in:
Dan Klishch 2024-05-14 01:00:33 -04:00 committed by Ali Mohammad Pur
parent aeb17a864e
commit 0f1d67f2f0
4 changed files with 274 additions and 0 deletions

View file

@ -31,6 +31,12 @@ struct SuspendNever {
void await_resume() const noexcept { }
};
struct SuspendAlways {
bool await_ready() const noexcept { return false; }
void await_suspend(std::coroutine_handle<>) const noexcept { }
void await_resume() const noexcept { }
};
struct SymmetricControlTransfer {
SymmetricControlTransfer(std::coroutine_handle<> handle)
: m_handle(handle ? handle : std::noop_coroutine())

213
AK/Generator.h Normal file
View file

@ -0,0 +1,213 @@
/*
* Copyright (c) 2024, Dan Klishch <danilklishch@gmail.com>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#pragma once
#include <AK/Coroutine.h>
#include <AK/Variant.h>
namespace AK {
namespace Detail {
class YieldAwaiter {
public:
YieldAwaiter(std::coroutine_handle<> control_transfer, std::coroutine_handle<>& awaiter)
: m_control_transfer(control_transfer)
, m_awaiter(awaiter)
{
}
bool await_ready() const { return false; }
auto await_suspend(std::coroutine_handle<> handle)
{
m_awaiter = handle;
return m_control_transfer;
}
void await_resume() { }
private:
std::coroutine_handle<> m_control_transfer;
std::coroutine_handle<>& m_awaiter;
};
}
template<typename Y, typename R>
class [[nodiscard]] Generator {
struct GeneratorPromiseType;
AK_MAKE_NONCOPYABLE(Generator);
public:
using YieldType = Y;
using ReturnType = R;
using promise_type = GeneratorPromiseType;
~Generator()
{
destroy_stored_object();
if (m_handle)
m_handle.destroy();
}
Generator(Generator&& other)
{
m_handle = AK::exchange(other.m_handle, {});
m_read_returned_object = exchange(other.m_read_returned_object, false);
m_currently_stored_type = other.m_currently_stored_type;
if (m_currently_stored_type == CurrentlyStoredType::Yield) {
new (m_data) YieldType(move(*reinterpret_cast<YieldType*>(other.m_data)));
} else if (m_currently_stored_type == CurrentlyStoredType::Return) {
new (m_data) ReturnType(move(*reinterpret_cast<ReturnType*>(other.m_data)));
}
other.destroy_stored_object();
if (m_handle)
m_handle.promise().m_coroutine = this;
}
Generator& operator=(Generator&& other)
{
if (this != &other) {
this->~Generator();
new (this) Generator(move(other));
}
return *this;
}
bool is_done() const { return !m_handle || m_handle.done(); }
void destroy()
{
VERIFY(m_handle && !m_handle.promise().m_awaiter);
destroy_stored_object();
m_handle.destroy();
m_handle = {};
}
Coroutine<Variant<Y, R>> next()
{
if (!is_done()) {
VERIFY(m_currently_stored_type != CurrentlyStoredType::Return);
co_await Detail::YieldAwaiter { m_handle, m_handle.promise().m_awaiter };
if (m_handle)
m_handle.promise().m_awaiter = {};
}
if (is_done()) {
VERIFY(m_currently_stored_type == CurrentlyStoredType::Return && !m_read_returned_object);
m_read_returned_object = true;
co_return move(*reinterpret_cast<ReturnType*>(m_data));
} else {
VERIFY(m_currently_stored_type == CurrentlyStoredType::Yield);
co_return move(*reinterpret_cast<YieldType*>(m_data));
}
}
private:
template<typename U>
friend struct Detail::TryAwaiter;
struct GeneratorPromiseType {
Generator get_return_object()
{
return { std::coroutine_handle<promise_type>::from_promise(*this) };
}
Detail::SuspendAlways initial_suspend() { return {}; }
Detail::SymmetricControlTransfer final_suspend() noexcept
{
VERIFY(m_awaiter);
return { m_awaiter };
}
template<typename U>
requires requires { { T(forward<U>(declval<U>())) }; }
void return_value(U&& returned_object)
{
m_coroutine->place_returned_object(forward<U>(returned_object));
}
void return_value(ReturnType&& returned_object)
{
m_coroutine->place_returned_object(move(returned_object));
}
Detail::SymmetricControlTransfer yield_value(YieldType&& yield_value)
{
m_coroutine->place_yield_object(move(yield_value));
VERIFY(m_awaiter);
return { m_awaiter };
}
std::coroutine_handle<> m_awaiter;
Generator* m_coroutine { nullptr }; // Must be named `m_coroutine` for CO_TRY to work
};
Generator(std::coroutine_handle<promise_type>&& handle)
: m_handle(move(handle))
{
m_handle.promise().m_coroutine = this;
}
void destroy_stored_object()
{
switch (m_currently_stored_type) {
case CurrentlyStoredType::Empty:
break;
case CurrentlyStoredType::Yield:
reinterpret_cast<YieldType*>(m_data)->~YieldType();
break;
case CurrentlyStoredType::Return:
reinterpret_cast<ReturnType*>(m_data)->~ReturnType();
break;
}
m_currently_stored_type = CurrentlyStoredType::Empty;
}
template<typename... Args>
YieldType* place_yield_object(Args&&... args)
{
destroy_stored_object();
m_currently_stored_type = CurrentlyStoredType::Yield;
return new (m_data) YieldType(forward<Args>(args)...);
}
template<typename... Args>
ReturnType* place_returned_object(Args&&... args)
{
destroy_stored_object();
m_currently_stored_type = CurrentlyStoredType::Return;
return new (m_data) ReturnType(forward<Args>(args)...);
}
ReturnType* return_value() // Must be defined for CO_TRY.
{
destroy_stored_object();
m_currently_stored_type = CurrentlyStoredType::Return;
return reinterpret_cast<ReturnType*>(m_data);
}
std::coroutine_handle<promise_type> m_handle;
enum class CurrentlyStoredType {
Empty,
Yield,
Return,
} m_currently_stored_type
= CurrentlyStoredType::Empty;
bool m_read_returned_object { false };
alignas(max(alignof(YieldType), alignof(ReturnType))) u8 m_data[max(sizeof(YieldType), sizeof(ReturnType))];
};
}
#ifdef USING_AK_GLOBALLY
using AK::Generator;
#endif

View file

@ -35,6 +35,7 @@ set(AK_TEST_SOURCES
TestFlyString.cpp
TestFormat.cpp
TestFuzzyMatch.cpp
TestGeneratorAK.cpp
TestGenericLexer.cpp
TestHashFunctions.cpp
TestHashMap.cpp

View file

@ -0,0 +1,54 @@
/*
* Copyright (c) 2024, Dan Klishch <danilklishch@gmail.com>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/Generator.h>
#include <LibTest/AsyncTestCase.h>
namespace {
Generator<int, Empty> generate_sync(Vector<int>& order)
{
ScopeGuard guard = [&] {
order.append(7);
};
order.append(2);
co_yield 1;
order.append(4);
co_yield 2;
order.append(6);
co_return {};
}
}
ASYNC_TEST_CASE(sync_order)
{
Vector<int> order;
auto gen = generate_sync(order);
EXPECT(!gen.is_done());
order.append(1);
auto result1 = gen.next();
order.append(3);
EXPECT(result1.await_ready());
EXPECT_EQ(result1.await_resume(), 1);
auto result2 = gen.next();
order.append(5);
EXPECT(result2.await_ready());
EXPECT_EQ(result2.await_resume(), 2);
auto end = gen.next();
order.append(8);
EXPECT(end.await_ready());
EXPECT_EQ(end.await_resume(), Empty {});
EXPECT_EQ(order, (Vector<int> { 1, 2, 3, 4, 5, 6, 7, 8 }));
co_return;
}