Easier bookkeeping of sessin parties

This commit is contained in:
Ev Kontsevoy 2016-05-01 17:31:28 -07:00
parent 68e8e8f001
commit f4dfbf2e50
11 changed files with 132 additions and 197 deletions

View file

@ -49,9 +49,7 @@ type Config struct {
// APIServer implements http API server for AuthServer interface
type APIServer struct {
httprouter.Router
a *AuthWithRoles
s *AuthServer
se session.Service
a *AuthWithRoles
}
// NewAPIServer returns a new instance of APIServer HTTP handler
@ -110,7 +108,6 @@ func NewAPIServer(a *AuthWithRoles) *APIServer {
srv.POST("/v1/tokens/register/auth", httplib.MakeHandler(srv.registerNewAuthServer))
// Sesssions
srv.POST("/v1/sessions/:id/parties", httplib.MakeHandler(srv.upsertSessionParty))
srv.POST("/v1/sessions", httplib.MakeHandler(srv.createSession))
srv.PUT("/v1/sessions/:id", httplib.MakeHandler(srv.updateSession))
srv.GET("/v1/sessions", httplib.MakeHandler(srv.getSessions))
@ -580,27 +577,6 @@ func (s *APIServer) updateSession(w http.ResponseWriter, r *http.Request, p http
return message("ok"), nil
}
type upsertPartyReq struct {
Party session.Party `json:"party"`
TTL time.Duration `json:"ttl"`
}
func (s *APIServer) upsertSessionParty(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
var req *upsertPartyReq
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}
sid, err := session.ParseID(p[0].Value)
if err != nil {
return nil, trace.Wrap(err)
}
if err := s.a.UpsertParty(*sid, req.Party, req.TTL); err != nil {
return nil, trace.Wrap(err)
}
return req.Party, nil
}
func (s *APIServer) getSessions(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
sessions, err := s.a.GetSessions()
if err != nil {

View file

@ -394,34 +394,3 @@ func (s *APISuite) TestSharedSessions(c *C) {
c.Assert(history[0].GetString(events.SessionEventID), Equals, string(sess.ID))
c.Assert(history[0].GetString("1"), Equals, "one")
}
func (s *APISuite) TestSharedSessionsParties(c *C) {
out, err := s.clt.GetSessions()
c.Assert(err, IsNil)
c.Assert(out, DeepEquals, []session.Session{})
date := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
sess := session.Session{
Active: true,
ID: session.NewID(),
TerminalParams: session.TerminalParams{W: 100, H: 100},
Created: date,
LastActive: date,
Login: "bob",
}
c.Assert(s.clt.CreateSession(sess), IsNil)
p1 := session.Party{
ID: session.NewID(),
User: "bob",
RemoteAddr: "example.com",
ServerID: "id-1",
LastActive: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
}
c.Assert(s.clt.UpsertParty(sess.ID, p1, 0), IsNil)
sess.Parties = []session.Party{p1}
out, err = s.clt.GetSessions()
c.Assert(err, IsNil)
c.Assert(out, DeepEquals, []session.Session{sess})
}

View file

@ -48,6 +48,7 @@ func (a *AuthWithRoles) GetSession(id session.ID) (*session.Session, error) {
if err := a.permChecker.HasPermission(a.role, ActionGetSession); err != nil {
return nil, trace.Wrap(err)
}
// TODO (ev): need session registry here
return a.sessions.GetSession(id)
}
@ -64,12 +65,6 @@ func (a *AuthWithRoles) UpdateSession(req session.UpdateRequest) error {
return a.sessions.UpdateSession(req)
}
func (a *AuthWithRoles) UpsertParty(id session.ID, p session.Party, ttl time.Duration) error {
if err := a.permChecker.HasPermission(a.role, ActionUpsertParty); err != nil {
return trace.Wrap(err)
}
return a.sessions.UpsertParty(id, p, ttl)
}
func (a *AuthWithRoles) UpsertCertAuthority(ca services.CertAuthority, ttl time.Duration) error {
if err := a.permChecker.HasPermission(a.role, ActionUpsertCertAuthority); err != nil {
return trace.Wrap(err)

View file

@ -161,16 +161,6 @@ func (c *Client) UpdateSession(req session.UpdateRequest) error {
return trace.Wrap(err)
}
// UpsertParty updates existing session party or inserts new party
func (c *Client) UpsertParty(id session.ID, p session.Party, ttl time.Duration) error {
// saving extra round-trip
if err := id.Check(); err != nil {
return trace.Wrap(err)
}
_, err := c.PostJSON(c.Endpoint("sessions", string(id), "parties"), upsertPartyReq{Party: p, TTL: ttl})
return trace.Wrap(err)
}
// GetLocalDomain returns local auth domain of the current auth server
func (c *Client) GetLocalDomain() (string, error) {
out, err := c.Get(c.Endpoint("domain"), url.Values{})
@ -819,7 +809,6 @@ type ClientI interface {
GetSession(id session.ID) (*session.Session, error)
CreateSession(s session.Session) error
UpdateSession(req session.UpdateRequest) error
UpsertParty(id session.ID, p session.Party, ttl time.Duration) error
UpsertCertAuthority(cert services.CertAuthority, ttl time.Duration) error
GetCertAuthorities(caType services.CertAuthType, loadKeys bool) ([]*services.CertAuthority, error)
DeleteCertAuthority(caType services.CertAuthID) error

View file

@ -67,7 +67,6 @@ func NewStandardPermissions() PermissionChecker {
ActionGetUser: true,
ActionGetLocalDomain: true,
ActionGetUserKeys: true,
ActionUpsertParty: true,
ActionUpsertSession: true,
ActionGetSession: true,
ActionGetSessions: true,
@ -164,7 +163,6 @@ const (
ActionViewSession = "ViewSession"
ActionDeleteSession = "DeleteSession"
ActionUpsertSession = "UpsertSession"
ActionUpsertParty = "UpsertParty"
ActionUpsertCertAuthority = "UpsertCertAuthority"
ActionGetCertAuthorities = "GetCertAuthorities"
ActionGetCertAuthoritiesWithSigningKeys = "GetCertAuthoritiesWithSigningKeys"

View file

@ -190,8 +190,6 @@ func (sl *SessionLogger) Write(bytes []byte) (written int, err error) {
logrus.Error(err)
return written, trace.Wrap(err)
}
logrus.Infof("--> sessionLogger %d bytes -> %v\n%v", written, sl.streamFile.Name(), bytes)
// log this as a session event (but not more often than once a sec)
sl.LogEvent(EventFields{
EventType: SessionPrintEvent,

View file

@ -280,7 +280,6 @@ func (process *TeleportProcess) initAuthService(authority auth.Authority) error
// create the audit log, which will be consuming (and recording) all events
// and record sessions
// TODO (ev): update server configuration to include "logdir" setting
auditLog, err := events.NewAuditLog(filepath.Join(cfg.DataDir, "log"), false)
if err != nil {
log.Error(err)

View file

@ -100,7 +100,7 @@ func NewID() ID {
type Session struct {
// ID is a unique session identifier
ID ID `json:"id"`
// Parties is a list of session parties
// Parties is a list of session parties.
Parties []Party `json:"parties"`
// TerminalParams sets terminal properties
TerminalParams TerminalParams `json:"terminal_params"`
@ -116,6 +116,18 @@ type Session struct {
LastActive time.Time `json:"last_active"`
}
// RemoveParty helper allows to remove a party by it's ID from the
// session's list. Returns 'false' if pid couldn't be found
func (s *Session) RemoveParty(pid ID) bool {
for i := range s.Parties {
if s.Parties[i].ID == pid {
s.Parties = append(s.Parties[:i], s.Parties[i+1:]...)
return true
}
}
return false
}
// Party is a participant a user or a script executing some action
// in the context of the session
type Party struct {
@ -177,6 +189,10 @@ type UpdateRequest struct {
ID ID `json:"id"`
Active *bool `json:"active"`
TerminalParams *TerminalParams `json:"terminal_params"`
// Parties allows to update the list of session parties. nil means
// "do not update", empty list means "everybody is gone"
Parties *[]Party `json:"parties"`
}
// Check returns nil if request is valid, error otherwize
@ -212,8 +228,6 @@ type Service interface {
// UpdateSession updates certain session parameters (last_active, terminal parameters)
// other parameters will not be updated
UpdateSession(req UpdateRequest) error
// UpsertParty upserts active session party
UpsertParty(id ID, p Party, ttl time.Duration) error
}
type server struct {
@ -359,6 +373,8 @@ func (s *server) CreateSession(sess Session) error {
// UpdateSession updates session parameters - can mark it as inactive and update it's terminal parameters
func (s *server) UpdateSession(req UpdateRequest) error {
logrus.Infof("sessionServer.UpdateSession(%v). parties: %v", req.ID, req.Parties)
lock := "sessions" + string(req.ID)
s.bk.AcquireLock(lock, time.Second)
defer s.bk.ReleaseLock(lock)
@ -376,6 +392,9 @@ func (s *server) UpdateSession(req UpdateRequest) error {
if req.Active != nil {
sess.Active = *req.Active
}
if req.Parties != nil {
sess.Parties = *req.Parties
}
err = s.bk.UpsertJSONVal(activeBucket(), string(req.ID), sess, s.activeSessionTTL)
if err != nil {
return trace.Wrap(err)
@ -383,27 +402,6 @@ func (s *server) UpdateSession(req UpdateRequest) error {
return nil
}
// UpsertParty updates or inserts active session party
func (s *server) UpsertParty(sid ID, p Party, ttl time.Duration) error {
session, err := s.GetSession(sid)
if err != nil {
return trace.Wrap(err)
}
update := false
for i := range session.Parties {
if session.Parties[i].ID == p.ID {
session.Parties[i] = p
update = true
break
}
}
if !update {
session.Parties = append(session.Parties, p)
}
session.Active = true
return trace.Wrap(s.bk.UpsertJSONVal(activeBucket(), string(session.ID), session, s.activeSessionTTL))
}
// NewTerminalParamsFromUint32 returns new terminal parameters from uint32 width and height
func NewTerminalParamsFromUint32(w uint32, h uint32) (*TerminalParams, error) {
if w > maxSize || w < minSize {

View file

@ -146,52 +146,55 @@ func (s *BoltSuite) TestSessionsInactivity(c *C) {
}
func (s *BoltSuite) TestPartiesCRUD(c *C) {
// create session:
sess := Session{
ID: NewID(),
Active: true,
TerminalParams: TerminalParams{W: 100, H: 100},
Login: "bob",
Login: "vincent",
LastActive: s.clock.UtcNow(),
Created: s.clock.UtcNow(),
}
c.Assert(s.srv.CreateSession(sess), IsNil)
p1 := Party{
ID: NewID(),
User: "bob",
RemoteAddr: "example.com",
ServerID: "id-1",
LastActive: s.clock.UtcNow(),
// add two people:
parties := []Party{
{
ID: NewID(),
RemoteAddr: "1_remote_addr",
User: "first",
ServerID: "luna",
LastActive: s.clock.UtcNow(),
},
{
ID: NewID(),
RemoteAddr: "2_remote_addr",
User: "second",
ServerID: "luna",
LastActive: s.clock.UtcNow(),
},
}
c.Assert(s.srv.UpsertParty(sess.ID, p1, defaults.ActivePartyTTL), IsNil)
out, err := s.srv.GetSession(sess.ID)
s.srv.UpdateSession(UpdateRequest{
ID: sess.ID,
Parties: &parties,
})
// verify they're in the session:
copy, err := s.srv.GetSession(sess.ID)
c.Assert(err, IsNil)
sess.Parties = []Party{p1}
c.Assert(out, DeepEquals, &sess)
c.Assert(len(copy.Parties), Equals, 2)
// add one more party
p2 := Party{
ID: NewID(),
User: "alice",
RemoteAddr: "example.com",
ServerID: "id-2",
LastActive: s.clock.UtcNow(),
}
c.Assert(s.srv.UpsertParty(sess.ID, p2, defaults.ActivePartyTTL), IsNil)
// empty update (list of parties must not change)
s.srv.UpdateSession(UpdateRequest{ID: sess.ID})
copy, _ = s.srv.GetSession(sess.ID)
c.Assert(len(copy.Parties), Equals, 2)
out, err = s.srv.GetSession(sess.ID)
c.Assert(err, IsNil)
sess.Parties = []Party{p1, p2}
c.Assert(out, DeepEquals, &sess)
// remove the 2nd party:
deleted := copy.RemoveParty(parties[1].ID)
c.Assert(deleted, Equals, true)
s.srv.UpdateSession(UpdateRequest{ID: copy.ID,
Parties: &copy.Parties})
copy, _ = s.srv.GetSession(sess.ID)
c.Assert(len(copy.Parties), Equals, 1)
// Update session party
s.clock.Sleep(time.Second)
p1.LastActive = s.clock.UtcNow()
c.Assert(s.srv.UpsertParty(sess.ID, p1, defaults.ActivePartyTTL), IsNil)
out, err = s.srv.GetSession(sess.ID)
c.Assert(err, IsNil)
sess.Parties = []Party{p1, p2}
c.Assert(out, DeepEquals, &sess)
// we still have the 1st party in:
c.Assert(parties[0].ID, Equals, copy.Parties[0].ID)
}

View file

@ -109,17 +109,13 @@ func (s *sessionRegistry) leaveShell(party *party) error {
return trace.Wrap(err)
}
if len(sess.parties) != 0 {
// emit an audit event
s.srv.EmitAuditEvent(events.SessionLeaveEvent, events.EventFields{
events.SessionEventID: string(sess.id),
events.EventUser: party.user,
events.SessionServerID: party.serverID,
})
return nil
}
// TODO remove session party in session server
// emit an audit event
s.srv.EmitAuditEvent(events.SessionLeaveEvent, events.EventFields{
events.SessionEventID: string(sess.id),
events.EventUser: party.user,
events.SessionServerID: party.serverID,
})
return nil
// this goroutine runs for a short amount of time only after a session
// becomes empty (no parties). It allows session to "linger" for a bit
@ -316,23 +312,6 @@ func (s *session) Close() error {
return trace.Wrap(err)
}
// upsertSessionParty updates the persistence layer (session object stored somewhere on disk)
// with a new connected client.
func (s *session) upsertSessionParty(sid rsession.ID, p *party) error {
if s.registry.srv.sessionServer == nil {
return nil
}
// session registry has a "session server" (which is actually a "session serializer")
// and we ask it to update the on-disk copy of this session with a new party
return s.registry.srv.sessionServer.UpsertParty(sid, rsession.Party{
ID: p.id,
User: p.user,
ServerID: p.serverID,
RemoteAddr: p.site,
LastActive: p.getLastActive(),
}, defaults.ActivePartyTTL)
}
// startShell starts a new shell process in the current session
func (s *session) startShell(ch ssh.Channel, ctx *ctx) error {
// create a new "party" (connected client)
@ -426,14 +405,38 @@ func (s *session) String() string {
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties))
}
// removeParty removes the party from two places:
// 1. from in-memory dictionary inside of this session
// 2. from sessin server's storage
func (s *session) removeParty(p *party) error {
s.Lock()
defer s.Unlock()
p.ctx.Infof("session.removeParty(%v)", p)
delete(s.parties, p.id)
s.writer.deleteWriter(string(p.id))
// in-memory locked remove:
lockedRemove := func() {
s.Lock()
defer s.Unlock()
delete(s.parties, p.id)
s.writer.deleteWriter(string(p.id))
}
lockedRemove()
// remove from the session server (asynchronously)
storageRemove := func(db rsession.Service) {
dbSession, err := db.GetSession(s.id)
if err != nil {
log.Error(err)
return
}
if dbSession.RemoveParty(p.id) {
db.UpdateSession(rsession.UpdateRequest{
ID: dbSession.ID,
Parties: &dbSession.Parties,
})
}
}
if s.registry.srv.sessionServer != nil {
go storageRemove(s.registry.srv.sessionServer)
}
return nil
}
@ -446,7 +449,7 @@ func (s *session) pollAndSyncTerm() {
}
syncTerm := func() error {
sess, err := sessionServer.GetSession(s.id)
if err != nil {
if err != nil || sess == nil {
log.Debugf("syncTerm: no session")
return err
}
@ -477,6 +480,7 @@ func (s *session) pollAndSyncTerm() {
}
}
// addParty is called when a new party joins the session.
func (s *session) addParty(p *party) {
s.parties[p.id] = p
// register this party as one of the session writers
@ -497,6 +501,29 @@ func (s *session) addParty(p *party) {
p.Write(recentData)
}
// update session on the session server
storageUpdate := func(db rsession.Service) {
dbSession, err := db.GetSession(s.id)
if err != nil {
log.Error(err)
return
}
dbSession.Parties = append(dbSession.Parties, rsession.Party{
ID: p.id,
User: p.user,
ServerID: p.serverID,
RemoteAddr: p.site,
LastActive: p.getLastActive(),
})
db.UpdateSession(rsession.UpdateRequest{
ID: dbSession.ID,
Parties: &dbSession.Parties,
})
}
if s.registry.srv.sessionServer != nil {
go storageUpdate(s.registry.srv.sessionServer)
}
// this goroutine keeps pumping party's input into the session
go func() {
defer s.term.Add(-1)
@ -506,22 +533,6 @@ func (s *session) addParty(p *party) {
log.Error(err)
}
}()
// this goroutine updates the status of this party in the auth
// server storage (last activity time)
go func() {
for {
if err := s.upsertSessionParty(s.id, p); err != nil {
p.ctx.Warningf("failed to upsert session party: %v", err)
}
select {
case <-p.closeC:
p.ctx.Infof("party heartbeat ended")
return
case <-time.After(1 * time.Second):
}
}
}()
}
func (s *session) join(ch ssh.Channel, req *ssh.Request, ctx *ctx) (*party, error) {

View file

@ -107,6 +107,14 @@ func (w *sessionStreamHandler) stream(ws *websocket.Conn) error {
// keep polling in a loop:
for {
// wait for next timer tick or a signal to abort:
select {
case <-ticker.C:
case <-w.closeC:
log.Infof("[web] session.stream() exited")
return nil
}
newEvents := pollEvents()
sess, err := clt.GetSession(w.sessionID)
if err != nil {
@ -120,9 +128,8 @@ func (w *sessionStreamHandler) stream(ws *websocket.Conn) error {
if err != nil {
log.Error(err)
}
log.Infof("[WEB] streaming for %v. Events: %v, Nodes: %v, Parties: %v",
w.sessionID, len(newEvents), len(servers), len(sess.Parties))
log.Infof("[WEB] Events: %v", newEvents)
log.Infof("[WEB] streaming for %v. Events: %v, Nodes: %v, Parties: %v, Events: %v",
w.sessionID, len(newEvents), len(servers), len(sess.Parties), newEvents)
// push events to the web client
event := &sessionStreamEvent{
@ -133,14 +140,6 @@ func (w *sessionStreamHandler) stream(ws *websocket.Conn) error {
if err := websocket.JSON.Send(ws, event); err != nil {
log.Error(err)
}
// wait for next timer tick or a signal to abort:
select {
case <-ticker.C:
case <-w.closeC:
log.Infof("[web] session.stream() exited")
return nil
}
}
}