diff --git a/Userland/Libraries/LibCore/EventLoopImplementationUnix.cpp b/Userland/Libraries/LibCore/EventLoopImplementationUnix.cpp index 59154c2e76..34c7a4e240 100644 --- a/Userland/Libraries/LibCore/EventLoopImplementationUnix.cpp +++ b/Userland/Libraries/LibCore/EventLoopImplementationUnix.cpp @@ -239,10 +239,10 @@ struct ThreadData { return *data; } - static ThreadData& for_thread(pthread_t thread_id) + static ThreadData* for_thread(pthread_t thread_id) { pthread_rwlock_rdlock(&*s_thread_data_lock); - auto& result = *s_thread_data.get(thread_id).value(); + auto result = s_thread_data.get(thread_id).value_or(nullptr); pthread_rwlock_unlock(&*s_thread_data_lock); return result; } @@ -644,7 +644,10 @@ intptr_t EventLoopManagerUnix::register_timer(EventReceiver& object, int millise void EventLoopManagerUnix::unregister_timer(intptr_t timer_id) { auto* timer = bit_cast(timer_id); - auto& thread_data = ThreadData::for_thread(timer->owner_thread); + auto thread_data_ptr = ThreadData::for_thread(timer->owner_thread); + if (!thread_data_ptr) + return; + auto& thread_data = *thread_data_ptr; auto expected = false; if (timer->is_being_deleted.compare_exchange_strong(expected, true, AK::MemoryOrder::memory_order_acq_rel)) { if (timer->is_scheduled()) @@ -670,8 +673,11 @@ void EventLoopManagerUnix::register_notifier(Notifier& notifier) void EventLoopManagerUnix::unregister_notifier(Notifier& notifier) { - auto& thread_data = ThreadData::for_thread(notifier.owner_thread()); + auto thread_data_ptr = ThreadData::for_thread(notifier.owner_thread()); + if (!thread_data_ptr) + return; + auto& thread_data = *thread_data_ptr; auto it = thread_data.notifier_by_ptr.find(¬ifier); VERIFY(it != thread_data.notifier_by_ptr.end()); diff --git a/Userland/Libraries/LibCore/SOCKSProxyClient.cpp b/Userland/Libraries/LibCore/SOCKSProxyClient.cpp index db738c47ba..151a2c1bec 100644 --- a/Userland/Libraries/LibCore/SOCKSProxyClient.cpp +++ b/Userland/Libraries/LibCore/SOCKSProxyClient.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include @@ -125,27 +126,27 @@ StringView reply_response_name(Reply reply) VERIFY_NOT_REACHED(); } -ErrorOr send_version_identifier_and_method_selection_message(Core::Socket& socket, Core::SOCKSProxyClient::Version version, Method method) +Coroutine> send_version_identifier_and_method_selection_message(Core::Socket& socket, Core::SOCKSProxyClient::Version version, Method method) { Socks5VersionIdentifierAndMethodSelectionMessage message { .version_identifier = to_underlying(version), .method_count = 1, .methods = { to_underlying(method) }, }; - TRY(socket.write_value(message)); + CO_TRY(socket.write_value(message)); - auto response = TRY(socket.read_value()); + auto response = CO_TRY(socket.read_value()); if (response.version_identifier != to_underlying(version)) - return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); + co_return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); if (response.method != to_underlying(method)) - return Error::from_string_literal("SOCKS negotiation failed: Failed to negotiate a method"); + co_return Error::from_string_literal("SOCKS negotiation failed: Failed to negotiate a method"); - return {}; + co_return {}; } -ErrorOr send_connect_request_message(Core::Socket& socket, Core::SOCKSProxyClient::Version version, Core::SOCKSProxyClient::HostOrIPV4 target, int port, Core::SOCKSProxyClient::Command command) +Coroutine> send_connect_request_message(Core::Socket& socket, Core::SOCKSProxyClient::Version version, Core::SOCKSProxyClient::HostOrIPV4 target, int port, Core::SOCKSProxyClient::Command command) { AllocatingMemoryStream stream; @@ -158,89 +159,89 @@ ErrorOr send_connect_request_message(Core::Socket& socket, Core::SOCKSPro .port = htons(port), }; - TRY(stream.write_value(header)); + CO_TRY(stream.write_value(header)); - TRY(target.visit( - [&](ByteString const& hostname) -> ErrorOr { + CO_TRY(co_await target.visit( + [&](ByteString const& hostname) -> Coroutine> { u8 address_data[2]; address_data[0] = to_underlying(AddressType::DomainName); address_data[1] = hostname.length(); - TRY(stream.write_until_depleted({ address_data, sizeof(address_data) })); - TRY(stream.write_until_depleted({ hostname.characters(), hostname.length() })); - return {}; + CO_TRY(stream.write_until_depleted({ address_data, sizeof(address_data) })); + CO_TRY(stream.write_until_depleted({ hostname.characters(), hostname.length() })); + co_return {}; }, - [&](u32 ipv4) -> ErrorOr { + [&](u32 ipv4) -> Coroutine> { u8 address_data[5]; address_data[0] = to_underlying(AddressType::IPV4); u32 network_ordered_ipv4 = NetworkOrdered(ipv4); memcpy(address_data + 1, &network_ordered_ipv4, sizeof(network_ordered_ipv4)); - TRY(stream.write_until_depleted({ address_data, sizeof(address_data) })); - return {}; + CO_TRY(stream.write_until_depleted({ address_data, sizeof(address_data) })); + co_return {}; })); - TRY(stream.write_value(trailer)); + CO_TRY(stream.write_value(trailer)); - auto buffer = TRY(ByteBuffer::create_uninitialized(stream.used_buffer_size())); - TRY(stream.read_until_filled(buffer.bytes())); - TRY(socket.write_until_depleted(buffer)); + auto buffer = CO_TRY(ByteBuffer::create_uninitialized(stream.used_buffer_size())); + CO_TRY(stream.read_until_filled(buffer.bytes())); + CO_TRY(socket.write_until_depleted(buffer)); - auto response_header = TRY(socket.read_value()); + auto response_header = CO_TRY(socket.read_value()); if (response_header.version_identifier != to_underlying(version)) - return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); + co_return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); - auto response_address_type = TRY(socket.read_value()); + auto response_address_type = CO_TRY(socket.read_value()); switch (AddressType(response_address_type)) { case AddressType::IPV4: { u8 response_address_data[4]; - TRY(socket.read_until_filled({ response_address_data, sizeof(response_address_data) })); + CO_TRY(socket.read_until_filled({ response_address_data, sizeof(response_address_data) })); break; } case AddressType::DomainName: { - auto response_address_length = TRY(socket.read_value()); - auto buffer = TRY(ByteBuffer::create_uninitialized(response_address_length)); - TRY(socket.read_until_filled(buffer)); + auto response_address_length = CO_TRY(socket.read_value()); + auto buffer = CO_TRY(ByteBuffer::create_uninitialized(response_address_length)); + CO_TRY(socket.read_until_filled(buffer)); break; } case AddressType::IPV6: default: - return Error::from_string_literal("SOCKS negotiation failed: Invalid connect response address type"); + co_return Error::from_string_literal("SOCKS negotiation failed: Invalid connect response address type"); } - [[maybe_unused]] auto bound_port = TRY(socket.read_value()); + [[maybe_unused]] auto bound_port = CO_TRY(socket.read_value()); - return Reply(response_header.status); + co_return Reply(response_header.status); } -ErrorOr send_username_password_authentication_message(Core::Socket& socket, Core::SOCKSProxyClient::UsernamePasswordAuthenticationData const& auth_data) +Coroutine> send_username_password_authentication_message(Core::Socket& socket, Core::SOCKSProxyClient::UsernamePasswordAuthenticationData const& auth_data) { AllocatingMemoryStream stream; u8 version = 0x01; - TRY(stream.write_value(version)); + CO_TRY(stream.write_value(version)); u8 username_length = auth_data.username.length(); - TRY(stream.write_value(username_length)); + CO_TRY(stream.write_value(username_length)); - TRY(stream.write_until_depleted({ auth_data.username.characters(), auth_data.username.length() })); + CO_TRY(stream.write_until_depleted({ auth_data.username.characters(), auth_data.username.length() })); u8 password_length = auth_data.password.length(); - TRY(stream.write_value(password_length)); + CO_TRY(stream.write_value(password_length)); - TRY(stream.write_until_depleted({ auth_data.password.characters(), auth_data.password.length() })); + CO_TRY(stream.write_until_depleted({ auth_data.password.characters(), auth_data.password.length() })); - auto buffer = TRY(ByteBuffer::create_uninitialized(stream.used_buffer_size())); - TRY(stream.read_until_filled(buffer.bytes())); + auto buffer = CO_TRY(ByteBuffer::create_uninitialized(stream.used_buffer_size())); + CO_TRY(stream.read_until_filled(buffer.bytes())); - TRY(socket.write_until_depleted(buffer)); + CO_TRY(socket.write_until_depleted(buffer)); - auto response = TRY(socket.read_value()); + auto response = CO_TRY(socket.read_value()); if (response.version_identifier != version) - return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); + co_return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); - return response.status; + co_return response.status; } } @@ -252,49 +253,49 @@ SOCKSProxyClient::~SOCKSProxyClient() m_socket.on_ready_to_read = nullptr; } -ErrorOr> SOCKSProxyClient::connect(Socket& underlying, Version version, HostOrIPV4 const& target, int target_port, Variant const& auth_data, Command command) +Coroutine>> SOCKSProxyClient::async_connect(Socket& underlying, Version version, HostOrIPV4 const& target, int target_port, Variant const& auth_data, Command command) { if (version != Version::V5) - return Error::from_string_literal("SOCKS version not supported"); + co_return Error::from_string_literal("SOCKS version not supported"); - return auth_data.visit( - [&](Empty) -> ErrorOr> { - TRY(send_version_identifier_and_method_selection_message(underlying, version, Method::NoAuth)); - auto reply = TRY(send_connect_request_message(underlying, version, target, target_port, command)); + co_return co_await auth_data.visit( + [&](Empty) -> Coroutine>> { + CO_TRY(co_await send_version_identifier_and_method_selection_message(underlying, version, Method::NoAuth)); + auto reply = CO_TRY(co_await send_connect_request_message(underlying, version, target, target_port, command)); if (reply != Reply::Succeeded) { underlying.close(); - return Error::from_string_view(reply_response_name(reply)); + co_return Error::from_string_view(reply_response_name(reply)); } - return adopt_nonnull_own_or_enomem(new SOCKSProxyClient { + co_return adopt_nonnull_own_or_enomem(new SOCKSProxyClient { underlying, nullptr, }); }, - [&](UsernamePasswordAuthenticationData const& auth_data) -> ErrorOr> { - TRY(send_version_identifier_and_method_selection_message(underlying, version, Method::UsernamePassword)); - auto auth_response = TRY(send_username_password_authentication_message(underlying, auth_data)); + [&](UsernamePasswordAuthenticationData const& auth_data) -> Coroutine>> { + CO_TRY(co_await send_version_identifier_and_method_selection_message(underlying, version, Method::UsernamePassword)); + auto auth_response = CO_TRY(co_await send_username_password_authentication_message(underlying, auth_data)); if (auth_response != 0) { underlying.close(); - return Error::from_string_literal("SOCKS authentication failed"); + co_return Error::from_string_literal("SOCKS authentication failed"); } - auto reply = TRY(send_connect_request_message(underlying, version, target, target_port, command)); + auto reply = CO_TRY(co_await send_connect_request_message(underlying, version, target, target_port, command)); if (reply != Reply::Succeeded) { underlying.close(); - return Error::from_string_view(reply_response_name(reply)); + co_return Error::from_string_view(reply_response_name(reply)); } - return adopt_nonnull_own_or_enomem(new SOCKSProxyClient { + co_return adopt_nonnull_own_or_enomem(new SOCKSProxyClient { underlying, nullptr, }); }); } -ErrorOr> SOCKSProxyClient::connect(HostOrIPV4 const& server, int server_port, Version version, HostOrIPV4 const& target, int target_port, Variant const& auth_data, Command command) +Coroutine>> SOCKSProxyClient::async_connect(HostOrIPV4 const& server, int server_port, Version version, HostOrIPV4 const& target, int target_port, Variant const& auth_data, Command command) { - auto underlying = TRY(server.visit( + auto underlying = CO_TRY(server.visit( [&](u32 ipv4) { return Core::TCPSocket::connect({ IPv4Address(ipv4), static_cast(server_port) }); }, @@ -302,10 +303,10 @@ ErrorOr> SOCKSProxyClient::connect(HostOrIPV4 co return Core::TCPSocket::connect(hostname, static_cast(server_port)); })); - auto socket = TRY(connect(*underlying, version, target, target_port, auth_data, command)); + auto socket = CO_TRY(co_await async_connect(*underlying, version, target, target_port, auth_data, command)); socket->m_own_underlying_socket = move(underlying); - dbgln("SOCKS proxy connected, have {} available bytes", TRY(socket->m_socket.pending_bytes())); - return socket; + dbgln("SOCKS proxy connected, have {} available bytes", CO_TRY(socket->m_socket.pending_bytes())); + co_return socket; } } diff --git a/Userland/Libraries/LibCore/SOCKSProxyClient.h b/Userland/Libraries/LibCore/SOCKSProxyClient.h index cbb816870d..387055ab89 100644 --- a/Userland/Libraries/LibCore/SOCKSProxyClient.h +++ b/Userland/Libraries/LibCore/SOCKSProxyClient.h @@ -34,6 +34,9 @@ public: static ErrorOr> connect(Socket& underlying, Version, HostOrIPV4 const& target, int target_port, Variant const& auth_data = {}, Command = Command::Connect); static ErrorOr> connect(HostOrIPV4 const& server, int server_port, Version, HostOrIPV4 const& target, int target_port, Variant const& auth_data = {}, Command = Command::Connect); + static Coroutine>> async_connect(Socket& underlying, Version, HostOrIPV4 const& target, int target_port, Variant const& auth_data = {}, Command = Command::Connect); + static Coroutine>> async_connect(HostOrIPV4 const& server, int server_port, Version, HostOrIPV4 const& target, int target_port, Variant const& auth_data = {}, Command = Command::Connect); + virtual ~SOCKSProxyClient() override; // ^Stream::Stream diff --git a/Userland/Libraries/LibCore/Socket.cpp b/Userland/Libraries/LibCore/Socket.cpp index e2dc3b0676..b12c4c4610 100644 --- a/Userland/Libraries/LibCore/Socket.cpp +++ b/Userland/Libraries/LibCore/Socket.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include @@ -216,6 +217,16 @@ ErrorOr> TCPSocket::connect(SocketAddress const& addres return socket; } +Coroutine>> TCPSocket::async_connect(Core::SocketAddress const& address) +{ + co_return CO_TRY(connect(address)); +} + +Coroutine>> TCPSocket::async_connect(const AK::ByteString& host, u16 port) +{ + co_return CO_TRY(connect(host, port)); +} + ErrorOr> TCPSocket::adopt_fd(int fd) { if (fd < 0) { diff --git a/Userland/Libraries/LibCore/Socket.h b/Userland/Libraries/LibCore/Socket.h index 133610b841..c782b8d194 100644 --- a/Userland/Libraries/LibCore/Socket.h +++ b/Userland/Libraries/LibCore/Socket.h @@ -160,6 +160,9 @@ public: static ErrorOr> connect(SocketAddress const& address); static ErrorOr> adopt_fd(int fd); + static Coroutine>> async_connect(ByteString const& host, u16 port); + static Coroutine>> async_connect(SocketAddress const& address); + TCPSocket(TCPSocket&& other) : Socket(static_cast(other)) , m_helper(move(other.m_helper)) diff --git a/Userland/Libraries/LibTLS/Socket.cpp b/Userland/Libraries/LibTLS/Socket.cpp index 495d0e0d55..46b62372ea 100644 --- a/Userland/Libraries/LibTLS/Socket.cpp +++ b/Userland/Libraries/LibTLS/Socket.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include #include @@ -50,48 +51,88 @@ ErrorOr TLSv12::write_some(ReadonlyBytes bytes) return bytes.size(); } -ErrorOr> TLSv12::connect(ByteString const& host, u16 port, Options options) +template +struct PromiseAwaiter { + bool await_ready() const { return promise->is_resolved(); } + void await_suspend(std::coroutine_handle<> awaiter) + { + promise->when_resolved([awaiter](auto&) { + Core::deferred_invoke([awaiter] { awaiter.resume(); }); + }); + promise->when_rejected([awaiter](auto&) { + Core::deferred_invoke([awaiter] { awaiter.resume(); }); + }); + } + ErrorOr await_resume() + { + if constexpr (IsVoid) + return {}; + else + return promise->await(); // Already resolved, so this should never yield to the event loop. + } + + NonnullRefPtr> promise; +}; + +Coroutine>> TLSv12::async_connect(ByteString const& host, u16 port, Options options) { auto promise = Core::Promise::construct(); - OwnPtr tcp_socket = TRY(Core::TCPSocket::connect(host, port)); - TRY(tcp_socket->set_blocking(false)); + OwnPtr tcp_socket = CO_TRY(co_await Core::TCPSocket::async_connect(host, port)); + CO_TRY(tcp_socket->set_blocking(false)); auto tls_socket = make(move(tcp_socket), move(options)); tls_socket->set_sni(host); - tls_socket->on_connected = [&] { + tls_socket->on_connected = [=] { promise->resolve({}); }; - tls_socket->on_tls_error = [&](auto alert) { - tls_socket->try_disambiguate_error(); + tls_socket->on_tls_error = [&tls_socket = *tls_socket, promise](auto alert) { + tls_socket.try_disambiguate_error(); promise->reject(AK::Error::from_string_view(enum_to_string(alert))); }; - TRY(promise->await()); + ScopeGuard clear_callbacks = [&tls_socket = *tls_socket] { + tls_socket.on_tls_error = nullptr; + tls_socket.on_connected = nullptr; + }; + + CO_TRY(co_await PromiseAwaiter { promise }); - tls_socket->on_tls_error = nullptr; - tls_socket->on_connected = nullptr; tls_socket->m_context.should_expect_successful_read = true; - return tls_socket; + co_return tls_socket; } -ErrorOr> TLSv12::connect(ByteString const& host, Core::Socket& underlying_stream, Options options) +Coroutine>> TLSv12::async_connect(ByteString const& host, Core::Socket& underlying_stream, Options options) { auto promise = Core::Promise::construct(); - TRY(underlying_stream.set_blocking(false)); + CO_TRY(underlying_stream.set_blocking(false)); auto tls_socket = make(&underlying_stream, move(options)); tls_socket->set_sni(host); - tls_socket->on_connected = [&] { + tls_socket->on_connected = [=] { promise->resolve({}); }; - tls_socket->on_tls_error = [&](auto alert) { + tls_socket->on_tls_error = [&, promise](auto alert) { tls_socket->try_disambiguate_error(); promise->reject(AK::Error::from_string_view(enum_to_string(alert))); }; - TRY(promise->await()); - tls_socket->on_tls_error = nullptr; - tls_socket->on_connected = nullptr; + ScopeGuard clear_callbacks = [&tls_socket = *tls_socket] { + tls_socket.on_tls_error = nullptr; + tls_socket.on_connected = nullptr; + }; + + CO_TRY(co_await PromiseAwaiter { promise }); + tls_socket->m_context.should_expect_successful_read = true; - return tls_socket; + co_return tls_socket; +} + +ErrorOr> TLSv12::connect(const AK::ByteString& host, u16 port, TLS::Options options) +{ + return Core::run_async_in_current_event_loop([&] { return async_connect(host, port, move(options)); }); +} + +ErrorOr> TLSv12::connect(const AK::ByteString& host, Core::Socket& underlying_stream, TLS::Options options) +{ + return Core::run_async_in_current_event_loop([&] { return async_connect(host, underlying_stream, move(options)); }); } void TLSv12::setup_connection() diff --git a/Userland/Libraries/LibTLS/TLSv12.h b/Userland/Libraries/LibTLS/TLSv12.h index 806fd9ff9b..5591d6dd0a 100644 --- a/Userland/Libraries/LibTLS/TLSv12.h +++ b/Userland/Libraries/LibTLS/TLSv12.h @@ -360,6 +360,9 @@ public: static ErrorOr> connect(ByteString const& host, u16 port, Options = {}); static ErrorOr> connect(ByteString const& host, Core::Socket& underlying_stream, Options = {}); + static Coroutine>> async_connect(ByteString const& host, u16 port, Options = {}); + static Coroutine>> async_connect(ByteString const& host, Core::Socket& underlying_stream, Options = {}); + using StreamVariantType = Variant, Core::Socket*>; explicit TLSv12(StreamVariantType, Options); diff --git a/Userland/Libraries/LibThreading/ThreadPool.h b/Userland/Libraries/LibThreading/ThreadPool.h index 7992d11371..4da0d50f66 100644 --- a/Userland/Libraries/LibThreading/ThreadPool.h +++ b/Userland/Libraries/LibThreading/ThreadPool.h @@ -55,32 +55,45 @@ public: using Work = TWork; friend struct ThreadPoolLooper; - ThreadPool(Optional concurrency = {}) + template + ThreadPool(Optional concurrency = {}, Args&&... looper_args) requires(IsFunction) : m_handler([](Work work) { return work(); }) , m_work_available(m_mutex) , m_work_done(m_mutex) { - initialize_workers(concurrency.value_or(Core::System::hardware_concurrency())); + initialize_workers(concurrency.value_or(Core::System::hardware_concurrency()), forward(looper_args)...); } - explicit ThreadPool(Function handler, Optional concurrency = {}) + template + explicit ThreadPool(Function handler, Optional concurrency = {}, Args&&... looper_args) : m_handler(move(handler)) , m_work_available(m_mutex) , m_work_done(m_mutex) { - initialize_workers(concurrency.value_or(Core::System::hardware_concurrency())); + initialize_workers(concurrency.value_or(Core::System::hardware_concurrency()), forward(looper_args)...); } ~ThreadPool() { - m_should_exit.store(true, AK::MemoryOrder::memory_order_release); + request_exit(); for (auto& worker : m_workers) { m_work_available.broadcast(); (void)worker->join(); } } + void request_exit() + { + m_should_exit.store(true, AK::MemoryOrder::memory_order_release); + m_work_available.broadcast(); + } + + bool was_exit_requested() const + { + return m_should_exit.load(AK::MemoryOrder::memory_order_acquire); + } + void submit(Work work) { m_work_queue.with_locked([&](auto& queue) { @@ -107,11 +120,12 @@ public: } private: - void initialize_workers(size_t concurrency) + template + void initialize_workers(size_t concurrency, Args&&... looper_args) { for (size_t i = 0; i < concurrency; ++i) { - m_workers.append(Thread::construct([this]() -> intptr_t { - Looper thread_looper; + m_workers.append(Thread::construct([this, looper_args...]() -> intptr_t { + Looper thread_looper { move(looper_args)... }; for (; !m_should_exit;) { auto result = thread_looper.next(*this, true); m_busy_count--; diff --git a/Userland/Services/RequestServer/ConnectionCache.cpp b/Userland/Services/RequestServer/ConnectionCache.cpp index c7dfdd1353..52ca512db0 100644 --- a/Userland/Services/RequestServer/ConnectionCache.cpp +++ b/Userland/Services/RequestServer/ConnectionCache.cpp @@ -25,9 +25,8 @@ void request_did_finish(URL::URL const& url, Core::Socket const* socket) dbgln_if(REQUESTSERVER_DEBUG, "Request for {} finished", url); ConnectionKey partial_key { url.serialized_host().release_value_but_fixme_should_propagate_errors().to_byte_string(), url.port_or_default() }; - auto fire_off_next_job = [&](auto& cache) { - using CacheType = typename RemoveCVReference::ProtectedType; - auto [it, end] = cache.with_read_locked([&](auto const& cache) { + auto fire_off_next_job = [socket, url, partial_key = move(partial_key)](auto& cache) -> Coroutine { + auto [it, end] = cache.with_read_locked([&](auto& cache) { struct Result { decltype(cache.begin()) it; decltype(cache.end()) end; @@ -41,12 +40,12 @@ void request_did_finish(URL::URL const& url, Core::Socket const* socket) }); if (it == end) { dbgln("Request for URL {} finished, but we don't own that!", url); - return; + co_return; } auto connection_it = it->value->find_if([&](auto& connection) { return connection->socket == socket; }); if (connection_it.is_end()) { dbgln("Request for URL {} finished, but we don't have a socket for that!", url); - return; + co_return; } auto& connection = *connection_it; @@ -54,12 +53,11 @@ void request_did_finish(URL::URL const& url, Core::Socket const* socket) connection->job_data->timing_info.performing_request = Duration::from_milliseconds(connection->job_data->timing_info.timer.elapsed_milliseconds()); connection->job_data->timing_info.timer.start(); } - auto& properties = g_inferred_server_properties.with_write_locked([&](auto& map) -> InferredServerProperties& { return map.ensure(partial_key.hostname); }); if (!connection->socket->is_open()) properties.requests_served_per_connection = min(properties.requests_served_per_connection, connection->max_queue_length + 1); - if (connection->request_queue.with_read_locked([](auto const& queue) { return queue.is_empty(); })) { + if (connection->request_queue.with_read_locked([&](auto& queue) { return queue.is_empty(); })) { // Immediately mark the connection as finished, as new jobs will never be run if they are queued // before the deferred_invoke() below runs otherwise. connection->has_started = false; @@ -76,8 +74,8 @@ void request_did_finish(URL::URL const& url, Core::Socket const* socket) if (ptr->has_started) return; - dbgln_if(REQUESTSERVER_DEBUG, "Removing no-longer-used connection {} (socket {})", ptr, ptr->socket.ptr()); - cache.with_write_locked([&](CacheType& cache) { + dbgln_if(REQUESTSERVER_DEBUG, "Removing no-longer-used connection {} (socket {})", ptr, ptr->socket); + cache.with_write_locked([&](auto& cache) { auto did_remove = cache_entry.remove_first_matching([&](auto& entry) { return entry == ptr; }); VERIFY(did_remove); if (cache_entry.is_empty()) @@ -89,7 +87,7 @@ void request_did_finish(URL::URL const& url, Core::Socket const* socket) }); } else { auto timer = Core::ElapsedTimer::start_new(); - if (auto result = recreate_socket_if_needed(*connection, url); result.is_error()) { + if (auto result = co_await recreate_socket_if_needed(*connection, url); result.is_error()) { if constexpr (REQUESTSERVER_DEBUG) { connection->job_data->timing_info.starting_connection += Duration::from_milliseconds(timer.elapsed_milliseconds()); } @@ -97,19 +95,19 @@ void request_did_finish(URL::URL const& url, Core::Socket const* socket) dbgln("ConnectionCache request finish handler, reconnection failed with {}", result.error()); connection->job_data->fail(Core::NetworkJob::Error::ConnectionFailed); }); - return; + co_return; } if constexpr (REQUESTSERVER_DEBUG) { connection->job_data->timing_info.starting_connection += Duration::from_milliseconds(timer.elapsed_milliseconds()); } connection->has_started = true; - Core::deferred_invoke([&connection = *connection, url, &cache] { + Core::deferred_invoke([&connection = *connection, &cache, url] { cache.with_read_locked([&](auto&) { dbgln_if(REQUESTSERVER_DEBUG, "Running next job in queue for connection {}", &connection); connection.timer.start(); connection.current_url = url; - connection.job_data = connection.request_queue.with_write_locked([](auto& queue) { return queue.take_first(); }); + connection.job_data = connection.request_queue.with_write_locked([&](auto& queue) { return queue.take_first(); }); if constexpr (REQUESTSERVER_DEBUG) { connection.job_data->timing_info.waiting_in_queue = Duration::from_milliseconds(connection.job_data->timing_info.timer.elapsed_milliseconds() - connection.job_data->timing_info.performing_request.to_milliseconds()); connection.job_data->timing_info.timer.start(); @@ -122,46 +120,42 @@ void request_did_finish(URL::URL const& url, Core::Socket const* socket) }; if (is>(socket)) - fire_off_next_job(g_tls_connection_cache); + Core::deferred_invoke([fire_off_next_job = move(fire_off_next_job)] { Core::run_async_in_current_event_loop([&] { return fire_off_next_job(g_tls_connection_cache); }); }); else if (is>(socket)) - fire_off_next_job(g_tcp_connection_cache); + Core::deferred_invoke([fire_off_next_job = move(fire_off_next_job)] { Core::run_async_in_current_event_loop([&] { return fire_off_next_job(g_tcp_connection_cache); }); }); else dbgln("Unknown socket {} finished for URL {}", socket, url); } void dump_jobs() { + dbgln("=========== TLS Connection Cache =========="); g_tls_connection_cache.with_read_locked([](auto& cache) { - dbgln("=========== TLS Connection Cache =========="); for (auto& connection : cache) { dbgln(" - {}:{}", connection.key.hostname, connection.key.port); for (auto& entry : *connection.value) { - dbgln(" - Connection {} (started={}) (socket={})", &entry, entry->has_started, entry->socket.ptr()); + dbgln(" - Connection {} (started={}) (socket={})", &entry, entry->has_started, entry->socket); dbgln(" Currently loading {} ({} elapsed)", entry->current_url, entry->timer.is_valid() ? entry->timer.elapsed() : 0); dbgln(" Request Queue:"); - entry->request_queue.for_each_locked([](auto const& job) { + entry->request_queue.for_each_locked([](auto& job) { dbgln(" - {}", &job); }); } } }); - + dbgln("=========== TCP Connection Cache =========="); g_tcp_connection_cache.with_read_locked([](auto& cache) { - dbgln("=========== TCP Connection Cache =========="); for (auto& connection : cache) { dbgln(" - {}:{}", connection.key.hostname, connection.key.port); for (auto& entry : *connection.value) { - dbgln(" - Connection {} (started={}) (socket={})", &entry, entry->has_started, entry->socket.ptr()); + dbgln(" - Connection {} (started={}) (socket={})", &entry, entry->has_started, entry->socket); dbgln(" Currently loading {} ({} elapsed)", entry->current_url, entry->timer.is_valid() ? entry->timer.elapsed() : 0); dbgln(" Request Queue:"); - entry->request_queue.for_each_locked([](auto const& job) { + entry->request_queue.for_each_locked([](auto& job) { dbgln(" - {}", &job); }); } } }); } - -size_t hits; -size_t misses; } diff --git a/Userland/Services/RequestServer/ConnectionCache.h b/Userland/Services/RequestServer/ConnectionCache.h index 8a5482c820..1a0aaa5645 100644 --- a/Userland/Services/RequestServer/ConnectionCache.h +++ b/Userland/Services/RequestServer/ConnectionCache.h @@ -7,6 +7,7 @@ #pragma once +#include #include #include #include @@ -35,19 +36,20 @@ struct Proxy { OwnPtr proxy_client_storage {}; template - ErrorOr> tunnel(URL::URL const& url, Args&&... args) + Coroutine>> tunnel(URL::URL const& url, Args&&... args) { if (data.type == Core::ProxyData::Direct) { - return TRY(SocketType::connect(TRY(url.serialized_host()).to_byte_string(), url.port_or_default(), forward(args)...)); + co_return CO_TRY(co_await SocketType::async_connect(CO_TRY(url.serialized_host()).to_byte_string(), url.port_or_default(), forward(args)...)); } + if (data.type == Core::ProxyData::SOCKS5) { if constexpr (requires { SocketType::connect(declval(), *proxy_client_storage, forward(args)...); }) { - proxy_client_storage = TRY(Core::SOCKSProxyClient::connect(data.host_ipv4, data.port, Core::SOCKSProxyClient::Version::V5, TRY(url.serialized_host()).to_byte_string(), url.port_or_default())); - return TRY(SocketType::connect(TRY(url.serialized_host()).to_byte_string(), *proxy_client_storage, forward(args)...)); + proxy_client_storage = CO_TRY(co_await Core::SOCKSProxyClient::async_connect(data.host_ipv4, data.port, Core::SOCKSProxyClient::Version::V5, CO_TRY(url.serialized_host()).to_byte_string(), url.port_or_default())); + co_return CO_TRY(co_await SocketType::async_connect(CO_TRY(url.serialized_host()).to_byte_string(), *proxy_client_storage, forward(args)...)); } else if constexpr (IsSame) { - return TRY(Core::SOCKSProxyClient::connect(data.host_ipv4, data.port, Core::SOCKSProxyClient::Version::V5, TRY(url.serialized_host()).to_byte_string(), url.port_or_default())); + co_return CO_TRY(co_await Core::SOCKSProxyClient::async_connect(data.host_ipv4, data.port, Core::SOCKSProxyClient::Version::V5, CO_TRY(url.serialized_host()).to_byte_string(), url.port_or_default())); } else { - return Error::from_string_literal("SOCKS5 not supported for this socket type"); + co_return Error::from_string_literal("SOCKS5 not supported for this socket type"); } } VERIFY_NOT_REACHED(); @@ -58,23 +60,24 @@ struct JobData { Function start {}; Function fail {}; Function()> provide_client_certificates {}; - struct TimingInfo { #if REQUESTSERVER_DEBUG + struct { bool valid { true }; Core::ElapsedTimer timer {}; URL::URL url {}; Duration waiting_in_queue {}; Duration starting_connection {}; Duration performing_request {}; -#endif } timing_info {}; - JobData(Function start, Function fail, Function()> provide_client_certificates, TimingInfo timing_info) - : start(move(start)) - , fail(move(fail)) - , provide_client_certificates(move(provide_client_certificates)) - , timing_info(move(timing_info)) + ~JobData() { + if (!timing_info.valid) + return; + dbgln("[RSTIMING] JobData for {} timings:", timing_info.url); + dbgln("[RSTIMING] - Waiting in queue: {}ms", timing_info.waiting_in_queue.to_milliseconds()); + dbgln("[RSTIMING] - Starting connection: {}ms", timing_info.starting_connection.to_milliseconds()); + dbgln("[RSTIMING] - Performing request: {}ms", timing_info.performing_request.to_milliseconds()); } JobData(JobData&& other) @@ -83,30 +86,29 @@ struct JobData { , provide_client_certificates(move(other.provide_client_certificates)) , timing_info(move(other.timing_info)) { -#if REQUESTSERVER_DEBUG other.timing_info.valid = false; -#endif } -#if REQUESTSERVER_DEBUG - ~JobData() + JobData( + Function start, + Function fail, + Function()> provide_client_certificates, + decltype(timing_info) timing_info) + : start(move(start)) + , fail(move(fail)) + , provide_client_certificates(move(provide_client_certificates)) + , timing_info(move(timing_info)) { - if (timing_info.valid) { - dbgln("JobData for {} timings:", timing_info.url); - dbgln(" - Waiting in queue: {}ms", timing_info.waiting_in_queue.to_milliseconds()); - dbgln(" - Starting connection: {}ms", timing_info.starting_connection.to_milliseconds()); - dbgln(" - Performing request: {}ms", timing_info.performing_request.to_milliseconds()); - } } #endif template - static JobData create(NonnullRefPtr job, [[maybe_unused]] URL::URL url) + static JobData create(NonnullRefPtr job) { return JobData { - [job](auto& socket) { job->start(socket); }, - [job](auto error) { job->fail(error); }, - [job] { + /* .start = */ [job](auto& socket) { job->start(socket); }, + /* .fail = */ [job](auto error) { job->fail(error); }, + /* .provide_client_certificates = */ [job] { if constexpr (requires { job->on_certificate_requested; }) { if (job->on_certificate_requested) return job->on_certificate_requested(); @@ -114,17 +116,13 @@ struct JobData { // "use" `job`, otherwise clang gets sad. (void)job; } - return Vector {}; - }, - { + return Vector {}; }, #if REQUESTSERVER_DEBUG + /* .timing_info = */ { .timer = Core::ElapsedTimer::start_new(Core::TimerType::Precise), - .url = move(url), - .waiting_in_queue = {}, - .starting_connection = {}, - .performing_request = {}, -#endif + .url = job->url(), }, +#endif }; } }; @@ -154,7 +152,6 @@ struct ConnectionKey { bool operator==(ConnectionKey const&) const = default; }; - }; template<> @@ -179,11 +176,11 @@ void request_did_finish(URL::URL const&, Core::Socket const*); void dump_jobs(); constexpr static size_t MaxConcurrentConnectionsPerURL = 4; -constexpr static size_t ConnectionKeepAliveTimeMilliseconds = 20'000; +constexpr static size_t ConnectionKeepAliveTimeMilliseconds = 10'000; constexpr static size_t ConnectionCacheQueueHighWatermark = 4; template -ErrorOr recreate_socket_if_needed(T& connection, URL::URL const& url) +Coroutine> recreate_socket_if_needed(T& connection, URL::URL const& url) { using SocketType = typename T::SocketType; using SocketStorageType = typename T::StorageType; @@ -215,30 +212,24 @@ ErrorOr recreate_socket_if_needed(T& connection, URL::URL const& url) return connection.job_data->provide_client_certificates(); return {}; }); - TRY(set_socket(TRY((connection.proxy.template tunnel(url, move(options)))))); + CO_TRY(set_socket(CO_TRY(co_await (connection.proxy.template tunnel(url, move(options)))))); } else { - TRY(set_socket(TRY((connection.proxy.template tunnel(url))))); + CO_TRY(set_socket(CO_TRY(co_await (connection.proxy.template tunnel(url))))); } - dbgln_if(REQUESTSERVER_DEBUG, "Creating a new socket for {} -> {}", url, connection.socket.ptr()); + dbgln_if(REQUESTSERVER_DEBUG, "Creating a new socket for {} -> {}", url, connection.socket); } - return {}; + co_return {}; } -extern size_t hits; -extern size_t misses; - -template -void start_connection(const URL::URL& url, auto job, auto& sockets_for_url, size_t index, Duration, Cache&); - -void ensure_connection(auto& cache, const URL::URL& url, auto job, Core::ProxyData proxy_data = {}) +Coroutine async_get_or_create_connection(auto& cache, URL::URL url, auto job, Core::ProxyData proxy_data = {}) { using CacheEntryType = RemoveCVReference::ProtectedType>().begin()->value)>; auto hostname = url.serialized_host().release_value_but_fixme_should_propagate_errors().to_byte_string(); auto& properties = g_inferred_server_properties.with_write_locked([&](auto& map) -> InferredServerProperties& { return map.ensure(hostname); }); - auto& sockets_for_url = *cache.with_write_locked([&](auto& map) -> NonnullOwnPtr& { - return map.ensure({ move(hostname), url.port_or_default(), proxy_data }, [] { return make(); }); + auto& sockets_for_url = *cache.with_write_locked([&](auto& map) -> CacheEntryType* { + return map.ensure({ move(hostname), url.port_or_default(), proxy_data }, [] { return make(); }).ptr(); }); // Find the connection with an empty queue; if none exist, we'll find the least backed-up connection later. @@ -246,43 +237,40 @@ void ensure_connection(auto& cache, const URL::URL& url, auto job, Core::ProxyDa // issues with concurrent connections, so we'll only allow one connection per URL in that case to avoid issues. // This is a bit too aggressive, but there's no way to know if the server can handle concurrent connections // without trying it out first, and that's not worth the effort as HTTP/1.0 is a legacy protocol anyway. - auto it = sockets_for_url.find_if([&](auto const& connection) { - return properties.requests_served_per_connection < 2 - || connection->request_queue.with_read_locked([](auto const& queue) { return queue.size(); }) <= ConnectionCacheQueueHighWatermark; + auto it = cache.with_read_locked([&](auto&) { + return sockets_for_url.find_if([&](auto& connection) { + return properties.requests_served_per_connection < 2 + || connection->request_queue.with_read_locked([&](auto const& queue) { return queue.size(); }) < ConnectionCacheQueueHighWatermark; + }); }); auto did_add_new_connection = false; auto failed_to_find_a_socket = it.is_end(); - - Proxy proxy { proxy_data }; size_t index; + Proxy proxy { proxy_data }; - auto timer = Core::ElapsedTimer::start_new(); - if (failed_to_find_a_socket && sockets_for_url.size() < MaxConcurrentConnectionsPerURL) { - using ConnectionType = RemoveCVReference().at(0))>; - auto& connection = cache.with_write_locked([&](auto&) -> ConnectionType& { - index = sockets_for_url.size(); - sockets_for_url.append(AK::make( + auto start_timer = Core::ElapsedTimer::start_new(); + if (failed_to_find_a_socket && sockets_for_url.size() < ConnectionCache::MaxConcurrentConnectionsPerURL) { + using ConnectionType = RemoveCVReference().at(0))>; + cache.with_write_locked([&](auto&) { + sockets_for_url.append(make( nullptr, typename ConnectionType::QueueType {}, Core::Timer::create_single_shot(ConnectionKeepAliveTimeMilliseconds, nullptr), true)); - auto& connection = sockets_for_url.last(); - connection->proxy = move(proxy); - return *connection; + index = sockets_for_url.size() - 1; }); - ScopeGuard start_guard = [&] { - connection.is_being_started = false; + auto& socket_for_url = sockets_for_url[index]; + ScopeGuard created = [&] { + socket_for_url->is_being_started = false; }; - dbgln_if(REQUESTSERVER_DEBUG, "I will start a connection ({}) for URL {}", &connection, url); - auto connection_result = proxy.tunnel(url); - misses++; + auto connection_result = co_await proxy.tunnel(url); if (connection_result.is_error()) { dbgln("ConnectionCache: Connection to {} failed: {}", url, connection_result.error()); Core::deferred_invoke([job] { job->fail(Core::NetworkJob::Error::ConnectionFailed); }); - return; + co_return; } auto socket_result = Core::BufferedSocket::create(connection_result.release_value()); if (socket_result.is_error()) { @@ -290,21 +278,20 @@ void ensure_connection(auto& cache, const URL::URL& url, auto job, Core::ProxyDa Core::deferred_invoke([job] { job->fail(Core::NetworkJob::Error::ConnectionFailed); }); - return; + co_return; } + + socket_for_url->socket = socket_result.release_value(); + socket_for_url->proxy = move(proxy); did_add_new_connection = true; - connection.socket = socket_result.release_value(); } - - auto elapsed = Duration::from_milliseconds(timer.elapsed_milliseconds()); - if (failed_to_find_a_socket) { if (!did_add_new_connection) { // Find the least backed-up connection (based on how many entries are in their request queue). index = 0; auto min_queue_size = (size_t)-1; for (auto it = sockets_for_url.begin(); it != sockets_for_url.end(); ++it) { - if (auto queue_size = (*it)->request_queue.with_read_locked([](auto const& queue) { return queue.size(); }); min_queue_size > queue_size) { + if (auto queue_size = (*it)->request_queue.with_read_locked([](auto& queue) { return queue.size(); }); min_queue_size > queue_size) { index = it.index(); min_queue_size = queue_size; } @@ -312,76 +299,67 @@ void ensure_connection(auto& cache, const URL::URL& url, auto job, Core::ProxyDa } } else { index = it.index(); - hits++; } - - dbgln_if(REQUESTSERVER_DEBUG, "ConnectionCache: Hits: {}, Misses: {}", RequestServer::ConnectionCache::hits, RequestServer::ConnectionCache::misses); - start_connection(url, job, sockets_for_url, index, elapsed, cache); -} - -template -void start_connection(URL::URL const& url, auto job, auto& sockets_for_url, size_t index, Duration setup_time, Cache& cache) -{ if (sockets_for_url.is_empty()) { Core::deferred_invoke([job] { job->fail(Core::NetworkJob::Error::ConnectionFailed); }); - return; + co_return; } auto& connection = *sockets_for_url[index]; if (connection.is_being_started) { - // Someone else is creating the connection, queue the job and let them handle it. - dbgln_if(REQUESTSERVER_DEBUG, "Enqueue request for URL {} in {} - {}", url, &connection, connection.socket.ptr()); - auto size = connection.request_queue.with_write_locked([&](auto& queue) { - queue.append(JobData::create(job, url)); - return queue.size(); + dbgln_if(REQUESTSERVER_DEBUG, "Enqueue request for URL {} in {} - {}", url, &connection, connection.socket); + connection.request_queue.with_write_locked([&](auto& queue) { + queue.append(JobData::create(job)); + connection.max_queue_length = max(connection.max_queue_length, queue.size()); }); - connection.max_queue_length = max(connection.max_queue_length, size); - return; + co_return; } + auto connection_time = start_timer.elapsed_milliseconds(); + if (!connection.has_started) { connection.has_started = true; - Core::deferred_invoke([&connection, &cache, url, job, setup_time] { - (void)setup_time; - auto job_data = JobData::create(job, url); - if constexpr (REQUESTSERVER_DEBUG) { - job_data.timing_info.waiting_in_queue = Duration::from_milliseconds(job_data.timing_info.timer.elapsed_milliseconds()); - job_data.timing_info.timer.start(); - } - if (auto result = recreate_socket_if_needed(connection, url); result.is_error()) { - dbgln_if(REQUESTSERVER_DEBUG, "ConnectionCache: request failed to start, failed to make a socket: {}", result.error()); - if constexpr (REQUESTSERVER_DEBUG) { - job_data.timing_info.starting_connection += Duration::from_milliseconds(job_data.timing_info.timer.elapsed_milliseconds()) + setup_time; - job_data.timing_info.timer.start(); - } - Core::deferred_invoke([job] { - job->fail(Core::NetworkJob::Error::ConnectionFailed); - }); - } else { - cache.with_write_locked([&](auto&) { - dbgln_if(REQUESTSERVER_DEBUG, "Immediately start request for url {} in {} - {}", url, &connection, connection.socket.ptr()); - connection.job_data = move(job_data); + Core::deferred_invoke([&connection, url, job = move(job), connection_time] { + Core::run_async_in_current_event_loop([&connection, url = move(url), job = move(job), connection_time] -> Coroutine { + auto timer = Core::ElapsedTimer::start_new(); + // if !REQUESTSERVER_DEBUG, this is unused. + (void)connection_time; + (void)timer; + + if (auto result = co_await recreate_socket_if_needed(connection, url); result.is_error()) { if constexpr (REQUESTSERVER_DEBUG) { - connection.job_data->timing_info.starting_connection += Duration::from_milliseconds(connection.job_data->timing_info.timer.elapsed_milliseconds()) + setup_time; - connection.job_data->timing_info.timer.start(); + connection.job_data->timing_info.starting_connection += Duration::from_milliseconds(timer.elapsed_milliseconds()); } + dbgln("ConnectionCache: request failed to start, failed to make a socket: {}", result.error()); + Core::deferred_invoke([job] { + job->fail(Core::NetworkJob::Error::ConnectionFailed); + }); + } else { + dbgln_if(REQUESTSERVER_DEBUG, "Immediately start request for url {} in {} - {}", url, &connection, connection.socket); connection.removal_timer->stop(); connection.timer.start(); connection.current_url = url; + connection.job_data = JobData::create(job); + if constexpr (REQUESTSERVER_DEBUG) + connection.job_data->timing_info.starting_connection += Duration::from_milliseconds(timer.elapsed_milliseconds() + connection_time); connection.socket->set_notifications_enabled(true); connection.job_data->start(*connection.socket); - }); - } + } + }); }); } else { - dbgln_if(REQUESTSERVER_DEBUG, "Enqueue request for URL {} in {} - {}", url, &connection, connection.socket.ptr()); - auto size = connection.request_queue.with_write_locked([&](auto& queue) { - queue.append(JobData::create(job, url)); - return queue.size(); + dbgln_if(REQUESTSERVER_DEBUG, "Enqueue request for URL {} in {} - {}", url, &connection, connection.socket); + connection.request_queue.with_write_locked([&](auto& queue) { + queue.append(JobData::create(job)); + connection.max_queue_length = max(connection.max_queue_length, queue.size()); }); - connection.max_queue_length = max(connection.max_queue_length, size); } } + +void ensure_connection(auto& cache, URL::URL const& url, auto job, Core::ProxyData proxy_data = {}) +{ + Core::EventLoop::current().adopt_coroutine(async_get_or_create_connection(cache, url, move(job), proxy_data)); +} } diff --git a/Userland/Services/RequestServer/ConnectionFromClient.cpp b/Userland/Services/RequestServer/ConnectionFromClient.cpp index 8f83a5a54f..0a95ef9b63 100644 --- a/Userland/Services/RequestServer/ConnectionFromClient.cpp +++ b/Userland/Services/RequestServer/ConnectionFromClient.cpp @@ -26,11 +26,18 @@ static IDAllocator s_client_ids; ConnectionFromClient::ConnectionFromClient(NonnullOwnPtr socket) : IPC::ConnectionFromClient(*this, move(socket), s_client_ids.allocate()) - , m_thread_pool([this](Work work) { worker_do_work(move(work)); }) + , m_thread_pipe_fds(MUST(Core::System::pipe2(O_CLOEXEC | O_NONBLOCK))) + , m_thread_pool([this](Work work) { worker_do_work(move(work)); }, {}, m_thread_pipe_fds[0]) { s_connections.set(client_id(), *this); } +ConnectionFromClient::~ConnectionFromClient() +{ + close(m_thread_pipe_fds[0]); + close(m_thread_pipe_fds[1]); +} + class Job : public RefCounted , public Weakable { public: @@ -63,6 +70,8 @@ public: s_jobs.remove(m_url); } + URL::URL const& url() const { return m_url; } + private: explicit Job(URL::URL url) : m_url(move(url)) @@ -76,15 +85,37 @@ private: template IterationDecision ConnectionFromClient::Looper::next(Pool& pool, bool wait) { - bool should_exit = false; - auto timer = Core::Timer::create_repeating(100, [&] { - if (Threading::ThreadPoolLooper::next(pool, false) == IterationDecision::Break) { + if (done) + return IterationDecision::Break; + + auto should_quit = false; + + auto exit_timer = Core::Timer::create_repeating(100, [&] { + if (pool.was_exit_requested()) { + done = true; + should_quit = true; event_loop.quit(0); - should_exit = true; } }); - timer->start(); + exit_timer->start(); + + notifier->on_activation = [&] { + char buffer[1]; + auto nread = read(notifier->fd(), buffer, 1); + if (nread == 1) { + if (pool.was_exit_requested()) { + done = true; + should_quit = true; + } else { + should_quit = Threading::ThreadPoolLooper::next(pool, true) == IterationDecision::Break; + } + + if (should_quit) + event_loop.quit(0); + } + }; + if (!wait) { event_loop.deferred_invoke([&] { event_loop.quit(0); @@ -93,7 +124,10 @@ IterationDecision ConnectionFromClient::Looper::next(Pool& pool, bool wait event_loop.exec(); - if (should_exit) + exit_timer->stop(); + notifier->on_activation = nullptr; + + if (should_quit) return IterationDecision::Break; return IterationDecision::Continue; } @@ -188,6 +222,10 @@ Messages::RequestServer::ConnectNewClientResponse ConnectionFromClient::connect_ void ConnectionFromClient::enqueue(Work work) { m_thread_pool.submit(move(work)); + auto nwritten = write(m_thread_pipe_fds[1], "x", 1); // notify the worker threads + if (nwritten < 0) { + VERIFY_NOT_REACHED(); + } } Messages::RequestServer::IsSupportedProtocolResponse ConnectionFromClient::is_supported_protocol(ByteString const& protocol) diff --git a/Userland/Services/RequestServer/ConnectionFromClient.h b/Userland/Services/RequestServer/ConnectionFromClient.h index 2d2c820502..c2dfa29eff 100644 --- a/Userland/Services/RequestServer/ConnectionFromClient.h +++ b/Userland/Services/RequestServer/ConnectionFromClient.h @@ -23,7 +23,7 @@ class ConnectionFromClient final C_OBJECT(ConnectionFromClient); public: - ~ConnectionFromClient() override = default; + ~ConnectionFromClient() override; virtual void die() override; @@ -74,10 +74,18 @@ private: template struct Looper : public Threading::ThreadPoolLooper { + Looper(int pipe_fd) + : notifier(Core::Notifier::construct(pipe_fd, Core::NotificationType::Read)) + { + } + IterationDecision next(Pool& pool, bool wait); Core::EventLoop event_loop; + NonnullRefPtr notifier; + bool done { false }; }; + Array m_thread_pipe_fds { -1, -1 }; Threading::ThreadPool m_thread_pool; Threading::Mutex m_ipc_mutex; };