Azure Cache for Redis engine support (#16551)

This commit is contained in:
STeve (Xin) Huang 2022-09-29 14:25:53 -04:00 committed by GitHub
parent af7cd0239d
commit aabced42dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 1490 additions and 79 deletions

View file

@ -18,13 +18,13 @@ package types
import (
"fmt"
"net"
"strings"
"text/template"
"time"
"github.com/gravitational/teleport/api/utils"
awsutils "github.com/gravitational/teleport/api/utils/aws"
azureutils "github.com/gravitational/teleport/api/utils/azure"
"github.com/gogo/protobuf/proto"
"github.com/google/go-cmp/cmp"
@ -511,11 +511,27 @@ func (d *DatabaseV3) CheckAndSetDefaults() error {
}
d.Spec.AWS.MemoryDB.TLSEnabled = endpointInfo.TransitEncryptionEnabled
d.Spec.AWS.MemoryDB.EndpointType = endpointInfo.EndpointType
case strings.Contains(d.Spec.URI, AzureEndpointSuffix):
name, err := parseAzureEndpoint(d.Spec.URI)
case azureutils.IsDatabaseEndpoint(d.Spec.URI):
// For Azure MySQL and PostgresSQL.
name, err := azureutils.ParseDatabaseEndpoint(d.Spec.URI)
if err != nil {
return trace.Wrap(err)
}
if d.Spec.Azure.Name == "" {
d.Spec.Azure.Name = name
}
case azureutils.IsCacheForRedisEndpoint(d.Spec.URI):
// ResourceID is required for fetching Redis tokens.
if d.Spec.Azure.ResourceID == "" {
return trace.BadParameter("missing ResourceID for Azure Cache %v", d.Metadata.Name)
}
name, err := azureutils.ParseCacheForRedisEndpoint(d.Spec.URI)
if err != nil {
return trace.Wrap(err)
}
if d.Spec.Azure.Name == "" {
d.Spec.Azure.Name = name
}
@ -523,21 +539,6 @@ func (d *DatabaseV3) CheckAndSetDefaults() error {
return nil
}
// parseAzureEndpoint extracts database server name from Azure endpoint.
func parseAzureEndpoint(endpoint string) (name string, err error) {
host, _, err := net.SplitHostPort(endpoint)
if err != nil {
return "", trace.Wrap(err)
}
// Azure endpoint looks like this:
// name.mysql.database.azure.com
parts := strings.Split(host, ".")
if !strings.HasSuffix(host, AzureEndpointSuffix) || len(parts) != 5 {
return "", trace.BadParameter("failed to parse %v as Azure endpoint", endpoint)
}
return parts[0], nil
}
// GetIAMPolicy returns AWS IAM policy for this database.
func (d *DatabaseV3) GetIAMPolicy() (string, error) {
if d.IsRDS() {
@ -721,11 +722,6 @@ func (d Databases) Less(i, j int) bool { return d[i].GetName() < d[j].GetName()
// Swap swaps two databases.
func (d Databases) Swap(i, j int) { d[i], d[j] = d[j], d[i] }
const (
// AzureEndpointSuffix is the Azure database endpoint suffix.
AzureEndpointSuffix = ".database.azure.com"
)
type arnTemplateInput struct {
Partition, Region, AccountID, ResourceID string
}

View file

@ -175,6 +175,108 @@ func TestDatabaseMemoryDBEndpoint(t *testing.T) {
})
}
func TestDatabaseAzureEndpoints(t *testing.T) {
t.Parallel()
tests := []struct {
name string
spec DatabaseSpecV3
expectError bool
expectAzure Azure
}{
{
name: "valid MySQL",
spec: DatabaseSpecV3{
Protocol: "mysql",
URI: "example-mysql.mysql.database.azure.com:3306",
},
expectAzure: Azure{
Name: "example-mysql",
},
},
{
name: "valid PostgresSQL",
spec: DatabaseSpecV3{
Protocol: "postgres",
URI: "example-postgres.postgres.database.azure.com:5432",
},
expectAzure: Azure{
Name: "example-postgres",
},
},
{
name: "invalid database endpoint",
spec: DatabaseSpecV3{
Protocol: "postgres",
URI: "invalid.database.azure.com:5432",
},
expectError: true,
},
{
name: "valid Redis",
spec: DatabaseSpecV3{
Protocol: "redis",
URI: "example-redis.redis.cache.windows.net:6380",
Azure: Azure{
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-redis",
},
},
expectAzure: Azure{
Name: "example-redis",
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-redis",
},
},
{
name: "valid Redis Enterprise",
spec: DatabaseSpecV3{
Protocol: "redis",
URI: "rediss://example-redis-enterprise.region.redisenterprise.cache.azure.net?mode=cluster",
Azure: Azure{
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-redis-enterprise",
},
},
expectAzure: Azure{
Name: "example-redis-enterprise",
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-redis-enterprise",
},
},
{
name: "invalid Redis (missing resource ID)",
spec: DatabaseSpecV3{
Protocol: "redis",
URI: "rediss://example-redis-enterprise.region.redisenterprise.cache.azure.net?mode=cluster",
},
expectError: true,
},
{
name: "invalid Redis (unknown format)",
spec: DatabaseSpecV3{
Protocol: "redis",
URI: "rediss://bad-format.redisenterprise.cache.azure.net?mode=cluster",
Azure: Azure{
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/bad-format",
},
},
expectError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
database, err := NewDatabaseV3(Metadata{
Name: "test",
}, test.spec)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expectAzure, database.GetAzure())
}
})
}
}
func TestMySQLVersionValidation(t *testing.T) {
t.Parallel()

View file

@ -0,0 +1,122 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"net"
"net/url"
"strings"
"github.com/gravitational/trace"
)
// IsDatabaseEndpoint returns true if provided endpoint is a valid database
// endpoint.
func IsDatabaseEndpoint(endpoint string) bool {
return strings.Contains(endpoint, DatabaseEndpointSuffix)
}
// IsCacheForRedisEndpoint returns true if provided endpoint is a valid Azure
// Cache for Redis endpoint.
func IsCacheForRedisEndpoint(endpoint string) bool {
return IsRedisEndpoint(endpoint) || IsRedisEnterpriseEndpoint(endpoint)
}
// IsRedisEndpoint returns true if provided endpoint is a valid Redis
// (non-Enterprise tier) endpoint.
func IsRedisEndpoint(endpoint string) bool {
return strings.Contains(endpoint, RedisEndpointSuffix)
}
// IsRedisEnterpriseEndpoint returns true if provided endpoint is a valid Redis
// Enterprise endpoint.
func IsRedisEnterpriseEndpoint(endpoint string) bool {
return strings.Contains(endpoint, RedisEnterpriseEndpointSuffix)
}
// ParseDatabaseEndpoint extracts database server name from Azure endpoint.
func ParseDatabaseEndpoint(endpoint string) (name string, err error) {
host, _, err := net.SplitHostPort(endpoint)
if err != nil {
return "", trace.Wrap(err)
}
// Azure endpoint looks like this:
// name.mysql.database.azure.com
parts := strings.Split(host, ".")
if !strings.HasSuffix(host, DatabaseEndpointSuffix) || len(parts) != 5 {
return "", trace.BadParameter("failed to parse %v as Azure endpoint", endpoint)
}
return parts[0], nil
}
// ParseCacheForRedisEndpoint extracts database server name from Azure Cache
// for Redis endpoint.
func ParseCacheForRedisEndpoint(endpoint string) (name string, err error) {
// Note that the Redis URI may contain schema and parameters.
host, err := GetHostFromRedisURI(endpoint)
if err != nil {
return "", trace.Wrap(err)
}
switch {
// Redis (non-Enterprise) endpoint looks like this:
// name.redis.cache.windows.net
case strings.HasSuffix(host, RedisEndpointSuffix):
return strings.TrimSuffix(host, RedisEndpointSuffix), nil
// Redis Enterprise endpoint looks like this:
// name.region.redisenterprise.cache.azure.net
case strings.HasSuffix(host, RedisEnterpriseEndpointSuffix):
name, _, ok := strings.Cut(strings.TrimSuffix(host, RedisEnterpriseEndpointSuffix), ".")
if !ok {
return "", trace.BadParameter("failed to parse %v as Azure Cache endpoint", endpoint)
}
return name, nil
default:
return "", trace.BadParameter("failed to parse %v as Azure Cache endpoint", endpoint)
}
}
// GetHostFromRedisURI extracts host name from a Redis URI. The URI may start
// with "redis://", "rediss://", or without. The URI may also have parameters
// like "?mode=cluster".
func GetHostFromRedisURI(uri string) (string, error) {
// Add a temporary schema to make a valid URL for url.Parse if schema is
// not found.
if !strings.Contains(uri, "://") {
uri = "schema://" + uri
}
parsed, err := url.Parse(uri)
if err != nil {
return "", trace.Wrap(err)
}
return parsed.Hostname(), nil
}
const (
// DatabaseEndpointSuffix is the Azure database endpoint suffix. Used for
// MySQL, PostgresSQL, etc.
DatabaseEndpointSuffix = ".database.azure.com"
// RedisEndpointSuffix is the endpoint suffix for Redis.
RedisEndpointSuffix = ".redis.cache.windows.net"
// RedisEnterpriseEndpointSuffix is the endpoint suffix for Redis Enterprise.
RedisEnterpriseEndpointSuffix = ".redisenterprise.cache.azure.net"
)

View file

@ -0,0 +1,39 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"testing"
"github.com/stretchr/testify/require"
)
func FuzzParseDatabaseEndpoint(f *testing.F) {
f.Fuzz(func(t *testing.T, endpoint string) {
require.NotPanics(t, func() {
ParseDatabaseEndpoint(endpoint)
})
})
}
func FuzzParseCacheForRedisEndpoint(f *testing.F) {
f.Fuzz(func(t *testing.T, endpoint string) {
require.NotPanics(t, func() {
ParseCacheForRedisEndpoint(endpoint)
})
})
}

2
go.mod
View file

@ -10,6 +10,8 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2 v2.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.4.1
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1

4
go.sum
View file

@ -98,6 +98,10 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.0.0 h1:3
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.0.0/go.mod h1:yFGqqJ4W/nOViqHDfuwmjyJtZXLmmMoHN0DNPCigKUE=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.0.0 h1:A6qn+g+bsKoBhFzDFXLhNAup//D+Q7+MuofypSUtNfY=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.0.0/go.mod h1:GgxvszemyuFZyiw4vPxGib+Cp6z7Q3rYQb4DsKPOAAw=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2 v2.0.0 h1:gQhbudzftueZd/xWdFHDibO9kNghIn4B7DH1yrtoGwg=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2 v2.0.0/go.mod h1:ydU7PuJLTH175OGSecHKlCi9d6wTEh8aX5JsNL4PPSs=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise v1.0.0 h1:ZyZb7h9t0hw6pNdCJ1bEwoRsVJ+gHngUlexPfSfj3ZA=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise v1.0.0/go.mod h1:KOdfFo3kIa8BlSMMNTslBV7MTVyVnrKgxhn0rjlMNxM=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.0.0 h1:vsovXlTyKHZXnqzQyt7QMVkwpJBDkHchQL53qXaGBRY=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.0.0/go.mod h1:UZy1vHcRdEymNP1d6fTrvYHpSdkXoUdowfrvffcQOOU=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.4.1 h1:QSdcrd/UFJv6Bp/CfoVf2SrENpFn9P6Yh8yb+xNhYMM=

View file

@ -0,0 +1,75 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"sync"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/gravitational/trace"
)
// ClientMap is a generic map that caches a collection of Azure clients by
// subscriptions.
type ClientMap[ClientType any] struct {
mu sync.RWMutex
clients map[string]ClientType
newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error)
}
// NewClientMap creates a new ClientMap.
func NewClientMap[ClientType any](newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error)) ClientMap[ClientType] {
return ClientMap[ClientType]{
clients: make(map[string]ClientType),
newClient: newClient,
}
}
// Get returns an Azure client by subscription. A new client is created if the
// subscription is not found in the map.
func (m *ClientMap[ClientType]) Get(subscription string, getCredentials func() (azcore.TokenCredential, error)) (client ClientType, err error) {
m.mu.RLock()
if client, ok := m.clients[subscription]; ok {
m.mu.RUnlock()
return client, nil
}
m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
// If some other thread already got here first.
if client, ok := m.clients[subscription]; ok {
return client, nil
}
cred, err := getCredentials()
if err != nil {
return client, trace.Wrap(err)
}
// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
options := &arm.ClientOptions{}
client, err = m.newClient(subscription, cred, options)
if err != nil {
return client, trace.Wrap(err)
}
m.clients[subscription] = client
return client, nil
}

View file

@ -0,0 +1,76 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)
func TestClientMap(t *testing.T) {
t.Parallel()
mockNewClientFunc := func(subscription string, cred azcore.TokenCredential, opts *arm.ClientOptions) (CacheForRedisClient, error) {
if subscription == "good-sub" {
return NewRedisClientByAPI(nil), nil
}
return nil, trace.BadParameter("failed to create")
}
clientMap := NewClientMap(mockNewClientFunc)
// Note that some test cases (e.g. "get from cache") depend on previous
// test cases. Thus running in sequence.
t.Run("get credentials failed", func(t *testing.T) {
client, err := clientMap.Get("some-sub", func() (azcore.TokenCredential, error) {
return nil, trace.AccessDenied("failed to get credentials")
})
require.ErrorIs(t, err, trace.AccessDenied("failed to get credentials"))
require.Nil(t, client)
})
t.Run("create client failed", func(t *testing.T) {
client, err := clientMap.Get("bad-sub", func() (azcore.TokenCredential, error) {
return nil, nil
})
require.ErrorIs(t, err, trace.BadParameter("failed to create"))
require.Nil(t, client)
})
t.Run("create client succeed", func(t *testing.T) {
client, err := clientMap.Get("good-sub", func() (azcore.TokenCredential, error) {
return nil, nil
})
require.NoError(t, err)
require.NotNil(t, client)
require.IsType(t, NewRedisClientByAPI(nil), client)
})
t.Run("get from cache", func(t *testing.T) {
// Return an error for getCredentials but it shouldn't even be called
// as the client is returned from existing cache.
client, err := clientMap.Get("good-sub", func() (azcore.TokenCredential, error) {
return nil, trace.AccessDenied("failed to get credentials")
})
require.NoError(t, err)
require.NotNil(t, client)
require.IsType(t, NewRedisClientByAPI(nil), client)
})
}

View file

@ -21,6 +21,7 @@ import (
"net/http"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/gravitational/trace"
)
@ -29,20 +30,21 @@ import (
// to trace error. If the provided error is not a `ResponseError` it returns.
// the error without modifying it.
func ConvertResponseError(err error) error {
responseErr, ok := err.(*azcore.ResponseError)
if !ok {
return err
}
switch v := err.(type) {
case *azcore.ResponseError:
switch v.StatusCode {
case http.StatusForbidden:
return trace.AccessDenied(v.Error())
case http.StatusConflict:
return trace.AlreadyExists(v.Error())
case http.StatusNotFound:
return trace.NotFound(v.Error())
}
switch responseErr.StatusCode {
case http.StatusForbidden:
return trace.AccessDenied(responseErr.Error())
case http.StatusConflict:
return trace.AlreadyExists(responseErr.Error())
case http.StatusNotFound:
return trace.NotFound(responseErr.Error())
}
case *azidentity.AuthenticationFailedError:
return trace.AccessDenied(v.Error())
}
return err // Return unmodified.
}

View file

@ -59,3 +59,9 @@ type ARMPostgres interface {
}
var _ ARMPostgres = (*armpostgresql.ServersClient)(nil)
// CacheForRedisClient provides an interface for an Azure Redis For Cache client.
type CacheForRedisClient interface {
// GetToken retrieves the auth token for provided resource ID.
GetToken(ctx context.Context, resourceID string) (string, error)
}

View file

@ -23,6 +23,8 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription"
"github.com/gravitational/trace"
@ -202,3 +204,47 @@ func (m *ARMPostgresMock) NewListByResourceGroupPager(group string, _ *armpostgr
},
})
}
// ARMRedisMock mocks armRedisClient.
type ARMRedisMock struct {
Token string
NoAuth bool
}
func (m *ARMRedisMock) ListKeys(ctx context.Context, resourceGroupName string, name string, options *armredis.ClientListKeysOptions) (armredis.ClientListKeysResponse, error) {
if m.NoAuth {
return armredis.ClientListKeysResponse{}, trace.AccessDenied("unauthorized")
}
return armredis.ClientListKeysResponse{
AccessKeys: armredis.AccessKeys{
PrimaryKey: &m.Token,
},
}, nil
}
// ARMRedisEnterpriseDatabaseMock mocks armRedisEnterpriseDatabaseClient.
type ARMRedisEnterpriseDatabaseMock struct {
Token string
TokensByDatabaseName map[string]string
NoAuth bool
}
func (m *ARMRedisEnterpriseDatabaseMock) ListKeys(ctx context.Context, resourceGroupName string, clusterName string, databaseName string, options *armredisenterprise.DatabasesClientListKeysOptions) (armredisenterprise.DatabasesClientListKeysResponse, error) {
if m.NoAuth {
return armredisenterprise.DatabasesClientListKeysResponse{}, trace.AccessDenied("unauthorized")
}
if len(m.TokensByDatabaseName) != 0 {
if token, found := m.TokensByDatabaseName[databaseName]; found {
return armredisenterprise.DatabasesClientListKeysResponse{
AccessKeys: armredisenterprise.AccessKeys{
PrimaryKey: &token,
},
}, nil
}
}
return armredisenterprise.DatabasesClientListKeysResponse{
AccessKeys: armredisenterprise.AccessKeys{
PrimaryKey: &m.Token,
},
}, nil
}

79
lib/cloud/azure/redis.go Normal file
View file

@ -0,0 +1,79 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"context"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2"
"github.com/sirupsen/logrus"
"github.com/gravitational/trace"
)
// armRedisClient is an interface defines a subset of functions of armredis.Client.
type armRedisClient interface {
ListKeys(ctx context.Context, resourceGroupName string, name string, options *armredis.ClientListKeysOptions) (armredis.ClientListKeysResponse, error)
}
// redisClient is an Azure Redis client.
type redisClient struct {
api armRedisClient
}
// NewRedisClient creates a new Azure Redis client by subscription and credentials.
func NewRedisClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (CacheForRedisClient, error) {
logrus.Debug("Initializing Azure Redis client.")
api, err := armredis.NewClient(subscription, cred, options)
if err != nil {
return nil, trace.Wrap(err)
}
return NewRedisClientByAPI(api), nil
}
// NewRedisClientByAPI creates a new Azure Redis client by ARM API client.
func NewRedisClientByAPI(api armRedisClient) CacheForRedisClient {
return &redisClient{
api: api,
}
}
// GetToken retrieves the auth token for provided resource group and resource
// name.
func (c *redisClient) GetToken(ctx context.Context, resourceID string) (string, error) {
id, err := arm.ParseResourceID(resourceID)
if err != nil {
return "", trace.Wrap(err)
}
resp, err := c.api.ListKeys(ctx, id.ResourceGroupName, id.Name, &armredis.ClientListKeysOptions{})
if err != nil {
return "", trace.Wrap(ConvertResponseError(err))
}
// There are two keys. Pick first one available.
if resp.PrimaryKey != nil {
return *resp.PrimaryKey, nil
}
if resp.SecondaryKey != nil {
return *resp.SecondaryKey, nil
}
return "", trace.NotFound("missing keys")
}

View file

@ -0,0 +1,107 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"context"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise"
"github.com/sirupsen/logrus"
"github.com/gravitational/trace"
)
// armRedisEnterpriseDatabaseClient is an interface defines a subset of
// functions of armredisenterprise.DatabaseClient
type armRedisEnterpriseDatabaseClient interface {
ListKeys(ctx context.Context, resourceGroupName string, clusterName string, databaseName string, options *armredisenterprise.DatabasesClientListKeysOptions) (armredisenterprise.DatabasesClientListKeysResponse, error)
}
// redisEnterpriseClient is an Azure Redis Enterprise client.
type redisEnterpriseClient struct {
databaseAPI armRedisEnterpriseDatabaseClient
}
// NewRedisEnterpriseClient creates a new Azure Redis Enterprise client by
// subscription and credentials.
func NewRedisEnterpriseClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (CacheForRedisClient, error) {
logrus.Debug("Initializing Azure Redis Enterprise client.")
databaseAPI, err := armredisenterprise.NewDatabasesClient(subscription, cred, options)
if err != nil {
return nil, trace.Wrap(err)
}
// TODO(greedy52) Redis Enterprise requires a different API client
// (armredisenterprise.Client) for auto-discovery.
return NewRedisEnterpriseClientByAPI(databaseAPI), nil
}
// NewRedisEnterpriseClientByAPI creates a new Azure Redis Enterprise client by
// ARM API client(s).
func NewRedisEnterpriseClientByAPI(databaseAPI armRedisEnterpriseDatabaseClient) CacheForRedisClient {
return &redisEnterpriseClient{
databaseAPI: databaseAPI,
}
}
// GetToken retrieves the auth token for provided resource group and resource
// name.
func (c *redisEnterpriseClient) GetToken(ctx context.Context, resourceID string) (string, error) {
id, err := arm.ParseResourceID(resourceID)
if err != nil {
return "", trace.Wrap(err)
}
clusterName, databaseName, err := c.getClusterAndDatabaseName(id)
if err != nil {
return "", trace.Wrap(err)
}
resp, err := c.databaseAPI.ListKeys(ctx, id.ResourceGroupName, clusterName, databaseName, &armredisenterprise.DatabasesClientListKeysOptions{})
if err != nil {
return "", trace.Wrap(ConvertResponseError(err))
}
// There are two keys. Pick first one available.
if resp.PrimaryKey != nil {
return *resp.PrimaryKey, nil
}
if resp.SecondaryKey != nil {
return *resp.SecondaryKey, nil
}
return "", trace.NotFound("missing keys")
}
// getClusterAndDatabaseName returns the cluster name and the database name
// based on the resource ID. Both armredisenterprise.Cluster.ID and
// armredisenterprise.Database.ID are supported.
func (c *redisEnterpriseClient) getClusterAndDatabaseName(id *arm.ResourceID) (string, string, error) {
switch id.ResourceType.String() {
case "Microsoft.Cache/redisEnterprise":
// It appears an Enterprise cluster always has only one "database", and
// the database name is always "default".
return id.Name, RedisEnterpriseClusterDefaultDatabase, nil
case "Microsoft.Cache/redisEnterprise/databases":
return id.Parent.Name, id.Name, nil
default:
return "", "", trace.BadParameter("unknown Azure Cache for Redis resource type: %v", id.ResourceType)
}
}
// RedisEnterpriseClusterDefaultDatabase is the default database name for a
// Redis Enterprise cluster.
const RedisEnterpriseClusterDefaultDatabase = "default"

View file

@ -0,0 +1,81 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestRedisEnterpriseClient(t *testing.T) {
t.Run("GetToken", func(t *testing.T) {
tests := []struct {
name string
mockDatabaseAPI armRedisEnterpriseDatabaseClient
resourceID string
expectError bool
expectToken string
}{
{
name: "access denied",
resourceID: "cluster-name",
mockDatabaseAPI: &ARMRedisEnterpriseDatabaseMock{
NoAuth: true,
},
expectError: true,
},
{
name: "succeed (default database name)",
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-teleport",
mockDatabaseAPI: &ARMRedisEnterpriseDatabaseMock{
TokensByDatabaseName: map[string]string{
"default": "some-token",
},
},
expectToken: "some-token",
},
{
name: "succeed (specific database name)",
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-teleport/databases/some-database",
mockDatabaseAPI: &ARMRedisEnterpriseDatabaseMock{
TokensByDatabaseName: map[string]string{
"some-database": "some-token",
},
},
expectToken: "some-token",
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
c := NewRedisEnterpriseClientByAPI(test.mockDatabaseAPI)
token, err := c.GetToken(context.TODO(), test.resourceID)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expectToken, token)
}
})
}
})
}

View file

@ -0,0 +1,66 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestRedisClient(t *testing.T) {
t.Run("GetToken", func(t *testing.T) {
tests := []struct {
name string
mockAPI armRedisClient
expectError bool
expectToken string
}{
{
name: "access denied",
mockAPI: &ARMRedisMock{
NoAuth: true,
},
expectError: true,
},
{
name: "succeed",
mockAPI: &ARMRedisMock{
Token: "some-token",
},
expectToken: "some-token",
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
c := NewRedisClientByAPI(test.mockAPI)
token, err := c.GetToken(context.TODO(), "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-teleport")
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expectToken, token)
}
})
}
})
}

View file

@ -100,6 +100,10 @@ type AzureClients interface {
GetAzurePostgresClient(subscription string) (azure.DBServersClient, error)
// GetAzureSubscriptionClient returns an Azure Subscriptions client
GetAzureSubscriptionClient() (*azure.SubscriptionClient, error)
// GetAzureRedisClient returns an Azure Redis client for the given subscription.
GetAzureRedisClient(subscription string) (azure.CacheForRedisClient, error)
// GetAzureRedisEnterpriseClient returns an Azure Redis Enterprise client for the given subscription.
GetAzureRedisEnterpriseClient(subscription string) (azure.CacheForRedisClient, error)
}
// NewClients returns a new instance of cloud clients retriever.
@ -107,8 +111,10 @@ func NewClients() Clients {
return &cloudClients{
awsSessions: make(map[string]*awssession.Session),
azureClients: azureClients{
azureMySQLClients: make(map[string]azure.DBServersClient),
azurePostgresClients: make(map[string]azure.DBServersClient),
azureMySQLClients: make(map[string]azure.DBServersClient),
azurePostgresClients: make(map[string]azure.DBServersClient),
azureRedisClients: azure.NewClientMap(azure.NewRedisClient),
azureRedisEnterpriseClients: azure.NewClientMap(azure.NewRedisEnterpriseClient),
},
}
}
@ -139,6 +145,10 @@ type azureClients struct {
azurePostgresClients map[string]azure.DBServersClient
// azureSubscriptionsClient is the cached Azure Subscriptions client.
azureSubscriptionsClient *azure.SubscriptionClient
// azureRedisClients contains the cached Azure Redis clients.
azureRedisClients azure.ClientMap[azure.CacheForRedisClient]
// azureRedisEnterpriseClients contains the cached Azure Redis Enterprise clients.
azureRedisEnterpriseClients azure.ClientMap[azure.CacheForRedisClient]
}
// GetAWSSession returns AWS session for the specified region.
@ -299,6 +309,16 @@ func (c *cloudClients) GetAzureSubscriptionClient() (*azure.SubscriptionClient,
return c.initAzureSubscriptionsClient()
}
// GetAzureRedisClient returns an Azure Redis client for the given subscription.
func (c *cloudClients) GetAzureRedisClient(subscription string) (azure.CacheForRedisClient, error) {
return c.azureRedisClients.Get(subscription, c.GetAzureCredential)
}
// GetAzureRedisEnterpriseClient returns an Azure Redis Enterprise client for the given subscription.
func (c *cloudClients) GetAzureRedisEnterpriseClient(subscription string) (azure.CacheForRedisClient, error) {
return c.azureRedisEnterpriseClients.Get(subscription, c.GetAzureCredential)
}
// Close closes all initialized clients.
func (c *cloudClients) Close() (err error) {
c.mtx.Lock()
@ -469,6 +489,8 @@ type TestCloudClients struct {
AzurePostgres azure.DBServersClient
AzurePostgresPerSub map[string]azure.DBServersClient
AzureSubscriptionClient *azure.SubscriptionClient
AzureRedis azure.CacheForRedisClient
AzureRedisEnterprise azure.CacheForRedisClient
}
// GetAWSSession returns AWS session for the specified region.
@ -562,6 +584,16 @@ func (c *TestCloudClients) GetAWSSSMClient(region string) (ssmiface.SSMAPI, erro
return c.SSM, nil
}
// GetAzureRedisClient returns an Azure Redis client for the given subscription.
func (c *TestCloudClients) GetAzureRedisClient(subscription string) (azure.CacheForRedisClient, error) {
return c.AzureRedis, nil
}
// GetAzureRedisEnterpriseClient returns an Azure Redis Enterprise client for the given subscription.
func (c *TestCloudClients) GetAzureRedisEnterpriseClient(subscription string) (azure.CacheForRedisClient, error) {
return c.AzureRedisEnterprise, nil
}
// Close closes all initialized clients.
func (c *TestCloudClients) Close() error {
return nil

View file

@ -1286,6 +1286,9 @@ func applyDatabasesConfig(fc *FileConfig, cfg *service.Config) error {
Domain: database.AD.Domain,
SPN: database.AD.SPN,
},
Azure: service.DatabaseAzure{
ResourceID: database.Azure.ResourceID,
},
}
if err := db.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)

View file

@ -1341,6 +1341,8 @@ type Database struct {
GCP DatabaseGCP `yaml:"gcp"`
// AD contains Active Directory database configuration.
AD DatabaseAD `yaml:"ad"`
// Azure contains Azure database configuration.
Azure DatabaseAzure `yaml:"azure"`
}
// DatabaseAD contains database Active Directory configuration.
@ -1431,6 +1433,12 @@ type DatabaseGCP struct {
InstanceID string `yaml:"instance_id,omitempty"`
}
// DatabaseAzure contains Azure database configuration.
type DatabaseAzure struct {
// ResourceID is the Azure fully qualified ID for the resource.
ResourceID string `yaml:"resource_id,omitempty"`
}
// Apps represents the configuration for the collection of applications this
// service will start. In file configuration this would be the "app_service"
// section.

View file

@ -777,6 +777,8 @@ type Database struct {
GCP DatabaseGCP
// AD contains Active Directory configuration for database.
AD DatabaseAD
// Azure contains Azure database configuration.
Azure DatabaseAzure
}
// TLSMode defines all possible database verification modes.
@ -908,6 +910,12 @@ type DatabaseAD struct {
SPN string
}
// DatabaseAzure contains Azure database configuration.
type DatabaseAzure struct {
// ResourceID is the Azure fully qualified ID for the resource.
ResourceID string `yaml:"resource_id,omitempty"`
}
// CheckAndSetDefaults validates database Active Directory configuration.
func (d *DatabaseAD) CheckAndSetDefaults(name string) error {
if d.KeytabFile == "" {

View file

@ -125,6 +125,9 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) {
Domain: db.AD.Domain,
SPN: db.AD.SPN,
},
Azure: types.Azure{
ResourceID: db.Azure.ResourceID,
},
})
if err != nil {
return trace.Wrap(err)

View file

@ -33,6 +33,7 @@ import (
"github.com/gravitational/teleport/api/types"
awsutils "github.com/gravitational/teleport/api/utils/aws"
azureutils "github.com/gravitational/teleport/api/utils/azure"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
@ -121,7 +122,7 @@ func TestDatabaseFromAzureDBServer(t *testing.T) {
Name: name,
Port: "3306",
Properties: azure.ServerProperties{
FullyQualifiedDomainName: name + ".mysql" + types.AzureEndpointSuffix,
FullyQualifiedDomainName: name + ".mysql" + azureutils.DatabaseEndpointSuffix,
UserVisibleState: string(armmysql.ServerStateReady),
Version: string(armmysql.ServerVersionFive7),
},

View file

@ -2011,7 +2011,9 @@ func withRDSPostgres(name, authToken string) withDatabaseOption {
Region: testAWSRegion,
},
// Set CA cert, otherwise we will attempt to download RDS roots.
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
TLS: types.DatabaseTLS{
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
},
})
require.NoError(t, err)
testCtx.postgres[name] = testPostgres{
@ -2043,7 +2045,9 @@ func withRedshiftPostgres(name, authToken string) withDatabaseOption {
Redshift: types.Redshift{ClusterID: "redshift-cluster-1"},
},
// Set CA cert, otherwise we will attempt to download Redshift roots.
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
TLS: types.DatabaseTLS{
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
},
})
require.NoError(t, err)
testCtx.postgres[name] = testPostgres{
@ -2078,7 +2082,9 @@ func withCloudSQLPostgres(name, authToken string) withDatabaseOption {
InstanceID: "instance-1",
},
// Set CA cert to pass cert validation.
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
TLS: types.DatabaseTLS{
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
},
})
require.NoError(t, err)
testCtx.postgres[name] = testPostgres{
@ -2109,7 +2115,9 @@ func withAzurePostgres(name, authToken string) withDatabaseOption {
Name: name,
},
// Set CA cert, otherwise we will attempt to download RDS roots.
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
TLS: types.DatabaseTLS{
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
},
})
require.NoError(t, err)
testCtx.postgres[name] = testPostgres{
@ -2363,6 +2371,38 @@ func withSQLServer(name string) withDatabaseOption {
}
}
func withAzureRedis(name string, token string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
redisServer, err := redis.NewTestServer(t, common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
}, redis.TestServerPassword(token))
require.NoError(t, err)
database, err := types.NewDatabaseV3(types.Metadata{
Name: name,
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolRedis,
URI: fmt.Sprintf("rediss://%s", net.JoinHostPort("localhost", redisServer.Port())),
DynamicLabels: dynamicLabels,
Azure: types.Azure{
Name: name,
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-teleport",
},
// Set CA cert to pass cert validation.
TLS: types.DatabaseTLS{
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
},
})
require.NoError(t, err)
testCtx.redis[name] = testRedis{
db: redisServer,
resource: database,
}
return database
}
}
var dynamicLabels = types.LabelsToV2(map[string]types.CommandLabel{
"echo": &types.CommandLabelV2{
Period: types.NewDuration(time.Second),

View file

@ -55,7 +55,10 @@ func TestAuthTokens(t *testing.T) {
withCloudSQLMySQL("mysql-cloudsql-correct-token", "root", cloudSQLPassword),
withCloudSQLMySQL("mysql-cloudsql-incorrect-token", "root", "qwe123"),
withAzureMySQL("mysql-azure-correct-token", "root", azureAccessToken),
withAzureMySQL("mysql-azure-incorrect-token", "root", "qwe123"))
withAzureMySQL("mysql-azure-incorrect-token", "root", "qwe123"),
withAzureRedis("redis-azure-correct-token", azureRedisToken),
withAzureRedis("redis-azure-incorrect-token", "qwe123"),
)
go testCtx.startHandlingConnections()
testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{types.Wildcard}, []string{types.Wildcard})
@ -123,6 +126,17 @@ func TestAuthTokens(t *testing.T) {
protocol: defaults.ProtocolMySQL,
err: "Access denied for user",
},
{
desc: "correct Azure Redis auth token",
service: "redis-azure-correct-token",
protocol: defaults.ProtocolRedis,
},
{
desc: "incorrect Azure Redis auth token",
service: "redis-azure-incorrect-token",
protocol: defaults.ProtocolRedis,
err: "WRONGPASS invalid username-password pair",
},
}
for _, test := range tests {
@ -146,6 +160,15 @@ func TestAuthTokens(t *testing.T) {
require.NoError(t, err)
require.NoError(t, conn.Close())
}
case defaults.ProtocolRedis:
conn, err := testCtx.redisClient(ctx, "alice", test.service, "default")
if test.err != "" {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
} else {
require.NoError(t, err)
require.NoError(t, conn.Close())
}
default:
t.Fatalf("unrecognized database protocol in test: %q", test.protocol)
}
@ -186,6 +209,8 @@ const (
cloudSQLPassword = "cloudsql-password"
// azureAccessToken is a mock Azure access token.
azureAccessToken = "azure-access-token"
// azureRedisToken is a mock Azure Redis token.
azureRedisToken = "azure-redis-token"
)
// GetRDSAuthToken generates RDS/Aurora auth token.
@ -218,6 +243,12 @@ func (a *testAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *common.S
return azureAccessToken, nil
}
// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Azure Redis token for %v.", sessionCtx)
return azureRedisToken, nil
}
func TestDBCertSigning(t *testing.T) {
authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
Clock: clockwork.NewFakeClockAt(time.Now()),

View file

@ -27,6 +27,7 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
awsutils "github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
@ -47,8 +48,14 @@ func (s *Server) initCACert(ctx context.Context, database types.Database) error
types.DatabaseTypeRedshift,
types.DatabaseTypeElastiCache,
types.DatabaseTypeMemoryDB,
types.DatabaseTypeCloudSQL,
types.DatabaseTypeAzure:
types.DatabaseTypeCloudSQL:
case types.DatabaseTypeAzure:
// Azure Cache for Redis uses system cert poool
if database.GetProtocol() == defaults.ProtocolRedis {
return nil
}
default:
return nil
}

View file

@ -35,6 +35,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
azureutils "github.com/gravitational/teleport/api/utils/azure"
clients "github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/services"
@ -683,7 +684,7 @@ func makeAzureMySQLServer(t *testing.T, name, subscription, group, region string
name,
)
fqdn := name + ".mysql" + types.AzureEndpointSuffix
fqdn := name + ".mysql" + azureutils.DatabaseEndpointSuffix
state := armmysql.ServerStateReady
version := armmysql.ServerVersionFive7
server := &armmysql.Server{
@ -724,7 +725,7 @@ func makeAzurePostgresServer(t *testing.T, name, subscription, group, region str
name,
)
fqdn := name + ".postgres" + types.AzureEndpointSuffix
fqdn := name + ".postgres" + azureutils.DatabaseEndpointSuffix
state := armpostgresql.ServerStateReady
version := armpostgresql.ServerVersionEleven
server := &armpostgresql.Server{

View file

@ -27,14 +27,17 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
azureutils "github.com/gravitational/teleport/api/utils/azure"
"github.com/gravitational/teleport/api/utils/retryutils"
libauth "github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/cloud"
libazure "github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/aws/aws-sdk-go/aws"
@ -61,6 +64,8 @@ type Auth interface {
GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) (string, error)
// GetAzureAccessToken generates Azure database access token.
GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (string, error)
// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error)
// GetTLSConfig builds the client TLS configuration for the session.
GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error)
// GetAuthPreference returns the cluster authentication config.
@ -69,10 +74,20 @@ type Auth interface {
io.Closer
}
// AuthClient is an interface that defines a subset of libauth.Client's
// functions that are required for database auth.
type AuthClient interface {
// GenerateDatabaseCert generates client certificate used by a database
// service to authenticate with the database instance.
GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error)
// GetAuthPreference returns the cluster authentication config.
GetAuthPreference(ctx context.Context) (types.AuthPreference, error)
}
// AuthConfig is the database access authenticator configuration.
type AuthConfig struct {
// AuthClient is the cluster auth client.
AuthClient *libauth.Client
AuthClient AuthClient
// Clients provides interface for obtaining cloud provider clients.
Clients cloud.Clients
// Clock is the clock implementation.
@ -294,6 +309,51 @@ func (a *dbAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (
return token.Token, nil
}
// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error) {
resourceID, err := arm.ParseResourceID(sessionCtx.Database.GetAzure().ResourceID)
if err != nil {
return "", trace.Wrap(err)
}
var client libazure.CacheForRedisClient
switch resourceID.ResourceType.String() {
case "Microsoft.Cache/Redis":
client, err = a.cfg.Clients.GetAzureRedisClient(resourceID.SubscriptionID)
if err != nil {
return "", trace.Wrap(err)
}
case "Microsoft.Cache/redisEnterprise", "Microsoft.Cache/redisEnterprise/databases":
client, err = a.cfg.Clients.GetAzureRedisEnterpriseClient(resourceID.SubscriptionID)
if err != nil {
return "", trace.Wrap(err)
}
default:
return "", trace.BadParameter("unknown Azure Cache for Redis resource type: %v", resourceID.ResourceType)
}
token, err := client.GetToken(ctx, sessionCtx.Database.GetAzure().ResourceID)
if err != nil {
// Some Azure error messages are long, multi-lined, and may even
// contain divider lines like "------". It's unreadable in redis-cli as
// the message has to be merged to a single line string. Thus logging
// the original error as debug and returning a more user friendly
// message.
a.cfg.Log.WithError(err).Debugf("Failed to get token for Azure Redis %q.", sessionCtx.Database.GetName())
switch {
case trace.IsAccessDenied(err):
return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure the database agent has the \"listKeys\" permission to the database.", sessionCtx.Database.GetName())
case trace.IsNotFound(err):
// Note that Azure Cache for Redis should always have both keys
// generated at all time. Here just checking in case something
// wrong with the API.
return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure either the primary key or the secondary key is generated.", sessionCtx.Database.GetName())
default:
return "", trace.Errorf("Failed to get token for Azure Redis %q.", sessionCtx.Database.GetName())
}
}
return token, nil
}
// GetTLSConfig builds the client TLS configuration for the session.
//
// For RDS/Aurora, the config must contain RDS root certificate as a trusted
@ -316,29 +376,13 @@ func (a *dbAuth) GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Co
// getTLSConfigVerifyFull returns tls.Config with full verification enabled ('verify-full' mode).
// Config also includes database specific adjustment.
func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session) (*tls.Config, error) {
tlsConfig := &tls.Config{
RootCAs: x509.NewCertPool(),
}
switch sessionCtx.Database.GetProtocol() {
case defaults.ProtocolMongoDB, defaults.ProtocolRedis:
// Mongo and Redis are using custom URI schema.
default:
// Don't set the ServerName when connecting to a MongoDB cluster - in case
// of replica set the driver may dial multiple servers and will set
// ServerName itself. For Postgres/MySQL we're always connecting to the
// server specified in URI so set ServerName ourselves.
addr, err := utils.ParseAddr(sessionCtx.Database.GetURI())
if err != nil {
return nil, trace.Wrap(err)
}
tlsConfig.ServerName = addr.Host()
}
tlsConfig := &tls.Config{}
// Add CA certificate to the trusted pool if it's present, e.g. when
// connecting to RDS/Aurora which require AWS CA or when was provided in config file.
tlsConfig, err := appendCAToRoot(tlsConfig, sessionCtx)
if err != nil {
//
// Some databases may also require the system cert pool, e.g Azure Redis.
if err := setupTLSConfigRootCAs(tlsConfig, sessionCtx); err != nil {
return nil, trace.Wrap(err)
}
@ -371,10 +415,9 @@ func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session
tlsConfig.VerifyConnection = getVerifyCloudSQLCertificate(tlsConfig.RootCAs)
}
dbTLSConfig := sessionCtx.Database.GetTLS()
// Use user provided server name if set. Override the current value if needed.
if dbTLSConfig.ServerName != "" {
tlsConfig.ServerName = dbTLSConfig.ServerName
// Setup server name for verification.
if err := setupTLSConfigServerName(tlsConfig, sessionCtx); err != nil {
return nil, trace.Wrap(err)
}
// RDS/Aurora/Redshift/ElastiCache and Cloud SQL auth is done with an auth
@ -441,15 +484,91 @@ func (a *dbAuth) appendClientCert(ctx context.Context, sessionCtx *Session, tlsC
return tlsConfig, nil
}
// appendCAToRoot appends CA certificate from session context to provided tlsConfig.
func appendCAToRoot(tlsConfig *tls.Config, sessionCtx *Session) (*tls.Config, error) {
// setupTLSConfigRootCAs initializes the root CA cert pool for the provided
// tlsConfig based on session context.
func setupTLSConfigRootCAs(tlsConfig *tls.Config, sessionCtx *Session) error {
// Start with an empty pool.
tlsConfig.RootCAs = x509.NewCertPool()
// If CA is provided by the database object, always use it.
if len(sessionCtx.Database.GetCA()) != 0 {
if !tlsConfig.RootCAs.AppendCertsFromPEM([]byte(sessionCtx.Database.GetCA())) {
return nil, trace.BadParameter("invalid server CA certificate")
return trace.BadParameter("invalid server CA certificate")
}
return nil
}
return tlsConfig, nil
// Overwrite with the system cert pool, if required.
if shouldUseSystemCertPool(sessionCtx) {
systemCertPool, err := x509.SystemCertPool()
if err != nil {
return trace.Wrap(err)
}
tlsConfig.RootCAs = systemCertPool
return nil
}
// Use the empty pool. Client cert will be added later.
return nil
}
// shouldUseSystemCertPool returns true for database servers presenting
// certificates signed by publicly trusted CAs so a system cert pool can be
// used.
func shouldUseSystemCertPool(sessionCtx *Session) bool {
// Azure Cache for Redis certificates are signed by DigiCert Global Root G2.
return sessionCtx.Database.IsAzure() && sessionCtx.Database.GetProtocol() == defaults.ProtocolRedis
}
// setupTLSConfigServerName initializes the server name for the provided
// tlsConfig based on session context.
func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error {
// Use user provided server name if set. Override the current value if needed.
if dbTLSConfig := sessionCtx.Database.GetTLS(); dbTLSConfig.ServerName != "" {
tlsConfig.ServerName = dbTLSConfig.ServerName
return nil
}
// If server name is set prior to this function, use that.
if tlsConfig.ServerName != "" {
return nil
}
switch sessionCtx.Database.GetProtocol() {
case defaults.ProtocolMongoDB:
// Don't set the ServerName when connecting to a MongoDB cluster - in case
// of replica set the driver may dial multiple servers and will set
// ServerName itself.
return nil
case defaults.ProtocolRedis:
// Azure Redis servers always serve the certificates with the proper
// hostnames. However, OSS cluster mode may redirect to an IP address,
// and without correct ServerName the handshake will fail as the IPs
// are not in SANs.
if sessionCtx.Database.IsAzure() {
serverName, err := azureutils.GetHostFromRedisURI(sessionCtx.Database.GetURI())
if err != nil {
return trace.Wrap(err)
}
tlsConfig.ServerName = serverName
return nil
}
// Redis is using custom URI schema.
return nil
default:
// For other databases we're always connecting to the server specified
// in URI so set ServerName ourselves.
addr, err := utils.ParseAddr(sessionCtx.Database.GetURI())
if err != nil {
return trace.Wrap(err)
}
tlsConfig.ServerName = addr.Host()
return nil
}
}
// verifyConnectionFunc returns a certificate validation function. serverName if empty will skip the hostname validation.

View file

@ -0,0 +1,300 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package common
import (
"context"
"crypto/tls"
"crypto/x509"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud"
libcloudazure "github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/trace"
)
func TestAuthGetAzureCacheForRedisToken(t *testing.T) {
t.Parallel()
auth, err := NewAuth(AuthConfig{
AuthClient: new(authClientMock),
Clients: &cloud.TestCloudClients{
AzureRedis: libcloudazure.NewRedisClientByAPI(&libcloudazure.ARMRedisMock{
Token: "azure-redis-token",
}),
AzureRedisEnterprise: libcloudazure.NewRedisEnterpriseClientByAPI(&libcloudazure.ARMRedisEnterpriseDatabaseMock{
Token: "azure-redis-enterprise-token",
}),
},
})
require.NoError(t, err)
tests := []struct {
name string
resourceID string
expectError bool
expectToken string
}{
{
name: "invalid resource ID",
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/some-unknown-service/example-teleport",
expectError: true,
},
{
name: "Redis (non-Enterprise)",
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-teleport",
expectToken: "azure-redis-token",
},
{
name: "Redis Enterprise",
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-teleport",
expectToken: "azure-redis-enterprise-token",
},
{
name: "Redis Enterprise (database resource ID)",
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-teleport/databases/default",
expectToken: "azure-redis-enterprise-token",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token, err := auth.GetAzureCacheForRedisToken(context.TODO(), &Session{
Database: newAzureRedisDatabase(t, test.resourceID),
})
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expectToken, token)
}
})
}
}
func TestAuthGetTLSConfig(t *testing.T) {
t.Parallel()
auth, err := NewAuth(AuthConfig{
AuthClient: new(authClientMock),
Clients: &cloud.TestCloudClients{},
})
require.NoError(t, err)
systemCertPool, err := x509.SystemCertPool()
require.NoError(t, err)
// The authClientMock uses fixtures.TLSCACertPEM as the root signing CA.
defaultCertPool := x509.NewCertPool()
require.True(t, defaultCertPool.AppendCertsFromPEM([]byte(fixtures.TLSCACertPEM)))
// Use a different CA to pretend to be CAs for AWS hosted databases.
awsCertPool := x509.NewCertPool()
require.True(t, awsCertPool.AppendCertsFromPEM([]byte(fixtures.SAMLOktaCertPEM)))
tests := []struct {
name string
sessionDatabase types.Database
expectServerName string
expectRootCAs *x509.CertPool
expectClientCertificates bool
expectVerifyConnection bool
expectInsecureSkipVerify bool
}{
{
name: "self-hosted",
sessionDatabase: newSelfHostedDatabase(t, "localhost:8888"),
expectServerName: "localhost",
expectRootCAs: defaultCertPool,
expectClientCertificates: true,
},
{
name: "AWS ElastiCache Redis",
sessionDatabase: newElastiCacheRedisDatabase(t, fixtures.SAMLOktaCertPEM),
expectRootCAs: awsCertPool,
},
{
name: "AWS Redishift",
sessionDatabase: newRedshiftDatabase(t, fixtures.SAMLOktaCertPEM),
expectServerName: "redshift-cluster-1.abcdefghijklmnop.us-east-1.redshift.amazonaws.com",
expectRootCAs: awsCertPool,
},
{
name: "Azure Redis",
sessionDatabase: newAzureRedisDatabase(t, "resource-id"),
expectServerName: "test-database.redis.cache.windows.net",
expectRootCAs: systemCertPool,
},
{
name: "GCP Cloud SQL",
sessionDatabase: newCloudSQLDatabase(t, "project-id", "instance-id"),
// RootCAs is empty, and custom VerifyConnection function is provided.
expectServerName: "project-id:instance-id",
expectRootCAs: x509.NewCertPool(),
expectInsecureSkipVerify: true,
expectVerifyConnection: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tlsConfig, err := auth.GetTLSConfig(context.TODO(), &Session{
Identity: tlsca.Identity{},
DatabaseUser: "default",
Database: test.sessionDatabase,
})
require.NoError(t, err)
require.Equal(t, test.expectServerName, tlsConfig.ServerName)
require.Equal(t, test.expectInsecureSkipVerify, tlsConfig.InsecureSkipVerify)
// nolint:staticcheck
// TODO x509.CertPool.Subjects() is deprecated. use
// x509.CertPool.Equal introduced in 1.19 for comparison.
require.Equal(t, test.expectRootCAs.Subjects(), tlsConfig.RootCAs.Subjects())
if test.expectClientCertificates {
require.Len(t, tlsConfig.Certificates, 1)
} else {
require.Empty(t, tlsConfig.Certificates)
}
if test.expectVerifyConnection {
require.NotNil(t, tlsConfig.VerifyConnection)
} else {
require.Nil(t, tlsConfig.VerifyConnection)
}
})
}
}
func newAzureRedisDatabase(t *testing.T, resourceID string) types.Database {
database, err := types.NewDatabaseV3(types.Metadata{
Name: "test-database",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolRedis,
URI: "rediss://test-database.redis.cache.windows.net:8888",
Azure: types.Azure{
ResourceID: resourceID,
},
})
require.NoError(t, err)
return database
}
func newSelfHostedDatabase(t *testing.T, uri string) types.Database {
database, err := types.NewDatabaseV3(types.Metadata{
Name: "test-database",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolMySQL,
URI: uri,
})
require.NoError(t, err)
return database
}
func newCloudSQLDatabase(t *testing.T, projectID, instanceID string) types.Database {
database, err := types.NewDatabaseV3(types.Metadata{
Name: "test-database",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolMySQL,
URI: "cloudsql:8888",
GCP: types.GCPCloudSQL{
ProjectID: projectID,
InstanceID: instanceID,
},
})
require.NoError(t, err)
return database
}
func newElastiCacheRedisDatabase(t *testing.T, ca string) types.Database {
database, err := types.NewDatabaseV3(types.Metadata{
Name: "test-database",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolRedis,
URI: "master.example-cluster.xxxxxx.cac1.cache.amazonaws.com:6379",
TLS: types.DatabaseTLS{
CACert: ca,
},
})
require.NoError(t, err)
return database
}
func newRedshiftDatabase(t *testing.T, ca string) types.Database {
database, err := types.NewDatabaseV3(types.Metadata{
Name: "test-database",
}, types.DatabaseSpecV3{
Protocol: defaults.ProtocolPostgres,
URI: "redshift-cluster-1.abcdefghijklmnop.us-east-1.redshift.amazonaws.com:5432",
TLS: types.DatabaseTLS{
CACert: ca,
},
})
require.NoError(t, err)
return database
}
// authClientMock is a mock that implements AuthClient interface.
type authClientMock struct {
}
// GenerateDatabaseCert generates a cert using fixtures TLS CA.
func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) {
csr, err := tlsca.ParseCertificateRequestPEM(req.CSR)
if err != nil {
return nil, trace.Wrap(err)
}
tlsCACert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM))
if err != nil {
return nil, trace.Wrap(err)
}
tlsCA, err := tlsca.FromTLSCertificate(tlsCACert)
if err != nil {
return nil, trace.Wrap(err)
}
certReq := tlsca.CertificateRequest{
PublicKey: csr.PublicKey,
Subject: csr.Subject,
NotAfter: time.Now().Add(req.TTL.Get()),
DNSNames: []string{"localhost", "127.0.0.1"},
}
cert, err := tlsCA.GenerateCertificate(certReq)
if err != nil {
return nil, trace.Wrap(err)
}
return &proto.DatabaseCertResponse{
Cert: cert,
CACerts: [][]byte{
[]byte(fixtures.TLSCACertPEM),
},
}, nil
}
// GetAuthPreference always returns types.DefaultAuthPreference().
func (m *authClientMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
return types.DefaultAuthPreference(), nil
}

View file

@ -31,6 +31,9 @@ import (
// TestRegisterEngine verifies database engine registration.
func TestRegisterEngine(t *testing.T) {
// Cleanup "test" engine in case this test is run in a loop.
RegisterEngine(nil, "test")
ec := EngineConfig{
Context: context.Background(),
Clock: clockwork.NewFakeClock(),

View file

@ -146,14 +146,20 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
return trace.Wrap(err)
}
// If fail to get the initial username or password, return an error right
// away without making a connection to the Redis server.
username, password, err := e.getInitialUsernameAndPassowrd(ctx, sessionCtx)
if err != nil {
return trace.Wrap(err)
}
// Initialize newClient factory function with current connection state.
e.newClient, err = e.getNewClientFn(ctx, sessionCtx)
if err != nil {
return trace.Wrap(err)
}
// Create new client without username or password. Those will be added when we receive AUTH command.
e.redisClient, err = e.newClient("", "")
e.redisClient, err = e.newClient(username, password)
if err != nil {
return trace.Wrap(err)
}
@ -174,6 +180,23 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
return nil
}
// getInitialUsernameAndPassowrd returns the username and password used for
// the initial connection.
func (e *Engine) getInitialUsernameAndPassowrd(ctx context.Context, sessionCtx *common.Session) (string, string, error) {
switch {
case sessionCtx.Database.IsAzure():
// Retrieve the auth token for Azure Cache for Redis. Use default user.
password, err := e.Auth.GetAzureCacheForRedisToken(ctx, sessionCtx)
return "", password, trace.Wrap(err)
default:
// Create new client without username or password. Those will be added
// when we receive AUTH command (e.g. self-hosted), or they can be
// fetched by the OnConnect callback (e.g. ElastiCache managed users).
return "", "", nil
}
}
// getNewClientFn returns a partial Redis client factory function.
func (e *Engine) getNewClientFn(ctx context.Context, sessionCtx *common.Session) (redisClientFactoryFn, error) {
tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx)

View file

@ -20,8 +20,10 @@
package protocol
import (
"bufio"
"reflect"
"strconv"
"strings"
"github.com/go-redis/redis/v8"
"github.com/gravitational/trace"
@ -129,7 +131,29 @@ func writeError(wr *redis.Writer, prefix string, val error) error {
return trace.Wrap(err)
}
if _, err := wr.WriteString(val.Error()); err != nil {
// If the error message contains "\r" or "\n", redis-cli will have trouble
// parsing the message and show "Bad simple string value" instead. So if
// newlines are detected in the original error message, merge them to one
// line.
errString := val.Error()
if strings.ContainsAny(errString, "\r\n") {
scanner := bufio.NewScanner(strings.NewReader(errString))
errString = ""
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
if errString != "" {
errString += " "
}
errString += line
}
}
if _, err := wr.WriteString(errString); err != nil {
return trace.Wrap(err)
}

View file

@ -76,6 +76,11 @@ func TestWriteCmd(t *testing.T) {
val: errors.New("something bad"),
expected: []byte("-ERR something bad\r\n"),
},
{
name: "multi-line error",
val: errors.New("something bad.\r\n \n and another line"),
expected: []byte("-ERR something bad. and another line\r\n"),
},
{
name: "Teleport error",
val: trace.Errorf("something bad"),