mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 17:53:28 +00:00
More review additions
This commit is contained in:
parent
03ae893846
commit
b4c3b16d03
|
@ -811,11 +811,6 @@ type ParsedProxyHost struct {
|
|||
// itself.
|
||||
UsingDefaultWebProxyPort bool
|
||||
WebProxyAddr string
|
||||
|
||||
// UsingDefaultSSHProxyPort means that the port in SSHProxyAddr was
|
||||
// supplied by ParseProxyHost function rather than ProxyHost string
|
||||
// itself.
|
||||
UsingDefaultSSHProxyPort bool
|
||||
SSHProxyAddr string
|
||||
}
|
||||
|
||||
|
@ -838,7 +833,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) {
|
|||
// set the default values of the port strings. One, both, or neither may
|
||||
// be overridden by the port string parsing below.
|
||||
usingDefaultWebProxyPort := true
|
||||
usingDefaultSSHProxyPort := true
|
||||
webPort := strconv.Itoa(defaults.HTTPListenPort)
|
||||
sshPort := strconv.Itoa(defaults.SSHProxyListenPort)
|
||||
|
||||
|
@ -866,7 +860,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) {
|
|||
}
|
||||
if text := strings.TrimSpace(parts[1]); len(text) > 0 {
|
||||
sshPort = text
|
||||
usingDefaultSSHProxyPort = false
|
||||
}
|
||||
|
||||
default:
|
||||
|
@ -877,8 +870,6 @@ func ParseProxyHost(proxyHost string) (*ParsedProxyHost, error) {
|
|||
Host: host,
|
||||
UsingDefaultWebProxyPort: usingDefaultWebProxyPort,
|
||||
WebProxyAddr: net.JoinHostPort(host, webPort),
|
||||
|
||||
UsingDefaultSSHProxyPort: usingDefaultSSHProxyPort,
|
||||
SSHProxyAddr: net.JoinHostPort(host, sshPort),
|
||||
}
|
||||
return result, nil
|
||||
|
|
|
@ -58,7 +58,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: true,
|
||||
UsingDefaultSSHProxyPort: true,
|
||||
WebProxyAddr: "example.org:3080",
|
||||
SSHProxyAddr: "example.org:3023",
|
||||
},
|
||||
|
@ -69,7 +68,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: false,
|
||||
UsingDefaultSSHProxyPort: true,
|
||||
WebProxyAddr: "example.org:1234",
|
||||
SSHProxyAddr: "example.org:3023",
|
||||
},
|
||||
|
@ -80,7 +78,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: false,
|
||||
UsingDefaultSSHProxyPort: true,
|
||||
WebProxyAddr: "example.org:1234",
|
||||
SSHProxyAddr: "example.org:3023",
|
||||
},
|
||||
|
@ -91,7 +88,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: true,
|
||||
UsingDefaultSSHProxyPort: false,
|
||||
WebProxyAddr: "example.org:3080",
|
||||
SSHProxyAddr: "example.org:200",
|
||||
},
|
||||
|
@ -102,7 +98,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: true,
|
||||
UsingDefaultSSHProxyPort: false,
|
||||
WebProxyAddr: "example.org:3080",
|
||||
SSHProxyAddr: "example.org:200",
|
||||
},
|
||||
|
@ -113,7 +108,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: false,
|
||||
UsingDefaultSSHProxyPort: true,
|
||||
WebProxyAddr: "example.org:100",
|
||||
SSHProxyAddr: "example.org:3023",
|
||||
},
|
||||
|
@ -124,7 +118,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: false,
|
||||
UsingDefaultSSHProxyPort: false,
|
||||
WebProxyAddr: "example.org:100",
|
||||
SSHProxyAddr: "example.org:200",
|
||||
},
|
||||
|
@ -135,7 +128,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: false,
|
||||
UsingDefaultSSHProxyPort: true,
|
||||
WebProxyAddr: "example.org:100",
|
||||
SSHProxyAddr: "example.org:3023",
|
||||
},
|
||||
|
@ -146,7 +138,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: false,
|
||||
UsingDefaultSSHProxyPort: false,
|
||||
WebProxyAddr: "example.org:100",
|
||||
SSHProxyAddr: "example.org:200",
|
||||
},
|
||||
|
@ -157,7 +148,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
expect: ParsedProxyHost{
|
||||
Host: "example.org",
|
||||
UsingDefaultWebProxyPort: true,
|
||||
UsingDefaultSSHProxyPort: true,
|
||||
WebProxyAddr: "example.org:3080",
|
||||
SSHProxyAddr: "example.org:3023",
|
||||
},
|
||||
|
@ -183,7 +173,6 @@ func TestParseProxyHostString(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, expected.Host, actual.Host)
|
||||
require.Equal(t, expected.UsingDefaultWebProxyPort, actual.UsingDefaultWebProxyPort)
|
||||
require.Equal(t, expected.UsingDefaultSSHProxyPort, actual.UsingDefaultSSHProxyPort)
|
||||
require.Equal(t, expected.WebProxyAddr, actual.WebProxyAddr)
|
||||
require.Equal(t, expected.SSHProxyAddr, actual.SSHProxyAddr)
|
||||
})
|
||||
|
|
|
@ -20,12 +20,16 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type raceResult struct {
|
||||
|
@ -33,36 +37,53 @@ type raceResult struct {
|
|||
err error
|
||||
}
|
||||
|
||||
// nonOKResponseError indicates that the racer made contact with a server &
|
||||
// issued a request but received a non-OK response. This is still
|
||||
// considered a failure by the port resolution algorithm.
|
||||
type nonOKResponseError struct {
|
||||
Status int
|
||||
}
|
||||
// logResponseBody reads and dumps a response body to the log at the supplied
|
||||
// level. Note that it is still the caller's responsibility to close the body
|
||||
// stream.
|
||||
func logResponseBody(level logrus.Level, bodyStream io.Reader) {
|
||||
if log.Logger.Level < level {
|
||||
return
|
||||
}
|
||||
|
||||
func (err nonOKResponseError) Error() string {
|
||||
return fmt.Sprintf("Non-OK response status: %03d", err.Status)
|
||||
// NB: `ReadAll()` will time out (or be cancelled) according to the
|
||||
// context originally supplied to the request that initiated this
|
||||
// response, so no need to have an independent reading timeout
|
||||
// here.
|
||||
body, err := ioutil.ReadAll(bodyStream)
|
||||
if err != nil {
|
||||
// This is only for debugging purposes, so it's safe to just give up here.
|
||||
log.WithError(err).Debug("Failed to read body stream")
|
||||
return
|
||||
}
|
||||
|
||||
log.Logf(level, "Response body: %q", body)
|
||||
}
|
||||
|
||||
// raceRequest drives an HTTP request to completion and posts the results back
|
||||
// to the supplied channel.
|
||||
func raceRequest(ctx context.Context, cli *http.Client, addr string, results chan<- raceResult) {
|
||||
target := fmt.Sprintf("https://%s/", addr)
|
||||
func raceRequest(ctx context.Context, cli *http.Client, addr string, waitgroup *sync.WaitGroup, results chan<- raceResult) {
|
||||
defer waitgroup.Done()
|
||||
|
||||
target := fmt.Sprintf("https://%s/webapi/ping", addr)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
|
||||
|
||||
if err == nil {
|
||||
var rsp *http.Response
|
||||
rsp, err = cli.Do(request)
|
||||
if err == nil {
|
||||
rsp.Body.Close()
|
||||
defer rsp.Body.Close()
|
||||
|
||||
// If the request returned a non-OK response then we're still going
|
||||
// to treat this as a failure and return an error to the race
|
||||
// aggregator.
|
||||
if rsp.StatusCode != http.StatusOK {
|
||||
err = nonOKResponseError{Status: rsp.StatusCode}
|
||||
rsp = nil
|
||||
log.Debugf("Racer received non-OK response: %03d", rsp.StatusCode)
|
||||
logResponseBody(logrus.DebugLevel, rsp.Body)
|
||||
|
||||
err = trace.BadParameter("Received non-ok response: %03d", rsp.StatusCode)
|
||||
}
|
||||
} else {
|
||||
log.WithError(err).Debug("Race request failed")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,11 +93,12 @@ func raceRequest(ctx context.Context, cli *http.Client, addr string, results cha
|
|||
|
||||
// startRacer starts the asynchronous execution of a single request, and keeps
|
||||
// all the associated bookeeping up to date.
|
||||
func startRacer(ctx context.Context, cli *http.Client, host string, candidates []int, results chan<- raceResult) []int {
|
||||
func startRacer(ctx context.Context, cli *http.Client, host string, candidates []int, waitGroup *sync.WaitGroup, results chan<- raceResult) []int {
|
||||
port, tail := candidates[0], candidates[1:]
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
log.Debugf("Trying %s...", addr)
|
||||
go raceRequest(ctx, cli, addr, results)
|
||||
waitGroup.Add(1)
|
||||
go raceRequest(ctx, cli, addr, waitGroup, results)
|
||||
return tail
|
||||
}
|
||||
|
||||
|
@ -100,10 +122,23 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
|
|||
},
|
||||
}
|
||||
|
||||
// NOTE: We rely on a specific order of deferred function execution in
|
||||
// order not to deadlock as we exit this function.
|
||||
// Please be careful when moving chunks around.
|
||||
|
||||
// Make sure all of our live goroutines have quit before we return. This is
|
||||
// mainly for testing, so we can assert that the racers are all exiting
|
||||
// properly in error conditions.
|
||||
var racersInFlight sync.WaitGroup
|
||||
defer func() {
|
||||
log.Debug("Waiting for all in-flight racers to finish")
|
||||
racersInFlight.Wait()
|
||||
}()
|
||||
|
||||
// Define an inner context that we'll give to the requests to cancel
|
||||
// them regardless of if we exit successfully or we are killed from above.
|
||||
raceCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
raceCtx, cancelRace := context.WithCancel(ctx)
|
||||
defer cancelRace()
|
||||
|
||||
// Make the channel for the race results big enough so we're guaranteed that a
|
||||
// channel write will never block. Once we have a hit we will stop reading the
|
||||
|
@ -113,12 +148,12 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
|
|||
results := make(chan raceResult, len(candidates))
|
||||
|
||||
// Start the first attempt racing
|
||||
outstandingRacers := len(candidates)
|
||||
candidates = startRacer(raceCtx, httpClient, host, candidates, results)
|
||||
unfinishedRacers := len(candidates)
|
||||
candidates = startRacer(raceCtx, httpClient, host, candidates, &racersInFlight, results)
|
||||
|
||||
// Start a ticker that will kick off the subsequent racers after a small
|
||||
// interval. We don't want to start them all at once, as we may swamp the
|
||||
// network and give away advantage we have from doing this concurrently in
|
||||
// network and give away any advantage we have from doing this concurrently in
|
||||
// the first place. RFC8305 recommends an interval of between 100ms and 2s,
|
||||
// with 250ms being a "sensible default"
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
|
@ -134,11 +169,11 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
|
|||
case <-ticker.C:
|
||||
// It's time to kick off a new racer
|
||||
if len(candidates) > 0 {
|
||||
candidates = startRacer(raceCtx, httpClient, host, candidates, results)
|
||||
candidates = startRacer(raceCtx, httpClient, host, candidates, &racersInFlight, results)
|
||||
}
|
||||
|
||||
case r := <-results:
|
||||
outstandingRacers--
|
||||
unfinishedRacers--
|
||||
|
||||
// if the request succeeded, it wins the race
|
||||
if r.err == nil {
|
||||
|
@ -154,7 +189,7 @@ func pickDefaultAddr(ctx context.Context, insecure bool, host string, ports []in
|
|||
// the ping failed. This could be for any number of reasons. All we
|
||||
// really care about is whether _all_ of the ping attempts have
|
||||
// failed and it's time to return with error
|
||||
if outstandingRacers == 0 {
|
||||
if unfinishedRacers == 0 {
|
||||
// Context errors like cancellation or timeout take precedence over any
|
||||
// underlying HTTP errors, as the caller is expected to interrogate them
|
||||
// to decide what it should do next. This is not so much the case for other
|
||||
|
|
|
@ -33,8 +33,8 @@ import (
|
|||
|
||||
var testLog = log.WithField(trace.Component, "test")
|
||||
|
||||
func newWaitForeverHandler() (http.Handler, chan interface{}) {
|
||||
doneChannel := make(chan interface{})
|
||||
func newWaitForeverHandler() (http.Handler, chan struct{}) {
|
||||
doneChannel := make(chan struct{})
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
testLog.Debug("Waiting forever...")
|
||||
<-doneChannel
|
||||
|
@ -176,7 +176,7 @@ func TestResolveDefaultAddrTimeout(t *testing.T) {
|
|||
func TestResolveNonOKResponseIsAnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given a single candidate servers configured to respond with a non-OK status
|
||||
// Given a single candidate server configured to respond with a non-OK status
|
||||
// code
|
||||
servers := []*httptest.Server{
|
||||
makeTestServer(t, newRespondingHandlerWithStatus(http.StatusTeapot)),
|
||||
|
@ -188,7 +188,48 @@ func TestResolveNonOKResponseIsAnError(t *testing.T) {
|
|||
|
||||
// Expect that the resolution fails because the server responded with a non-OK
|
||||
// response
|
||||
require.ErrorIs(t, err, nonOKResponseError{Status: http.StatusTeapot})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestResolveUndeliveredBodyDoesNotBlockForever(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given a single candidate server configured to respond with a non-OK status
|
||||
// code and a looooong, streaming body that never arrives
|
||||
doneChannel := make(chan struct{})
|
||||
defer close(doneChannel)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
f, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
testLog.Error("ResponseWriter must also be a Flusher, or the test is invalid")
|
||||
t.Fatal()
|
||||
}
|
||||
|
||||
testLog.Debugf("Writing response header to %T", w)
|
||||
w.Header().Set("Content-Length", "1048576")
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
|
||||
w.Write([]byte("I'm a little teapot, short and stout."))
|
||||
f.Flush()
|
||||
|
||||
testLog.Debug("Waiting forever instead of sending response body")
|
||||
<-doneChannel
|
||||
|
||||
testLog.Debug("Exiting handler")
|
||||
})
|
||||
|
||||
servers := []*httptest.Server{makeTestServer(t, handler)}
|
||||
ports := mustGetCandidatePorts(servers)
|
||||
|
||||
// When I attempt to resolve a default address
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
_, err := pickDefaultAddr(ctx, true, "127.0.0.1", ports)
|
||||
|
||||
// Expect that the resolution fails with a context timeout
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestResolveDefaultAddrTimeoutBeforeAllRacersLaunched(t *testing.T) {
|
||||
|
|
|
@ -1828,6 +1828,7 @@ func setClientWebProxyAddr(cf *CLIConf, c *client.Config) error {
|
|||
|
||||
proxyAddress := parsedAddrs.WebProxyAddr
|
||||
if parsedAddrs.UsingDefaultWebProxyPort {
|
||||
log.Debug("Web proxy port was not set. Attempting to detect port number to use.")
|
||||
timeout, cancel := context.WithTimeout(context.Background(), proxyDefaultResolutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
@ -1836,8 +1837,7 @@ func setClientWebProxyAddr(cf *CLIConf, c *client.Config) error {
|
|||
|
||||
// On error, fall back to the legacy behaviour
|
||||
if err != nil {
|
||||
log.Debugf("Proxy port resolution failed: %v", err)
|
||||
log.Debug("Falling back to legacy default")
|
||||
log.WithError(err).Debug("Proxy port resolution failed, falling back to legacy default.")
|
||||
return c.ParseProxyHost(cf.Proxy)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue