Connection Diagnostics: Postgres Database tester (#18558)

* Connection Diagnostics: Postgres Database tester

When adding a new resource using the Web UI we want to allow users to
test connecting to it.

We have two connection testers already:
- for SSH Nodes
- for Kube clusters

This PR adds a third tester: Postgres Database.

Most of the required changes for any Database are already present but
we're focusing on Postgres for now. Other databases will be added as
future PRs.

Testing a Database is similar to the other tests:
- generate certs for the logged in user
- connect to the resource using those certs

When generating the certificate, we inject an ID so that the Database
Service can add Connection Diagnostic traces.
This commit is contained in:
Marco André Dinis 2022-12-13 18:55:40 +01:00 committed by GitHub
parent 0433a9d5cb
commit 6fd0fa8ef3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 2619 additions and 1019 deletions

View file

@ -4389,6 +4389,14 @@ message ConnectionDiagnosticTrace {
RBAC_KUBE = 6;
// KUBE_PRINCIPAL is used when checking if the Kube Cluster has at least one user principals.
KUBE_PRINCIPAL = 7;
// RBAC_DATABASE is for RBAC checks to database access (db_labels).
RBAC_DATABASE = 8;
// RBAC_DATABASE_LOGIN is for RBAC checks to database login (db_name and db_user).
RBAC_DATABASE_LOGIN = 9;
// DATABASE_DB_USER is used when checking whether the Database has the requested Database User.
DATABASE_DB_USER = 10;
// DATABASE_DB_NAME is used when checking whether the Database has the requested Database Name.
DATABASE_DB_NAME = 11;
}
TraceType Type = 1 [(gogoproto.jsontag) = "type"];
// StatusType describes whether this was a success or a failure.

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,288 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package conntest
import (
"context"
"encoding/json"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/client/conntest"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/web/ui"
)
func startPostgresTestServer(t *testing.T, authServer *auth.Server) *postgres.TestServer {
postgresTestServer, err := postgres.NewTestServer(common.TestServerConfig{
AuthClient: authServer,
})
require.NoError(t, err)
go func() {
t.Logf("Postgres Fake server running at %s port", postgresTestServer.Port())
assert.NoError(t, postgresTestServer.Serve())
}()
t.Cleanup(func() {
postgresTestServer.Close()
})
return postgresTestServer
}
func TestDiagnoseConnectionForPostgresDatabases(t *testing.T) {
ctx := context.Background()
// Start Teleport Auth and Proxy services
authProcess, proxyProcess, provisionToken := helpers.MakeTestServers(t)
authServer := authProcess.GetAuthServer()
proxyAddr, err := proxyProcess.ProxyWebAddr()
require.NoError(t, err)
// Start Fake Postgres Database
postgresTestServer := startPostgresTestServer(t, authServer)
// Start Teleport Database Service
databaseResourceName := "mypsqldb"
databaseDBName := "dbname"
databaseDBUser := "dbuser"
helpers.MakeTestDatabaseServer(t, *proxyAddr, provisionToken, service.Database{
Name: databaseResourceName,
Protocol: defaults.ProtocolPostgres,
URI: net.JoinHostPort("localhost", postgresTestServer.Port()),
})
// Wait for the Database Server to be registered
waitForDatabases(t, authServer, []string{databaseResourceName})
roleWithFullAccess, err := types.NewRole("fullaccess", types.RoleSpecV5{
Allow: types.RoleConditions{
Namespaces: []string{apidefaults.Namespace},
DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}},
Rules: []types.Rule{
types.NewRule(types.KindConnectionDiagnostic, services.RW()),
},
DatabaseUsers: []string{databaseDBUser},
DatabaseNames: []string{databaseDBName},
},
})
require.NoError(t, err)
require.NoError(t, authServer.UpsertRole(ctx, roleWithFullAccess))
for _, tt := range []struct {
name string
teleportUser string
reqResourceName string
reqDBUser string
reqDBName string
expectedSuccess bool
expectedMessage string
expectedTraces []types.ConnectionDiagnosticTrace
}{
{
name: "success",
teleportUser: "success",
reqResourceName: databaseResourceName,
reqDBUser: databaseDBUser,
reqDBName: databaseDBName,
expectedSuccess: true,
expectedMessage: "success",
expectedTraces: []types.ConnectionDiagnosticTrace{
{
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE,
Status: types.ConnectionDiagnosticTrace_SUCCESS,
Details: "A Database Agent is available to proxy the connection to the Database.",
},
{
Type: types.ConnectionDiagnosticTrace_CONNECTIVITY,
Status: types.ConnectionDiagnosticTrace_SUCCESS,
Details: "Database is accessible from the Database Agent.",
},
{
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE_LOGIN,
Status: types.ConnectionDiagnosticTrace_SUCCESS,
Details: "Access to Database User and Database Name granted.",
},
{
Type: types.ConnectionDiagnosticTrace_DATABASE_DB_USER,
Status: types.ConnectionDiagnosticTrace_SUCCESS,
Details: "Database User exists in the Database.",
},
{
Type: types.ConnectionDiagnosticTrace_DATABASE_DB_NAME,
Status: types.ConnectionDiagnosticTrace_SUCCESS,
Details: "Database Name exists in the Database.",
},
},
},
{
name: "databse not found",
teleportUser: "dbnotfound",
reqResourceName: "dbnotfound",
reqDBUser: databaseDBUser,
reqDBName: databaseDBName,
expectedSuccess: false,
expectedMessage: "failed",
expectedTraces: []types.ConnectionDiagnosticTrace{
{
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE,
Status: types.ConnectionDiagnosticTrace_FAILED,
Details: "Database not found. " +
"Ensure your role grants access by adding it to the 'db_labels' property. " +
"This can also happen when you don't have a Database Agent proxying the database - " +
"you can fix that by adding the database labels to the 'db_service.resources.labels' in 'teleport.yaml' file of the database agent.",
},
},
},
{
name: "no access to db user/name",
teleportUser: "deniedlogin",
reqResourceName: databaseResourceName,
reqDBUser: "root",
reqDBName: "system",
expectedSuccess: false,
expectedMessage: "failed",
expectedTraces: []types.ConnectionDiagnosticTrace{
{
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE_LOGIN,
Status: types.ConnectionDiagnosticTrace_FAILED,
Details: "Access denied when accessing Database. Please check the Error message for more information.",
},
},
},
} {
t.Run(tt.name, func(t *testing.T) {
tt := tt
// Set up User
user, err := types.NewUser(tt.teleportUser)
require.NoError(t, err)
user.AddRole(roleWithFullAccess.GetName())
require.NoError(t, authServer.UpsertUser(user))
userPassword := uuid.NewString()
require.NoError(t, authServer.UpsertPassword(tt.teleportUser, []byte(userPassword)))
webPack := helpers.LoginWebClient(t, proxyAddr.String(), tt.teleportUser, userPassword)
diagnoseReq := conntest.TestConnectionRequest{
ResourceKind: types.KindDatabase,
ResourceName: tt.reqResourceName,
DatabaseUser: tt.reqDBUser,
DatabaseName: tt.reqDBName,
// Default is 30 seconds but since tests run locally, we can reduce this value to also improve test responsiveness
DialTimeout: time.Second,
InsecureSkipVerify: true,
}
diagnoseConnectionEndpoint := strings.Join([]string{"sites", "$site", "diagnostics", "connections"}, "/")
resp, err := webPack.DoRequest(http.MethodPost, diagnoseConnectionEndpoint, diagnoseReq)
require.NoError(t, err)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode, string(respBody))
var connectionDiagnostic ui.ConnectionDiagnostic
require.NoError(t, json.Unmarshal(respBody, &connectionDiagnostic))
gotFailedTraces := 0
expectedFailedTraces := 0
for i, trace := range connectionDiagnostic.Traces {
if trace.Status == types.ConnectionDiagnosticTrace_FAILED.String() {
gotFailedTraces++
}
t.Logf("%d status='%s' type='%s' details='%s' error='%s'\n", i, trace.Status, trace.TraceType, trace.Details, trace.Error)
}
require.Equal(t, tt.expectedSuccess, connectionDiagnostic.Success)
require.Equal(t, tt.expectedMessage, connectionDiagnostic.Message)
for _, expectedTrace := range tt.expectedTraces {
if expectedTrace.Status == types.ConnectionDiagnosticTrace_FAILED {
expectedFailedTraces++
}
foundTrace := false
for _, returnedTrace := range connectionDiagnostic.Traces {
if expectedTrace.Type.String() != returnedTrace.TraceType {
continue
}
foundTrace = true
require.Equal(t, expectedTrace.Status.String(), returnedTrace.Status)
require.Equal(t, expectedTrace.Details, returnedTrace.Details)
require.Contains(t, returnedTrace.Error, expectedTrace.Error)
}
require.True(t, foundTrace, expectedTrace)
}
require.Equal(t, expectedFailedTraces, gotFailedTraces)
})
}
}
func waitForDatabases(t *testing.T, authServer *auth.Server, dbNames []string) {
ctx := context.Background()
require.Eventually(t, func() bool {
all, err := authServer.GetDatabaseServers(ctx, apidefaults.Namespace)
assert.NoError(t, err)
if len(dbNames) > len(all) {
return false
}
registered := 0
for _, db := range dbNames {
for _, a := range all {
if a.GetName() == db {
registered++
break
}
}
}
return registered == len(dbNames)
}, 10*time.Second, 100*time.Millisecond)
}

View file

@ -28,21 +28,28 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh/agent"
"github.com/gravitational/teleport/api/breaker"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/client"
libclient "github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/identityfile"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils"
)
// CommandOptions controls how the SSH command is built.
@ -301,3 +308,119 @@ func WaitForDatabaseServers(t *testing.T, authServer *auth.Server, dbs []service
}
}
}
// MakeTestServers starts an Auth and a Proxy Service.
// Besides those processes, it also returns a provision token which can be used to add other services.
func MakeTestServers(t *testing.T) (auth *service.TeleportProcess, proxy *service.TeleportProcess, provisionToken string) {
provisionToken = uuid.NewString()
var err error
// Set up a test auth server.
//
// We need this to get a random port assigned to it and allow parallel
// execution of this test.
cfg := service.MakeDefaultConfig()
cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()
cfg.Hostname = "localhost"
cfg.DataDir = t.TempDir()
cfg.SetAuthServerAddress(cfg.Auth.ListenAddr)
cfg.Auth.ListenAddr.Addr = NewListener(t, service.ListenerAuth, &cfg.FileDescriptors)
cfg.Auth.Preference.SetSecondFactor(constants.SecondFactorOff)
cfg.Auth.StorageConfig.Params = backend.Params{defaults.BackendPath: filepath.Join(cfg.DataDir, defaults.BackendDir)}
cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{
StaticTokens: []types.ProvisionTokenV1{{
Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleTrustedCluster, types.RoleNode, types.RoleApp},
Expires: time.Now().Add(time.Minute),
Token: provisionToken,
}},
})
require.NoError(t, err)
cfg.SSH.Enabled = false
cfg.Auth.Enabled = true
cfg.Proxy.Enabled = false
cfg.Log = utils.NewLoggerForTests()
auth, err = service.NewTeleport(cfg)
require.NoError(t, err)
require.NoError(t, auth.Start())
t.Cleanup(func() {
require.NoError(t, auth.Close())
require.NoError(t, auth.Wait())
})
// Wait for proxy to become ready.
_, err = auth.WaitForEventTimeout(30*time.Second, service.AuthTLSReady)
// in reality, the auth server should start *much* sooner than this. we use a very large
// timeout here because this isn't the kind of problem that this test is meant to catch.
require.NoError(t, err, "auth server didn't start after 30s")
authAddr, err := auth.AuthAddr()
require.NoError(t, err)
// Set up a test proxy service.
cfg = service.MakeDefaultConfig()
cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()
cfg.Hostname = "localhost"
cfg.DataDir = t.TempDir()
cfg.SetAuthServerAddress(*authAddr)
cfg.SetToken(provisionToken)
cfg.SSH.Enabled = false
cfg.Auth.Enabled = false
cfg.Proxy.Enabled = true
cfg.Proxy.ReverseTunnelListenAddr.Addr = NewListener(t, service.ListenerProxyTunnel, &cfg.FileDescriptors)
cfg.Proxy.WebAddr.Addr = NewListener(t, service.ListenerProxyWeb, &cfg.FileDescriptors)
cfg.Proxy.PublicAddrs = []utils.NetAddr{
cfg.Proxy.WebAddr,
}
cfg.Proxy.DisableWebInterface = true
cfg.Log = utils.NewLoggerForTests()
proxy, err = service.NewTeleport(cfg)
require.NoError(t, err)
require.NoError(t, proxy.Start())
t.Cleanup(func() {
require.NoError(t, proxy.Close())
require.NoError(t, proxy.Wait())
})
// Wait for proxy to become ready.
_, err = proxy.WaitForEventTimeout(10*time.Second, service.ProxyWebServerReady)
require.NoError(t, err, "proxy web server didn't start after 10s")
return auth, proxy, provisionToken
}
// MakeTestDatabaseServer creates a Database Service
// It receives the Proxy Address, a Token (to join the cluster) and a list of Datbases
func MakeTestDatabaseServer(t *testing.T, proxyAddr utils.NetAddr, token string, dbs ...service.Database) (db *service.TeleportProcess) {
// Proxy uses self-signed certificates in tests.
lib.SetInsecureDevMode(true)
cfg := service.MakeDefaultConfig()
cfg.Hostname = "localhost"
cfg.DataDir = t.TempDir()
cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()
cfg.SetAuthServerAddress(proxyAddr)
cfg.SetToken(token)
cfg.SSH.Enabled = false
cfg.Auth.Enabled = false
cfg.Databases.Enabled = true
cfg.Databases.Databases = dbs
cfg.Log = utils.NewLoggerForTests()
db, err := service.NewTeleport(cfg)
require.NoError(t, err)
require.NoError(t, db.Start())
t.Cleanup(func() {
assert.NoError(t, db.Close())
})
// Wait for database agent to start.
_, err = db.WaitForEventTimeout(10*time.Second, service.DatabasesReady)
require.NoError(t, err, "database server didn't start after 10s")
return db
}

145
integration/helpers/web.go Normal file
View file

@ -0,0 +1,145 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package helpers
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"testing"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/lib/httplib/csrf"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/web"
"github.com/gravitational/teleport/lib/web/ui"
)
// WebClientPack is an authenticated HTTP Client for Teleport.
type WebClientPack struct {
clt *http.Client
host string
webCookie string
bearerToken string
clusterName string
}
// LoginWebClient receives the host url, the username and a password.
// It will login into that host and return a WebClientPack.
func LoginWebClient(t *testing.T, host, username, password string) *WebClientPack {
csReq, err := json.Marshal(web.CreateSessionReq{
User: username,
Pass: password,
})
require.NoError(t, err)
// Create POST request to create session.
u := url.URL{
Scheme: "https",
Host: host,
Path: "/v1/webapi/sessions/web",
}
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewBuffer(csReq))
require.NoError(t, err)
// Attach CSRF token in cookie and header.
csrfToken, err := utils.CryptoRandomHex(32)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: csrf.CookieName,
Value: csrfToken,
})
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set(csrf.HeaderName, csrfToken)
// Issue request.
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
// Read in response.
var csResp *web.CreateSessionResponse
err = json.NewDecoder(resp.Body).Decode(&csResp)
require.NoError(t, err)
// Extract session cookie and bearer token.
require.Len(t, resp.Cookies(), 1)
cookie := resp.Cookies()[0]
require.Equal(t, cookie.Name, web.CookieName)
webClient := &WebClientPack{
clt: client,
host: host,
webCookie: cookie.Value,
bearerToken: csResp.Token,
}
resp, err = webClient.DoRequest(http.MethodGet, "sites", nil)
require.NoError(t, err)
defer resp.Body.Close()
var clusters []ui.Cluster
require.NoError(t, json.NewDecoder(resp.Body).Decode(&clusters))
require.NotEmpty(t, clusters)
webClient.clusterName = clusters[0].Name
return webClient
}
// DoRequest receives a method, endpoint and payload and sends an HTTP Request to the Teleport API.
// The endpoint must not contain the host neither the base path ('/v1/webapi/').
// Returns the http.Response.
func (w *WebClientPack) DoRequest(method, endpoint string, payload any) (*http.Response, error) {
endpoint = strings.ReplaceAll(endpoint, "$site", w.clusterName)
u := url.URL{
Scheme: "https",
Host: w.host,
Path: fmt.Sprintf("/v1/webapi/%s", endpoint),
}
bs, err := json.Marshal(payload)
if err != nil {
return nil, trace.Wrap(err)
}
req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(bs))
if err != nil {
return nil, trace.Wrap(err)
}
req.AddCookie(&http.Cookie{
Name: web.CookieName,
Value: w.webCookie,
})
req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", w.bearerToken))
req.Header.Add("Content-Type", "application/json")
resp, err := w.clt.Do(req)
return resp, trace.Wrap(err)
}

View file

@ -1488,6 +1488,7 @@ func (a *Server) generateUserCert(req certRequest) (*proto.Certs, error) {
Generation: req.generation,
AllowedResourceIDs: req.checker.GetAllowedResourceIDs(),
PrivateKeyPolicy: attestedKeyPolicy,
ConnectionDiagnosticID: req.connectionDiagnosticID,
}
subject, err := identity.Subject()
if err != nil {

View file

@ -597,6 +597,7 @@ func definitionForBuiltinRole(clusterName string, recConfig types.SessionRecordi
types.NewRule(types.KindDatabase, services.RW()),
types.NewRule(types.KindSemaphore, services.RW()),
types.NewRule(types.KindLock, services.RO()),
types.NewRule(types.KindConnectionDiagnostic, services.RW()),
},
},
})

168
lib/client/alpn.go Normal file
View file

@ -0,0 +1,168 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package client
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"time"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
alpn "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/utils"
)
// ALPNAuthClient contains the required auth.ClientI methods to create a local ALPN proxy.
type ALPNAuthClient interface {
// GetClusterCACert returns the PEM-encoded TLS certs for the local cluster.
// If the cluster has multiple TLS certs, they will all be concatenated.
GetClusterCACert(ctx context.Context) (*proto.GetClusterCACertResponse, error)
// GetCurrentUser returns current user as seen by the server.
// Useful especially in the context of remote clusters which perform role and trait mapping.
GetCurrentUser(ctx context.Context) (types.User, error)
// GenerateUserCerts takes the public key in the OpenSSH `authorized_keys` plain
// text format, signs it using User Certificate Authority signing key and
// returns the resulting certificates.
GenerateUserCerts(ctx context.Context, req proto.UserCertsRequest) (*proto.Certs, error)
}
// ALPNAuthTunnelConfig contains the required fields used to create an authed ALPN Proxy
type ALPNAuthTunnelConfig struct {
// AuthClient is the client that's used to interact with the cluster and obtain Certificates.
AuthClient ALPNAuthClient
// Listener to be used to accept connections that will go trough the tunnel.
Listener net.Listener
// InsecureSkipTLSVerify turns off verification for x509 upstream ALPN proxy service certificate.
InsecureSkipVerify bool
// Expires is a desired time of the expiry of the certificate.
Expires time.Time
// Protocol name.
Protocol alpn.Protocol
// PublicProxyAddr is public address of the proxy
PublicProxyAddr string
// ConnectionDiagnosticID contains the ID to be used to store Connection Diagnostic checks.
// Can be empty.
ConnectionDiagnosticID string
// RouteToDatabase contains the destination server that must receive the connection.
// Specific for database proxying.
RouteToDatabase proto.RouteToDatabase
}
// RunALPNAuthTunnel runs a local authenticated ALPN proxy to another service.
// At least one Route (which defines the service) must be defined
func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error {
protocols := []alpn.Protocol{cfg.Protocol}
if alpn.HasPingSupport(cfg.Protocol) {
protocols = append(alpn.ProtocolsWithPing(cfg.Protocol), protocols...)
}
var pool *x509.CertPool
alpnUpgradeRequired := alpnproxy.IsALPNConnUpgradeRequired(cfg.PublicProxyAddr, cfg.InsecureSkipVerify)
if alpnUpgradeRequired {
caCert, err := cfg.AuthClient.GetClusterCACert(ctx)
if err != nil {
return trace.Wrap(err)
}
pool = x509.NewCertPool()
if ok := pool.AppendCertsFromPEM(caCert.GetTLSCA()); !ok {
return trace.BadParameter("failed to append cert from cluster's TLS CA Cert")
}
}
address, err := utils.ParseAddr(cfg.PublicProxyAddr)
if err != nil {
return trace.Wrap(err)
}
tlsCert, err := getUserCerts(ctx, cfg.AuthClient, cfg.Expires, cfg.RouteToDatabase, cfg.ConnectionDiagnosticID)
if err != nil {
return trace.BadParameter("failed to parse private key: %v", err)
}
lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{
InsecureSkipVerify: cfg.InsecureSkipVerify,
RemoteProxyAddr: cfg.PublicProxyAddr,
Protocols: protocols,
Listener: cfg.Listener,
ParentContext: ctx,
SNI: address.Host(),
Certs: []tls.Certificate{*tlsCert},
RootCAs: pool,
ALPNConnUpgradeRequired: alpnUpgradeRequired,
})
if err != nil {
return trace.Wrap(err)
}
go func() {
defer cfg.Listener.Close()
if err := lp.Start(ctx); err != nil {
log.WithError(err).Info("ALPN proxy stopped.")
}
}()
return nil
}
func getUserCerts(ctx context.Context, client ALPNAuthClient, expires time.Time, routeToDatabase proto.RouteToDatabase, connectionDiagnosticID string) (*tls.Certificate, error) {
key, err := GenerateRSAKey()
if err != nil {
return nil, trace.Wrap(err)
}
currentUser, err := client.GetCurrentUser(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
certs, err := client.GenerateUserCerts(ctx, proto.UserCertsRequest{
PublicKey: key.MarshalSSHPublicKey(),
Username: currentUser.GetName(),
Expires: expires,
ConnectionDiagnosticID: connectionDiagnosticID,
RouteToDatabase: routeToDatabase,
})
if err != nil {
return nil, trace.Wrap(err)
}
tlsCert, err := keys.X509KeyPair(certs.TLS, key.PrivateKeyPEM())
if err != nil {
return nil, trace.BadParameter("failed to parse private key: %v", err)
}
return &tlsCert, nil
}

View file

@ -37,6 +37,12 @@ type TestConnectionRequest struct {
// ResourceName is the identification of the resource's instance to test.
ResourceName string `json:"resource_name"`
// DialTimeout when trying to connect to the destination host
DialTimeout time.Duration `json:"dial_timeout,omitempty"`
// InsecureSkipTLSVerify turns off verification for x509 upstream ALPN proxy service certificate.
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
// SSHPrincipal is the Linux username to use in a connection test.
// Specific to SSHTester.
SSHPrincipal string `json:"ssh_principal,omitempty"`
@ -50,8 +56,13 @@ type TestConnectionRequest struct {
// Specific to KubernetesTester.
KubernetesImpersonation KubernetesImpersonation `json:"kubernetes_impersonation,omitempty"`
// DialTimeout when trying to connect to the destination host
DialTimeout time.Duration `json:"dial_timeout,omitempty"`
// DatabaseUser is the database User to be tested
// Specific to DatabaseTester.
DatabaseUser string `json:"database_user,omitempty"`
// DatabaseName is the database user of the Database to be tested
// Specific to DatabaseTester.
DatabaseName string `json:"database_name,omitempty"`
}
// KubernetesImpersonation allows to configure a subset of `kubernetes_users` and
@ -117,6 +128,9 @@ type ConnectionTesterConfig struct {
// ProxyHostPort is the proxy to use in the `--proxy` format (host:webPort,sshPort)
ProxyHostPort string
// PublicProxyAddr is public address of the proxy.
PublicProxyAddr string
// KubernetesPublicProxyAddr is the kubernetes proxy.
KubernetesPublicProxyAddr string
@ -148,6 +162,15 @@ func ConnectionTesterForKind(cfg ConnectionTesterConfig) (ConnectionTester, erro
},
)
return tester, trace.Wrap(err)
case types.KindDatabase:
tester, err := NewDatabaseConnectionTester(
DatabaseConnectionTesterConfig{
UserClient: cfg.UserClient,
PublicProxyAddr: cfg.PublicProxyAddr,
TLSRoutingEnabled: cfg.TLSRoutingEnabled,
},
)
return tester, trace.Wrap(err)
default:
return nil, trace.NotImplemented("resource %q does not have a connection tester", cfg.ResourceKind)
}

View file

@ -0,0 +1,430 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package conntest
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/gravitational/trace"
apiclient "github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/conntest/database"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
alpn "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/srv/db/common/role"
)
// databasePinger describes the required methods to test a Database Connection.
type databasePinger interface {
// Ping tests the connection to the Database with a simple request.
Ping(ctx context.Context, req database.PingParams) error
// IsConnectionRefusedError returns whether the error is referring to a connection refused.
IsConnectionRefusedError(error) bool
// IsInvalidDatabaseUserError returns whether the error is referring to an invalid (non-existent) user.
IsInvalidDatabaseUserError(error) bool
// IsInvalidDatabaseNameError returns whether the error is referring to an invalid (non-existent) database name.
IsInvalidDatabaseNameError(error) bool
}
// ClientDatabaseConnectionTester contains the required auth.ClientI methods to test a Database Connection
type ClientDatabaseConnectionTester interface {
client.ALPNAuthClient
services.ConnectionsDiagnostic
apiclient.ListResourcesClient
}
// DatabaseConnectionTesterConfig defines the config fields for DatabaseConnectionTester.
type DatabaseConnectionTesterConfig struct {
// UserClient is an auth client that has a User's identity.
UserClient ClientDatabaseConnectionTester
// PublicProxyAddr is public address of the proxy
PublicProxyAddr string
// TLSRoutingEnabled indicates that proxy supports ALPN SNI server where
// all proxy services are exposed on a single TLS listener (Proxy Web Listener).
TLSRoutingEnabled bool
}
// DatabaseConnectionTester implements the ConnectionTester interface for Testing Database access.
type DatabaseConnectionTester struct {
cfg DatabaseConnectionTesterConfig
}
// NewDatabaseConnectionTester returns a new DatabaseConnectionTester
func NewDatabaseConnectionTester(cfg DatabaseConnectionTesterConfig) (*DatabaseConnectionTester, error) {
return &DatabaseConnectionTester{
cfg: cfg,
}, nil
}
// TestConnection tests the access to a database using:
// - auth Client using the User access
// - the resource name
// - database user and database name to connect to
//
// A new ConnectionDiagnostic is created and used to store the traces as it goes through the checkpoints
// To connect to the Database, we will create a cert-key pair and setup a Database client back to Teleport Proxy.
// The following checkpoints are reported:
// - database server for the requested database exists / the user's roles can access it
// - the user can use the requested database user and database name (per their roles)
// - the database is acessible and accepting connections from the database server
// - the database has the database user and database name that was requested
func (s *DatabaseConnectionTester) TestConnection(ctx context.Context, req TestConnectionRequest) (types.ConnectionDiagnostic, error) {
if req.ResourceKind != types.KindDatabase {
return nil, trace.BadParameter("invalid value for ResourceKind, expected %q got %q", types.KindDatabase, req.ResourceKind)
}
connectionDiagnosticID := uuid.NewString()
connectionDiagnostic, err := types.NewConnectionDiagnosticV1(
connectionDiagnosticID,
map[string]string{},
types.ConnectionDiagnosticSpecV1{
// We start with a failed state so that we don't need to set it to each return statement once an error is returned.
// if the test reaches the end, we force the test to be a success by calling
// connectionDiagnostic.SetMessage(types.DiagnosticMessageSuccess)
// connectionDiagnostic.SetSuccess(true)
Message: types.DiagnosticMessageFailed,
},
)
if err != nil {
return nil, trace.Wrap(err)
}
if err := s.cfg.UserClient.CreateConnectionDiagnostic(ctx, connectionDiagnostic); err != nil {
return nil, trace.Wrap(err)
}
databaseServers, err := s.getDatabaseServers(ctx, req.ResourceName)
if err != nil {
return nil, trace.Wrap(err)
}
if len(databaseServers) == 0 {
connDiag, err := s.appendDiagnosticTrace(ctx,
connectionDiagnosticID,
types.ConnectionDiagnosticTrace_RBAC_DATABASE,
"Database not found. "+
"Ensure your role grants access by adding it to the 'db_labels' property. "+
"This can also happen when you don't have a Database Agent proxying the database - "+
"you can fix that by adding the database labels to the 'db_service.resources.labels' in 'teleport.yaml' file of the database agent.",
trace.NotFound("%s not found", req.ResourceName),
)
if err != nil {
return nil, trace.Wrap(err)
}
return connDiag, nil
}
databaseServer := databaseServers[0]
routeToDatabase := proto.RouteToDatabase{
ServiceName: databaseServer.GetName(),
Protocol: databaseServer.GetDatabase().GetProtocol(),
Username: req.DatabaseUser,
Database: req.DatabaseName,
}
databasePinger, err := getDatabaseConnTester(routeToDatabase.Protocol)
if err != nil {
return nil, trace.Wrap(err)
}
if err := checkDatabaseLogin(routeToDatabase.Protocol, req.DatabaseUser, req.DatabaseName); err != nil {
return nil, trace.Wrap(err)
}
if _, err := s.appendDiagnosticTrace(ctx,
connectionDiagnosticID,
types.ConnectionDiagnosticTrace_RBAC_DATABASE,
"A Database Agent is available to proxy the connection to the Database.",
nil,
); err != nil {
return nil, trace.Wrap(err)
}
listener, err := s.runALPNTunnel(ctx, req, routeToDatabase, connectionDiagnosticID)
if err != nil {
return nil, trace.Wrap(err)
}
defer listener.Close()
ping, err := newPing(listener.Addr().String(), req.DatabaseUser, req.DatabaseName)
if err != nil {
return nil, trace.Wrap(err)
}
if pingErr := databasePinger.Ping(ctx, ping); pingErr != nil {
connDiag, err := s.handlePingError(ctx, connectionDiagnosticID, pingErr, databasePinger)
return connDiag, trace.Wrap(err)
}
return s.handlePingSuccess(ctx, connectionDiagnosticID)
}
func (s *DatabaseConnectionTester) runALPNTunnel(ctx context.Context, req TestConnectionRequest, routeToDatabase proto.RouteToDatabase, connectionDiagnosticID string) (net.Listener, error) {
list, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, trace.Wrap(err)
}
alpnProtocol, err := alpn.ToALPNProtocol(routeToDatabase.Protocol)
if err != nil {
return nil, trace.Wrap(err)
}
err = client.RunALPNAuthTunnel(ctx, client.ALPNAuthTunnelConfig{
AuthClient: s.cfg.UserClient,
Listener: list,
Protocol: alpnProtocol,
Expires: time.Now().Add(time.Minute).UTC(),
PublicProxyAddr: s.cfg.PublicProxyAddr,
ConnectionDiagnosticID: connectionDiagnosticID,
RouteToDatabase: routeToDatabase,
InsecureSkipVerify: req.InsecureSkipVerify,
})
if err != nil {
return nil, trace.Wrap(err)
}
return list, nil
}
func (s *DatabaseConnectionTester) getDatabaseServers(ctx context.Context, databaseName string) ([]types.DatabaseServer, error) {
// Lookup the Database Server that's proxying the requested Database.
listResourcesResponse, err := s.cfg.UserClient.ListResources(ctx, proto.ListResourcesRequest{
PredicateExpression: fmt.Sprintf(`name == "%s"`, databaseName),
ResourceType: types.KindDatabaseServer,
Limit: defaults.MaxIterationLimit,
})
if err != nil {
return nil, trace.Wrap(err)
}
databaseServers, err := types.ResourcesWithLabels(listResourcesResponse.Resources).AsDatabaseServers()
return databaseServers, trace.Wrap(err)
}
func checkDatabaseLogin(protocol, databaseUser, databaseName string) error {
matchers := role.DatabaseRoleMatchers(protocol, databaseUser, databaseName)
needUser := false
needDatabase := false
for _, matcher := range matchers {
_, userMatcher := matcher.(*services.DatabaseUserMatcher)
needUser = needUser || userMatcher
_, nameMatcher := matcher.(*services.DatabaseNameMatcher)
needDatabase = needDatabase || nameMatcher
}
if needUser && databaseUser == "" {
return trace.BadParameter("missing required parameter Database User")
}
if needDatabase && databaseName == "" {
return trace.BadParameter("missing required parameter Database Name")
}
return nil
}
func newPing(alpnProxyAddr, databaseUser, databaseName string) (database.PingParams, error) {
proxyHost, proxyPortStr, err := net.SplitHostPort(alpnProxyAddr)
if err != nil {
return database.PingParams{}, trace.Wrap(err)
}
proxyPort, err := strconv.Atoi(proxyPortStr)
if err != nil {
return database.PingParams{}, trace.Wrap(err)
}
return database.PingParams{
Host: proxyHost,
Port: proxyPort,
Username: databaseUser,
Database: databaseName,
}, nil
}
func (s DatabaseConnectionTester) handlePingSuccess(ctx context.Context, connectionDiagnosticID string) (types.ConnectionDiagnostic, error) {
if _, err := s.appendDiagnosticTrace(ctx, connectionDiagnosticID,
types.ConnectionDiagnosticTrace_CONNECTIVITY,
"Database is accessible from the Database Agent.",
nil,
); err != nil {
return nil, trace.Wrap(err)
}
if _, err := s.appendDiagnosticTrace(ctx, connectionDiagnosticID,
types.ConnectionDiagnosticTrace_RBAC_DATABASE_LOGIN,
"Access to Database User and Database Name granted.",
nil,
); err != nil {
return nil, trace.Wrap(err)
}
if _, err := s.appendDiagnosticTrace(ctx, connectionDiagnosticID,
types.ConnectionDiagnosticTrace_DATABASE_DB_USER,
"Database User exists in the Database.",
nil,
); err != nil {
return nil, trace.Wrap(err)
}
connDiag, err := s.appendDiagnosticTrace(ctx, connectionDiagnosticID,
types.ConnectionDiagnosticTrace_DATABASE_DB_NAME,
"Database Name exists in the Database.",
nil,
)
if err != nil {
return nil, trace.Wrap(err)
}
connDiag.SetMessage(types.DiagnosticMessageSuccess)
connDiag.SetSuccess(true)
if err := s.cfg.UserClient.UpdateConnectionDiagnostic(ctx, connDiag); err != nil {
return nil, trace.Wrap(err)
}
return connDiag, nil
}
func errorFromDatabaseService(pingErr error) bool {
// If the requested DB User/Name can't be used per RBAC checks, the Database Agent returns an error which gets here.
if strings.Contains(pingErr.Error(), "access to db denied. User does not have permissions. Confirm database user and name.") {
return true
}
// When there's an error when trying to use RDS IAM auth.
if strings.Contains(pingErr.Error(), "FATAL: PAM authentication failed for user") &&
strings.Contains(pingErr.Error(), "IAM policy") {
return true
}
return false
}
func (s DatabaseConnectionTester) handlePingError(ctx context.Context, connectionDiagnosticID string, pingErr error, databasePinger databasePinger) (types.ConnectionDiagnostic, error) {
// The Database Agent (lib/srv/db/server.go) might add an trace in some cases.
// Here, it must be ignored to prevent multiple failed traces.
if errorFromDatabaseService(pingErr) {
connDiag, err := s.cfg.UserClient.GetConnectionDiagnostic(ctx, connectionDiagnosticID)
if err != nil {
return nil, trace.Wrap(err)
}
return connDiag, nil
}
if databasePinger.IsConnectionRefusedError(pingErr) {
connDiag, err := s.appendDiagnosticTrace(ctx,
connectionDiagnosticID,
types.ConnectionDiagnosticTrace_CONNECTIVITY,
"There was a connection problem between the Database Agent and the Database. "+
"Ensure the Database is running and accessible from the Database Agent.",
pingErr,
)
if err != nil {
return nil, trace.Wrap(err)
}
return connDiag, nil
}
// Requested DB User is allowed per RBAC rules, but those entities don't exist in the Database itself.
if databasePinger.IsInvalidDatabaseUserError(pingErr) {
connDiag, err := s.appendDiagnosticTrace(ctx,
connectionDiagnosticID,
types.ConnectionDiagnosticTrace_DATABASE_DB_USER,
"The Database rejected the provided Database User. Ensure that the database user exists.",
pingErr,
)
if err != nil {
return nil, trace.Wrap(err)
}
return connDiag, nil
}
// Requested DB Name is allowed per RBAC rules, but those entities don't exist in the Database itself.
if databasePinger.IsInvalidDatabaseNameError(pingErr) {
connDiag, err := s.appendDiagnosticTrace(ctx,
connectionDiagnosticID,
types.ConnectionDiagnosticTrace_DATABASE_DB_NAME,
"The Database rejected the provided Database Name. Ensure that the database name exists.",
pingErr,
)
if err != nil {
return nil, trace.Wrap(err)
}
return connDiag, nil
}
connDiag, err := s.appendDiagnosticTrace(ctx,
connectionDiagnosticID,
types.ConnectionDiagnosticTrace_UNKNOWN_ERROR,
fmt.Sprintf("Unknown error. %v", pingErr),
pingErr,
)
if err != nil {
return nil, trace.Wrap(err)
}
return connDiag, nil
}
func (s DatabaseConnectionTester) appendDiagnosticTrace(ctx context.Context, connectionDiagnosticID string, traceType types.ConnectionDiagnosticTrace_TraceType, message string, err error) (types.ConnectionDiagnostic, error) {
connDiag, err := s.cfg.UserClient.AppendDiagnosticTrace(
ctx,
connectionDiagnosticID,
types.NewTraceDiagnosticConnection(
traceType,
message,
err,
))
if err != nil {
return nil, trace.Wrap(err)
}
return connDiag, nil
}
func getDatabaseConnTester(protocol string) (databasePinger, error) {
switch protocol {
case defaults.ProtocolPostgres:
return &database.PostgresPinger{}, nil
}
return nil, trace.NotImplemented("database protocol %q is not supported yet for testing connection", protocol)
}

View file

@ -0,0 +1,54 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package database
import (
"github.com/gravitational/trace"
)
// PingParams contains the required fields necessary to test a Database Connection.
type PingParams struct {
// Host is the hostname of the Database (does not include port).
Host string
// Port is the port where the Database is accepting connections.
Port int
// Username is the user to be used to login into the database.
Username string
// Database is the database name to be used to login into the database.
Database string
}
// CheckAndSetDefaults validates and set the default values for the Ping.
func (req *PingParams) CheckAndSetDefaults() error {
if req.Database == "" {
return trace.BadParameter("missing required parameter Database")
}
if req.Username == "" {
return trace.BadParameter("missing required parameter Username")
}
if req.Port == 0 {
return trace.BadParameter("missing required parameter Port")
}
if req.Host == "" {
req.Host = "localhost"
}
return nil
}

View file

@ -0,0 +1,114 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package database
import (
"context"
"errors"
"fmt"
"strings"
"github.com/gravitational/trace"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
"github.com/sirupsen/logrus"
)
const (
// A simple query to execute when running the Ping request
selectOneQuery = "select 1;"
)
// PostgresPinger implements the DatabasePinger interface for the Postgres protocol
type PostgresPinger struct{}
// Ping connects to the database and issues a basic select statement to validate the connection.
func (p *PostgresPinger) Ping(ctx context.Context, ping PingParams) error {
if err := ping.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
pgconnConfig, err := pgconn.ParseConfig(
fmt.Sprintf("postgres://%s@%s:%d/%s",
ping.Username,
ping.Host,
ping.Port,
ping.Database,
),
)
if err != nil {
return trace.Wrap(err)
}
conn, err := pgconn.ConnectConfig(ctx, pgconnConfig)
if err != nil {
return trace.Wrap(err)
}
defer func() {
if err := conn.Close(ctx); err != nil {
logrus.WithError(err).Info("failed to close connection in PostgresPinger.Ping")
}
}()
result, err := conn.Exec(ctx, selectOneQuery).ReadAll()
if err != nil {
return trace.Wrap(err)
}
if len(result) != 1 {
return trace.BadParameter("unexpected length for result: %+v", result)
}
return nil
}
// IsConnectionRefusedError checks whether the error is of type invalid database user.
// This can happen when the user doesn't exist.
func (p *PostgresPinger) IsConnectionRefusedError(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "connection refused (SQLSTATE )")
}
// IsInvalidDatabaseUserError checks whether the error is of type invalid database user.
// This can happen when the user doesn't exist.
func (p *PostgresPinger) IsInvalidDatabaseUserError(err error) bool {
var pge *pgconn.PgError
if errors.As(err, &pge) {
if pge.SQLState() == pgerrcode.InvalidAuthorizationSpecification {
return true
}
}
return false
}
// IsInvalidDatabaseNameError checks whether the error is of type invalid database name.
// This can happen when the database doesn't exist.
func (p *PostgresPinger) IsInvalidDatabaseNameError(err error) bool {
var pge *pgconn.PgError
if errors.As(err, &pge) {
if pge.SQLState() == pgerrcode.InvalidCatalogName {
return true
}
}
return false
}

View file

@ -0,0 +1,175 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package database
import (
"context"
"crypto/tls"
"crypto/x509/pkix"
"errors"
"strconv"
"testing"
"time"
"github.com/gravitational/trace"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/tlsca"
)
func TestPostgresErrors(t *testing.T) {
p := PostgresPinger{}
for _, tt := range []struct {
name string
pingErr error
errCheck require.ErrorAssertionFunc
}{
{
name: "connection refused error",
pingErr: errors.New("failed to connect to `host=127.0.0.1 user=postgres database=postgres`: server error (: connection refused (SQLSTATE ))"),
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(tt, p.IsConnectionRefusedError(err))
},
},
{
name: "invalid database error",
pingErr: &pgconn.PgError{
Code: pgerrcode.InvalidCatalogName,
},
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(tt, p.IsInvalidDatabaseNameError(err))
},
},
{
name: "invalid user error",
pingErr: &pgconn.PgError{
Code: pgerrcode.InvalidAuthorizationSpecification,
},
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(tt, p.IsInvalidDatabaseUserError(err))
},
},
} {
t.Run(tt.name, func(t *testing.T) {
tt.errCheck(t, tt.pingErr)
})
}
}
// mockClient is a mock that implements AuthClient interface.
type mockClient struct {
common.AuthClientCA
ca types.CertAuthority
}
func setupMockClient(t *testing.T) *mockClient {
t.Helper()
_, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: "example.com"}, nil, time.Minute)
require.NoError(t, err)
ca, err := types.NewCertAuthority(types.CertAuthoritySpecV2{
Type: types.HostCA,
ClusterName: "example.com",
ActiveKeys: types.CAKeySet{
SSH: []*types.SSHKeyPair{{PublicKey: []byte("SSH CA cert")}},
TLS: []*types.TLSKeyPair{{Cert: cert}},
},
})
require.NoError(t, err)
return &mockClient{
ca: ca,
}
}
func (c *mockClient) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) {
csr, err := tlsca.ParseCertificateRequestPEM(req.CSR)
if err != nil {
return nil, trace.Wrap(err)
}
tlsCACert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM))
if err != nil {
return nil, trace.Wrap(err)
}
tlsCA, err := tlsca.FromTLSCertificate(tlsCACert)
if err != nil {
return nil, trace.Wrap(err)
}
certReq := tlsca.CertificateRequest{
PublicKey: csr.PublicKey,
Subject: csr.Subject,
NotAfter: time.Now().Add(req.TTL.Get()),
DNSNames: req.ServerNames,
}
cert, err := tlsCA.GenerateCertificate(certReq)
if err != nil {
return nil, trace.Wrap(err)
}
return &proto.DatabaseCertResponse{
Cert: cert,
CACerts: [][]byte{
[]byte(fixtures.TLSCACertPEM),
},
}, nil
}
func (c *mockClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) {
return c.ca, nil
}
func TestPostgresPing(t *testing.T) {
mockClt := setupMockClient(t)
postgresTestServer, err := postgres.NewTestServer(common.TestServerConfig{
AuthClient: mockClt,
})
require.NoError(t, err)
go func() {
t.Logf("Postgres Fake server running at %s port", postgresTestServer.Port())
require.NoError(t, postgresTestServer.Serve())
}()
t.Cleanup(func() {
postgresTestServer.Close()
})
port, err := strconv.Atoi(postgresTestServer.Port())
require.NoError(t, err)
p := PostgresPinger{}
err = p.Ping(context.Background(), PingParams{
Host: "localhost",
Port: port,
Username: "someuser",
Database: "somedb",
})
require.NoError(t, err)
}

View file

@ -38,7 +38,7 @@ import (
// TestServerConfig combines parameters for a test Postgres/MySQL server.
type TestServerConfig struct {
// AuthClient will be used to retrieve trusted CA.
AuthClient auth.ClientI
AuthClient AuthClientCA
// Name is the server name for identification purposes.
Name string
// AuthUser is used in tests simulating IAM token authentication.
@ -92,6 +92,17 @@ func (cfg *TestServerConfig) Port() (string, error) {
return port, nil
}
// AuthClientCA contains the required methods to Generate mTLS certificate to be used
// by the postgres TestServer.
type AuthClientCA interface {
// GenerateDatabaseCert generates client certificate used by a database
// service to authenticate with the database instance.
GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error)
// GetCertAuthority returns cert authority by id
GetCertAuthority(context.Context, types.CertAuthID, bool, ...services.MarshalOption) (types.CertAuthority, error)
}
// MakeTestServerTLSConfig returns TLS config suitable for configuring test
// database Postgres/MySQL servers.
func MakeTestServerTLSConfig(config TestServerConfig) (*tls.Config, error) {

View file

@ -813,6 +813,23 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro
err = engine.HandleConnection(ctx, sessionCtx)
if err != nil {
connectionDiagnosticID := sessionCtx.Identity.ConnectionDiagnosticID
if connectionDiagnosticID != "" && trace.IsAccessDenied(err) {
_, diagErr := s.cfg.AuthClient.AppendDiagnosticTrace(ctx,
connectionDiagnosticID,
&types.ConnectionDiagnosticTrace{
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE_LOGIN,
Status: types.ConnectionDiagnosticTrace_FAILED,
Details: "Access denied when accessing Database. Please check the Error message for more information.",
Error: err.Error(),
},
)
if diagErr != nil {
return trace.Wrap(diagErr)
}
}
return trace.Wrap(err)
}
return nil

View file

@ -180,6 +180,9 @@ type Identity struct {
AllowedResourceIDs []types.ResourceID
// PrivateKeyPolicy is the private key policy supported by this identity.
PrivateKeyPolicy keys.PrivateKeyPolicy
// ConnectionDiagnosticID is used to add connection diagnostic messages when Testing a Connection.
ConnectionDiagnosticID string
}
// RouteToApp holds routing information for applications.
@ -433,6 +436,11 @@ var (
// deadline of the session on a certificates issued after an MFA check.
// See https://github.com/gravitational/teleport/issues/18544.
PreviousIdentityExpiresASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 2, 12}
// ConnectionDiagnosticIDASN1ExtensionOID is an extension OID used to indicate the Connection Diagnostic ID.
// When using the Test Connection feature, there's propagation of the ConnectionDiagnosticID.
// Each service (ex DB Agent) uses that to add checkpoints describing if it was a success or a failure.
ConnectionDiagnosticIDASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 2, 13}
)
// Subject converts identity to X.509 subject name
@ -685,6 +693,15 @@ func (id *Identity) Subject() (pkix.Name, error) {
)
}
if id.ConnectionDiagnosticID != "" {
subject.ExtraNames = append(subject.ExtraNames,
pkix.AttributeTypeAndValue{
Type: ConnectionDiagnosticIDASN1ExtensionOID,
Value: id.ConnectionDiagnosticID,
},
)
}
return subject, nil
}
@ -868,6 +885,11 @@ func FromSubject(subject pkix.Name, expires time.Time) (*Identity, error) {
if ok {
id.PrivateKeyPolicy = keys.PrivateKeyPolicy(val)
}
case attr.Type.Equal(ConnectionDiagnosticIDASN1ExtensionOID):
val, ok := attr.Value.(string)
if ok {
id.ConnectionDiagnosticID = val
}
}
}

View file

@ -74,6 +74,7 @@ func (h *Handler) diagnoseConnection(w http.ResponseWriter, r *http.Request, p h
ResourceKind: req.ResourceKind,
UserClient: userClt,
ProxyHostPort: h.ProxyHostPort(),
PublicProxyAddr: h.cfg.PublicProxyAddr,
KubernetesPublicProxyAddr: h.kubeProxyHostPort(),
TLSRoutingEnabled: proxySettings.TLSRoutingEnabled,
}