af_unix: Add OOB support

This patch adds OOB support for AF_UNIX sockets.
The semantics is same as TCP.

The last byte of a message with the OOB flag is
treated as the OOB byte. The byte is separated into
a skb and a pointer to the skb is stored in unix_sock.
The pointer is used to enforce OOB semantics.

Signed-off-by: Rao Shoaib <rao.shoaib@oracle.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
Rao Shoaib 2021-08-01 00:57:07 -07:00 committed by David S. Miller
parent 7e89350c90
commit 314001f0bf
6 changed files with 602 additions and 2 deletions

View File

@ -70,6 +70,9 @@ struct unix_sock {
struct socket_wq peer_wq;
wait_queue_entry_t peer_wake;
struct scm_stat scm_stat;
#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
struct sk_buff *oob_skb;
#endif
};
static inline struct unix_sock *unix_sk(const struct sock *sk)

View File

@ -25,6 +25,11 @@ config UNIX_SCM
depends on UNIX
default y
config AF_UNIX_OOB
bool
depends on UNIX
default y
config UNIX_DIAG
tristate "UNIX: socket monitoring interface"
depends on UNIX

View File

@ -503,6 +503,12 @@ static void unix_sock_destructor(struct sock *sk)
skb_queue_purge(&sk->sk_receive_queue);
#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
if (u->oob_skb) {
kfree_skb(u->oob_skb);
u->oob_skb = NULL;
}
#endif
WARN_ON(refcount_read(&sk->sk_wmem_alloc));
WARN_ON(!sk_unhashed(sk));
WARN_ON(sk->sk_socket);
@ -1889,6 +1895,46 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
*/
#define UNIX_SKB_FRAGS_SZ (PAGE_SIZE << get_order(32768))
#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
static int queue_oob(struct socket *sock, struct msghdr *msg, struct sock *other)
{
struct unix_sock *ousk = unix_sk(other);
struct sk_buff *skb;
int err = 0;
skb = sock_alloc_send_skb(sock->sk, 1, msg->msg_flags & MSG_DONTWAIT, &err);
if (!skb)
return err;
skb_put(skb, 1);
skb->len = 1;
err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, 1);
if (err) {
kfree_skb(skb);
return err;
}
unix_state_lock(other);
maybe_add_creds(skb, sock, other);
skb_get(skb);
if (ousk->oob_skb)
kfree_skb(ousk->oob_skb);
ousk->oob_skb = skb;
scm_stat_add(other, skb);
skb_queue_tail(&other->sk_receive_queue, skb);
sk_send_sigurg(other);
unix_state_unlock(other);
other->sk_data_ready(other);
return err;
}
#endif
static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
size_t len)
{
@ -1907,8 +1953,14 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
return err;
err = -EOPNOTSUPP;
if (msg->msg_flags&MSG_OOB)
goto out_err;
if (msg->msg_flags & MSG_OOB) {
#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
if (len)
len--;
else
#endif
goto out_err;
}
if (msg->msg_namelen) {
err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
@ -1973,6 +2025,15 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
sent += size;
}
#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
if (msg->msg_flags & MSG_OOB) {
err = queue_oob(sock, msg, other);
if (err)
goto out_err;
sent++;
}
#endif
scm_destroy(&scm);
return sent;
@ -2358,6 +2419,59 @@ struct unix_stream_read_state {
unsigned int splice_flags;
};
#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
static int unix_stream_recv_urg(struct unix_stream_read_state *state)
{
struct socket *sock = state->socket;
struct sock *sk = sock->sk;
struct unix_sock *u = unix_sk(sk);
int chunk = 1;
if (sock_flag(sk, SOCK_URGINLINE) || !u->oob_skb)
return -EINVAL;
chunk = state->recv_actor(u->oob_skb, 0, chunk, state);
if (chunk < 0)
return -EFAULT;
if (!(state->flags & MSG_PEEK)) {
UNIXCB(u->oob_skb).consumed += 1;
kfree_skb(u->oob_skb);
u->oob_skb = NULL;
}
state->msg->msg_flags |= MSG_OOB;
return 1;
}
static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk,
int flags, int copied)
{
struct unix_sock *u = unix_sk(sk);
if (!unix_skb_len(skb) && !(flags & MSG_PEEK)) {
skb_unlink(skb, &sk->sk_receive_queue);
consume_skb(skb);
skb = NULL;
} else {
if (skb == u->oob_skb) {
if (copied) {
skb = NULL;
} else if (sock_flag(sk, SOCK_URGINLINE)) {
if (!(flags & MSG_PEEK)) {
u->oob_skb = NULL;
consume_skb(skb);
}
} else if (!(flags & MSG_PEEK)) {
skb_unlink(skb, &sk->sk_receive_queue);
consume_skb(skb);
skb = skb_peek(&sk->sk_receive_queue);
}
}
}
return skb;
}
#endif
static int unix_stream_read_generic(struct unix_stream_read_state *state,
bool freezable)
{
@ -2383,6 +2497,15 @@ static int unix_stream_read_generic(struct unix_stream_read_state *state,
if (unlikely(flags & MSG_OOB)) {
err = -EOPNOTSUPP;
#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
mutex_lock(&u->iolock);
unix_state_lock(sk);
err = unix_stream_recv_urg(state);
unix_state_unlock(sk);
mutex_unlock(&u->iolock);
#endif
goto out;
}
@ -2411,6 +2534,18 @@ static int unix_stream_read_generic(struct unix_stream_read_state *state,
}
last = skb = skb_peek(&sk->sk_receive_queue);
last_len = last ? last->len : 0;
#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
if (skb) {
skb = manage_oob(skb, sk, flags, copied);
if (!skb) {
unix_state_unlock(sk);
if (copied)
break;
goto redo;
}
}
#endif
again:
if (skb == NULL) {
if (copied >= target)
@ -2746,6 +2881,20 @@ static int unix_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
case SIOCUNIXFILE:
err = unix_open_file(sk);
break;
#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
case SIOCATMARK:
{
struct sk_buff *skb;
struct unix_sock *u = unix_sk(sk);
int answ = 0;
skb = skb_peek(&sk->sk_receive_queue);
if (skb && skb == u->oob_skb)
answ = 1;
err = put_user(answ, (int __user *)arg);
}
break;
#endif
default:
err = -ENOIOCTLCMD;
break;

View File

@ -38,6 +38,7 @@ TARGETS += mount_setattr
TARGETS += mqueue
TARGETS += nci
TARGETS += net
TARGETS += net/af_unix
TARGETS += net/forwarding
TARGETS += net/mptcp
TARGETS += netfilter

View File

@ -0,0 +1,5 @@
##TEST_GEN_FILES := test_unix_oob
TEST_PROGS := test_unix_oob
include ../../lib.mk
all: $(TEST_PROGS)

View File

@ -0,0 +1,437 @@
// SPDX-License-Identifier: GPL-2.0-or-later
#include <stdio.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <string.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <errno.h>
#include <netinet/tcp.h>
#include <sys/un.h>
#include <sys/signal.h>
#include <sys/poll.h>
static int pipefd[2];
static int signal_recvd;
static pid_t producer_id;
static char sock_name[32];
static void sig_hand(int sn, siginfo_t *si, void *p)
{
signal_recvd = sn;
}
static int set_sig_handler(int signal)
{
struct sigaction sa;
sa.sa_sigaction = sig_hand;
sigemptyset(&sa.sa_mask);
sa.sa_flags = SA_SIGINFO | SA_RESTART;
return sigaction(signal, &sa, NULL);
}
static void set_filemode(int fd, int set)
{
int flags = fcntl(fd, F_GETFL, 0);
if (set)
flags &= ~O_NONBLOCK;
else
flags |= O_NONBLOCK;
fcntl(fd, F_SETFL, flags);
}
static void signal_producer(int fd)
{
char cmd;
cmd = 'S';
write(fd, &cmd, sizeof(cmd));
}
static void wait_for_signal(int fd)
{
char buf[5];
read(fd, buf, 5);
}
static void die(int status)
{
fflush(NULL);
unlink(sock_name);
kill(producer_id, SIGTERM);
exit(status);
}
int is_sioctatmark(int fd)
{
int ans = -1;
if (ioctl(fd, SIOCATMARK, &ans, sizeof(ans)) < 0) {
#ifdef DEBUG
perror("SIOCATMARK Failed");
#endif
}
return ans;
}
void read_oob(int fd, char *c)
{
*c = ' ';
if (recv(fd, c, sizeof(*c), MSG_OOB) < 0) {
#ifdef DEBUG
perror("Reading MSG_OOB Failed");
#endif
}
}
int read_data(int pfd, char *buf, int size)
{
int len = 0;
memset(buf, size, '0');
len = read(pfd, buf, size);
#ifdef DEBUG
if (len < 0)
perror("read failed");
#endif
return len;
}
static void wait_for_data(int pfd, int event)
{
struct pollfd pfds[1];
pfds[0].fd = pfd;
pfds[0].events = event;
poll(pfds, 1, -1);
}
void producer(struct sockaddr_un *consumer_addr)
{
int cfd;
char buf[64];
int i;
memset(buf, 'x', sizeof(buf));
cfd = socket(AF_UNIX, SOCK_STREAM, 0);
wait_for_signal(pipefd[0]);
if (connect(cfd, (struct sockaddr *)consumer_addr,
sizeof(struct sockaddr)) != 0) {
perror("Connect failed");
kill(0, SIGTERM);
exit(1);
}
for (i = 0; i < 2; i++) {
/* Test 1: Test for SIGURG and OOB */
wait_for_signal(pipefd[0]);
memset(buf, 'x', sizeof(buf));
buf[63] = '@';
send(cfd, buf, sizeof(buf), MSG_OOB);
wait_for_signal(pipefd[0]);
/* Test 2: Test for OOB being overwitten */
memset(buf, 'x', sizeof(buf));
buf[63] = '%';
send(cfd, buf, sizeof(buf), MSG_OOB);
memset(buf, 'x', sizeof(buf));
buf[63] = '#';
send(cfd, buf, sizeof(buf), MSG_OOB);
wait_for_signal(pipefd[0]);
/* Test 3: Test for SIOCATMARK */
memset(buf, 'x', sizeof(buf));
buf[63] = '@';
send(cfd, buf, sizeof(buf), MSG_OOB);
memset(buf, 'x', sizeof(buf));
buf[63] = '%';
send(cfd, buf, sizeof(buf), MSG_OOB);
memset(buf, 'x', sizeof(buf));
send(cfd, buf, sizeof(buf), 0);
wait_for_signal(pipefd[0]);
/* Test 4: Test for 1byte OOB msg */
memset(buf, 'x', sizeof(buf));
buf[0] = '@';
send(cfd, buf, 1, MSG_OOB);
}
}
int
main(int argc, char **argv)
{
int lfd, pfd;
struct sockaddr_un consumer_addr, paddr;
socklen_t len = sizeof(consumer_addr);
char buf[1024];
int on = 0;
char oob;
int flags;
int atmark;
char *tmp_file;
lfd = socket(AF_UNIX, SOCK_STREAM, 0);
memset(&consumer_addr, 0, sizeof(consumer_addr));
consumer_addr.sun_family = AF_UNIX;
sprintf(sock_name, "unix_oob_%d", getpid());
unlink(sock_name);
strcpy(consumer_addr.sun_path, sock_name);
if ((bind(lfd, (struct sockaddr *)&consumer_addr,
sizeof(consumer_addr))) != 0) {
perror("socket bind failed");
exit(1);
}
pipe(pipefd);
listen(lfd, 1);
producer_id = fork();
if (producer_id == 0) {
producer(&consumer_addr);
exit(0);
}
set_sig_handler(SIGURG);
signal_producer(pipefd[1]);
pfd = accept(lfd, (struct sockaddr *) &paddr, &len);
fcntl(pfd, F_SETOWN, getpid());
signal_recvd = 0;
signal_producer(pipefd[1]);
/* Test 1:
* veriyf that SIGURG is
* delivered and 63 bytes are
* read and oob is '@'
*/
wait_for_data(pfd, POLLIN | POLLPRI);
read_oob(pfd, &oob);
len = read_data(pfd, buf, 1024);
if (!signal_recvd || len != 63 || oob != '@') {
fprintf(stderr, "Test 1 failed sigurg %d len %d %c\n",
signal_recvd, len, oob);
die(1);
}
signal_recvd = 0;
signal_producer(pipefd[1]);
/* Test 2:
* Verify that the first OOB is over written by
* the 2nd one and the first OOB is returned as
* part of the read, and sigurg is received.
*/
wait_for_data(pfd, POLLIN | POLLPRI);
len = 0;
while (len < 70)
len = recv(pfd, buf, 1024, MSG_PEEK);
len = read_data(pfd, buf, 1024);
read_oob(pfd, &oob);
if (!signal_recvd || len != 127 || oob != '#') {
fprintf(stderr, "Test 2 failed, sigurg %d len %d OOB %c\n",
signal_recvd, len, oob);
die(1);
}
signal_recvd = 0;
signal_producer(pipefd[1]);
/* Test 3:
* verify that 2nd oob over writes
* the first one and read breaks at
* oob boundary returning 127 bytes
* and sigurg is received and atmark
* is set.
* oob is '%' and second read returns
* 64 bytes.
*/
len = 0;
wait_for_data(pfd, POLLIN | POLLPRI);
while (len < 150)
len = recv(pfd, buf, 1024, MSG_PEEK);
len = read_data(pfd, buf, 1024);
atmark = is_sioctatmark(pfd);
read_oob(pfd, &oob);
if (!signal_recvd || len != 127 || oob != '%' || atmark != 1) {
fprintf(stderr, "Test 3 failed, sigurg %d len %d OOB %c ",
"atmark %d\n", signal_recvd, len, oob, atmark);
die(1);
}
signal_recvd = 0;
len = read_data(pfd, buf, 1024);
if (len != 64) {
fprintf(stderr, "Test 3.1 failed, sigurg %d len %d OOB %c\n",
signal_recvd, len, oob);
die(1);
}
signal_recvd = 0;
signal_producer(pipefd[1]);
/* Test 4:
* verify that a single byte
* oob message is delivered.
* set non blocking mode and
* check proper error is
* returned and sigurg is
* received and correct
* oob is read.
*/
set_filemode(pfd, 0);
wait_for_data(pfd, POLLIN | POLLPRI);
len = read_data(pfd, buf, 1024);
if ((len == -1) && (errno == 11))
len = 0;
read_oob(pfd, &oob);
if (!signal_recvd || len != 0 || oob != '@') {
fprintf(stderr, "Test 4 failed, sigurg %d len %d OOB %c\n",
signal_recvd, len, oob);
die(1);
}
set_filemode(pfd, 1);
/* Inline Testing */
on = 1;
if (setsockopt(pfd, SOL_SOCKET, SO_OOBINLINE, &on, sizeof(on))) {
perror("SO_OOBINLINE");
die(1);
}
signal_recvd = 0;
signal_producer(pipefd[1]);
/* Test 1 -- Inline:
* Check that SIGURG is
* delivered and 63 bytes are
* read and oob is '@'
*/
wait_for_data(pfd, POLLIN | POLLPRI);
len = read_data(pfd, buf, 1024);
if (!signal_recvd || len != 63) {
fprintf(stderr, "Test 1 Inline failed, sigurg %d len %d\n",
signal_recvd, len);
die(1);
}
len = read_data(pfd, buf, 1024);
if (len != 1) {
fprintf(stderr,
"Test 1.1 Inline failed, sigurg %d len %d oob %c\n",
signal_recvd, len, oob);
die(1);
}
signal_recvd = 0;
signal_producer(pipefd[1]);
/* Test 2 -- Inline:
* Verify that the first OOB is over written by
* the 2nd one and read breaks correctly on
* 2nd OOB boundary with the first OOB returned as
* part of the read, and sigurg is delivered and
* siocatmark returns true.
* next read returns one byte, the oob byte
* and siocatmark returns false.
*/
len = 0;
wait_for_data(pfd, POLLIN | POLLPRI);
while (len < 70)
len = recv(pfd, buf, 1024, MSG_PEEK);
len = read_data(pfd, buf, 1024);
atmark = is_sioctatmark(pfd);
if (len != 127 || atmark != 1 || !signal_recvd) {
fprintf(stderr, "Test 2 Inline failed, len %d atmark %d\n",
len, atmark);
die(1);
}
len = read_data(pfd, buf, 1024);
atmark = is_sioctatmark(pfd);
if (len != 1 || buf[0] != '#' || atmark == 1) {
fprintf(stderr, "Test 2.1 Inline failed, len %d data %c atmark %d\n",
len, buf[0], atmark);
die(1);
}
signal_recvd = 0;
signal_producer(pipefd[1]);
/* Test 3 -- Inline:
* verify that 2nd oob over writes
* the first one and read breaks at
* oob boundary returning 127 bytes
* and sigurg is received and siocatmark
* is true after the read.
* subsequent read returns 65 bytes
* because of oob which should be '%'.
*/
len = 0;
wait_for_data(pfd, POLLIN | POLLPRI);
while (len < 126)
len = recv(pfd, buf, 1024, MSG_PEEK);
len = read_data(pfd, buf, 1024);
atmark = is_sioctatmark(pfd);
if (!signal_recvd || len != 127 || !atmark) {
fprintf(stderr,
"Test 3 Inline failed, sigurg %d len %d data %c\n",
signal_recvd, len, buf[0]);
die(1);
}
len = read_data(pfd, buf, 1024);
atmark = is_sioctatmark(pfd);
if (len != 65 || buf[0] != '%' || atmark != 0) {
fprintf(stderr,
"Test 3.1 Inline failed, len %d oob %c atmark %d\n",
len, buf[0], atmark);
die(1);
}
signal_recvd = 0;
signal_producer(pipefd[1]);
/* Test 4 -- Inline:
* verify that a single
* byte oob message is delivered
* and read returns one byte, the oob
* byte and sigurg is received
*/
wait_for_data(pfd, POLLIN | POLLPRI);
len = read_data(pfd, buf, 1024);
if (!signal_recvd || len != 1 || buf[0] != '@') {
fprintf(stderr,
"Test 4 Inline failed, signal %d len %d data %c\n",
signal_recvd, len, buf[0]);
die(1);
}
die(0);
}