mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 08:43:58 +00:00
Ensure invalid tunnel agent connections get closed (#17899)
* Ensure invalid tunnel agent connections get closed. Connections from reverse tunnel agents were being marked as invalid by the proxy under certain conditions but were ultimately never closed. This could lead to scenarios where the agent thought things were fine while the proxy considered that agent unhealthy and unroutable. Pruning of invalid connections used to occur when a proxy tried to retrieve a connection for that tunnel. This further muddied the point in time at which the proxy could close a connection, because it never stopped tracking the connection and closed it in the same step. To remedy this, the proxy now explicitly closes such connections and removes them from the mapping, so tracking stops immediately. To prevent a connection that is servicing an active session from being closed, the proxy now tracks which connections have sessions. A connection is not closed while it has any active sessions, so those sessions are not forcibly terminated. When the proxy receives a heartbeat from an agent, it now restores the connection to a valid state. If too many heartbeats have been missed for an agent, the proxy now terminates the connection, again only if it is not serving any sessions. Fixes #15911
This commit is contained in:
parent
2aad238a12
commit
514bfc7ac6
|
@ -31,7 +31,11 @@ import (
|
|||
|
||||
// ConnectProxyTransport opens a channel over the remote tunnel and connects
|
||||
// to the requested host.
|
||||
func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (*ChConn, bool, error) {
|
||||
//
|
||||
// Returns the net.Conn wrapper over an SSH channel, whether the provided ssh.Conn
|
||||
// should be considered invalid due to errors opening or sending a request to the
|
||||
// channel while setting up the ChConn, and any error that occurs.
|
||||
func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (conn *ChConn, invalid bool, err error) {
|
||||
if err := req.CheckAndSetDefaults(); err != nil {
|
||||
return nil, false, trace.Wrap(err)
|
||||
}
|
||||
|
@ -43,7 +47,7 @@ func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (*ChCon
|
|||
|
||||
channel, discard, err := sconn.OpenChannel(constants.ChanTransport, nil)
|
||||
if err != nil {
|
||||
return nil, false, trace.Wrap(err)
|
||||
return nil, true, trace.Wrap(err)
|
||||
}
|
||||
|
||||
// DiscardRequests will return when the channel or underlying connection is closed.
|
||||
|
@ -55,7 +59,7 @@ func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (*ChCon
|
|||
// this SSH channel.
|
||||
ok, err := channel.SendRequest(constants.ChanTransportDialReq, true, payload)
|
||||
if err != nil {
|
||||
return nil, true, trace.Wrap(err)
|
||||
return nil, true, trace.NewAggregate(trace.Wrap(err), channel.Close())
|
||||
}
|
||||
if !ok {
|
||||
defer channel.Close()
|
||||
|
|
|
@ -1122,7 +1122,7 @@ func (i *TeleInstance) AddUser(username string, mappings []string) *User {
|
|||
func (i *TeleInstance) Start() error {
|
||||
// Build a list of expected events to wait for before unblocking based off
|
||||
// the configuration passed in.
|
||||
expectedEvents := []string{}
|
||||
var expectedEvents []string
|
||||
if i.Config.Auth.Enabled {
|
||||
expectedEvents = append(expectedEvents, service.AuthTLSReady)
|
||||
}
|
||||
|
@ -1147,6 +1147,8 @@ func (i *TeleInstance) Start() error {
|
|||
expectedEvents = append(expectedEvents, service.KubernetesReady)
|
||||
}
|
||||
|
||||
expectedEvents = append(expectedEvents, service.InstanceReady)
|
||||
|
||||
// Start the process and block until the expected events have arrived.
|
||||
receivedEvents, err := StartAndWait(i.Process, expectedEvents)
|
||||
if err != nil {
|
||||
|
|
|
@ -1131,6 +1131,7 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) {
|
|||
rcConf.Proxy.DisableWebInterface = true
|
||||
rcConf.SSH.Enabled = false
|
||||
rcConf.CircuitBreakerConfig = breaker.NoopBreakerConfig()
|
||||
rcConf.Log = log
|
||||
|
||||
log.Infof("Root cluster config: %#v", rcConf)
|
||||
|
||||
|
@ -1162,7 +1163,9 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) {
|
|||
|
||||
rcProxyAddr := net.JoinHostPort(rcAddr, helpers.PortStr(t, rc.Web))
|
||||
require.Zero(t, ph.Count())
|
||||
_, err = rc.StartNode(makeNodeConfig("node1", rcProxyAddr))
|
||||
nodeCfg := makeNodeConfig("node1", rcProxyAddr)
|
||||
nodeCfg.Log = log
|
||||
_, err = rc.StartNode(nodeCfg)
|
||||
require.Error(t, err)
|
||||
|
||||
timeout := time.Second * 60
|
||||
|
|
|
@ -112,7 +112,7 @@ func testProxyTunnelStrategyAgentMesh(t *testing.T) {
|
|||
testResource: func(t *testing.T, p *proxyTunnelStrategy) {
|
||||
p.makeDatabase(t)
|
||||
|
||||
// wait for the node to be connected to both proxies
|
||||
// wait for the database to be connected to both proxies
|
||||
helpers.WaitForActiveTunnelConnections(t, p.proxies[0].Tunnel, p.cluster, 1)
|
||||
helpers.WaitForActiveTunnelConnections(t, p.proxies[1].Tunnel, p.cluster, 1)
|
||||
|
||||
|
@ -303,6 +303,8 @@ func (p *proxyTunnelStrategy) makeAuth(t *testing.T) {
|
|||
|
||||
conf := service.MakeDefaultConfig()
|
||||
conf.DataDir = t.TempDir()
|
||||
conf.Log = auth.Log
|
||||
|
||||
conf.Auth.Enabled = true
|
||||
conf.Auth.NetworkingConfig.SetTunnelStrategy(p.strategy)
|
||||
conf.Auth.SessionRecordingConfig.SetMode(types.RecordAtNodeSync)
|
||||
|
@ -331,6 +333,7 @@ func (p *proxyTunnelStrategy) makeProxy(t *testing.T) {
|
|||
conf.SetAuthServerAddress(*authAddr)
|
||||
conf.SetToken("token")
|
||||
conf.DataDir = t.TempDir()
|
||||
conf.Log = proxy.Log
|
||||
|
||||
conf.Auth.Enabled = false
|
||||
conf.SSH.Enabled = false
|
||||
|
@ -371,13 +374,15 @@ func (p *proxyTunnelStrategy) makeNode(t *testing.T) {
|
|||
})
|
||||
|
||||
conf := service.MakeDefaultConfig()
|
||||
conf.SetAuthServerAddress(utils.FromAddr(p.lb.Addr()))
|
||||
conf.Version = types.V3
|
||||
conf.SetToken("token")
|
||||
conf.DataDir = t.TempDir()
|
||||
conf.Log = node.Log
|
||||
|
||||
conf.Auth.Enabled = false
|
||||
conf.Proxy.Enabled = false
|
||||
conf.SSH.Enabled = true
|
||||
conf.ProxyServer = utils.FromAddr(p.lb.Addr())
|
||||
|
||||
process, err := service.NewTeleport(conf)
|
||||
require.NoError(t, err)
|
||||
|
@ -413,14 +418,16 @@ func (p *proxyTunnelStrategy) makeDatabase(t *testing.T) {
|
|||
})
|
||||
|
||||
conf := service.MakeDefaultConfig()
|
||||
conf.SetAuthServerAddress(utils.FromAddr(p.lb.Addr()))
|
||||
conf.Version = types.V3
|
||||
conf.SetToken("token")
|
||||
conf.DataDir = t.TempDir()
|
||||
conf.Log = db.Log
|
||||
|
||||
conf.Auth.Enabled = false
|
||||
conf.Proxy.Enabled = false
|
||||
conf.SSH.Enabled = false
|
||||
conf.Databases.Enabled = true
|
||||
conf.ProxyServer = utils.FromAddr(p.lb.Addr())
|
||||
conf.Databases.Databases = []service.Database{
|
||||
{
|
||||
Name: p.cluster + "-postgres",
|
||||
|
|
|
@ -641,7 +641,7 @@ func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) {
|
|||
|
||||
var r discoveryRequest
|
||||
if err := json.Unmarshal(req.Payload, &r); err != nil {
|
||||
a.log.WithError(err).Warningf("Bad payload")
|
||||
a.log.WithError(err).Warn("Bad payload")
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ type connKey struct {
|
|||
type remoteConn struct {
|
||||
// lastHeartbeat is the last time a heartbeat was received.
|
||||
// intentionally placed first to ensure 64-bit alignment
|
||||
lastHeartbeat int64
|
||||
lastHeartbeat atomic.Int64
|
||||
|
||||
*connConfig
|
||||
mu sync.Mutex
|
||||
|
@ -62,17 +62,20 @@ type remoteConn struct {
|
|||
|
||||
// invalid indicates the connection is invalid and connections can no longer
|
||||
// be made on it.
|
||||
invalid int32
|
||||
invalid atomic.Bool
|
||||
|
||||
// lastError is the last error that occurred before this connection became
|
||||
// invalid.
|
||||
lastError error
|
||||
|
||||
// Used to make sure calling Close on the connection multiple times is safe.
|
||||
closed int32
|
||||
closed atomic.Bool
|
||||
|
||||
// clock is used to control time in tests.
|
||||
clock clockwork.Clock
|
||||
|
||||
// sessions counts the number of active sessions being serviced by this connection
|
||||
sessions atomic.Int64
|
||||
}
|
||||
|
||||
// connConfig is the configuration for the remoteConn.
|
||||
|
@ -120,23 +123,24 @@ func (c *remoteConn) String() string {
|
|||
|
||||
func (c *remoteConn) Close() error {
|
||||
// If the connection has already been closed, return right away.
|
||||
if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
|
||||
if c.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
// Close the discovery channel.
|
||||
if c.discoveryCh != nil {
|
||||
c.discoveryCh.Close()
|
||||
errs = append(errs, c.discoveryCh.Close())
|
||||
c.discoveryCh = nil
|
||||
}
|
||||
|
||||
// Close the SSH connection which will close the underlying net.Conn as well.
|
||||
err := c.sconn.Close()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return trace.NewAggregate(errs...)
|
||||
|
||||
}
|
||||
|
||||
|
@ -155,27 +159,56 @@ func (c *remoteConn) ChannelConn(channel ssh.Channel) net.Conn {
|
|||
return sshutils.NewChConn(c.sconn, channel)
|
||||
}
|
||||
|
||||
func (c *remoteConn) incrementActiveSessions() {
|
||||
c.sessions.Add(1)
|
||||
}
|
||||
|
||||
func (c *remoteConn) decrementActiveSessions() {
|
||||
c.sessions.Add(-1)
|
||||
}
|
||||
|
||||
func (c *remoteConn) activeSessions() int64 {
|
||||
return c.sessions.Load()
|
||||
}
|
||||
|
||||
func (c *remoteConn) markInvalid(err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.lastError = err
|
||||
atomic.StoreInt32(&c.invalid, 1)
|
||||
c.log.Debugf("Disconnecting connection to %v %v: %v.", c.clusterName, c.conn.RemoteAddr(), err)
|
||||
c.invalid.Store(true)
|
||||
c.log.Warnf("Unhealthy connection to %v %v: %v.", c.clusterName, c.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
func (c *remoteConn) markValid() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.lastError = nil
|
||||
c.invalid.Store(false)
|
||||
}
|
||||
|
||||
func (c *remoteConn) isInvalid() bool {
|
||||
return atomic.LoadInt32(&c.invalid) == 1
|
||||
return c.invalid.Load()
|
||||
}
|
||||
|
||||
func (c *remoteConn) setLastHeartbeat(tm time.Time) {
|
||||
atomic.StoreInt64(&c.lastHeartbeat, tm.UnixNano())
|
||||
c.lastHeartbeat.Store(tm.UnixNano())
|
||||
}
|
||||
|
||||
func (c *remoteConn) getLastHeartbeat() time.Time {
|
||||
hb := c.lastHeartbeat.Load()
|
||||
if hb == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
return time.Unix(0, hb)
|
||||
}
|
||||
|
||||
// isReady returns true when connection is ready to be tried,
|
||||
// it returns true when connection has received the first heartbeat
|
||||
func (c *remoteConn) isReady() bool {
|
||||
return atomic.LoadInt64(&c.lastHeartbeat) != 0
|
||||
return c.lastHeartbeat.Load() != 0
|
||||
}
|
||||
|
||||
func (c *remoteConn) openDiscoveryChannel() (ssh.Channel, error) {
|
||||
|
|
|
@ -52,6 +52,9 @@ const (
|
|||
// connected agents via a discovery request. It is a function of track.DefaultProxyExpiry
|
||||
// to ensure that the proxies are always synced before the tracker expiry.
|
||||
proxySyncInterval = track.DefaultProxyExpiry * 2 / 3
|
||||
|
||||
// missedHeartBeatThreshold is the number of missed heart beats needed to terminate a connection.
|
||||
missedHeartBeatThreshold = 3
|
||||
)
|
||||
|
||||
// withPeriodicFunctionInterval adjusts the periodic function interval
|
||||
|
@ -261,7 +264,9 @@ func (s *localSite) adviseReconnect(ctx context.Context) {
|
|||
|
||||
wg.Add(1)
|
||||
go func(conn *remoteConn) {
|
||||
conn.adviseReconnect()
|
||||
if err := conn.adviseReconnect(); err != nil {
|
||||
s.log.WithError(err).Warn("Failed sending reconnect advisory")
|
||||
}
|
||||
wg.Done()
|
||||
}(conn)
|
||||
}
|
||||
|
@ -296,15 +301,13 @@ func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) {
|
|||
// return a connection to that node. Otherwise net.Dial to the target host.
|
||||
targetConn, useTunnel, err := s.getConn(params)
|
||||
if err != nil {
|
||||
userAgent.Close()
|
||||
return nil, trace.Wrap(err)
|
||||
return nil, trace.NewAggregate(trace.Wrap(err), userAgent.Close())
|
||||
}
|
||||
|
||||
// Get a host certificate for the forwarding node from the cache.
|
||||
hostCertificate, err := s.certificateCache.getHostCertificate(params.Address, params.Principals)
|
||||
if err != nil {
|
||||
userAgent.Close()
|
||||
return nil, trace.Wrap(err)
|
||||
return nil, trace.NewAggregate(trace.Wrap(err), userAgent.Close())
|
||||
}
|
||||
|
||||
// Create a forwarding server that serves a single SSH connection on it. This
|
||||
|
@ -534,15 +537,25 @@ func (s *localSite) fanOutProxies(proxies []types.Server) {
|
|||
// if the agent has missed several heartbeats in a row, Proxy marks
|
||||
// the connection as invalid.
|
||||
func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
|
||||
proxyResyncTicker := s.clock.NewTicker(s.proxySyncInterval)
|
||||
|
||||
defer func() {
|
||||
s.log.Debugf("Cluster connection closed.")
|
||||
rconn.Close()
|
||||
proxyResyncTicker.Stop()
|
||||
}()
|
||||
logger := s.log.WithFields(log.Fields{
|
||||
"serverID": rconn.nodeID,
|
||||
"addr": rconn.conn.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
firstHeartbeat := true
|
||||
proxyResyncTicker := s.clock.NewTicker(s.proxySyncInterval)
|
||||
defer func() {
|
||||
proxyResyncTicker.Stop()
|
||||
logger.Warn("Closing remote connection to agent.")
|
||||
s.removeRemoteConn(rconn)
|
||||
if err := rconn.Close(); err != nil && !utils.IsOKNetworkError(err) {
|
||||
logger.WithError(err).Warn("Failed to close remote connection")
|
||||
}
|
||||
if !firstHeartbeat {
|
||||
reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Dec()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.srv.ctx.Done():
|
||||
|
@ -554,7 +567,7 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
|
|||
}
|
||||
|
||||
if err := rconn.sendDiscoveryRequest(req); err != nil {
|
||||
s.log.WithError(err).Debugf("Marking connection invalid on error")
|
||||
s.log.WithError(err).Debug("Marking connection invalid on error")
|
||||
rconn.markInvalid(err)
|
||||
return
|
||||
}
|
||||
|
@ -564,13 +577,13 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
|
|||
}
|
||||
|
||||
if err := rconn.sendDiscoveryRequest(req); err != nil {
|
||||
s.log.WithError(err).Debugf("Marking connection invalid on error")
|
||||
logger.WithError(err).Debug("Failed to send discovery request to agent")
|
||||
rconn.markInvalid(err)
|
||||
return
|
||||
}
|
||||
case req := <-reqC:
|
||||
if req == nil {
|
||||
s.log.Debugf("Cluster agent disconnected.")
|
||||
logger.Debug("Agent disconnected.")
|
||||
rconn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
|
||||
return
|
||||
}
|
||||
|
@ -582,7 +595,6 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
|
|||
rconn.updateProxies(current)
|
||||
}
|
||||
reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Inc()
|
||||
defer reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Dec()
|
||||
firstHeartbeat = false
|
||||
}
|
||||
var timeSent time.Time
|
||||
|
@ -592,16 +604,52 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
|
|||
roundtrip = s.srv.Clock.Now().Sub(timeSent)
|
||||
}
|
||||
}
|
||||
|
||||
log := logger
|
||||
if roundtrip != 0 {
|
||||
s.log.WithFields(log.Fields{"latency": roundtrip, "nodeID": rconn.nodeID}).Debugf("Ping <- %v", rconn.conn.RemoteAddr())
|
||||
} else {
|
||||
s.log.WithFields(log.Fields{"nodeID": rconn.nodeID}).Debugf("Ping <- %v", rconn.conn.RemoteAddr())
|
||||
log = logger.WithField("latency", roundtrip)
|
||||
}
|
||||
tm := s.clock.Now().UTC()
|
||||
rconn.setLastHeartbeat(tm)
|
||||
log.Debugf("Ping <- %v", rconn.conn.RemoteAddr())
|
||||
|
||||
rconn.setLastHeartbeat(s.clock.Now().UTC())
|
||||
rconn.markValid()
|
||||
// Note that time.After is re-created every time a request is processed.
|
||||
case <-s.clock.After(s.offlineThreshold):
|
||||
case t := <-s.clock.After(s.offlineThreshold):
|
||||
rconn.markInvalid(trace.ConnectionProblem(nil, "no heartbeats for %v", s.offlineThreshold))
|
||||
|
||||
// terminate and remove the connection after missing more than missedHeartBeatThreshold heartbeats if
|
||||
// the connection isn't still servicing any sessions
|
||||
hb := rconn.getLastHeartbeat()
|
||||
if t.After(hb.Add(s.offlineThreshold * missedHeartBeatThreshold)) {
|
||||
count := rconn.activeSessions()
|
||||
if count == 0 {
|
||||
logger.Errorf("Closing unhealthy and idle connection. Heartbeat last received at %s", hb)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Warnf("Deferring closure of unhealthy connection due to %d active connections", count)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *localSite) removeRemoteConn(rconn *remoteConn) {
|
||||
s.remoteConnsMtx.Lock()
|
||||
defer s.remoteConnsMtx.Unlock()
|
||||
|
||||
key := connKey{
|
||||
uuid: rconn.nodeID,
|
||||
connType: types.TunnelType(rconn.tunnelType),
|
||||
}
|
||||
|
||||
conns := s.remoteConns[key]
|
||||
for i, conn := range conns {
|
||||
if conn == rconn {
|
||||
s.remoteConns[key] = append(conns[:i], conns[i+1:]...)
|
||||
if len(s.remoteConns[key]) == 0 {
|
||||
delete(s.remoteConns, key)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -610,36 +658,40 @@ func (s *localSite) getRemoteConn(dreq *sshutils.DialReq) (*remoteConn, error) {
|
|||
s.remoteConnsMtx.Lock()
|
||||
defer s.remoteConnsMtx.Unlock()
|
||||
|
||||
// Loop over all connections and remove and invalid connections from the
|
||||
// connection map.
|
||||
for key, conns := range s.remoteConns {
|
||||
validConns := conns[:0]
|
||||
for _, conn := range conns {
|
||||
if !conn.isInvalid() {
|
||||
validConns = append(validConns, conn)
|
||||
}
|
||||
}
|
||||
if len(validConns) == 0 {
|
||||
delete(s.remoteConns, key)
|
||||
} else {
|
||||
s.remoteConns[key] = validConns
|
||||
}
|
||||
}
|
||||
|
||||
key := connKey{
|
||||
uuid: dreq.ServerID,
|
||||
connType: dreq.ConnType,
|
||||
}
|
||||
if len(s.remoteConns[key]) == 0 {
|
||||
|
||||
conns := s.remoteConns[key]
|
||||
if len(conns) == 0 {
|
||||
return nil, trace.NotFound("no %v reverse tunnel for %v found", dreq.ConnType, dreq.ServerID)
|
||||
}
|
||||
|
||||
conns := s.remoteConns[key]
|
||||
// Check the remoteConns from newest to oldest for one
|
||||
// that has heartbeated and is valid. If none are valid, try
|
||||
// the newest ready but invalid connection.
|
||||
var newestInvalidConn *remoteConn
|
||||
for i := len(conns) - 1; i >= 0; i-- {
|
||||
if conns[i].isReady() {
|
||||
switch {
|
||||
case !conns[i].isReady(): // skip remoteConn that haven't heartbeated yet
|
||||
continue
|
||||
case !conns[i].isInvalid(): // return the first valid remoteConn that has heartbeated
|
||||
return conns[i], nil
|
||||
case newestInvalidConn == nil && conns[i].isInvalid(): // cache the first invalid remoteConn in case none are valid
|
||||
newestInvalidConn = conns[i]
|
||||
}
|
||||
}
|
||||
|
||||
// This indicates that there were no ready and valid connections, but at least
|
||||
// one ready and invalid connection. We can at least attempt to connect on the
|
||||
// invalid connection instead of giving up entirely. If anything the error might
|
||||
// be more informative than the default offline message returned below.
|
||||
if newestInvalidConn != nil {
|
||||
return newestInvalidConn, nil
|
||||
}
|
||||
|
||||
// The agent is having issues and there is no way to connect
|
||||
return nil, trace.NotFound("%v is offline: no active %v tunnels found", dreq.ConnType, dreq.ServerID)
|
||||
}
|
||||
|
||||
|
@ -650,11 +702,43 @@ func (s *localSite) chanTransportConn(rconn *remoteConn, dreq *sshutils.DialReq)
|
|||
if err != nil {
|
||||
if markInvalid {
|
||||
rconn.markInvalid(err)
|
||||
// If not serving any connections close and remove this connection immediately.
|
||||
// Otherwise, let the heartbeat handler detect this connection is down.
|
||||
if rconn.activeSessions() == 0 {
|
||||
s.removeRemoteConn(rconn)
|
||||
return nil, trace.NewAggregate(trace.Wrap(err), rconn.Close())
|
||||
}
|
||||
}
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
return newSessionTrackingConn(rconn, conn), nil
|
||||
}
|
||||
|
||||
// sessionTrackingConn wraps a net.Conn in order
|
||||
// to maintain the number of active sessions for
|
||||
// a remoteConn.
|
||||
type sessionTrackingConn struct {
|
||||
net.Conn
|
||||
rc *remoteConn
|
||||
}
|
||||
|
||||
// newSessionTrackingConn wraps the provided net.Conn to alert the remoteConn
|
||||
// when it is no longer active. Prior to returning the remoteConn active sessions
|
||||
// are incremented. Close must be called to decrement the count.
|
||||
func newSessionTrackingConn(rconn *remoteConn, conn net.Conn) *sessionTrackingConn {
|
||||
rconn.incrementActiveSessions()
|
||||
return &sessionTrackingConn{
|
||||
rc: rconn,
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Close decrements the remoteConn active session count and then
|
||||
// closes the underlying net.Conn
|
||||
func (c *sessionTrackingConn) Close() error {
|
||||
c.rc.decrementActiveSessions()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// periodicFunctions runs periodic functions for the local cluster.
|
||||
|
@ -677,7 +761,7 @@ func (s *localSite) periodicFunctions() {
|
|||
// sshTunnelStats reports SSH tunnel statistics for the cluster.
|
||||
func (s *localSite) sshTunnelStats() error {
|
||||
missing := s.srv.NodeWatcher.GetNodes(func(server services.Node) bool {
|
||||
// Skip over any servers that that have a TTL larger than announce TTL (10
|
||||
// Skip over any servers that have a TTL larger than announce TTL (10
|
||||
// minutes) and are non-IoT SSH servers (they won't have tunnels).
|
||||
//
|
||||
// Servers with a TTL larger than the announce TTL skipped over to work around
|
||||
|
|
|
@ -37,15 +37,105 @@ import (
|
|||
"github.com/gravitational/teleport/lib/utils"
|
||||
)
|
||||
|
||||
func TestRemoteConnCleanup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
clock := clockwork.NewFakeClock()
|
||||
|
||||
watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{
|
||||
ResourceWatcherConfig: services.ResourceWatcherConfig{
|
||||
Component: "test",
|
||||
Log: utils.NewLoggerForTests(),
|
||||
Clock: clock,
|
||||
Client: &mockLocalSiteClient{},
|
||||
},
|
||||
ProxiesC: make(chan []types.Server, 2),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, watcher.WaitInitialization())
|
||||
|
||||
// setup the site
|
||||
srv := &server{
|
||||
ctx: ctx,
|
||||
Config: Config{Clock: clock},
|
||||
localAuthClient: &mockLocalSiteClient{},
|
||||
log: utils.NewLoggerForTests(),
|
||||
offlineThreshold: time.Second,
|
||||
proxyWatcher: watcher,
|
||||
}
|
||||
|
||||
site, err := newlocalSite(srv, "clustername", nil, withPeriodicFunctionInterval(time.Hour), withProxySyncInterval(time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// add a connection
|
||||
rconn := &mockRemoteConnConn{}
|
||||
sconn := &mockedSSHConn{}
|
||||
conn1, err := site.addConn(uuid.NewString(), types.NodeTunnel, rconn, sconn)
|
||||
require.NoError(t, err)
|
||||
|
||||
reqs := make(chan *ssh.Request)
|
||||
|
||||
// terminated by too many missed heartbeats
|
||||
go func() {
|
||||
site.handleHeartbeat(conn1, nil, reqs)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// send an initial heartbeat
|
||||
reqs <- &ssh.Request{Type: "heartbeat"}
|
||||
|
||||
// create a fake session
|
||||
fakeSession := newSessionTrackingConn(conn1, &mockRemoteConnConn{})
|
||||
|
||||
// advance the clock to trigger missing a heartbeat, the last advance
|
||||
// should not force the connection to close since there is still an active session
|
||||
for i := 0; i <= missedHeartBeatThreshold+1; i++ {
|
||||
// wait until the heartbeat loop has created the timer
|
||||
clock.BlockUntil(3) // periodic ticker + heart beat timer + resync ticker = 3
|
||||
clock.Advance(srv.offlineThreshold)
|
||||
}
|
||||
|
||||
// the fake session should have prevented anything from closing
|
||||
require.False(t, conn1.closed.Load())
|
||||
require.False(t, sconn.closed.Load())
|
||||
|
||||
// send another heartbeat to reset exceeding the threshold
|
||||
reqs <- &ssh.Request{Type: "heartbeat"}
|
||||
|
||||
// close the fake session
|
||||
clock.BlockUntil(3) // periodic ticker + heart beat timer + resync ticker = 3
|
||||
require.NoError(t, fakeSession.Close())
|
||||
|
||||
// advance the clock to trigger missing a heartbeat, the last advance
|
||||
// should force the connection to close since there are no active sessions
|
||||
for i := 0; i <= missedHeartBeatThreshold; i++ {
|
||||
// wait until the heartbeat loop has created the timer
|
||||
clock.BlockUntil(3) // periodic ticker + heart beat timer + resync ticker = 3
|
||||
clock.Advance(srv.offlineThreshold)
|
||||
}
|
||||
|
||||
// wait for handleHeartbeat to finish
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(30 * time.Second): // artificially high to prevent flakiness
|
||||
t.Fatal("LocalSite heart beat handler never terminated")
|
||||
}
|
||||
|
||||
// assert the connections were closed
|
||||
require.True(t, conn1.closed.Load())
|
||||
require.True(t, sconn.closed.Load())
|
||||
}
|
||||
|
||||
func TestLocalSiteOverlap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
srv := &server{
|
||||
ctx: ctx,
|
||||
localAuthClient: &mockLocalSiteClient{},
|
||||
Config: Config{Clock: clockwork.NewFakeClock()},
|
||||
ctx: context.Background(),
|
||||
localAuthClient: &mockLocalSiteClient{},
|
||||
}
|
||||
|
||||
site, err := newlocalSite(srv, "clustername", nil, withPeriodicFunctionInterval(time.Hour))
|
||||
|
@ -58,35 +148,76 @@ func TestLocalSiteOverlap(t *testing.T) {
|
|||
ConnType: connType,
|
||||
}
|
||||
|
||||
conn1, err := site.addConn(nodeID, connType, mockRemoteConnConn{}, nil)
|
||||
// add a few connections for the same node id
|
||||
conn1, err := site.addConn(nodeID, connType, &mockRemoteConnConn{}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn2, err := site.addConn(nodeID, connType, mockRemoteConnConn{}, nil)
|
||||
conn2, err := site.addConn(nodeID, connType, &mockRemoteConnConn{}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn3, err := site.addConn(nodeID, connType, &mockRemoteConnConn{}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// no heartbeats from any of them shouldn't return a connection
|
||||
c, err := site.getRemoteConn(dreq)
|
||||
require.True(t, trace.IsNotFound(err))
|
||||
require.Nil(t, c)
|
||||
|
||||
// ensure conn1 is ready
|
||||
conn1.setLastHeartbeat(time.Now())
|
||||
|
||||
// getRemoteConn returns the only healthy connection
|
||||
c, err = site.getRemoteConn(dreq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn1, c)
|
||||
|
||||
// ensure conn2 is ready
|
||||
conn2.setLastHeartbeat(time.Now())
|
||||
|
||||
// getRemoteConn returns the newest healthy connection
|
||||
c, err = site.getRemoteConn(dreq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn2, c)
|
||||
|
||||
// mark conn2 invalid
|
||||
conn2.markInvalid(nil)
|
||||
|
||||
// getRemoteConn returns the only healthy connection
|
||||
c, err = site.getRemoteConn(dreq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn1, c)
|
||||
|
||||
// mark conn1 invalid
|
||||
conn1.markInvalid(nil)
|
||||
|
||||
// getRemoteConn returns the newest invalid connection since none are healthy
|
||||
c, err = site.getRemoteConn(dreq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn2, c)
|
||||
|
||||
// remove conn2
|
||||
site.removeRemoteConn(conn2)
|
||||
|
||||
// getRemoteConn returns the only invalid connection
|
||||
c, err = site.getRemoteConn(dreq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn1, c)
|
||||
|
||||
// remove conn1
|
||||
site.removeRemoteConn(conn1)
|
||||
|
||||
// no ready connections exist
|
||||
c, err = site.getRemoteConn(dreq)
|
||||
require.True(t, trace.IsNotFound(err))
|
||||
require.Nil(t, c)
|
||||
|
||||
// mark conn3 as ready
|
||||
conn3.setLastHeartbeat(time.Now())
|
||||
|
||||
// getRemoteConn returns the only healthy connection
|
||||
c, err = site.getRemoteConn(dreq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conn3, c)
|
||||
}
|
||||
|
||||
func TestProxyResync(t *testing.T) {
|
||||
|
@ -239,10 +370,16 @@ func (mockLocalSiteClient) NewWatcher(context.Context, types.Watch) (types.Watch
|
|||
|
||||
type mockRemoteConnConn struct {
|
||||
net.Conn
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (c *mockRemoteConnConn) Close() error {
|
||||
c.closed.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// called for logging by (*remoteConn).markInvalid()
|
||||
func (mockRemoteConnConn) RemoteAddr() net.Addr {
|
||||
func (*mockRemoteConnConn) RemoteAddr() net.Addr {
|
||||
return &utils.NetAddr{
|
||||
Addr: "localhost",
|
||||
AddrNetwork: "tcp",
|
||||
|
|
|
@ -391,10 +391,7 @@ func (p *proxyCollector) getResourcesAndUpdateCurrent(ctx context.Context) error
|
|||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if len(proxies) == 0 {
|
||||
// At least one proxy ought to exist.
|
||||
return trace.NotFound("empty proxy list")
|
||||
}
|
||||
|
||||
newCurrent := make(map[string]types.Server, len(proxies))
|
||||
for _, proxy := range proxies {
|
||||
newCurrent[proxy.GetName()] = proxy
|
||||
|
|
|
@ -133,14 +133,6 @@ func TestProxyWatcher(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
t.Cleanup(w.Close)
|
||||
|
||||
// Since no proxy is yet present, the ProxyWatcher should immediately
|
||||
// yield back to its retry loop.
|
||||
select {
|
||||
case <-w.ResetC:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timeout waiting for ProxyWatcher reset.")
|
||||
}
|
||||
|
||||
// Add a proxy server.
|
||||
proxy := newProxyServer(t, "proxy1", "127.0.0.1:2023")
|
||||
require.NoError(t, presence.UpsertProxy(proxy))
|
||||
|
|
Loading…
Reference in a new issue