From 8263e0a6199213e9ba5a30ea07a1ce461d78c85a Mon Sep 17 00:00:00 2001 From: Dan Klishch Date: Sun, 4 Feb 2024 15:09:00 -0500 Subject: [PATCH] AK: Introduce AK::Coroutine --- AK/Coroutine.h | 243 +++++++++++++++++++ AK/Forward.h | 3 + Meta/CMake/common_compile_options.cmake | 1 + Tests/AK/CMakeLists.txt | 1 + Tests/AK/TestCoroutine.cpp | 296 ++++++++++++++++++++++++ Userland/Libraries/LibCore/EventLoop.h | 13 ++ 6 files changed, 557 insertions(+) create mode 100644 AK/Coroutine.h create mode 100644 Tests/AK/TestCoroutine.cpp diff --git a/AK/Coroutine.h b/AK/Coroutine.h new file mode 100644 index 0000000000..3fac791a56 --- /dev/null +++ b/AK/Coroutine.h @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2024, Dan Klishch + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include + +namespace AK { + +namespace Detail { +struct SuspendNever { + // Even though we set -fno-exceptions, Clang really wants these to be noexcept. + bool await_ready() const noexcept { return true; } + void await_suspend(std::coroutine_handle<>) const noexcept { } + void await_resume() const noexcept { } +}; + +struct SymmetricControlTransfer { + SymmetricControlTransfer(std::coroutine_handle<> handle) + : m_handle(handle ? handle : std::noop_coroutine()) + { + } + + bool await_ready() const noexcept { return false; } + auto await_suspend(std::coroutine_handle<>) const noexcept { return m_handle; } + void await_resume() const noexcept { } + + std::coroutine_handle<> m_handle; +}; + +template +struct TryAwaiter; + +template +struct ValueHolder { + alignas(T) u8 m_return_value[sizeof(T)]; +}; + +template<> +struct ValueHolder { }; +} + +template +class [[nodiscard]] Coroutine : private Detail::ValueHolder { + struct CoroutinePromiseVoid; + struct CoroutinePromiseValue; + + AK_MAKE_NONCOPYABLE(Coroutine); + +public: + using ReturnType = T; + using promise_type = Conditional, CoroutinePromiseVoid, CoroutinePromiseValue>; + + ~Coroutine() + { + VERIFY(await_ready()); + if constexpr (!SameAs) + return_value()->~T(); + if (m_handle) + m_handle.destroy(); + } + + Coroutine(Coroutine&& other) + { + m_handle = AK::exchange(other.m_handle, {}); + if (!await_ready()) + m_handle.promise().m_coroutine = this; + else if constexpr (!IsVoid) + new (return_value()) T(move(*other.return_value())); + } + + Coroutine& operator=(Coroutine&& other) + { + if (this != &other) { + this->~Coroutine(); + new (this) Coroutine(move(other)); + } + return *this; + } + + bool await_ready() const + { + return !m_handle || m_handle.done(); + } + + void await_suspend(std::coroutine_handle<> awaiter) + { + m_handle.promise().m_awaiter = awaiter; + } + + // Do NOT bind the result of await_resume() on a temporary coroutine (or the result of CO_TRY) to auto&&! + [[nodiscard]] decltype(auto) await_resume() + { + if constexpr (SameAs) + return; + else + return static_cast(*return_value()); + } + +private: + template + friend struct Detail::TryAwaiter; + + // You cannot just have return_value and return_void defined in the same promise type because C++. + struct CoroutinePromiseBase { + CoroutinePromiseBase() = default; + + Coroutine get_return_object() + { + return { std::coroutine_handle::from_promise(*static_cast(this)) }; + } + + Detail::SuspendNever initial_suspend() { return {}; } + + Detail::SymmetricControlTransfer final_suspend() noexcept + { + return { m_awaiter }; + } + + std::coroutine_handle<> m_awaiter; + Coroutine* m_coroutine { nullptr }; + }; + + struct CoroutinePromiseValue : CoroutinePromiseBase { + template + requires requires { { T(forward(declval())) }; } + void return_value(U&& returned_object) + { + new (this->m_coroutine->return_value()) T(forward(returned_object)); + } + + void return_value(T&& returned_object) + { + new (this->m_coroutine->return_value()) T(move(returned_object)); + } + }; + + struct CoroutinePromiseVoid : CoroutinePromiseBase { + void return_void() { } + }; + + Coroutine(std::coroutine_handle&& handle) + : m_handle(move(handle)) + { + m_handle.promise().m_coroutine = this; + } + + T* return_value() + { + return reinterpret_cast(this->m_return_value); + } + + std::coroutine_handle m_handle; +}; + +template +T must_sync(Coroutine>&& coroutine) +{ + VERIFY(coroutine.await_ready()); + auto&& object = coroutine.await_resume(); + VERIFY(!object.is_error()); + return object.release_value(); +} + +namespace Detail { +template +struct TryAwaiter { + TryAwaiter(T& expression) + requires(!IsSpecializationOf) + : m_expression(&expression) + { + } + + TryAwaiter(T&& expression) + requires(!IsSpecializationOf) + : m_expression(&expression) + { + } + + bool await_ready() { return false; } + + template + requires IsSpecializationOf + std::coroutine_handle<> await_suspend(std::coroutine_handle handle) + { + if (!m_expression->is_error()) { + return handle; + } else { + auto awaiter = handle.promise().m_awaiter; + auto* coroutine = handle.promise().m_coroutine; + using ReturnType = RemoveReference::ReturnType; + static_assert(IsSpecializationOf, + "CO_TRY can only be used inside functions returning a specialization of ErrorOr"); + + // Move error to the user-visible AK::Coroutine + new (coroutine->return_value()) ReturnType(m_expression->release_error()); + // ... and tell it that there's a result available. + coroutine->m_handle = {}; + + // Run destructors for locals in the coroutine that failed. + handle.destroy(); + + // Lastly, transfer control to the parent (or nothing, if parent is not yet suspended). + if (awaiter) + return awaiter; + return std::noop_coroutine(); + } + } + + decltype(auto) await_resume() + { + return m_expression->release_value(); + } + + T* m_expression { nullptr }; +}; +} + +#ifdef AK_COMPILER_CLANG +# define CO_TRY(expression) (co_await ::AK::Detail::TryAwaiter { (expression) }) +#else +// GCC cannot handle CO_TRY(...CO_TRY(...)...), this hack ensures that it always has the right type information available. +// FIXME: Remove this once GCC can correctly infer the result type of `co_await TryAwaiter { ... }`. +# define CO_TRY(expression) static_cast(co_await ::AK::Detail::TryAwaiter { (expression) }) + +namespace Detail { +template +auto declval_coro_result(Coroutine&&) -> T; +template +auto declval_coro_result(T&&) -> T; +} + +#endif +} + +#ifdef USING_AK_GLOBALLY +using AK::Coroutine; +#endif diff --git a/AK/Forward.h b/AK/Forward.h index afc4e18dbd..c6fb4b86cd 100644 --- a/AK/Forward.h +++ b/AK/Forward.h @@ -26,6 +26,8 @@ class Bitmap; using ByteBuffer = Detail::ByteBuffer<32>; class CircularBuffer; class ConstrainedStream; +template +class Coroutine; class CountingStream; class DeprecatedFlyString; class ByteString; @@ -163,6 +165,7 @@ using AK::ByteString; using AK::CircularBuffer; using AK::CircularQueue; using AK::ConstrainedStream; +using AK::Coroutine; using AK::CountingStream; using AK::DeprecatedFlyString; using AK::DeprecatedStringCodePointIterator; diff --git a/Meta/CMake/common_compile_options.cmake b/Meta/CMake/common_compile_options.cmake index f60189cffd..84a00d309f 100644 --- a/Meta/CMake/common_compile_options.cmake +++ b/Meta/CMake/common_compile_options.cmake @@ -34,6 +34,7 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang$") add_compile_options(-Wno-implicit-const-int-float-conversion) add_compile_options(-Wno-user-defined-literals) add_compile_options(-Wno-vla-cxx-extension) + add_compile_options(-Wno-coroutine-missing-unhandled-exception) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") # Only ignore expansion-to-defined for g++, clang's implementation doesn't complain about function-like macros add_compile_options(-Wno-expansion-to-defined) diff --git a/Tests/AK/CMakeLists.txt b/Tests/AK/CMakeLists.txt index ec85150dae..276924d777 100644 --- a/Tests/AK/CMakeLists.txt +++ b/Tests/AK/CMakeLists.txt @@ -20,6 +20,7 @@ set(AK_TEST_SOURCES TestCircularDeque.cpp TestCircularQueue.cpp TestComplex.cpp + TestCoroutine.cpp TestDisjointChunks.cpp TestDistinctNumeric.cpp TestDoublyLinkedList.cpp diff --git a/Tests/AK/TestCoroutine.cpp b/Tests/AK/TestCoroutine.cpp new file mode 100644 index 0000000000..44a043d726 --- /dev/null +++ b/Tests/AK/TestCoroutine.cpp @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2024, Dan Klishch + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include +#include +#include + +namespace { +Coroutine id(int a) +{ + co_return a; +} + +Coroutine sum(int a, int b) +{ + int c = co_await id(a); + int d = co_await id(b); + co_return c + d; +} +} + +TEST_CASE(no_spin) +{ + auto coro = sum(2, 3); + EXPECT(coro.await_ready()); + EXPECT_EQ(coro.await_resume(), 5); +} + +namespace { +struct LoopSpinner { + bool await_ready() const { return false; } + void await_suspend(std::coroutine_handle<> awaiter) + { + Core::deferred_invoke([awaiter] { + awaiter.resume(); + }); + } + void await_resume() { } +}; + +Coroutine loop_spinner() +{ + co_await LoopSpinner {}; + co_return 42; +} + +Coroutine> failing_loop_spinner() +{ + co_await LoopSpinner {}; + co_return Error::from_errno(ENOMEM); +} + +Coroutine two_level_loop_spinner() +{ + EXPECT_EQ(co_await loop_spinner(), 42); + co_return 43; +} +} + +TEST_CASE(loop_spinners) +{ + EXPECT_EQ(Core::run_async_in_new_event_loop(loop_spinner), 42); + EXPECT_EQ(Core::run_async_in_new_event_loop(failing_loop_spinner).error().code(), ENOMEM); + EXPECT_EQ(Core::run_async_in_new_event_loop(two_level_loop_spinner), 43); +} + +namespace { +Coroutine spinner1(Vector& result) +{ + result.append(1); + co_await LoopSpinner {}; + result.append(2); + co_return 3; +} + +Coroutine spinner2(Vector& result) +{ + result.append(4); + co_await LoopSpinner {}; + result.append(5); + co_return 6; +} + +Coroutine> interleaved() +{ + Vector result; + + result.append(7); + auto coro1 = spinner1(result); + result.append(8); + auto coro2 = spinner2(result); + result.append(9); + + result.append(co_await coro2); + result.append(co_await coro1); + + co_return result; +} +} + +TEST_CASE(interleaved_coroutines) +{ + EXPECT_EQ(Core::run_async_in_new_event_loop(interleaved), (Vector { 7, 1, 8, 4, 9, 2, 5, 6, 3 })); +} + +namespace { +Coroutine void_coro(int& result) +{ + result = 45; + co_return; +} +} + +TEST_CASE(void_coro) +{ + int result = 0; + auto coro = void_coro(result); + EXPECT(coro.await_ready()); + EXPECT_EQ(result, 45); +} + +namespace { +Coroutine destructors_inner(Vector& order) +{ + ScopeGuard guard = [&] { + order.append(1); + }; + co_await LoopSpinner {}; + order.append(2); + co_return; +} + +Coroutine> destructors_outer() +{ + Vector order; + order.append(3); + co_await destructors_inner(order); + order.append(4); + co_return order; +} +} + +TEST_CASE(destructors_order) +{ + EXPECT_EQ(Core::run_async_in_new_event_loop(destructors_outer), (Vector { 3, 2, 1, 4 })); +} + +namespace { +class Class { + AK_MAKE_NONCOPYABLE(Class); + +public: + Class() + : m_cookie(1) + { + } + + ~Class() + { + VERIFY(m_cookie >= 0); + m_cookie = -1; + AK::taint_for_optimizer(m_cookie); + } + + Class(Class&& other) + { + VERIFY(other.m_cookie >= 0); + m_cookie = exchange(other.m_cookie, 0) + 1; + AK::taint_for_optimizer(m_cookie); + } + + Class& operator=(Class&& other) = delete; + + int cookie() { return m_cookie; } + +private: + int m_cookie; +}; + +Coroutine return_class_1() +{ + co_await LoopSpinner {}; + co_return {}; +} + +Coroutine return_class_2() +{ + co_await LoopSpinner {}; + Class c; + co_return c; +} + +Coroutine> return_class_3() +{ + co_await LoopSpinner {}; + co_return Class {}; +} + +Coroutine move_count() +{ + { + auto c = co_await return_class_1(); + // 1. Construct temporary as an argument for return_value. + // 2. Move this temporary into Coroutine. + // 3. Move class from Coroutine to local variable. + EXPECT_EQ(c.cookie(), 3); + } + + { + auto c = co_await return_class_2(); + // 1. Construct new class and store it as a local variable. + // 2. Move this temporary into Coroutine. + // 3. Move class from Coroutine to local variable. + EXPECT_EQ(c.cookie(), 3); + } + + { + auto c_or_error = co_await return_class_3(); + auto c = c_or_error.release_value(); + // 1. Construct temporary as an argument for the constructor of a temporary ErrorOr. + // 2. Move temporary ErrorOr into Coroutine. + // 3. Move ErrorOr from Coroutine to c_or_error. + // 4. Move Class from c_or_error to c. + EXPECT_EQ(c.cookie(), 4); + } +} +} + +TEST_CASE(move_count) +{ + Core::run_async_in_new_event_loop(move_count); +} + +namespace { +Coroutine> co_try_success() +{ + auto c = CO_TRY(co_await return_class_3()); + // 1. Construct temporary as an argument for the constructor of a temporary ErrorOr. + // 2. Move temporary ErrorOr into Coroutine. + // -. Some magic is done in TryAwaiter. + // 3. Move Class from ErrorOr inside Coroutine to c. + EXPECT_EQ(c.cookie(), 3); + co_return {}; +} + +Coroutine> co_try_fail() +{ + ErrorOr error = Error::from_string_literal("ERROR!"); + CO_TRY(error); + co_return {}; +} + +Coroutine> co_try_fail_inner() +{ + co_await LoopSpinner {}; + co_return Error::from_string_literal("ERROR!"); +} + +Coroutine> co_try_fail_async() +{ + CO_TRY(co_await co_try_fail_inner()); + co_return {}; +} +} + +TEST_CASE(co_try) +{ + { + auto result = Core::run_async_in_new_event_loop(co_try_success); + EXPECT(!result.is_error()); + } + + { + auto result = Core::run_async_in_new_event_loop(co_try_fail); + EXPECT(result.is_error()); + } + + { + auto result = Core::run_async_in_new_event_loop(co_try_fail_async); + EXPECT(result.is_error()); + } +} + +namespace { +Coroutine nothing() { co_return; } +} + +TEST_CASE(move_void_coroutine) +{ + auto void_coro = nothing(); + auto moved = move(void_coro); + EXPECT(moved.await_ready()); +} diff --git a/Userland/Libraries/LibCore/EventLoop.h b/Userland/Libraries/LibCore/EventLoop.h index 61cbfafcf9..b69f9a756e 100644 --- a/Userland/Libraries/LibCore/EventLoop.h +++ b/Userland/Libraries/LibCore/EventLoop.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -103,4 +104,16 @@ private: void deferred_invoke(ESCAPING Function); +template +requires(IsSpecializationOf, Coroutine>) +auto run_async_in_new_event_loop(T&& function) +{ + Core::EventLoop loop; + auto coro = function(); + loop.spin_until([&] { + return coro.await_ready(); + }); + return coro.await_resume(); +} + }