mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 01:34:01 +00:00
Create single server context for forwarding server.
This commit is contained in:
parent
4996825586
commit
61b2873b33
|
@ -749,7 +749,7 @@ func (s *IntSuite) TestTwoClusters(c *check.C) {
|
|||
return nil
|
||||
}
|
||||
case <-stopCh:
|
||||
return trace.BadParameter("unable to find %v events after 5s: %v", count)
|
||||
return trace.BadParameter("unable to find %v events after 5s", count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -201,6 +201,10 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
|
|||
return ctx
|
||||
}
|
||||
|
||||
func (c *ServerContext) ID() int {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *ServerContext) GetServer() Server {
|
||||
return c.srv
|
||||
}
|
||||
|
|
|
@ -79,11 +79,6 @@ type Server struct {
|
|||
// to the client.
|
||||
hostCertificate ssh.Signer
|
||||
|
||||
// remoteClient represents a *ssh.Client connected to the target node.
|
||||
remoteClient *ssh.Client
|
||||
// remoteSession represents a *ssh.Session on the target node.
|
||||
remoteSession *ssh.Session
|
||||
|
||||
// authHandlers are common authorization and authentication handlers shared
|
||||
// by the regular and forwarding server.
|
||||
authHandlers *srv.AuthHandlers
|
||||
|
@ -305,7 +300,7 @@ func (s *Server) Serve() {
|
|||
|
||||
// build a remote session to the remote node
|
||||
s.log.Debugf("Creating remote connection to %v@%v", sconn.User(), s.clientConn.RemoteAddr().String())
|
||||
s.remoteClient, s.remoteSession, err = s.newRemoteSession(sconn.User())
|
||||
remoteClient, remoteSession, err := s.newRemoteSession(sconn.User())
|
||||
if err != nil {
|
||||
// reject the connection with an error so the client doesn't hang then
|
||||
// close the connection
|
||||
|
@ -320,13 +315,30 @@ func (s *Server) Serve() {
|
|||
return
|
||||
}
|
||||
|
||||
// create server context for this connection, it's closed when the
|
||||
// connection is closed
|
||||
ctx := srv.NewServerContext(s, sconn, identityContext)
|
||||
|
||||
ctx.RemoteClient = remoteClient
|
||||
ctx.RemoteSession = remoteSession
|
||||
ctx.SetAgent(s.userAgent, s.userAgentChannel)
|
||||
|
||||
ctx.AddCloser(sconn)
|
||||
ctx.AddCloser(s.targetConn)
|
||||
ctx.AddCloser(s.serverConn)
|
||||
ctx.AddCloser(s.clientConn)
|
||||
ctx.AddCloser(remoteSession)
|
||||
ctx.AddCloser(remoteClient)
|
||||
|
||||
s.log.Debugf("Created connection context %v", ctx.ID())
|
||||
|
||||
// create a cancelable context and pass it to a keep alive loop. the keep
|
||||
// alive loop will keep pinging the remote server and after it has missed a
|
||||
// certain number of keep alive requests it will cancel the context which
|
||||
// will close any listening goroutines.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go s.keepAliveLoop(cancel, sconn)
|
||||
go s.handleConnection(ctx, sconn, identityContext, chans, reqs)
|
||||
heartbeatContext, cancel := context.WithCancel(context.Background())
|
||||
go s.keepAliveLoop(ctx, sconn, cancel)
|
||||
go s.handleConnection(ctx, heartbeatContext, sconn, chans, reqs)
|
||||
}
|
||||
|
||||
// newRemoteSession will create and return a *ssh.Client and *ssh.Session
|
||||
|
@ -372,37 +384,32 @@ func (s *Server) newRemoteSession(systemLogin string) (*ssh.Client, *ssh.Session
|
|||
return client, session, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(ctx context.Context, sconn *ssh.ServerConn, identityContext srv.IdentityContext, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
|
||||
func (s *Server) handleConnection(ctx *srv.ServerContext, heartbeatContext context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
|
||||
defer s.log.Debugf("Closing connection context %v and releasing resources.", ctx.ID())
|
||||
defer ctx.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
// global out-of-band requests
|
||||
case newRequest := <-reqs:
|
||||
if newRequest == nil {
|
||||
s.log.Debugf("Closing connection to %v", sconn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
go s.handleGlobalRequest(newRequest)
|
||||
go s.handleGlobalRequest(ctx, newRequest)
|
||||
// channel requests
|
||||
case newChannel := <-chans:
|
||||
if newChannel == nil {
|
||||
s.log.Debugf("Closing connection to %v", sconn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
go s.handleChannel(sconn, identityContext, newChannel)
|
||||
go s.handleChannel(ctx, sconn, newChannel)
|
||||
// if the heartbeats failed, we close everything and cleanup
|
||||
case <-ctx.Done():
|
||||
sconn.Close()
|
||||
|
||||
s.clientConn.Close()
|
||||
s.serverConn.Close()
|
||||
|
||||
s.log.Debugf("Context closed, cleaning up and closing forwarding server")
|
||||
case <-heartbeatContext.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) keepAliveLoop(cancel context.CancelFunc, sconn *ssh.ServerConn) {
|
||||
func (s *Server) keepAliveLoop(ctx *srv.ServerContext, sconn *ssh.ServerConn, cancel context.CancelFunc) {
|
||||
var missed int
|
||||
|
||||
// tick at 1/3 of the idle timeout duration
|
||||
|
@ -413,7 +420,7 @@ func (s *Server) keepAliveLoop(cancel context.CancelFunc, sconn *ssh.ServerConn)
|
|||
select {
|
||||
case <-keepAliveTick.C:
|
||||
// send a keep alive to the target node and the client to ensure both are alive.
|
||||
proxyToNodeOk := s.sendKeepAliveWithTimeout(s.remoteClient, defaults.ReadHeadersTimeout)
|
||||
proxyToNodeOk := s.sendKeepAliveWithTimeout(ctx.RemoteClient, defaults.ReadHeadersTimeout)
|
||||
proxyToClientOk := s.sendKeepAliveWithTimeout(sconn, defaults.ReadHeadersTimeout)
|
||||
if proxyToNodeOk && proxyToClientOk {
|
||||
missed = 0
|
||||
|
@ -441,8 +448,8 @@ func (s *Server) rejectChannel(chans <-chan ssh.NewChannel, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleGlobalRequest(req *ssh.Request) {
|
||||
ok, err := s.remoteSession.SendRequest(req.Type, req.WantReply, req.Payload)
|
||||
func (s *Server) handleGlobalRequest(ctx *srv.ServerContext, req *ssh.Request) {
|
||||
ok, err := ctx.RemoteSession.SendRequest(req.Type, req.WantReply, req.Payload)
|
||||
if err != nil {
|
||||
s.log.Warnf("Failed to forward global request %v: %v", req.Type, err)
|
||||
return
|
||||
|
@ -455,7 +462,7 @@ func (s *Server) handleGlobalRequest(req *ssh.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleChannel(sconn *ssh.ServerConn, identityContext srv.IdentityContext, nch ssh.NewChannel) {
|
||||
func (s *Server) handleChannel(ctx *srv.ServerContext, sconn *ssh.ServerConn, nch ssh.NewChannel) {
|
||||
channelType := nch.ChannelType()
|
||||
|
||||
switch channelType {
|
||||
|
@ -470,7 +477,7 @@ func (s *Server) handleChannel(sconn *ssh.ServerConn, identityContext srv.Identi
|
|||
if err != nil {
|
||||
s.log.Infof("Unable to accept channel: %v", err)
|
||||
}
|
||||
go s.handleSessionRequests(sconn, identityContext, ch, requests)
|
||||
go s.handleSessionRequests(ctx, sconn, ch, requests)
|
||||
// port forwarding
|
||||
case "direct-tcpip":
|
||||
req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData())
|
||||
|
@ -482,34 +489,19 @@ func (s *Server) handleChannel(sconn *ssh.ServerConn, identityContext srv.Identi
|
|||
if err != nil {
|
||||
s.log.Infof("Unable to accept channel: %v", err)
|
||||
}
|
||||
go s.handleDirectTCPIPRequest(sconn, identityContext, ch, req)
|
||||
go s.handleDirectTCPIPRequest(ctx, sconn, ch, req)
|
||||
default:
|
||||
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDirectTCPIPRequest handles port forwarding requests.
|
||||
func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
|
||||
ctx := srv.NewServerContext(s, sconn, identityContext)
|
||||
|
||||
ctx.RemoteClient = s.remoteClient
|
||||
ctx.RemoteSession = s.remoteSession
|
||||
ctx.SetAgent(s.userAgent, s.userAgentChannel)
|
||||
|
||||
ctx.AddCloser(ch)
|
||||
ctx.AddCloser(sconn)
|
||||
ctx.AddCloser(s.targetConn)
|
||||
ctx.AddCloser(s.serverConn)
|
||||
ctx.AddCloser(s.clientConn)
|
||||
ctx.AddCloser(s.remoteSession)
|
||||
ctx.AddCloser(s.remoteClient)
|
||||
|
||||
defer ctx.Debugf("Closed direct-tcp context")
|
||||
defer ctx.Close()
|
||||
|
||||
func (s *Server) handleDirectTCPIPRequest(ctx *srv.ServerContext, sconn *ssh.ServerConn, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
|
||||
srcAddr := fmt.Sprintf("%v:%d", req.Orig, req.OrigPort)
|
||||
dstAddr := fmt.Sprintf("%v:%d", req.Host, req.Port)
|
||||
|
||||
defer s.log.Debugf("Completing direct-tcpip request from %v to %v.", srcAddr, dstAddr)
|
||||
|
||||
// check if the role allows port forwarding for this user
|
||||
err := s.authHandlers.CheckPortForward(dstAddr, ctx)
|
||||
if err != nil {
|
||||
|
@ -517,9 +509,9 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext
|
|||
return
|
||||
}
|
||||
|
||||
ctx.Debugf("Opening direct-tcpip channel from %v to %v", srcAddr, dstAddr)
|
||||
s.log.Debugf("Opening direct-tcpip channel from %v to %v.", srcAddr, dstAddr)
|
||||
|
||||
conn, err := s.remoteClient.Dial("tcp", dstAddr)
|
||||
conn, err := ctx.RemoteClient.Dial("tcp", dstAddr)
|
||||
if err != nil {
|
||||
ctx.Infof("Failed to connect to: %v: %v", dstAddr, err)
|
||||
return
|
||||
|
@ -566,25 +558,14 @@ func (s *Server) handleTerminalResize(sconn *ssh.ServerConn, ch ssh.Channel) {
|
|||
}
|
||||
}
|
||||
|
||||
// handleSessionRequests handles out of band session requests once the session channel has been created
|
||||
// this function's loop handles all the "exec", "subsystem" and "shell" requests.
|
||||
func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
|
||||
ctx := srv.NewServerContext(s, sconn, identityContext)
|
||||
// handleSessionRequests handles out of band session requests once the session
|
||||
// channel has been created this function's loop handles all the "exec",
|
||||
// "subsystem" and "shell" requests.
|
||||
func (s *Server) handleSessionRequests(ctx *srv.ServerContext, sconn *ssh.ServerConn, ch ssh.Channel, in <-chan *ssh.Request) {
|
||||
defer s.log.Debugf("Closing session request to %v.", sconn.RemoteAddr())
|
||||
defer ch.Close()
|
||||
|
||||
ctx.RemoteClient = s.remoteClient
|
||||
ctx.RemoteSession = s.remoteSession
|
||||
ctx.SetAgent(s.userAgent, s.userAgentChannel)
|
||||
|
||||
ctx.AddCloser(ch)
|
||||
ctx.AddCloser(sconn)
|
||||
ctx.AddCloser(s.targetConn)
|
||||
ctx.AddCloser(s.serverConn)
|
||||
ctx.AddCloser(s.clientConn)
|
||||
ctx.AddCloser(s.remoteSession)
|
||||
ctx.AddCloser(s.remoteClient)
|
||||
|
||||
defer s.log.Debugf("Closed session context")
|
||||
defer ctx.Close()
|
||||
s.log.Debugf("Opening session request to %v.", sconn.RemoteAddr())
|
||||
|
||||
for {
|
||||
// update ctx with the session ID:
|
||||
|
@ -607,6 +588,7 @@ func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext sr
|
|||
// this means that subsystem has finished executing and
|
||||
// want us to close session and the channel
|
||||
ctx.Debugf("Subsystem execution result: %v", result.Err)
|
||||
|
||||
return
|
||||
case req := <-in:
|
||||
if req == nil {
|
||||
|
@ -630,13 +612,14 @@ func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext sr
|
|||
if err != nil {
|
||||
ctx.Infof("Failed to send exit status for %v: %v", result.Command, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) dispatch(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error {
|
||||
ctx.Debugf("Handling request %v (WantReply=%v)", req.Type, req.WantReply)
|
||||
ctx.Debugf("Handling request %v, want reply %v.", req.Type, req.WantReply)
|
||||
|
||||
switch req.Type {
|
||||
case sshutils.ExecRequest:
|
||||
|
@ -674,13 +657,13 @@ func (s *Server) handleAgentForward(ch ssh.Channel, req *ssh.Request, ctx *srv.S
|
|||
}
|
||||
|
||||
// route authentication requests to the agent that was forwarded to the proxy
|
||||
err = agent.ForwardToAgent(s.remoteClient, ctx.GetAgent())
|
||||
err = agent.ForwardToAgent(ctx.RemoteClient, ctx.GetAgent())
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// make an "auth-agent-req@openssh.com" request on the target node
|
||||
err = agent.RequestAgentForwarding(s.remoteSession)
|
||||
err = agent.RequestAgentForwarding(ctx.RemoteSession)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
@ -723,7 +706,7 @@ func (s *Server) handleEnv(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerCont
|
|||
return trace.Wrap(err, "failed to parse env request")
|
||||
}
|
||||
|
||||
err := s.remoteSession.Setenv(e.Name, e.Value)
|
||||
err := ctx.RemoteSession.Setenv(e.Name, e.Value)
|
||||
if err != nil {
|
||||
s.log.Debugf("Unable to set environment variable: %v: %v", e.Name, e.Value)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue