LibIPC+LibWeb: Transfer IPC::Files using sendmsg/recvmsg directly

This refactor eliminates the need for a second "fd passing socket" on
Lagom, as it uses SCM_RIGHTS in the expected fashion, to send fds along
with the data of our Unix socket message.
This commit is contained in:
Andrew Kaster 2024-04-17 16:46:24 -06:00 committed by Tim Flynn
parent a18c7c4405
commit cb87725ec8
9 changed files with 109 additions and 53 deletions

View file

@ -372,9 +372,9 @@ public:)~~~");
static i32 static_message_id() { return (int)MessageID::@message.pascal_name@; }
virtual const char* message_name() const override { return "@endpoint.name@::@message.pascal_name@"; }
static ErrorOr<NonnullOwnPtr<@message.pascal_name@>> decode(Stream& stream, Core::LocalSocket& socket)
static ErrorOr<NonnullOwnPtr<@message.pascal_name@>> decode(Stream& stream, Queue<IPC::File>& files)
{
IPC::Decoder decoder { stream, socket };)~~~");
IPC::Decoder decoder { stream, files };)~~~");
for (auto const& parameter : parameters) {
auto parameter_generator = message_generator.fork();
@ -620,7 +620,7 @@ public:
static u32 static_magic() { return @endpoint.magic@; }
static ErrorOr<NonnullOwnPtr<IPC::Message>> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Core::LocalSocket& socket)
static ErrorOr<NonnullOwnPtr<IPC::Message>> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Queue<IPC::File>& files)
{
FixedMemoryStream stream { buffer };
auto message_endpoint_magic = TRY(stream.read_value<u32>());)~~~");
@ -649,7 +649,7 @@ public:
message_generator.append(R"~~~(
case (int)Messages::@endpoint.name@::MessageID::@message.pascal_name@:
return TRY(Messages::@endpoint.name@::@message.pascal_name@::decode(stream, socket));)~~~");
return TRY(Messages::@endpoint.name@::@message.pascal_name@::decode(stream, files));)~~~");
};
do_decode_message(message.name);

View file

@ -7,6 +7,7 @@
#include <LibCore/System.h>
#include <LibIPC/Connection.h>
#include <LibIPC/File.h>
#include <LibIPC/Stub.h>
#include <sys/select.h>
@ -35,18 +36,6 @@ void ConnectionBase::set_deferred_invoker(NonnullOwnPtr<DeferredInvoker> deferre
m_deferred_invoker = move(deferred_invoker);
}
void ConnectionBase::set_fd_passing_socket(NonnullOwnPtr<Core::LocalSocket> socket)
{
m_fd_passing_socket = move(socket);
}
Core::LocalSocket& ConnectionBase::fd_passing_socket()
{
if (m_fd_passing_socket)
return *m_fd_passing_socket;
return *m_socket;
}
ErrorOr<void> ConnectionBase::post_message(Message const& message)
{
return post_message(TRY(message.encode()));
@ -59,7 +48,7 @@ ErrorOr<void> ConnectionBase::post_message(MessageBuffer buffer)
if (!m_socket->is_open())
return Error::from_string_literal("Trying to post_message during IPC shutdown");
if (auto result = buffer.transfer_message(fd_passing_socket(), *m_socket); result.is_error()) {
if (auto result = buffer.transfer_message(*m_socket); result.is_error()) {
shutdown_with_error(result.error());
return result.release_error();
}
@ -122,6 +111,7 @@ ErrorOr<Vector<u8>> ConnectionBase::read_as_much_as_possible_from_socket_without
}
u8 buffer[4096];
Vector<int> received_fds;
bool should_shut_down = false;
auto schedule_shutdown = [this, &should_shut_down]() {
@ -132,7 +122,7 @@ ErrorOr<Vector<u8>> ConnectionBase::read_as_much_as_possible_from_socket_without
};
while (m_socket->is_open()) {
auto maybe_bytes_read = m_socket->read_without_waiting({ buffer, 4096 });
auto maybe_bytes_read = m_socket->receive_message({ buffer, 4096 }, MSG_DONTWAIT, received_fds);
if (maybe_bytes_read.is_error()) {
auto error = maybe_bytes_read.release_error();
if (error.is_syscall() && error.code() == EAGAIN) {
@ -156,6 +146,8 @@ ErrorOr<Vector<u8>> ConnectionBase::read_as_much_as_possible_from_socket_without
}
bytes.append(bytes_read.data(), bytes_read.size());
for (auto const& fd : received_fds)
m_unprocessed_fds.enqueue(IPC::File::adopt_fd(fd));
}
if (!bytes.is_empty()) {

View file

@ -8,12 +8,14 @@
#pragma once
#include <AK/ByteBuffer.h>
#include <AK/Queue.h>
#include <AK/Try.h>
#include <LibCore/Event.h>
#include <LibCore/EventLoop.h>
#include <LibCore/Notifier.h>
#include <LibCore/Socket.h>
#include <LibCore/Timer.h>
#include <LibIPC/File.h>
#include <LibIPC/Forward.h>
#include <LibIPC/Message.h>
#include <errno.h>
@ -38,7 +40,7 @@ class ConnectionBase : public Core::EventReceiver {
public:
virtual ~ConnectionBase() override = default;
void set_fd_passing_socket(NonnullOwnPtr<Core::LocalSocket>);
void set_fd_passing_socket(NonnullOwnPtr<Core::LocalSocket>) { }
void set_deferred_invoker(NonnullOwnPtr<DeferredInvoker>);
DeferredInvoker& deferred_invoker() { return *m_deferred_invoker; }
@ -49,7 +51,7 @@ public:
virtual void die() { }
Core::LocalSocket& socket() { return *m_socket; }
Core::LocalSocket& fd_passing_socket();
Core::LocalSocket const& fd_passing_socket() const { return *m_socket; }
protected:
explicit ConnectionBase(IPC::Stub&, NonnullOwnPtr<Core::LocalSocket>, u32 local_endpoint_magic);
@ -70,11 +72,11 @@ protected:
IPC::Stub& m_local_stub;
NonnullOwnPtr<Core::LocalSocket> m_socket;
OwnPtr<Core::LocalSocket> m_fd_passing_socket;
RefPtr<Core::Timer> m_responsiveness_timer;
Vector<NonnullOwnPtr<Message>> m_unprocessed_messages;
Queue<IPC::File> m_unprocessed_fds;
ByteBuffer m_unprocessed_bytes;
u32 m_local_endpoint_magic { 0 };
@ -138,13 +140,13 @@ protected:
index += sizeof(message_size);
auto remaining_bytes = ReadonlyBytes { bytes.data() + index, message_size };
auto local_message = LocalEndpoint::decode_message(remaining_bytes, fd_passing_socket());
auto local_message = LocalEndpoint::decode_message(remaining_bytes, m_unprocessed_fds);
if (!local_message.is_error()) {
m_unprocessed_messages.append(local_message.release_value());
continue;
}
auto peer_message = PeerEndpoint::decode_message(remaining_bytes, fd_passing_socket());
auto peer_message = PeerEndpoint::decode_message(remaining_bytes, m_unprocessed_fds);
if (!peer_message.is_error()) {
m_unprocessed_messages.append(peer_message.release_value());
continue;

View file

@ -90,8 +90,12 @@ ErrorOr<URL::URL> decode(Decoder& decoder)
template<>
ErrorOr<File> decode(Decoder& decoder)
{
int fd = TRY(decoder.socket().receive_fd(O_CLOEXEC));
return File::adopt_fd(fd);
auto file = TRY(decoder.files().try_dequeue());
auto fd = file.fd();
auto fd_flags = TRY(Core::System::fcntl(fd, F_GETFD));
TRY(Core::System::fcntl(fd, F_SETFD, fd_flags | FD_CLOEXEC));
return file;
}
template<>

View file

@ -11,6 +11,7 @@
#include <AK/Concepts.h>
#include <AK/Forward.h>
#include <AK/NumericLimits.h>
#include <AK/Queue.h>
#include <AK/StdLibExtras.h>
#include <AK/String.h>
#include <AK/Try.h>
@ -35,9 +36,9 @@ inline ErrorOr<T> decode(Decoder&)
class Decoder {
public:
Decoder(Stream& stream, Core::LocalSocket& socket)
Decoder(Stream& stream, Queue<IPC::File>& files)
: m_stream(stream)
, m_socket(socket)
, m_files(files)
{
}
@ -60,11 +61,11 @@ public:
ErrorOr<size_t> decode_size();
Stream& stream() { return m_stream; }
Core::LocalSocket& socket() { return m_socket; }
Queue<IPC::File>& files() { return m_files; }
private:
Stream& m_stream;
Core::LocalSocket& m_socket;
Queue<IPC::File>& m_files;
};
template<Arithmetic T>

View file

@ -37,7 +37,7 @@ ErrorOr<void> MessageBuffer::append_file_descriptor(int fd)
return {};
}
ErrorOr<void> MessageBuffer::transfer_message(Core::LocalSocket& fd_passing_socket, Core::LocalSocket& data_socket)
ErrorOr<void> MessageBuffer::transfer_message(Core::LocalSocket& socket)
{
Checked<MessageSizeType> checked_message_size { m_data.size() };
checked_message_size -= sizeof(MessageSizeType);
@ -45,17 +45,30 @@ ErrorOr<void> MessageBuffer::transfer_message(Core::LocalSocket& fd_passing_sock
if (checked_message_size.has_overflow())
return Error::from_string_literal("Message is too large for IPC encoding");
auto message_size = checked_message_size.value();
MessageSizeType const message_size = checked_message_size.value();
m_data.span().overwrite(0, reinterpret_cast<u8 const*>(&message_size), sizeof(message_size));
for (auto const& fd : m_fds)
TRY(fd_passing_socket.send_fd(fd->value()));
auto raw_fds = Vector<int, 1> {};
auto num_fds_to_transfer = m_fds.size();
if (num_fds_to_transfer > 0) {
raw_fds.ensure_capacity(num_fds_to_transfer);
for (auto& owned_fd : m_fds) {
raw_fds.unchecked_append(owned_fd->value());
}
}
ReadonlyBytes bytes_to_write { m_data.span() };
size_t writes_done = 0;
while (!bytes_to_write.is_empty()) {
auto maybe_nwritten = data_socket.write_some(bytes_to_write);
ErrorOr<ssize_t> maybe_nwritten = 0;
if (num_fds_to_transfer > 0) {
maybe_nwritten = socket.send_message(bytes_to_write, 0, raw_fds);
if (!maybe_nwritten.is_error())
num_fds_to_transfer = 0;
} else {
maybe_nwritten = socket.write_some(bytes_to_write);
}
++writes_done;
if (maybe_nwritten.is_error()) {

View file

@ -44,7 +44,7 @@ public:
ErrorOr<void> append_file_descriptor(int fd);
ErrorOr<void> transfer_message(Core::LocalSocket& fd_passing_socket, Core::LocalSocket& data_socket);
ErrorOr<void> transfer_message(Core::LocalSocket& socket);
private:
Vector<u8, 1024> m_data;

View file

@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/ByteReader.h>
#include <AK/MemoryStream.h>
#include <LibCore/Socket.h>
#include <LibCore/System.h>
@ -259,7 +260,7 @@ ErrorOr<void> MessagePort::send_message_on_socket(SerializedTransferRecord const
IPC::Encoder encoder(buffer);
MUST(encoder.encode(serialize_with_transfer_result));
TRY(buffer.transfer_message(*m_fd_passing_socket, *m_socket));
TRY(buffer.transfer_message(*m_socket));
return {};
}
@ -276,41 +277,76 @@ void MessagePort::post_port_message(SerializedTransferRecord serialize_with_tran
});
}
void MessagePort::read_from_socket()
ErrorOr<MessagePort::ParseDecision> MessagePort::parse_message()
{
auto num_bytes_ready = MUST(m_socket->pending_bytes());
static constexpr size_t HEADER_SIZE = sizeof(u32);
auto num_bytes_ready = m_buffered_data.size();
switch (m_socket_state) {
case SocketState::Header: {
if (num_bytes_ready < sizeof(u32))
break;
m_socket_incoming_message_size = MUST(m_socket->read_value<u32>());
num_bytes_ready -= sizeof(u32);
if (num_bytes_ready < HEADER_SIZE)
return ParseDecision::NotEnoughData;
m_socket_incoming_message_size = ByteReader::load32(m_buffered_data.data());
// NOTE: We don't decrement the number of ready bytes because we want to remove the entire
// message + header from the buffer in one go on success
m_socket_state = SocketState::Data;
}
[[fallthrough]];
}
case SocketState::Data: {
if (num_bytes_ready < m_socket_incoming_message_size)
break;
return ParseDecision::NotEnoughData;
Vector<u8, 1024> data;
data.resize(m_socket_incoming_message_size, true);
MUST(m_socket->read_until_filled(data));
auto payload = m_buffered_data.span().slice(HEADER_SIZE, m_socket_incoming_message_size);
FixedMemoryStream stream { data, FixedMemoryStream::Mode::ReadOnly };
IPC::Decoder decoder(stream, *m_fd_passing_socket);
FixedMemoryStream stream { payload, FixedMemoryStream::Mode::ReadOnly };
IPC::Decoder decoder { stream, m_unprocessed_fds };
auto serialize_with_transfer_result = MUST(decoder.decode<SerializedTransferRecord>());
auto serialized_transfer_record = TRY(decoder.decode<SerializedTransferRecord>());
// Make sure to advance our state machine before dispatching the MessageEvent,
// as dispatching events can run arbitrary JS (and cause us to receive another message!)
m_socket_state = SocketState::Header;
post_message_task_steps(serialize_with_transfer_result);
m_buffered_data.remove(0, HEADER_SIZE + m_socket_incoming_message_size);
post_message_task_steps(serialized_transfer_record);
break;
}
case SocketState::Error:
VERIFY_NOT_REACHED();
break;
return Error::from_errno(ENOMSG);
}
return ParseDecision::ParseNextMessage;
}
void MessagePort::read_from_socket()
{
u8 buffer[4096] {};
Vector<int> fds;
// FIXME: What if pending bytes is > 4096? Should we loop here?
auto maybe_bytes = m_socket->receive_message(buffer, MSG_NOSIGNAL, fds);
if (maybe_bytes.is_error()) {
dbgln("MessagePort::read_from_socket(): Failed to receive message: {}", maybe_bytes.error());
return;
}
auto bytes = maybe_bytes.release_value();
m_buffered_data.append(bytes.data(), bytes.size());
for (auto fd : fds)
m_unprocessed_fds.enqueue(IPC::File::adopt_fd(fd));
while (true) {
auto parse_decision_or_error = parse_message();
if (parse_decision_or_error.is_error()) {
dbgln("MessagePort::read_from_socket(): Failed to parse message: {}", parse_decision_or_error.error());
return;
}
if (parse_decision_or_error.value() == ParseDecision::NotEnoughData)
break;
}
}

View file

@ -78,6 +78,12 @@ private:
ErrorOr<void> send_message_on_socket(SerializedTransferRecord const&);
void read_from_socket();
enum class ParseDecision {
NotEnoughData,
ParseNextMessage,
};
ErrorOr<ParseDecision> parse_message();
// The HTML spec implies(!) that this is MessagePort.[[RemotePort]]
JS::GCPtr<MessagePort> m_remote_port;
@ -93,6 +99,8 @@ private:
Error,
} m_socket_state { SocketState::Header };
size_t m_socket_incoming_message_size { 0 };
Queue<IPC::File> m_unprocessed_fds;
Vector<u8> m_buffered_data;
JS::GCPtr<DOM::EventTarget> m_worker_event_target;
};