fix tests

Sasha Klizhentas 2017-10-12 16:51:18 -07:00
parent 0290cccb57
commit e461b4e6bd
20 changed files with 348 additions and 75 deletions

View file

@ -418,7 +418,7 @@ func (i *TeleInstance) StartProxy(cfg ProxyConfig) error {
tconf.Token = "token"
tconf.Proxy.Enabled = true
tconf.Proxy.SSHAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.SSHPort))
tconf.Proxy.ReverseTunnelListenAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.SSHPort))
tconf.Proxy.ReverseTunnelListenAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.ReverseTunnelPort))
tconf.Proxy.WebAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.WebPort))
tconf.Proxy.DisableReverseTunnel = false
tconf.Proxy.DisableWebService = true

View file

@ -862,6 +862,7 @@ func (s *IntSuite) TestDiscovery(c *check.C) {
lb, err := utils.NewLoadBalancer(context.TODO(), frontend)
c.Assert(err, check.IsNil)
c.Assert(lb.Listen(), check.IsNil)
go lb.Serve()
defer lb.Close()
remote := NewInstance("cluster-remote", HostID, Host, s.getPorts(5), s.priv, s.pub)
@ -873,7 +874,8 @@ func (s *IntSuite) TestDiscovery(c *check.C) {
c.Assert(main.Create(remote.Secrets.AsSlice(), false, nil), check.IsNil)
mainSecrets := main.Secrets
// switch listen address of the main cluster to load balancer
lb.AddBackend(*utils.MustParseAddr(mainSecrets.ListenAddr))
mainProxyAddr := *utils.MustParseAddr(mainSecrets.ListenAddr)
lb.AddBackend(mainProxyAddr)
mainSecrets.ListenAddr = frontend.String()
c.Assert(remote.Create(mainSecrets.AsSlice(), true, nil), check.IsNil)
@ -892,33 +894,79 @@ func (s *IntSuite) TestDiscovery(c *check.C) {
// start second proxy
nodePorts := s.getPorts(3)
proxyReverseTunnelPort, proxyWebPort, proxySSHPort := nodePorts[0], nodePorts[1], nodePorts[2]
err = main.StartProxy(ProxyConfig{
proxyConfig := ProxyConfig{
Name: "cluster-main-proxy",
SSHPort: proxySSHPort,
WebPort: proxyWebPort,
ReverseTunnelPort: proxyReverseTunnelPort,
})
}
err = main.StartProxy(proxyConfig)
c.Assert(err, check.IsNil)
// add second proxy as a backend to the load balancer
lb.AddBackend(*utils.MustParseAddr(fmt.Sprintf("127.0.0.0:%v", proxyReverseTunnelPort)))
lb.AddBackend(*utils.MustParseAddr(fmt.Sprintf("127.0.0.1:%v", proxyReverseTunnelPort)))
// execute the connection via first proxy
cmd := []string{"echo", "hello world"}
tc, err := main.NewClient(ClientConfig{Login: username, Cluster: "cluster-remote", Host: "127.0.0.1", Port: remote.GetPortSSHInt()})
cfg := ClientConfig{Login: username, Cluster: "cluster-remote", Host: "127.0.0.1", Port: remote.GetPortSSHInt()}
output, err := runCommand(main, []string{"echo", "hello world"}, cfg, 1)
c.Assert(err, check.IsNil)
output := &bytes.Buffer{}
tc.Stdout = output
c.Assert(output, check.Equals, "hello world\n")
// execute the connection via second proxy, should work
cfgProxy := ClientConfig{
Login: username,
Cluster: "cluster-remote",
Host: "127.0.0.1",
Port: remote.GetPortSSHInt(),
Proxy: &proxyConfig,
}
output, err = runCommand(main, []string{"echo", "hello world"}, cfgProxy, 10)
c.Assert(err, check.IsNil)
err = tc.SSH(context.TODO(), cmd, false)
c.Assert(output, check.Equals, "hello world\n")
// now disconnect the main proxy and make sure it will reconnect eventually
lb.RemoveBackend(mainProxyAddr)
// requests going via main proxy will fail
output, err = runCommand(main, []string{"echo", "hello world"}, cfg, 1)
c.Assert(err, check.NotNil)
// requests going via second proxy will succeed
output, err = runCommand(main, []string{"echo", "hello world"}, cfgProxy, 1)
c.Assert(err, check.IsNil)
c.Assert(output.String(), check.Equals, "hello world\n")
c.Assert(output, check.Equals, "hello world\n")
// connect the main proxy back and make sure agents have reconnected over time
lb.AddBackend(mainProxyAddr)
output, err = runCommand(main, []string{"echo", "hello world"}, cfg, 10)
c.Assert(err, check.IsNil)
c.Assert(output, check.Equals, "hello world\n")
// stop cluster and remaining nodes
c.Assert(remote.Stop(true), check.IsNil)
c.Assert(main.Stop(true), check.IsNil)
}
// runCommand is a shortcut for running an SSH command: it creates
// a client connected to the proxy hosted by the instance, retries the
// command for the given number of attempts, and returns the output
func runCommand(instance *TeleInstance, cmd []string, cfg ClientConfig, attempts int) (string, error) {
tc, err := instance.NewClient(cfg)
if err != nil {
return "", trace.Wrap(err)
}
output := &bytes.Buffer{}
tc.Stdout = output
for i := 0; i < attempts; i++ {
err = tc.SSH(context.TODO(), cmd, false)
if err == nil {
break
}
time.Sleep(time.Millisecond * 50)
}
return output.String(), trace.Wrap(err)
}
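Because the diff above interleaves removed and added lines, the resulting call pattern is shown here in one piece; a usage sketch assembled from the test above (the login, command, and attempt count are the ones the test uses):
cfg := ClientConfig{Login: username, Cluster: "cluster-remote", Host: "127.0.0.1", Port: remote.GetPortSSHInt()}
// give the agents up to 10 attempts (50ms apart) to re-establish the tunnel
output, err := runCommand(main, []string{"echo", "hello world"}, cfg, 10)
c.Assert(err, check.IsNil)
c.Assert(output, check.Equals, "hello world\n")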
// getPorts helper returns a range of unallocated ports available for listening on
func (s *IntSuite) getPorts(num int) []int {
if len(s.ports) < num {

View file

@ -64,6 +64,9 @@ type AccessPoint interface {
// UpsertTunnelConnection upserts tunnel connection
UpsertTunnelConnection(conn services.TunnelConnection) error
// DeleteTunnelConnection deletes tunnel connection
DeleteTunnelConnection(clusterName, connName string) error
// GetTunnelConnections returns tunnel connections for a given cluster
GetTunnelConnections(clusterName string) ([]services.TunnelConnection, error)

View file

@ -105,6 +105,7 @@ func NewAPIServer(config *APIConfig) http.Handler {
srv.POST("/:version/tunnelconnections", srv.withAuth(srv.upsertTunnelConnection))
srv.GET("/:version/tunnelconnections/:cluster", srv.withAuth(srv.getTunnelConnections))
srv.GET("/:version/tunnelconnections", srv.withAuth(srv.getAllTunnelConnections))
srv.DELETE("/:version/tunnelconnections/:cluster/:conn", srv.withAuth(srv.deleteTunnelConnection))
srv.DELETE("/:version/tunnelconnections/:cluster", srv.withAuth(srv.deleteTunnelConnections))
srv.DELETE("/:version/tunnelconnections", srv.withAuth(srv.deleteAllTunnelConnections))
@ -1793,6 +1794,15 @@ func (s *APIServer) getAllTunnelConnections(auth ClientI, w http.ResponseWriter,
return items, nil
}
// deleteTunnelConnection deletes tunnel connection by name
func (s *APIServer) deleteTunnelConnection(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) {
err := auth.DeleteTunnelConnection(p.ByName("cluster"), p.ByName("conn"))
if err != nil {
return nil, trace.Wrap(err)
}
return message("ok"), nil
}
// deleteTunnelConnections deletes all tunnel connections for cluster
func (s *APIServer) deleteTunnelConnections(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) {
err := auth.DeleteTunnelConnections(p.ByName("cluster"))

View file

@ -942,6 +942,13 @@ func (a *AuthWithRoles) GetAllTunnelConnections() ([]services.TunnelConnection,
return a.authServer.GetAllTunnelConnections()
}
func (a *AuthWithRoles) DeleteTunnelConnection(clusterName string, connName string) error {
if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbDelete); err != nil {
return trace.Wrap(err)
}
return a.authServer.DeleteTunnelConnection(clusterName, connName)
}
func (a *AuthWithRoles) DeleteTunnelConnections(clusterName string) error {
if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbList); err != nil {
return trace.Wrap(err)

View file

@ -529,6 +529,18 @@ func (c *Client) GetAllTunnelConnections() ([]services.TunnelConnection, error)
return conns, nil
}
// DeleteTunnelConnection deletes tunnel connection by name
func (c *Client) DeleteTunnelConnection(clusterName string, connName string) error {
if clusterName == "" {
return trace.BadParameter("missing parameter cluster name")
}
if connName == "" {
return trace.BadParameter("missing parameter connection name")
}
_, err := c.Delete(c.Endpoint("tunnelconnections", clusterName, connName))
return trace.Wrap(err)
}
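For orientation, a minimal, hypothetical call site for this new client method; an already constructed, authenticated *Client is assumed, and the cluster and connection names mirror the fixtures used elsewhere in this change:
// Hedged sketch: delete the record of one disconnected tunnel agent.
// clt is assumed to be an authenticated *Client.
if err := clt.DeleteTunnelConnection("example.com", "conn1"); err != nil {
    return trace.Wrap(err)
}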
// DeleteTunnelConnections deletes all tunnel connections for cluster
func (c *Client) DeleteTunnelConnections(clusterName string) error {
if clusterName == "" {

View file

@ -83,7 +83,7 @@ func (s *AuthInitSuite) TestReadIdentity(c *C) {
id, err := ReadIdentityFromKeyPair(priv, cert)
c.Assert(err, IsNil)
c.Assert(id.AuthorityDomain, Equals, "example.com")
c.Assert(id.ID, DeepEquals, IdentityID{HostUUID: "id1", Role: teleport.RoleNode})
c.Assert(id.ID, DeepEquals, IdentityID{HostUUID: "id1.example.com", Role: teleport.RoleNode})
c.Assert(id.CertBytes, DeepEquals, cert)
c.Assert(id.KeyBytes, DeepEquals, priv)

View file

@ -858,7 +858,7 @@ func (c *TunClient) GetDialer() AccessPointDialer {
}
time.Sleep(4 * time.Duration(attempt) * dialRetryInterval)
}
c.Error("%v", err)
c.Errorf("%v", err)
return nil, trace.Wrap(err)
}
}

View file

@ -179,7 +179,7 @@ func (s *Suite) TestTTL(c *check.C) {
v, err = s.bk.GetVal(bucket, "key")
c.Assert(trace.IsNotFound(err), check.Equals, true)
c.Assert(err.Error(), check.Equals, `key 'key' is not found`)
c.Assert(err.Error(), check.Equals, `key "key" is not found`)
c.Assert(v, check.IsNil)
}

View file

@ -80,6 +80,8 @@ type AgentConfig struct {
// Clock is a clock passed in tests, if not set wall clock
// will be used
Clock clockwork.Clock
// EventsC is an optional events channel, used for testing purposes
EventsC chan string
}
// CheckAndSetDefaults checks parameters and sets default values
@ -87,9 +89,6 @@ func (a *AgentConfig) CheckAndSetDefaults() error {
if a.Addr.IsEmpty() {
return trace.BadParameter("missing parameter Addr")
}
if a.DiscoveryC == nil {
return trace.BadParameter("missing parameter DiscoveryC")
}
if a.Context == nil {
return trace.BadParameter("missing parameter Context")
}
@ -430,7 +429,7 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
func (a *Agent) run() {
ticker, err := utils.NewSwitchTicker(defaults.FastAttempts, defaults.NetworkRetryDuration, defaults.NetworkBackoffDuration)
if err != nil {
log.Error("failed to run: %v", err)
log.Errorf("failed to run: %v", err)
return
}
defer ticker.Stop()
@ -473,6 +472,15 @@ func (a *Agent) run() {
} else {
a.setState(agentStateConnected)
}
if a.EventsC != nil {
select {
case a.EventsC <- ConnectedEvent:
case <-a.ctx.Done():
a.Debugf("context is closing")
return
default:
}
}
// start heartbeat even if an error happened, it will reconnect
// when this happens, this is #1 issue we have right now with Teleport. So we are making
// it EASY to see in the logs. This condition should never be permanent (repeats
@ -484,6 +492,9 @@ func (a *Agent) run() {
}
}
// ConnectedEvent is used to indicate that reverse tunnel has connected
const ConnectedEvent = "connected"
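The server tests further down in this change use the optional EventsC channel together with this constant to wait for the agent to come up; a condensed sketch of that pattern, with the channel size and timeout taken from those tests:
eventsC := make(chan string, 1)
// the agent is created with AgentConfig{..., EventsC: eventsC} and started
select {
case event := <-eventsC:
    // the agent emits ConnectedEvent once the reverse tunnel is established
    c.Assert(event, Equals, reversetunnel.ConnectedEvent)
case <-time.After(time.Second):
    c.Fatalf("timeout waiting for the reverse tunnel to connect")
}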
// processRequests is a blocking function which runs in a loop sending heartbeats
// to the given SSH connection and processes inbound requests from the
// remote proxy

View file

@ -78,6 +78,18 @@ func (s *remoteSite) connectionCount() int {
return len(s.connections)
}
func (s *remoteSite) hasValidConnections() bool {
s.Lock()
defer s.Unlock()
for _, conn := range s.connections {
if !conn.isInvalid() {
return true
}
}
return false
}
func (s *remoteSite) nextConn() (*remoteConn, error) {
s.Lock()
defer s.Unlock()
@ -153,6 +165,12 @@ func (s *remoteSite) registerHeartbeat(t time.Time) {
}
}
// deleteConnectionRecord deletes the connection record to let peer proxies know
// that this node lost the connection and needs to be discovered
func (s *remoteSite) deleteConnectionRecord() {
s.srv.AccessPoint.DeleteTunnelConnection(s.connInfo.GetClusterName(), s.connInfo.GetName())
}
func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
defer func() {
s.Infof("cluster connection closed")
@ -165,8 +183,12 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
return
case req := <-reqC:
if req == nil {
s.Infof("cluster disconnected")
s.Infof("cluster agent disconnected")
conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
if !s.hasValidConnections() {
s.Debugf("deleting connection record")
s.deleteConnectionRecord()
}
return
}
var timeSent time.Time

View file

@ -420,6 +420,11 @@ func (s *PresenceService) GetAllTunnelConnections() ([]services.TunnelConnection
return conns, nil
}
// DeleteTunnelConnection deletes tunnel connection by name
func (s *PresenceService) DeleteTunnelConnection(clusterName, connectionName string) error {
return s.DeleteKey([]string{tunnelConnectionsPrefix, clusterName}, connectionName)
}
// DeleteTunnelConnections deletes all tunnel connections for cluster
func (s *PresenceService) DeleteTunnelConnections(clusterName string) error {
err := s.DeleteBucket([]string{tunnelConnectionsPrefix}, clusterName)

View file

@ -105,6 +105,9 @@ type Presence interface {
// GetAllTunnelConnections returns all tunnel connections
GetAllTunnelConnections() ([]TunnelConnection, error)
// DeleteTunnelConnection deletes tunnel connection by name
DeleteTunnelConnection(clusterName string, connName string) error
// DeleteTunnelConnections deletes all tunnel connections for cluster
DeleteTunnelConnections(clusterName string) error

View file

@ -600,4 +600,20 @@ func (s *ServicesTestSuite) TunnelConnectionsCRUD(c *C) {
err = s.PresenceS.DeleteAllTunnelConnections()
c.Assert(err, IsNil)
// test delete individual connection
err = s.PresenceS.UpsertTunnelConnection(conn)
c.Assert(err, IsNil)
out, err = s.PresenceS.GetTunnelConnections(clusterName)
c.Assert(err, IsNil)
c.Assert(len(out), Equals, 1)
fixtures.DeepCompare(c, out[0], conn)
err = s.PresenceS.DeleteTunnelConnection(clusterName, conn.GetName())
c.Assert(err, IsNil)
out, err = s.PresenceS.GetTunnelConnections(clusterName)
c.Assert(err, IsNil)
c.Assert(len(out), Equals, 0)
}

View file

@ -87,6 +87,8 @@ func (s *SrvSuite) SetUpSuite(c *C) {
utils.InitLoggerForTests()
}
const hostID = "00000000-0000-0000-0000-000000000000"
func (s *SrvSuite) SetUpTest(c *C) {
var err error
s.dir = c.MkDir()
@ -155,7 +157,7 @@ func (s *SrvSuite) SetUpTest(c *C) {
// set up host private key and certificate
hpriv, hpub, err := s.a.GenerateKeyPair("")
c.Assert(err, IsNil)
hcert, err := s.a.GenerateHostCert(hpub, "00000000-0000-0000-0000-000000000000", s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0)
hcert, err := s.a.GenerateHostCert(hpub, hostID, s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0)
c.Assert(err, IsNil)
// set up user CA and set up a user that has access to the server
@ -469,12 +471,13 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
reverseTunnelPort := s.freePorts[len(s.freePorts)-1]
s.freePorts = s.freePorts[:len(s.freePorts)-1]
reverseTunnelAddress := utils.NetAddr{AddrNetwork: "tcp", Addr: fmt.Sprintf("%v:%v", s.domainName, reverseTunnelPort)}
reverseTunnelServer, err := reversetunnel.NewServer(
reverseTunnelAddress,
[]ssh.Signer{s.signer},
s.roleAuth,
state.NoCache,
)
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ID: s.domainName,
ListenAddr: reverseTunnelAddress,
HostSigners: []ssh.Signer{s.signer},
AccessPoint: s.roleAuth,
NewCachingAccessPoint: state.NoCache,
})
c.Assert(err, IsNil)
c.Assert(reverseTunnelServer.Start(), IsNil)
@ -500,14 +503,14 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
c.Assert(tsrv.Start(), IsNil)
tunClt, err := auth.NewTunClient("test",
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
c.Assert(err, IsNil)
defer tunClt.Close()
agentPool, err := reversetunnel.NewAgentPool(reversetunnel.AgentPoolConfig{
Client: tunClt,
HostSigners: []ssh.Signer{s.signer},
HostUUID: s.domainName,
HostUUID: hostID,
AccessPoint: tunClt,
})
c.Assert(err, IsNil)
@ -519,13 +522,27 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
err = agentPool.FetchAndSyncAgents()
c.Assert(err, IsNil)
rsAgent, err := reversetunnel.NewAgent(
reverseTunnelAddress,
"remote",
"localhost",
[]ssh.Signer{s.signer}, tunClt, tunClt)
eventsC := make(chan string, 1)
rsAgent, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{
Context: context.TODO(),
Addr: reverseTunnelAddress,
RemoteCluster: "remote",
Username: hostID,
Signers: []ssh.Signer{s.signer},
Client: tunClt,
AccessPoint: tunClt,
EventsC: eventsC,
})
c.Assert(err, IsNil)
c.Assert(rsAgent.Start(), IsNil)
rsAgent.Start()
timeout := time.After(time.Second)
select {
case event := <-eventsC:
c.Assert(event, Equals, reversetunnel.ConnectedEvent)
case <-timeout:
c.Fatalf("timeout waiting for clusters to connect")
}
sshConfig := &ssh.ClientConfig{
User: s.user,
@ -620,12 +637,13 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) {
AddrNetwork: "tcp",
Addr: fmt.Sprintf("%v:%v", s.domainName, reverseTunnelPort),
}
reverseTunnelServer, err := reversetunnel.NewServer(
reverseTunnelAddress,
[]ssh.Signer{s.signer},
s.roleAuth,
state.NoCache,
)
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ID: s.domainName,
ListenAddr: reverseTunnelAddress,
HostSigners: []ssh.Signer{s.signer},
AccessPoint: s.roleAuth,
NewCachingAccessPoint: state.NoCache,
})
c.Assert(err, IsNil)
c.Assert(reverseTunnelServer.Start(), IsNil)
@ -652,28 +670,49 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) {
c.Assert(tsrv.Start(), IsNil)
tunClt, err := auth.NewTunClient("test",
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
c.Assert(err, IsNil)
defer tunClt.Close()
// start agent and load balance requests
rsAgent, err := reversetunnel.NewAgent(
reverseTunnelAddress,
"remote",
"localhost",
[]ssh.Signer{s.signer}, tunClt, tunClt)
eventsC := make(chan string, 2)
rsAgent, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{
Context: context.TODO(),
Addr: reverseTunnelAddress,
RemoteCluster: "remote",
Username: hostID,
Signers: []ssh.Signer{s.signer},
Client: tunClt,
AccessPoint: tunClt,
EventsC: eventsC,
})
c.Assert(err, IsNil)
c.Assert(rsAgent.Start(), IsNil)
rsAgent.Start()
rsAgent2, err := reversetunnel.NewAgent(
reverseTunnelAddress,
"remote",
"localhost",
[]ssh.Signer{s.signer}, tunClt, tunClt)
rsAgent2, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{
Context: context.TODO(),
Addr: reverseTunnelAddress,
RemoteCluster: "remote",
Username: hostID,
Signers: []ssh.Signer{s.signer},
Client: tunClt,
AccessPoint: tunClt,
EventsC: eventsC,
})
c.Assert(err, IsNil)
c.Assert(rsAgent2.Start(), IsNil)
rsAgent2.Start()
defer rsAgent2.Close()
timeout := time.After(time.Second)
for i := 0; i < 2; i++ {
select {
case event := <-eventsC:
c.Assert(event, Equals, reversetunnel.ConnectedEvent)
case <-timeout:
c.Fatalf("timeout waiting for clusters to connect")
}
}
sshConfig := &ssh.ClientConfig{
User: s.user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(up.certSigner)},
@ -700,13 +739,14 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) {
AddrNetwork: "tcp",
Addr: fmt.Sprintf("%v:0", s.domainName),
}
reverseTunnelServer, err := reversetunnel.NewServer(
reverseTunnelAddress,
[]ssh.Signer{s.signer},
s.roleAuth,
state.NoCache,
reversetunnel.DirectSite(s.domainName, s.roleAuth),
)
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ID: s.domainName,
ListenAddr: reverseTunnelAddress,
HostSigners: []ssh.Signer{s.signer},
AccessPoint: s.roleAuth,
NewCachingAccessPoint: state.NoCache,
DirectClusters: []reversetunnel.DirectCluster{{Name: s.domainName, Client: s.roleAuth}},
})
c.Assert(err, IsNil)
proxy, err := New(
@ -731,7 +771,7 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) {
c.Assert(tsrv.Start(), IsNil)
tunClt, err := auth.NewTunClient("test",
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
c.Assert(err, IsNil)
defer tunClt.Close()

View file

@ -499,11 +499,16 @@ func (cs *CachingAuthClient) UpsertTunnelConnection(conn services.TunnelConnecti
return cs.ap.UpsertTunnelConnection(conn)
}
// DeleteTunnelConnection is a part of auth.AccessPoint implementation
func (cs *CachingAuthClient) DeleteTunnelConnection(clusterName, connName string) error {
return cs.ap.DeleteTunnelConnection(clusterName, connName)
}
// try calls a given function f and checks for errors. If f() fails, the current
// time is recorded. Future calls to f will be ignored until sufficient time passes
// since the last error
func (cs *CachingAuthClient) try(f func() error) error {
tooSoon := cs.lastErrorTime.Add(defaults.NetworkBackoffDuration).After(time.Now())
tooSoon := cs.lastErrorTime.Add(defaults.NetworkRetryDuration).After(time.Now())
if tooSoon {
cs.Warnf("backoff: using cached value due to recent errors")
return trace.ConnectionProblem(fmt.Errorf("backoff"), "backing off due to recent errors")

View file

@ -91,7 +91,6 @@ var (
TunnelConnections = []services.TunnelConnection{
services.MustCreateTunnelConnection("conn1", services.TunnelConnectionSpecV2{
ClusterName: "example.com",
ProxyAddr: "localhost:3025",
ProxyName: "p1",
LastHeartbeat: time.Date(2015, 6, 5, 4, 3, 2, 1, time.UTC).UTC(),
}),
@ -234,7 +233,7 @@ func (s *ClusterSnapshotSuite) TestTry(c *check.C) {
c.Assert(failedCalls, check.Equals, 1)
// "wait" for backoff duration and try again:
ap.lastErrorTime = time.Now().Add(-backoffDuration)
ap.lastErrorTime = time.Now().Add(-defaults.NetworkBackoffDuration)
ap.try(success)
ap.try(failure)

View file

@ -44,6 +44,7 @@ func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends ...NetAddr)
"listen": frontend.String(),
},
}),
connections: make(map[NetAddr]map[int64]net.Conn),
}, nil
}
@ -51,6 +52,7 @@ func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends ...NetAddr)
// balancer used in tests.
type LoadBalancer struct {
sync.RWMutex
connID int64
*log.Entry
frontend NetAddr
backends []NetAddr
@ -58,6 +60,41 @@ type LoadBalancer struct {
currentIndex int
listener net.Listener
listenerClosed bool
connections map[NetAddr]map[int64]net.Conn
}
// trackConnection adds a connection to the connection tracker
func (l *LoadBalancer) trackConnection(backend NetAddr, conn net.Conn) int64 {
l.Lock()
defer l.Unlock()
l.connID += 1
tracker, ok := l.connections[backend]
if !ok {
tracker = make(map[int64]net.Conn)
l.connections[backend] = tracker
}
tracker[l.connID] = conn
return l.connID
}
// untrackConnection removes connection from connection tracker
func (l *LoadBalancer) untrackConnection(backend NetAddr, id int64) {
l.Lock()
defer l.Unlock()
tracker, ok := l.connections[backend]
if !ok {
return
}
delete(tracker, id)
}
// dropConnections drops connections associated with backend
func (l *LoadBalancer) dropConnections(backend NetAddr) {
tracker := l.connections[backend]
for _, conn := range tracker {
conn.Close()
}
delete(l.connections, backend)
}
// AddBackend adds backend
@ -76,10 +113,10 @@ func (l *LoadBalancer) RemoveBackend(b NetAddr) {
for i := range l.backends {
if l.backends[i].Equals(b) {
l.backends = append(l.backends[:i], l.backends[i+1:]...)
l.dropConnections(b)
return
}
}
}
func (l *LoadBalancer) nextBackend() (*NetAddr, error) {
@ -171,12 +208,18 @@ func (l *LoadBalancer) forward(conn net.Conn) error {
return trace.Wrap(err)
}
connID := l.trackConnection(*backend, conn)
defer l.untrackConnection(*backend, connID)
backendConn, err := net.Dial(backend.AddrNetwork, backend.Addr)
if err != nil {
return trace.ConvertSystemError(err)
}
defer backendConn.Close()
backendConnID := l.trackConnection(*backend, backendConn)
defer l.untrackConnection(*backend, backendConnID)
logger := l.WithFields(log.Fields{
"source": conn.RemoteAddr(),
"dest": backendConn.RemoteAddr(),

View file

@ -172,6 +172,44 @@ func (s *LBSuite) TestClose(c *check.C) {
c.Assert(err, check.NotNil)
}
func (s *LBSuite) TestDropConnections(c *check.C) {
backend1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "backend 1")
}))
defer backend1.Close()
ports, err := GetFreeTCPPorts(1)
c.Assert(err, check.IsNil)
frontend := localAddr(ports[0])
backendAddr := urlToNetAddr(backend1.URL)
lb, err := NewLoadBalancer(context.TODO(), frontend, backendAddr)
c.Assert(err, check.IsNil)
err = lb.Listen()
c.Assert(err, check.IsNil)
go lb.Serve()
defer lb.Close()
conn, err := net.Dial("tcp", frontend.String())
c.Assert(err, check.IsNil)
defer conn.Close()
out, err := roundtripWithConn(conn)
c.Assert(err, check.IsNil)
c.Assert(out, check.Equals, "backend 1")
// to make sure multiple requests work on the same wire
out, err = roundtripWithConn(conn)
c.Assert(err, check.IsNil)
c.Assert(out, check.Equals, "backend 1")
// removing backend results in dropped connection to this backend
lb.RemoveBackend(backendAddr)
out, err = roundtripWithConn(conn)
c.Assert(err, check.NotNil)
}
func urlToNetAddr(u string) NetAddr {
parsed, err := url.Parse(u)
if err != nil {
@ -196,7 +234,15 @@ func roundtrip(addr string) (string, error) {
return "", err
}
defer conn.Close()
fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
return roundtripWithConn(conn)
}
// roundtripWithConn performs an HTTP GET over the existing connection
func roundtripWithConn(conn net.Conn) (string, error) {
_, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
if err != nil {
return "", err
}
re, err := http.ReadResponse(bufio.NewReader(conn), nil)
if err != nil {

View file

@ -73,6 +73,8 @@ import (
kyaml "k8s.io/apimachinery/pkg/util/yaml"
)
const hostID = "00000000-0000-0000-0000-000000000000"
func TestWeb(t *testing.T) {
TestingT(t)
}
@ -241,7 +243,7 @@ func (s *WebSuite) SetUpTest(c *C) {
hpriv, hpub, err := s.authServer.GenerateKeyPair("")
c.Assert(err, IsNil)
hcert, err := s.authServer.GenerateHostCert(
hpub, "00000000-0000-0000-0000-000000000000", s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0)
hpub, hostID, s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0)
c.Assert(err, IsNil)
// set up user CA and set up a user that has access to the server
@ -274,16 +276,17 @@ func (s *WebSuite) SetUpTest(c *C) {
c.Assert(s.node.Start(), IsNil)
// create reverse tunnel service:
revTunServer, err := reversetunnel.NewServer(
utils.NetAddr{
revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{
ID: node.ID(),
ListenAddr: utils.NetAddr{
AddrNetwork: "tcp",
Addr: fmt.Sprintf("%v:0", s.domainName),
},
[]ssh.Signer{s.signer},
s.roleAuth,
state.NoCache,
reversetunnel.DirectSite(s.domainName, s.roleAuth),
)
HostSigners: []ssh.Signer{s.signer},
AccessPoint: s.roleAuth,
NewCachingAccessPoint: state.NoCache,
DirectClusters: []reversetunnel.DirectCluster{{Name: s.domainName, Client: s.roleAuth}},
})
c.Assert(err, IsNil)
apiPort := s.freePorts[len(s.freePorts)-1]
@ -307,7 +310,7 @@ func (s *WebSuite) SetUpTest(c *C) {
// create a tun client
tunClient, err := auth.NewTunClient("test", []utils.NetAddr{tunAddr},
s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
c.Assert(err, IsNil)
// proxy server: