Merge pull request #707 from gravitational/ev/669

Code refactoring / cleanup
This commit is contained in:
Ev Kontsevoy 2017-01-12 14:47:42 -08:00 committed by GitHub
commit 36a2d488d7
5 changed files with 110 additions and 87 deletions

View file

@ -940,27 +940,40 @@ func (m *Handler) getSiteNodes(w http.ResponseWriter, r *http.Request, p httprou
//
// Sucessful response is a websocket stream that allows read write to the server
//
func (m *Handler) siteNodeConnect(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) {
func (m *Handler) siteNodeConnect(
w http.ResponseWriter,
r *http.Request,
p httprouter.Params,
ctx *SessionContext,
site reversetunnel.RemoteSite) (interface{}, error) {
q := r.URL.Query()
params := q.Get("params")
if params == "" {
return nil, trace.BadParameter("missing params")
}
var req *connectReq
var req *terminalRequest
if err := json.Unmarshal([]byte(params), &req); err != nil {
return nil, trace.Wrap(err)
}
log.Debugf("[WEB] new terminal request for ns=%s, server=%s, login=%s",
req.Namespace, req.ServerID, req.Login)
req.Namespace = p.ByName("namespace")
log.Infof("web client connected to node %#v", req)
connect, err := newConnectHandler(*req, ctx, site)
term, err := newTerminal(*req, ctx, site)
if err != nil {
return nil, trace.Wrap(err)
}
// this is to make sure we close web socket connections once
// sessionContext that owns them expires
ctx.AddClosers(connect)
defer connect.Close()
connect.Handler().ServeHTTP(w, r)
ctx.AddClosers(term)
defer term.Close()
// start the websocket session with a web-based terminal:
log.Infof("[WEB] getting terminal to '%#v'", req)
term.Run(w, r)
return nil, nil
}

View file

@ -48,29 +48,36 @@ type SessionContext struct {
closers []io.Closer
}
func (c *SessionContext) getConnectHandler(sessionID session.ID) (*connectHandler, error) {
// getTerminal finds and returns an active web terminal for a given session:
func (c *SessionContext) getTerminal(sessionID session.ID) (*terminalHandler, error) {
c.Lock()
defer c.Unlock()
for _, closer := range c.closers {
handler, ok := closer.(*connectHandler)
if ok && handler.req.SessionID == sessionID {
return handler, nil
term, ok := closer.(*terminalHandler)
if ok && term.params.SessionID == sessionID {
return term, nil
}
}
return nil, trace.NotFound("no connected streams")
}
func (c *SessionContext) UpdateSessionTerminal(namespace string, sessionID session.ID, params session.TerminalParams) error {
err := c.clt.UpdateSession(session.UpdateRequest{ID: sessionID, TerminalParams: &params, Namespace: namespace})
func (c *SessionContext) UpdateSessionTerminal(
namespace string, sessionID session.ID, params session.TerminalParams) error {
err := c.clt.UpdateSession(session.UpdateRequest{
ID: sessionID,
TerminalParams: &params,
Namespace: namespace,
})
if err != nil {
return trace.Wrap(err)
}
handler, err := c.getConnectHandler(sessionID)
term, err := c.getTerminal(sessionID)
if err != nil {
return trace.Wrap(err)
}
return trace.Wrap(handler.resizePTYWindow(params))
return trace.Wrap(term.resizePTYWindow(params))
}
func (c *SessionContext) AddClosers(closers ...io.Closer) {

View file

@ -51,7 +51,7 @@ type sealData struct {
Nonce []byte `json:"nonce"`
}
// SSHAgentOIDCLogin is used by SSH Agent to login using OpenID connect
// SSHAgentOIDCLogin is used by SSH Agent (tsh) to login using OpenID connect
func SSHAgentOIDCLogin(proxyAddr, connectorID string, pubKey []byte, ttl time.Duration, insecure bool, pool *x509.CertPool) (*SSHLoginResponse, error) {
clt, proxyURL, err := initClient(proxyAddr, insecure, pool)
if err != nil {

View file

@ -17,7 +17,6 @@ limitations under the License.
package web
import (
"fmt"
"net/http"
"github.com/gravitational/teleport/lib/reversetunnel"
@ -32,9 +31,9 @@ import (
"golang.org/x/net/websocket"
)
// connectReq is a request to open interactive SSH
// connection to remote server
type connectReq struct {
// terminalRequest describes a request to crate a web-based terminal
// to a remote SSH server
type terminalRequest struct {
// ServerID is a server id to connect to
ServerID string `json:"server_id"`
// User is linux username to connect as
@ -47,7 +46,12 @@ type connectReq struct {
Namespace string `json:"namespace"`
}
func newConnectHandler(req connectReq, ctx *SessionContext, site reversetunnel.RemoteSite) (*connectHandler, error) {
// newTerminal creates a web-based terminal based on WebSockets and returns a new
// terminalHandler
func newTerminal(req terminalRequest,
ctx *SessionContext,
site reversetunnel.RemoteSite) (*terminalHandler, error) {
clt, err := site.GetClient()
if err != nil {
return nil, trace.Wrap(err)
@ -72,29 +76,33 @@ func newConnectHandler(req connectReq, ctx *SessionContext, site reversetunnel.R
if req.Term.W <= 0 || req.Term.H <= 0 {
return nil, trace.BadParameter("term: bad term dimensions")
}
return &connectHandler{
req: req,
return &terminalHandler{
params: req,
ctx: ctx,
site: site,
server: *server,
}, nil
}
// connectHandler is a websocket to SSH proxy handler
type connectHandler struct {
ctx *SessionContext
site reversetunnel.RemoteSite
up *sshutils.Upstream
req connectReq
ws *websocket.Conn
// terminalHandler connects together an SSH session with a web-based
// terminal via a web socket.
type terminalHandler struct {
// params describe the terminal configuration
params terminalRequest
// ctx is a web session context for the currently logged in user
ctx *SessionContext
ws *websocket.Conn
up *sshutils.Upstream
// site/cluster we're connected to
site reversetunnel.RemoteSite
// server we're connected to
server services.Server
}
func (w *connectHandler) String() string {
return fmt.Sprintf("connectHandler(%#v)", w.req)
}
func (w *connectHandler) Close() error {
func (w *terminalHandler) Close() error {
if w.ws != nil {
w.ws.Close()
}
@ -104,32 +112,9 @@ func (w *connectHandler) Close() error {
return nil
}
// connect is called when a web browser wants to start piping an active terminal session
// io/out via the provided websocket
func (w *connectHandler) connect(ws *websocket.Conn) {
// connectUpstream establishes an SSH connection to a requested node
up, err := w.connectUpstream()
if err != nil {
log.Error(err)
return
}
w.up = up
w.ws = ws
// PipeShell will be piping inputs/output to/from SSH connection (to the node)
// and the websocket (to a browser)
err = w.up.PipeShell(utils.NewWebSockWrapper(ws, utils.WebSocketTextMode),
&sshutils.PTYReqParams{
W: uint32(w.req.Term.W),
H: uint32(w.req.Term.H),
})
log.Infof("pipe shell finished with: %v", err)
}
// resizePTYWindow is called when a brower resizes its window. Now the node
// needs to be notified via SSH
func (w *connectHandler) resizePTYWindow(params session.TerminalParams) error {
func (w *terminalHandler) resizePTYWindow(params session.TerminalParams) error {
_, err := w.up.GetSession().SendRequest(
// send SSH "window resized" SSH request:
sshutils.WindowChangeReq, false,
@ -141,8 +126,8 @@ func (w *connectHandler) resizePTYWindow(params session.TerminalParams) error {
}
// connectUpstream establishes the SSH connection to a requested SSH server (node)
func (w *connectHandler) connectUpstream() (*sshutils.Upstream, error) {
agent, err := w.ctx.GetAgent()
func (t *terminalHandler) connectUpstream() (*sshutils.Upstream, error) {
agent, err := t.ctx.GetAgent()
if err != nil {
return nil, trace.Wrap(err)
}
@ -151,8 +136,8 @@ func (w *connectHandler) connectUpstream() (*sshutils.Upstream, error) {
if err != nil {
return nil, trace.Wrap(err)
}
client, err := w.site.ConnectToServer(
w.server.GetAddr(), w.req.Login, []ssh.AuthMethod{ssh.PublicKeys(signers...)})
client, err := t.site.ConnectToServer(
t.server.GetAddr(), t.params.Login, []ssh.AuthMethod{ssh.PublicKeys(signers...)})
if err != nil {
return nil, trace.Wrap(err)
}
@ -183,21 +168,39 @@ func (w *connectHandler) connectUpstream() (*sshutils.Upstream, error) {
sshutils.SetEnvReq, false,
ssh.Marshal(sshutils.EnvReqParams{
Name: sshutils.SessionEnvVar,
Value: string(w.req.SessionID),
Value: string(t.params.SessionID),
}))
return up, nil
}
func (w *connectHandler) Handler() http.Handler {
// Run creates a new websocket connection to the SSH server and runs
// the "loop" piping the input/output of the SSH session into the
// js-based terminal.
func (t *terminalHandler) Run(w http.ResponseWriter, r *http.Request) {
webSocketLoop := func(ws *websocket.Conn) {
up, err := t.connectUpstream()
if err != nil {
log.Error(err)
return
}
t.up = up
t.ws = ws
// PipeShell will be piping inputs/output to/from SSH connection (to the node)
// and the websocket (to a browser)
err = t.up.PipeShell(utils.NewWebSockWrapper(ws, utils.WebSocketTextMode),
&sshutils.PTYReqParams{
W: uint32(t.params.Term.W),
H: uint32(t.params.Term.H),
})
log.Infof("pipe shell finished with: %v", err)
}
// TODO(klizhentas)
// we instantiate a server explicitly here instead of using
// websocket.HandlerFunc to set empty origin checker
// make sure we check origin when in prod mode
return &websocket.Server{
Handler: w.connect,
}
}
func newWSHandler(host string, auth []string) *connectHandler {
return &connectHandler{}
ws := &websocket.Server{Handler: webSocketLoop}
ws.ServeHTTP(w, r)
}

View file

@ -547,7 +547,7 @@ func (s *WebSuite) TestGetSiteNodes(c *C) {
c.Assert(nodes2, DeepEquals, nodes)
}
func (s *WebSuite) connect(c *C, pack *authPack, opts ...session.ID) *websocket.Conn {
func (s *WebSuite) makeTerminal(c *C, pack *authPack, opts ...session.ID) *websocket.Conn {
var sessionID session.ID
if len(opts) == 0 {
sessionID = session.NewID()
@ -555,7 +555,7 @@ func (s *WebSuite) connect(c *C, pack *authPack, opts ...session.ID) *websocket.
sessionID = opts[0]
}
u := url.URL{Host: s.url().Host, Scheme: WSS, Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", currentSiteShortcut)}
data, err := json.Marshal(connectReq{
data, err := json.Marshal(terminalRequest{
ServerID: s.srvID,
Login: s.user,
Term: session.TerminalParams{W: 100, H: 100},
@ -608,11 +608,11 @@ func (s *WebSuite) sessionStream(c *C, pack *authPack, sessionID session.ID, opt
return clt
}
func (s *WebSuite) TestConnect(c *C) {
clt := s.connect(c, s.authPack(c))
defer clt.Close()
func (s *WebSuite) TestTerminal(c *C) {
term := s.makeTerminal(c, s.authPack(c))
defer term.Close()
_, err := io.WriteString(clt, "echo vinsong\r\n")
_, err := io.WriteString(term, "echo vinsong\r\n")
c.Assert(err, IsNil)
resultC := make(chan struct{})
@ -620,7 +620,7 @@ func (s *WebSuite) TestConnect(c *C) {
go func() {
out := make([]byte, 100)
for {
n, err := clt.Read(out)
n, err := term.Read(out)
c.Assert(err, IsNil)
c.Assert(n > 0, Equals, true)
if strings.Contains(removeSpace(string(out)), "vinsong") {
@ -642,7 +642,7 @@ func (s *WebSuite) TestConnect(c *C) {
func (s *WebSuite) TestNodesWithSessions(c *C) {
sid := session.NewID()
pack := s.authPack(c)
clt := s.connect(c, pack, sid)
clt := s.makeTerminal(c, pack, sid)
defer clt.Close()
// to make sure we have a session
@ -697,7 +697,7 @@ func (s *WebSuite) TestNodesWithSessions(c *C) {
func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) {
sid := session.NewID()
pack := s.authPack(c)
clt := s.connect(c, pack, sid)
clt := s.makeTerminal(c, pack, sid)
defer clt.Close()
// to make sure we have a session
@ -754,16 +754,16 @@ func (s *WebSuite) TestCreateSession(c *C) {
func (s *WebSuite) TestResizeTerminal(c *C) {
sid := session.NewID()
pack := s.authPack(c)
clt := s.connect(c, pack, sid)
defer clt.Close()
term := s.makeTerminal(c, pack, sid)
defer term.Close()
// to make sure we have a session
_, err := io.WriteString(clt, "expr 137 + 39\r\n")
_, err := io.WriteString(term, "expr 137 + 39\r\n")
c.Assert(err, IsNil)
// make sure server has replied
out := make([]byte, 100)
clt.Read(out)
term.Read(out)
params := session.TerminalParams{W: 300, H: 120}
_, err = pack.clt.PutJSON(
@ -784,8 +784,8 @@ func (s *WebSuite) TestResizeTerminal(c *C) {
func (s *WebSuite) TestPlayback(c *C) {
pack := s.authPack(c)
sid := session.NewID()
clt := s.connect(c, pack, sid)
defer clt.Close()
term := s.makeTerminal(c, pack, sid)
defer term.Close()
}
func removeSpace(in string) string {