More review additions

This commit is contained in:
Trent Clarke 2021-05-04 18:01:49 +10:00 committed by Russell Jones
parent 03ae893846
commit b4c3b16d03
5 changed files with 105 additions and 49 deletions

View file

@ -811,11 +811,6 @@ type ParsedProxyHost struct {
// itself.
UsingDefaultWebProxyPort bool
WebProxyAddr string
// UsingDefaultSSHProxyPort means that the port in SSHProxyAddr was
// supplied by ParseProxyHost function rather than ProxyHost string
// itself.
UsingDefaultSSHProxyPort bool
SSHProxyAddr string
}
@ -838,7 +833,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) {
// set the default values of the port strings. One, both, or neither may
// be overridden by the port string parsing below.
usingDefaultWebProxyPort := true
usingDefaultSSHProxyPort := true
webPort := strconv.Itoa(defaults.HTTPListenPort)
sshPort := strconv.Itoa(defaults.SSHProxyListenPort)
@ -866,7 +860,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) {
}
if text := strings.TrimSpace(parts[1]); len(text) > 0 {
sshPort = text
usingDefaultSSHProxyPort = false
}
default:
@ -877,8 +870,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) {
Host: host,
UsingDefaultWebProxyPort: usingDefaultWebProxyPort,
WebProxyAddr: net.JoinHostPort(host, webPort),
UsingDefaultSSHProxyPort: usingDefaultSSHProxyPort,
SSHProxyAddr: net.JoinHostPort(host, sshPort),
}
return result, nil

View file

@ -58,7 +58,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: true,
UsingDefaultSSHProxyPort: true,
WebProxyAddr: "example.org:3080",
SSHProxyAddr: "example.org:3023",
},
@ -69,7 +68,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: false,
UsingDefaultSSHProxyPort: true,
WebProxyAddr: "example.org:1234",
SSHProxyAddr: "example.org:3023",
},
@ -80,7 +78,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: false,
UsingDefaultSSHProxyPort: true,
WebProxyAddr: "example.org:1234",
SSHProxyAddr: "example.org:3023",
},
@ -91,7 +88,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: true,
UsingDefaultSSHProxyPort: false,
WebProxyAddr: "example.org:3080",
SSHProxyAddr: "example.org:200",
},
@ -102,7 +98,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: true,
UsingDefaultSSHProxyPort: false,
WebProxyAddr: "example.org:3080",
SSHProxyAddr: "example.org:200",
},
@ -113,7 +108,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: false,
UsingDefaultSSHProxyPort: true,
WebProxyAddr: "example.org:100",
SSHProxyAddr: "example.org:3023",
},
@ -124,7 +118,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: false,
UsingDefaultSSHProxyPort: false,
WebProxyAddr: "example.org:100",
SSHProxyAddr: "example.org:200",
},
@ -135,7 +128,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: false,
UsingDefaultSSHProxyPort: true,
WebProxyAddr: "example.org:100",
SSHProxyAddr: "example.org:3023",
},
@ -146,7 +138,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: false,
UsingDefaultSSHProxyPort: false,
WebProxyAddr: "example.org:100",
SSHProxyAddr: "example.org:200",
},
@ -157,7 +148,6 @@ func TestParseProxyHostString(t *testing.T) {
expect: ParsedProxyHost{
Host: "example.org",
UsingDefaultWebProxyPort: true,
UsingDefaultSSHProxyPort: true,
WebProxyAddr: "example.org:3080",
SSHProxyAddr: "example.org:3023",
},
@ -183,7 +173,6 @@ func TestParseProxyHostString(t *testing.T) {
require.NoError(t, err)
require.Equal(t, expected.Host, actual.Host)
require.Equal(t, expected.UsingDefaultWebProxyPort, actual.UsingDefaultWebProxyPort)
require.Equal(t, expected.UsingDefaultSSHProxyPort, actual.UsingDefaultSSHProxyPort)
require.Equal(t, expected.WebProxyAddr, actual.WebProxyAddr)
require.Equal(t, expected.SSHProxyAddr, actual.SSHProxyAddr)
})

View file

@ -20,12 +20,16 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
"sync"
"time"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
)
type raceResult struct {
@ -33,36 +37,53 @@ type raceResult struct {
err error
}
// nonOKResponseError indicates that the racer made contact with a server &
// issued a request but received a non-OK response. This is still
// considered a failure by the port resolution algorithm.
type nonOKResponseError struct {
Status int
}
// logResponseBody reads and dumps a response body to the log at the supplied
// level. Note that it is still the caller's responsibility to close the body
// stream.
func logResponseBody(level logrus.Level, bodyStream io.Reader) {
if log.Logger.Level < level {
return
}
func (err nonOKResponseError) Error() string {
return fmt.Sprintf("Non-OK response status: %03d", err.Status)
// NB: `ReadAll()` will time out (or be cancelled) according to the
// context originally supplied to the request that initiated this
// response, so no need to have an independent reading timeout
// here.
body, err := ioutil.ReadAll(bodyStream)
if err != nil {
// This is only for debugging purposes, so it's safe to just give up here.
log.WithError(err).Debug("Failed to read body stream")
return
}
log.Logf(level, "Response body: %q", body)
}
// raceRequest drives an HTTP request to completion and posts the results back
// to the supplied channel.
func raceRequest(ctx context.Context, cli *http.Client, addr string, results chan<- raceResult) {
target := fmt.Sprintf("https://%s/", addr)
func raceRequest(ctx context.Context, cli *http.Client, addr string, waitgroup *sync.WaitGroup, results chan<- raceResult) {
defer waitgroup.Done()
target := fmt.Sprintf("https://%s/webapi/ping", addr)
request, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
if err == nil {
var rsp *http.Response
rsp, err = cli.Do(request)
if err == nil {
rsp.Body.Close()
defer rsp.Body.Close()
// If the request returned a non-OK response then we're still going
// to treat this as a failure and return an error to the race
// aggregator.
if rsp.StatusCode != http.StatusOK {
err = nonOKResponseError{Status: rsp.StatusCode}
rsp = nil
log.Debugf("Racer received non-OK response: %03d", rsp.StatusCode)
logResponseBody(logrus.DebugLevel, rsp.Body)
err = trace.BadParameter("Received non-ok response: %03d", rsp.StatusCode)
}
} else {
log.WithError(err).Debug("Race request failed")
}
}
@ -72,11 +93,12 @@ func raceRequest(ctx context.Context, cli *http.Client, addr string, results cha
// startRacer starts the asynchronous execution of a single request, and keeps
// all the associated bookeeping up to date.
func startRacer(ctx context.Context, cli *http.Client, host string, candidates []int, results chan<- raceResult) []int {
func startRacer(ctx context.Context, cli *http.Client, host string, candidates []int, waitGroup *sync.WaitGroup, results chan<- raceResult) []int {
port, tail := candidates[0], candidates[1:]
addr := net.JoinHostPort(host, strconv.Itoa(port))
log.Debugf("Trying %s...", addr)
go raceRequest(ctx, cli, addr, results)
waitGroup.Add(1)
go raceRequest(ctx, cli, addr, waitGroup, results)
return tail
}
@ -100,10 +122,23 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
},
}
// NOTE: We rely on a specific order of deferred function execution in
// order not to deadlock as we exit this function.
// Please be careful when moving chunks around.
// Make sure all of our live goroutines have quit before we return. This is
// mainly for testing, so we can assert that the racers are all exiting
// properly in error conditions.
var racersInFlight sync.WaitGroup
defer func() {
log.Debug("Waiting for all in-flight racers to finish")
racersInFlight.Wait()
}()
// Define an inner context that we'll give to the requests to cancel
// them regardless of if we exit successfully or we are killed from above.
raceCtx, cancel := context.WithCancel(ctx)
defer cancel()
raceCtx, cancelRace := context.WithCancel(ctx)
defer cancelRace()
// Make the channel for the race results big enough so we're guaranteed that a
// channel write will never block. Once we have a hit we will stop reading the
@ -113,12 +148,12 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
results := make(chan raceResult, len(candidates))
// Start the first attempt racing
outstandingRacers := len(candidates)
candidates = startRacer(raceCtx, httpClient, host, candidates, results)
unfinishedRacers := len(candidates)
candidates = startRacer(raceCtx, httpClient, host, candidates, &racersInFlight, results)
// Start a ticker that will kick off the subsequent racers after a small
// interval. We don't want to start them all at once, as we may swamp the
// network and give away advantage we have from doing this concurrently in
// network and give away any advantage we have from doing this concurrently in
// the first place. RFC8305 recommends an interval of between 100ms and 2s,
// with 250ms being a "sensible default"
ticker := time.NewTicker(250 * time.Millisecond)
@ -134,11 +169,11 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
case <-ticker.C:
// It's time to kick off a new racer
if len(candidates) > 0 {
candidates = startRacer(raceCtx, httpClient, host, candidates, results)
candidates = startRacer(raceCtx, httpClient, host, candidates, &racersInFlight, results)
}
case r := <-results:
outstandingRacers--
unfinishedRacers--
// if the request succeeded, it wins the race
if r.err == nil {
@ -154,7 +189,7 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
// the ping failed. This could be for any number of reasons. All we
// really care about is whether _all_ of the ping attempts have
// failed and it's time to return with error
if outstandingRacers == 0 {
if unfinishedRacers == 0 {
// Context errors like cancellation or timeout take precedence over any
// underlying HTTP errors, as the caller is expected to interrogate them
// to decide what it should do next. This is not so much the case for other

View file

@ -33,8 +33,8 @@ import (
var testLog = log.WithField(trace.Component, "test")
func newWaitForeverHandler() (http.Handler, chan interface{}) {
doneChannel := make(chan interface{})
func newWaitForeverHandler() (http.Handler, chan struct{}) {
doneChannel := make(chan struct{})
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
testLog.Debug("Waiting forever...")
<-doneChannel
@ -176,7 +176,7 @@ func TestResolveDefaultAddrTimeout(t *testing.T) {
func TestResolveNonOKResponseIsAnError(t *testing.T) {
t.Parallel()
// Given a single candidate servers configured to respond with a non-OK status
// Given a single candidate server configured to respond with a non-OK status
// code
servers := []*httptest.Server{
makeTestServer(t, newRespondingHandlerWithStatus(http.StatusTeapot)),
@ -188,7 +188,48 @@ func TestResolveNonOKResponseIsAnError(t *testing.T) {
// Expect that the resolution fails because the server responded with a non-OK
// response
require.ErrorIs(t, err, nonOKResponseError{Status: http.StatusTeapot})
require.Error(t, err)
}
func TestResolveUndeliveredBodyDoesNotBlockForever(t *testing.T) {
t.Parallel()
// Given a single candidate server configured to respond with a non-OK status
// code and a looooong, streaming body that never arrives
doneChannel := make(chan struct{})
defer close(doneChannel)
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
f, ok := w.(http.Flusher)
if !ok {
testLog.Error("ResponseWriter must also be a Flusher, or the test is invalid")
t.Fatal()
}
testLog.Debugf("Writing response header to %T", w)
w.Header().Set("Content-Length", "1048576")
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusTeapot)
w.Write([]byte("I'm a little teapot, short and stout."))
f.Flush()
testLog.Debug("Waiting forever instead of sending response body")
<-doneChannel
testLog.Debug("Exiting handler")
})
servers := []*httptest.Server{makeTestServer(t, handler)}
ports := mustGetCandidatePorts(servers)
// When I attempt to resolve a default address
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := pickDefaultAddr(ctx, true, "127.0.0.1", ports)
// Expect that the resolution fails with a context timeout
require.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestResolveDefaultAddrTimeoutBeforeAllRacersLaunched(t *testing.T) {

View file

@ -1828,6 +1828,7 @@ func setClientWebProxyAddr(cf *CLIConf, c *client.Config) error {
proxyAddress := parsedAddrs.WebProxyAddr
if parsedAddrs.UsingDefaultWebProxyPort {
log.Debug("Web proxy port was not set. Attempting to detect port number to use.")
timeout, cancel := context.WithTimeout(context.Background(), proxyDefaultResolutionTimeout)
defer cancel()
@ -1836,8 +1837,7 @@ func setClientWebProxyAddr(cf *CLIConf, c *client.Config) error {
// On error, fall back to the legacy behaviour
if err != nil {
log.Debugf("Proxy port resolution failed: %v", err)
log.Debug("Falling back to legacy default")
log.WithError(err).Debug("Proxy port resolution failed, falling back to legacy default.")
return c.ParseProxyHost(cf.Proxy)
}
}