Client refactorings

- Broke apart huge runShell() and runCommand() functions
- Introduced "client session" object which now holds context for a
  session on the client.
This commit is contained in:
Ev Kontsevoy 2016-09-07 19:15:08 -07:00
parent 50fbc5855f
commit db4558ddda
2 changed files with 258 additions and 250 deletions

View file

@ -690,13 +690,7 @@ func (tc *TeleportClient) runCommand(siteName string, nodeAddresses []string, pr
// sessionID : when empty, creates a new shell. otherwise it tries to join the existing session.
// stdin : standard input to use. if nil, uses os.Stdin
func (tc *TeleportClient) runShell(nodeClient *NodeClient, sessToJoin *session.Session, stdin io.Reader) error {
var (
state *term.State
err error
)
defer nodeClient.Close()
address := tc.NodeHostPort()
if stdin == nil {
stdin = os.Stdin
}
@ -707,128 +701,59 @@ func (tc *TeleportClient) runShell(nodeClient *NodeClient, sessToJoin *session.S
tc.Stderr = os.Stderr
}
attachedTerm := (stdin == os.Stdin && term.IsTerminal(0))
sess, err := newSession(nodeClient, sessToJoin, tc.Config.Env, attachedTerm)
winSize := &term.Winsize{Width: 80, Height: 25}
if attachedTerm {
winSize, err = term.GetWinsize(0)
if err != nil {
log.Error(err)
}
}
var sessionID session.ID
if sessToJoin != nil {
// if we're joining an existing session, we need to assume that session's
// existing/current terminal size:
sessionID = sessToJoin.ID
winSize = sessToJoin.TerminalParams.Winsize()
if attachedTerm {
err = term.SetWinsize(0, winSize)
if err != nil {
log.Error(err)
}
os.Stdout.Write([]byte(fmt.Sprintf("\x1b[8;%d;%dt", winSize.Height, winSize.Width)))
}
}
shell, err := nodeClient.Shell(
int(winSize.Width),
int(winSize.Height),
sessionID,
tc.Config.Env,
attachedTerm)
// request a new shell session on the SSH server:
shell, err := nodeClient.Shell(sess)
if err != nil {
return trace.Wrap(err)
}
defer shell.Close()
// user-supplied callback
// call the client-supplied callback
if tc.OnShellCreated != nil {
exit, err := tc.OnShellCreated(shell)
if exit {
return trace.Wrap(err)
}
}
// terminal must be in raw mode
if attachedTerm {
state, err = term.SetRawTerminal(0)
if err != nil {
return trace.Wrap(err)
}
log.Infof("[CLIENT] connecting to remote shell using stdin")
} else {
log.Infof("[CLIENT] connecting to remote shell NOT using stdin")
}
defer func() {
if state != nil {
term.RestoreTerminal(0, state)
}
if tc.ExitMsg != "" {
fmt.Println(tc.ExitMsg)
}
}()
broadcastClose := utils.NewCloseBroadcaster()
// Catch term signals, but only if we're attached to a real terminal
closer := utils.NewCloseBroadcaster()
if attachedTerm {
exitSignals := make(chan os.Signal, 1)
signal.Notify(exitSignals, syscall.SIGTERM)
go func() {
defer broadcastClose.Close()
<-exitSignals
if tc.ExitMsg == "" {
tc.ExitMsg = fmt.Sprintf("Connection to %s closed\n", address)
}
}()
// Catch Ctrl-C signal
ctrlCSignal := make(chan os.Signal, 1)
signal.Notify(ctrlCSignal, syscall.SIGINT)
go func() {
for {
<-ctrlCSignal
_, err := shell.Write([]byte{3})
if err != nil {
log.Errorf(err.Error())
}
}
}()
// Catch Ctrl-Z signal
ctrlZSignal := make(chan os.Signal, 1)
signal.Notify(ctrlZSignal, syscall.SIGTSTP)
go func() {
for {
<-ctrlZSignal
_, err := shell.Write([]byte{26})
if err != nil {
log.Errorf(err.Error())
}
}
}()
tc.watchSignals(shell, closer)
}
// start piping input into the remote shell and pipe the output from
// the remote shell into stdout:
tc.pipeInOut(shell, stdin, closer)
// wait for the session to end
<-closer.C
return nil
}
// copy from the remote shell to the local
// pipeInOut launches two goroutines: one to pipe the local input into the remote shell,
// and another to pipe the output of the remote shell into the local output
func (tc *TeleportClient) pipeInOut(shell io.ReadWriteCloser, localInput io.Reader, closer *utils.CloseBroadcaster) {
// copy from the remote shell to the local output
go func() {
defer broadcastClose.Close()
defer closer.Close()
_, err := io.Copy(tc.Stdout, shell)
if err != nil {
log.Errorf(err.Error())
}
if tc.ExitMsg == "" {
tc.ExitMsg = fmt.Sprintf("Connection to %s closed from the remote side", address)
tc.ExitMsg = fmt.Sprintf("Connection to %s closed from the remote side", tc.NodeHostPort())
}
}()
// copy from the local shell to the remote
// copy from the local input to the remote shell:
go func() {
defer broadcastClose.Close()
defer closer.Close()
buf := make([]byte, 128)
for {
n, err := stdin.Read(buf)
n, err := localInput.Read(buf)
if err != nil {
fmt.Println(trace.Wrap(err))
return
@ -843,8 +768,45 @@ func (tc *TeleportClient) runShell(nodeClient *NodeClient, sessToJoin *session.S
}
}()
<-broadcastClose.C
return nil
}
// watchSignals register UNIX signal handlers and properly terminates a remote shell session
// must be called as a goroutine right after a remote shell is created
func (tc *TeleportClient) watchSignals(shell io.Writer, closer *utils.CloseBroadcaster) {
exitSignals := make(chan os.Signal, 1)
// catch SIGTERM
signal.Notify(exitSignals, syscall.SIGTERM)
go func() {
defer closer.Close()
<-exitSignals
if tc.ExitMsg == "" {
tc.ExitMsg = fmt.Sprintf("Connection to %s closed\n", tc.NodeHostPort())
}
}()
// Catch Ctrl-C signal
ctrlCSignal := make(chan os.Signal, 1)
signal.Notify(ctrlCSignal, syscall.SIGINT)
go func() {
for {
<-ctrlCSignal
_, err := shell.Write([]byte{3})
if err != nil {
log.Errorf(err.Error())
}
}
}()
// Catch Ctrl-Z signal
ctrlZSignal := make(chan os.Signal, 1)
signal.Notify(ctrlZSignal, syscall.SIGTSTP)
go func() {
for {
<-ctrlZSignal
_, err := shell.Write([]byte{26})
if err != nil {
log.Errorf(err.Error())
}
}
}()
}
// getProxyLogin determines which SSH login to use when connecting to proxy.

View file

@ -262,6 +262,104 @@ func (proxy *ProxyClient) Close() error {
return proxy.Client.Close()
}
type NodeSession struct {
// id is the Teleport session ID
id session.ID
// env is the environment variables that need to be created
// on the server for this session
env map[string]string
// attachedTerm is set to true when this session is be controlled by
// a real terminal.
// This will be set to False for sessions initiated by the Web client or
// for non-interactive sessions (commands)
attachedTerm bool
// terminalSize is the inital size of the terminal. It only has meaning
// when the session is interactive
terminalSize *term.Winsize
// serverSession is the server-side SSH session
serverSession *ssh.Session
// nodeClient is the parent of this session: the client connected to an
// SSH node
nodeClient *NodeClient
}
func newSession(client *NodeClient,
joinSession *session.Session,
env map[string]string,
attachedTerm bool) (*NodeSession, error) {
var err error
ns := &NodeSession{
attachedTerm: attachedTerm,
env: env,
nodeClient: client,
terminalSize: &term.Winsize{Width: 80, Height: 25},
}
// read the size of the terminal window:
if attachedTerm {
ns.terminalSize, err = term.GetWinsize(0)
if err != nil {
log.Error(err)
}
state, err := term.SetRawTerminal(0)
if err != nil {
return nil, trace.Wrap(err)
}
defer term.RestoreTerminal(0, state)
}
// if we're joining an existing session, we need to assume that session's
// existing/current terminal size:
if joinSession != nil {
ns.id = joinSession.ID
ns.terminalSize = joinSession.TerminalParams.Winsize()
if attachedTerm {
err = term.SetWinsize(0, ns.terminalSize)
if err != nil {
log.Error(err)
}
os.Stdout.Write([]byte(fmt.Sprintf("\x1b[8;%d;%dt", ns.terminalSize.Height, ns.terminalSize.Width)))
}
// new session!
} else {
ns.id = session.NewID()
}
if ns.env == nil {
ns.env = make(map[string]string)
}
ns.env[sshutils.SessionEnvVar] = string(ns.id)
// create the server-side session:
ns.serverSession, err = client.Client.NewSession()
if err != nil {
return nil, trace.Wrap(err)
}
// pass language info into the remote session.
evarsToPass := []string{"LANG", "LANGUAGE"}
for _, evar := range evarsToPass {
if value := os.Getenv(evar); value != "" {
err = ns.serverSession.Setenv(evar, value)
if err != nil {
log.Warn(err)
}
}
}
// pass environment variables set by client
for key, val := range env {
err = ns.serverSession.Setenv(key, val)
if err != nil {
log.Warn(err)
}
}
return ns, nil
}
// Shell returns a configured remote shell (for a window of a requested size)
// as io.ReadWriterCloser object
//
@ -271,183 +369,131 @@ func (proxy *ProxyClient) Close() error {
// a new session
// - env : list of environment variables to set for a new session
// - attachedTerm : boolean indicating if this client is attached to a real terminal
func (client *NodeClient) Shell(
width, height int,
sessionID session.ID,
env map[string]string,
attachedTerm bool) (io.ReadWriteCloser, error) {
if sessionID == "" {
// initiate a new session if not passed
sessionID = session.NewID()
func (client *NodeClient) Shell(nc *NodeSession) (io.ReadWriteCloser, error) {
pipe, err := nc.allocateTerminal()
// start the shell on the server:
if err := nc.serverSession.Shell(); err != nil {
return nil, trace.Wrap(err)
}
return pipe, err
}
// allocateTerminal creates (allocates) a server-side terminal for a given session.
func (ns *NodeSession) allocateTerminal() (io.ReadWriteCloser, error) {
err := ns.serverSession.RequestPty("xterm",
int(ns.terminalSize.Height),
int(ns.terminalSize.Width),
ssh.TerminalModes{})
siteClient, err := client.Proxy.ConnectToSite()
if err != nil {
return nil, trace.Wrap(err)
}
clientSession, err := client.Client.NewSession()
writer, err := ns.serverSession.StdinPipe()
if err != nil {
return nil, trace.Wrap(err)
}
// ask the server to drop us into the existing session:
if len(sessionID) > 0 {
err = clientSession.Setenv(sshutils.SessionEnvVar, string(sessionID))
if err != nil {
log.Warn(err)
}
}
// pass language info into the remote session.
evarsToPass := []string{"LANG", "LANGUAGE"}
for _, evar := range evarsToPass {
if value := os.Getenv(evar); value != "" {
err = clientSession.Setenv(evar, value)
if err != nil {
log.Warn(err)
}
}
}
// pass environment variables set by client
for key, val := range env {
err = clientSession.Setenv(key, val)
if err != nil {
log.Warn(err)
}
}
terminalModes := ssh.TerminalModes{}
err = clientSession.RequestPty("xterm", height, width, terminalModes)
reader, err := ns.serverSession.StdoutPipe()
if err != nil {
return nil, trace.Wrap(err)
}
writer, err := clientSession.StdinPipe()
stderr, err := ns.serverSession.StderrPipe()
if err != nil {
return nil, trace.Wrap(err)
}
reader, err := clientSession.StdoutPipe()
if err != nil {
return nil, trace.Wrap(err)
closer := utils.NewCloseBroadcaster()
if ns.attachedTerm {
go ns.updateTerminalSize(closer)
}
stderr, err := clientSession.StderrPipe()
if err != nil {
return nil, trace.Wrap(err)
}
broadcastClose := utils.NewCloseBroadcaster()
// this goroutine sleeps until a terminal size changes (it receives an OS signal)
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGWINCH)
currentSize, _ := term.GetWinsize(0)
broadcastTerminalSize := func() {
for {
select {
case sig := <-sigC:
if sig == nil {
return
}
// get the size:
winSize, err := term.GetWinsize(0)
if err != nil {
log.Warnf("[CLIENT] Error getting size: %s", err)
break
}
// it's the result of our own size change (see below)
if winSize.Height == currentSize.Height && winSize.Width == currentSize.Width {
continue
}
// send the new window size to the server
_, err = clientSession.SendRequest(
sshutils.WindowChangeReq, false,
ssh.Marshal(sshutils.WinChangeReqParams{
W: uint32(winSize.Width),
H: uint32(winSize.Height),
}))
if err != nil {
log.Warnf("[CLIENT] failed to send window change reqest: %v", err)
}
case <-broadcastClose.C:
return
}
}
}
// detect changes of the session's terminal
updateTerminalSize := func() {
tick := time.NewTicker(defaults.SessionRefreshPeriod)
defer tick.Stop()
var prevSess *session.Session
for {
select {
case <-tick.C:
sess, err := siteClient.GetSession(sessionID)
if err != nil {
log.Error(err)
continue
}
log.Infof("[CLIENT] updating the session %v with %d parties", sess.ID, len(sess.Parties))
// no previous session
if prevSess == nil {
prevSess = sess
continue
}
// nothing changed
if prevSess.TerminalParams.W == sess.TerminalParams.W && prevSess.TerminalParams.H == sess.TerminalParams.H {
continue
}
newSize := sess.TerminalParams.Winsize()
currentSize, err = term.GetWinsize(0)
if err != nil {
log.Error(err)
}
if currentSize.Width != newSize.Width || currentSize.Height != newSize.Height {
// ok, something have changed, let's resize to the new parameters
err = term.SetWinsize(0, newSize)
if err != nil {
log.Error(err)
}
os.Stdout.Write([]byte(fmt.Sprintf("\x1b[8;%d;%dt", newSize.Height, newSize.Width)))
}
prevSess = sess
case <-broadcastClose.C:
return
}
}
}
if attachedTerm {
go broadcastTerminalSize()
go updateTerminalSize()
}
go func() {
io.Copy(os.Stderr, stderr)
}()
err = clientSession.Shell()
if err != nil {
return nil, trace.Wrap(err)
}
return utils.NewPipeNetConn(
reader,
writer,
utils.MultiCloser(writer, clientSession, broadcastClose),
utils.MultiCloser(writer, ns.serverSession, closer),
&net.IPAddr{},
&net.IPAddr{},
), nil
}
func (ns *NodeSession) updateTerminalSize(closer *utils.CloseBroadcaster) {
// sibscribe for "terminal resized" signal:
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGWINCH)
currentSize, _ := term.GetWinsize(0)
// start the timer which asks for server-side window size changes:
siteClient, err := ns.nodeClient.Proxy.ConnectToSite()
if err != nil {
log.Error(err)
}
tick := time.NewTicker(defaults.SessionRefreshPeriod)
defer tick.Stop()
var prevSess *session.Session
for {
select {
case sig := <-sigC:
if sig == nil {
return
}
// get the size:
winSize, err := term.GetWinsize(0)
if err != nil {
log.Warnf("[CLIENT] Error getting size: %s", err)
break
}
// it's the result of our own size change (see below)
if winSize.Height == currentSize.Height && winSize.Width == currentSize.Width {
continue
}
// send the new window size to the server
_, err = ns.serverSession.SendRequest(
sshutils.WindowChangeReq, false,
ssh.Marshal(sshutils.WinChangeReqParams{
W: uint32(winSize.Width),
H: uint32(winSize.Height),
}))
if err != nil {
log.Warnf("[CLIENT] failed to send window change reqest: %v", err)
}
case <-tick.C:
sess, err := siteClient.GetSession(ns.id)
if err != nil {
log.Error(err)
continue
}
// no previous session
if prevSess == nil || sess == nil {
prevSess = sess
continue
}
log.Infof("[CLIENT] updating the session %v with %d parties", sess.ID, len(sess.Parties))
// nothing changed
if prevSess.TerminalParams.W == sess.TerminalParams.W && prevSess.TerminalParams.H == sess.TerminalParams.H {
continue
}
newSize := sess.TerminalParams.Winsize()
currentSize, err = term.GetWinsize(0)
if err != nil {
log.Error(err)
}
if currentSize.Width != newSize.Width || currentSize.Height != newSize.Height {
// ok, something have changed, let's resize to the new parameters
err = term.SetWinsize(0, newSize)
if err != nil {
log.Error(err)
}
os.Stdout.Write([]byte(fmt.Sprintf("\x1b[8;%d;%dt", newSize.Height, newSize.Width)))
}
prevSess = sess
case <-closer.C:
return
}
}
}
// Run executes command on the remote server and writes its stdout to
// the 'output' argument
func (client *NodeClient) Run(cmd []string, stdin io.Reader, stdout, stderr io.Writer, env map[string]string) error {