mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 08:43:58 +00:00
Node session race (#6195)
* Attempt to isolate and improve state handling of a NodeSession. * Add terminal close for kube terminal tests * Address review comments * Small tweaks Co-authored-by: Andrew Lytvynov <andrew@goteleport.com>
This commit is contained in:
parent
0dbc11d838
commit
4acf50902c
|
@ -423,6 +423,7 @@ func (s *IntSuite) TestAuditOn(c *check.C) {
|
|||
|
||||
// lets type "echo hi" followed by "enter" and then "exit" + "enter":
|
||||
myTerm.Type("\aecho hi\n\r\aexit\n\r\a")
|
||||
myTerm.closeSend()
|
||||
|
||||
// wait for session to end:
|
||||
select {
|
||||
|
@ -867,6 +868,7 @@ func (s *IntSuite) verifySessionJoin(c *check.C, t *TeleInstance) {
|
|||
|
||||
personA := NewTerminal(250)
|
||||
personB := NewTerminal(250)
|
||||
personB.closeSend()
|
||||
|
||||
// PersonA: SSH into the server, wait one second, then type some commands on stdin:
|
||||
openSession := func() {
|
||||
|
@ -904,6 +906,7 @@ func (s *IntSuite) verifySessionJoin(c *check.C, t *TeleInstance) {
|
|||
}
|
||||
}
|
||||
c.Assert(err, check.IsNil)
|
||||
personA.closeSend()
|
||||
}
|
||||
|
||||
go openSession()
|
||||
|
@ -936,9 +939,6 @@ func (s *IntSuite) TestShutdown(c *check.C) {
|
|||
|
||||
person := NewTerminal(250)
|
||||
|
||||
// commandsC receive commands
|
||||
commandsC := make(chan string)
|
||||
|
||||
// PersonA: SSH into the server, wait one second, then type some commands on stdin:
|
||||
openSession := func() {
|
||||
cl, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()})
|
||||
|
@ -946,12 +946,6 @@ func (s *IntSuite) TestShutdown(c *check.C) {
|
|||
cl.Stdout = person
|
||||
cl.Stdin = person
|
||||
|
||||
go func() {
|
||||
for command := range commandsC {
|
||||
person.Type(command)
|
||||
}
|
||||
}()
|
||||
|
||||
err = cl.SSH(context.TODO(), []string{}, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
@ -991,6 +985,7 @@ func (s *IntSuite) TestShutdown(c *check.C) {
|
|||
|
||||
// now type exit and wait for shutdown to complete
|
||||
person.Type("exit\n\r")
|
||||
person.closeSend()
|
||||
|
||||
select {
|
||||
case <-shutdownContext.Done():
|
||||
|
@ -1020,6 +1015,7 @@ func (s *IntSuite) TestDisconnectScenarios(c *check.C) {
|
|||
|
||||
testCases := []disconnectTestCase{
|
||||
{
|
||||
comment: "recording at node",
|
||||
recordingMode: services.RecordAtNode,
|
||||
options: services.RoleOptions{
|
||||
ClientIdleTimeout: services.NewDuration(500 * time.Millisecond),
|
||||
|
@ -1027,6 +1023,7 @@ func (s *IntSuite) TestDisconnectScenarios(c *check.C) {
|
|||
disconnectTimeout: time.Second,
|
||||
},
|
||||
{
|
||||
comment: "recording at proxy",
|
||||
recordingMode: services.RecordAtProxy,
|
||||
options: services.RoleOptions{
|
||||
ForwardAgent: services.NewBool(true),
|
||||
|
@ -1035,6 +1032,7 @@ func (s *IntSuite) TestDisconnectScenarios(c *check.C) {
|
|||
disconnectTimeout: time.Second,
|
||||
},
|
||||
{
|
||||
comment: "recording at node: expired certificate is disconnected",
|
||||
recordingMode: services.RecordAtNode,
|
||||
options: services.RoleOptions{
|
||||
DisconnectExpiredCert: services.NewBool(true),
|
||||
|
@ -1043,6 +1041,7 @@ func (s *IntSuite) TestDisconnectScenarios(c *check.C) {
|
|||
disconnectTimeout: 4 * time.Second,
|
||||
},
|
||||
{
|
||||
comment: "recording at proxy: expired certificate is disconnected and forwarding agent",
|
||||
recordingMode: services.RecordAtProxy,
|
||||
options: services.RoleOptions{
|
||||
ForwardAgent: services.NewBool(true),
|
||||
|
@ -1215,7 +1214,7 @@ func (s *IntSuite) runDisconnectTest(c *check.C, tc disconnectTestCase) {
|
|||
select {
|
||||
case <-time.After(tc.disconnectTimeout + time.Second):
|
||||
dumpGoroutineProfile()
|
||||
c.Fatalf("%s: timeout waiting for session to exit: %+v", timeNow(), tc)
|
||||
c.Fatalf("%s (%s): timeout waiting for session to exit: %+v", timeNow(), tc.comment, tc)
|
||||
case <-ctx.Done():
|
||||
// session closed. a test case is successful if the first
|
||||
// session to close encountered the expected error variant.
|
||||
|
@ -1228,6 +1227,7 @@ func timeNow() string {
|
|||
|
||||
func enterInput(ctx context.Context, c *check.C, person *Terminal, command, pattern string) {
|
||||
person.Type(command)
|
||||
defer person.closeSend()
|
||||
abortTime := time.Now().Add(10 * time.Second)
|
||||
var matched bool
|
||||
var output string
|
||||
|
@ -1245,7 +1245,7 @@ func enterInput(ctx context.Context, c *check.C, person *Terminal, command, patt
|
|||
return
|
||||
}
|
||||
if time.Now().After(abortTime) {
|
||||
c.Fatalf("failed to capture pattern %q in %q", pattern, output)
|
||||
c.Fatalf("Failed to capture pattern %q in %q", pattern, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3239,6 +3239,7 @@ func (s *IntSuite) TestAuditOff(c *check.C) {
|
|||
|
||||
// lets type "echo hi" followed by "enter" and then "exit" + "enter":
|
||||
myTerm.Type("\aecho hi\n\r\aexit\n\r\a")
|
||||
myTerm.closeSend()
|
||||
|
||||
// wait for session to end
|
||||
select {
|
||||
|
@ -3248,9 +3249,7 @@ func (s *IntSuite) TestAuditOff(c *check.C) {
|
|||
}
|
||||
|
||||
// audit log should have the fact that the session occurred recorded in it
|
||||
sessions, err = site.GetSessions(defaults.Namespace)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(sessions), check.Equals, 1)
|
||||
// but the session could have been garbage collected at this point.
|
||||
|
||||
// however, attempts to read the actual sessions should fail because it was
|
||||
// not actually recorded
|
||||
|
@ -3390,6 +3389,7 @@ func (s *IntSuite) TestPAM(c *check.C) {
|
|||
cl.Stdin = termSession
|
||||
|
||||
termSession.Type("\aecho hi\n\r\aexit\n\r\a")
|
||||
termSession.closeSend()
|
||||
err = cl.SSH(context.TODO(), []string{}, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
|
@ -4188,6 +4188,7 @@ func (s *IntSuite) TestWindowChange(c *check.C) {
|
|||
|
||||
personA := NewTerminal(250)
|
||||
personB := NewTerminal(250)
|
||||
defer personB.closeSend()
|
||||
|
||||
// openSession will open a new session on a server.
|
||||
openSession := func() {
|
||||
|
@ -4293,6 +4294,7 @@ func (s *IntSuite) TestWindowChange(c *check.C) {
|
|||
|
||||
// Close the session.
|
||||
personA.Type("\aexit\r\n\a")
|
||||
personA.closeSend()
|
||||
}
|
||||
|
||||
// TestList checks that the list of servers returned is identity aware.
|
||||
|
@ -4672,6 +4674,7 @@ func (s *IntSuite) TestBPFInteractive(c *check.C) {
|
|||
|
||||
// "Type" a command into the terminal.
|
||||
term.Type(fmt.Sprintf("\a%v\n\r\aexit\n\r\a", lsPath))
|
||||
term.closeSend()
|
||||
err = client.SSH(context.TODO(), []string{}, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
|
@ -4890,6 +4893,8 @@ func (s *IntSuite) TestBPFSessionDifferentiation(c *check.C) {
|
|||
}
|
||||
writeTerm(termA)
|
||||
writeTerm(termB)
|
||||
termA.closeSend()
|
||||
termB.closeSend()
|
||||
|
||||
// Wait 10 seconds for both events to arrive, otherwise timeout.
|
||||
timeout := time.After(10 * time.Second)
|
||||
|
@ -5082,6 +5087,8 @@ func runCommand(instance *TeleInstance, cmd []string, cfg ClientConfig, attempts
|
|||
close(doneC)
|
||||
}()
|
||||
tc.Stdout = write
|
||||
var buf bytes.Buffer
|
||||
tc.Stdin = &buf
|
||||
for i := 0; i < attempts; i++ {
|
||||
err = tc.SSH(context.TODO(), cmd, false)
|
||||
if err == nil {
|
||||
|
@ -5184,9 +5191,14 @@ func (t *Terminal) Write(data []byte) (n int, err error) {
|
|||
return t.written.Write(data)
|
||||
}
|
||||
|
||||
// closeSend closes the input channel thus signalling the reads to exit
|
||||
func (t *Terminal) closeSend() {
|
||||
close(t.typed)
|
||||
}
|
||||
|
||||
func (t *Terminal) Read(p []byte) (n int, err error) {
|
||||
for n = 0; n < len(p); n++ {
|
||||
p[n] = <-t.typed
|
||||
for ch := range t.typed {
|
||||
p[n] = ch
|
||||
if p[n] == '\r' {
|
||||
break
|
||||
}
|
||||
|
@ -5195,6 +5207,10 @@ func (t *Terminal) Read(p []byte) (n int, err error) {
|
|||
n--
|
||||
}
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
n++
|
||||
if n == len(p) {
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
|
|
@ -266,6 +266,7 @@ func (s *KubeSuite) TestKubeExec(c *check.C) {
|
|||
term := NewTerminal(250)
|
||||
// lets type "echo hi" followed by "enter" and then "exit" + "enter":
|
||||
term.Type("\aecho hi\n\r\aexit\n\r\a")
|
||||
term.closeSend()
|
||||
|
||||
out = &bytes.Buffer{}
|
||||
err = kubeExec(proxyClientConfig, kubeExecArgs{
|
||||
|
@ -309,6 +310,7 @@ loop:
|
|||
// interactive command, allocate pty
|
||||
term = NewTerminal(250)
|
||||
term.Type("\aecho hi\n\r\aexit\n\r\a")
|
||||
term.closeSend()
|
||||
out = &bytes.Buffer{}
|
||||
err = kubeExec(impersonatingProxyClientConfig, kubeExecArgs{
|
||||
podName: pod.Name,
|
||||
|
@ -326,6 +328,7 @@ loop:
|
|||
// are allowed by the role
|
||||
term = NewTerminal(250)
|
||||
term.Type("\aecho hi\n\r\aexit\n\r\a")
|
||||
term.closeSend()
|
||||
out = &bytes.Buffer{}
|
||||
err = kubeExec(scopedProxyClientConfig, kubeExecArgs{
|
||||
podName: pod.Name,
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
|
@ -34,6 +35,8 @@ import (
|
|||
|
||||
"github.com/moby/term"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/teleport/lib/client/escape"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
|
@ -41,7 +44,6 @@ import (
|
|||
"github.com/gravitational/teleport/lib/session"
|
||||
"github.com/gravitational/teleport/lib/sshutils"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
type NodeSession struct {
|
||||
|
@ -218,25 +220,19 @@ func (ns *NodeSession) interactiveSession(callback interactiveCallback) error {
|
|||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
defer remoteTerm.Close()
|
||||
|
||||
// call the passed callback and give them the established
|
||||
// ssh session:
|
||||
if err := callback(sess, remoteTerm); err != nil {
|
||||
remoteTerm.Close()
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// Catch term signals, but only if we're attached to a real terminal
|
||||
if ns.isTerminalAttached() {
|
||||
ns.watchSignals(remoteTerm)
|
||||
}
|
||||
|
||||
// start piping input into the remote shell and pipe the output from
|
||||
// the remote shell into stdout:
|
||||
ns.pipeInOut(remoteTerm)
|
||||
|
||||
// switch the terminal to raw mode (and switch back on exit!)
|
||||
if ns.isTerminalAttached() {
|
||||
// switch the terminal to raw mode (and switch back on exit!)
|
||||
ts, err := term.SetRawTerminal(0)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
|
@ -244,8 +240,13 @@ func (ns *NodeSession) interactiveSession(callback interactiveCallback) error {
|
|||
defer term.RestoreTerminal(0, ts)
|
||||
}
|
||||
}
|
||||
// wait for the session to end
|
||||
<-ns.closer.C
|
||||
|
||||
// Pipe input into the remote shell and pipe the output from the remote
|
||||
// shell into stdout.
|
||||
//
|
||||
// Note, pipeInOut takes ownership of remoteTerm and will close it upon
|
||||
// completion.
|
||||
ns.pipeInOut(remoteTerm)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -537,23 +538,47 @@ func (ns *NodeSession) watchSignals(shell io.Writer) {
|
|||
}()
|
||||
}
|
||||
|
||||
// pipeInOut launches two goroutines: one to pipe the local input into the remote shell,
|
||||
// and another to pipe the output of the remote shell into the local output
|
||||
// pipeInOut pipes the local input into the remote shell, and the output of the
|
||||
// remote shell into the local output.
|
||||
func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser) {
|
||||
// copy from the remote shell to the local output
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
// Create a pipe to use in front of ns.stdin.
|
||||
//
|
||||
// This allows us to close the goroutine sending stdin to the remote shell
|
||||
// while blocked reading from stdin.
|
||||
stdin, stdinSink := io.Pipe()
|
||||
go func() {
|
||||
// Forward actual stdin to the pipe.
|
||||
//
|
||||
// Note: this is not registered with the WaitGroup on purpse. This
|
||||
// goroutine will dangle after the session was terminated, until the
|
||||
// last read from ns.stdin unblocks.
|
||||
_, err := io.Copy(stdinSink, ns.stdin)
|
||||
stdinSink.CloseWithError(err)
|
||||
}()
|
||||
// copy from the remote shell to the local output
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer ns.closer.Close()
|
||||
defer stdinSink.Close()
|
||||
_, err := io.Copy(ns.stdout, shell)
|
||||
if err != nil {
|
||||
log.Errorf(err.Error())
|
||||
log.Error("Error copying from shell:", err.Error())
|
||||
}
|
||||
}()
|
||||
// copy from the local input to the remote shell:
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer ns.closer.Close()
|
||||
defer shell.Close()
|
||||
defer stdin.Close()
|
||||
buf := make([]byte, 128)
|
||||
|
||||
stdin := ns.stdin
|
||||
stdin := io.Reader(stdin)
|
||||
if ns.isTerminalAttached() && ns.enableEscapeSequences {
|
||||
stdin = escape.NewReader(stdin, ns.stderr, func(err error) {
|
||||
switch err {
|
||||
|
@ -568,16 +593,21 @@ func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser) {
|
|||
})
|
||||
}
|
||||
for {
|
||||
n, err := stdin.Read(buf)
|
||||
if err != nil {
|
||||
fmt.Fprintf(ns.stderr, "\r\n%v\r\n", trace.Wrap(err))
|
||||
select {
|
||||
case <-ns.closer.C:
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
_, err = shell.Write(buf[:n])
|
||||
default:
|
||||
n, err := stdin.Read(buf)
|
||||
if n > 0 {
|
||||
if _, err := shell.Write(buf[:n]); err != nil {
|
||||
ns.ExitMsg = err.Error()
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
ns.ExitMsg = err.Error()
|
||||
if err != io.EOF {
|
||||
fmt.Fprintf(ns.stderr, "\r\n%v\r\n", trace.Wrap(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -586,8 +616,5 @@ func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser) {
|
|||
}
|
||||
|
||||
func (ns *NodeSession) Close() error {
|
||||
if ns.closer != nil {
|
||||
ns.closer.Close()
|
||||
}
|
||||
return nil
|
||||
return ns.closer.Close()
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
package srv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -526,6 +527,8 @@ type session struct {
|
|||
// hasEnhancedRecording returns true if this session has enhanced session
|
||||
// recording events associated.
|
||||
hasEnhancedRecording bool
|
||||
|
||||
serverCtx context.Context
|
||||
}
|
||||
|
||||
// newSession creates a new session with a given ID within a given context.
|
||||
|
@ -597,6 +600,7 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
|
|||
closeC: make(chan bool),
|
||||
lingerTTL: defaults.SessionIdlePeriod,
|
||||
startTime: startTime,
|
||||
serverCtx: ctx.srv.Context(),
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
@ -636,9 +640,9 @@ func (s *session) Close() error {
|
|||
s.term.Close()
|
||||
}
|
||||
close(s.closeC)
|
||||
|
||||
// close all writers in our multi-writer
|
||||
s.writer.Close()
|
||||
if s.recorder != nil {
|
||||
s.recorder.Close(s.serverCtx)
|
||||
}
|
||||
}()
|
||||
})
|
||||
return nil
|
||||
|
@ -1327,18 +1331,6 @@ func (m *multiWriter) Write(p []byte) (n int, err error) {
|
|||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *multiWriter) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for writerName, writer := range m.writers {
|
||||
logrus.Debugf("Closing session writer: %v.", writerName)
|
||||
if closer, ok := writer.WriteCloser.(io.Closer); ok {
|
||||
closer.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *multiWriter) getRecentWrites() []byte {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
|
|
@ -29,7 +29,7 @@ func NewCloseBroadcaster() *CloseBroadcaster {
|
|||
|
||||
// CloseBroadcaster is a helper struct
|
||||
// that implements io.Closer and uses channel
|
||||
// to broadcast it's closed state once called
|
||||
// to broadcast its closed state once called
|
||||
type CloseBroadcaster struct {
|
||||
sync.Once
|
||||
C chan struct{}
|
||||
|
|
Loading…
Reference in a new issue