Ensure invalid tunnel agent connections get closed (#17899)

* Ensure invalid tunnel agent connections get closed

Connections from reverse tunnel agents were being marked
as invalid by the proxy under certain conditions but were
ultimately never closed. This could lead to scenarios where
the agent believed its tunnel was healthy while the proxy
considered the agent unhealthy and unroutable.

Pruning of invalid connections previously happened only when a
proxy tried to retrieve a connection for that tunnel. This also
muddied the point in time at which the proxy could close a
connection, since it never explicitly stopped tracking the
connection and closed it at the same time.

To remedy this, the proxy now explicitly closes invalid
connections and removes them from the mapping immediately so
they are no longer tracked. To prevent a connection that is
still servicing active sessions from being closed, the proxy
now tracks the number of sessions on each connection and defers
closing while any sessions remain, so they are not forcibly
terminated.

When the proxy receives a heartbeat from an agent, it now
restores the connection to a valid state. If too many heartbeats
have been missed for an agent, the proxy now terminates the
connection, again only if it is not serving any sessions.
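
In effect, the close decision reduces to the following predicate (a simplified
sketch of the policy described above, not the exact code in this diff; the
real check lives in the localSite heartbeat handler further down, and "time"
is assumed to be imported):

// shouldClose sketches the new policy: an unhealthy tunnel connection is only
// torn down once it has missed several heartbeats in a row and is no longer
// servicing any sessions.
func shouldClose(now, lastHeartbeat time.Time, offlineThreshold time.Duration, missedThreshold int, activeSessions int64) bool {
	missedTooMany := now.After(lastHeartbeat.Add(offlineThreshold * time.Duration(missedThreshold)))
	return missedTooMany && activeSessions == 0
}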

Fixes #15911
rosstimothy 2022-11-04 14:05:13 -04:00 committed by GitHub
parent 2aad238a12
commit 514bfc7ac6
10 changed files with 342 additions and 83 deletions

@ -31,7 +31,11 @@ import (
// ConnectProxyTransport opens a channel over the remote tunnel and connects
// to the requested host.
func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (*ChConn, bool, error) {
//
// Returns the net.Conn wrapper over an SSH channel, whether the provided ssh.Conn
// should be considered invalid due to errors opening or sending a request to the
// channel while setting up the ChConn, and any error that occurs.
func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (conn *ChConn, invalid bool, err error) {
if err := req.CheckAndSetDefaults(); err != nil {
return nil, false, trace.Wrap(err)
}
@ -43,7 +47,7 @@ func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (*ChCon
channel, discard, err := sconn.OpenChannel(constants.ChanTransport, nil)
if err != nil {
return nil, false, trace.Wrap(err)
return nil, true, trace.Wrap(err)
}
// DiscardRequests will return when the channel or underlying connection is closed.
@ -55,7 +59,7 @@ func ConnectProxyTransport(sconn ssh.Conn, req *DialReq, exclusive bool) (*ChCon
// this SSH channel.
ok, err := channel.SendRequest(constants.ChanTransportDialReq, true, payload)
if err != nil {
return nil, true, trace.Wrap(err)
return nil, true, trace.NewAggregate(trace.Wrap(err), channel.Close())
}
if !ok {
defer channel.Close()
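
For callers, the second return value answers "should the whole ssh.Conn be
marked unhealthy?" rather than "did this dial fail?". A hedged sketch of such
a caller, modeled on localSite.chanTransportConn later in this diff (rconn and
dreq stand in for the proxy's tracked connection and dial request):

conn, shouldMarkInvalid, err := sshutils.ConnectProxyTransport(rconn.sconn, dreq, false)
if err != nil {
	if shouldMarkInvalid {
		// The underlying ssh.Conn is suspect, not just this single dial attempt.
		rconn.markInvalid(err)
	}
	return nil, trace.Wrap(err)
}
return conn, nil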

@ -1122,7 +1122,7 @@ func (i *TeleInstance) AddUser(username string, mappings []string) *User {
func (i *TeleInstance) Start() error {
// Build a list of expected events to wait for before unblocking based off
// the configuration passed in.
expectedEvents := []string{}
var expectedEvents []string
if i.Config.Auth.Enabled {
expectedEvents = append(expectedEvents, service.AuthTLSReady)
}
@ -1147,6 +1147,8 @@ func (i *TeleInstance) Start() error {
expectedEvents = append(expectedEvents, service.KubernetesReady)
}
expectedEvents = append(expectedEvents, service.InstanceReady)
// Start the process and block until the expected events have arrived.
receivedEvents, err := StartAndWait(i.Process, expectedEvents)
if err != nil {

@ -1131,6 +1131,7 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) {
rcConf.Proxy.DisableWebInterface = true
rcConf.SSH.Enabled = false
rcConf.CircuitBreakerConfig = breaker.NoopBreakerConfig()
rcConf.Log = log
log.Infof("Root cluster config: %#v", rcConf)
@ -1162,7 +1163,9 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) {
rcProxyAddr := net.JoinHostPort(rcAddr, helpers.PortStr(t, rc.Web))
require.Zero(t, ph.Count())
_, err = rc.StartNode(makeNodeConfig("node1", rcProxyAddr))
nodeCfg := makeNodeConfig("node1", rcProxyAddr)
nodeCfg.Log = log
_, err = rc.StartNode(nodeCfg)
require.Error(t, err)
timeout := time.Second * 60

@ -112,7 +112,7 @@ func testProxyTunnelStrategyAgentMesh(t *testing.T) {
testResource: func(t *testing.T, p *proxyTunnelStrategy) {
p.makeDatabase(t)
// wait for the node to be connected to both proxies
// wait for the database to be connected to both proxies
helpers.WaitForActiveTunnelConnections(t, p.proxies[0].Tunnel, p.cluster, 1)
helpers.WaitForActiveTunnelConnections(t, p.proxies[1].Tunnel, p.cluster, 1)
@ -303,6 +303,8 @@ func (p *proxyTunnelStrategy) makeAuth(t *testing.T) {
conf := service.MakeDefaultConfig()
conf.DataDir = t.TempDir()
conf.Log = auth.Log
conf.Auth.Enabled = true
conf.Auth.NetworkingConfig.SetTunnelStrategy(p.strategy)
conf.Auth.SessionRecordingConfig.SetMode(types.RecordAtNodeSync)
@ -331,6 +333,7 @@ func (p *proxyTunnelStrategy) makeProxy(t *testing.T) {
conf.SetAuthServerAddress(*authAddr)
conf.SetToken("token")
conf.DataDir = t.TempDir()
conf.Log = proxy.Log
conf.Auth.Enabled = false
conf.SSH.Enabled = false
@ -371,13 +374,15 @@ func (p *proxyTunnelStrategy) makeNode(t *testing.T) {
})
conf := service.MakeDefaultConfig()
conf.SetAuthServerAddress(utils.FromAddr(p.lb.Addr()))
conf.Version = types.V3
conf.SetToken("token")
conf.DataDir = t.TempDir()
conf.Log = node.Log
conf.Auth.Enabled = false
conf.Proxy.Enabled = false
conf.SSH.Enabled = true
conf.ProxyServer = utils.FromAddr(p.lb.Addr())
process, err := service.NewTeleport(conf)
require.NoError(t, err)
@ -413,14 +418,16 @@ func (p *proxyTunnelStrategy) makeDatabase(t *testing.T) {
})
conf := service.MakeDefaultConfig()
conf.SetAuthServerAddress(utils.FromAddr(p.lb.Addr()))
conf.Version = types.V3
conf.SetToken("token")
conf.DataDir = t.TempDir()
conf.Log = db.Log
conf.Auth.Enabled = false
conf.Proxy.Enabled = false
conf.SSH.Enabled = false
conf.Databases.Enabled = true
conf.ProxyServer = utils.FromAddr(p.lb.Addr())
conf.Databases.Databases = []service.Database{
{
Name: p.cluster + "-postgres",

@ -641,7 +641,7 @@ func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) {
var r discoveryRequest
if err := json.Unmarshal(req.Payload, &r); err != nil {
a.log.WithError(err).Warningf("Bad payload")
a.log.WithError(err).Warn("Bad payload")
return
}

View file

@ -48,7 +48,7 @@ type connKey struct {
type remoteConn struct {
// lastHeartbeat is the last time a heartbeat was received.
// intentionally placed first to ensure 64-bit alignment
lastHeartbeat int64
lastHeartbeat atomic.Int64
*connConfig
mu sync.Mutex
@ -62,17 +62,20 @@ type remoteConn struct {
// invalid indicates the connection is invalid and connections can no longer
// be made on it.
invalid int32
invalid atomic.Bool
// lastError is the last error that occurred before this connection became
// invalid.
lastError error
// Used to make sure calling Close on the connection multiple times is safe.
closed int32
closed atomic.Bool
// clock is used to control time in tests.
clock clockwork.Clock
// sessions counts the number of active sessions being serviced by this connection
sessions atomic.Int64
}
// connConfig is the configuration for the remoteConn.
@ -120,23 +123,24 @@ func (c *remoteConn) String() string {
func (c *remoteConn) Close() error {
// If the connection has already been closed, return right away.
if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
if c.closed.Swap(true) {
return nil
}
var errs []error
// Close the discovery channel.
if c.discoveryCh != nil {
c.discoveryCh.Close()
errs = append(errs, c.discoveryCh.Close())
c.discoveryCh = nil
}
// Close the SSH connection which will close the underlying net.Conn as well.
err := c.sconn.Close()
if err != nil {
return trace.Wrap(err)
errs = append(errs, err)
}
return nil
return trace.NewAggregate(errs...)
}
@ -155,27 +159,56 @@ func (c *remoteConn) ChannelConn(channel ssh.Channel) net.Conn {
return sshutils.NewChConn(c.sconn, channel)
}
func (c *remoteConn) incrementActiveSessions() {
c.sessions.Add(1)
}
func (c *remoteConn) decrementActiveSessions() {
c.sessions.Add(-1)
}
func (c *remoteConn) activeSessions() int64 {
return c.sessions.Load()
}
func (c *remoteConn) markInvalid(err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.lastError = err
atomic.StoreInt32(&c.invalid, 1)
c.log.Debugf("Disconnecting connection to %v %v: %v.", c.clusterName, c.conn.RemoteAddr(), err)
c.invalid.Store(true)
c.log.Warnf("Unhealthy connection to %v %v: %v.", c.clusterName, c.conn.RemoteAddr(), err)
}
func (c *remoteConn) markValid() {
c.mu.Lock()
defer c.mu.Unlock()
c.lastError = nil
c.invalid.Store(false)
}
func (c *remoteConn) isInvalid() bool {
return atomic.LoadInt32(&c.invalid) == 1
return c.invalid.Load()
}
func (c *remoteConn) setLastHeartbeat(tm time.Time) {
atomic.StoreInt64(&c.lastHeartbeat, tm.UnixNano())
c.lastHeartbeat.Store(tm.UnixNano())
}
func (c *remoteConn) getLastHeartbeat() time.Time {
hb := c.lastHeartbeat.Load()
if hb == 0 {
return time.Time{}
}
return time.Unix(0, hb)
}
// isReady returns true when connection is ready to be tried,
// it returns true when connection has received the first heartbeat
func (c *remoteConn) isReady() bool {
return atomic.LoadInt64(&c.lastHeartbeat) != 0
return c.lastHeartbeat.Load() != 0
}
func (c *remoteConn) openDiscoveryChannel() (ssh.Channel, error) {
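
The move from int32 fields plus atomic package helpers to Go 1.19 typed
atomics is mechanical, but the Swap-based guard in Close above is worth
calling out. As a standalone illustration of the idempotent-Close pattern (a
sketch, not code from this diff; assumes "sync/atomic" is imported):

type closeOnce struct {
	closed atomic.Bool
}

// Close is safe to call repeatedly and concurrently: Swap returns the previous
// value, so only the first caller gets past the guard and releases resources.
func (c *closeOnce) Close() error {
	if c.closed.Swap(true) {
		return nil
	}
	// release resources exactly once here
	return nil
}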

@ -52,6 +52,9 @@ const (
// connected agents via a discovery request. It is a function of track.DefaultProxyExpiry
// to ensure that the proxies are always synced before the tracker expiry.
proxySyncInterval = track.DefaultProxyExpiry * 2 / 3
// missedHeartBeatThreshold is the number of missed heart beats needed to terminate a connection.
missedHeartBeatThreshold = 3
)
// withPeriodicFunctionInterval adjusts the periodic function interval
@ -261,7 +264,9 @@ func (s *localSite) adviseReconnect(ctx context.Context) {
wg.Add(1)
go func(conn *remoteConn) {
conn.adviseReconnect()
if err := conn.adviseReconnect(); err != nil {
s.log.WithError(err).Warn("Failed sending reconnect advisory")
}
wg.Done()
}(conn)
}
@ -296,15 +301,13 @@ func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) {
// return a connection to that node. Otherwise net.Dial to the target host.
targetConn, useTunnel, err := s.getConn(params)
if err != nil {
userAgent.Close()
return nil, trace.Wrap(err)
return nil, trace.NewAggregate(trace.Wrap(err), userAgent.Close())
}
// Get a host certificate for the forwarding node from the cache.
hostCertificate, err := s.certificateCache.getHostCertificate(params.Address, params.Principals)
if err != nil {
userAgent.Close()
return nil, trace.Wrap(err)
return nil, trace.NewAggregate(trace.Wrap(err), userAgent.Close())
}
// Create a forwarding server that serves a single SSH connection on it. This
@ -534,15 +537,25 @@ func (s *localSite) fanOutProxies(proxies []types.Server) {
// if the agent has missed several heartbeats in a row, Proxy marks
// the connection as invalid.
func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
proxyResyncTicker := s.clock.NewTicker(s.proxySyncInterval)
defer func() {
s.log.Debugf("Cluster connection closed.")
rconn.Close()
proxyResyncTicker.Stop()
}()
logger := s.log.WithFields(log.Fields{
"serverID": rconn.nodeID,
"addr": rconn.conn.RemoteAddr().String(),
})
firstHeartbeat := true
proxyResyncTicker := s.clock.NewTicker(s.proxySyncInterval)
defer func() {
proxyResyncTicker.Stop()
logger.Warn("Closing remote connection to agent.")
s.removeRemoteConn(rconn)
if err := rconn.Close(); err != nil && !utils.IsOKNetworkError(err) {
logger.WithError(err).Warn("Failed to close remote connection")
}
if !firstHeartbeat {
reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Dec()
}
}()
for {
select {
case <-s.srv.ctx.Done():
@ -554,7 +567,7 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
}
if err := rconn.sendDiscoveryRequest(req); err != nil {
s.log.WithError(err).Debugf("Marking connection invalid on error")
s.log.WithError(err).Debug("Marking connection invalid on error")
rconn.markInvalid(err)
return
}
@ -564,13 +577,13 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
}
if err := rconn.sendDiscoveryRequest(req); err != nil {
s.log.WithError(err).Debugf("Marking connection invalid on error")
logger.WithError(err).Debug("Failed to send discovery request to agent")
rconn.markInvalid(err)
return
}
case req := <-reqC:
if req == nil {
s.log.Debugf("Cluster agent disconnected.")
logger.Debug("Agent disconnected.")
rconn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
return
}
@ -582,7 +595,6 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
rconn.updateProxies(current)
}
reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Inc()
defer reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Dec()
firstHeartbeat = false
}
var timeSent time.Time
@ -592,16 +604,52 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
roundtrip = s.srv.Clock.Now().Sub(timeSent)
}
}
log := logger
if roundtrip != 0 {
s.log.WithFields(log.Fields{"latency": roundtrip, "nodeID": rconn.nodeID}).Debugf("Ping <- %v", rconn.conn.RemoteAddr())
} else {
s.log.WithFields(log.Fields{"nodeID": rconn.nodeID}).Debugf("Ping <- %v", rconn.conn.RemoteAddr())
log = logger.WithField("latency", roundtrip)
}
tm := s.clock.Now().UTC()
rconn.setLastHeartbeat(tm)
log.Debugf("Ping <- %v", rconn.conn.RemoteAddr())
rconn.setLastHeartbeat(s.clock.Now().UTC())
rconn.markValid()
// Note that time.After is re-created everytime a request is processed.
case <-s.clock.After(s.offlineThreshold):
case t := <-s.clock.After(s.offlineThreshold):
rconn.markInvalid(trace.ConnectionProblem(nil, "no heartbeats for %v", s.offlineThreshold))
// terminate and remove the connection after missing more than missedHeartBeatThreshold heartbeats if
// the connection isn't still servicing any sessions
hb := rconn.getLastHeartbeat()
if t.After(hb.Add(s.offlineThreshold * missedHeartBeatThreshold)) {
count := rconn.activeSessions()
if count == 0 {
logger.Errorf("Closing unhealthy and idle connection. Heartbeat last received at %s", hb)
return
}
logger.Warnf("Deferring closure of unhealthy connection due to %d active connections", count)
}
}
}
}
func (s *localSite) removeRemoteConn(rconn *remoteConn) {
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()
key := connKey{
uuid: rconn.nodeID,
connType: types.TunnelType(rconn.tunnelType),
}
conns := s.remoteConns[key]
for i, conn := range conns {
if conn == rconn {
s.remoteConns[key] = append(conns[:i], conns[i+1:]...)
if len(s.remoteConns[key]) == 0 {
delete(s.remoteConns, key)
}
return
}
}
}
@ -610,36 +658,40 @@ func (s *localSite) getRemoteConn(dreq *sshutils.DialReq) (*remoteConn, error) {
s.remoteConnsMtx.Lock()
defer s.remoteConnsMtx.Unlock()
// Loop over all connections and remove and invalid connections from the
// connection map.
for key, conns := range s.remoteConns {
validConns := conns[:0]
for _, conn := range conns {
if !conn.isInvalid() {
validConns = append(validConns, conn)
}
}
if len(validConns) == 0 {
delete(s.remoteConns, key)
} else {
s.remoteConns[key] = validConns
}
}
key := connKey{
uuid: dreq.ServerID,
connType: dreq.ConnType,
}
if len(s.remoteConns[key]) == 0 {
conns := s.remoteConns[key]
if len(conns) == 0 {
return nil, trace.NotFound("no %v reverse tunnel for %v found", dreq.ConnType, dreq.ServerID)
}
conns := s.remoteConns[key]
// Check the remoteConns from newest to oldest for one
// that has heartbeated and is valid. If none are valid, try
// the newest ready but invalid connection.
var newestInvalidConn *remoteConn
for i := len(conns) - 1; i >= 0; i-- {
if conns[i].isReady() {
switch {
case !conns[i].isReady(): // skip remoteConn that haven't heartbeated yet
continue
case !conns[i].isInvalid(): // return the first valid remoteConn that has heartbeated
return conns[i], nil
case newestInvalidConn == nil && conns[i].isInvalid(): // cache the first invalid remoteConn in case none are valid
newestInvalidConn = conns[i]
}
}
// This indicates that there were no ready and valid connections, but at least
// one ready and invalid connection. We can at least attempt to connect on the
// invalid connection instead of giving up entirely. If anything the error might
// be more informative than the default offline message returned below.
if newestInvalidConn != nil {
return newestInvalidConn, nil
}
// The agent is having issues and there is no way to connect
return nil, trace.NotFound("%v is offline: no active %v tunnels found", dreq.ConnType, dreq.ServerID)
}
@ -650,11 +702,43 @@ func (s *localSite) chanTransportConn(rconn *remoteConn, dreq *sshutils.DialReq)
if err != nil {
if markInvalid {
rconn.markInvalid(err)
// If not serving any connections close and remove this connection immediately.
// Otherwise, let the heartbeat handler detect this connection is down.
if rconn.activeSessions() == 0 {
s.removeRemoteConn(rconn)
return nil, trace.NewAggregate(trace.Wrap(err), rconn.Close())
}
}
return nil, trace.Wrap(err)
}
return conn, nil
return newSessionTrackingConn(rconn, conn), nil
}
// sessionTrackingConn wraps a net.Conn in order
// to maintain the number of active sessions for
// a remoteConn.
type sessionTrackingConn struct {
net.Conn
rc *remoteConn
}
// newSessionTrackingConn wraps the provided net.Conn to alert the remoteConn
// when it is no longer active. Prior to returning the remoteConn active sessions
// are incremented. Close must be called to decrement the count.
func newSessionTrackingConn(rconn *remoteConn, conn net.Conn) *sessionTrackingConn {
rconn.incrementActiveSessions()
return &sessionTrackingConn{
rc: rconn,
Conn: conn,
}
}
// Close decrements the remoteConn active session count and then
// closes the underlying net.Conn
func (c *sessionTrackingConn) Close() error {
c.rc.decrementActiveSessions()
return c.Conn.Close()
}
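
A hypothetical usage sketch of the wrapper (names follow this diff): a caller
obtains a connection through chanTransportConn and simply closes it when the
session ends, which decrements the remoteConn's active-session count and lets
the heartbeat handler reap the connection once it is unhealthy and idle:

// Sketch: dial over the tunnel and release the session count when done.
conn, err := s.chanTransportConn(rconn, dreq)
if err != nil {
	return trace.Wrap(err)
}
// Closing the wrapper decrements rconn's active sessions before closing the
// underlying channel connection.
defer conn.Close()
// ... proxy traffic over conn for the lifetime of the session ...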
// periodicFunctions runs functions periodic functions for the local cluster.
@ -677,7 +761,7 @@ func (s *localSite) periodicFunctions() {
// sshTunnelStats reports SSH tunnel statistics for the cluster.
func (s *localSite) sshTunnelStats() error {
missing := s.srv.NodeWatcher.GetNodes(func(server services.Node) bool {
// Skip over any servers that that have a TTL larger than announce TTL (10
// Skip over any servers that have a TTL larger than announce TTL (10
// minutes) and are non-IoT SSH servers (they won't have tunnels).
//
// Servers with a TTL larger than the announce TTL skipped over to work around

@ -37,15 +37,105 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
func TestRemoteConnCleanup(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
clock := clockwork.NewFakeClock()
watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "test",
Log: utils.NewLoggerForTests(),
Clock: clock,
Client: &mockLocalSiteClient{},
},
ProxiesC: make(chan []types.Server, 2),
})
require.NoError(t, err)
require.NoError(t, watcher.WaitInitialization())
// setup the site
srv := &server{
ctx: ctx,
Config: Config{Clock: clock},
localAuthClient: &mockLocalSiteClient{},
log: utils.NewLoggerForTests(),
offlineThreshold: time.Second,
proxyWatcher: watcher,
}
site, err := newlocalSite(srv, "clustername", nil, withPeriodicFunctionInterval(time.Hour), withProxySyncInterval(time.Hour))
require.NoError(t, err)
// add a connection
rconn := &mockRemoteConnConn{}
sconn := &mockedSSHConn{}
conn1, err := site.addConn(uuid.NewString(), types.NodeTunnel, rconn, sconn)
require.NoError(t, err)
reqs := make(chan *ssh.Request)
// terminated by too many missed heartbeats
go func() {
site.handleHeartbeat(conn1, nil, reqs)
cancel()
}()
// send an initial heartbeat
reqs <- &ssh.Request{Type: "heartbeat"}
// create a fake session
fakeSession := newSessionTrackingConn(conn1, &mockRemoteConnConn{})
// advance the clock to trigger missing a heartbeat, the last advance
// should not force the connection to close since there is still an active session
for i := 0; i <= missedHeartBeatThreshold+1; i++ {
// wait until the heartbeat loop has created the timer
clock.BlockUntil(3) // periodic ticker + heart beat timer + resync ticker = 3
clock.Advance(srv.offlineThreshold)
}
// the fake session should have prevented anything from closing
require.False(t, conn1.closed.Load())
require.False(t, sconn.closed.Load())
// send another heartbeat to reset exceeding the threshold
reqs <- &ssh.Request{Type: "heartbeat"}
// close the fake session
clock.BlockUntil(3) // periodic ticker + heart beat timer + resync ticker = 3
require.NoError(t, fakeSession.Close())
// advance the clock to trigger missing a heartbeat, the last advance
// should force the connection to close since there are no active sessions
for i := 0; i <= missedHeartBeatThreshold; i++ {
// wait until the heartbeat loop has created the timer
clock.BlockUntil(3) // periodic ticker + heart beat timer + resync ticker = 3
clock.Advance(srv.offlineThreshold)
}
// wait for handleHeartbeat to finish
select {
case <-ctx.Done():
case <-time.After(30 * time.Second): // artificially high to prevent flakiness
t.Fatal("LocalSite heart beat handler never terminated")
}
// assert the connections were closed
require.True(t, conn1.closed.Load())
require.True(t, sconn.closed.Load())
}
func TestLocalSiteOverlap(t *testing.T) {
t.Parallel()
ctx := context.Background()
srv := &server{
ctx: ctx,
localAuthClient: &mockLocalSiteClient{},
Config: Config{Clock: clockwork.NewFakeClock()},
ctx: context.Background(),
localAuthClient: &mockLocalSiteClient{},
}
site, err := newlocalSite(srv, "clustername", nil, withPeriodicFunctionInterval(time.Hour))
@ -58,35 +148,76 @@ func TestLocalSiteOverlap(t *testing.T) {
ConnType: connType,
}
conn1, err := site.addConn(nodeID, connType, mockRemoteConnConn{}, nil)
// add a few connections for the same node id
conn1, err := site.addConn(nodeID, connType, &mockRemoteConnConn{}, nil)
require.NoError(t, err)
conn2, err := site.addConn(nodeID, connType, mockRemoteConnConn{}, nil)
conn2, err := site.addConn(nodeID, connType, &mockRemoteConnConn{}, nil)
require.NoError(t, err)
conn3, err := site.addConn(nodeID, connType, &mockRemoteConnConn{}, nil)
require.NoError(t, err)
// no heartbeats from any of them shouldn't return a connection
c, err := site.getRemoteConn(dreq)
require.True(t, trace.IsNotFound(err))
require.Nil(t, c)
// ensure conn1 is ready
conn1.setLastHeartbeat(time.Now())
// getRemoteConn returns the only healthy connection
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn1, c)
// ensure conn2 is ready
conn2.setLastHeartbeat(time.Now())
// getRemoteConn returns the newest healthy connection
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn2, c)
// mark conn2 invalid
conn2.markInvalid(nil)
// getRemoteConn returns the only healthy connection
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn1, c)
// mark conn1 invalid
conn1.markInvalid(nil)
// getRemoteConn returns the only healthy connection
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn2, c)
// remove conn2
site.removeRemoteConn(conn2)
// getRemoteConn returns the only invalid connection
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn1, c)
// remove conn1
site.removeRemoteConn(conn1)
// no ready connections exist
c, err = site.getRemoteConn(dreq)
require.True(t, trace.IsNotFound(err))
require.Nil(t, c)
// mark conn3 as ready
conn3.setLastHeartbeat(time.Now())
// getRemoteConn returns the only healthy connection
c, err = site.getRemoteConn(dreq)
require.NoError(t, err)
require.Equal(t, conn3, c)
}
func TestProxyResync(t *testing.T) {
@ -239,10 +370,16 @@ func (mockLocalSiteClient) NewWatcher(context.Context, types.Watch) (types.Watch
type mockRemoteConnConn struct {
net.Conn
closed atomic.Bool
}
func (c *mockRemoteConnConn) Close() error {
c.closed.Store(true)
return nil
}
// called for logging by (*remoteConn).markInvalid()
func (mockRemoteConnConn) RemoteAddr() net.Addr {
func (*mockRemoteConnConn) RemoteAddr() net.Addr {
return &utils.NetAddr{
Addr: "localhost",
AddrNetwork: "tcp",

@ -391,10 +391,7 @@ func (p *proxyCollector) getResourcesAndUpdateCurrent(ctx context.Context) error
if err != nil {
return trace.Wrap(err)
}
if len(proxies) == 0 {
// At least one proxy ought to exist.
return trace.NotFound("empty proxy list")
}
newCurrent := make(map[string]types.Server, len(proxies))
for _, proxy := range proxies {
newCurrent[proxy.GetName()] = proxy

@ -133,14 +133,6 @@ func TestProxyWatcher(t *testing.T) {
require.NoError(t, err)
t.Cleanup(w.Close)
// Since no proxy is yet present, the ProxyWatcher should immediately
// yield back to its retry loop.
select {
case <-w.ResetC:
case <-time.After(time.Second):
t.Fatal("Timeout waiting for ProxyWatcher reset.")
}
// Add a proxy server.
proxy := newProxyServer(t, "proxy1", "127.0.0.1:2023")
require.NoError(t, presence.UpsertProxy(proxy))