mirror of
https://github.com/gravitational/teleport
synced 2024-10-22 02:03:24 +00:00
close outstanding connections when invalidating the session
This commit is contained in:
parent
2320adb534
commit
447e839f39
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
package web
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
|
@ -71,9 +72,15 @@ type connectHandler struct {
|
|||
site reversetunnel.RemoteSite
|
||||
up *sshutils.Upstream
|
||||
req connectReq
|
||||
ws *websocket.Conn
|
||||
}
|
||||
|
||||
func (w *connectHandler) String() string {
|
||||
return fmt.Sprintf("connectHandler(%#v)", w.req)
|
||||
}
|
||||
|
||||
func (w *connectHandler) Close() error {
|
||||
w.ws.Close()
|
||||
if w.up != nil {
|
||||
return w.up.Close()
|
||||
}
|
||||
|
@ -87,6 +94,7 @@ func (w *connectHandler) connect(ws *websocket.Conn) {
|
|||
return
|
||||
}
|
||||
w.up = up
|
||||
w.ws = ws
|
||||
err = w.up.PipeShell(ws)
|
||||
log.Infof("pipe shell finished with: %v", err)
|
||||
ws.Write([]byte("\n\rdisconnected\n\r"))
|
||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
package web
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -37,11 +38,27 @@ import (
|
|||
// between requests for example to avoid connecting
|
||||
// to the auth server on every page hit
|
||||
type sessionContext struct {
|
||||
sync.Mutex
|
||||
*log.Entry
|
||||
sess *auth.Session
|
||||
user string
|
||||
clt *auth.TunClient
|
||||
parent *sessionCache
|
||||
sess *auth.Session
|
||||
user string
|
||||
clt *auth.TunClient
|
||||
parent *sessionCache
|
||||
closers []io.Closer
|
||||
}
|
||||
|
||||
func (c *sessionContext) AddClosers(closers ...io.Closer) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.closers = append(c.closers, closers...)
|
||||
}
|
||||
|
||||
func (c *sessionContext) TransferClosers() []io.Closer {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
closers := c.closers
|
||||
c.closers = nil
|
||||
return closers
|
||||
}
|
||||
|
||||
func (c *sessionContext) Invalidate() error {
|
||||
|
@ -89,6 +106,11 @@ func (c *sessionContext) GetAuthMethods() ([]ssh.AuthMethod, error) {
|
|||
|
||||
// Close cleans up connections associated with requests
|
||||
func (c *sessionContext) Close() error {
|
||||
closers := c.TransferClosers()
|
||||
for _, closer := range closers {
|
||||
c.Infof("closing %v", closer)
|
||||
closer.Close()
|
||||
}
|
||||
if c.clt != nil {
|
||||
return trace.Wrap(c.clt.Close())
|
||||
}
|
||||
|
|
|
@ -51,9 +51,11 @@ type sessionStreamHandler struct {
|
|||
site reversetunnel.RemoteSite
|
||||
sessionID string
|
||||
closeC chan bool
|
||||
ws *websocket.Conn
|
||||
}
|
||||
|
||||
func (w *sessionStreamHandler) Close() error {
|
||||
w.ws.Close()
|
||||
w.closeOnce.Do(func() {
|
||||
close(w.closeC)
|
||||
})
|
||||
|
@ -61,6 +63,7 @@ func (w *sessionStreamHandler) Close() error {
|
|||
}
|
||||
|
||||
func (w *sessionStreamHandler) stream(ws *websocket.Conn) error {
|
||||
w.ws = ws
|
||||
clt, err := w.site.GetClient()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
|
|
@ -236,6 +236,13 @@ func (m *Handler) renewSession(w http.ResponseWriter, r *http.Request, _ httprou
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
// transfer ownership over connections that were opened in the
|
||||
// sessionContext
|
||||
newContext, err := ctx.parent.ValidateSession(newSess.User.Name, newSess.ID)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
newContext.AddClosers(ctx.TransferClosers()...)
|
||||
if err := SetSession(w, newSess.User.Name, newSess.ID); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -442,6 +449,9 @@ func (m *Handler) siteNodeConnect(w http.ResponseWriter, r *http.Request, p http
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
// this is to make sure we close web socket connections once
|
||||
// sessionContext that owns them expires
|
||||
ctx.AddClosers(connect)
|
||||
defer connect.Close()
|
||||
connect.Handler().ServeHTTP(w, r)
|
||||
return nil, nil
|
||||
|
@ -471,6 +481,9 @@ func (m *Handler) siteSessionStream(w http.ResponseWriter, r *http.Request, p ht
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
// this is to make sure we close web socket connections once
|
||||
// sessionContext that owns them expires
|
||||
ctx.AddClosers(connect)
|
||||
defer connect.Close()
|
||||
connect.Handler().ServeHTTP(w, r)
|
||||
return nil, nil
|
||||
|
|
|
@ -76,7 +76,7 @@ type WebSuite struct {
|
|||
var _ = Suite(&WebSuite{})
|
||||
|
||||
func (s *WebSuite) SetUpSuite(c *C) {
|
||||
utils.InitLoggerDebug()
|
||||
utils.InitLoggerCLI()
|
||||
}
|
||||
|
||||
func (s *WebSuite) SetUpTest(c *C) {
|
||||
|
@ -613,6 +613,45 @@ func (s *WebSuite) TestNodesWithSessions(c *C) {
|
|||
c.Assert(len(event.Session.Parties), Equals, 2)
|
||||
}
|
||||
|
||||
func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) {
|
||||
sid := "testsession"
|
||||
pack := s.authPack(c)
|
||||
clt := s.connect(c, pack, sid)
|
||||
defer clt.Close()
|
||||
|
||||
// to make sure we have a session
|
||||
_, err := io.WriteString(clt, "expr 137 + 39\r\n")
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// make sure server has replied
|
||||
out := make([]byte, 100)
|
||||
clt.Read(out)
|
||||
|
||||
_, err = pack.clt.Delete(
|
||||
pack.clt.Endpoint("webapi", "sessions", pack.session.Token))
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// wait until we timeout or detect that connection has been closed
|
||||
after := time.After(time.Second)
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
for {
|
||||
_, err := clt.Read(out)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-after:
|
||||
c.Fatalf("timeout")
|
||||
case err := <-errC:
|
||||
c.Assert(err, Equals, io.EOF)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func getEvent(schema string, events []lunk.Entry) *lunk.Entry {
|
||||
for i := range events {
|
||||
e := events[i]
|
||||
|
|
Loading…
Reference in a new issue