teleport/lib/auth/auth.go

3765 lines
121 KiB
Go

/*
Copyright 2015-2019 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package auth implements certificate signing authority and access control server
// Authority server is composed of several parts:
//
// * Authority server itself that implements signing and acl logic
// * HTTP server wrapper for authority server
// * HTTP client wrapper
//
package auth
import (
"bytes"
"context"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"math"
"math/big"
insecurerand "math/rand"
"net"
"net/url"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
saml2 "github.com/russellhaering/gosaml2"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"golang.org/x/crypto/ssh"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/auth/keystore"
"github.com/gravitational/teleport/lib/auth/native"
wanlib "github.com/gravitational/teleport/lib/auth/webauthn"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/inventory"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/observability/tracing"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/srv/db/common/role"
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/interval"
)
const (
	// ErrFieldKeyUserMaxedAttempts is the key of an error field that flags a
	// user as having maxed out their allowed attempts.
	// NOTE(review): the code that sets/reads this field is outside this view.
	ErrFieldKeyUserMaxedAttempts = "maxed-attempts"
	// MaxFailedAttemptsErrMsg is a user-friendly error message that tells a user that they are locked.
	MaxFailedAttemptsErrMsg = "too many incorrect attempts, please try again later"
)
// ServerOption configures a Server during NewServer. Options run in the order
// given, after the Server has been assembled from the config; a non-nil error
// aborts construction.
type ServerOption func(*Server) error

// NewServer creates and configures a new Server instance.
//
// It registers the auth server's prometheus collectors, fills every nil
// service in cfg with a default implementation (backend-backed local
// services, or discard/no-op implementations for the audit log, emitter,
// streamer, enforcer and trace client), builds the key store and the
// signature limiter, and then applies opts in order. If no option sets a
// clock, the real clock is used.
//
// Note that cfg is modified in place: the defaults are written back into it.
//
// If an option returns an error, the server's close context is cancelled and
// its inventory controller is closed before the error is returned, so a
// failed construction does not leave them behind.
func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
	err := metrics.RegisterPrometheusCollectors(prometheusCollectors...)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if cfg.Trust == nil {
		cfg.Trust = local.NewCAService(cfg.Backend)
	}
	if cfg.Presence == nil {
		cfg.Presence = local.NewPresenceService(cfg.Backend)
	}
	if cfg.Provisioner == nil {
		cfg.Provisioner = local.NewProvisioningService(cfg.Backend)
	}
	if cfg.Identity == nil {
		cfg.Identity = local.NewIdentityService(cfg.Backend)
	}
	if cfg.Access == nil {
		cfg.Access = local.NewAccessService(cfg.Backend)
	}
	if cfg.DynamicAccessExt == nil {
		cfg.DynamicAccessExt = local.NewDynamicAccessService(cfg.Backend)
	}
	if cfg.ClusterConfiguration == nil {
		clusterConfig, err := local.NewClusterConfigurationService(cfg.Backend)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		cfg.ClusterConfiguration = clusterConfig
	}
	if cfg.Restrictions == nil {
		cfg.Restrictions = local.NewRestrictionsService(cfg.Backend)
	}
	if cfg.Apps == nil {
		cfg.Apps = local.NewAppService(cfg.Backend)
	}
	if cfg.Databases == nil {
		cfg.Databases = local.NewDatabasesService(cfg.Backend)
	}
	if cfg.Events == nil {
		cfg.Events = local.NewEventsService(cfg.Backend)
	}
	if cfg.AuditLog == nil {
		cfg.AuditLog = events.NewDiscardAuditLog()
	}
	if cfg.Emitter == nil {
		cfg.Emitter = events.NewDiscardEmitter()
	}
	if cfg.Streamer == nil {
		cfg.Streamer = events.NewDiscardEmitter()
	}
	if cfg.WindowsDesktops == nil {
		cfg.WindowsDesktops = local.NewWindowsDesktopService(cfg.Backend)
	}
	if cfg.ConnectionsDiagnostic == nil {
		cfg.ConnectionsDiagnostic = local.NewConnectionsDiagnosticService(cfg.Backend)
	}
	if cfg.SessionTrackerService == nil {
		cfg.SessionTrackerService, err = local.NewSessionTrackerService(cfg.Backend)
		if err != nil {
			return nil, trace.Wrap(err)
		}
	}
	if cfg.Enforcer == nil {
		cfg.Enforcer = local.NewNoopEnforcer()
	}
	if cfg.KeyStoreConfig.RSAKeyPairSource == nil {
		// No key source supplied: fall back to native key generation.
		native.PrecomputeKeys()
		cfg.KeyStoreConfig.RSAKeyPairSource = native.GenerateKeyPair
	}
	if cfg.KeyStoreConfig.HostUUID == "" {
		cfg.KeyStoreConfig.HostUUID = cfg.HostUUID
	}
	if cfg.TraceClient == nil {
		cfg.TraceClient = tracing.NewNoopClient()
	}

	// connLimiter bounds the number of concurrent signature operations.
	connLimiter, err := limiter.NewConnectionsLimiter(limiter.Config{
		MaxConnections: defaults.LimiterMaxConcurrentSignatures,
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	keyStore, err := keystore.NewKeyStore(cfg.KeyStoreConfig)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	svcs := &Services{
		Trust:                 cfg.Trust,
		Presence:              cfg.Presence,
		Provisioner:           cfg.Provisioner,
		Identity:              cfg.Identity,
		Access:                cfg.Access,
		DynamicAccessExt:      cfg.DynamicAccessExt,
		ClusterConfiguration:  cfg.ClusterConfiguration,
		Restrictions:          cfg.Restrictions,
		Apps:                  cfg.Apps,
		Databases:             cfg.Databases,
		IAuditLog:             cfg.AuditLog,
		Events:                cfg.Events,
		WindowsDesktops:       cfg.WindowsDesktops,
		SessionTrackerService: cfg.SessionTrackerService,
		Enforcer:              cfg.Enforcer,
		ConnectionsDiagnostic: cfg.ConnectionsDiagnostic,
	}

	// closeCtx is the root of the server's lifetime; Close cancels it.
	closeCtx, cancelFunc := context.WithCancel(context.Background())
	as := Server{
		bk:              cfg.Backend,
		limiter:         connLimiter,
		Authority:       cfg.Authority,
		AuthServiceName: cfg.AuthServiceName,
		ServerID:        cfg.HostUUID,
		oidcClients:     make(map[string]*oidcClient),
		samlProviders:   make(map[string]*samlProvider),
		githubClients:   make(map[string]*githubClient),
		cancelFunc:      cancelFunc,
		closeCtx:        closeCtx,
		emitter:         cfg.Emitter,
		streamer:        cfg.Streamer,
		unstable:        local.NewUnstableService(cfg.Backend),
		Services:        svcs,
		Cache:           svcs,
		keyStore:        keyStore,
		getClaimsFun:    getClaims,
		inventory:       inventory.NewController(cfg.Presence),
		traceClient:     cfg.TraceClient,
	}
	for _, o := range opts {
		if err := o(&as); err != nil {
			// Release what was started above so a failed construction does
			// not leak it. Cleanup is best-effort: the option error is what
			// the caller needs to see.
			cancelFunc()
			if as.inventory != nil {
				if closeErr := as.inventory.Close(); closeErr != nil {
					log.Debugf("Failed to close inventory controller after option error: %v.", closeErr)
				}
			}
			return nil, trace.Wrap(err)
		}
	}
	if as.clock == nil {
		as.clock = clockwork.NewRealClock()
	}

	return &as, nil
}
// Services bundles the service interfaces used by the auth server. Every
// field is an embedded interface, so *Services exposes their combined method
// sets. NewServer builds one from InitConfig and installs it as both the
// Server's Services and its Cache.
type Services struct {
	services.Trust
	services.Presence
	services.Provisioner
	services.Identity
	services.Access
	services.DynamicAccessExt
	services.ClusterConfiguration
	services.Restrictions
	services.Apps
	services.Databases
	services.WindowsDesktops
	services.SessionTrackerService
	services.Enforcer
	services.ConnectionsDiagnostic
	types.Events
	// IAuditLog can be replaced after construction via Server.SetAuditLog.
	events.IAuditLog
}
// GetWebSession returns the existing web session described by req, looked up
// in the Identity service's web session store.
// Implements ReadAccessPoint
func (r *Services) GetWebSession(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error) {
	sessions := r.Identity.WebSessions()
	return sessions.Get(ctx, req)
}

// GetWebToken returns the existing web token described by req, looked up in
// the Identity service's web token store.
// Implements ReadAccessPoint
func (r *Services) GetWebToken(ctx context.Context, req types.GetWebTokenRequest) (types.WebToken, error) {
	tokens := r.Identity.WebTokens()
	return tokens.Get(ctx, req)
}
var (
	// generateRequestsCount counts requests to generate new server keys.
	generateRequestsCount = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: teleport.MetricGenerateRequests,
			Help: "Number of requests to generate new server keys",
		},
	)
	// generateThrottledRequestsCount counts requests to generate new server
	// keys that were throttled.
	generateThrottledRequestsCount = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: teleport.MetricGenerateRequestsThrottled,
			Help: "Number of throttled requests to generate new server keys",
		},
	)
	// generateRequestsCurrent reports the number of current requests to
	// generate server keys.
	generateRequestsCurrent = prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: teleport.MetricGenerateRequestsCurrent,
			Help: "Number of current generate requests for server keys",
		},
	)
	// generateRequestsLatencies records the latency of requests to generate
	// server keys.
	generateRequestsLatencies = prometheus.NewHistogram(
		prometheus.HistogramOpts{
			Name: teleport.MetricGenerateRequestsHistogram,
			Help: "Latency for generate requests for server keys",
			// lowest bucket start of upper bound 0.001 sec (1 ms) with factor 2
			// highest bucket start of 0.001 sec * 2^15 == 32.768 sec
			Buckets: prometheus.ExponentialBuckets(0.001, 2, 16),
		},
	)
	// UserLoginCount counts user logins.
	UserLoginCount = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: teleport.MetricUserLoginCount,
			Help: "Number of times there was a user login",
		},
	)
	// heartbeatsMissedByAuth reports the number of nodes that have missed
	// their keep-alives, as last computed by the auth server.
	heartbeatsMissedByAuth = prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: teleport.MetricHeartbeatsMissed,
			Help: "Number of heartbeats missed by auth server",
		},
	)
	// registeredAgents reports, per Teleport version label, the number of
	// distinct Teleport servers known to the cluster.
	registeredAgents = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: teleport.MetricNamespace,
			Name:      teleport.MetricRegisteredServers,
			Help: "The number of Teleport servers (a server consists of one or more Teleport services) that have connected to the Teleport cluster, including the Teleport version. " +
				"After disconnecting, a Teleport server has a TTL of 10 minutes, so this value will include servers that have recently disconnected but have not reached their TTL.",
		},
		[]string{teleport.TagVersion},
	)
	// prometheusCollectors lists every collector declared above so they can
	// be registered together.
	prometheusCollectors = []prometheus.Collector{
		generateRequestsCount, generateThrottledRequestsCount,
		generateRequestsCurrent, generateRequestsLatencies, UserLoginCount, heartbeatsMissedByAuth,
		registeredAgents,
	}
)
// Server keeps the cluster together. It acts as a certificate authority (CA) for
// a cluster and:
// - generates the keypair for the node it's running on
// - invites other SSH nodes to a cluster, by issuing invite tokens
// - adds other SSH nodes to a cluster, by checking their token and signing their keys
// - same for users and their sessions
// - checks public keys to see if they're signed by it (can be trusted or not)
type Server struct {
	// lock guards fields that may be replaced after construction. Within this
	// file it protects clock (GetClock/SetClock) and lockWatcher
	// (SetLockWatcher/checkLockInForce).
	lock sync.RWMutex
	// oidcClients is a map from authID & proxyAddr -> oidcClient
	oidcClients map[string]*oidcClient
	// samlProviders holds SAML providers keyed by a string.
	// NOTE(review): key format is not visible here; presumably analogous to
	// oidcClients — confirm against its users.
	samlProviders map[string]*samlProvider
	// githubClients holds GitHub clients keyed by a string.
	// NOTE(review): key format is not visible here; confirm against its users.
	githubClients map[string]*githubClient
	// clock is the server's time source. NewServer defaults it to the real
	// clock; read it through GetClock and replace it with SetClock.
	clock clockwork.Clock
	// bk is the backend the default services are built on; Close closes it
	// when it is non-nil.
	bk backend.Backend
	// closeCtx is cancelled by Close (through cancelFunc) and tells
	// long-running loops such as runPeriodicOperations to stop.
	closeCtx   context.Context
	cancelFunc context.CancelFunc
	// Authority is the embedded SSH certificate authority; generateHostCert
	// delegates host certificate signing to it.
	sshca.Authority
	// AuthServiceName is a human-readable name of this CA. If several Auth services are running
	// (managing multiple teleport clusters) this field is used to tell them apart in UIs
	// It usually defaults to the hostname of the machine the Auth service runs on.
	AuthServiceName string
	// ServerID is the server ID of this auth server.
	ServerID string
	// unstable implements unstable backend methods not suitable
	// for inclusion in Services.
	unstable local.UnstableService
	// Services encapsulate services - provisioner, trust, etc. used by the auth
	// server in a separate structure. Reads through Services hit the backend.
	*Services
	// Cache should either be the same as Services, or a caching layer over it.
	// As it's an interface (and thus directly implementing all of its methods)
	// its embedding takes priority over Services (which only indirectly
	// implements its methods), thus any implemented GetFoo method on both Cache
	// and Services will call the one from Cache. To bypass the cache, call the
	// method on Services instead.
	Cache
	// privateKey is used in tests to use pre-generated private keys
	privateKey []byte
	// cipherSuites is a list of ciphersuites that the auth server supports.
	cipherSuites []uint16
	// limiter limits the number of active connections per client IP.
	// NewServer configures it with defaults.LimiterMaxConcurrentSignatures.
	limiter *limiter.ConnectionsLimiter
	// emitter is the events emitter, used to submit discrete events.
	emitter apievents.Emitter
	// streamer is events sessionstreamer, used to create continuous
	// session related streams
	streamer events.Streamer
	// keyStore is an interface for interacting with private keys in CAs which
	// may be backed by HSMs
	keyStore keystore.KeyStore
	// lockWatcher is a lock watcher, used to verify cert generation requests.
	// It is nil until SetLockWatcher is called; guarded by lock.
	lockWatcher *services.LockWatcher
	// getClaimsFun is used in tests for overriding the implementation of getClaims method used in OIDC.
	getClaimsFun func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error)
	// inventory is the inventory controller created by NewServer and closed by Close.
	inventory *inventory.Controller
	// traceClient is used to forward spans to the upstream collector for components
	// within the cluster that don't have a direct connection to said collector
	traceClient otlptrace.Client
}
// CloseContext returns the server's close context, which is cancelled when
// the server is closed.
func (a *Server) CloseContext() context.Context {
	ctx := a.closeCtx
	return ctx
}

// SetLockWatcher installs the lock watcher consulted by checkLockInForce.
// The assignment is guarded by the server lock.
func (a *Server) SetLockWatcher(lockWatcher *services.LockWatcher) {
	a.lock.Lock()
	a.lockWatcher = lockWatcher
	a.lock.Unlock()
}

// checkLockInForce reports an error if any of targets is locked under mode.
// It fails with BadParameter when no lock watcher has been installed yet.
// The read lock is held for the duration of the check.
func (a *Server) checkLockInForce(mode constants.LockingMode, targets []types.LockTarget) error {
	a.lock.RLock()
	defer a.lock.RUnlock()

	if watcher := a.lockWatcher; watcher != nil {
		return watcher.CheckLockInForce(mode, targets...)
	}
	return trace.BadParameter("lockWatcher is not set")
}
// runPeriodicOperations runs periodic bookkeeping operations performed by the
// auth server until its close context is cancelled:
//   - checking whether cert authorities need to be rotated,
//   - recomputing the missed-heartbeats gauge from the current node list,
//   - refreshing the registered-agents version metrics.
func (a *Server) runPeriodicOperations() {
	// Scope the work to the server lifetime so that Close also cancels
	// in-flight calls (updateVersionMetrics already uses closeCtx).
	ctx := a.closeCtx
	// run periodic functions with a semi-random period
	// to avoid contention on the database in case if there are multiple
	// auth servers running - so they don't compete trying
	// to update the same resources.
	r := insecurerand.New(insecurerand.NewSource(a.GetClock().Now().UnixNano()))
	period := defaults.HighResPollingPeriod + time.Duration(r.Intn(int(defaults.HighResPollingPeriod/time.Second)))*time.Second
	log.Debugf("Ticking with period: %v.", period)
	ticker := a.GetClock().NewTicker(period)
	// Create a ticker with jitter
	heartbeatCheckTicker := interval.New(interval.Config{
		Duration: apidefaults.ServerKeepAliveTTL() * 2,
		Jitter:   utils.NewSeventhJitter(),
	})
	promTicker := time.NewTicker(defaults.PrometheusScrapeInterval)
	defer ticker.Stop()
	defer heartbeatCheckTicker.Stop()
	defer promTicker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.Chan():
			err := a.autoRotateCertAuthorities(ctx)
			if err != nil {
				if trace.IsCompareFailed(err) {
					log.Debugf("Cert authority has been updated concurrently: %v.", err)
				} else {
					log.Errorf("Failed to perform cert rotation check: %v.", err)
				}
			}
		case <-heartbeatCheckTicker.Next():
			nodes, err := a.GetNodes(ctx, apidefaults.Namespace)
			if err != nil {
				log.Errorf("Failed to load nodes for heartbeat metric calculation: %v", err)
				// Keep the previously reported value rather than publishing a
				// misleading count computed from an empty node list.
				continue
			}
			// Count from scratch on every tick: the gauge reports the current
			// number of nodes with missed keep-alives, not a running total.
			missedKeepAliveCount := 0
			for _, node := range nodes {
				if services.NodeHasMissedKeepAlives(node) {
					missedKeepAliveCount++
				}
			}
			// Update prometheus gauge
			heartbeatsMissedByAuth.Set(float64(missedKeepAliveCount))
		case <-promTicker.C:
			a.updateVersionMetrics()
		}
	}
}
// updateVersionMetrics leverages the cache to report all versions of teleport
// servers connected to the cluster via the registeredAgents prometheus gauge.
//
// Every host is counted at most once, keyed by its host ID, so a host running
// several services contributes a single entry for its version. A failure to
// list one kind of server is logged at debug level and that kind is skipped,
// so the metric is best-effort.
func (a *Server) updateVersionMetrics() {
	now := time.Now()
	seenHosts := make(map[string]struct{})
	versionCount := make(map[string]int)

	// versionedServer is satisfied by every server kind reported below.
	type versionedServer interface {
		Expiry() time.Time
		GetName() string
		GetTeleportVersion() string
	}
	// hostIDServer is satisfied by DB and App servers, which store the host ID
	// in their spec. Nodes, Proxies, Auths, KubeServices and
	// WindowsDesktopServices use the host UUID as their name instead.
	type hostIDServer interface {
		versionedServer
		GetHostID() string
	}

	// record counts the server's version once per host. Entries whose expiry
	// is set and already in the past are skipped because DynamoDB can take up
	// to 48hr to expire items from the backend; a zero expiry is treated as
	// "no expiry set" rather than as already expired.
	record := func(server interface{}) {
		srv, ok := server.(versionedServer)
		if !ok {
			return
		}
		if expiry := srv.Expiry(); !expiry.IsZero() && expiry.Before(now) {
			return
		}
		hostID := srv.GetName()
		if withHostID, ok := server.(hostIDServer); ok {
			hostID = withHostID.GetHostID()
		}
		if _, seen := seenHosts[hostID]; seen {
			return
		}
		seenHosts[hostID] = struct{}{}
		versionCount[srv.GetTeleportVersion()]++
	}

	proxyServers, err := a.GetProxies()
	if err != nil {
		log.Debugf("Failed to get Proxies for teleport_registered_servers metric: %v", err)
	}
	for _, proxyServer := range proxyServers {
		record(proxyServer)
	}

	authServers, err := a.GetAuthServers()
	if err != nil {
		log.Debugf("Failed to get Auth servers for teleport_registered_servers metric: %v", err)
	}
	for _, authServer := range authServers {
		record(authServer)
	}

	servers, err := a.GetNodes(a.closeCtx, apidefaults.Namespace)
	if err != nil {
		log.Debugf("Failed to get Nodes for teleport_registered_servers metric: %v", err)
	}
	for _, server := range servers {
		record(server)
	}

	dbs, err := a.GetDatabaseServers(a.closeCtx, apidefaults.Namespace)
	if err != nil {
		log.Debugf("Failed to get Database servers for teleport_registered_servers metric: %v", err)
	}
	for _, db := range dbs {
		record(db)
	}

	apps, err := a.GetApplicationServers(a.closeCtx, apidefaults.Namespace)
	if err != nil {
		log.Debugf("Failed to get Application servers for teleport_registered_servers metric: %v", err)
	}
	for _, app := range apps {
		record(app)
	}

	kubeServers, err := a.GetKubernetesServers(a.closeCtx)
	if err != nil {
		log.Debugf("Failed to get Kube servers for teleport_registered_servers metric: %v", err)
	}
	for _, kubeService := range kubeServers {
		record(kubeService)
	}

	windowsServices, err := a.GetWindowsDesktopServices(a.closeCtx)
	if err != nil {
		log.Debugf("Failed to get Window Desktop Services for teleport_registered_servers metric: %v", err)
	}
	for _, windowsService := range windowsServices {
		record(windowsService)
	}

	// Reset the gauge so that versions that are no longer seen are removed
	// from the exported metrics.
	registeredAgents.Reset()
	for version, count := range versionCount {
		registeredAgents.WithLabelValues(version).Set(float64(count))
	}
}
// Close cancels the server's close context, then closes the inventory
// controller and, when one is configured, the backend. Errors from closing
// these components are combined into a single aggregate error; nil is
// returned when neither failed.
func (a *Server) Close() error {
	a.cancelFunc()

	inventoryErr := a.inventory.Close()

	var backendErr error
	if a.bk != nil {
		backendErr = a.bk.Close()
	}

	// trace.NewAggregate drops nil entries.
	return trace.NewAggregate(inventoryErr, backendErr)
}
// GetClock returns the clock used by the server. The read is guarded by the
// server lock, so it may be called concurrently with SetClock.
func (a *Server) GetClock() clockwork.Clock {
	a.lock.RLock()
	clock := a.clock
	a.lock.RUnlock()
	return clock
}

// SetClock sets clock, used in tests
func (a *Server) SetClock(clock clockwork.Clock) {
	a.lock.Lock()
	a.clock = clock
	a.lock.Unlock()
}
// SetAuditLog replaces the audit log embedded in the server's Services.
// NOTE(review): this write is not guarded by the server lock.
func (a *Server) SetAuditLog(auditLog events.IAuditLog) {
	svcs := a.Services
	svcs.IAuditLog = auditLog
}

// SetEnforcer replaces the enforcer embedded in the server's Services.
// NOTE(review): this write is not guarded by the server lock.
func (a *Server) SetEnforcer(enforcer services.Enforcer) {
	svcs := a.Services
	svcs.Enforcer = enforcer
}
// GetDomainName returns the domain name that identifies this authority
// server, also known as the "cluster name". It is read from the cluster
// name resource.
func (a *Server) GetDomainName() (string, error) {
	cn, err := a.GetClusterName()
	if err == nil {
		return cn.GetClusterName(), nil
	}
	return "", trace.Wrap(err)
}
// GetClusterCACert returns the PEM-encoded TLS certs of the local cluster's
// host CA. When the CA holds several TLS certs they are joined with newlines.
// A NotFound error is returned if the host CA carries no TLS certs.
func (a *Server) GetClusterCACert(ctx context.Context) (*proto.GetClusterCACertResponse, error) {
	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Look up this cluster's host CA. The final argument is false here,
	// whereas GenerateHostCert (which signs) passes true.
	// NOTE(review): presumably this controls loading of signing keys — confirm.
	hostCAID := types.CertAuthID{Type: types.HostCA, DomainName: clusterName.GetClusterName()}
	hostCA, err := a.GetCertAuthority(ctx, hostCAID, false)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	tlsCerts := services.GetTLSCerts(hostCA)
	if len(tlsCerts) == 0 {
		return nil, trace.NotFound("no tls certs found in host CA")
	}

	resp := &proto.GetClusterCACertResponse{
		TLSCA: bytes.Join(tlsCerts, []byte("\n")),
	}
	return resp, nil
}
// GenerateHostCert signs hostPublicKey with the local cluster's host CA and
// returns the resulting host certificate. Host ID, node name, principals,
// cluster name, role and ttl are encoded into the certificate; for node
// roles, lock checks are applied before signing (see generateHostCert).
func (a *Server) GenerateHostCert(hostPublicKey []byte, hostID, nodeName string, principals []string, clusterName string, role types.SystemRole, ttl time.Duration) ([]byte, error) {
	domainName, err := a.GetDomainName()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// The host CA is read through Services (bypassing Cache) because its
	// signer is needed below.
	hostCAID := types.CertAuthID{Type: types.HostCA, DomainName: domainName}
	ca, err := a.Services.GetCertAuthority(context.TODO(), hostCAID, true)
	if err != nil {
		return nil, trace.BadParameter("failed to load host CA for %q: %v", domainName, err)
	}

	caSigner, err := a.keyStore.GetSSHSigner(ca)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	params := services.HostCertParams{
		CASigner:      caSigner,
		PublicHostKey: hostPublicKey,
		HostID:        hostID,
		NodeName:      nodeName,
		Principals:    principals,
		ClusterName:   clusterName,
		Role:          role,
		TTL:           ttl,
	}
	return a.generateHostCert(params)
}
// generateHostCert signs a host certificate with the embedded Authority.
// Node certificates are refused when a lock targeting the node — by bare host
// ID or by its cluster-qualified FQDN — is in force under the cluster's
// locking mode. Other roles are not lock-checked here.
func (a *Server) generateHostCert(p services.HostCertParams) ([]byte, error) {
	authPref, err := a.GetAuthPreference(context.TODO())
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if p.Role != types.RoleNode {
		return a.Authority.GenerateHostCert(p)
	}

	nodeTargets := []types.LockTarget{
		{Node: p.HostID},
		{Node: HostFQDN(p.HostID, p.ClusterName)},
	}
	if lockErr := a.checkLockInForce(authPref.GetLockingMode(), nodeTargets); lockErr != nil {
		return nil, trace.Wrap(lockErr)
	}
	return a.Authority.GenerateHostCert(p)
}

// GetKeyStore returns the KeyStore the auth server uses for CA private keys,
// which may be backed by HSMs.
func (a *Server) GetKeyStore() keystore.KeyStore {
	ks := a.keyStore
	return ks
}
// certRequest carries every parameter needed to issue a user certificate
// pair (SSH and TLS) via generateUserCert. user and checker are mandatory;
// see check for the remaining validation rules.
type certRequest struct {
	// user is a user to generate certificate for
	user types.User
	// impersonator is a user who generates the certificate,
	// is set when different from the user in the certificate
	impersonator string
	// checker is used to perform RBAC checks.
	checker services.AccessChecker
	// ttl is the requested duration of the certificate. Unless
	// overrideRoleTTL is set, it may be shortened to the role's session TTL.
	ttl time.Duration
	// publicKey is RSA public key in authorized_keys format
	publicKey []byte
	// compatibility is compatibility mode (certificate format); when
	// unspecified, the format is taken from the user's roles.
	compatibility string
	// overrideRoleTTL is used for requests when the requested TTL should not be
	// adjusted based off the role of the user. This is used by tctl to allow
	// creating long lived user certs.
	overrideRoleTTL bool
	// usage is a list of acceptable usages to be encoded in X509 certificate,
	// is used to limit ways the certificate can be used, for example
	// the cert can be only used against kubernetes endpoint, and not auth endpoint,
	// no usage means unrestricted (to keep backwards compatibility)
	usage []string
	// routeToCluster is an optional teleport cluster name to route the
	// certificate requests to, this teleport cluster name will be used to
	// route the requests to in case of kubernetes. Defaults to the local
	// cluster when empty.
	routeToCluster string
	// kubernetesCluster specifies the target kubernetes cluster for TLS
	// identities. This can be empty on older Teleport clients.
	kubernetesCluster string
	// traits hold claim data used to populate a role at runtime.
	traits wrappers.Traits
	// activeRequests tracks privilege escalation requests applied
	// during the construction of the certificate.
	activeRequests services.RequestIDs
	// appSessionID is the session ID of the application session.
	appSessionID string
	// appPublicAddr is the public address of the application.
	appPublicAddr string
	// appClusterName is the name of the cluster this application is in.
	appClusterName string
	// appName is the name of the application to generate cert for.
	appName string
	// awsRoleARN is the role ARN to generate certificate for.
	awsRoleARN string
	// dbService identifies the name of the database service requests will
	// be routed to.
	dbService string
	// dbProtocol specifies the protocol of the database a certificate will
	// be issued for.
	dbProtocol string
	// dbUser is the optional database user which, if provided, will be used
	// as a default username. It is mandatory for MongoDB (see check).
	dbUser string
	// dbName is the optional database name which, if provided, will be used
	// as a default database.
	dbName string
	// mfaVerified is the UUID of an MFA device when this certRequest was
	// created immediately after an MFA check.
	mfaVerified string
	// clientIP is an IP of the client requesting the certificate.
	clientIP string
	// sourceIP is an IP this certificate should be pinned to
	sourceIP string
	// disallowReissue flags that a cert should not be allowed to issue future
	// certificates.
	disallowReissue bool
	// renewable indicates that the certificate can be renewed,
	// having its TTL increased
	renewable bool
	// includeHostCA indicates that host CA certs should be included in the
	// returned certs
	includeHostCA bool
	// generation indicates the number of times this certificate has been
	// renewed.
	generation uint64
}
// check verifies the cert request is valid: user and checker must be set,
// and MongoDB certificates must name a database user, because that user has
// to be encoded into the certificate to authenticate the connection.
func (r *certRequest) check() error {
	switch {
	case r.user == nil:
		return trace.BadParameter("missing parameter user")
	case r.checker == nil:
		return trace.BadParameter("missing parameter checker")
	case r.dbProtocol == defaults.ProtocolMongoDB && r.dbUser == "":
		return trace.BadParameter("must provide database user name to generate certificate for database %q", r.dbService)
	default:
		return nil
	}
}
// certRequestOption mutates a certRequest before it is processed.
type certRequestOption func(*certRequest)

// certRequestMFAVerified returns an option recording mfaID as the MFA device
// verified immediately before the request was made.
func certRequestMFAVerified(mfaID string) certRequestOption {
	return func(req *certRequest) {
		req.mfaVerified = mfaID
	}
}

// certRequestClientIP returns an option recording ip as the IP of the client
// requesting the certificate.
func certRequestClientIP(ip string) certRequestOption {
	return func(req *certRequest) {
		req.clientIP = ip
	}
}
// GenerateUserTestCerts generates an SSH and TLS user certificate pair for
// username, signed over key. It is used internally for tests. The user's own
// traits are encoded, and access is evaluated against the user's roles in the
// local cluster.
func (a *Server) GenerateUserTestCerts(key []byte, username string, ttl time.Duration, compatibility, routeToCluster, sourceIP string) ([]byte, []byte, error) {
	user, err := a.GetUser(username, false)
	if err != nil {
		return nil, nil, trace.Wrap(err)
	}
	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, nil, trace.Wrap(err)
	}
	checker, err := services.NewAccessChecker(services.AccessInfoFromUser(user), clusterName.GetClusterName(), a)
	if err != nil {
		return nil, nil, trace.Wrap(err)
	}

	req := certRequest{
		user:           user,
		checker:        checker,
		publicKey:      key,
		ttl:            ttl,
		compatibility:  compatibility,
		routeToCluster: routeToCluster,
		traits:         user.GetTraits(),
		sourceIP:       sourceIP,
	}
	certs, err := a.generateUserCert(req)
	if err != nil {
		return nil, nil, trace.Wrap(err)
	}
	return certs.SSH, certs.TLS, nil
}
// AppTestCertRequest combines parameters for generating a test app access
// cert with GenerateUserAppTestCert.
type AppTestCertRequest struct {
	// PublicKey is the public key to sign.
	PublicKey []byte
	// Username is the Teleport user name to sign certificate for.
	Username string
	// TTL is the test certificate validity period.
	TTL time.Duration
	// PublicAddr is the application public address. Used for routing.
	PublicAddr string
	// ClusterName is the name of the cluster application resides in. Used for routing.
	ClusterName string
	// SessionID is the optional session ID to encode. Used for routing.
	// A random ID is generated when it is empty.
	SessionID string
	// AWSRoleARN is optional AWS role ARN a user wants to assume to encode.
	AWSRoleARN string
}
// GenerateUserAppTestCert generates an application-specific TLS certificate
// for req.Username, used internally for tests. The certificate is restricted
// to application usage and carries the routing data from req; when
// req.SessionID is empty a random session ID is generated.
func (a *Server) GenerateUserAppTestCert(req AppTestCertRequest) ([]byte, error) {
	user, err := a.GetUser(req.Username, false)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	checker, err := services.NewAccessChecker(services.AccessInfoFromUser(user), clusterName.GetClusterName(), a)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	appSessionID := req.SessionID
	if len(appSessionID) == 0 {
		appSessionID = uuid.New().String()
	}

	// Application certificates are never used to log into servers, but SSH
	// certificate generation requires a principal, so use a random login.
	randomLogin := uuid.New().String()
	appTraits := wrappers.Traits(map[string][]string{
		constants.TraitLogins: {randomLogin},
	})

	certs, err := a.generateUserCert(certRequest{
		user:      user,
		publicKey: req.PublicKey,
		checker:   checker,
		ttl:       req.TTL,
		traits:    appTraits,
		// Restrict the certificate to applications only.
		usage: []string{teleport.UsageAppsOnly},
		// Application routing information.
		appSessionID:   appSessionID,
		appPublicAddr:  req.PublicAddr,
		appClusterName: req.ClusterName,
		awsRoleARN:     req.AWSRoleARN,
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return certs.TLS, nil
}
// DatabaseTestCertRequest combines parameters for generating test database
// access certificate with GenerateDatabaseTestCert.
type DatabaseTestCertRequest struct {
	// PublicKey is the public key to sign.
	PublicKey []byte
	// Cluster is the Teleport cluster name the certificate routes to.
	Cluster string
	// Username is the Teleport username.
	Username string
	// RouteToDatabase contains database routing information (service name,
	// protocol, database user and database name).
	RouteToDatabase tlsca.RouteToDatabase
}
// GenerateDatabaseTestCert generates a one-hour database access TLS
// certificate for the provided parameters. The Teleport username doubles as
// the only login. Used only internally in tests.
func (a *Server) GenerateDatabaseTestCert(req DatabaseTestCertRequest) ([]byte, error) {
	user, err := a.GetUser(req.Username, false)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	checker, err := services.NewAccessChecker(services.AccessInfoFromUser(user), clusterName.GetClusterName(), a)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	route := req.RouteToDatabase
	certReq := certRequest{
		user:      user,
		publicKey: req.PublicKey,
		checker:   checker,
		ttl:       time.Hour,
		traits: map[string][]string{
			constants.TraitLogins: {req.Username},
		},
		routeToCluster: req.Cluster,
		dbService:      route.ServiceName,
		dbProtocol:     route.Protocol,
		dbUser:         route.Username,
		dbName:         route.Database,
	}
	certs, err := a.generateUserCert(certReq)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return certs.TLS, nil
}
// generateUserCert generates SSH and TLS certificates for the user in req.
//
// The request is rejected when:
//   - req fails its own validation (req.check);
//   - the role set restricts access to specific resource IDs but the cluster
//     is not licensed for resource access requests;
//   - a lock is in force for the user, the verified MFA device, any of the
//     user's roles or any active access request;
//   - the target remote cluster cannot be fetched or accessed (an AccessDenied
//     result is reported to the caller as NotFound).
//
// Unless req.overrideRoleTTL is set, the certificate TTL is clamped by the
// role set and only logins compatible with that TTL are granted. The returned
// certs always carry the local user CA certificates and, when
// req.includeHostCA is set, the local host CA certificates as well.
// A failure to emit the CertificateCreate audit event is logged and does not
// fail the request.
func (a *Server) generateUserCert(req certRequest) (*proto.Certs, error) {
	ctx := context.TODO()
	err := req.check()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Certificates restricted to specific resources require the resource
	// access requests feature.
	if len(req.checker.GetAllowedResourceIDs()) > 0 && !modules.GetModules().Features().ResourceAccessRequests {
		return nil, trace.AccessDenied("this Teleport cluster is not licensed for resource access requests, please contact the cluster administrator")
	}

	// Reject the cert request if there is a matching lock in force.
	authPref, err := a.GetAuthPreference(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	lockingMode := req.checker.LockingMode(authPref.GetLockingMode())
	lockTargets := []types.LockTarget{
		{User: req.user.GetName()},
		{MFADevice: req.mfaVerified},
	}
	lockTargets = append(lockTargets,
		services.RolesToLockTargets(req.checker.RoleNames())...,
	)
	lockTargets = append(lockTargets,
		services.AccessRequestsToLockTargets(req.activeRequests.AccessRequests)...,
	)
	if err := a.checkLockInForce(lockingMode, lockTargets); err != nil {
		return nil, trace.Wrap(err)
	}

	// reuse the same RSA keys for SSH and TLS keys
	cryptoPubKey, err := sshutils.CryptoPublicKey(req.publicKey)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// extract the passed in certificate format. if nothing was passed in, fetch
	// the certificate format from the role.
	certificateFormat, err := utils.CheckCertificateFormatFlag(req.compatibility)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if certificateFormat == teleport.CertificateFormatUnspecified {
		certificateFormat = req.checker.CertificateFormat()
	}

	var sessionTTL time.Duration
	var allowedLogins []string

	// If the role TTL is ignored, do not restrict session TTL and allowed logins.
	// The only caller setting this parameter should be "tctl auth sign".
	// Otherwise, set the session TTL to the smallest of all roles and
	// then only grant access to allowed logins based on that.
	if req.overrideRoleTTL {
		// Take whatever was passed in. Pass in 0 to CheckLoginDuration so all
		// logins are returned for the role set.
		sessionTTL = req.ttl
		allowedLogins, err = req.checker.CheckLoginDuration(0)
		if err != nil {
			return nil, trace.Wrap(err)
		}
	} else {
		// Adjust session TTL to the smaller of two values: the session TTL
		// requested in tsh or the session TTL for the role.
		sessionTTL = req.checker.AdjustSessionTTL(req.ttl)

		// Return a list of logins that meet the session TTL limit. This means if
		// the requested session TTL is larger than the max session TTL for a login,
		// that login will not be included in the list of allowed logins.
		allowedLogins, err = req.checker.CheckLoginDuration(sessionTTL)
		if err != nil {
			return nil, trace.Wrap(err)
		}
	}

	clusterName, err := a.GetDomainName()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if req.routeToCluster == "" {
		req.routeToCluster = clusterName
	}
	if req.routeToCluster != clusterName {
		// Authorize access to a remote cluster.
		rc, err := a.GetRemoteCluster(req.routeToCluster)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		if err := req.checker.CheckAccessToRemoteCluster(rc); err != nil {
			if trace.IsAccessDenied(err) {
				return nil, trace.NotFound("remote cluster %q not found", req.routeToCluster)
			}
			return nil, trace.Wrap(err)
		}
	}

	userCA, err := a.GetCertAuthority(ctx, types.CertAuthID{
		Type:       types.UserCA,
		DomainName: clusterName,
	}, true)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	caSigner, err := a.keyStore.GetSSHSigner(userCA)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Add the special join-only principal used for joining sessions.
	// All users have access to this and join RBAC rules are checked after the connection is established.
	allowedLogins = append(allowedLogins, teleport.SSHSessionJoinPrincipal)

	requestedResourcesStr, err := types.ResourceIDsToString(req.checker.GetAllowedResourceIDs())
	if err != nil {
		return nil, trace.Wrap(err)
	}

	params := services.UserCertParams{
		CASigner:              caSigner,
		PublicUserKey:         req.publicKey,
		Username:              req.user.GetName(),
		Impersonator:          req.impersonator,
		AllowedLogins:         allowedLogins,
		TTL:                   sessionTTL,
		Roles:                 req.checker.RoleNames(),
		CertificateFormat:     certificateFormat,
		PermitPortForwarding:  req.checker.CanPortForward(),
		PermitAgentForwarding: req.checker.CanForwardAgents(),
		PermitX11Forwarding:   req.checker.PermitX11Forwarding(),
		RouteToCluster:        req.routeToCluster,
		Traits:                req.traits,
		ActiveRequests:        req.activeRequests,
		MFAVerified:           req.mfaVerified,
		ClientIP:              req.clientIP,
		DisallowReissue:       req.disallowReissue,
		Renewable:             req.renewable,
		Generation:            req.generation,
		CertificateExtensions: req.checker.CertificateExtensions(),
		AllowedResourceIDs:    requestedResourcesStr,
		SourceIP:              req.sourceIP,
	}
	sshCert, err := a.Authority.GenerateUserCert(params)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	kubeGroups, kubeUsers, err := req.checker.CheckKubeGroupsAndUsers(sessionTTL, req.overrideRoleTTL)
	// NotFound errors are acceptable - this user may have no k8s access
	// granted and that shouldn't prevent us from issuing a TLS cert.
	if err != nil && !trace.IsNotFound(err) {
		return nil, trace.Wrap(err)
	}

	// Only validate/default kubernetes cluster name for the current teleport
	// cluster. If this cert is targeting a trusted teleport cluster, leave all
	// the kubernetes cluster validation up to them.
	if req.routeToCluster == clusterName {
		req.kubernetesCluster, err = kubeutils.CheckOrSetKubeCluster(a.closeCtx, a, req.kubernetesCluster, clusterName)
		if err != nil {
			if !trace.IsNotFound(err) {
				return nil, trace.Wrap(err)
			}
			log.Debug("Failed setting default kubernetes cluster for user login (user did not provide a cluster); leaving KubernetesCluster extension in the TLS certificate empty")
		}
	}

	// See which database names and users this user is allowed to use.
	dbNames, dbUsers, err := req.checker.CheckDatabaseNamesAndUsers(sessionTTL, req.overrideRoleTTL)
	if err != nil && !trace.IsNotFound(err) {
		return nil, trace.Wrap(err)
	}

	// See which AWS role ARNs this user is allowed to assume.
	roleARNs, err := req.checker.CheckAWSRoleARNs(sessionTTL, req.overrideRoleTTL)
	if err != nil && !trace.IsNotFound(err) {
		return nil, trace.Wrap(err)
	}

	// generate TLS certificate
	cert, signer, err := a.keyStore.GetTLSCertAndSigner(userCA)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	tlsAuthority, err := tlsca.FromCertAndSigner(cert, signer)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	identity := tlsca.Identity{
		Username:          req.user.GetName(),
		Impersonator:      req.impersonator,
		Groups:            req.checker.RoleNames(),
		Principals:        allowedLogins,
		Usage:             req.usage,
		RouteToCluster:    req.routeToCluster,
		KubernetesCluster: req.kubernetesCluster,
		Traits:            req.traits,
		KubernetesGroups:  kubeGroups,
		KubernetesUsers:   kubeUsers,
		RouteToApp: tlsca.RouteToApp{
			SessionID:   req.appSessionID,
			PublicAddr:  req.appPublicAddr,
			ClusterName: req.appClusterName,
			Name:        req.appName,
			AWSRoleARN:  req.awsRoleARN,
		},
		TeleportCluster: clusterName,
		RouteToDatabase: tlsca.RouteToDatabase{
			ServiceName: req.dbService,
			Protocol:    req.dbProtocol,
			Username:    req.dbUser,
			Database:    req.dbName,
		},
		DatabaseNames:      dbNames,
		DatabaseUsers:      dbUsers,
		MFAVerified:        req.mfaVerified,
		ClientIP:           req.clientIP,
		AWSRoleARNs:        roleARNs,
		ActiveRequests:     req.activeRequests.AccessRequests,
		DisallowReissue:    req.disallowReissue,
		Renewable:          req.renewable,
		Generation:         req.generation,
		AllowedResourceIDs: req.checker.GetAllowedResourceIDs(),
	}
	subject, err := identity.Subject()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	certRequest := tlsca.CertificateRequest{
		Clock:     a.clock,
		PublicKey: cryptoPubKey,
		Subject:   subject,
		NotAfter:  a.clock.Now().UTC().Add(sessionTTL),
	}
	tlsCert, err := tlsAuthority.GenerateCertificate(certRequest)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	eventIdentity := identity.GetEventIdentity()
	eventIdentity.Expires = certRequest.NotAfter
	// Audit emission is best-effort here: the error must be bound in the if
	// statement so that a failure is actually observed and logged.
	if err := a.emitter.EmitAuditEvent(a.closeCtx, &apievents.CertificateCreate{
		Metadata: apievents.Metadata{
			Type: events.CertificateCreateEvent,
			Code: events.CertificateCreateCode,
		},
		CertificateType: events.CertificateTypeUser,
		Identity:        &eventIdentity,
	}); err != nil {
		log.WithError(err).Warn("Failed to emit certificate create event.")
	}

	// create certs struct to return to user
	certs := &proto.Certs{
		SSH: sshCert,
		TLS: tlsCert,
	}

	// always include user CA TLS and SSH certs
	cas := []types.CertAuthority{userCA}

	// also include host CA certs if requested
	if req.includeHostCA {
		hostCA, err := a.GetCertAuthority(ctx, types.CertAuthID{
			Type:       types.HostCA,
			DomainName: clusterName,
		}, false)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		cas = append(cas, hostCA)
	}

	for _, ca := range cas {
		certs.TLSCACerts = append(certs.TLSCACerts, services.GetTLSCerts(ca)...)
		certs.SSHCACerts = append(certs.SSHCACerts, services.GetSSHCheckingKeys(ca)...)
	}

	return certs, nil
}
// WithUserLock runs authenticateFn, which performs user authentication, and
// tracks failed login attempts for username.
//
// If the user does not exist, authenticateFn is still called and its result
// returned, so that unknown usernames cannot be detected by timing.
//
// If the user is currently locked (by failed account recovery attempts or by
// failed login attempts), authenticateFn is not called. An AccessDenied error
// carrying ErrFieldKeyUserMaxedAttempts is returned instead.
//
// If authenticateFn succeeds, the user's failed login attempts are reset.
// If it fails with a ConnectionProblem error, access is denied but the attempt
// is not recorded. This avoids locking users out because of backend failures.
// Any other failure is recorded. Once the user exceeds
// defaults.MaxLoginAttempts, the account is locked for
// defaults.AccountLockInterval.
func (a *Server) WithUserLock(username string, authenticateFn func() error) error {
	user, err := a.Services.GetUser(username, false)
	if err != nil {
		if trace.IsNotFound(err) {
			// If user is not found, still call authenticateFn. It should
			// always return an error. This prevents username oracles and
			// timing attacks.
			return authenticateFn()
		}
		return trace.Wrap(err)
	}

	status := user.GetStatus()
	if status.IsLocked {
		if status.RecoveryAttemptLockExpires.After(a.clock.Now().UTC()) {
			log.Debugf("%v exceeds %v failed account recovery attempts, locked until %v",
				user.GetName(), defaults.MaxAccountRecoveryAttempts, apiutils.HumanTimeFormat(status.RecoveryAttemptLockExpires))

			err := trace.AccessDenied(MaxFailedAttemptsErrMsg)
			err.AddField(ErrFieldKeyUserMaxedAttempts, true)
			return err
		}
		if status.LockExpires.After(a.clock.Now().UTC()) {
			log.Debugf("%v exceeds %v failed login attempts, locked until %v",
				user.GetName(), defaults.MaxLoginAttempts, apiutils.HumanTimeFormat(status.LockExpires))

			err := trace.AccessDenied(MaxFailedAttemptsErrMsg)
			err.AddField(ErrFieldKeyUserMaxedAttempts, true)
			return err
		}
	}

	fnErr := authenticateFn()
	if fnErr == nil {
		// Upon successful login, reset the failed attempt counter. There
		// being no recorded attempts (NotFound) is not an error.
		if err := a.DeleteUserLoginAttempts(username); err != nil && !trace.IsNotFound(err) {
			return trace.Wrap(err)
		}
		return nil
	}

	// Do not lock the user in case the DB is flaky or down. The error to
	// inspect is fnErr, the authentication failure; err is nil at this point.
	if trace.IsConnectionProblem(fnErr) {
		return trace.Wrap(fnErr)
	}

	// log failed attempt and possibly lock user
	attempt := services.LoginAttempt{Time: a.clock.Now().UTC(), Success: false}
	err = a.AddUserLoginAttempt(username, attempt, defaults.AttemptTTL)
	if err != nil {
		log.Error(trace.DebugReport(err))
		return trace.Wrap(fnErr)
	}
	loginAttempts, err := a.GetUserLoginAttempts(username)
	if err != nil {
		log.Error(trace.DebugReport(err))
		return trace.Wrap(fnErr)
	}
	if !services.LastFailed(defaults.MaxLoginAttempts, loginAttempts) {
		log.Debugf("%v user has less than %v failed login attempts", username, defaults.MaxLoginAttempts)
		return trace.Wrap(fnErr)
	}

	lockUntil := a.clock.Now().UTC().Add(defaults.AccountLockInterval)
	log.Debugf("%v exceeds %v failed login attempts, locked until %v",
		username, defaults.MaxLoginAttempts, apiutils.HumanTimeFormat(lockUntil))
	user.SetLocked(lockUntil, "user has exceeded maximum failed login attempts")
	err = a.UpsertUser(user)
	if err != nil {
		log.Error(trace.DebugReport(err))
		return trace.Wrap(fnErr)
	}

	retErr := trace.AccessDenied(MaxFailedAttemptsErrMsg)
	retErr.AddField(ErrFieldKeyUserMaxedAttempts, true)
	return retErr
}
// PreAuthenticatedSignIn creates and stores a new web session for user in MFA
// authentication flows where the password was already verified before the
// second factor challenge was issued.
//
// Roles, traits and allowed resource IDs are derived from identity, and the
// identity's active access requests are carried over. The returned session
// has its secrets stripped.
func (a *Server) PreAuthenticatedSignIn(ctx context.Context, user string, identity tlsca.Identity) (types.WebSession, error) {
	info, err := services.AccessInfoFromLocalIdentity(identity, a)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	sessionReq := types.NewWebSessionRequest{
		User:                 user,
		Roles:                info.Roles,
		Traits:               info.Traits,
		AccessRequests:       identity.ActiveRequests,
		RequestedResourceIDs: info.AllowedResourceIDs,
	}
	webSession, err := a.NewWebSession(ctx, sessionReq)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	if err = a.upsertWebSession(ctx, user, webSession); err != nil {
		return nil, trace.Wrap(err)
	}
	return webSession.WithoutSecrets(), nil
}
// CreateAuthenticateChallenge implements AuthService.CreateAuthenticateChallenge.
//
// The user the challenge is created for is resolved from the request:
//   - user credentials: the password is checked under WithUserLock;
//   - recovery start token: the token must be a valid recovery start token;
//   - passwordless: no username is required;
//   - otherwise (unset or context user): the caller's identity from ctx.
//
// Failures while building the challenges are logged and reported to the
// caller as a generic AccessDenied.
func (a *Server) CreateAuthenticateChallenge(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
	var (
		username     string
		passwordless bool
	)

	switch req.GetRequest().(type) {
	case *proto.CreateAuthenticateChallengeRequest_UserCredentials:
		creds := req.GetUserCredentials()
		username = creds.GetUsername()
		checkPassword := func() error {
			return a.checkPasswordWOToken(username, creds.GetPassword())
		}
		if err := a.WithUserLock(username, checkPassword); err != nil {
			return nil, trace.Wrap(err)
		}

	case *proto.CreateAuthenticateChallengeRequest_RecoveryStartTokenID:
		recoveryToken, err := a.GetUserToken(ctx, req.GetRecoveryStartTokenID())
		if err != nil {
			log.Error(trace.DebugReport(err))
			return nil, trace.AccessDenied("invalid token")
		}
		if err := a.verifyUserToken(recoveryToken, UserTokenTypeRecoveryStart); err != nil {
			return nil, trace.Wrap(err)
		}
		username = recoveryToken.GetUser()

	case *proto.CreateAuthenticateChallengeRequest_Passwordless:
		// Passwordless challenges allow an empty username.
		passwordless = true

	default:
		// Unset or CreateAuthenticateChallengeRequest_ContextUser.
		clientUser, err := GetClientUsername(ctx)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		username = clientUser
	}

	challenges, err := a.mfaAuthChallenge(ctx, username, passwordless)
	if err != nil {
		log.Error(trace.DebugReport(err))
		return nil, trace.AccessDenied("unable to create MFA challenges")
	}
	return challenges, nil
}
// CreateRegisterChallenge implements AuthService.CreateRegisterChallenge.
//
// The request's token must be a privilege, privilege exception, reset
// password, reset password invite or approved recovery token. Any lookup or
// verification failure is reported as a generic AccessDenied("invalid token").
func (a *Server) CreateRegisterChallenge(ctx context.Context, req *proto.CreateRegisterChallengeRequest) (*proto.MFARegisterChallenge, error) {
	userToken, err := a.GetUserToken(ctx, req.GetTokenID())
	if err != nil {
		log.Error(trace.DebugReport(err))
		return nil, trace.AccessDenied("invalid token")
	}

	// Only these token kinds may be used to register a new MFA device.
	if err := a.verifyUserToken(userToken,
		UserTokenTypePrivilege,
		UserTokenTypePrivilegeException,
		UserTokenTypeResetPassword,
		UserTokenTypeResetPasswordInvite,
		UserTokenTypeRecoveryApproved,
	); err != nil {
		return nil, trace.AccessDenied("invalid token")
	}

	challengeReq := &newRegisterChallengeRequest{
		username:    userToken.GetUser(),
		token:       userToken,
		deviceType:  req.GetDeviceType(),
		deviceUsage: req.GetDeviceUsage(),
	}
	challenge, err := a.createRegisterChallenge(ctx, challengeReq)
	return challenge, trace.Wrap(err)
}
// newRegisterChallengeRequest holds the parameters consumed by
// createRegisterChallenge.
type newRegisterChallengeRequest struct {
	// username is the user the new MFA device is being registered for.
	username string
	// deviceType selects the kind of challenge to create (TOTP or Webauthn);
	// other values are rejected by createRegisterChallenge.
	deviceType proto.DeviceType
	// deviceUsage is the intended device usage; for Webauthn,
	// DEVICE_USAGE_PASSWORDLESS requests a passwordless registration.
	deviceUsage proto.DeviceUsage
	// token is a user token resource. It may be nil.
	// It is used as follows:
	//   - TOTP: when set, TOTP user token secrets are created for the token
	//     (createTOTPUserTokenSecrets) and their QR code is included in the
	//     returned challenge.
	//   - Webauthn: createRegisterChallenge does not read this field; the
	//     registration flow uses webIdentityOverride or the Server's
	//     IdentityService instead.
	token types.UserToken
	// webIdentityOverride is an optional RegistrationIdentity override to be used
	// to store webauthn challenge. A common override is decorating the regular
	// Identity with an in-memory SessionData storage.
	// Defaults to the Server's IdentityService.
	webIdentityOverride wanlib.RegistrationIdentity
}
// createRegisterChallenge builds an MFA registration challenge for
// req.username according to req.deviceType.
//
// For TOTP, a new key is generated. If req.token is set, user token secrets
// are created for it and the QR code is attached to the challenge. For
// Webauthn, a registration flow is begun using the cluster's webauthn
// configuration and either req.webIdentityOverride or the Server's services.
// Any other device type yields a BadParameter error.
func (a *Server) createRegisterChallenge(ctx context.Context, req *newRegisterChallengeRequest) (*proto.MFARegisterChallenge, error) {
	switch req.deviceType {
	case proto.DeviceType_DEVICE_TYPE_TOTP:
		key, opts, err := a.newTOTPKey(req.username)
		if err != nil {
			return nil, trace.Wrap(err)
		}

		totpChallenge := &proto.TOTPRegisterChallenge{
			Secret:        key.Secret(),
			Issuer:        key.Issuer(),
			PeriodSeconds: uint32(opts.Period),
			Algorithm:     opts.Algorithm.String(),
			Digits:        uint32(opts.Digits.Length()),
			Account:       key.AccountName(),
		}
		if req.token != nil {
			tokenSecrets, err := a.createTOTPUserTokenSecrets(ctx, req.token, key)
			if err != nil {
				return nil, trace.Wrap(err)
			}
			totpChallenge.QRCode = tokenSecrets.GetQRCode()
		}

		return &proto.MFARegisterChallenge{
			Request: &proto.MFARegisterChallenge_TOTP{TOTP: totpChallenge},
		}, nil

	case proto.DeviceType_DEVICE_TYPE_WEBAUTHN:
		authPref, err := a.GetAuthPreference(ctx)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		webConfig, err := authPref.GetWebauthn()
		if err != nil {
			return nil, trace.Wrap(err)
		}

		// Prefer the caller-supplied identity, falling back to the Server's.
		var identity wanlib.RegistrationIdentity = a.Services
		if req.webIdentityOverride != nil {
			identity = req.webIdentityOverride
		}
		flow := &wanlib.RegistrationFlow{
			Webauthn: webConfig,
			Identity: identity,
		}

		isPasswordless := req.deviceUsage == proto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS
		creation, err := flow.Begin(ctx, req.username, isPasswordless)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		return &proto.MFARegisterChallenge{
			Request: &proto.MFARegisterChallenge_Webauthn{
				Webauthn: wanlib.CredentialCreationToProto(creation),
			},
		}, nil
	}

	return nil, trace.BadParameter("MFA device type %q unsupported", req.deviceType.String())
}
// GetMFADevices returns all MFA devices of a user.
//
// When req carries a token ID, the token must be an approved recovery token
// and its user is used. If no token is given, or the token names no user, the
// caller's username from ctx is used instead.
func (a *Server) GetMFADevices(ctx context.Context, req *proto.GetMFADevicesRequest) (*proto.GetMFADevicesResponse, error) {
	var username string

	if tokenID := req.GetTokenID(); tokenID != "" {
		recoveryToken, err := a.GetUserToken(ctx, tokenID)
		if err != nil {
			log.Error(trace.DebugReport(err))
			return nil, trace.AccessDenied("invalid token")
		}
		if err := a.verifyUserToken(recoveryToken, UserTokenTypeRecoveryApproved); err != nil {
			return nil, trace.Wrap(err)
		}
		username = recoveryToken.GetUser()
	}

	// Fall back to the identity of the caller.
	if username == "" {
		clientUser, err := GetClientUsername(ctx)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		username = clientUser
	}

	devices, err := a.Services.GetMFADevices(ctx, username, false)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return &proto.GetMFADevicesResponse{Devices: devices}, nil
}
// DeleteMFADeviceSync implements AuthService.DeleteMFADeviceSync.
//
// The request's token must be an approved recovery token or a privilege
// token. The named device of the token's user is then deleted through
// deleteMFADeviceSafely.
func (a *Server) DeleteMFADeviceSync(ctx context.Context, req *proto.DeleteMFADeviceSyncRequest) error {
	userToken, err := a.GetUserToken(ctx, req.GetTokenID())
	if err != nil {
		log.Error(trace.DebugReport(err))
		return trace.AccessDenied("invalid token")
	}

	if err = a.verifyUserToken(userToken, UserTokenTypeRecoveryApproved, UserTokenTypePrivilege); err != nil {
		return trace.Wrap(err)
	}

	err = a.deleteMFADeviceSafely(ctx, userToken.GetUser(), req.GetDeviceName())
	return trace.Wrap(err)
}
// deleteMFADeviceSafely deletes user's MFA device identified by deviceName
// (matched against either the device name or its ID).
//
// When the cluster requires a second factor, the deletion is refused if it
// would remove the user's last device of the required kind: any known kind
// for "on", or specifically OTP/Webauthn devices for "otp"/"webauthn". This
// keeps users from locking themselves out. Devices of unknown types are
// ignored when counting, and cannot themselves be deleted. On success an
// MFADeviceDelete audit event is emitted; a failure to emit it is returned.
func (a *Server) deleteMFADeviceSafely(ctx context.Context, user, deviceName string) error {
	devices, err := a.Services.GetMFADevices(ctx, user, true)
	if err != nil {
		return trace.Wrap(err)
	}
	authPref, err := a.GetAuthPreference(ctx)
	if err != nil {
		return trace.Wrap(err)
	}

	// Second factor kind keyed by the dynamic type name of MFADevice.Device.
	secondFactorByKind := map[string]constants.SecondFactorType{
		fmt.Sprintf("%T", &types.MFADevice_Totp{}):     constants.SecondFactorOTP,
		fmt.Sprintf("%T", &types.MFADevice_U2F{}):      constants.SecondFactorWebauthn,
		fmt.Sprintf("%T", &types.MFADevice_Webauthn{}): constants.SecondFactorWebauthn,
	}

	countBySF := make(map[constants.SecondFactorType]int)
	totalKnown := 0
	var target *types.MFADevice

	// Locate the device to delete while counting the known devices.
	for _, dev := range devices {
		if dev.GetName() == deviceName || dev.Id == deviceName {
			target = dev
		}

		sf, known := secondFactorByKind[fmt.Sprintf("%T", dev.Device)]
		if !known {
			if dev == target {
				return trace.NotImplemented("cannot delete device of type %T", dev.Device)
			}
			log.Warnf("Ignoring unknown device with type %T in deletion.", dev.Device)
			continue
		}

		countBySF[sf]++
		totalKnown++
	}

	if target == nil {
		return trace.NotFound("MFA device %q does not exist", deviceName)
	}

	// Prevent users from deleting their last device for clusters that require second factors.
	const minDevices = 2
	sf := authPref.GetSecondFactor()
	switch sf {
	case constants.SecondFactorOff, constants.SecondFactorOptional:
		// MFA is not required, allow deletion.
	case constants.SecondFactorOn:
		if totalKnown < minDevices {
			return trace.BadParameter(
				"cannot delete the last MFA device for this user; add a replacement device first to avoid getting locked out")
		}
	case constants.SecondFactorOTP, constants.SecondFactorWebauthn:
		if countBySF[sf] < minDevices {
			return trace.BadParameter(
				"cannot delete the last %s device for this user; add a replacement device first to avoid getting locked out", sf)
		}
	default:
		return trace.BadParameter("unexpected second factor type: %s", sf)
	}

	if err := a.DeleteMFADevice(ctx, user, target.Id); err != nil {
		return trace.Wrap(err)
	}

	// Record the deletion in the audit log.
	clusterName, err := a.GetClusterName()
	if err != nil {
		return trace.Wrap(err)
	}
	deleteEvent := &apievents.MFADeviceDelete{
		Metadata: apievents.Metadata{
			Type:        events.MFADeviceDeleteEvent,
			Code:        events.MFADeviceDeleteEventCode,
			ClusterName: clusterName.GetClusterName(),
		},
		UserMetadata: apievents.UserMetadata{
			User: user,
		},
		MFADeviceMetadata: mfaDeviceEventMetadata(target),
	}
	if err := a.emitter.EmitAuditEvent(ctx, deleteEvent); err != nil {
		return trace.Wrap(err)
	}
	return nil
}
// AddMFADeviceSync implements AuthService.AddMFADeviceSync.
//
// The request's token must be a privilege or privilege exception token. The
// new device's register response is verified and, on success, the device is
// added for the token's user and returned.
func (a *Server) AddMFADeviceSync(ctx context.Context, req *proto.AddMFADeviceSyncRequest) (*proto.AddMFADeviceSyncResponse, error) {
	token, err := a.GetUserToken(ctx, req.GetTokenID())
	if err != nil {
		log.Error(trace.DebugReport(err))
		return nil, trace.AccessDenied("invalid token")
	}

	if err = a.verifyUserToken(token, UserTokenTypePrivilege, UserTokenTypePrivilegeException); err != nil {
		return nil, trace.Wrap(err)
	}

	fields := &newMFADeviceFields{
		username:      token.GetUser(),
		newDeviceName: req.GetNewDeviceName(),
		tokenID:       token.GetName(),
		deviceResp:    req.GetNewMFAResponse(),
		deviceUsage:   req.DeviceUsage,
	}
	newDevice, err := a.verifyMFARespAndAddDevice(ctx, fields)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	return &proto.AddMFADeviceSyncResponse{Device: newDevice}, nil
}
// newMFADeviceFields holds the inputs for verifyMFARespAndAddDevice and the
// device-type specific register functions it dispatches to.
type newMFADeviceFields struct {
	// username is the user the new device is added for.
	username string
	// newDeviceName is the name given to the new device.
	newDeviceName string
	// tokenID is the ID of a reset/invite/recovery/privilege token.
	// It is used as follows:
	//   - TOTP: the TOTP secret is looked up from the user token secrets
	//     stored under this ID; when set, it takes precedence over totpSecret.
	//   - Webauthn: not read by registerWebauthnDevice, which relies on
	//     webIdentityOverride or the Server's IdentityService instead.
	// This field can be empty to use storage overrides.
	tokenID string
	// totpSecret is a secret shared by client and server to generate totp codes.
	// It is used only when tokenID is empty.
	totpSecret string
	// webIdentityOverride is an optional RegistrationIdentity override to be used
	// for device registration. A common override is decorating the regular
	// Identity with an in-memory SessionData storage.
	// Defaults to the Server's IdentityService.
	webIdentityOverride wanlib.RegistrationIdentity
	// deviceResp is the register response from the new device.
	deviceResp *proto.MFARegisterResponse
	// deviceUsage describes the intended usage of the new device; for Webauthn,
	// DEVICE_USAGE_PASSWORDLESS registers a passwordless device.
	deviceUsage proto.DeviceUsage
}
// verifyMFARespAndAddDevice validates the MFA register response in
// req.deviceResp and, on success, adds the new MFA device for req.username.
//
// It fails with BadParameter when the cluster has second factors disabled,
// when the register response is missing, or when the response is of an
// unknown type. Errors from the TOTP/Webauthn registration are returned
// wrapped. On success an MFADeviceAdd audit event is emitted; a failure to
// emit it is only logged.
func (a *Server) verifyMFARespAndAddDevice(ctx context.Context, req *newMFADeviceFields) (*types.MFADevice, error) {
	authPref, err := a.GetAuthPreference(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if authPref.GetSecondFactor() == constants.SecondFactorOff {
		return nil, trace.BadParameter("second factor disabled by cluster configuration")
	}

	// deviceResp may come straight from a client request (AddMFADeviceSync
	// passes req.GetNewMFAResponse()), so it can be nil. Reject it explicitly
	// rather than dereferencing a nil pointer below.
	if req.deviceResp == nil {
		return nil, trace.BadParameter("missing MFA register response")
	}

	var dev *types.MFADevice
	switch req.deviceResp.GetResponse().(type) {
	case *proto.MFARegisterResponse_TOTP:
		dev, err = a.registerTOTPDevice(ctx, req.deviceResp, req)
		if err != nil {
			return nil, trace.Wrap(err)
		}
	case *proto.MFARegisterResponse_Webauthn:
		dev, err = a.registerWebauthnDevice(ctx, req.deviceResp, req)
		if err != nil {
			return nil, trace.Wrap(err)
		}
	default:
		return nil, trace.BadParameter("MFARegisterResponse is an unknown response type %T", req.deviceResp.GetResponse())
	}

	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if err := a.emitter.EmitAuditEvent(ctx, &apievents.MFADeviceAdd{
		Metadata: apievents.Metadata{
			Type:        events.MFADeviceAddEvent,
			Code:        events.MFADeviceAddEventCode,
			ClusterName: clusterName.GetClusterName(),
		},
		UserMetadata: apievents.UserMetadata{
			User: req.username,
		},
		MFADeviceMetadata: mfaDeviceEventMetadata(dev),
	}); err != nil {
		log.WithError(err).Warn("Failed to emit add mfa device event.")
	}

	return dev, nil
}
// registerTOTPDevice verifies the TOTP code in regResp against the device's
// shared secret and, on success, stores the new TOTP device for req.username.
//
// The secret is taken from the user token secrets stored under req.tokenID
// when set, otherwise from req.totpSecret; if neither is provided a
// BadParameter error is returned. The cluster must allow TOTP second factors.
func (a *Server) registerTOTPDevice(ctx context.Context, regResp *proto.MFARegisterResponse, req *newMFADeviceFields) (*types.MFADevice, error) {
	authPref, err := a.GetAuthPreference(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if !authPref.IsSecondFactorTOTPAllowed() {
		return nil, trace.BadParameter("second factor TOTP not allowed by cluster")
	}

	var secret string
	if req.tokenID != "" {
		tokenSecrets, err := a.GetUserTokenSecrets(ctx, req.tokenID)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		secret = tokenSecrets.GetOTPKey()
	} else if req.totpSecret != "" {
		secret = req.totpSecret
	} else {
		return nil, trace.BadParameter("missing TOTP secret")
	}

	device, err := services.NewTOTPDevice(req.newDeviceName, secret, a.clock.Now())
	if err != nil {
		return nil, trace.Wrap(err)
	}

	code := regResp.GetTOTP().GetCode()
	if err := a.checkTOTP(ctx, req.username, code, device); err != nil {
		return nil, trace.Wrap(err)
	}
	if err := a.UpsertMFADevice(ctx, req.username, device); err != nil {
		return nil, trace.Wrap(err)
	}
	return device, nil
}
// registerWebauthnDevice completes a Webauthn registration for req.username
// using the credential creation response in regResp.
//
// The cluster must allow Webauthn second factors. The registration identity
// is req.webIdentityOverride when supplied, otherwise the Server's services.
// The flow's Finish call upserts the device on success.
func (a *Server) registerWebauthnDevice(ctx context.Context, regResp *proto.MFARegisterResponse, req *newMFADeviceFields) (*types.MFADevice, error) {
	authPref, err := a.GetAuthPreference(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if !authPref.IsSecondFactorWebauthnAllowed() {
		return nil, trace.BadParameter("second factor webauthn not allowed by cluster")
	}
	webConfig, err := authPref.GetWebauthn()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Override Identity, if supplied.
	var identity wanlib.RegistrationIdentity = a.Services
	if req.webIdentityOverride != nil {
		identity = req.webIdentityOverride
	}
	flow := &wanlib.RegistrationFlow{
		Webauthn: webConfig,
		Identity: identity,
	}

	registration := wanlib.RegisterResponse{
		User:             req.username,
		DeviceName:       req.newDeviceName,
		CreationResponse: wanlib.CredentialCreationResponseFromProto(regResp.GetWebauthn()),
		Passwordless:     req.deviceUsage == proto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS,
	}
	// Finish upserts the device on success.
	device, err := flow.Finish(ctx, registration)
	return device, trace.Wrap(err)
}
// GetWebSession returns the existing web session described by req.
// The call is delegated explicitly to Services because Cache implements the
// same method directly.
func (a *Server) GetWebSession(ctx context.Context, req types.GetWebSessionRequest) (types.WebSession, error) {
	session, err := a.Services.GetWebSession(ctx, req)
	return session, err
}
// GetWebToken returns the existing web token described by req.
// The call is delegated explicitly to Services because Cache implements the
// same method directly.
func (a *Server) GetWebToken(ctx context.Context, req types.GetWebTokenRequest) (types.WebToken, error) {
	token, err := a.Services.GetWebToken(ctx, req)
	return token, err
}
// ExtendWebSession creates a new web session for a user based on a valid previous (current) session.
//
// If there is an approved access request, additional roles are appended to the roles that were
// extracted from identity. The new session expiration time will not exceed the expiration time
// of the previous session, nor the expiry of the access request.
//
// If there is a switchback request, the roles will switchback to user's default roles and
// the expiration time is derived from users recently logged in time.
//
// The previous session's login time is preserved on the new session.
func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identity tlsca.Identity) (types.WebSession, error) {
	prevSession, err := a.GetWebSession(ctx, types.GetWebSessionRequest{
		User:      req.User,
		SessionID: req.PrevSessionID,
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// consider absolute expiry time that may be set for this session
	// by some external identity service, so we can not renew this session
	// anymore without extra logic for renewal with external OIDC provider
	expiresAt := prevSession.GetExpiryTime()
	if !expiresAt.IsZero() && expiresAt.Before(a.clock.Now().UTC()) {
		return nil, trace.NotFound("web session has expired")
	}

	accessInfo, err := services.AccessInfoFromLocalIdentity(identity, a)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	roles := accessInfo.Roles
	traits := accessInfo.Traits
	allowedResourceIDs := accessInfo.AllowedResourceIDs
	accessRequests := identity.ActiveRequests

	if req.AccessRequestID != "" {
		accessRequest, err := a.getValidatedAccessRequest(ctx, req.User, req.AccessRequestID)
		if err != nil {
			return nil, trace.Wrap(err)
		}

		roles = append(roles, accessRequest.GetRoles()...)
		roles = apiutils.Deduplicate(roles)
		accessRequests = apiutils.Deduplicate(append(accessRequests, req.AccessRequestID))

		if len(accessRequest.GetRequestedResourceIDs()) > 0 {
			// There's not a consistent way to merge multiple resource access
			// requests, a user may be able to request access to different resources
			// with different roles which should not overlap.
			if len(allowedResourceIDs) > 0 {
				return nil, trace.BadParameter("user is already logged in with a resource access request, cannot assume another")
			}
			allowedResourceIDs = accessRequest.GetRequestedResourceIDs()
		}

		// Let session expire with the shortest expiry time. A zero expiresAt
		// means the previous session had no absolute expiry; the access
		// request's expiry must still bound the elevated session, otherwise
		// a zero (default) TTL would be used for it.
		if accessExpiry := accessRequest.GetAccessExpiry(); expiresAt.IsZero() || expiresAt.After(accessExpiry) {
			expiresAt = accessExpiry
		}
	}

	if req.Switchback {
		if prevSession.GetLoginTime().IsZero() {
			return nil, trace.BadParameter("Unable to switchback, log in time was not recorded.")
		}

		// Get default/static roles.
		user, err := a.GetUser(req.User, false)
		if err != nil {
			return nil, trace.Wrap(err, "failed to switchback")
		}

		// Reset any search-based access requests
		allowedResourceIDs = nil

		// Calculate expiry time.
		roleSet, err := services.FetchRoles(user.GetRoles(), a, user.GetTraits())
		if err != nil {
			return nil, trace.Wrap(err)
		}

		sessionTTL := roleSet.AdjustSessionTTL(apidefaults.CertDuration)

		// Set default roles and expiration.
		expiresAt = prevSession.GetLoginTime().UTC().Add(sessionTTL)
		roles = user.GetRoles()
		accessRequests = nil
	}

	sessionTTL := utils.ToTTL(a.clock, expiresAt)
	sess, err := a.NewWebSession(ctx, types.NewWebSessionRequest{
		User:                 req.User,
		Roles:                roles,
		Traits:               traits,
		SessionTTL:           sessionTTL,
		AccessRequests:       accessRequests,
		RequestedResourceIDs: allowedResourceIDs,
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Keep preserving the login time.
	sess.SetLoginTime(prevSession.GetLoginTime())

	if err := a.upsertWebSession(ctx, req.User, sess); err != nil {
		return nil, trace.Wrap(err)
	}

	return sess, nil
}
// getValidatedAccessRequest fetches the access request accessRequestID owned
// by user and checks that it can be assumed.
//
// It fails with NotFound when no such request exists, with AccessDenied when
// the request is denied or still awaiting approval, with the validation error
// when services.ValidateAccessRequestForUser rejects it, and with
// BadParameter when its access expiry is already in the past.
func (a *Server) getValidatedAccessRequest(ctx context.Context, user, accessRequestID string) (types.AccessRequest, error) {
	found, err := a.GetAccessRequests(ctx, types.AccessRequestFilter{
		User: user,
		ID:   accessRequestID,
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if len(found) == 0 {
		return nil, trace.NotFound("access request %q not found", accessRequestID)
	}
	accessReq := found[0]

	switch state := accessReq.GetState(); {
	case state.IsApproved():
		// Approved requests may proceed to validation.
	case state.IsDenied():
		return nil, trace.AccessDenied("access request %q has been denied", accessRequestID)
	default:
		return nil, trace.AccessDenied("access request %q is awaiting approval", accessRequestID)
	}

	if err := services.ValidateAccessRequestForUser(ctx, a, accessReq); err != nil {
		return nil, trace.Wrap(err)
	}

	if accessReq.GetAccessExpiry().Before(a.GetClock().Now()) {
		return nil, trace.BadParameter("access request %q has expired", accessRequestID)
	}
	return accessReq, nil
}
// CreateWebSession creates and stores a new web session for user without any
// authentication checks; it is intended for administrative use.
//
// The session carries the user's stored roles and traits and records the
// current time as its login time.
func (a *Server) CreateWebSession(ctx context.Context, user string) (types.WebSession, error) {
	storedUser, err := a.GetUser(user, false)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	sessionReq := types.NewWebSessionRequest{
		User:      user,
		Roles:     storedUser.GetRoles(),
		Traits:    storedUser.GetTraits(),
		LoginTime: a.clock.Now().UTC(),
	}
	session, err := a.NewWebSession(ctx, sessionReq)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	if err = a.upsertWebSession(ctx, user, session); err != nil {
		return nil, trace.Wrap(err)
	}
	return session, nil
}
// GenerateToken generates and stores a multi-purpose provisioning token.
//
// The token expires after req.TTL, or defaults.ProvisioningTokenTTL when no
// TTL is given. If req.Token is empty a random token is generated and written
// back into req.Token. req.Labels, when present, become the token's metadata
// labels. When the token grants the trusted cluster role, a
// TrustedClusterTokenCreate audit event is emitted (emission failures are only
// logged). The token value is returned.
func (a *Server) GenerateToken(ctx context.Context, req *proto.GenerateTokenRequest) (string, error) {
	ttl := defaults.ProvisioningTokenTTL
	if req.TTL != 0 {
		ttl = req.TTL.Get()
	}
	expires := a.clock.Now().UTC().Add(ttl)

	if req.Token == "" {
		generated, err := utils.CryptoRandomHex(TokenLenBytes)
		if err != nil {
			return "", trace.Wrap(err)
		}
		req.Token = generated
	}

	provisionToken, err := types.NewProvisionToken(req.Token, req.Roles, expires)
	if err != nil {
		return "", trace.Wrap(err)
	}
	if len(req.Labels) > 0 {
		md := provisionToken.GetMetadata()
		md.Labels = req.Labels
		provisionToken.SetMetadata(md)
	}

	if err := a.UpsertToken(ctx, provisionToken); err != nil {
		return "", trace.Wrap(err)
	}

	userMetadata := ClientUserMetadata(ctx)
	for _, role := range req.Roles {
		if role != types.RoleTrustedCluster {
			continue
		}
		tokenEvent := &apievents.TrustedClusterTokenCreate{
			Metadata: apievents.Metadata{
				Type: events.TrustedClusterTokenCreateEvent,
				Code: events.TrustedClusterTokenCreateCode,
			},
			UserMetadata: userMetadata,
		}
		if err := a.emitter.EmitAuditEvent(ctx, tokenEvent); err != nil {
			log.WithError(err).Warn("Failed to emit trusted cluster token create event.")
		}
	}

	return req.Token, nil
}
// ExtractHostID returns the host ID part of hostName, which is expected to be
// of the form "<hostID>.<clusterName>". A BadParameter error is returned when
// hostName does not end with "."+clusterName.
func ExtractHostID(hostName string, clusterName string) (string, error) {
	suffix := "." + clusterName
	hostID := strings.TrimSuffix(hostName, suffix)
	// suffix is never empty, so an unchanged length means it was not present.
	if len(hostID) == len(hostName) {
		return "", trace.BadParameter("expected suffix %q in %q", suffix, hostName)
	}
	return hostID, nil
}
// HostFQDN returns the host's fully qualified name: the host UUID and the
// cluster name joined with a single dot ("<hostUUID>.<clusterName>").
func HostFQDN(hostUUID, clusterName string) string {
	const layout = "%s.%s"
	return fmt.Sprintf(layout, hostUUID, clusterName)
}
// GenerateHostCerts generates new host certificates (signed
// by the host certificate authority) for a node.
//
// The response carries an SSH host certificate for req.PublicSSHKey, an
// X.509 certificate for req.PublicTLSKey, and the host CA's TLS certificates
// and SSH checking keys. Requests are rate limited per system role. When the
// client's expected CA rotation state does not match the cached CA, the CA is
// re-read from the backend; a mismatch that persists is reported as a
// BadParameter error asking the client to re-register.
func (a *Server) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest) (*proto.Certs, error) {
	if err := req.CheckAndSetDefaults(); err != nil {
		return nil, trace.Wrap(err)
	}
	if err := req.Role.Check(); err != nil {
		return nil, trace.Wrap(err)
	}
	// Reject instance cert requests without system roles up front, before
	// taking a rate limiter slot or doing any signing work.
	if req.Role == types.RoleInstance && len(req.SystemRoles) == 0 {
		return nil, trace.BadParameter("cannot generate instance cert with no system roles")
	}

	if err := a.limiter.AcquireConnection(req.Role.String()); err != nil {
		generateThrottledRequestsCount.Inc()
		log.Debugf("Node %q [%v] is rate limited: %v.", req.NodeName, req.HostID, req.Role)
		return nil, trace.Wrap(err)
	}
	defer a.limiter.ReleaseConnection(req.Role.String())

	// Only observe latencies for non-throttled requests. The observation must
	// be wrapped in a closure: arguments to a deferred call are evaluated when
	// the defer statement executes, so deferring
	// Observe(time.Since(start).Seconds()) directly would always record ~0s.
	// The elapsed time is measured with the same clock that produced start.
	start := a.clock.Now()
	defer func() {
		generateRequestsLatencies.Observe(a.clock.Since(start).Seconds())
	}()
	generateRequestsCount.Inc()
	generateRequestsCurrent.Inc()
	defer generateRequestsCurrent.Dec()

	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// If the request contains 0.0.0.0, this implies an advertise IP was not
	// specified on the node. Try and guess what the address by replacing 0.0.0.0
	// with the RemoteAddr as known to the Auth Server.
	if apiutils.SliceContainsStr(req.AdditionalPrincipals, defaults.AnyAddress) {
		remoteHost, err := utils.Host(req.RemoteAddr)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		req.AdditionalPrincipals = utils.ReplaceInSlice(
			req.AdditionalPrincipals,
			defaults.AnyAddress,
			remoteHost)
	}

	if _, _, _, _, err := ssh.ParseAuthorizedKey(req.PublicSSHKey); err != nil {
		return nil, trace.BadParameter("failed to parse SSH public key")
	}
	cryptoPubKey, err := tlsca.ParsePublicKeyPEM(req.PublicTLSKey)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// get the certificate authority that will be signing the public key of the host,
	client := a.Cache
	if req.NoCache {
		client = a.Services
	}
	ca, err := client.GetCertAuthority(ctx, types.CertAuthID{
		Type:       types.HostCA,
		DomainName: clusterName.GetClusterName(),
	}, true)
	if err != nil {
		return nil, trace.BadParameter("failed to load host CA for %q: %v", clusterName.GetClusterName(), err)
	}

	// could be a couple of scenarios, either client data is out of sync,
	// or auth server is out of sync, either way, for now check that
	// cache is out of sync, this will result in higher read rate
	// to the backend, which is a fine tradeoff
	if !req.NoCache && req.Rotation != nil && !req.Rotation.Matches(ca.GetRotation()) {
		log.Debugf("Client sent rotation state %v, cache state is %v, using state from the DB.", req.Rotation, ca.GetRotation())
		ca, err = a.Services.GetCertAuthority(ctx, types.CertAuthID{
			Type:       types.HostCA,
			DomainName: clusterName.GetClusterName(),
		}, true)
		if err != nil {
			return nil, trace.BadParameter("failed to load host CA for %q: %v", clusterName.GetClusterName(), err)
		}
		if !req.Rotation.Matches(ca.GetRotation()) {
			return nil, trace.BadParameter(""+
				"the client expected state is out of sync, server rotation state: %v, "+
				"client rotation state: %v, re-register the client from scratch to fix the issue.",
				ca.GetRotation(), req.Rotation)
		}
	}

	isAdminRole := req.Role == types.RoleAdmin

	cert, signer, err := a.keyStore.GetTLSCertAndSigner(ca)
	if trace.IsNotFound(err) && isAdminRole {
		// If there is no local TLS signer found in the host CA ActiveKeys, this
		// auth server may have a newly configured HSM and has only populated
		// local keys in the AdditionalTrustedKeys until the next CA rotation.
		// This is the only case where we should be able to get a signer from
		// AdditionalTrustedKeys but not ActiveKeys.
		cert, signer, err = a.keyStore.GetAdditionalTrustedTLSCertAndSigner(ca)
	}
	if err != nil {
		return nil, trace.Wrap(err)
	}
	tlsAuthority, err := tlsca.FromCertAndSigner(cert, signer)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	caSigner, err := a.keyStore.GetSSHSigner(ca)
	if trace.IsNotFound(err) && isAdminRole {
		// If there is no local SSH signer found in the host CA ActiveKeys, this
		// auth server may have a newly configured HSM and has only populated
		// local keys in the AdditionalTrustedKeys until the next CA rotation.
		// This is the only case where we should be able to get a signer from
		// AdditionalTrustedKeys but not ActiveKeys.
		caSigner, err = a.keyStore.GetAdditionalTrustedSSHSigner(ca)
	}
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// generate host SSH certificate
	hostSSHCert, err := a.generateHostCert(services.HostCertParams{
		CASigner:      caSigner,
		PublicHostKey: req.PublicSSHKey,
		HostID:        req.HostID,
		NodeName:      req.NodeName,
		ClusterName:   clusterName.GetClusterName(),
		Role:          req.Role,
		Principals:    req.AdditionalPrincipals,
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	systemRoles := make([]string, 0, len(req.SystemRoles))
	for _, r := range req.SystemRoles {
		systemRoles = append(systemRoles, string(r))
	}

	// generate host TLS certificate
	identity := tlsca.Identity{
		Username:        HostFQDN(req.HostID, clusterName.GetClusterName()),
		Groups:          []string{req.Role.String()},
		TeleportCluster: clusterName.GetClusterName(),
		SystemRoles:     systemRoles,
	}
	subject, err := identity.Subject()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	certRequest := tlsca.CertificateRequest{
		Clock:     a.clock,
		PublicKey: cryptoPubKey,
		Subject:   subject,
		NotAfter:  a.clock.Now().UTC().Add(defaults.CATTL),
		DNSNames:  append([]string{}, req.AdditionalPrincipals...),
	}

	// API requests need to specify a DNS name, which must be present in the certificate's DNS Names.
	// The target DNS is not always known in advance, so we add a default one to all certificates.
	certRequest.DNSNames = append(certRequest.DNSNames, DefaultDNSNamesForRole(req.Role)...)
	// Unlike additional principals, DNS Names is x509 specific and is limited
	// to services with TLS endpoints (e.g. auth, proxies, kubernetes)
	if (types.SystemRoles{req.Role}).IncludeAny(types.RoleAuth, types.RoleAdmin, types.RoleProxy, types.RoleKube, types.RoleWindowsDesktop) {
		certRequest.DNSNames = append(certRequest.DNSNames, req.DNSNames...)
	}
	hostTLSCert, err := tlsAuthority.GenerateCertificate(certRequest)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return &proto.Certs{
		SSH:        hostSSHCert,
		TLS:        hostTLSCert,
		TLSCACerts: services.GetTLSCerts(ca),
		SSHCACerts: services.GetSSHCheckingKeys(ca),
	}, nil
}
// UnstableAssertSystemRole is not a stable part of the public API. Older
// instances use it to prove that they hold a given system role; the assertion
// itself is delegated to the server's unstable service.
// DELETE IN: 12.0 (deprecated in v11, but required for back-compat with v10 clients)
func (a *Server) UnstableAssertSystemRole(ctx context.Context, req proto.UnstableSystemRoleAssertion) error {
	err := a.unstable.AssertSystemRole(ctx, req)
	return trace.Wrap(err)
}
// UnstableGetSystemRoleAssertions is not a stable part of the public API. It
// returns the system role assertions recorded by the server's unstable
// service for the given server ID and assertion ID.
func (a *Server) UnstableGetSystemRoleAssertions(ctx context.Context, serverID string, assertionID string) (proto.UnstableSystemRoleAssertionSet, error) {
	assertions, err := a.unstable.GetSystemRoleAssertions(ctx, serverID, assertionID)
	if err != nil {
		return assertions, trace.Wrap(err)
	}
	return assertions, nil
}
// RegisterInventoryControlStream sends the downstream hello on an upstream
// inventory control stream and then registers the stream with the inventory
// controller.
//
// The upstream hello has already been received and checked at the rbac layer.
// The downstream hello is only sent from here so that in-memory streams for
// local auth are simple to create. (In theory both hellos could be sent
// simultaneously for a slight perf gain, but sending the downstream hello here
// lets it double as an indicator that the stream made it through the rbac
// layer.)
func (a *Server) RegisterInventoryControlStream(ics client.UpstreamInventoryControlStream, hello proto.UpstreamInventoryHello) error {
	reply := proto.DownstreamInventoryHello{
		ServerID: a.ServerID,
		Version:  teleport.Version,
	}
	if err := ics.Send(a.CloseContext(), reply); err != nil {
		return trace.Wrap(err)
	}
	a.inventory.RegisterControlStream(ics, hello)
	return nil
}
// MakeLocalInventoryControlStream sets up an in-memory control stream which automatically registers with this auth
// server upon hello exchange.
//
// It returns the downstream half of a pipe created by client.InventoryControlStreamPipe (opts are passed through).
// A background goroutine performs a single select on the upstream half and exits after whichever happens first:
//   - a message arrives: it must be a proto.UpstreamInventoryHello, which is handed to
//     RegisterInventoryControlStream. Any other message type, or a registration error, closes the upstream half
//     with that error.
//   - the upstream half is done: the goroutine simply exits.
//   - this server's close context is done: the upstream half is closed.
func (a *Server) MakeLocalInventoryControlStream(opts ...client.ICSPipeOption) client.DownstreamInventoryControlStream {
	upstream, downstream := client.InventoryControlStreamPipe(opts...)
	go func() {
		select {
		case msg := <-upstream.Recv():
			// The first message on a fresh stream is expected to be the upstream hello.
			hello, ok := msg.(proto.UpstreamInventoryHello)
			if !ok {
				upstream.CloseWithError(trace.BadParameter("expected upstream hello, got: %T", msg))
				return
			}
			// Registration sends the downstream hello back over the same pipe.
			if err := a.RegisterInventoryControlStream(upstream, hello); err != nil {
				upstream.CloseWithError(err)
				return
			}
		case <-upstream.Done():
			// Stream already finished before a hello arrived; nothing to clean up here.
		case <-a.CloseContext().Done():
			// Server is shutting down before the hello exchange completed.
			upstream.Close()
		}
	}()
	return downstream
}
// GetInventoryStatus returns an inventory status summary. When req.Connected
// is set, the summary lists the upstream hello of every control stream the
// inventory controller iterates over; otherwise an empty summary is returned.
func (a *Server) GetInventoryStatus(ctx context.Context, req proto.InventoryStatusRequest) proto.InventoryStatusSummary {
	var summary proto.InventoryStatusSummary
	if !req.Connected {
		return summary
	}
	a.inventory.Iter(func(handle inventory.UpstreamHandle) {
		summary.Connected = append(summary.Connected, handle.Hello())
	})
	return summary
}
// PingInventory pings the instance identified by req.ServerID over its
// inventory control stream and returns the duration reported by that ping.
// A NotFound error is returned when no control stream is registered for the
// server.
func (a *Server) PingInventory(ctx context.Context, req proto.InventoryPingRequest) (proto.InventoryPingResponse, error) {
	var rsp proto.InventoryPingResponse
	handle, found := a.inventory.GetControlStream(req.ServerID)
	if !found {
		return rsp, trace.NotFound("no control stream found for server %q", req.ServerID)
	}
	elapsed, err := handle.Ping(ctx)
	if err != nil {
		return rsp, trace.Wrap(err)
	}
	rsp.Duration = elapsed
	return rsp, nil
}
const (
	// TokenExpiredOrNotFound is the message the auth server returns when a
	// provisioning token is either past its TTL or could not be found; both
	// cases share this single message.
	TokenExpiredOrNotFound = "token expired or not found"
)
// ValidateToken takes a provisioning token value and checks whether it is
// valid, returning the matching token (which carries the roles and labels it
// grants).
//
// Static tokens are checked first using a constant-time comparison and are
// returned immediately, since they have no expiration. Otherwise the token is
// looked up in the backend and its TTL is checked. A token that is missing or
// expired yields an AccessDenied error carrying TokenExpiredOrNotFound.
func (a *Server) ValidateToken(ctx context.Context, token string) (types.ProvisionToken, error) {
	staticTokens, err := a.GetStaticTokens()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	tokenBytes := []byte(token)
	for _, candidate := range staticTokens.GetStaticTokens() {
		if subtle.ConstantTimeCompare([]byte(candidate.GetName()), tokenBytes) == 1 {
			return candidate, nil
		}
	}

	// Not a static token: look for an ephemeral token in the backend and
	// make sure it has not expired.
	provisionToken, err := a.GetToken(ctx, token)
	switch {
	case trace.IsNotFound(err):
		return nil, trace.AccessDenied(TokenExpiredOrNotFound)
	case err != nil:
		return nil, trace.Wrap(err)
	case !a.checkTokenTTL(provisionToken):
		return nil, trace.AccessDenied(TokenExpiredOrNotFound)
	}
	return provisionToken, nil
}
// checkTokenTTL reports whether tok is still valid at the server's current
// time. When the token has expired, it is deleted from the backend and false
// is returned; deletion failures other than NotFound are logged as warnings.
func (a *Server) checkTokenTTL(tok types.ProvisionToken) bool {
	ctx := context.TODO()
	now := a.clock.Now().UTC()
	if !tok.Expiry().Before(now) {
		return true
	}
	if err := a.DeleteToken(ctx, tok.GetName()); err != nil && !trace.IsNotFound(err) {
		log.Warnf("Unable to delete token from backend: %v.", err)
	}
	return false
}
// DeleteToken removes the token with the given value.
//
// Statically configured tokens cannot be removed and yield a BadParameter
// error (the token value is masked in the message). Otherwise deletion is
// first attempted as a user token and then as a provisioning (node) token; if
// both attempts fail, the error from the provisioning token deletion is
// returned.
func (a *Server) DeleteToken(ctx context.Context, token string) (err error) {
	staticTokens, err := a.GetStaticTokens()
	if err != nil {
		return trace.Wrap(err)
	}
	for _, candidate := range staticTokens.GetStaticTokens() {
		if subtle.ConstantTimeCompare([]byte(candidate.GetName()), []byte(token)) == 1 {
			return trace.BadParameter("token %s is statically configured and cannot be removed", backend.MaskKeyName(token))
		}
	}
	if err = a.DeleteUserToken(ctx, token); err == nil {
		return nil
	}
	if err = a.Services.DeleteToken(ctx, token); err != nil {
		return trace.Wrap(err)
	}
	return nil
}
// GetTokens returns all tokens: provisioning tokens stored in the backend,
// statically configured tokens (when present), and user tokens. Machine tokens
// usually have node roles such as auth, proxy or node; each user token is
// converted into a provision token that carries only the signup role.
// The opts argument is currently unused.
func (a *Server) GetTokens(ctx context.Context, opts ...services.MarshalOption) (tokens []types.ProvisionToken, err error) {
	backendTokens, err := a.Services.GetTokens(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	tokens = backendTokens

	staticTokens, err := a.GetStaticTokens()
	switch {
	case err == nil:
		tokens = append(tokens, staticTokens.GetStaticTokens()...)
	case !trace.IsNotFound(err):
		return nil, trace.Wrap(err)
	}

	userTokens, err := a.GetUserTokens(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	for _, userToken := range userTokens {
		converted, err := types.NewProvisionToken(userToken.GetName(), types.SystemRoles{types.RoleSignup}, userToken.Expiry())
		if err != nil {
			return nil, trace.Wrap(err)
		}
		tokens = append(tokens, converted)
	}
	return tokens, nil
}
// NewWebSession creates and returns a new web session for the specified
// request; the session is not persisted here.
//
// A fresh key pair is generated, and user certificates are issued for it
// using an access checker built from the request's roles, traits and
// requested resource IDs. When req.SessionTTL is zero, the TTL comes from the
// checker's adjustment of apidefaults.CertDuration. The session ID and bearer
// token are random values from utils.CryptoRandomHex. The bearer token
// expires after the smaller of the session TTL and BearerTokenTTL. Both
// expiries are measured from req.LoginTime when it is set, otherwise from the
// current time.
func (a *Server) NewWebSession(ctx context.Context, req types.NewWebSessionRequest) (types.WebSession, error) {
	user, err := a.GetUser(req.User, false)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	accessInfo := &services.AccessInfo{
		Roles:              req.Roles,
		Traits:             req.Traits,
		AllowedResourceIDs: req.RequestedResourceIDs,
	}
	checker, err := services.NewAccessChecker(accessInfo, clusterName.GetClusterName(), a)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	networkingConfig, err := a.GetClusterNetworkingConfig(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	privateKey, publicKey, err := native.GenerateKeyPair()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	ttl := req.SessionTTL
	if ttl == 0 {
		ttl = checker.AdjustSessionTTL(apidefaults.CertDuration)
	}
	certs, err := a.generateUserCert(certRequest{
		user:           user,
		ttl:            ttl,
		publicKey:      publicKey,
		checker:        checker,
		traits:         req.Traits,
		activeRequests: services.RequestIDs{AccessRequests: req.AccessRequests},
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	sessionID, err := utils.CryptoRandomHex(SessionTokenBytes)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	bearerToken, err := utils.CryptoRandomHex(SessionTokenBytes)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Expiries are anchored at the login time when one was supplied.
	anchor := req.LoginTime
	if anchor.IsZero() {
		anchor = a.clock.Now()
	}
	anchor = anchor.UTC()

	spec := types.WebSessionSpecV2{
		User:               req.User,
		Priv:               privateKey,
		Pub:                certs.SSH,
		TLSCert:            certs.TLS,
		Expires:            anchor.Add(ttl),
		BearerToken:        bearerToken,
		BearerTokenExpires: anchor.Add(utils.MinTTL(ttl, BearerTokenTTL)),
		LoginTime:          req.LoginTime,
		IdleTimeout:        types.Duration(networkingConfig.GetWebIdleTimeout()),
	}
	UserLoginCount.Inc()

	webSession, err := types.NewWebSession(sessionID, types.KindWebSession, spec)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return webSession, nil
}
// GetWebSessionInfo returns the web session identified by sessionID for the
// given user, with its authentication secrets removed (via WithoutSecrets).
// Implements auth.WebUIService
func (a *Server) GetWebSessionInfo(ctx context.Context, user, sessionID string) (types.WebSession, error) {
	lookup := types.GetWebSessionRequest{
		User:      user,
		SessionID: sessionID,
	}
	webSession, err := a.GetWebSession(ctx, lookup)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return webSession.WithoutSecrets(), nil
}
// DeleteNamespace deletes the named namespace. The default namespace can never
// be deleted, and a namespace that still has registered nodes is rejected
// with a BadParameter error.
func (a *Server) DeleteNamespace(namespace string) error {
	if namespace == apidefaults.Namespace {
		return trace.AccessDenied("can't delete default namespace")
	}
	nodes, err := a.GetNodes(context.TODO(), namespace)
	if err != nil {
		return trace.Wrap(err)
	}
	if count := len(nodes); count != 0 {
		return trace.BadParameter("can't delete namespace %v that has %v registered nodes", namespace, count)
	}
	return a.Services.DeleteNamespace(namespace)
}
// CreateAccessRequest validates and stores a new access request, then emits an
// AccessRequestCreate audit event. A failure to emit the event is logged only.
//
// The request's creation time is set to now. Its access expiry is replaced by
// now plus the maximum TTL allowed by the requested roles (see
// calculateMaxAccessTTL) whenever it lies in the past or beyond that maximum.
// The resource expiry defaults to the access expiry. For pending requests it
// is capped at now plus defaults.PendingAccessDuration, to limit orphaned
// access requests. Dry-run requests are validated and populated but never
// stored.
func (a *Server) CreateAccessRequest(ctx context.Context, req types.AccessRequest) error {
	// Variable expansion must be applied to requests in the pending state.
	expandVars := services.ExpandVars(req.GetState().IsPending())
	if err := services.ValidateAccessRequestForUser(ctx, a, req, expandVars); err != nil {
		return trace.Wrap(err)
	}

	maxTTL, err := a.calculateMaxAccessTTL(ctx, req)
	if err != nil {
		return trace.Wrap(err)
	}

	now := a.clock.Now().UTC()
	req.SetCreationTime(now)

	// Set access expiry if an allowable value was not provided.
	maxExpiry := now.Add(maxTTL)
	if accessExpiry := req.GetAccessExpiry(); accessExpiry.Before(now) || accessExpiry.After(maxExpiry) {
		req.SetAccessExpiry(maxExpiry)
	}

	// By default, the resource expiry matches the access expiry.
	req.SetExpiry(req.GetAccessExpiry())

	// A pending request's resource expiry is capped to PendingAccessDuration.
	if req.GetState().IsPending() {
		if pendingExpiry := now.Add(defaults.PendingAccessDuration); pendingExpiry.Before(req.Expiry()) {
			req.SetExpiry(pendingExpiry)
		}
	}

	// Everything above succeeded; a dry run stops before creating the request.
	if req.GetDryRun() {
		return nil
	}

	if err := a.Services.CreateAccessRequest(ctx, req); err != nil {
		return trace.Wrap(err)
	}

	event := &apievents.AccessRequestCreate{
		Metadata: apievents.Metadata{
			Type: events.AccessRequestCreateEvent,
			Code: events.AccessRequestCreateCode,
		},
		UserMetadata: ClientUserMetadataWithUser(ctx, req.GetUser()),
		ResourceMetadata: apievents.ResourceMetadata{
			Expires: req.GetAccessExpiry(),
		},
		Roles:                req.GetRoles(),
		RequestedResourceIDs: types.EventResourceIDs(req.GetRequestedResourceIDs()),
		RequestID:            req.GetName(),
		RequestState:         req.GetState().String(),
		Reason:               req.GetRequestReason(),
	}
	if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil {
		log.WithError(err).Warn("Failed to emit access request create event.")
	}
	return nil
}
// DeleteAccessRequest deletes the named access request and emits an
// AccessRequestDelete audit event. A failure to emit the event is logged but
// does not fail the call.
func (a *Server) DeleteAccessRequest(ctx context.Context, name string) error {
	if err := a.Services.DeleteAccessRequest(ctx, name); err != nil {
		return trace.Wrap(err)
	}
	event := &apievents.AccessRequestDelete{
		Metadata: apievents.Metadata{
			Type: events.AccessRequestDeleteEvent,
			Code: events.AccessRequestDeleteCode,
		},
		UserMetadata: ClientUserMetadata(ctx),
		RequestID:    name,
	}
	err := a.emitter.EmitAuditEvent(ctx, event)
	if err != nil {
		log.WithError(err).Warn("Failed to emit access request delete event.")
	}
	return nil
}
// SetAccessRequestState updates the state of an access request and emits an
// AccessRequestUpdate audit event describing the change.
//
// The event records the requesting user (UpdatedBy), the new state, reason,
// roles, the delegator from the context when present, and any annotations
// (annotations that fail to encode are skipped with a debug log).
//
// The state change is committed before the event is emitted, so a failure to
// emit the event is logged but not returned. Returning it would tell the
// caller the update failed when it actually succeeded. CreateAccessRequest,
// DeleteAccessRequest and SubmitAccessReview already handle emit failures
// this way.
func (a *Server) SetAccessRequestState(ctx context.Context, params types.AccessRequestUpdate) error {
	req, err := a.Services.SetAccessRequestState(ctx, params)
	if err != nil {
		return trace.Wrap(err)
	}
	event := &apievents.AccessRequestCreate{
		Metadata: apievents.Metadata{
			Type: events.AccessRequestUpdateEvent,
			Code: events.AccessRequestUpdateCode,
		},
		ResourceMetadata: apievents.ResourceMetadata{
			UpdatedBy: ClientUsername(ctx),
			Expires:   req.GetAccessExpiry(),
		},
		RequestID:    params.RequestID,
		RequestState: params.State.String(),
		Reason:       params.Reason,
		Roles:        params.Roles,
	}

	if delegator := apiutils.GetDelegator(ctx); delegator != "" {
		event.Delegator = delegator
	}

	if len(params.Annotations) > 0 {
		annotations, err := apievents.EncodeMapStrings(params.Annotations)
		if err != nil {
			log.WithError(err).Debug("Failed to encode access request annotations.")
		} else {
			event.Annotations = annotations
		}
	}
	if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil {
		log.WithError(err).Warn("Failed to emit access request update event.")
	}
	return nil
}
// SubmitAccessReview applies an access review to an access request and emits
// an AccessRequestReview audit event, returning the updated request.
//
// The review author must have at least one allow directive for reviews.
// Final permission checks and application of the review are done by
// ApplyAccessReview, because their validity depends on optimistic locking.
// A failure to emit the audit event is logged only.
func (a *Server) SubmitAccessReview(ctx context.Context, params types.AccessReviewSubmission) (types.AccessRequest, error) {
	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Set up a checker for the review author and bail out early when the
	// author has no allow directives at all.
	reviewChecker, err := services.NewReviewPermissionChecker(ctx, a, params.Review.Author)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if !reviewChecker.HasAllowDirectives() {
		return nil, trace.AccessDenied("user %q cannot submit reviews", params.Review.Author)
	}

	updated, err := a.ApplyAccessReview(ctx, params, reviewChecker)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	event := &apievents.AccessRequestCreate{
		Metadata: apievents.Metadata{
			Type:        events.AccessRequestReviewEvent,
			Code:        events.AccessRequestReviewCode,
			ClusterName: clusterName.GetClusterName(),
		},
		ResourceMetadata: apievents.ResourceMetadata{
			Expires: updated.GetAccessExpiry(),
		},
		RequestID:     params.RequestID,
		RequestState:  updated.GetState().String(),
		ProposedState: params.Review.ProposedState.String(),
		Reason:        params.Review.Reason,
		Reviewer:      params.Review.Author,
	}
	if rawAnnotations := params.Review.Annotations; len(rawAnnotations) > 0 {
		if encoded, err := apievents.EncodeMapStrings(rawAnnotations); err != nil {
			log.WithError(err).Debugf("Failed to encode access request annotations.")
		} else {
			event.Annotations = encoded
		}
	}

	if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil {
		log.WithError(err).Warn("Failed to emit access request update event.")
	}
	return updated, nil
}
// GetAccessCapabilities computes the access capabilities described by req via
// services.CalculateAccessCapabilities, using this server as the data source.
func (a *Server) GetAccessCapabilities(ctx context.Context, req types.AccessCapabilitiesRequest) (*types.AccessCapabilities, error) {
	capabilities, err := services.CalculateAccessCapabilities(ctx, a, req)
	if err == nil {
		return capabilities, nil
	}
	return nil, trace.Wrap(err)
}
// calculateMaxAccessTTL determines the maximum allowable TTL for a given
// access request from the MaxSessionTTLs of the requested roles. An access
// request's life cannot exceed the smallest positive MaxSessionTTL among those
// roles, and never exceeds defaults.MaxAccessDuration.
func (a *Server) calculateMaxAccessTTL(ctx context.Context, req types.AccessRequest) (time.Duration, error) {
	maxTTL := defaults.MaxAccessDuration
	for _, name := range req.GetRoles() {
		role, err := a.GetRole(ctx, name)
		if err != nil {
			return 0, trace.Wrap(err)
		}
		sessionTTL := time.Duration(role.GetOptions().MaxSessionTTL)
		// Unset (non-positive) limits and limits no tighter than the current
		// bound leave the bound unchanged.
		if sessionTTL <= 0 || sessionTTL >= maxTTL {
			continue
		}
		maxTTL = sessionTTL
	}
	return maxTTL, nil
}
// NewKeepAliver returns a new keep-aliver backed by this server. Its context
// is a cancelable child of ctx, and a goroutine started here runs
// forwardKeepAlives on the returned instance.
func (a *Server) NewKeepAliver(ctx context.Context) (types.KeepAliver, error) {
	keepAliveCtx, cancel := context.WithCancel(ctx)
	aliver := &authKeepAliver{
		a:           a,
		ctx:         keepAliveCtx,
		cancel:      cancel,
		keepAlivesC: make(chan types.KeepAlive),
	}
	go aliver.forwardKeepAlives()
	return aliver, nil
}
// GenerateCertAuthorityCRL generates an empty CRL for the local CA of a given
// type, signed with that CA's TLS signer. The CRL is valid from one minute
// before the server's current time (to tolerate clock skew) until one year
// after it. The result is DER-encoded, as returned by
// x509.CreateRevocationList.
func (a *Server) GenerateCertAuthorityCRL(ctx context.Context, caType types.CertAuthType) ([]byte, error) {
	// Generate a CRL for the current cluster CA.
	clusterName, err := a.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	ca, err := a.GetCertAuthority(ctx, types.CertAuthID{
		Type:       caType,
		DomainName: clusterName.GetClusterName(),
	}, true)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// TODO(awly): this will only create a CRL for an active signer.
	// If there are multiple signers (multiple HSMs), we won't have the full CRL coverage.
	// Generate a CRL per signer and return all of them separately.

	cert, signer, err := a.keyStore.GetTLSCertAndSigner(ca)
	if trace.IsNotFound(err) {
		// If there is no local TLS signer found in the CA's ActiveKeys, this
		// auth server may have a newly configured HSM and has only populated
		// local keys in the AdditionalTrustedKeys until the next CA rotation.
		// This is the only case where we should be able to get a signer from
		// AdditionalTrustedKeys but not ActiveKeys.
		cert, signer, err = a.keyStore.GetAdditionalTrustedTLSCertAndSigner(ca)
	}
	if err != nil {
		return nil, trace.Wrap(err)
	}
	tlsAuthority, err := tlsca.FromCertAndSigner(cert, signer)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Empty CRL valid for 1yr. Read the server clock once so both bounds are
	// derived from the same instant and follow the clock the rest of the auth
	// server uses (which is also what tests control).
	now := a.clock.Now()
	template := &x509.RevocationList{
		Number:     big.NewInt(1),
		ThisUpdate: now.Add(-1 * time.Minute), // 1 min in the past to account for clock skew.
		NextUpdate: now.Add(365 * 24 * time.Hour),
	}
	crl, err := x509.CreateRevocationList(rand.Reader, template, tlsAuthority.Cert, tlsAuthority.Signer)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return crl, nil
}
var (
	// ErrDone indicates that resource iteration is complete. A callback passed
	// to IterateResources may return it to stop iteration without an error.
	ErrDone = errors.New("done iterating")
)
// IterateResources loads all resources matching the provided request, one page
// at a time via ListResources, and passes them one by one to f.
//
// To stop early, f may return ErrDone (or an error wrapping it, as matched by
// errors.Is), in which case IterateResources returns nil. Any other error from
// f stops iteration and is returned. Iteration ends normally once a page comes
// back without a NextKey.
func (a *Server) IterateResources(ctx context.Context, req proto.ListResourcesRequest, f func(resource types.ResourceWithLabels) error) error {
	for {
		page, err := a.ListResources(ctx, req)
		if err != nil {
			return trace.Wrap(err)
		}
		for _, resource := range page.Resources {
			switch err := f(resource); {
			case err == nil:
				// Keep going.
			case errors.Is(err, ErrDone):
				return nil
			default:
				return trace.Wrap(err)
			}
		}
		if page.NextKey == "" {
			return nil
		}
		req.StartKey = page.NextKey
	}
}
// CreateAuditStream creates an audit event stream for session sid using the
// streamer selected by the session recording mode (see modeStreamer).
func (a *Server) CreateAuditStream(ctx context.Context, sid session.ID) (apievents.Stream, error) {
	streamer, err := a.modeStreamer(ctx)
	if err == nil {
		return streamer.CreateAuditStream(ctx, sid)
	}
	return nil, trace.Wrap(err)
}
// ResumeAuditStream resumes a previously created audit stream for session sid
// and upload uploadID, using the streamer selected by the session recording
// mode (see modeStreamer).
func (a *Server) ResumeAuditStream(ctx context.Context, sid session.ID, uploadID string) (apievents.Stream, error) {
	streamer, err := a.modeStreamer(ctx)
	if err == nil {
		return streamer.ResumeAuditStream(ctx, sid, uploadID)
	}
	return nil, trace.Wrap(err)
}
// modeStreamer returns the streamer matching the cluster's session recording
// mode.
//
// In sync mode, the auth server forwards session control events to the event
// log in addition to sending them and data events to the record storage, so
// the streamer is tee'd with the emitter. In async mode, clients submit
// session control events during the session and also write a local recording
// that is uploaded at the end. Forwarding events here would then produce
// duplicates, so the plain streamer is returned.
func (a *Server) modeStreamer(ctx context.Context) (events.Streamer, error) {
	recConfig, err := a.GetSessionRecordingConfig(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if !services.IsRecordSync(recConfig.GetMode()) {
		return a.streamer, nil
	}
	return events.NewTeeStreamer(a.streamer, a.emitter), nil
}
// CreateApp creates a new application resource and emits an AppCreate audit
// event describing it. A failure to emit the event is logged but does not fail
// the call.
func (a *Server) CreateApp(ctx context.Context, app types.Application) error {
	if err := a.Services.CreateApp(ctx, app); err != nil {
		return trace.Wrap(err)
	}
	event := &apievents.AppCreate{
		Metadata: apievents.Metadata{
			Type: events.AppCreateEvent,
			Code: events.AppCreateCode,
		},
		UserMetadata: ClientUserMetadata(ctx),
		ResourceMetadata: apievents.ResourceMetadata{
			Name:    app.GetName(),
			Expires: app.Expiry(),
		},
		AppMetadata: apievents.AppMetadata{
			AppURI:        app.GetURI(),
			AppPublicAddr: app.GetPublicAddr(),
			AppLabels:     app.GetStaticLabels(),
		},
	}
	err := a.emitter.EmitAuditEvent(ctx, event)
	if err != nil {
		log.WithError(err).Warn("Failed to emit app create event.")
	}
	return nil
}
// UpdateApp updates an existing application resource and emits an AppUpdate
// audit event describing it. A failure to emit the event is logged but does
// not fail the call.
func (a *Server) UpdateApp(ctx context.Context, app types.Application) error {
	if err := a.Services.UpdateApp(ctx, app); err != nil {
		return trace.Wrap(err)
	}
	event := &apievents.AppUpdate{
		Metadata: apievents.Metadata{
			Type: events.AppUpdateEvent,
			Code: events.AppUpdateCode,
		},
		UserMetadata: ClientUserMetadata(ctx),
		ResourceMetadata: apievents.ResourceMetadata{
			Name:    app.GetName(),
			Expires: app.Expiry(),
		},
		AppMetadata: apievents.AppMetadata{
			AppURI:        app.GetURI(),
			AppPublicAddr: app.GetPublicAddr(),
			AppLabels:     app.GetStaticLabels(),
		},
	}
	err := a.emitter.EmitAuditEvent(ctx, event)
	if err != nil {
		log.WithError(err).Warn("Failed to emit app update event.")
	}
	return nil
}
// DeleteApp deletes the named application resource and emits an AppDelete
// audit event. A failure to emit the event is logged but does not fail the
// call.
func (a *Server) DeleteApp(ctx context.Context, name string) error {
	if err := a.Services.DeleteApp(ctx, name); err != nil {
		return trace.Wrap(err)
	}
	event := &apievents.AppDelete{
		Metadata: apievents.Metadata{
			Type: events.AppDeleteEvent,
			Code: events.AppDeleteCode,
		},
		UserMetadata: ClientUserMetadata(ctx),
		ResourceMetadata: apievents.ResourceMetadata{
			Name: name,
		},
	}
	err := a.emitter.EmitAuditEvent(ctx, event)
	if err != nil {
		log.WithError(err).Warn("Failed to emit app delete event.")
	}
	return nil
}
// CreateSessionTracker creates a tracker resource for an active session.
// Trackers whose host policy sets require moderation (non-empty
// RequireSessionJoin) are rejected with AccessDenied unless the
// ModeratedSessions feature is enabled.
func (a *Server) CreateSessionTracker(ctx context.Context, tracker types.SessionTracker) (types.SessionTracker, error) {
	for _, policySet := range tracker.GetHostPolicySets() {
		if len(policySet.RequireSessionJoin) == 0 {
			continue
		}
		if !modules.GetModules().Features().ModeratedSessions {
			return nil, trace.AccessDenied("this Teleport cluster is not licensed for moderated sessions, please contact the cluster administrator")
		}
	}
	return a.Services.CreateSessionTracker(ctx, tracker)
}
// CreateDatabase creates a new database resource and emits a DatabaseCreate
// audit event describing it (protocol, URI, static labels, and AWS/GCP
// identifiers). A failure to emit the event is logged but does not fail the
// call.
func (a *Server) CreateDatabase(ctx context.Context, database types.Database) error {
	if err := a.Services.CreateDatabase(ctx, database); err != nil {
		return trace.Wrap(err)
	}
	awsMeta, gcpMeta := database.GetAWS(), database.GetGCP()
	event := &apievents.DatabaseCreate{
		Metadata: apievents.Metadata{
			Type: events.DatabaseCreateEvent,
			Code: events.DatabaseCreateCode,
		},
		UserMetadata: ClientUserMetadata(ctx),
		ResourceMetadata: apievents.ResourceMetadata{
			Name:    database.GetName(),
			Expires: database.Expiry(),
		},
		DatabaseMetadata: apievents.DatabaseMetadata{
			DatabaseProtocol:             database.GetProtocol(),
			DatabaseURI:                  database.GetURI(),
			DatabaseLabels:               database.GetStaticLabels(),
			DatabaseAWSRegion:            awsMeta.Region,
			DatabaseAWSRedshiftClusterID: awsMeta.Redshift.ClusterID,
			DatabaseGCPProjectID:         gcpMeta.ProjectID,
			DatabaseGCPInstanceID:        gcpMeta.InstanceID,
		},
	}
	err := a.emitter.EmitAuditEvent(ctx, event)
	if err != nil {
		log.WithError(err).Warn("Failed to emit database create event.")
	}
	return nil
}
// UpdateDatabase updates an existing database resource and emits a
// DatabaseUpdate audit event describing it (protocol, URI, static labels, and
// AWS/GCP identifiers). A failure to emit the event is logged but does not
// fail the call.
func (a *Server) UpdateDatabase(ctx context.Context, database types.Database) error {
	if err := a.Services.UpdateDatabase(ctx, database); err != nil {
		return trace.Wrap(err)
	}
	awsMeta, gcpMeta := database.GetAWS(), database.GetGCP()
	event := &apievents.DatabaseUpdate{
		Metadata: apievents.Metadata{
			Type: events.DatabaseUpdateEvent,
			Code: events.DatabaseUpdateCode,
		},
		UserMetadata: ClientUserMetadata(ctx),
		ResourceMetadata: apievents.ResourceMetadata{
			Name:    database.GetName(),
			Expires: database.Expiry(),
		},
		DatabaseMetadata: apievents.DatabaseMetadata{
			DatabaseProtocol:             database.GetProtocol(),
			DatabaseURI:                  database.GetURI(),
			DatabaseLabels:               database.GetStaticLabels(),
			DatabaseAWSRegion:            awsMeta.Region,
			DatabaseAWSRedshiftClusterID: awsMeta.Redshift.ClusterID,
			DatabaseGCPProjectID:         gcpMeta.ProjectID,
			DatabaseGCPInstanceID:        gcpMeta.InstanceID,
		},
	}
	err := a.emitter.EmitAuditEvent(ctx, event)
	if err != nil {
		log.WithError(err).Warn("Failed to emit database update event.")
	}
	return nil
}
// DeleteDatabase deletes the named database resource and emits a
// DatabaseDelete audit event. A failure to emit the event is logged but does
// not fail the call.
func (a *Server) DeleteDatabase(ctx context.Context, name string) error {
	if err := a.Services.DeleteDatabase(ctx, name); err != nil {
		return trace.Wrap(err)
	}
	event := &apievents.DatabaseDelete{
		Metadata: apievents.Metadata{
			Type: events.DatabaseDeleteEvent,
			Code: events.DatabaseDeleteCode,
		},
		UserMetadata: ClientUserMetadata(ctx),
		ResourceMetadata: apievents.ResourceMetadata{
			Name: name,
		},
	}
	err := a.emitter.EmitAuditEvent(ctx, event)
	if err != nil {
		log.WithError(err).Warn("Failed to emit database delete event.")
	}
	return nil
}
// ListResources returns paginated resources depending on the resource type.
//
// Every resource kind except Windows desktops is served by the cache. Because
// WindowsDesktopService does not contain the desktop resources, desktops are
// not handled at the cache level. For KindWindowsDesktop the request is
// translated into a ListWindowsDesktops call, and the desktops are converted
// back into a generic resource page.
func (a *Server) ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) {
	if req.ResourceType != types.KindWindowsDesktop {
		return a.Cache.ListResources(ctx, req)
	}
	desktopReq := types.ListWindowsDesktopsRequest{
		WindowsDesktopFilter: req.WindowsDesktopFilter,
		Limit:                int(req.Limit),
		StartKey:             req.StartKey,
		PredicateExpression:  req.PredicateExpression,
		Labels:               req.Labels,
		SearchKeywords:       req.SearchKeywords,
	}
	desktopPage, err := a.ListWindowsDesktops(ctx, desktopReq)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return &types.ListResourcesResponse{
		Resources: types.WindowsDesktops(desktopPage.Desktops).AsResources(),
		NextKey:   desktopPage.NextKey,
	}, nil
}
// isMFARequired reports whether a per-session MFA check is needed for the
// user represented by checker to access the target described in req.
//
// MFA is always required when the cluster auth preference has
// RequireSessionMFA set. Otherwise the target (node, Kubernetes cluster,
// database or Windows desktop) is resolved and checker.CheckAccess decides:
// only services.ErrSessionMFARequired results in Required=true. Any other
// access error is masked as Required=false so that this call cannot be used
// as an oracle for resource names; the real connection attempt enforces
// access.
//
// Missing request fields yield BadParameter. Unknown Kubernetes clusters,
// databases and Windows desktops yield NotFound. Unknown nodes yield
// Required=false (see the comment in the node case).
func (a *Server) isMFARequired(ctx context.Context, checker services.AccessChecker, req *proto.IsMFARequiredRequest) (*proto.IsMFARequiredResponse, error) {
	pref, err := a.GetAuthPreference(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if pref.GetRequireSessionMFA() {
		// Cluster always requires MFA, regardless of roles.
		return &proto.IsMFARequiredResponse{Required: true}, nil
	}

	var noMFAAccessErr, notFoundErr error
	switch t := req.Target.(type) {
	case *proto.IsMFARequiredRequest_Node:
		if t.Node.Node == "" {
			return nil, trace.BadParameter("empty Node field")
		}
		if t.Node.Login == "" {
			return nil, trace.BadParameter("empty Login field")
		}
		// Find the target node and check whether MFA is required.
		nodes, err := a.GetNodes(ctx, apidefaults.Namespace)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		var matches []types.Server
		for _, n := range nodes {
			// Get the server address without port number.
			addr, _, err := net.SplitHostPort(n.GetAddr())
			if err != nil {
				addr = n.GetAddr()
			}
			// Match NodeName to UUID, hostname or self-reported server address.
			if n.GetName() == t.Node.Node || n.GetHostname() == t.Node.Node || addr == t.Node.Node {
				matches = append(matches, n)
			}
		}
		if len(matches) == 0 {
			// If t.Node.Node is not a known registered node, it may be an
			// unregistered host running OpenSSH with a certificate created via
			// `tctl auth sign`. In these cases, let the user through without
			// extra checks.
			//
			// If t.Node.Node turns out to be an alias for a real node (e.g.
			// private network IP), and MFA check was actually required, the
			// Node itself will check the cert extensions and reject the
			// connection.
			return &proto.IsMFARequiredResponse{Required: false}, nil
		}
		// Check RBAC against every matching node and stop at the first one
		// that reports ErrSessionMFARequired, so that MFA is required if any
		// of the ambiguous matches requires it.
		for _, n := range matches {
			err := checker.CheckAccess(
				n,
				services.AccessMFAParams{},
				services.NewLoginMatcher(t.Node.Login),
			)
			// Ignore other errors; they'll be caught on the real access attempt.
			if err != nil && errors.Is(err, services.ErrSessionMFARequired) {
				noMFAAccessErr = err
				break
			}
		}

	case *proto.IsMFARequiredRequest_KubernetesCluster:
		notFoundErr = trace.NotFound("kubernetes cluster %q not found", t.KubernetesCluster)
		if t.KubernetesCluster == "" {
			return nil, trace.BadParameter("missing KubernetesCluster field in a kubernetes-only UserCertsRequest")
		}
		// Find the target cluster and check whether MFA is required.
		svcs, err := a.GetKubernetesServers(ctx)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		var cluster types.KubeCluster
		for _, svc := range svcs {
			kubeCluster := svc.GetCluster()
			if kubeCluster.GetName() == t.KubernetesCluster {
				cluster = kubeCluster
				break
			}
		}
		if cluster == nil {
			return nil, trace.Wrap(notFoundErr)
		}

		noMFAAccessErr = checker.CheckAccess(cluster, services.AccessMFAParams{})

	case *proto.IsMFARequiredRequest_Database:
		notFoundErr = trace.NotFound("database service %q not found", t.Database.ServiceName)
		if t.Database.ServiceName == "" {
			return nil, trace.BadParameter("missing ServiceName field in a database-only UserCertsRequest")
		}
		servers, err := a.GetDatabaseServers(ctx, apidefaults.Namespace)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		var db types.Database
		for _, server := range servers {
			if server.GetDatabase().GetName() == t.Database.ServiceName {
				db = server.GetDatabase()
				break
			}
		}
		if db == nil {
			return nil, trace.Wrap(notFoundErr)
		}

		dbRoleMatchers := role.DatabaseRoleMatchers(
			db.GetProtocol(),
			t.Database.Username,
			t.Database.GetDatabase(),
		)
		noMFAAccessErr = checker.CheckAccess(
			db,
			services.AccessMFAParams{},
			dbRoleMatchers...,
		)

	case *proto.IsMFARequiredRequest_WindowsDesktop:
		desktopName := t.WindowsDesktop.GetWindowsDesktop()
		// An empty name would turn the filter below into "match everything"
		// and the MFA decision would be made against an arbitrary desktop.
		if desktopName == "" {
			return nil, trace.BadParameter("missing WindowsDesktop field in a windows-desktop-only UserCertsRequest")
		}
		desktops, err := a.GetWindowsDesktops(ctx, types.WindowsDesktopFilter{Name: desktopName})
		if err != nil {
			return nil, trace.Wrap(err)
		}
		if len(desktops) == 0 {
			return nil, trace.NotFound("windows desktop %q not found", desktopName)
		}

		noMFAAccessErr = checker.CheckAccess(desktops[0],
			services.AccessMFAParams{},
			services.NewWindowsLoginMatcher(t.WindowsDesktop.GetLogin()))

	default:
		return nil, trace.BadParameter("unknown Target %T", req.Target)
	}

	// No error means that MFA is not required for this resource by
	// AccessChecker.
	if noMFAAccessErr == nil {
		return &proto.IsMFARequiredResponse{Required: false}, nil
	}
	// Errors other than ErrSessionMFARequired mean something else is wrong,
	// most likely access denied.
	if !errors.Is(noMFAAccessErr, services.ErrSessionMFARequired) {
		if !trace.IsAccessDenied(noMFAAccessErr) {
			log.WithError(noMFAAccessErr).Warn("Could not determine MFA access")
		}

		// Mask the access denied errors by returning false to prevent resource
		// name oracles. Auth will be denied (and generate an audit log entry)
		// when the client attempts to connect.
		return &proto.IsMFARequiredResponse{Required: false}, nil
	}
	// If we reach here, the error from AccessChecker was
	// ErrSessionMFARequired.

	return &proto.IsMFARequiredResponse{Required: true}, nil
}
// mfaAuthChallenge constructs an MFAAuthenticateChallenge for all MFA devices
// registered by the user.
func (a *Server) mfaAuthChallenge(ctx context.Context, user string, passwordless bool) (*proto.MFAAuthenticateChallenge, error) {
// Check what kind of MFA is enabled.
apref, err := a.GetAuthPreference(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
enableTOTP := apref.IsSecondFactorTOTPAllowed()
enableWebauthn := apref.IsSecondFactorWebauthnAllowed()
// Fetch configurations. The IsSecondFactor*Allowed calls above already
// include the necessary checks of config empty, disabled, etc.
var u2fPref *types.U2F
switch val, err := apref.GetU2F(); {
case trace.IsNotFound(err): // OK, may happen.
case err != nil: // NOK, unexpected.
return nil, trace.Wrap(err)
default:
u2fPref = val
}
var webConfig *types.Webauthn
switch val, err := apref.GetWebauthn(); {
case trace.IsNotFound(err): // OK, may happen.
case err != nil: // NOK, unexpected.
return nil, trace.Wrap(err)
default:
webConfig = val
}
// Handle passwordless separately, it works differently from MFA.
if passwordless {
if !enableWebauthn {
return nil, trace.BadParameter("passwordless requires WebAuthn")
}
webLogin := &wanlib.PasswordlessFlow{
Webauthn: webConfig,
Identity: a.Services,
}
assertion, err := webLogin.Begin(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return &proto.MFAAuthenticateChallenge{
WebauthnChallenge: wanlib.CredentialAssertionToProto(assertion),
}, nil
}
// User required for non-passwordless.
if user == "" {
return nil, trace.BadParameter("user required")
}
devs, err := a.Services.GetMFADevices(ctx, user, true /* withSecrets */)
if err != nil {
return nil, trace.Wrap(err)
}
groupedDevs := groupByDeviceType(devs, enableWebauthn)
challenge := &proto.MFAAuthenticateChallenge{}
// TOTP challenge.
if enableTOTP && groupedDevs.TOTP {
challenge.TOTP = &proto.TOTPChallenge{}
}
// WebAuthn challenge.
if len(groupedDevs.Webauthn) > 0 {
webLogin := &wanlib.LoginFlow{
U2F: u2fPref,
Webauthn: webConfig,
Identity: wanlib.WithDevices(a.Services, groupedDevs.Webauthn),
}
assertion, err := webLogin.Begin(ctx, user)
if err != nil {
return nil, trace.Wrap(err)
}
challenge.WebauthnChallenge = wanlib.CredentialAssertionToProto(assertion)
}
return challenge, nil
}
// devicesByType is the result of groupByDeviceType.
type devicesByType struct {
	// TOTP is true when at least one TOTP device was seen.
	TOTP bool
	// Webauthn collects U2F and WebAuthn devices, but only when grouping
	// them was requested.
	Webauthn []*types.MFADevice
}

// groupByDeviceType sorts devs by device kind. U2F and WebAuthn devices are
// collected together in Webauthn only when groupWebauthn is true. Devices of
// an unrecognized kind are skipped with a warning.
func groupByDeviceType(devs []*types.MFADevice, groupWebauthn bool) devicesByType {
	var grouped devicesByType
	for _, d := range devs {
		switch d.Device.(type) {
		case *types.MFADevice_Totp:
			grouped.TOTP = true
		case *types.MFADevice_U2F, *types.MFADevice_Webauthn:
			if groupWebauthn {
				grouped.Webauthn = append(grouped.Webauthn, d)
			}
		default:
			log.Warningf("Skipping MFA device of unknown type %T.", d.Device)
		}
	}
	return grouped
}
// validateMFAAuthResponse validates the answer to an MFA or passwordless
// challenge.
//
// It returns the device used to solve the challenge (if applicable) and the
// username. For passwordless logins the username is taken from the WebAuthn
// assertion rather than from the user argument.
//
// user must be non-empty unless passwordless is set. Passwordless logins only
// accept WebAuthn responses, mirroring mfaAuthChallenge, which refuses to
// issue passwordless challenges without WebAuthn. A failed WebAuthn
// verification is reported as AccessDenied. Unknown response types yield
// BadParameter.
func (a *Server) validateMFAAuthResponse(
	ctx context.Context,
	resp *proto.MFAAuthenticateResponse, user string, passwordless bool,
) (*types.MFADevice, string, error) {
	// Sanity check user/passwordless.
	if user == "" && !passwordless {
		return nil, "", trace.BadParameter("user required")
	}

	switch res := resp.Response.(type) {
	// cases in order of preference
	case *proto.MFAAuthenticateResponse_Webauthn:
		// Read necessary configurations.
		authPref, err := a.GetAuthPreference(ctx)
		if err != nil {
			return nil, "", trace.Wrap(err)
		}
		u2f, err := authPref.GetU2F()
		switch {
		case trace.IsNotFound(err): // OK, may happen.
		case err != nil: // Unexpected.
			return nil, "", trace.Wrap(err)
		}
		webConfig, err := authPref.GetWebauthn()
		if err != nil {
			return nil, "", trace.Wrap(err)
		}

		// Convert the response once and reuse it for either flow.
		assertionResp := wanlib.CredentialAssertionResponseFromProto(res.Webauthn)
		var dev *types.MFADevice
		if passwordless {
			webLogin := &wanlib.PasswordlessFlow{
				Webauthn: webConfig,
				Identity: a.Services,
			}
			// The passwordless flow identifies the user from the assertion.
			dev, user, err = webLogin.Finish(ctx, assertionResp)
		} else {
			webLogin := &wanlib.LoginFlow{
				U2F:      u2f,
				Webauthn: webConfig,
				Identity: a.Services,
			}
			dev, err = webLogin.Finish(ctx, user, assertionResp)
		}
		if err != nil {
			return nil, "", trace.AccessDenied("MFA response validation failed: %v", err)
		}
		return dev, user, nil

	case *proto.MFAAuthenticateResponse_TOTP:
		// TOTP cannot identify a user on its own, so it must never satisfy a
		// passwordless login.
		if passwordless {
			return nil, "", trace.BadParameter("passwordless requires WebAuthn")
		}
		dev, err := a.checkOTP(user, res.TOTP.GetCode())
		return dev, user, trace.Wrap(err)

	default:
		return nil, "", trace.BadParameter("unknown or missing MFAAuthenticateResponse type %T", resp.Response)
	}
}
// upsertWebSession stores session and then a web token whose bearer token,
// user and expiry are copied from the session.
//
// NOTE(review): the user argument is not read; the token's user comes from
// session.GetUser().
func (a *Server) upsertWebSession(ctx context.Context, user string, session types.WebSession) error {
	if err := a.WebSessions().Upsert(ctx, session); err != nil {
		return trace.Wrap(err)
	}

	expires := session.GetBearerTokenExpiryTime()
	spec := types.WebTokenSpecV3{
		User:  session.GetUser(),
		Token: session.GetBearerToken(),
	}
	webToken, err := types.NewWebToken(expires, spec)
	if err != nil {
		return trace.Wrap(err)
	}
	return trace.Wrap(a.WebTokens().Upsert(ctx, webToken))
}
// mergeKeySets returns a clone of a with every SSH, TLS and JWT key pair of b
// appended after a's own entries.
func mergeKeySets(a, b types.CAKeySet) types.CAKeySet {
	merged := a.Clone()
	merged.SSH = append(merged.SSH, b.SSH...)
	merged.TLS = append(merged.TLS, b.TLS...)
	merged.JWT = append(merged.JWT, b.JWT...)
	return merged
}
// addAddtionalTrustedKeysAtomic appends newKeys to the AdditionalTrustedKeys
// of currentCA using an optimistic CompareAndSwap loop.
//
// needsUpdate is re-evaluated against the latest version of the CA on every
// attempt. Once it returns false, nothing is written and nil is returned. On
// a compare failure (the CA was updated concurrently), the CA is re-read from
// the backend and the merge is retried. The loop stops early if either the
// caller's ctx or the server's closeCtx is done.
func (a *Server) addAddtionalTrustedKeysAtomic(
	ctx context.Context,
	currentCA types.CertAuthority,
	newKeys types.CAKeySet,
	needsUpdate func(types.CertAuthority) bool,
) error {
	for {
		select {
		case <-a.closeCtx.Done():
			return trace.Wrap(a.closeCtx.Err())
		case <-ctx.Done():
			// Previously only closeCtx was honored; a cancelled caller
			// could keep retrying CompareAndSwap, which takes no ctx.
			return trace.Wrap(ctx.Err())
		default:
		}
		if !needsUpdate(currentCA) {
			return nil
		}

		newCA := currentCA.Clone()
		mergedKeySet := mergeKeySets(newCA.GetAdditionalTrustedKeys(), newKeys)
		if err := newCA.SetAdditionalTrustedKeys(mergedKeySet); err != nil {
			return trace.Wrap(err)
		}

		err := a.CompareAndSwapCertAuthority(newCA, currentCA)
		switch {
		case err == nil:
			return nil
		case !trace.IsCompareFailed(err):
			return trace.Wrap(err)
		}

		// The CA was concurrently updated: reload it and try again.
		currentCA, err = a.Services.GetCertAuthority(ctx, currentCA.GetID(), true)
		if err != nil {
			return trace.Wrap(err)
		}
	}
}
// newKeySet generates a fresh set of keys for the CA type in caID:
//   - UserCA / HostCA: one SSH key pair and one TLS key pair,
//   - DatabaseCA: one TLS key pair only,
//   - JWTSigner: one JWT key pair.
//
// TLS key pairs are created for caID.DomainName. Any other CA type yields
// BadParameter.
// Keep this function in sync with lib/service/suite/suite.go:NewTestCAWithConfig().
func newKeySet(keyStore keystore.KeyStore, caID types.CertAuthID) (types.CAKeySet, error) {
	var ks types.CAKeySet
	switch caType := caID.Type; caType {
	case types.UserCA, types.HostCA:
		sshPair, err := keyStore.NewSSHKeyPair()
		if err != nil {
			return ks, trace.Wrap(err)
		}
		tlsPair, err := keyStore.NewTLSKeyPair(caID.DomainName)
		if err != nil {
			return ks, trace.Wrap(err)
		}
		ks.SSH = append(ks.SSH, sshPair)
		ks.TLS = append(ks.TLS, tlsPair)
		return ks, nil
	case types.DatabaseCA:
		tlsPair, err := keyStore.NewTLSKeyPair(caID.DomainName)
		if err != nil {
			return ks, trace.Wrap(err)
		}
		ks.TLS = append(ks.TLS, tlsPair)
		return ks, nil
	case types.JWTSigner:
		jwtPair, err := keyStore.NewJWTKeyPair()
		if err != nil {
			return ks, trace.Wrap(err)
		}
		ks.JWT = append(ks.JWT, jwtPair)
		return ks, nil
	default:
		return ks, trace.BadParameter("unknown ca type: %s", caType)
	}
}
// ensureLocalAdditionalKeys makes sure ca carries additional trusted keys
// usable by this auth server's key store. If the key store already reports
// such keys, it does nothing. Otherwise it generates a new key set for the
// CA and merges it in atomically, re-checking the condition on every retry.
func (a *Server) ensureLocalAdditionalKeys(ctx context.Context, ca types.CertAuthority) error {
	if a.keyStore.HasLocalAdditionalKeys(ca) {
		return nil
	}

	missingLocalKeys := func(candidate types.CertAuthority) bool {
		return !a.keyStore.HasLocalAdditionalKeys(candidate)
	}
	generated, err := newKeySet(a.keyStore, ca.GetID())
	if err != nil {
		return trace.Wrap(err)
	}
	if err := a.addAddtionalTrustedKeysAtomic(ctx, ca, generated, missingLocalKeys); err != nil {
		return trace.Wrap(err)
	}

	log.Infof("Successfully added local additional trusted keys to %s CA.", ca.GetType())
	return nil
}
// createSelfSignedCA generates fresh active keys for caID, builds a CA of
// type caID.Type for cluster caID.DomainName from them, and stores it via
// CreateCertAuthority.
func (a *Server) createSelfSignedCA(caID types.CertAuthID) error {
	activeKeys, err := newKeySet(a.keyStore, caID)
	if err != nil {
		return trace.Wrap(err)
	}

	spec := types.CertAuthoritySpecV2{
		Type:        caID.Type,
		ClusterName: caID.DomainName,
		ActiveKeys:  activeKeys,
	}
	authority, err := types.NewCertAuthority(spec)
	if err != nil {
		return trace.Wrap(err)
	}
	return trace.Wrap(a.CreateCertAuthority(authority))
}
// deleteUnusedKeys collects the private keys referenced by the active and
// additional trusted key sets of every local CA type, and asks the key store
// to delete any keys it holds that are not in that list (for example, keys
// in a connected HSM).
func (a *Server) deleteUnusedKeys(ctx context.Context) error {
	cn, err := a.Services.GetClusterName()
	if err != nil {
		return trace.Wrap(err)
	}
	domain := cn.GetClusterName()

	var inUse [][]byte
	collect := func(ks types.CAKeySet) {
		for _, p := range ks.SSH {
			inUse = append(inUse, p.PrivateKey)
		}
		for _, p := range ks.TLS {
			inUse = append(inUse, p.Key)
		}
		for _, p := range ks.JWT {
			inUse = append(inUse, p.PrivateKey)
		}
	}

	for _, caType := range types.CertAuthTypes {
		id := types.CertAuthID{Type: caType, DomainName: domain}
		ca, err := a.Services.GetCertAuthority(ctx, id, true)
		if err != nil {
			return trace.Wrap(err)
		}
		collect(ca.GetActiveKeys())
		collect(ca.GetAdditionalTrustedKeys())
	}
	return trace.Wrap(a.keyStore.DeleteUnusedKeys(inUse))
}
// authKeepAliver forwards keep-alives directly to an in-process auth Server.
// It exposes KeepAlives/Done/Close/Error (presumably to satisfy
// types.KeepAliver — verify against its constructor).
type authKeepAliver struct {
	sync.RWMutex
	a           *Server
	ctx         context.Context
	cancel      context.CancelFunc
	keepAlivesC chan types.KeepAlive
	// err records why the keep aliver was closed, if it failed; guarded by
	// the embedded RWMutex.
	err error
}

// KeepAlives returns a channel accepting keep alive requests
func (k *authKeepAliver) KeepAlives() chan<- types.KeepAlive {
	return k.keepAlivesC
}

// forwardKeepAlives relays keep-alives from keepAlivesC to the auth server
// until either the keep aliver or the auth server is closed. The first
// KeepAliveServer error closes the keep aliver and is reported by Error.
func (k *authKeepAliver) forwardKeepAlives() {
	for {
		select {
		case <-k.a.closeCtx.Done():
			k.Close()
			return
		case <-k.ctx.Done():
			return
		case keepAlive := <-k.keepAlivesC:
			err := k.a.KeepAliveServer(k.ctx, keepAlive)
			if err != nil {
				k.closeWithError(err)
				return
			}
		}
	}
}

// closeWithError records err and then closes the keep aliver.
//
// The error is stored before cancelling: observers of Done() typically call
// Error() as soon as the channel is closed, and must see the error.
// Cancelling first left a window where Error() still returned nil.
func (k *authKeepAliver) closeWithError(err error) {
	k.Lock()
	k.err = err
	k.Unlock()
	k.Close()
}

// Error returns the error that closed the keep aliver, or nil if it is still
// running or was closed without an error.
func (k *authKeepAliver) Error() error {
	k.RLock()
	defer k.RUnlock()
	return k.err
}

// Done returns a channel that is closed once the keep aliver is closed.
func (k *authKeepAliver) Done() <-chan struct{} {
	return k.ctx.Done()
}

// Close cancels the keep aliver's context, which stops forwardKeepAlives. It
// always returns nil.
func (k *authKeepAliver) Close() error {
	k.cancel()
	return nil
}
const (
	// BearerTokenTTL is how long a standard bearer token remains valid before
	// the client has to renew it.
	BearerTokenTTL time.Duration = 10 * time.Minute

	// TokenLenBytes is the length, in bytes, of an invite token.
	TokenLenBytes = 16

	// RecoveryTokenLenBytes is the length, in bytes, of a user token issued
	// for account recovery.
	RecoveryTokenLenBytes = 32

	// SessionTokenBytes is the length, in bytes, of a web or application
	// session token.
	SessionTokenBytes = 32
)
// oidcClient is an internal structure that pairs an OIDC client with the
// connector configuration it was built from.
type oidcClient struct {
	client    *oidc.Client
	connector types.OIDCConnector
	// syncCtx controls the provider sync goroutine.
	syncCtx    context.Context
	syncCancel context.CancelFunc
	// firstSync will be closed once the first provider sync succeeds
	firstSync chan struct{}
}

// samlProvider is an internal structure that pairs a SAML service provider
// with the connector configuration it was built from.
type samlProvider struct {
	provider  *saml2.SAMLServiceProvider
	connector types.SAMLConnector
}

// githubClient is an internal structure that pairs a GitHub OAuth 2 client
// with the OAuth 2 configuration it was built from.
type githubClient struct {
	client *oauth2.Client
	config oauth2.Config
}
// oauth2ConfigsEqual reports whether a and b have the same client ID and
// secret, redirect URL, auth and token URLs, auth method, and scopes. Scopes
// are compared element by element, so order matters.
func oauth2ConfigsEqual(a, b oauth2.Config) bool {
	sameFields := a.Credentials.ID == b.Credentials.ID &&
		a.Credentials.Secret == b.Credentials.Secret &&
		a.RedirectURL == b.RedirectURL &&
		a.AuthURL == b.AuthURL &&
		a.TokenURL == b.TokenURL &&
		a.AuthMethod == b.AuthMethod
	if !sameFields || len(a.Scope) != len(b.Scope) {
		return false
	}
	for i, scope := range a.Scope {
		if scope != b.Scope[i] {
			return false
		}
	}
	return true
}
// isHTTPS returns nil when u parses as a URL whose scheme is exactly "https".
// A parse failure is returned wrapped; any other scheme yields BadParameter.
func isHTTPS(u string) error {
	parsed, err := url.Parse(u)
	switch {
	case err != nil:
		return trace.Wrap(err)
	case parsed.Scheme != "https":
		return trace.BadParameter("expected scheme https, got %q", parsed.Scheme)
	default:
		return nil
	}
}
// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
// TLS config with client CAs pool of the specified cluster.
//
// The cluster is taken from the SNI server name when the client encodes one
// there; otherwise the pool returned by DefaultClientCertPool for an empty
// cluster name is used. If the resulting pool is too large to advertise in a
// TLS handshake, only currentClusterName's CA pool is used. When a pool cannot
// be built, the callback returns (nil, nil), which makes crypto/tls fall back
// to the original config.
func WithClusterCAs(tlsConfig *tls.Config, ap AccessCache, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
	return func(info *tls.ClientHelloInfo) (*tls.Config, error) {
		var clusterName string
		if info.ServerName != "" {
			// Newer clients will set SNI that encodes the cluster name.
			var err error
			clusterName, err = apiutils.DecodeClusterName(info.ServerName)
			if err != nil && !trace.IsNotFound(err) {
				log.WithError(err).Debugf("Ignoring unsupported cluster name %q.", info.ServerName)
				clusterName = ""
			}
		}
		pool, totalSubjectsLen, err := DefaultClientCertPool(ap, clusterName)
		if err != nil {
			log.WithError(err).Errorf("Failed to retrieve client pool for %q.", clusterName)
			// this falls back to the default config
			return nil, nil
		}

		// Per https://tools.ietf.org/html/rfc5246#section-7.4.4 the total size of
		// the known CA subjects sent to the client can't exceed 2^16-1 (due to
		// 2-byte length encoding). The crypto/tls stack will panic if this
		// happens.
		//
		// This usually happens on the root cluster with a very large (>500) number
		// of leaf clusters. In these cases, the client cert will be signed by the
		// current (root) cluster.
		//
		// If the number of CAs turns out too large for the handshake, drop all but
		// the current cluster CA. In the unlikely case where it's wrong, the
		// client will be rejected.
		if totalSubjectsLen >= int64(math.MaxUint16) {
			log.Debugf("Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.")

			pool, _, err = DefaultClientCertPool(ap, currentClusterName)
			if err != nil {
				log.WithError(err).Errorf("Failed to retrieve client pool for %q.", currentClusterName)
				// this falls back to the default config
				return nil, nil
			}
		}
		tlsCopy := tlsConfig.Clone()
		tlsCopy.ClientCAs = pool
		return tlsCopy, nil
	}
}
// DefaultDNSNamesForRole returns the default DNS names for role.
//
// The Auth, Admin, Proxy, Kube, App, Database and WindowsDesktop roles get
// the API domain and its wildcard ("*.<APIDomain>", "<APIDomain>"). Every
// other role gets nil.
func DefaultDNSNamesForRole(role types.SystemRole) []string {
	roles := types.SystemRoles{role}
	servesAPI := roles.IncludeAny(types.RoleAuth, types.RoleAdmin, types.RoleProxy, types.RoleKube, types.RoleApp, types.RoleDatabase, types.RoleWindowsDesktop)
	if !servesAPI {
		return nil
	}
	return []string{"*." + constants.APIDomain, constants.APIDomain}
}