Guard server session with a mutex to prevent races when the session is (#5365)

used from multiple goroutines.
This commit is contained in:
a-palchikov 2021-02-03 13:47:49 +01:00 committed by GitHub
parent cc35ce0912
commit aa5c5223a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 156 additions and 120 deletions

View file

@ -175,7 +175,7 @@ type ServerContext struct {
*sshutils.ConnectionContext
*log.Entry
sync.RWMutex
mu sync.RWMutex
// env is a list of environment variables passed to the session.
env map[string]string
@ -387,6 +387,11 @@ func (c *ServerContext) ID() int {
// SessionID returns the ID of the session in the context.
func (c *ServerContext) SessionID() rsession.ID {
c.mu.RLock()
defer c.mu.RUnlock()
if c.session == nil {
return ""
}
return c.session.id
}
@ -398,9 +403,11 @@ func (c *ServerContext) GetServer() Server {
// CreateOrJoinSession will look in the SessionRegistry for the session ID. If
// no session is found, a new one is created. If one is found, it is returned.
func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
c.mu.Lock()
defer c.mu.Unlock()
// As SSH conversation progresses, at some point a session will be created and
// its ID will be added to the environment
ssid, found := c.GetEnv(sshutils.SessionEnvVar)
ssid, found := c.getEnvLocked(sshutils.SessionEnvVar)
if !found {
return nil
}
@ -411,9 +418,9 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
}
findSession := func() (*session, bool) {
reg.Lock()
defer reg.Unlock()
return reg.findSession(rsession.ID(ssid))
reg.mu.Lock()
defer reg.mu.Unlock()
return reg.findSessionLocked(rsession.ID(ssid))
}
// update ctx with a session ID
@ -421,7 +428,7 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
if c.session == nil {
log.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr())
} else {
log.Debugf("Will join session %v for SSH connection %v.", c.session, c.ServerConn.RemoteAddr())
log.Debugf("Will join session %v for SSH connection %v.", c.session.id, c.ServerConn.RemoteAddr())
}
return nil
@ -435,45 +442,47 @@ func (c *ServerContext) TrackActivity(ch ssh.Channel) ssh.Channel {
// GetClientLastActive returns time when client was last active
func (c *ServerContext) GetClientLastActive() time.Time {
c.RLock()
defer c.RUnlock()
c.mu.RLock()
defer c.mu.RUnlock()
return c.clientLastActive
}
// UpdateClientActivity sets last recorded client activity associated with this context
// either channel or session
func (c *ServerContext) UpdateClientActivity() {
c.Lock()
defer c.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
c.clientLastActive = c.srv.GetClock().Now().UTC()
}
// AddCloser adds any closer in ctx that will be called
// whenever server closes session channel
func (c *ServerContext) AddCloser(closer io.Closer) {
c.Lock()
defer c.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
c.closers = append(c.closers, closer)
}
// GetTerm returns a Terminal.
func (c *ServerContext) GetTerm() Terminal {
c.RLock()
defer c.RUnlock()
c.mu.RLock()
defer c.mu.RUnlock()
return c.term
}
// SetTerm set a Terminal.
func (c *ServerContext) SetTerm(t Terminal) {
c.Lock()
defer c.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
c.term = t
}
// VisitEnv grants visitor-style access to env variables.
func (c *ServerContext) VisitEnv(visit func(key, val string)) {
c.mu.RLock()
defer c.mu.RUnlock()
// visit the parent env first since locally defined variables
// effectively "override" parent defined variables.
c.Parent().VisitEnv(visit)
@ -484,11 +493,19 @@ func (c *ServerContext) VisitEnv(visit func(key, val string)) {
// SetEnv sets a environment variable within this context.
func (c *ServerContext) SetEnv(key, val string) {
c.mu.Lock()
c.env[key] = val
c.mu.Unlock()
}
// GetEnv returns a environment variable within this context.
func (c *ServerContext) GetEnv(key string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.getEnvLocked(key)
}
func (c *ServerContext) getEnvLocked(key string) (string, bool) {
val, ok := c.env[key]
if ok {
return val, true
@ -496,12 +513,26 @@ func (c *ServerContext) GetEnv(key string) (string, bool) {
return c.Parent().GetEnv(key)
}
// setSession sets the context's session
func (c *ServerContext) setSession(sess *session) {
c.mu.Lock()
defer c.mu.Unlock()
c.session = sess
}
// getSession returns the context's session
func (c *ServerContext) getSession() *session {
c.mu.RLock()
defer c.mu.RUnlock()
return c.session
}
// takeClosers returns all resources that should be closed and sets the properties to null
// we do this to avoid calling Close() under lock to avoid potential deadlocks
func (c *ServerContext) takeClosers() []io.Closer {
// this is done to avoid any operation holding the lock for too long
c.Lock()
defer c.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
closers := []io.Closer{}
if c.term != nil {
@ -557,8 +588,9 @@ func (c *ServerContext) reportStats(conn utils.Stater) {
if !c.srv.UseTunnel() {
sessionDataEvent.ConnectionMetadata.LocalAddr = c.ServerConn.LocalAddr().String()
}
if c.session != nil {
sessionDataEvent.SessionMetadata.SessionID = string(c.session.id)
sessionID := string(c.SessionID())
if sessionID != "" {
sessionDataEvent.SessionMetadata.SessionID = sessionID
}
if err := c.GetServer().EmitAuditEvent(c.GetServer().Context(), sessionDataEvent); err != nil {
c.WithError(err).Warn("Failed to emit session data event.")
@ -735,13 +767,14 @@ func buildEnvironment(ctx *ServerContext) []string {
}
// If a session has been created try and set TERM, SSH_TTY, and SSH_SESSION_ID.
if ctx.session != nil {
if ctx.session.term != nil {
env = append(env, fmt.Sprintf("TERM=%v", ctx.session.term.GetTermType()))
env = append(env, fmt.Sprintf("SSH_TTY=%s", ctx.session.term.TTY().Name()))
session := ctx.getSession()
if session != nil {
if session.term != nil {
env = append(env, fmt.Sprintf("TERM=%v", session.term.GetTermType()))
env = append(env, fmt.Sprintf("SSH_TTY=%s", session.term.TTY().Name()))
}
if ctx.session.id != "" {
env = append(env, fmt.Sprintf("%s=%s", teleport.SSHSessionID, ctx.session.id))
if session.id != "" {
env = append(env, fmt.Sprintf("%s=%s", teleport.SSHSessionID, session.id))
}
}

View file

@ -366,8 +366,9 @@ func emitExecAuditEvent(ctx *ServerContext, cmd string, execErr error) {
}
var sessionMeta events.SessionMetadata
if ctx.session != nil {
sessionMeta.SessionID = string(ctx.session.id)
sessionID := string(ctx.SessionID())
if sessionID != "" {
sessionMeta.SessionID = sessionID
}
userMeta := events.UserMetadata{

View file

@ -215,9 +215,7 @@ func (t *proxySubsys) Start(sconn *ssh.ServerConn, ch ssh.Channel, req *ssh.Requ
)
// did the client pass us a true client IP ahead of time via an environment variable?
// (usually the web client would do that)
ctx.Lock()
trueClientIP, ok := ctx.GetEnv(sshutils.TrueClientAddrVar)
ctx.Unlock()
if ok {
a, err := utils.ParseAddr(trueClientIP)
if err == nil {

View file

@ -65,7 +65,7 @@ func init() {
// SessionRegistry holds a map of all active sessions on a given
// SSH server
type SessionRegistry struct {
sync.Mutex
mu sync.Mutex
// log holds the structured logger
log *logrus.Entry
@ -94,31 +94,31 @@ func NewSessionRegistry(srv Server) (*SessionRegistry, error) {
}
func (s *SessionRegistry) addSession(sess *session) {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sess.id] = sess
}
func (s *SessionRegistry) removeSession(sess *session) {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sess.id)
}
func (s *SessionRegistry) findSession(id rsession.ID) (*session, bool) {
func (s *SessionRegistry) findSessionLocked(id rsession.ID) (*session, bool) {
sess, found := s.sessions[id]
return sess, found
}
func (s *SessionRegistry) Close() {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
for _, se := range s.sessions {
se.Close()
}
s.log.Debugf("Closing Session Registry.")
s.log.Debug("Closing Session Registry.")
}
// emitSessionJoinEvent emits a session join event to both the Audit Log as
@ -138,7 +138,7 @@ func (s *SessionRegistry) emitSessionJoinEvent(ctx *ServerContext) {
ServerAddr: ctx.ServerConn.LocalAddr().String(),
},
SessionMetadata: events.SessionMetadata{
SessionID: string(ctx.session.id),
SessionID: string(ctx.SessionID()),
},
UserMetadata: events.UserMetadata{
User: ctx.Identity.TeleportUser,
@ -154,13 +154,14 @@ func (s *SessionRegistry) emitSessionJoinEvent(ctx *ServerContext) {
}
// Emit session join event to Audit Log.
if err := ctx.session.recorder.EmitAuditEvent(ctx.srv.Context(), sessionJoinEvent); err != nil {
session := ctx.getSession()
if err := session.recorder.EmitAuditEvent(ctx.srv.Context(), sessionJoinEvent); err != nil {
s.log.WithError(err).Warn("Failed to emit session join event.")
}
// Notify all members of the party that a new member has joined over the
// "x-teleport-event" channel.
for _, p := range s.getParties(ctx.session) {
for _, p := range session.getParties() {
eventPayload, err := json.Marshal(sessionJoinEvent)
if err != nil {
s.log.Warnf("Unable to marshal %v for %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err)
@ -177,11 +178,12 @@ func (s *SessionRegistry) emitSessionJoinEvent(ctx *ServerContext) {
// OpenSession either joins an existing session or starts a new session.
func (s *SessionRegistry) OpenSession(ch ssh.Channel, req *ssh.Request, ctx *ServerContext) error {
if ctx.session != nil {
ctx.Infof("Joining existing session %v.", ctx.session.id)
session := ctx.getSession()
if session != nil {
ctx.Infof("Joining existing session %v.", session.id)
// Update the in-memory data structure that a party member has joined.
_, err := ctx.session.join(ch, req, ctx)
_, err := session.join(ch, req, ctx)
if err != nil {
return trace.Wrap(err)
}
@ -204,7 +206,7 @@ func (s *SessionRegistry) OpenSession(ch ssh.Channel, req *ssh.Request, ctx *Ser
if err != nil {
return trace.Wrap(err)
}
ctx.session = sess
ctx.setSession(sess)
s.addSession(sess)
ctx.Infof("Creating (interactive) session %v.", sid)
@ -233,7 +235,7 @@ func (s *SessionRegistry) OpenExecSession(channel ssh.Channel, req *ssh.Request,
// Start a non-interactive session (TTY attached). Close the session if an error
// occurs, otherwise it will be closed by the callee.
ctx.session = sess
ctx.setSession(sess)
err = sess.startExec(channel, ctx)
defer sess.Close()
if err != nil {
@ -274,7 +276,7 @@ func (s *SessionRegistry) emitSessionLeaveEvent(party *party) {
// Notify all members of the party that a new member has left over the
// "x-teleport-event" channel.
for _, p := range s.getParties(party.s) {
for _, p := range party.s.getParties() {
eventPayload, err := utils.FastMarshal(sessionLeaveEvent)
if err != nil {
s.log.Warnf("Unable to marshal %v for %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err)
@ -292,8 +294,8 @@ func (s *SessionRegistry) emitSessionLeaveEvent(party *party) {
// leaveSession removes the given party from this session.
func (s *SessionRegistry) leaveSession(party *party) error {
sess := party.s
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
// Emit session leave event to both the Audit Log as well as over the
// "x-teleport-event" channel in the SSH connection.
@ -378,34 +380,16 @@ func (s *SessionRegistry) leaveSession(party *party) error {
return nil
}
// getParties allows to safely return a list of parties connected to this
// session (as determined by ctx)
func (s *SessionRegistry) getParties(sess *session) []*party {
var parties []*party
if sess == nil {
return parties
}
sess.Lock()
defer sess.Unlock()
for _, p := range sess.parties {
parties = append(parties, p)
}
return parties
}
// NotifyWinChange is called to notify all members in the party that the PTY
// size has changed. The notification is sent as a global SSH request and it
// is the responsibility of the client to update it's window size upon receipt.
func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *ServerContext) error {
if ctx.session == nil {
s.log.Debugf("Unable to update window size, no session found in context.")
session := ctx.getSession()
if session == nil {
s.log.Debug("Unable to update window size, no session found in context.")
return nil
}
sid := ctx.session.id
sid := session.id
// Build the resize event.
resizeEvent := &events.Resize{
@ -433,12 +417,12 @@ func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *S
// Report the updated window size to the event log (this is so the sessions
// can be replayed correctly).
if err := ctx.session.recorder.EmitAuditEvent(s.srv.Context(), resizeEvent); err != nil {
if err := session.recorder.EmitAuditEvent(s.srv.Context(), resizeEvent); err != nil {
s.log.WithError(err).Warn("Failed to emit resize audit event.")
}
// Update the size of the server side PTY.
err := ctx.session.term.SetWinSize(params)
err := session.term.SetWinSize(params)
if err != nil {
return trace.Wrap(err)
}
@ -453,7 +437,7 @@ func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *S
// Notify all members of the party (except originator) that the size of the
// window has changed so the client can update it's own local PTY. Note that
// OpenSSH clients will ignore this and not update their own local PTY.
for _, p := range s.getParties(ctx.session) {
for _, p := range session.getParties() {
// Don't send the window change notification back to the originator.
if p.ctx.ID() == ctx.ID() {
continue
@ -478,10 +462,10 @@ func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *S
}
func (s *SessionRegistry) broadcastResult(sid rsession.ID, r ExecResult) error {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
sess, found := s.findSession(sid)
sess, found := s.findSessionLocked(sid)
if !found {
return trace.NotFound("session %v not found", sid)
}
@ -492,7 +476,7 @@ func (s *SessionRegistry) broadcastResult(sid rsession.ID, r ExecResult) error {
// session struct describes an active (in progress) SSH session. These sessions
// are managed by 'SessionRegistry' containers which are attached to SSH servers.
type session struct {
sync.Mutex
mu sync.RWMutex
// log holds the structured logger
log *logrus.Entry
@ -617,17 +601,23 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
// ID returns a string representation of the session ID.
func (s *session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id.String()
}
// PID returns the PID of the Teleport process under which the shell is running.
func (s *session) PID() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.term.PID()
}
// Recorder returns a events.SessionRecorder which can be used to emit events
// to a session as well as the audit log.
func (s *session) Recorder() events.StreamWriter {
s.mu.RLock()
defer s.mu.RUnlock()
return s.recorder
}
@ -646,15 +636,7 @@ func (s *session) Close() error {
close(s.closeC)
// close all writers in our multi-writer
s.writer.Lock()
defer s.writer.Unlock()
for writerName, writer := range s.writer.writers {
s.log.Debugf("Closing session writer: %v.", writerName)
closer, ok := io.Writer(writer).(io.WriteCloser)
if ok {
closer.Close()
}
}
s.writer.Close()
}()
})
return nil
@ -663,8 +645,8 @@ func (s *session) Close() error {
// isLingering returns true if every party has left this session. Occurs
// under a lock.
func (s *session) isLingering() bool {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
return len(s.parties) == 0
}
@ -1085,8 +1067,8 @@ func (s *session) String() string {
// removePartyMember removes participant from in-memory representation of
// party members. Occurs under a lock.
func (s *session) removePartyMember(party *party) {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
delete(s.parties, party.id)
}
@ -1105,14 +1087,14 @@ func (s *session) removeParty(p *party) error {
}
func (s *session) GetLingerTTL() time.Duration {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
return s.lingerTTL
}
func (s *session) SetLingerTTL(ttl time.Duration) {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
s.lingerTTL = ttl
}
@ -1123,8 +1105,8 @@ func (s *session) getNamespace() string {
// exportPartyMembers exports participants in the in-memory map of party
// members. Occurs under a lock.
func (s *session) exportPartyMembers() []rsession.Party {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
var partyList []rsession.Party
for _, p := range s.parties {
@ -1142,8 +1124,8 @@ func (s *session) exportPartyMembers() []rsession.Party {
// exportParticipants returns a list of all members that joined the party.
func (s *session) exportParticipants() []string {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
var participants []string
for _, p := range s.participants {
@ -1201,8 +1183,8 @@ func (s *session) heartbeat(ctx *ServerContext) {
// addPartyMember adds participant to in-memory map of party members. Occurs
// under a lock.
func (s *session) addPartyMember(p *party) {
s.Lock()
defer s.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
s.parties[p.id] = p
s.participants[p.id] = p
@ -1221,16 +1203,7 @@ func (s *session) addParty(p *party) error {
// Write last chunk (so the newly joined parties won't stare at a blank
// screen).
getRecentWrite := func() []byte {
s.writer.Lock()
defer s.writer.Unlock()
data := make([]byte, 0, 1024)
for i := range s.writer.recentWrites {
data = append(data, s.writer.recentWrites[i]...)
}
return data
}
if _, err := p.Write(getRecentWrite()); err != nil {
if _, err := p.Write(s.writer.getRecentWrites()); err != nil {
return trace.Wrap(err)
}
@ -1261,12 +1234,21 @@ func (s *session) join(ch ssh.Channel, req *ssh.Request, ctx *ServerContext) (*p
return p, nil
}
func (s *session) getParties() (parties []*party) {
s.mu.Lock()
defer s.mu.Unlock()
for _, p := range s.parties {
parties = append(parties, p)
}
return parties
}
func newMultiWriter() *multiWriter {
return &multiWriter{writers: make(map[string]writerWrapper)}
}
type multiWriter struct {
sync.RWMutex
mu sync.RWMutex
writers map[string]writerWrapper
recentWrites [][]byte
}
@ -1277,14 +1259,14 @@ type writerWrapper struct {
}
func (m *multiWriter) addWriter(id string, w io.WriteCloser, closeOnError bool) {
m.Lock()
defer m.Unlock()
m.mu.Lock()
defer m.mu.Unlock()
m.writers[id] = writerWrapper{WriteCloser: w, closeOnError: closeOnError}
}
func (m *multiWriter) deleteWriter(id string) {
m.Lock()
defer m.Unlock()
m.mu.Lock()
defer m.mu.Unlock()
delete(m.writers, id)
}
@ -1304,8 +1286,8 @@ func (m *multiWriter) lockedAddRecentWrite(p []byte) {
func (m *multiWriter) Write(p []byte) (n int, err error) {
// lock and make a local copy of available writers:
getWriters := func() (writers []writerWrapper) {
m.RLock()
defer m.RUnlock()
m.mu.RLock()
defer m.mu.RUnlock()
writers = make([]writerWrapper, 0, len(m.writers))
for _, w := range m.writers {
writers = append(writers, w)
@ -1334,6 +1316,28 @@ func (m *multiWriter) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (m *multiWriter) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
for writerName, writer := range m.writers {
logrus.Debugf("Closing session writer: %v.", writerName)
if closer, ok := writer.WriteCloser.(io.Closer); ok {
closer.Close()
}
}
return nil
}
func (m *multiWriter) getRecentWrites() []byte {
m.mu.Lock()
defer m.mu.Unlock()
data := make([]byte, 0, 1024)
for i := range m.recentWrites {
data = append(data, m.recentWrites[i]...)
}
return data
}
type party struct {
sync.Mutex

View file

@ -634,7 +634,7 @@ func (t *remoteTerminal) prepareRemoteSession(session *ssh.Session, ctx *ServerC
teleport.SSHSessionWebproxyAddr: ctx.ProxyPublicAddress(),
teleport.SSHTeleportHostUUID: ctx.srv.ID(),
teleport.SSHTeleportClusterName: ctx.ClusterName,
teleport.SSHSessionID: string(ctx.session.id),
teleport.SSHSessionID: string(ctx.SessionID()),
}
for k, v := range envs {