Fix mongo access with mfa and add tests (#8799)

This commit is contained in:
Roman Tkachenko 2021-11-02 12:06:58 -07:00 committed by GitHub
parent 6cf111b241
commit d87ee8f640
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 227 additions and 24 deletions

View file

@ -755,10 +755,13 @@ func (p *databasePack) waitForLeaf(t *testing.T) {
accessPoint, err := site.CachingAccessPoint()
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for {
select {
case <-time.Tick(500 * time.Millisecond):
servers, err := accessPoint.GetDatabaseServers(context.Background(), apidefaults.Namespace)
servers, err := accessPoint.GetDatabaseServers(ctx, apidefaults.Namespace)
if err != nil {
logrus.WithError(err).Debugf("Leaf cluster access point is unavailable.")
continue
@ -772,7 +775,7 @@ func (p *databasePack) waitForLeaf(t *testing.T) {
continue
}
return
case <-time.After(10 * time.Second):
case <-ctx.Done():
t.Fatal("Leaf cluster access point is unavailable.")
}
}

View file

@ -279,13 +279,8 @@ func (proxy *ProxyClient) reissueUserCerts(ctx context.Context, cachePolicy Cert
// Database certs have to be requested with CertUsage All because
// pre-7.0 servers do not accept usage-restricted certificates.
if params.RouteToDatabase.ServiceName != "" {
switch params.RouteToDatabase.Protocol {
case defaults.ProtocolMongoDB:
// MongoDB expects certificate and key pair in the same pem file.
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = append(certs.TLS, key.Priv...)
default:
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = certs.TLS
}
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = makeDatabaseClientPEM(
params.RouteToDatabase.Protocol, certs.TLS, key.Priv)
}
case proto.UserCertsRequest_SSH:
@ -293,19 +288,25 @@ func (proxy *ProxyClient) reissueUserCerts(ctx context.Context, cachePolicy Cert
case proto.UserCertsRequest_App:
key.AppTLSCerts[params.RouteToApp.Name] = certs.TLS
case proto.UserCertsRequest_Database:
switch params.RouteToDatabase.Protocol {
case defaults.ProtocolMongoDB:
// MongoDB expects certificate and key pair in the same pem file.
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = append(certs.TLS, key.Priv...)
default:
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = certs.TLS
}
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = makeDatabaseClientPEM(
params.RouteToDatabase.Protocol, certs.TLS, key.Priv)
case proto.UserCertsRequest_Kubernetes:
key.KubeTLSCerts[params.KubernetesCluster] = certs.TLS
}
return key, nil
}
// makeDatabaseClientPEM returns appropriate client PEM file contents for the
// specified database type. Some databases only need certificate in the PEM
// file, others both certificate and key.
func makeDatabaseClientPEM(proto string, cert, key []byte) []byte {
// MongoDB expects certificate and key pair in the same pem file.
if proto == defaults.ProtocolMongoDB {
return append(cert, key...)
}
return cert
}
// PromptMFAChallengeHandler is a handler for MFA challenges.
//
// The challenge c from proxyAddr should be presented to the user, asking to
@ -434,7 +435,8 @@ func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params Reis
case proto.UserCertsRequest_Kubernetes:
key.KubeTLSCerts[initReq.KubernetesCluster] = crt.TLS
case proto.UserCertsRequest_Database:
key.DBTLSCerts[initReq.RouteToDatabase.ServiceName] = crt.TLS
key.DBTLSCerts[params.RouteToDatabase.ServiceName] = makeDatabaseClientPEM(
params.RouteToDatabase.Protocol, crt.TLS, key.Priv)
default:
return nil, trace.BadParameter("server returned a TLS certificate but cert request usage was %s", initReq.Usage)
}

190
tool/tsh/db_test.go Normal file
View file

@ -0,0 +1,190 @@
/*
Copyright 2015-2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"context"
"encoding/pem"
"os"
"testing"
"time"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)
// TestDatabaseLogin verifies "tsh db login" command.
func TestDatabaseLogin(t *testing.T) {
os.RemoveAll(profile.FullProfilePath(""))
t.Cleanup(func() {
os.RemoveAll(profile.FullProfilePath(""))
})
connector := mockConnector(t)
alice, err := types.NewUser("alice@example.com")
require.NoError(t, err)
alice.SetRoles([]string{"access"})
authProcess, proxyProcess := makeTestServers(t, connector, alice)
makeTestDatabaseServer(t, authProcess, proxyProcess, service.Database{
Name: "postgres",
Protocol: defaults.ProtocolPostgres,
URI: "localhost:5432",
}, service.Database{
Name: "mongo",
Protocol: defaults.ProtocolMongoDB,
URI: "localhost:27017",
})
authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)
proxyAddr, err := proxyProcess.ProxyWebAddr()
require.NoError(t, err)
// Log into Teleport cluster.
err = Run([]string{
"login", "--insecure", "--debug", "--auth", connector.GetName(), "--proxy", proxyAddr.String(),
}, cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
return nil
}))
require.NoError(t, err)
// Fetch the active profile.
profile, err := client.StatusFor("", proxyAddr.Host(), alice.GetName())
require.NoError(t, err)
// Log into test Postgres database.
err = Run([]string{
"db", "login", "--debug", "postgres",
})
require.NoError(t, err)
// Verify Postgres identity file contains certificate.
certs, keys, err := decodePEM(profile.DatabaseCertPath("postgres"))
require.NoError(t, err)
require.Len(t, certs, 1)
require.Len(t, keys, 0)
// Log into test Mongo database.
err = Run([]string{
"db", "login", "--debug", "--db-user", "admin", "mongo",
})
require.NoError(t, err)
// Verify Mongo identity file contains both certificate and key.
certs, keys, err = decodePEM(profile.DatabaseCertPath("mongo"))
require.NoError(t, err)
require.Len(t, certs, 1)
require.Len(t, keys, 1)
}
func makeTestDatabaseServer(t *testing.T, auth *service.TeleportProcess, proxy *service.TeleportProcess, dbs ...service.Database) (db *service.TeleportProcess) {
// Proxy uses self-signed certificates in tests.
lib.SetInsecureDevMode(true)
cfg := service.MakeDefaultConfig()
cfg.Hostname = "localhost"
cfg.DataDir = t.TempDir()
proxyAddr, err := proxy.ProxyWebAddr()
require.NoError(t, err)
cfg.AuthServers = []utils.NetAddr{*proxyAddr}
cfg.Token = proxy.Config.Token
cfg.SSH.Enabled = false
cfg.Auth.Enabled = false
cfg.Databases.Enabled = true
cfg.Databases.Databases = dbs
cfg.Log = utils.NewLoggerForTests()
db, err = service.NewTeleport(cfg)
require.NoError(t, err)
require.NoError(t, db.Start())
t.Cleanup(func() {
db.Close()
})
// Wait for database agent to start.
eventCh := make(chan service.Event, 1)
db.WaitForEvent(db.ExitContext(), service.DatabasesReady, eventCh)
select {
case <-eventCh:
case <-time.After(10 * time.Second):
t.Fatal("database server didn't start after 10s")
}
// Wait for all databases to register to avoid races.
for _, database := range dbs {
waitForDatabase(t, auth, database)
}
return db
}
func waitForDatabase(t *testing.T, auth *service.TeleportProcess, db service.Database) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for {
select {
case <-time.After(500 * time.Millisecond):
all, err := auth.GetAuthServer().GetDatabaseServers(ctx, apidefaults.Namespace)
require.NoError(t, err)
for _, a := range all {
if a.GetName() == db.Name {
return
}
}
case <-ctx.Done():
t.Fatal("database not registered after 10s")
}
}
}
// decodePEM sorts out specified PEM file into certificates and private keys.
func decodePEM(pemPath string) (certs []pem.Block, keys []pem.Block, err error) {
bytes, err := os.ReadFile(pemPath)
if err != nil {
return nil, nil, trace.Wrap(err)
}
var block *pem.Block
for {
block, bytes = pem.Decode(bytes)
if block == nil {
break
}
switch block.Type {
case "CERTIFICATE":
certs = append(certs, *block)
case "RSA PRIVATE KEY":
keys = append(keys, *block)
}
}
return certs, keys, nil
}

View file

@ -50,7 +50,15 @@ import (
const staticToken = "test-static-token"
var randomLocalAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}
var ports utils.PortList
func init() {
var err error
ports, err = utils.GetFreeTCPPorts(5000, utils.PortStartingNumber)
if err != nil {
panic(fmt.Sprintf("failed to allocate tcp ports for tests: %v", err))
}
}
func TestMain(m *testing.M) {
utils.InitLoggerForTests()
@ -913,12 +921,12 @@ func makeTestServers(t *testing.T, bootstrap ...types.Resource) (auth *service.T
cfg.Hostname = "localhost"
cfg.DataDir = t.TempDir()
cfg.AuthServers = []utils.NetAddr{randomLocalAddr}
cfg.AuthServers = []utils.NetAddr{{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}}
cfg.Auth.Resources = bootstrap
cfg.Auth.StorageConfig.Params = backend.Params{defaults.BackendPath: filepath.Join(cfg.DataDir, defaults.BackendDir)}
cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{
StaticTokens: []types.ProvisionTokenV1{{
Roles: []types.SystemRole{types.RoleProxy},
Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase},
Expires: time.Now().Add(time.Minute),
Token: staticToken,
}},
@ -926,7 +934,7 @@ func makeTestServers(t *testing.T, bootstrap ...types.Resource) (auth *service.T
require.NoError(t, err)
cfg.SSH.Enabled = false
cfg.Auth.Enabled = true
cfg.Auth.SSHAddr = randomLocalAddr
cfg.Auth.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
cfg.Proxy.Enabled = false
cfg.Log = utils.NewLoggerForTests()
@ -962,9 +970,9 @@ func makeTestServers(t *testing.T, bootstrap ...types.Resource) (auth *service.T
cfg.SSH.Enabled = false
cfg.Auth.Enabled = false
cfg.Proxy.Enabled = true
cfg.Proxy.WebAddr = randomLocalAddr
cfg.Proxy.SSHAddr = randomLocalAddr
cfg.Proxy.ReverseTunnelListenAddr = randomLocalAddr
cfg.Proxy.WebAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
cfg.Proxy.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
cfg.Proxy.ReverseTunnelListenAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
cfg.Proxy.DisableWebInterface = true
cfg.Log = utils.NewLoggerForTests()