AWS cross-account db discovery (#22866)

This commit is contained in:
Gavin Frazar 2023-04-10 09:36:32 -07:00 committed by GitHub
parent ff15a40654
commit c42ae4e6ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 281 additions and 169 deletions

View file

@ -78,6 +78,10 @@ type Database interface {
GetAWS() AWS
// SetStatusAWS sets the database AWS metadata in the status field.
SetStatusAWS(AWS)
// SetAWSExternalID sets the database AWS external ID in the Spec.AWS field.
SetAWSExternalID(id string)
// SetAWSAssumeRole sets the database AWS assume role arn in the Spec.AWS field.
SetAWSAssumeRole(roleARN string)
// GetGCP returns GCP information for Cloud SQL databases.
GetGCP() GCPCloudSQL
// GetAzure returns Azure database server metadata.
@ -341,6 +345,16 @@ func (d *DatabaseV3) SetStatusAWS(aws AWS) {
d.Status.AWS = aws
}
// SetAWSExternalID sets the database AWS external ID in the Spec.AWS field.
func (d *DatabaseV3) SetAWSExternalID(id string) {
d.Spec.AWS.ExternalID = id
}
// SetAWSAssumeRole sets the database AWS assume role arn in the Spec.AWS field.
func (d *DatabaseV3) SetAWSAssumeRole(roleARN string) {
d.Spec.AWS.AssumeRoleARN = roleARN
}
// GetGCP returns GCP information for Cloud SQL databases.
func (d *DatabaseV3) GetGCP() GCPCloudSQL {
return d.Spec.GCP

View file

@ -253,13 +253,18 @@ type awsAssumeRoleOpts struct {
// when getting an AWS session.
type AWSAssumeRoleOptionFn func(*awsAssumeRoleOpts)
// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) AWSAssumeRoleOptionFn {
return func(options *awsAssumeRoleOpts) {
options.assumeRoleARN = roleARN
options.assumeRoleExternalID = externalID
}
}
// WithAssumeRoleFromAWSMeta extracts options needed from AWS metadata for
// assuming an AWS role.
func WithAssumeRoleFromAWSMeta(meta types.AWS) AWSAssumeRoleOptionFn {
return func(options *awsAssumeRoleOpts) {
options.assumeRoleARN = meta.AssumeRoleARN
options.assumeRoleExternalID = meta.ExternalID
}
return WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID)
}
// WithChainedAssumeRole sets a role to assume with a base session to use

View file

@ -39,11 +39,13 @@ import (
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/eks"
"github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/redshift"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
@ -918,9 +920,22 @@ func (m *mockGKEAPI) ListClusters(ctx context.Context, projectID string, locatio
func TestDiscoveryDatabase(t *testing.T) {
awsRedshiftResource, awsRedshiftDB := makeRedshiftCluster(t, "aws-redshift", "us-east-1")
awsRDSInstance, awsRDSDB := makeRDSInstance(t, "aws-rds", "us-west-1")
azRedisResource, azRedisDB := makeAzureRedisServer(t, "az-redis", "sub1", "group1", "East US")
role := services.AssumeRole{RoleARN: "arn:aws:iam::123456789012:role/test-role", ExternalID: "test123"}
awsRDSDBWithRole := awsRDSDB.Copy()
awsRDSDBWithRole.SetAWSAssumeRole("arn:aws:iam::123456789012:role/test-role")
awsRDSDBWithRole.SetAWSExternalID("test123")
testClients := &cloud.TestCloudClients{
STS: &mocks.STSMock{},
RDS: &mocks.RDSMock{
DBInstances: []*rds.DBInstance{awsRDSInstance},
DBEngineVersions: []*rds.DBEngineVersion{
{Engine: aws.String(services.RDSEnginePostgres)},
},
},
Redshift: &mocks.RedshiftMock{
Clusters: []*redshift.Cluster{awsRedshiftResource},
},
@ -949,6 +964,16 @@ func TestDiscoveryDatabase(t *testing.T) {
}},
expectDatabases: []types.Database{awsRedshiftDB},
},
{
name: "discover AWS database with assumed role",
awsMatchers: []services.AWSMatcher{{
Types: []string{services.AWSMatcherRDS},
Tags: map[string]utils.Strings{types.Wildcard: {types.Wildcard}},
Regions: []string{"us-west-1"},
AssumeRole: role,
}},
expectDatabases: []types.Database{awsRDSDBWithRole},
},
{
name: "discover Azure database",
azureMatchers: []services.AzureMatcher{{
@ -979,6 +1004,26 @@ func TestDiscoveryDatabase(t *testing.T) {
}},
expectDatabases: []types.Database{awsRedshiftDB},
},
{
name: "update existing database with assumed role",
existingDatabases: []types.Database{
mustNewDatabase(t, types.Metadata{
Name: "aws-rds",
Description: "should be updated",
Labels: map[string]string{types.OriginLabel: types.OriginCloud},
}, types.DatabaseSpecV3{
Protocol: "postgres",
URI: "should.be.updated.com:12345",
}),
},
awsMatchers: []services.AWSMatcher{{
Types: []string{services.AWSMatcherRDS},
Tags: map[string]utils.Strings{types.Wildcard: {types.Wildcard}},
Regions: []string{"us-west-1"},
AssumeRole: role,
}},
expectDatabases: []types.Database{awsRDSDBWithRole},
},
{
name: "delete existing database",
existingDatabases: []types.Database{
@ -1091,6 +1136,23 @@ func TestDiscoveryDatabase(t *testing.T) {
}
}
func makeRDSInstance(t *testing.T, name, region string) (*rds.DBInstance, types.Database) {
instance := &rds.DBInstance{
DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)),
DBInstanceIdentifier: aws.String(name),
DbiResourceId: aws.String(uuid.New().String()),
Engine: aws.String(services.RDSEnginePostgres),
DBInstanceStatus: aws.String("available"),
Endpoint: &rds.Endpoint{
Address: aws.String("localhost"),
Port: aws.Int64(5432),
},
}
database, err := services.NewDatabaseFromRDSInstance(instance)
require.NoError(t, err)
return instance, database
}
func makeRedshiftCluster(t *testing.T, name, region string) (*redshift.Cluster, types.Database) {
t.Helper()
cluster := &redshift.Cluster{

View file

@ -39,6 +39,8 @@ type elastiCacheFetcherConfig struct {
ElastiCache elasticacheiface.ElastiCacheAPI
// Region is the AWS region to query databases in.
Region string
// AssumeRole is the AWS IAM role to assume before discovering databases.
AssumeRole services.AssumeRole
}
// CheckAndSetDefaults validates the config and sets defaults.
@ -74,6 +76,7 @@ func newElastiCacheFetcher(config elastiCacheFetcherConfig) (common.Fetcher, err
trace.Component: "watch:elasticache",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
@ -168,6 +171,7 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels
}
}
applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole)
return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil
}

View file

@ -48,12 +48,7 @@ func TestElastiCacheFetcher(t *testing.T) {
aws.StringValue(elasticacheUnsupported.ARN): elasticacheUnsupportedTags,
}
tests := []struct {
name string
inputClients cloud.AWSClients
inputLabels map[string]string
wantDatabases types.Databases
}{
tests := []awsFetcherTest{
{
name: "fetch all",
inputClients: &cloud.TestCloudClients{
@ -62,7 +57,7 @@ func TestElastiCacheFetcher(t *testing.T) {
TagsByARN: elasticacheTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{elasticacheDatabaseProd, elasticacheDatabaseQA},
},
{
@ -73,7 +68,7 @@ func TestElastiCacheFetcher(t *testing.T) {
TagsByARN: elasticacheTagsByARN,
},
},
inputLabels: envProdLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", envProdLabels),
wantDatabases: types.Databases{elasticacheDatabaseProd},
},
{
@ -84,7 +79,7 @@ func TestElastiCacheFetcher(t *testing.T) {
TagsByARN: elasticacheTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{elasticacheDatabaseProd},
},
{
@ -95,20 +90,11 @@ func TestElastiCacheFetcher(t *testing.T) {
TagsByARN: elasticacheTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{elasticacheDatabaseProd},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherElastiCache, "us-east-2", toTypeLabels(test.inputLabels))
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
}
testAWSFetchers(t, tests...)
}
func makeElastiCacheCluster(t *testing.T, name, region, env string, opts ...func(*elasticache.ReplicationGroup)) (*elasticache.ReplicationGroup, types.Database, []*elasticache.Tag) {

View file

@ -39,6 +39,8 @@ type memoryDBFetcherConfig struct {
MemoryDB memorydbiface.MemoryDBAPI
// Region is the AWS region to query databases in.
Region string
// AssumeRole is the AWS IAM role to assume before discovering databases.
AssumeRole services.AssumeRole
}
// CheckAndSetDefaults validates the config and sets defaults.
@ -74,6 +76,7 @@ func newMemoryDBFetcher(config memoryDBFetcherConfig) (common.Fetcher, error) {
trace.Component: "watch:memorydb",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
@ -136,6 +139,7 @@ func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e
databases = append(databases, database)
}
}
applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole)
return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil
}

View file

@ -47,12 +47,7 @@ func TestMemoryDBFetcher(t *testing.T) {
aws.StringValue(memorydbUnsupported.ARN): memorydbUnsupportedTags,
}
tests := []struct {
name string
inputClients cloud.AWSClients
inputLabels map[string]string
wantDatabases types.Databases
}{
tests := []awsFetcherTest{
{
name: "fetch all",
inputClients: &cloud.TestCloudClients{
@ -61,7 +56,7 @@ func TestMemoryDBFetcher(t *testing.T) {
TagsByARN: memorydbTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{memorydbDatabaseProd, memorydbDatabaseTest},
},
{
@ -72,7 +67,7 @@ func TestMemoryDBFetcher(t *testing.T) {
TagsByARN: memorydbTagsByARN,
},
},
inputLabels: envProdLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", envProdLabels),
wantDatabases: types.Databases{memorydbDatabaseProd},
},
{
@ -83,7 +78,7 @@ func TestMemoryDBFetcher(t *testing.T) {
TagsByARN: memorydbTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{memorydbDatabaseProd},
},
{
@ -94,20 +89,11 @@ func TestMemoryDBFetcher(t *testing.T) {
TagsByARN: memorydbTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{memorydbDatabaseProd},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherMemoryDB, "us-east-2", toTypeLabels(test.inputLabels))
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
}
testAWSFetchers(t, tests...)
}
func makeMemoryDBCluster(t *testing.T, name, region, env string, opts ...func(*memorydb.Cluster)) (*memorydb.Cluster, types.Database, []*memorydb.Tag) {

View file

@ -41,6 +41,8 @@ type rdsFetcherConfig struct {
RDS rdsiface.RDSAPI
// Region is the AWS region to query databases in.
Region string
// AssumeRole is the AWS IAM role to assume before discovering databases.
AssumeRole services.AssumeRole
}
// CheckAndSetDefaults validates the config and sets defaults.
@ -76,6 +78,7 @@ func newRDSDBInstancesFetcher(config rdsFetcherConfig) (common.Fetcher, error) {
trace.Component: "watch:rds",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
@ -87,6 +90,7 @@ func (f *rdsDBInstancesFetcher) Get(ctx context.Context) (types.ResourcesWithLab
return nil, trace.Wrap(err)
}
applyAssumeRoleToDatabases(rdsDatabases, f.cfg.AssumeRole)
return filterDatabasesByLabels(rdsDatabases, f.cfg.Labels, f.log).AsResources(), nil
}
@ -172,6 +176,7 @@ func newRDSAuroraClustersFetcher(config rdsFetcherConfig) (common.Fetcher, error
trace.Component: "watch:aurora",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
@ -183,6 +188,7 @@ func (f *rdsAuroraClustersFetcher) Get(ctx context.Context) (types.ResourcesWith
return nil, trace.Wrap(err)
}
applyAssumeRoleToDatabases(auroraDatabases, f.cfg.AssumeRole)
return filterDatabasesByLabels(auroraDatabases, f.cfg.Labels, f.log).AsResources(), nil
}

View file

@ -48,6 +48,7 @@ func newRDSDBProxyFetcher(config rdsFetcherConfig) (common.Fetcher, error) {
trace.Component: "watch:rdsproxy",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
@ -60,6 +61,7 @@ func (f *rdsDBProxyFetcher) Get(ctx context.Context) (types.ResourcesWithLabels,
return nil, trace.Wrap(err)
}
applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole)
return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil
}

View file

@ -38,40 +38,33 @@ func TestRDSDBProxyFetcher(t *testing.T) {
rdsProxyEndpointVpc1, rdsProxyEndpointDatabaseVpc1 := makeRDSProxyCustomEndpoint(t, rdsProxyVpc1, "endpoint-1", "us-east-1")
rdsProxyEndpointVpc2, rdsProxyEndpointDatabaseVpc2 := makeRDSProxyCustomEndpoint(t, rdsProxyVpc2, "endpoint-2", "us-east-1")
clients := &cloud.TestCloudClients{
RDS: &mocks.RDSMock{
DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2},
DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2},
DBProxyTargetPort: 9999,
},
}
tests := []struct {
name string
inputLabels map[string]string
wantDatabases types.Databases
}{
tests := []awsFetcherTest{
{
name: "fetch all",
inputLabels: wildcardLabels,
name: "fetch all",
inputClients: &cloud.TestCloudClients{
RDS: &mocks.RDSMock{
DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2},
DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2},
DBProxyTargetPort: 9999,
},
},
inputMatchers: makeAWSMatchersForType(services.AWSMatcherRDSProxy, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{rdsProxyDatabaseVpc1, rdsProxyDatabaseVpc2, rdsProxyEndpointDatabaseVpc1, rdsProxyEndpointDatabaseVpc2},
},
{
name: "fetch vpc1",
inputLabels: map[string]string{"vpc-id": "vpc1"},
name: "fetch vpc1",
inputClients: &cloud.TestCloudClients{
RDS: &mocks.RDSMock{
DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2},
DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2},
DBProxyTargetPort: 9999,
},
},
inputMatchers: makeAWSMatchersForType(services.AWSMatcherRDSProxy, "us-east-1", map[string]string{"vpc-id": "vpc1"}),
wantDatabases: types.Databases{rdsProxyDatabaseVpc1, rdsProxyEndpointDatabaseVpc1},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fetchers := mustMakeAWSFetchersForMatcher(t, clients, services.AWSMatcherRDSProxy, "us-east-2", toTypeLabels(test.inputLabels))
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
}
testAWSFetchers(t, tests...)
}
func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types.Database) {

View file

@ -55,12 +55,7 @@ func TestRDSFetchers(t *testing.T) {
auroraClusterUnknownStatus, auroraDatabaseUnknownStatus := makeRDSCluster(t, "cluster-5", "us-east-1", nil, withRDSClusterStatus("status-does-not-exist"))
auroraClusterNoWriter, auroraDatabasesNoWriter := makeRDSClusterWithExtraEndpoints(t, "cluster-6", "us-east-1", envDevLabels, false)
tests := []struct {
name string
inputClients cloud.AWSClients
inputMatchers []services.AWSMatcher
wantDatabases types.Databases
}{
tests := []awsFetcherTest{
{
name: "fetch all",
inputClients: &cloud.TestCloudClients{
@ -206,16 +201,7 @@ func TestRDSFetchers(t *testing.T) {
wantDatabases: auroraDatabasesNoWriter,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fetchers := mustMakeAWSFetchers(t, test.inputClients, test.inputMatchers)
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
}
testAWSFetchers(t, tests...)
}
func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) (*rds.DBInstance, types.Database) {

View file

@ -40,6 +40,8 @@ type redshiftFetcherConfig struct {
Redshift redshiftiface.RedshiftAPI
// Region is the AWS region to query databases in.
Region string
// AssumeRole is the AWS IAM role to assume before discovering databases.
AssumeRole services.AssumeRole
}
// CheckAndSetDefaults validates the config and sets defaults.
@ -75,6 +77,7 @@ func newRedshiftFetcher(config redshiftFetcherConfig) (common.Fetcher, error) {
trace.Component: "watch:redshift",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
@ -104,6 +107,7 @@ func (f *redshiftFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e
databases = append(databases, database)
}
applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole)
return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil
}

View file

@ -41,6 +41,8 @@ type redshiftServerlessFetcherConfig struct {
Region string
// Client is the Redshift Serverless API client.
Client redshiftserverlessiface.RedshiftServerlessAPI
// AssumeRole is the AWS IAM role to assume before discovering databases.
AssumeRole services.AssumeRole
}
// CheckAndSetDefaults validates the config and sets defaults.
@ -83,6 +85,7 @@ func newRedshiftServerlessFetcher(config redshiftServerlessFetcherConfig) (commo
trace.Component: "watch:rss<", // (r)ed(s)hift (s)erver(<)less
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
@ -106,6 +109,7 @@ func (f *redshiftServerlessFetcher) Get(ctx context.Context) (types.ResourcesWit
databases = append(databases, vpcEndpointDatabases...)
}
applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole)
return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil
}

View file

@ -47,12 +47,7 @@ func TestRedshiftServerlessFetcher(t *testing.T) {
endpointNotAvailable := mocks.RedshiftServerlessEndpointAccess(workgroupNotAvailable, "endpoint-creating", "us-east-1")
endpointNotAvailable.EndpointStatus = aws.String("creating")
tests := []struct {
name string
inputClients cloud.AWSClients
inputLabels map[string]string
wantDatabases types.Databases
}{
tests := []awsFetcherTest{
{
name: "fetch all",
inputClients: &cloud.TestCloudClients{
@ -62,7 +57,7 @@ func TestRedshiftServerlessFetcher(t *testing.T) {
TagsByARN: tagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshiftServerless, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{workgroupProdDB, workgroupDevDB, endpointProdDB, endpointProdDev},
},
{
@ -74,7 +69,7 @@ func TestRedshiftServerlessFetcher(t *testing.T) {
TagsByARN: tagsByARN,
},
},
inputLabels: envProdLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshiftServerless, "us-east-1", envProdLabels),
wantDatabases: types.Databases{workgroupProdDB, endpointProdDB},
},
{
@ -86,20 +81,11 @@ func TestRedshiftServerlessFetcher(t *testing.T) {
TagsByARN: tagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshiftServerless, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{workgroupProdDB},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherRedshiftServerless, "us-east-2", toTypeLabels(test.inputLabels))
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
}
testAWSFetchers(t, tests...)
}
func makeRedshiftServerlessWorkgroup(t *testing.T, name, region string, labels map[string]string) (*redshiftserverless.Workgroup, types.Database) {

View file

@ -38,12 +38,7 @@ func TestRedshiftFetcher(t *testing.T) {
redshiftUse1Unavailable, _ := makeRedshiftCluster(t, "us-east-1", "qa", withRedshiftStatus("paused"))
redshiftUse1UnknownStatus, redshiftDatabaseUnknownStatus := makeRedshiftCluster(t, "us-east-1", "test", withRedshiftStatus("status-does-not-exist"))
tests := []struct {
name string
inputClients cloud.AWSClients
inputLabels map[string]string
wantDatabases types.Databases
}{
tests := []awsFetcherTest{
{
name: "fetch all",
inputClients: &cloud.TestCloudClients{
@ -51,7 +46,7 @@ func TestRedshiftFetcher(t *testing.T) {
Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Dev},
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshift, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUse1Dev},
},
{
@ -61,7 +56,7 @@ func TestRedshiftFetcher(t *testing.T) {
Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Dev},
},
},
inputLabels: envProdLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshift, "us-east-1", envProdLabels),
wantDatabases: types.Databases{redshiftDatabaseUse1Prod},
},
{
@ -71,20 +66,11 @@ func TestRedshiftFetcher(t *testing.T) {
Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Unavailable, redshiftUse1UnknownStatus},
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshift, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUnknownStatus},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherRedshift, "us-east-2", toTypeLabels(test.inputLabels))
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
}
testAWSFetchers(t, tests...)
}
func makeRedshiftCluster(t *testing.T, region, env string, opts ...func(*redshift.Cluster)) (*redshift.Cluster, types.Database) {

View file

@ -29,7 +29,7 @@ import (
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
type makeAWSFetcherFunc func(context.Context, cloud.AWSClients, string, types.Labels) (common.Fetcher, error)
type makeAWSFetcherFunc func(context.Context, cloud.AWSClients, string, types.Labels, services.AssumeRole) (common.Fetcher, error)
type makeAzureFetcherFunc func(azureFetcherConfig) (common.Fetcher, error)
var (
@ -71,7 +71,7 @@ func MakeAWSFetchers(ctx context.Context, clients cloud.AWSClients, matchers []s
for _, makeFetcher := range makeFetchers {
for _, region := range matcher.Regions {
fetcher, err := makeFetcher(ctx, clients, region, matcher.Tags)
fetcher, err := makeFetcher(ctx, clients, region, matcher.Tags, matcher.AssumeRole)
if err != nil {
return nil, trace.Wrap(err)
}
@ -116,65 +116,69 @@ func MakeAzureFetchers(clients cloud.AzureClients, matchers []services.AzureMatc
}
// makeRDSInstanceFetcher returns RDS instance fetcher for the provided region and tags.
func makeRDSInstanceFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) {
rds, err := clients.GetAWSRDSClient(ctx, region)
func makeRDSInstanceFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) {
rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID))
if err != nil {
return nil, trace.Wrap(err)
}
fetcher, err := newRDSDBInstancesFetcher(rdsFetcherConfig{
Region: region,
Labels: tags,
RDS: rds,
Region: region,
Labels: tags,
RDS: rds,
AssumeRole: assumeRole,
})
return fetcher, trace.Wrap(err)
}
// makeRDSAuroraFetcher returns RDS Aurora fetcher for the provided region and tags.
func makeRDSAuroraFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) {
rds, err := clients.GetAWSRDSClient(ctx, region)
func makeRDSAuroraFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) {
rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID))
if err != nil {
return nil, trace.Wrap(err)
}
fetcher, err := newRDSAuroraClustersFetcher(rdsFetcherConfig{
Region: region,
Labels: tags,
RDS: rds,
Region: region,
Labels: tags,
RDS: rds,
AssumeRole: assumeRole,
})
return fetcher, trace.Wrap(err)
}
// makeRDSProxyFetcher returns RDS proxy fetcher for the provided region and tags.
func makeRDSProxyFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) {
rds, err := clients.GetAWSRDSClient(ctx, region)
func makeRDSProxyFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) {
rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID))
if err != nil {
return nil, trace.Wrap(err)
}
return newRDSDBProxyFetcher(rdsFetcherConfig{
Region: region,
Labels: tags,
RDS: rds,
Region: region,
Labels: tags,
RDS: rds,
AssumeRole: assumeRole,
})
}
// makeRedshiftFetcher returns Redshift fetcher for the provided region and tags.
func makeRedshiftFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) {
redshift, err := clients.GetAWSRedshiftClient(ctx, region)
func makeRedshiftFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) {
redshift, err := clients.GetAWSRedshiftClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID))
if err != nil {
return nil, trace.Wrap(err)
}
return newRedshiftFetcher(redshiftFetcherConfig{
Region: region,
Labels: tags,
Redshift: redshift,
Region: region,
Labels: tags,
Redshift: redshift,
AssumeRole: assumeRole,
})
}
// makeElastiCacheFetcher returns ElastiCache fetcher for the provided region and tags.
func makeElastiCacheFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) {
elastiCache, err := clients.GetAWSElastiCacheClient(ctx, region)
func makeElastiCacheFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) {
elastiCache, err := clients.GetAWSElastiCacheClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID))
if err != nil {
return nil, trace.Wrap(err)
}
@ -182,33 +186,36 @@ func makeElastiCacheFetcher(ctx context.Context, clients cloud.AWSClients, regio
Region: region,
Labels: tags,
ElastiCache: elastiCache,
AssumeRole: assumeRole,
})
}
// makeMemoryDBFetcher returns MemoryDB fetcher for the provided region and tags.
func makeMemoryDBFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) {
memorydb, err := clients.GetAWSMemoryDBClient(ctx, region)
func makeMemoryDBFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) {
memorydb, err := clients.GetAWSMemoryDBClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID))
if err != nil {
return nil, trace.Wrap(err)
}
return newMemoryDBFetcher(memoryDBFetcherConfig{
Region: region,
Labels: tags,
MemoryDB: memorydb,
Region: region,
Labels: tags,
MemoryDB: memorydb,
AssumeRole: assumeRole,
})
}
// makeRedshiftServerlessFetcher returns Redshift Serverless fetcher for the
// provided region and tags.
func makeRedshiftServerlessFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) {
client, err := clients.GetAWSRedshiftServerlessClient(ctx, region)
func makeRedshiftServerlessFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) {
client, err := clients.GetAWSRedshiftServerlessClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID))
if err != nil {
return nil, trace.Wrap(err)
}
return newRedshiftServerlessFetcher(redshiftServerlessFetcherConfig{
Region: region,
Labels: tags,
Client: client,
Region: region,
Labels: tags,
Client: client,
AssumeRole: assumeRole,
})
}
@ -228,6 +235,14 @@ func filterDatabasesByLabels(databases types.Databases, labels types.Labels, log
return matchedDatabases
}
// applyAssumeRoleToDatabases applies assume role settings from fetcher to databases.
func applyAssumeRoleToDatabases(databases types.Databases, assumeRole services.AssumeRole) {
for _, db := range databases {
db.SetAWSAssumeRole(assumeRole.RoleARN)
db.SetAWSExternalID(assumeRole.ExternalID)
}
}
// flatten flattens a nested slice [][]T to []T.
func flatten[T any](s [][]T) (result []T) {
for i := range s {

View file

@ -25,6 +25,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/discovery/common"
)
@ -43,6 +44,14 @@ func toTypeLabels(labels map[string]string) types.Labels {
return result
}
func makeAWSMatchersForType(matcherType, region string, tags map[string]string) []services.AWSMatcher {
return []services.AWSMatcher{{
Types: []string{matcherType},
Regions: []string{region},
Tags: toTypeLabels(tags),
}}
}
func mustMakeAWSFetchers(t *testing.T, clients cloud.AWSClients, matchers []services.AWSMatcher) []common.Fetcher {
t.Helper()
@ -57,16 +66,6 @@ func mustMakeAWSFetchers(t *testing.T, clients cloud.AWSClients, matchers []serv
return fetchers
}
func mustMakeAWSFetchersForMatcher(t *testing.T, clients cloud.AWSClients, matcherType, region string, tags types.Labels) []common.Fetcher {
t.Helper()
return mustMakeAWSFetchers(t, clients, []services.AWSMatcher{{
Types: []string{matcherType},
Regions: []string{region},
Tags: tags,
}})
}
func mustMakeAzureFetchers(t *testing.T, clients cloud.AzureClients, matchers []services.AzureMatcher) []common.Fetcher {
t.Helper()
@ -96,3 +95,73 @@ func mustGetDatabases(t *testing.T, fetchers []common.Fetcher) types.Databases {
}
return all
}
// testAssumeRole is a fixture for testing fetchers.
// every matcher, stub database, and mock AWS Session created uses this fixture.
// Tests will cover:
// - that fetchers use the configured assume role when using AWS cloud clients.
// - that databases discovered and created by fetchers have the assumed role used to discover them populated.
var testAssumeRole = services.AssumeRole{
RoleARN: "arn:aws:iam::123456789012:role/test-role",
ExternalID: "externalID123",
}
// awsFetcherTest is a common test struct for AWS fetchers.
type awsFetcherTest struct {
name string
inputClients *cloud.TestCloudClients
inputMatchers []services.AWSMatcher
wantDatabases types.Databases
}
// testAWSFetchers is a helper that tests AWS fetchers, since
// all of the AWS fetcher tests are fundamentally the same.
func testAWSFetchers(t *testing.T, tests ...awsFetcherTest) {
t.Helper()
for _, test := range tests {
test := test
require.Nil(t, test.inputClients.STS, "testAWSFetchers injects an STS mock itself, but test input had already configured it. This is a test configuration error.")
stsMock := &mocks.STSMock{}
test.inputClients.STS = stsMock
t.Run(test.name, func(t *testing.T) {
t.Helper()
fetchers := mustMakeAWSFetchers(t, test.inputClients, test.inputMatchers)
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
t.Run(test.name+" with assume role", func(t *testing.T) {
t.Helper()
matchers := copyAWSMatchersWithAssumeRole(testAssumeRole, test.inputMatchers...)
wantDBs := copyDatabasesWithAWSAssumeRole(testAssumeRole, test.wantDatabases...)
fetchers := mustMakeAWSFetchers(t, test.inputClients, matchers)
require.ElementsMatch(t, wantDBs, mustGetDatabases(t, fetchers))
require.Equal(t, []string{testAssumeRole.RoleARN}, stsMock.GetAssumedRoleARNs())
require.Equal(t, []string{testAssumeRole.ExternalID}, stsMock.GetAssumedRoleExternalIDs())
})
}
}
// copyDatabasesWithAWSAssumeRole copies input databases and sets a given AWS assume role for each copy.
func copyDatabasesWithAWSAssumeRole(role services.AssumeRole, databases ...types.Database) types.Databases {
if len(databases) == 0 {
return databases
}
out := make(types.Databases, 0, len(databases))
for _, db := range databases {
out = append(out, db.Copy())
}
applyAssumeRoleToDatabases(out, role)
return out
}
// copyAWSMatchersWithAssumeRole copies input AWS matchers and sets a given AWS assume role for each copy.
func copyAWSMatchersWithAssumeRole(role services.AssumeRole, matchers ...services.AWSMatcher) []services.AWSMatcher {
if len(matchers) == 0 {
return matchers
}
out := make([]services.AWSMatcher, 0, len(matchers))
for _, m := range matchers {
m.AssumeRole = role
out = append(out, m)
}
return out
}