Reuse a NodeClient when running a single non-shell command

When running `tsh ssh foo@bar cmd` we end up dialing `bar` twice - once to
(maybe) start port forwarding and a second time to execute `cmd`.
Instead, reuse the first connection to run `cmd` and only fall back to
re-dialing if we're matching multiple nodes by label.

This gives ~20-30% speedup for non-interactive commands (useful for
tools like ansible):

```
> hyperfine 'tsh ssh localhost true' '~/src/teleport/build/tsh ssh localhost true'
Benchmark #1: tsh ssh localhost true
  Time (mean ± σ):      65.5 ms ±   5.0 ms    [User: 12.9 ms, System: 6.1 ms]
  Range (min … max):    57.0 ms …  74.2 ms    41 runs

Benchmark #2: ~/src/teleport/build/tsh ssh localhost true
  Time (mean ± σ):      51.7 ms ±   3.2 ms    [User: 9.0 ms, System: 5.0 ms]
  Range (min … max):    48.5 ms …  68.5 ms    57 runs

Summary
  '~/src/teleport/build/tsh ssh localhost true' ran
    1.27 ± 0.12 times faster than 'tsh ssh localhost true'
```
This commit is contained in:
Andrew Lytvynov 2020-07-17 15:45:46 -07:00 committed by Andrew Lytvynov
parent 80dd6cf065
commit 53b1eb4727

View file

@ -991,6 +991,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, runLocally
tc.ExitStatus = 1
return trace.Wrap(err)
}
defer nodeClient.Close()
// If forwarding ports were specified, start port forwarding.
tc.startPortForwarding(ctx, nodeClient)
@ -1021,8 +1022,10 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, runLocally
if len(command) > 0 {
if len(nodeAddrs) > 1 {
fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes matched label selector, running command on all.")
return tc.runCommandOnNodes(ctx, siteInfo.Name, nodeAddrs, proxyClient, command)
}
return tc.runCommand(ctx, siteInfo.Name, nodeAddrs, proxyClient, command)
// Reuse the existing nodeClient we connected above.
return tc.runCommand(ctx, nodeClient, command)
}
// Issue "shell" request to run single node.
@ -1457,54 +1460,32 @@ func (tc *TeleportClient) ListAllNodes(ctx context.Context) ([]services.Server,
return proxyClient.FindServersByLabels(ctx, tc.Namespace, nil)
}
// runCommand executes a given bash command on a bunch of remote nodes
func (tc *TeleportClient) runCommand(
// runCommandOnNodes executes a given bash command on a bunch of remote nodes.
func (tc *TeleportClient) runCommandOnNodes(
ctx context.Context, siteName string, nodeAddresses []string, proxyClient *ProxyClient, command []string) error {
resultsC := make(chan error, len(nodeAddresses))
for _, address := range nodeAddresses {
go func(address string) {
var (
err error
nodeSession *NodeSession
)
var err error
defer func() {
resultsC <- err
}()
var nodeClient *NodeClient
nodeClient, err = proxyClient.ConnectToNode(ctx,
NodeAddr{Addr: address, Namespace: tc.Namespace, Cluster: siteName},
tc.Config.HostLogin, false)
if err != nil {
// err is passed to resultsC in the defer above.
fmt.Fprintln(tc.Stderr, err)
return
}
defer nodeClient.Close()
// run the command on one node:
if len(nodeAddresses) > 1 {
fmt.Printf("Running command on %v:\n", address)
}
nodeSession, err = newSession(nodeClient, nil, tc.Config.Env, tc.Stdin, tc.Stdout, tc.Stderr, tc.useLegacyID(nodeClient), tc.EnableEscapeSequences)
if err != nil {
log.Error(err)
return
}
defer nodeSession.Close()
if err = nodeSession.runCommand(ctx, command, tc.OnShellCreated, tc.Config.Interactive); err != nil {
originErr := trace.Unwrap(err)
exitErr, ok := originErr.(*ssh.ExitError)
if ok {
tc.ExitStatus = exitErr.ExitStatus()
} else {
// if an error occurs, but no exit status is passed back, GoSSH returns
// a generic error like this. in this case the error message is printed
// to stderr by the remote process so we have to quietly return 1:
if strings.Contains(originErr.Error(), "exited without exit status") {
tc.ExitStatus = 1
}
}
}
fmt.Printf("Running command on %v:\n", address)
err = tc.runCommand(ctx, nodeClient, command)
// err is passed to resultsC in the defer above.
}(address)
}
var lastError error
@ -1516,6 +1497,33 @@ func (tc *TeleportClient) runCommand(
return trace.Wrap(lastError)
}
// runCommand executes a given bash command on an established NodeClient.
func (tc *TeleportClient) runCommand(ctx context.Context, nodeClient *NodeClient, command []string) error {
nodeSession, err := newSession(nodeClient, nil, tc.Config.Env, tc.Stdin, tc.Stdout, tc.Stderr, tc.useLegacyID(nodeClient), tc.EnableEscapeSequences)
if err != nil {
return trace.Wrap(err)
}
defer nodeSession.Close()
if err := nodeSession.runCommand(ctx, command, tc.OnShellCreated, tc.Config.Interactive); err != nil {
originErr := trace.Unwrap(err)
exitErr, ok := originErr.(*ssh.ExitError)
if ok {
tc.ExitStatus = exitErr.ExitStatus()
} else {
// if an error occurs, but no exit status is passed back, GoSSH returns
// a generic error like this. in this case the error message is printed
// to stderr by the remote process so we have to quietly return 1:
if strings.Contains(originErr.Error(), "exited without exit status") {
tc.ExitStatus = 1
}
}
return trace.Wrap(err)
}
return nil
}
// runShell starts an interactive SSH session/shell.
// sessionID : when empty, creates a new shell. otherwise it tries to join the existing session.
func (tc *TeleportClient) runShell(nodeClient *NodeClient, sessToJoin *session.Session) error {