mirror of
https://github.com/gravitational/teleport
synced 2024-10-20 17:23:22 +00:00
Fix Moderated session on leave pause action. (#21366)
* Fix Moderated session on leave pause action. * Unpause when moderator rejoins session with onLeave=pause This commite fixes a case when the moderator rejoins a paused session but Teleport didn't resume the operations and the session was still locked to everyone. * Fix panic when re-joining SSH session. * Update lib/kube/proxy/sess.go Co-authored-by: Noah Stride <noah.stride@goteleport.com> * Update lib/kube/proxy/sess.go Co-authored-by: Noah Stride <noah.stride@goteleport.com> * Fix typo Co-authored-by: Tiago Silva <tiago.silva@goteleport.com> --------- Co-authored-by: Tiago Silva <tiago.silva@goteleport.com> Co-authored-by: Noah Stride <noah.stride@goteleport.com>
This commit is contained in:
parent
20d4c79812
commit
0db64bbabe
|
@ -32,15 +32,17 @@ import (
|
|||
"github.com/gravitational/teleport/api/utils/keys"
|
||||
)
|
||||
|
||||
type OnSessionLeaveAction string
|
||||
|
||||
const (
|
||||
// OnSessionLeaveTerminate is a moderated sessions policy constant that terminates
|
||||
// a session once the require policy is no longer fulfilled.
|
||||
OnSessionLeaveTerminate = "terminate"
|
||||
OnSessionLeaveTerminate OnSessionLeaveAction = "terminate"
|
||||
|
||||
// OnSessionLeaveTerminate is a moderated sessions policy constant that pauses
|
||||
// a session once the require policies is no longer fulfilled. It is resumed
|
||||
// once the requirements are fulfilled again.
|
||||
OnSessionLeavePause = "pause"
|
||||
OnSessionLeavePause OnSessionLeaveAction = "pause"
|
||||
)
|
||||
|
||||
// Role contains a set of permissions or settings
|
||||
|
|
|
@ -235,12 +235,12 @@ func SliceContainsMode(s []types.SessionParticipantMode, e types.SessionParticip
|
|||
return false
|
||||
}
|
||||
|
||||
// PolicyOptions is a set of settings for the session determined by the matched require policy.
|
||||
// PolicyOptions is a set of settings for the session determined by the matched required policy.
|
||||
type PolicyOptions struct {
|
||||
TerminateOnLeave bool
|
||||
OnLeaveAction types.OnSessionLeaveAction
|
||||
}
|
||||
|
||||
// Generate a pretty-printed string of precise requirements for session start suitable for user display.
|
||||
// PrettyRequirementsList generates a pretty-printed string of precise requirements for session start suitable for user display.
|
||||
func (e *SessionAccessEvaluator) PrettyRequirementsList() string {
|
||||
s := new(strings.Builder)
|
||||
s.WriteString("require all:")
|
||||
|
@ -275,7 +275,7 @@ func (e *SessionAccessEvaluator) extractApplicablePolicies(set *types.SessionTra
|
|||
|
||||
// FulfilledFor checks if a given session may run with a list of participants.
|
||||
func (e *SessionAccessEvaluator) FulfilledFor(participants []SessionAccessContext) (bool, PolicyOptions, error) {
|
||||
options := PolicyOptions{TerminateOnLeave: true}
|
||||
var options PolicyOptions
|
||||
|
||||
// Check every policy set to check if it's fulfilled.
|
||||
// We need every policy set to match to allow the session.
|
||||
|
@ -286,6 +286,17 @@ policySetLoop:
|
|||
continue
|
||||
}
|
||||
|
||||
if options.OnLeaveAction != types.OnSessionLeaveTerminate {
|
||||
terminateOnLeave := types.OnSessionLeavePause
|
||||
for _, p := range policies {
|
||||
if p.OnLeave != string(types.OnSessionLeavePause) {
|
||||
terminateOnLeave = types.OnSessionLeaveTerminate
|
||||
break
|
||||
}
|
||||
}
|
||||
options = PolicyOptions{OnLeaveAction: terminateOnLeave}
|
||||
}
|
||||
|
||||
// Check every require policy to see if it's fulfilled.
|
||||
// Only one needs to be checked to pass the policyset.
|
||||
for _, requirePolicy := range policies {
|
||||
|
@ -309,10 +320,10 @@ policySetLoop:
|
|||
// Evaluate the filter in the require policy against the participant and allow policy.
|
||||
matchesPredicate, err := e.matchesPredicate(&participant, requirePolicy, allowPolicy)
|
||||
if err != nil {
|
||||
return false, PolicyOptions{}, trace.Wrap(err)
|
||||
return false, options, trace.Wrap(err)
|
||||
}
|
||||
|
||||
// If the the filter matches the participant and the allow policy matches the session
|
||||
// If the filter matches the participant and the allow policy matches the session
|
||||
// we conclude that the participant matches against the require policy.
|
||||
if matchesPredicate && e.matchesJoin(allowPolicy) {
|
||||
left--
|
||||
|
@ -322,13 +333,6 @@ policySetLoop:
|
|||
|
||||
// If we've matched enough participants against the require policy, we can allow the session.
|
||||
if left <= 0 {
|
||||
switch requirePolicy.OnLeave {
|
||||
case types.OnSessionLeaveTerminate:
|
||||
case types.OnSessionLeavePause:
|
||||
options.TerminateOnLeave = false
|
||||
default:
|
||||
}
|
||||
|
||||
// We matched at least one require policy within the set. Continue ahead.
|
||||
continue policySetLoop
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ type startTestCase struct {
|
|||
participants []SessionAccessContext
|
||||
owner string
|
||||
expected []bool
|
||||
terminate types.OnSessionLeaveAction
|
||||
}
|
||||
|
||||
func successStartTestCase(t *testing.T) startTestCase {
|
||||
|
@ -43,14 +44,14 @@ func successStartTestCase(t *testing.T) startTestCase {
|
|||
Filter: "contains(user.roles, \"participant\")",
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Count: 2,
|
||||
OnLeave: types.OnSessionLeaveTerminate,
|
||||
OnLeave: string(types.OnSessionLeaveTerminate),
|
||||
Modes: []string{"peer"},
|
||||
}})
|
||||
|
||||
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
|
||||
Roles: []string{hostRole.GetName()},
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Modes: []string{string("*")},
|
||||
Modes: []string{"*"},
|
||||
}})
|
||||
|
||||
return startTestCase{
|
||||
|
@ -69,7 +70,99 @@ func successStartTestCase(t *testing.T) startTestCase {
|
|||
Mode: "peer",
|
||||
},
|
||||
},
|
||||
expected: []bool{true, true},
|
||||
expected: []bool{true, true},
|
||||
terminate: types.OnSessionLeaveTerminate,
|
||||
}
|
||||
}
|
||||
|
||||
func successStartTestCasePause(t *testing.T) startTestCase {
|
||||
hostRole, err := types.NewRole("host", types.RoleSpecV6{})
|
||||
require.NoError(t, err)
|
||||
participantRole, err := types.NewRole("participant", types.RoleSpecV6{})
|
||||
require.NoError(t, err)
|
||||
|
||||
hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{{
|
||||
Filter: "contains(user.roles, \"participant\")",
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Count: 2,
|
||||
OnLeave: string(types.OnSessionLeavePause),
|
||||
Modes: []string{"peer"},
|
||||
}})
|
||||
|
||||
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
|
||||
Roles: []string{hostRole.GetName()},
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Modes: []string{"*"},
|
||||
}})
|
||||
|
||||
return startTestCase{
|
||||
name: "successStartTestCasePause",
|
||||
host: []types.Role{hostRole},
|
||||
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
|
||||
participants: []SessionAccessContext{
|
||||
{
|
||||
Username: "participant",
|
||||
Roles: []types.Role{participantRole},
|
||||
Mode: "peer",
|
||||
},
|
||||
{
|
||||
Username: "participant2",
|
||||
Roles: []types.Role{participantRole},
|
||||
Mode: "peer",
|
||||
},
|
||||
},
|
||||
expected: []bool{true, true},
|
||||
terminate: types.OnSessionLeavePause,
|
||||
}
|
||||
}
|
||||
|
||||
func pauseCanBeOverwritten(t *testing.T) startTestCase {
|
||||
hostRole, err := types.NewRole("host", types.RoleSpecV6{})
|
||||
require.NoError(t, err)
|
||||
participantRole, err := types.NewRole("participant", types.RoleSpecV6{})
|
||||
require.NoError(t, err)
|
||||
|
||||
hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{
|
||||
{
|
||||
Filter: "contains(user.roles, \"participant\")",
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Count: 2,
|
||||
OnLeave: string(types.OnSessionLeavePause),
|
||||
Modes: []string{"peer"},
|
||||
},
|
||||
{
|
||||
Filter: "contains(user.roles, \"participant\")",
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Count: 2,
|
||||
OnLeave: string(types.OnSessionLeaveTerminate),
|
||||
Modes: []string{"peer"},
|
||||
},
|
||||
})
|
||||
|
||||
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
|
||||
Roles: []string{hostRole.GetName()},
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Modes: []string{"*"},
|
||||
}})
|
||||
|
||||
return startTestCase{
|
||||
name: "pauseCanBeOverwritten",
|
||||
host: []types.Role{hostRole},
|
||||
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
|
||||
participants: []SessionAccessContext{
|
||||
{
|
||||
Username: "participant",
|
||||
Roles: []types.Role{participantRole},
|
||||
Mode: "peer",
|
||||
},
|
||||
{
|
||||
Username: "participant2",
|
||||
Roles: []types.Role{participantRole},
|
||||
Mode: "peer",
|
||||
},
|
||||
},
|
||||
expected: []bool{true, true},
|
||||
terminate: types.OnSessionLeaveTerminate,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,7 +176,7 @@ func successStartTestCaseSpec(t *testing.T) startTestCase {
|
|||
Filter: "contains(user.spec.roles, \"participant\")",
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Count: 2,
|
||||
OnLeave: types.OnSessionLeaveTerminate,
|
||||
OnLeave: string(types.OnSessionLeaveTerminate),
|
||||
Modes: []string{"peer"},
|
||||
}})
|
||||
|
||||
|
@ -109,7 +202,8 @@ func successStartTestCaseSpec(t *testing.T) startTestCase {
|
|||
Mode: "peer",
|
||||
},
|
||||
},
|
||||
expected: []bool{true, true},
|
||||
expected: []bool{true, true},
|
||||
terminate: types.OnSessionLeaveTerminate,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,7 +223,7 @@ func failCountStartTestCase(t *testing.T) startTestCase {
|
|||
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
|
||||
Roles: []string{hostRole.GetName()},
|
||||
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
|
||||
Modes: []string{string("*")},
|
||||
Modes: []string{"*"},
|
||||
}})
|
||||
|
||||
return startTestCase{
|
||||
|
@ -148,7 +242,8 @@ func failCountStartTestCase(t *testing.T) startTestCase {
|
|||
Mode: "peer",
|
||||
},
|
||||
},
|
||||
expected: []bool{false, false},
|
||||
expected: []bool{false, false},
|
||||
terminate: types.OnSessionLeaveTerminate,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -187,7 +282,7 @@ func failFilterStartTestCase(t *testing.T) startTestCase {
|
|||
participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
|
||||
Roles: []string{hostRole.GetName()},
|
||||
Kinds: []string{string(types.SSHSessionKind)},
|
||||
Modes: []string{string("*")},
|
||||
Modes: []string{"*"},
|
||||
}})
|
||||
|
||||
return startTestCase{
|
||||
|
@ -206,21 +301,29 @@ func failFilterStartTestCase(t *testing.T) startTestCase {
|
|||
Mode: "peer",
|
||||
},
|
||||
},
|
||||
expected: []bool{false},
|
||||
expected: []bool{false},
|
||||
terminate: types.OnSessionLeaveTerminate,
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionAccessStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []startTestCase{
|
||||
successStartTestCase(t),
|
||||
successStartTestCasePause(t),
|
||||
successStartTestCaseSpec(t),
|
||||
failCountStartTestCase(t),
|
||||
failFilterStartTestCase(t),
|
||||
succeedDiscardPolicySetStartTestCase(t),
|
||||
pauseCanBeOverwritten(t),
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var policies []*types.SessionTrackerPolicySet
|
||||
for _, role := range testCase.host {
|
||||
policySet := role.GetSessionPolicySet()
|
||||
|
@ -229,9 +332,10 @@ func TestSessionAccessStart(t *testing.T) {
|
|||
|
||||
for i, kind := range testCase.sessionKinds {
|
||||
evaluator := NewSessionAccessEvaluator(policies, kind, testCase.owner)
|
||||
result, _, err := evaluator.FulfilledFor(testCase.participants)
|
||||
result, policyOptions, err := evaluator.FulfilledFor(testCase.participants)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testCase.expected[i], result)
|
||||
require.Equal(t, testCase.terminate, policyOptions.OnLeaveAction)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -383,6 +487,8 @@ func versionDefaultJoinTestCase(t *testing.T) joinTestCase {
|
|||
}
|
||||
|
||||
func TestSessionAccessJoin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []joinTestCase{
|
||||
successJoinTestCase(t),
|
||||
successGlobJoinTestCase(t),
|
||||
|
|
|
@ -927,12 +927,12 @@ func (s *session) join(p *party) error {
|
|||
}()
|
||||
}
|
||||
|
||||
if !s.started {
|
||||
canStart, _, err := s.canStart()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
canStart, _, err := s.canStart()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
if !s.started {
|
||||
if canStart {
|
||||
go func() {
|
||||
if err := s.launch(); err != nil {
|
||||
|
@ -948,6 +948,18 @@ func (s *session) join(p *party) error {
|
|||
s.BroadcastMessage(base)
|
||||
}
|
||||
}
|
||||
} else if canStart && s.tracker.GetState() == types.SessionState_SessionStatePending {
|
||||
// If the session is already running, but the party is a moderator that left
|
||||
// a session with onLeave=pause and then rejoined, we need to unpause the session.
|
||||
// When the moderator left the session, the session was paused, and we spawn
|
||||
// a goroutine to wait for the moderator to rejoin. If the moderator rejoins
|
||||
// before the session ends, we need to unpause the session by updating its state and
|
||||
// the goroutine will unblock the s.io terminal.
|
||||
// types.SessionState_SessionStatePending marks a session that is waiting for
|
||||
// a moderator to rejoin.
|
||||
if err := s.tracker.UpdateState(s.forwarder.ctx, types.SessionState_SessionStateRunning); err != nil {
|
||||
s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1049,7 +1061,7 @@ func (s *session) unlockedLeave(id uuid.UUID) (bool, error) {
|
|||
}
|
||||
|
||||
if !canStart {
|
||||
if options.TerminateOnLeave {
|
||||
if options.OnLeaveAction == types.OnSessionLeaveTerminate {
|
||||
go func() {
|
||||
if err := s.Close(); err != nil {
|
||||
s.log.WithError(err).Errorf("Failed to close session")
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"os/user"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -498,6 +499,9 @@ type session struct {
|
|||
|
||||
// serverMeta contains metadata about the target node of this session.
|
||||
serverMeta apievents.ServerMetadata
|
||||
|
||||
// started is true after the session start.
|
||||
started atomic.Bool
|
||||
}
|
||||
|
||||
// newSession creates a new session with a given ID within a given context.
|
||||
|
@ -876,7 +880,13 @@ func (s *session) setHasEnhancedRecording(val bool) {
|
|||
|
||||
// launch launches the session.
|
||||
// Must be called under session Lock.
|
||||
func (s *session) launch(ctx *ServerContext) error {
|
||||
func (s *session) launch() {
|
||||
// Mark the session as started here, as we want to avoid double initialization.
|
||||
if s.started.Swap(true) {
|
||||
s.log.Debugf("Session has already started")
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Debug("Launching session")
|
||||
s.BroadcastMessage("Connecting to %v over SSH", s.serverMeta.ServerHostname)
|
||||
|
||||
|
@ -933,8 +943,6 @@ func (s *session) launch(ctx *ServerContext) error {
|
|||
_, err := io.Copy(s.term.PTY(), s.io)
|
||||
s.log.Debugf("Copying from reader to PTY completed with error %v.", err)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startInteractive starts a new interactive process (or a shell) in the
|
||||
|
@ -1276,7 +1284,7 @@ func (s *session) removePartyUnderLock(p *party) error {
|
|||
// Remove party for the term writer
|
||||
s.io.DeleteWriter(string(p.id))
|
||||
|
||||
// Emit session leave event to both the Audit Log as well as over the
|
||||
// Emit session leave event to both the Audit Log and over the
|
||||
// "x-teleport-event" channel in the SSH connection.
|
||||
s.emitSessionLeaveEvent(p.ctx)
|
||||
|
||||
|
@ -1286,7 +1294,7 @@ func (s *session) removePartyUnderLock(p *party) error {
|
|||
}
|
||||
|
||||
if !canRun {
|
||||
if policyOptions.TerminateOnLeave {
|
||||
if policyOptions.OnLeaveAction == types.OnSessionLeaveTerminate {
|
||||
// Force termination in goroutine to avoid deadlock
|
||||
go s.registry.ForceTerminate(s.scx)
|
||||
return nil
|
||||
|
@ -1474,18 +1482,30 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
if canStart {
|
||||
if err := s.launch(s.scx); err != nil {
|
||||
s.log.WithError(err).Error("Failed to launch session")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
switch {
|
||||
case canStart && !s.started.Load():
|
||||
s.launch()
|
||||
|
||||
base := "Waiting for required participants..."
|
||||
if s.displayParticipantRequirements {
|
||||
s.BroadcastMessage(base+"\r\n%v", s.access.PrettyRequirementsList())
|
||||
} else {
|
||||
s.BroadcastMessage(base)
|
||||
return nil
|
||||
case canStart:
|
||||
// If the session is already running, but the party is a moderator that leaved
|
||||
// a session with onLeave=pause and then rejoined, we need to unpause the session.
|
||||
// When the moderator leaved the session, the session was paused, and we spawn
|
||||
// a goroutine to wait for the moderator to rejoin. If the moderator rejoins
|
||||
// before the session ends, we need to unpause the session by updating its state and
|
||||
// the goroutine will unblock the s.io terminal.
|
||||
// types.SessionState_SessionStatePending marks a session that is waiting for
|
||||
// a moderator to rejoin.
|
||||
if err := s.tracker.UpdateState(s.serverCtx, types.SessionState_SessionStateRunning); err != nil {
|
||||
s.log.Warnf("Failed to set tracker state to %v", types.SessionState_SessionStateRunning)
|
||||
}
|
||||
default:
|
||||
const base = "Waiting for required participants..."
|
||||
if s.displayParticipantRequirements {
|
||||
s.BroadcastMessage(base+"\r\n%v", s.access.PrettyRequirementsList())
|
||||
} else {
|
||||
s.BroadcastMessage(base)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue