Added better connection handling for reverse tunnels

Two changes:

1. Read/write timeouts are applied on both client & server.
2. Server now always closes its connection if there are any issues
   communicating with the client.
This commit is contained in:
Ev Kontsevoy 2016-10-23 16:56:15 -07:00
parent aa680fc3c9
commit 3411cc31da
2 changed files with 28 additions and 16 deletions

View file

@ -171,6 +171,12 @@ func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) {
return
}
// apply read/write timeouts to this connection that are 10x of what normal
// reverse tunnel ping is supposed to be:
conn = utils.ObeyTimeouts(conn,
defaults.ReverseTunnelAgentHeartbeatPeriod*10,
"reverse tunnel client")
wg := sync.WaitGroup{}
wg.Add(2)

View file

@ -173,6 +173,11 @@ func (s *server) Close() error {
}
func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
// apply read/write timeouts to the server connection
conn = utils.ObeyTimeouts(conn,
defaults.ReverseTunnelAgentHeartbeatPeriod*10,
"reverse tunnel server connection")
ct := nch.ChannelType()
if ct != chanHeartbeat {
msg := fmt.Sprintf("reversetunnel received unknown channel request %v from %v",
@ -557,23 +562,24 @@ func (s *tunnelSite) setLastActive(t time.Time) {
}
func (s *tunnelSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
go func() {
for {
select {
case req := <-reqC:
if req == nil {
s.log.Infof("[TUNNEL] site disconnected: %v", s.domainName)
conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
return
}
log.Debugf("[TUNNEL] ping from \"%s\" %s", s.domainName, conn.conn.RemoteAddr())
s.setLastActive(time.Now())
case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod):
conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats"))
conn.sshConn.Close()
}
}
defer func() {
s.log.Infof("[TUNNEL] site connection closed: %v", s.domainName)
conn.Close()
}()
for {
select {
case req := <-reqC:
if req == nil {
s.log.Infof("[TUNNEL] site disconnected: %v", s.domainName)
conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
return
}
log.Debugf("[TUNNEL] ping from \"%s\" %s", s.domainName, conn.conn.RemoteAddr())
s.setLastActive(time.Now())
case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod):
conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats"))
}
}
}
func (s *tunnelSite) GetName() string {