diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index 115b6c6578c..03775685fb6 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -16,6 +16,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" ) @@ -56,6 +57,11 @@ type Conn struct { input *block // application data waiting to be read hand bytes.Buffer // handshake data waiting to be read + // activeCall is an atomic int32; the low bit is whether Close has + // been called. the rest of the bits are the number of goroutines + // in Conn.Write. + activeCall int32 + tmp [16]byte } @@ -855,8 +861,22 @@ func (c *Conn) readHandshake() (interface{}, error) { return m, nil } +var errClosed = errors.New("crypto/tls: use of closed connection") + // Write writes data to the connection. func (c *Conn) Write(b []byte) (int, error) { + // interlock with Close below + for { + x := atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return 0, errClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { + defer atomic.AddInt32(&c.activeCall, -2) + break + } + } + if err := c.Handshake(); err != nil { return 0, err } @@ -960,6 +980,27 @@ func (c *Conn) Read(b []byte) (n int, err error) { // Close closes the connection. func (c *Conn) Close() error { + // Interlock with Conn.Write above. + var x int32 + for { + x = atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return errClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { + break + } + } + if x != 0 { + // io.Writer and io.Closer should not be used concurrently. + // If Close is called while a Write is currently in-flight, + // interpret that as a sign that this Close is really just + // being used to break the Write and/or clean up resources and + // avoid sending the alertCloseNotify, which may block + // waiting on handshakeMutex or the c.out mutex. + return c.conn.Close() + } + var alertErr error c.handshakeMutex.Lock() diff --git a/src/crypto/tls/tls_test.go b/src/crypto/tls/tls_test.go index 6b5d455be41..5cc14278a06 100644 --- a/src/crypto/tls/tls_test.go +++ b/src/crypto/tls/tls_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "errors" "fmt" "internal/testenv" "io" @@ -364,3 +365,104 @@ func TestVerifyHostnameResumed(t *testing.T) { c.Close() } } + +func TestConnCloseBreakingWrite(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + srvCh := make(chan *Conn, 1) + var serr error + var sconn net.Conn + go func() { + var err error + sconn, err = ln.Accept() + if err != nil { + serr = err + srvCh <- nil + return + } + serverConfig := *testConfig + srv := Server(sconn, &serverConfig) + if err := srv.Handshake(); err != nil { + serr = fmt.Errorf("handshake: %v", err) + srvCh <- nil + return + } + srvCh <- srv + }() + + cconn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cconn.Close() + + conn := &changeImplConn{ + Conn: cconn, + } + + clientConfig := *testConfig + tconn := Client(conn, &clientConfig) + if err := tconn.Handshake(); err != nil { + t.Fatal(err) + } + + srv := <-srvCh + if srv == nil { + t.Fatal(serr) + } + defer sconn.Close() + + connClosed := make(chan struct{}) + conn.closeFunc = func() error { + close(connClosed) + return nil + } + + inWrite := make(chan bool, 1) + var errConnClosed = errors.New("conn closed for test") + conn.writeFunc = func(p []byte) (n int, err error) { + inWrite <- true + <-connClosed + return 0, errConnClosed + } + + closeReturned := make(chan bool, 1) + go func() { + <-inWrite + tconn.Close() // test that this doesn't block forever. + closeReturned <- true + }() + + _, err = tconn.Write([]byte("foo")) + if err != errConnClosed { + t.Errorf("Write error = %v; want errConnClosed", err) + } + + <-closeReturned + if err := tconn.Close(); err != errClosed { + t.Errorf("Close error = %v; want errClosed", err) + } +} + +// changeImplConn is a net.Conn which can change its Write and Close +// methods. +type changeImplConn struct { + net.Conn + writeFunc func([]byte) (int, error) + closeFunc func() error +} + +func (w *changeImplConn) Write(p []byte) (n int, err error) { + if w.writeFunc != nil { + return w.writeFunc(p) + } + return w.Conn.Write(p) +} + +func (w *changeImplConn) Close() error { + if w.closeFunc != nil { + return w.closeFunc() + } + return w.Conn.Close() +}