net/http: add Transport.GetProxyConnectHeader

Fixes golang/go#41048

Change-Id: I38e01605bffb6f85100c098051b0c416dd77f261
Reviewed-on: https://go-review.googlesource.com/c/go/+/259917
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Brad Fitzpatrick 2020-10-06 10:53:11 -07:00
parent db428ad7b6
commit 930fa890c9
2 changed files with 74 additions and 1 deletions

View file

@ -240,8 +240,18 @@ type Transport struct {
// ProxyConnectHeader optionally specifies headers to send to
// proxies during CONNECT requests.
// To set the header dynamically, see GetProxyConnectHeader.
ProxyConnectHeader Header
// GetProxyConnectHeader optionally specifies a func to return
// headers to send to proxyURL during a CONNECT request to the
// ip:port target.
// If it returns an error, the Transport's RoundTrip fails with
// that error. It can return (nil, nil) to not add headers.
// If GetProxyConnectHeader is non-nil, ProxyConnectHeader is
// ignored.
GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error)
// MaxResponseHeaderBytes specifies a limit on how many
// response bytes are allowed in the server's response
// header.
@ -313,6 +323,7 @@ func (t *Transport) Clone() *Transport {
ResponseHeaderTimeout: t.ResponseHeaderTimeout,
ExpectContinueTimeout: t.ExpectContinueTimeout,
ProxyConnectHeader: t.ProxyConnectHeader.Clone(),
GetProxyConnectHeader: t.GetProxyConnectHeader,
MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
ForceAttemptHTTP2: t.ForceAttemptHTTP2,
WriteBufferSize: t.WriteBufferSize,
@ -1623,7 +1634,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
}
case cm.targetScheme == "https":
conn := pconn.conn
hdr := t.ProxyConnectHeader
var hdr Header
if t.GetProxyConnectHeader != nil {
var err error
hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr)
if err != nil {
conn.Close()
return nil, err
}
} else {
hdr = t.ProxyConnectHeader
}
if hdr == nil {
hdr = make(Header)
}

View file

@ -5174,6 +5174,57 @@ func TestTransportProxyConnectHeader(t *testing.T) {
}
}
func TestTransportProxyGetConnectHeader(t *testing.T) {
defer afterTest(t)
reqc := make(chan *Request, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "CONNECT" {
t.Errorf("method = %q; want CONNECT", r.Method)
}
reqc <- r
c, _, err := w.(Hijacker).Hijack()
if err != nil {
t.Errorf("Hijack: %v", err)
return
}
c.Close()
}))
defer ts.Close()
c := ts.Client()
c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
return url.Parse(ts.URL)
}
// These should be ignored:
c.Transport.(*Transport).ProxyConnectHeader = Header{
"User-Agent": {"foo"},
"Other": {"bar"},
}
c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
return Header{
"User-Agent": {"foo2"},
"Other": {"bar2"},
}, nil
}
res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
if err == nil {
res.Body.Close()
t.Errorf("unexpected success")
}
select {
case <-time.After(3 * time.Second):
t.Fatal("timeout")
case r := <-reqc:
if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
}
if got, want := r.Header.Get("Other"), "bar2"; got != want {
t.Errorf("CONNECT request Other = %q; want %q", got, want)
}
}
}
var errFakeRoundTrip = errors.New("fake roundtrip")
type funcRoundTripper func()
@ -5842,6 +5893,7 @@ func TestTransportClone(t *testing.T) {
ResponseHeaderTimeout: time.Second,
ExpectContinueTimeout: time.Second,
ProxyConnectHeader: Header{},
GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
MaxResponseHeaderBytes: 1,
ForceAttemptHTTP2: true,
TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{