net: annotate data-races around sock->ops

The IPV6_ADDRFORM socket option is evil, because it can change sock->ops
while other threads might be reading it. The same issue applies to
sk->sk_family, which is set to AF_INET.

Annotating reads of sock->ops with READ_ONCE() is needed for sockets
that might be affected by IPV6_ADDRFORM.

Note that mptcp_is_tcpsk() can also overwrite sock->ops.

Adding annotations for all sk->sk_family reads will require
more patches :/

BUG: KCSAN: data-race in ____sys_sendmsg / do_ipv6_setsockopt

write to 0xffff888109f24ca0 of 8 bytes by task 4470 on cpu 0:
do_ipv6_setsockopt+0x2c5e/0x2ce0 net/ipv6/ipv6_sockglue.c:491
ipv6_setsockopt+0x57/0x130 net/ipv6/ipv6_sockglue.c:1012
udpv6_setsockopt+0x95/0xa0 net/ipv6/udp.c:1690
sock_common_setsockopt+0x61/0x70 net/core/sock.c:3663
__sys_setsockopt+0x1c3/0x230 net/socket.c:2273
__do_sys_setsockopt net/socket.c:2284 [inline]
__se_sys_setsockopt net/socket.c:2281 [inline]
__x64_sys_setsockopt+0x66/0x80 net/socket.c:2281
do_syscall_x64 arch/x86/entry/common.c:50 [inline]
do_syscall_64+0x41/0xc0 arch/x86/entry/common.c:80
entry_SYSCALL_64_after_hwframe+0x63/0xcd

read to 0xffff888109f24ca0 of 8 bytes by task 4469 on cpu 1:
sock_sendmsg_nosec net/socket.c:724 [inline]
sock_sendmsg net/socket.c:747 [inline]
____sys_sendmsg+0x349/0x4c0 net/socket.c:2503
___sys_sendmsg net/socket.c:2557 [inline]
__sys_sendmmsg+0x263/0x500 net/socket.c:2643
__do_sys_sendmmsg net/socket.c:2672 [inline]
__se_sys_sendmmsg net/socket.c:2669 [inline]
__x64_sys_sendmmsg+0x57/0x60 net/socket.c:2669
do_syscall_x64 arch/x86/entry/common.c:50 [inline]
do_syscall_64+0x41/0xc0 arch/x86/entry/common.c:80
entry_SYSCALL_64_after_hwframe+0x63/0xcd

value changed: 0xffffffff850e32b8 -> 0xffffffff850da890

Reported by Kernel Concurrency Sanitizer on:
CPU: 1 PID: 4469 Comm: syz-executor.1 Not tainted 6.4.0-rc5-syzkaller-00313-g4c605260bc60 #0
Hardware name: Google Google Compute Engine/Google Compute Engine, BIOS Google 05/25/2023

Reported-by: syzbot <syzkaller@googlegroups.com>
Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Link: https://lore.kernel.org/r/20230808135809.2300241-1-edumazet@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Eric Dumazet 2023-08-08 13:58:09 +00:00 committed by Jakub Kicinski
parent e05a53ab86
commit 1ded5e5a59
9 changed files with 118 additions and 78 deletions

View file

@ -123,7 +123,7 @@ struct socket {
struct file *file; struct file *file;
struct sock *sk; struct sock *sk;
const struct proto_ops *ops; const struct proto_ops *ops; /* Might change with IPV6_ADDRFORM or MPTCP. */
struct socket_wq wq; struct socket_wq wq;
}; };

View file

@ -1019,7 +1019,7 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
} }
} }
err = csocket->ops->connect(csocket, err = READ_ONCE(csocket->ops)->connect(csocket,
(struct sockaddr *)&sin_server, (struct sockaddr *)&sin_server,
sizeof(struct sockaddr_in), 0); sizeof(struct sockaddr_in), 0);
if (err < 0) { if (err < 0) {
@ -1060,7 +1060,7 @@ p9_fd_create_unix(struct p9_client *client, const char *addr, char *args)
return err; return err;
} }
err = csocket->ops->connect(csocket, (struct sockaddr *)&sun_server, err = READ_ONCE(csocket->ops)->connect(csocket, (struct sockaddr *)&sun_server,
sizeof(struct sockaddr_un) - 1, 0); sizeof(struct sockaddr_un) - 1, 0);
if (err < 0) { if (err < 0) {
pr_err("%s (%d): problem connecting socket: %s: %d\n", pr_err("%s (%d): problem connecting socket: %s: %d\n",

View file

@ -130,6 +130,7 @@ EXPORT_SYMBOL(__scm_destroy);
int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p) int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p)
{ {
const struct proto_ops *ops = READ_ONCE(sock->ops);
struct cmsghdr *cmsg; struct cmsghdr *cmsg;
int err; int err;
@ -153,7 +154,7 @@ int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p)
switch (cmsg->cmsg_type) switch (cmsg->cmsg_type)
{ {
case SCM_RIGHTS: case SCM_RIGHTS:
if (!sock->ops || sock->ops->family != PF_UNIX) if (!ops || ops->family != PF_UNIX)
goto error; goto error;
err=scm_fp_copy(cmsg, &p->fp); err=scm_fp_copy(cmsg, &p->fp);
if (err<0) if (err<0)

View file

@ -1198,13 +1198,17 @@ static int sk_psock_verdict_recv(struct sock *sk, struct sk_buff *skb)
static void sk_psock_verdict_data_ready(struct sock *sk) static void sk_psock_verdict_data_ready(struct sock *sk)
{ {
struct socket *sock = sk->sk_socket; struct socket *sock = sk->sk_socket;
const struct proto_ops *ops;
int copied; int copied;
trace_sk_data_ready(sk); trace_sk_data_ready(sk);
if (unlikely(!sock || !sock->ops || !sock->ops->read_skb)) if (unlikely(!sock))
return; return;
copied = sock->ops->read_skb(sk, sk_psock_verdict_recv); ops = READ_ONCE(sock->ops);
if (!ops || !ops->read_skb)
return;
copied = ops->read_skb(sk, sk_psock_verdict_recv);
if (copied >= 0) { if (copied >= 0) {
struct sk_psock *psock; struct sk_psock *psock;

View file

@ -1277,14 +1277,19 @@ int sk_setsockopt(struct sock *sk, int level, int optname,
break; break;
case SO_RCVLOWAT: case SO_RCVLOWAT:
{
int (*set_rcvlowat)(struct sock *sk, int val) = NULL;
if (val < 0) if (val < 0)
val = INT_MAX; val = INT_MAX;
if (sock && sock->ops->set_rcvlowat) if (sock)
ret = sock->ops->set_rcvlowat(sk, val); set_rcvlowat = READ_ONCE(sock->ops)->set_rcvlowat;
if (set_rcvlowat)
ret = set_rcvlowat(sk, val);
else else
WRITE_ONCE(sk->sk_rcvlowat, val ? : 1); WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
break; break;
}
case SO_RCVTIMEO_OLD: case SO_RCVTIMEO_OLD:
case SO_RCVTIMEO_NEW: case SO_RCVTIMEO_NEW:
ret = sock_set_timeout(&sk->sk_rcvtimeo, optval, ret = sock_set_timeout(&sk->sk_rcvtimeo, optval,
@ -1379,11 +1384,16 @@ int sk_setsockopt(struct sock *sk, int level, int optname,
break; break;
case SO_PEEK_OFF: case SO_PEEK_OFF:
if (sock->ops->set_peek_off) {
ret = sock->ops->set_peek_off(sk, val); int (*set_peek_off)(struct sock *sk, int val);
set_peek_off = READ_ONCE(sock->ops)->set_peek_off;
if (set_peek_off)
ret = set_peek_off(sk, val);
else else
ret = -EOPNOTSUPP; ret = -EOPNOTSUPP;
break; break;
}
case SO_NOFCS: case SO_NOFCS:
sock_valbool_flag(sk, SOCK_NOFCS, valbool); sock_valbool_flag(sk, SOCK_NOFCS, valbool);
@ -1816,7 +1826,7 @@ int sk_getsockopt(struct sock *sk, int level, int optname,
{ {
struct sockaddr_storage address; struct sockaddr_storage address;
lv = sock->ops->getname(sock, (struct sockaddr *)&address, 2); lv = READ_ONCE(sock->ops)->getname(sock, (struct sockaddr *)&address, 2);
if (lv < 0) if (lv < 0)
return -ENOTCONN; return -ENOTCONN;
if (lv < len) if (lv < len)
@ -1858,7 +1868,7 @@ int sk_getsockopt(struct sock *sk, int level, int optname,
break; break;
case SO_PEEK_OFF: case SO_PEEK_OFF:
if (!sock->ops->set_peek_off) if (!READ_ONCE(sock->ops)->set_peek_off)
return -EOPNOTSUPP; return -EOPNOTSUPP;
v.val = READ_ONCE(sk->sk_peek_off); v.val = READ_ONCE(sk->sk_peek_off);

View file

@ -474,8 +474,8 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
WRITE_ONCE(sk->sk_prot, &tcp_prot); WRITE_ONCE(sk->sk_prot, &tcp_prot);
/* Paired with READ_ONCE() in tcp_(get|set)sockopt() */ /* Paired with READ_ONCE() in tcp_(get|set)sockopt() */
WRITE_ONCE(icsk->icsk_af_ops, &ipv4_specific); WRITE_ONCE(icsk->icsk_af_ops, &ipv4_specific);
sk->sk_socket->ops = &inet_stream_ops; WRITE_ONCE(sk->sk_socket->ops, &inet_stream_ops);
sk->sk_family = PF_INET; WRITE_ONCE(sk->sk_family, PF_INET);
tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); tcp_sync_mss(sk, icsk->icsk_pmtu_cookie);
} else { } else {
struct proto *prot = &udp_prot; struct proto *prot = &udp_prot;
@ -488,8 +488,8 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
/* Paired with READ_ONCE(sk->sk_prot) in inet6_dgram_ops */ /* Paired with READ_ONCE(sk->sk_prot) in inet6_dgram_ops */
WRITE_ONCE(sk->sk_prot, prot); WRITE_ONCE(sk->sk_prot, prot);
sk->sk_socket->ops = &inet_dgram_ops; WRITE_ONCE(sk->sk_socket->ops, &inet_dgram_ops);
sk->sk_family = PF_INET; WRITE_ONCE(sk->sk_family, PF_INET);
} }
/* Disable all options not to allocate memory anymore, /* Disable all options not to allocate memory anymore,

View file

@ -67,11 +67,11 @@ static bool mptcp_is_tcpsk(struct sock *sk)
* Hand the socket over to tcp so all further socket ops * Hand the socket over to tcp so all further socket ops
* bypass mptcp. * bypass mptcp.
*/ */
sock->ops = &inet_stream_ops; WRITE_ONCE(sock->ops, &inet_stream_ops);
return true; return true;
#if IS_ENABLED(CONFIG_MPTCP_IPV6) #if IS_ENABLED(CONFIG_MPTCP_IPV6)
} else if (unlikely(sk->sk_prot == &tcpv6_prot)) { } else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
sock->ops = &inet6_stream_ops; WRITE_ONCE(sock->ops, &inet6_stream_ops);
return true; return true;
#endif #endif
} }
@ -3683,7 +3683,7 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
goto unlock; goto unlock;
} }
err = ssock->ops->bind(ssock, uaddr, addr_len); err = READ_ONCE(ssock->ops)->bind(ssock, uaddr, addr_len);
if (!err) if (!err)
mptcp_copy_inaddrs(sock->sk, ssock->sk); mptcp_copy_inaddrs(sock->sk, ssock->sk);
@ -3717,7 +3717,7 @@ static int mptcp_listen(struct socket *sock, int backlog)
inet_sk_state_store(sk, TCP_LISTEN); inet_sk_state_store(sk, TCP_LISTEN);
sock_set_flag(sk, SOCK_RCU_FREE); sock_set_flag(sk, SOCK_RCU_FREE);
err = ssock->ops->listen(ssock, backlog); err = READ_ONCE(ssock->ops)->listen(ssock, backlog);
inet_sk_state_store(sk, inet_sk_state_load(ssock->sk)); inet_sk_state_store(sk, inet_sk_state_load(ssock->sk));
if (!err) { if (!err) {
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);

View file

@ -136,9 +136,10 @@ static void sock_splice_eof(struct file *file);
static void sock_show_fdinfo(struct seq_file *m, struct file *f) static void sock_show_fdinfo(struct seq_file *m, struct file *f)
{ {
struct socket *sock = f->private_data; struct socket *sock = f->private_data;
const struct proto_ops *ops = READ_ONCE(sock->ops);
if (sock->ops->show_fdinfo) if (ops->show_fdinfo)
sock->ops->show_fdinfo(m, sock); ops->show_fdinfo(m, sock);
} }
#else #else
#define sock_show_fdinfo NULL #define sock_show_fdinfo NULL
@ -646,12 +647,14 @@ EXPORT_SYMBOL(sock_alloc);
static void __sock_release(struct socket *sock, struct inode *inode) static void __sock_release(struct socket *sock, struct inode *inode)
{ {
if (sock->ops) { const struct proto_ops *ops = READ_ONCE(sock->ops);
struct module *owner = sock->ops->owner;
if (ops) {
struct module *owner = ops->owner;
if (inode) if (inode)
inode_lock(inode); inode_lock(inode);
sock->ops->release(sock); ops->release(sock);
sock->sk = NULL; sock->sk = NULL;
if (inode) if (inode)
inode_unlock(inode); inode_unlock(inode);
@ -722,7 +725,7 @@ static noinline void call_trace_sock_send_length(struct sock *sk, int ret,
static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg) static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg)
{ {
int ret = INDIRECT_CALL_INET(sock->ops->sendmsg, inet6_sendmsg, int ret = INDIRECT_CALL_INET(READ_ONCE(sock->ops)->sendmsg, inet6_sendmsg,
inet_sendmsg, sock, msg, inet_sendmsg, sock, msg,
msg_data_left(msg)); msg_data_left(msg));
BUG_ON(ret == -EIOCBQUEUED); BUG_ON(ret == -EIOCBQUEUED);
@ -786,13 +789,14 @@ int kernel_sendmsg_locked(struct sock *sk, struct msghdr *msg,
struct kvec *vec, size_t num, size_t size) struct kvec *vec, size_t num, size_t size)
{ {
struct socket *sock = sk->sk_socket; struct socket *sock = sk->sk_socket;
const struct proto_ops *ops = READ_ONCE(sock->ops);
if (!sock->ops->sendmsg_locked) if (!ops->sendmsg_locked)
return sock_no_sendmsg_locked(sk, msg, size); return sock_no_sendmsg_locked(sk, msg, size);
iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size); iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size);
return sock->ops->sendmsg_locked(sk, msg, msg_data_left(msg)); return ops->sendmsg_locked(sk, msg, msg_data_left(msg));
} }
EXPORT_SYMBOL(kernel_sendmsg_locked); EXPORT_SYMBOL(kernel_sendmsg_locked);
@ -1017,7 +1021,8 @@ static noinline void call_trace_sock_recv_length(struct sock *sk, int ret, int f
static inline int sock_recvmsg_nosec(struct socket *sock, struct msghdr *msg, static inline int sock_recvmsg_nosec(struct socket *sock, struct msghdr *msg,
int flags) int flags)
{ {
int ret = INDIRECT_CALL_INET(sock->ops->recvmsg, inet6_recvmsg, int ret = INDIRECT_CALL_INET(READ_ONCE(sock->ops)->recvmsg,
inet6_recvmsg,
inet_recvmsg, sock, msg, inet_recvmsg, sock, msg,
msg_data_left(msg), flags); msg_data_left(msg), flags);
if (trace_sock_recv_length_enabled()) if (trace_sock_recv_length_enabled())
@ -1072,19 +1077,23 @@ static ssize_t sock_splice_read(struct file *file, loff_t *ppos,
unsigned int flags) unsigned int flags)
{ {
struct socket *sock = file->private_data; struct socket *sock = file->private_data;
const struct proto_ops *ops;
if (unlikely(!sock->ops->splice_read)) ops = READ_ONCE(sock->ops);
if (unlikely(!ops->splice_read))
return copy_splice_read(file, ppos, pipe, len, flags); return copy_splice_read(file, ppos, pipe, len, flags);
return sock->ops->splice_read(sock, ppos, pipe, len, flags); return ops->splice_read(sock, ppos, pipe, len, flags);
} }
static void sock_splice_eof(struct file *file) static void sock_splice_eof(struct file *file)
{ {
struct socket *sock = file->private_data; struct socket *sock = file->private_data;
const struct proto_ops *ops;
if (sock->ops->splice_eof) ops = READ_ONCE(sock->ops);
sock->ops->splice_eof(sock); if (ops->splice_eof)
ops->splice_eof(sock);
} }
static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to) static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to)
@ -1181,13 +1190,14 @@ EXPORT_SYMBOL(vlan_ioctl_set);
static long sock_do_ioctl(struct net *net, struct socket *sock, static long sock_do_ioctl(struct net *net, struct socket *sock,
unsigned int cmd, unsigned long arg) unsigned int cmd, unsigned long arg)
{ {
const struct proto_ops *ops = READ_ONCE(sock->ops);
struct ifreq ifr; struct ifreq ifr;
bool need_copyout; bool need_copyout;
int err; int err;
void __user *argp = (void __user *)arg; void __user *argp = (void __user *)arg;
void __user *data; void __user *data;
err = sock->ops->ioctl(sock, cmd, arg); err = ops->ioctl(sock, cmd, arg);
/* /*
* If this ioctl is unknown try to hand it down * If this ioctl is unknown try to hand it down
@ -1216,6 +1226,7 @@ static long sock_do_ioctl(struct net *net, struct socket *sock,
static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg) static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg)
{ {
const struct proto_ops *ops;
struct socket *sock; struct socket *sock;
struct sock *sk; struct sock *sk;
void __user *argp = (void __user *)arg; void __user *argp = (void __user *)arg;
@ -1223,6 +1234,7 @@ static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg)
struct net *net; struct net *net;
sock = file->private_data; sock = file->private_data;
ops = READ_ONCE(sock->ops);
sk = sock->sk; sk = sock->sk;
net = sock_net(sk); net = sock_net(sk);
if (unlikely(cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15))) { if (unlikely(cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15))) {
@ -1280,23 +1292,23 @@ static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg)
break; break;
case SIOCGSTAMP_OLD: case SIOCGSTAMP_OLD:
case SIOCGSTAMPNS_OLD: case SIOCGSTAMPNS_OLD:
if (!sock->ops->gettstamp) { if (!ops->gettstamp) {
err = -ENOIOCTLCMD; err = -ENOIOCTLCMD;
break; break;
} }
err = sock->ops->gettstamp(sock, argp, err = ops->gettstamp(sock, argp,
cmd == SIOCGSTAMP_OLD, cmd == SIOCGSTAMP_OLD,
!IS_ENABLED(CONFIG_64BIT)); !IS_ENABLED(CONFIG_64BIT));
break; break;
case SIOCGSTAMP_NEW: case SIOCGSTAMP_NEW:
case SIOCGSTAMPNS_NEW: case SIOCGSTAMPNS_NEW:
if (!sock->ops->gettstamp) { if (!ops->gettstamp) {
err = -ENOIOCTLCMD; err = -ENOIOCTLCMD;
break; break;
} }
err = sock->ops->gettstamp(sock, argp, err = ops->gettstamp(sock, argp,
cmd == SIOCGSTAMP_NEW, cmd == SIOCGSTAMP_NEW,
false); false);
break; break;
case SIOCGIFCONF: case SIOCGIFCONF:
@ -1357,9 +1369,10 @@ EXPORT_SYMBOL(sock_create_lite);
static __poll_t sock_poll(struct file *file, poll_table *wait) static __poll_t sock_poll(struct file *file, poll_table *wait)
{ {
struct socket *sock = file->private_data; struct socket *sock = file->private_data;
const struct proto_ops *ops = READ_ONCE(sock->ops);
__poll_t events = poll_requested_events(wait), flag = 0; __poll_t events = poll_requested_events(wait), flag = 0;
if (!sock->ops->poll) if (!ops->poll)
return 0; return 0;
if (sk_can_busy_loop(sock->sk)) { if (sk_can_busy_loop(sock->sk)) {
@ -1371,14 +1384,14 @@ static __poll_t sock_poll(struct file *file, poll_table *wait)
flag = POLL_BUSY_LOOP; flag = POLL_BUSY_LOOP;
} }
return sock->ops->poll(file, sock, wait) | flag; return ops->poll(file, sock, wait) | flag;
} }
static int sock_mmap(struct file *file, struct vm_area_struct *vma) static int sock_mmap(struct file *file, struct vm_area_struct *vma)
{ {
struct socket *sock = file->private_data; struct socket *sock = file->private_data;
return sock->ops->mmap(file, sock, vma); return READ_ONCE(sock->ops)->mmap(file, sock, vma);
} }
static int sock_close(struct inode *inode, struct file *filp) static int sock_close(struct inode *inode, struct file *filp)
@ -1728,7 +1741,7 @@ int __sys_socketpair(int family, int type, int protocol, int __user *usockvec)
goto out; goto out;
} }
err = sock1->ops->socketpair(sock1, sock2); err = READ_ONCE(sock1->ops)->socketpair(sock1, sock2);
if (unlikely(err < 0)) { if (unlikely(err < 0)) {
sock_release(sock2); sock_release(sock2);
sock_release(sock1); sock_release(sock1);
@ -1789,7 +1802,7 @@ int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen)
(struct sockaddr *)&address, (struct sockaddr *)&address,
addrlen); addrlen);
if (!err) if (!err)
err = sock->ops->bind(sock, err = READ_ONCE(sock->ops)->bind(sock,
(struct sockaddr *) (struct sockaddr *)
&address, addrlen); &address, addrlen);
} }
@ -1823,7 +1836,7 @@ int __sys_listen(int fd, int backlog)
err = security_socket_listen(sock, backlog); err = security_socket_listen(sock, backlog);
if (!err) if (!err)
err = sock->ops->listen(sock, backlog); err = READ_ONCE(sock->ops)->listen(sock, backlog);
fput_light(sock->file, fput_needed); fput_light(sock->file, fput_needed);
} }
@ -1843,6 +1856,7 @@ struct file *do_accept(struct file *file, unsigned file_flags,
struct file *newfile; struct file *newfile;
int err, len; int err, len;
struct sockaddr_storage address; struct sockaddr_storage address;
const struct proto_ops *ops;
sock = sock_from_file(file); sock = sock_from_file(file);
if (!sock) if (!sock)
@ -1851,15 +1865,16 @@ struct file *do_accept(struct file *file, unsigned file_flags,
newsock = sock_alloc(); newsock = sock_alloc();
if (!newsock) if (!newsock)
return ERR_PTR(-ENFILE); return ERR_PTR(-ENFILE);
ops = READ_ONCE(sock->ops);
newsock->type = sock->type; newsock->type = sock->type;
newsock->ops = sock->ops; newsock->ops = ops;
/* /*
* We don't need try_module_get here, as the listening socket (sock) * We don't need try_module_get here, as the listening socket (sock)
* has the protocol module (sock->ops->owner) held. * has the protocol module (sock->ops->owner) held.
*/ */
__module_get(newsock->ops->owner); __module_get(ops->owner);
newfile = sock_alloc_file(newsock, flags, sock->sk->sk_prot_creator->name); newfile = sock_alloc_file(newsock, flags, sock->sk->sk_prot_creator->name);
if (IS_ERR(newfile)) if (IS_ERR(newfile))
@ -1869,14 +1884,13 @@ struct file *do_accept(struct file *file, unsigned file_flags,
if (err) if (err)
goto out_fd; goto out_fd;
err = sock->ops->accept(sock, newsock, sock->file->f_flags | file_flags, err = ops->accept(sock, newsock, sock->file->f_flags | file_flags,
false); false);
if (err < 0) if (err < 0)
goto out_fd; goto out_fd;
if (upeer_sockaddr) { if (upeer_sockaddr) {
len = newsock->ops->getname(newsock, len = ops->getname(newsock, (struct sockaddr *)&address, 2);
(struct sockaddr *)&address, 2);
if (len < 0) { if (len < 0) {
err = -ECONNABORTED; err = -ECONNABORTED;
goto out_fd; goto out_fd;
@ -1989,8 +2003,8 @@ int __sys_connect_file(struct file *file, struct sockaddr_storage *address,
if (err) if (err)
goto out; goto out;
err = sock->ops->connect(sock, (struct sockaddr *)address, addrlen, err = READ_ONCE(sock->ops)->connect(sock, (struct sockaddr *)address,
sock->file->f_flags | file_flags); addrlen, sock->file->f_flags | file_flags);
out: out:
return err; return err;
} }
@ -2039,7 +2053,7 @@ int __sys_getsockname(int fd, struct sockaddr __user *usockaddr,
if (err) if (err)
goto out_put; goto out_put;
err = sock->ops->getname(sock, (struct sockaddr *)&address, 0); err = READ_ONCE(sock->ops)->getname(sock, (struct sockaddr *)&address, 0);
if (err < 0) if (err < 0)
goto out_put; goto out_put;
/* "err" is actually length in this case */ /* "err" is actually length in this case */
@ -2071,13 +2085,15 @@ int __sys_getpeername(int fd, struct sockaddr __user *usockaddr,
sock = sockfd_lookup_light(fd, &err, &fput_needed); sock = sockfd_lookup_light(fd, &err, &fput_needed);
if (sock != NULL) { if (sock != NULL) {
const struct proto_ops *ops = READ_ONCE(sock->ops);
err = security_socket_getpeername(sock); err = security_socket_getpeername(sock);
if (err) { if (err) {
fput_light(sock->file, fput_needed); fput_light(sock->file, fput_needed);
return err; return err;
} }
err = sock->ops->getname(sock, (struct sockaddr *)&address, 1); err = ops->getname(sock, (struct sockaddr *)&address, 1);
if (err >= 0) if (err >= 0)
/* "err" is actually length in this case */ /* "err" is actually length in this case */
err = move_addr_to_user(&address, err, usockaddr, err = move_addr_to_user(&address, err, usockaddr,
@ -2227,6 +2243,7 @@ int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval,
int optlen) int optlen)
{ {
sockptr_t optval = USER_SOCKPTR(user_optval); sockptr_t optval = USER_SOCKPTR(user_optval);
const struct proto_ops *ops;
char *kernel_optval = NULL; char *kernel_optval = NULL;
int err, fput_needed; int err, fput_needed;
struct socket *sock; struct socket *sock;
@ -2255,12 +2272,13 @@ int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval,
if (kernel_optval) if (kernel_optval)
optval = KERNEL_SOCKPTR(kernel_optval); optval = KERNEL_SOCKPTR(kernel_optval);
ops = READ_ONCE(sock->ops);
if (level == SOL_SOCKET && !sock_use_custom_sol_socket(sock)) if (level == SOL_SOCKET && !sock_use_custom_sol_socket(sock))
err = sock_setsockopt(sock, level, optname, optval, optlen); err = sock_setsockopt(sock, level, optname, optval, optlen);
else if (unlikely(!sock->ops->setsockopt)) else if (unlikely(!ops->setsockopt))
err = -EOPNOTSUPP; err = -EOPNOTSUPP;
else else
err = sock->ops->setsockopt(sock, level, optname, optval, err = ops->setsockopt(sock, level, optname, optval,
optlen); optlen);
kfree(kernel_optval); kfree(kernel_optval);
out_put: out_put:
@ -2285,6 +2303,7 @@ int __sys_getsockopt(int fd, int level, int optname, char __user *optval,
int __user *optlen) int __user *optlen)
{ {
int max_optlen __maybe_unused; int max_optlen __maybe_unused;
const struct proto_ops *ops;
int err, fput_needed; int err, fput_needed;
struct socket *sock; struct socket *sock;
@ -2299,12 +2318,13 @@ int __sys_getsockopt(int fd, int level, int optname, char __user *optval,
if (!in_compat_syscall()) if (!in_compat_syscall())
max_optlen = BPF_CGROUP_GETSOCKOPT_MAX_OPTLEN(optlen); max_optlen = BPF_CGROUP_GETSOCKOPT_MAX_OPTLEN(optlen);
ops = READ_ONCE(sock->ops);
if (level == SOL_SOCKET) if (level == SOL_SOCKET)
err = sock_getsockopt(sock, level, optname, optval, optlen); err = sock_getsockopt(sock, level, optname, optval, optlen);
else if (unlikely(!sock->ops->getsockopt)) else if (unlikely(!ops->getsockopt))
err = -EOPNOTSUPP; err = -EOPNOTSUPP;
else else
err = sock->ops->getsockopt(sock, level, optname, optval, err = ops->getsockopt(sock, level, optname, optval,
optlen); optlen);
if (!in_compat_syscall()) if (!in_compat_syscall())
@ -2332,7 +2352,7 @@ int __sys_shutdown_sock(struct socket *sock, int how)
err = security_socket_shutdown(sock, how); err = security_socket_shutdown(sock, how);
if (!err) if (!err)
err = sock->ops->shutdown(sock, how); err = READ_ONCE(sock->ops)->shutdown(sock, how);
return err; return err;
} }
@ -3324,6 +3344,7 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock,
void __user *argp = compat_ptr(arg); void __user *argp = compat_ptr(arg);
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
const struct proto_ops *ops;
if (cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15)) if (cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15))
return sock_ioctl(file, cmd, (unsigned long)argp); return sock_ioctl(file, cmd, (unsigned long)argp);
@ -3333,10 +3354,11 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock,
return compat_siocwandev(net, argp); return compat_siocwandev(net, argp);
case SIOCGSTAMP_OLD: case SIOCGSTAMP_OLD:
case SIOCGSTAMPNS_OLD: case SIOCGSTAMPNS_OLD:
if (!sock->ops->gettstamp) ops = READ_ONCE(sock->ops);
if (!ops->gettstamp)
return -ENOIOCTLCMD; return -ENOIOCTLCMD;
return sock->ops->gettstamp(sock, argp, cmd == SIOCGSTAMP_OLD, return ops->gettstamp(sock, argp, cmd == SIOCGSTAMP_OLD,
!COMPAT_USE_64BIT_TIME); !COMPAT_USE_64BIT_TIME);
case SIOCETHTOOL: case SIOCETHTOOL:
case SIOCBONDSLAVEINFOQUERY: case SIOCBONDSLAVEINFOQUERY:
@ -3417,6 +3439,7 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd,
unsigned long arg) unsigned long arg)
{ {
struct socket *sock = file->private_data; struct socket *sock = file->private_data;
const struct proto_ops *ops = READ_ONCE(sock->ops);
int ret = -ENOIOCTLCMD; int ret = -ENOIOCTLCMD;
struct sock *sk; struct sock *sk;
struct net *net; struct net *net;
@ -3424,8 +3447,8 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd,
sk = sock->sk; sk = sock->sk;
net = sock_net(sk); net = sock_net(sk);
if (sock->ops->compat_ioctl) if (ops->compat_ioctl)
ret = sock->ops->compat_ioctl(sock, cmd, arg); ret = ops->compat_ioctl(sock, cmd, arg);
if (ret == -ENOIOCTLCMD && if (ret == -ENOIOCTLCMD &&
(cmd >= SIOCIWFIRST && cmd <= SIOCIWLAST)) (cmd >= SIOCIWFIRST && cmd <= SIOCIWLAST))
@ -3449,7 +3472,7 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd,
int kernel_bind(struct socket *sock, struct sockaddr *addr, int addrlen) int kernel_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
{ {
return sock->ops->bind(sock, addr, addrlen); return READ_ONCE(sock->ops)->bind(sock, addr, addrlen);
} }
EXPORT_SYMBOL(kernel_bind); EXPORT_SYMBOL(kernel_bind);
@ -3463,7 +3486,7 @@ EXPORT_SYMBOL(kernel_bind);
int kernel_listen(struct socket *sock, int backlog) int kernel_listen(struct socket *sock, int backlog)
{ {
return sock->ops->listen(sock, backlog); return READ_ONCE(sock->ops)->listen(sock, backlog);
} }
EXPORT_SYMBOL(kernel_listen); EXPORT_SYMBOL(kernel_listen);
@ -3481,6 +3504,7 @@ EXPORT_SYMBOL(kernel_listen);
int kernel_accept(struct socket *sock, struct socket **newsock, int flags) int kernel_accept(struct socket *sock, struct socket **newsock, int flags)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
const struct proto_ops *ops = READ_ONCE(sock->ops);
int err; int err;
err = sock_create_lite(sk->sk_family, sk->sk_type, sk->sk_protocol, err = sock_create_lite(sk->sk_family, sk->sk_type, sk->sk_protocol,
@ -3488,15 +3512,15 @@ int kernel_accept(struct socket *sock, struct socket **newsock, int flags)
if (err < 0) if (err < 0)
goto done; goto done;
err = sock->ops->accept(sock, *newsock, flags, true); err = ops->accept(sock, *newsock, flags, true);
if (err < 0) { if (err < 0) {
sock_release(*newsock); sock_release(*newsock);
*newsock = NULL; *newsock = NULL;
goto done; goto done;
} }
(*newsock)->ops = sock->ops; (*newsock)->ops = ops;
__module_get((*newsock)->ops->owner); __module_get(ops->owner);
done: done:
return err; return err;
@ -3519,7 +3543,7 @@ EXPORT_SYMBOL(kernel_accept);
int kernel_connect(struct socket *sock, struct sockaddr *addr, int addrlen, int kernel_connect(struct socket *sock, struct sockaddr *addr, int addrlen,
int flags) int flags)
{ {
return sock->ops->connect(sock, addr, addrlen, flags); return READ_ONCE(sock->ops)->connect(sock, addr, addrlen, flags);
} }
EXPORT_SYMBOL(kernel_connect); EXPORT_SYMBOL(kernel_connect);
@ -3534,7 +3558,7 @@ EXPORT_SYMBOL(kernel_connect);
int kernel_getsockname(struct socket *sock, struct sockaddr *addr) int kernel_getsockname(struct socket *sock, struct sockaddr *addr)
{ {
return sock->ops->getname(sock, addr, 0); return READ_ONCE(sock->ops)->getname(sock, addr, 0);
} }
EXPORT_SYMBOL(kernel_getsockname); EXPORT_SYMBOL(kernel_getsockname);
@ -3549,7 +3573,7 @@ EXPORT_SYMBOL(kernel_getsockname);
int kernel_getpeername(struct socket *sock, struct sockaddr *addr) int kernel_getpeername(struct socket *sock, struct sockaddr *addr)
{ {
return sock->ops->getname(sock, addr, 1); return READ_ONCE(sock->ops)->getname(sock, addr, 1);
} }
EXPORT_SYMBOL(kernel_getpeername); EXPORT_SYMBOL(kernel_getpeername);
@ -3563,7 +3587,7 @@ EXPORT_SYMBOL(kernel_getpeername);
int kernel_sock_shutdown(struct socket *sock, enum sock_shutdown_cmd how) int kernel_sock_shutdown(struct socket *sock, enum sock_shutdown_cmd how)
{ {
return sock->ops->shutdown(sock, how); return READ_ONCE(sock->ops)->shutdown(sock, how);
} }
EXPORT_SYMBOL(kernel_sock_shutdown); EXPORT_SYMBOL(kernel_sock_shutdown);

View file

@ -29,10 +29,11 @@ struct sock *unix_get_socket(struct file *filp)
/* Socket ? */ /* Socket ? */
if (S_ISSOCK(inode->i_mode) && !(filp->f_mode & FMODE_PATH)) { if (S_ISSOCK(inode->i_mode) && !(filp->f_mode & FMODE_PATH)) {
struct socket *sock = SOCKET_I(inode); struct socket *sock = SOCKET_I(inode);
const struct proto_ops *ops = READ_ONCE(sock->ops);
struct sock *s = sock->sk; struct sock *s = sock->sk;
/* PF_UNIX ? */ /* PF_UNIX ? */
if (s && sock->ops && sock->ops->family == PF_UNIX) if (s && ops && ops->family == PF_UNIX)
u_sock = s; u_sock = s;
} else { } else {
/* Could be an io_uring instance */ /* Could be an io_uring instance */