close outstanding connections when invalidating the session

This commit is contained in:
klizhentas 2016-02-26 18:10:01 -08:00
parent 2320adb534
commit 447e839f39
5 changed files with 90 additions and 5 deletions

View file

@ -17,6 +17,7 @@ limitations under the License.
package web
import (
"fmt"
"net"
"net/http"
@ -71,9 +72,15 @@ type connectHandler struct {
site reversetunnel.RemoteSite
up *sshutils.Upstream
req connectReq
ws *websocket.Conn
}
func (w *connectHandler) String() string {
return fmt.Sprintf("connectHandler(%#v)", w.req)
}
func (w *connectHandler) Close() error {
w.ws.Close()
if w.up != nil {
return w.up.Close()
}
@ -87,6 +94,7 @@ func (w *connectHandler) connect(ws *websocket.Conn) {
return
}
w.up = up
w.ws = ws
err = w.up.PipeShell(ws)
log.Infof("pipe shell finished with: %v", err)
ws.Write([]byte("\n\rdisconnected\n\r"))

View file

@ -17,6 +17,7 @@ limitations under the License.
package web
import (
"io"
"net/http"
"sync"
"time"
@ -37,11 +38,27 @@ import (
// between requests for example to avoid connecting
// to the auth server on every page hit
type sessionContext struct {
sync.Mutex
*log.Entry
sess *auth.Session
user string
clt *auth.TunClient
parent *sessionCache
sess *auth.Session
user string
clt *auth.TunClient
parent *sessionCache
closers []io.Closer
}
func (c *sessionContext) AddClosers(closers ...io.Closer) {
c.Lock()
defer c.Unlock()
c.closers = append(c.closers, closers...)
}
func (c *sessionContext) TransferClosers() []io.Closer {
c.Lock()
defer c.Unlock()
closers := c.closers
c.closers = nil
return closers
}
func (c *sessionContext) Invalidate() error {
@ -89,6 +106,11 @@ func (c *sessionContext) GetAuthMethods() ([]ssh.AuthMethod, error) {
// Close cleans up connections associated with requests
func (c *sessionContext) Close() error {
closers := c.TransferClosers()
for _, closer := range closers {
c.Infof("closing %v", closer)
closer.Close()
}
if c.clt != nil {
return trace.Wrap(c.clt.Close())
}

View file

@ -51,9 +51,11 @@ type sessionStreamHandler struct {
site reversetunnel.RemoteSite
sessionID string
closeC chan bool
ws *websocket.Conn
}
func (w *sessionStreamHandler) Close() error {
w.ws.Close()
w.closeOnce.Do(func() {
close(w.closeC)
})
@ -61,6 +63,7 @@ func (w *sessionStreamHandler) Close() error {
}
func (w *sessionStreamHandler) stream(ws *websocket.Conn) error {
w.ws = ws
clt, err := w.site.GetClient()
if err != nil {
return trace.Wrap(err)

View file

@ -236,6 +236,13 @@ func (m *Handler) renewSession(w http.ResponseWriter, r *http.Request, _ httprou
if err != nil {
return nil, trace.Wrap(err)
}
// transfer ownership over connections that were opened in the
// sessionContext
newContext, err := ctx.parent.ValidateSession(newSess.User.Name, newSess.ID)
if err != nil {
return nil, trace.Wrap(err)
}
newContext.AddClosers(ctx.TransferClosers()...)
if err := SetSession(w, newSess.User.Name, newSess.ID); err != nil {
return nil, trace.Wrap(err)
}
@ -442,6 +449,9 @@ func (m *Handler) siteNodeConnect(w http.ResponseWriter, r *http.Request, p http
if err != nil {
return nil, trace.Wrap(err)
}
// this is to make sure we close web socket connections once
// sessionContext that owns them expires
ctx.AddClosers(connect)
defer connect.Close()
connect.Handler().ServeHTTP(w, r)
return nil, nil
@ -471,6 +481,9 @@ func (m *Handler) siteSessionStream(w http.ResponseWriter, r *http.Request, p ht
if err != nil {
return nil, trace.Wrap(err)
}
// this is to make sure we close web socket connections once
// sessionContext that owns them expires
ctx.AddClosers(connect)
defer connect.Close()
connect.Handler().ServeHTTP(w, r)
return nil, nil

View file

@ -76,7 +76,7 @@ type WebSuite struct {
var _ = Suite(&WebSuite{})
func (s *WebSuite) SetUpSuite(c *C) {
utils.InitLoggerDebug()
utils.InitLoggerCLI()
}
func (s *WebSuite) SetUpTest(c *C) {
@ -613,6 +613,45 @@ func (s *WebSuite) TestNodesWithSessions(c *C) {
c.Assert(len(event.Session.Parties), Equals, 2)
}
func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) {
sid := "testsession"
pack := s.authPack(c)
clt := s.connect(c, pack, sid)
defer clt.Close()
// to make sure we have a session
_, err := io.WriteString(clt, "expr 137 + 39\r\n")
c.Assert(err, IsNil)
// make sure server has replied
out := make([]byte, 100)
clt.Read(out)
_, err = pack.clt.Delete(
pack.clt.Endpoint("webapi", "sessions", pack.session.Token))
c.Assert(err, IsNil)
// wait until we timeout or detect that connection has been closed
after := time.After(time.Second)
errC := make(chan error)
go func() {
for {
_, err := clt.Read(out)
if err != nil {
errC <- err
}
}
}()
select {
case <-after:
c.Fatalf("timeout")
case err := <-errC:
c.Assert(err, Equals, io.EOF)
}
}
func getEvent(schema string, events []lunk.Entry) *lunk.Entry {
for i := range events {
e := events[i]