Improve CertAuthorityWatcher (#10403)

* Improve CertAuthorityWatcher

CertAuthorityWatcher and its usage are refactored to allow for
all the following:
 - eliminate retransmission of the same CAs
 - reduce memory usage by having one local watcher per proxy
 - adds the ability to filter only the CAs that are desired
 - reduce the time required to send the first CAs

watchCertAuthorities now compares all CAs it receives from the
watcher with the previous CA of the same type and only sends to
the remote site if they are not identical. This is to reduce
unnecessary network traffic which can be problematic for a
root cluster with a larger number of leafs.

The CertAuthorityWatcher is refactored to leverage a fanout
to emit events to any number of watchers, each subscription
can be for a subset of the configured CA types. The proxy
now has only one CertAuthorityWatcher that is passed around
similarly to the LockWatcher. This reduces the memory usage
for proxies, which prior to this has one local CAWatcher per
remote site.

updateCertAuthorities no longer waits on the utils.Retry it
is provided with before starting to watch CAs. By doing this
the proxy no longer has to wait ~8 minutes before it even
starts to watch CAs.
This commit is contained in:
rosstimothy 2022-05-17 15:06:41 -04:00 committed by GitHub
parent 414c82a341
commit 1ac0957d0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 423 additions and 254 deletions

View file

@ -269,7 +269,7 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error {
Clock: p.clock,
Client: p.siteAPI,
},
WatchCertTypes: []types.CertAuthType{p.certType},
Types: []types.CertAuthType{p.certType},
})
if err != nil {
return err
@ -280,16 +280,30 @@ func (p *phaseWatcher) waitForPhase(phase string, fn func() error) error {
return trace.Wrap(err)
}
sub, err := watcher.Subscribe(ctx, services.CertAuthorityTarget{
ClusterName: p.clusterRootName,
Type: p.certType,
})
if err != nil {
return trace.Wrap(err)
}
defer sub.Close()
var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q certType: %v err: %v", phase, lastPhase, p.certType, ctx.Err())
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == p.clusterRootName &&
ca.GetType() == p.certType &&
ca.GetRotation().Phase == phase {
case <-sub.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q certType: %v err: %v", phase, lastPhase, p.certType, sub.Error())
case evt := <-sub.Events():
switch evt.Type {
case types.OpPut:
ca, ok := evt.Resource.(types.CertAuthority)
if !ok {
return trace.BadParameter("expected a ca got type %T", evt.Resource)
}
if ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase

View file

@ -4097,45 +4097,39 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
err = waitForProcessEvent(svc, service.TeleportPhaseChangeEvent, 10*time.Second)
require.NoError(t, err)
// waitForPhase waits until aux cluster detects the rotation
waitForPhase := func(phase string) error {
ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10)
defer cancel()
watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Clock: tconf.Clock,
Client: aux.GetSiteAPI(clusterAux),
},
WatchCertTypes: []types.CertAuthType{types.HostCA},
Types: []types.CertAuthType{types.HostCA},
})
if err != nil {
return err
}
defer watcher.Close()
var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == clusterMain &&
ca.GetType() == types.HostCA &&
ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase
}
}
}
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)
}
err = waitForPhase(types.RotationPhaseInit)
require.NoError(t, err)
t.Cleanup(watcher.Close)
// waitForPhase waits until aux cluster detects the rotation
waitForPhase := func(phase string) {
require.Eventually(t, func() bool {
ca, err := aux.Process.GetAuthServer().GetCertAuthority(
ctx,
types.CertAuthID{
Type: types.HostCA,
DomainName: clusterMain,
}, false)
if err != nil {
return false
}
if ca.GetRotation().Phase == phase {
return true
}
return false
}, tconf.PollingPeriod*10, tconf.PollingPeriod/2, "failed to converge to phase %q", phase)
}
waitForPhase(types.RotationPhaseInit)
// update clients
err = svc.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{
@ -4148,8 +4142,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)
err = waitForPhase(types.RotationPhaseUpdateClients)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateClients)
// old client should work as is
err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*")
@ -4168,8 +4161,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = waitForReload(serviceC, svc)
require.NoError(t, err)
err = waitForPhase(types.RotationPhaseUpdateServers)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateServers)
// new credentials will work from this phase to others
newCreds, err := GenerateUserCreds(UserCredsRequest{Process: svc, Username: suite.me.Username})
@ -4197,8 +4189,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)
t.Log("Service reload completed, waiting for phase.")
err = waitForPhase(types.RotationPhaseStandby)
require.NoError(t, err)
waitForPhase(types.RotationPhaseStandby)
t.Log("Phase completed.")
// new client still works

View file

@ -427,20 +427,12 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error {
return trace.CompareFailed("remote certificate authority rotation has been updated")
}
func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteClusterVersion string) {
s.Debugf("Watching for cert authority changes.")
func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *services.CertAuthorityWatcher, remoteVersion string) {
defer remoteWatcher.Close()
cas := make(map[types.CertAuthType]types.CertAuthority)
for {
startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}
err := s.watchCertAuthorities(remoteClusterVersion)
err := s.watchCertAuthorities(remoteWatcher, remoteVersion, cas)
if err != nil {
switch {
case trace.IsNotFound(err):
@ -456,70 +448,88 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteClusterVersi
}
}
startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}
}
}
func (s *remoteSite) watchCertAuthorities(remoteClusterVersion string) error {
localWatchedTypes, err := s.getLocalWatchedCerts(remoteClusterVersion)
func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error {
targets, err := s.getLocalWatchedCerts(remoteVersion)
if err != nil {
return trace.Wrap(err)
}
localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.localAccessPoint,
localWatch, err := s.srv.CertAuthorityWatcher.Subscribe(s.ctx, targets...)
if err != nil {
return trace.Wrap(err)
}
defer func() {
if err := localWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close local ca watcher subscription.")
}
}()
remoteWatch, err := remoteWatcher.Subscribe(
s.ctx,
services.CertAuthorityTarget{
ClusterName: s.domainName,
Type: types.HostCA,
},
WatchCertTypes: localWatchedTypes,
})
)
if err != nil {
return trace.Wrap(err)
}
defer localWatcher.Close()
remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.remoteAccessPoint,
},
WatchCertTypes: []types.CertAuthType{types.HostCA},
})
if err != nil {
return trace.Wrap(err)
defer func() {
if err := remoteWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close remote ca watcher subscription.")
}
defer remoteWatcher.Close()
}()
s.Debugf("Watching for cert authority changes.")
for {
select {
case <-s.ctx.Done():
s.WithError(s.ctx.Err()).Debug("Context is closing.")
return trace.Wrap(s.ctx.Err())
case <-localWatcher.Done():
case <-localWatch.Done():
s.Warn("Local CertAuthority watcher subscription has closed")
return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName)
case <-remoteWatcher.Done():
case <-remoteWatch.Done():
s.Warn("Remote CertAuthority watcher subscription has closed")
return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName)
case cas := <-localWatcher.CertAuthorityC:
for _, localCA := range cas {
if localCA.GetClusterName() != s.srv.ClusterName ||
!localWatcher.IsWatched(localCA.GetType()) {
case evt := <-localWatch.Events():
switch evt.Type {
case types.OpPut:
localCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}
ca, ok := cas[localCA.GetType()]
if ok && services.CertAuthoritiesEquivalent(ca, localCA) {
continue
}
// clone to prevent a race with watcher filtering
localCA = localCA.Clone()
if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, localCA); err != nil {
s.WithError(err).Warn("Failed to rotate external ca")
log.WithError(err).Warn("Failed to rotate external ca")
return trace.Wrap(err)
}
cas[localCA.GetType()] = localCA
}
case cas := <-remoteWatcher.CertAuthorityC:
for _, remoteCA := range cas {
if remoteCA.GetType() != types.HostCA ||
remoteCA.GetClusterName() != s.domainName {
case evt := <-remoteWatch.Events():
switch evt.Type {
case types.OpPut:
remoteCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}
@ -549,8 +559,17 @@ func (s *remoteSite) watchCertAuthorities(remoteClusterVersion string) error {
}
// getLocalWatchedCerts returns local certificates types that should be watched by the cert authority watcher.
func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]types.CertAuthType, error) {
localWatchedTypes := []types.CertAuthType{types.HostCA, types.UserCA}
func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]services.CertAuthorityTarget, error) {
localWatchedTypes := []services.CertAuthorityTarget{
{
Type: types.HostCA,
ClusterName: s.srv.ClusterName,
},
{
Type: types.UserCA,
ClusterName: s.srv.ClusterName,
},
}
// Delete in 11.0.
ver10orAbove, err := utils.MinVerWithoutPreRelease(remoteClusterVersion, constants.DatabaseCAMinVersion)
@ -559,7 +578,7 @@ func (s *remoteSite) getLocalWatchedCerts(remoteClusterVersion string) ([]types.
}
if ver10orAbove {
localWatchedTypes = append(localWatchedTypes, types.DatabaseCA)
localWatchedTypes = append(localWatchedTypes, services.CertAuthorityTarget{ClusterName: s.srv.ClusterName, Type: types.DatabaseCA})
} else {
s.Debugf("Connected to remote cluster of version %s. Database CA won't be propagated.", remoteClusterVersion)
}

View file

@ -22,6 +22,7 @@ import (
"testing"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
@ -31,34 +32,48 @@ func Test_remoteSite_getLocalWatchedCerts(t *testing.T) {
tests := []struct {
name string
clusterVersion string
want []types.CertAuthType
wantErr bool
want []services.CertAuthorityTarget
errorAssertion require.ErrorAssertionFunc
}{
{
name: "pre Database CA, only Host and User CA",
clusterVersion: "9.0.0",
want: []types.CertAuthType{types.HostCA, types.UserCA},
want: []services.CertAuthorityTarget{
{Type: types.HostCA, ClusterName: "test"},
{Type: types.UserCA, ClusterName: "test"},
},
errorAssertion: require.NoError,
},
{
name: "all certs should be returned",
clusterVersion: "10.0.0",
want: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
want: []services.CertAuthorityTarget{
{Type: types.HostCA, ClusterName: "test"},
{Type: types.UserCA, ClusterName: "test"},
{Type: types.DatabaseCA, ClusterName: "test"},
},
errorAssertion: require.NoError,
},
{
name: "invalid version",
clusterVersion: "foo",
wantErr: true,
errorAssertion: require.Error,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &remoteSite{
srv: &server{
Config: Config{
ClusterName: "test",
},
},
Entry: log.NewEntry(utils.NewLoggerForTests()),
}
got, err := s.getLocalWatchedCerts(tt.clusterVersion)
if (err != nil) != tt.wantErr {
t.Errorf("getLocalWatchedCerts() error = %v, wantErr %v", err, tt.wantErr)
tt.errorAssertion(t, err)
if err != nil {
return
}

View file

@ -205,6 +205,9 @@ type Config struct {
// NodeWatcher is a node watcher.
NodeWatcher *services.NodeWatcher
// CertAuthorityWatcher is a cert authority watcher.
CertAuthorityWatcher *services.CertAuthorityWatcher
}
// CheckAndSetDefaults checks parameters and sets default values
@ -259,6 +262,9 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.NodeWatcher == nil {
return trace.BadParameter("missing parameter NodeWatcher")
}
if cfg.CertAuthorityWatcher == nil {
return trace.BadParameter("missing parameter CertAuthorityWatcher")
}
return nil
}
@ -1040,6 +1046,11 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
connInfo.SetExpiry(srv.Clock.Now().Add(srv.offlineThreshold))
closeContext, cancel := context.WithCancel(srv.ctx)
defer func() {
if err != nil {
cancel()
}
}()
remoteSite := &remoteSite{
srv: srv,
domainName: domainName,
@ -1063,20 +1074,17 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
clt, _, err := remoteSite.getRemoteClient()
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteClient = clt
remoteVersion, err := getRemoteAuthVersion(closeContext, sconn)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteAccessPoint = accessPoint
@ -1088,7 +1096,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
},
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.nodeWatcher = nodeWatcher
@ -1098,7 +1105,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
// is signed by the correct certificate authority.
certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.certificateCache = certificateCache
@ -1111,11 +1117,25 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
go remoteSite.updateCertAuthorities(caRetry, remoteVersion)
remoteWatcher, err := services.NewCertAuthorityWatcher(srv.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: srv.log,
Clock: srv.Clock,
Client: remoteSite.remoteAccessPoint,
},
Types: []types.CertAuthType{types.HostCA},
})
if err != nil {
return nil, trace.Wrap(err)
}
go func() {
remoteSite.updateCertAuthorities(caRetry, remoteWatcher, remoteVersion)
}()
lockRetry, err := utils.NewLinear(utils.LinearConfig{
First: utils.HalfJitter(srv.Config.PollingPeriod),
@ -1125,7 +1145,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

View file

@ -2853,6 +2853,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}
caWatcher, err := services.NewCertAuthorityWatcher(process.ExitContext(), services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: process.log.WithField(trace.Component, teleport.ComponentProxy),
Client: conn.Client,
},
AuthorityGetter: accessPoint,
Types: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
})
if err != nil {
return trace.Wrap(err)
}
serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites)
if err != nil {
return trace.Wrap(err)
@ -2893,6 +2906,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
if err != nil {
return trace.Wrap(err)

View file

@ -919,10 +919,8 @@ type CertAuthorityWatcherConfig struct {
ResourceWatcherConfig
// AuthorityGetter is responsible for fetching cert authority resources.
AuthorityGetter
// CertAuthorityC receives up-to-date list of all cert authority resources.
CertAuthorityC chan []types.CertAuthority
// WatchCertTypes stores all certificate types that should be monitored.
WatchCertTypes []types.CertAuthType
// Types restricts which cert authority types are retrieved via the AuthorityGetter.
Types []types.CertAuthType
}
// CheckAndSetDefaults checks parameters and sets default values.
@ -937,15 +935,12 @@ func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error {
}
cfg.AuthorityGetter = getter
}
if cfg.CertAuthorityC == nil {
cfg.CertAuthorityC = make(chan []types.CertAuthority)
}
return nil
}
// IsWatched return true if the given certificate auth type is being observer by the watcher.
func (cfg *CertAuthorityWatcherConfig) IsWatched(certType types.CertAuthType) bool {
for _, observedType := range cfg.WatchCertTypes {
for _, observedType := range cfg.Types {
if observedType == certType {
return true
}
@ -961,6 +956,12 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig
collector := &caCollector{
CertAuthorityWatcherConfig: cfg,
fanout: NewFanout(),
cas: make(map[types.CertAuthType]map[string]types.CertAuthority, len(cfg.Types)),
}
for _, t := range cfg.Types {
collector.cas[t] = make(map[string]types.CertAuthority)
}
watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig)
@ -968,6 +969,7 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig
return nil, trace.Wrap(err)
}
collector.fanout.SetInit()
return &CertAuthorityWatcher{watcher, collector}, nil
}
@ -980,26 +982,66 @@ type CertAuthorityWatcher struct {
// caCollector accompanies resourceWatcher when monitoring cert authority resources.
type caCollector struct {
CertAuthorityWatcherConfig
fanout *Fanout
collectedCAs CertAuthorityTypeMap
// lock protects concurrent access to cas
lock sync.RWMutex
// cas maps ca type -> cluster -> ca
cas map[types.CertAuthType]map[string]types.CertAuthority
}
// CertAuthorityMap maps clusterName -> types.CertAuthority
type CertAuthorityMap map[string]types.CertAuthority
// CertAuthorityTypeMap maps types.CertAuthType -> map(clusterName -> types.CertAuthority)
type CertAuthorityTypeMap map[types.CertAuthType]CertAuthorityMap
// ToSlice converts CertAuthorityTypeMap to a slice.
func (cat *CertAuthorityTypeMap) ToSlice() []types.CertAuthority {
slice := make([]types.CertAuthority, 0)
for _, cert := range *cat {
for _, ca := range cert {
slice = append(slice, ca)
// CertAuthorityTarget lists the attributes of interactions to be disabled.
type CertAuthorityTarget struct {
// ClusterName specifies the name of the cluster to watch.
ClusterName string
// Type specifies the ca types to watch for.
Type types.CertAuthType
}
// Subscribe is used to subscribe to the lock updates.
func (c *caCollector) Subscribe(ctx context.Context, targets ...CertAuthorityTarget) (types.Watcher, error) {
watchKinds, err := caTargetToWatchKinds(targets)
if err != nil {
return nil, trace.Wrap(err)
}
return slice
sub, err := c.fanout.NewWatcher(ctx, types.Watch{Kinds: watchKinds})
if err != nil {
return nil, trace.Wrap(err)
}
select {
case event := <-sub.Events():
if event.Type != types.OpInit {
return nil, trace.BadParameter("expected init event, got %v instead", event.Type)
}
case <-sub.Done():
return nil, trace.Wrap(sub.Error())
}
return sub, nil
}
func caTargetToWatchKinds(targets []CertAuthorityTarget) ([]types.WatchKind, error) {
watchKinds := make([]types.WatchKind, 0, len(targets))
for _, target := range targets {
kind := types.WatchKind{
Kind: types.KindCertAuthority,
// Note that watching SubKind doesn't work for types.WatchKind - to do so it would
// require a custom filter, which was recently added but - we can't use yet due to
// older clients not supporting the filter.
SubKind: string(target.Type),
}
if target.ClusterName != "" {
kind.Name = target.ClusterName
}
watchKinds = append(watchKinds, kind)
}
if len(watchKinds) == 0 {
watchKinds = []types.WatchKind{{Kind: types.KindCertAuthority}}
}
return watchKinds, nil
}
// resourceKind specifies the resource kind to watch.
@ -1009,28 +1051,27 @@ func (c *caCollector) resourceKind() string {
// getResourcesAndUpdateCurrent refreshes the list of current resources.
func (c *caCollector) getResourcesAndUpdateCurrent(ctx context.Context) error {
updatedCerts := make(CertAuthorityTypeMap)
var cas []types.CertAuthority
for _, caType := range c.WatchCertTypes {
cas, err := c.AuthorityGetter.GetCertAuthorities(ctx, caType, false)
for _, t := range c.Types {
authorities, err := c.AuthorityGetter.GetCertAuthorities(ctx, t, false)
if err != nil {
return trace.Wrap(err)
}
updatedCerts[caType] = make(CertAuthorityMap, len(cas))
for _, ca := range cas {
updatedCerts[caType][ca.GetName()] = ca
}
cas = append(cas, authorities...)
}
c.lock.Lock()
c.collectedCAs = updatedCerts
c.lock.Unlock()
defer c.lock.Unlock()
select {
case <-ctx.Done():
return trace.Wrap(ctx.Err())
case c.CertAuthorityC <- updatedCerts.ToSlice():
for _, ca := range cas {
if !c.watchingType(ca.GetType()) {
continue
}
c.cas[ca.GetType()][ca.GetName()] = ca
c.fanout.Emit(types.Event{Type: types.OpPut, Resource: ca.Clone()})
}
return nil
}
@ -1046,17 +1087,12 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty
switch event.Type {
case types.OpDelete:
caType := types.CertAuthType(event.Resource.GetSubKind())
// Check if the certificate should be processed.
_, found := c.collectedCAs[caType]
if found {
delete(c.collectedCAs[caType], event.Resource.GetName())
if !c.watchingType(caType) {
return
}
select {
case <-ctx.Done():
case c.CertAuthorityC <- c.collectedCAs.ToSlice():
}
delete(c.cas[caType], event.Resource.GetName())
c.fanout.Emit(event)
case types.OpPut:
ca, ok := event.Resource.(types.CertAuthority)
if !ok {
@ -1064,28 +1100,31 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty
return
}
caType := ca.GetType()
_, found := c.collectedCAs[caType]
// Check if the certificate should be processed.
if found {
c.collectedCAs[caType][ca.GetName()] = ca
if !c.watchingType(ca.GetType()) {
return
}
select {
case <-ctx.Done():
case c.CertAuthorityC <- c.collectedCAs.ToSlice():
authority, ok := c.cas[ca.GetType()][ca.GetName()]
if ok && CertAuthoritiesEquivalent(authority, ca) {
return
}
c.cas[ca.GetType()][ca.GetName()] = ca
c.fanout.Emit(event)
default:
c.Log.Warnf("Unsupported event type %s.", event.Type)
return
}
}
// GetCurrent returns the currently stored authorities.
func (c *caCollector) GetCurrent() []types.CertAuthority {
c.lock.RLock()
defer c.lock.RUnlock()
return c.collectedCAs.ToSlice()
func (c *caCollector) watchingType(t types.CertAuthType) bool {
for _, caType := range c.Types {
if caType == t {
return true
}
}
return false
}
func (c *caCollector) notifyStale() {}

View file

@ -37,6 +37,7 @@ import (
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/backend/lite"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/tlsca"
@ -520,9 +521,10 @@ func resourceDiff(res1, res2 types.Resource) string {
func caDiff(ca1, ca2 types.CertAuthority) string {
return cmp.Diff(ca1, ca2,
cmpopts.IgnoreFields(types.Metadata{}, "ID"),
cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs"),
cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs", "JWTKeyPairs"),
cmpopts.IgnoreFields(types.SSHKeyPair{}, "PrivateKey"),
cmpopts.IgnoreFields(types.TLSKeyPair{}, "Key"),
cmpopts.IgnoreFields(types.JWTKeyPair{}, "PrivateKey"),
cmpopts.EquateEmpty(),
)
}
@ -723,10 +725,12 @@ func newApp(t *testing.T, name string) types.Application {
func TestCertAuthorityWatcher(t *testing.T) {
t.Parallel()
ctx := context.Background()
clock := clockwork.NewFakeClock()
bk, err := lite.NewWithConfig(ctx, lite.Config{
Path: t.TempDir(),
PollStreamPeriod: 200 * time.Millisecond,
Clock: clock,
})
require.NoError(t, err)
@ -744,85 +748,88 @@ func TestCertAuthorityWatcher(t *testing.T) {
Trust: caService,
Events: local.NewEventsService(bk),
},
Clock: clock,
},
CertAuthorityC: make(chan []types.CertAuthority, 10),
WatchCertTypes: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
Types: []types.CertAuthType{types.HostCA, types.UserCA, types.DatabaseCA},
})
require.NoError(t, err)
t.Cleanup(w.Close)
nothingWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "test",
MaxRetryPeriod: 200 * time.Millisecond,
Client: &client{
Trust: caService,
Events: local.NewEventsService(bk),
},
},
CertAuthorityC: make(chan []types.CertAuthority, 10),
})
target := services.CertAuthorityTarget{ClusterName: "test"}
sub, err := w.Subscribe(ctx, target)
require.NoError(t, err)
t.Cleanup(nothingWatcher.Close)
t.Cleanup(func() { require.NoError(t, sub.Close()) })
require.Empty(t, w.GetCurrent())
require.Empty(t, nothingWatcher.GetCurrent())
// Initially there are no cas so watcher should send an empty list.
// create a CA for the cluster and a type we are filtering for
// and ensure we receive the event
ca := newCertAuthority(t, "test", types.HostCA)
require.NoError(t, caService.UpsertCertAuthority(ca))
select {
case changeset := <-w.CertAuthorityC:
require.Len(t, changeset, 0)
require.Empty(t, nothingWatcher.GetCurrent())
case <-w.Done():
t.Fatal("Watcher has unexpectedly exited.")
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for the first event.")
case event := <-sub.Events():
caFromEvent, ok := event.Resource.(types.CertAuthority)
require.True(t, ok)
require.Empty(t, caDiff(ca, caFromEvent))
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
// Add an authority.
ca1 := newCertAuthority(t, "ca1", types.HostCA)
require.NoError(t, caService.CreateCertAuthority(ca1))
// The first event is always the current list of apps.
// create a CA with a type we are filtering for another cluster that we are NOT filtering for
// and ensure that we DO NOT receive the event
require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.UserCA)))
select {
case changeset := <-w.CertAuthorityC:
require.Len(t, changeset, 1)
require.Empty(t, caDiff(changeset[0], ca1))
require.Empty(t, nothingWatcher.GetCurrent())
case <-w.Done():
t.Fatal("Watcher has unexpectedly exited.")
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for the first event.")
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}
// Add a second ca.
ca2 := newCertAuthority(t, "ca2", types.UserCA)
require.NoError(t, caService.CreateCertAuthority(ca2))
// Watcher should detect the ca list change.
// create a CA for the cluster and a type we are filtering for
// and ensure we receive the event
ca2 := newCertAuthority(t, "test", types.UserCA)
require.NoError(t, caService.UpsertCertAuthority(ca2))
select {
case changeset := <-w.CertAuthorityC:
require.Len(t, changeset, 2)
require.Empty(t, nothingWatcher.GetCurrent())
case <-w.Done():
t.Fatal("Watcher has unexpectedly exited.")
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for the update event.")
case event := <-sub.Events():
caFromEvent, ok := event.Resource.(types.CertAuthority)
require.True(t, ok)
require.Empty(t, caDiff(ca2, caFromEvent))
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
// Delete the first ca.
require.NoError(t, caService.DeleteCertAuthority(ca1.GetID()))
// Watcher should detect the ca list change.
// delete a CA with type being watched in the cluster we are filtering for
// and ensure we receive the event
require.NoError(t, caService.DeleteCertAuthority(ca.GetID()))
select {
case changeset := <-w.CertAuthorityC:
require.Len(t, changeset, 1)
require.Empty(t, caDiff(changeset[0], ca2))
require.Empty(t, nothingWatcher.GetCurrent())
case <-w.Done():
t.Fatal("Watcher has unexpectedly exited.")
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for the update event.")
case event := <-sub.Events():
require.Equal(t, types.KindCertAuthority, event.Resource.GetKind())
require.Equal(t, string(types.HostCA), event.Resource.GetSubKind())
require.Equal(t, "test", event.Resource.GetName())
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
// create a CA with a type we are NOT filtering for but for a cluster we are filtering for
// and ensure we DO NOT receive the event
signer := newCertAuthority(t, "test", types.JWTSigner)
require.NoError(t, caService.UpsertCertAuthority(signer))
select {
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}
// delete a CA with a name we are filtering for but a type we are NOT filtering for
// and ensure we do NOT receive the event
require.NoError(t, caService.DeleteCertAuthority(signer.GetID()))
select {
case event := <-sub.Events():
t.Fatalf("Unexpected event: %v.", event)
case <-sub.Done():
t.Fatal("CA watcher subscription has unexpectedly exited.")
case <-time.After(time.Second):
}
}
@ -839,15 +846,25 @@ func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) type
Type: caType,
ClusterName: name,
ActiveKeys: types.CAKeySet{
SSH: []*types.SSHKeyPair{{
SSH: []*types.SSHKeyPair{
{
PrivateKey: priv,
PrivateKeyType: types.PrivateKeyType_RAW,
PublicKey: pub,
}},
TLS: []*types.TLSKeyPair{{
},
},
TLS: []*types.TLSKeyPair{
{
Cert: cert,
Key: key,
}},
},
},
JWT: []*types.JWTKeyPair{
{
PublicKey: []byte(fixtures.JWTSignerPublicKey),
PrivateKey: []byte(fixtures.JWTSignerPrivateKey),
},
},
},
Roles: nil,
SigningAlg: types.CertAuthoritySpecV2_RSA_SHA2_256,

View file

@ -1126,6 +1126,7 @@ func TestProxyRoundRobin(t *testing.T) {
defer listener.Close()
lockWatcher := newLockWatcher(ctx, t, proxyClient)
nodeWatcher := newNodeWatcher(ctx, t, proxyClient)
caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient)
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ClusterName: f.testSrv.ClusterName(),
@ -1143,6 +1144,7 @@ func TestProxyRoundRobin(t *testing.T) {
Log: logger,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
require.NoError(t, err)
logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.")
@ -1252,6 +1254,7 @@ func TestProxyDirectAccess(t *testing.T) {
proxyClient, _ := newProxyClient(t, f.testSrv)
lockWatcher := newLockWatcher(ctx, t, proxyClient)
nodeWatcher := newNodeWatcher(ctx, t, proxyClient)
caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient)
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ClientTLS: proxyClient.TLSConfig(),
@ -1269,6 +1272,7 @@ func TestProxyDirectAccess(t *testing.T) {
Log: logger,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
require.NoError(t, err)
@ -1863,6 +1867,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) {
proxyClient, _ := newProxyClient(t, f.testSrv)
lockWatcher := newLockWatcher(ctx, t, proxyClient)
nodeWatcher := newNodeWatcher(ctx, t, proxyClient)
caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient)
reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{
ClientTLS: proxyClient.TLSConfig(),
@ -1880,6 +1885,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) {
Log: logger,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
require.NoError(t, err)
@ -2098,6 +2104,19 @@ func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *ser
return nodeWatcher
}
func newCertAuthorityWatcher(ctx context.Context, t *testing.T, client types.Events) *services.CertAuthorityWatcher {
caWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "test",
Client: client,
},
Types: []types.CertAuthType{types.HostCA, types.UserCA},
})
require.NoError(t, err)
t.Cleanup(caWatcher.Close)
return caWatcher
}
// maxPipeSize is one larger than the maximum pipe size for most operating
// systems which appears to be 65536 bytes.
//

View file

@ -275,6 +275,16 @@ func newWebSuite(t *testing.T) *WebSuite {
})
require.NoError(t, err)
caWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Client: s.proxyClient,
},
Types: []types.CertAuthType{types.HostCA, types.UserCA},
})
require.NoError(t, err)
defer caWatcher.Close()
revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{
ID: node.ID(),
Listener: revTunListener,
@ -289,6 +299,7 @@ func newWebSuite(t *testing.T) *WebSuite {
DataDir: t.TempDir(),
LockWatcher: proxyLockWatcher,
NodeWatcher: proxyNodeWatcher,
CertAuthorityWatcher: caWatcher,
})
require.NoError(t, err)
s.proxyTunnel = revTunServer
@ -3746,6 +3757,16 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula
require.NoError(t, err)
t.Cleanup(proxyLockWatcher.Close)
proxyCAWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Client: client,
},
Types: []types.CertAuthType{types.HostCA, types.UserCA},
})
require.NoError(t, err)
t.Cleanup(proxyLockWatcher.Close)
proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
@ -3769,6 +3790,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula
DataDir: t.TempDir(),
LockWatcher: proxyLockWatcher,
NodeWatcher: proxyNodeWatcher,
CertAuthorityWatcher: proxyCAWatcher,
})
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, revTunServer.Close()) })