mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 17:53:28 +00:00
Generate database access credentials with tctl auth sign command (#10785)
* feat(tctl): sign command to generate database access credentials * feat(tctl): make auth sign parameters app-name and db-name mutually exclusive * feat(tctl): add flag db-user to auth sign command * test(tctl): remove references to deprecated package ioutil * test(tctl): update test to check error type * chore(tctl): add godoc to `getDatabaseServer` function * refactor(tctl): rename database-related flags in auth sign * refactor(tctl): rename flag from `db` to `db-service`
This commit is contained in:
parent
7c7bb75f22
commit
6f971d1fb7
|
@ -67,6 +67,9 @@ type AuthCommand struct {
|
|||
leafCluster string
|
||||
kubeCluster string
|
||||
appName string
|
||||
dbService string
|
||||
dbName string
|
||||
dbUser string
|
||||
signOverwrite bool
|
||||
|
||||
rotateGracePeriod time.Duration
|
||||
|
@ -118,7 +121,10 @@ func (a *AuthCommand) Initialize(app *kingpin.Application, config *service.Confi
|
|||
a.authSign.Flag("kube-cluster", `Leaf cluster to generate identity file for when --format is set to "kubernetes"`).Hidden().StringVar(&a.leafCluster)
|
||||
a.authSign.Flag("leaf-cluster", `Leaf cluster to generate identity file for when --format is set to "kubernetes"`).StringVar(&a.leafCluster)
|
||||
a.authSign.Flag("kube-cluster-name", `Kubernetes cluster to generate identity file for when --format is set to "kubernetes"`).StringVar(&a.kubeCluster)
|
||||
a.authSign.Flag("app-name", `Application to generate identity file for`).StringVar(&a.appName)
|
||||
a.authSign.Flag("app-name", `Application to generate identity file for. Mutually exclusive with "--db-service".`).StringVar(&a.appName)
|
||||
a.authSign.Flag("db-service", `Database to generate identity file for. Mutually exclusive with "--app-name".`).StringVar(&a.dbService)
|
||||
a.authSign.Flag("db-user", `Database user placed on the identity file. Only used when "--db-service" is set.`).StringVar(&a.dbUser)
|
||||
a.authSign.Flag("db-name", `Database name placed on the identity file. Only used when "--db-service" is set.`).StringVar(&a.dbName)
|
||||
|
||||
a.authRotate = auth.Command("rotate", "Rotate certificate authorities in the cluster")
|
||||
a.authRotate.Flag("grace-period", "Grace period keeps previous certificate authorities signatures valid, if set to 0 will force users to relogin and nodes to re-register.").
|
||||
|
@ -594,10 +600,19 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI auth.Clie
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
var routeToApp proto.RouteToApp
|
||||
var certUsage proto.UserCertsRequest_CertUsage
|
||||
var (
|
||||
routeToApp proto.RouteToApp
|
||||
routeToDatabase proto.RouteToDatabase
|
||||
certUsage proto.UserCertsRequest_CertUsage
|
||||
)
|
||||
|
||||
if a.appName != "" {
|
||||
// `appName` and `db` are mutually exclusive.
|
||||
if a.appName != "" && a.dbService != "" {
|
||||
return trace.BadParameter("only --app-name or --db-service can be set, not both")
|
||||
}
|
||||
|
||||
switch {
|
||||
case a.appName != "":
|
||||
server, err := getApplicationServer(ctx, clusterAPI, a.appName)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
@ -619,6 +634,19 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI auth.Clie
|
|||
SessionID: appSession.GetName(),
|
||||
}
|
||||
certUsage = proto.UserCertsRequest_App
|
||||
case a.dbService != "":
|
||||
server, err := getDatabaseServer(context.TODO(), clusterAPI, a.dbService)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
routeToDatabase = proto.RouteToDatabase{
|
||||
ServiceName: a.dbService,
|
||||
Protocol: server.GetDatabase().GetProtocol(),
|
||||
Database: a.dbName,
|
||||
Username: a.dbUser,
|
||||
}
|
||||
certUsage = proto.UserCertsRequest_Database
|
||||
}
|
||||
|
||||
reqExpiry := time.Now().UTC().Add(a.genTTL)
|
||||
|
@ -632,6 +660,7 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI auth.Clie
|
|||
KubernetesCluster: a.kubeCluster,
|
||||
RouteToApp: routeToApp,
|
||||
Usage: certUsage,
|
||||
RouteToDatabase: routeToDatabase,
|
||||
})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
@ -832,3 +861,20 @@ func getApplicationServer(ctx context.Context, clusterAPI auth.ClientI, appName
|
|||
}
|
||||
return nil, trace.NotFound("app %q not found", appName)
|
||||
}
|
||||
|
||||
// getDatabaseServer fetches a single `DatabaseServer` by name using the
|
||||
// provided `auth.ClientI`.
|
||||
func getDatabaseServer(ctx context.Context, clientAPI auth.ClientI, dbName string) (types.DatabaseServer, error) {
|
||||
servers, err := clientAPI.GetDatabaseServers(ctx, apidefaults.Namespace)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
for _, server := range servers {
|
||||
if server.GetName() == dbName {
|
||||
return server, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, trace.NotFound("database %q not found", dbName)
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/gravitational/teleport/lib/auth"
|
||||
"github.com/gravitational/teleport/lib/client"
|
||||
"github.com/gravitational/teleport/lib/client/identityfile"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/kube/kubeconfig"
|
||||
"github.com/gravitational/teleport/lib/service"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
|
@ -273,6 +274,7 @@ type mockClient struct {
|
|||
remoteClusters []types.RemoteCluster
|
||||
kubeServices []types.Server
|
||||
appServices []types.AppServer
|
||||
dbServices []types.DatabaseServer
|
||||
appSession types.WebSession
|
||||
}
|
||||
|
||||
|
@ -308,6 +310,10 @@ func (c *mockClient) CreateAppSession(ctx context.Context, req types.CreateAppSe
|
|||
return c.appSession, nil
|
||||
}
|
||||
|
||||
func (c *mockClient) GetDatabaseServers(context.Context, string, ...services.MarshalOption) ([]types.DatabaseServer, error) {
|
||||
return c.dbServices, nil
|
||||
}
|
||||
|
||||
func TestCheckKubeCluster(t *testing.T) {
|
||||
const teleportCluster = "local-teleport"
|
||||
clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{
|
||||
|
@ -668,3 +674,138 @@ func TestGenerateAppCertificates(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDatabaseUserCertificates(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tests := map[string]struct {
|
||||
clusterName string
|
||||
dbService string
|
||||
dbName string
|
||||
dbUser string
|
||||
expectedDbProtocol string
|
||||
dbServices []types.DatabaseServer
|
||||
expectedErr error
|
||||
}{
|
||||
"DatabaseExists": {
|
||||
clusterName: "example.com",
|
||||
dbService: "db-1",
|
||||
expectedDbProtocol: defaults.ProtocolPostgres,
|
||||
dbServices: []types.DatabaseServer{
|
||||
&types.DatabaseServerV3{
|
||||
Metadata: types.Metadata{
|
||||
Name: "db-1",
|
||||
},
|
||||
Spec: types.DatabaseServerSpecV3{
|
||||
Hostname: "example.com",
|
||||
Database: &types.DatabaseV3{
|
||||
Spec: types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolPostgres,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"DatabaseWithUserExists": {
|
||||
clusterName: "example.com",
|
||||
dbService: "db-user-1",
|
||||
dbUser: "mongo-user",
|
||||
expectedDbProtocol: defaults.ProtocolMongoDB,
|
||||
dbServices: []types.DatabaseServer{
|
||||
&types.DatabaseServerV3{
|
||||
Metadata: types.Metadata{
|
||||
Name: "db-user-1",
|
||||
},
|
||||
Spec: types.DatabaseServerSpecV3{
|
||||
Hostname: "example.com",
|
||||
Database: &types.DatabaseV3{
|
||||
Spec: types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolMongoDB,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"DatabaseWithDatabaseNameExists": {
|
||||
clusterName: "example.com",
|
||||
dbService: "db-user-1",
|
||||
dbName: "root-database",
|
||||
expectedDbProtocol: defaults.ProtocolMongoDB,
|
||||
dbServices: []types.DatabaseServer{
|
||||
&types.DatabaseServerV3{
|
||||
Metadata: types.Metadata{
|
||||
Name: "db-user-1",
|
||||
},
|
||||
Spec: types.DatabaseServerSpecV3{
|
||||
Hostname: "example.com",
|
||||
Database: &types.DatabaseV3{
|
||||
Spec: types.DatabaseSpecV3{
|
||||
Protocol: defaults.ProtocolMongoDB,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"DatabaseNotFound": {
|
||||
clusterName: "example.com",
|
||||
dbService: "db-2",
|
||||
dbServices: []types.DatabaseServer{},
|
||||
expectedErr: trace.NotFound(""),
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
clusterName, err := services.NewClusterNameWithRandomID(
|
||||
types.ClusterNameSpecV2{
|
||||
ClusterName: test.clusterName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authClient := &mockClient{
|
||||
clusterName: clusterName,
|
||||
userCerts: &proto.Certs{
|
||||
SSH: []byte("SSH cert"),
|
||||
TLS: []byte("TLS cert"),
|
||||
},
|
||||
dbServices: test.dbServices,
|
||||
}
|
||||
|
||||
certsDir := t.TempDir()
|
||||
output := filepath.Join(certsDir, test.dbService)
|
||||
ac := AuthCommand{
|
||||
output: output,
|
||||
outputFormat: identityfile.FormatTLS,
|
||||
signOverwrite: true,
|
||||
genTTL: time.Hour,
|
||||
dbService: test.dbService,
|
||||
dbName: test.dbName,
|
||||
dbUser: test.dbUser,
|
||||
}
|
||||
|
||||
err = ac.generateUserKeys(ctx, authClient)
|
||||
if test.expectedErr != nil {
|
||||
require.Error(t, err)
|
||||
require.IsType(t, test.expectedErr, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedRouteToDatabase := proto.RouteToDatabase{
|
||||
ServiceName: test.dbService,
|
||||
Protocol: test.expectedDbProtocol,
|
||||
Database: test.dbName,
|
||||
Username: test.dbUser,
|
||||
}
|
||||
require.Equal(t, proto.UserCertsRequest_Database, authClient.userCertsReq.Usage)
|
||||
require.Equal(t, expectedRouteToDatabase, authClient.userCertsReq.RouteToDatabase)
|
||||
|
||||
certBytes, err := os.ReadFile(filepath.Join(certsDir, test.dbService+".crt"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, authClient.userCerts.TLS, certBytes, "certificates match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue