From b4c3b16d0363582c4f942fee26a396d68cad3cb9 Mon Sep 17 00:00:00 2001 From: Trent Clarke Date: Tue, 4 May 2021 18:01:49 +1000 Subject: [PATCH] More review additions --- lib/client/api.go | 9 --- lib/client/api_test.go | 11 ---- tool/tsh/resolve_default_addr.go | 81 +++++++++++++++++++-------- tool/tsh/resolve_default_addr_test.go | 49 ++++++++++++++-- tool/tsh/tsh.go | 4 +- 5 files changed, 105 insertions(+), 49 deletions(-) diff --git a/lib/client/api.go b/lib/client/api.go index 0fbcf2f9127..ff65c020ec4 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -811,11 +811,6 @@ type ParsedProxyHost struct { // itself. UsingDefaultWebProxyPort bool WebProxyAddr string - - // UsingDefaultSSHProxyPort means that the port in SSHProxyAddr was - // supplied by ParseProxyHost function rather than ProxyHost string - // itself. - UsingDefaultSSHProxyPort bool SSHProxyAddr string } @@ -838,7 +833,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) { // set the default values of the port strings. One, both, or neither may // be overridden by the port string parsing below. usingDefaultWebProxyPort := true - usingDefaultSSHProxyPort := true webPort := strconv.Itoa(defaults.HTTPListenPort) sshPort := strconv.Itoa(defaults.SSHProxyListenPort) @@ -866,7 +860,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) { } if text := strings.TrimSpace(parts[1]); len(text) > 0 { sshPort = text - usingDefaultSSHProxyPort = false } default: @@ -877,8 +870,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) { Host: host, UsingDefaultWebProxyPort: usingDefaultWebProxyPort, WebProxyAddr: net.JoinHostPort(host, webPort), - - UsingDefaultSSHProxyPort: usingDefaultSSHProxyPort, SSHProxyAddr: net.JoinHostPort(host, sshPort), } return result, nil diff --git a/lib/client/api_test.go b/lib/client/api_test.go index 8aca93a19f4..dd2fde21226 100644 --- a/lib/client/api_test.go +++ b/lib/client/api_test.go @@ -58,7 +58,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: true, - UsingDefaultSSHProxyPort: true, WebProxyAddr: "example.org:3080", SSHProxyAddr: "example.org:3023", }, @@ -69,7 +68,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: false, - UsingDefaultSSHProxyPort: true, WebProxyAddr: "example.org:1234", SSHProxyAddr: "example.org:3023", }, @@ -80,7 +78,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: false, - UsingDefaultSSHProxyPort: true, WebProxyAddr: "example.org:1234", SSHProxyAddr: "example.org:3023", }, @@ -91,7 +88,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: true, - UsingDefaultSSHProxyPort: false, WebProxyAddr: "example.org:3080", SSHProxyAddr: "example.org:200", }, @@ -102,7 +98,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: true, - UsingDefaultSSHProxyPort: false, WebProxyAddr: "example.org:3080", SSHProxyAddr: "example.org:200", }, @@ -113,7 +108,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: false, - UsingDefaultSSHProxyPort: true, WebProxyAddr: "example.org:100", SSHProxyAddr: "example.org:3023", }, @@ -124,7 +118,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: false, - UsingDefaultSSHProxyPort: false, WebProxyAddr: "example.org:100", SSHProxyAddr: "example.org:200", }, @@ -135,7 +128,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: false, - UsingDefaultSSHProxyPort: true, WebProxyAddr: "example.org:100", SSHProxyAddr: "example.org:3023", }, @@ -146,7 +138,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: false, - UsingDefaultSSHProxyPort: false, WebProxyAddr: "example.org:100", SSHProxyAddr: "example.org:200", }, @@ -157,7 +148,6 @@ func TestParseProxyHostString(t *testing.T) { expect: ParsedProxyHost{ Host: "example.org", UsingDefaultWebProxyPort: true, - UsingDefaultSSHProxyPort: true, WebProxyAddr: "example.org:3080", SSHProxyAddr: "example.org:3023", }, @@ -183,7 +173,6 @@ func TestParseProxyHostString(t *testing.T) { require.NoError(t, err) require.Equal(t, expected.Host, actual.Host) require.Equal(t, expected.UsingDefaultWebProxyPort, actual.UsingDefaultWebProxyPort) - require.Equal(t, expected.UsingDefaultSSHProxyPort, actual.UsingDefaultSSHProxyPort) require.Equal(t, expected.WebProxyAddr, actual.WebProxyAddr) require.Equal(t, expected.SSHProxyAddr, actual.SSHProxyAddr) }) diff --git a/tool/tsh/resolve_default_addr.go b/tool/tsh/resolve_default_addr.go index 449b228edf5..373283be4e2 100644 --- a/tool/tsh/resolve_default_addr.go +++ b/tool/tsh/resolve_default_addr.go @@ -20,12 +20,16 @@ import ( "context" "crypto/tls" "fmt" + "io" + "io/ioutil" "net" "net/http" "strconv" + "sync" "time" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" ) type raceResult struct { @@ -33,36 +37,53 @@ type raceResult struct { err error } -// nonOKResponseError indicates that the racer made contact with a server & -// issued a request but received a non-OK response. This is still -// considered a failure by the port resolution algorithm. -type nonOKResponseError struct { - Status int -} +// logResponseBody reads and dumps a response body to the log at the supplied +// level. Note that it is still the caller's responsibility to close the body +// stream. +func logResponseBody(level logrus.Level, bodyStream io.Reader) { + if log.Logger.Level < level { + return + } -func (err nonOKResponseError) Error() string { - return fmt.Sprintf("Non-OK response status: %03d", err.Status) + // NB: `ReadAll()` will time out (or be cancelled) according to the + // context originally supplied to the request that initiated this + // response, so no need to have an independent reading timeout + // here. + body, err := ioutil.ReadAll(bodyStream) + if err != nil { + // This is only for debugging purposes, so it's safe to just give up here. + log.WithError(err).Debug("Failed to read body stream") + return + } + + log.Logf(level, "Response body: %q", body) } // raceRequest drives an HTTP request to completion and posts the results back // to the supplied channel. -func raceRequest(ctx context.Context, cli *http.Client, addr string, results chan<- raceResult) { - target := fmt.Sprintf("https://%s/", addr) +func raceRequest(ctx context.Context, cli *http.Client, addr string, waitgroup *sync.WaitGroup, results chan<- raceResult) { + defer waitgroup.Done() + + target := fmt.Sprintf("https://%s/webapi/ping", addr) request, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil) if err == nil { var rsp *http.Response rsp, err = cli.Do(request) if err == nil { - rsp.Body.Close() + defer rsp.Body.Close() // If the request returned a non-OK response then we're still going // to treat this as a failure and return an error to the race // aggregator. if rsp.StatusCode != http.StatusOK { - err = nonOKResponseError{Status: rsp.StatusCode} - rsp = nil + log.Debugf("Racer received non-OK response: %03d", rsp.StatusCode) + logResponseBody(logrus.DebugLevel, rsp.Body) + + err = trace.BadParameter("Received non-ok response: %03d", rsp.StatusCode) } + } else { + log.WithError(err).Debug("Race request failed") } } @@ -72,11 +93,12 @@ func raceRequest(ctx context.Context, cli *http.Client, addr string, results cha // startRacer starts the asynchronous execution of a single request, and keeps // all the associated bookeeping up to date. -func startRacer(ctx context.Context, cli *http.Client, host string, candidates []int, results chan<- raceResult) []int { +func startRacer(ctx context.Context, cli *http.Client, host string, candidates []int, waitGroup *sync.WaitGroup, results chan<- raceResult) []int { port, tail := candidates[0], candidates[1:] addr := net.JoinHostPort(host, strconv.Itoa(port)) log.Debugf("Trying %s...", addr) - go raceRequest(ctx, cli, addr, results) + waitGroup.Add(1) + go raceRequest(ctx, cli, addr, waitGroup, results) return tail } @@ -100,10 +122,23 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in }, } + // NOTE: We rely on a specific order of deferred function execution in + // order not to deadlock as we exit this function. + // Please be careful when moving chunks around. + + // Make sure all of our live goroutines have quit before we return. This is + // mainly for testing, so we can assert that the racers are all exiting + // properly in error conditions. + var racersInFlight sync.WaitGroup + defer func() { + log.Debug("Waiting for all in-flight racers to finish") + racersInFlight.Wait() + }() + // Define an inner context that we'll give to the requests to cancel // them regardless of if we exit successfully or we are killed from above. - raceCtx, cancel := context.WithCancel(ctx) - defer cancel() + raceCtx, cancelRace := context.WithCancel(ctx) + defer cancelRace() // Make the channel for the race results big enough so we're guaranteed that a // channel write will never block. Once we have a hit we will stop reading the @@ -113,12 +148,12 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in results := make(chan raceResult, len(candidates)) // Start the first attempt racing - outstandingRacers := len(candidates) - candidates = startRacer(raceCtx, httpClient, host, candidates, results) + unfinishedRacers := len(candidates) + candidates = startRacer(raceCtx, httpClient, host, candidates, &racersInFlight, results) // Start a ticker that will kick off the subsequent racers after a small // interval. We don't want to start them all at once, as we may swamp the - // network and give away advantage we have from doing this concurrently in + // network and give away any advantage we have from doing this concurrently in // the first place. RFC8305 recommends an interval of between 100ms and 2s, // with 250ms being a "sensible default" ticker := time.NewTicker(250 * time.Millisecond) @@ -134,11 +169,11 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in case <-ticker.C: // It's time to kick off a new racer if len(candidates) > 0 { - candidates = startRacer(raceCtx, httpClient, host, candidates, results) + candidates = startRacer(raceCtx, httpClient, host, candidates, &racersInFlight, results) } case r := <-results: - outstandingRacers-- + unfinishedRacers-- // if the request succeeded, it wins the race if r.err == nil { @@ -154,7 +189,7 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in // the ping failed. This could be for any number of reasons. All we // really care about is whether _all_ of the ping attempts have // failed and it's time to return with error - if outstandingRacers == 0 { + if unfinishedRacers == 0 { // Context errors like cancellation or timeout take precedence over any // underlying HTTP errors, as the caller is expected to interrogate them // to decide what it should do next. This is not so much the case for other diff --git a/tool/tsh/resolve_default_addr_test.go b/tool/tsh/resolve_default_addr_test.go index c1bbe74aa57..c8786aa901c 100644 --- a/tool/tsh/resolve_default_addr_test.go +++ b/tool/tsh/resolve_default_addr_test.go @@ -33,8 +33,8 @@ import ( var testLog = log.WithField(trace.Component, "test") -func newWaitForeverHandler() (http.Handler, chan interface{}) { - doneChannel := make(chan interface{}) +func newWaitForeverHandler() (http.Handler, chan struct{}) { + doneChannel := make(chan struct{}) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { testLog.Debug("Waiting forever...") <-doneChannel @@ -176,7 +176,7 @@ func TestResolveDefaultAddrTimeout(t *testing.T) { func TestResolveNonOKResponseIsAnError(t *testing.T) { t.Parallel() - // Given a single candidate servers configured to respond with a non-OK status + // Given a single candidate server configured to respond with a non-OK status // code servers := []*httptest.Server{ makeTestServer(t, newRespondingHandlerWithStatus(http.StatusTeapot)), @@ -188,7 +188,48 @@ func TestResolveNonOKResponseIsAnError(t *testing.T) { // Expect that the resolution fails because the server responded with a non-OK // response - require.ErrorIs(t, err, nonOKResponseError{Status: http.StatusTeapot}) + require.Error(t, err) +} + +func TestResolveUndeliveredBodyDoesNotBlockForever(t *testing.T) { + t.Parallel() + + // Given a single candidate server configured to respond with a non-OK status + // code and a looooong, streaming body that never arrives + doneChannel := make(chan struct{}) + defer close(doneChannel) + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + f, ok := w.(http.Flusher) + if !ok { + testLog.Error("ResponseWriter must also be a Flusher, or the test is invalid") + t.Fatal() + } + + testLog.Debugf("Writing response header to %T", w) + w.Header().Set("Content-Length", "1048576") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusTeapot) + + w.Write([]byte("I'm a little teapot, short and stout.")) + f.Flush() + + testLog.Debug("Waiting forever instead of sending response body") + <-doneChannel + + testLog.Debug("Exiting handler") + }) + + servers := []*httptest.Server{makeTestServer(t, handler)} + ports := mustGetCandidatePorts(servers) + + // When I attempt to resolve a default address + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := pickDefaultAddr(ctx, true, "127.0.0.1", ports) + + // Expect that the resolution fails with a context timeout + require.ErrorIs(t, err, context.DeadlineExceeded) } func TestResolveDefaultAddrTimeoutBeforeAllRacersLaunched(t *testing.T) { diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 4ca0f65435f..da1e0cf8ede 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -1828,6 +1828,7 @@ func setClientWebProxyAddr(cf *CLIConf, c *client.Config) error { proxyAddress := parsedAddrs.WebProxyAddr if parsedAddrs.UsingDefaultWebProxyPort { + log.Debug("Web proxy port was not set. Attempting to detect port number to use.") timeout, cancel := context.WithTimeout(context.Background(), proxyDefaultResolutionTimeout) defer cancel() @@ -1836,8 +1837,7 @@ func setClientWebProxyAddr(cf *CLIConf, c *client.Config) error { // On error, fall back to the legacy behaviour if err != nil { - log.Debugf("Proxy port resolution failed: %v", err) - log.Debug("Falling back to legacy default") + log.WithError(err).Debug("Proxy port resolution failed, falling back to legacy default.") return c.ParseProxyHost(cf.Proxy) } }