* Push window size changes to clients instead of polling.

* Cache services.ClusterConfig within srv.ServerContext for the duration
  of a connection.
* Create a single websocket between the browser and the proxy for all
* terminal bytes and events.
This commit is contained in:
Russell Jones 2018-05-04 00:36:08 +00:00
parent a8190188cd
commit 876e04af07
20 changed files with 1623 additions and 1225 deletions

View file

@ -345,3 +345,9 @@ const (
// SharedDirMode is a mode for a directory shared with group
SharedDirMode = 0750
)
const (
// SessionEvent is sent by servers to clients when an audit event occurs on
// the session.
SessionEvent = "x-teleport-event"
)

View file

@ -2525,6 +2525,120 @@ func runAndMatch(tc *client.TeleportClient, attempts int, command []string, patt
return err
}
// TestWindowChange checks if custom Teleport window change requests are sent
// when the server side PTY changes its size.
func (s *IntSuite) TestWindowChange(c *check.C) {
t := s.newTeleport(c, nil, true)
defer t.Stop(true)
site := t.GetSiteAPI(Site)
c.Assert(site, check.NotNil)
personA := NewTerminal(250)
personB := NewTerminal(250)
// openSession will open a new session on a server.
openSession := func() {
cl, err := t.NewClient(ClientConfig{
Login: s.me.Username,
Cluster: Site,
Host: Host,
Port: t.GetPortSSHInt(),
})
c.Assert(err, check.IsNil)
cl.Stdout = &personA
cl.Stdin = &personA
err = cl.SSH(context.TODO(), []string{}, false)
c.Assert(err, check.IsNil)
}
// joinSession will join the existing session on a server.
joinSession := func() {
// Find the existing session in the backend.
var sessionID string
for {
time.Sleep(time.Millisecond)
sessions, _ := site.GetSessions(defaults.Namespace)
if len(sessions) == 0 {
continue
}
sessionID = string(sessions[0].ID)
break
}
cl, err := t.NewClient(ClientConfig{
Login: s.me.Username,
Cluster: Site,
Host: Host,
Port: t.GetPortSSHInt(),
})
c.Assert(err, check.IsNil)
cl.Stdout = &personB
cl.Stdin = &personB
// Change the size of the window immediately after it is created.
cl.OnShellCreated = func(s *ssh.Session, c *ssh.Client, terminal io.ReadWriteCloser) (exit bool, err error) {
err = s.WindowChange(48, 160)
if err != nil {
return true, trace.Wrap(err)
}
return false, nil
}
for i := 0; i < 10; i++ {
err = cl.Join(context.TODO(), defaults.Namespace, session.ID(sessionID), &personB)
if err == nil {
break
}
}
c.Assert(err, check.IsNil)
}
// waitForOutput checks the output of the passed in terminal of a string until
// some timeout has occured.
waitForOutput := func(t Terminal, s string) error {
tickerCh := time.Tick(500 * time.Millisecond)
timeoutCh := time.After(30 * time.Second)
for {
select {
case <-tickerCh:
if strings.Contains(t.Output(500), s) {
return nil
}
case <-timeoutCh:
return trace.BadParameter("timed out waiting for output")
}
}
}
// Open session, the initial size will be 80x24.
go openSession()
// Use the "printf" command to print the terminal size on the screen and
// make sure it is 80x25.
personA.Type("\aprintf '%s %s\n' $(tput cols) $(tput lines)\n\r\a")
err := waitForOutput(personA, "80 25")
c.Assert(err, check.IsNil)
// As soon as person B joins the session, the terminal is resized to 160x48.
// Have another user join the session. As soon as the second shell is
// created, the window is resized to 160x48 (see joinSession implementation).
go joinSession()
// Use the "printf" command to print the window size again and make sure it's
// 160x48.
personA.Type("\aprintf '%s %s\n' $(tput cols) $(tput lines)\n\r\a")
err = waitForOutput(personA, "160 48")
c.Assert(err, check.IsNil)
// Close the session.
personA.Type("\aexit\r\n\a")
}
// runCommand is a shortcut for running SSH command, it creates a client
// connected to proxy of the passed in instance, runs the command, and returns
// the result. If multiple attempts are requested, a 250 millisecond delay is

View file

@ -518,8 +518,12 @@ type TeleportClient struct {
localAgent *LocalKeyAgent
// OnShellCreated gets called when the shell is created. It's
// safe to keep it nil
// safe to keep it nil.
OnShellCreated ShellCreatedCallback
// eventsCh is a channel used to inform clients about events have that
// occured during the session.
eventsCh chan events.EventFields
}
// ShellCreatedCallback can be supplied for every teleport client. It will
@ -568,6 +572,12 @@ func NewClient(c *Config) (tc *TeleportClient, err error) {
tc.Stdin = os.Stdin
}
// Create a buffered channel to hold events that occured during this session.
// This channel must be buffered because the SSH connection directly feeds
// into it. Delays in pulling messages off the global SSH request channel
// could lead to the connection hanging.
tc.eventsCh = make(chan events.EventFields, 1024)
// sometimes we need to use external auth without using local auth
// methods, e.g. in automation daemons
if c.SkipLocalAuth {
@ -1500,6 +1510,24 @@ func (tc *TeleportClient) u2fLogin(pub []byte) (*auth.SSHLoginResponse, error) {
return response, trace.Wrap(err)
}
// SendEvent adds a events.EventFields to the channel.
func (tc *TeleportClient) SendEvent(ctx context.Context, e events.EventFields) error {
// Try and send the event to the eventsCh. If blocking, keep blocking until
// the passed in context in canceled.
select {
case tc.eventsCh <- e:
return nil
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}
// EventsChannel returns a channel that can be used to listen for events that
// occur for this session.
func (tc *TeleportClient) EventsChannel() <-chan events.EventFields {
return tc.eventsCh
}
// loopbackPool reads trusted CAs if it finds it in a predefined location
// and will work only if target proxy address is loopback
func loopbackPool(proxyAddr string) *x509.CertPool {

View file

@ -34,6 +34,7 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/scp"
@ -63,6 +64,7 @@ type NodeClient struct {
Namespace string
Client *ssh.Client
Proxy *ProxyClient
TC *TeleportClient
}
// GetSites returns list of the "sites" (AKA teleport clusters) connected to the proxy
@ -420,9 +422,62 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress string,
return nil, trace.Wrap(err)
}
client := ssh.NewClient(conn, chans, reqs)
// We pass an empty channel which we close right away to ssh.NewClient
// because the client need to handle requests itself.
emptyCh := make(chan *ssh.Request)
close(emptyCh)
return &NodeClient{Client: client, Proxy: proxy, Namespace: defaults.Namespace}, nil
client := ssh.NewClient(conn, chans, emptyCh)
nc := &NodeClient{
Client: client,
Proxy: proxy,
Namespace: defaults.Namespace,
TC: proxy.teleportClient,
}
// Start a goroutine that will run for the duration of the client to process
// global requests from the client. Teleport clients will use this to update
// terminal sizes when the remote PTY size has changed.
go nc.handleGlobalRequests(ctx, reqs)
return nc, nil
}
func (c *NodeClient) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh.Request) {
for {
select {
case r := <-requestCh:
// When the channel is closing, nil is returned.
if r == nil {
return
}
switch r.Type {
case teleport.SessionEvent:
// Parse event and create events.EventFields that can be consumed directly
// by caller.
var e events.EventFields
err := json.Unmarshal(r.Payload, &e)
if err != nil {
log.Warnf("Unable to parse event: %v: %v.", string(r.Payload), err)
continue
}
// Send event to event channel.
err = c.TC.SendEvent(ctx, e)
if err != nil {
log.Warnf("Unable to send event %v: %v.", string(r.Payload), err)
continue
}
default:
// This handles keepalive messages and matches the behaviour of OpenSSH.
r.Reply(false, nil)
}
case <-ctx.Done():
return
}
}
}
// newClientConn is a wrapper around ssh.NewClientConn
@ -504,18 +559,18 @@ func (client *NodeClient) Download(remoteSourcePath, localDestinationPath string
// scp runs remote scp command(shellCmd) on the remote server and
// runs local scp handler using scpConf
func (client *NodeClient) scp(scpCommand scp.Command, shellCmd string, errWriter io.Writer) error {
session, err := client.Client.NewSession()
s, err := client.Client.NewSession()
if err != nil {
return trace.Wrap(err)
}
defer session.Close()
defer s.Close()
stdin, err := session.StdinPipe()
stdin, err := s.StdinPipe()
if err != nil {
return trace.Wrap(err)
}
stdout, err := session.StdoutPipe()
stdout, err := s.StdoutPipe()
if err != nil {
return trace.Wrap(err)
}
@ -537,7 +592,7 @@ func (client *NodeClient) scp(scpCommand scp.Command, shellCmd string, errWriter
close(closeC)
}()
runErr := session.Run(shellCmd)
runErr := s.Run(shellCmd)
if runErr != nil && err == nil {
err = runErr
}

View file

@ -32,6 +32,7 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"
@ -44,6 +45,7 @@ import (
type NodeSession struct {
// namespace is a session this namespace belongs to
namespace string
// id is the Teleport session ID
id session.ID
@ -281,81 +283,102 @@ func (ns *NodeSession) allocateTerminal(termType string, s *ssh.Session) (io.Rea
}
func (ns *NodeSession) updateTerminalSize(s *ssh.Session) {
// sibscribe for "terminal resized" signal:
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGWINCH)
currentSize, _ := term.GetWinsize(0)
// SIGWINCH is sent to the process when the window size of the terminal has
// changed.
sigwinchCh := make(chan os.Signal, 1)
signal.Notify(sigwinchCh, syscall.SIGWINCH)
// start the timer which asks for server-side window size changes:
siteClient, err := ns.nodeClient.Proxy.ConnectToSite(context.TODO(), true)
lastSize, err := term.GetWinsize(0)
if err != nil {
log.Error(err)
log.Errorf("Unable to get window size: %v", err)
return
}
tick := time.NewTicker(defaults.SessionRefreshPeriod)
defer tick.Stop()
var prevSess *session.Session
// Sync the local terminal with size received from the remote server every
// two seconds. If we try and do it live, synchronization jitters occur.
tickerCh := time.NewTicker(defaults.TerminalResizePeriod)
defer tickerCh.Stop()
for {
select {
// our own terminal window got resized:
case sig := <-sigC:
if sig == nil {
// The client updated the size of the local PTY. This change needs to occur
// on the server side PTY as well.
case sigwinch := <-sigwinchCh:
if sigwinch == nil {
return
}
// get the size:
winSize, err := term.GetWinsize(0)
currSize, err := term.GetWinsize(0)
if err != nil {
log.Warnf("[CLIENT] Error getting size: %s", err)
break
}
// it's the result of our own size change (see below)
if winSize.Height == currentSize.Height && winSize.Width == currentSize.Width {
log.Warnf("Unable to get window size: %v.", err)
continue
}
// send the new window size to the server
// Terminal size has not changed, don't do anything.
if currSize.Height == lastSize.Height && currSize.Width == lastSize.Width {
continue
}
// Send the "window-change" request over the channel.
_, err = s.SendRequest(
sshutils.WindowChangeRequest, false,
sshutils.WindowChangeRequest,
false,
ssh.Marshal(sshutils.WinChangeReqParams{
W: uint32(winSize.Width),
H: uint32(winSize.Height),
W: uint32(currSize.Width),
H: uint32(currSize.Height),
}))
if err != nil {
log.Warnf("[CLIENT] failed to send window change reqest: %v", err)
}
case <-tick.C:
sess, err := siteClient.GetSession(ns.namespace, ns.id)
if err != nil {
if !trace.IsNotFound(err) {
log.Error(trace.DebugReport(err))
}
log.Warnf("Unable to send %v reqest: %v.", sshutils.WindowChangeRequest, err)
continue
}
// no previous session
if prevSess == nil || sess == nil {
prevSess = sess
continue
}
// nothing changed
if prevSess.TerminalParams.W == sess.TerminalParams.W && prevSess.TerminalParams.H == sess.TerminalParams.H {
continue
}
log.Infof("[CLIENT] updating the session %v with %d parties", sess.ID, len(sess.Parties))
newSize := sess.TerminalParams.Winsize()
currentSize, err = term.GetWinsize(0)
log.Debugf("Updated window size from %v to %v due to SIGWINCH.", lastSize, currSize)
lastSize = currSize
// Extract "resize" events in the stream and store the last window size.
case event := <-ns.nodeClient.TC.EventsChannel():
// Only "resize" events are important to tsh, all others can be ignored.
if event.GetType() != events.ResizeEvent {
continue
}
terminalParams, err := session.UnmarshalTerminalParams(event.GetString(events.TerminalSize))
if err != nil {
log.Error(err)
log.Warnf("Unable to unmarshal terminal parameters: %v.", err)
continue
}
if currentSize.Width != newSize.Width || currentSize.Height != newSize.Height {
// ok, something have changed, let's resize to the new parameters
err = term.SetWinsize(0, newSize)
if err != nil {
log.Error(err)
}
os.Stdout.Write([]byte(fmt.Sprintf("\x1b[8;%d;%dt", newSize.Height, newSize.Width)))
lastSize = terminalParams.Winsize()
log.Debugf("Recevied window size %v from node in session.\n", lastSize, event.GetString(events.SessionEventID))
// Update size of local terminal with the last size received from remote server.
case <-tickerCh.C:
// Get the current size of the terminal and the last size report that was
// received.
currSize, err := term.GetWinsize(0)
if err != nil {
log.Warnf("Unable to get current terminal size: %v.", err)
continue
}
prevSess = sess
// Terminal size has not changed, don't do anything.
if currSize.Width == lastSize.Width && currSize.Height == lastSize.Height {
continue
}
// This changes the size of the local PTY. This will re-draw what's within
// the window.
err = term.SetWinsize(0, lastSize)
if err != nil {
log.Warnf("Unable to update terminal size: %v.\n", err)
continue
}
// This is what we use to resize the physical terminal window itself.
os.Stdout.Write([]byte(fmt.Sprintf("\x1b[8;%d;%dt", lastSize.Height, lastSize.Width)))
log.Debugf("Updated window size from to %v due to remote window change.", currSize, lastSize)
case <-ns.closer.C:
return
}

View file

@ -221,19 +221,21 @@ var (
// their stored list of auth servers
AuthServersRefreshPeriod = 30 * time.Second
// SessionRefreshPeriod is how often tsh polls information about session
// TODO(klizhentas) all polling periods should go away once backend
// releases events
// TerminalResizePeriod is how long tsh waits before updating the size of the
// terminal window.
TerminalResizePeriod = 2 * time.Second
// SessionRefreshPeriod is how often session data is updated on the backend.
// The web client polls this information about session to update the UI.
//
// TODO(klizhentas): All polling periods should go away once backend supports
// events.
SessionRefreshPeriod = 2 * time.Second
// SessionIdlePeriod is the period of inactivity after which the
// session will be considered idle
SessionIdlePeriod = SessionRefreshPeriod * 10
// TerminalSizeRefreshPeriod is how frequently clients who share sessions sync up
// their terminal sizes
TerminalSizeRefreshPeriod = 2 * time.Second
// NewtworkBackoffDuration is a standard backoff on network requests
// usually is slow, e.g. once in 30 seconds
NetworkBackoffDuration = time.Second * 30
@ -399,3 +401,15 @@ const (
// CATTL is a default lifetime of a CA certificate
CATTL = time.Hour * 24 * 365 * 10
)
const (
// AuditEnvelopeType is sending a audit event over the websocket to the web client.
AuditEnvelopeType = "audit"
// RawEnvelopeType is sending raw terminal bytes over the websocket to the web
// client.
RawEnvelopeType = "raw"
// ResizeRequestEnvelopeType is receiving a resize request.
ResizeRequestEnvelopeType = "resize.request"
)

View file

@ -1617,8 +1617,9 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
// Register web proxy server
var webServer *http.Server
var webHandler *web.RewritingHandler
if !process.Config.Proxy.DisableWebService {
webHandler, err := web.NewHandler(
webHandler, err = web.NewHandler(
web.Config{
Proxy: tsrv,
AuthServers: cfg.AuthServers[0],
@ -1718,6 +1719,9 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
if webServer != nil {
warnOnErr(webServer.Close())
}
if webHandler != nil {
warnOnErr(webHandler.Close())
}
warnOnErr(sshProxy.Close())
} else {
log.Infof("Shutting down gracefully.")
@ -1729,6 +1733,9 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
if webServer != nil {
warnOnErr(webServer.Shutdown(ctx))
}
if webHandler != nil {
warnOnErr(webHandler.Close())
}
}
log.Infof("Exited.")
})

View file

@ -21,6 +21,8 @@ package session
import (
"fmt"
"sort"
"strconv"
"strings"
"time"
"github.com/gravitational/teleport/lib/backend"
@ -154,12 +156,35 @@ func (p *Party) String() string {
)
}
// TerminalParams holds parameters of the terminal used in session
// TerminalParams holds the terminal size in a session.
type TerminalParams struct {
W int `json:"w"`
H int `json:"h"`
}
// UnmarshalTerminalParams takes a serialized string that contains the
// terminal parameters and returns a *TerminalParams.
func UnmarshalTerminalParams(s string) (*TerminalParams, error) {
parts := strings.Split(s, ":")
if len(parts) != 2 {
return nil, trace.BadParameter("failed to unmarshal: too many parts")
}
w, err := strconv.Atoi(parts[0])
if err != nil {
return nil, trace.Wrap(err)
}
h, err := strconv.Atoi(parts[1])
if err != nil {
return nil, trace.Wrap(err)
}
return &TerminalParams{
W: w,
H: h,
}, nil
}
// Serialize is a more strict version of String(): it returns a string
// representation of terminal size, this is used in our APIs.
// Format : "W:H"

View file

@ -172,6 +172,10 @@ type ServerContext struct {
// ClusterName is the name of the cluster current user is authenticated with.
ClusterName string
// ClusterConfig holds the cluster configuration at the time this context was
// created.
ClusterConfig services.ClusterConfig
// RemoteClient holds a SSH client to a remote server. Only used by the
// recording proxy.
RemoteClient *ssh.Client
@ -183,7 +187,12 @@ type ServerContext struct {
// NewServerContext creates a new *ServerContext which is used to pass and
// manage resources.
func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext IdentityContext) *ServerContext {
func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext IdentityContext) (*ServerContext, error) {
clusterConfig, err := srv.GetAccessPoint().GetClusterConfig()
if err != nil {
return nil, trace.Wrap(err)
}
ctx := &ServerContext{
id: int(atomic.AddInt32(&ctxID, int32(1))),
env: make(map[string]string),
@ -192,6 +201,7 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
ExecResultCh: make(chan ExecResult, 10),
SubsystemResultCh: make(chan SubsystemResult, 10),
ClusterName: conn.Permissions.Extensions[utils.CertTeleportClusterName],
ClusterConfig: clusterConfig,
Identity: identityContext,
}
@ -205,7 +215,8 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
"id": ctx.id,
},
})
return ctx
return ctx, nil
}
func (c *ServerContext) ID() int {

View file

@ -76,8 +76,8 @@ type Exec interface {
// NewExecRequest creates a new local or remote Exec.
func NewExecRequest(ctx *ServerContext, command string) (Exec, error) {
// doesn't matter what mode the cluster is in, if this is a teleport node
// return a local *localExec
// It doesn't matter what mode the cluster is in, if this is a Teleport node
// return a local *localExec.
if ctx.srv.Component() == teleport.ComponentNode {
return &localExec{
Ctx: ctx,
@ -85,14 +85,9 @@ func NewExecRequest(ctx *ServerContext, command string) (Exec, error) {
}, nil
}
clusterConfig, err := ctx.srv.GetAccessPoint().GetClusterConfig()
if err != nil {
return nil, trace.Wrap(err)
}
// when in recording mode, return an *remoteExec which will execute the
// command on a remote host. used by forwarding nodes.
if clusterConfig.GetSessionRecording() == services.RecordAtProxy {
// When in recording mode, return an *remoteExec which will execute the
// command on a remote host. This is used by in-memory forwarding nodes.
if ctx.ClusterConfig.GetSessionRecording() == services.RecordAtProxy {
return &remoteExec{
ctx: ctx,
command: command,
@ -100,8 +95,8 @@ func NewExecRequest(ctx *ServerContext, command string) (Exec, error) {
}, nil
}
// otherwise return a *localExec which will execute locally on the server.
// used by the regular teleport nodes.
// Otherwise return a *localExec which will execute locally on the server.
// used by the regular Teleport nodes.
return &localExec{
Ctx: ctx,
Command: command,

View file

@ -507,13 +507,8 @@ func (s *Server) handleChannel(nch ssh.NewChannel) {
channelType := nch.ChannelType()
switch channelType {
// A client requested the terminal size to be sent along with every
// session message (Teleport-specific SSH channel for web-based terminals).
case "x-teleport-request-resize-events":
ch, _, _ := nch.Accept()
go s.handleTerminalResize(ch)
// Channels of type "session" handle requests that are invovled in running
// commands on a server.
// commands on a server, subsystem requests, and agent forwarding.
case "session":
ch, requests, err := nch.Accept()
if err != nil {
@ -549,12 +544,17 @@ func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTC
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx := srv.NewServerContext(s, s.sconn, s.identityContext)
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
return
}
ctx.RemoteClient = s.remoteClient
defer ctx.Close()
// Check if the role allows port forwarding for this user.
err := s.authHandlers.CheckPortForward(dstAddr, ctx)
err = s.authHandlers.CheckPortForward(dstAddr, ctx)
if err != nil {
ch.Stderr().Write([]byte(err.Error()))
return
@ -597,27 +597,18 @@ func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTC
wg.Wait()
}
// handleTerminalResize is called by the web proxy via its SSH connection.
// when a web browser connects to the web API, the web proxy asks us,
// by creating this new SSH channel, to start injecting the terminal size
// into every SSH write back to it.
//
// This is the only way to make web-based terminal UI not break apart
// when window changes its size.
func (s *Server) handleTerminalResize(channel ssh.Channel) {
err := s.sessionRegistry.PushTermSizeToParty(s.sconn, channel)
if err != nil {
s.log.Warnf("Unable to push terminal size to party: %v", err)
}
}
// handleSessionRequests handles out of band session requests once the session
// channel has been created this function's loop handles all the "exec",
// "subsystem" and "shell" requests.
func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) {
// Create context for this channel. This context will be closed when the
// session request is complete.
ctx := srv.NewServerContext(s, s.sconn, s.identityContext)
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
return
}
ctx.RemoteClient = s.remoteClient
ctx.AddCloser(ch)
defer ctx.Close()

View file

@ -699,10 +699,15 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
channelType := nch.ChannelType()
if s.proxyMode {
if channelType == "session" { // interactive sessions
// Channels of type "session" handle requests that are invovled in running
// commands on a server. In the case of proxy mode subsystem and agent
// forwarding requests occur over the "session" channel.
if channelType == "session" {
ch, requests, err := nch.Accept()
if err != nil {
log.Infof("could not accept channel (%s)", err)
log.Warnf("Unable to accept channel: %v.", err)
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleSessionRequests(sconn, identityContext, ch, requests)
} else {
@ -712,26 +717,29 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
}
switch channelType {
// a client requested the terminal size to be sent along with every
// session message (Teleport-specific SSH channel for web-based terminals)
case "x-teleport-request-resize-events":
ch, _, _ := nch.Accept()
go s.handleTerminalResize(sconn, ch)
case "session": // interactive sessions
// Channels of type "session" handle requests that are invovled in running
// commands on a server, subsystem requests, and agent forwarding.
case "session":
ch, requests, err := nch.Accept()
if err != nil {
log.Infof("could not accept channel (%s)", err)
log.Warnf("Unable to accept channel: %v.", err)
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleSessionRequests(sconn, identityContext, ch, requests)
case "direct-tcpip": //port forwarding
// Channels of type "direct-tcpip" handles request for port forwarding.
case "direct-tcpip":
req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData())
if err != nil {
log.Errorf("failed to parse request data: %v, err: %v", string(nch.ExtraData()), err)
log.Errorf("Failed to parse request data: %v, err: %v.", string(nch.ExtraData()), err)
nch.Reject(ssh.UnknownChannelType, "failed to parse direct-tcpip request")
return
}
ch, _, err := nch.Accept()
if err != nil {
log.Infof("could not accept channel (%s)", err)
log.Warnf("Unable to accept channel: %v.", err)
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleDirectTCPIPRequest(sconn, identityContext, ch, req)
default:
@ -739,10 +747,17 @@ func (s *Server) HandleNewChan(nc net.Conn, sconn *ssh.ServerConn, nch ssh.NewCh
}
}
// handleDirectTCPIPRequest does the port forwarding
// handleDirectTCPIPRequest handles port forwarding requests.
func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
// ctx holds the connection context and keeps track of the associated resources
ctx := srv.NewServerContext(s, sconn, identityContext)
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, err := srv.NewServerContext(s, sconn, identityContext)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
return
}
ctx.IsTestStub = s.isTestStub
ctx.AddCloser(ch)
defer ctx.Debugf("direct-tcp closed")
@ -752,7 +767,7 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext
dstAddr := fmt.Sprintf("%v:%d", req.Host, req.Port)
// check if the role allows port forwarding for this user
err := s.authHandlers.CheckPortForward(dstAddr, ctx)
err = s.authHandlers.CheckPortForward(dstAddr, ctx)
if err != nil {
ch.Stderr().Write([]byte(err.Error()))
return
@ -822,25 +837,18 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext
}
}
// handleTerminalResize is called by the web proxy via its SSH connection.
// when a web browser connects to the web API, the web proxy asks us,
// by creating this new SSH channel, to start injecting the terminal size
// into every SSH write back to it.
//
// this is the only way to make web-based terminal UI not break apart
// when window changes its size
func (s *Server) handleTerminalResize(sconn *ssh.ServerConn, ch ssh.Channel) {
err := s.reg.PushTermSizeToParty(sconn, ch)
if err != nil {
log.Warnf("Unable to push terminal size to party: %v", err)
}
}
// handleSessionRequests handles out of band session requests once the session channel has been created
// this function's loop handles all the "exec", "subsystem" and "shell" requests.
// handleSessionRequests handles out of band session requests once the session
// channel has been created this function's loop handles all the "exec",
// "subsystem" and "shell" requests.
func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
// ctx holds the connection context and keeps track of the associated resources
ctx := srv.NewServerContext(s, sconn, identityContext)
// Create context for this channel. This context will be closed when the
// session request is complete.
ctx, err := srv.NewServerContext(s, sconn, identityContext)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
return
}
ctx.IsTestStub = s.isTestStub
ctx.AddCloser(ch)
defer ctx.Close()
@ -991,29 +999,27 @@ func (s *Server) handleAgentForwardNode(req *ssh.Request, ctx *srv.ServerContext
// requests should never fail, all errors should be logged and we should
// continue processing requests.
func (s *Server) handleAgentForwardProxy(req *ssh.Request, ctx *srv.ServerContext) error {
// we only support agent forwarding at the proxy when the proxy is in recording mode
clusterConfig, err := s.GetAccessPoint().GetClusterConfig()
if err != nil {
return trace.Wrap(err)
}
if clusterConfig.GetSessionRecording() != services.RecordAtProxy {
// Forwarding an agent to the proxy is only supported when the proxy is in
// recording mode.
if ctx.ClusterConfig.GetSessionRecording() != services.RecordAtProxy {
return trace.BadParameter("agent forwarding to proxy only supported in recording mode")
}
// check if the user's RBAC role allows agent forwarding
err = s.authHandlers.CheckAgentForward(ctx)
// Check if the user's RBAC role allows agent forwarding.
err := s.authHandlers.CheckAgentForward(ctx)
if err != nil {
return trace.Wrap(err)
}
// open a channel to the client where the client will serve an agent
// Open a channel to the client where the client will serve an agent.
authChannel, _, err := ctx.Conn.OpenChannel(sshutils.AuthAgentRequest, nil)
if err != nil {
return trace.Wrap(err)
}
// we save the agent so it can be used when we make a proxy subsystem request
// later and use it to build a remote connection to the target node.
// Save the agent so it can be used when making a proxy subsystem request
// later. It will also be used when building a remote connection to the
// target node.
ctx.SetAgent(agent.NewClient(authChannel), authChannel)
return nil

View file

@ -18,6 +18,7 @@ package srv
import (
"context"
"encoding/json"
"fmt"
"io"
"path/filepath"
@ -110,22 +111,56 @@ func (s *SessionRegistry) Close() {
s.log.Debugf("Closing Session Registry.")
}
// joinShell either joins an existing session or starts a new shell
// emitSessionJoinEvent emits a session join event to both the Audit Log as
// well as sending a "x-teleport-event" global request on the SSH connection.
func (s *SessionRegistry) emitSessionJoinEvent(ctx *ServerContext) {
sessionJoinEvent := events.EventFields{
events.EventType: events.SessionJoinEvent,
events.SessionEventID: string(ctx.session.id),
events.EventNamespace: s.srv.GetNamespace(),
events.EventLogin: ctx.Identity.Login,
events.EventUser: ctx.Identity.TeleportUser,
events.LocalAddr: ctx.Conn.LocalAddr().String(),
events.RemoteAddr: ctx.Conn.RemoteAddr().String(),
events.SessionServerID: ctx.srv.ID(),
}
// Emit session join event to Audit Log.
ctx.session.recorder.alog.EmitAuditEvent(events.SessionJoinEvent, sessionJoinEvent)
// Notify all members of the party that a new member has joined over the
// "x-teleport-event" channel.
for _, p := range s.getParties(ctx.session) {
eventPayload, err := json.Marshal(sessionJoinEvent)
if err != nil {
s.log.Warnf("Unable to marshal %v for %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err)
continue
}
_, _, err = p.sconn.SendRequest(teleport.SessionEvent, false, eventPayload)
if err != nil {
s.log.Warnf("Unable to send %v to %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err)
continue
}
s.log.Debugf("Sent %v to %v.", events.SessionJoinEvent, p.sconn.RemoteAddr())
}
}
// OpenSession either joins an existing session or starts a new session.
func (s *SessionRegistry) OpenSession(ch ssh.Channel, req *ssh.Request, ctx *ServerContext) error {
if ctx.session != nil {
// emit "joined session" event:
ctx.session.recorder.alog.EmitAuditEvent(events.SessionJoinEvent, events.EventFields{
events.SessionEventID: string(ctx.session.id),
events.EventNamespace: s.srv.GetNamespace(),
events.EventLogin: ctx.Identity.Login,
events.EventUser: ctx.Identity.TeleportUser,
events.LocalAddr: ctx.Conn.LocalAddr().String(),
events.RemoteAddr: ctx.Conn.RemoteAddr().String(),
events.SessionServerID: ctx.srv.ID(),
})
ctx.Infof("Joining existing session %v.", ctx.session.id)
// Update the in-memory data structure that a party member has joined.
_, err := ctx.session.join(ch, req, ctx)
return trace.Wrap(err)
if err != nil {
return trace.Wrap(err)
}
// Emit session join event to both the Audit Log as well as over the
// "x-teleport-event" channel in the SSH connection.
s.emitSessionJoinEvent(ctx)
return nil
}
// session not found? need to create one. start by getting/generating an ID for it
sid, found := ctx.GetEnv(sshutils.SessionEnvVar)
@ -150,25 +185,52 @@ func (s *SessionRegistry) OpenSession(ch ssh.Channel, req *ssh.Request, ctx *Ser
return nil
}
// leaveSession removes the given party from this session
// emitSessionLeaveEvent emits a session leave event to both the Audit Log as
// well as sending a "x-teleport-event" global request on the SSH connection.
func (s *SessionRegistry) emitSessionLeaveEvent(party *party) {
sessionLeaveEvent := events.EventFields{
events.EventType: events.SessionLeaveEvent,
events.SessionEventID: party.id.String(),
events.EventUser: party.user,
events.SessionServerID: party.serverID,
events.EventNamespace: s.srv.GetNamespace(),
}
// Emit session leave event to Audit Log.
party.s.recorder.alog.EmitAuditEvent(events.SessionLeaveEvent, sessionLeaveEvent)
// Notify all members of the party that a new member has left over the
// "x-teleport-event" channel.
for _, p := range s.getParties(party.s) {
eventPayload, err := json.Marshal(sessionLeaveEvent)
if err != nil {
s.log.Warnf("Unable to marshal %v for %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err)
continue
}
_, _, err = p.sconn.SendRequest(teleport.SessionEvent, false, eventPayload)
if err != nil {
s.log.Warnf("Unable to send %v to %v: %v.", events.SessionJoinEvent, p.sconn.RemoteAddr(), err)
continue
}
s.log.Debugf("Sent %v to %v.", events.SessionJoinEvent, p.sconn.RemoteAddr())
}
}
// leaveSession removes the given party from this session.
func (s *SessionRegistry) leaveSession(party *party) error {
sess := party.s
s.Lock()
defer s.Unlock()
// remove from in-memory representation of the session:
// Emit session leave event to both the Audit Log as well as over the
// "x-teleport-event" channel in the SSH connection.
s.emitSessionLeaveEvent(party)
// Remove member from in-members representation of party.
if err := sess.removeParty(party); err != nil {
return trace.Wrap(err)
}
// emit "session leave" event (party left the session)
sess.recorder.alog.EmitAuditEvent(events.SessionLeaveEvent, events.EventFields{
events.SessionEventID: string(sess.id),
events.EventUser: party.user,
events.SessionServerID: party.serverID,
events.EventNamespace: s.srv.GetNamespace(),
})
// this goroutine runs for a short amount of time only after a session
// becomes empty (no parties). It allows session to "linger" for a bit
// allowing parties to reconnect if they lost connection momentarily
@ -221,54 +283,84 @@ func (s *SessionRegistry) leaveSession(party *party) error {
// getParties allows to safely return a list of parties connected to this
// session (as determined by ctx)
func (s *SessionRegistry) getParties(ctx *ServerContext) (parties []*party) {
sess := ctx.session
if sess != nil {
sess.Lock()
defer sess.Unlock()
func (s *SessionRegistry) getParties(sess *session) []*party {
var parties []*party
parties = make([]*party, 0, len(sess.parties))
for _, p := range sess.parties {
parties = append(parties, p)
}
if sess == nil {
return parties
}
sess.Lock()
defer sess.Unlock()
for _, p := range sess.parties {
parties = append(parties, p)
}
return parties
}
// notifyWinChange is called when an SSH server receives a command notifying
// us that the terminal size has changed
// NotifyWinChange is called to notify all members in the party that the PTY
// size has changed. The notification is sent as a global SSH request and it
// is the responsibility of the client to update it's window size upon receipt.
func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *ServerContext) error {
if ctx.session == nil {
s.log.Debugf("Unable to update window size, no session found in context.")
return nil
}
sid := ctx.session.id
// report this to the event/audit log:
ctx.session.recorder.alog.EmitAuditEvent(events.ResizeEvent, events.EventFields{
// Build the resize event.
resizeEvent := events.EventFields{
events.EventType: events.ResizeEvent,
events.EventNamespace: s.srv.GetNamespace(),
events.SessionEventID: sid,
events.EventLogin: ctx.Identity.Login,
events.EventUser: ctx.Identity.TeleportUser,
events.TerminalSize: params.Serialize(),
})
}
// Report the updated window size to the event log (this is so the sessions
// can be replayed correctly).
ctx.session.recorder.alog.EmitAuditEvent(events.ResizeEvent, resizeEvent)
// Update the size of the server side PTY.
err := ctx.session.term.SetWinSize(params)
if err != nil {
return trace.Wrap(err)
}
// notify all connected parties about the change in real time
// (if they're capable)
for _, p := range s.getParties(ctx) {
p.onWindowChanged(&params)
// If sessions are being recorded at the proxy, sessions can not be shared.
// In that situation, PTY size information does not need to be propagated
// back to all clients and we can return right away.
if ctx.ClusterConfig.GetSessionRecording() == services.RecordAtProxy {
return nil
}
go func() {
err := s.srv.GetSessionServer().UpdateSession(
rsession.UpdateRequest{ID: sid, TerminalParams: &params, Namespace: s.srv.GetNamespace()})
if err != nil {
s.log.Errorf("Unable to update session %v: %v", sid, err)
// Notify all members of the party (except originator) that the size of the
// window has changed so the client can update it's own local PTY. Note that
// OpenSSH clients will ignore this and not update their own local PTY.
for _, p := range s.getParties(ctx.session) {
// Don't send the window change notification back to the originator.
if p.ctx.ID() == ctx.ID() {
continue
}
}()
eventPayload, err := json.Marshal(resizeEvent)
if err != nil {
s.log.Warnf("Unable to marshal resize event for %v: %v.", p.sconn.RemoteAddr(), err)
continue
}
// Send the message as a global request.
_, _, err = p.sconn.SendRequest(teleport.SessionEvent, false, eventPayload)
if err != nil {
s.log.Warnf("Unable to resize event to %v: %v.", p.sconn.RemoteAddr(), err)
continue
}
s.log.Debugf("Sent resize event %v to %v.", params, p.sconn.RemoteAddr())
}
return nil
}
@ -289,43 +381,6 @@ func (s *SessionRegistry) findSession(id rsession.ID) (*session, bool) {
return sess, found
}
func (r *SessionRegistry) PushTermSizeToParty(sconn *ssh.ServerConn, ch ssh.Channel) error {
// the party may not be immediately available for this connection,
// keep asking for a full second:
for i := 0; i < 10; i++ {
party := r.partyForConnection(sconn)
if party == nil {
time.Sleep(time.Millisecond * 100)
continue
}
// this starts a loop which will keep updating the terminal
// size for every SSH write back to this connection
party.termSizePusher(ch)
return nil
}
return trace.Errorf("unable to push term size to party")
}
// partyForConnection finds an existing party which owns the given connection
func (r *SessionRegistry) partyForConnection(sconn *ssh.ServerConn) *party {
r.Lock()
defer r.Unlock()
for _, session := range r.sessions {
session.Lock()
defer session.Unlock()
parties := session.parties
for _, party := range parties {
if party.sconn == sconn {
return party
}
}
}
return nil
}
// sessionRecorder implements io.Writer to be plugged into the multi-writer
// associated with every session. It forwards session stream to the audit log
type sessionRecorder struct {
@ -343,21 +398,18 @@ type sessionRecorder struct {
}
func newSessionRecorder(alog events.IAuditLog, ctx *ServerContext, sid rsession.ID) (*sessionRecorder, error) {
var err error
var auditLog events.IAuditLog
if alog == nil {
auditLog = &events.DiscardAuditLog{}
} else {
clusterConfig, err := ctx.srv.GetAccessPoint().GetClusterConfig()
if err != nil {
return nil, trace.Wrap(err)
}
// always write sessions to local disk first
// forward them to auth server later
// Always write sessions to local disk first, then forward them to the Auth
// Server later.
auditLog, err = events.NewForwarder(events.ForwarderConfig{
SessionID: sid,
ServerID: "upload",
DataDir: filepath.Join(ctx.srv.GetDataDir(), teleport.LogsDir),
RecordSessions: clusterConfig.GetSessionRecording() != services.RecordOff,
RecordSessions: ctx.ClusterConfig.GetSessionRecording() != services.RecordOff,
Namespace: ctx.srv.GetNamespace(),
ForwardTo: alog,
})
@ -533,10 +585,12 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
return sess, nil
}
// isLingering returns 'true' if every party has left this session
// isLingering returns true if every party has left this session. Occurs
// under a lock.
func (s *session) isLingering() bool {
s.Lock()
defer s.Unlock()
return len(s.parties) == 0
}
@ -644,9 +698,9 @@ func (s *session) start(ch ssh.Channel, ctx *ServerContext) error {
events.TerminalSize: params.Serialize(),
})
// start asynchronous loop of synchronizing session state with
// the session server (terminal size and activity)
go s.pollAndSync()
// Start a heartbeat that marks this session as active with current members
// of party in the backend.
go s.heartbeat(ctx)
doneCh := make(chan bool, 1)
@ -732,41 +786,25 @@ func (s *session) String() string {
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties))
}
// removeParty removes the party from two places:
// 1. from in-memory dictionary inside of this session
// 2. from sessin server's storage
// removePartyMember removes participant from in-memory representation of
// party members. Occurs under a lock.
func (s *session) removePartyMember(party *party) {
s.Lock()
defer s.Unlock()
delete(s.parties, party.id)
}
// removeParty removes the party from the in-memory map that holds all party
// members.
func (s *session) removeParty(p *party) error {
p.ctx.Infof("Removing party %v from session %v", p, s.id)
ns := s.getNamespace()
// Removes participant from in-memory map of party members.
s.removePartyMember(p)
// in-memory locked remove:
lockedRemove := func() {
s.Lock()
defer s.Unlock()
delete(s.parties, p.id)
s.writer.deleteWriter(string(p.id))
}
lockedRemove()
s.writer.deleteWriter(string(p.id))
// remove from the session server (asynchronously)
storageRemove := func(db rsession.Service) {
dbSession, err := db.GetSession(ns, s.id)
if err != nil {
s.log.Errorf("Unable to get session %v: %v", s.id, err)
return
}
if dbSession != nil && dbSession.RemoveParty(p.id) {
db.UpdateSession(rsession.UpdateRequest{
ID: dbSession.ID,
Parties: &dbSession.Parties,
Namespace: ns,
})
}
}
if s.registry.srv.GetSessionServer() != nil {
go storageRemove(s.registry.srv.GetSessionServer())
}
return nil
}
@ -786,81 +824,82 @@ func (s *session) getNamespace() string {
return s.registry.srv.GetNamespace()
}
// pollAndSync is a loops forever trying to sync terminal size to what's in
// the session (so all connected parties have the same terminal size) and
// update the "active" field of the session. If the session are recorded at
// the proxy, then this function does nothing as it's counterpart in the proxy
// will do this work.
func (s *session) pollAndSync() {
// exportPartyMembers exports participants in the in-memory map of party
// members. Occurs under a lock.
func (s *session) exportPartyMembers() []rsession.Party {
s.Lock()
defer s.Unlock()
var partyList []rsession.Party
for _, p := range s.parties {
partyList = append(partyList, rsession.Party{
ID: p.id,
User: p.user,
ServerID: p.serverID,
RemoteAddr: p.site,
LastActive: p.getLastActive(),
})
}
return partyList
}
// heartbeat will loop as long as the session is not closed and mark it as
// active and update the list of party members. If the session are recorded at
// the proxy, then this function does nothing as it's counterpart
// in the proxy will do this work.
func (s *session) heartbeat(ctx *ServerContext) {
// If sessions are being recorded at the proxy, an identical version of this
// goroutine is running in the proxy, which means it does not need to run here.
clusterConfig, err := s.registry.srv.GetAccessPoint().GetClusterConfig()
if err != nil {
s.log.Errorf("Unable to sync terminal size: %v.", err)
if ctx.ClusterConfig.GetSessionRecording() == services.RecordAtProxy &&
s.registry.srv.Component() == teleport.ComponentNode {
return
}
if clusterConfig.GetSessionRecording() == services.RecordAtProxy &&
s.registry.srv.Component() == teleport.ComponentNode {
// If no session server (endpoint interface for active sessions) is passed in
// (for example Teleconsole does this) then nothing to sync.
sessionServer := s.registry.srv.GetSessionServer()
if sessionServer == nil {
return
}
s.log.Debugf("Starting poll and sync of terminal size to all parties.")
defer s.log.Debugf("Stopping poll and sync of terminal size to all parties.")
ns := s.getNamespace()
tickerCh := time.NewTicker(defaults.SessionRefreshPeriod)
defer tickerCh.Stop()
sessionServer := s.registry.srv.GetSessionServer()
if sessionServer == nil {
return
}
errCount := 0
sync := func() error {
sess, err := sessionServer.GetSession(ns, s.id)
if err != nil || sess == nil {
return trace.Wrap(err)
}
var active = true
sessionServer.UpdateSession(rsession.UpdateRequest{
Namespace: ns,
ID: sess.ID,
Active: &active,
Parties: nil,
})
winSize, err := s.term.GetWinSize()
if err != nil {
return err
}
termSizeChanged := (int(winSize.Width) != sess.TerminalParams.W ||
int(winSize.Height) != sess.TerminalParams.H)
if termSizeChanged {
s.log.Debugf("Terminal changed from: %v to %v", sess.TerminalParams, winSize)
err = s.term.SetWinSize(sess.TerminalParams)
}
return err
}
tick := time.NewTicker(defaults.TerminalSizeRefreshPeriod)
defer tick.Stop()
// Loop as long as the session is active, updating the session in the backend.
for {
if err := sync(); err != nil {
s.log.Infof("Unable to sync terminal: %v", err)
errCount++
// if the error count keeps going up, this means we're stuck in
// a bad state: end this goroutine to avoid leaks
if errCount > maxTermSyncErrorCount {
return
}
} else {
errCount = 0
}
select {
case <-tickerCh.C:
partyList := s.exportPartyMembers()
var active = true
err := sessionServer.UpdateSession(rsession.UpdateRequest{
Namespace: s.getNamespace(),
ID: s.id,
Active: &active,
Parties: &partyList,
})
if err != nil {
s.log.Warnf("Unable to update session %v as active: %v", s.id, err)
}
case <-s.closeC:
return
case <-tick.C:
}
}
}
// addPartyMember adds participant to in-memory map of party members. Occurs
// under a lock.
func (s *session) addPartyMember(p *party) {
s.Lock()
defer s.Unlock()
s.parties[p.id] = p
}
// addParty is called when a new party joins the session.
func (s *session) addParty(p *party) error {
if s.login != p.login {
@ -869,9 +908,11 @@ func (s *session) addParty(p *party) error {
s.login, p.login, s.id)
}
s.parties[p.id] = p
// write last chunk (so the newly joined parties won't stare
// at a blank screen)
// Adds participant to in-memory map of party members.
s.addPartyMember(p)
// Write last chunk (so the newly joined parties won't stare at a blank
// screen).
getRecentWrite := func() []byte {
s.writer.Lock()
defer s.writer.Unlock()
@ -883,39 +924,14 @@ func (s *session) addParty(p *party) error {
}
p.Write(getRecentWrite())
// register this party as one of the session writers
// (output will go to it)
// Register this party as one of the session writers (output will go to it).
s.writer.addWriter(string(p.id), p, true)
p.ctx.AddCloser(p)
s.term.AddParty(1)
// update session on the session server
storageUpdate := func(db rsession.Service) {
dbSession, err := db.GetSession(s.getNamespace(), s.id)
if err != nil {
s.log.Errorf("Unable to get session %v: %v", s.id, err)
return
}
dbSession.Parties = append(dbSession.Parties, rsession.Party{
ID: p.id,
User: p.user,
ServerID: p.serverID,
RemoteAddr: p.site,
LastActive: p.getLastActive(),
})
db.UpdateSession(rsession.UpdateRequest{
ID: dbSession.ID,
Parties: &dbSession.Parties,
Namespace: s.getNamespace(),
})
}
if s.registry.srv.GetSessionServer() != nil {
go storageUpdate(s.registry.srv.GetSessionServer())
}
s.log.Infof("New party %v joined session: %v", p.String(), s.id)
// this goroutine keeps pumping party's input into the session
// This goroutine keeps pumping party's input into the session.
go func() {
defer s.term.AddParty(-1)
_, err := io.Copy(s.term.PTY(), p)
@ -1046,48 +1062,6 @@ func newParty(s *session, ch ssh.Channel, ctx *ServerContext) *party {
}
}
func (p *party) onWindowChanged(params *rsession.TerminalParams) {
p.log.Debugf("Window size changed to %v in party: %v", params.Serialize(), p.id)
p.Lock()
defer p.Unlock()
// this prefix will be appended to the end of every socker write going
// to this party:
prefix := []byte("\x00" + params.Serialize())
if p.termSizeC != nil && len(p.termSizeC) == 0 {
p.termSizeC <- prefix
}
}
// This goroutine pushes terminal resize events directly into a connected web client
func (p *party) termSizePusher(ch ssh.Channel) {
var (
err error
n int
)
defer func() {
if err != nil {
p.log.Error(err)
}
}()
for err == nil {
select {
case newSize := <-p.termSizeC:
n, err = ch.Write(newSize)
if err == io.EOF {
continue
}
if err != nil || n == 0 {
return
}
case <-p.closeC:
return
}
}
}
func (p *party) updateActivity() {
p.Lock()
defer p.Unlock()

View file

@ -85,19 +85,15 @@ type Terminal interface {
// NewTerminal returns a new terminal. Terminal can be local or remote
// depending on cluster configuration.
func NewTerminal(ctx *ServerContext) (Terminal, error) {
// doesn't matter what mode the cluster is in, if this is a teleport node
// return a local terminal
// It doesn't matter what mode the cluster is in, if this is a Teleport node
// return a local terminal.
if ctx.srv.Component() == teleport.ComponentNode {
return newLocalTerminal(ctx)
}
// otherwise find out what mode the cluster is in and return the
// correct terminal
clusterConfig, err := ctx.srv.GetAccessPoint().GetClusterConfig()
if err != nil {
return nil, trace.Wrap(err)
}
if clusterConfig.GetSessionRecording() == services.RecordAtProxy {
// If this is not a Teleport node, find out what mode the cluster is in and
// return the correct terminal.
if ctx.ClusterConfig.GetSessionRecording() == services.RecordAtProxy {
return newRemoteTerminal(ctx)
}
return newLocalTerminal(ctx)

View file

@ -180,22 +180,16 @@ func (t *TermHandlers) HandleShell(ch ssh.Channel, req *ssh.Request, ctx *Server
}
// HandleWinChange handles requests of type "window-change" which update the
// size of the TTY running on the server.
// size of the PTY running on the server and update any other members in the
// party.
func (t *TermHandlers) HandleWinChange(ch ssh.Channel, req *ssh.Request, ctx *ServerContext) error {
params, err := parseWinChange(req)
if err != nil {
ctx.Error(err)
return trace.Wrap(err)
}
term := ctx.GetTerm()
if term != nil {
err = term.SetWinSize(*params)
if err != nil {
ctx.Errorf("Unable to set window size: %v", err)
}
}
// Update any other members in the party that the window size has changed
// and to update their terminal windows accordingly.
err = t.SessionRegistry.NotifyWinChange(*params, ctx)
if err != nil {
return trace.Wrap(err)

View file

@ -135,10 +135,6 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) {
}
}
if h.sessionStreamPollPeriod == 0 {
h.sessionStreamPollPeriod = sessionStreamPollPeriod
}
if h.clock == nil {
h.clock = clockwork.NewRealClock()
}
@ -175,12 +171,10 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) {
h.GET("/webapi/sites/:site/namespaces/:namespace/nodes", h.WithClusterAuth(h.siteNodesGet))
// active sessions handlers
h.GET("/webapi/sites/:site/namespaces/:namespace/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket)
h.GET("/webapi/sites/:site/namespaces/:namespace/sessions", h.WithClusterAuth(h.siteSessionsGet)) // get active list of sessions
h.POST("/webapi/sites/:site/namespaces/:namespace/sessions", h.WithClusterAuth(h.siteSessionGenerate)) // create active session metadata
h.GET("/webapi/sites/:site/namespaces/:namespace/sessions/:sid", h.WithClusterAuth(h.siteSessionGet)) // get active session metadata
h.PUT("/webapi/sites/:site/namespaces/:namespace/sessions/:sid", h.WithClusterAuth(h.siteSessionUpdate)) // update active session metadata (parameters)
h.GET("/webapi/sites/:site/namespaces/:namespace/sessions/:sid/events/stream", h.WithClusterAuth(h.siteSessionStream)) // get active session's byte stream (from events)
h.GET("/webapi/sites/:site/namespaces/:namespace/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket)
h.GET("/webapi/sites/:site/namespaces/:namespace/sessions", h.WithClusterAuth(h.siteSessionsGet)) // get active list of sessions
h.POST("/webapi/sites/:site/namespaces/:namespace/sessions", h.WithClusterAuth(h.siteSessionGenerate)) // create active session metadata
h.GET("/webapi/sites/:site/namespaces/:namespace/sessions/:sid", h.WithClusterAuth(h.siteSessionGet)) // get active session metadata
// recorded sessions handlers
h.GET("/webapi/sites/:site/events", h.WithClusterAuth(h.siteEventsGet)) // get recorded list of sessions (from events)
@ -1406,54 +1400,11 @@ func (h *Handler) siteNodeConnect(
// start the websocket session with a web-based terminal:
log.Infof("[WEB] getting terminal to '%#v'", req)
term.Run(w, r)
term.Serve(w, r)
return nil, nil
}
// sessionStreamEvent is sent over the session stream socket, it contains
// last events that occurred (only new events are sent)
type sessionStreamEvent struct {
Events []events.EventFields `json:"events"`
Session *session.Session `json:"session"`
Servers []services.ServerV1 `json:"servers"`
}
// siteSessionStream returns a stream of events related to the session
//
// GET /v1/webapi/sites/:site/namespaces/:namespace/sessions/:sid/events/stream?access_token=bearer_token
//
// Successful response is a websocket stream that allows read write to the server and returns
// json events
//
func (h *Handler) siteSessionStream(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) {
sessionID, err := session.ParseID(p.ByName("sid"))
if err != nil {
return nil, trace.Wrap(err)
}
namespace := p.ByName("namespace")
if !services.IsValidNamespace(namespace) {
return nil, trace.BadParameter("invalid namespace %q", namespace)
}
connect, err := newSessionStreamHandler(namespace,
*sessionID, ctx, site, h.sessionStreamPollPeriod)
if err != nil {
return nil, trace.Wrap(err)
}
// this is to make sure we close web socket connections once
// sessionContext that owns them expires
ctx.AddClosers(connect)
defer func() {
connect.Close()
ctx.RemoveCloser(connect)
}()
connect.Handler().ServeHTTP(w, r)
return nil, nil
}
type siteSessionGenerateReq struct {
Session session.Session `json:"session"`
}
@ -1496,48 +1447,6 @@ type siteSessionUpdateReq struct {
TerminalParams session.TerminalParams `json:"terminal_params"`
}
// siteSessionUpdate udpdates the site session
//
// PUT /v1/webapi/sites/:site/sessions/:sid
//
// Request body:
//
// {"terminal_params": {"w": 100, "h": 100}}
//
// Response body:
//
// {"message": "ok"}
//
func (h *Handler) siteSessionUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) {
sessionID, err := session.ParseID(p.ByName("sid"))
if err != nil {
return nil, trace.Wrap(err)
}
var req *siteSessionUpdateReq
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}
siteAPI, err := site.GetClient()
if err != nil {
log.Error(err)
return nil, trace.Wrap(err)
}
namespace := p.ByName("namespace")
if !services.IsValidNamespace(namespace) {
return nil, trace.BadParameter("invalid namespace %q", namespace)
}
err = ctx.UpdateSessionTerminal(siteAPI, namespace, *sessionID, req.TerminalParams)
if err != nil {
log.Error(err)
return nil, trace.Wrap(err)
}
return ok(), nil
}
type siteSessionsGetResponse struct {
Sessions []session.Session `json:"sessions"`
}

File diff suppressed because it is too large Load diff

View file

@ -26,22 +26,23 @@ import (
"sync"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/gravitational/ttlmap"
"github.com/jonboulle/clockwork"
log "github.com/sirupsen/logrus"
"github.com/tstranex/u2f"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// SessionContext is a context associated with users'
@ -57,43 +58,7 @@ type SessionContext struct {
remoteClt map[string]auth.ClientI
parent *sessionCache
closers []io.Closer
}
// getTerminal finds and returns an active web terminal for a given session:
func (c *SessionContext) getTerminal(sessionID session.ID) (*TerminalHandler, error) {
c.Lock()
defer c.Unlock()
for _, closer := range c.closers {
term, ok := closer.(*TerminalHandler)
if ok && term.params.SessionID == sessionID {
return term, nil
}
}
return nil, trace.NotFound("no connected streams")
}
// UpdateSessionTerminal is called when a browser window is resized and
// we need to update PTY on the server side
func (c *SessionContext) UpdateSessionTerminal(
siteAPI auth.ClientI, namespace string, sessionID session.ID, params session.TerminalParams) error {
// update the session size on the auth server's side
err := siteAPI.UpdateSession(session.UpdateRequest{
ID: sessionID,
TerminalParams: &params,
Namespace: namespace,
})
if err != nil {
log.Error(err)
}
// update the server-side PTY to match the browser window size
term, err := c.getTerminal(sessionID)
if err != nil {
log.Error(err)
return trace.Wrap(err)
}
return trace.Wrap(term.resizePTYWindow(params))
tc *client.TeleportClient
}
func (c *SessionContext) AddClosers(closers ...io.Closer) {

View file

@ -1,171 +0,0 @@
/*
Copyright 2015 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package web
import (
"io"
"io/ioutil"
"net/http"
"sync"
"time"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"golang.org/x/net/websocket"
)
func newSessionStreamHandler(namespace string, sessionID session.ID, ctx *SessionContext, site reversetunnel.RemoteSite, pollPeriod time.Duration) (*sessionStreamHandler, error) {
return &sessionStreamHandler{
pollPeriod: pollPeriod,
sessionID: sessionID,
ctx: ctx,
site: site,
closeC: make(chan bool),
namespace: namespace,
}, nil
}
// sessionStreamHandler streams events related to some particular session
// as a stream of JSON encoded event packets
type sessionStreamHandler struct {
closeOnce sync.Once
pollPeriod time.Duration
ctx *SessionContext
site reversetunnel.RemoteSite
namespace string
sessionID session.ID
closeC chan bool
ws *websocket.Conn
}
func (w *sessionStreamHandler) Close() error {
if w.ws != nil {
w.ws.Close()
}
w.closeOnce.Do(func() {
close(w.closeC)
})
return nil
}
// sessionStreamPollPeriod defines how frequently web sessions are
// sent new events
var sessionStreamPollPeriod = time.Second
// stream runs in a loop generating "something changed" events for a
// given active WebSession
//
// The events are fed to a web client via the websocket
func (w *sessionStreamHandler) stream(ws *websocket.Conn) error {
w.ws = ws
clt, err := w.site.GetClient()
if err != nil {
return trace.Wrap(err)
}
// spin up a goroutine to detect closed socket by reading
// from it
go func() {
defer w.Close()
io.Copy(ioutil.Discard, ws)
}()
eventsCursor := -1
emptyEventList := make([]events.EventFields, 0)
pollEvents := func() []events.EventFields {
// ask for any events than happened since the last call:
re, err := clt.GetSessionEvents(w.namespace, w.sessionID, eventsCursor+1, false)
if err != nil {
if !trace.IsNotFound(err) {
log.Error(err)
}
return emptyEventList
}
batchLen := len(re)
if batchLen == 0 {
return emptyEventList
}
// advance the cursor, so next time we'll ask for the latest:
eventsCursor = re[batchLen-1].GetInt(events.EventCursor)
return re
}
ticker := time.NewTicker(w.pollPeriod)
defer ticker.Stop()
defer w.Close()
// keep polling in a loop:
for {
// wait for next timer tick or a signal to abort:
select {
case <-ticker.C:
case <-w.closeC:
log.Infof("[web] session.stream() exited")
return nil
}
newEvents := pollEvents()
sess, err := clt.GetSession(w.namespace, w.sessionID)
if err != nil {
if trace.IsNotFound(err) {
continue
}
log.Error(err)
}
if sess == nil {
log.Warningf("invalid session ID: %v", w.sessionID)
continue
}
servers, err := clt.GetNodes(w.namespace)
if err != nil {
log.Error(err)
}
if len(newEvents) > 0 {
log.Infof("[WEB] streaming for %v. Events: %v, Nodes: %v, Parties: %v",
w.sessionID, len(newEvents), len(servers), len(sess.Parties))
}
// push events to the web client
event := &sessionStreamEvent{
Events: newEvents,
Session: sess,
Servers: services.ServersToV1(servers),
}
if err := websocket.JSON.Send(ws, event); err != nil {
log.Error(err)
}
}
}
func (w *sessionStreamHandler) Handler() http.Handler {
// TODO(klizhentas)
// we instantiate a server explicitly here instead of using
// websocket.HandlerFunc to set empty origin checker
// make sure we check origin when in prod mode
return &websocket.Server{
Handler: func(ws *websocket.Conn) {
if err := w.stream(ws); err != nil {
log.WithFields(log.Fields{"sid": w.sessionID}).Infof("handler returned: %#v", err)
}
},
}
}

View file

@ -18,55 +18,78 @@ package web
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/net/websocket"
"golang.org/x/text/encoding"
"golang.org/x/text/encoding/unicode"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/net/websocket"
)
// TerminalRequest describes a request to crate a web-based terminal
// to a remote SSH server
// TerminalRequest describes a request to create a web-based terminal
// to a remote SSH server.
type TerminalRequest struct {
// Server describes a server to connect to (serverId|hostname[:port])
// Server describes a server to connect to (serverId|hostname[:port]).
Server string `json:"server_id"`
// User is linux username to connect as
// Login is Linux username to connect as.
Login string `json:"login"`
// Term sets PTY params like width and height
// Term is the initial PTY size.
Term session.TerminalParams `json:"term"`
// SessionID is a teleport session ID to join as
// SessionID is a Teleport session ID to join as.
SessionID session.ID `json:"sid"`
// Namespace is node namespace
// Namespace is node namespace.
Namespace string `json:"namespace"`
// Proxy server address
// ProxyHostPort is the address of the server to connect to.
ProxyHostPort string `json:"-"`
// Remote cluster name
// Cluster is the name of the remote cluster to connect to.
Cluster string `json:"-"`
// InteractiveCommand is a command to execute
// InteractiveCommand is a command to execut.e
InteractiveCommand []string `json:"-"`
// SessionTimeout is how long to wait for the session end event to arrive.
SessionTimeout time.Duration
}
// NodeProvider is a provider of nodes for namespace
type NodeProvider interface {
// AuthProvider is a subset of the full Auth API.
type AuthProvider interface {
GetNodes(namespace string) ([]services.Server, error)
GetSessionEvents(namespace string, sid session.ID, after int, includePrintEvents bool) ([]events.EventFields, error)
}
// newTerminal creates a web-based terminal based on WebSockets and returns a new
// TerminalHandler
func NewTerminal(req TerminalRequest, provider NodeProvider, ctx *SessionContext) (*TerminalHandler, error) {
// make sure whatever session is requested is a valid session
// newTerminal creates a web-based terminal based on WebSockets and returns a
// new TerminalHandler.
func NewTerminal(req TerminalRequest, authProvider AuthProvider, ctx *SessionContext) (*TerminalHandler, error) {
if req.SessionTimeout == 0 {
req.SessionTimeout = defaults.HTTPIdleTimeout
}
// Make sure whatever session is requested is a valid session.
_, err := session.ParseID(string(req.SessionID))
if err != nil {
return nil, trace.BadParameter("sid: invalid session id")
@ -79,7 +102,7 @@ func NewTerminal(req TerminalRequest, provider NodeProvider, ctx *SessionContext
return nil, trace.BadParameter("term: bad term dimensions")
}
servers, err := provider.GetNodes(req.Namespace)
servers, err := authProvider.GetNodes(req.Namespace)
if err != nil {
return nil, trace.Wrap(err)
}
@ -90,50 +113,336 @@ func NewTerminal(req TerminalRequest, provider NodeProvider, ctx *SessionContext
}
return &TerminalHandler{
params: req,
ctx: ctx,
hostName: hostName,
hostPort: hostPort,
namespace: req.Namespace,
sessionID: req.SessionID,
params: req,
ctx: ctx,
hostName: hostName,
hostPort: hostPort,
authProvider: authProvider,
sessionTimeout: req.SessionTimeout,
}, nil
}
// TerminalHandler connects together an SSH session with a web-based
// terminal via a web socket.
type TerminalHandler struct {
// params describe the terminal configuration
// namespace is node namespace.
namespace string
// sessionID is a Teleport session ID to join as.
sessionID session.ID
// params is the initial PTY size.
params TerminalRequest
// ctx is a web session context for the currently logged in user
// ctx is a web session context for the currently logged in user.
ctx *SessionContext
// ws is the websocket which is connected to stdin/out/err of the terminal shell
// ws is the websocket which is connected to stdin/out/err of the terminal shell.
ws *websocket.Conn
// hostName we're connected to
// hostName is the hostname of the server.
hostName string
// hostPort we're connected to
// hostPort is the port of the server.
hostPort int
// sshClient is initialized after an SSH connection to a node is established
// sshSession holds the "shell" SSH channel to the node.
sshSession *ssh.Session
// teleportClient is the client used to form the connection.
teleportClient *client.TeleportClient
// terminalContext is used to signal when the terminal sesson is closing.
terminalContext context.Context
// terminalCancel is used to signal when the terminal session is closing.
terminalCancel context.CancelFunc
// eventContext is used to signal when the event stream is closing.
eventContext context.Context
// eventCancel is used to signal when the event is closing.
eventCancel context.CancelFunc
// request is the HTTP request that initiated the websocket connection.
request *http.Request
// authProvider is used to fetch nodes and sessions from the backend.
authProvider AuthProvider
// sessionTimeout is how long to wait for the session end event to arrive.
sessionTimeout time.Duration
}
// Serve builds a connect to the remote node and then pumps back two types of
// events: raw input/output events for what's happening on the terminal itself
// and audit log events relevant to this session.
func (t *TerminalHandler) Serve(w http.ResponseWriter, r *http.Request) {
t.request = r
// This allows closing of the websocket if the user logs out before exiting
// the session.
t.ctx.AddClosers(t)
defer t.ctx.RemoveCloser(t)
// We initial a server explicitly here instead of using websocket.HandlerFunc
// to set an empty origin checker (this is to make our lives easier in tests).
// The main use of the origin checker is to enforce the browsers same-origin
// policy. That does not matter here because even if malicious Javascript
// would try and open a websocket the request to this endpoint requires the
// bearer token to be in the URL so it would not be sent along by default
// like cookies are.
ws := &websocket.Server{Handler: t.handler}
ws.ServeHTTP(w, r)
}
// Close the websocket stream.
func (t *TerminalHandler) Close() error {
// Close the websocket connection to the client web browser.
if t.ws != nil {
t.ws.Close()
}
// Close the SSH connection to the remote node.
if t.sshSession != nil {
t.sshSession.Close()
}
// If the terminal handler was closed (most likely due to the *SessionContext
// closing) then the stream should be closed as well.
t.terminalCancel()
return nil
}
// resizePTYWindow is called when a brower resizes its window. Now the node
// needs to be notified via SSH
func (t *TerminalHandler) resizePTYWindow(params session.TerminalParams) error {
// handler is the main websocket loop. It creates a Teleport client and then
// pumps raw events and audit events back to the client until the SSH session
// is complete.
func (t *TerminalHandler) handler(ws *websocket.Conn) {
// Create a Teleport client, if not able to, show the reason to the user in
// the terminal.
tc, err := t.makeClient(ws)
if err != nil {
errToTerm(err, ws)
return
}
// Create two contexts for signaling. The first
t.terminalContext, t.terminalCancel = context.WithCancel(context.Background())
t.eventContext, t.eventCancel = context.WithCancel(context.Background())
// Pump raw terminal in/out and audit events into the websocket.
go t.streamTerminal(ws, tc)
go t.streamEvents(ws, tc)
// Block until the terminal session is complete.
<-t.terminalContext.Done()
// Block until the session end event is sent or a timeout occurs.
timeoutCh := time.After(t.sessionTimeout)
for {
select {
case <-timeoutCh:
t.eventCancel()
case <-t.eventContext.Done():
}
log.Debugf("Closing websocket stream to web client.")
return
}
}
// makeClient builds a *client.TeleportClient for the connection.
func (t *TerminalHandler) makeClient(ws *websocket.Conn) (*client.TeleportClient, error) {
agent, cert, err := t.ctx.GetAgent()
if err != nil {
return nil, trace.BadParameter("failed to get user credentials: %v", err)
}
signers, err := agent.Signers()
if err != nil {
return nil, trace.BadParameter("failed to get user credentials: %v", err)
}
tlsConfig, err := t.ctx.ClientTLSConfig()
if err != nil {
return nil, trace.BadParameter("failed to get client TLS config: %v", err)
}
// Create a wrapped websocket to wrap/unwrap the envelope used to
// communicate over the websocket.
wrappedSock := newWrappedSocket(ws, t)
clientConfig := &client.Config{
SkipLocalAuth: true,
ForwardAgent: true,
Agent: agent,
TLS: tlsConfig,
AuthMethods: []ssh.AuthMethod{ssh.PublicKeys(signers...)},
DefaultPrincipal: cert.ValidPrincipals[0],
HostLogin: t.params.Login,
Username: t.ctx.user,
Namespace: t.params.Namespace,
Stdout: wrappedSock,
Stderr: wrappedSock,
Stdin: wrappedSock,
SiteName: t.params.Cluster,
ProxyHostPort: t.params.ProxyHostPort,
Host: t.hostName,
HostPort: t.hostPort,
Env: map[string]string{sshutils.SessionEnvVar: string(t.params.SessionID)},
HostKeyCallback: func(string, net.Addr, ssh.PublicKey) error { return nil },
ClientAddr: t.request.RemoteAddr,
}
if len(t.params.InteractiveCommand) > 0 {
clientConfig.Interactive = true
}
tc, err := client.NewClient(clientConfig)
if err != nil {
return nil, trace.BadParameter("failed to create client: %v", err)
}
// Save the *ssh.Session after the shell has been created. The session is
// used to update all other parties window size to that of the web client and
// to allow future window changes.
tc.OnShellCreated = func(s *ssh.Session, c *ssh.Client, _ io.ReadWriteCloser) (bool, error) {
t.sshSession = s
t.windowChange(&t.params.Term)
return false, nil
}
return tc, nil
}
// streamTerminal opens a SSH connection to the remote host and streams
// events back to the web client.
func (t *TerminalHandler) streamTerminal(ws *websocket.Conn, tc *client.TeleportClient) {
defer t.terminalCancel()
// Establish SSH connection to the server. This function will block until
// either an error occurs or it completes successfully.
err := tc.SSH(t.terminalContext, t.params.InteractiveCommand, false)
if err != nil {
log.Warningf("failed to SSH: %v", err)
errToTerm(err, ws)
}
}
// streamEvents receives events over the SSH connection (as well as periodic
// polling) to update the client with relevant audit events.
func (t *TerminalHandler) streamEvents(ws *websocket.Conn, tc *client.TeleportClient) {
// A cursor are used to keep track of where we are in the event stream. This
// is to find "session.end" events.
var cursor int = -1
tickerCh := time.NewTicker(defaults.SessionRefreshPeriod)
defer tickerCh.Stop()
for {
select {
// Send push events that come over the events channel to the web client.
case event := <-tc.EventsChannel():
e := eventEnvelope{
Type: defaults.AuditEnvelopeType,
Payload: event,
}
log.Debugf("Sending audit event %v to web client.", event.GetType())
err := websocket.JSON.Send(ws, e)
if err != nil {
log.Errorf("Unable to %v event to web client: %v.", event.GetType(), err)
continue
}
// Poll for events to send to the web client. This is for events that can
// not be sent over the events channel (like "session.end" which lingers for
// a while after all party members have left).
case <-tickerCh.C:
// Fetch all session events from the backend.
sessionEvents, cur, err := t.pollEvents(cursor)
if err != nil {
if !trace.IsNotFound(err) {
log.Errorf("Unable to poll for events: %v.", err)
continue
}
continue
}
// Update the cursor location.
cursor = cur
// Send all events to the web client.
for _, sessionEvent := range sessionEvents {
ee := eventEnvelope{
Type: defaults.AuditEnvelopeType,
Payload: sessionEvent,
}
err = websocket.JSON.Send(ws, ee)
if err != nil {
log.Warnf("Unable to send %v events to web client: %v.", len(sessionEvents), err)
continue
}
// The session end event was sent over the websocket, we can now close the
// websocket.
if sessionEvent.GetType() == events.SessionEndEvent {
t.eventCancel()
return
}
}
case <-t.eventContext.Done():
return
}
}
}
// pollEvents polls the backend for events that don't get pushed over the
// SSH events channel. Eventually this function will be removed completely.
func (t *TerminalHandler) pollEvents(cursor int) ([]events.EventFields, int, error) {
// Poll for events since the last call (cursor location).
sessionEvents, err := t.authProvider.GetSessionEvents(t.namespace, t.sessionID, cursor+1, false)
if err != nil {
if !trace.IsNotFound(err) {
return nil, 0, trace.Wrap(err)
}
return nil, 0, trace.NotFound("no events from cursor: %v", cursor)
}
// Get the batch size to see if any events were returned.
batchLen := len(sessionEvents)
if batchLen == 0 {
return nil, 0, trace.NotFound("no events from cursor: %v", cursor)
}
// Advance the cursor.
newCursor := sessionEvents[batchLen-1].GetInt(events.EventCursor)
// Filter out any resize events as we get them over push notifications.
var filteredEvents []events.EventFields
for _, event := range sessionEvents {
if event.GetType() == events.ResizeEvent ||
event.GetType() == events.SessionJoinEvent ||
event.GetType() == events.SessionLeaveEvent ||
event.GetType() == events.SessionPrintEvent {
continue
}
filteredEvents = append(filteredEvents, event)
}
return filteredEvents, newCursor, nil
}
// windowChange is called when the browser window is resized. It sends a
// "window-change" channel request to the server.
func (t *TerminalHandler) windowChange(params *session.TerminalParams) error {
if t.sshSession == nil {
return nil
}
_, err := t.sshSession.SendRequest(
// send SSH "window resized" SSH request:
sshutils.WindowChangeRequest,
// no response needed
false,
ssh.Marshal(sshutils.WinChangeReqParams{
W: uint32(params.W),
@ -142,107 +451,26 @@ func (t *TerminalHandler) resizePTYWindow(params session.TerminalParams) error {
if err != nil {
log.Error(err)
}
return trace.Wrap(err)
}
// Run creates a new websocket connection to the SSH server and runs
// the "loop" piping the input/output of the SSH session into the
// js-based terminal.
func (t *TerminalHandler) Run(w http.ResponseWriter, r *http.Request) {
errToTerm := func(err error, w io.Writer) {
fmt.Fprintf(w, "%s\n\r", err.Error())
log.Error(err)
}
webSocketLoop := func(ws *websocket.Conn) {
agent, cert, err := t.ctx.GetAgent()
if err != nil {
log.Warningf("failed to get user credentials: %v", err)
errToTerm(err, ws)
return
}
signers, err := agent.Signers()
if err != nil {
log.Warningf("failed to get user credentials: %v", err)
errToTerm(err, ws)
return
}
tlsConfig, err := t.ctx.ClientTLSConfig()
if err != nil {
log.Warningf("failed to get client TLS config: %v", err)
errToTerm(err, ws)
return
}
// create teleport client:
output := utils.NewWebSockWrapper(ws, utils.WebSocketTextMode)
clientConfig := &client.Config{
SkipLocalAuth: true,
ForwardAgent: true,
Agent: agent,
TLS: tlsConfig,
AuthMethods: []ssh.AuthMethod{ssh.PublicKeys(signers...)},
DefaultPrincipal: cert.ValidPrincipals[0],
HostLogin: t.params.Login,
Username: t.ctx.user,
Namespace: t.params.Namespace,
Stdout: output,
Stderr: output,
Stdin: ws,
SiteName: t.params.Cluster,
ProxyHostPort: t.params.ProxyHostPort,
Host: t.hostName,
HostPort: t.hostPort,
Env: map[string]string{sshutils.SessionEnvVar: string(t.params.SessionID)},
HostKeyCallback: func(string, net.Addr, ssh.PublicKey) error { return nil },
ClientAddr: r.RemoteAddr,
}
if len(t.params.InteractiveCommand) > 0 {
clientConfig.Interactive = true
}
tc, err := client.NewClient(clientConfig)
if err != nil {
log.Warningf("failed to create client: %v", err)
errToTerm(err, ws)
return
}
// this callback will execute when a shell is created, it will give
// us a reference to ssh.Client object
tc.OnShellCreated = func(s *ssh.Session, c *ssh.Client, _ io.ReadWriteCloser) (bool, error) {
t.sshSession = s
t.resizePTYWindow(t.params.Term)
return false, nil
}
if err = tc.SSH(context.TODO(), t.params.InteractiveCommand, false); err != nil {
log.Warningf("failed to SSH: %v", err)
errToTerm(err, ws)
return
}
}
// this is to make sure we close web socket connections once
// sessionContext that owns them expires
t.ctx.AddClosers(t)
defer t.ctx.RemoveCloser(t)
// TODO(klizhentas)
// we instantiate a server explicitly here instead of using
// websocket.HandlerFunc to set empty origin checker
// make sure we check origin when in prod mode
ws := &websocket.Server{Handler: webSocketLoop}
ws.ServeHTTP(w, r)
// errToTerm displays an error in the terminal window.
func errToTerm(err error, w io.Writer) {
fmt.Fprintf(w, "%s\r\n", err.Error())
}
// resolveServerHostPort parses server name and attempts to resolve hostname and port
// resolveServerHostPort parses server name and attempts to resolve hostname
// and port.
func resolveServerHostPort(servername string, existingServers []services.Server) (string, int, error) {
// if port is 0, it means the client wants us to figure out which port to use
// If port is 0, client wants us to figure out which port to use.
var defaultPort = 0
if servername == "" {
return "", defaultPort, trace.BadParameter("empty server name")
}
// check if servername is UUID
// Check if servername is UUID.
for i := range existingServers {
node := existingServers[i]
if node.GetName() == servername {
@ -254,7 +482,7 @@ func resolveServerHostPort(servername string, existingServers []services.Server)
return servername, defaultPort, nil
}
// check for explicitly specified port
// Check for explicitly specified port.
host, portString, err := utils.SplitHostPort(servername)
if err != nil {
return "", defaultPort, trace.Wrap(err)
@ -267,3 +495,145 @@ func resolveServerHostPort(servername string, existingServers []services.Server)
return host, port, nil
}
// wrappedSocket wraps and unwraps the envelope that is used to send events
// over the websocket.
type wrappedSocket struct {
ws *websocket.Conn
terminal *TerminalHandler
encoder *encoding.Encoder
decoder *encoding.Decoder
}
func newWrappedSocket(ws *websocket.Conn, terminal *TerminalHandler) *wrappedSocket {
if ws == nil {
return nil
}
return &wrappedSocket{
ws: ws,
terminal: terminal,
encoder: unicode.UTF8.NewEncoder(),
decoder: unicode.UTF8.NewDecoder(),
}
}
// Write wraps the data bytes in a raw envelope and sends.
func (w *wrappedSocket) Write(data []byte) (n int, err error) {
encodedBytes, err := w.encoder.Bytes(data)
if err != nil {
return 0, trace.Wrap(err)
}
e := rawEnvelope{
Type: defaults.RawEnvelopeType,
Payload: encodedBytes,
}
err = websocket.JSON.Send(w.ws, e)
if err != nil {
return 0, trace.Wrap(err)
}
return len(data), nil
}
// Read unwraps the envelope and either fills out the passed in bytes or
// performs an action on the connection (sending window-change request).
func (w *wrappedSocket) Read(out []byte) (n int, err error) {
var ue unknownEnvelope
err = websocket.JSON.Receive(w.ws, &ue)
if err != nil {
if err == io.EOF {
return 0, io.EOF
}
return 0, trace.Wrap(err)
}
switch ue.Type {
case defaults.RawEnvelopeType:
var re rawEnvelope
err := json.Unmarshal(ue.Raw, &re)
if err != nil {
return 0, trace.Wrap(err)
}
var data []byte
data, err = w.decoder.Bytes(re.Payload)
if err != nil {
return 0, trace.Wrap(err)
}
if len(out) < len(data) {
log.Warningf("websocket failed to receive everything: %d vs %d", len(out), len(data))
}
return copy(out, data), nil
case defaults.ResizeRequestEnvelopeType:
if w.terminal == nil {
return 0, nil
}
var ee eventEnvelope
err := json.Unmarshal(ue.Raw, &ee)
if err != nil {
return 0, trace.Wrap(err)
}
params, err := session.UnmarshalTerminalParams(ee.Payload.GetString("size"))
if err != nil {
return 0, trace.Wrap(err)
}
// Send the window change request in a goroutine so reads are not blocked
// by network connectivity issues.
go w.terminal.windowChange(params)
return 0, nil
default:
return 0, trace.BadParameter("unknown envelope type")
}
}
// SetReadDeadline sets the network read deadline on the underlying websocket.
func (w *wrappedSocket) SetReadDeadline(t time.Time) error {
return w.ws.SetReadDeadline(t)
}
// Close the websocket.
func (w *wrappedSocket) Close() error {
return w.ws.Close()
}
// eventEnvelope is used to send/receive audit events.
type eventEnvelope struct {
Type string `json:"type"`
Payload events.EventFields `json:"payload"`
}
// rawEnvelope is used to send/receive terminal bytes.
type rawEnvelope struct {
Type string `json:"type"`
Payload []byte `json:"payload"`
}
// unknownEnvelope is used to figure out the type of data being unmarshaled.
type unknownEnvelope struct {
envelopeHeader
Raw []byte
}
type envelopeHeader struct {
Type string `json:"type"`
}
func (u *unknownEnvelope) UnmarshalJSON(raw []byte) error {
var eh envelopeHeader
if err := json.Unmarshal(raw, &eh); err != nil {
return err
}
u.Type = eh.Type
u.Raw = make([]byte, len(raw))
copy(u.Raw, raw)
return nil
}