From 7467e47718c8586b84e96008960f62fd58120ff9 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Tue, 30 Apr 2019 16:17:09 -0700 Subject: [PATCH] Cache auth servers and new find endpoint Whenever many IOT style nodes are connecting back to the web proxy server, they all call /find endpoint to discover the configuration. This new endpoint is designed to be fast and not hit the database. In addition to that every proxy reverse tunnel connection handler was fetching auth servers and this commit adds caching for the auth servers on the proxy side. --- lib/auth/api.go | 3 ++ lib/auth/auth_with_roles.go | 20 ++++++++++ lib/auth/clt.go | 10 +++++ lib/cache/cache.go | 6 +++ lib/cache/cache_test.go | 67 ++++++++++++++++++++++++++++++++ lib/cache/collections.go | 71 ++++++++++++++++++++++++++++++++++ lib/client/weblogin.go | 18 +++++++++ lib/reversetunnel/conn.go | 2 +- lib/reversetunnel/srv.go | 2 +- lib/service/connect.go | 4 +- lib/services/local/events.go | 24 ++++++++++++ lib/services/local/presence.go | 12 ++++++ lib/services/presence.go | 6 +++ lib/web/apiserver.go | 12 ++++++ 14 files changed, 253 insertions(+), 4 deletions(-) diff --git a/lib/auth/api.go b/lib/auth/api.go index a3e413f1c29..1d57e57b602 100644 --- a/lib/auth/api.go +++ b/lib/auth/api.go @@ -63,6 +63,9 @@ type ReadAccessPoint interface { // GetProxies returns a list of proxy servers registered in the cluster GetProxies() ([]services.Server, error) + // GetAuthServers returns a list of auth servers registered in the cluster + GetAuthServers() ([]services.Server, error) + // GetCertAuthority returns cert authority by id GetCertAuthority(id services.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (services.CertAuthority, error) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index e03ba890797..7fda67de319 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -392,6 +392,10 @@ func (a *AuthWithRoles) NewWatcher(ctx context.Context, watch services.Watch) (s if err := a.action(defaults.Namespace, services.KindProxy, services.VerbRead); err != nil { return nil, trace.Wrap(err) } + case services.KindAuthServer: + if err := a.action(defaults.Namespace, services.KindAuthServer, services.VerbRead); err != nil { + return nil, trace.Wrap(err) + } case services.KindTunnelConnection: if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbRead); err != nil { return nil, trace.Wrap(err) @@ -552,6 +556,22 @@ func (a *AuthWithRoles) GetAuthServers() ([]services.Server, error) { return a.authServer.GetAuthServers() } +// DeleteAllAuthServers deletes all auth servers +func (a *AuthWithRoles) DeleteAllAuthServers() error { + if err := a.action(defaults.Namespace, services.KindAuthServer, services.VerbDelete); err != nil { + return trace.Wrap(err) + } + return a.authServer.DeleteAllAuthServers() +} + +// DeleteAuthServer deletes auth server by name +func (a *AuthWithRoles) DeleteAuthServer(name string) error { + if err := a.action(defaults.Namespace, services.KindAuthServer, services.VerbDelete); err != nil { + return trace.Wrap(err) + } + return a.authServer.DeleteAuthServer(name) +} + func (a *AuthWithRoles) UpsertProxy(s services.Server) error { if err := a.action(defaults.Namespace, services.KindProxy, services.VerbCreate); err != nil { return trace.Wrap(err) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 80afdf77bf9..87336f784a0 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -1185,6 +1185,16 @@ func (c *Client) GetAuthServers() ([]services.Server, error) { return re, nil } +// DeleteAllAuthServers deletes all auth servers +func (c *Client) DeleteAllAuthServers() error { + return trace.NotImplemented("not implemented") +} + +// DeleteAuthServer deletes auth server by name +func (c *Client) DeleteAuthServer(name string) error { + return trace.NotImplemented("not implemented") +} + // UpsertProxy is used by proxies to report their presence // to other auth servers in form of hearbeat expiring after ttl period. func (c *Client) UpsertProxy(s services.Server) error { diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 1ec91850d44..86b01f1102b 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -64,6 +64,7 @@ func ForProxy(cfg Config) Config { {Kind: services.KindNamespace}, {Kind: services.KindNode}, {Kind: services.KindProxy}, + {Kind: services.KindAuthServer}, {Kind: services.KindReverseTunnel}, {Kind: services.KindTunnelConnection}, } @@ -612,6 +613,11 @@ func (c *Cache) GetNodes(namespace string, opts ...services.MarshalOption) ([]se return c.presenceCache.GetNodes(namespace, opts...) } +// GetAuthServers returns a list of registered servers +func (c *Cache) GetAuthServers() ([]services.Server, error) { + return c.presenceCache.GetAuthServers() +} + // GetReverseTunnels is a part of auth.AccessPoint implementation func (c *Cache) GetReverseTunnels(opts ...services.MarshalOption) ([]services.ReverseTunnel, error) { return c.presenceCache.GetReverseTunnels(services.AddOptions(opts, services.SkipValidation())...) diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index 80a66cf22be..6581b11574b 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -1073,6 +1073,73 @@ func (s *CacheSuite) TestProxies(c *check.C) { c.Assert(out, check.HasLen, 0) } +// TestAuthServers tests auth servers cache +func (s *CacheSuite) TestAuthServers(c *check.C) { + p := s.newPackForProxy(c) + defer p.Close() + + server := suite.NewServer(services.KindAuthServer, "srv1", "127.0.0.1:2022", defaults.Namespace) + err := p.presenceS.UpsertAuthServer(server) + c.Assert(err, check.IsNil) + + out, err := p.presenceS.GetAuthServers() + c.Assert(err, check.IsNil) + c.Assert(out, check.HasLen, 1) + srv := out[0] + + select { + case event := <-p.eventsC: + c.Assert(event.Type, check.Equals, EventProcessed) + case <-time.After(time.Second): + c.Fatalf("timeout waiting for event") + } + + out, err = p.cache.GetAuthServers() + c.Assert(err, check.IsNil) + c.Assert(out, check.HasLen, 1) + + srv.SetResourceID(out[0].GetResourceID()) + fixtures.DeepCompare(c, srv, out[0]) + + // update srv parameters + srv.SetAddr("127.0.0.2:2033") + + err = p.presenceS.UpsertAuthServer(srv) + c.Assert(err, check.IsNil) + + out, err = p.presenceS.GetAuthServers() + c.Assert(err, check.IsNil) + c.Assert(out, check.HasLen, 1) + srv = out[0] + + select { + case event := <-p.eventsC: + c.Assert(event.Type, check.Equals, EventProcessed) + case <-time.After(time.Second): + c.Fatalf("timeout waiting for event") + } + + out, err = p.cache.GetAuthServers() + c.Assert(err, check.IsNil) + c.Assert(out, check.HasLen, 1) + + srv.SetResourceID(out[0].GetResourceID()) + fixtures.DeepCompare(c, srv, out[0]) + + err = p.presenceS.DeleteAllAuthServers() + c.Assert(err, check.IsNil) + + select { + case <-p.eventsC: + case <-time.After(time.Second): + c.Fatalf("timeout waiting for event") + } + + out, err = p.cache.GetAuthServers() + c.Assert(err, check.IsNil) + c.Assert(out, check.HasLen, 0) +} + type proxyEvents struct { sync.Mutex watchers []services.Watcher diff --git a/lib/cache/collections.go b/lib/cache/collections.go index b1fdfdbd2af..2c7746c690a 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -93,6 +93,11 @@ func setupCollections(c *Cache, watches []services.WatchKind) (map[string]collec return nil, trace.BadParameter("missing parameter Presence") } collections[watch.Kind] = &proxy{watch: watch, Cache: c} + case services.KindAuthServer: + if c.Presence == nil { + return nil, trace.BadParameter("missing parameter Presence") + } + collections[watch.Kind] = &authServer{watch: watch, Cache: c} case services.KindReverseTunnel: if c.Presence == nil { return nil, trace.BadParameter("missing parameter Presence") @@ -303,6 +308,72 @@ func (c *proxy) watchKind() services.WatchKind { return c.watch } +type authServer struct { + *Cache + watch services.WatchKind +} + +// erase erases all data in the collection +func (c *authServer) erase() error { + if err := c.presenceCache.DeleteAllAuthServers(); err != nil { + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + } + return nil +} + +func (c *authServer) fetch() error { + resources, err := c.Presence.GetAuthServers() + if err != nil { + return trace.Wrap(err) + } + + if err := c.erase(); err != nil { + return trace.Wrap(err) + } + + for _, resource := range resources { + c.setTTL(resource) + if err := c.presenceCache.UpsertAuthServer(resource); err != nil { + return trace.Wrap(err) + } + } + return nil +} + +func (c *authServer) processEvent(event services.Event) error { + switch event.Type { + case backend.OpDelete: + err := c.presenceCache.DeleteAuthServer(event.Resource.GetName()) + if err != nil { + // resource could be missing in the cache + // expired or not created, if the first consumed + // event is delete + if !trace.IsNotFound(err) { + c.Warningf("Failed to delete resource %v.", err) + return trace.Wrap(err) + } + } + case backend.OpPut: + resource, ok := event.Resource.(services.Server) + if !ok { + return trace.BadParameter("unexpected type %T", event.Resource) + } + c.setTTL(resource) + if err := c.presenceCache.UpsertAuthServer(resource); err != nil { + return trace.Wrap(err) + } + default: + c.Warningf("Skipping unsupported event type %v.", event.Type) + } + return nil +} + +func (c *authServer) watchKind() services.WatchKind { + return c.watch +} + type node struct { *Cache watch services.WatchKind diff --git a/lib/client/weblogin.go b/lib/client/weblogin.go index 748857e21e9..1fa2fc66ca3 100644 --- a/lib/client/weblogin.go +++ b/lib/client/weblogin.go @@ -318,6 +318,24 @@ func (c *CredentialsClient) Ping(ctx context.Context, connectorName string) (*Pi return pr, nil } +// Find is like ping, but used by servers to only fetch discovery data, +// without auth connector data, it is designed for servers in IOT mode +// to fetch proxy public addresses on a large scale. +func (c *CredentialsClient) Find(ctx context.Context) (*PingResponse, error) { + response, err := c.clt.Get(ctx, c.clt.Endpoint("webapi", "find"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + + var pr *PingResponse + err = json.Unmarshal(response.Bytes(), &pr) + if err != nil { + return nil, trace.Wrap(err) + } + + return pr, nil +} + // SSHAgentSSOLogin is used by tsh to fetch user credentials using OpenID Connect (OIDC) or SAML. func (c *CredentialsClient) SSHAgentSSOLogin(login SSHLogin) (*auth.SSHLoginResponse, error) { rd, err := NewRedirector(login) diff --git a/lib/reversetunnel/conn.go b/lib/reversetunnel/conn.go index 8cb70764e7e..4b8509fae24 100644 --- a/lib/reversetunnel/conn.go +++ b/lib/reversetunnel/conn.go @@ -324,7 +324,7 @@ type transportParams struct { component string log *logrus.Entry closeContext context.Context - authClient auth.ClientI + authClient auth.AccessPoint channel ssh.Channel requestCh <-chan *ssh.Request diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index fd5b62fa221..04f9741e7b1 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -530,7 +530,7 @@ func (s *server) handleTransport(sconn *ssh.ServerConn, nch ssh.NewChannel) { go proxyTransport(&transportParams{ log: s.Entry, closeContext: s.ctx, - authClient: s.LocalAuthClient, + authClient: s.LocalAccessPoint, channel: channel, requestCh: requestCh, component: teleport.ComponentReverseTunnelServer, diff --git a/lib/service/connect.go b/lib/service/connect.go index cc5f825118e..112419bbcf9 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -668,7 +668,7 @@ func (process *TeleportProcess) rotate(conn *Connector, localState auth.StateV2, // rollback cycle. case "", services.RotationStateStandby: if principalsOrDNSNamesChanged { - process.Infof("Service %v has updated principals to %q, DNS Names to %q, going to request new principals and update.", id.Role, additionalPrincipals) + process.Infof("Service %v has updated principals to %q, DNS Names to %q, going to request new principals and update.", id.Role, additionalPrincipals, dnsNames) identity, err := process.reRegister(conn, additionalPrincipals, dnsNames, remote) if err != nil { return nil, trace.Wrap(err) @@ -820,7 +820,7 @@ func (process *TeleportProcess) findReverseTunnel(addrs []utils.NetAddr) (string return "", trace.Wrap(err) } - resp, err := clt.Ping(process.ExitContext(), "") + resp, err := clt.Find(process.ExitContext()) if err == nil { // If a tunnel public address is set, return it otherwise return the // tunnel listen address. diff --git a/lib/services/local/events.go b/lib/services/local/events.go index ec0261b53ea..af6288b1b58 100644 --- a/lib/services/local/events.go +++ b/lib/services/local/events.go @@ -76,6 +76,8 @@ func (e *EventsService) NewWatcher(ctx context.Context, watch services.Watch) (s parser = newNodeParser() case services.KindProxy: parser = newProxyParser() + case services.KindAuthServer: + parser = newAuthServerParser() case services.KindTunnelConnection: parser = newTunnelConnectionParser() case services.KindReverseTunnel: @@ -567,6 +569,28 @@ func (p *proxyParser) parse(event backend.Event) (services.Resource, error) { return parseServer(event, services.KindProxy) } +func newAuthServerParser() *authServerParser { + return &authServerParser{ + matchPrefix: backend.Key(authServersPrefix), + } +} + +type authServerParser struct { + matchPrefix []byte +} + +func (p *authServerParser) prefix() []byte { + return p.matchPrefix +} + +func (p *authServerParser) match(key []byte) bool { + return bytes.HasPrefix(key, p.matchPrefix) +} + +func (p *authServerParser) parse(event backend.Event) (services.Resource, error) { + return parseServer(event, services.KindAuthServer) +} + func newTunnelConnectionParser() *tunnelConnectionParser { return &tunnelConnectionParser{ matchPrefix: backend.Key(tunnelConnectionsPrefix), diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go index 0884b8a525d..0cdbdea0aa7 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -312,6 +312,18 @@ func (s *PresenceService) UpsertAuthServer(server services.Server) error { return s.upsertServer(authServersPrefix, server) } +// DeleteAllAuthServers deletes all auth servers +func (s *PresenceService) DeleteAllAuthServers() error { + startKey := backend.Key(authServersPrefix) + return s.DeleteRange(context.TODO(), startKey, backend.RangeEnd(startKey)) +} + +// DeleteAuthServer deletes auth server by name +func (s *PresenceService) DeleteAuthServer(name string) error { + key := backend.Key(authServersPrefix, name) + return s.Delete(context.TODO(), key) +} + // UpsertProxy registers proxy server presence, permanently if ttl is 0 or // for the specified duration with second resolution if it's >= 1 second func (s *PresenceService) UpsertProxy(server services.Server) error { diff --git a/lib/services/presence.go b/lib/services/presence.go index b535b29af27..097025199a9 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -60,6 +60,12 @@ type Presence interface { // for the specified duration with second resolution if it's >= 1 second UpsertAuthServer(server Server) error + // DeleteAuthServer deletes auth server by name + DeleteAuthServer(name string) error + + // DeleteAllAuthServers deletes all auth servers + DeleteAllAuthServers() error + // UpsertProxy registers proxy server presence, permanently if ttl is 0 or // for the specified duration with second resolution if it's >= 1 second UpsertProxy(server Server) error diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index a54eabad572..2a618b0edcc 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -153,6 +153,10 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*RewritingHandler, error) { // query the authentication configuration for a specific connector. h.GET("/webapi/ping", httplib.MakeHandler(h.ping)) h.GET("/webapi/ping/:connector", httplib.MakeHandler(h.pingWithConnector)) + // find is like ping, but is faster because it is optimized for servers + // and does not fetch the data that servers don't need, e.g. + // OIDC connectors and auth preferences + h.GET("/webapi/find", httplib.MakeHandler(h.find)) // Web sessions h.POST("/webapi/sessions", httplib.WithCSRFProtection(h.createSession)) @@ -514,6 +518,14 @@ func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Para }, nil } +func (h *Handler) find(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + return client.PingResponse{ + Proxy: h.cfg.ProxySettings, + ServerVersion: teleport.Version, + MinClientVersion: teleport.MinClientVersion, + }, nil +} + func (h *Handler) pingWithConnector(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { authClient := h.cfg.ProxyClient connectorName := p.ByName("connector")