mirror of
https://github.com/gravitational/teleport
synced 2024-10-20 09:13:39 +00:00
Connection Diagnostics: Postgres Database tester (#18558)
* Connection Diagnostics: Postgres Database tester When adding a new resource using the Web UI we want to allow users to test connecting to it. We have two connection testers already: - for SSH Nodes - for Kube clusters This PR adds a third tester: Postgres Database. Most of the required changes for any Database are already present but we're focusing on Postgres for now. Other databases will be added as future PRs. Testing a Database is similar to the other tests: - generate certs for the logged in user - connect to the resource using those certs When generating the certificate, we inject an ID so that the Database Service can add Connection Diagnostic traces.
This commit is contained in:
parent
0433a9d5cb
commit
6fd0fa8ef3
|
@ -4389,6 +4389,14 @@ message ConnectionDiagnosticTrace {
|
|||
RBAC_KUBE = 6;
|
||||
// KUBE_PRINCIPAL is used when checking if the Kube Cluster has at least one user principals.
|
||||
KUBE_PRINCIPAL = 7;
|
||||
// RBAC_DATABASE is for RBAC checks to database access (db_labels).
|
||||
RBAC_DATABASE = 8;
|
||||
// RBAC_DATABASE_LOGIN is for RBAC checks to database login (db_name and db_user).
|
||||
RBAC_DATABASE_LOGIN = 9;
|
||||
// DATABASE_DB_USER is used when checking whether the Database has the requested Database User.
|
||||
DATABASE_DB_USER = 10;
|
||||
// DATABASE_DB_NAME is used when checking whether the Database has the requested Database Name.
|
||||
DATABASE_DB_NAME = 11;
|
||||
}
|
||||
TraceType Type = 1 [(gogoproto.jsontag) = "type"];
|
||||
// StatusType describes whether this was a success or a failure.
|
||||
|
|
File diff suppressed because it is too large
Load diff
288
integration/conntest/database_test.go
Normal file
288
integration/conntest/database_test.go
Normal file
|
@ -0,0 +1,288 @@
|
|||
// Copyright 2022 Gravitational, Inc
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package conntest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apidefaults "github.com/gravitational/teleport/api/defaults"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
"github.com/gravitational/teleport/integration/helpers"
|
||||
"github.com/gravitational/teleport/lib/auth"
|
||||
"github.com/gravitational/teleport/lib/client/conntest"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/service"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
"github.com/gravitational/teleport/lib/srv/db/common"
|
||||
"github.com/gravitational/teleport/lib/srv/db/postgres"
|
||||
"github.com/gravitational/teleport/lib/web/ui"
|
||||
)
|
||||
|
||||
func startPostgresTestServer(t *testing.T, authServer *auth.Server) *postgres.TestServer {
|
||||
postgresTestServer, err := postgres.NewTestServer(common.TestServerConfig{
|
||||
AuthClient: authServer,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
t.Logf("Postgres Fake server running at %s port", postgresTestServer.Port())
|
||||
assert.NoError(t, postgresTestServer.Serve())
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
postgresTestServer.Close()
|
||||
})
|
||||
|
||||
return postgresTestServer
|
||||
}
|
||||
|
||||
func TestDiagnoseConnectionForPostgresDatabases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start Teleport Auth and Proxy services
|
||||
authProcess, proxyProcess, provisionToken := helpers.MakeTestServers(t)
|
||||
authServer := authProcess.GetAuthServer()
|
||||
proxyAddr, err := proxyProcess.ProxyWebAddr()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start Fake Postgres Database
|
||||
postgresTestServer := startPostgresTestServer(t, authServer)
|
||||
|
||||
// Start Teleport Database Service
|
||||
databaseResourceName := "mypsqldb"
|
||||
databaseDBName := "dbname"
|
||||
databaseDBUser := "dbuser"
|
||||
helpers.MakeTestDatabaseServer(t, *proxyAddr, provisionToken, service.Database{
|
||||
Name: databaseResourceName,
|
||||
Protocol: defaults.ProtocolPostgres,
|
||||
URI: net.JoinHostPort("localhost", postgresTestServer.Port()),
|
||||
})
|
||||
// Wait for the Database Server to be registered
|
||||
waitForDatabases(t, authServer, []string{databaseResourceName})
|
||||
|
||||
roleWithFullAccess, err := types.NewRole("fullaccess", types.RoleSpecV5{
|
||||
Allow: types.RoleConditions{
|
||||
Namespaces: []string{apidefaults.Namespace},
|
||||
DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}},
|
||||
Rules: []types.Rule{
|
||||
types.NewRule(types.KindConnectionDiagnostic, services.RW()),
|
||||
},
|
||||
DatabaseUsers: []string{databaseDBUser},
|
||||
DatabaseNames: []string{databaseDBName},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, authServer.UpsertRole(ctx, roleWithFullAccess))
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
teleportUser string
|
||||
|
||||
reqResourceName string
|
||||
reqDBUser string
|
||||
reqDBName string
|
||||
|
||||
expectedSuccess bool
|
||||
expectedMessage string
|
||||
expectedTraces []types.ConnectionDiagnosticTrace
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
teleportUser: "success",
|
||||
|
||||
reqResourceName: databaseResourceName,
|
||||
reqDBUser: databaseDBUser,
|
||||
reqDBName: databaseDBName,
|
||||
|
||||
expectedSuccess: true,
|
||||
expectedMessage: "success",
|
||||
expectedTraces: []types.ConnectionDiagnosticTrace{
|
||||
{
|
||||
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE,
|
||||
Status: types.ConnectionDiagnosticTrace_SUCCESS,
|
||||
Details: "A Database Agent is available to proxy the connection to the Database.",
|
||||
},
|
||||
{
|
||||
Type: types.ConnectionDiagnosticTrace_CONNECTIVITY,
|
||||
Status: types.ConnectionDiagnosticTrace_SUCCESS,
|
||||
Details: "Database is accessible from the Database Agent.",
|
||||
},
|
||||
{
|
||||
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE_LOGIN,
|
||||
Status: types.ConnectionDiagnosticTrace_SUCCESS,
|
||||
Details: "Access to Database User and Database Name granted.",
|
||||
},
|
||||
{
|
||||
Type: types.ConnectionDiagnosticTrace_DATABASE_DB_USER,
|
||||
Status: types.ConnectionDiagnosticTrace_SUCCESS,
|
||||
Details: "Database User exists in the Database.",
|
||||
},
|
||||
{
|
||||
Type: types.ConnectionDiagnosticTrace_DATABASE_DB_NAME,
|
||||
Status: types.ConnectionDiagnosticTrace_SUCCESS,
|
||||
Details: "Database Name exists in the Database.",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "databse not found",
|
||||
teleportUser: "dbnotfound",
|
||||
|
||||
reqResourceName: "dbnotfound",
|
||||
reqDBUser: databaseDBUser,
|
||||
reqDBName: databaseDBName,
|
||||
|
||||
expectedSuccess: false,
|
||||
expectedMessage: "failed",
|
||||
expectedTraces: []types.ConnectionDiagnosticTrace{
|
||||
{
|
||||
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE,
|
||||
Status: types.ConnectionDiagnosticTrace_FAILED,
|
||||
Details: "Database not found. " +
|
||||
"Ensure your role grants access by adding it to the 'db_labels' property. " +
|
||||
"This can also happen when you don't have a Database Agent proxying the database - " +
|
||||
"you can fix that by adding the database labels to the 'db_service.resources.labels' in 'teleport.yaml' file of the database agent.",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no access to db user/name",
|
||||
teleportUser: "deniedlogin",
|
||||
|
||||
reqResourceName: databaseResourceName,
|
||||
reqDBUser: "root",
|
||||
reqDBName: "system",
|
||||
|
||||
expectedSuccess: false,
|
||||
expectedMessage: "failed",
|
||||
expectedTraces: []types.ConnectionDiagnosticTrace{
|
||||
{
|
||||
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE_LOGIN,
|
||||
Status: types.ConnectionDiagnosticTrace_FAILED,
|
||||
Details: "Access denied when accessing Database. Please check the Error message for more information.",
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt := tt
|
||||
|
||||
// Set up User
|
||||
user, err := types.NewUser(tt.teleportUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AddRole(roleWithFullAccess.GetName())
|
||||
require.NoError(t, authServer.UpsertUser(user))
|
||||
|
||||
userPassword := uuid.NewString()
|
||||
require.NoError(t, authServer.UpsertPassword(tt.teleportUser, []byte(userPassword)))
|
||||
|
||||
webPack := helpers.LoginWebClient(t, proxyAddr.String(), tt.teleportUser, userPassword)
|
||||
|
||||
diagnoseReq := conntest.TestConnectionRequest{
|
||||
ResourceKind: types.KindDatabase,
|
||||
ResourceName: tt.reqResourceName,
|
||||
DatabaseUser: tt.reqDBUser,
|
||||
DatabaseName: tt.reqDBName,
|
||||
// Default is 30 seconds but since tests run locally, we can reduce this value to also improve test responsiveness
|
||||
DialTimeout: time.Second,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
diagnoseConnectionEndpoint := strings.Join([]string{"sites", "$site", "diagnostics", "connections"}, "/")
|
||||
resp, err := webPack.DoRequest(http.MethodPost, diagnoseConnectionEndpoint, diagnoseReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode, string(respBody))
|
||||
|
||||
var connectionDiagnostic ui.ConnectionDiagnostic
|
||||
require.NoError(t, json.Unmarshal(respBody, &connectionDiagnostic))
|
||||
|
||||
gotFailedTraces := 0
|
||||
expectedFailedTraces := 0
|
||||
|
||||
for i, trace := range connectionDiagnostic.Traces {
|
||||
if trace.Status == types.ConnectionDiagnosticTrace_FAILED.String() {
|
||||
gotFailedTraces++
|
||||
}
|
||||
|
||||
t.Logf("%d status='%s' type='%s' details='%s' error='%s'\n", i, trace.Status, trace.TraceType, trace.Details, trace.Error)
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedSuccess, connectionDiagnostic.Success)
|
||||
require.Equal(t, tt.expectedMessage, connectionDiagnostic.Message)
|
||||
for _, expectedTrace := range tt.expectedTraces {
|
||||
if expectedTrace.Status == types.ConnectionDiagnosticTrace_FAILED {
|
||||
expectedFailedTraces++
|
||||
}
|
||||
|
||||
foundTrace := false
|
||||
for _, returnedTrace := range connectionDiagnostic.Traces {
|
||||
if expectedTrace.Type.String() != returnedTrace.TraceType {
|
||||
continue
|
||||
}
|
||||
|
||||
foundTrace = true
|
||||
require.Equal(t, expectedTrace.Status.String(), returnedTrace.Status)
|
||||
require.Equal(t, expectedTrace.Details, returnedTrace.Details)
|
||||
require.Contains(t, returnedTrace.Error, expectedTrace.Error)
|
||||
}
|
||||
|
||||
require.True(t, foundTrace, expectedTrace)
|
||||
}
|
||||
|
||||
require.Equal(t, expectedFailedTraces, gotFailedTraces)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func waitForDatabases(t *testing.T, authServer *auth.Server, dbNames []string) {
|
||||
ctx := context.Background()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
all, err := authServer.GetDatabaseServers(ctx, apidefaults.Namespace)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(dbNames) > len(all) {
|
||||
return false
|
||||
}
|
||||
|
||||
registered := 0
|
||||
for _, db := range dbNames {
|
||||
for _, a := range all {
|
||||
if a.GetName() == db {
|
||||
registered++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return registered == len(dbNames)
|
||||
|
||||
}, 10*time.Second, 100*time.Millisecond)
|
||||
}
|
|
@ -28,21 +28,28 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
|
||||
"github.com/gravitational/teleport/api/breaker"
|
||||
"github.com/gravitational/teleport/api/constants"
|
||||
apidefaults "github.com/gravitational/teleport/api/defaults"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
apievents "github.com/gravitational/teleport/api/types/events"
|
||||
"github.com/gravitational/teleport/api/utils/retryutils"
|
||||
"github.com/gravitational/teleport/lib"
|
||||
"github.com/gravitational/teleport/lib/auth"
|
||||
"github.com/gravitational/teleport/lib/backend"
|
||||
"github.com/gravitational/teleport/lib/client"
|
||||
libclient "github.com/gravitational/teleport/lib/client"
|
||||
"github.com/gravitational/teleport/lib/client/identityfile"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/service"
|
||||
"github.com/gravitational/teleport/lib/teleagent"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
)
|
||||
|
||||
// CommandOptions controls how the SSH command is built.
|
||||
|
@ -301,3 +308,119 @@ func WaitForDatabaseServers(t *testing.T, authServer *auth.Server, dbs []service
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MakeTestServers starts an Auth and a Proxy Service.
|
||||
// Besides those processes, it also returns a provision token which can be used to add other services.
|
||||
func MakeTestServers(t *testing.T) (auth *service.TeleportProcess, proxy *service.TeleportProcess, provisionToken string) {
|
||||
provisionToken = uuid.NewString()
|
||||
var err error
|
||||
// Set up a test auth server.
|
||||
//
|
||||
// We need this to get a random port assigned to it and allow parallel
|
||||
// execution of this test.
|
||||
cfg := service.MakeDefaultConfig()
|
||||
cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()
|
||||
cfg.Hostname = "localhost"
|
||||
cfg.DataDir = t.TempDir()
|
||||
cfg.SetAuthServerAddress(cfg.Auth.ListenAddr)
|
||||
cfg.Auth.ListenAddr.Addr = NewListener(t, service.ListenerAuth, &cfg.FileDescriptors)
|
||||
cfg.Auth.Preference.SetSecondFactor(constants.SecondFactorOff)
|
||||
cfg.Auth.StorageConfig.Params = backend.Params{defaults.BackendPath: filepath.Join(cfg.DataDir, defaults.BackendDir)}
|
||||
cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{
|
||||
StaticTokens: []types.ProvisionTokenV1{{
|
||||
Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleTrustedCluster, types.RoleNode, types.RoleApp},
|
||||
Expires: time.Now().Add(time.Minute),
|
||||
Token: provisionToken,
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
cfg.SSH.Enabled = false
|
||||
cfg.Auth.Enabled = true
|
||||
cfg.Proxy.Enabled = false
|
||||
cfg.Log = utils.NewLoggerForTests()
|
||||
|
||||
auth, err = service.NewTeleport(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, auth.Start())
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, auth.Close())
|
||||
require.NoError(t, auth.Wait())
|
||||
})
|
||||
|
||||
// Wait for proxy to become ready.
|
||||
_, err = auth.WaitForEventTimeout(30*time.Second, service.AuthTLSReady)
|
||||
// in reality, the auth server should start *much* sooner than this. we use a very large
|
||||
// timeout here because this isn't the kind of problem that this test is meant to catch.
|
||||
require.NoError(t, err, "auth server didn't start after 30s")
|
||||
|
||||
authAddr, err := auth.AuthAddr()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up a test proxy service.
|
||||
cfg = service.MakeDefaultConfig()
|
||||
cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()
|
||||
cfg.Hostname = "localhost"
|
||||
cfg.DataDir = t.TempDir()
|
||||
|
||||
cfg.SetAuthServerAddress(*authAddr)
|
||||
cfg.SetToken(provisionToken)
|
||||
cfg.SSH.Enabled = false
|
||||
cfg.Auth.Enabled = false
|
||||
cfg.Proxy.Enabled = true
|
||||
cfg.Proxy.ReverseTunnelListenAddr.Addr = NewListener(t, service.ListenerProxyTunnel, &cfg.FileDescriptors)
|
||||
cfg.Proxy.WebAddr.Addr = NewListener(t, service.ListenerProxyWeb, &cfg.FileDescriptors)
|
||||
cfg.Proxy.PublicAddrs = []utils.NetAddr{
|
||||
cfg.Proxy.WebAddr,
|
||||
}
|
||||
cfg.Proxy.DisableWebInterface = true
|
||||
cfg.Log = utils.NewLoggerForTests()
|
||||
|
||||
proxy, err = service.NewTeleport(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, proxy.Start())
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, proxy.Close())
|
||||
require.NoError(t, proxy.Wait())
|
||||
})
|
||||
|
||||
// Wait for proxy to become ready.
|
||||
_, err = proxy.WaitForEventTimeout(10*time.Second, service.ProxyWebServerReady)
|
||||
require.NoError(t, err, "proxy web server didn't start after 10s")
|
||||
|
||||
return auth, proxy, provisionToken
|
||||
}
|
||||
|
||||
// MakeTestDatabaseServer creates a Database Service
|
||||
// It receives the Proxy Address, a Token (to join the cluster) and a list of Datbases
|
||||
func MakeTestDatabaseServer(t *testing.T, proxyAddr utils.NetAddr, token string, dbs ...service.Database) (db *service.TeleportProcess) {
|
||||
// Proxy uses self-signed certificates in tests.
|
||||
lib.SetInsecureDevMode(true)
|
||||
|
||||
cfg := service.MakeDefaultConfig()
|
||||
cfg.Hostname = "localhost"
|
||||
cfg.DataDir = t.TempDir()
|
||||
cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()
|
||||
cfg.SetAuthServerAddress(proxyAddr)
|
||||
cfg.SetToken(token)
|
||||
cfg.SSH.Enabled = false
|
||||
cfg.Auth.Enabled = false
|
||||
cfg.Databases.Enabled = true
|
||||
cfg.Databases.Databases = dbs
|
||||
cfg.Log = utils.NewLoggerForTests()
|
||||
|
||||
db, err := service.NewTeleport(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Start())
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, db.Close())
|
||||
})
|
||||
|
||||
// Wait for database agent to start.
|
||||
_, err = db.WaitForEventTimeout(10*time.Second, service.DatabasesReady)
|
||||
require.NoError(t, err, "database server didn't start after 10s")
|
||||
|
||||
return db
|
||||
}
|
||||
|
|
145
integration/helpers/web.go
Normal file
145
integration/helpers/web.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
// Copyright 2022 Gravitational, Inc
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gravitational/teleport/lib/httplib/csrf"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
"github.com/gravitational/teleport/lib/web"
|
||||
"github.com/gravitational/teleport/lib/web/ui"
|
||||
)
|
||||
|
||||
// WebClientPack is an authenticated HTTP Client for Teleport.
|
||||
type WebClientPack struct {
|
||||
clt *http.Client
|
||||
host string
|
||||
webCookie string
|
||||
bearerToken string
|
||||
clusterName string
|
||||
}
|
||||
|
||||
// LoginWebClient receives the host url, the username and a password.
|
||||
// It will login into that host and return a WebClientPack.
|
||||
func LoginWebClient(t *testing.T, host, username, password string) *WebClientPack {
|
||||
csReq, err := json.Marshal(web.CreateSessionReq{
|
||||
User: username,
|
||||
Pass: password,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create POST request to create session.
|
||||
u := url.URL{
|
||||
Scheme: "https",
|
||||
Host: host,
|
||||
Path: "/v1/webapi/sessions/web",
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewBuffer(csReq))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attach CSRF token in cookie and header.
|
||||
csrfToken, err := utils.CryptoRandomHex(32)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: csrf.CookieName,
|
||||
Value: csrfToken,
|
||||
})
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set(csrf.HeaderName, csrfToken)
|
||||
|
||||
// Issue request.
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Read in response.
|
||||
var csResp *web.CreateSessionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&csResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract session cookie and bearer token.
|
||||
require.Len(t, resp.Cookies(), 1)
|
||||
cookie := resp.Cookies()[0]
|
||||
require.Equal(t, cookie.Name, web.CookieName)
|
||||
|
||||
webClient := &WebClientPack{
|
||||
clt: client,
|
||||
host: host,
|
||||
webCookie: cookie.Value,
|
||||
bearerToken: csResp.Token,
|
||||
}
|
||||
|
||||
resp, err = webClient.DoRequest(http.MethodGet, "sites", nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
var clusters []ui.Cluster
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&clusters))
|
||||
require.NotEmpty(t, clusters)
|
||||
|
||||
webClient.clusterName = clusters[0].Name
|
||||
return webClient
|
||||
}
|
||||
|
||||
// DoRequest receives a method, endpoint and payload and sends an HTTP Request to the Teleport API.
|
||||
// The endpoint must not contain the host neither the base path ('/v1/webapi/').
|
||||
// Returns the http.Response.
|
||||
func (w *WebClientPack) DoRequest(method, endpoint string, payload any) (*http.Response, error) {
|
||||
endpoint = strings.ReplaceAll(endpoint, "$site", w.clusterName)
|
||||
u := url.URL{
|
||||
Scheme: "https",
|
||||
Host: w.host,
|
||||
Path: fmt.Sprintf("/v1/webapi/%s", endpoint),
|
||||
}
|
||||
|
||||
bs, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(bs))
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: web.CookieName,
|
||||
Value: w.webCookie,
|
||||
})
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", w.bearerToken))
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
resp, err := w.clt.Do(req)
|
||||
return resp, trace.Wrap(err)
|
||||
}
|
|
@ -1488,6 +1488,7 @@ func (a *Server) generateUserCert(req certRequest) (*proto.Certs, error) {
|
|||
Generation: req.generation,
|
||||
AllowedResourceIDs: req.checker.GetAllowedResourceIDs(),
|
||||
PrivateKeyPolicy: attestedKeyPolicy,
|
||||
ConnectionDiagnosticID: req.connectionDiagnosticID,
|
||||
}
|
||||
subject, err := identity.Subject()
|
||||
if err != nil {
|
||||
|
|
|
@ -597,6 +597,7 @@ func definitionForBuiltinRole(clusterName string, recConfig types.SessionRecordi
|
|||
types.NewRule(types.KindDatabase, services.RW()),
|
||||
types.NewRule(types.KindSemaphore, services.RW()),
|
||||
types.NewRule(types.KindLock, services.RO()),
|
||||
types.NewRule(types.KindConnectionDiagnostic, services.RW()),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
|
168
lib/client/alpn.go
Normal file
168
lib/client/alpn.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
|
||||
"github.com/gravitational/teleport/api/client/proto"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
"github.com/gravitational/teleport/api/utils/keys"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy"
|
||||
alpn "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
)
|
||||
|
||||
// ALPNAuthClient contains the required auth.ClientI methods to create a local ALPN proxy.
|
||||
type ALPNAuthClient interface {
|
||||
// GetClusterCACert returns the PEM-encoded TLS certs for the local cluster.
|
||||
// If the cluster has multiple TLS certs, they will all be concatenated.
|
||||
GetClusterCACert(ctx context.Context) (*proto.GetClusterCACertResponse, error)
|
||||
|
||||
// GetCurrentUser returns current user as seen by the server.
|
||||
// Useful especially in the context of remote clusters which perform role and trait mapping.
|
||||
GetCurrentUser(ctx context.Context) (types.User, error)
|
||||
|
||||
// GenerateUserCerts takes the public key in the OpenSSH `authorized_keys` plain
|
||||
// text format, signs it using User Certificate Authority signing key and
|
||||
// returns the resulting certificates.
|
||||
GenerateUserCerts(ctx context.Context, req proto.UserCertsRequest) (*proto.Certs, error)
|
||||
}
|
||||
|
||||
// ALPNAuthTunnelConfig contains the required fields used to create an authed ALPN Proxy
|
||||
type ALPNAuthTunnelConfig struct {
|
||||
// AuthClient is the client that's used to interact with the cluster and obtain Certificates.
|
||||
AuthClient ALPNAuthClient
|
||||
|
||||
// Listener to be used to accept connections that will go trough the tunnel.
|
||||
Listener net.Listener
|
||||
|
||||
// InsecureSkipTLSVerify turns off verification for x509 upstream ALPN proxy service certificate.
|
||||
InsecureSkipVerify bool
|
||||
|
||||
// Expires is a desired time of the expiry of the certificate.
|
||||
Expires time.Time
|
||||
|
||||
// Protocol name.
|
||||
Protocol alpn.Protocol
|
||||
|
||||
// PublicProxyAddr is public address of the proxy
|
||||
PublicProxyAddr string
|
||||
|
||||
// ConnectionDiagnosticID contains the ID to be used to store Connection Diagnostic checks.
|
||||
// Can be empty.
|
||||
ConnectionDiagnosticID string
|
||||
|
||||
// RouteToDatabase contains the destination server that must receive the connection.
|
||||
// Specific for database proxying.
|
||||
RouteToDatabase proto.RouteToDatabase
|
||||
}
|
||||
|
||||
// RunALPNAuthTunnel runs a local authenticated ALPN proxy to another service.
|
||||
// At least one Route (which defines the service) must be defined
|
||||
func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error {
|
||||
protocols := []alpn.Protocol{cfg.Protocol}
|
||||
if alpn.HasPingSupport(cfg.Protocol) {
|
||||
protocols = append(alpn.ProtocolsWithPing(cfg.Protocol), protocols...)
|
||||
}
|
||||
|
||||
var pool *x509.CertPool
|
||||
|
||||
alpnUpgradeRequired := alpnproxy.IsALPNConnUpgradeRequired(cfg.PublicProxyAddr, cfg.InsecureSkipVerify)
|
||||
|
||||
if alpnUpgradeRequired {
|
||||
caCert, err := cfg.AuthClient.GetClusterCACert(ctx)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
pool = x509.NewCertPool()
|
||||
if ok := pool.AppendCertsFromPEM(caCert.GetTLSCA()); !ok {
|
||||
return trace.BadParameter("failed to append cert from cluster's TLS CA Cert")
|
||||
}
|
||||
}
|
||||
|
||||
address, err := utils.ParseAddr(cfg.PublicProxyAddr)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
tlsCert, err := getUserCerts(ctx, cfg.AuthClient, cfg.Expires, cfg.RouteToDatabase, cfg.ConnectionDiagnosticID)
|
||||
if err != nil {
|
||||
return trace.BadParameter("failed to parse private key: %v", err)
|
||||
}
|
||||
|
||||
lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
RemoteProxyAddr: cfg.PublicProxyAddr,
|
||||
Protocols: protocols,
|
||||
Listener: cfg.Listener,
|
||||
ParentContext: ctx,
|
||||
SNI: address.Host(),
|
||||
Certs: []tls.Certificate{*tlsCert},
|
||||
RootCAs: pool,
|
||||
ALPNConnUpgradeRequired: alpnUpgradeRequired,
|
||||
})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer cfg.Listener.Close()
|
||||
if err := lp.Start(ctx); err != nil {
|
||||
log.WithError(err).Info("ALPN proxy stopped.")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUserCerts(ctx context.Context, client ALPNAuthClient, expires time.Time, routeToDatabase proto.RouteToDatabase, connectionDiagnosticID string) (*tls.Certificate, error) {
|
||||
key, err := GenerateRSAKey()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
currentUser, err := client.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
certs, err := client.GenerateUserCerts(ctx, proto.UserCertsRequest{
|
||||
PublicKey: key.MarshalSSHPublicKey(),
|
||||
Username: currentUser.GetName(),
|
||||
Expires: expires,
|
||||
ConnectionDiagnosticID: connectionDiagnosticID,
|
||||
RouteToDatabase: routeToDatabase,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
tlsCert, err := keys.X509KeyPair(certs.TLS, key.PrivateKeyPEM())
|
||||
if err != nil {
|
||||
return nil, trace.BadParameter("failed to parse private key: %v", err)
|
||||
}
|
||||
|
||||
return &tlsCert, nil
|
||||
}
|
|
@ -37,6 +37,12 @@ type TestConnectionRequest struct {
|
|||
// ResourceName is the identification of the resource's instance to test.
|
||||
ResourceName string `json:"resource_name"`
|
||||
|
||||
// DialTimeout when trying to connect to the destination host
|
||||
DialTimeout time.Duration `json:"dial_timeout,omitempty"`
|
||||
|
||||
// InsecureSkipTLSVerify turns off verification for x509 upstream ALPN proxy service certificate.
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
|
||||
|
||||
// SSHPrincipal is the Linux username to use in a connection test.
|
||||
// Specific to SSHTester.
|
||||
SSHPrincipal string `json:"ssh_principal,omitempty"`
|
||||
|
@ -50,8 +56,13 @@ type TestConnectionRequest struct {
|
|||
// Specific to KubernetesTester.
|
||||
KubernetesImpersonation KubernetesImpersonation `json:"kubernetes_impersonation,omitempty"`
|
||||
|
||||
// DialTimeout when trying to connect to the destination host
|
||||
DialTimeout time.Duration `json:"dial_timeout,omitempty"`
|
||||
// DatabaseUser is the database User to be tested
|
||||
// Specific to DatabaseTester.
|
||||
DatabaseUser string `json:"database_user,omitempty"`
|
||||
|
||||
// DatabaseName is the database user of the Database to be tested
|
||||
// Specific to DatabaseTester.
|
||||
DatabaseName string `json:"database_name,omitempty"`
|
||||
}
|
||||
|
||||
// KubernetesImpersonation allows to configure a subset of `kubernetes_users` and
|
||||
|
@ -117,6 +128,9 @@ type ConnectionTesterConfig struct {
|
|||
// ProxyHostPort is the proxy to use in the `--proxy` format (host:webPort,sshPort)
|
||||
ProxyHostPort string
|
||||
|
||||
// PublicProxyAddr is public address of the proxy.
|
||||
PublicProxyAddr string
|
||||
|
||||
// KubernetesPublicProxyAddr is the kubernetes proxy.
|
||||
KubernetesPublicProxyAddr string
|
||||
|
||||
|
@ -148,6 +162,15 @@ func ConnectionTesterForKind(cfg ConnectionTesterConfig) (ConnectionTester, erro
|
|||
},
|
||||
)
|
||||
return tester, trace.Wrap(err)
|
||||
case types.KindDatabase:
|
||||
tester, err := NewDatabaseConnectionTester(
|
||||
DatabaseConnectionTesterConfig{
|
||||
UserClient: cfg.UserClient,
|
||||
PublicProxyAddr: cfg.PublicProxyAddr,
|
||||
TLSRoutingEnabled: cfg.TLSRoutingEnabled,
|
||||
},
|
||||
)
|
||||
return tester, trace.Wrap(err)
|
||||
default:
|
||||
return nil, trace.NotImplemented("resource %q does not have a connection tester", cfg.ResourceKind)
|
||||
}
|
||||
|
|
430
lib/client/conntest/database.go
Normal file
430
lib/client/conntest/database.go
Normal file
|
@ -0,0 +1,430 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package conntest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitational/trace"
|
||||
|
||||
apiclient "github.com/gravitational/teleport/api/client"
|
||||
"github.com/gravitational/teleport/api/client/proto"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
"github.com/gravitational/teleport/lib/client"
|
||||
"github.com/gravitational/teleport/lib/client/conntest/database"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
alpn "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
"github.com/gravitational/teleport/lib/srv/db/common/role"
|
||||
)
|
||||
|
||||
// databasePinger describes the required methods to test a Database Connection.
|
||||
type databasePinger interface {
|
||||
// Ping tests the connection to the Database with a simple request.
|
||||
Ping(ctx context.Context, req database.PingParams) error
|
||||
|
||||
// IsConnectionRefusedError returns whether the error is referring to a connection refused.
|
||||
IsConnectionRefusedError(error) bool
|
||||
|
||||
// IsInvalidDatabaseUserError returns whether the error is referring to an invalid (non-existent) user.
|
||||
IsInvalidDatabaseUserError(error) bool
|
||||
|
||||
// IsInvalidDatabaseNameError returns whether the error is referring to an invalid (non-existent) database name.
|
||||
IsInvalidDatabaseNameError(error) bool
|
||||
}
|
||||
|
||||
// ClientDatabaseConnectionTester contains the required auth.ClientI methods to test a Database Connection
|
||||
type ClientDatabaseConnectionTester interface {
|
||||
client.ALPNAuthClient
|
||||
|
||||
services.ConnectionsDiagnostic
|
||||
apiclient.ListResourcesClient
|
||||
}
|
||||
|
||||
// DatabaseConnectionTesterConfig defines the config fields for DatabaseConnectionTester.
|
||||
type DatabaseConnectionTesterConfig struct {
|
||||
// UserClient is an auth client that has a User's identity.
|
||||
UserClient ClientDatabaseConnectionTester
|
||||
|
||||
// PublicProxyAddr is public address of the proxy
|
||||
PublicProxyAddr string
|
||||
|
||||
// TLSRoutingEnabled indicates that proxy supports ALPN SNI server where
|
||||
// all proxy services are exposed on a single TLS listener (Proxy Web Listener).
|
||||
TLSRoutingEnabled bool
|
||||
}
|
||||
|
||||
// DatabaseConnectionTester implements the ConnectionTester interface for Testing Database access.
|
||||
type DatabaseConnectionTester struct {
|
||||
cfg DatabaseConnectionTesterConfig
|
||||
}
|
||||
|
||||
// NewDatabaseConnectionTester returns a new DatabaseConnectionTester
|
||||
func NewDatabaseConnectionTester(cfg DatabaseConnectionTesterConfig) (*DatabaseConnectionTester, error) {
|
||||
return &DatabaseConnectionTester{
|
||||
cfg: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestConnection tests the access to a database using:
|
||||
// - auth Client using the User access
|
||||
// - the resource name
|
||||
// - database user and database name to connect to
|
||||
//
|
||||
// A new ConnectionDiagnostic is created and used to store the traces as it goes through the checkpoints
|
||||
// To connect to the Database, we will create a cert-key pair and setup a Database client back to Teleport Proxy.
|
||||
// The following checkpoints are reported:
|
||||
// - database server for the requested database exists / the user's roles can access it
|
||||
// - the user can use the requested database user and database name (per their roles)
|
||||
// - the database is acessible and accepting connections from the database server
|
||||
// - the database has the database user and database name that was requested
|
||||
func (s *DatabaseConnectionTester) TestConnection(ctx context.Context, req TestConnectionRequest) (types.ConnectionDiagnostic, error) {
|
||||
if req.ResourceKind != types.KindDatabase {
|
||||
return nil, trace.BadParameter("invalid value for ResourceKind, expected %q got %q", types.KindDatabase, req.ResourceKind)
|
||||
}
|
||||
|
||||
connectionDiagnosticID := uuid.NewString()
|
||||
connectionDiagnostic, err := types.NewConnectionDiagnosticV1(
|
||||
connectionDiagnosticID,
|
||||
map[string]string{},
|
||||
types.ConnectionDiagnosticSpecV1{
|
||||
// We start with a failed state so that we don't need to set it to each return statement once an error is returned.
|
||||
// if the test reaches the end, we force the test to be a success by calling
|
||||
// connectionDiagnostic.SetMessage(types.DiagnosticMessageSuccess)
|
||||
// connectionDiagnostic.SetSuccess(true)
|
||||
Message: types.DiagnosticMessageFailed,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if err := s.cfg.UserClient.CreateConnectionDiagnostic(ctx, connectionDiagnostic); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
databaseServers, err := s.getDatabaseServers(ctx, req.ResourceName)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if len(databaseServers) == 0 {
|
||||
connDiag, err := s.appendDiagnosticTrace(ctx,
|
||||
connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_RBAC_DATABASE,
|
||||
"Database not found. "+
|
||||
"Ensure your role grants access by adding it to the 'db_labels' property. "+
|
||||
"This can also happen when you don't have a Database Agent proxying the database - "+
|
||||
"you can fix that by adding the database labels to the 'db_service.resources.labels' in 'teleport.yaml' file of the database agent.",
|
||||
trace.NotFound("%s not found", req.ResourceName),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return connDiag, nil
|
||||
}
|
||||
|
||||
databaseServer := databaseServers[0]
|
||||
routeToDatabase := proto.RouteToDatabase{
|
||||
ServiceName: databaseServer.GetName(),
|
||||
Protocol: databaseServer.GetDatabase().GetProtocol(),
|
||||
Username: req.DatabaseUser,
|
||||
Database: req.DatabaseName,
|
||||
}
|
||||
|
||||
databasePinger, err := getDatabaseConnTester(routeToDatabase.Protocol)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if err := checkDatabaseLogin(routeToDatabase.Protocol, req.DatabaseUser, req.DatabaseName); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if _, err := s.appendDiagnosticTrace(ctx,
|
||||
connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_RBAC_DATABASE,
|
||||
"A Database Agent is available to proxy the connection to the Database.",
|
||||
nil,
|
||||
); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
listener, err := s.runALPNTunnel(ctx, req, routeToDatabase, connectionDiagnosticID)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
ping, err := newPing(listener.Addr().String(), req.DatabaseUser, req.DatabaseName)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if pingErr := databasePinger.Ping(ctx, ping); pingErr != nil {
|
||||
connDiag, err := s.handlePingError(ctx, connectionDiagnosticID, pingErr, databasePinger)
|
||||
return connDiag, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return s.handlePingSuccess(ctx, connectionDiagnosticID)
|
||||
}
|
||||
|
||||
func (s *DatabaseConnectionTester) runALPNTunnel(ctx context.Context, req TestConnectionRequest, routeToDatabase proto.RouteToDatabase, connectionDiagnosticID string) (net.Listener, error) {
|
||||
list, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
alpnProtocol, err := alpn.ToALPNProtocol(routeToDatabase.Protocol)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
err = client.RunALPNAuthTunnel(ctx, client.ALPNAuthTunnelConfig{
|
||||
AuthClient: s.cfg.UserClient,
|
||||
Listener: list,
|
||||
Protocol: alpnProtocol,
|
||||
Expires: time.Now().Add(time.Minute).UTC(),
|
||||
PublicProxyAddr: s.cfg.PublicProxyAddr,
|
||||
ConnectionDiagnosticID: connectionDiagnosticID,
|
||||
RouteToDatabase: routeToDatabase,
|
||||
InsecureSkipVerify: req.InsecureSkipVerify,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (s *DatabaseConnectionTester) getDatabaseServers(ctx context.Context, databaseName string) ([]types.DatabaseServer, error) {
|
||||
// Lookup the Database Server that's proxying the requested Database.
|
||||
listResourcesResponse, err := s.cfg.UserClient.ListResources(ctx, proto.ListResourcesRequest{
|
||||
PredicateExpression: fmt.Sprintf(`name == "%s"`, databaseName),
|
||||
ResourceType: types.KindDatabaseServer,
|
||||
Limit: defaults.MaxIterationLimit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
databaseServers, err := types.ResourcesWithLabels(listResourcesResponse.Resources).AsDatabaseServers()
|
||||
return databaseServers, trace.Wrap(err)
|
||||
}
|
||||
|
||||
func checkDatabaseLogin(protocol, databaseUser, databaseName string) error {
|
||||
matchers := role.DatabaseRoleMatchers(protocol, databaseUser, databaseName)
|
||||
needUser := false
|
||||
needDatabase := false
|
||||
|
||||
for _, matcher := range matchers {
|
||||
_, userMatcher := matcher.(*services.DatabaseUserMatcher)
|
||||
needUser = needUser || userMatcher
|
||||
|
||||
_, nameMatcher := matcher.(*services.DatabaseNameMatcher)
|
||||
needDatabase = needDatabase || nameMatcher
|
||||
}
|
||||
|
||||
if needUser && databaseUser == "" {
|
||||
return trace.BadParameter("missing required parameter Database User")
|
||||
}
|
||||
|
||||
if needDatabase && databaseName == "" {
|
||||
return trace.BadParameter("missing required parameter Database Name")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newPing(alpnProxyAddr, databaseUser, databaseName string) (database.PingParams, error) {
|
||||
proxyHost, proxyPortStr, err := net.SplitHostPort(alpnProxyAddr)
|
||||
if err != nil {
|
||||
return database.PingParams{}, trace.Wrap(err)
|
||||
}
|
||||
|
||||
proxyPort, err := strconv.Atoi(proxyPortStr)
|
||||
if err != nil {
|
||||
return database.PingParams{}, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return database.PingParams{
|
||||
Host: proxyHost,
|
||||
Port: proxyPort,
|
||||
Username: databaseUser,
|
||||
Database: databaseName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s DatabaseConnectionTester) handlePingSuccess(ctx context.Context, connectionDiagnosticID string) (types.ConnectionDiagnostic, error) {
|
||||
if _, err := s.appendDiagnosticTrace(ctx, connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_CONNECTIVITY,
|
||||
"Database is accessible from the Database Agent.",
|
||||
nil,
|
||||
); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if _, err := s.appendDiagnosticTrace(ctx, connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_RBAC_DATABASE_LOGIN,
|
||||
"Access to Database User and Database Name granted.",
|
||||
nil,
|
||||
); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if _, err := s.appendDiagnosticTrace(ctx, connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_DATABASE_DB_USER,
|
||||
"Database User exists in the Database.",
|
||||
nil,
|
||||
); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
connDiag, err := s.appendDiagnosticTrace(ctx, connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_DATABASE_DB_NAME,
|
||||
"Database Name exists in the Database.",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
connDiag.SetMessage(types.DiagnosticMessageSuccess)
|
||||
connDiag.SetSuccess(true)
|
||||
|
||||
if err := s.cfg.UserClient.UpdateConnectionDiagnostic(ctx, connDiag); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return connDiag, nil
|
||||
}
|
||||
|
||||
func errorFromDatabaseService(pingErr error) bool {
|
||||
// If the requested DB User/Name can't be used per RBAC checks, the Database Agent returns an error which gets here.
|
||||
if strings.Contains(pingErr.Error(), "access to db denied. User does not have permissions. Confirm database user and name.") {
|
||||
return true
|
||||
}
|
||||
|
||||
// When there's an error when trying to use RDS IAM auth.
|
||||
if strings.Contains(pingErr.Error(), "FATAL: PAM authentication failed for user") &&
|
||||
strings.Contains(pingErr.Error(), "IAM policy") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s DatabaseConnectionTester) handlePingError(ctx context.Context, connectionDiagnosticID string, pingErr error, databasePinger databasePinger) (types.ConnectionDiagnostic, error) {
|
||||
// The Database Agent (lib/srv/db/server.go) might add an trace in some cases.
|
||||
// Here, it must be ignored to prevent multiple failed traces.
|
||||
if errorFromDatabaseService(pingErr) {
|
||||
connDiag, err := s.cfg.UserClient.GetConnectionDiagnostic(ctx, connectionDiagnosticID)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return connDiag, nil
|
||||
}
|
||||
|
||||
if databasePinger.IsConnectionRefusedError(pingErr) {
|
||||
connDiag, err := s.appendDiagnosticTrace(ctx,
|
||||
connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_CONNECTIVITY,
|
||||
"There was a connection problem between the Database Agent and the Database. "+
|
||||
"Ensure the Database is running and accessible from the Database Agent.",
|
||||
pingErr,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return connDiag, nil
|
||||
}
|
||||
|
||||
// Requested DB User is allowed per RBAC rules, but those entities don't exist in the Database itself.
|
||||
if databasePinger.IsInvalidDatabaseUserError(pingErr) {
|
||||
connDiag, err := s.appendDiagnosticTrace(ctx,
|
||||
connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_DATABASE_DB_USER,
|
||||
"The Database rejected the provided Database User. Ensure that the database user exists.",
|
||||
pingErr,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return connDiag, nil
|
||||
}
|
||||
|
||||
// Requested DB Name is allowed per RBAC rules, but those entities don't exist in the Database itself.
|
||||
if databasePinger.IsInvalidDatabaseNameError(pingErr) {
|
||||
connDiag, err := s.appendDiagnosticTrace(ctx,
|
||||
connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_DATABASE_DB_NAME,
|
||||
"The Database rejected the provided Database Name. Ensure that the database name exists.",
|
||||
pingErr,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return connDiag, nil
|
||||
}
|
||||
|
||||
connDiag, err := s.appendDiagnosticTrace(ctx,
|
||||
connectionDiagnosticID,
|
||||
types.ConnectionDiagnosticTrace_UNKNOWN_ERROR,
|
||||
fmt.Sprintf("Unknown error. %v", pingErr),
|
||||
pingErr,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return connDiag, nil
|
||||
}
|
||||
|
||||
func (s DatabaseConnectionTester) appendDiagnosticTrace(ctx context.Context, connectionDiagnosticID string, traceType types.ConnectionDiagnosticTrace_TraceType, message string, err error) (types.ConnectionDiagnostic, error) {
|
||||
connDiag, err := s.cfg.UserClient.AppendDiagnosticTrace(
|
||||
ctx,
|
||||
connectionDiagnosticID,
|
||||
types.NewTraceDiagnosticConnection(
|
||||
traceType,
|
||||
message,
|
||||
err,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return connDiag, nil
|
||||
}
|
||||
|
||||
func getDatabaseConnTester(protocol string) (databasePinger, error) {
|
||||
switch protocol {
|
||||
case defaults.ProtocolPostgres:
|
||||
return &database.PostgresPinger{}, nil
|
||||
}
|
||||
return nil, trace.NotImplemented("database protocol %q is not supported yet for testing connection", protocol)
|
||||
}
|
54
lib/client/conntest/database/database.go
Normal file
54
lib/client/conntest/database/database.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// PingParams contains the required fields necessary to test a Database Connection.
|
||||
type PingParams struct {
|
||||
// Host is the hostname of the Database (does not include port).
|
||||
Host string
|
||||
// Port is the port where the Database is accepting connections.
|
||||
Port int
|
||||
// Username is the user to be used to login into the database.
|
||||
Username string
|
||||
// Database is the database name to be used to login into the database.
|
||||
Database string
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults validates and set the default values for the Ping.
|
||||
func (req *PingParams) CheckAndSetDefaults() error {
|
||||
if req.Database == "" {
|
||||
return trace.BadParameter("missing required parameter Database")
|
||||
}
|
||||
|
||||
if req.Username == "" {
|
||||
return trace.BadParameter("missing required parameter Username")
|
||||
}
|
||||
|
||||
if req.Port == 0 {
|
||||
return trace.BadParameter("missing required parameter Port")
|
||||
}
|
||||
|
||||
if req.Host == "" {
|
||||
req.Host = "localhost"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
114
lib/client/conntest/database/postgres.go
Normal file
114
lib/client/conntest/database/postgres.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgerrcode"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// A simple query to execute when running the Ping request
|
||||
selectOneQuery = "select 1;"
|
||||
)
|
||||
|
||||
// PostgresPinger implements the DatabasePinger interface for the Postgres protocol
|
||||
type PostgresPinger struct{}
|
||||
|
||||
// Ping connects to the database and issues a basic select statement to validate the connection.
|
||||
func (p *PostgresPinger) Ping(ctx context.Context, ping PingParams) error {
|
||||
if err := ping.CheckAndSetDefaults(); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
pgconnConfig, err := pgconn.ParseConfig(
|
||||
fmt.Sprintf("postgres://%s@%s:%d/%s",
|
||||
ping.Username,
|
||||
ping.Host,
|
||||
ping.Port,
|
||||
ping.Database,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
conn, err := pgconn.ConnectConfig(ctx, pgconnConfig)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := conn.Close(ctx); err != nil {
|
||||
logrus.WithError(err).Info("failed to close connection in PostgresPinger.Ping")
|
||||
}
|
||||
}()
|
||||
|
||||
result, err := conn.Exec(ctx, selectOneQuery).ReadAll()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
return trace.BadParameter("unexpected length for result: %+v", result)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnectionRefusedError checks whether the error is of type invalid database user.
|
||||
// This can happen when the user doesn't exist.
|
||||
func (p *PostgresPinger) IsConnectionRefusedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.Contains(err.Error(), "connection refused (SQLSTATE )")
|
||||
}
|
||||
|
||||
// IsInvalidDatabaseUserError checks whether the error is of type invalid database user.
|
||||
// This can happen when the user doesn't exist.
|
||||
func (p *PostgresPinger) IsInvalidDatabaseUserError(err error) bool {
|
||||
var pge *pgconn.PgError
|
||||
if errors.As(err, &pge) {
|
||||
if pge.SQLState() == pgerrcode.InvalidAuthorizationSpecification {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsInvalidDatabaseNameError checks whether the error is of type invalid database name.
|
||||
// This can happen when the database doesn't exist.
|
||||
func (p *PostgresPinger) IsInvalidDatabaseNameError(err error) bool {
|
||||
var pge *pgconn.PgError
|
||||
if errors.As(err, &pge) {
|
||||
if pge.SQLState() == pgerrcode.InvalidCatalogName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
175
lib/client/conntest/database/postgres_test.go
Normal file
175
lib/client/conntest/database/postgres_test.go
Normal file
|
@ -0,0 +1,175 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgerrcode"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gravitational/teleport/api/client/proto"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
"github.com/gravitational/teleport/lib/fixtures"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
"github.com/gravitational/teleport/lib/srv/db/common"
|
||||
"github.com/gravitational/teleport/lib/srv/db/postgres"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
)
|
||||
|
||||
func TestPostgresErrors(t *testing.T) {
|
||||
p := PostgresPinger{}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
pingErr error
|
||||
errCheck require.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "connection refused error",
|
||||
pingErr: errors.New("failed to connect to `host=127.0.0.1 user=postgres database=postgres`: server error (: connection refused (SQLSTATE ))"),
|
||||
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
|
||||
require.True(tt, p.IsConnectionRefusedError(err))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid database error",
|
||||
pingErr: &pgconn.PgError{
|
||||
Code: pgerrcode.InvalidCatalogName,
|
||||
},
|
||||
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
|
||||
require.True(tt, p.IsInvalidDatabaseNameError(err))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid user error",
|
||||
pingErr: &pgconn.PgError{
|
||||
Code: pgerrcode.InvalidAuthorizationSpecification,
|
||||
},
|
||||
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
|
||||
require.True(tt, p.IsInvalidDatabaseUserError(err))
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.errCheck(t, tt.pingErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockClient is a mock that implements AuthClient interface.
|
||||
|
||||
type mockClient struct {
|
||||
common.AuthClientCA
|
||||
|
||||
ca types.CertAuthority
|
||||
}
|
||||
|
||||
func setupMockClient(t *testing.T) *mockClient {
|
||||
t.Helper()
|
||||
|
||||
_, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: "example.com"}, nil, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
ca, err := types.NewCertAuthority(types.CertAuthoritySpecV2{
|
||||
Type: types.HostCA,
|
||||
ClusterName: "example.com",
|
||||
ActiveKeys: types.CAKeySet{
|
||||
SSH: []*types.SSHKeyPair{{PublicKey: []byte("SSH CA cert")}},
|
||||
TLS: []*types.TLSKeyPair{{Cert: cert}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return &mockClient{
|
||||
ca: ca,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockClient) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) {
|
||||
csr, err := tlsca.ParseCertificateRequestPEM(req.CSR)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
tlsCACert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM))
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
tlsCA, err := tlsca.FromTLSCertificate(tlsCACert)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
certReq := tlsca.CertificateRequest{
|
||||
PublicKey: csr.PublicKey,
|
||||
Subject: csr.Subject,
|
||||
NotAfter: time.Now().Add(req.TTL.Get()),
|
||||
DNSNames: req.ServerNames,
|
||||
}
|
||||
cert, err := tlsCA.GenerateCertificate(certReq)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return &proto.DatabaseCertResponse{
|
||||
Cert: cert,
|
||||
CACerts: [][]byte{
|
||||
[]byte(fixtures.TLSCACertPEM),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *mockClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) {
|
||||
return c.ca, nil
|
||||
}
|
||||
|
||||
func TestPostgresPing(t *testing.T) {
|
||||
mockClt := setupMockClient(t)
|
||||
|
||||
postgresTestServer, err := postgres.NewTestServer(common.TestServerConfig{
|
||||
AuthClient: mockClt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
t.Logf("Postgres Fake server running at %s port", postgresTestServer.Port())
|
||||
require.NoError(t, postgresTestServer.Serve())
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
postgresTestServer.Close()
|
||||
})
|
||||
|
||||
port, err := strconv.Atoi(postgresTestServer.Port())
|
||||
require.NoError(t, err)
|
||||
|
||||
p := PostgresPinger{}
|
||||
err = p.Ping(context.Background(), PingParams{
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "someuser",
|
||||
Database: "somedb",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
}
|
|
@ -38,7 +38,7 @@ import (
|
|||
// TestServerConfig combines parameters for a test Postgres/MySQL server.
|
||||
type TestServerConfig struct {
|
||||
// AuthClient will be used to retrieve trusted CA.
|
||||
AuthClient auth.ClientI
|
||||
AuthClient AuthClientCA
|
||||
// Name is the server name for identification purposes.
|
||||
Name string
|
||||
// AuthUser is used in tests simulating IAM token authentication.
|
||||
|
@ -92,6 +92,17 @@ func (cfg *TestServerConfig) Port() (string, error) {
|
|||
return port, nil
|
||||
}
|
||||
|
||||
// AuthClientCA contains the required methods to Generate mTLS certificate to be used
|
||||
// by the postgres TestServer.
|
||||
type AuthClientCA interface {
|
||||
// GenerateDatabaseCert generates client certificate used by a database
|
||||
// service to authenticate with the database instance.
|
||||
GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error)
|
||||
|
||||
// GetCertAuthority returns cert authority by id
|
||||
GetCertAuthority(context.Context, types.CertAuthID, bool, ...services.MarshalOption) (types.CertAuthority, error)
|
||||
}
|
||||
|
||||
// MakeTestServerTLSConfig returns TLS config suitable for configuring test
|
||||
// database Postgres/MySQL servers.
|
||||
func MakeTestServerTLSConfig(config TestServerConfig) (*tls.Config, error) {
|
||||
|
|
|
@ -813,6 +813,23 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro
|
|||
|
||||
err = engine.HandleConnection(ctx, sessionCtx)
|
||||
if err != nil {
|
||||
connectionDiagnosticID := sessionCtx.Identity.ConnectionDiagnosticID
|
||||
if connectionDiagnosticID != "" && trace.IsAccessDenied(err) {
|
||||
_, diagErr := s.cfg.AuthClient.AppendDiagnosticTrace(ctx,
|
||||
connectionDiagnosticID,
|
||||
&types.ConnectionDiagnosticTrace{
|
||||
Type: types.ConnectionDiagnosticTrace_RBAC_DATABASE_LOGIN,
|
||||
Status: types.ConnectionDiagnosticTrace_FAILED,
|
||||
Details: "Access denied when accessing Database. Please check the Error message for more information.",
|
||||
Error: err.Error(),
|
||||
},
|
||||
)
|
||||
|
||||
if diagErr != nil {
|
||||
return trace.Wrap(diagErr)
|
||||
}
|
||||
}
|
||||
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -180,6 +180,9 @@ type Identity struct {
|
|||
AllowedResourceIDs []types.ResourceID
|
||||
// PrivateKeyPolicy is the private key policy supported by this identity.
|
||||
PrivateKeyPolicy keys.PrivateKeyPolicy
|
||||
|
||||
// ConnectionDiagnosticID is used to add connection diagnostic messages when Testing a Connection.
|
||||
ConnectionDiagnosticID string
|
||||
}
|
||||
|
||||
// RouteToApp holds routing information for applications.
|
||||
|
@ -433,6 +436,11 @@ var (
|
|||
// deadline of the session on a certificates issued after an MFA check.
|
||||
// See https://github.com/gravitational/teleport/issues/18544.
|
||||
PreviousIdentityExpiresASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 2, 12}
|
||||
|
||||
// ConnectionDiagnosticIDASN1ExtensionOID is an extension OID used to indicate the Connection Diagnostic ID.
|
||||
// When using the Test Connection feature, there's propagation of the ConnectionDiagnosticID.
|
||||
// Each service (ex DB Agent) uses that to add checkpoints describing if it was a success or a failure.
|
||||
ConnectionDiagnosticIDASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 2, 13}
|
||||
)
|
||||
|
||||
// Subject converts identity to X.509 subject name
|
||||
|
@ -685,6 +693,15 @@ func (id *Identity) Subject() (pkix.Name, error) {
|
|||
)
|
||||
}
|
||||
|
||||
if id.ConnectionDiagnosticID != "" {
|
||||
subject.ExtraNames = append(subject.ExtraNames,
|
||||
pkix.AttributeTypeAndValue{
|
||||
Type: ConnectionDiagnosticIDASN1ExtensionOID,
|
||||
Value: id.ConnectionDiagnosticID,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return subject, nil
|
||||
}
|
||||
|
||||
|
@ -868,6 +885,11 @@ func FromSubject(subject pkix.Name, expires time.Time) (*Identity, error) {
|
|||
if ok {
|
||||
id.PrivateKeyPolicy = keys.PrivateKeyPolicy(val)
|
||||
}
|
||||
case attr.Type.Equal(ConnectionDiagnosticIDASN1ExtensionOID):
|
||||
val, ok := attr.Value.(string)
|
||||
if ok {
|
||||
id.ConnectionDiagnosticID = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -74,6 +74,7 @@ func (h *Handler) diagnoseConnection(w http.ResponseWriter, r *http.Request, p h
|
|||
ResourceKind: req.ResourceKind,
|
||||
UserClient: userClt,
|
||||
ProxyHostPort: h.ProxyHostPort(),
|
||||
PublicProxyAddr: h.cfg.PublicProxyAddr,
|
||||
KubernetesPublicProxyAddr: h.kubeProxyHostPort(),
|
||||
TLSRoutingEnabled: proxySettings.TLSRoutingEnabled,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue