mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 08:43:58 +00:00
introduce curiosity protocol and fix logs
This commit is contained in:
parent
a2cd00de8f
commit
53f4a0128e
11
constants.go
11
constants.go
|
@ -57,10 +57,15 @@ const (
|
|||
// ComponentFields stores component-specific fields
|
||||
ComponentFields = "fields"
|
||||
|
||||
// ComponentReverseTunnel is reverse tunnel agent and server
|
||||
// that together establish a bi-directional SSH revers tunnel
|
||||
// ComponentReverseTunnelServer is reverse tunnel server
|
||||
// that together with agent establish a bi-directional SSH revers tunnel
|
||||
// to bypass firewall restrictions
|
||||
ComponentReverseTunnel = "reversetunnel"
|
||||
ComponentReverseTunnelServer = "proxy:server"
|
||||
|
||||
// ComponentReverseTunnel is reverse tunnel agent
|
||||
// that together with server establish a bi-directional SSH revers tunnel
|
||||
// to bypass firewall restrictions
|
||||
ComponentReverseTunnelAgent = "proxy:agent"
|
||||
|
||||
// ComponentAuth is the cluster CA node (auth server API)
|
||||
ComponentAuth = "auth"
|
||||
|
|
|
@ -41,60 +41,84 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// AgentConfig holds configuration for agent
|
||||
type AgentConfig struct {
|
||||
// Addr is target address to dial
|
||||
Addr utils.NetAddr
|
||||
// RemoteCluster is a remote cluster name to connect to
|
||||
RemoteCluster string
|
||||
// Signers contains authentication signers
|
||||
Signers []ssh.Signer
|
||||
// Client is a client to the local auth servers
|
||||
Client *auth.TunClient
|
||||
// AccessPoint is a caching access point to the local auth servers
|
||||
AccessPoint auth.AccessPoint
|
||||
// Context is a parent context
|
||||
Context context.Context
|
||||
// DiscoveryC is a channel that receives discovery requests
|
||||
// from reverse tunnel server
|
||||
DiscoveryC chan *discoveryRequest
|
||||
// Username is the name of this client used to authenticate on SSH
|
||||
Username string
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults checks parameters and sets default values
|
||||
func (a *AgentConfig) CheckAndSetDefaults() error {
|
||||
if a.Addr.IsEmpty() {
|
||||
return trace.BadParameter("missing parameter Addr")
|
||||
}
|
||||
if a.DiscoveryC == nil {
|
||||
return trace.BadParameter("missing parameter DiscoveryC")
|
||||
}
|
||||
if a.Context == nil {
|
||||
return trace.BadParameter("missing parameter Context")
|
||||
}
|
||||
if a.Client == nil {
|
||||
return trace.BadParameter("missing parameter Client")
|
||||
}
|
||||
if a.AccessPoint == nil {
|
||||
return trace.BadParameter("missing parameter AccessPoint")
|
||||
}
|
||||
if len(a.Signers) == 0 {
|
||||
return trace.BadParameter("missing parameter Signers")
|
||||
}
|
||||
if len(a.Username) == 0 {
|
||||
return trace.BadParameter("missing parameter Username")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Agent is a reverse tunnel agent running as a part of teleport Proxies
|
||||
// to establish outbound reverse tunnels to remote proxies
|
||||
type Agent struct {
|
||||
log *log.Entry
|
||||
addr utils.NetAddr
|
||||
clt *auth.TunClient
|
||||
// domain name of the tunnel server, used only for debugging & logging
|
||||
remoteDomainName string
|
||||
// clientName format is "hostid.domain" (where 'domain' is local domain name)
|
||||
clientName string
|
||||
*log.Entry
|
||||
AgentConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
hostKeyCallback utils.HostKeyCallback
|
||||
authMethods []ssh.AuthMethod
|
||||
accessPoint auth.AccessPoint
|
||||
}
|
||||
|
||||
// AgentOption specifies parameter that could be passed to Agents
|
||||
type AgentOption func(a *Agent) error
|
||||
|
||||
// NewAgent returns a new reverse tunnel agent
|
||||
// Parameters:
|
||||
// addr points to the remote reverse tunnel server
|
||||
// remoteDomainName is the domain name of the runnel server, used only for logging
|
||||
// clientName is hostid.domain (where 'domain' is local domain name)
|
||||
func NewAgent(
|
||||
addr utils.NetAddr,
|
||||
remoteDomainName string,
|
||||
clientName string,
|
||||
signers []ssh.Signer,
|
||||
clt *auth.TunClient,
|
||||
accessPoint auth.AccessPoint, parentContext context.Context) (*Agent, error) {
|
||||
|
||||
log.Debugf("reversetunnel.NewAgent %s -> %s", clientName, remoteDomainName)
|
||||
|
||||
ctx, cancel := context.WithCancel(parentContext)
|
||||
func NewAgent(cfg AgentConfig) (*Agent, error) {
|
||||
ctx, cancel := context.WithCancel(cfg.Context)
|
||||
|
||||
a := &Agent{
|
||||
log: log.WithFields(log.Fields{
|
||||
teleport.Component: teleport.ComponentReverseTunnel,
|
||||
AgentConfig: cfg,
|
||||
Entry: log.WithFields(log.Fields{
|
||||
teleport.Component: teleport.ComponentReverseTunnelAgent,
|
||||
teleport.ComponentFields: map[string]interface{}{
|
||||
"side": "agent",
|
||||
"remote": addr.String(),
|
||||
"mode": "agent",
|
||||
"remote": cfg.Addr.String(),
|
||||
"client": cfg.Username,
|
||||
},
|
||||
}),
|
||||
clt: clt,
|
||||
addr: addr,
|
||||
remoteDomainName: remoteDomainName,
|
||||
clientName: clientName,
|
||||
authMethods: []ssh.AuthMethod{ssh.PublicKeys(signers...)},
|
||||
accessPoint: accessPoint,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
authMethods: []ssh.AuthMethod{ssh.PublicKeys(cfg.Signers...)},
|
||||
}
|
||||
a.hostKeyCallback = a.checkHostSignature
|
||||
return a, nil
|
||||
|
@ -110,8 +134,7 @@ func (a *Agent) Close() error {
|
|||
func (a *Agent) Start() error {
|
||||
conn, err := a.connect()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create remote tunnel for %v on %s(%s): %v",
|
||||
a.clientName, a.remoteDomainName, a.addr.FullAddress(), err)
|
||||
a.Warningf("Failed to create remote tunnel: %v", err)
|
||||
}
|
||||
// start heartbeat even if error happend, it will reconnect
|
||||
go a.runHeartbeat(conn)
|
||||
|
@ -125,7 +148,7 @@ func (a *Agent) Wait() error {
|
|||
|
||||
// String returns debug-friendly
|
||||
func (a *Agent) String() string {
|
||||
return fmt.Sprintf("tunagent(remote=%s)", a.addr.String())
|
||||
return fmt.Sprintf("tunagent(remote=%s)", a.Addr.String())
|
||||
}
|
||||
|
||||
func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.PublicKey) error {
|
||||
|
@ -133,7 +156,7 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub
|
|||
if !ok {
|
||||
return trace.BadParameter("expected certificate")
|
||||
}
|
||||
cas, err := a.accessPoint.GetCertAuthorities(services.HostCA, false)
|
||||
cas, err := a.AccessPoint.GetCertAuthorities(services.HostCA, false)
|
||||
if err != nil {
|
||||
return trace.Wrap(err, "failed to fetch remote certs")
|
||||
}
|
||||
|
@ -144,7 +167,7 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub
|
|||
}
|
||||
for _, checker := range checkers {
|
||||
if sshutils.KeysEqual(checker, cert.SignatureKey) {
|
||||
a.log.Debugf("matched key %v for %v", ca.GetName(), hostport)
|
||||
a.Debugf("matched key %v for %v", ca.GetName(), hostport)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -154,14 +177,11 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub
|
|||
}
|
||||
|
||||
func (a *Agent) connect() (conn *ssh.Client, err error) {
|
||||
if a.addr.IsEmpty() {
|
||||
return nil, trace.BadParameter("reverse tunnel cannot be created: target address is empty")
|
||||
}
|
||||
for _, authMethod := range a.authMethods {
|
||||
// if http_proxy is set, dial through the proxy
|
||||
dialer := proxy.DialerFromEnvironment()
|
||||
conn, err = dialer.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{
|
||||
User: a.clientName,
|
||||
conn, err = dialer.Dial(a.Addr.AddrNetwork, a.Addr.Addr, &ssh.ClientConfig{
|
||||
User: a.Username,
|
||||
Auth: []ssh.AuthMethod{authMethod},
|
||||
HostKeyCallback: a.hostKeyCallback,
|
||||
Timeout: defaults.DefaultDialTimeout,
|
||||
|
@ -174,12 +194,12 @@ func (a *Agent) connect() (conn *ssh.Client, err error) {
|
|||
}
|
||||
|
||||
func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) {
|
||||
log.Debugf("[HA Agent] proxyAccessPoint")
|
||||
a.Debugf("proxyAccessPoint")
|
||||
defer ch.Close()
|
||||
|
||||
conn, err := a.clt.GetDialer()()
|
||||
conn, err := a.Client.GetDialer()()
|
||||
if err != nil {
|
||||
a.log.Errorf("error dialing: %v", err)
|
||||
a.Warningf("error dialing: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -215,7 +235,7 @@ func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) {
|
|||
// ch : SSH channel which received "teleport-transport" out-of-band request
|
||||
// reqC : request payload
|
||||
func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
|
||||
log.Debugf("[HA Agent] proxyTransport")
|
||||
a.Debugf("proxyTransport")
|
||||
defer ch.Close()
|
||||
|
||||
// always push space into stderr to make sure the caller can always
|
||||
|
@ -226,15 +246,15 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
|
|||
var req *ssh.Request
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
a.log.Infof("is closed, returning")
|
||||
a.Infof("is closed, returning")
|
||||
return
|
||||
case req = <-reqC:
|
||||
if req == nil {
|
||||
a.log.Infof("connection closed, returning")
|
||||
a.Infof("connection closed, returning")
|
||||
return
|
||||
}
|
||||
case <-time.After(defaults.DefaultDialTimeout):
|
||||
a.log.Errorf("timeout waiting for dial")
|
||||
a.Warningf("timeout waiting for dial")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -245,9 +265,9 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
|
|||
// list of auth servers and return that. otherwise try and connect to the
|
||||
// passed in server.
|
||||
if server == RemoteAuthServer {
|
||||
authServers, err := a.clt.GetAuthServers()
|
||||
authServers, err := a.Client.GetAuthServers()
|
||||
if err != nil {
|
||||
a.log.Errorf("unable to find auth servers: %v", err)
|
||||
a.Warningf("unable to find auth servers: %v", err)
|
||||
return
|
||||
}
|
||||
for _, as := range authServers {
|
||||
|
@ -257,7 +277,7 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
|
|||
servers = append(servers, server)
|
||||
}
|
||||
|
||||
log.Debugf("got out of band request %v", servers)
|
||||
a.Debugf("got out of band request %v", servers)
|
||||
|
||||
var conn net.Conn
|
||||
var err error
|
||||
|
@ -284,7 +304,7 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
|
|||
|
||||
// successfully dialed
|
||||
req.Reply(true, []byte("connected"))
|
||||
a.log.Infof("successfully dialed to %v, start proxying", server)
|
||||
a.Debugf("successfully dialed to %v, start proxying", server)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
@ -317,7 +337,7 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) {
|
|||
if conn == nil {
|
||||
return trace.Errorf("heartbeat cannot ping: need to reconnect")
|
||||
}
|
||||
log.Infof("[TUNNEL CLIENT] connected to %s", conn.RemoteAddr())
|
||||
a.Infof("connected to %s", conn.RemoteAddr())
|
||||
defer conn.Close()
|
||||
hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil)
|
||||
if err != nil {
|
||||
|
@ -336,26 +356,26 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) {
|
|||
return nil
|
||||
// time to ping:
|
||||
case <-ticker.C:
|
||||
log.Debugf("[TUNNEL CLIENT] pings \"%s\" at %s", a.remoteDomainName, conn.RemoteAddr())
|
||||
_, err := hb.SendRequest("ping", false, nil)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
a.Debugf("ping -> %v", conn.RemoteAddr())
|
||||
// ssh channel closed:
|
||||
case req := <-reqC:
|
||||
if req == nil {
|
||||
return trace.Errorf("heartbeat: connection closed")
|
||||
return trace.ConnectionProblem(nil, "heartbeat: connection closed")
|
||||
}
|
||||
// new access point request:
|
||||
case nch := <-newAccesspointC:
|
||||
if nch == nil {
|
||||
continue
|
||||
}
|
||||
a.log.Infof("[TUNNEL CLIENT] access point request: %v", nch.ChannelType())
|
||||
a.Debugf("access point request: %v", nch.ChannelType())
|
||||
ch, req, err := nch.Accept()
|
||||
if err != nil {
|
||||
a.log.Errorf("failed to accept request: %v", err)
|
||||
a.Warningf("failed to accept request: %v", err)
|
||||
continue
|
||||
}
|
||||
go a.proxyAccessPoint(ch, req)
|
||||
|
@ -364,10 +384,22 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) {
|
|||
if nch == nil {
|
||||
continue
|
||||
}
|
||||
a.log.Infof("[TUNNEL CLIENT] transport request: %v", nch.ChannelType())
|
||||
a.Debugf("transport request: %v", nch.ChannelType())
|
||||
ch, req, err := nch.Accept()
|
||||
if err != nil {
|
||||
a.log.Errorf("failed to accept request: %v", err)
|
||||
a.Warningf("failed to accept request: %v", err)
|
||||
continue
|
||||
}
|
||||
go a.proxyTransport(ch, req)
|
||||
// new discovery request
|
||||
case nch := <-newTransportC:
|
||||
if nch == nil {
|
||||
continue
|
||||
}
|
||||
a.Debugf("transport request: %v", nch.ChannelType())
|
||||
ch, req, err := nch.Accept()
|
||||
if err != nil {
|
||||
a.Warningf("failed to accept request: %v", err)
|
||||
continue
|
||||
}
|
||||
go a.proxyTransport(ch, req)
|
||||
|
@ -398,11 +430,50 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) {
|
|||
}
|
||||
}
|
||||
|
||||
// handleDisovery receives discovery requests from the reverse tunnel
|
||||
// server, that informs agent about proxies registered in the remote
|
||||
// cluster and the reverse tunnels already established
|
||||
//
|
||||
// ch : SSH channel which received "teleport-transport" out-of-band request
|
||||
// reqC : request payload
|
||||
func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) {
|
||||
a.Debugf("handleDiscovery")
|
||||
defer ch.Close()
|
||||
|
||||
for {
|
||||
var req *ssh.Request
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
a.Infof("is closed, returning")
|
||||
return
|
||||
case req = <-reqC:
|
||||
if req == nil {
|
||||
a.Infof("connection closed, returning")
|
||||
return
|
||||
}
|
||||
r, err := unmarshalDiscoveryRequest(req.Payload)
|
||||
if err != nil {
|
||||
a.Warningf("bad payload: %v", err)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case a.DiscoveryC <- r:
|
||||
case <-a.ctx.Done():
|
||||
a.Infof("is closed, returning")
|
||||
return
|
||||
default:
|
||||
}
|
||||
req.Reply(true, []byte("thanks"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
chanHeartbeat = "teleport-heartbeat"
|
||||
chanAccessPoint = "teleport-access-point"
|
||||
chanTransport = "teleport-transport"
|
||||
chanTransportDialReq = "teleport-transport-dial"
|
||||
chanDiscovery = "teleport-discovery"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -24,10 +24,11 @@ import (
|
|||
type AgentPool struct {
|
||||
sync.Mutex
|
||||
*log.Entry
|
||||
cfg AgentPoolConfig
|
||||
agents map[agentKey]*Agent
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
cfg AgentPoolConfig
|
||||
agents map[agentKey]*Agent
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
discoveryC chan *discoveryRequest
|
||||
}
|
||||
|
||||
// AgentPoolConfig holds configuration parameters for the agent pool
|
||||
|
@ -68,19 +69,19 @@ func (cfg *AgentPoolConfig) CheckAndSetDefaults() error {
|
|||
|
||||
// NewAgentPool returns new isntance of the agent pool
|
||||
func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) {
|
||||
if err := cfg.CheckAndSetDefaults(); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(cfg.Context)
|
||||
pool := &AgentPool{
|
||||
agents: make(map[agentKey]*Agent),
|
||||
cfg: cfg,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
agents: make(map[agentKey]*Agent),
|
||||
cfg: cfg,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
discoveryC: make(chan *discoveryRequest),
|
||||
}
|
||||
pool.Entry = log.WithFields(log.Fields{
|
||||
teleport.Component: teleport.ComponentReverseTunnel,
|
||||
teleport.ComponentFields: map[string]interface{}{
|
||||
"side": "agent",
|
||||
"mode": "agentpool",
|
||||
},
|
||||
teleport.Component: teleport.ComponentReverseTunnelAgent,
|
||||
})
|
||||
return pool, nil
|
||||
}
|
||||
|
@ -105,6 +106,22 @@ func (m *AgentPool) Wait() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *AgentPool) processDiscoveryRequests() {
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
m.Debugf("closing")
|
||||
return
|
||||
case req := <-m.discoveryC:
|
||||
if req == nil {
|
||||
m.Debugf("channel closed")
|
||||
return
|
||||
}
|
||||
m.Debugf("got discovery request, following proxies are not connected %v", req.Proxies)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FetchAndSyncAgents executes one time fetch and sync request
|
||||
// (used in tests instead of polling)
|
||||
func (m *AgentPool) FetchAndSyncAgents() error {
|
||||
|
@ -160,7 +177,16 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error {
|
|||
|
||||
for _, key := range agentsToAdd {
|
||||
m.Debugf("adding %v", &key)
|
||||
agent, err := NewAgent(key.addr, key.domainName, m.cfg.HostUUID, m.cfg.HostSigners, m.cfg.Client, m.cfg.AccessPoint, m.ctx)
|
||||
agent, err := NewAgent(AgentConfig{
|
||||
Addr: key.addr,
|
||||
RemoteCluster: key.domainName,
|
||||
Username: m.cfg.HostUUID,
|
||||
Signers: m.cfg.HostSigners,
|
||||
Client: m.cfg.Client,
|
||||
AccessPoint: m.cfg.AccessPoint,
|
||||
Context: m.ctx,
|
||||
DiscoveryC: m.discoveryC,
|
||||
})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
|
52
lib/reversetunnel/discovery.go
Normal file
52
lib/reversetunnel/discovery.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package reversetunnel
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
type discoveryRequest struct {
|
||||
Proxies []services.Server `json:"proxies"`
|
||||
}
|
||||
|
||||
type discoveryRequestRaw struct {
|
||||
Proxies []json.RawMessage `json:"proxies"`
|
||||
}
|
||||
|
||||
func marshalDiscoveryRequest(req discoveryRequest) ([]byte, error) {
|
||||
var out discoveryRequestRaw
|
||||
m := services.GetServerMarshaler()
|
||||
for _, p := range req.Proxies {
|
||||
data, err := m.MarshalServer(p)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
out.Proxies = append(out.Proxies, data)
|
||||
}
|
||||
|
||||
return json.Marshal(out)
|
||||
}
|
||||
|
||||
func unmarshalDiscoveryRequest(data []byte) (*discoveryRequest, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, trace.BadParameter("missing payload")
|
||||
}
|
||||
var raw discoveryRequestRaw
|
||||
err := json.Unmarshal(data, &raw)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
m := services.GetServerMarshaler()
|
||||
var out discoveryRequest
|
||||
for _, bytes := range raw.Proxies {
|
||||
proxy, err := m.UnmarshalServer([]byte(bytes), services.KindProxy)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
out.Proxies = append(out.Proxies, proxy)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
|
@ -40,11 +40,9 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
|
|||
accessPoint: accessPoint,
|
||||
domainName: domainName,
|
||||
log: log.WithFields(log.Fields{
|
||||
teleport.Component: teleport.ComponentReverseTunnel,
|
||||
teleport.Component: teleport.ComponentReverseTunnelServer,
|
||||
teleport.ComponentFields: map[string]string{
|
||||
"domainName": domainName,
|
||||
"side": "server",
|
||||
"type": "localSite",
|
||||
"cluster": domainName,
|
||||
},
|
||||
}),
|
||||
}, nil
|
||||
|
@ -77,7 +75,7 @@ func (s *localSite) GetClient() (auth.ClientI, error) {
|
|||
}
|
||||
|
||||
func (s *localSite) String() string {
|
||||
return fmt.Sprintf("localSite(%v)", s.domainName)
|
||||
return fmt.Sprintf("local(%v)", s.domainName)
|
||||
}
|
||||
|
||||
func (s *localSite) GetStatus() string {
|
||||
|
@ -94,7 +92,7 @@ func (s *localSite) GetLastConnected() time.Time {
|
|||
|
||||
// Dial dials a given host in this site (cluster).
|
||||
func (s *localSite) Dial(from net.Addr, to net.Addr) (net.Conn, error) {
|
||||
s.log.Debugf("[PROXY] localSite.Dial(from=%v, to=%v)", from, to)
|
||||
s.log.Debugf("local.Dial(from=%v, to=%v)", from, to)
|
||||
return net.Dial(to.Network(), to.String())
|
||||
}
|
||||
|
||||
|
|
|
@ -127,10 +127,9 @@ func newClusterPeer(srv *server, connInfo services.TunnelConnection) (*clusterPe
|
|||
srv: srv,
|
||||
connInfo: connInfo,
|
||||
log: log.WithFields(log.Fields{
|
||||
teleport.Component: teleport.ComponentReverseTunnel,
|
||||
teleport.Component: teleport.ComponentReverseTunnelServer,
|
||||
teleport.ComponentFields: map[string]string{
|
||||
"cluster": connInfo.GetClusterName(),
|
||||
"side": "server",
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
|
|
@ -14,9 +14,11 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
|
||||
*/
|
||||
|
||||
package reversetunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
@ -44,17 +46,17 @@ import (
|
|||
type remoteSite struct {
|
||||
sync.Mutex
|
||||
|
||||
log *log.Entry
|
||||
*log.Entry
|
||||
domainName string
|
||||
connections []*remoteConn
|
||||
lastUsed int
|
||||
lastActive time.Time
|
||||
srv *server
|
||||
|
||||
transport *http.Transport
|
||||
clt *auth.Client
|
||||
accessPoint auth.AccessPoint
|
||||
connInfo services.TunnelConnection
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *remoteSite) CachingAccessPoint() (auth.AccessPoint, error) {
|
||||
|
@ -100,7 +102,7 @@ func (s *remoteSite) addConn(conn net.Conn, sshConn ssh.Conn) (*remoteConn, erro
|
|||
rc := &remoteConn{
|
||||
sshConn: sshConn,
|
||||
conn: conn,
|
||||
log: s.log,
|
||||
log: s.Entry,
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
|
@ -114,7 +116,7 @@ func (s *remoteSite) addConn(conn net.Conn, sshConn ssh.Conn) (*remoteConn, erro
|
|||
func (s *remoteSite) getLatestTunnelConnection() (services.TunnelConnection, error) {
|
||||
conns, err := s.srv.AccessPoint.GetTunnelConnections(s.domainName)
|
||||
if err != nil {
|
||||
s.log.Warningf("[TUNNEL] failed to fetch tunnel statuses: %v", err)
|
||||
s.Warningf("failed to fetch tunnel statuses: %v", err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
var lastConn services.TunnelConnection
|
||||
|
@ -125,7 +127,7 @@ func (s *remoteSite) getLatestTunnelConnection() (services.TunnelConnection, err
|
|||
}
|
||||
}
|
||||
if lastConn == nil {
|
||||
return nil, trace.NotFound("no connections from %v found in the cluster", s.domainName)
|
||||
return nil, trace.NotFound("no connections found")
|
||||
}
|
||||
return lastConn, nil
|
||||
}
|
||||
|
@ -146,24 +148,27 @@ func (s *remoteSite) registerHeartbeat(t time.Time) {
|
|||
s.connInfo.SetLastHeartbeat(t)
|
||||
err := s.srv.AccessPoint.UpsertTunnelConnection(s.connInfo)
|
||||
if err != nil {
|
||||
log.Warningf("[TUNNEL] failed to register heartbeat: %v", err)
|
||||
log.Warningf("failed to register heartbeat: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
|
||||
defer func() {
|
||||
s.log.Infof("[TUNNEL] cluster connection closed: %v", s.domainName)
|
||||
s.Infof("cluster connection closed")
|
||||
conn.Close()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.Infof("closing")
|
||||
return
|
||||
case req := <-reqC:
|
||||
if req == nil {
|
||||
s.log.Infof("[TUNNEL] cluster disconnected: %v", s.domainName)
|
||||
s.Infof("cluster disconnected")
|
||||
conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
|
||||
return
|
||||
}
|
||||
log.Debugf("[TUNNEL] ping from \"%s\" %s", s.domainName, conn.conn.RemoteAddr())
|
||||
s.Debugf("ping <- %v", conn.conn.RemoteAddr())
|
||||
go s.registerHeartbeat(time.Now())
|
||||
case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod):
|
||||
conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats"))
|
||||
|
@ -183,10 +188,96 @@ func (s *remoteSite) GetLastConnected() time.Time {
|
|||
return connInfo.GetLastHeartbeat()
|
||||
}
|
||||
|
||||
func (s *remoteSite) periodicSendDiscoveryRequests() {
|
||||
ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod)
|
||||
defer ticker.Stop()
|
||||
if err := s.sendDiscoveryRequest(); err != nil {
|
||||
s.Warningf("failed to fetch cluster peers: %v", err)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.Debugf("closing")
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := s.sendDiscoveryRequest()
|
||||
if err != nil {
|
||||
s.Warningf("could not send discovery request: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findDisconnectedProxies
|
||||
func (s *remoteSite) findDisconnectedProxies() ([]services.Server, error) {
|
||||
conns, err := s.srv.AccessPoint.GetTunnelConnections(s.domainName)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
connected := make(map[string]bool)
|
||||
for _, conn := range conns {
|
||||
connected[conn.GetProxyName()] = true
|
||||
}
|
||||
proxies, err := s.srv.AccessPoint.GetProxies()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
var missing []services.Server
|
||||
for i := range proxies {
|
||||
proxy := proxies[i]
|
||||
if !connected[proxy.GetName()] {
|
||||
missing = append(missing, proxy)
|
||||
}
|
||||
}
|
||||
return missing, nil
|
||||
}
|
||||
|
||||
func (s *remoteSite) sendDiscoveryRequest() error {
|
||||
disconnectedProxies, err := s.findDisconnectedProxies()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if len(disconnectedProxies) == 0 {
|
||||
return nil
|
||||
}
|
||||
s.Infof("detected disconnected proxies: %v", disconnectedProxies)
|
||||
req := discoveryRequest{
|
||||
Proxies: disconnectedProxies,
|
||||
}
|
||||
payload, err := marshalDiscoveryRequest(req)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
send := func() error {
|
||||
remoteConn, err := s.nextConn()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
discoveryC, err := remoteConn.openDiscoveryChannel()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
_, err = discoveryC.SendRequest("ping", false, payload)
|
||||
remoteConn.markInvalid(err)
|
||||
s.Errorf("disconnecting cluster on %v, err: %v",
|
||||
remoteConn.conn.RemoteAddr(),
|
||||
err)
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
for i := 0; i < s.connectionCount(); i++ {
|
||||
err := send()
|
||||
if err != nil {
|
||||
s.Warningf("%v")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dialAccessPoint establishes a connection from the proxy (reverse tunnel server)
|
||||
// back into the client using previously established tunnel.
|
||||
func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) {
|
||||
s.log.Infof("[TUNNEL] dial to site '%s'", s.GetName())
|
||||
s.Debugf("dialAccessPoint")
|
||||
|
||||
try := func() (net.Conn, error) {
|
||||
remoteConn, err := s.nextConn()
|
||||
|
@ -196,13 +287,12 @@ func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) {
|
|||
ch, _, err := remoteConn.sshConn.OpenChannel(chanAccessPoint, nil)
|
||||
if err != nil {
|
||||
remoteConn.markInvalid(err)
|
||||
s.log.Errorf("[TUNNEL] disconnecting site '%s' on %v. Err: %v",
|
||||
s.GetName(),
|
||||
s.Errorf("disconnecting cluster on %v, err: %v",
|
||||
remoteConn.conn.RemoteAddr(),
|
||||
err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
s.log.Infof("[TUNNEL] success dialing to site '%s'", s.GetName())
|
||||
s.Infof("success dialing to cluster")
|
||||
return utils.NewChConn(remoteConn.sshConn, ch), nil
|
||||
}
|
||||
|
||||
|
@ -222,7 +312,7 @@ func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) {
|
|||
// located in a remote connected site, the connection goes through the
|
||||
// reverse proxy tunnel.
|
||||
func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) {
|
||||
s.log.Infof("[TUNNEL] dialing %v@%v through the tunnel", to, s.domainName)
|
||||
s.Debugf("dialing %v through the tunnel", to)
|
||||
stop := false
|
||||
|
||||
_, addr := to.Network(), to.String()
|
||||
|
@ -268,7 +358,7 @@ func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) {
|
|||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
s.log.Errorf("[TUNNEL] Dial(addr=%v) failed: %v", addr, err)
|
||||
s.Warningf("Dial(addr=%v) failed: %v", addr, err)
|
||||
}
|
||||
// didn't connect and no error? this means we didn't have any connected
|
||||
// tunnels to try
|
||||
|
@ -279,9 +369,9 @@ func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) {
|
|||
}
|
||||
|
||||
func (s *remoteSite) handleAuthProxy(w http.ResponseWriter, r *http.Request) {
|
||||
s.log.Infof("[TUNNEL] handleAuthProxy()")
|
||||
s.Debugf("handleAuthProxy()")
|
||||
|
||||
fwd, err := forward.New(forward.RoundTripper(s.transport), forward.Logger(s.log))
|
||||
fwd, err := forward.New(forward.RoundTripper(s.transport), forward.Logger(s.Entry))
|
||||
if err != nil {
|
||||
roundtrip.ReplyJSON(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
|
|
|
@ -158,7 +158,7 @@ func NewServer(cfg Config) (Server, error) {
|
|||
|
||||
var err error
|
||||
s, err := sshutils.NewServer(
|
||||
teleport.ComponentReverseTunnel,
|
||||
teleport.ComponentReverseTunnelServer,
|
||||
cfg.ListenAddr,
|
||||
srv,
|
||||
cfg.HostSigners,
|
||||
|
@ -604,11 +604,29 @@ func (s *server) RemoveSite(domainName string) error {
|
|||
}
|
||||
|
||||
type remoteConn struct {
|
||||
sshConn ssh.Conn
|
||||
conn net.Conn
|
||||
invalid int32
|
||||
log *log.Entry
|
||||
counter int32
|
||||
sshConn ssh.Conn
|
||||
conn net.Conn
|
||||
invalid int32
|
||||
log *log.Entry
|
||||
counter int32
|
||||
discoveryC ssh.Channel
|
||||
discoveryErr error
|
||||
}
|
||||
|
||||
func (rc *remoteConn) openDiscoveryChannel() (ssh.Channel, error) {
|
||||
if rc.discoveryC != nil {
|
||||
return rc.discoveryC, nil
|
||||
}
|
||||
if rc.discoveryErr != nil {
|
||||
return nil, trace.Wrap(rc.discoveryErr)
|
||||
}
|
||||
discoveryC, _, err := rc.sshConn.OpenChannel(chanDiscovery, nil)
|
||||
if err != nil {
|
||||
rc.discoveryErr = err
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
rc.discoveryC = discoveryC
|
||||
return rc.discoveryC, nil
|
||||
}
|
||||
|
||||
func (rc *remoteConn) String() string {
|
||||
|
@ -616,6 +634,10 @@ func (rc *remoteConn) String() string {
|
|||
}
|
||||
|
||||
func (rc *remoteConn) Close() error {
|
||||
if rc.discoveryC != nil {
|
||||
rc.discoveryC.Close()
|
||||
rc.discoveryC = nil
|
||||
}
|
||||
return rc.sshConn.Close()
|
||||
}
|
||||
|
||||
|
@ -645,13 +667,13 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) {
|
|||
srv: srv,
|
||||
domainName: domainName,
|
||||
connInfo: connInfo,
|
||||
log: log.WithFields(log.Fields{
|
||||
teleport.Component: teleport.ComponentReverseTunnel,
|
||||
Entry: log.WithFields(log.Fields{
|
||||
teleport.Component: teleport.ComponentReverseTunnelServer,
|
||||
teleport.ComponentFields: map[string]string{
|
||||
"domainName": domainName,
|
||||
"side": "server",
|
||||
"cluster": domainName,
|
||||
},
|
||||
}),
|
||||
ctx: srv.ctx,
|
||||
}
|
||||
// transport uses connection do dial out to the remote address
|
||||
remoteSite.transport = &http.Transport{
|
||||
|
|
|
@ -44,8 +44,7 @@ const (
|
|||
// InitLogger configures the global logger for a given purpose / verbosity level
|
||||
func InitLogger(purpose LoggingPurpose, level log.Level) {
|
||||
log.StandardLogger().Hooks = make(log.LevelHooks)
|
||||
formatter := &trace.TextFormatter{}
|
||||
formatter.DisableTimestamp = true
|
||||
formatter := &trace.TextFormatter{DisableTimestamp: true}
|
||||
log.SetFormatter(formatter)
|
||||
log.SetLevel(level)
|
||||
|
||||
|
|
129
vendor/github.com/gravitational/trace/log.go
generated
vendored
129
vendor/github.com/gravitational/trace/log.go
generated
vendored
|
@ -18,7 +18,12 @@ limitations under the License.
|
|||
package trace
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
|
@ -40,20 +45,55 @@ const (
|
|||
// TextFormatter is logrus-compatible formatter and adds
|
||||
// file and line details to every logged entry.
|
||||
type TextFormatter struct {
|
||||
log.TextFormatter
|
||||
DisableTimestamp bool
|
||||
}
|
||||
|
||||
// Format implements logrus.Formatter interface and adds file and line
|
||||
func (tf *TextFormatter) Format(e *log.Entry) ([]byte, error) {
|
||||
var file string
|
||||
if frameNo := findFrame(); frameNo != -1 {
|
||||
t := newTrace(frameNo, nil)
|
||||
new := e.WithFields(log.Fields{FileField: t.Loc(), FunctionField: t.FuncName()})
|
||||
new.Time = e.Time
|
||||
new.Level = e.Level
|
||||
new.Message = e.Message
|
||||
e = new
|
||||
file = t.Loc()
|
||||
}
|
||||
return (&tf.TextFormatter).Format(e)
|
||||
|
||||
w := &writer{bytes.Buffer{}}
|
||||
|
||||
// time
|
||||
if !tf.DisableTimestamp {
|
||||
w.writeField(e.Time.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// level
|
||||
w.writeField(strings.ToUpper(padMax(e.Level.String(), 4)))
|
||||
|
||||
// component if present, highly visible
|
||||
component, ok := e.Data[Component]
|
||||
if ok {
|
||||
if w.Len() > 0 {
|
||||
w.WriteByte(' ')
|
||||
}
|
||||
w.WriteByte('[')
|
||||
w.WriteString(strings.ToUpper(padMax(fmt.Sprintf("%v", component), 11)))
|
||||
w.WriteByte(']')
|
||||
}
|
||||
|
||||
// message
|
||||
if e.Message != "" {
|
||||
w.writeField(e.Message)
|
||||
}
|
||||
|
||||
// file, if present
|
||||
if file != "" {
|
||||
w.writeField(file)
|
||||
}
|
||||
|
||||
// rest of the fields
|
||||
if len(e.Data) > 0 {
|
||||
w.WriteByte(' ')
|
||||
w.writeMap(e.Data)
|
||||
}
|
||||
w.WriteByte('\n')
|
||||
return w.Bytes(), nil
|
||||
}
|
||||
|
||||
// JSONFormatter implements logrus.Formatter interface and adds file and line
|
||||
|
@ -91,3 +131,78 @@ func findFrame() int {
|
|||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
type writer struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (w *writer) writeField(value interface{}) {
|
||||
if w.Len() > 0 {
|
||||
w.WriteByte(' ')
|
||||
}
|
||||
w.writeValue(value)
|
||||
}
|
||||
|
||||
func (w *writer) writeValue(value interface{}) {
|
||||
stringVal, ok := value.(string)
|
||||
if !ok {
|
||||
stringVal = fmt.Sprint(value)
|
||||
}
|
||||
if !needsQuoting(stringVal) {
|
||||
w.WriteString(stringVal)
|
||||
} else {
|
||||
w.WriteString(fmt.Sprintf("%q", stringVal))
|
||||
}
|
||||
}
|
||||
|
||||
func (w *writer) writeKeyValue(key string, value interface{}) {
|
||||
if w.Len() > 0 {
|
||||
w.WriteByte(' ')
|
||||
}
|
||||
w.WriteString(key)
|
||||
w.WriteByte(':')
|
||||
w.writeValue(value)
|
||||
}
|
||||
|
||||
func (w *writer) writeMap(m map[string]interface{}) {
|
||||
if len(m) == 0 {
|
||||
return
|
||||
}
|
||||
keys := make([]string, 0, len(m))
|
||||
for key := range m {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, key := range keys {
|
||||
if key == Component {
|
||||
continue
|
||||
}
|
||||
switch val := m[key].(type) {
|
||||
case map[string]interface{}:
|
||||
w.WriteString(key)
|
||||
w.WriteString(":{")
|
||||
w.writeMap(val)
|
||||
w.WriteString(" }")
|
||||
default:
|
||||
w.writeKeyValue(key, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func needsQuoting(text string) bool {
|
||||
for _, ch := range text {
|
||||
if ch < 32 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func padMax(in string, chars int) string {
|
||||
switch {
|
||||
case len(in) < chars:
|
||||
return in + strings.Repeat(" ", chars-len(in))
|
||||
default:
|
||||
return in[:chars]
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue