mirror of
https://github.com/golang/go
synced 2024-10-14 11:53:56 +00:00
net/http: add Transport.GetProxyConnectHeader
Fixes golang/go#41048 Change-Id: I38e01605bffb6f85100c098051b0c416dd77f261 Reviewed-on: https://go-review.googlesource.com/c/go/+/259917 Trust: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Go Bot <gobot@golang.org> Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
parent
db428ad7b6
commit
930fa890c9
|
@ -240,8 +240,18 @@ type Transport struct {
|
|||
|
||||
// ProxyConnectHeader optionally specifies headers to send to
|
||||
// proxies during CONNECT requests.
|
||||
// To set the header dynamically, see GetProxyConnectHeader.
|
||||
ProxyConnectHeader Header
|
||||
|
||||
// GetProxyConnectHeader optionally specifies a func to return
|
||||
// headers to send to proxyURL during a CONNECT request to the
|
||||
// ip:port target.
|
||||
// If it returns an error, the Transport's RoundTrip fails with
|
||||
// that error. It can return (nil, nil) to not add headers.
|
||||
// If GetProxyConnectHeader is non-nil, ProxyConnectHeader is
|
||||
// ignored.
|
||||
GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error)
|
||||
|
||||
// MaxResponseHeaderBytes specifies a limit on how many
|
||||
// response bytes are allowed in the server's response
|
||||
// header.
|
||||
|
@ -313,6 +323,7 @@ func (t *Transport) Clone() *Transport {
|
|||
ResponseHeaderTimeout: t.ResponseHeaderTimeout,
|
||||
ExpectContinueTimeout: t.ExpectContinueTimeout,
|
||||
ProxyConnectHeader: t.ProxyConnectHeader.Clone(),
|
||||
GetProxyConnectHeader: t.GetProxyConnectHeader,
|
||||
MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
|
||||
ForceAttemptHTTP2: t.ForceAttemptHTTP2,
|
||||
WriteBufferSize: t.WriteBufferSize,
|
||||
|
@ -1623,7 +1634,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
|
|||
}
|
||||
case cm.targetScheme == "https":
|
||||
conn := pconn.conn
|
||||
hdr := t.ProxyConnectHeader
|
||||
var hdr Header
|
||||
if t.GetProxyConnectHeader != nil {
|
||||
var err error
|
||||
hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
hdr = t.ProxyConnectHeader
|
||||
}
|
||||
if hdr == nil {
|
||||
hdr = make(Header)
|
||||
}
|
||||
|
|
|
@ -5174,6 +5174,57 @@ func TestTransportProxyConnectHeader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTransportProxyGetConnectHeader(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
reqc := make(chan *Request, 1)
|
||||
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
if r.Method != "CONNECT" {
|
||||
t.Errorf("method = %q; want CONNECT", r.Method)
|
||||
}
|
||||
reqc <- r
|
||||
c, _, err := w.(Hijacker).Hijack()
|
||||
if err != nil {
|
||||
t.Errorf("Hijack: %v", err)
|
||||
return
|
||||
}
|
||||
c.Close()
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := ts.Client()
|
||||
c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
|
||||
return url.Parse(ts.URL)
|
||||
}
|
||||
// These should be ignored:
|
||||
c.Transport.(*Transport).ProxyConnectHeader = Header{
|
||||
"User-Agent": {"foo"},
|
||||
"Other": {"bar"},
|
||||
}
|
||||
c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
|
||||
return Header{
|
||||
"User-Agent": {"foo2"},
|
||||
"Other": {"bar2"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
|
||||
if err == nil {
|
||||
res.Body.Close()
|
||||
t.Errorf("unexpected success")
|
||||
}
|
||||
select {
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout")
|
||||
case r := <-reqc:
|
||||
if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
|
||||
t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
|
||||
}
|
||||
if got, want := r.Header.Get("Other"), "bar2"; got != want {
|
||||
t.Errorf("CONNECT request Other = %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var errFakeRoundTrip = errors.New("fake roundtrip")
|
||||
|
||||
type funcRoundTripper func()
|
||||
|
@ -5842,6 +5893,7 @@ func TestTransportClone(t *testing.T) {
|
|||
ResponseHeaderTimeout: time.Second,
|
||||
ExpectContinueTimeout: time.Second,
|
||||
ProxyConnectHeader: Header{},
|
||||
GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
|
||||
MaxResponseHeaderBytes: 1,
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
|
||||
|
|
Loading…
Reference in a new issue