diff --git a/src/net/http/client.go b/src/net/http/client.go index 10f5684a797..ee0fd2cb62a 100644 --- a/src/net/http/client.go +++ b/src/net/http/client.go @@ -475,6 +475,7 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo URL: u, Header: make(Header), Cancel: ireq.Cancel, + ctx: ireq.ctx, } if ireq.Method == "POST" || ireq.Method == "PUT" { req.Method = "GET" diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index e4fed268032..b9e17c52709 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -8,6 +8,7 @@ package http_test import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/base64" @@ -290,6 +291,33 @@ func TestClientRedirects(t *testing.T) { } } +func TestClientRedirectContext(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Redirect(w, r, "/", StatusFound) + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + c := &Client{CheckRedirect: func(req *Request, via []*Request) error { + cancel() + if len(via) > 2 { + return errors.New("too many redirects") + } + return nil + }} + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(ctx) + _, err := c.Do(req) + ue, ok := err.(*url.Error) + if !ok { + t.Fatalf("got error %T; want *url.Error") + } + if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn { + t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err) + } +} + func TestPostRedirects(t *testing.T) { defer afterTest(t) var log struct {