diff --git a/Kernel/API/POSIX/sys/socket.h b/Kernel/API/POSIX/sys/socket.h index 750a353a41..d8fc765b88 100644 --- a/Kernel/API/POSIX/sys/socket.h +++ b/Kernel/API/POSIX/sys/socket.h @@ -81,6 +81,32 @@ struct msghdr { int msg_flags; }; +// These three are non-POSIX, but common: +#define CMSG_ALIGN(x) (((x) + sizeof(void*) - 1) & ~(sizeof(void*) - 1)) +#define CMSG_SPACE(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + CMSG_ALIGN(x)) +#define CMSG_LEN(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + (x)) + +static inline struct cmsghdr* CMSG_FIRSTHDR(struct msghdr* msg) +{ + if (msg->msg_controllen < sizeof(struct cmsghdr)) + return (struct cmsghdr*)0; + return (struct cmsghdr*)msg->msg_control; +} + +static inline struct cmsghdr* CMSG_NXTHDR(struct msghdr* msg, struct cmsghdr* cmsg) +{ + struct cmsghdr* next = (struct cmsghdr*)((char*)cmsg + CMSG_ALIGN(cmsg->cmsg_len)); + unsigned offset = (char*)next - (char*)msg->msg_control; + if (msg->msg_controllen < offset + sizeof(struct cmsghdr)) + return (struct cmsghdr*)0; + return next; +} + +static inline void* CMSG_DATA(struct cmsghdr* cmsg) +{ + return (void*)(cmsg + 1); +} + struct sockaddr { sa_family_t sa_family; char sa_data[14]; diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index 4b55e8bea2..33db0b60b0 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -520,6 +520,26 @@ ErrorOr> LocalSocket::recvfd(OpenFileDesc return queue.take_first(); } +ErrorOr> LocalSocket::recvfds(OpenFileDescription const& socket_description, int n) +{ + MutexLocker locker(mutex()); + NonnullLockRefPtrVector fds; + + auto role = this->role(socket_description); + if (role != Role::Connected && role != Role::Accepted) + return set_so_error(EINVAL); + auto& queue = recvfd_queue_for(socket_description); + + for (int i = 0; i < n; ++i) { + if (queue.is_empty()) + break; + + fds.append(queue.take_first()); + } + + return fds; +} + ErrorOr LocalSocket::try_set_path(StringView path) { m_path = TRY(KString::try_create(path)); diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h index 5416bd5373..9698a49384 100644 --- a/Kernel/Net/LocalSocket.h +++ b/Kernel/Net/LocalSocket.h @@ -28,6 +28,7 @@ public: ErrorOr sendfd(OpenFileDescription const& socket_description, NonnullLockRefPtr passing_description); ErrorOr> recvfd(OpenFileDescription const& socket_description); + ErrorOr> recvfds(OpenFileDescription const& socket_description, int n); static void for_each(Function); static ErrorOr try_for_each(Function(LocalSocket const&)>); diff --git a/Kernel/Syscalls/socket.cpp b/Kernel/Syscalls/socket.cpp index e02e78bb10..f29598a35d 100644 --- a/Kernel/Syscalls/socket.cpp +++ b/Kernel/Syscalls/socket.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include #include @@ -199,6 +200,24 @@ ErrorOr Process::sys$sendmsg(int sockfd, Userspacesend_signal(SIGPIPE, &Process::current()); return EPIPE; } + + if (msg.msg_controllen > 0) { + // Handle command messages. + auto cmsg_buffer = TRY(ByteBuffer::create_uninitialized(msg.msg_controllen)); + TRY(copy_from_user(cmsg_buffer.data(), msg.msg_control, msg.msg_controllen)); + msg.msg_control = cmsg_buffer.data(); + for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (socket.is_local() && cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { + auto& local_socket = static_cast(socket); + int* fds = (int*)CMSG_DATA(cmsg); + size_t nfds = (cmsg->cmsg_len - CMSG_ALIGN(sizeof(struct cmsghdr))) / sizeof(int); + for (size_t i = 0; i < nfds; ++i) { + TRY(local_socket.sendfd(*description, TRY(open_file_description(fds[i])))); + } + } + } + } + auto data_buffer = TRY(UserOrKernelBuffer::for_user_buffer((u8*)iovs[0].iov_base, iovs[0].iov_len)); while (true) { @@ -267,21 +286,41 @@ ErrorOr Process::sys$recvmsg(int sockfd, Userspace user msg_flags |= MSG_TRUNC; } - if (socket.wants_timestamp()) { - struct { - cmsghdr cmsg; - timeval timestamp; - } cmsg_timestamp; - socklen_t control_length = sizeof(cmsg_timestamp); - if (msg.msg_controllen < control_length) { + socklen_t current_cmsg_len = 0; + auto try_add_cmsg = [&](int level, int type, void const* data, socklen_t len) -> ErrorOr { + if (current_cmsg_len + len > msg.msg_controllen) { msg_flags |= MSG_CTRUNC; - } else { - cmsg_timestamp = { { control_length, SOL_SOCKET, SCM_TIMESTAMP }, timestamp.to_timeval() }; - TRY(copy_to_user(msg.msg_control, &cmsg_timestamp, control_length)); + return false; } - TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_controllen, &control_length)); + + cmsghdr cmsg = { (socklen_t)CMSG_LEN(len), level, type }; + cmsghdr* target = (cmsghdr*)(((char*)msg.msg_control) + current_cmsg_len); + TRY(copy_to_user(target, &cmsg)); + TRY(copy_to_user(CMSG_DATA(target), data, len)); + current_cmsg_len += CMSG_ALIGN(cmsg.cmsg_len); + return true; + }; + + if (socket.wants_timestamp()) { + timeval time = timestamp.to_timeval(); + TRY(try_add_cmsg(SOL_SOCKET, SCM_TIMESTAMP, &time, sizeof(time))); } + int space_for_fds = (msg.msg_controllen - current_cmsg_len - sizeof(struct cmsghdr)) / sizeof(int); + if (space_for_fds > 0 && socket.is_local()) { + auto& local_socket = static_cast(socket); + auto descriptions = TRY(local_socket.recvfds(description, space_for_fds)); + Vector fdnums; + for (auto& description : descriptions) { + auto fd_allocation = TRY(m_fds.with_exclusive([](auto& fds) { return fds.allocate(); })); + m_fds.with_exclusive([&](auto& fds) { fds[fd_allocation.fd].set(description, 0); }); + fdnums.append(fd_allocation.fd); + } + TRY(try_add_cmsg(SOL_SOCKET, SCM_RIGHTS, fdnums.data(), fdnums.size() * sizeof(int))); + } + + TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_controllen, ¤t_cmsg_len)); + TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_flags, &msg_flags)); return result.value(); } diff --git a/Userland/Libraries/LibC/sys/socket.h b/Userland/Libraries/LibC/sys/socket.h index 158ec12303..f505992c02 100644 --- a/Userland/Libraries/LibC/sys/socket.h +++ b/Userland/Libraries/LibC/sys/socket.h @@ -33,30 +33,4 @@ int socketpair(int domain, int type, int protocol, int sv[2]); int sendfd(int sockfd, int fd); int recvfd(int sockfd, int options); -// These three are non-POSIX, but common: -#define CMSG_ALIGN(x) (((x) + sizeof(void*) - 1) & ~(sizeof(void*) - 1)) -#define CMSG_SPACE(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + CMSG_ALIGN(x)) -#define CMSG_LEN(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + (x)) - -static inline struct cmsghdr* CMSG_FIRSTHDR(struct msghdr* msg) -{ - if (msg->msg_controllen < sizeof(struct cmsghdr)) - return 0; - return (struct cmsghdr*)msg->msg_control; -} - -static inline struct cmsghdr* CMSG_NXTHDR(struct msghdr* msg, struct cmsghdr* cmsg) -{ - struct cmsghdr* next = (struct cmsghdr*)((char*)cmsg + CMSG_ALIGN(cmsg->cmsg_len)); - unsigned offset = (char*)next - (char*)msg->msg_control; - if (msg->msg_controllen < offset + sizeof(struct cmsghdr)) - return NULL; - return next; -} - -static inline void* CMSG_DATA(struct cmsghdr* cmsg) -{ - return (void*)(cmsg + 1); -} - __END_DECLS