diff --git a/lib/client/api.go b/lib/client/api.go index dbb4ff7a8b9..862f182dff4 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -991,6 +991,7 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, runLocally tc.ExitStatus = 1 return trace.Wrap(err) } + defer nodeClient.Close() // If forwarding ports were specified, start port forwarding. tc.startPortForwarding(ctx, nodeClient) @@ -1021,8 +1022,10 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, runLocally if len(command) > 0 { if len(nodeAddrs) > 1 { fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes matched label selector, running command on all.") + return tc.runCommandOnNodes(ctx, siteInfo.Name, nodeAddrs, proxyClient, command) } - return tc.runCommand(ctx, siteInfo.Name, nodeAddrs, proxyClient, command) + // Reuse the existing nodeClient we connected above. + return tc.runCommand(ctx, nodeClient, command) } // Issue "shell" request to run single node. @@ -1457,54 +1460,32 @@ func (tc *TeleportClient) ListAllNodes(ctx context.Context) ([]services.Server, return proxyClient.FindServersByLabels(ctx, tc.Namespace, nil) } -// runCommand executes a given bash command on a bunch of remote nodes -func (tc *TeleportClient) runCommand( +// runCommandOnNodes executes a given bash command on a bunch of remote nodes. +func (tc *TeleportClient) runCommandOnNodes( ctx context.Context, siteName string, nodeAddresses []string, proxyClient *ProxyClient, command []string) error { resultsC := make(chan error, len(nodeAddresses)) for _, address := range nodeAddresses { go func(address string) { - var ( - err error - nodeSession *NodeSession - ) + var err error defer func() { resultsC <- err }() + var nodeClient *NodeClient nodeClient, err = proxyClient.ConnectToNode(ctx, NodeAddr{Addr: address, Namespace: tc.Namespace, Cluster: siteName}, tc.Config.HostLogin, false) if err != nil { + // err is passed to resultsC in the defer above. fmt.Fprintln(tc.Stderr, err) return } defer nodeClient.Close() - // run the command on one node: - if len(nodeAddresses) > 1 { - fmt.Printf("Running command on %v:\n", address) - } - nodeSession, err = newSession(nodeClient, nil, tc.Config.Env, tc.Stdin, tc.Stdout, tc.Stderr, tc.useLegacyID(nodeClient), tc.EnableEscapeSequences) - if err != nil { - log.Error(err) - return - } - defer nodeSession.Close() - if err = nodeSession.runCommand(ctx, command, tc.OnShellCreated, tc.Config.Interactive); err != nil { - originErr := trace.Unwrap(err) - exitErr, ok := originErr.(*ssh.ExitError) - if ok { - tc.ExitStatus = exitErr.ExitStatus() - } else { - // if an error occurs, but no exit status is passed back, GoSSH returns - // a generic error like this. in this case the error message is printed - // to stderr by the remote process so we have to quietly return 1: - if strings.Contains(originErr.Error(), "exited without exit status") { - tc.ExitStatus = 1 - } - } - } + fmt.Printf("Running command on %v:\n", address) + err = tc.runCommand(ctx, nodeClient, command) + // err is passed to resultsC in the defer above. }(address) } var lastError error @@ -1516,6 +1497,33 @@ func (tc *TeleportClient) runCommand( return trace.Wrap(lastError) } +// runCommand executes a given bash command on an established NodeClient. +func (tc *TeleportClient) runCommand(ctx context.Context, nodeClient *NodeClient, command []string) error { + nodeSession, err := newSession(nodeClient, nil, tc.Config.Env, tc.Stdin, tc.Stdout, tc.Stderr, tc.useLegacyID(nodeClient), tc.EnableEscapeSequences) + if err != nil { + return trace.Wrap(err) + } + defer nodeSession.Close() + if err := nodeSession.runCommand(ctx, command, tc.OnShellCreated, tc.Config.Interactive); err != nil { + originErr := trace.Unwrap(err) + exitErr, ok := originErr.(*ssh.ExitError) + if ok { + tc.ExitStatus = exitErr.ExitStatus() + } else { + // if an error occurs, but no exit status is passed back, GoSSH returns + // a generic error like this. in this case the error message is printed + // to stderr by the remote process so we have to quietly return 1: + if strings.Contains(originErr.Error(), "exited without exit status") { + tc.ExitStatus = 1 + } + } + + return trace.Wrap(err) + } + + return nil +} + // runShell starts an interactive SSH session/shell. // sessionID : when empty, creates a new shell. otherwise it tries to join the existing session. func (tc *TeleportClient) runShell(nodeClient *NodeClient, sessToJoin *session.Session) error {