mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 16:53:57 +00:00
Azure Cache for Redis engine support (#16551)
This commit is contained in:
parent
af7cd0239d
commit
aabced42dc
|
@ -18,13 +18,13 @@ package types
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/teleport/api/utils"
|
||||
awsutils "github.com/gravitational/teleport/api/utils/aws"
|
||||
azureutils "github.com/gravitational/teleport/api/utils/azure"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
@ -511,11 +511,27 @@ func (d *DatabaseV3) CheckAndSetDefaults() error {
|
|||
}
|
||||
d.Spec.AWS.MemoryDB.TLSEnabled = endpointInfo.TransitEncryptionEnabled
|
||||
d.Spec.AWS.MemoryDB.EndpointType = endpointInfo.EndpointType
|
||||
case strings.Contains(d.Spec.URI, AzureEndpointSuffix):
|
||||
name, err := parseAzureEndpoint(d.Spec.URI)
|
||||
|
||||
case azureutils.IsDatabaseEndpoint(d.Spec.URI):
|
||||
// For Azure MySQL and PostgresSQL.
|
||||
name, err := azureutils.ParseDatabaseEndpoint(d.Spec.URI)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if d.Spec.Azure.Name == "" {
|
||||
d.Spec.Azure.Name = name
|
||||
}
|
||||
case azureutils.IsCacheForRedisEndpoint(d.Spec.URI):
|
||||
// ResourceID is required for fetching Redis tokens.
|
||||
if d.Spec.Azure.ResourceID == "" {
|
||||
return trace.BadParameter("missing ResourceID for Azure Cache %v", d.Metadata.Name)
|
||||
}
|
||||
|
||||
name, err := azureutils.ParseCacheForRedisEndpoint(d.Spec.URI)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
if d.Spec.Azure.Name == "" {
|
||||
d.Spec.Azure.Name = name
|
||||
}
|
||||
|
@ -523,21 +539,6 @@ func (d *DatabaseV3) CheckAndSetDefaults() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// parseAzureEndpoint extracts database server name from Azure endpoint.
|
||||
func parseAzureEndpoint(endpoint string) (name string, err error) {
|
||||
host, _, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
// Azure endpoint looks like this:
|
||||
// name.mysql.database.azure.com
|
||||
parts := strings.Split(host, ".")
|
||||
if !strings.HasSuffix(host, AzureEndpointSuffix) || len(parts) != 5 {
|
||||
return "", trace.BadParameter("failed to parse %v as Azure endpoint", endpoint)
|
||||
}
|
||||
return parts[0], nil
|
||||
}
|
||||
|
||||
// GetIAMPolicy returns AWS IAM policy for this database.
|
||||
func (d *DatabaseV3) GetIAMPolicy() (string, error) {
|
||||
if d.IsRDS() {
|
||||
|
@ -721,11 +722,6 @@ func (d Databases) Less(i, j int) bool { return d[i].GetName() < d[j].GetName()
|
|||
// Swap swaps two databases.
|
||||
func (d Databases) Swap(i, j int) { d[i], d[j] = d[j], d[i] }
|
||||
|
||||
const (
|
||||
// AzureEndpointSuffix is the Azure database endpoint suffix.
|
||||
AzureEndpointSuffix = ".database.azure.com"
|
||||
)
|
||||
|
||||
type arnTemplateInput struct {
|
||||
Partition, Region, AccountID, ResourceID string
|
||||
}
|
||||
|
|
|
@ -175,6 +175,108 @@ func TestDatabaseMemoryDBEndpoint(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestDatabaseAzureEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
spec DatabaseSpecV3
|
||||
expectError bool
|
||||
expectAzure Azure
|
||||
}{
|
||||
{
|
||||
name: "valid MySQL",
|
||||
spec: DatabaseSpecV3{
|
||||
Protocol: "mysql",
|
||||
URI: "example-mysql.mysql.database.azure.com:3306",
|
||||
},
|
||||
expectAzure: Azure{
|
||||
Name: "example-mysql",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid PostgresSQL",
|
||||
spec: DatabaseSpecV3{
|
||||
Protocol: "postgres",
|
||||
URI: "example-postgres.postgres.database.azure.com:5432",
|
||||
},
|
||||
expectAzure: Azure{
|
||||
Name: "example-postgres",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid database endpoint",
|
||||
spec: DatabaseSpecV3{
|
||||
Protocol: "postgres",
|
||||
URI: "invalid.database.azure.com:5432",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "valid Redis",
|
||||
spec: DatabaseSpecV3{
|
||||
Protocol: "redis",
|
||||
URI: "example-redis.redis.cache.windows.net:6380",
|
||||
Azure: Azure{
|
||||
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-redis",
|
||||
},
|
||||
},
|
||||
expectAzure: Azure{
|
||||
Name: "example-redis",
|
||||
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-redis",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid Redis Enterprise",
|
||||
spec: DatabaseSpecV3{
|
||||
Protocol: "redis",
|
||||
URI: "rediss://example-redis-enterprise.region.redisenterprise.cache.azure.net?mode=cluster",
|
||||
Azure: Azure{
|
||||
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-redis-enterprise",
|
||||
},
|
||||
},
|
||||
expectAzure: Azure{
|
||||
Name: "example-redis-enterprise",
|
||||
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-redis-enterprise",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid Redis (missing resource ID)",
|
||||
spec: DatabaseSpecV3{
|
||||
Protocol: "redis",
|
||||
URI: "rediss://example-redis-enterprise.region.redisenterprise.cache.azure.net?mode=cluster",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid Redis (unknown format)",
|
||||
spec: DatabaseSpecV3{
|
||||
Protocol: "redis",
|
||||
URI: "rediss://bad-format.redisenterprise.cache.azure.net?mode=cluster",
|
||||
Azure: Azure{
|
||||
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/bad-format",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
database, err := NewDatabaseV3(Metadata{
|
||||
Name: "test",
|
||||
}, test.spec)
|
||||
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expectAzure, database.GetAzure())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLVersionValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
122
api/utils/azure/endpoints.go
Normal file
122
api/utils/azure/endpoints.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// IsDatabaseEndpoint returns true if provided endpoint is a valid database
|
||||
// endpoint.
|
||||
func IsDatabaseEndpoint(endpoint string) bool {
|
||||
return strings.Contains(endpoint, DatabaseEndpointSuffix)
|
||||
}
|
||||
|
||||
// IsCacheForRedisEndpoint returns true if provided endpoint is a valid Azure
|
||||
// Cache for Redis endpoint.
|
||||
func IsCacheForRedisEndpoint(endpoint string) bool {
|
||||
return IsRedisEndpoint(endpoint) || IsRedisEnterpriseEndpoint(endpoint)
|
||||
}
|
||||
|
||||
// IsRedisEndpoint returns true if provided endpoint is a valid Redis
|
||||
// (non-Enterprise tier) endpoint.
|
||||
func IsRedisEndpoint(endpoint string) bool {
|
||||
return strings.Contains(endpoint, RedisEndpointSuffix)
|
||||
}
|
||||
|
||||
// IsRedisEnterpriseEndpoint returns true if provided endpoint is a valid Redis
|
||||
// Enterprise endpoint.
|
||||
func IsRedisEnterpriseEndpoint(endpoint string) bool {
|
||||
return strings.Contains(endpoint, RedisEnterpriseEndpointSuffix)
|
||||
}
|
||||
|
||||
// ParseDatabaseEndpoint extracts database server name from Azure endpoint.
|
||||
func ParseDatabaseEndpoint(endpoint string) (name string, err error) {
|
||||
host, _, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
// Azure endpoint looks like this:
|
||||
// name.mysql.database.azure.com
|
||||
parts := strings.Split(host, ".")
|
||||
if !strings.HasSuffix(host, DatabaseEndpointSuffix) || len(parts) != 5 {
|
||||
return "", trace.BadParameter("failed to parse %v as Azure endpoint", endpoint)
|
||||
}
|
||||
return parts[0], nil
|
||||
}
|
||||
|
||||
// ParseCacheForRedisEndpoint extracts database server name from Azure Cache
|
||||
// for Redis endpoint.
|
||||
func ParseCacheForRedisEndpoint(endpoint string) (name string, err error) {
|
||||
// Note that the Redis URI may contain schema and parameters.
|
||||
host, err := GetHostFromRedisURI(endpoint)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
|
||||
switch {
|
||||
// Redis (non-Enterprise) endpoint looks like this:
|
||||
// name.redis.cache.windows.net
|
||||
case strings.HasSuffix(host, RedisEndpointSuffix):
|
||||
return strings.TrimSuffix(host, RedisEndpointSuffix), nil
|
||||
|
||||
// Redis Enterprise endpoint looks like this:
|
||||
// name.region.redisenterprise.cache.azure.net
|
||||
case strings.HasSuffix(host, RedisEnterpriseEndpointSuffix):
|
||||
name, _, ok := strings.Cut(strings.TrimSuffix(host, RedisEnterpriseEndpointSuffix), ".")
|
||||
if !ok {
|
||||
return "", trace.BadParameter("failed to parse %v as Azure Cache endpoint", endpoint)
|
||||
}
|
||||
return name, nil
|
||||
|
||||
default:
|
||||
return "", trace.BadParameter("failed to parse %v as Azure Cache endpoint", endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
// GetHostFromRedisURI extracts host name from a Redis URI. The URI may start
|
||||
// with "redis://", "rediss://", or without. The URI may also have parameters
|
||||
// like "?mode=cluster".
|
||||
func GetHostFromRedisURI(uri string) (string, error) {
|
||||
// Add a temporary schema to make a valid URL for url.Parse if schema is
|
||||
// not found.
|
||||
if !strings.Contains(uri, "://") {
|
||||
uri = "schema://" + uri
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
return parsed.Hostname(), nil
|
||||
}
|
||||
|
||||
const (
|
||||
// DatabaseEndpointSuffix is the Azure database endpoint suffix. Used for
|
||||
// MySQL, PostgresSQL, etc.
|
||||
DatabaseEndpointSuffix = ".database.azure.com"
|
||||
|
||||
// RedisEndpointSuffix is the endpoint suffix for Redis.
|
||||
RedisEndpointSuffix = ".redis.cache.windows.net"
|
||||
|
||||
// RedisEnterpriseEndpointSuffix is the endpoint suffix for Redis Enterprise.
|
||||
RedisEnterpriseEndpointSuffix = ".redisenterprise.cache.azure.net"
|
||||
)
|
39
api/utils/azure/fuzz_test.go
Normal file
39
api/utils/azure/fuzz_test.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func FuzzParseDatabaseEndpoint(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, endpoint string) {
|
||||
require.NotPanics(t, func() {
|
||||
ParseDatabaseEndpoint(endpoint)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzParseCacheForRedisEndpoint(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, endpoint string) {
|
||||
require.NotPanics(t, func() {
|
||||
ParseCacheForRedisEndpoint(endpoint)
|
||||
})
|
||||
})
|
||||
}
|
2
go.mod
2
go.mod
|
@ -10,6 +10,8 @@ require (
|
|||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.0.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.0.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2 v2.0.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise v1.0.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.0.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.4.1
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1
|
||||
|
|
4
go.sum
4
go.sum
|
@ -98,6 +98,10 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.0.0 h1:3
|
|||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.0.0/go.mod h1:yFGqqJ4W/nOViqHDfuwmjyJtZXLmmMoHN0DNPCigKUE=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.0.0 h1:A6qn+g+bsKoBhFzDFXLhNAup//D+Q7+MuofypSUtNfY=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.0.0/go.mod h1:GgxvszemyuFZyiw4vPxGib+Cp6z7Q3rYQb4DsKPOAAw=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2 v2.0.0 h1:gQhbudzftueZd/xWdFHDibO9kNghIn4B7DH1yrtoGwg=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2 v2.0.0/go.mod h1:ydU7PuJLTH175OGSecHKlCi9d6wTEh8aX5JsNL4PPSs=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise v1.0.0 h1:ZyZb7h9t0hw6pNdCJ1bEwoRsVJ+gHngUlexPfSfj3ZA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise v1.0.0/go.mod h1:KOdfFo3kIa8BlSMMNTslBV7MTVyVnrKgxhn0rjlMNxM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.0.0 h1:vsovXlTyKHZXnqzQyt7QMVkwpJBDkHchQL53qXaGBRY=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.0.0/go.mod h1:UZy1vHcRdEymNP1d6fTrvYHpSdkXoUdowfrvffcQOOU=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.4.1 h1:QSdcrd/UFJv6Bp/CfoVf2SrENpFn9P6Yh8yb+xNhYMM=
|
||||
|
|
75
lib/cloud/azure/client_map.go
Normal file
75
lib/cloud/azure/client_map.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// ClientMap is a generic map that caches a collection of Azure clients by
|
||||
// subscriptions.
|
||||
type ClientMap[ClientType any] struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]ClientType
|
||||
newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error)
|
||||
}
|
||||
|
||||
// NewClientMap creates a new ClientMap.
|
||||
func NewClientMap[ClientType any](newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error)) ClientMap[ClientType] {
|
||||
return ClientMap[ClientType]{
|
||||
clients: make(map[string]ClientType),
|
||||
newClient: newClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns an Azure client by subscription. A new client is created if the
|
||||
// subscription is not found in the map.
|
||||
func (m *ClientMap[ClientType]) Get(subscription string, getCredentials func() (azcore.TokenCredential, error)) (client ClientType, err error) {
|
||||
m.mu.RLock()
|
||||
if client, ok := m.clients[subscription]; ok {
|
||||
m.mu.RUnlock()
|
||||
return client, nil
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// If some other thread already got here first.
|
||||
if client, ok := m.clients[subscription]; ok {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
cred, err := getCredentials()
|
||||
if err != nil {
|
||||
return client, trace.Wrap(err)
|
||||
}
|
||||
|
||||
// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
|
||||
options := &arm.ClientOptions{}
|
||||
client, err = m.newClient(subscription, cred, options)
|
||||
if err != nil {
|
||||
return client, trace.Wrap(err)
|
||||
}
|
||||
|
||||
m.clients[subscription] = client
|
||||
return client, nil
|
||||
}
|
76
lib/cloud/azure/client_map_test.go
Normal file
76
lib/cloud/azure/client_map_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientMap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockNewClientFunc := func(subscription string, cred azcore.TokenCredential, opts *arm.ClientOptions) (CacheForRedisClient, error) {
|
||||
if subscription == "good-sub" {
|
||||
return NewRedisClientByAPI(nil), nil
|
||||
}
|
||||
return nil, trace.BadParameter("failed to create")
|
||||
}
|
||||
clientMap := NewClientMap(mockNewClientFunc)
|
||||
|
||||
// Note that some test cases (e.g. "get from cache") depend on previous
|
||||
// test cases. Thus running in sequence.
|
||||
t.Run("get credentials failed", func(t *testing.T) {
|
||||
client, err := clientMap.Get("some-sub", func() (azcore.TokenCredential, error) {
|
||||
return nil, trace.AccessDenied("failed to get credentials")
|
||||
})
|
||||
require.ErrorIs(t, err, trace.AccessDenied("failed to get credentials"))
|
||||
require.Nil(t, client)
|
||||
})
|
||||
|
||||
t.Run("create client failed", func(t *testing.T) {
|
||||
client, err := clientMap.Get("bad-sub", func() (azcore.TokenCredential, error) {
|
||||
return nil, nil
|
||||
})
|
||||
require.ErrorIs(t, err, trace.BadParameter("failed to create"))
|
||||
require.Nil(t, client)
|
||||
})
|
||||
|
||||
t.Run("create client succeed", func(t *testing.T) {
|
||||
client, err := clientMap.Get("good-sub", func() (azcore.TokenCredential, error) {
|
||||
return nil, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
require.IsType(t, NewRedisClientByAPI(nil), client)
|
||||
})
|
||||
|
||||
t.Run("get from cache", func(t *testing.T) {
|
||||
// Return an error for getCredentials but it shouldn't even be called
|
||||
// as the client is returned from existing cache.
|
||||
client, err := clientMap.Get("good-sub", func() (azcore.TokenCredential, error) {
|
||||
return nil, trace.AccessDenied("failed to get credentials")
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
require.IsType(t, NewRedisClientByAPI(nil), client)
|
||||
})
|
||||
}
|
|
@ -21,6 +21,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
@ -29,20 +30,21 @@ import (
|
|||
// to trace error. If the provided error is not a `ResponseError` it returns.
|
||||
// the error without modifying it.
|
||||
func ConvertResponseError(err error) error {
|
||||
responseErr, ok := err.(*azcore.ResponseError)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
switch v := err.(type) {
|
||||
case *azcore.ResponseError:
|
||||
switch v.StatusCode {
|
||||
case http.StatusForbidden:
|
||||
return trace.AccessDenied(v.Error())
|
||||
case http.StatusConflict:
|
||||
return trace.AlreadyExists(v.Error())
|
||||
case http.StatusNotFound:
|
||||
return trace.NotFound(v.Error())
|
||||
}
|
||||
|
||||
switch responseErr.StatusCode {
|
||||
case http.StatusForbidden:
|
||||
return trace.AccessDenied(responseErr.Error())
|
||||
case http.StatusConflict:
|
||||
return trace.AlreadyExists(responseErr.Error())
|
||||
case http.StatusNotFound:
|
||||
return trace.NotFound(responseErr.Error())
|
||||
}
|
||||
case *azidentity.AuthenticationFailedError:
|
||||
return trace.AccessDenied(v.Error())
|
||||
|
||||
}
|
||||
return err // Return unmodified.
|
||||
}
|
||||
|
||||
|
|
|
@ -59,3 +59,9 @@ type ARMPostgres interface {
|
|||
}
|
||||
|
||||
var _ ARMPostgres = (*armpostgresql.ServersClient)(nil)
|
||||
|
||||
// CacheForRedisClient provides an interface for an Azure Redis For Cache client.
|
||||
type CacheForRedisClient interface {
|
||||
// GetToken retrieves the auth token for provided resource ID.
|
||||
GetToken(ctx context.Context, resourceID string) (string, error)
|
||||
}
|
||||
|
|
|
@ -23,6 +23,8 @@ import (
|
|||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
|
@ -202,3 +204,47 @@ func (m *ARMPostgresMock) NewListByResourceGroupPager(group string, _ *armpostgr
|
|||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ARMRedisMock mocks armRedisClient.
|
||||
type ARMRedisMock struct {
|
||||
Token string
|
||||
NoAuth bool
|
||||
}
|
||||
|
||||
func (m *ARMRedisMock) ListKeys(ctx context.Context, resourceGroupName string, name string, options *armredis.ClientListKeysOptions) (armredis.ClientListKeysResponse, error) {
|
||||
if m.NoAuth {
|
||||
return armredis.ClientListKeysResponse{}, trace.AccessDenied("unauthorized")
|
||||
}
|
||||
return armredis.ClientListKeysResponse{
|
||||
AccessKeys: armredis.AccessKeys{
|
||||
PrimaryKey: &m.Token,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ARMRedisEnterpriseDatabaseMock mocks armRedisEnterpriseDatabaseClient.
|
||||
type ARMRedisEnterpriseDatabaseMock struct {
|
||||
Token string
|
||||
TokensByDatabaseName map[string]string
|
||||
NoAuth bool
|
||||
}
|
||||
|
||||
func (m *ARMRedisEnterpriseDatabaseMock) ListKeys(ctx context.Context, resourceGroupName string, clusterName string, databaseName string, options *armredisenterprise.DatabasesClientListKeysOptions) (armredisenterprise.DatabasesClientListKeysResponse, error) {
|
||||
if m.NoAuth {
|
||||
return armredisenterprise.DatabasesClientListKeysResponse{}, trace.AccessDenied("unauthorized")
|
||||
}
|
||||
if len(m.TokensByDatabaseName) != 0 {
|
||||
if token, found := m.TokensByDatabaseName[databaseName]; found {
|
||||
return armredisenterprise.DatabasesClientListKeysResponse{
|
||||
AccessKeys: armredisenterprise.AccessKeys{
|
||||
PrimaryKey: &token,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return armredisenterprise.DatabasesClientListKeysResponse{
|
||||
AccessKeys: armredisenterprise.AccessKeys{
|
||||
PrimaryKey: &m.Token,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
|
79
lib/cloud/azure/redis.go
Normal file
79
lib/cloud/azure/redis.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// armRedisClient is an interface defines a subset of functions of armredis.Client.
|
||||
type armRedisClient interface {
|
||||
ListKeys(ctx context.Context, resourceGroupName string, name string, options *armredis.ClientListKeysOptions) (armredis.ClientListKeysResponse, error)
|
||||
}
|
||||
|
||||
// redisClient is an Azure Redis client.
|
||||
type redisClient struct {
|
||||
api armRedisClient
|
||||
}
|
||||
|
||||
// NewRedisClient creates a new Azure Redis client by subscription and credentials.
|
||||
func NewRedisClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (CacheForRedisClient, error) {
|
||||
logrus.Debug("Initializing Azure Redis client.")
|
||||
api, err := armredis.NewClient(subscription, cred, options)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return NewRedisClientByAPI(api), nil
|
||||
}
|
||||
|
||||
// NewRedisClientByAPI creates a new Azure Redis client by ARM API client.
|
||||
func NewRedisClientByAPI(api armRedisClient) CacheForRedisClient {
|
||||
return &redisClient{
|
||||
api: api,
|
||||
}
|
||||
}
|
||||
|
||||
// GetToken retrieves the auth token for provided resource group and resource
|
||||
// name.
|
||||
func (c *redisClient) GetToken(ctx context.Context, resourceID string) (string, error) {
|
||||
id, err := arm.ParseResourceID(resourceID)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
|
||||
resp, err := c.api.ListKeys(ctx, id.ResourceGroupName, id.Name, &armredis.ClientListKeysOptions{})
|
||||
if err != nil {
|
||||
return "", trace.Wrap(ConvertResponseError(err))
|
||||
}
|
||||
|
||||
// There are two keys. Pick first one available.
|
||||
if resp.PrimaryKey != nil {
|
||||
return *resp.PrimaryKey, nil
|
||||
}
|
||||
if resp.SecondaryKey != nil {
|
||||
return *resp.SecondaryKey, nil
|
||||
}
|
||||
return "", trace.NotFound("missing keys")
|
||||
}
|
107
lib/cloud/azure/redis_enterprise.go
Normal file
107
lib/cloud/azure/redis_enterprise.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// armRedisEnterpriseDatabaseClient is an interface defines a subset of
|
||||
// functions of armredisenterprise.DatabaseClient
|
||||
type armRedisEnterpriseDatabaseClient interface {
|
||||
ListKeys(ctx context.Context, resourceGroupName string, clusterName string, databaseName string, options *armredisenterprise.DatabasesClientListKeysOptions) (armredisenterprise.DatabasesClientListKeysResponse, error)
|
||||
}
|
||||
|
||||
// redisEnterpriseClient is an Azure Redis Enterprise client.
|
||||
type redisEnterpriseClient struct {
|
||||
databaseAPI armRedisEnterpriseDatabaseClient
|
||||
}
|
||||
|
||||
// NewRedisEnterpriseClient creates a new Azure Redis Enterprise client by
|
||||
// subscription and credentials.
|
||||
func NewRedisEnterpriseClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (CacheForRedisClient, error) {
|
||||
logrus.Debug("Initializing Azure Redis Enterprise client.")
|
||||
databaseAPI, err := armredisenterprise.NewDatabasesClient(subscription, cred, options)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
// TODO(greedy52) Redis Enterprise requires a different API client
|
||||
// (armredisenterprise.Client) for auto-discovery.
|
||||
return NewRedisEnterpriseClientByAPI(databaseAPI), nil
|
||||
}
|
||||
|
||||
// NewRedisEnterpriseClientByAPI creates a new Azure Redis Enterprise client by
|
||||
// ARM API client(s).
|
||||
func NewRedisEnterpriseClientByAPI(databaseAPI armRedisEnterpriseDatabaseClient) CacheForRedisClient {
|
||||
return &redisEnterpriseClient{
|
||||
databaseAPI: databaseAPI,
|
||||
}
|
||||
}
|
||||
|
||||
// GetToken retrieves the auth token for provided resource group and resource
|
||||
// name.
|
||||
func (c *redisEnterpriseClient) GetToken(ctx context.Context, resourceID string) (string, error) {
|
||||
id, err := arm.ParseResourceID(resourceID)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
clusterName, databaseName, err := c.getClusterAndDatabaseName(id)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
|
||||
resp, err := c.databaseAPI.ListKeys(ctx, id.ResourceGroupName, clusterName, databaseName, &armredisenterprise.DatabasesClientListKeysOptions{})
|
||||
if err != nil {
|
||||
return "", trace.Wrap(ConvertResponseError(err))
|
||||
}
|
||||
|
||||
// There are two keys. Pick first one available.
|
||||
if resp.PrimaryKey != nil {
|
||||
return *resp.PrimaryKey, nil
|
||||
}
|
||||
if resp.SecondaryKey != nil {
|
||||
return *resp.SecondaryKey, nil
|
||||
}
|
||||
return "", trace.NotFound("missing keys")
|
||||
}
|
||||
|
||||
// getClusterAndDatabaseName returns the cluster name and the database name
|
||||
// based on the resource ID. Both armredisenterprise.Cluster.ID and
|
||||
// armredisenterprise.Database.ID are supported.
|
||||
func (c *redisEnterpriseClient) getClusterAndDatabaseName(id *arm.ResourceID) (string, string, error) {
|
||||
switch id.ResourceType.String() {
|
||||
case "Microsoft.Cache/redisEnterprise":
|
||||
// It appears an Enterprise cluster always has only one "database", and
|
||||
// the database name is always "default".
|
||||
return id.Name, RedisEnterpriseClusterDefaultDatabase, nil
|
||||
case "Microsoft.Cache/redisEnterprise/databases":
|
||||
return id.Parent.Name, id.Name, nil
|
||||
default:
|
||||
return "", "", trace.BadParameter("unknown Azure Cache for Redis resource type: %v", id.ResourceType)
|
||||
}
|
||||
}
|
||||
|
||||
// RedisEnterpriseClusterDefaultDatabase is the default database name for a
|
||||
// Redis Enterprise cluster.
|
||||
const RedisEnterpriseClusterDefaultDatabase = "default"
|
81
lib/cloud/azure/redis_enterprise_test.go
Normal file
81
lib/cloud/azure/redis_enterprise_test.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedisEnterpriseClient(t *testing.T) {
|
||||
t.Run("GetToken", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockDatabaseAPI armRedisEnterpriseDatabaseClient
|
||||
resourceID string
|
||||
expectError bool
|
||||
expectToken string
|
||||
}{
|
||||
{
|
||||
name: "access denied",
|
||||
resourceID: "cluster-name",
|
||||
mockDatabaseAPI: &ARMRedisEnterpriseDatabaseMock{
|
||||
NoAuth: true,
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "succeed (default database name)",
|
||||
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-teleport",
|
||||
mockDatabaseAPI: &ARMRedisEnterpriseDatabaseMock{
|
||||
TokensByDatabaseName: map[string]string{
|
||||
"default": "some-token",
|
||||
},
|
||||
},
|
||||
expectToken: "some-token",
|
||||
},
|
||||
{
|
||||
name: "succeed (specific database name)",
|
||||
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-teleport/databases/some-database",
|
||||
mockDatabaseAPI: &ARMRedisEnterpriseDatabaseMock{
|
||||
TokensByDatabaseName: map[string]string{
|
||||
"some-database": "some-token",
|
||||
},
|
||||
},
|
||||
expectToken: "some-token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := NewRedisEnterpriseClientByAPI(test.mockDatabaseAPI)
|
||||
token, err := c.GetToken(context.TODO(), test.resourceID)
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expectToken, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
66
lib/cloud/azure/redis_test.go
Normal file
66
lib/cloud/azure/redis_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedisClient(t *testing.T) {
|
||||
t.Run("GetToken", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockAPI armRedisClient
|
||||
expectError bool
|
||||
expectToken string
|
||||
}{
|
||||
{
|
||||
name: "access denied",
|
||||
mockAPI: &ARMRedisMock{
|
||||
NoAuth: true,
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "succeed",
|
||||
mockAPI: &ARMRedisMock{
|
||||
Token: "some-token",
|
||||
},
|
||||
expectToken: "some-token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := NewRedisClientByAPI(test.mockAPI)
|
||||
token, err := c.GetToken(context.TODO(), "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-teleport")
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expectToken, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
|
@ -100,6 +100,10 @@ type AzureClients interface {
|
|||
GetAzurePostgresClient(subscription string) (azure.DBServersClient, error)
|
||||
// GetAzureSubscriptionClient returns an Azure Subscriptions client
|
||||
GetAzureSubscriptionClient() (*azure.SubscriptionClient, error)
|
||||
// GetAzureRedisClient returns an Azure Redis client for the given subscription.
|
||||
GetAzureRedisClient(subscription string) (azure.CacheForRedisClient, error)
|
||||
// GetAzureRedisEnterpriseClient returns an Azure Redis Enterprise client for the given subscription.
|
||||
GetAzureRedisEnterpriseClient(subscription string) (azure.CacheForRedisClient, error)
|
||||
}
|
||||
|
||||
// NewClients returns a new instance of cloud clients retriever.
|
||||
|
@ -107,8 +111,10 @@ func NewClients() Clients {
|
|||
return &cloudClients{
|
||||
awsSessions: make(map[string]*awssession.Session),
|
||||
azureClients: azureClients{
|
||||
azureMySQLClients: make(map[string]azure.DBServersClient),
|
||||
azurePostgresClients: make(map[string]azure.DBServersClient),
|
||||
azureMySQLClients: make(map[string]azure.DBServersClient),
|
||||
azurePostgresClients: make(map[string]azure.DBServersClient),
|
||||
azureRedisClients: azure.NewClientMap(azure.NewRedisClient),
|
||||
azureRedisEnterpriseClients: azure.NewClientMap(azure.NewRedisEnterpriseClient),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -139,6 +145,10 @@ type azureClients struct {
|
|||
azurePostgresClients map[string]azure.DBServersClient
|
||||
// azureSubscriptionsClient is the cached Azure Subscriptions client.
|
||||
azureSubscriptionsClient *azure.SubscriptionClient
|
||||
// azureRedisClients contains the cached Azure Redis clients.
|
||||
azureRedisClients azure.ClientMap[azure.CacheForRedisClient]
|
||||
// azureRedisEnterpriseClients contains the cached Azure Redis Enterprise clients.
|
||||
azureRedisEnterpriseClients azure.ClientMap[azure.CacheForRedisClient]
|
||||
}
|
||||
|
||||
// GetAWSSession returns AWS session for the specified region.
|
||||
|
@ -299,6 +309,16 @@ func (c *cloudClients) GetAzureSubscriptionClient() (*azure.SubscriptionClient,
|
|||
return c.initAzureSubscriptionsClient()
|
||||
}
|
||||
|
||||
// GetAzureRedisClient returns an Azure Redis client for the given subscription.
|
||||
func (c *cloudClients) GetAzureRedisClient(subscription string) (azure.CacheForRedisClient, error) {
|
||||
return c.azureRedisClients.Get(subscription, c.GetAzureCredential)
|
||||
}
|
||||
|
||||
// GetAzureRedisEnterpriseClient returns an Azure Redis Enterprise client for the given subscription.
|
||||
func (c *cloudClients) GetAzureRedisEnterpriseClient(subscription string) (azure.CacheForRedisClient, error) {
|
||||
return c.azureRedisEnterpriseClients.Get(subscription, c.GetAzureCredential)
|
||||
}
|
||||
|
||||
// Close closes all initialized clients.
|
||||
func (c *cloudClients) Close() (err error) {
|
||||
c.mtx.Lock()
|
||||
|
@ -469,6 +489,8 @@ type TestCloudClients struct {
|
|||
AzurePostgres azure.DBServersClient
|
||||
AzurePostgresPerSub map[string]azure.DBServersClient
|
||||
AzureSubscriptionClient *azure.SubscriptionClient
|
||||
AzureRedis azure.CacheForRedisClient
|
||||
AzureRedisEnterprise azure.CacheForRedisClient
|
||||
}
|
||||
|
||||
// GetAWSSession returns AWS session for the specified region.
|
||||
|
@ -562,6 +584,16 @@ func (c *TestCloudClients) GetAWSSSMClient(region string) (ssmiface.SSMAPI, erro
|
|||
return c.SSM, nil
|
||||
}
|
||||
|
||||
// GetAzureRedisClient returns an Azure Redis client for the given subscription.
|
||||
func (c *TestCloudClients) GetAzureRedisClient(subscription string) (azure.CacheForRedisClient, error) {
|
||||
return c.AzureRedis, nil
|
||||
}
|
||||
|
||||
// GetAzureRedisEnterpriseClient returns an Azure Redis Enterprise client for the given subscription.
|
||||
func (c *TestCloudClients) GetAzureRedisEnterpriseClient(subscription string) (azure.CacheForRedisClient, error) {
|
||||
return c.AzureRedisEnterprise, nil
|
||||
}
|
||||
|
||||
// Close closes all initialized clients.
|
||||
func (c *TestCloudClients) Close() error {
|
||||
return nil
|
||||
|
|
|
@ -1286,6 +1286,9 @@ func applyDatabasesConfig(fc *FileConfig, cfg *service.Config) error {
|
|||
Domain: database.AD.Domain,
|
||||
SPN: database.AD.SPN,
|
||||
},
|
||||
Azure: service.DatabaseAzure{
|
||||
ResourceID: database.Azure.ResourceID,
|
||||
},
|
||||
}
|
||||
if err := db.CheckAndSetDefaults(); err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
|
|
@ -1341,6 +1341,8 @@ type Database struct {
|
|||
GCP DatabaseGCP `yaml:"gcp"`
|
||||
// AD contains Active Directory database configuration.
|
||||
AD DatabaseAD `yaml:"ad"`
|
||||
// Azure contains Azure database configuration.
|
||||
Azure DatabaseAzure `yaml:"azure"`
|
||||
}
|
||||
|
||||
// DatabaseAD contains database Active Directory configuration.
|
||||
|
@ -1431,6 +1433,12 @@ type DatabaseGCP struct {
|
|||
InstanceID string `yaml:"instance_id,omitempty"`
|
||||
}
|
||||
|
||||
// DatabaseAzure contains Azure database configuration.
|
||||
type DatabaseAzure struct {
|
||||
// ResourceID is the Azure fully qualified ID for the resource.
|
||||
ResourceID string `yaml:"resource_id,omitempty"`
|
||||
}
|
||||
|
||||
// Apps represents the configuration for the collection of applications this
|
||||
// service will start. In file configuration this would be the "app_service"
|
||||
// section.
|
||||
|
|
|
@ -777,6 +777,8 @@ type Database struct {
|
|||
GCP DatabaseGCP
|
||||
// AD contains Active Directory configuration for database.
|
||||
AD DatabaseAD
|
||||
// Azure contains Azure database configuration.
|
||||
Azure DatabaseAzure
|
||||
}
|
||||
|
||||
// TLSMode defines all possible database verification modes.
|
||||
|
@ -908,6 +910,12 @@ type DatabaseAD struct {
|
|||
SPN string
|
||||
}
|
||||
|
||||
// DatabaseAzure contains Azure database configuration.
|
||||
type DatabaseAzure struct {
|
||||
// ResourceID is the Azure fully qualified ID for the resource.
|
||||
ResourceID string `yaml:"resource_id,omitempty"`
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults validates database Active Directory configuration.
|
||||
func (d *DatabaseAD) CheckAndSetDefaults(name string) error {
|
||||
if d.KeytabFile == "" {
|
||||
|
|
|
@ -125,6 +125,9 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) {
|
|||
Domain: db.AD.Domain,
|
||||
SPN: db.AD.SPN,
|
||||
},
|
||||
Azure: types.Azure{
|
||||
ResourceID: db.Azure.ResourceID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
|
|
@ -33,6 +33,7 @@ import (
|
|||
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
awsutils "github.com/gravitational/teleport/api/utils/aws"
|
||||
azureutils "github.com/gravitational/teleport/api/utils/azure"
|
||||
"github.com/gravitational/teleport/lib/cloud/azure"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/fixtures"
|
||||
|
@ -121,7 +122,7 @@ func TestDatabaseFromAzureDBServer(t *testing.T) {
|
|||
Name: name,
|
||||
Port: "3306",
|
||||
Properties: azure.ServerProperties{
|
||||
FullyQualifiedDomainName: name + ".mysql" + types.AzureEndpointSuffix,
|
||||
FullyQualifiedDomainName: name + ".mysql" + azureutils.DatabaseEndpointSuffix,
|
||||
UserVisibleState: string(armmysql.ServerStateReady),
|
||||
Version: string(armmysql.ServerVersionFive7),
|
||||
},
|
||||
|
|
|
@ -2011,7 +2011,9 @@ func withRDSPostgres(name, authToken string) withDatabaseOption {
|
|||
Region: testAWSRegion,
|
||||
},
|
||||
// Set CA cert, otherwise we will attempt to download RDS roots.
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
TLS: types.DatabaseTLS{
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testCtx.postgres[name] = testPostgres{
|
||||
|
@ -2043,7 +2045,9 @@ func withRedshiftPostgres(name, authToken string) withDatabaseOption {
|
|||
Redshift: types.Redshift{ClusterID: "redshift-cluster-1"},
|
||||
},
|
||||
// Set CA cert, otherwise we will attempt to download Redshift roots.
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
TLS: types.DatabaseTLS{
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testCtx.postgres[name] = testPostgres{
|
||||
|
@ -2078,7 +2082,9 @@ func withCloudSQLPostgres(name, authToken string) withDatabaseOption {
|
|||
InstanceID: "instance-1",
|
||||
},
|
||||
// Set CA cert to pass cert validation.
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
TLS: types.DatabaseTLS{
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testCtx.postgres[name] = testPostgres{
|
||||
|
@ -2109,7 +2115,9 @@ func withAzurePostgres(name, authToken string) withDatabaseOption {
|
|||
Name: name,
|
||||
},
|
||||
// Set CA cert, otherwise we will attempt to download RDS roots.
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
TLS: types.DatabaseTLS{
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testCtx.postgres[name] = testPostgres{
|
||||
|
@ -2363,6 +2371,38 @@ func withSQLServer(name string) withDatabaseOption {
|
|||
}
|
||||
}
|
||||
|
||||
func withAzureRedis(name string, token string) withDatabaseOption {
|
||||
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
|
||||
redisServer, err := redis.NewTestServer(t, common.TestServerConfig{
|
||||
Name: name,
|
||||
AuthClient: testCtx.authClient,
|
||||
}, redis.TestServerPassword(token))
|
||||
require.NoError(t, err)
|
||||
|
||||
database, err := types.NewDatabaseV3(types.Metadata{
|
||||
Name: name,
|
||||
}, types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolRedis,
|
||||
URI: fmt.Sprintf("rediss://%s", net.JoinHostPort("localhost", redisServer.Port())),
|
||||
DynamicLabels: dynamicLabels,
|
||||
Azure: types.Azure{
|
||||
Name: name,
|
||||
ResourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-teleport",
|
||||
},
|
||||
// Set CA cert to pass cert validation.
|
||||
TLS: types.DatabaseTLS{
|
||||
CACert: string(testCtx.databaseCA.GetActiveKeys().TLS[0].Cert),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testCtx.redis[name] = testRedis{
|
||||
db: redisServer,
|
||||
resource: database,
|
||||
}
|
||||
return database
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicLabels = types.LabelsToV2(map[string]types.CommandLabel{
|
||||
"echo": &types.CommandLabelV2{
|
||||
Period: types.NewDuration(time.Second),
|
||||
|
|
|
@ -55,7 +55,10 @@ func TestAuthTokens(t *testing.T) {
|
|||
withCloudSQLMySQL("mysql-cloudsql-correct-token", "root", cloudSQLPassword),
|
||||
withCloudSQLMySQL("mysql-cloudsql-incorrect-token", "root", "qwe123"),
|
||||
withAzureMySQL("mysql-azure-correct-token", "root", azureAccessToken),
|
||||
withAzureMySQL("mysql-azure-incorrect-token", "root", "qwe123"))
|
||||
withAzureMySQL("mysql-azure-incorrect-token", "root", "qwe123"),
|
||||
withAzureRedis("redis-azure-correct-token", azureRedisToken),
|
||||
withAzureRedis("redis-azure-incorrect-token", "qwe123"),
|
||||
)
|
||||
go testCtx.startHandlingConnections()
|
||||
|
||||
testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{types.Wildcard}, []string{types.Wildcard})
|
||||
|
@ -123,6 +126,17 @@ func TestAuthTokens(t *testing.T) {
|
|||
protocol: defaults.ProtocolMySQL,
|
||||
err: "Access denied for user",
|
||||
},
|
||||
{
|
||||
desc: "correct Azure Redis auth token",
|
||||
service: "redis-azure-correct-token",
|
||||
protocol: defaults.ProtocolRedis,
|
||||
},
|
||||
{
|
||||
desc: "incorrect Azure Redis auth token",
|
||||
service: "redis-azure-incorrect-token",
|
||||
protocol: defaults.ProtocolRedis,
|
||||
err: "WRONGPASS invalid username-password pair",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
@ -146,6 +160,15 @@ func TestAuthTokens(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
}
|
||||
case defaults.ProtocolRedis:
|
||||
conn, err := testCtx.redisClient(ctx, "alice", test.service, "default")
|
||||
if test.err != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), test.err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unrecognized database protocol in test: %q", test.protocol)
|
||||
}
|
||||
|
@ -186,6 +209,8 @@ const (
|
|||
cloudSQLPassword = "cloudsql-password"
|
||||
// azureAccessToken is a mock Azure access token.
|
||||
azureAccessToken = "azure-access-token"
|
||||
// azureRedisToken is a mock Azure Redis token.
|
||||
azureRedisToken = "azure-redis-token"
|
||||
)
|
||||
|
||||
// GetRDSAuthToken generates RDS/Aurora auth token.
|
||||
|
@ -218,6 +243,12 @@ func (a *testAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *common.S
|
|||
return azureAccessToken, nil
|
||||
}
|
||||
|
||||
// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
|
||||
func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
|
||||
a.Infof("Generating Azure Redis token for %v.", sessionCtx)
|
||||
return azureRedisToken, nil
|
||||
}
|
||||
|
||||
func TestDBCertSigning(t *testing.T) {
|
||||
authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
|
||||
Clock: clockwork.NewFakeClockAt(time.Now()),
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
awsutils "github.com/gravitational/teleport/api/utils/aws"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
|
||||
|
@ -47,8 +48,14 @@ func (s *Server) initCACert(ctx context.Context, database types.Database) error
|
|||
types.DatabaseTypeRedshift,
|
||||
types.DatabaseTypeElastiCache,
|
||||
types.DatabaseTypeMemoryDB,
|
||||
types.DatabaseTypeCloudSQL,
|
||||
types.DatabaseTypeAzure:
|
||||
types.DatabaseTypeCloudSQL:
|
||||
|
||||
case types.DatabaseTypeAzure:
|
||||
// Azure Cache for Redis uses system cert poool
|
||||
if database.GetProtocol() == defaults.ProtocolRedis {
|
||||
return nil
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
azureutils "github.com/gravitational/teleport/api/utils/azure"
|
||||
clients "github.com/gravitational/teleport/lib/cloud"
|
||||
"github.com/gravitational/teleport/lib/cloud/azure"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
|
@ -683,7 +684,7 @@ func makeAzureMySQLServer(t *testing.T, name, subscription, group, region string
|
|||
name,
|
||||
)
|
||||
|
||||
fqdn := name + ".mysql" + types.AzureEndpointSuffix
|
||||
fqdn := name + ".mysql" + azureutils.DatabaseEndpointSuffix
|
||||
state := armmysql.ServerStateReady
|
||||
version := armmysql.ServerVersionFive7
|
||||
server := &armmysql.Server{
|
||||
|
@ -724,7 +725,7 @@ func makeAzurePostgresServer(t *testing.T, name, subscription, group, region str
|
|||
name,
|
||||
)
|
||||
|
||||
fqdn := name + ".postgres" + types.AzureEndpointSuffix
|
||||
fqdn := name + ".postgres" + azureutils.DatabaseEndpointSuffix
|
||||
state := armpostgresql.ServerStateReady
|
||||
version := armpostgresql.ServerVersionEleven
|
||||
server := &armpostgresql.Server{
|
||||
|
|
|
@ -27,14 +27,17 @@ import (
|
|||
|
||||
"github.com/gravitational/teleport/api/client/proto"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
azureutils "github.com/gravitational/teleport/api/utils/azure"
|
||||
"github.com/gravitational/teleport/api/utils/retryutils"
|
||||
libauth "github.com/gravitational/teleport/lib/auth"
|
||||
"github.com/gravitational/teleport/lib/auth/native"
|
||||
"github.com/gravitational/teleport/lib/cloud"
|
||||
libazure "github.com/gravitational/teleport/lib/cloud/azure"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
|
@ -61,6 +64,8 @@ type Auth interface {
|
|||
GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) (string, error)
|
||||
// GetAzureAccessToken generates Azure database access token.
|
||||
GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (string, error)
|
||||
// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
|
||||
GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error)
|
||||
// GetTLSConfig builds the client TLS configuration for the session.
|
||||
GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error)
|
||||
// GetAuthPreference returns the cluster authentication config.
|
||||
|
@ -69,10 +74,20 @@ type Auth interface {
|
|||
io.Closer
|
||||
}
|
||||
|
||||
// AuthClient is an interface that defines a subset of libauth.Client's
|
||||
// functions that are required for database auth.
|
||||
type AuthClient interface {
|
||||
// GenerateDatabaseCert generates client certificate used by a database
|
||||
// service to authenticate with the database instance.
|
||||
GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error)
|
||||
// GetAuthPreference returns the cluster authentication config.
|
||||
GetAuthPreference(ctx context.Context) (types.AuthPreference, error)
|
||||
}
|
||||
|
||||
// AuthConfig is the database access authenticator configuration.
|
||||
type AuthConfig struct {
|
||||
// AuthClient is the cluster auth client.
|
||||
AuthClient *libauth.Client
|
||||
AuthClient AuthClient
|
||||
// Clients provides interface for obtaining cloud provider clients.
|
||||
Clients cloud.Clients
|
||||
// Clock is the clock implementation.
|
||||
|
@ -294,6 +309,51 @@ func (a *dbAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (
|
|||
return token.Token, nil
|
||||
}
|
||||
|
||||
// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
|
||||
func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error) {
|
||||
resourceID, err := arm.ParseResourceID(sessionCtx.Database.GetAzure().ResourceID)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
|
||||
var client libazure.CacheForRedisClient
|
||||
switch resourceID.ResourceType.String() {
|
||||
case "Microsoft.Cache/Redis":
|
||||
client, err = a.cfg.Clients.GetAzureRedisClient(resourceID.SubscriptionID)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
case "Microsoft.Cache/redisEnterprise", "Microsoft.Cache/redisEnterprise/databases":
|
||||
client, err = a.cfg.Clients.GetAzureRedisEnterpriseClient(resourceID.SubscriptionID)
|
||||
if err != nil {
|
||||
return "", trace.Wrap(err)
|
||||
}
|
||||
default:
|
||||
return "", trace.BadParameter("unknown Azure Cache for Redis resource type: %v", resourceID.ResourceType)
|
||||
}
|
||||
token, err := client.GetToken(ctx, sessionCtx.Database.GetAzure().ResourceID)
|
||||
if err != nil {
|
||||
// Some Azure error messages are long, multi-lined, and may even
|
||||
// contain divider lines like "------". It's unreadable in redis-cli as
|
||||
// the message has to be merged to a single line string. Thus logging
|
||||
// the original error as debug and returning a more user friendly
|
||||
// message.
|
||||
a.cfg.Log.WithError(err).Debugf("Failed to get token for Azure Redis %q.", sessionCtx.Database.GetName())
|
||||
switch {
|
||||
case trace.IsAccessDenied(err):
|
||||
return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure the database agent has the \"listKeys\" permission to the database.", sessionCtx.Database.GetName())
|
||||
case trace.IsNotFound(err):
|
||||
// Note that Azure Cache for Redis should always have both keys
|
||||
// generated at all time. Here just checking in case something
|
||||
// wrong with the API.
|
||||
return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure either the primary key or the secondary key is generated.", sessionCtx.Database.GetName())
|
||||
default:
|
||||
return "", trace.Errorf("Failed to get token for Azure Redis %q.", sessionCtx.Database.GetName())
|
||||
}
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetTLSConfig builds the client TLS configuration for the session.
|
||||
//
|
||||
// For RDS/Aurora, the config must contain RDS root certificate as a trusted
|
||||
|
@ -316,29 +376,13 @@ func (a *dbAuth) GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Co
|
|||
// getTLSConfigVerifyFull returns tls.Config with full verification enabled ('verify-full' mode).
|
||||
// Config also includes database specific adjustment.
|
||||
func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session) (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: x509.NewCertPool(),
|
||||
}
|
||||
|
||||
switch sessionCtx.Database.GetProtocol() {
|
||||
case defaults.ProtocolMongoDB, defaults.ProtocolRedis:
|
||||
// Mongo and Redis are using custom URI schema.
|
||||
default:
|
||||
// Don't set the ServerName when connecting to a MongoDB cluster - in case
|
||||
// of replica set the driver may dial multiple servers and will set
|
||||
// ServerName itself. For Postgres/MySQL we're always connecting to the
|
||||
// server specified in URI so set ServerName ourselves.
|
||||
addr, err := utils.ParseAddr(sessionCtx.Database.GetURI())
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
tlsConfig.ServerName = addr.Host()
|
||||
}
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
// Add CA certificate to the trusted pool if it's present, e.g. when
|
||||
// connecting to RDS/Aurora which require AWS CA or when was provided in config file.
|
||||
tlsConfig, err := appendCAToRoot(tlsConfig, sessionCtx)
|
||||
if err != nil {
|
||||
//
|
||||
// Some databases may also require the system cert pool, e.g Azure Redis.
|
||||
if err := setupTLSConfigRootCAs(tlsConfig, sessionCtx); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
|
@ -371,10 +415,9 @@ func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session
|
|||
tlsConfig.VerifyConnection = getVerifyCloudSQLCertificate(tlsConfig.RootCAs)
|
||||
}
|
||||
|
||||
dbTLSConfig := sessionCtx.Database.GetTLS()
|
||||
// Use user provided server name if set. Override the current value if needed.
|
||||
if dbTLSConfig.ServerName != "" {
|
||||
tlsConfig.ServerName = dbTLSConfig.ServerName
|
||||
// Setup server name for verification.
|
||||
if err := setupTLSConfigServerName(tlsConfig, sessionCtx); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
// RDS/Aurora/Redshift/ElastiCache and Cloud SQL auth is done with an auth
|
||||
|
@ -441,15 +484,91 @@ func (a *dbAuth) appendClientCert(ctx context.Context, sessionCtx *Session, tlsC
|
|||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// appendCAToRoot appends CA certificate from session context to provided tlsConfig.
|
||||
func appendCAToRoot(tlsConfig *tls.Config, sessionCtx *Session) (*tls.Config, error) {
|
||||
// setupTLSConfigRootCAs initializes the root CA cert pool for the provided
|
||||
// tlsConfig based on session context.
|
||||
func setupTLSConfigRootCAs(tlsConfig *tls.Config, sessionCtx *Session) error {
|
||||
// Start with an empty pool.
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
|
||||
// If CA is provided by the database object, always use it.
|
||||
if len(sessionCtx.Database.GetCA()) != 0 {
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM([]byte(sessionCtx.Database.GetCA())) {
|
||||
return nil, trace.BadParameter("invalid server CA certificate")
|
||||
return trace.BadParameter("invalid server CA certificate")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
// Overwrite with the system cert pool, if required.
|
||||
if shouldUseSystemCertPool(sessionCtx) {
|
||||
systemCertPool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
tlsConfig.RootCAs = systemCertPool
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use the empty pool. Client cert will be added later.
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldUseSystemCertPool returns true for database servers presenting
|
||||
// certificates signed by publicly trusted CAs so a system cert pool can be
|
||||
// used.
|
||||
func shouldUseSystemCertPool(sessionCtx *Session) bool {
|
||||
// Azure Cache for Redis certificates are signed by DigiCert Global Root G2.
|
||||
return sessionCtx.Database.IsAzure() && sessionCtx.Database.GetProtocol() == defaults.ProtocolRedis
|
||||
}
|
||||
|
||||
// setupTLSConfigServerName initializes the server name for the provided
|
||||
// tlsConfig based on session context.
|
||||
func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error {
|
||||
// Use user provided server name if set. Override the current value if needed.
|
||||
if dbTLSConfig := sessionCtx.Database.GetTLS(); dbTLSConfig.ServerName != "" {
|
||||
tlsConfig.ServerName = dbTLSConfig.ServerName
|
||||
return nil
|
||||
}
|
||||
|
||||
// If server name is set prior to this function, use that.
|
||||
if tlsConfig.ServerName != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch sessionCtx.Database.GetProtocol() {
|
||||
case defaults.ProtocolMongoDB:
|
||||
// Don't set the ServerName when connecting to a MongoDB cluster - in case
|
||||
// of replica set the driver may dial multiple servers and will set
|
||||
// ServerName itself.
|
||||
return nil
|
||||
|
||||
case defaults.ProtocolRedis:
|
||||
// Azure Redis servers always serve the certificates with the proper
|
||||
// hostnames. However, OSS cluster mode may redirect to an IP address,
|
||||
// and without correct ServerName the handshake will fail as the IPs
|
||||
// are not in SANs.
|
||||
if sessionCtx.Database.IsAzure() {
|
||||
serverName, err := azureutils.GetHostFromRedisURI(sessionCtx.Database.GetURI())
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
tlsConfig.ServerName = serverName
|
||||
return nil
|
||||
}
|
||||
|
||||
// Redis is using custom URI schema.
|
||||
return nil
|
||||
|
||||
default:
|
||||
// For other databases we're always connecting to the server specified
|
||||
// in URI so set ServerName ourselves.
|
||||
addr, err := utils.ParseAddr(sessionCtx.Database.GetURI())
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
tlsConfig.ServerName = addr.Host()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// verifyConnectionFunc returns a certificate validation function. serverName if empty will skip the hostname validation.
|
||||
|
|
300
lib/srv/db/common/auth_test.go
Normal file
300
lib/srv/db/common/auth_test.go
Normal file
|
@ -0,0 +1,300 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gravitational/teleport/api/client/proto"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
"github.com/gravitational/teleport/lib/cloud"
|
||||
libcloudazure "github.com/gravitational/teleport/lib/cloud/azure"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/fixtures"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
func TestAuthGetAzureCacheForRedisToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auth, err := NewAuth(AuthConfig{
|
||||
AuthClient: new(authClientMock),
|
||||
Clients: &cloud.TestCloudClients{
|
||||
AzureRedis: libcloudazure.NewRedisClientByAPI(&libcloudazure.ARMRedisMock{
|
||||
Token: "azure-redis-token",
|
||||
}),
|
||||
AzureRedisEnterprise: libcloudazure.NewRedisEnterpriseClientByAPI(&libcloudazure.ARMRedisEnterpriseDatabaseMock{
|
||||
Token: "azure-redis-enterprise-token",
|
||||
}),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
resourceID string
|
||||
expectError bool
|
||||
expectToken string
|
||||
}{
|
||||
{
|
||||
name: "invalid resource ID",
|
||||
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/some-unknown-service/example-teleport",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Redis (non-Enterprise)",
|
||||
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/Redis/example-teleport",
|
||||
expectToken: "azure-redis-token",
|
||||
},
|
||||
{
|
||||
name: "Redis Enterprise",
|
||||
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-teleport",
|
||||
expectToken: "azure-redis-enterprise-token",
|
||||
},
|
||||
{
|
||||
name: "Redis Enterprise (database resource ID)",
|
||||
resourceID: "/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.Cache/redisEnterprise/example-teleport/databases/default",
|
||||
expectToken: "azure-redis-enterprise-token",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
token, err := auth.GetAzureCacheForRedisToken(context.TODO(), &Session{
|
||||
Database: newAzureRedisDatabase(t, test.resourceID),
|
||||
})
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expectToken, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthGetTLSConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auth, err := NewAuth(AuthConfig{
|
||||
AuthClient: new(authClientMock),
|
||||
Clients: &cloud.TestCloudClients{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
systemCertPool, err := x509.SystemCertPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
// The authClientMock uses fixtures.TLSCACertPEM as the root signing CA.
|
||||
defaultCertPool := x509.NewCertPool()
|
||||
require.True(t, defaultCertPool.AppendCertsFromPEM([]byte(fixtures.TLSCACertPEM)))
|
||||
|
||||
// Use a different CA to pretend to be CAs for AWS hosted databases.
|
||||
awsCertPool := x509.NewCertPool()
|
||||
require.True(t, awsCertPool.AppendCertsFromPEM([]byte(fixtures.SAMLOktaCertPEM)))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionDatabase types.Database
|
||||
expectServerName string
|
||||
expectRootCAs *x509.CertPool
|
||||
expectClientCertificates bool
|
||||
expectVerifyConnection bool
|
||||
expectInsecureSkipVerify bool
|
||||
}{
|
||||
{
|
||||
name: "self-hosted",
|
||||
sessionDatabase: newSelfHostedDatabase(t, "localhost:8888"),
|
||||
expectServerName: "localhost",
|
||||
expectRootCAs: defaultCertPool,
|
||||
expectClientCertificates: true,
|
||||
},
|
||||
{
|
||||
name: "AWS ElastiCache Redis",
|
||||
sessionDatabase: newElastiCacheRedisDatabase(t, fixtures.SAMLOktaCertPEM),
|
||||
expectRootCAs: awsCertPool,
|
||||
},
|
||||
{
|
||||
name: "AWS Redishift",
|
||||
sessionDatabase: newRedshiftDatabase(t, fixtures.SAMLOktaCertPEM),
|
||||
expectServerName: "redshift-cluster-1.abcdefghijklmnop.us-east-1.redshift.amazonaws.com",
|
||||
expectRootCAs: awsCertPool,
|
||||
},
|
||||
{
|
||||
name: "Azure Redis",
|
||||
sessionDatabase: newAzureRedisDatabase(t, "resource-id"),
|
||||
expectServerName: "test-database.redis.cache.windows.net",
|
||||
expectRootCAs: systemCertPool,
|
||||
},
|
||||
{
|
||||
name: "GCP Cloud SQL",
|
||||
sessionDatabase: newCloudSQLDatabase(t, "project-id", "instance-id"),
|
||||
// RootCAs is empty, and custom VerifyConnection function is provided.
|
||||
expectServerName: "project-id:instance-id",
|
||||
expectRootCAs: x509.NewCertPool(),
|
||||
expectInsecureSkipVerify: true,
|
||||
expectVerifyConnection: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
tlsConfig, err := auth.GetTLSConfig(context.TODO(), &Session{
|
||||
Identity: tlsca.Identity{},
|
||||
DatabaseUser: "default",
|
||||
Database: test.sessionDatabase,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, test.expectServerName, tlsConfig.ServerName)
|
||||
require.Equal(t, test.expectInsecureSkipVerify, tlsConfig.InsecureSkipVerify)
|
||||
|
||||
// nolint:staticcheck
|
||||
// TODO x509.CertPool.Subjects() is deprecated. use
|
||||
// x509.CertPool.Equal introduced in 1.19 for comparison.
|
||||
require.Equal(t, test.expectRootCAs.Subjects(), tlsConfig.RootCAs.Subjects())
|
||||
|
||||
if test.expectClientCertificates {
|
||||
require.Len(t, tlsConfig.Certificates, 1)
|
||||
} else {
|
||||
require.Empty(t, tlsConfig.Certificates)
|
||||
}
|
||||
|
||||
if test.expectVerifyConnection {
|
||||
require.NotNil(t, tlsConfig.VerifyConnection)
|
||||
} else {
|
||||
require.Nil(t, tlsConfig.VerifyConnection)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newAzureRedisDatabase(t *testing.T, resourceID string) types.Database {
|
||||
database, err := types.NewDatabaseV3(types.Metadata{
|
||||
Name: "test-database",
|
||||
}, types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolRedis,
|
||||
URI: "rediss://test-database.redis.cache.windows.net:8888",
|
||||
Azure: types.Azure{
|
||||
ResourceID: resourceID,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return database
|
||||
}
|
||||
|
||||
func newSelfHostedDatabase(t *testing.T, uri string) types.Database {
|
||||
database, err := types.NewDatabaseV3(types.Metadata{
|
||||
Name: "test-database",
|
||||
}, types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolMySQL,
|
||||
URI: uri,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return database
|
||||
}
|
||||
|
||||
func newCloudSQLDatabase(t *testing.T, projectID, instanceID string) types.Database {
|
||||
database, err := types.NewDatabaseV3(types.Metadata{
|
||||
Name: "test-database",
|
||||
}, types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolMySQL,
|
||||
URI: "cloudsql:8888",
|
||||
GCP: types.GCPCloudSQL{
|
||||
ProjectID: projectID,
|
||||
InstanceID: instanceID,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return database
|
||||
}
|
||||
|
||||
func newElastiCacheRedisDatabase(t *testing.T, ca string) types.Database {
|
||||
database, err := types.NewDatabaseV3(types.Metadata{
|
||||
Name: "test-database",
|
||||
}, types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolRedis,
|
||||
URI: "master.example-cluster.xxxxxx.cac1.cache.amazonaws.com:6379",
|
||||
TLS: types.DatabaseTLS{
|
||||
CACert: ca,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return database
|
||||
}
|
||||
|
||||
func newRedshiftDatabase(t *testing.T, ca string) types.Database {
|
||||
database, err := types.NewDatabaseV3(types.Metadata{
|
||||
Name: "test-database",
|
||||
}, types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolPostgres,
|
||||
URI: "redshift-cluster-1.abcdefghijklmnop.us-east-1.redshift.amazonaws.com:5432",
|
||||
TLS: types.DatabaseTLS{
|
||||
CACert: ca,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return database
|
||||
}
|
||||
|
||||
// authClientMock is a mock that implements AuthClient interface.
|
||||
type authClientMock struct {
|
||||
}
|
||||
|
||||
// GenerateDatabaseCert generates a cert using fixtures TLS CA.
|
||||
func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) {
|
||||
csr, err := tlsca.ParseCertificateRequestPEM(req.CSR)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
tlsCACert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM))
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
tlsCA, err := tlsca.FromTLSCertificate(tlsCACert)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
certReq := tlsca.CertificateRequest{
|
||||
PublicKey: csr.PublicKey,
|
||||
Subject: csr.Subject,
|
||||
NotAfter: time.Now().Add(req.TTL.Get()),
|
||||
DNSNames: []string{"localhost", "127.0.0.1"},
|
||||
}
|
||||
cert, err := tlsCA.GenerateCertificate(certReq)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return &proto.DatabaseCertResponse{
|
||||
Cert: cert,
|
||||
CACerts: [][]byte{
|
||||
[]byte(fixtures.TLSCACertPEM),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAuthPreference always returns types.DefaultAuthPreference().
|
||||
func (m *authClientMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
|
||||
return types.DefaultAuthPreference(), nil
|
||||
}
|
|
@ -31,6 +31,9 @@ import (
|
|||
|
||||
// TestRegisterEngine verifies database engine registration.
|
||||
func TestRegisterEngine(t *testing.T) {
|
||||
// Cleanup "test" engine in case this test is run in a loop.
|
||||
RegisterEngine(nil, "test")
|
||||
|
||||
ec := EngineConfig{
|
||||
Context: context.Background(),
|
||||
Clock: clockwork.NewFakeClock(),
|
||||
|
|
|
@ -146,14 +146,20 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// If fail to get the initial username or password, return an error right
|
||||
// away without making a connection to the Redis server.
|
||||
username, password, err := e.getInitialUsernameAndPassowrd(ctx, sessionCtx)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// Initialize newClient factory function with current connection state.
|
||||
e.newClient, err = e.getNewClientFn(ctx, sessionCtx)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// Create new client without username or password. Those will be added when we receive AUTH command.
|
||||
e.redisClient, err = e.newClient("", "")
|
||||
e.redisClient, err = e.newClient(username, password)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
@ -174,6 +180,23 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
|
|||
return nil
|
||||
}
|
||||
|
||||
// getInitialUsernameAndPassowrd returns the username and password used for
|
||||
// the initial connection.
|
||||
func (e *Engine) getInitialUsernameAndPassowrd(ctx context.Context, sessionCtx *common.Session) (string, string, error) {
|
||||
switch {
|
||||
case sessionCtx.Database.IsAzure():
|
||||
// Retrieve the auth token for Azure Cache for Redis. Use default user.
|
||||
password, err := e.Auth.GetAzureCacheForRedisToken(ctx, sessionCtx)
|
||||
return "", password, trace.Wrap(err)
|
||||
|
||||
default:
|
||||
// Create new client without username or password. Those will be added
|
||||
// when we receive AUTH command (e.g. self-hosted), or they can be
|
||||
// fetched by the OnConnect callback (e.g. ElastiCache managed users).
|
||||
return "", "", nil
|
||||
}
|
||||
}
|
||||
|
||||
// getNewClientFn returns a partial Redis client factory function.
|
||||
func (e *Engine) getNewClientFn(ctx context.Context, sessionCtx *common.Session) (redisClientFactoryFn, error) {
|
||||
tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx)
|
||||
|
|
|
@ -20,8 +20,10 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gravitational/trace"
|
||||
|
@ -129,7 +131,29 @@ func writeError(wr *redis.Writer, prefix string, val error) error {
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
if _, err := wr.WriteString(val.Error()); err != nil {
|
||||
// If the error message contains "\r" or "\n", redis-cli will have trouble
|
||||
// parsing the message and show "Bad simple string value" instead. So if
|
||||
// newlines are detected in the original error message, merge them to one
|
||||
// line.
|
||||
errString := val.Error()
|
||||
if strings.ContainsAny(errString, "\r\n") {
|
||||
scanner := bufio.NewScanner(strings.NewReader(errString))
|
||||
errString = ""
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if errString != "" {
|
||||
errString += " "
|
||||
}
|
||||
errString += line
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := wr.WriteString(errString); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -76,6 +76,11 @@ func TestWriteCmd(t *testing.T) {
|
|||
val: errors.New("something bad"),
|
||||
expected: []byte("-ERR something bad\r\n"),
|
||||
},
|
||||
{
|
||||
name: "multi-line error",
|
||||
val: errors.New("something bad.\r\n \n and another line"),
|
||||
expected: []byte("-ERR something bad. and another line\r\n"),
|
||||
},
|
||||
{
|
||||
name: "Teleport error",
|
||||
val: trace.Errorf("something bad"),
|
||||
|
|
Loading…
Reference in a new issue