From 0db64bbabe5056e0115c7f038e10f83b6e3ea2a0 Mon Sep 17 00:00:00 2001 From: Jakub Nyckowski Date: Fri, 17 Feb 2023 09:11:59 -0500 Subject: [PATCH] Fix Moderated session on leave pause action. (#21366) * Fix Moderated session on leave pause action. * Unpause when moderator rejoins session with onLeave=pause This commite fixes a case when the moderator rejoins a paused session but Teleport didn't resume the operations and the session was still locked to everyone. * Fix panic when re-joining SSH session. * Update lib/kube/proxy/sess.go Co-authored-by: Noah Stride * Update lib/kube/proxy/sess.go Co-authored-by: Noah Stride * Fix typo Co-authored-by: Tiago Silva --------- Co-authored-by: Tiago Silva Co-authored-by: Noah Stride --- api/types/role.go | 6 +- lib/auth/session_access.go | 30 ++++---- lib/auth/session_access_test.go | 126 +++++++++++++++++++++++++++++--- lib/kube/proxy/sess.go | 24 ++++-- lib/srv/sess.go | 52 +++++++++---- 5 files changed, 191 insertions(+), 47 deletions(-) diff --git a/api/types/role.go b/api/types/role.go index 1e0c5709524..75534818ee1 100644 --- a/api/types/role.go +++ b/api/types/role.go @@ -32,15 +32,17 @@ import ( "github.com/gravitational/teleport/api/utils/keys" ) +type OnSessionLeaveAction string + const ( // OnSessionLeaveTerminate is a moderated sessions policy constant that terminates // a session once the require policy is no longer fulfilled. - OnSessionLeaveTerminate = "terminate" + OnSessionLeaveTerminate OnSessionLeaveAction = "terminate" // OnSessionLeaveTerminate is a moderated sessions policy constant that pauses // a session once the require policies is no longer fulfilled. It is resumed // once the requirements are fulfilled again. - OnSessionLeavePause = "pause" + OnSessionLeavePause OnSessionLeaveAction = "pause" ) // Role contains a set of permissions or settings diff --git a/lib/auth/session_access.go b/lib/auth/session_access.go index 56b2bc9e9e5..28a6df0a763 100644 --- a/lib/auth/session_access.go +++ b/lib/auth/session_access.go @@ -235,12 +235,12 @@ func SliceContainsMode(s []types.SessionParticipantMode, e types.SessionParticip return false } -// PolicyOptions is a set of settings for the session determined by the matched require policy. +// PolicyOptions is a set of settings for the session determined by the matched required policy. type PolicyOptions struct { - TerminateOnLeave bool + OnLeaveAction types.OnSessionLeaveAction } -// Generate a pretty-printed string of precise requirements for session start suitable for user display. +// PrettyRequirementsList generates a pretty-printed string of precise requirements for session start suitable for user display. func (e *SessionAccessEvaluator) PrettyRequirementsList() string { s := new(strings.Builder) s.WriteString("require all:") @@ -275,7 +275,7 @@ func (e *SessionAccessEvaluator) extractApplicablePolicies(set *types.SessionTra // FulfilledFor checks if a given session may run with a list of participants. func (e *SessionAccessEvaluator) FulfilledFor(participants []SessionAccessContext) (bool, PolicyOptions, error) { - options := PolicyOptions{TerminateOnLeave: true} + var options PolicyOptions // Check every policy set to check if it's fulfilled. // We need every policy set to match to allow the session. @@ -286,6 +286,17 @@ policySetLoop: continue } + if options.OnLeaveAction != types.OnSessionLeaveTerminate { + terminateOnLeave := types.OnSessionLeavePause + for _, p := range policies { + if p.OnLeave != string(types.OnSessionLeavePause) { + terminateOnLeave = types.OnSessionLeaveTerminate + break + } + } + options = PolicyOptions{OnLeaveAction: terminateOnLeave} + } + // Check every require policy to see if it's fulfilled. // Only one needs to be checked to pass the policyset. for _, requirePolicy := range policies { @@ -309,10 +320,10 @@ policySetLoop: // Evaluate the filter in the require policy against the participant and allow policy. matchesPredicate, err := e.matchesPredicate(&participant, requirePolicy, allowPolicy) if err != nil { - return false, PolicyOptions{}, trace.Wrap(err) + return false, options, trace.Wrap(err) } - // If the the filter matches the participant and the allow policy matches the session + // If the filter matches the participant and the allow policy matches the session // we conclude that the participant matches against the require policy. if matchesPredicate && e.matchesJoin(allowPolicy) { left-- @@ -322,13 +333,6 @@ policySetLoop: // If we've matched enough participants against the require policy, we can allow the session. if left <= 0 { - switch requirePolicy.OnLeave { - case types.OnSessionLeaveTerminate: - case types.OnSessionLeavePause: - options.TerminateOnLeave = false - default: - } - // We matched at least one require policy within the set. Continue ahead. continue policySetLoop } diff --git a/lib/auth/session_access_test.go b/lib/auth/session_access_test.go index 71547a97861..1932c189b6f 100644 --- a/lib/auth/session_access_test.go +++ b/lib/auth/session_access_test.go @@ -31,6 +31,7 @@ type startTestCase struct { participants []SessionAccessContext owner string expected []bool + terminate types.OnSessionLeaveAction } func successStartTestCase(t *testing.T) startTestCase { @@ -43,14 +44,14 @@ func successStartTestCase(t *testing.T) startTestCase { Filter: "contains(user.roles, \"participant\")", Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, Count: 2, - OnLeave: types.OnSessionLeaveTerminate, + OnLeave: string(types.OnSessionLeaveTerminate), Modes: []string{"peer"}, }}) participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ Roles: []string{hostRole.GetName()}, Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, - Modes: []string{string("*")}, + Modes: []string{"*"}, }}) return startTestCase{ @@ -69,7 +70,99 @@ func successStartTestCase(t *testing.T) startTestCase { Mode: "peer", }, }, - expected: []bool{true, true}, + expected: []bool{true, true}, + terminate: types.OnSessionLeaveTerminate, + } +} + +func successStartTestCasePause(t *testing.T) startTestCase { + hostRole, err := types.NewRole("host", types.RoleSpecV6{}) + require.NoError(t, err) + participantRole, err := types.NewRole("participant", types.RoleSpecV6{}) + require.NoError(t, err) + + hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{{ + Filter: "contains(user.roles, \"participant\")", + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, + Count: 2, + OnLeave: string(types.OnSessionLeavePause), + Modes: []string{"peer"}, + }}) + + participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ + Roles: []string{hostRole.GetName()}, + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, + Modes: []string{"*"}, + }}) + + return startTestCase{ + name: "successStartTestCasePause", + host: []types.Role{hostRole}, + sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind}, + participants: []SessionAccessContext{ + { + Username: "participant", + Roles: []types.Role{participantRole}, + Mode: "peer", + }, + { + Username: "participant2", + Roles: []types.Role{participantRole}, + Mode: "peer", + }, + }, + expected: []bool{true, true}, + terminate: types.OnSessionLeavePause, + } +} + +func pauseCanBeOverwritten(t *testing.T) startTestCase { + hostRole, err := types.NewRole("host", types.RoleSpecV6{}) + require.NoError(t, err) + participantRole, err := types.NewRole("participant", types.RoleSpecV6{}) + require.NoError(t, err) + + hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{ + { + Filter: "contains(user.roles, \"participant\")", + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, + Count: 2, + OnLeave: string(types.OnSessionLeavePause), + Modes: []string{"peer"}, + }, + { + Filter: "contains(user.roles, \"participant\")", + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, + Count: 2, + OnLeave: string(types.OnSessionLeaveTerminate), + Modes: []string{"peer"}, + }, + }) + + participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ + Roles: []string{hostRole.GetName()}, + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, + Modes: []string{"*"}, + }}) + + return startTestCase{ + name: "pauseCanBeOverwritten", + host: []types.Role{hostRole}, + sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind}, + participants: []SessionAccessContext{ + { + Username: "participant", + Roles: []types.Role{participantRole}, + Mode: "peer", + }, + { + Username: "participant2", + Roles: []types.Role{participantRole}, + Mode: "peer", + }, + }, + expected: []bool{true, true}, + terminate: types.OnSessionLeaveTerminate, } } @@ -83,7 +176,7 @@ func successStartTestCaseSpec(t *testing.T) startTestCase { Filter: "contains(user.spec.roles, \"participant\")", Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, Count: 2, - OnLeave: types.OnSessionLeaveTerminate, + OnLeave: string(types.OnSessionLeaveTerminate), Modes: []string{"peer"}, }}) @@ -109,7 +202,8 @@ func successStartTestCaseSpec(t *testing.T) startTestCase { Mode: "peer", }, }, - expected: []bool{true, true}, + expected: []bool{true, true}, + terminate: types.OnSessionLeaveTerminate, } } @@ -129,7 +223,7 @@ func failCountStartTestCase(t *testing.T) startTestCase { participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ Roles: []string{hostRole.GetName()}, Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, - Modes: []string{string("*")}, + Modes: []string{"*"}, }}) return startTestCase{ @@ -148,7 +242,8 @@ func failCountStartTestCase(t *testing.T) startTestCase { Mode: "peer", }, }, - expected: []bool{false, false}, + expected: []bool{false, false}, + terminate: types.OnSessionLeaveTerminate, } } @@ -187,7 +282,7 @@ func failFilterStartTestCase(t *testing.T) startTestCase { participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ Roles: []string{hostRole.GetName()}, Kinds: []string{string(types.SSHSessionKind)}, - Modes: []string{string("*")}, + Modes: []string{"*"}, }}) return startTestCase{ @@ -206,21 +301,29 @@ func failFilterStartTestCase(t *testing.T) startTestCase { Mode: "peer", }, }, - expected: []bool{false}, + expected: []bool{false}, + terminate: types.OnSessionLeaveTerminate, } } func TestSessionAccessStart(t *testing.T) { + t.Parallel() + testCases := []startTestCase{ successStartTestCase(t), + successStartTestCasePause(t), successStartTestCaseSpec(t), failCountStartTestCase(t), failFilterStartTestCase(t), succeedDiscardPolicySetStartTestCase(t), + pauseCanBeOverwritten(t), } for _, testCase := range testCases { + testCase := testCase t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + var policies []*types.SessionTrackerPolicySet for _, role := range testCase.host { policySet := role.GetSessionPolicySet() @@ -229,9 +332,10 @@ func TestSessionAccessStart(t *testing.T) { for i, kind := range testCase.sessionKinds { evaluator := NewSessionAccessEvaluator(policies, kind, testCase.owner) - result, _, err := evaluator.FulfilledFor(testCase.participants) + result, policyOptions, err := evaluator.FulfilledFor(testCase.participants) require.NoError(t, err) require.Equal(t, testCase.expected[i], result) + require.Equal(t, testCase.terminate, policyOptions.OnLeaveAction) } }) } @@ -383,6 +487,8 @@ func versionDefaultJoinTestCase(t *testing.T) joinTestCase { } func TestSessionAccessJoin(t *testing.T) { + t.Parallel() + testCases := []joinTestCase{ successJoinTestCase(t), successGlobJoinTestCase(t), diff --git a/lib/kube/proxy/sess.go b/lib/kube/proxy/sess.go index c2f6149c416..c4b0f3feefb 100644 --- a/lib/kube/proxy/sess.go +++ b/lib/kube/proxy/sess.go @@ -927,12 +927,12 @@ func (s *session) join(p *party) error { }() } - if !s.started { - canStart, _, err := s.canStart() - if err != nil { - return trace.Wrap(err) - } + canStart, _, err := s.canStart() + if err != nil { + return trace.Wrap(err) + } + if !s.started { if canStart { go func() { if err := s.launch(); err != nil { @@ -948,6 +948,18 @@ func (s *session) join(p *party) error { s.BroadcastMessage(base) } } + } else if canStart && s.tracker.GetState() == types.SessionState_SessionStatePending { + // If the session is already running, but the party is a moderator that left + // a session with onLeave=pause and then rejoined, we need to unpause the session. + // When the moderator left the session, the session was paused, and we spawn + // a goroutine to wait for the moderator to rejoin. If the moderator rejoins + // before the session ends, we need to unpause the session by updating its state and + // the goroutine will unblock the s.io terminal. + // types.SessionState_SessionStatePending marks a session that is waiting for + // a moderator to rejoin. + if err := s.tracker.UpdateState(s.forwarder.ctx, types.SessionState_SessionStateRunning); err != nil { + s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning) + } } return nil @@ -1049,7 +1061,7 @@ func (s *session) unlockedLeave(id uuid.UUID) (bool, error) { } if !canStart { - if options.TerminateOnLeave { + if options.OnLeaveAction == types.OnSessionLeaveTerminate { go func() { if err := s.Close(); err != nil { s.log.WithError(err).Errorf("Failed to close session") diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 37cd30f4d5b..3d275ce7b9d 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -24,6 +24,7 @@ import ( "os/user" "path/filepath" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -498,6 +499,9 @@ type session struct { // serverMeta contains metadata about the target node of this session. serverMeta apievents.ServerMetadata + + // started is true after the session start. + started atomic.Bool } // newSession creates a new session with a given ID within a given context. @@ -876,7 +880,13 @@ func (s *session) setHasEnhancedRecording(val bool) { // launch launches the session. // Must be called under session Lock. -func (s *session) launch(ctx *ServerContext) error { +func (s *session) launch() { + // Mark the session as started here, as we want to avoid double initialization. + if s.started.Swap(true) { + s.log.Debugf("Session has already started") + return + } + s.log.Debug("Launching session") s.BroadcastMessage("Connecting to %v over SSH", s.serverMeta.ServerHostname) @@ -933,8 +943,6 @@ func (s *session) launch(ctx *ServerContext) error { _, err := io.Copy(s.term.PTY(), s.io) s.log.Debugf("Copying from reader to PTY completed with error %v.", err) }() - - return nil } // startInteractive starts a new interactive process (or a shell) in the @@ -1276,7 +1284,7 @@ func (s *session) removePartyUnderLock(p *party) error { // Remove party for the term writer s.io.DeleteWriter(string(p.id)) - // Emit session leave event to both the Audit Log as well as over the + // Emit session leave event to both the Audit Log and over the // "x-teleport-event" channel in the SSH connection. s.emitSessionLeaveEvent(p.ctx) @@ -1286,7 +1294,7 @@ func (s *session) removePartyUnderLock(p *party) error { } if !canRun { - if policyOptions.TerminateOnLeave { + if policyOptions.OnLeaveAction == types.OnSessionLeaveTerminate { // Force termination in goroutine to avoid deadlock go s.registry.ForceTerminate(s.scx) return nil @@ -1474,18 +1482,30 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { return trace.Wrap(err) } - if canStart { - if err := s.launch(s.scx); err != nil { - s.log.WithError(err).Error("Failed to launch session") - } - return nil - } + switch { + case canStart && !s.started.Load(): + s.launch() - base := "Waiting for required participants..." - if s.displayParticipantRequirements { - s.BroadcastMessage(base+"\r\n%v", s.access.PrettyRequirementsList()) - } else { - s.BroadcastMessage(base) + return nil + case canStart: + // If the session is already running, but the party is a moderator that leaved + // a session with onLeave=pause and then rejoined, we need to unpause the session. + // When the moderator leaved the session, the session was paused, and we spawn + // a goroutine to wait for the moderator to rejoin. If the moderator rejoins + // before the session ends, we need to unpause the session by updating its state and + // the goroutine will unblock the s.io terminal. + // types.SessionState_SessionStatePending marks a session that is waiting for + // a moderator to rejoin. + if err := s.tracker.UpdateState(s.serverCtx, types.SessionState_SessionStateRunning); err != nil { + s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning) + } + default: + const base = "Waiting for required participants..." + if s.displayParticipantRequirements { + s.BroadcastMessage(base+"\r\n%v", s.access.PrettyRequirementsList()) + } else { + s.BroadcastMessage(base) + } } }