Add missing error checks in lib/srv and lib/sshutils

There's many more left in lib/srv, but this change is already big.
Some errors are left unhandled, where it makes sense.
This commit is contained in:
Andrew Lytvynov 2020-05-13 09:49:04 -07:00 committed by Andrew Lytvynov
parent 519df4daff
commit d52ca0617d
8 changed files with 88 additions and 48 deletions

View file

@ -817,7 +817,9 @@ func (s *Server) HandleRequest(r *ssh.Request) {
s.handleVersionRequest(r)
default:
if r.WantReply {
r.Reply(false, nil)
if err := r.Reply(false, nil); err != nil {
log.Warnf("Failed to reply to %q request: %v", r.Type, err)
}
}
log.Debugf("Discarding %q global request: %+v", r.Type, r)
}
@ -827,7 +829,7 @@ func (s *Server) HandleRequest(r *ssh.Request) {
func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChannel) {
identityContext, err := s.authHandlers.CreateIdentityContext(ccx.ServerConn)
if err != nil {
nch.Reject(ssh.Prohibited, fmt.Sprintf("Unable to create identity from connection: %v", err))
rejectChannel(nch, ssh.Prohibited, fmt.Sprintf("Unable to create identity from connection: %v", err))
return
}
@ -840,13 +842,13 @@ func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChann
req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData())
if err != nil {
log.Errorf("Failed to parse request data: %v, err: %v.", string(nch.ExtraData()), err)
nch.Reject(ssh.UnknownChannelType, "failed to parse direct-tcpip request")
rejectChannel(nch, ssh.UnknownChannelType, "failed to parse direct-tcpip request")
return
}
ch, _, err := nch.Accept()
if err != nil {
log.Warnf("Unable to accept channel: %v.", err)
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleProxyJump(ccx, identityContext, ch, *req)
@ -858,13 +860,13 @@ func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChann
ch, requests, err := nch.Accept()
if err != nil {
log.Warnf("Unable to accept channel: %v.", err)
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleSessionRequests(ccx, identityContext, ch, requests)
return
default:
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
rejectChannel(nch, ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
return
}
}
@ -876,7 +878,7 @@ func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChann
ch, requests, err := nch.Accept()
if err != nil {
log.Warnf("Unable to accept channel: %v.", err)
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleSessionRequests(ccx, identityContext, ch, requests)
@ -885,18 +887,18 @@ func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChann
req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData())
if err != nil {
log.Errorf("Failed to parse request data: %v, err: %v.", string(nch.ExtraData()), err)
nch.Reject(ssh.UnknownChannelType, "failed to parse direct-tcpip request")
rejectChannel(nch, ssh.UnknownChannelType, "failed to parse direct-tcpip request")
return
}
ch, _, err := nch.Accept()
if err != nil {
log.Warnf("Unable to accept channel: %v.", err)
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleDirectTCPIPRequest(ccx, identityContext, ch, req)
default:
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
rejectChannel(nch, ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
}
}
@ -907,7 +909,7 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident
ctx, err := srv.NewServerContext(ccx, s, identityContext)
if err != nil {
log.Errorf("Unable to create connection context: %v.", err)
channel.Stderr().Write([]byte("Unable to create connection context."))
writeStderr(channel, "Unable to create connection context.")
return
}
ctx.IsTestStub = s.isTestStub
@ -920,7 +922,7 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident
// Check if the role allows port forwarding for this user.
err = s.authHandlers.CheckPortForward(ctx.DstAddr, ctx)
if err != nil {
channel.Stderr().Write([]byte(err.Error()))
writeStderr(channel, err.Error())
return
}
@ -932,7 +934,7 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident
// from another process.
cmd, err := srv.ConfigureCommand(ctx)
if err != nil {
channel.Stderr().Write([]byte(err.Error()))
writeStderr(channel, err.Error())
}
// Create a pipe for std{in,out} that will be used to transfer data between
@ -950,7 +952,7 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident
// to the target host.
err = cmd.Start()
if err != nil {
channel.Stderr().Write([]byte(err.Error()))
writeStderr(channel, err.Error())
return
}
@ -985,7 +987,7 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident
}
err = cmd.Wait()
if err != nil {
channel.Stderr().Write([]byte(err.Error()))
writeStderr(channel, err.Error())
return
}
@ -1009,7 +1011,7 @@ func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, identity
ctx, err := srv.NewServerContext(ccx, s, identityContext)
if err != nil {
log.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
writeStderr(ch, "Unable to create connection context.")
return
}
ctx.IsTestStub = s.isTestStub
@ -1025,7 +1027,7 @@ func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, identity
clusterConfig, err := s.GetAccessPoint().GetClusterConfig()
if err != nil {
log.Errorf("Unable to fetch cluster config: %v.", err)
ch.Stderr().Write([]byte("Unable to fetch cluster configuration."))
writeStderr(ch, "Unable to fetch cluster configuration.")
return
}
@ -1051,7 +1053,7 @@ func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, identity
ctx.Errorf("Unable to update context: %v.", errorMessage)
// write the error to channel and close it
ch.Stderr().Write([]byte(errorMessage))
writeStderr(ch, errorMessage)
_, err := ch.SendRequest("exit-status", false, ssh.Marshal(struct{ C uint32 }{C: teleport.RemoteCommandFailure}))
if err != nil {
ctx.Errorf("Failed to send exit status %v.", errorMessage)
@ -1076,7 +1078,9 @@ func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, identity
return
}
if req.WantReply {
req.Reply(true, nil)
if err := req.Reply(true, nil); err != nil {
log.Warnf("Failed to reply to %q request: %v", req.Type, err)
}
}
case result := <-ctx.ExecResultCh:
ctx.Debugf("Exec request (%q) complete: %v", result.Command, result.Code)
@ -1316,7 +1320,7 @@ func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContext, identityContex
ctx, err := srv.NewServerContext(ccx, s, identityContext)
if err != nil {
log.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
writeStderr(ch, "Unable to create connection context.")
return
}
ctx.IsTestStub = s.isTestStub
@ -1326,7 +1330,7 @@ func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContext, identityContex
clusterConfig, err := s.GetAccessPoint().GetClusterConfig()
if err != nil {
log.Errorf("Unable to fetch cluster config: %v.", err)
ch.Stderr().Write([]byte("Unable to fetch cluster configuration."))
writeStderr(ch, "Unable to fetch cluster configuration.")
return
}
@ -1359,7 +1363,7 @@ func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContext, identityContex
err = s.handleAgentForwardProxy(&ssh.Request{}, ctx)
if err != nil {
log.Warningf("Failed to request agent in recording mode: %v", err)
ch.Stderr().Write([]byte("Failed to request agent"))
writeStderr(ch, "Failed to request agent")
return
}
}
@ -1392,29 +1396,31 @@ func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContext, identityContex
})
if err != nil {
log.Errorf("Unable instantiate proxy subsystem: %v.", err)
ch.Stderr().Write([]byte("Unable to instantiate proxy subsystem."))
writeStderr(ch, "Unable to instantiate proxy subsystem.")
return
}
if err := subsys.Start(ctx.Conn, ch, &ssh.Request{}, ctx); err != nil {
log.Errorf("Unable to start proxy subsystem: %v.", err)
ch.Stderr().Write([]byte("Unable to start proxy subsystem."))
writeStderr(ch, "Unable to start proxy subsystem.")
return
}
if err := subsys.Wait(); err != nil {
log.Errorf("Proxy subsystem failed: %v.", err)
ch.Stderr().Write([]byte("Proxy subsystem failed."))
writeStderr(ch, "Proxy subsystem failed.")
return
}
}
func (s *Server) replyError(ch ssh.Channel, req *ssh.Request, err error) {
log.Error(err)
message := []byte(trace.UserMessage(err))
ch.Stderr().Write(message)
message := trace.UserMessage(err)
writeStderr(ch, message)
if req.WantReply {
req.Reply(false, message)
if err := req.Reply(false, []byte(message)); err != nil {
log.Warnf("Failed to reply to %q request: %v", req.Type, err)
}
}
}
@ -1431,3 +1437,15 @@ func (s *Server) parseSubsystemRequest(req *ssh.Request, ctx *srv.ServerContext)
}
return nil, trace.BadParameter("unrecognized subsystem: %v", r.Name)
}
func writeStderr(ch ssh.Channel, msg string) {
if _, err := fmt.Fprint(ch.Stderr(), msg); err != nil {
log.Warnf("Failed writing to ssh.Channel.Stderr(): %v", err)
}
}
func rejectChannel(ch ssh.NewChannel, reason ssh.RejectionReason, msg string) {
if err := ch.Reject(reason, msg); err != nil {
log.Warnf("Failed to reject new ssh.Channel: %v", err)
}
}

View file

@ -767,15 +767,16 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
c.Assert(err, IsNil)
done := make(chan struct{})
go func() {
io.Copy(stdout, reader)
_, err := io.Copy(stdout, reader)
c.Assert(err, IsNil)
close(done)
}()
// to make sure labels have the right output
s.srv.syncUpdateLabels()
srv2.syncUpdateLabels()
s.srv.heartbeat.ForceSend(time.Second)
s.srv.heartbeat.ForceSend(time.Second)
c.Assert(s.srv.heartbeat.ForceSend(time.Second), IsNil)
c.Assert(srv2.heartbeat.ForceSend(time.Second), IsNil)
// request "list of sites":
c.Assert(se3.RequestSubsystem("proxysites"), IsNil)
<-done

View file

@ -137,7 +137,7 @@ func (s *SessionRegistry) emitSessionJoinEvent(ctx *ServerContext) {
}
// Emit session join event to Audit Log.
ctx.session.recorder.GetAuditLog().EmitAuditEvent(events.SessionJoin, sessionJoinEvent)
ctx.session.emitAuditEvent(events.SessionJoin, sessionJoinEvent)
// Notify all members of the party that a new member has joined over the
// "x-teleport-event" channel.
@ -235,7 +235,7 @@ func (s *SessionRegistry) emitSessionLeaveEvent(party *party) {
}
// Emit session leave event to Audit Log.
party.s.recorder.GetAuditLog().EmitAuditEvent(events.SessionLeave, sessionLeaveEvent)
party.s.emitAuditEvent(events.SessionLeave, sessionLeaveEvent)
// Notify all members of the party that a new member has left over the
// "x-teleport-event" channel.
@ -299,7 +299,7 @@ func (s *SessionRegistry) leaveSession(party *party) error {
events.SessionParticipants: sess.exportParticipants(),
events.SessionServerHostname: s.srv.GetInfo().GetHostname(),
}
sess.recorder.GetAuditLog().EmitAuditEvent(events.SessionEnd, eventFields)
sess.emitAuditEvent(events.SessionEnd, eventFields)
// close recorder to free up associated resources
// and flush data
@ -365,7 +365,7 @@ func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *S
// Report the updated window size to the event log (this is so the sessions
// can be replayed correctly).
ctx.session.recorder.GetAuditLog().EmitAuditEvent(events.TerminalResize, resizeEvent)
ctx.session.emitAuditEvent(events.TerminalResize, resizeEvent)
// Update the size of the server side PTY.
err := ctx.session.term.SetWinSize(params)
@ -681,7 +681,7 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
if !ctx.srv.UseTunnel() {
eventFields[events.LocalAddr] = ctx.Conn.LocalAddr().String()
}
s.recorder.GetAuditLog().EmitAuditEvent(events.SessionStart, eventFields)
s.emitAuditEvent(events.SessionStart, eventFields)
// Start a heartbeat that marks this session as active with current members
// of party in the backend.
@ -734,7 +734,9 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
}
if result != nil {
s.registry.broadcastResult(s.id, *result)
if err := s.registry.broadcastResult(s.id, *result); err != nil {
s.log.Warningf("Failed to broadcast session result: %v", err)
}
}
if err != nil {
s.log.Infof("Shell exited with error: %v", err)
@ -748,7 +750,9 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
// wait for the session to end before the shell, kill the shell
go func() {
<-s.closeC
s.term.Kill()
if err := s.term.Kill(); err != nil {
s.log.Debugf("Failed killing the shell: %v", err)
}
}()
return nil
}
@ -791,7 +795,7 @@ func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {
if !ctx.srv.UseTunnel() {
eventFields[events.LocalAddr] = ctx.Conn.LocalAddr().String()
}
s.recorder.GetAuditLog().EmitAuditEvent(events.SessionStart, eventFields)
s.emitAuditEvent(events.SessionStart, eventFields)
// Start execution. If the program failed to start, send that result back.
// Note this is a partial start. Teleport will have re-exec'ed itself and
@ -864,7 +868,7 @@ func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {
},
events.SessionServerHostname: ctx.srv.GetInfo().GetHostname(),
}
s.recorder.GetAuditLog().EmitAuditEvent(events.SessionEnd, eventFields)
s.emitAuditEvent(events.SessionEnd, eventFields)
// Close recorder to free up associated resources and flush data.
s.recorder.Close()
@ -1047,7 +1051,9 @@ func (s *session) addParty(p *party) error {
}
return data
}
p.Write(getRecentWrite())
if _, err := p.Write(getRecentWrite()); err != nil {
return trace.Wrap(err)
}
// Register this party as one of the session writers (output will go to it).
s.writer.addWriter(string(p.id), p, true)
@ -1076,6 +1082,12 @@ func (s *session) join(ch ssh.Channel, req *ssh.Request, ctx *ServerContext) (*p
return p, nil
}
func (s *session) emitAuditEvent(e events.Event, f events.EventFields) {
if err := s.recorder.GetAuditLog().EmitAuditEvent(e, f); err != nil {
s.log.Warningf("Failed to emit audit event: %v", err)
}
}
func newMultiWriter() *multiWriter {
return &multiWriter{writers: make(map[string]writerWrapper)}
}

View file

@ -83,7 +83,9 @@ func (t *TermHandlers) HandlePTYReq(ch ssh.Channel, req *ssh.Request, ctx *Serve
ctx.SetTerm(term)
ctx.termAllocated = true
}
term.SetWinSize(*params)
if err := term.SetWinSize(*params); err != nil {
ctx.Errorf("Failed setting window size: %v", err)
}
term.SetTermType(ptyRequest.Env)
term.SetTerminalModes(termModes)

View file

@ -457,7 +457,10 @@ func (s *Server) HandleConnection(conn net.Conn) {
// send keepalive pings to the clients
case <-keepAliveTick.C:
const wantReply = true
sconn.SendRequest(teleport.KeepAliveReqType, wantReply, keepAlivePayload[:])
_, _, err = sconn.SendRequest(teleport.KeepAliveReqType, wantReply, keepAlivePayload[:])
if err != nil {
log.Errorf("Failed sending keepalive request: %v", err)
}
}
}
}

View file

@ -74,8 +74,10 @@ func (s *ServerSuite) TestStartStop(c *check.C) {
c.Assert(err, check.IsNil)
defer clt.Close()
// call new session to initiate opening new channel
clt.NewSession()
// Call new session to initiate opening new channel. This should get
// rejected and fail.
_, err = clt.NewSession()
c.Assert(err, check.NotNil)
c.Assert(srv.Close(), check.IsNil)
wait(c, srv)
@ -114,7 +116,8 @@ func (s *ServerSuite) TestShutdown(c *check.C) {
defer clt.Close()
// call new session to initiate opening new channel
clt.NewSession()
_, err = clt.NewSession()
c.Assert(err, check.IsNil)
// context will timeout because there is a connection around
ctx, ctxc := context.WithTimeout(context.TODO(), 50*time.Millisecond)

View file

@ -205,7 +205,8 @@ func (s *LBSuite) TestDropConnections(c *check.C) {
c.Assert(out, check.Equals, "backend 1")
// removing backend results in dropped connection to this backend
lb.RemoveBackend(backendAddr)
err = lb.RemoveBackend(backendAddr)
c.Assert(err, check.IsNil)
_, err = RoundtripWithConn(conn)
c.Assert(err, check.NotNil)
}

View file

@ -309,7 +309,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) {
}
httplib.SetIndexHTMLHeaders(w.Header())
if err := indexPage.Execute(w, session); err != nil {
log.Errorf("failed to execute index page template: %v", err)
log.Errorf("Failed to execute index page template: %v", err)
}
} else {
http.NotFound(w, r)