From 3f11a2f3747a7e2a92e1e65558631095508a5599 Mon Sep 17 00:00:00 2001 From: Robert Watson Date: Thu, 24 Jun 2004 01:37:04 +0000 Subject: [PATCH] Introduce sbreserve_locked(), which asserts the socket buffer lock on the socket buffer having its limits adjusted. sbreserve() now acquires the lock before calling sbreserve_locked(). In soreserve(), acquire socket buffer locks across read-modify-writes of socket buffer fields, and calls into sbreserve/sbrelease; make sure to acquire in keeping with the socket buffer lock order. In tcp_mss(), acquire the socket buffer lock in the calling context so that we have atomic read-modify-write on buffer sizes. --- sys/kern/uipc_sockbuf.c | 35 +++++++++++++++++++++++++++-------- sys/kern/uipc_socket2.c | 35 +++++++++++++++++++++++++++-------- sys/netinet/tcp_input.c | 8 ++++++-- sys/netinet/tcp_reass.c | 8 ++++++-- sys/sys/socketvar.h | 2 ++ 5 files changed, 68 insertions(+), 20 deletions(-) diff --git a/sys/kern/uipc_sockbuf.c b/sys/kern/uipc_sockbuf.c index 7f18ba87fe2f..e366c3c89769 100644 --- a/sys/kern/uipc_sockbuf.c +++ b/sys/kern/uipc_sockbuf.c @@ -456,24 +456,26 @@ soreserve(so, sndcc, rcvcc) { struct thread *td = curthread; - if (sbreserve(&so->so_snd, sndcc, so, td) == 0) - goto bad; - if (sbreserve(&so->so_rcv, rcvcc, so, td) == 0) - goto bad2; + SOCKBUF_LOCK(&so->so_snd); SOCKBUF_LOCK(&so->so_rcv); + if (sbreserve_locked(&so->so_snd, sndcc, so, td) == 0) + goto bad; + if (sbreserve_locked(&so->so_rcv, rcvcc, so, td) == 0) + goto bad2; if (so->so_rcv.sb_lowat == 0) so->so_rcv.sb_lowat = 1; - SOCKBUF_UNLOCK(&so->so_rcv); - SOCKBUF_LOCK(&so->so_snd); if (so->so_snd.sb_lowat == 0) so->so_snd.sb_lowat = MCLBYTES; if (so->so_snd.sb_lowat > so->so_snd.sb_hiwat) so->so_snd.sb_lowat = so->so_snd.sb_hiwat; + SOCKBUF_UNLOCK(&so->so_rcv); SOCKBUF_UNLOCK(&so->so_snd); return (0); bad2: - sbrelease(&so->so_snd, so); + sbrelease_locked(&so->so_snd, so); bad: + SOCKBUF_UNLOCK(&so->so_rcv); + SOCKBUF_UNLOCK(&so->so_snd); return 
(ENOBUFS); } @@ -503,7 +505,7 @@ sysctl_handle_sb_max(SYSCTL_HANDLER_ARGS) * if buffering efficiency is near the normal case. */ int -sbreserve(sb, cc, so, td) +sbreserve_locked(sb, cc, so, td) struct sockbuf *sb; u_long cc; struct socket *so; @@ -511,6 +513,8 @@ sbreserve(sb, cc, so, td) { rlim_t sbsize_limit; + SOCKBUF_LOCK_ASSERT(sb); + /* * td will only be NULL when we're in an interrupt * (e.g. in tcp_input()) @@ -532,6 +536,21 @@ sbreserve(sb, cc, so, td) return (1); } +int +sbreserve(sb, cc, so, td) + struct sockbuf *sb; + u_long cc; + struct socket *so; + struct thread *td; +{ + int error; + + SOCKBUF_LOCK(sb); + error = sbreserve_locked(sb, cc, so, td); + SOCKBUF_UNLOCK(sb); + return (error); +} + /* * Free mbufs held by a socket, and reserved mbuf space. */ diff --git a/sys/kern/uipc_socket2.c b/sys/kern/uipc_socket2.c index 7f18ba87fe2f..e366c3c89769 100644 --- a/sys/kern/uipc_socket2.c +++ b/sys/kern/uipc_socket2.c @@ -456,24 +456,26 @@ soreserve(so, sndcc, rcvcc) { struct thread *td = curthread; - if (sbreserve(&so->so_snd, sndcc, so, td) == 0) - goto bad; - if (sbreserve(&so->so_rcv, rcvcc, so, td) == 0) - goto bad2; + SOCKBUF_LOCK(&so->so_snd); SOCKBUF_LOCK(&so->so_rcv); + if (sbreserve_locked(&so->so_snd, sndcc, so, td) == 0) + goto bad; + if (sbreserve_locked(&so->so_rcv, rcvcc, so, td) == 0) + goto bad2; if (so->so_rcv.sb_lowat == 0) so->so_rcv.sb_lowat = 1; - SOCKBUF_UNLOCK(&so->so_rcv); - SOCKBUF_LOCK(&so->so_snd); if (so->so_snd.sb_lowat == 0) so->so_snd.sb_lowat = MCLBYTES; if (so->so_snd.sb_lowat > so->so_snd.sb_hiwat) so->so_snd.sb_lowat = so->so_snd.sb_hiwat; + SOCKBUF_UNLOCK(&so->so_rcv); SOCKBUF_UNLOCK(&so->so_snd); return (0); bad2: - sbrelease(&so->so_snd, so); + sbrelease_locked(&so->so_snd, so); bad: + SOCKBUF_UNLOCK(&so->so_rcv); + SOCKBUF_UNLOCK(&so->so_snd); return (ENOBUFS); } @@ -503,7 +505,7 @@ sysctl_handle_sb_max(SYSCTL_HANDLER_ARGS) * if buffering efficiency is near the normal case. 
*/ int -sbreserve(sb, cc, so, td) +sbreserve_locked(sb, cc, so, td) struct sockbuf *sb; u_long cc; struct socket *so; @@ -511,6 +513,8 @@ sbreserve(sb, cc, so, td) { rlim_t sbsize_limit; + SOCKBUF_LOCK_ASSERT(sb); + /* * td will only be NULL when we're in an interrupt * (e.g. in tcp_input()) @@ -532,6 +536,21 @@ sbreserve(sb, cc, so, td) return (1); } +int +sbreserve(sb, cc, so, td) + struct sockbuf *sb; + u_long cc; + struct socket *so; + struct thread *td; +{ + int error; + + SOCKBUF_LOCK(sb); + error = sbreserve_locked(sb, cc, so, td); + SOCKBUF_UNLOCK(sb); + return (error); +} + /* * Free mbufs held by a socket, and reserved mbuf space. */ diff --git a/sys/netinet/tcp_input.c b/sys/netinet/tcp_input.c index 581fe9a2ee1e..cab335470748 100644 --- a/sys/netinet/tcp_input.c +++ b/sys/netinet/tcp_input.c @@ -2990,6 +2990,7 @@ tcp_mss(tp, offer) * Make the socket buffers an integral number of mss units; * if the mss is larger than the socket buffer, decrease the mss. */ + SOCKBUF_LOCK(&so->so_snd); if ((so->so_snd.sb_hiwat == tcp_sendspace) && metrics.rmx_sendpipe) bufsize = metrics.rmx_sendpipe; else @@ -3001,10 +3002,12 @@ tcp_mss(tp, offer) if (bufsize > sb_max) bufsize = sb_max; if (bufsize > so->so_snd.sb_hiwat) - (void)sbreserve(&so->so_snd, bufsize, so, NULL); + (void)sbreserve_locked(&so->so_snd, bufsize, so, NULL); } + SOCKBUF_UNLOCK(&so->so_snd); tp->t_maxseg = mss; + SOCKBUF_LOCK(&so->so_rcv); if ((so->so_rcv.sb_hiwat == tcp_recvspace) && metrics.rmx_recvpipe) bufsize = metrics.rmx_recvpipe; else @@ -3014,8 +3017,9 @@ tcp_mss(tp, offer) if (bufsize > sb_max) bufsize = sb_max; if (bufsize > so->so_rcv.sb_hiwat) - (void)sbreserve(&so->so_rcv, bufsize, so, NULL); + (void)sbreserve_locked(&so->so_rcv, bufsize, so, NULL); } + SOCKBUF_UNLOCK(&so->so_rcv); /* * While we're here, check the others too */ diff --git a/sys/netinet/tcp_reass.c b/sys/netinet/tcp_reass.c index 581fe9a2ee1e..cab335470748 100644 --- a/sys/netinet/tcp_reass.c +++ b/sys/netinet/tcp_reass.c 
@@ -2990,6 +2990,7 @@ tcp_mss(tp, offer) * Make the socket buffers an integral number of mss units; * if the mss is larger than the socket buffer, decrease the mss. */ + SOCKBUF_LOCK(&so->so_snd); if ((so->so_snd.sb_hiwat == tcp_sendspace) && metrics.rmx_sendpipe) bufsize = metrics.rmx_sendpipe; else @@ -3001,10 +3002,12 @@ tcp_mss(tp, offer) if (bufsize > sb_max) bufsize = sb_max; if (bufsize > so->so_snd.sb_hiwat) - (void)sbreserve(&so->so_snd, bufsize, so, NULL); + (void)sbreserve_locked(&so->so_snd, bufsize, so, NULL); } + SOCKBUF_UNLOCK(&so->so_snd); tp->t_maxseg = mss; + SOCKBUF_LOCK(&so->so_rcv); if ((so->so_rcv.sb_hiwat == tcp_recvspace) && metrics.rmx_recvpipe) bufsize = metrics.rmx_recvpipe; else @@ -3014,8 +3017,9 @@ tcp_mss(tp, offer) if (bufsize > sb_max) bufsize = sb_max; if (bufsize > so->so_rcv.sb_hiwat) - (void)sbreserve(&so->so_rcv, bufsize, so, NULL); + (void)sbreserve_locked(&so->so_rcv, bufsize, so, NULL); } + SOCKBUF_UNLOCK(&so->so_rcv); /* * While we're here, check the others too */ diff --git a/sys/sys/socketvar.h b/sys/sys/socketvar.h index 6ea4bb808897..68f6a2fc4491 100644 --- a/sys/sys/socketvar.h +++ b/sys/sys/socketvar.h @@ -465,6 +465,8 @@ void sbrelease(struct sockbuf *sb, struct socket *so); void sbrelease_locked(struct sockbuf *sb, struct socket *so); int sbreserve(struct sockbuf *sb, u_long cc, struct socket *so, struct thread *td); +int sbreserve_locked(struct sockbuf *sb, u_long cc, struct socket *so, + struct thread *td); void sbtoxsockbuf(struct sockbuf *sb, struct xsockbuf *xsb); int sbwait(struct sockbuf *sb); int sb_lock(struct sockbuf *sb);