diff --git a/constants.go b/constants.go index 57161eca496..b5e8eb35080 100644 --- a/constants.go +++ b/constants.go @@ -57,10 +57,15 @@ const ( // ComponentFields stores component-specific fields ComponentFields = "fields" - // ComponentReverseTunnel is reverse tunnel agent and server - // that together establish a bi-directional SSH revers tunnel + // ComponentReverseTunnelServer is reverse tunnel server + // that together with agent establish a bi-directional SSH revers tunnel // to bypass firewall restrictions - ComponentReverseTunnel = "reversetunnel" + ComponentReverseTunnelServer = "proxy:server" + + // ComponentReverseTunnel is reverse tunnel agent + // that together with server establish a bi-directional SSH revers tunnel + // to bypass firewall restrictions + ComponentReverseTunnelAgent = "proxy:agent" // ComponentAuth is the cluster CA node (auth server API) ComponentAuth = "auth" diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 5ada5844b49..18e47441fab 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -41,60 +41,84 @@ import ( "golang.org/x/crypto/ssh" ) +// AgentConfig holds configuration for agent +type AgentConfig struct { + // Addr is target address to dial + Addr utils.NetAddr + // RemoteCluster is a remote cluster name to connect to + RemoteCluster string + // Signers contains authentication signers + Signers []ssh.Signer + // Client is a client to the local auth servers + Client *auth.TunClient + // AccessPoint is a caching access point to the local auth servers + AccessPoint auth.AccessPoint + // Context is a parent context + Context context.Context + // DiscoveryC is a channel that receives discovery requests + // from reverse tunnel server + DiscoveryC chan *discoveryRequest + // Username is the name of this client used to authenticate on SSH + Username string +} + +// CheckAndSetDefaults checks parameters and sets default values +func (a *AgentConfig) CheckAndSetDefaults() error { + if a.Addr.IsEmpty() { + return trace.BadParameter("missing parameter Addr") + } + if a.DiscoveryC == nil { + return trace.BadParameter("missing parameter DiscoveryC") + } + if a.Context == nil { + return trace.BadParameter("missing parameter Context") + } + if a.Client == nil { + return trace.BadParameter("missing parameter Client") + } + if a.AccessPoint == nil { + return trace.BadParameter("missing parameter AccessPoint") + } + if len(a.Signers) == 0 { + return trace.BadParameter("missing parameter Signers") + } + if len(a.Username) == 0 { + return trace.BadParameter("missing parameter Username") + } + return nil +} + // Agent is a reverse tunnel agent running as a part of teleport Proxies // to establish outbound reverse tunnels to remote proxies type Agent struct { - log *log.Entry - addr utils.NetAddr - clt *auth.TunClient - // domain name of the tunnel server, used only for debugging & logging - remoteDomainName string - // clientName format is "hostid.domain" (where 'domain' is local domain name) - clientName string + *log.Entry + AgentConfig ctx context.Context cancel context.CancelFunc hostKeyCallback utils.HostKeyCallback authMethods []ssh.AuthMethod - accessPoint auth.AccessPoint } -// AgentOption specifies parameter that could be passed to Agents -type AgentOption func(a *Agent) error - // NewAgent returns a new reverse tunnel agent // Parameters: // addr points to the remote reverse tunnel server // remoteDomainName is the domain name of the runnel server, used only for logging // clientName is hostid.domain (where 'domain' is local domain name) -func NewAgent( - addr utils.NetAddr, - remoteDomainName string, - clientName string, - signers []ssh.Signer, - clt *auth.TunClient, - accessPoint auth.AccessPoint, parentContext context.Context) (*Agent, error) { - - log.Debugf("reversetunnel.NewAgent %s -> %s", clientName, remoteDomainName) - - ctx, cancel := context.WithCancel(parentContext) +func NewAgent(cfg AgentConfig) (*Agent, error) { + ctx, cancel := context.WithCancel(cfg.Context) a := &Agent{ - log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, + AgentConfig: cfg, + Entry: log.WithFields(log.Fields{ + teleport.Component: teleport.ComponentReverseTunnelAgent, teleport.ComponentFields: map[string]interface{}{ - "side": "agent", - "remote": addr.String(), - "mode": "agent", + "remote": cfg.Addr.String(), + "client": cfg.Username, }, }), - clt: clt, - addr: addr, - remoteDomainName: remoteDomainName, - clientName: clientName, - authMethods: []ssh.AuthMethod{ssh.PublicKeys(signers...)}, - accessPoint: accessPoint, - ctx: ctx, - cancel: cancel, + ctx: ctx, + cancel: cancel, + authMethods: []ssh.AuthMethod{ssh.PublicKeys(cfg.Signers...)}, } a.hostKeyCallback = a.checkHostSignature return a, nil @@ -110,8 +134,7 @@ func (a *Agent) Close() error { func (a *Agent) Start() error { conn, err := a.connect() if err != nil { - log.Errorf("Failed to create remote tunnel for %v on %s(%s): %v", - a.clientName, a.remoteDomainName, a.addr.FullAddress(), err) + a.Warningf("Failed to create remote tunnel: %v", err) } // start heartbeat even if error happend, it will reconnect go a.runHeartbeat(conn) @@ -125,7 +148,7 @@ func (a *Agent) Wait() error { // String returns debug-friendly func (a *Agent) String() string { - return fmt.Sprintf("tunagent(remote=%s)", a.addr.String()) + return fmt.Sprintf("tunagent(remote=%s)", a.Addr.String()) } func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.PublicKey) error { @@ -133,7 +156,7 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub if !ok { return trace.BadParameter("expected certificate") } - cas, err := a.accessPoint.GetCertAuthorities(services.HostCA, false) + cas, err := a.AccessPoint.GetCertAuthorities(services.HostCA, false) if err != nil { return trace.Wrap(err, "failed to fetch remote certs") } @@ -144,7 +167,7 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub } for _, checker := range checkers { if sshutils.KeysEqual(checker, cert.SignatureKey) { - a.log.Debugf("matched key %v for %v", ca.GetName(), hostport) + a.Debugf("matched key %v for %v", ca.GetName(), hostport) return nil } } @@ -154,14 +177,11 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub } func (a *Agent) connect() (conn *ssh.Client, err error) { - if a.addr.IsEmpty() { - return nil, trace.BadParameter("reverse tunnel cannot be created: target address is empty") - } for _, authMethod := range a.authMethods { // if http_proxy is set, dial through the proxy dialer := proxy.DialerFromEnvironment() - conn, err = dialer.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{ - User: a.clientName, + conn, err = dialer.Dial(a.Addr.AddrNetwork, a.Addr.Addr, &ssh.ClientConfig{ + User: a.Username, Auth: []ssh.AuthMethod{authMethod}, HostKeyCallback: a.hostKeyCallback, Timeout: defaults.DefaultDialTimeout, @@ -174,12 +194,12 @@ func (a *Agent) connect() (conn *ssh.Client, err error) { } func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) { - log.Debugf("[HA Agent] proxyAccessPoint") + a.Debugf("proxyAccessPoint") defer ch.Close() - conn, err := a.clt.GetDialer()() + conn, err := a.Client.GetDialer()() if err != nil { - a.log.Errorf("error dialing: %v", err) + a.Warningf("error dialing: %v", err) return } @@ -215,7 +235,7 @@ func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) { // ch : SSH channel which received "teleport-transport" out-of-band request // reqC : request payload func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { - log.Debugf("[HA Agent] proxyTransport") + a.Debugf("proxyTransport") defer ch.Close() // always push space into stderr to make sure the caller can always @@ -226,15 +246,15 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { var req *ssh.Request select { case <-a.ctx.Done(): - a.log.Infof("is closed, returning") + a.Infof("is closed, returning") return case req = <-reqC: if req == nil { - a.log.Infof("connection closed, returning") + a.Infof("connection closed, returning") return } case <-time.After(defaults.DefaultDialTimeout): - a.log.Errorf("timeout waiting for dial") + a.Warningf("timeout waiting for dial") return } @@ -245,9 +265,9 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { // list of auth servers and return that. otherwise try and connect to the // passed in server. if server == RemoteAuthServer { - authServers, err := a.clt.GetAuthServers() + authServers, err := a.Client.GetAuthServers() if err != nil { - a.log.Errorf("unable to find auth servers: %v", err) + a.Warningf("unable to find auth servers: %v", err) return } for _, as := range authServers { @@ -257,7 +277,7 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { servers = append(servers, server) } - log.Debugf("got out of band request %v", servers) + a.Debugf("got out of band request %v", servers) var conn net.Conn var err error @@ -284,7 +304,7 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { // successfully dialed req.Reply(true, []byte("connected")) - a.log.Infof("successfully dialed to %v, start proxying", server) + a.Debugf("successfully dialed to %v, start proxying", server) wg := sync.WaitGroup{} wg.Add(2) @@ -317,7 +337,7 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { if conn == nil { return trace.Errorf("heartbeat cannot ping: need to reconnect") } - log.Infof("[TUNNEL CLIENT] connected to %s", conn.RemoteAddr()) + a.Infof("connected to %s", conn.RemoteAddr()) defer conn.Close() hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil) if err != nil { @@ -336,26 +356,26 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { return nil // time to ping: case <-ticker.C: - log.Debugf("[TUNNEL CLIENT] pings \"%s\" at %s", a.remoteDomainName, conn.RemoteAddr()) _, err := hb.SendRequest("ping", false, nil) if err != nil { log.Error(err) return trace.Wrap(err) } + a.Debugf("ping -> %v", conn.RemoteAddr()) // ssh channel closed: case req := <-reqC: if req == nil { - return trace.Errorf("heartbeat: connection closed") + return trace.ConnectionProblem(nil, "heartbeat: connection closed") } // new access point request: case nch := <-newAccesspointC: if nch == nil { continue } - a.log.Infof("[TUNNEL CLIENT] access point request: %v", nch.ChannelType()) + a.Debugf("access point request: %v", nch.ChannelType()) ch, req, err := nch.Accept() if err != nil { - a.log.Errorf("failed to accept request: %v", err) + a.Warningf("failed to accept request: %v", err) continue } go a.proxyAccessPoint(ch, req) @@ -364,10 +384,22 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { if nch == nil { continue } - a.log.Infof("[TUNNEL CLIENT] transport request: %v", nch.ChannelType()) + a.Debugf("transport request: %v", nch.ChannelType()) ch, req, err := nch.Accept() if err != nil { - a.log.Errorf("failed to accept request: %v", err) + a.Warningf("failed to accept request: %v", err) + continue + } + go a.proxyTransport(ch, req) + // new discovery request + case nch := <-newTransportC: + if nch == nil { + continue + } + a.Debugf("transport request: %v", nch.ChannelType()) + ch, req, err := nch.Accept() + if err != nil { + a.Warningf("failed to accept request: %v", err) continue } go a.proxyTransport(ch, req) @@ -398,11 +430,50 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { } } +// handleDisovery receives discovery requests from the reverse tunnel +// server, that informs agent about proxies registered in the remote +// cluster and the reverse tunnels already established +// +// ch : SSH channel which received "teleport-transport" out-of-band request +// reqC : request payload +func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { + a.Debugf("handleDiscovery") + defer ch.Close() + + for { + var req *ssh.Request + select { + case <-a.ctx.Done(): + a.Infof("is closed, returning") + return + case req = <-reqC: + if req == nil { + a.Infof("connection closed, returning") + return + } + r, err := unmarshalDiscoveryRequest(req.Payload) + if err != nil { + a.Warningf("bad payload: %v", err) + return + } + select { + case a.DiscoveryC <- r: + case <-a.ctx.Done(): + a.Infof("is closed, returning") + return + default: + } + req.Reply(true, []byte("thanks")) + } + } +} + const ( chanHeartbeat = "teleport-heartbeat" chanAccessPoint = "teleport-access-point" chanTransport = "teleport-transport" chanTransportDialReq = "teleport-transport-dial" + chanDiscovery = "teleport-discovery" ) const ( diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 069440b7001..4ccd2f74a37 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -24,10 +24,11 @@ import ( type AgentPool struct { sync.Mutex *log.Entry - cfg AgentPoolConfig - agents map[agentKey]*Agent - ctx context.Context - cancel context.CancelFunc + cfg AgentPoolConfig + agents map[agentKey]*Agent + ctx context.Context + cancel context.CancelFunc + discoveryC chan *discoveryRequest } // AgentPoolConfig holds configuration parameters for the agent pool @@ -68,19 +69,19 @@ func (cfg *AgentPoolConfig) CheckAndSetDefaults() error { // NewAgentPool returns new isntance of the agent pool func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } ctx, cancel := context.WithCancel(cfg.Context) pool := &AgentPool{ - agents: make(map[agentKey]*Agent), - cfg: cfg, - ctx: ctx, - cancel: cancel, + agents: make(map[agentKey]*Agent), + cfg: cfg, + ctx: ctx, + cancel: cancel, + discoveryC: make(chan *discoveryRequest), } pool.Entry = log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, - teleport.ComponentFields: map[string]interface{}{ - "side": "agent", - "mode": "agentpool", - }, + teleport.Component: teleport.ComponentReverseTunnelAgent, }) return pool, nil } @@ -105,6 +106,22 @@ func (m *AgentPool) Wait() error { return nil } +func (m *AgentPool) processDiscoveryRequests() { + for { + select { + case <-m.ctx.Done(): + m.Debugf("closing") + return + case req := <-m.discoveryC: + if req == nil { + m.Debugf("channel closed") + return + } + m.Debugf("got discovery request, following proxies are not connected %v", req.Proxies) + } + } +} + // FetchAndSyncAgents executes one time fetch and sync request // (used in tests instead of polling) func (m *AgentPool) FetchAndSyncAgents() error { @@ -160,7 +177,16 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { for _, key := range agentsToAdd { m.Debugf("adding %v", &key) - agent, err := NewAgent(key.addr, key.domainName, m.cfg.HostUUID, m.cfg.HostSigners, m.cfg.Client, m.cfg.AccessPoint, m.ctx) + agent, err := NewAgent(AgentConfig{ + Addr: key.addr, + RemoteCluster: key.domainName, + Username: m.cfg.HostUUID, + Signers: m.cfg.HostSigners, + Client: m.cfg.Client, + AccessPoint: m.cfg.AccessPoint, + Context: m.ctx, + DiscoveryC: m.discoveryC, + }) if err != nil { return trace.Wrap(err) } diff --git a/lib/reversetunnel/discovery.go b/lib/reversetunnel/discovery.go new file mode 100644 index 00000000000..1f3584ffecd --- /dev/null +++ b/lib/reversetunnel/discovery.go @@ -0,0 +1,52 @@ +package reversetunnel + +import ( + "encoding/json" + + "github.com/gravitational/teleport/lib/services" + + "github.com/gravitational/trace" +) + +type discoveryRequest struct { + Proxies []services.Server `json:"proxies"` +} + +type discoveryRequestRaw struct { + Proxies []json.RawMessage `json:"proxies"` +} + +func marshalDiscoveryRequest(req discoveryRequest) ([]byte, error) { + var out discoveryRequestRaw + m := services.GetServerMarshaler() + for _, p := range req.Proxies { + data, err := m.MarshalServer(p) + if err != nil { + return nil, trace.Wrap(err) + } + out.Proxies = append(out.Proxies, data) + } + + return json.Marshal(out) +} + +func unmarshalDiscoveryRequest(data []byte) (*discoveryRequest, error) { + if len(data) == 0 { + return nil, trace.BadParameter("missing payload") + } + var raw discoveryRequestRaw + err := json.Unmarshal(data, &raw) + if err != nil { + return nil, trace.Wrap(err) + } + m := services.GetServerMarshaler() + var out discoveryRequest + for _, bytes := range raw.Proxies { + proxy, err := m.UnmarshalServer([]byte(bytes), services.KindProxy) + if err != nil { + return nil, trace.Wrap(err) + } + out.Proxies = append(out.Proxies, proxy) + } + return &out, nil +} diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 9f059684778..5a38d4871f5 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -40,11 +40,9 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi accessPoint: accessPoint, domainName: domainName, log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, + teleport.Component: teleport.ComponentReverseTunnelServer, teleport.ComponentFields: map[string]string{ - "domainName": domainName, - "side": "server", - "type": "localSite", + "cluster": domainName, }, }), }, nil @@ -77,7 +75,7 @@ func (s *localSite) GetClient() (auth.ClientI, error) { } func (s *localSite) String() string { - return fmt.Sprintf("localSite(%v)", s.domainName) + return fmt.Sprintf("local(%v)", s.domainName) } func (s *localSite) GetStatus() string { @@ -94,7 +92,7 @@ func (s *localSite) GetLastConnected() time.Time { // Dial dials a given host in this site (cluster). func (s *localSite) Dial(from net.Addr, to net.Addr) (net.Conn, error) { - s.log.Debugf("[PROXY] localSite.Dial(from=%v, to=%v)", from, to) + s.log.Debugf("local.Dial(from=%v, to=%v)", from, to) return net.Dial(to.Network(), to.String()) } diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 4e6b26ed42b..7035fae25a5 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -127,10 +127,9 @@ func newClusterPeer(srv *server, connInfo services.TunnelConnection) (*clusterPe srv: srv, connInfo: connInfo, log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, + teleport.Component: teleport.ComponentReverseTunnelServer, teleport.ComponentFields: map[string]string{ "cluster": connInfo.GetClusterName(), - "side": "server", }, }), } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 46be0a0e9e5..a299509abb2 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -14,9 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ + package reversetunnel import ( + "context" "fmt" "io/ioutil" "net" @@ -44,17 +46,17 @@ import ( type remoteSite struct { sync.Mutex - log *log.Entry + *log.Entry domainName string connections []*remoteConn lastUsed int lastActive time.Time srv *server - transport *http.Transport clt *auth.Client accessPoint auth.AccessPoint connInfo services.TunnelConnection + ctx context.Context } func (s *remoteSite) CachingAccessPoint() (auth.AccessPoint, error) { @@ -100,7 +102,7 @@ func (s *remoteSite) addConn(conn net.Conn, sshConn ssh.Conn) (*remoteConn, erro rc := &remoteConn{ sshConn: sshConn, conn: conn, - log: s.log, + log: s.Entry, } s.Lock() @@ -114,7 +116,7 @@ func (s *remoteSite) addConn(conn net.Conn, sshConn ssh.Conn) (*remoteConn, erro func (s *remoteSite) getLatestTunnelConnection() (services.TunnelConnection, error) { conns, err := s.srv.AccessPoint.GetTunnelConnections(s.domainName) if err != nil { - s.log.Warningf("[TUNNEL] failed to fetch tunnel statuses: %v", err) + s.Warningf("failed to fetch tunnel statuses: %v", err) return nil, trace.Wrap(err) } var lastConn services.TunnelConnection @@ -125,7 +127,7 @@ func (s *remoteSite) getLatestTunnelConnection() (services.TunnelConnection, err } } if lastConn == nil { - return nil, trace.NotFound("no connections from %v found in the cluster", s.domainName) + return nil, trace.NotFound("no connections found") } return lastConn, nil } @@ -146,24 +148,27 @@ func (s *remoteSite) registerHeartbeat(t time.Time) { s.connInfo.SetLastHeartbeat(t) err := s.srv.AccessPoint.UpsertTunnelConnection(s.connInfo) if err != nil { - log.Warningf("[TUNNEL] failed to register heartbeat: %v", err) + log.Warningf("failed to register heartbeat: %v", err) } } func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) { defer func() { - s.log.Infof("[TUNNEL] cluster connection closed: %v", s.domainName) + s.Infof("cluster connection closed") conn.Close() }() for { select { + case <-s.ctx.Done(): + s.Infof("closing") + return case req := <-reqC: if req == nil { - s.log.Infof("[TUNNEL] cluster disconnected: %v", s.domainName) + s.Infof("cluster disconnected") conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected")) return } - log.Debugf("[TUNNEL] ping from \"%s\" %s", s.domainName, conn.conn.RemoteAddr()) + s.Debugf("ping <- %v", conn.conn.RemoteAddr()) go s.registerHeartbeat(time.Now()) case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod): conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats")) @@ -183,10 +188,96 @@ func (s *remoteSite) GetLastConnected() time.Time { return connInfo.GetLastHeartbeat() } +func (s *remoteSite) periodicSendDiscoveryRequests() { + ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) + defer ticker.Stop() + if err := s.sendDiscoveryRequest(); err != nil { + s.Warningf("failed to fetch cluster peers: %v", err) + } + for { + select { + case <-s.ctx.Done(): + s.Debugf("closing") + return + case <-ticker.C: + err := s.sendDiscoveryRequest() + if err != nil { + s.Warningf("could not send discovery request: %v", err) + } + } + } +} + +// findDisconnectedProxies +func (s *remoteSite) findDisconnectedProxies() ([]services.Server, error) { + conns, err := s.srv.AccessPoint.GetTunnelConnections(s.domainName) + if err != nil { + return nil, trace.Wrap(err) + } + connected := make(map[string]bool) + for _, conn := range conns { + connected[conn.GetProxyName()] = true + } + proxies, err := s.srv.AccessPoint.GetProxies() + if err != nil { + return nil, trace.Wrap(err) + } + var missing []services.Server + for i := range proxies { + proxy := proxies[i] + if !connected[proxy.GetName()] { + missing = append(missing, proxy) + } + } + return missing, nil +} + +func (s *remoteSite) sendDiscoveryRequest() error { + disconnectedProxies, err := s.findDisconnectedProxies() + if err != nil { + return trace.Wrap(err) + } + if len(disconnectedProxies) == 0 { + return nil + } + s.Infof("detected disconnected proxies: %v", disconnectedProxies) + req := discoveryRequest{ + Proxies: disconnectedProxies, + } + payload, err := marshalDiscoveryRequest(req) + if err != nil { + return trace.Wrap(err) + } + send := func() error { + remoteConn, err := s.nextConn() + if err != nil { + return trace.Wrap(err) + } + discoveryC, err := remoteConn.openDiscoveryChannel() + if err != nil { + return trace.Wrap(err) + } + _, err = discoveryC.SendRequest("ping", false, payload) + remoteConn.markInvalid(err) + s.Errorf("disconnecting cluster on %v, err: %v", + remoteConn.conn.RemoteAddr(), + err) + return trace.Wrap(err) + } + + for i := 0; i < s.connectionCount(); i++ { + err := send() + if err != nil { + s.Warningf("%v") + } + } + return nil +} + // dialAccessPoint establishes a connection from the proxy (reverse tunnel server) // back into the client using previously established tunnel. func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) { - s.log.Infof("[TUNNEL] dial to site '%s'", s.GetName()) + s.Debugf("dialAccessPoint") try := func() (net.Conn, error) { remoteConn, err := s.nextConn() @@ -196,13 +287,12 @@ func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) { ch, _, err := remoteConn.sshConn.OpenChannel(chanAccessPoint, nil) if err != nil { remoteConn.markInvalid(err) - s.log.Errorf("[TUNNEL] disconnecting site '%s' on %v. Err: %v", - s.GetName(), + s.Errorf("disconnecting cluster on %v, err: %v", remoteConn.conn.RemoteAddr(), err) return nil, trace.Wrap(err) } - s.log.Infof("[TUNNEL] success dialing to site '%s'", s.GetName()) + s.Infof("success dialing to cluster") return utils.NewChConn(remoteConn.sshConn, ch), nil } @@ -222,7 +312,7 @@ func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) { // located in a remote connected site, the connection goes through the // reverse proxy tunnel. func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) { - s.log.Infof("[TUNNEL] dialing %v@%v through the tunnel", to, s.domainName) + s.Debugf("dialing %v through the tunnel", to) stop := false _, addr := to.Network(), to.String() @@ -268,7 +358,7 @@ func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) { if err == nil { return conn, nil } - s.log.Errorf("[TUNNEL] Dial(addr=%v) failed: %v", addr, err) + s.Warningf("Dial(addr=%v) failed: %v", addr, err) } // didn't connect and no error? this means we didn't have any connected // tunnels to try @@ -279,9 +369,9 @@ func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) { } func (s *remoteSite) handleAuthProxy(w http.ResponseWriter, r *http.Request) { - s.log.Infof("[TUNNEL] handleAuthProxy()") + s.Debugf("handleAuthProxy()") - fwd, err := forward.New(forward.RoundTripper(s.transport), forward.Logger(s.log)) + fwd, err := forward.New(forward.RoundTripper(s.transport), forward.Logger(s.Entry)) if err != nil { roundtrip.ReplyJSON(w, http.StatusInternalServerError, err.Error()) return diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 7ff735f6486..a9c274d99e8 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -158,7 +158,7 @@ func NewServer(cfg Config) (Server, error) { var err error s, err := sshutils.NewServer( - teleport.ComponentReverseTunnel, + teleport.ComponentReverseTunnelServer, cfg.ListenAddr, srv, cfg.HostSigners, @@ -604,11 +604,29 @@ func (s *server) RemoveSite(domainName string) error { } type remoteConn struct { - sshConn ssh.Conn - conn net.Conn - invalid int32 - log *log.Entry - counter int32 + sshConn ssh.Conn + conn net.Conn + invalid int32 + log *log.Entry + counter int32 + discoveryC ssh.Channel + discoveryErr error +} + +func (rc *remoteConn) openDiscoveryChannel() (ssh.Channel, error) { + if rc.discoveryC != nil { + return rc.discoveryC, nil + } + if rc.discoveryErr != nil { + return nil, trace.Wrap(rc.discoveryErr) + } + discoveryC, _, err := rc.sshConn.OpenChannel(chanDiscovery, nil) + if err != nil { + rc.discoveryErr = err + return nil, trace.Wrap(err) + } + rc.discoveryC = discoveryC + return rc.discoveryC, nil } func (rc *remoteConn) String() string { @@ -616,6 +634,10 @@ func (rc *remoteConn) String() string { } func (rc *remoteConn) Close() error { + if rc.discoveryC != nil { + rc.discoveryC.Close() + rc.discoveryC = nil + } return rc.sshConn.Close() } @@ -645,13 +667,13 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { srv: srv, domainName: domainName, connInfo: connInfo, - log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, + Entry: log.WithFields(log.Fields{ + teleport.Component: teleport.ComponentReverseTunnelServer, teleport.ComponentFields: map[string]string{ - "domainName": domainName, - "side": "server", + "cluster": domainName, }, }), + ctx: srv.ctx, } // transport uses connection do dial out to the remote address remoteSite.transport = &http.Transport{ diff --git a/lib/utils/cli.go b/lib/utils/cli.go index 53316b326a3..248eee05b14 100644 --- a/lib/utils/cli.go +++ b/lib/utils/cli.go @@ -44,8 +44,7 @@ const ( // InitLogger configures the global logger for a given purpose / verbosity level func InitLogger(purpose LoggingPurpose, level log.Level) { log.StandardLogger().Hooks = make(log.LevelHooks) - formatter := &trace.TextFormatter{} - formatter.DisableTimestamp = true + formatter := &trace.TextFormatter{DisableTimestamp: true} log.SetFormatter(formatter) log.SetLevel(level) diff --git a/vendor/github.com/gravitational/trace/log.go b/vendor/github.com/gravitational/trace/log.go index 58f7858e494..310a90af486 100644 --- a/vendor/github.com/gravitational/trace/log.go +++ b/vendor/github.com/gravitational/trace/log.go @@ -18,7 +18,12 @@ limitations under the License. package trace import ( + "bytes" + "fmt" "regexp" + "sort" + "strings" + "time" log "github.com/sirupsen/logrus" @@ -40,20 +45,55 @@ const ( // TextFormatter is logrus-compatible formatter and adds // file and line details to every logged entry. type TextFormatter struct { - log.TextFormatter + DisableTimestamp bool } // Format implements logrus.Formatter interface and adds file and line func (tf *TextFormatter) Format(e *log.Entry) ([]byte, error) { + var file string if frameNo := findFrame(); frameNo != -1 { t := newTrace(frameNo, nil) - new := e.WithFields(log.Fields{FileField: t.Loc(), FunctionField: t.FuncName()}) - new.Time = e.Time - new.Level = e.Level - new.Message = e.Message - e = new + file = t.Loc() } - return (&tf.TextFormatter).Format(e) + + w := &writer{bytes.Buffer{}} + + // time + if !tf.DisableTimestamp { + w.writeField(e.Time.Format(time.RFC3339)) + } + + // level + w.writeField(strings.ToUpper(padMax(e.Level.String(), 4))) + + // component if present, highly visible + component, ok := e.Data[Component] + if ok { + if w.Len() > 0 { + w.WriteByte(' ') + } + w.WriteByte('[') + w.WriteString(strings.ToUpper(padMax(fmt.Sprintf("%v", component), 11))) + w.WriteByte(']') + } + + // message + if e.Message != "" { + w.writeField(e.Message) + } + + // file, if present + if file != "" { + w.writeField(file) + } + + // rest of the fields + if len(e.Data) > 0 { + w.WriteByte(' ') + w.writeMap(e.Data) + } + w.WriteByte('\n') + return w.Bytes(), nil } // JSONFormatter implements logrus.Formatter interface and adds file and line @@ -91,3 +131,78 @@ func findFrame() int { } return -1 } + +type writer struct { + bytes.Buffer +} + +func (w *writer) writeField(value interface{}) { + if w.Len() > 0 { + w.WriteByte(' ') + } + w.writeValue(value) +} + +func (w *writer) writeValue(value interface{}) { + stringVal, ok := value.(string) + if !ok { + stringVal = fmt.Sprint(value) + } + if !needsQuoting(stringVal) { + w.WriteString(stringVal) + } else { + w.WriteString(fmt.Sprintf("%q", stringVal)) + } +} + +func (w *writer) writeKeyValue(key string, value interface{}) { + if w.Len() > 0 { + w.WriteByte(' ') + } + w.WriteString(key) + w.WriteByte(':') + w.writeValue(value) +} + +func (w *writer) writeMap(m map[string]interface{}) { + if len(m) == 0 { + return + } + keys := make([]string, 0, len(m)) + for key := range m { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + if key == Component { + continue + } + switch val := m[key].(type) { + case map[string]interface{}: + w.WriteString(key) + w.WriteString(":{") + w.writeMap(val) + w.WriteString(" }") + default: + w.writeKeyValue(key, val) + } + } +} + +func needsQuoting(text string) bool { + for _, ch := range text { + if ch < 32 { + return true + } + } + return false +} + +func padMax(in string, chars int) string { + switch { + case len(in) < chars: + return in + strings.Repeat(" ", chars-len(in)) + default: + return in[:chars] + } +}