Parse AWS info from RDS/Redshift endpoint (#7385)

This commit is contained in:
Roman Tkachenko 2021-06-23 12:41:54 -07:00 committed by GitHub
parent 792e2432b5
commit b42bec61c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 120 additions and 37 deletions

View file

@ -18,6 +18,8 @@ package types
import (
"fmt"
"net"
"strings"
"time"
"github.com/gogo/protobuf/proto"
@ -295,7 +297,6 @@ func (s *DatabaseServerV3) CheckAndSetDefaults() error {
if err := s.Metadata.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
for key := range s.Spec.DynamicLabels {
if !IsValidLabelKey(key) {
return trace.BadParameter("database server %q invalid label key: %q", s.GetName(), key)
@ -313,9 +314,62 @@ func (s *DatabaseServerV3) CheckAndSetDefaults() error {
if s.Spec.HostID == "" {
return trace.BadParameter("database server %q host ID is empty", s.GetName())
}
// In case of RDS, Aurora or Redshift, AWS information such as region or
// cluster ID can be extracted from the endpoint if not provided.
switch {
case strings.Contains(s.Spec.URI, rdsEndpointSuffix):
region, err := parseRDSEndpoint(s.Spec.URI)
if err != nil {
return trace.Wrap(err)
}
if s.Spec.AWS.Region == "" {
s.Spec.AWS.Region = region
}
case strings.Contains(s.Spec.URI, redshiftEndpointSuffix):
clusterID, region, err := parseRedshiftEndpoint(s.Spec.URI)
if err != nil {
return trace.Wrap(err)
}
if s.Spec.AWS.Redshift.ClusterID == "" {
s.Spec.AWS.Redshift.ClusterID = clusterID
}
if s.Spec.AWS.Region == "" {
s.Spec.AWS.Region = region
}
}
return nil
}
// parseRDSEndpoint extracts region from the provided RDS endpoint.
func parseRDSEndpoint(endpoint string) (region string, err error) {
host, _, err := net.SplitHostPort(endpoint)
if err != nil {
return "", trace.Wrap(err)
}
// RDS/Aurora endpoint looks like this:
// aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com
parts := strings.Split(host, ".")
if !strings.HasSuffix(host, rdsEndpointSuffix) || len(parts) != 6 {
return "", trace.BadParameter("failed to parse %v as RDS endpoint", endpoint)
}
return parts[2], nil
}
// parseRedshiftEndpoint extracts cluster ID and region from the provided Redshift endpoint.
func parseRedshiftEndpoint(endpoint string) (clusterID, region string, err error) {
host, _, err := net.SplitHostPort(endpoint)
if err != nil {
return "", "", trace.Wrap(err)
}
// Redshift endpoint looks like this:
// redshift-cluster-1.abcdefghijklmnop.us-east-1.rds.amazonaws.com
parts := strings.Split(host, ".")
if !strings.HasSuffix(host, redshiftEndpointSuffix) || len(parts) != 6 {
return "", "", trace.BadParameter("failed to parse %v as Redshift endpoint", endpoint)
}
return parts[0], parts[2], nil
}
// Copy returns a copy of this database server object.
func (s *DatabaseServerV3) Copy() DatabaseServer {
return proto.Clone(s).(*DatabaseServerV3)
@ -361,3 +415,10 @@ func DeduplicateDatabaseServers(servers []DatabaseServer) (result []DatabaseServ
}
return result
}
const (
// rdsEndpointSuffix is the RDS/Aurora endpoint suffix.
rdsEndpointSuffix = ".rds.amazonaws.com"
// redshiftEndpointSuffix is the Redshift endpoint suffix.
redshiftEndpointSuffix = ".redshift.amazonaws.com"
)

View file

@ -0,0 +1,56 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package types
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestDatabaseServerRDSEndpoint verifies AWS info is correctly populated
// based on the RDS endpoint.
func TestDatabaseServerRDSEndpoint(t *testing.T) {
server, err := NewDatabaseServerV3("rds", nil, DatabaseServerSpecV3{
Protocol: "postgres",
URI: "aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432",
Hostname: "host-1",
HostID: "host-1",
})
require.NoError(t, err)
require.Equal(t, AWS{
Region: "us-west-1",
}, server.GetAWS())
}
// TestDatabaseServerRedshiftEndpoint verifies AWS info is correctly populated
// based on the Redshift endpoint.
func TestDatabaseServerRedshiftEndpoint(t *testing.T) {
server, err := NewDatabaseServerV3("redshift", nil, DatabaseServerSpecV3{
Protocol: "postgres",
URI: "redshift-cluster-1.abcdefghijklmnop.us-east-1.redshift.amazonaws.com:5438",
Hostname: "host-1",
HostID: "host-1",
})
require.NoError(t, err)
require.Equal(t, AWS{
Region: "us-east-1",
Redshift: Redshift{
ClusterID: "redshift-cluster-1",
},
}, server.GetAWS())
}

View file

@ -1403,21 +1403,6 @@ db_service:
`,
outError: `invalid database "foo" address`,
},
{
desc: "missing Redshift region",
inConfigString: `
db_service:
enabled: true
databases:
- name: foo
protocol: postgres
uri: 192.168.1.1:5438
aws:
redshift:
cluster_id: cluster-1
`,
outError: `missing AWS region for Redshift database "foo"`,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {

View file

@ -667,12 +667,6 @@ func (d *Database) Check() error {
d.Name, err)
}
}
// Validate Redshift specific configuration.
if d.AWS.Redshift.ClusterID != "" {
if d.AWS.Region == "" {
return trace.BadParameter("missing AWS region for Redshift database %q", d.Name)
}
}
// Validate Cloud SQL specific configuration.
switch {
case d.GCP.ProjectID != "" && d.GCP.InstanceID == "":

View file

@ -282,20 +282,6 @@ func TestCheckDatabase(t *testing.T) {
},
outErr: true,
},
{
desc: "Redshift region not set",
inDatabase: Database{
Name: "example",
Protocol: defaults.ProtocolPostgres,
URI: "redshift-cluster-1.aaa.us-east-1.redshift.amazonaws.com:5439",
AWS: DatabaseAWS{
Redshift: DatabaseAWSRedshift{
ClusterID: "redshift-cluster-1",
},
},
},
outErr: true,
},
{
desc: "MongoDB connection string",
inDatabase: Database{

View file

@ -169,7 +169,7 @@ func New(ctx context.Context, config Config) (*Server, error) {
// starting up dynamic labels and loading root certs for RDS dbs.
for _, db := range server.cfg.Servers {
if err := server.initDatabaseServer(ctx, db); err != nil {
return nil, trace.Wrap(err)
return nil, trace.Wrap(err, "failed to initialize %v", server)
}
}
@ -186,6 +186,7 @@ func (s *Server) initDatabaseServer(ctx context.Context, server types.DatabaseSe
if err := s.initCACert(ctx, server); err != nil {
return trace.Wrap(err)
}
s.log.Debugf("Initialized %v.", server)
return nil
}