mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 16:53:57 +00:00
Set extra proxy headers in all tsh
HTTP requests (#19766)
Before this commit, the `tsh` HTTP requests that had the extra headers were those that did not use `roundtrip`. This commit leverages `http.RoundTripper.RoundTrip` to ensure that all requests have the the extra headers.
This commit is contained in:
parent
a2c2f3a092
commit
d72ac18247
|
@ -71,32 +71,39 @@ func parse(addr string) (*url.URL, error) {
|
|||
return addrURL, nil
|
||||
}
|
||||
|
||||
// HTTPFallbackRoundTripper is a wrapper for http.Transport that downgrades requests
|
||||
// to plain HTTP when using a plain HTTP proxy at localhost.
|
||||
type HTTPFallbackRoundTripper struct {
|
||||
// HTTPRoundTripper is a wrapper for http.Transport that
|
||||
// - adds extra HTTP headers to all requests, and
|
||||
// - downgrades requests to plain HTTP when proxy is at localhost and the wrapped http.Transport has TLSClientConfig.InsecureSkipVerify set to true.
|
||||
type HTTPRoundTripper struct {
|
||||
*http.Transport
|
||||
// extraHeaders is a map of extra HTTP headers to be included in requests.
|
||||
extraHeaders map[string]string
|
||||
// isProxyHTTPLocalhost indicates that the HTTP_PROXY is at "http://localhost"
|
||||
isProxyHTTPLocalhost bool
|
||||
}
|
||||
|
||||
// NewHTTPFallbackRoundTripper creates a new initialized HTTP fallback roundtripper.
|
||||
func NewHTTPFallbackRoundTripper(transport *http.Transport, insecure bool) *HTTPFallbackRoundTripper {
|
||||
// NewHTTPRoundTripper creates a new initialized HTTP roundtripper.
|
||||
func NewHTTPRoundTripper(transport *http.Transport, extraHeaders map[string]string) *HTTPRoundTripper {
|
||||
proxyConfig := httpproxy.FromEnvironment()
|
||||
rt := HTTPFallbackRoundTripper{
|
||||
return &HTTPRoundTripper{
|
||||
Transport: transport,
|
||||
extraHeaders: extraHeaders,
|
||||
isProxyHTTPLocalhost: strings.HasPrefix(proxyConfig.HTTPProxy, "http://localhost"),
|
||||
}
|
||||
if rt.TLSClientConfig != nil {
|
||||
rt.TLSClientConfig.InsecureSkipVerify = insecure
|
||||
}
|
||||
return &rt
|
||||
}
|
||||
|
||||
// RoundTrip executes a single HTTP transaction. Part of the RoundTripper interface.
|
||||
func (rt *HTTPFallbackRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
tlsConfig := rt.Transport.TLSClientConfig
|
||||
func (rt *HTTPRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Add extra HTTP headers.
|
||||
for header, v := range rt.extraHeaders {
|
||||
req.Header.Add(header, v)
|
||||
}
|
||||
|
||||
// Use plain HTTP if proxying via http://localhost in insecure mode.
|
||||
tlsConfig := rt.Transport.TLSClientConfig
|
||||
if rt.isProxyHTTPLocalhost && tlsConfig != nil && tlsConfig.InsecureSkipVerify {
|
||||
req.URL.Scheme = "http"
|
||||
}
|
||||
|
||||
return rt.Transport.RoundTrip(req)
|
||||
}
|
||||
|
|
|
@ -19,7 +19,9 @@ package proxy
|
|||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -182,12 +184,14 @@ func buildProxyAddr(addr, user, pass string) (string, error) {
|
|||
func TestProxyAwareRoundTripper(t *testing.T) {
|
||||
t.Setenv("HTTP_PROXY", "http://localhost:8888")
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{},
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||
return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
|
||||
},
|
||||
}
|
||||
rt := NewHTTPFallbackRoundTripper(transport, true)
|
||||
rt := NewHTTPRoundTripper(transport, nil)
|
||||
req, err := http.NewRequest(http.MethodGet, "https://localhost:9999", nil)
|
||||
require.NoError(t, err)
|
||||
// Don't care about response, only if the scheme changed.
|
||||
|
@ -197,6 +201,190 @@ func TestProxyAwareRoundTripper(t *testing.T) {
|
|||
require.Equal(t, "http", req.URL.Scheme)
|
||||
}
|
||||
|
||||
// TestHttpRoundTripperDowngrade tests that the round tripper downgrades https requests to http
|
||||
// when HTTP_PROXY is set to "http://localhost:*" (i.e. there's an http proxy running on localhost).
|
||||
func TestHttpRoundTripperDowngrade(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
setHTTPProxy bool
|
||||
shouldHitProxy bool
|
||||
}{
|
||||
{
|
||||
desc: "hits http proxy if insecure and localhost http proxy is set",
|
||||
setHTTPProxy: true,
|
||||
shouldHitProxy: true,
|
||||
},
|
||||
{
|
||||
desc: "does not hit http proxy if insecure and localhost http proxy is not set",
|
||||
setHTTPProxy: false,
|
||||
shouldHitProxy: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
newHandler := func(runningAtProxy bool, wasHit *bool) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
*wasHit = true
|
||||
if tc.shouldHitProxy {
|
||||
// If the request should hit the proxy, then:
|
||||
// - this handler is running at the proxy, and
|
||||
// - the scheme should be http.
|
||||
require.True(t, runningAtProxy)
|
||||
require.Equal(t, "http", r.URL.Scheme)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// Start localhost http proxy.
|
||||
runningAtProxy := true
|
||||
loopback := true
|
||||
https := false
|
||||
httpProxyWasHit := false
|
||||
httpProxy, err := newServer(newHandler(runningAtProxy, &httpProxyWasHit), loopback, https)
|
||||
require.NoError(t, err)
|
||||
defer httpProxy.Close()
|
||||
|
||||
// Start non-localhost https server.
|
||||
runningAtProxy = false
|
||||
loopback = false
|
||||
https = true
|
||||
httpsSrvWasHit := false
|
||||
httpsSrv, err := newServer(newHandler(runningAtProxy, &httpsSrvWasHit), loopback, https)
|
||||
require.NoError(t, err)
|
||||
defer httpsSrv.Close()
|
||||
|
||||
if tc.setHTTPProxy {
|
||||
// url.Parse won't correctly parse an absolute URL without a scheme.
|
||||
u, err := url.Parse("http://" + httpProxy.Listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
_, port, err := net.SplitHostPort(u.Host)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set HTTP_PROXY to "http://localhost:*".
|
||||
t.Setenv("HTTP_PROXY", fmt.Sprintf("http://localhost:%s", port))
|
||||
}
|
||||
|
||||
clt := newClient(t, nil)
|
||||
|
||||
// Perform any request.
|
||||
// Set addr to the https server. If HTTP_PROXY was set above,
|
||||
// the http proxy should be hit regardless.
|
||||
addr := httpsSrv.Listener.Addr().String()
|
||||
request(t, clt, addr)
|
||||
|
||||
// Validate that the correct server was hit.
|
||||
require.Equal(t, tc.shouldHitProxy, httpProxyWasHit)
|
||||
require.Equal(t, !tc.shouldHitProxy, httpsSrvWasHit)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHttpRoundTripperExtraHeaders tests that the round tripper adds the extra headers set.
|
||||
func TestHttpRoundTripperExtraHeaders(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
extraHeaders map[string]string
|
||||
expectHeaders func(*testing.T, http.Header)
|
||||
}{
|
||||
{
|
||||
desc: "extra headers are added",
|
||||
extraHeaders: map[string]string{
|
||||
"header1": "value1",
|
||||
"header2": "value2",
|
||||
},
|
||||
expectHeaders: func(t *testing.T, headers http.Header) {
|
||||
require.Equal(t, []string{"value1"}, headers.Values("header1"))
|
||||
require.Equal(t, []string{"value2"}, headers.Values("header2"))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "extra headers do not overwrite existing headers",
|
||||
extraHeaders: map[string]string{
|
||||
"header1": "value1",
|
||||
"Content-Type": "value2",
|
||||
},
|
||||
expectHeaders: func(t *testing.T, headers http.Header) {
|
||||
require.Equal(t, []string{"value1"}, headers.Values("header1"))
|
||||
require.Equal(t, []string{"application/json", "value2"}, headers.Values("Content-Type"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
||||
tc.expectHeaders(t, r.Header)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// Start localhost https server.
|
||||
loopback := true
|
||||
tls := true
|
||||
httpsSrv, err := newServer(handler, loopback, tls)
|
||||
require.NoError(t, err)
|
||||
defer httpsSrv.Close()
|
||||
|
||||
clt := newClient(t, tc.extraHeaders)
|
||||
|
||||
// Perform any request.
|
||||
// Set the address to the localhost https server.
|
||||
addr := httpsSrv.Listener.Addr().String()
|
||||
request(t, clt, addr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newServer starts a new server that:
|
||||
// - runs TLS if `https`
|
||||
// - uses a loopback listener if `loopback`
|
||||
func newServer(handler http.HandlerFunc, loopback bool, https bool) (*httptest.Server, error) {
|
||||
srv := httptest.NewUnstartedServer(handler)
|
||||
|
||||
if !loopback {
|
||||
// Replace the test-supplied loopback listener with the first available
|
||||
// non-loopback address.
|
||||
srv.Listener.Close()
|
||||
l, err := net.Listen("tcp", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
srv.Listener = l
|
||||
}
|
||||
|
||||
if https {
|
||||
srv.StartTLS()
|
||||
} else {
|
||||
srv.Start()
|
||||
}
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// newClient creates a new https roundtrip client.
|
||||
func newClient(t *testing.T, extraHeaders map[string]string) *http.Client {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
// Setting insecure ensures that https requests succeed.
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||
return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
|
||||
},
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: NewHTTPRoundTripper(transport, extraHeaders),
|
||||
}
|
||||
}
|
||||
|
||||
// request perform a POST request.
|
||||
func request(t *testing.T, clt *http.Client, addr string) {
|
||||
url := "https://" + addr + "/v1/content"
|
||||
resp, err := clt.Post(url, "application/json", nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
successTests := []struct {
|
||||
name, addr, scheme, host, path string
|
||||
|
|
|
@ -103,7 +103,7 @@ func newWebClient(cfg *Config) (*http.Client, error) {
|
|||
}
|
||||
return &http.Client{
|
||||
Transport: otelhttp.NewTransport(
|
||||
proxy.NewHTTPFallbackRoundTripper(&transport, cfg.Insecure),
|
||||
proxy.NewHTTPRoundTripper(&transport, nil),
|
||||
otelhttp.WithSpanNameFormatter(tracing.HTTPTransportFormatter),
|
||||
),
|
||||
Timeout: cfg.Timeout,
|
||||
|
|
|
@ -2787,7 +2787,12 @@ func makeProxySSHClient(ctx context.Context, tc *TeleportClient, sshConfig *ssh.
|
|||
if len(tc.JumpHosts) > 0 {
|
||||
sshProxyAddr = tc.JumpHosts[0].Addr.Addr
|
||||
// Check if JumpHost address is a proxy web address.
|
||||
resp, err := webclient.Find(&webclient.Config{Context: ctx, ProxyAddr: sshProxyAddr, Insecure: tc.InsecureSkipVerify})
|
||||
resp, err := webclient.Find(&webclient.Config{
|
||||
Context: ctx,
|
||||
ProxyAddr: sshProxyAddr,
|
||||
Insecure: tc.InsecureSkipVerify,
|
||||
ExtraHeaders: tc.ExtraProxyHeaders,
|
||||
})
|
||||
// If JumpHost address is a proxy web port and proxy supports TLSRouting dial proxy with TLSWrapper.
|
||||
if err == nil && resp.Proxy.TLSRoutingEnabled {
|
||||
log.Infof("Connecting to proxy=%v login=%q using TLS Routing JumpHost", sshProxyAddr, sshConfig.User)
|
||||
|
@ -3284,6 +3289,7 @@ func (tc *TeleportClient) newSSHLogin(priv *keys.PrivateKey) (SSHLogin, error) {
|
|||
RouteToCluster: tc.SiteName,
|
||||
KubernetesCluster: tc.KubernetesCluster,
|
||||
AttestationStatement: attestationStatement,
|
||||
ExtraHeaders: tc.ExtraProxyHeaders,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -37,39 +37,30 @@ import (
|
|||
)
|
||||
|
||||
func NewInsecureWebClient() *http.Client {
|
||||
// Because Teleport clients can't be configured (yet), they take the default
|
||||
// list of cipher suites from Go.
|
||||
tlsConfig := utils.TLSConfig(nil)
|
||||
transport := http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||
return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
|
||||
},
|
||||
}
|
||||
return newClient(true, nil, nil)
|
||||
}
|
||||
|
||||
func newClient(insecure bool, pool *x509.CertPool, extraHeaders map[string]string) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: otelhttp.NewTransport(
|
||||
apiproxy.NewHTTPFallbackRoundTripper(&transport, true /* insecure */),
|
||||
apiproxy.NewHTTPRoundTripper(httpTransport(insecure, pool), extraHeaders),
|
||||
otelhttp.WithSpanNameFormatter(tracing.HTTPTransportFormatter),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func newClientWithPool(pool *x509.CertPool) *http.Client {
|
||||
func httpTransport(insecure bool, pool *x509.CertPool) *http.Transport {
|
||||
// Because Teleport clients can't be configured (yet), they take the default
|
||||
// list of cipher suites from Go.
|
||||
tlsConfig := utils.TLSConfig(nil)
|
||||
tlsConfig.InsecureSkipVerify = insecure
|
||||
tlsConfig.RootCAs = pool
|
||||
|
||||
return &http.Client{
|
||||
Transport: otelhttp.NewTransport(
|
||||
&http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||
return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
|
||||
},
|
||||
},
|
||||
otelhttp.WithSpanNameFormatter(tracing.HTTPTransportFormatter),
|
||||
),
|
||||
return &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||
return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -46,9 +46,9 @@ func TestNewInsecureWebClientNoProxy(t *testing.T) {
|
|||
require.Contains(t, err.Error(), "no such host")
|
||||
}
|
||||
|
||||
func TestNewClientWithPoolHTTPProxy(t *testing.T) {
|
||||
func TestNewSecureWebClientHTTPProxy(t *testing.T) {
|
||||
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
|
||||
client := newClientWithPool(nil)
|
||||
client := newClient(false, nil, nil)
|
||||
//nolint:bodyclose // resp should be nil, so there will be no body to close.
|
||||
resp, err := client.Get("https://fakedomain.example.com")
|
||||
// Client should try to proxy through nonexistent server at localhost.
|
||||
|
@ -58,10 +58,10 @@ func TestNewClientWithPoolHTTPProxy(t *testing.T) {
|
|||
require.Contains(t, err.Error(), "no such host")
|
||||
}
|
||||
|
||||
func TestNewClientWithPoolNoProxy(t *testing.T) {
|
||||
func TestNewSecureWebClientNoProxy(t *testing.T) {
|
||||
t.Setenv("HTTPS_PROXY", "fakeproxy.example.com:9999")
|
||||
t.Setenv("NO_PROXY", "fakedomain.example.com")
|
||||
client := newClientWithPool(nil)
|
||||
client := newClient(false, nil, nil)
|
||||
//nolint:bodyclose // resp should be nil, so there will be no body to close.
|
||||
resp, err := client.Get("https://fakedomain.example.com")
|
||||
require.Error(t, err, "GET unexpectedly succeeded: %+v", resp)
|
||||
|
|
|
@ -91,7 +91,7 @@ type RedirectorConfig struct {
|
|||
|
||||
// NewRedirector returns new local web server redirector
|
||||
func NewRedirector(ctx context.Context, login SSHLoginSSO, config *RedirectorConfig) (*Redirector, error) {
|
||||
clt, proxyURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool)
|
||||
clt, proxyURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
|
|
@ -172,7 +172,7 @@ type SSHLogin struct {
|
|||
TTL time.Duration
|
||||
// Insecure turns off verification for x509 target proxy
|
||||
Insecure bool
|
||||
// Pool is x509 cert pool to use for server certifcate verification
|
||||
// Pool is x509 cert pool to use for server certificate verification
|
||||
Pool *x509.CertPool
|
||||
// Compatibility sets compatibility mode for SSH certificates
|
||||
Compatibility string
|
||||
|
@ -184,6 +184,8 @@ type SSHLogin struct {
|
|||
KubernetesCluster string
|
||||
// AttestationStatement is an attestation statement.
|
||||
AttestationStatement *keys.AttestationStatement
|
||||
// ExtraHeaders is a map of extra HTTP headers to be included in requests.
|
||||
ExtraHeaders map[string]string
|
||||
}
|
||||
|
||||
// SSHLoginSSO contains SSH login parameters for SSO login.
|
||||
|
@ -250,13 +252,13 @@ type SSHLoginPasswordless struct {
|
|||
}
|
||||
|
||||
// initClient creates a new client to the HTTPS web proxy.
|
||||
func initClient(proxyAddr string, insecure bool, pool *x509.CertPool) (*WebClient, *url.URL, error) {
|
||||
func initClient(proxyAddr string, insecure bool, pool *x509.CertPool, extraHeaders map[string]string) (*WebClient, *url.URL, error) {
|
||||
log := logrus.WithFields(logrus.Fields{
|
||||
trace.Component: teleport.ComponentClient,
|
||||
})
|
||||
log.Debugf("HTTPS client init(proxyAddr=%v, insecure=%v)", proxyAddr, insecure)
|
||||
log.Debugf("HTTPS client init(proxyAddr=%v, insecure=%v, extraHeaders=%v)", proxyAddr, insecure, extraHeaders)
|
||||
|
||||
// validate proxyAddr:
|
||||
// validate proxy address
|
||||
host, port, err := net.SplitHostPort(proxyAddr)
|
||||
if err != nil || host == "" || port == "" {
|
||||
if err != nil {
|
||||
|
@ -270,18 +272,13 @@ func initClient(proxyAddr string, insecure bool, pool *x509.CertPool) (*WebClien
|
|||
return nil, nil, trace.BadParameter("'%v' is not a valid proxy address", proxyAddr)
|
||||
}
|
||||
|
||||
var opts []roundtrip.ClientParam
|
||||
|
||||
if insecure {
|
||||
// Skip https cert verification, print a warning that this is insecure.
|
||||
// Skipping https cert verification, print a warning that this is insecure.
|
||||
fmt.Fprintf(os.Stderr, "WARNING: You are using insecure connection to Teleport proxy %v\n", proxyAddr)
|
||||
opts = append(opts, roundtrip.HTTPClient(NewInsecureWebClient()))
|
||||
} else if pool != nil {
|
||||
// use custom set of trusted CAs
|
||||
opts = append(opts, roundtrip.HTTPClient(newClientWithPool(pool)))
|
||||
}
|
||||
|
||||
clt, err := NewWebClient(proxyAddr, opts...)
|
||||
opt := roundtrip.HTTPClient(newClient(insecure, pool, extraHeaders))
|
||||
clt, err := NewWebClient(proxyAddr, opt)
|
||||
if err != nil {
|
||||
return nil, nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -360,7 +357,7 @@ func SSHAgentSSOLogin(ctx context.Context, login SSHLoginSSO, config *Redirector
|
|||
|
||||
// SSHAgentLogin is used by tsh to fetch local user credentials.
|
||||
func SSHAgentLogin(ctx context.Context, login SSHLoginDirect) (*auth.SSHLoginResponse, error) {
|
||||
clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool)
|
||||
clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -395,7 +392,7 @@ func SSHAgentLogin(ctx context.Context, login SSHLoginDirect) (*auth.SSHLoginRes
|
|||
//
|
||||
// Returns the SSH certificate if authn is successful or an error.
|
||||
func SSHAgentPasswordlessLogin(ctx context.Context, login SSHLoginPasswordless) (*auth.SSHLoginResponse, error) {
|
||||
webClient, webURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool)
|
||||
webClient, webURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -466,7 +463,7 @@ func SSHAgentPasswordlessLogin(ctx context.Context, login SSHLoginPasswordless)
|
|||
// prompt the user to provide 2nd factor and pass the response to the proxy.
|
||||
// If the authentication succeeds, we will get a temporary certificate back.
|
||||
func SSHAgentMFALogin(ctx context.Context, login SSHLoginMFA) (*auth.SSHLoginResponse, error) {
|
||||
clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool)
|
||||
clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -534,7 +531,7 @@ func SSHAgentMFALogin(ctx context.Context, login SSHLoginMFA) (*auth.SSHLoginRes
|
|||
|
||||
// HostCredentials is used to fetch host credentials for a node.
|
||||
func HostCredentials(ctx context.Context, proxyAddr string, insecure bool, req types.RegisterUsingTokenRequest) (*proto.Certs, error) {
|
||||
clt, _, err := initClient(proxyAddr, insecure, nil)
|
||||
clt, _, err := initClient(proxyAddr, insecure, nil, nil)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -554,7 +551,7 @@ func HostCredentials(ctx context.Context, proxyAddr string, insecure bool, req t
|
|||
|
||||
// GetWebConfig is used by teleterm to fetch webconfig.js from proxies
|
||||
func GetWebConfig(ctx context.Context, proxyAddr string, insecure bool) (*webclient.WebConfig, error) {
|
||||
clt, _, err := initClient(proxyAddr, insecure, nil)
|
||||
clt, _, err := initClient(proxyAddr, insecure, nil, nil)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
|
|
@ -36,71 +36,84 @@ import (
|
|||
"github.com/gravitational/teleport/lib/client"
|
||||
)
|
||||
|
||||
func TestPlainHttpFallback(t *testing.T) {
|
||||
// TestHostCredentialsHttpFallback tests that HostCredentials requests (/v1/webapi/host/credentials/)
|
||||
// fall back to HTTP only if the address is a loopback and the insecure mode was set.
|
||||
func TestHostCredentialsHttpFallback(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
path string
|
||||
handler http.HandlerFunc
|
||||
actionUnderTest func(ctx context.Context, addr string, insecure bool) error
|
||||
desc string
|
||||
loopback bool
|
||||
insecure bool
|
||||
fallback bool
|
||||
}{
|
||||
{
|
||||
desc: "HostCredentials",
|
||||
path: "/v1/webapi/host/credentials",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI != "/v1/webapi/host/credentials" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(proto.Certs{})
|
||||
},
|
||||
actionUnderTest: func(ctx context.Context, addr string, insecure bool) error {
|
||||
_, err := client.HostCredentials(ctx, addr, insecure, types.RegisterUsingTokenRequest{})
|
||||
return err
|
||||
},
|
||||
desc: "falls back to http if loopback and insecure",
|
||||
loopback: true,
|
||||
insecure: true,
|
||||
fallback: true,
|
||||
},
|
||||
{
|
||||
desc: "does not fall back to http if loopback and secure",
|
||||
loopback: true,
|
||||
insecure: false,
|
||||
fallback: false,
|
||||
},
|
||||
{
|
||||
desc: "does not fall back to http if non-loopback and insecure",
|
||||
loopback: false,
|
||||
insecure: true,
|
||||
fallback: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.desc, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
for _, tc := range testCases {
|
||||
// Start an http server (not https) so that the request only succeeds
|
||||
// if the fallback occurs.
|
||||
var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI != "/v1/webapi/host/credentials" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(proto.Certs{})
|
||||
}
|
||||
httpSvr, err := newServer(handler, tc.loopback)
|
||||
require.NoError(t, err)
|
||||
defer httpSvr.Close()
|
||||
|
||||
t.Run("Allowed on insecure & loopback", func(t *testing.T) {
|
||||
httpSvr := httptest.NewServer(testCase.handler)
|
||||
defer httpSvr.Close()
|
||||
// Send the HostCredentials request.
|
||||
ctx := context.Background()
|
||||
_, err = client.HostCredentials(ctx, httpSvr.Listener.Addr().String(), tc.insecure, types.RegisterUsingTokenRequest{})
|
||||
|
||||
err := testCase.actionUnderTest(ctx, httpSvr.Listener.Addr().String(), true /* insecure */)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Denied on secure", func(t *testing.T) {
|
||||
httpSvr := httptest.NewServer(testCase.handler)
|
||||
defer httpSvr.Close()
|
||||
|
||||
err := testCase.actionUnderTest(ctx, httpSvr.Listener.Addr().String(), false /* secure */)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Denied on non-loopback", func(t *testing.T) {
|
||||
nonLoopbackSvr := httptest.NewUnstartedServer(testCase.handler)
|
||||
|
||||
// replace the test-supplied loopback listener with the first available
|
||||
// non-loopback address
|
||||
nonLoopbackSvr.Listener.Close()
|
||||
l, err := net.Listen("tcp", "0.0.0.0:0")
|
||||
require.NoError(t, err)
|
||||
nonLoopbackSvr.Listener = l
|
||||
nonLoopbackSvr.Start()
|
||||
defer nonLoopbackSvr.Close()
|
||||
|
||||
err = testCase.actionUnderTest(ctx, nonLoopbackSvr.Listener.Addr().String(), true /* insecure */)
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
// If it should fallback, then no error should occur
|
||||
// as the request will hit the running http server.
|
||||
if tc.fallback {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newServer starts a new server that uses a loopback listener if `loopback`.
|
||||
func newServer(handler http.HandlerFunc, loopback bool) (*httptest.Server, error) {
|
||||
srv := httptest.NewUnstartedServer(handler)
|
||||
|
||||
if !loopback {
|
||||
// Replace the test-supplied loopback listener with the first available
|
||||
// non-loopback address.
|
||||
srv.Listener.Close()
|
||||
l, err := net.Listen("tcp", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
srv.Listener = l
|
||||
}
|
||||
|
||||
srv.Start()
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
func TestSSHAgentPasswordlessLogin(t *testing.T) {
|
||||
silenceLogger(t)
|
||||
|
||||
|
|
Loading…
Reference in a new issue