diff --git a/integration/integration_test.go b/integration/integration_test.go index b284e56c85a..8531a467688 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -191,6 +191,8 @@ func TestIntegrations(t *testing.T) { t.Run("RotateSuccess", suite.bind(testRotateSuccess)) t.Run("RotateTrustedClusters", suite.bind(testRotateTrustedClusters)) t.Run("SessionStartContainsAccessRequest", suite.bind(testSessionStartContainsAccessRequest)) + t.Run("SessionStreaming", suite.bind(testSessionStreaming)) + t.Run("SSHExitCode", suite.bind(testSSHExitCode)) t.Run("Shutdown", suite.bind(testShutdown)) t.Run("TrustedClusters", suite.bind(testTrustedClusters)) t.Run("TrustedClustersWithLabels", suite.bind(testTrustedClustersWithLabels)) @@ -199,7 +201,6 @@ func TestIntegrations(t *testing.T) { t.Run("TwoClustersTunnel", suite.bind(testTwoClustersTunnel)) t.Run("UUIDBasedProxy", suite.bind(testUUIDBasedProxy)) t.Run("WindowChange", suite.bind(testWindowChange)) - t.Run("SessionStreaming", suite.bind(testSessionStreaming)) } // testAuditOn creates a live session, records a bunch of data through it @@ -1206,10 +1207,11 @@ func runDisconnectTest(t *testing.T, suite *integrationTestSuite, tc disconnectT return default: } + if tc.assertExpected != nil { tc.assertExpected(t, err) - } else if err != nil && !trace.IsEOF(err) { - require.FailNowf(t, "Missing EOF", "expected EOF or nil, got %v instead", err) + } else if err != nil && !trace.IsEOF(err) && !isSSHError(err) { + require.FailNowf(t, "Missing EOF", "expected EOF, ExitError, or nil, got %v instead", err) } } @@ -1233,6 +1235,15 @@ func runDisconnectTest(t *testing.T, suite *integrationTestSuite, tc disconnectT } } +func isSSHError(err error) bool { + switch trace.Unwrap(err).(type) { + case *ssh.ExitError, *ssh.ExitMissingError: + return true + default: + return false + } +} + func timeNow() string { return time.Now().Format(time.StampMilli) } @@ -3375,7 +3386,9 @@ func testPAM(t *testing.T, suite *integrationTestSuite) { termSession.Type("\aecho hi\n\r\aexit\n\r\a") err = cl.SSH(context.TODO(), []string{}, false) - require.NoError(t, err) + if !isSSHError(err) { + require.NoError(t, err) + } cancel() }() @@ -4172,7 +4185,9 @@ func testWindowChange(t *testing.T, suite *integrationTestSuite) { cl.Stdin = personA err = cl.SSH(context.TODO(), []string{}, false) - require.NoError(t, err) + if !isSSHError(err) { + require.NoError(t, err) + } } // joinSession will join the existing session on a server. @@ -4206,10 +4221,12 @@ func testWindowChange(t *testing.T, suite *integrationTestSuite) { for i := 0; i < 10; i++ { err = cl.Join(context.TODO(), apidefaults.Namespace, session.ID(sessionID), personB) - if err == nil { + if err == nil || isSSHError(err) { + err = nil break } } + require.NoError(t, err) } @@ -4766,6 +4783,132 @@ func testBPFExec(t *testing.T, suite *integrationTestSuite) { } } +func testSSHExitCode(t *testing.T, suite *integrationTestSuite) { + lsPath, err := exec.LookPath("ls") + require.NoError(t, err) + + var tests = []struct { + desc string + command []string + input string + interactive bool + errorAssertion require.ErrorAssertionFunc + statusCode int + }{ + // A successful noninteractive session should have a zero status code + { + desc: "Run Command and Exit Successfully", + command: []string{lsPath}, + interactive: false, + errorAssertion: require.NoError, + }, + // A failed noninteractive session should have a non-zero status code + { + desc: "Run Command and Fail With Code 2", + command: []string{"exit 2"}, + interactive: false, + errorAssertion: require.Error, + statusCode: 2, + }, + // A failed interactive session should have a non-zero status code + { + desc: "Run Command Interactively and Fail With Code 2", + command: []string{"exit 2"}, + interactive: true, + errorAssertion: require.Error, + statusCode: 2, + }, + // A failed interactive session should have a non-zero status code + { + desc: "Interactively Fail With Code 3", + input: "exit 3\n\r", + interactive: true, + errorAssertion: require.Error, + statusCode: 3, + }, + // A failed interactive session should have a non-zero status code + { + desc: "Interactively Fail With Code 3", + input: fmt.Sprintf("%v\n\rexit 3\n\r", lsPath), + interactive: true, + errorAssertion: require.Error, + statusCode: 3, + }, + // A successful interactive session should have a zero status code + { + desc: "Interactively Run Command and Exit Successfully", + input: fmt.Sprintf("%v\n\rexit\n\r", lsPath), + interactive: true, + errorAssertion: require.NoError, + }, + // A successful interactive session should have a zero status code + { + desc: "Interactively Exit", + input: "exit\n\r", + interactive: true, + errorAssertion: require.NoError, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + // Create and start a Teleport cluster. + makeConfig := func() (*testing.T, []string, []*InstanceSecrets, *service.Config) { + // Create default config. + tconf := suite.defaultServiceConfig() + + // Configure Auth. + tconf.Auth.Preference.SetSecondFactor("off") + tconf.Auth.Enabled = true + tconf.Auth.NoAudit = true + + // Configure Proxy. + tconf.Proxy.Enabled = true + tconf.Proxy.DisableWebService = false + tconf.Proxy.DisableWebInterface = true + + // Configure Node. + tconf.SSH.Enabled = true + return t, nil, nil, tconf + } + main := suite.newTeleportWithConfig(makeConfig()) + t.Cleanup(func() { main.StopAll() }) + + // context to signal when the client is done with the terminal. + doneContext, doneCancel := context.WithTimeout(context.Background(), time.Second*10) + defer doneCancel() + + cli, err := main.NewClient(t, ClientConfig{ + Login: suite.me.Username, + Cluster: Site, + Host: Host, + Port: main.GetPortSSHInt(), + Interactive: tt.interactive, + }) + require.NoError(t, err) + + if tt.interactive { + // Create a new terminal and connect it to std{in,out} of client. + term := NewTerminal(250) + cli.Stdout = term + cli.Stdin = term + term.Type(tt.input) + } + + // run the ssh command + err = cli.SSH(doneContext, tt.command, false) + tt.errorAssertion(t, err) + + // check that the exit code of the session matches the expected one + if err != nil { + var exitError *ssh.ExitError + require.ErrorAs(t, trace.Unwrap(err), &exitError) + require.Equal(t, tt.statusCode, exitError.ExitStatus()) + } + }) + } +} + // testBPFSessionDifferentiation verifies that the bpf package can // differentiate events from two different sessions. This test in turn also // verifies the cgroup package. diff --git a/lib/client/api.go b/lib/client/api.go index 04a6f227cec..439544850db 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1953,6 +1953,13 @@ func (tc *TeleportClient) runShell(nodeClient *NodeClient, sessToJoin *session.S return trace.Wrap(err) } if err = nodeSession.runShell(tc.OnShellCreated); err != nil { + switch e := trace.Unwrap(err).(type) { + case *ssh.ExitError: + tc.ExitStatus = e.ExitStatus() + case *ssh.ExitMissingError: + tc.ExitStatus = 1 + } + return trace.Wrap(err) } if nodeSession.ExitMsg == "" { diff --git a/lib/client/session.go b/lib/client/session.go index e75ade29947..d5407a7be64 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -265,7 +265,7 @@ func (ns *NodeSession) interactiveSession(callback interactiveCallback) error { } // wait for the session to end <-ns.closer.C - return nil + return sess.Wait() } // allocateTerminal creates (allocates) a server-side terminal for this session.