/* * Copyright (c) 2024, Dan Klishch * * SPDX-License-Identifier: BSD-2-Clause */ #pragma once #include #include #include namespace AK { namespace Detail { // FIXME: GCC ICEs when a simpler implementation of CO_TRY_OR_FAIL is used. See also LibTest/AsyncTestCase.h. #ifdef AK_COMPILER_GCC namespace Test { template struct TryOrFailAwaiter; } #endif struct SuspendNever { // Even though we set -fno-exceptions, Clang really wants these to be noexcept. bool await_ready() const noexcept { return true; } void await_suspend(std::coroutine_handle<>) const noexcept { } void await_resume() const noexcept { } }; struct SuspendAlways { bool await_ready() const noexcept { return false; } void await_suspend(std::coroutine_handle<>) const noexcept { } void await_resume() const noexcept { } }; struct SymmetricControlTransfer { SymmetricControlTransfer(std::coroutine_handle<> handle) : m_handle(handle ? handle : std::noop_coroutine()) { } bool await_ready() const noexcept { return false; } auto await_suspend(std::coroutine_handle<>) const noexcept { return m_handle; } void await_resume() const noexcept { } std::coroutine_handle<> m_handle; }; template struct TryAwaiter; template struct ValueHolder { alignas(T) u8 m_return_value[sizeof(T)]; }; template<> struct ValueHolder { }; } template class [[nodiscard]] Coroutine : private Detail::ValueHolder { struct CoroutinePromiseVoid; struct CoroutinePromiseValue; AK_MAKE_NONCOPYABLE(Coroutine); public: using ReturnType = T; using promise_type = Conditional, CoroutinePromiseVoid, CoroutinePromiseValue>; ~Coroutine() { VERIFY(await_ready()); if constexpr (!SameAs) return_value()->~T(); if (m_handle) m_handle.destroy(); } Coroutine(Coroutine&& other) { m_handle = AK::exchange(other.m_handle, {}); if (!await_ready()) m_handle.promise().m_coroutine = this; else if constexpr (!IsVoid) new (return_value()) T(move(*other.return_value())); } Coroutine& operator=(Coroutine&& other) { if (this != &other) { this->~Coroutine(); new (this) Coroutine(move(other)); } return *this; } bool await_ready() const { return !m_handle || m_handle.done(); } void await_suspend(std::coroutine_handle<> awaiter) { m_handle.promise().m_awaiter = awaiter; } // Do NOT bind the result of await_resume() on a temporary coroutine (or the result of CO_TRY) to auto&&! [[nodiscard]] decltype(auto) await_resume() { if constexpr (SameAs) return; else return static_cast(*return_value()); } private: template friend struct Detail::TryAwaiter; #ifdef AK_COMPILER_GCC template friend struct AK::Detail::Test::TryOrFailAwaiter; #endif // You cannot just have return_value and return_void defined in the same promise type because C++. struct CoroutinePromiseBase { CoroutinePromiseBase() = default; Coroutine get_return_object() { return { std::coroutine_handle::from_promise(*static_cast(this)) }; } Detail::SuspendNever initial_suspend() { return {}; } Detail::SymmetricControlTransfer final_suspend() noexcept { return { m_awaiter }; } std::coroutine_handle<> m_awaiter; Coroutine* m_coroutine { nullptr }; }; struct CoroutinePromiseValue : CoroutinePromiseBase { template requires requires { { T(forward(declval())) }; } void return_value(U&& returned_object) { new (this->m_coroutine->return_value()) T(forward(returned_object)); } void return_value(T&& returned_object) { new (this->m_coroutine->return_value()) T(move(returned_object)); } }; struct CoroutinePromiseVoid : CoroutinePromiseBase { void return_void() { } }; Coroutine(std::coroutine_handle&& handle) : m_handle(move(handle)) { m_handle.promise().m_coroutine = this; } T* return_value() { return reinterpret_cast(this->m_return_value); } std::coroutine_handle m_handle; }; template T must_sync(Coroutine>&& coroutine) { VERIFY(coroutine.await_ready()); auto&& object = coroutine.await_resume(); VERIFY(!object.is_error()); return object.release_value(); } namespace Detail { template struct TryAwaiter { TryAwaiter(T& expression) requires(!IsSpecializationOf) : m_expression(&expression) { } TryAwaiter(T&& expression) requires(!IsSpecializationOf) : m_expression(&expression) { } bool await_ready() { return false; } template requires IsSpecializationOf std::coroutine_handle<> await_suspend(std::coroutine_handle handle) { if (!m_expression->is_error()) { return handle; } else { auto awaiter = handle.promise().m_awaiter; auto* coroutine = handle.promise().m_coroutine; using ReturnType = RemoveReference::ReturnType; static_assert(IsSpecializationOf, "CO_TRY can only be used inside functions returning a specialization of ErrorOr"); // Move error to the user-visible AK::Coroutine new (coroutine->return_value()) ReturnType(m_expression->release_error()); // ... and tell it that there's a result available. coroutine->m_handle = {}; // Run destructors for locals in the coroutine that failed. handle.destroy(); // Lastly, transfer control to the parent (or nothing, if parent is not yet suspended). if (awaiter) return awaiter; return std::noop_coroutine(); } } decltype(auto) await_resume() { return m_expression->release_value(); } T* m_expression { nullptr }; }; } #ifdef AK_COMPILER_CLANG # define CO_TRY(expression) (co_await ::AK::Detail::TryAwaiter { (expression) }) #else // GCC cannot handle CO_TRY(...CO_TRY(...)...), this hack ensures that it always has the right type information available. // FIXME: Remove this once GCC can correctly infer the result type of `co_await TryAwaiter { ... }`. # define CO_TRY(expression) static_cast(co_await ::AK::Detail::TryAwaiter { (expression) }) namespace Detail { template auto declval_coro_result(Coroutine&&) -> T; template auto declval_coro_result(T&&) -> T; } #endif } #ifdef USING_AK_GLOBALLY using AK::Coroutine; #endif