diff --git a/src/internal/poll/fd_unix.go b/src/internal/poll/fd_unix.go index fe8a5c8ec0..3b17cd22b0 100644 --- a/src/internal/poll/fd_unix.go +++ b/src/internal/poll/fd_unix.go @@ -231,7 +231,7 @@ func (fd *FD) ReadFrom(p []byte) (int, syscall.Sockaddr, error) { } // ReadMsg wraps the recvmsg network call. -func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) { +func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) { if err := fd.readLock(); err != nil { return 0, 0, 0, nil, err } @@ -240,7 +240,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er return 0, 0, 0, nil, err } for { - n, oobn, flags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, 0) + n, oobn, sysflags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, flags) if err != nil { if err == syscall.EINTR { continue @@ -253,7 +253,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er } } err = fd.eofError(n, err) - return n, oobn, flags, sa, err + return n, oobn, sysflags, sa, err } } diff --git a/src/internal/poll/fd_windows.go b/src/internal/poll/fd_windows.go index d8c834f929..4a5169527c 100644 --- a/src/internal/poll/fd_windows.go +++ b/src/internal/poll/fd_windows.go @@ -1013,7 +1013,7 @@ func sockaddrToRaw(sa syscall.Sockaddr) (unsafe.Pointer, int32, error) { } // ReadMsg wraps the WSARecvMsg network call. -func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) { +func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) { if err := fd.readLock(); err != nil { return 0, 0, 0, nil, err } @@ -1028,6 +1028,7 @@ func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, er o.rsa = new(syscall.RawSockaddrAny) o.msg.Name = (syscall.Pointer)(unsafe.Pointer(o.rsa)) o.msg.Namelen = int32(unsafe.Sizeof(*o.rsa)) + o.msg.Flags = uint32(flags) n, err := execIO(o, func(o *operation) error { return windows.WSARecvMsg(o.fd.Sysfd, &o.msg, &o.qty, &o.o, nil) }) diff --git a/src/net/fd_posix.go b/src/net/fd_posix.go index 2945e46a48..4703ff33a1 100644 --- a/src/net/fd_posix.go +++ b/src/net/fd_posix.go @@ -64,10 +64,10 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { return n, sa, wrapSyscallError(readFromSyscallName, err) } -func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { - n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob) +func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { + n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags) runtime.KeepAlive(fd) - return n, oobn, flags, sa, wrapSyscallError(readMsgSyscallName, err) + return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err) } func (fd *netFD) Write(p []byte) (nn int, err error) { diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go index c1514f1698..b94eec0e18 100644 --- a/src/net/iprawsock_posix.go +++ b/src/net/iprawsock_posix.go @@ -75,7 +75,7 @@ func stripIPv4Header(n int, b []byte) int { func (c *IPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) { var sa syscall.Sockaddr - n, oobn, flags, sa, err = c.fd.readMsg(b, oob) + n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0) switch sa := sa.(type) { case *syscall.SockaddrInet4: addr = &IPAddr{IP: sa.Addr[0:]} diff --git a/src/net/net_fake.go b/src/net/net_fake.go index 49dc57c6ff..74fc1da6fd 100644 --- a/src/net/net_fake.go +++ b/src/net/net_fake.go @@ -268,7 +268,7 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { return 0, nil, syscall.ENOSYS } -func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { +func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { return 0, 0, 0, nil, syscall.ENOSYS } diff --git a/src/net/udpsock_posix.go b/src/net/udpsock_posix.go index 3b5346e573..fcfb9c004c 100644 --- a/src/net/udpsock_posix.go +++ b/src/net/udpsock_posix.go @@ -56,7 +56,7 @@ func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) { func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) { var sa syscall.Sockaddr - n, oobn, flags, sa, err = c.fd.readMsg(b, oob) + n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0) switch sa := sa.(type) { case *syscall.SockaddrInet4: addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} diff --git a/src/net/unixsock_posix.go b/src/net/unixsock_posix.go index 1d1f27449f..0306b5989b 100644 --- a/src/net/unixsock_posix.go +++ b/src/net/unixsock_posix.go @@ -113,7 +113,11 @@ func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) { func (c *UnixConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) { var sa syscall.Sockaddr - n, oobn, flags, sa, err = c.fd.readMsg(b, oob) + n, oobn, flags, sa, err = c.fd.readMsg(b, oob, readMsgFlags) + if oobn > 0 { + setReadMsgCloseOnExec(oob[:oobn]) + } + switch sa := sa.(type) { case *syscall.SockaddrUnix: if sa.Name != "" { diff --git a/src/net/unixsock_readmsg_linux.go b/src/net/unixsock_readmsg_linux.go new file mode 100644 index 0000000000..3296681017 --- /dev/null +++ b/src/net/unixsock_readmsg_linux.go @@ -0,0 +1,17 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux +// +build linux + +package net + +import ( + "syscall" +) + +const readMsgFlags = syscall.MSG_CMSG_CLOEXEC + +func setReadMsgCloseOnExec(oob []byte) { +} diff --git a/src/net/unixsock_readmsg_other.go b/src/net/unixsock_readmsg_other.go new file mode 100644 index 0000000000..c8db657cd6 --- /dev/null +++ b/src/net/unixsock_readmsg_other.go @@ -0,0 +1,13 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build (js && wasm) || windows +// +build js,wasm windows + +package net + +const readMsgFlags = 0 + +func setReadMsgCloseOnExec(oob []byte) { +} diff --git a/src/net/unixsock_readmsg_posix.go b/src/net/unixsock_readmsg_posix.go new file mode 100644 index 0000000000..07d7df5e66 --- /dev/null +++ b/src/net/unixsock_readmsg_posix.go @@ -0,0 +1,33 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris +// +build aix darwin dragonfly freebsd netbsd openbsd solaris + +package net + +import ( + "syscall" +) + +const readMsgFlags = 0 + +func setReadMsgCloseOnExec(oob []byte) { + scms, err := syscall.ParseSocketControlMessage(oob) + if err != nil { + return + } + + for _, scm := range scms { + if scm.Header.Level == syscall.SOL_SOCKET && scm.Header.Type == syscall.SCM_RIGHTS { + fds, err := syscall.ParseUnixRights(&scm) + if err != nil { + continue + } + for _, fd := range fds { + syscall.CloseOnExec(fd) + } + } + } +} diff --git a/src/net/unixsock_readmsg_test.go b/src/net/unixsock_readmsg_test.go new file mode 100644 index 0000000000..4961ecbe10 --- /dev/null +++ b/src/net/unixsock_readmsg_test.go @@ -0,0 +1,105 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +package net + +import ( + "os" + "syscall" + "testing" + "time" +) + +func TestUnixConnReadMsgUnixSCMRightsCloseOnExec(t *testing.T) { + if !testableNetwork("unix") { + t.Skip("not unix system") + } + + scmFile, err := os.Open(os.DevNull) + if err != nil { + t.Fatalf("file open: %v", err) + } + defer scmFile.Close() + + rights := syscall.UnixRights(int(scmFile.Fd())) + fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0) + if err != nil { + t.Fatalf("Socketpair: %v", err) + } + + writeFile := os.NewFile(uintptr(fds[0]), "write-socket") + defer writeFile.Close() + readFile := os.NewFile(uintptr(fds[1]), "read-socket") + defer readFile.Close() + + cw, err := FileConn(writeFile) + if err != nil { + t.Fatalf("FileConn: %v", err) + } + defer cw.Close() + cr, err := FileConn(readFile) + if err != nil { + t.Fatalf("FileConn: %v", err) + } + defer cr.Close() + + ucw, ok := cw.(*UnixConn) + if !ok { + t.Fatalf("got %T; want UnixConn", cw) + } + ucr, ok := cr.(*UnixConn) + if !ok { + t.Fatalf("got %T; want UnixConn", cr) + } + + oob := make([]byte, syscall.CmsgSpace(4)) + err = ucw.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + t.Fatalf("Can't set unix connection timeout: %v", err) + } + _, _, err = ucw.WriteMsgUnix(nil, rights, nil) + if err != nil { + t.Fatalf("UnixConn readMsg: %v", err) + } + err = ucr.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + t.Fatalf("Can't set unix connection timeout: %v", err) + } + _, oobn, _, _, err := ucr.ReadMsgUnix(nil, oob) + if err != nil { + t.Fatalf("UnixConn readMsg: %v", err) + } + + scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + t.Fatalf("ParseSocketControlMessage: %v", err) + } + if len(scms) != 1 { + t.Fatalf("got scms = %#v; expected 1 SocketControlMessage", scms) + } + scm := scms[0] + gotFds, err := syscall.ParseUnixRights(&scm) + if err != nil { + t.Fatalf("syscall.ParseUnixRights: %v", err) + } + if len(gotFds) != 1 { + t.Fatalf("got FDs %#v: wanted only 1 fd", gotFds) + } + defer func() { + if err := syscall.Close(int(gotFds[0])); err != nil { + t.Fatalf("fail to close gotFds: %v", err) + } + }() + + flags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(gotFds[0]), uintptr(syscall.F_GETFD), 0) + if errno != 0 { + t.Fatalf("Can't get flags of fd:%#v, with err:%v", gotFds[0], errno) + } + if flags&syscall.FD_CLOEXEC == 0 { + t.Fatalf("got flags %#x, want %#x (FD_CLOEXEC) set", flags, syscall.FD_CLOEXEC) + } +} diff --git a/src/syscall/creds_test.go b/src/syscall/creds_test.go index 736b497bc4..c1a8b516e8 100644 --- a/src/syscall/creds_test.go +++ b/src/syscall/creds_test.go @@ -105,8 +105,8 @@ func TestSCMCredentials(t *testing.T) { if err != nil { t.Fatalf("ReadMsgUnix: %v", err) } - if flags != 0 { - t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags) + if flags != syscall.MSG_CMSG_CLOEXEC { + t.Fatalf("ReadMsgUnix flags = %#x, want %#x (MSG_CMSG_CLOEXEC)", flags, syscall.MSG_CMSG_CLOEXEC) } if n != tt.dataLen { t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)