Fix Moderated session on leave pause action. (#21366)

* Fix Moderated session on leave pause action.

* Unpause when moderator rejoins session with onLeave=pause

This commite fixes a case when the moderator rejoins a paused session
but Teleport didn't resume the operations and the session was still
locked to everyone.

* Fix panic when re-joining SSH session.

* Update lib/kube/proxy/sess.go

Co-authored-by: Noah Stride <noah.stride@goteleport.com>

* Update lib/kube/proxy/sess.go

Co-authored-by: Noah Stride <noah.stride@goteleport.com>

* Fix typo

Co-authored-by: Tiago Silva <tiago.silva@goteleport.com>

---------

Co-authored-by: Tiago Silva <tiago.silva@goteleport.com>
Co-authored-by: Noah Stride <noah.stride@goteleport.com>
This commit is contained in:
Jakub Nyckowski 2023-02-17 09:11:59 -05:00 committed by GitHub
parent 20d4c79812
commit 0db64bbabe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 191 additions and 47 deletions

View file

@ -32,15 +32,17 @@ import (
"github.com/gravitational/teleport/api/utils/keys"
)
type OnSessionLeaveAction string
const (
// OnSessionLeaveTerminate is a moderated sessions policy constant that terminates
// a session once the require policy is no longer fulfilled.
OnSessionLeaveTerminate = "terminate"
OnSessionLeaveTerminate OnSessionLeaveAction = "terminate"
// OnSessionLeaveTerminate is a moderated sessions policy constant that pauses
// a session once the require policies is no longer fulfilled. It is resumed
// once the requirements are fulfilled again.
OnSessionLeavePause = "pause"
OnSessionLeavePause OnSessionLeaveAction = "pause"
)
// Role contains a set of permissions or settings

View file

@ -235,12 +235,12 @@ func SliceContainsMode(s []types.SessionParticipantMode, e types.SessionParticip
return false
}
// PolicyOptions is a set of settings for the session determined by the matched require policy.
// PolicyOptions is a set of settings for the session determined by the matched required policy.
type PolicyOptions struct {
TerminateOnLeave bool
OnLeaveAction types.OnSessionLeaveAction
}
// Generate a pretty-printed string of precise requirements for session start suitable for user display.
// PrettyRequirementsList generates a pretty-printed string of precise requirements for session start suitable for user display.
func (e *SessionAccessEvaluator) PrettyRequirementsList() string {
s := new(strings.Builder)
s.WriteString("require all:")
@ -275,7 +275,7 @@ func (e *SessionAccessEvaluator) extractApplicablePolicies(set *types.SessionTra
// FulfilledFor checks if a given session may run with a list of participants.
func (e *SessionAccessEvaluator) FulfilledFor(participants []SessionAccessContext) (bool, PolicyOptions, error) {
options := PolicyOptions{TerminateOnLeave: true}
var options PolicyOptions
// Check every policy set to check if it's fulfilled.
// We need every policy set to match to allow the session.
@ -286,6 +286,17 @@ policySetLoop:
continue
}
if options.OnLeaveAction != types.OnSessionLeaveTerminate {
terminateOnLeave := types.OnSessionLeavePause
for _, p := range policies {
if p.OnLeave != string(types.OnSessionLeavePause) {
terminateOnLeave = types.OnSessionLeaveTerminate
break
}
}
options = PolicyOptions{OnLeaveAction: terminateOnLeave}
}
// Check every require policy to see if it's fulfilled.
// Only one needs to be checked to pass the policyset.
for _, requirePolicy := range policies {
@ -309,10 +320,10 @@ policySetLoop:
// Evaluate the filter in the require policy against the participant and allow policy.
matchesPredicate, err := e.matchesPredicate(&participant, requirePolicy, allowPolicy)
if err != nil {
return false, PolicyOptions{}, trace.Wrap(err)
return false, options, trace.Wrap(err)
}
// If the the filter matches the participant and the allow policy matches the session
// If the filter matches the participant and the allow policy matches the session
// we conclude that the participant matches against the require policy.
if matchesPredicate && e.matchesJoin(allowPolicy) {
left--
@ -322,13 +333,6 @@ policySetLoop:
// If we've matched enough participants against the require policy, we can allow the session.
if left <= 0 {
switch requirePolicy.OnLeave {
case types.OnSessionLeaveTerminate:
case types.OnSessionLeavePause:
options.TerminateOnLeave = false
default:
}
// We matched at least one require policy within the set. Continue ahead.
continue policySetLoop
}

View file

@ -31,6 +31,7 @@ type startTestCase struct {
participants []SessionAccessContext
owner string
expected []bool
terminate types.OnSessionLeaveAction
}
func successStartTestCase(t *testing.T) startTestCase {
@ -43,14 +44,14 @@ func successStartTestCase(t *testing.T) startTestCase {
Filter: "contains(user.roles, \"participant\")",
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Count: 2,
OnLeave: types.OnSessionLeaveTerminate,
OnLeave: string(types.OnSessionLeaveTerminate),
Modes: []string{"peer"},
}})
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{hostRole.GetName()},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{string("*")},
Modes: []string{"*"},
}})
return startTestCase{
@ -69,7 +70,99 @@ func successStartTestCase(t *testing.T) startTestCase {
Mode: "peer",
},
},
expected: []bool{true, true},
expected: []bool{true, true},
terminate: types.OnSessionLeaveTerminate,
}
}
func successStartTestCasePause(t *testing.T) startTestCase {
hostRole, err := types.NewRole("host", types.RoleSpecV6{})
require.NoError(t, err)
participantRole, err := types.NewRole("participant", types.RoleSpecV6{})
require.NoError(t, err)
hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{{
Filter: "contains(user.roles, \"participant\")",
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Count: 2,
OnLeave: string(types.OnSessionLeavePause),
Modes: []string{"peer"},
}})
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{hostRole.GetName()},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{"*"},
}})
return startTestCase{
name: "successStartTestCasePause",
host: []types.Role{hostRole},
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participants: []SessionAccessContext{
{
Username: "participant",
Roles: []types.Role{participantRole},
Mode: "peer",
},
{
Username: "participant2",
Roles: []types.Role{participantRole},
Mode: "peer",
},
},
expected: []bool{true, true},
terminate: types.OnSessionLeavePause,
}
}
func pauseCanBeOverwritten(t *testing.T) startTestCase {
hostRole, err := types.NewRole("host", types.RoleSpecV6{})
require.NoError(t, err)
participantRole, err := types.NewRole("participant", types.RoleSpecV6{})
require.NoError(t, err)
hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{
{
Filter: "contains(user.roles, \"participant\")",
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Count: 2,
OnLeave: string(types.OnSessionLeavePause),
Modes: []string{"peer"},
},
{
Filter: "contains(user.roles, \"participant\")",
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Count: 2,
OnLeave: string(types.OnSessionLeaveTerminate),
Modes: []string{"peer"},
},
})
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{hostRole.GetName()},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{"*"},
}})
return startTestCase{
name: "pauseCanBeOverwritten",
host: []types.Role{hostRole},
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participants: []SessionAccessContext{
{
Username: "participant",
Roles: []types.Role{participantRole},
Mode: "peer",
},
{
Username: "participant2",
Roles: []types.Role{participantRole},
Mode: "peer",
},
},
expected: []bool{true, true},
terminate: types.OnSessionLeaveTerminate,
}
}
@ -83,7 +176,7 @@ func successStartTestCaseSpec(t *testing.T) startTestCase {
Filter: "contains(user.spec.roles, \"participant\")",
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Count: 2,
OnLeave: types.OnSessionLeaveTerminate,
OnLeave: string(types.OnSessionLeaveTerminate),
Modes: []string{"peer"},
}})
@ -109,7 +202,8 @@ func successStartTestCaseSpec(t *testing.T) startTestCase {
Mode: "peer",
},
},
expected: []bool{true, true},
expected: []bool{true, true},
terminate: types.OnSessionLeaveTerminate,
}
}
@ -129,7 +223,7 @@ func failCountStartTestCase(t *testing.T) startTestCase {
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{hostRole.GetName()},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{string("*")},
Modes: []string{"*"},
}})
return startTestCase{
@ -148,7 +242,8 @@ func failCountStartTestCase(t *testing.T) startTestCase {
Mode: "peer",
},
},
expected: []bool{false, false},
expected: []bool{false, false},
terminate: types.OnSessionLeaveTerminate,
}
}
@ -187,7 +282,7 @@ func failFilterStartTestCase(t *testing.T) startTestCase {
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{hostRole.GetName()},
Kinds: []string{string(types.SSHSessionKind)},
Modes: []string{string("*")},
Modes: []string{"*"},
}})
return startTestCase{
@ -206,21 +301,29 @@ func failFilterStartTestCase(t *testing.T) startTestCase {
Mode: "peer",
},
},
expected: []bool{false},
expected: []bool{false},
terminate: types.OnSessionLeaveTerminate,
}
}
func TestSessionAccessStart(t *testing.T) {
t.Parallel()
testCases := []startTestCase{
successStartTestCase(t),
successStartTestCasePause(t),
successStartTestCaseSpec(t),
failCountStartTestCase(t),
failFilterStartTestCase(t),
succeedDiscardPolicySetStartTestCase(t),
pauseCanBeOverwritten(t),
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
var policies []*types.SessionTrackerPolicySet
for _, role := range testCase.host {
policySet := role.GetSessionPolicySet()
@ -229,9 +332,10 @@ func TestSessionAccessStart(t *testing.T) {
for i, kind := range testCase.sessionKinds {
evaluator := NewSessionAccessEvaluator(policies, kind, testCase.owner)
result, _, err := evaluator.FulfilledFor(testCase.participants)
result, policyOptions, err := evaluator.FulfilledFor(testCase.participants)
require.NoError(t, err)
require.Equal(t, testCase.expected[i], result)
require.Equal(t, testCase.terminate, policyOptions.OnLeaveAction)
}
})
}
@ -383,6 +487,8 @@ func versionDefaultJoinTestCase(t *testing.T) joinTestCase {
}
func TestSessionAccessJoin(t *testing.T) {
t.Parallel()
testCases := []joinTestCase{
successJoinTestCase(t),
successGlobJoinTestCase(t),

View file

@ -927,12 +927,12 @@ func (s *session) join(p *party) error {
}()
}
if !s.started {
canStart, _, err := s.canStart()
if err != nil {
return trace.Wrap(err)
}
canStart, _, err := s.canStart()
if err != nil {
return trace.Wrap(err)
}
if !s.started {
if canStart {
go func() {
if err := s.launch(); err != nil {
@ -948,6 +948,18 @@ func (s *session) join(p *party) error {
s.BroadcastMessage(base)
}
}
} else if canStart && s.tracker.GetState() == types.SessionState_SessionStatePending {
// If the session is already running, but the party is a moderator that left
// a session with onLeave=pause and then rejoined, we need to unpause the session.
// When the moderator left the session, the session was paused, and we spawn
// a goroutine to wait for the moderator to rejoin. If the moderator rejoins
// before the session ends, we need to unpause the session by updating its state and
// the goroutine will unblock the s.io terminal.
// types.SessionState_SessionStatePending marks a session that is waiting for
// a moderator to rejoin.
if err := s.tracker.UpdateState(s.forwarder.ctx, types.SessionState_SessionStateRunning); err != nil {
s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning)
}
}
return nil
@ -1049,7 +1061,7 @@ func (s *session) unlockedLeave(id uuid.UUID) (bool, error) {
}
if !canStart {
if options.TerminateOnLeave {
if options.OnLeaveAction == types.OnSessionLeaveTerminate {
go func() {
if err := s.Close(); err != nil {
s.log.WithError(err).Errorf("Failed to close session")

View file

@ -24,6 +24,7 @@ import (
"os/user"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@ -498,6 +499,9 @@ type session struct {
// serverMeta contains metadata about the target node of this session.
serverMeta apievents.ServerMetadata
// started is true after the session start.
started atomic.Bool
}
// newSession creates a new session with a given ID within a given context.
@ -876,7 +880,13 @@ func (s *session) setHasEnhancedRecording(val bool) {
// launch launches the session.
// Must be called under session Lock.
func (s *session) launch(ctx *ServerContext) error {
func (s *session) launch() {
// Mark the session as started here, as we want to avoid double initialization.
if s.started.Swap(true) {
s.log.Debugf("Session has already started")
return
}
s.log.Debug("Launching session")
s.BroadcastMessage("Connecting to %v over SSH", s.serverMeta.ServerHostname)
@ -933,8 +943,6 @@ func (s *session) launch(ctx *ServerContext) error {
_, err := io.Copy(s.term.PTY(), s.io)
s.log.Debugf("Copying from reader to PTY completed with error %v.", err)
}()
return nil
}
// startInteractive starts a new interactive process (or a shell) in the
@ -1276,7 +1284,7 @@ func (s *session) removePartyUnderLock(p *party) error {
// Remove party for the term writer
s.io.DeleteWriter(string(p.id))
// Emit session leave event to both the Audit Log as well as over the
// Emit session leave event to both the Audit Log and over the
// "x-teleport-event" channel in the SSH connection.
s.emitSessionLeaveEvent(p.ctx)
@ -1286,7 +1294,7 @@ func (s *session) removePartyUnderLock(p *party) error {
}
if !canRun {
if policyOptions.TerminateOnLeave {
if policyOptions.OnLeaveAction == types.OnSessionLeaveTerminate {
// Force termination in goroutine to avoid deadlock
go s.registry.ForceTerminate(s.scx)
return nil
@ -1474,18 +1482,30 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
return trace.Wrap(err)
}
if canStart {
if err := s.launch(s.scx); err != nil {
s.log.WithError(err).Error("Failed to launch session")
}
return nil
}
switch {
case canStart && !s.started.Load():
s.launch()
base := "Waiting for required participants..."
if s.displayParticipantRequirements {
s.BroadcastMessage(base+"\r\n%v", s.access.PrettyRequirementsList())
} else {
s.BroadcastMessage(base)
return nil
case canStart:
// If the session is already running, but the party is a moderator that leaved
// a session with onLeave=pause and then rejoined, we need to unpause the session.
// When the moderator leaved the session, the session was paused, and we spawn
// a goroutine to wait for the moderator to rejoin. If the moderator rejoins
// before the session ends, we need to unpause the session by updating its state and
// the goroutine will unblock the s.io terminal.
// types.SessionState_SessionStatePending marks a session that is waiting for
// a moderator to rejoin.
if err := s.tracker.UpdateState(s.serverCtx, types.SessionState_SessionStateRunning); err != nil {
s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning)
}
default:
const base = "Waiting for required participants..."
if s.displayParticipantRequirements {
s.BroadcastMessage(base+"\r\n%v", s.access.PrettyRequirementsList())
} else {
s.BroadcastMessage(base)
}
}
}