Node session race (#6195)

* Attempt to isolate and improve state handling of a NodeSession.

* Add terminal close for kube terminal tests

* Address review comments

* Small tweaks

Co-authored-by: Andrew Lytvynov <andrew@goteleport.com>
This commit is contained in:
a-palchikov 2021-04-23 02:16:28 +02:00 committed by GitHub
parent 0dbc11d838
commit 4acf50902c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 98 additions and 60 deletions

View file

@ -423,6 +423,7 @@ func (s *IntSuite) TestAuditOn(c *check.C) {
// let's type "echo hi" followed by "enter" and then "exit" + "enter":
myTerm.Type("\aecho hi\n\r\aexit\n\r\a")
myTerm.closeSend()
// wait for session to end:
select {
@ -867,6 +868,7 @@ func (s *IntSuite) verifySessionJoin(c *check.C, t *TeleInstance) {
personA := NewTerminal(250)
personB := NewTerminal(250)
personB.closeSend()
// PersonA: SSH into the server, wait one second, then type some commands on stdin:
openSession := func() {
@ -904,6 +906,7 @@ func (s *IntSuite) verifySessionJoin(c *check.C, t *TeleInstance) {
}
}
c.Assert(err, check.IsNil)
personA.closeSend()
}
go openSession()
@ -936,9 +939,6 @@ func (s *IntSuite) TestShutdown(c *check.C) {
person := NewTerminal(250)
// commandsC receive commands
commandsC := make(chan string)
// PersonA: SSH into the server, wait one second, then type some commands on stdin:
openSession := func() {
cl, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()})
@ -946,12 +946,6 @@ func (s *IntSuite) TestShutdown(c *check.C) {
cl.Stdout = person
cl.Stdin = person
go func() {
for command := range commandsC {
person.Type(command)
}
}()
err = cl.SSH(context.TODO(), []string{}, false)
c.Assert(err, check.IsNil)
}
@ -991,6 +985,7 @@ func (s *IntSuite) TestShutdown(c *check.C) {
// now type exit and wait for shutdown to complete
person.Type("exit\n\r")
person.closeSend()
select {
case <-shutdownContext.Done():
@ -1020,6 +1015,7 @@ func (s *IntSuite) TestDisconnectScenarios(c *check.C) {
testCases := []disconnectTestCase{
{
comment: "recording at node",
recordingMode: services.RecordAtNode,
options: services.RoleOptions{
ClientIdleTimeout: services.NewDuration(500 * time.Millisecond),
@ -1027,6 +1023,7 @@ func (s *IntSuite) TestDisconnectScenarios(c *check.C) {
disconnectTimeout: time.Second,
},
{
comment: "recording at proxy",
recordingMode: services.RecordAtProxy,
options: services.RoleOptions{
ForwardAgent: services.NewBool(true),
@ -1035,6 +1032,7 @@ func (s *IntSuite) TestDisconnectScenarios(c *check.C) {
disconnectTimeout: time.Second,
},
{
comment: "recording at node: expired certificate is disconnected",
recordingMode: services.RecordAtNode,
options: services.RoleOptions{
DisconnectExpiredCert: services.NewBool(true),
@ -1043,6 +1041,7 @@ func (s *IntSuite) TestDisconnectScenarios(c *check.C) {
disconnectTimeout: 4 * time.Second,
},
{
comment: "recording at proxy: expired certificate is disconnected and forwarding agent",
recordingMode: services.RecordAtProxy,
options: services.RoleOptions{
ForwardAgent: services.NewBool(true),
@ -1215,7 +1214,7 @@ func (s *IntSuite) runDisconnectTest(c *check.C, tc disconnectTestCase) {
select {
case <-time.After(tc.disconnectTimeout + time.Second):
dumpGoroutineProfile()
c.Fatalf("%s: timeout waiting for session to exit: %+v", timeNow(), tc)
c.Fatalf("%s (%s): timeout waiting for session to exit: %+v", timeNow(), tc.comment, tc)
case <-ctx.Done():
// session closed. a test case is successful if the first
// session to close encountered the expected error variant.
@ -1228,6 +1227,7 @@ func timeNow() string {
func enterInput(ctx context.Context, c *check.C, person *Terminal, command, pattern string) {
person.Type(command)
defer person.closeSend()
abortTime := time.Now().Add(10 * time.Second)
var matched bool
var output string
@ -1245,7 +1245,7 @@ func enterInput(ctx context.Context, c *check.C, person *Terminal, command, patt
return
}
if time.Now().After(abortTime) {
c.Fatalf("failed to capture pattern %q in %q", pattern, output)
c.Fatalf("Failed to capture pattern %q in %q", pattern, output)
}
}
}
@ -3239,6 +3239,7 @@ func (s *IntSuite) TestAuditOff(c *check.C) {
// let's type "echo hi" followed by "enter" and then "exit" + "enter":
myTerm.Type("\aecho hi\n\r\aexit\n\r\a")
myTerm.closeSend()
// wait for session to end
select {
@ -3248,9 +3249,7 @@ func (s *IntSuite) TestAuditOff(c *check.C) {
}
// audit log should have the fact that the session occurred recorded in it
sessions, err = site.GetSessions(defaults.Namespace)
c.Assert(err, check.IsNil)
c.Assert(len(sessions), check.Equals, 1)
// but the session could have been garbage collected at this point.
// however, attempts to read the actual sessions should fail because it was
// not actually recorded
@ -3390,6 +3389,7 @@ func (s *IntSuite) TestPAM(c *check.C) {
cl.Stdin = termSession
termSession.Type("\aecho hi\n\r\aexit\n\r\a")
termSession.closeSend()
err = cl.SSH(context.TODO(), []string{}, false)
c.Assert(err, check.IsNil)
@ -4188,6 +4188,7 @@ func (s *IntSuite) TestWindowChange(c *check.C) {
personA := NewTerminal(250)
personB := NewTerminal(250)
defer personB.closeSend()
// openSession will open a new session on a server.
openSession := func() {
@ -4293,6 +4294,7 @@ func (s *IntSuite) TestWindowChange(c *check.C) {
// Close the session.
personA.Type("\aexit\r\n\a")
personA.closeSend()
}
// TestList checks that the list of servers returned is identity aware.
@ -4672,6 +4674,7 @@ func (s *IntSuite) TestBPFInteractive(c *check.C) {
// "Type" a command into the terminal.
term.Type(fmt.Sprintf("\a%v\n\r\aexit\n\r\a", lsPath))
term.closeSend()
err = client.SSH(context.TODO(), []string{}, false)
c.Assert(err, check.IsNil)
@ -4890,6 +4893,8 @@ func (s *IntSuite) TestBPFSessionDifferentiation(c *check.C) {
}
writeTerm(termA)
writeTerm(termB)
termA.closeSend()
termB.closeSend()
// Wait 10 seconds for both events to arrive, otherwise timeout.
timeout := time.After(10 * time.Second)
@ -5082,6 +5087,8 @@ func runCommand(instance *TeleInstance, cmd []string, cfg ClientConfig, attempts
close(doneC)
}()
tc.Stdout = write
var buf bytes.Buffer
tc.Stdin = &buf
for i := 0; i < attempts; i++ {
err = tc.SSH(context.TODO(), cmd, false)
if err == nil {
@ -5184,9 +5191,14 @@ func (t *Terminal) Write(data []byte) (n int, err error) {
return t.written.Write(data)
}
// closeSend closes the t.typed input channel, which terminates the
// `for ch := range t.typed` loop in Read and so signals readers to exit.
//
// It must be called at most once per Terminal: closing an already closed
// channel panics.
func (t *Terminal) closeSend() {
close(t.typed)
}
func (t *Terminal) Read(p []byte) (n int, err error) {
for n = 0; n < len(p); n++ {
p[n] = <-t.typed
for ch := range t.typed {
p[n] = ch
if p[n] == '\r' {
break
}
@ -5195,6 +5207,10 @@ func (t *Terminal) Read(p []byte) (n int, err error) {
n--
}
time.Sleep(time.Millisecond * 10)
n++
if n == len(p) {
return n, nil
}
}
return n, nil
}

View file

@ -266,6 +266,7 @@ func (s *KubeSuite) TestKubeExec(c *check.C) {
term := NewTerminal(250)
// let's type "echo hi" followed by "enter" and then "exit" + "enter":
term.Type("\aecho hi\n\r\aexit\n\r\a")
term.closeSend()
out = &bytes.Buffer{}
err = kubeExec(proxyClientConfig, kubeExecArgs{
@ -309,6 +310,7 @@ loop:
// interactive command, allocate pty
term = NewTerminal(250)
term.Type("\aecho hi\n\r\aexit\n\r\a")
term.closeSend()
out = &bytes.Buffer{}
err = kubeExec(impersonatingProxyClientConfig, kubeExecArgs{
podName: pod.Name,
@ -326,6 +328,7 @@ loop:
// are allowed by the role
term = NewTerminal(250)
term.Type("\aecho hi\n\r\aexit\n\r\a")
term.closeSend()
out = &bytes.Buffer{}
err = kubeExec(scopedProxyClientConfig, kubeExecArgs{
podName: pod.Name,

View file

@ -26,6 +26,7 @@ import (
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
@ -34,6 +35,8 @@ import (
"github.com/moby/term"
"github.com/gravitational/trace"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/client/escape"
"github.com/gravitational/teleport/lib/defaults"
@ -41,7 +44,6 @@ import (
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
)
type NodeSession struct {
@ -218,25 +220,19 @@ func (ns *NodeSession) interactiveSession(callback interactiveCallback) error {
if err != nil {
return trace.Wrap(err)
}
defer remoteTerm.Close()
// call the passed callback and give them the established
// ssh session:
if err := callback(sess, remoteTerm); err != nil {
remoteTerm.Close()
return trace.Wrap(err)
}
// Catch term signals, but only if we're attached to a real terminal
if ns.isTerminalAttached() {
ns.watchSignals(remoteTerm)
}
// start piping input into the remote shell and pipe the output from
// the remote shell into stdout:
ns.pipeInOut(remoteTerm)
// switch the terminal to raw mode (and switch back on exit!)
if ns.isTerminalAttached() {
// switch the terminal to raw mode (and switch back on exit!)
ts, err := term.SetRawTerminal(0)
if err != nil {
log.Warn(err)
@ -244,8 +240,13 @@ func (ns *NodeSession) interactiveSession(callback interactiveCallback) error {
defer term.RestoreTerminal(0, ts)
}
}
// wait for the session to end
<-ns.closer.C
// Pipe input into the remote shell and pipe the output from the remote
// shell into stdout.
//
// Note, pipeInOut takes ownership of remoteTerm and will close it upon
// completion.
ns.pipeInOut(remoteTerm)
return nil
}
@ -537,23 +538,47 @@ func (ns *NodeSession) watchSignals(shell io.Writer) {
}()
}
// pipeInOut launches two goroutines: one to pipe the local input into the remote shell,
// and another to pipe the output of the remote shell into the local output
// pipeInOut pipes the local input into the remote shell, and the output of the
// remote shell into the local output.
func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser) {
// copy from the remote shell to the local output
var wg sync.WaitGroup
defer wg.Wait()
// Create a pipe to use in front of ns.stdin.
//
// This allows us to stop the goroutine sending stdin to the remote shell
// even while it is blocked reading from stdin.
stdin, stdinSink := io.Pipe()
go func() {
// Forward actual stdin to the pipe.
//
// Note: this is not registered with the WaitGroup on purpose. This
// goroutine will dangle after the session was terminated, until the
// last read from ns.stdin unblocks.
_, err := io.Copy(stdinSink, ns.stdin)
stdinSink.CloseWithError(err)
}()
// copy from the remote shell to the local output
wg.Add(1)
go func() {
defer wg.Done()
defer ns.closer.Close()
defer stdinSink.Close()
_, err := io.Copy(ns.stdout, shell)
if err != nil {
log.Errorf(err.Error())
log.Error("Error copying from shell:", err.Error())
}
}()
// copy from the local input to the remote shell:
wg.Add(1)
go func() {
defer wg.Done()
defer ns.closer.Close()
defer shell.Close()
defer stdin.Close()
buf := make([]byte, 128)
stdin := ns.stdin
stdin := io.Reader(stdin)
if ns.isTerminalAttached() && ns.enableEscapeSequences {
stdin = escape.NewReader(stdin, ns.stderr, func(err error) {
switch err {
@ -568,16 +593,21 @@ func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser) {
})
}
for {
n, err := stdin.Read(buf)
if err != nil {
fmt.Fprintf(ns.stderr, "\r\n%v\r\n", trace.Wrap(err))
select {
case <-ns.closer.C:
return
}
if n > 0 {
_, err = shell.Write(buf[:n])
default:
n, err := stdin.Read(buf)
if n > 0 {
if _, err := shell.Write(buf[:n]); err != nil {
ns.ExitMsg = err.Error()
return
}
}
if err != nil {
ns.ExitMsg = err.Error()
if err != io.EOF {
fmt.Fprintf(ns.stderr, "\r\n%v\r\n", trace.Wrap(err))
}
return
}
}
@ -586,8 +616,5 @@ func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser) {
}
func (ns *NodeSession) Close() error {
if ns.closer != nil {
ns.closer.Close()
}
return nil
return ns.closer.Close()
}

View file

@ -17,6 +17,7 @@ limitations under the License.
package srv
import (
"context"
"encoding/json"
"fmt"
"io"
@ -526,6 +527,8 @@ type session struct {
// hasEnhancedRecording returns true if this session has enhanced session
// recording events associated.
hasEnhancedRecording bool
serverCtx context.Context
}
// newSession creates a new session with a given ID within a given context.
@ -597,6 +600,7 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
closeC: make(chan bool),
lingerTTL: defaults.SessionIdlePeriod,
startTime: startTime,
serverCtx: ctx.srv.Context(),
}
return sess, nil
}
@ -636,9 +640,9 @@ func (s *session) Close() error {
s.term.Close()
}
close(s.closeC)
// close all writers in our multi-writer
s.writer.Close()
if s.recorder != nil {
s.recorder.Close(s.serverCtx)
}
}()
})
return nil
@ -1327,18 +1331,6 @@ func (m *multiWriter) Write(p []byte) (n int, err error) {
return len(p), nil
}
// Close walks every entry in m.writers while holding m.mu, logs the entry's
// name at debug level, and closes the entry's WriteCloser when it satisfies
// io.Closer. Errors returned by the individual Close calls are discarded;
// the method always returns nil.
func (m *multiWriter) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	for name, w := range m.writers {
		logrus.Debugf("Closing session writer: %v.", name)
		c, isCloser := w.WriteCloser.(io.Closer)
		if !isCloser {
			// Nothing to close for this writer; move on to the next one.
			continue
		}
		c.Close()
	}
	return nil
}
func (m *multiWriter) getRecentWrites() []byte {
m.mu.Lock()
defer m.mu.Unlock()

View file

@ -29,7 +29,7 @@ func NewCloseBroadcaster() *CloseBroadcaster {
// CloseBroadcaster is a helper struct
// that implements io.Closer and uses channel
// to broadcast it's closed state once called
// to broadcast its closed state once called
type CloseBroadcaster struct {
sync.Once
C chan struct{}