mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 08:43:58 +00:00
Fix mongo access with mfa and add tests (#8799)
This commit is contained in:
parent
6cf111b241
commit
d87ee8f640
|
@ -755,10 +755,13 @@ func (p *databasePack) waitForLeaf(t *testing.T) {
|
|||
accessPoint, err := site.CachingAccessPoint()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-time.Tick(500 * time.Millisecond):
|
||||
servers, err := accessPoint.GetDatabaseServers(context.Background(), apidefaults.Namespace)
|
||||
servers, err := accessPoint.GetDatabaseServers(ctx, apidefaults.Namespace)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Debugf("Leaf cluster access point is unavailable.")
|
||||
continue
|
||||
|
@ -772,7 +775,7 @@ func (p *databasePack) waitForLeaf(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
return
|
||||
case <-time.After(10 * time.Second):
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Leaf cluster access point is unavailable.")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -279,13 +279,8 @@ func (proxy *ProxyClient) reissueUserCerts(ctx context.Context, cachePolicy Cert
|
|||
// Database certs have to be requested with CertUsage All because
|
||||
// pre-7.0 servers do not accept usage-restricted certificates.
|
||||
if params.RouteToDatabase.ServiceName != "" {
|
||||
switch params.RouteToDatabase.Protocol {
|
||||
case defaults.ProtocolMongoDB:
|
||||
// MongoDB expects certificate and key pair in the same pem file.
|
||||
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = append(certs.TLS, key.Priv...)
|
||||
default:
|
||||
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = certs.TLS
|
||||
}
|
||||
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = makeDatabaseClientPEM(
|
||||
params.RouteToDatabase.Protocol, certs.TLS, key.Priv)
|
||||
}
|
||||
|
||||
case proto.UserCertsRequest_SSH:
|
||||
|
@ -293,19 +288,25 @@ func (proxy *ProxyClient) reissueUserCerts(ctx context.Context, cachePolicy Cert
|
|||
case proto.UserCertsRequest_App:
|
||||
key.AppTLSCerts[params.RouteToApp.Name] = certs.TLS
|
||||
case proto.UserCertsRequest_Database:
|
||||
switch params.RouteToDatabase.Protocol {
|
||||
case defaults.ProtocolMongoDB:
|
||||
// MongoDB expects certificate and key pair in the same pem file.
|
||||
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = append(certs.TLS, key.Priv...)
|
||||
default:
|
||||
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = certs.TLS
|
||||
}
|
||||
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = makeDatabaseClientPEM(
|
||||
params.RouteToDatabase.Protocol, certs.TLS, key.Priv)
|
||||
case proto.UserCertsRequest_Kubernetes:
|
||||
key.KubeTLSCerts[params.KubernetesCluster] = certs.TLS
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// makeDatabaseClientPEM returns appropriate client PEM file contents for the
|
||||
// specified database type. Some databases only need certificate in the PEM
|
||||
// file, others both certificate and key.
|
||||
func makeDatabaseClientPEM(proto string, cert, key []byte) []byte {
|
||||
// MongoDB expects certificate and key pair in the same pem file.
|
||||
if proto == defaults.ProtocolMongoDB {
|
||||
return append(cert, key...)
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
// PromptMFAChallengeHandler is a handler for MFA challenges.
|
||||
//
|
||||
// The challenge c from proxyAddr should be presented to the user, asking to
|
||||
|
@ -434,7 +435,8 @@ func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params Reis
|
|||
case proto.UserCertsRequest_Kubernetes:
|
||||
key.KubeTLSCerts[initReq.KubernetesCluster] = crt.TLS
|
||||
case proto.UserCertsRequest_Database:
|
||||
key.DBTLSCerts[initReq.RouteToDatabase.ServiceName] = crt.TLS
|
||||
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = makeDatabaseClientPEM(
|
||||
params.RouteToDatabase.Protocol, crt.TLS, key.Priv)
|
||||
default:
|
||||
return nil, trace.BadParameter("server returned a TLS certificate but cert request usage was %s", initReq.Usage)
|
||||
}
|
||||
|
|
190
tool/tsh/db_test.go
Normal file
190
tool/tsh/db_test.go
Normal file
|
@ -0,0 +1,190 @@
|
|||
/*
|
||||
Copyright 2015-2017 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
apidefaults "github.com/gravitational/teleport/api/defaults"
|
||||
"github.com/gravitational/teleport/api/profile"
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
"github.com/gravitational/teleport/lib"
|
||||
"github.com/gravitational/teleport/lib/client"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/service"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDatabaseLogin verifies "tsh db login" command.
|
||||
func TestDatabaseLogin(t *testing.T) {
|
||||
os.RemoveAll(profile.FullProfilePath(""))
|
||||
t.Cleanup(func() {
|
||||
os.RemoveAll(profile.FullProfilePath(""))
|
||||
})
|
||||
|
||||
connector := mockConnector(t)
|
||||
|
||||
alice, err := types.NewUser("alice@example.com")
|
||||
require.NoError(t, err)
|
||||
alice.SetRoles([]string{"access"})
|
||||
|
||||
authProcess, proxyProcess := makeTestServers(t, connector, alice)
|
||||
makeTestDatabaseServer(t, authProcess, proxyProcess, service.Database{
|
||||
Name: "postgres",
|
||||
Protocol: defaults.ProtocolPostgres,
|
||||
URI: "localhost:5432",
|
||||
}, service.Database{
|
||||
Name: "mongo",
|
||||
Protocol: defaults.ProtocolMongoDB,
|
||||
URI: "localhost:27017",
|
||||
})
|
||||
|
||||
authServer := authProcess.GetAuthServer()
|
||||
require.NotNil(t, authServer)
|
||||
|
||||
proxyAddr, err := proxyProcess.ProxyWebAddr()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Log into Teleport cluster.
|
||||
err = Run([]string{
|
||||
"login", "--insecure", "--debug", "--auth", connector.GetName(), "--proxy", proxyAddr.String(),
|
||||
}, cliOption(func(cf *CLIConf) error {
|
||||
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
|
||||
return nil
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch the active profile.
|
||||
profile, err := client.StatusFor("", proxyAddr.Host(), alice.GetName())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Log into test Postgres database.
|
||||
err = Run([]string{
|
||||
"db", "login", "--debug", "postgres",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify Postgres identity file contains certificate.
|
||||
certs, keys, err := decodePEM(profile.DatabaseCertPath("postgres"))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, certs, 1)
|
||||
require.Len(t, keys, 0)
|
||||
|
||||
// Log into test Mongo database.
|
||||
err = Run([]string{
|
||||
"db", "login", "--debug", "--db-user", "admin", "mongo",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify Mongo identity file contains both certificate and key.
|
||||
certs, keys, err = decodePEM(profile.DatabaseCertPath("mongo"))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, certs, 1)
|
||||
require.Len(t, keys, 1)
|
||||
}
|
||||
|
||||
func makeTestDatabaseServer(t *testing.T, auth *service.TeleportProcess, proxy *service.TeleportProcess, dbs ...service.Database) (db *service.TeleportProcess) {
|
||||
// Proxy uses self-signed certificates in tests.
|
||||
lib.SetInsecureDevMode(true)
|
||||
|
||||
cfg := service.MakeDefaultConfig()
|
||||
cfg.Hostname = "localhost"
|
||||
cfg.DataDir = t.TempDir()
|
||||
|
||||
proxyAddr, err := proxy.ProxyWebAddr()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg.AuthServers = []utils.NetAddr{*proxyAddr}
|
||||
cfg.Token = proxy.Config.Token
|
||||
cfg.SSH.Enabled = false
|
||||
cfg.Auth.Enabled = false
|
||||
cfg.Databases.Enabled = true
|
||||
cfg.Databases.Databases = dbs
|
||||
cfg.Log = utils.NewLoggerForTests()
|
||||
|
||||
db, err = service.NewTeleport(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Start())
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
})
|
||||
|
||||
// Wait for database agent to start.
|
||||
eventCh := make(chan service.Event, 1)
|
||||
db.WaitForEvent(db.ExitContext(), service.DatabasesReady, eventCh)
|
||||
select {
|
||||
case <-eventCh:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("database server didn't start after 10s")
|
||||
}
|
||||
|
||||
// Wait for all databases to register to avoid races.
|
||||
for _, database := range dbs {
|
||||
waitForDatabase(t, auth, database)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func waitForDatabase(t *testing.T, auth *service.TeleportProcess, db service.Database) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
all, err := auth.GetAuthServer().GetDatabaseServers(ctx, apidefaults.Namespace)
|
||||
require.NoError(t, err)
|
||||
for _, a := range all {
|
||||
if a.GetName() == db.Name {
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("database not registered after 10s")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decodePEM sorts out specified PEM file into certificates and private keys.
|
||||
func decodePEM(pemPath string) (certs []pem.Block, keys []pem.Block, err error) {
|
||||
bytes, err := os.ReadFile(pemPath)
|
||||
if err != nil {
|
||||
return nil, nil, trace.Wrap(err)
|
||||
}
|
||||
var block *pem.Block
|
||||
for {
|
||||
block, bytes = pem.Decode(bytes)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
switch block.Type {
|
||||
case "CERTIFICATE":
|
||||
certs = append(certs, *block)
|
||||
case "RSA PRIVATE KEY":
|
||||
keys = append(keys, *block)
|
||||
}
|
||||
}
|
||||
return certs, keys, nil
|
||||
}
|
|
@ -50,7 +50,15 @@ import (
|
|||
|
||||
const staticToken = "test-static-token"
|
||||
|
||||
var randomLocalAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}
|
||||
var ports utils.PortList
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
ports, err = utils.GetFreeTCPPorts(5000, utils.PortStartingNumber)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to allocate tcp ports for tests: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
utils.InitLoggerForTests()
|
||||
|
@ -913,12 +921,12 @@ func makeTestServers(t *testing.T, bootstrap ...types.Resource) (auth *service.T
|
|||
cfg.Hostname = "localhost"
|
||||
cfg.DataDir = t.TempDir()
|
||||
|
||||
cfg.AuthServers = []utils.NetAddr{randomLocalAddr}
|
||||
cfg.AuthServers = []utils.NetAddr{{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}}
|
||||
cfg.Auth.Resources = bootstrap
|
||||
cfg.Auth.StorageConfig.Params = backend.Params{defaults.BackendPath: filepath.Join(cfg.DataDir, defaults.BackendDir)}
|
||||
cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{
|
||||
StaticTokens: []types.ProvisionTokenV1{{
|
||||
Roles: []types.SystemRole{types.RoleProxy},
|
||||
Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase},
|
||||
Expires: time.Now().Add(time.Minute),
|
||||
Token: staticToken,
|
||||
}},
|
||||
|
@ -926,7 +934,7 @@ func makeTestServers(t *testing.T, bootstrap ...types.Resource) (auth *service.T
|
|||
require.NoError(t, err)
|
||||
cfg.SSH.Enabled = false
|
||||
cfg.Auth.Enabled = true
|
||||
cfg.Auth.SSHAddr = randomLocalAddr
|
||||
cfg.Auth.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
|
||||
cfg.Proxy.Enabled = false
|
||||
cfg.Log = utils.NewLoggerForTests()
|
||||
|
||||
|
@ -962,9 +970,9 @@ func makeTestServers(t *testing.T, bootstrap ...types.Resource) (auth *service.T
|
|||
cfg.SSH.Enabled = false
|
||||
cfg.Auth.Enabled = false
|
||||
cfg.Proxy.Enabled = true
|
||||
cfg.Proxy.WebAddr = randomLocalAddr
|
||||
cfg.Proxy.SSHAddr = randomLocalAddr
|
||||
cfg.Proxy.ReverseTunnelListenAddr = randomLocalAddr
|
||||
cfg.Proxy.WebAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
|
||||
cfg.Proxy.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
|
||||
cfg.Proxy.ReverseTunnelListenAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
|
||||
cfg.Proxy.DisableWebInterface = true
|
||||
cfg.Log = utils.NewLoggerForTests()
|
||||
|
||||
|
|
Loading…
Reference in a new issue