From 930fa890c9b6a75700bda3dc4043de81350749ea Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 6 Oct 2020 10:53:11 -0700 Subject: [PATCH] net/http: add Transport.GetProxyConnectHeader Fixes golang/go#41048 Change-Id: I38e01605bffb6f85100c098051b0c416dd77f261 Reviewed-on: https://go-review.googlesource.com/c/go/+/259917 Trust: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Go Bot Reviewed-by: Damien Neil --- src/net/http/transport.go | 23 ++++++++++++++- src/net/http/transport_test.go | 52 ++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/src/net/http/transport.go b/src/net/http/transport.go index b97c4268b5..4546166430 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -240,8 +240,18 @@ type Transport struct { // ProxyConnectHeader optionally specifies headers to send to // proxies during CONNECT requests. + // To set the header dynamically, see GetProxyConnectHeader. ProxyConnectHeader Header + // GetProxyConnectHeader optionally specifies a func to return + // headers to send to proxyURL during a CONNECT request to the + // ip:port target. + // If it returns an error, the Transport's RoundTrip fails with + // that error. It can return (nil, nil) to not add headers. + // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is + // ignored. + GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) + // MaxResponseHeaderBytes specifies a limit on how many // response bytes are allowed in the server's response // header. @@ -313,6 +323,7 @@ func (t *Transport) Clone() *Transport { ResponseHeaderTimeout: t.ResponseHeaderTimeout, ExpectContinueTimeout: t.ExpectContinueTimeout, ProxyConnectHeader: t.ProxyConnectHeader.Clone(), + GetProxyConnectHeader: t.GetProxyConnectHeader, MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, ForceAttemptHTTP2: t.ForceAttemptHTTP2, WriteBufferSize: t.WriteBufferSize, @@ -1623,7 +1634,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } case cm.targetScheme == "https": conn := pconn.conn - hdr := t.ProxyConnectHeader + var hdr Header + if t.GetProxyConnectHeader != nil { + var err error + hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr) + if err != nil { + conn.Close() + return nil, err + } + } else { + hdr = t.ProxyConnectHeader + } if hdr == nil { hdr = make(Header) } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index f4b7623630..a1c9e822b4 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -5174,6 +5174,57 @@ func TestTransportProxyConnectHeader(t *testing.T) { } } +func TestTransportProxyGetConnectHeader(t *testing.T) { + defer afterTest(t) + reqc := make(chan *Request, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })) + defer ts.Close() + + c := ts.Client() + c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) + } + // These should be ignored: + c.Transport.(*Transport).ProxyConnectHeader = Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) { + return Header{ + "User-Agent": {"foo2"}, + "Other": {"bar2"}, + }, nil + } + + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar2"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() @@ -5842,6 +5893,7 @@ func TestTransportClone(t *testing.T) { ResponseHeaderTimeout: time.Second, ExpectContinueTimeout: time.Second, ProxyConnectHeader: Header{}, + GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil }, MaxResponseHeaderBytes: 1, ForceAttemptHTTP2: true, TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{