Create single server context for forwarding server.

This commit is contained in:
Russell Jones 2017-12-27 13:08:54 -08:00
parent 4996825586
commit 61b2873b33
3 changed files with 58 additions and 71 deletions

View file

@ -749,7 +749,7 @@ func (s *IntSuite) TestTwoClusters(c *check.C) {
return nil
}
case <-stopCh:
return trace.BadParameter("unable to find %v events after 5s: %v", count)
return trace.BadParameter("unable to find %v events after 5s", count)
}
}
}

View file

@ -201,6 +201,10 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
return ctx
}
func (c *ServerContext) ID() int {
return c.id
}
func (c *ServerContext) GetServer() Server {
return c.srv
}

View file

@ -79,11 +79,6 @@ type Server struct {
// to the client.
hostCertificate ssh.Signer
// remoteClient represents a *ssh.Client connected to the target node.
remoteClient *ssh.Client
// remoteSession represents a *ssh.Session on the target node.
remoteSession *ssh.Session
// authHandlers are common authorization and authentication handlers shared
// by the regular and forwarding server.
authHandlers *srv.AuthHandlers
@ -305,7 +300,7 @@ func (s *Server) Serve() {
// build a remote session to the remote node
s.log.Debugf("Creating remote connection to %v@%v", sconn.User(), s.clientConn.RemoteAddr().String())
s.remoteClient, s.remoteSession, err = s.newRemoteSession(sconn.User())
remoteClient, remoteSession, err := s.newRemoteSession(sconn.User())
if err != nil {
// reject the connection with an error so the client doesn't hang then
// close the connection
@ -320,13 +315,30 @@ func (s *Server) Serve() {
return
}
// create server context for this connection, it's closed when the
// connection is closed
ctx := srv.NewServerContext(s, sconn, identityContext)
ctx.RemoteClient = remoteClient
ctx.RemoteSession = remoteSession
ctx.SetAgent(s.userAgent, s.userAgentChannel)
ctx.AddCloser(sconn)
ctx.AddCloser(s.targetConn)
ctx.AddCloser(s.serverConn)
ctx.AddCloser(s.clientConn)
ctx.AddCloser(remoteSession)
ctx.AddCloser(remoteClient)
s.log.Debugf("Created connection context %v", ctx.ID())
// create a cancelable context and pass it to a keep alive loop. the keep
// alive loop will keep pinging the remote server and after it has missed a
// certain number of keep alive requests it will cancel the context which
// will close any listening goroutines.
ctx, cancel := context.WithCancel(context.Background())
go s.keepAliveLoop(cancel, sconn)
go s.handleConnection(ctx, sconn, identityContext, chans, reqs)
heartbeatContext, cancel := context.WithCancel(context.Background())
go s.keepAliveLoop(ctx, sconn, cancel)
go s.handleConnection(ctx, heartbeatContext, sconn, chans, reqs)
}
// newRemoteSession will create and return a *ssh.Client and *ssh.Session
@ -372,37 +384,32 @@ func (s *Server) newRemoteSession(systemLogin string) (*ssh.Client, *ssh.Session
return client, session, nil
}
func (s *Server) handleConnection(ctx context.Context, sconn *ssh.ServerConn, identityContext srv.IdentityContext, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
func (s *Server) handleConnection(ctx *srv.ServerContext, heartbeatContext context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
defer s.log.Debugf("Closing connection context %v and releasing resources.", ctx.ID())
defer ctx.Close()
for {
select {
// global out-of-band requests
case newRequest := <-reqs:
if newRequest == nil {
s.log.Debugf("Closing connection to %v", sconn.RemoteAddr())
return
}
go s.handleGlobalRequest(newRequest)
go s.handleGlobalRequest(ctx, newRequest)
// channel requests
case newChannel := <-chans:
if newChannel == nil {
s.log.Debugf("Closing connection to %v", sconn.RemoteAddr())
return
}
go s.handleChannel(sconn, identityContext, newChannel)
go s.handleChannel(ctx, sconn, newChannel)
// if the heartbeats failed, we close everything and cleanup
case <-ctx.Done():
sconn.Close()
s.clientConn.Close()
s.serverConn.Close()
s.log.Debugf("Context closed, cleaning up and closing forwarding server")
case <-heartbeatContext.Done():
return
}
}
}
func (s *Server) keepAliveLoop(cancel context.CancelFunc, sconn *ssh.ServerConn) {
func (s *Server) keepAliveLoop(ctx *srv.ServerContext, sconn *ssh.ServerConn, cancel context.CancelFunc) {
var missed int
// tick at 1/3 of the idle timeout duration
@ -413,7 +420,7 @@ func (s *Server) keepAliveLoop(cancel context.CancelFunc, sconn *ssh.ServerConn)
select {
case <-keepAliveTick.C:
// send a keep alive to the target node and the client to ensure both are alive.
proxyToNodeOk := s.sendKeepAliveWithTimeout(s.remoteClient, defaults.ReadHeadersTimeout)
proxyToNodeOk := s.sendKeepAliveWithTimeout(ctx.RemoteClient, defaults.ReadHeadersTimeout)
proxyToClientOk := s.sendKeepAliveWithTimeout(sconn, defaults.ReadHeadersTimeout)
if proxyToNodeOk && proxyToClientOk {
missed = 0
@ -441,8 +448,8 @@ func (s *Server) rejectChannel(chans <-chan ssh.NewChannel, err error) {
}
}
func (s *Server) handleGlobalRequest(req *ssh.Request) {
ok, err := s.remoteSession.SendRequest(req.Type, req.WantReply, req.Payload)
func (s *Server) handleGlobalRequest(ctx *srv.ServerContext, req *ssh.Request) {
ok, err := ctx.RemoteSession.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
s.log.Warnf("Failed to forward global request %v: %v", req.Type, err)
return
@ -455,7 +462,7 @@ func (s *Server) handleGlobalRequest(req *ssh.Request) {
}
}
func (s *Server) handleChannel(sconn *ssh.ServerConn, identityContext srv.IdentityContext, nch ssh.NewChannel) {
func (s *Server) handleChannel(ctx *srv.ServerContext, sconn *ssh.ServerConn, nch ssh.NewChannel) {
channelType := nch.ChannelType()
switch channelType {
@ -470,7 +477,7 @@ func (s *Server) handleChannel(sconn *ssh.ServerConn, identityContext srv.Identi
if err != nil {
s.log.Infof("Unable to accept channel: %v", err)
}
go s.handleSessionRequests(sconn, identityContext, ch, requests)
go s.handleSessionRequests(ctx, sconn, ch, requests)
// port forwarding
case "direct-tcpip":
req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData())
@ -482,34 +489,19 @@ func (s *Server) handleChannel(sconn *ssh.ServerConn, identityContext srv.Identi
if err != nil {
s.log.Infof("Unable to accept channel: %v", err)
}
go s.handleDirectTCPIPRequest(sconn, identityContext, ch, req)
go s.handleDirectTCPIPRequest(ctx, sconn, ch, req)
default:
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
}
}
// handleDirectTCPIPRequest handles port forwarding requests.
func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
ctx := srv.NewServerContext(s, sconn, identityContext)
ctx.RemoteClient = s.remoteClient
ctx.RemoteSession = s.remoteSession
ctx.SetAgent(s.userAgent, s.userAgentChannel)
ctx.AddCloser(ch)
ctx.AddCloser(sconn)
ctx.AddCloser(s.targetConn)
ctx.AddCloser(s.serverConn)
ctx.AddCloser(s.clientConn)
ctx.AddCloser(s.remoteSession)
ctx.AddCloser(s.remoteClient)
defer ctx.Debugf("Closed direct-tcp context")
defer ctx.Close()
func (s *Server) handleDirectTCPIPRequest(ctx *srv.ServerContext, sconn *ssh.ServerConn, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
srcAddr := fmt.Sprintf("%v:%d", req.Orig, req.OrigPort)
dstAddr := fmt.Sprintf("%v:%d", req.Host, req.Port)
defer s.log.Debugf("Completing direct-tcpip request from %v to %v.", srcAddr, dstAddr)
// check if the role allows port forwarding for this user
err := s.authHandlers.CheckPortForward(dstAddr, ctx)
if err != nil {
@ -517,9 +509,9 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext
return
}
ctx.Debugf("Opening direct-tcpip channel from %v to %v", srcAddr, dstAddr)
s.log.Debugf("Opening direct-tcpip channel from %v to %v.", srcAddr, dstAddr)
conn, err := s.remoteClient.Dial("tcp", dstAddr)
conn, err := ctx.RemoteClient.Dial("tcp", dstAddr)
if err != nil {
ctx.Infof("Failed to connect to: %v: %v", dstAddr, err)
return
@ -566,25 +558,14 @@ func (s *Server) handleTerminalResize(sconn *ssh.ServerConn, ch ssh.Channel) {
}
}
// handleSessionRequests handles out of band session requests once the session channel has been created
// this function's loop handles all the "exec", "subsystem" and "shell" requests.
func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
ctx := srv.NewServerContext(s, sconn, identityContext)
// handleSessionRequests handles out of band session requests once the session
// channel has been created this function's loop handles all the "exec",
// "subsystem" and "shell" requests.
func (s *Server) handleSessionRequests(ctx *srv.ServerContext, sconn *ssh.ServerConn, ch ssh.Channel, in <-chan *ssh.Request) {
defer s.log.Debugf("Closing session request to %v.", sconn.RemoteAddr())
defer ch.Close()
ctx.RemoteClient = s.remoteClient
ctx.RemoteSession = s.remoteSession
ctx.SetAgent(s.userAgent, s.userAgentChannel)
ctx.AddCloser(ch)
ctx.AddCloser(sconn)
ctx.AddCloser(s.targetConn)
ctx.AddCloser(s.serverConn)
ctx.AddCloser(s.clientConn)
ctx.AddCloser(s.remoteSession)
ctx.AddCloser(s.remoteClient)
defer s.log.Debugf("Closed session context")
defer ctx.Close()
s.log.Debugf("Opening session request to %v.", sconn.RemoteAddr())
for {
// update ctx with the session ID:
@ -607,6 +588,7 @@ func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext sr
// this means that subsystem has finished executing and
// want us to close session and the channel
ctx.Debugf("Subsystem execution result: %v", result.Err)
return
case req := <-in:
if req == nil {
@ -630,13 +612,14 @@ func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext sr
if err != nil {
ctx.Infof("Failed to send exit status for %v: %v", result.Command, err)
}
return
}
}
}
func (s *Server) dispatch(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error {
ctx.Debugf("Handling request %v (WantReply=%v)", req.Type, req.WantReply)
ctx.Debugf("Handling request %v, want reply %v.", req.Type, req.WantReply)
switch req.Type {
case sshutils.ExecRequest:
@ -674,13 +657,13 @@ func (s *Server) handleAgentForward(ch ssh.Channel, req *ssh.Request, ctx *srv.S
}
// route authentication requests to the agent that was forwarded to the proxy
err = agent.ForwardToAgent(s.remoteClient, ctx.GetAgent())
err = agent.ForwardToAgent(ctx.RemoteClient, ctx.GetAgent())
if err != nil {
return trace.Wrap(err)
}
// make an "auth-agent-req@openssh.com" request on the target node
err = agent.RequestAgentForwarding(s.remoteSession)
err = agent.RequestAgentForwarding(ctx.RemoteSession)
if err != nil {
return trace.Wrap(err)
}
@ -723,7 +706,7 @@ func (s *Server) handleEnv(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerCont
return trace.Wrap(err, "failed to parse env request")
}
err := s.remoteSession.Setenv(e.Name, e.Value)
err := ctx.RemoteSession.Setenv(e.Name, e.Value)
if err != nil {
s.log.Debugf("Unable to set environment variable: %v: %v", e.Name, e.Value)
}