diff --git a/src/net/http/transport.go b/src/net/http/transport.go index bb9657f4ee..f0ae6ef0b9 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1375,6 +1375,17 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) { return } +// ReadFrom exposes persistConnWriter's underlying Conn to io.Copy and if +// the Conn implements io.ReaderFrom, it can take advantage of optimizations +// such as sendfile. +func (w persistConnWriter) ReadFrom(r io.Reader) (n int64, err error) { + n, err = io.Copy(w.pc.conn, r) + w.pc.nwrite += n + return +} + +var _ io.ReaderFrom = (*persistConnWriter)(nil) + // connectMethod is the map key (in its String form) for keeping persistent // TCP connections alive for subsequent HTTP requests. // diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 6e075847dd..74767f8499 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -5059,3 +5059,146 @@ func TestTransportRequestReplayable(t *testing.T) { }) } } + +// testMockTCPConn is a mock TCP connection used to test that +// ReadFrom is called when sending the request body. +type testMockTCPConn struct { + *net.TCPConn + + ReadFromCalled bool +} + +func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { + c.ReadFromCalled = true + return c.TCPConn.ReadFrom(r) +} + +func TestTransportRequestWriteRoundTrip(t *testing.T) { + nBytes := int64(1 << 10) + newFileFunc := func() (r io.Reader, done func(), err error) { + f, err := ioutil.TempFile("", "net-http-newfilefunc") + if err != nil { + return nil, nil, err + } + + // Write some bytes to the file to enable reading. + if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { + return nil, nil, fmt.Errorf("failed to write data to file: %v", err) + } + if _, err := f.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("failed to seek to front: %v", err) + } + + done = func() { + f.Close() + os.Remove(f.Name()) + } + + return f, done, nil + } + + newBufferFunc := func() (io.Reader, func(), error) { + return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil + } + + cases := []struct { + name string + readerFunc func() (io.Reader, func(), error) + contentLength int64 + expectedReadFrom bool + }{ + { + name: "file, length", + readerFunc: newFileFunc, + contentLength: nBytes, + expectedReadFrom: true, + }, + { + name: "file, no length", + readerFunc: newFileFunc, + }, + { + name: "file, negative length", + readerFunc: newFileFunc, + contentLength: -1, + }, + { + name: "buffer", + contentLength: nBytes, + readerFunc: newBufferFunc, + }, + { + name: "buffer, no length", + readerFunc: newBufferFunc, + }, + { + name: "buffer, length -1", + contentLength: -1, + readerFunc: newBufferFunc, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r, cleanup, err := tc.readerFunc() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + tConn := &testMockTCPConn{} + trFunc := func(tr *Transport) { + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) + } + + tConn.TCPConn = tcpConn + return tConn, nil + } + } + + cst := newClientServerTest( + t, + h1Mode, + HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + w.WriteHeader(200) + }), + trFunc, + ) + defer cst.close() + + req, err := NewRequest("PUT", cst.ts.URL, r) + if err != nil { + t.Fatal(err) + } + req.ContentLength = tc.contentLength + req.Header.Set("Content-Type", "application/octet-stream") + resp, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("status code = %d; want 200", resp.StatusCode) + } + + if !tConn.ReadFromCalled && tc.expectedReadFrom { + t.Fatalf("did not call ReadFrom") + } + + if tConn.ReadFromCalled && !tc.expectedReadFrom { + t.Fatalf("ReadFrom was unexpectedly invoked") + } + }) + } +}