mirror of
https://github.com/gravitational/teleport
synced 2024-10-20 17:23:22 +00:00
fix tests
This commit is contained in:
parent
0290cccb57
commit
e461b4e6bd
|
@ -418,7 +418,7 @@ func (i *TeleInstance) StartProxy(cfg ProxyConfig) error {
|
|||
tconf.Token = "token"
|
||||
tconf.Proxy.Enabled = true
|
||||
tconf.Proxy.SSHAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.SSHPort))
|
||||
tconf.Proxy.ReverseTunnelListenAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.SSHPort))
|
||||
tconf.Proxy.ReverseTunnelListenAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.ReverseTunnelPort))
|
||||
tconf.Proxy.WebAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.WebPort))
|
||||
tconf.Proxy.DisableReverseTunnel = false
|
||||
tconf.Proxy.DisableWebService = true
|
||||
|
|
|
@ -862,6 +862,7 @@ func (s *IntSuite) TestDiscovery(c *check.C) {
|
|||
lb, err := utils.NewLoadBalancer(context.TODO(), frontend)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(lb.Listen(), check.IsNil)
|
||||
go lb.Serve()
|
||||
defer lb.Close()
|
||||
|
||||
remote := NewInstance("cluster-remote", HostID, Host, s.getPorts(5), s.priv, s.pub)
|
||||
|
@ -873,7 +874,8 @@ func (s *IntSuite) TestDiscovery(c *check.C) {
|
|||
c.Assert(main.Create(remote.Secrets.AsSlice(), false, nil), check.IsNil)
|
||||
mainSecrets := main.Secrets
|
||||
// switch listen address of the main cluster to load balancer
|
||||
lb.AddBackend(*utils.MustParseAddr(mainSecrets.ListenAddr))
|
||||
mainProxyAddr := *utils.MustParseAddr(mainSecrets.ListenAddr)
|
||||
lb.AddBackend(mainProxyAddr)
|
||||
mainSecrets.ListenAddr = frontend.String()
|
||||
c.Assert(remote.Create(mainSecrets.AsSlice(), true, nil), check.IsNil)
|
||||
|
||||
|
@ -892,33 +894,79 @@ func (s *IntSuite) TestDiscovery(c *check.C) {
|
|||
// start second proxy
|
||||
nodePorts := s.getPorts(3)
|
||||
proxyReverseTunnelPort, proxyWebPort, proxySSHPort := nodePorts[0], nodePorts[1], nodePorts[2]
|
||||
err = main.StartProxy(ProxyConfig{
|
||||
proxyConfig := ProxyConfig{
|
||||
Name: "cluster-main-proxy",
|
||||
SSHPort: proxySSHPort,
|
||||
WebPort: proxyWebPort,
|
||||
ReverseTunnelPort: proxyReverseTunnelPort,
|
||||
})
|
||||
}
|
||||
err = main.StartProxy(proxyConfig)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// add second proxy as a backend to the load balancer
|
||||
lb.AddBackend(*utils.MustParseAddr(fmt.Sprintf("127.0.0.0:%v", proxyReverseTunnelPort)))
|
||||
lb.AddBackend(*utils.MustParseAddr(fmt.Sprintf("127.0.0.1:%v", proxyReverseTunnelPort)))
|
||||
|
||||
// execute the connection via first proxy
|
||||
cmd := []string{"echo", "hello world"}
|
||||
tc, err := main.NewClient(ClientConfig{Login: username, Cluster: "cluster-remote", Host: "127.0.0.1", Port: remote.GetPortSSHInt()})
|
||||
cfg := ClientConfig{Login: username, Cluster: "cluster-remote", Host: "127.0.0.1", Port: remote.GetPortSSHInt()}
|
||||
output, err := runCommand(main, []string{"echo", "hello world"}, cfg, 1)
|
||||
c.Assert(err, check.IsNil)
|
||||
output := &bytes.Buffer{}
|
||||
tc.Stdout = output
|
||||
c.Assert(output, check.Equals, "hello world\n")
|
||||
|
||||
// execute the connection via second proxy, should work
|
||||
cfgProxy := ClientConfig{
|
||||
Login: username,
|
||||
Cluster: "cluster-remote",
|
||||
Host: "127.0.0.1",
|
||||
Port: remote.GetPortSSHInt(),
|
||||
Proxy: &proxyConfig,
|
||||
}
|
||||
output, err = runCommand(main, []string{"echo", "hello world"}, cfgProxy, 10)
|
||||
c.Assert(err, check.IsNil)
|
||||
err = tc.SSH(context.TODO(), cmd, false)
|
||||
c.Assert(output, check.Equals, "hello world\n")
|
||||
|
||||
// now disconnect the main proxy and make sure it will reconnect eventually
|
||||
lb.RemoveBackend(mainProxyAddr)
|
||||
|
||||
// requests going via main proxy will fail
|
||||
output, err = runCommand(main, []string{"echo", "hello world"}, cfg, 1)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
// requests going via second proxy will succeed
|
||||
output, err = runCommand(main, []string{"echo", "hello world"}, cfgProxy, 1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(output.String(), check.Equals, "hello world\n")
|
||||
c.Assert(output, check.Equals, "hello world\n")
|
||||
|
||||
// connect the main proxy back and make sure agents have reconnected over time
|
||||
lb.AddBackend(mainProxyAddr)
|
||||
output, err = runCommand(main, []string{"echo", "hello world"}, cfg, 10)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(output, check.Equals, "hello world\n")
|
||||
|
||||
// stop cluster and remaining nodes
|
||||
c.Assert(remote.Stop(true), check.IsNil)
|
||||
c.Assert(main.Stop(true), check.IsNil)
|
||||
}
|
||||
|
||||
// runCommand is a shortcut for running SSH command, it creates
|
||||
// a client connected to proxy hosted by instance
|
||||
// and returns the result
|
||||
func runCommand(instance *TeleInstance, cmd []string, cfg ClientConfig, attempts int) (string, error) {
|
||||
tc, err := instance.NewClient(cfg)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
output := &bytes.Buffer{}
|
||||
tc.Stdout = output
|
||||
for i := 0; i < attempts; i++ {
|
||||
err = tc.SSH(context.TODO(), cmd, false)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
}
|
||||
return output.String(), trace.Wrap(err)
|
||||
}
|
||||
|
||||
// getPorts helper returns a range of unallocated ports available for litening on
|
||||
func (s *IntSuite) getPorts(num int) []int {
|
||||
if len(s.ports) < num {
|
||||
|
|
|
@ -64,6 +64,9 @@ type AccessPoint interface {
|
|||
// UpsertTunnelConnection upserts tunnel connection
|
||||
UpsertTunnelConnection(conn services.TunnelConnection) error
|
||||
|
||||
// DeleteTunnelConnection deletes tunnel connection
|
||||
DeleteTunnelConnection(clusterName, connName string) error
|
||||
|
||||
// GetTunnelConnections returns tunnel connections for a given cluster
|
||||
GetTunnelConnections(clusterName string) ([]services.TunnelConnection, error)
|
||||
|
||||
|
|
|
@ -105,6 +105,7 @@ func NewAPIServer(config *APIConfig) http.Handler {
|
|||
srv.POST("/:version/tunnelconnections", srv.withAuth(srv.upsertTunnelConnection))
|
||||
srv.GET("/:version/tunnelconnections/:cluster", srv.withAuth(srv.getTunnelConnections))
|
||||
srv.GET("/:version/tunnelconnections", srv.withAuth(srv.getAllTunnelConnections))
|
||||
srv.DELETE("/:version/tunnelconnections/:cluster/:conn", srv.withAuth(srv.deleteTunnelConnection))
|
||||
srv.DELETE("/:version/tunnelconnections/:cluster", srv.withAuth(srv.deleteTunnelConnections))
|
||||
srv.DELETE("/:version/tunnelconnections", srv.withAuth(srv.deleteAllTunnelConnections))
|
||||
|
||||
|
@ -1793,6 +1794,15 @@ func (s *APIServer) getAllTunnelConnections(auth ClientI, w http.ResponseWriter,
|
|||
return items, nil
|
||||
}
|
||||
|
||||
// deleteTunnelConnection deletes tunnel connection by name
|
||||
func (s *APIServer) deleteTunnelConnection(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) {
|
||||
err := auth.DeleteTunnelConnection(p.ByName("cluster"), p.ByName("conn"))
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return message("ok"), nil
|
||||
}
|
||||
|
||||
// deleteTunnelConnections deletes all tunnel connections for cluster
|
||||
func (s *APIServer) deleteTunnelConnections(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) {
|
||||
err := auth.DeleteTunnelConnections(p.ByName("cluster"))
|
||||
|
|
|
@ -942,6 +942,13 @@ func (a *AuthWithRoles) GetAllTunnelConnections() ([]services.TunnelConnection,
|
|||
return a.authServer.GetAllTunnelConnections()
|
||||
}
|
||||
|
||||
func (a *AuthWithRoles) DeleteTunnelConnection(clusterName string, connName string) error {
|
||||
if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbDelete); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return a.authServer.DeleteTunnelConnection(clusterName, connName)
|
||||
}
|
||||
|
||||
func (a *AuthWithRoles) DeleteTunnelConnections(clusterName string) error {
|
||||
if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbList); err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
|
|
@ -529,6 +529,18 @@ func (c *Client) GetAllTunnelConnections() ([]services.TunnelConnection, error)
|
|||
return conns, nil
|
||||
}
|
||||
|
||||
// DeleteTunnelConnection deletes tunnel connection by name
|
||||
func (c *Client) DeleteTunnelConnection(clusterName string, connName string) error {
|
||||
if clusterName == "" {
|
||||
return trace.BadParameter("missing parameter cluster name")
|
||||
}
|
||||
if connName == "" {
|
||||
return trace.BadParameter("missing parameter connection name")
|
||||
}
|
||||
_, err := c.Delete(c.Endpoint("tunnelconnections", clusterName, connName))
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// DeleteTunnelConnections deletes all tunnel connections for cluster
|
||||
func (c *Client) DeleteTunnelConnections(clusterName string) error {
|
||||
if clusterName == "" {
|
||||
|
|
|
@ -83,7 +83,7 @@ func (s *AuthInitSuite) TestReadIdentity(c *C) {
|
|||
id, err := ReadIdentityFromKeyPair(priv, cert)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(id.AuthorityDomain, Equals, "example.com")
|
||||
c.Assert(id.ID, DeepEquals, IdentityID{HostUUID: "id1", Role: teleport.RoleNode})
|
||||
c.Assert(id.ID, DeepEquals, IdentityID{HostUUID: "id1.example.com", Role: teleport.RoleNode})
|
||||
c.Assert(id.CertBytes, DeepEquals, cert)
|
||||
c.Assert(id.KeyBytes, DeepEquals, priv)
|
||||
|
||||
|
|
|
@ -858,7 +858,7 @@ func (c *TunClient) GetDialer() AccessPointDialer {
|
|||
}
|
||||
time.Sleep(4 * time.Duration(attempt) * dialRetryInterval)
|
||||
}
|
||||
c.Error("%v", err)
|
||||
c.Errorf("%v", err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -179,7 +179,7 @@ func (s *Suite) TestTTL(c *check.C) {
|
|||
|
||||
v, err = s.bk.GetVal(bucket, "key")
|
||||
c.Assert(trace.IsNotFound(err), check.Equals, true)
|
||||
c.Assert(err.Error(), check.Equals, `key 'key' is not found`)
|
||||
c.Assert(err.Error(), check.Equals, `key "key" is not found`)
|
||||
c.Assert(v, check.IsNil)
|
||||
}
|
||||
|
||||
|
|
|
@ -80,6 +80,8 @@ type AgentConfig struct {
|
|||
// Clock is a clock passed in tests, if not set wall clock
|
||||
// will be used
|
||||
Clock clockwork.Clock
|
||||
// EventsC is an optional events channel, used for testing purposes
|
||||
EventsC chan string
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults checks parameters and sets default values
|
||||
|
@ -87,9 +89,6 @@ func (a *AgentConfig) CheckAndSetDefaults() error {
|
|||
if a.Addr.IsEmpty() {
|
||||
return trace.BadParameter("missing parameter Addr")
|
||||
}
|
||||
if a.DiscoveryC == nil {
|
||||
return trace.BadParameter("missing parameter DiscoveryC")
|
||||
}
|
||||
if a.Context == nil {
|
||||
return trace.BadParameter("missing parameter Context")
|
||||
}
|
||||
|
@ -430,7 +429,7 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
|
|||
func (a *Agent) run() {
|
||||
ticker, err := utils.NewSwitchTicker(defaults.FastAttempts, defaults.NetworkRetryDuration, defaults.NetworkBackoffDuration)
|
||||
if err != nil {
|
||||
log.Error("failed to run: %v", err)
|
||||
log.Errorf("failed to run: %v", err)
|
||||
return
|
||||
}
|
||||
defer ticker.Stop()
|
||||
|
@ -473,6 +472,15 @@ func (a *Agent) run() {
|
|||
} else {
|
||||
a.setState(agentStateConnected)
|
||||
}
|
||||
if a.EventsC != nil {
|
||||
select {
|
||||
case a.EventsC <- ConnectedEvent:
|
||||
case <-a.ctx.Done():
|
||||
a.Debugf("context is closing")
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
// start heartbeat even if error happend, it will reconnect
|
||||
// when this happens, this is #1 issue we have right now with Teleport. So we are making
|
||||
// it EASY to see in the logs. This condition should never be permanent (repeates
|
||||
|
@ -484,6 +492,9 @@ func (a *Agent) run() {
|
|||
}
|
||||
}
|
||||
|
||||
// ConnectedEvent is used to indicate that reverse tunnel has connected
|
||||
const ConnectedEvent = "connected"
|
||||
|
||||
// processRequests is a blocking function which runs in a loop sending heartbeats
|
||||
// to the given SSH connection and processes inbound requests from the
|
||||
// remote proxy
|
||||
|
|
|
@ -78,6 +78,18 @@ func (s *remoteSite) connectionCount() int {
|
|||
return len(s.connections)
|
||||
}
|
||||
|
||||
func (s *remoteSite) hasValidConnections() bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for _, conn := range s.connections {
|
||||
if !conn.isInvalid() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *remoteSite) nextConn() (*remoteConn, error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
@ -153,6 +165,12 @@ func (s *remoteSite) registerHeartbeat(t time.Time) {
|
|||
}
|
||||
}
|
||||
|
||||
// deleteConnectionRecord deletes connection record to let know peer proxies
|
||||
// that this node lost the connection and needs to be discovered
|
||||
func (s *remoteSite) deleteConnectionRecord() {
|
||||
s.srv.AccessPoint.DeleteTunnelConnection(s.connInfo.GetClusterName(), s.connInfo.GetName())
|
||||
}
|
||||
|
||||
func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
|
||||
defer func() {
|
||||
s.Infof("cluster connection closed")
|
||||
|
@ -165,8 +183,12 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
|
|||
return
|
||||
case req := <-reqC:
|
||||
if req == nil {
|
||||
s.Infof("cluster disconnected")
|
||||
s.Infof("cluster agent disconnected")
|
||||
conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
|
||||
if !s.hasValidConnections() {
|
||||
s.Debugf("deleting connection record")
|
||||
s.deleteConnectionRecord()
|
||||
}
|
||||
return
|
||||
}
|
||||
var timeSent time.Time
|
||||
|
|
|
@ -420,6 +420,11 @@ func (s *PresenceService) GetAllTunnelConnections() ([]services.TunnelConnection
|
|||
return conns, nil
|
||||
}
|
||||
|
||||
// DeleteTunnelConnection deletes tunnel connection by name
|
||||
func (s *PresenceService) DeleteTunnelConnection(clusterName, connectionName string) error {
|
||||
return s.DeleteKey([]string{tunnelConnectionsPrefix, clusterName}, connectionName)
|
||||
}
|
||||
|
||||
// DeleteTunnelConnections deletes all tunnel connections for cluster
|
||||
func (s *PresenceService) DeleteTunnelConnections(clusterName string) error {
|
||||
err := s.DeleteBucket([]string{tunnelConnectionsPrefix}, clusterName)
|
||||
|
|
|
@ -105,6 +105,9 @@ type Presence interface {
|
|||
// GetAllTunnelConnections returns all tunnel connections
|
||||
GetAllTunnelConnections() ([]TunnelConnection, error)
|
||||
|
||||
// DeleteTunnelConnection deletes tunnel connection by name
|
||||
DeleteTunnelConnection(clusterName string, connName string) error
|
||||
|
||||
// DeleteTunnelConnections deletes all tunnel connections for cluster
|
||||
DeleteTunnelConnections(clusterName string) error
|
||||
|
||||
|
|
|
@ -600,4 +600,20 @@ func (s *ServicesTestSuite) TunnelConnectionsCRUD(c *C) {
|
|||
|
||||
err = s.PresenceS.DeleteAllTunnelConnections()
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// test delete individual connection
|
||||
err = s.PresenceS.UpsertTunnelConnection(conn)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
out, err = s.PresenceS.GetTunnelConnections(clusterName)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(len(out), Equals, 1)
|
||||
fixtures.DeepCompare(c, out[0], conn)
|
||||
|
||||
err = s.PresenceS.DeleteTunnelConnection(clusterName, conn.GetName())
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
out, err = s.PresenceS.GetTunnelConnections(clusterName)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(len(out), Equals, 0)
|
||||
}
|
||||
|
|
|
@ -87,6 +87,8 @@ func (s *SrvSuite) SetUpSuite(c *C) {
|
|||
utils.InitLoggerForTests()
|
||||
}
|
||||
|
||||
const hostID = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
func (s *SrvSuite) SetUpTest(c *C) {
|
||||
var err error
|
||||
s.dir = c.MkDir()
|
||||
|
@ -155,7 +157,7 @@ func (s *SrvSuite) SetUpTest(c *C) {
|
|||
// set up host private key and certificate
|
||||
hpriv, hpub, err := s.a.GenerateKeyPair("")
|
||||
c.Assert(err, IsNil)
|
||||
hcert, err := s.a.GenerateHostCert(hpub, "00000000-0000-0000-0000-000000000000", s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0)
|
||||
hcert, err := s.a.GenerateHostCert(hpub, hostID, s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// set up user CA and set up a user that has access to the server
|
||||
|
@ -469,12 +471,13 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
|
|||
reverseTunnelPort := s.freePorts[len(s.freePorts)-1]
|
||||
s.freePorts = s.freePorts[:len(s.freePorts)-1]
|
||||
reverseTunnelAddress := utils.NetAddr{AddrNetwork: "tcp", Addr: fmt.Sprintf("%v:%v", s.domainName, reverseTunnelPort)}
|
||||
reverseTunnelServer, err := reversetunnel.NewServer(
|
||||
reverseTunnelAddress,
|
||||
[]ssh.Signer{s.signer},
|
||||
s.roleAuth,
|
||||
state.NoCache,
|
||||
)
|
||||
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
|
||||
ID: s.domainName,
|
||||
ListenAddr: reverseTunnelAddress,
|
||||
HostSigners: []ssh.Signer{s.signer},
|
||||
AccessPoint: s.roleAuth,
|
||||
NewCachingAccessPoint: state.NoCache,
|
||||
})
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(reverseTunnelServer.Start(), IsNil)
|
||||
|
||||
|
@ -500,14 +503,14 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
|
|||
c.Assert(tsrv.Start(), IsNil)
|
||||
|
||||
tunClt, err := auth.NewTunClient("test",
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
c.Assert(err, IsNil)
|
||||
defer tunClt.Close()
|
||||
|
||||
agentPool, err := reversetunnel.NewAgentPool(reversetunnel.AgentPoolConfig{
|
||||
Client: tunClt,
|
||||
HostSigners: []ssh.Signer{s.signer},
|
||||
HostUUID: s.domainName,
|
||||
HostUUID: hostID,
|
||||
AccessPoint: tunClt,
|
||||
})
|
||||
c.Assert(err, IsNil)
|
||||
|
@ -519,13 +522,27 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
|
|||
err = agentPool.FetchAndSyncAgents()
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
rsAgent, err := reversetunnel.NewAgent(
|
||||
reverseTunnelAddress,
|
||||
"remote",
|
||||
"localhost",
|
||||
[]ssh.Signer{s.signer}, tunClt, tunClt)
|
||||
eventsC := make(chan string, 1)
|
||||
rsAgent, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{
|
||||
Context: context.TODO(),
|
||||
Addr: reverseTunnelAddress,
|
||||
RemoteCluster: "remote",
|
||||
Username: hostID,
|
||||
Signers: []ssh.Signer{s.signer},
|
||||
Client: tunClt,
|
||||
AccessPoint: tunClt,
|
||||
EventsC: eventsC,
|
||||
})
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rsAgent.Start(), IsNil)
|
||||
rsAgent.Start()
|
||||
|
||||
timeout := time.After(time.Second)
|
||||
select {
|
||||
case event := <-eventsC:
|
||||
c.Assert(event, Equals, reversetunnel.ConnectedEvent)
|
||||
case <-timeout:
|
||||
c.Fatalf("timeout waiting for clusters to connect")
|
||||
}
|
||||
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: s.user,
|
||||
|
@ -620,12 +637,13 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) {
|
|||
AddrNetwork: "tcp",
|
||||
Addr: fmt.Sprintf("%v:%v", s.domainName, reverseTunnelPort),
|
||||
}
|
||||
reverseTunnelServer, err := reversetunnel.NewServer(
|
||||
reverseTunnelAddress,
|
||||
[]ssh.Signer{s.signer},
|
||||
s.roleAuth,
|
||||
state.NoCache,
|
||||
)
|
||||
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
|
||||
ID: s.domainName,
|
||||
ListenAddr: reverseTunnelAddress,
|
||||
HostSigners: []ssh.Signer{s.signer},
|
||||
AccessPoint: s.roleAuth,
|
||||
NewCachingAccessPoint: state.NoCache,
|
||||
})
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
c.Assert(reverseTunnelServer.Start(), IsNil)
|
||||
|
@ -652,28 +670,49 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) {
|
|||
c.Assert(tsrv.Start(), IsNil)
|
||||
|
||||
tunClt, err := auth.NewTunClient("test",
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
c.Assert(err, IsNil)
|
||||
defer tunClt.Close()
|
||||
|
||||
// start agent and load balance requests
|
||||
rsAgent, err := reversetunnel.NewAgent(
|
||||
reverseTunnelAddress,
|
||||
"remote",
|
||||
"localhost",
|
||||
[]ssh.Signer{s.signer}, tunClt, tunClt)
|
||||
eventsC := make(chan string, 2)
|
||||
rsAgent, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{
|
||||
Context: context.TODO(),
|
||||
Addr: reverseTunnelAddress,
|
||||
RemoteCluster: "remote",
|
||||
Username: hostID,
|
||||
Signers: []ssh.Signer{s.signer},
|
||||
Client: tunClt,
|
||||
AccessPoint: tunClt,
|
||||
EventsC: eventsC,
|
||||
})
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rsAgent.Start(), IsNil)
|
||||
rsAgent.Start()
|
||||
|
||||
rsAgent2, err := reversetunnel.NewAgent(
|
||||
reverseTunnelAddress,
|
||||
"remote",
|
||||
"localhost",
|
||||
[]ssh.Signer{s.signer}, tunClt, tunClt)
|
||||
rsAgent2, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{
|
||||
Context: context.TODO(),
|
||||
Addr: reverseTunnelAddress,
|
||||
RemoteCluster: "remote",
|
||||
Username: hostID,
|
||||
Signers: []ssh.Signer{s.signer},
|
||||
Client: tunClt,
|
||||
AccessPoint: tunClt,
|
||||
EventsC: eventsC,
|
||||
})
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rsAgent2.Start(), IsNil)
|
||||
rsAgent2.Start()
|
||||
defer rsAgent2.Close()
|
||||
|
||||
timeout := time.After(time.Second)
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case event := <-eventsC:
|
||||
c.Assert(event, Equals, reversetunnel.ConnectedEvent)
|
||||
case <-timeout:
|
||||
c.Fatalf("timeout waiting for clusters to connect")
|
||||
}
|
||||
}
|
||||
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: s.user,
|
||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(up.certSigner)},
|
||||
|
@ -700,13 +739,14 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) {
|
|||
AddrNetwork: "tcp",
|
||||
Addr: fmt.Sprintf("%v:0", s.domainName),
|
||||
}
|
||||
reverseTunnelServer, err := reversetunnel.NewServer(
|
||||
reverseTunnelAddress,
|
||||
[]ssh.Signer{s.signer},
|
||||
s.roleAuth,
|
||||
state.NoCache,
|
||||
reversetunnel.DirectSite(s.domainName, s.roleAuth),
|
||||
)
|
||||
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
|
||||
ID: s.domainName,
|
||||
ListenAddr: reverseTunnelAddress,
|
||||
HostSigners: []ssh.Signer{s.signer},
|
||||
AccessPoint: s.roleAuth,
|
||||
NewCachingAccessPoint: state.NoCache,
|
||||
DirectClusters: []reversetunnel.DirectCluster{{Name: s.domainName, Client: s.roleAuth}},
|
||||
})
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
proxy, err := New(
|
||||
|
@ -731,7 +771,7 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) {
|
|||
c.Assert(tsrv.Start(), IsNil)
|
||||
|
||||
tunClt, err := auth.NewTunClient("test",
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
c.Assert(err, IsNil)
|
||||
defer tunClt.Close()
|
||||
|
||||
|
|
|
@ -499,11 +499,16 @@ func (cs *CachingAuthClient) UpsertTunnelConnection(conn services.TunnelConnecti
|
|||
return cs.ap.UpsertTunnelConnection(conn)
|
||||
}
|
||||
|
||||
// DeleteTunnelConnection is a part of auth.AccessPoint implementation
|
||||
func (cs *CachingAuthClient) DeleteTunnelConnection(clusterName, connName string) error {
|
||||
return cs.ap.DeleteTunnelConnection(clusterName, connName)
|
||||
}
|
||||
|
||||
// try calls a given function f and checks for errors. If f() fails, the current
|
||||
// time is recorded. Future calls to f will be ingored until sufficient time passes
|
||||
// since th last error
|
||||
func (cs *CachingAuthClient) try(f func() error) error {
|
||||
tooSoon := cs.lastErrorTime.Add(defaults.NetworkBackoffDuration).After(time.Now())
|
||||
tooSoon := cs.lastErrorTime.Add(defaults.NetworkRetryDuration).After(time.Now())
|
||||
if tooSoon {
|
||||
cs.Warnf("backoff: using cached value due to recent errors")
|
||||
return trace.ConnectionProblem(fmt.Errorf("backoff"), "backing off due to recent errors")
|
||||
|
|
|
@ -91,7 +91,6 @@ var (
|
|||
TunnelConnections = []services.TunnelConnection{
|
||||
services.MustCreateTunnelConnection("conn1", services.TunnelConnectionSpecV2{
|
||||
ClusterName: "example.com",
|
||||
ProxyAddr: "localhost:3025",
|
||||
ProxyName: "p1",
|
||||
LastHeartbeat: time.Date(2015, 6, 5, 4, 3, 2, 1, time.UTC).UTC(),
|
||||
}),
|
||||
|
@ -234,7 +233,7 @@ func (s *ClusterSnapshotSuite) TestTry(c *check.C) {
|
|||
c.Assert(failedCalls, check.Equals, 1)
|
||||
|
||||
// "wait" for backoff duration and try again:
|
||||
ap.lastErrorTime = time.Now().Add(-backoffDuration)
|
||||
ap.lastErrorTime = time.Now().Add(-defaults.NetworkBackoffDuration)
|
||||
|
||||
ap.try(success)
|
||||
ap.try(failure)
|
||||
|
|
|
@ -44,6 +44,7 @@ func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends ...NetAddr)
|
|||
"listen": frontend.String(),
|
||||
},
|
||||
}),
|
||||
connections: make(map[NetAddr]map[int64]net.Conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -51,6 +52,7 @@ func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends ...NetAddr)
|
|||
// balancer used in tests.
|
||||
type LoadBalancer struct {
|
||||
sync.RWMutex
|
||||
connID int64
|
||||
*log.Entry
|
||||
frontend NetAddr
|
||||
backends []NetAddr
|
||||
|
@ -58,6 +60,41 @@ type LoadBalancer struct {
|
|||
currentIndex int
|
||||
listener net.Listener
|
||||
listenerClosed bool
|
||||
connections map[NetAddr]map[int64]net.Conn
|
||||
}
|
||||
|
||||
// trackeConnection adds connection to the connection tracker
|
||||
func (l *LoadBalancer) trackConnection(backend NetAddr, conn net.Conn) int64 {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.connID += 1
|
||||
tracker, ok := l.connections[backend]
|
||||
if !ok {
|
||||
tracker = make(map[int64]net.Conn)
|
||||
l.connections[backend] = tracker
|
||||
}
|
||||
tracker[l.connID] = conn
|
||||
return l.connID
|
||||
}
|
||||
|
||||
// untrackConnection removes connection from connection tracker
|
||||
func (l *LoadBalancer) untrackConnection(backend NetAddr, id int64) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
tracker, ok := l.connections[backend]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(tracker, id)
|
||||
}
|
||||
|
||||
// dropConnections drops connections associated with backend
|
||||
func (l *LoadBalancer) dropConnections(backend NetAddr) {
|
||||
tracker := l.connections[backend]
|
||||
for _, conn := range tracker {
|
||||
conn.Close()
|
||||
}
|
||||
delete(l.connections, backend)
|
||||
}
|
||||
|
||||
// AddBackend adds backend
|
||||
|
@ -76,10 +113,10 @@ func (l *LoadBalancer) RemoveBackend(b NetAddr) {
|
|||
for i := range l.backends {
|
||||
if l.backends[i].Equals(b) {
|
||||
l.backends = append(l.backends[:i], l.backends[i+1:]...)
|
||||
l.dropConnections(b)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (l *LoadBalancer) nextBackend() (*NetAddr, error) {
|
||||
|
@ -171,12 +208,18 @@ func (l *LoadBalancer) forward(conn net.Conn) error {
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
connID := l.trackConnection(*backend, conn)
|
||||
defer l.untrackConnection(*backend, connID)
|
||||
|
||||
backendConn, err := net.Dial(backend.AddrNetwork, backend.Addr)
|
||||
if err != nil {
|
||||
return trace.ConvertSystemError(err)
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
backendConnID := l.trackConnection(*backend, backendConn)
|
||||
defer l.untrackConnection(*backend, backendConnID)
|
||||
|
||||
logger := l.WithFields(log.Fields{
|
||||
"source": conn.RemoteAddr(),
|
||||
"dest": backendConn.RemoteAddr(),
|
||||
|
|
|
@ -172,6 +172,44 @@ func (s *LBSuite) TestClose(c *check.C) {
|
|||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *LBSuite) TestDropConnections(c *check.C) {
|
||||
backend1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "backend 1")
|
||||
}))
|
||||
defer backend1.Close()
|
||||
|
||||
ports, err := GetFreeTCPPorts(1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
frontend := localAddr(ports[0])
|
||||
|
||||
backendAddr := urlToNetAddr(backend1.URL)
|
||||
lb, err := NewLoadBalancer(context.TODO(), frontend, backendAddr)
|
||||
c.Assert(err, check.IsNil)
|
||||
err = lb.Listen()
|
||||
c.Assert(err, check.IsNil)
|
||||
go lb.Serve()
|
||||
defer lb.Close()
|
||||
|
||||
conn, err := net.Dial("tcp", frontend.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
defer conn.Close()
|
||||
|
||||
out, err := roundtripWithConn(conn)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(out, check.Equals, "backend 1")
|
||||
|
||||
// to make sure multiple requests work on the same wire
|
||||
out, err = roundtripWithConn(conn)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(out, check.Equals, "backend 1")
|
||||
|
||||
// removing backend results in dropped connection to this backend
|
||||
lb.RemoveBackend(backendAddr)
|
||||
out, err = roundtripWithConn(conn)
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func urlToNetAddr(u string) NetAddr {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
|
@ -196,7 +234,15 @@ func roundtrip(addr string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
|
||||
return roundtripWithConn(conn)
|
||||
}
|
||||
|
||||
// roundtripWithConn uses HTTP get on the existing connection
|
||||
func roundtripWithConn(conn net.Conn) (string, error) {
|
||||
_, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
re, err := http.ReadResponse(bufio.NewReader(conn), nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -73,6 +73,8 @@ import (
|
|||
kyaml "k8s.io/apimachinery/pkg/util/yaml"
|
||||
)
|
||||
|
||||
const hostID = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
func TestWeb(t *testing.T) {
|
||||
TestingT(t)
|
||||
}
|
||||
|
@ -241,7 +243,7 @@ func (s *WebSuite) SetUpTest(c *C) {
|
|||
hpriv, hpub, err := s.authServer.GenerateKeyPair("")
|
||||
c.Assert(err, IsNil)
|
||||
hcert, err := s.authServer.GenerateHostCert(
|
||||
hpub, "00000000-0000-0000-0000-000000000000", s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0)
|
||||
hpub, hostID, s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// set up user CA and set up a user that has access to the server
|
||||
|
@ -274,16 +276,17 @@ func (s *WebSuite) SetUpTest(c *C) {
|
|||
c.Assert(s.node.Start(), IsNil)
|
||||
|
||||
// create reverse tunnel service:
|
||||
revTunServer, err := reversetunnel.NewServer(
|
||||
utils.NetAddr{
|
||||
revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{
|
||||
ID: node.ID(),
|
||||
ListenAddr: utils.NetAddr{
|
||||
AddrNetwork: "tcp",
|
||||
Addr: fmt.Sprintf("%v:0", s.domainName),
|
||||
},
|
||||
[]ssh.Signer{s.signer},
|
||||
s.roleAuth,
|
||||
state.NoCache,
|
||||
reversetunnel.DirectSite(s.domainName, s.roleAuth),
|
||||
)
|
||||
HostSigners: []ssh.Signer{s.signer},
|
||||
AccessPoint: s.roleAuth,
|
||||
NewCachingAccessPoint: state.NoCache,
|
||||
DirectClusters: []reversetunnel.DirectCluster{{Name: s.domainName, Client: s.roleAuth}},
|
||||
})
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
apiPort := s.freePorts[len(s.freePorts)-1]
|
||||
|
@ -307,7 +310,7 @@ func (s *WebSuite) SetUpTest(c *C) {
|
|||
|
||||
// create a tun client
|
||||
tunClient, err := auth.NewTunClient("test", []utils.NetAddr{tunAddr},
|
||||
s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// proxy server:
|
||||
|
|
Loading…
Reference in a new issue