mirror of
https://github.com/SerenityOS/serenity
synced 2024-10-01 21:53:54 +00:00
LibIPC+LibWeb: Transfer IPC::Files using sendmsg/recvmsg directly
This refactor eliminates the need for a second "fd passing socket" on Lagom, as it uses SCM_RIGHTS in the expected fashion, to send fds along with the data of our Unix socket message.
This commit is contained in:
parent
a18c7c4405
commit
cb87725ec8
|
@ -372,9 +372,9 @@ public:)~~~");
|
|||
static i32 static_message_id() { return (int)MessageID::@message.pascal_name@; }
|
||||
virtual const char* message_name() const override { return "@endpoint.name@::@message.pascal_name@"; }
|
||||
|
||||
static ErrorOr<NonnullOwnPtr<@message.pascal_name@>> decode(Stream& stream, Core::LocalSocket& socket)
|
||||
static ErrorOr<NonnullOwnPtr<@message.pascal_name@>> decode(Stream& stream, Queue<IPC::File>& files)
|
||||
{
|
||||
IPC::Decoder decoder { stream, socket };)~~~");
|
||||
IPC::Decoder decoder { stream, files };)~~~");
|
||||
|
||||
for (auto const& parameter : parameters) {
|
||||
auto parameter_generator = message_generator.fork();
|
||||
|
@ -620,7 +620,7 @@ public:
|
|||
|
||||
static u32 static_magic() { return @endpoint.magic@; }
|
||||
|
||||
static ErrorOr<NonnullOwnPtr<IPC::Message>> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Core::LocalSocket& socket)
|
||||
static ErrorOr<NonnullOwnPtr<IPC::Message>> decode_message(ReadonlyBytes buffer, [[maybe_unused]] Queue<IPC::File>& files)
|
||||
{
|
||||
FixedMemoryStream stream { buffer };
|
||||
auto message_endpoint_magic = TRY(stream.read_value<u32>());)~~~");
|
||||
|
@ -649,7 +649,7 @@ public:
|
|||
|
||||
message_generator.append(R"~~~(
|
||||
case (int)Messages::@endpoint.name@::MessageID::@message.pascal_name@:
|
||||
return TRY(Messages::@endpoint.name@::@message.pascal_name@::decode(stream, socket));)~~~");
|
||||
return TRY(Messages::@endpoint.name@::@message.pascal_name@::decode(stream, files));)~~~");
|
||||
};
|
||||
|
||||
do_decode_message(message.name);
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
#include <LibCore/System.h>
|
||||
#include <LibIPC/Connection.h>
|
||||
#include <LibIPC/File.h>
|
||||
#include <LibIPC/Stub.h>
|
||||
#include <sys/select.h>
|
||||
|
||||
|
@ -35,18 +36,6 @@ void ConnectionBase::set_deferred_invoker(NonnullOwnPtr<DeferredInvoker> deferre
|
|||
m_deferred_invoker = move(deferred_invoker);
|
||||
}
|
||||
|
||||
void ConnectionBase::set_fd_passing_socket(NonnullOwnPtr<Core::LocalSocket> socket)
|
||||
{
|
||||
m_fd_passing_socket = move(socket);
|
||||
}
|
||||
|
||||
Core::LocalSocket& ConnectionBase::fd_passing_socket()
|
||||
{
|
||||
if (m_fd_passing_socket)
|
||||
return *m_fd_passing_socket;
|
||||
return *m_socket;
|
||||
}
|
||||
|
||||
ErrorOr<void> ConnectionBase::post_message(Message const& message)
|
||||
{
|
||||
return post_message(TRY(message.encode()));
|
||||
|
@ -59,7 +48,7 @@ ErrorOr<void> ConnectionBase::post_message(MessageBuffer buffer)
|
|||
if (!m_socket->is_open())
|
||||
return Error::from_string_literal("Trying to post_message during IPC shutdown");
|
||||
|
||||
if (auto result = buffer.transfer_message(fd_passing_socket(), *m_socket); result.is_error()) {
|
||||
if (auto result = buffer.transfer_message(*m_socket); result.is_error()) {
|
||||
shutdown_with_error(result.error());
|
||||
return result.release_error();
|
||||
}
|
||||
|
@ -122,6 +111,7 @@ ErrorOr<Vector<u8>> ConnectionBase::read_as_much_as_possible_from_socket_without
|
|||
}
|
||||
|
||||
u8 buffer[4096];
|
||||
Vector<int> received_fds;
|
||||
|
||||
bool should_shut_down = false;
|
||||
auto schedule_shutdown = [this, &should_shut_down]() {
|
||||
|
@ -132,7 +122,7 @@ ErrorOr<Vector<u8>> ConnectionBase::read_as_much_as_possible_from_socket_without
|
|||
};
|
||||
|
||||
while (m_socket->is_open()) {
|
||||
auto maybe_bytes_read = m_socket->read_without_waiting({ buffer, 4096 });
|
||||
auto maybe_bytes_read = m_socket->receive_message({ buffer, 4096 }, MSG_DONTWAIT, received_fds);
|
||||
if (maybe_bytes_read.is_error()) {
|
||||
auto error = maybe_bytes_read.release_error();
|
||||
if (error.is_syscall() && error.code() == EAGAIN) {
|
||||
|
@ -156,6 +146,8 @@ ErrorOr<Vector<u8>> ConnectionBase::read_as_much_as_possible_from_socket_without
|
|||
}
|
||||
|
||||
bytes.append(bytes_read.data(), bytes_read.size());
|
||||
for (auto const& fd : received_fds)
|
||||
m_unprocessed_fds.enqueue(IPC::File::adopt_fd(fd));
|
||||
}
|
||||
|
||||
if (!bytes.is_empty()) {
|
||||
|
|
|
@ -8,12 +8,14 @@
|
|||
#pragma once
|
||||
|
||||
#include <AK/ByteBuffer.h>
|
||||
#include <AK/Queue.h>
|
||||
#include <AK/Try.h>
|
||||
#include <LibCore/Event.h>
|
||||
#include <LibCore/EventLoop.h>
|
||||
#include <LibCore/Notifier.h>
|
||||
#include <LibCore/Socket.h>
|
||||
#include <LibCore/Timer.h>
|
||||
#include <LibIPC/File.h>
|
||||
#include <LibIPC/Forward.h>
|
||||
#include <LibIPC/Message.h>
|
||||
#include <errno.h>
|
||||
|
@ -38,7 +40,7 @@ class ConnectionBase : public Core::EventReceiver {
|
|||
public:
|
||||
virtual ~ConnectionBase() override = default;
|
||||
|
||||
void set_fd_passing_socket(NonnullOwnPtr<Core::LocalSocket>);
|
||||
void set_fd_passing_socket(NonnullOwnPtr<Core::LocalSocket>) { }
|
||||
void set_deferred_invoker(NonnullOwnPtr<DeferredInvoker>);
|
||||
DeferredInvoker& deferred_invoker() { return *m_deferred_invoker; }
|
||||
|
||||
|
@ -49,7 +51,7 @@ public:
|
|||
virtual void die() { }
|
||||
|
||||
Core::LocalSocket& socket() { return *m_socket; }
|
||||
Core::LocalSocket& fd_passing_socket();
|
||||
Core::LocalSocket const& fd_passing_socket() const { return *m_socket; }
|
||||
|
||||
protected:
|
||||
explicit ConnectionBase(IPC::Stub&, NonnullOwnPtr<Core::LocalSocket>, u32 local_endpoint_magic);
|
||||
|
@ -70,11 +72,11 @@ protected:
|
|||
IPC::Stub& m_local_stub;
|
||||
|
||||
NonnullOwnPtr<Core::LocalSocket> m_socket;
|
||||
OwnPtr<Core::LocalSocket> m_fd_passing_socket;
|
||||
|
||||
RefPtr<Core::Timer> m_responsiveness_timer;
|
||||
|
||||
Vector<NonnullOwnPtr<Message>> m_unprocessed_messages;
|
||||
Queue<IPC::File> m_unprocessed_fds;
|
||||
ByteBuffer m_unprocessed_bytes;
|
||||
|
||||
u32 m_local_endpoint_magic { 0 };
|
||||
|
@ -138,13 +140,13 @@ protected:
|
|||
index += sizeof(message_size);
|
||||
auto remaining_bytes = ReadonlyBytes { bytes.data() + index, message_size };
|
||||
|
||||
auto local_message = LocalEndpoint::decode_message(remaining_bytes, fd_passing_socket());
|
||||
auto local_message = LocalEndpoint::decode_message(remaining_bytes, m_unprocessed_fds);
|
||||
if (!local_message.is_error()) {
|
||||
m_unprocessed_messages.append(local_message.release_value());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto peer_message = PeerEndpoint::decode_message(remaining_bytes, fd_passing_socket());
|
||||
auto peer_message = PeerEndpoint::decode_message(remaining_bytes, m_unprocessed_fds);
|
||||
if (!peer_message.is_error()) {
|
||||
m_unprocessed_messages.append(peer_message.release_value());
|
||||
continue;
|
||||
|
|
|
@ -90,8 +90,12 @@ ErrorOr<URL::URL> decode(Decoder& decoder)
|
|||
template<>
|
||||
ErrorOr<File> decode(Decoder& decoder)
|
||||
{
|
||||
int fd = TRY(decoder.socket().receive_fd(O_CLOEXEC));
|
||||
return File::adopt_fd(fd);
|
||||
auto file = TRY(decoder.files().try_dequeue());
|
||||
auto fd = file.fd();
|
||||
|
||||
auto fd_flags = TRY(Core::System::fcntl(fd, F_GETFD));
|
||||
TRY(Core::System::fcntl(fd, F_SETFD, fd_flags | FD_CLOEXEC));
|
||||
return file;
|
||||
}
|
||||
|
||||
template<>
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include <AK/Concepts.h>
|
||||
#include <AK/Forward.h>
|
||||
#include <AK/NumericLimits.h>
|
||||
#include <AK/Queue.h>
|
||||
#include <AK/StdLibExtras.h>
|
||||
#include <AK/String.h>
|
||||
#include <AK/Try.h>
|
||||
|
@ -35,9 +36,9 @@ inline ErrorOr<T> decode(Decoder&)
|
|||
|
||||
class Decoder {
|
||||
public:
|
||||
Decoder(Stream& stream, Core::LocalSocket& socket)
|
||||
Decoder(Stream& stream, Queue<IPC::File>& files)
|
||||
: m_stream(stream)
|
||||
, m_socket(socket)
|
||||
, m_files(files)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -60,11 +61,11 @@ public:
|
|||
ErrorOr<size_t> decode_size();
|
||||
|
||||
Stream& stream() { return m_stream; }
|
||||
Core::LocalSocket& socket() { return m_socket; }
|
||||
Queue<IPC::File>& files() { return m_files; }
|
||||
|
||||
private:
|
||||
Stream& m_stream;
|
||||
Core::LocalSocket& m_socket;
|
||||
Queue<IPC::File>& m_files;
|
||||
};
|
||||
|
||||
template<Arithmetic T>
|
||||
|
|
|
@ -37,7 +37,7 @@ ErrorOr<void> MessageBuffer::append_file_descriptor(int fd)
|
|||
return {};
|
||||
}
|
||||
|
||||
ErrorOr<void> MessageBuffer::transfer_message(Core::LocalSocket& fd_passing_socket, Core::LocalSocket& data_socket)
|
||||
ErrorOr<void> MessageBuffer::transfer_message(Core::LocalSocket& socket)
|
||||
{
|
||||
Checked<MessageSizeType> checked_message_size { m_data.size() };
|
||||
checked_message_size -= sizeof(MessageSizeType);
|
||||
|
@ -45,17 +45,30 @@ ErrorOr<void> MessageBuffer::transfer_message(Core::LocalSocket& fd_passing_sock
|
|||
if (checked_message_size.has_overflow())
|
||||
return Error::from_string_literal("Message is too large for IPC encoding");
|
||||
|
||||
auto message_size = checked_message_size.value();
|
||||
MessageSizeType const message_size = checked_message_size.value();
|
||||
m_data.span().overwrite(0, reinterpret_cast<u8 const*>(&message_size), sizeof(message_size));
|
||||
|
||||
for (auto const& fd : m_fds)
|
||||
TRY(fd_passing_socket.send_fd(fd->value()));
|
||||
auto raw_fds = Vector<int, 1> {};
|
||||
auto num_fds_to_transfer = m_fds.size();
|
||||
if (num_fds_to_transfer > 0) {
|
||||
raw_fds.ensure_capacity(num_fds_to_transfer);
|
||||
for (auto& owned_fd : m_fds) {
|
||||
raw_fds.unchecked_append(owned_fd->value());
|
||||
}
|
||||
}
|
||||
|
||||
ReadonlyBytes bytes_to_write { m_data.span() };
|
||||
size_t writes_done = 0;
|
||||
|
||||
while (!bytes_to_write.is_empty()) {
|
||||
auto maybe_nwritten = data_socket.write_some(bytes_to_write);
|
||||
ErrorOr<ssize_t> maybe_nwritten = 0;
|
||||
if (num_fds_to_transfer > 0) {
|
||||
maybe_nwritten = socket.send_message(bytes_to_write, 0, raw_fds);
|
||||
if (!maybe_nwritten.is_error())
|
||||
num_fds_to_transfer = 0;
|
||||
} else {
|
||||
maybe_nwritten = socket.write_some(bytes_to_write);
|
||||
}
|
||||
++writes_done;
|
||||
|
||||
if (maybe_nwritten.is_error()) {
|
||||
|
|
|
@ -44,7 +44,7 @@ public:
|
|||
|
||||
ErrorOr<void> append_file_descriptor(int fd);
|
||||
|
||||
ErrorOr<void> transfer_message(Core::LocalSocket& fd_passing_socket, Core::LocalSocket& data_socket);
|
||||
ErrorOr<void> transfer_message(Core::LocalSocket& socket);
|
||||
|
||||
private:
|
||||
Vector<u8, 1024> m_data;
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
* SPDX-License-Identifier: BSD-2-Clause
|
||||
*/
|
||||
|
||||
#include <AK/ByteReader.h>
|
||||
#include <AK/MemoryStream.h>
|
||||
#include <LibCore/Socket.h>
|
||||
#include <LibCore/System.h>
|
||||
|
@ -259,7 +260,7 @@ ErrorOr<void> MessagePort::send_message_on_socket(SerializedTransferRecord const
|
|||
IPC::Encoder encoder(buffer);
|
||||
MUST(encoder.encode(serialize_with_transfer_result));
|
||||
|
||||
TRY(buffer.transfer_message(*m_fd_passing_socket, *m_socket));
|
||||
TRY(buffer.transfer_message(*m_socket));
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -276,41 +277,76 @@ void MessagePort::post_port_message(SerializedTransferRecord serialize_with_tran
|
|||
});
|
||||
}
|
||||
|
||||
void MessagePort::read_from_socket()
|
||||
ErrorOr<MessagePort::ParseDecision> MessagePort::parse_message()
|
||||
{
|
||||
auto num_bytes_ready = MUST(m_socket->pending_bytes());
|
||||
static constexpr size_t HEADER_SIZE = sizeof(u32);
|
||||
|
||||
auto num_bytes_ready = m_buffered_data.size();
|
||||
switch (m_socket_state) {
|
||||
case SocketState::Header: {
|
||||
if (num_bytes_ready < sizeof(u32))
|
||||
break;
|
||||
m_socket_incoming_message_size = MUST(m_socket->read_value<u32>());
|
||||
num_bytes_ready -= sizeof(u32);
|
||||
if (num_bytes_ready < HEADER_SIZE)
|
||||
return ParseDecision::NotEnoughData;
|
||||
|
||||
m_socket_incoming_message_size = ByteReader::load32(m_buffered_data.data());
|
||||
// NOTE: We don't decrement the number of ready bytes because we want to remove the entire
|
||||
// message + header from the buffer in one go on success
|
||||
m_socket_state = SocketState::Data;
|
||||
}
|
||||
[[fallthrough]];
|
||||
}
|
||||
case SocketState::Data: {
|
||||
if (num_bytes_ready < m_socket_incoming_message_size)
|
||||
break;
|
||||
return ParseDecision::NotEnoughData;
|
||||
|
||||
Vector<u8, 1024> data;
|
||||
data.resize(m_socket_incoming_message_size, true);
|
||||
MUST(m_socket->read_until_filled(data));
|
||||
auto payload = m_buffered_data.span().slice(HEADER_SIZE, m_socket_incoming_message_size);
|
||||
|
||||
FixedMemoryStream stream { data, FixedMemoryStream::Mode::ReadOnly };
|
||||
IPC::Decoder decoder(stream, *m_fd_passing_socket);
|
||||
FixedMemoryStream stream { payload, FixedMemoryStream::Mode::ReadOnly };
|
||||
IPC::Decoder decoder { stream, m_unprocessed_fds };
|
||||
|
||||
auto serialize_with_transfer_result = MUST(decoder.decode<SerializedTransferRecord>());
|
||||
auto serialized_transfer_record = TRY(decoder.decode<SerializedTransferRecord>());
|
||||
|
||||
// Make sure to advance our state machine before dispatching the MessageEvent,
|
||||
// as dispatching events can run arbitrary JS (and cause us to receive another message!)
|
||||
m_socket_state = SocketState::Header;
|
||||
|
||||
post_message_task_steps(serialize_with_transfer_result);
|
||||
m_buffered_data.remove(0, HEADER_SIZE + m_socket_incoming_message_size);
|
||||
|
||||
post_message_task_steps(serialized_transfer_record);
|
||||
|
||||
break;
|
||||
}
|
||||
case SocketState::Error:
|
||||
VERIFY_NOT_REACHED();
|
||||
break;
|
||||
return Error::from_errno(ENOMSG);
|
||||
}
|
||||
|
||||
return ParseDecision::ParseNextMessage;
|
||||
}
|
||||
|
||||
void MessagePort::read_from_socket()
|
||||
{
|
||||
u8 buffer[4096] {};
|
||||
|
||||
Vector<int> fds;
|
||||
// FIXME: What if pending bytes is > 4096? Should we loop here?
|
||||
auto maybe_bytes = m_socket->receive_message(buffer, MSG_NOSIGNAL, fds);
|
||||
if (maybe_bytes.is_error()) {
|
||||
dbgln("MessagePort::read_from_socket(): Failed to receive message: {}", maybe_bytes.error());
|
||||
return;
|
||||
}
|
||||
auto bytes = maybe_bytes.release_value();
|
||||
|
||||
m_buffered_data.append(bytes.data(), bytes.size());
|
||||
|
||||
for (auto fd : fds)
|
||||
m_unprocessed_fds.enqueue(IPC::File::adopt_fd(fd));
|
||||
|
||||
while (true) {
|
||||
auto parse_decision_or_error = parse_message();
|
||||
if (parse_decision_or_error.is_error()) {
|
||||
dbgln("MessagePort::read_from_socket(): Failed to parse message: {}", parse_decision_or_error.error());
|
||||
return;
|
||||
}
|
||||
if (parse_decision_or_error.value() == ParseDecision::NotEnoughData)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -78,6 +78,12 @@ private:
|
|||
ErrorOr<void> send_message_on_socket(SerializedTransferRecord const&);
|
||||
void read_from_socket();
|
||||
|
||||
enum class ParseDecision {
|
||||
NotEnoughData,
|
||||
ParseNextMessage,
|
||||
};
|
||||
ErrorOr<ParseDecision> parse_message();
|
||||
|
||||
// The HTML spec implies(!) that this is MessagePort.[[RemotePort]]
|
||||
JS::GCPtr<MessagePort> m_remote_port;
|
||||
|
||||
|
@ -93,6 +99,8 @@ private:
|
|||
Error,
|
||||
} m_socket_state { SocketState::Header };
|
||||
size_t m_socket_incoming_message_size { 0 };
|
||||
Queue<IPC::File> m_unprocessed_fds;
|
||||
Vector<u8> m_buffered_data;
|
||||
|
||||
JS::GCPtr<DOM::EventTarget> m_worker_event_target;
|
||||
};
|
||||
|
|
Loading…
Reference in a new issue