Implment access-request system (workflow API)

This commit is contained in:
Forrest Marshall 2019-11-06 21:00:32 -08:00 committed by Forrest Marshall
parent a32468033a
commit ec327b6e03
32 changed files with 4725 additions and 425 deletions

View file

@ -373,6 +373,9 @@ const (
CertExtensionTeleportRouteToCluster = "teleport-route-to-cluster"
// CertExtensionTeleportTraits is used to propagate traits about the user.
CertExtensionTeleportTraits = "teleport-traits"
// CertExtensionTeleportActiveRequests is used to track which privilege
// escalation requests were used to construct the certificate.
CertExtensionTeleportActiveRequests = "teleport-active-requests"
)
const (

View file

@ -77,6 +77,9 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro
if cfg.Access == nil {
cfg.Access = local.NewAccessService(cfg.Backend)
}
if cfg.DynamicAccess == nil {
cfg.DynamicAccess = local.NewDynamicAccessService(cfg.Backend)
}
if cfg.ClusterConfiguration == nil {
cfg.ClusterConfiguration = local.NewClusterConfigurationService(cfg.Backend)
}
@ -111,6 +114,7 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro
Provisioner: cfg.Provisioner,
Identity: cfg.Identity,
Access: cfg.Access,
DynamicAccess: cfg.DynamicAccess,
ClusterConfiguration: cfg.ClusterConfiguration,
IAuditLog: cfg.AuditLog,
Events: cfg.Events,
@ -132,6 +136,7 @@ type AuthServices struct {
services.Provisioner
services.Identity
services.Access
services.DynamicAccess
services.ClusterConfiguration
services.Events
events.IAuditLog
@ -408,6 +413,9 @@ type certRequest struct {
routeToCluster string
// traits hold claim data used to populate a role at runtime.
traits wrappers.Traits
// activeRequests tracks privilege escalation requests applied
// during the construction of the certificate.
activeRequests services.RequestIDs
}
// GenerateUserTestCerts is used to generate user certificate, used internally for tests
@ -509,6 +517,7 @@ func (s *AuthServer) generateUserCert(req certRequest) (*certs, error) {
PermitAgentForwarding: req.checker.CanForwardAgents(),
RouteToCluster: req.routeToCluster,
Traits: req.traits,
ActiveRequests: req.activeRequests,
})
if err != nil {
return nil, trace.Wrap(err)
@ -1350,6 +1359,77 @@ func (a *AuthServer) DeleteRole(name string) error {
return a.Access.DeleteRole(name)
}
func (a *AuthServer) CreateAccessRequest(req services.AccessRequest) error {
if err := services.ValidateAccessRequest(a, req); err != nil {
return trace.Wrap(err)
}
ttl, err := a.calculateMaxAccessTTL(req)
if err != nil {
return trace.Wrap(err)
}
now := a.clock.Now().UTC()
req.SetCreationTime(now)
exp := now.Add(ttl)
// Set acccess expiry if an allowable default was not provided.
if req.GetAccessExpiry().Before(now) || req.GetAccessExpiry().After(exp) {
req.SetAccessExpiry(exp)
}
// By default, resource expiry should match access expiry.
req.SetExpiry(req.GetAccessExpiry())
// If the access-request is in a pending state, then the expiry of the underlying resource
// is capped to to PendingAccessDuration in order to limit orphaned access requests.
if req.GetState().IsPending() {
pexp := now.Add(defaults.PendingAccessDuration)
if pexp.Before(req.Expiry()) {
req.SetExpiry(pexp)
}
}
if err := a.DynamicAccess.CreateAccessRequest(req); err != nil {
return trace.Wrap(err)
}
err = a.EmitAuditEvent(events.AccessRequestCreated, events.EventFields{
events.AccessRequestID: req.GetName(),
events.EventUser: req.GetUser(),
events.UserRoles: req.GetRoles(),
events.AccessRequestState: req.GetState().String(),
})
return trace.Wrap(err)
}
func (a *AuthServer) SetAccessRequestState(reqID string, state services.RequestState, updatedBy ...string) error {
if err := a.DynamicAccess.SetAccessRequestState(reqID, state); err != nil {
return trace.Wrap(err)
}
u := "unknown"
if len(updatedBy) == 1 {
u = updatedBy[0]
}
err := a.EmitAuditEvent(events.AccessRequestUpdated, events.EventFields{
events.AccessRequestID: reqID,
events.AccessRequestState: state.String(),
events.AccessRequestUpdateBy: u,
})
return trace.Wrap(err)
}
// calculateMaxAccessTTL determines the maximum allowable TTL for a given access request
// based on the MaxSessionTTLs of the roles being requested (a access request's life cannot
// exceed the smallest allowable MaxSessionTTL value of the roles that it requests).
func (a *AuthServer) calculateMaxAccessTTL(req services.AccessRequest) (time.Duration, error) {
minTTL := defaults.MaxAccessDuration
for _, roleName := range req.GetRoles() {
role, err := a.GetRole(roleName)
if err != nil {
return 0, trace.Wrap(err)
}
roleTTL := time.Duration(role.GetOptions().MaxSessionTTL)
if roleTTL > 0 && roleTTL < minTTL {
minTTL = roleTTL
}
}
return minTTL, nil
}
// NewKeepAliver returns a new instance of keep aliver
func (a *AuthServer) NewKeepAliver(ctx context.Context) (services.KeepAliver, error) {
cancelCtx, cancel := context.WithCancel(ctx)

View file

@ -452,7 +452,16 @@ func (a *AuthWithRoles) NewWatcher(ctx context.Context, watch services.Watch) (s
return nil, trace.Wrap(err)
}
}
case services.KindAccessRequest:
var filter services.AccessRequestFilter
if err := filter.FromMap(kind.Filter); err != nil {
return nil, trace.Wrap(err)
}
if filter.User == "" || a.currentUserAction(filter.User) != nil {
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbRead); err != nil {
return nil, trace.Wrap(err)
}
}
default:
return nil, trace.AccessDenied("not authorized to watch %v events", kind.Kind)
}
@ -778,6 +787,47 @@ func (a *AuthWithRoles) DeleteWebSession(user string, sid string) error {
return a.authServer.DeleteWebSession(user, sid)
}
func (a *AuthWithRoles) GetAccessRequests(filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
// An exception is made to allow users to get their own access requests.
if filter.User == "" || a.currentUserAction(filter.User) != nil {
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbList); err != nil {
return nil, trace.Wrap(err)
}
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbRead); err != nil {
return nil, trace.Wrap(err)
}
}
return a.authServer.GetAccessRequests(filter)
}
func (a *AuthWithRoles) CreateAccessRequest(req services.AccessRequest) error {
// An exception is made to allow users to create access *pending* requests for themselves.
if !req.GetState().IsPending() || a.currentUserAction(req.GetUser()) != nil {
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbCreate); err != nil {
return trace.Wrap(err)
}
}
// Ensure that an access request cannot outlive the identity that creates it.
if req.GetAccessExpiry().Before(a.authServer.GetClock().Now()) || req.GetAccessExpiry().After(a.identity.Expires) {
req.SetAccessExpiry(a.identity.Expires)
}
return a.authServer.CreateAccessRequest(req)
}
func (a *AuthWithRoles) SetAccessRequestState(reqID string, state services.RequestState) error {
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbUpdate); err != nil {
return trace.Wrap(err)
}
return a.authServer.SetAccessRequestState(reqID, state, a.user.GetName())
}
func (a *AuthWithRoles) DeleteAccessRequest(name string) error {
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbUpdate); err != nil {
return trace.Wrap(err)
}
return a.authServer.DeleteAccessRequest(name)
}
func (a *AuthWithRoles) GetUsers(withSecrets bool) ([]services.User, error) {
if withSecrets {
// TODO(fspmarshall): replace admin requirement with VerbReadWithSecrets once we've
@ -908,6 +958,44 @@ func (a *AuthWithRoles) GenerateUserCerts(ctx context.Context, req proto.UserCer
return nil, trace.AccessDenied("this request can be only executed by an admin")
}
// TODO(fspmarshall): Move this logic to AuthServer.
if len(req.AccessRequests) > 0 {
// add any applicable access request values.
for _, reqID := range req.AccessRequests {
accessReq, err := a.authServer.GetAccessRequest(reqID)
if err != nil {
if trace.IsNotFound(err) {
return nil, trace.AccessDenied("invalid access request %q", reqID)
}
return nil, trace.Wrap(err)
}
if accessReq.GetUser() != req.Username {
return nil, trace.AccessDenied("invalid access request %q", reqID)
}
if !accessReq.GetState().IsApproved() {
if accessReq.GetState().IsDenied() {
return nil, trace.AccessDenied("access-request %q has been denied", reqID)
}
return nil, trace.AccessDenied("access-request %q is awaiting approval", reqID)
}
if err := services.ValidateAccessRequest(a.authServer, accessReq); err != nil {
return nil, trace.Wrap(err)
}
aexp := accessReq.GetAccessExpiry()
if aexp.Before(a.authServer.GetClock().Now()) {
return nil, trace.AccessDenied("access-request %q is expired", reqID)
}
if aexp.Before(req.Expires) {
// cannot generate a cert that would outlive the access request
req.Expires = aexp
}
roles = append(roles, accessReq.GetRoles()...)
}
// nothing prevents an access-request from including roles already posessed by the
// user, so we must make sure to trim duplicate roles.
roles = utils.Deduplicate(roles)
}
// Extract the user and role set for whom the certificate will be generated.
user, err := a.GetUser(req.Username, false)
if err != nil {
@ -929,6 +1017,9 @@ func (a *AuthWithRoles) GenerateUserCerts(ctx context.Context, req proto.UserCer
routeToCluster: req.RouteToCluster,
checker: checker,
traits: traits,
activeRequests: services.RequestIDs{
AccessRequests: req.AccessRequests,
},
})
if err != nil {
return nil, trace.Wrap(err)

View file

@ -123,7 +123,7 @@ func DecodeClusterName(serverName string) (string, error) {
}
const suffix = "." + teleport.APIDomain
if !strings.HasSuffix(serverName, suffix) {
return "", trace.BadParameter("unrecognized name, expected suffix %v, got %q", teleport.APIDomain, serverName)
return "", trace.NotFound("no cluster name is encoded")
}
clusterName := strings.TrimSuffix(serverName, suffix)
@ -810,6 +810,7 @@ func (c *Client) NewWatcher(ctx context.Context, watch services.Watch) (services
Name: kind.Name,
Kind: kind.Kind,
LoadSecrets: kind.LoadSecrets,
Filter: kind.Filter,
})
}
stream, err := clt.WatchEvents(cancelCtx, &protoWatch)
@ -2530,6 +2531,67 @@ func (c *Client) DeleteTrustedCluster(name string) error {
return trace.Wrap(err)
}
func (c *Client) GetAccessRequests(filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
clt, err := c.grpc()
if err != nil {
return nil, trace.Wrap(err)
}
rsp, err := clt.GetAccessRequests(context.TODO(), &filter)
if err != nil {
return nil, trail.FromGRPC(err)
}
reqs := make([]services.AccessRequest, 0, len(rsp.AccessRequests))
for _, req := range rsp.AccessRequests {
reqs = append(reqs, req)
}
return reqs, nil
}
func (c *Client) CreateAccessRequest(req services.AccessRequest) error {
r, ok := req.(*services.AccessRequestV3)
if !ok {
return trace.BadParameter("unexpected access request type %T", req)
}
clt, err := c.grpc()
if err != nil {
return trace.Wrap(err)
}
_, err = clt.CreateAccessRequest(context.TODO(), r)
if err != nil {
return trail.FromGRPC(err)
}
return nil
}
func (c *Client) DeleteAccessRequest(reqID string) error {
clt, err := c.grpc()
if err != nil {
return trace.Wrap(err)
}
_, err = clt.DeleteAccessRequest(context.TODO(), &proto.RequestID{
ID: reqID,
})
if err != nil {
return trail.FromGRPC(err)
}
return nil
}
func (c *Client) SetAccessRequestState(reqID string, state services.RequestState) error {
clt, err := c.grpc()
if err != nil {
return trace.Wrap(err)
}
_, err = clt.SetAccessRequestState(context.TODO(), &proto.RequestStateSetter{
ID: reqID,
State: state,
})
if err != nil {
return trail.FromGRPC(err)
}
return nil
}
// WebService implements features used by Web UI clients
type WebService interface {
// GetWebSessionInfo checks if a web sesion is valid, returns session id in case if
@ -2750,4 +2812,12 @@ type ClientI interface {
// ProcessKubeCSR processes CSR request against Kubernetes CA, returns
// signed certificate if sucessful.
ProcessKubeCSR(req KubeCSR) (*KubeCSRResponse, error)
// GetAccessRequests lists all existing access requests.
GetAccessRequests(services.AccessRequestFilter) ([]services.AccessRequest, error)
// CreateAccessRequest creates a new access request.
CreateAccessRequest(req services.AccessRequest) error
// DeleteAccessRequest deletes an access request.
DeleteAccessRequest(reqID string) error
// SetAccessRequestState updates the state of an existing access request.
SetAccessRequestState(reqID string, state services.RequestState) error
}

View file

@ -85,6 +85,7 @@ func (g *GRPCServer) WatchEvents(watch *proto.Watch, stream proto.AuthService_Wa
Name: kind.Name,
Kind: kind.Kind,
LoadSecrets: kind.LoadSecrets,
Filter: kind.Filter,
})
}
watcher, err := auth.NewWatcher(stream.Context(), servicesWatch)
@ -185,6 +186,66 @@ func (g *GRPCServer) GetUsers(req *proto.GetUsersRequest, stream proto.AuthServi
return nil
}
func (g *GRPCServer) GetAccessRequests(ctx context.Context, f *services.AccessRequestFilter) (*proto.AccessRequests, error) {
auth, err := g.authenticate(ctx)
if err != nil {
return nil, trail.ToGRPC(err)
}
var filter services.AccessRequestFilter
if f != nil {
filter = *f
}
reqs, err := auth.AuthWithRoles.GetAccessRequests(filter)
if err != nil {
return nil, trail.ToGRPC(err)
}
collector := make([]*services.AccessRequestV3, 0, len(reqs))
for _, req := range reqs {
r, ok := req.(*services.AccessRequestV3)
if !ok {
err = trace.BadParameter("unexpected access request type %T", req)
return nil, trail.ToGRPC(err)
}
collector = append(collector, r)
}
return &proto.AccessRequests{
AccessRequests: collector,
}, nil
}
func (g *GRPCServer) CreateAccessRequest(ctx context.Context, req *services.AccessRequestV3) (*empty.Empty, error) {
auth, err := g.authenticate(ctx)
if err != nil {
return nil, trail.ToGRPC(err)
}
if err := auth.AuthWithRoles.CreateAccessRequest(req); err != nil {
return nil, trail.ToGRPC(err)
}
return &empty.Empty{}, nil
}
func (g *GRPCServer) DeleteAccessRequest(ctx context.Context, id *proto.RequestID) (*empty.Empty, error) {
auth, err := g.authenticate(ctx)
if err != nil {
return nil, trail.ToGRPC(err)
}
if err := auth.AuthWithRoles.DeleteAccessRequest(id.ID); err != nil {
return nil, trail.ToGRPC(err)
}
return &empty.Empty{}, nil
}
func (g *GRPCServer) SetAccessRequestState(ctx context.Context, req *proto.RequestStateSetter) (*empty.Empty, error) {
auth, err := g.authenticate(ctx)
if err != nil {
return nil, trail.ToGRPC(err)
}
if err := auth.SetAccessRequestState(req.ID, req.State); err != nil {
return nil, trail.ToGRPC(err)
}
return &empty.Empty{}, nil
}
type grpcContext struct {
*AuthContext
*AuthWithRoles
@ -305,6 +366,10 @@ func eventToGRPC(in services.Event) (*proto.Event, error) {
out.Resource = &proto.Event_TunnelConnection{
TunnelConnection: r,
}
case *services.AccessRequestV3:
out.Resource = &proto.Event_AccessRequest{
AccessRequest: r,
}
default:
return nil, trace.BadParameter("resource type %T is not supported", in.Resource)
}
@ -371,6 +436,9 @@ func eventFromGRPC(in proto.Event) (*services.Event, error) {
} else if r := in.GetTunnelConnection(); r != nil {
out.Resource = r
return &out, nil
} else if r := in.GetAccessRequest(); r != nil {
out.Resource = r
return &out, nil
} else {
return nil, trace.BadParameter("received unsupported resource %T", in.Resource)
}

View file

@ -641,6 +641,37 @@ type clt interface {
UpsertUser(services.User) error
}
// CreateUserRoleAndRequestable creates two roles for a user, one base role with allowed login
// matching username, and another role with a login matching rolename that can be requested.
func CreateUserRoleAndRequestable(clt clt, username string, rolename string) (services.User, error) {
user, err := services.NewUser(username)
if err != nil {
return nil, trace.Wrap(err)
}
baseRole := services.RoleForUser(user)
baseRole.SetLogins(services.Allow, []string{username})
baseRole.SetAccessRequestConditions(services.Allow, services.AccessRequestConditions{
Roles: []string{rolename},
})
err = clt.UpsertRole(baseRole)
if err != nil {
return nil, trace.Wrap(err)
}
user.AddRole(baseRole.GetName())
err = clt.UpsertUser(user)
if err != nil {
return nil, trace.Wrap(err)
}
requestableRole := services.RoleForUser(user)
requestableRole.SetName(rolename)
requestableRole.SetLogins(services.Allow, []string{rolename})
err = clt.UpsertRole(requestableRole)
if err != nil {
return nil, trace.Wrap(err)
}
return user, nil
}
// CreateUserAndRole creates user and role and assignes role to a user, used in tests
func CreateUserAndRole(clt clt, username string, allowedLogins []string) (services.User, services.Role, error) {
user, err := services.NewUser(username)

View file

@ -105,6 +105,9 @@ type InitConfig struct {
// Access is service controlling access to resources
Access services.Access
// DynamicAccess is a service that manages dynamic RBAC.
DynamicAccess services.DynamicAccess
// Events is an event service
Events services.Events

View file

@ -291,6 +291,13 @@ func (k *Keygen) GenerateUserCert(c services.UserCertParams) ([]byte, error) {
if c.RouteToCluster != "" {
cert.Permissions.Extensions[teleport.CertExtensionTeleportRouteToCluster] = c.RouteToCluster
}
if !c.ActiveRequests.IsEmpty() {
requests, err := c.ActiveRequests.Marshal()
if err != nil {
return nil, trace.Wrap(err)
}
cert.Permissions.Extensions[teleport.CertExtensionTeleportActiveRequests] = string(requests)
}
}
signer, err := ssh.ParsePrivateKey(c.PrivateCASigningKey)

File diff suppressed because it is too large Load diff

View file

@ -53,6 +53,8 @@ message Event {
services.ReverseTunnelV2 ReverseTunnel = 12 [(gogoproto.jsontag) = "reverse_tunnel,omitempty"];
// TunnelConnection is a resource for tunnel connnections
services.TunnelConnectionV2 TunnelConnection = 13 [(gogoproto.jsontag) = "tunnel_connection,omitempty"];
// AccessRequest is a resource for access requests
services.AccessRequestV3 AccessRequest = 14 [(gogoproto.jsontag) = "access_request,omitempty"];
}
}
@ -72,6 +74,9 @@ message WatchKind {
// if specified only the events with a specific resource
// name will be sent
string Name = 3 [(gogoproto.jsontag) = "name"];
// Filter is an optional mapping of custom filter parameters.
// Valid values vary by resource kind.
map <string, string> Filter = 4 [(gogoproto.jsontag) = "filter,omitempty"];
}
// Set of certificates corresponding to a single public key.
@ -100,6 +105,9 @@ message UserCertsRequest {
// so that requests originating with this certificate will be redirected
// to this cluster
string RouteToCluster = 5 [(gogoproto.jsontag) = "route_to_cluster,omitempty"];
// AccessRequests is an optional list of request IDs indicating requests whose
// escalated privileges should be added to the certificate.
repeated string AccessRequests = 6 [(gogoproto.jsontag) = "access_requests,omitempty"];
}
// GetUserRequest specifies paramters for the GetUser method.
@ -116,6 +124,23 @@ message GetUsersRequest {
bool WithSecrets = 1 [(gogoproto.jsontag) = "with_secrets"];
}
// AccessRequests is a collection of AccessRequest values.
message AccessRequests {
repeated services.AccessRequestV3 AccessRequests = 1 [(gogoproto.jsontag) = "access_requests"];
}
// RequestStateSetter encodes the paramters necessary to update the
// state of a privilege escalation request.
message RequestStateSetter {
string ID = 1 [(gogoproto.jsontag) = "id"];
services.RequestState State = 2 [(gogoproto.jsontag) = "state"];
}
// RequestID is the unique identifier of an access request.
message RequestID {
string ID = 1 [(gogoproto.jsontag) = "id"];
}
// AuthService is authentication/authorization service implementation
service AuthService {
// SendKeepAlives allows node to send a stream of keep alive requests
@ -130,4 +155,12 @@ service AuthService {
rpc GetUser(GetUserRequest) returns (services.UserV2);
// GetUsers gets all current user resources.
rpc GetUsers(GetUsersRequest) returns (stream services.UserV2);
// GetAccessRequests gets all pending access requests.
rpc GetAccessRequests(services.AccessRequestFilter) returns (AccessRequests);
// CreateAccessRequest creates a new access request.
rpc CreateAccessRequest(services.AccessRequestV3) returns (google.protobuf.Empty);
// DeleteAccessRequest deletes an access request.
rpc DeleteAccessRequest(RequestID) returns (google.protobuf.Empty);
// SetAccessRequestState sets the state of an access request.
rpc SetAccessRequestState(RequestStateSetter) returns (google.protobuf.Empty);
}

View file

@ -1357,6 +1357,99 @@ func (s *TLSSuite) TestGetCertAuthority(c *check.C) {
fixtures.ExpectAccessDenied(c, err)
}
func (s *TLSSuite) TestAccessRequest(c *check.C) {
priv, pub, err := s.server.Auth().GenerateKeyPair("")
c.Assert(err, check.IsNil)
// make sure we can parse the private and public key
privateKey, err := ssh.ParseRawPrivateKey(priv)
c.Assert(err, check.IsNil)
_, err = tlsca.MarshalPublicKeyFromPrivateKeyPEM(privateKey)
c.Assert(err, check.IsNil)
_, _, _, _, err = ssh.ParseAuthorizedKey(pub)
c.Assert(err, check.IsNil)
user := "user1"
role := "some-role"
_, err = CreateUserRoleAndRequestable(s.server.Auth(), user, role)
c.Assert(err, check.IsNil)
testUser := TestUser(user)
testUser.TTL = time.Hour
userClient, err := s.server.NewClient(testUser)
c.Assert(err, check.IsNil)
req, err := services.NewAccessRequest(user, role)
c.Assert(err, check.IsNil)
c.Assert(userClient.CreateAccessRequest(req), check.IsNil)
// sanity check; ensure that roles for which no `allow` directive
// exists cannot be requested.
badReq, err := services.NewAccessRequest(user, "some-fake-role")
c.Assert(err, check.IsNil)
c.Assert(userClient.CreateAccessRequest(badReq), check.NotNil)
// generateCerts executes a GenerateUserCerts request, optionally applying
// one or more access-requests to the certificate.
generateCerts := func(reqIDs ...string) (*proto.Certs, error) {
return userClient.GenerateUserCerts(context.TODO(), proto.UserCertsRequest{
PublicKey: pub,
Username: user,
Expires: time.Now().Add(time.Hour).UTC(),
Format: teleport.CertificateFormatStandard,
AccessRequests: reqIDs,
})
}
// certContainsRole checks if a PEM encoded TLS cert contains the
// specified role.
certContainsRole := func(cert []byte, role string) bool {
tlsCert, err := tlsca.ParseCertificatePEM(cert)
c.Assert(err, check.IsNil)
identity, err := tlsca.FromSubject(tlsCert.Subject, tlsCert.NotAfter)
c.Assert(err, check.IsNil)
return utils.SliceContainsStr(identity.Groups, role)
}
// sanity check; ensure that role is not held if no request is applied.
userCerts, err := generateCerts()
c.Assert(err, check.IsNil)
if certContainsRole(userCerts.TLS, role) {
c.Errorf("unexpected role %s", role)
}
// attempt to apply request in PENDING state (should fail)
_, err = generateCerts(req.GetName())
c.Assert(err, check.NotNil)
// verify that user does not have the ability to approve their own request (not a special case, this
// user just wasn't created with the necessary roles for request management).
c.Assert(userClient.SetAccessRequestState(req.GetName(), services.RequestState_APPROVED), check.NotNil)
// attempt to apply request in APPROVED state (should succeed)
c.Assert(s.server.Auth().SetAccessRequestState(req.GetName(), services.RequestState_APPROVED), check.IsNil)
userCerts, err = generateCerts(req.GetName())
c.Assert(err, check.IsNil)
// ensure that the requested role was actually applied to the cert
if !certContainsRole(userCerts.TLS, role) {
c.Errorf("missing requested role %s", role)
}
// attempt to apply request in DENIED state (should fail)
c.Assert(s.server.Auth().SetAccessRequestState(req.GetName(), services.RequestState_DENIED), check.IsNil)
_, err = generateCerts(req.GetName())
c.Assert(err, check.NotNil)
// ensure that once in the DENIED state, a request cannot be set back to PENDING state.
c.Assert(s.server.Auth().SetAccessRequestState(req.GetName(), services.RequestState_PENDING), check.NotNil)
// ensure that once in the DENIED state, a request cannot be set back to APPROVED state.
c.Assert(s.server.Auth().SetAccessRequestState(req.GetName(), services.RequestState_APPROVED), check.NotNil)
}
// TestGenerateCerts tests edge cases around authorization of
// certificate generation for servers and users
func (s *TLSSuite) TestGenerateCerts(c *check.C) {

View file

@ -297,6 +297,10 @@ type ProfileStatus struct {
// Traits hold claim data used to populate a role at runtime.
Traits wrappers.Traits
// ActiveRequests tracks the privilege escalation requests applied
// during certificate construction.
ActiveRequests services.RequestIDs
}
// IsExpired returns true if profile is not expired yet
@ -397,13 +401,22 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
}
}
var activeRequests services.RequestIDs
rawRequests, ok := cert.Extensions[teleport.CertExtensionTeleportActiveRequests]
if ok {
if err := activeRequests.Unmarshal([]byte(rawRequests)); err != nil {
return nil, trace.Wrap(err)
}
}
// Extract extensions from certificate. This lists the abilities of the
// certificate (like can the user request a PTY, port forwarding, etc.)
var extensions []string
for ext, _ := range cert.Extensions {
if ext == teleport.CertExtensionTeleportRoles ||
ext == teleport.CertExtensionTeleportTraits ||
ext == teleport.CertExtensionTeleportRouteToCluster {
ext == teleport.CertExtensionTeleportRouteToCluster ||
ext == teleport.CertExtensionTeleportActiveRequests {
continue
}
extensions = append(extensions, ext)
@ -426,13 +439,14 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
Scheme: "https",
Host: profile.WebProxyAddr,
},
Username: profile.Username,
Logins: cert.ValidPrincipals,
ValidUntil: validUntil,
Extensions: extensions,
Roles: roles,
Cluster: clusterName,
Traits: traits,
Username: profile.Username,
Logins: cert.ValidPrincipals,
ValidUntil: validUntil,
Extensions: extensions,
Roles: roles,
Cluster: clusterName,
Traits: traits,
ActiveRequests: activeRequests,
}, nil
}
@ -875,6 +889,38 @@ func (tc *TeleportClient) GenerateCertsForCluster(ctx context.Context, routeToCl
return proxyClient.GenerateCertsForCluster(ctx, routeToCluster)
}
func (tc *TeleportClient) ReissueUserCerts(ctx context.Context, params ReissueParams) error {
proxyClient, err := tc.ConnectToProxy(ctx)
if err != nil {
return trace.Wrap(err)
}
return proxyClient.ReissueUserCerts(ctx, params)
}
func (tc *TeleportClient) CreateAccessRequest(ctx context.Context, req services.AccessRequest) error {
proxyClient, err := tc.ConnectToProxy(ctx)
if err != nil {
return trace.Wrap(err)
}
return proxyClient.CreateAccessRequest(ctx, req)
}
func (tc *TeleportClient) GetAccessRequests(ctx context.Context, filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
proxyClient, err := tc.ConnectToProxy(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return proxyClient.GetAccessRequests(ctx, filter)
}
func (tc *TeleportClient) NewWatcher(ctx context.Context, watch services.Watch) (services.Watcher, error) {
proxyClient, err := tc.ConnectToProxy(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
return proxyClient.NewWatcher(ctx, watch)
}
// SSH connects to a node and, if 'command' is specified, executes the command on it,
// otherwise runs interactive shell
//

View file

@ -166,6 +166,104 @@ func (proxy *ProxyClient) GenerateCertsForCluster(ctx context.Context, routeToCl
return trace.Wrap(err)
}
// ReissueParams encodes optional paramters for
// user certificate reissue.
type ReissueParams struct {
RouteToCluster string
AccessRequests []string
}
// ReissueUserCerts generates certificates for the user
// that have a metadata instructing server to route the requests to the cluster
func (proxy *ProxyClient) ReissueUserCerts(ctx context.Context, params ReissueParams) error {
localAgent := proxy.teleportClient.LocalAgent()
key, err := localAgent.GetKey()
if err != nil {
return trace.Wrap(err)
}
cert, err := key.SSHCert()
if err != nil {
return trace.Wrap(err)
}
tlsCert, err := key.TLSCertificate()
if err != nil {
return trace.Wrap(err)
}
clusterName, err := tlsca.ClusterName(tlsCert.Issuer)
if err != nil {
return trace.Wrap(err)
}
clt, err := proxy.ConnectToCluster(ctx, clusterName, true)
if err != nil {
return trace.Wrap(err)
}
if params.RouteToCluster != "" {
// Before requesting a certificate, check if the requested cluster is valid.
_, err = clt.GetCertAuthority(services.CertAuthID{
Type: services.HostCA,
DomainName: params.RouteToCluster,
}, false)
if err != nil {
return trace.NotFound("cluster %v not found", params.RouteToCluster)
}
}
req := proto.UserCertsRequest{
Username: cert.KeyId,
PublicKey: key.Pub,
Expires: time.Unix(int64(cert.ValidBefore), 0),
RouteToCluster: params.RouteToCluster,
AccessRequests: params.AccessRequests,
}
if _, ok := cert.Permissions.Extensions[teleport.CertExtensionTeleportRoles]; !ok {
req.Format = teleport.CertificateFormatOldSSH
}
certs, err := clt.GenerateUserCerts(ctx, req)
if err != nil {
return trace.Wrap(err)
}
key.Cert = certs.SSH
key.TLSCert = certs.TLS
// save the cert to the local storage (~/.tsh usually):
_, err = localAgent.AddKey(key)
return trace.Wrap(err)
}
// CreateAccessRequest attempts to create a new request for escalated privilege.
func (proxy *ProxyClient) CreateAccessRequest(ctx context.Context, req services.AccessRequest) error {
site, err := proxy.ConnectToCurrentCluster(ctx, false)
if err != nil {
return trace.Wrap(err)
}
return site.CreateAccessRequest(req)
}
func (proxy *ProxyClient) GetAccessRequests(ctx context.Context, filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
site, err := proxy.ConnectToCurrentCluster(ctx, false)
if err != nil {
return nil, trace.Wrap(err)
}
reqs, err := site.GetAccessRequests(filter)
if err != nil {
return nil, trace.Wrap(err)
}
return reqs, nil
}
func (proxy *ProxyClient) NewWatcher(ctx context.Context, watch services.Watch) (services.Watcher, error) {
site, err := proxy.ConnectToCurrentCluster(ctx, false)
if err != nil {
return nil, trace.Wrap(err)
}
watcher, err := site.NewWatcher(ctx, watch)
if err != nil {
return nil, trace.Wrap(err)
}
return watcher, nil
}
// FindServersByLabels returns list of the nodes which have labels exactly matching
// the given label set.
//

View file

@ -17,6 +17,8 @@ limitations under the License.
package client
import (
"bufio"
"bytes"
"io"
"io/ioutil"
"os"
@ -53,6 +55,10 @@ const (
// two different files (in the same directory)
IdentityFormatOpenSSH IdentityFileFormat = "openssh"
// IdentityFormatTLS is a standard TLS format used by common TLS clients (e.g. GRPC) where
// certificate and key are stored in separate files.
IdentityFormatTLS IdentityFileFormat = "tls"
// DefaultIdentityFormat is what Teleport uses by default
DefaultIdentityFormat = IdentityFormatFile
)
@ -130,9 +136,119 @@ func MakeIdentityFile(filePath string, key *Key, format IdentityFileFormat, cert
if err != nil {
return trace.Wrap(err)
}
case IdentityFormatTLS:
keyPath := filePath + ".key"
certPath := filePath + ".crt"
casPath := filePath + ".cas"
err = ioutil.WriteFile(certPath, key.TLSCert, fileMode)
if err != nil {
return trace.Wrap(err)
}
err = ioutil.WriteFile(keyPath, key.Priv, fileMode)
if err != nil {
return trace.Wrap(err)
}
var caCerts []byte
for _, ca := range certAuthorities {
for _, keyPair := range ca.GetTLSKeyPairs() {
caCerts = append(caCerts, keyPair.Cert...)
}
}
err = ioutil.WriteFile(casPath, caCerts, fileMode)
if err != nil {
return trace.Wrap(err)
}
default:
return trace.BadParameter("unsupported identity format: %q, use either %q or %q",
format, IdentityFormatFile, IdentityFormatOpenSSH)
return trace.BadParameter("unsupported identity format: %q, use one of %q, %q, or %q",
format, IdentityFormatFile, IdentityFormatOpenSSH, IdentityFormatTLS)
}
return nil
}
// IdentityFile represents the basic components of an identity file.
type IdentityFile struct {
PrivateKey []byte
Certs struct {
SSH []byte
TLS []byte
}
CACerts struct {
SSH [][]byte
TLS [][]byte
}
}
// DecodeIdentityFile attempts to break up the contents of an identity file
// into its respective components.
func DecodeIdentityFile(r io.Reader) (*IdentityFile, error) {
scanner := bufio.NewScanner(r)
var ident IdentityFile
// Subslice of scanner's buffer pointing to current line
// with leading and trailing whitespace trimmed.
var line []byte
// Attempt to scan to the next line.
scanln := func() bool {
if !scanner.Scan() {
line = nil
return false
}
line = bytes.TrimSpace(scanner.Bytes())
return true
}
// Check if the current line starts with prefix `p`.
peekln := func(p string) bool {
return bytes.HasPrefix(line, []byte(p))
}
// Get an "owned" copy of the current line.
cloneln := func() []byte {
ln := make([]byte, len(line))
copy(ln, line)
return ln
}
// Scan through all lines of identity file. Lines with a known prefix
// are copied out of the scanner's buffer. All others are ignored.
for scanln() {
switch {
case peekln("ssh"):
ident.Certs.SSH = cloneln()
case peekln("@cert-authority"):
ident.CACerts.SSH = append(ident.CACerts.SSH, cloneln())
case peekln("-----BEGIN"):
// Current line marks the beginning of a PEM block. Consume all
// lines until a corresponding END is found.
var pemBlock []byte
for {
pemBlock = append(pemBlock, line...)
pemBlock = append(pemBlock, '\n')
if peekln("-----END") {
break
}
if !scanln() {
// If scanner has terminated in the middle of a PEM block, either
// the reader encountered an error, or the PEM block is a fragment.
if err := scanner.Err(); err != nil {
return nil, trace.Wrap(err)
}
return nil, trace.BadParameter("invalid PEM block (fragment)")
}
}
// Decide where to place the pem block based on
// which pem blocks have already been found.
switch {
case ident.PrivateKey == nil:
ident.PrivateKey = pemBlock
case ident.Certs.TLS == nil:
ident.Certs.TLS = pemBlock
default:
ident.CACerts.TLS = append(ident.CACerts.TLS, pemBlock)
}
}
}
if err := scanner.Err(); err != nil {
return nil, trace.Wrap(err)
}
return &ident, nil
}

View file

@ -368,6 +368,11 @@ const (
// certificate rotations, by default to set to maximum allowed user
// cert duration
RotationGracePeriod = MaxCertDuration
// PendingAccessDuration defines the expiry of a pending access request.
PendingAccessDuration = time.Hour
// MaxAccessDuration defines the maximum time for which an access request
// can be active.
MaxAccessDuration = MaxCertDuration
)
// list of roles teleport service can run as:

View file

@ -139,6 +139,17 @@ const (
// UserConnector is the connector used to create the user.
UserConnector = "connector"
// AccessRequestCreateEvent is emitted when a new access request is created.
AccessRequestCreateEvent = "access_request.create"
// AccessRequestUpdateEvent is emitted when a request's state is updated.
AccessRequestUpdateEvent = "access_request.update"
// AccessRequestUpdateBy indicates the user that updated the request state.
AccessRequestUpdateBy = "updated_by"
// AccessRequestState is the state of a request.
AccessRequestState = "state"
// AccessRequestID is the ID of an access request.
AccessRequestID = "id"
// ExecEvent is an exec command executed by script or user on
// the server side
ExecEvent = "exec"

View file

@ -150,6 +150,15 @@ var (
Name: AuthAttemptEvent,
Code: AuthAttemptFailureCode,
}
// AccessRequestCreated is emitted when an access request is created.
AccessRequestCreated = Event{
Name: AccessRequestCreateEvent,
Code: AccessRequestCreateCode,
}
AccessRequestUpdated = Event{
Name: AccessRequestUpdateEvent,
Code: AccessRequestUpdateCode,
}
)
var (
@ -203,4 +212,8 @@ var (
ClientDisconnectCode = "T3006I"
// AuthAttemptFailureCode is the auth attempt failure event code.
AuthAttemptFailureCode = "T3007W"
// AccessRequestCreateCode is the the access request creation code.
AccessRequestCreateCode = "T5000I"
// AccessRequestUpdateCode is the access request state update code.
AccessRequestUpdateCode = "T5001I"
)

View file

@ -0,0 +1,508 @@
/*
Copyright 2019 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package services
import (
"fmt"
"time"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/pborman/uuid"
)
// RequestIDs is a collection of IDs for privelege escalation requests.
type RequestIDs struct {
AccessRequests []string `json:"access_requests,omitempty"`
}
func (r *RequestIDs) Marshal() ([]byte, error) {
data, err := utils.FastMarshal(r)
if err != nil {
return nil, trace.Wrap(err)
}
return data, nil
}
func (r *RequestIDs) Unmarshal(data []byte) error {
if err := utils.FastUnmarshal(data, r); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(r.Check())
}
func (r *RequestIDs) Check() error {
for _, id := range r.AccessRequests {
if uuid.Parse(id) == nil {
return trace.BadParameter("invalid request id %q", id)
}
}
return nil
}
func (r *RequestIDs) IsEmpty() bool {
return len(r.AccessRequests) < 1
}
// stateVariants allows iteration of the expected variants
// of RequestState.
var stateVariants = [4]RequestState{
RequestState_NONE,
RequestState_PENDING,
RequestState_APPROVED,
RequestState_DENIED,
}
// Parse attempts to interpret a value as a string representation
// of a RequestState.
func (s *RequestState) Parse(val string) error {
for _, state := range stateVariants {
if state.String() == val {
*s = state
return nil
}
}
return trace.BadParameter("unknown request state: %q", val)
}
// key values for map encoding of request filter
const (
keyID = "id"
keyUser = "user"
keyState = "state"
)
func (f *AccessRequestFilter) IntoMap() map[string]string {
m := make(map[string]string)
if f.ID != "" {
m[keyID] = f.ID
}
if f.User != "" {
m[keyUser] = f.User
}
if !f.State.IsNone() {
m[keyState] = f.State.String()
}
return m
}
func (f *AccessRequestFilter) FromMap(m map[string]string) error {
for key, val := range m {
switch key {
case keyID:
f.ID = val
case keyUser:
f.User = val
case keyState:
if err := f.State.Parse(val); err != nil {
return trace.Wrap(err)
}
default:
return trace.BadParameter("unknown filter key %s", key)
}
}
return nil
}
// Match checks if a given access request matches this filter.
func (f *AccessRequestFilter) Match(req AccessRequest) bool {
if f.ID != "" && req.GetName() != f.ID {
return false
}
if f.User != "" && req.GetUser() != f.User {
return false
}
if !f.State.IsNone() && req.GetState() != f.State {
return false
}
return true
}
func (f *AccessRequestFilter) Equals(o AccessRequestFilter) bool {
return f.ID == o.ID && f.User == o.User && f.State == o.State
}
// DynamicAccess is a service which manages dynamic RBAC.
type DynamicAccess interface {
// CreateAccessRequest stores a new access request.
CreateAccessRequest(AccessRequest) error
// SetAccessRequestState updates the state of an existing access request.
SetAccessRequestState(reqID string, state RequestState) error
// GetAccessRequest gets an access request by name (uuid).
GetAccessRequest(string) (AccessRequest, error)
// GetAccessRequests gets all currently active access requests.
GetAccessRequests(AccessRequestFilter) ([]AccessRequest, error)
// DeleteAccessRequest deletes an access request.
DeleteAccessRequest(string) error
}
// AccessRequest is a request for temporarily granted roles
type AccessRequest interface {
Resource
// GetUser gets the name of the requesting user
GetUser() string
// GetRoles gets the roles being requested by the user
GetRoles() []string
// GetState gets the current state of the request
GetState() RequestState
// SetState sets the approval state of the request
SetState(RequestState) error
// GetCreationTime gets the time at which the request was
// originally registered with the auth server.
GetCreationTime() time.Time
// SetCreationTime sets the creation time of the request.
SetCreationTime(time.Time)
// GetAccessExpiry gets the upper limit for which this request
// may be considered active.
GetAccessExpiry() time.Time
// SetAccessExpiry sets the upper limit for which this request
// may be considered active.
SetAccessExpiry(time.Time)
// CheckAndSetDefaults validates the access request and
// supplies default values where appropriate.
CheckAndSetDefaults() error
// Equals checks equality between access request values.
Equals(AccessRequest) bool
}
func (s RequestState) IsNone() bool {
return s == RequestState_NONE
}
func (s RequestState) IsPending() bool {
return s == RequestState_PENDING
}
func (s RequestState) IsApproved() bool {
return s == RequestState_APPROVED
}
func (s RequestState) IsDenied() bool {
return s == RequestState_DENIED
}
// NewAccessRequest assembled an AccessReqeust resource.
func NewAccessRequest(user string, roles ...string) (AccessRequest, error) {
req := AccessRequestV3{
Kind: KindAccessRequest,
Version: V3,
Metadata: Metadata{
Name: uuid.New(),
},
Spec: AccessRequestSpecV3{
User: user,
Roles: roles,
State: RequestState_PENDING,
},
}
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
return &req, nil
}
type UserAndRoleGetter interface {
UserGetter
RoleGetter
}
func ValidateAccessRequest(getter UserAndRoleGetter, req AccessRequest) error {
user, err := getter.GetUser(req.GetUser(), false)
if err != nil {
return trace.Wrap(err)
}
type rstate struct {
allowed bool
denied bool
}
roleStates := make(map[string]rstate, len(req.GetRoles()))
for _, r := range req.GetRoles() {
roleStates[r] = rstate{false, false}
}
for _, roleName := range user.GetRoles() {
role, err := getter.GetRole(roleName)
if err != nil {
return trace.Wrap(err)
}
Allow:
for _, r := range role.GetAccessRequestConditions(Allow).Roles {
s, ok := roleStates[r]
if !ok {
continue Allow
}
s.allowed = true
roleStates[r] = s
}
Deny:
for _, r := range role.GetAccessRequestConditions(Deny).Roles {
s, ok := roleStates[r]
if !ok {
continue Deny
}
s.denied = true
roleStates[r] = s
}
}
for roleName, roleState := range roleStates {
if roleState.denied || !roleState.allowed {
return trace.BadParameter("user %q cannot request role %q", req.GetUser(), roleName)
}
}
return nil
}
func (r *AccessRequestV3) GetUser() string {
return r.Spec.User
}
func (r *AccessRequestV3) GetRoles() []string {
return r.Spec.Roles
}
func (r *AccessRequestV3) GetState() RequestState {
return r.Spec.State
}
func (r *AccessRequestV3) SetState(state RequestState) error {
if r.Spec.State.IsDenied() {
if state.IsDenied() {
return nil
}
return trace.BadParameter("cannot set request-state %q (already denied)", state.String())
}
r.Spec.State = state
return nil
}
func (r *AccessRequestV3) GetCreationTime() time.Time {
return r.Spec.Created
}
func (r *AccessRequestV3) SetCreationTime(t time.Time) {
r.Spec.Created = t
}
func (r *AccessRequestV3) GetAccessExpiry() time.Time {
return r.Spec.Expires
}
func (r *AccessRequestV3) SetAccessExpiry(expiry time.Time) {
r.Spec.Expires = expiry
}
func (r *AccessRequestV3) CheckAndSetDefaults() error {
if err := r.Metadata.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
if r.GetState().IsNone() {
r.SetState(RequestState_PENDING)
}
if err := r.Check(); err != nil {
return trace.Wrap(err)
}
return nil
}
func (r *AccessRequestV3) Check() error {
if r.Kind == "" {
return trace.BadParameter("access request kind not set")
}
if r.Version == "" {
return trace.BadParameter("access request version not set")
}
if r.GetName() == "" {
return trace.BadParameter("access request id not set")
}
if uuid.Parse(r.GetName()) == nil {
return trace.BadParameter("invalid access request id %q", r.GetName())
}
if r.GetUser() == "" {
return trace.BadParameter("access request user name not set")
}
if len(r.GetRoles()) < 1 {
return trace.BadParameter("access request does not specify any roles")
}
return nil
}
func (r *AccessRequestV3) Equals(other AccessRequest) bool {
o, ok := other.(*AccessRequestV3)
if !ok {
return false
}
if r.GetName() != o.GetName() {
return false
}
return r.Spec.Equals(&o.Spec)
}
func (s *AccessRequestSpecV3) Equals(other *AccessRequestSpecV3) bool {
if s.User != other.User {
return false
}
if len(s.Roles) != len(other.Roles) {
return false
}
for i, role := range s.Roles {
if role != other.Roles[i] {
return false
}
}
if s.Created != other.Created {
return false
}
if s.Expires != other.Expires {
return false
}
return s.State == other.State
}
type AccessRequestMarshaler interface {
MarshalAccessRequest(req AccessRequest, opts ...MarshalOption) ([]byte, error)
UnmarshalAccessRequest(bytes []byte, opts ...MarshalOption) (AccessRequest, error)
}
type accessRequestMarshaler struct{}
func (r *accessRequestMarshaler) MarshalAccessRequest(req AccessRequest, opts ...MarshalOption) ([]byte, error) {
cfg, err := collectOptions(opts)
if err != nil {
return nil, trace.Wrap(err)
}
switch r := req.(type) {
case *AccessRequestV3:
if !cfg.PreserveResourceID {
// avoid modifying the original object
// to prevent unexpected data races
cp := *r
cp.SetResourceID(0)
r = &cp
}
return utils.FastMarshal(r)
default:
return nil, trace.BadParameter("unrecognized access request type: %T", req)
}
}
func (r *accessRequestMarshaler) UnmarshalAccessRequest(data []byte, opts ...MarshalOption) (AccessRequest, error) {
cfg, err := collectOptions(opts)
if err != nil {
return nil, trace.Wrap(err)
}
var req AccessRequestV3
if cfg.SkipValidation {
if err := utils.FastUnmarshal(data, &req); err != nil {
return nil, trace.Wrap(err)
}
} else {
if err := utils.UnmarshalWithSchema(GetAccessRequestSchema(), &req, data); err != nil {
return nil, trace.Wrap(err)
}
}
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
if cfg.ID != 0 {
req.SetResourceID(cfg.ID)
}
if !cfg.Expires.IsZero() {
req.SetExpiry(cfg.Expires)
}
return &req, nil
}
var accessRequestMarshalerInstance AccessRequestMarshaler = &accessRequestMarshaler{}
func GetAccessRequestMarshaler() AccessRequestMarshaler {
marshalerMutex.Lock()
defer marshalerMutex.Unlock()
return accessRequestMarshalerInstance
}
const AccessRequestSpecSchema = `{
"type": "object",
"additionalProperties": false,
"properties": {
"user": { "type": "string" },
"roles": {
"type": "array",
"items": { "type": "string" }
},
"state": { "type": "integer" },
"created": { "type": "string" },
"expires": { "type": "string" }
}
}`
func GetAccessRequestSchema() string {
return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, AccessRequestSpecSchema, DefaultDefinitions)
}
func (r *AccessRequestV3) GetKind() string {
return r.Kind
}
func (r *AccessRequestV3) GetSubKind() string {
return r.SubKind
}
func (r *AccessRequestV3) SetSubKind(subKind string) {
r.SubKind = subKind
}
func (r *AccessRequestV3) GetVersion() string {
return r.Version
}
func (r *AccessRequestV3) GetName() string {
return r.Metadata.Name
}
func (r *AccessRequestV3) SetName(name string) {
r.Metadata.Name = name
}
func (r *AccessRequestV3) Expiry() time.Time {
return r.Metadata.Expiry()
}
func (r *AccessRequestV3) SetExpiry(expiry time.Time) {
r.Metadata.SetExpiry(expiry)
}
func (r *AccessRequestV3) SetTTL(clock clockwork.Clock, ttl time.Duration) {
r.Metadata.SetTTL(clock, ttl)
}
func (r *AccessRequestV3) GetMetadata() Metadata {
return r.Metadata
}
func (r *AccessRequestV3) GetResourceID() int64 {
return r.Metadata.GetID()
}
func (r *AccessRequestV3) SetResourceID(id int64) {
r.Metadata.SetID(id)
}
func (r *AccessRequestV3) String() string {
return fmt.Sprintf("AccessRequest(user=%v,roles=%+v)", r.Spec.User, r.Spec.Roles)
}

View file

@ -0,0 +1,129 @@
/*
Copyright 2019 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package services
import (
"fmt"
"github.com/gravitational/teleport/lib/utils"
. "gopkg.in/check.v1"
)
type AccessRequestSuite struct {
}
var _ = Suite(&AccessRequestSuite{})
var _ = fmt.Printf
func (s *AccessRequestSuite) SetUpSuite(c *C) {
utils.InitLoggerForTests()
}
// TestRequestMarshaling verifies that marshaling/unmarshaling access requests
// works as expected (failures likely indicate a problem with json schema).
func (s *AccessRequestSuite) TestRequestMarshaling(c *C) {
req1, err := NewAccessRequest("some-user", "role-1", "role-2")
c.Assert(err, IsNil)
marshaled, err := GetAccessRequestMarshaler().MarshalAccessRequest(req1)
c.Assert(err, IsNil)
req2, err := GetAccessRequestMarshaler().UnmarshalAccessRequest(marshaled)
c.Assert(err, IsNil)
if !req1.Equals(req2) {
c.Errorf("unexpected inequality %+v <---> %+v", req1, req2)
}
}
// TestRequestFilterMatching verifies expected matching behavior for AccessRequestFilter.
func (s *AccessRequestSuite) TestRequestFilterMatching(c *C) {
reqA, err := NewAccessRequest("alice", "role-a")
c.Assert(err, IsNil)
reqB, err := NewAccessRequest("bob", "role-b")
c.Assert(err, IsNil)
testCases := []struct {
user string
id string
matchA bool
matchB bool
}{
{"", "", true, true},
{"alice", "", true, false},
{"", reqA.GetName(), true, false},
{"bob", reqA.GetName(), false, false},
{"carol", "", false, false},
}
for _, tc := range testCases {
m := AccessRequestFilter{
User: tc.user,
ID: tc.id,
}
if m.Match(reqA) != tc.matchA {
c.Errorf("bad filter behavior (a) %+v", tc)
}
if m.Match(reqB) != tc.matchB {
c.Errorf("bad filter behavior (b) %+v", tc)
}
}
}
// TestRequestFilterConversion verifies that filters convert to and from
// maps correctly.
func (s *AccessRequestSuite) TestRequestFilterConversion(c *C) {
testCases := []struct {
f AccessRequestFilter
m map[string]string
}{
{
AccessRequestFilter{User: "alice", ID: "foo", State: RequestState_PENDING},
map[string]string{"user": "alice", "id": "foo", "state": "PENDING"},
},
{
AccessRequestFilter{User: "bob"},
map[string]string{"user": "bob"},
},
{
AccessRequestFilter{},
map[string]string{},
},
}
for _, tc := range testCases {
if m := tc.f.IntoMap(); !utils.StringMapsEqual(m, tc.m) {
c.Errorf("bad map encoding: expected %+v, got %+v", tc.m, m)
}
var f AccessRequestFilter
if err := f.FromMap(tc.m); err != nil {
c.Errorf("failed to parse %+v: %s", tc.m, err)
}
if !f.Equals(tc.f) {
c.Errorf("bad map decoding: expected %+v, got %+v", tc.f, f)
}
}
badMaps := []map[string]string{
{"food": "carrots"},
{"state": "homesick"},
}
for _, m := range badMaps {
var f AccessRequestFilter
c.Assert(f.FromMap(m), NotNil)
}
}

View file

@ -111,6 +111,9 @@ type UserCertParams struct {
RouteToCluster string
// Traits hold claim data used to populate a role at runtime.
Traits wrappers.Traits
// ActiveRequests tracks privilege escalation requests applied during
// certificate construction.
ActiveRequests RequestIDs
}
// CertRoles defines certificate roles

View file

@ -48,6 +48,9 @@ type WatchKind struct {
Name string
// LoadSecrets specifies whether to load secrets
LoadSecrets bool
// Filter supplies custom event filter parameters that differ by
// resource (e.g. "state":"pending" for access requests).
Filter map[string]string
}
// Event represents an event that happened in the backend

View file

@ -0,0 +1,164 @@
/*
Copyright 2019 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package local
import (
"bytes"
"context"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/trace"
)
// DynamicAccessService manages dynamic RBAC
type DynamicAccessService struct {
backend.Backend
}
// NewDynamicAccessService returns new dynamic access service instance
func NewDynamicAccessService(backend backend.Backend) *AccessService {
return &AccessService{Backend: backend}
}
func (s *AccessService) CreateAccessRequest(req services.AccessRequest) error {
if err := req.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
item, err := itemFromAccessRequest(req)
if err != nil {
return trace.Wrap(err)
}
if _, err := s.Create(context.TODO(), item); err != nil {
return trace.Wrap(err)
}
return nil
}
func (s *AccessService) SetAccessRequestState(name string, state services.RequestState) error {
item, err := s.Get(context.TODO(), accessRequestKey(name))
if err != nil {
if trace.IsNotFound(err) {
return trace.NotFound("cannot set state of access request %q (not found)", name)
}
return trace.Wrap(err)
}
req, err := itemToAccessRequest(*item)
if err != nil {
return trace.Wrap(err)
}
if err := req.SetState(state); err != nil {
return trace.Wrap(err)
}
// approved requests should have a resource expiry which matches
// the underlying access expiry.
if state.IsApproved() {
req.SetExpiry(req.GetAccessExpiry())
}
newItem, err := itemFromAccessRequest(req)
if err != nil {
return trace.Wrap(err)
}
if _, err := s.CompareAndSwap(context.TODO(), *item, newItem); err != nil {
return trace.Wrap(err)
}
return nil
}
func (s *AccessService) GetAccessRequest(name string) (services.AccessRequest, error) {
item, err := s.Get(context.TODO(), accessRequestKey(name))
if err != nil {
if trace.IsNotFound(err) {
return nil, trace.NotFound("access request %q not found", name)
}
return nil, trace.Wrap(err)
}
req, err := itemToAccessRequest(*item)
if err != nil {
return nil, trace.Wrap(err)
}
return req, nil
}
func (s *AccessService) GetAccessRequests(filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
result, err := s.GetRange(context.TODO(), backend.Key(accessRequestsPrefix), backend.RangeEnd(backend.Key(accessRequestsPrefix)), backend.NoLimit)
if err != nil {
return nil, trace.Wrap(err)
}
var requests []services.AccessRequest
for _, item := range result.Items {
if !bytes.HasSuffix(item.Key, []byte(paramsPrefix)) {
continue
}
req, err := itemToAccessRequest(item)
if err != nil {
return nil, trace.Wrap(err)
}
if !filter.Match(req) {
// TODO(fspmarshall): optimize filtering to
// avoid full query/iteration in some cases.
continue
}
requests = append(requests, req)
}
return requests, nil
}
func (s *AccessService) DeleteAccessRequest(name string) error {
err := s.Delete(context.TODO(), accessRequestKey(name))
if err != nil {
if trace.IsNotFound(err) {
return trace.NotFound("cannot delete access request %q (not found)", name)
}
return trace.Wrap(err)
}
return nil
}
func itemFromAccessRequest(req services.AccessRequest) (backend.Item, error) {
value, err := services.GetAccessRequestMarshaler().MarshalAccessRequest(req)
if err != nil {
return backend.Item{}, trace.Wrap(err)
}
return backend.Item{
Key: accessRequestKey(req.GetName()),
Value: value,
Expires: req.Expiry(),
ID: req.GetResourceID(),
}, nil
}
func itemToAccessRequest(item backend.Item) (services.AccessRequest, error) {
req, err := services.GetAccessRequestMarshaler().UnmarshalAccessRequest(
item.Value,
services.WithResourceID(item.ID),
services.WithExpires(item.Expires),
)
if err != nil {
return nil, trace.Wrap(err)
}
return req, nil
}
func accessRequestKey(name string) []byte {
return backend.Key(accessRequestsPrefix, name, paramsPrefix)
}
const (
accessRequestsPrefix = "access_requests"
)

View file

@ -82,6 +82,12 @@ func (e *EventsService) NewWatcher(ctx context.Context, watch services.Watch) (s
parser = newTunnelConnectionParser()
case services.KindReverseTunnel:
parser = newReverseTunnelParser()
case services.KindAccessRequest:
p, err := newAccessRequestParser(kind.Filter)
if err != nil {
return nil, trace.Wrap(err)
}
parser = p
default:
return nil, trace.BadParameter("watcher on object kind %v is not supported", kind)
}
@ -134,6 +140,10 @@ func (w *watcher) parseEvent(e backend.Event) (*services.Event, error) {
if err != nil {
return nil, trace.Wrap(err)
}
// if resource is nil, then it was well-formed but is being filtered out.
if resource == nil {
return nil, nil
}
return &services.Event{Type: e.Type, Resource: resource}, nil
}
}
@ -157,6 +167,10 @@ func (w *watcher) forwardEvents() {
}
continue
}
// event is being filtered out
if converted == nil {
continue
}
select {
case w.eventsC <- *converted:
case <-w.backendWatcher.Done():
@ -485,6 +499,56 @@ func (p *roleParser) parse(event backend.Event) (services.Resource, error) {
}
}
func newAccessRequestParser(m map[string]string) (*accessRequestParser, error) {
var filter services.AccessRequestFilter
if err := filter.FromMap(m); err != nil {
return nil, trace.Wrap(err)
}
return &accessRequestParser{
filter: filter,
matchPrefix: backend.Key(accessRequestsPrefix),
matchSuffix: backend.Key(paramsPrefix),
}, nil
}
type accessRequestParser struct {
filter services.AccessRequestFilter
matchPrefix []byte
matchSuffix []byte
}
func (p *accessRequestParser) prefix() []byte {
return p.matchPrefix
}
func (p *accessRequestParser) match(key []byte) bool {
if !bytes.HasPrefix(key, p.matchPrefix) {
return false
}
if !bytes.HasSuffix(key, p.matchSuffix) {
return false
}
return true
}
func (p *accessRequestParser) parse(event backend.Event) (services.Resource, error) {
switch event.Type {
case backend.OpDelete:
return resourceHeader(event, services.KindAccessRequest, services.V3, 1)
case backend.OpPut:
req, err := itemToAccessRequest(event.Item)
if err != nil {
return nil, trace.Wrap(err)
}
if !p.filter.Match(req) {
return nil, nil
}
return req, nil
default:
return nil, trace.BadParameter("event %v is not supported", event.Type)
}
}
func newUserParser() *userParser {
return &userParser{
matchPrefix: backend.Key(webPrefix, usersPrefix),

View file

@ -62,6 +62,9 @@ const (
// KindRole is a role resource
KindRole = "role"
// KindAccessRequest is an AccessReqeust resource
KindAccessRequest = "access_request"
// KindOIDC is OIDC connector resource
KindOIDC = "oidc"

View file

@ -86,7 +86,7 @@ func RoleNameForCertAuthority(name string) string {
}
// NewAdminRole is the default admin role for all local users if another role
// is not explicitly assigned (Enterprise only).
// is not explicitly assigned (this role applies to all users in OSS version).
func NewAdminRole() Role {
role := &RoleV3{
Kind: KindRole,
@ -265,6 +265,11 @@ type Role interface {
GetKubeGroups(RoleConditionType) []string
// SetKubeGroups sets kubernetes groups for allow or deny condition.
SetKubeGroups(RoleConditionType, []string)
// GetAccessRequestConditions gets allow/deny conditions for access requests.
GetAccessRequestConditions(RoleConditionType) AccessRequestConditions
// SetAccessRequestConditions sets allow/deny conditions for access requests.
SetAccessRequestConditions(RoleConditionType, AccessRequestConditions)
}
// ApplyTraits applies the passed in traits to any variables within the role
@ -516,6 +521,27 @@ func (r *RoleV3) SetKubeGroups(rct RoleConditionType, groups []string) {
}
}
// GetAccessRequestConditions gets conditions for access requests.
func (r *RoleV3) GetAccessRequestConditions(rct RoleConditionType) AccessRequestConditions {
cond := r.Spec.Deny.Request
if rct == Allow {
cond = r.Spec.Allow.Request
}
if cond == nil {
return AccessRequestConditions{}
}
return *cond
}
// SetAccessRequestConditions sets allow/deny conditions for access requests.
func (r *RoleV3) SetAccessRequestConditions(rct RoleConditionType, cond AccessRequestConditions) {
if rct == Allow {
r.Spec.Allow.Request = &cond
} else {
r.Spec.Deny.Request = &cond
}
}
// GetNamespaces gets a list of namespaces this role is allowed or denied access to.
func (r *RoleV3) GetNamespaces(rct RoleConditionType) []string {
if rct == Allow {
@ -2198,6 +2224,16 @@ const RoleSpecV3SchemaDefinitions = `
"type": "array",
"items": { "type": "string" }
},
"request": {
"type": "object",
"additionalProperties": false,
"properties": {
"roles": {
"type": "array",
"items": { "type": "string" }
}
}
},
"rules": {
"type": "array",
"items": {

File diff suppressed because it is too large Load diff

View file

@ -371,6 +371,55 @@ message Namespace {
message NamespaceSpec {
}
// AccessRequest represents an access request resource specification
message AccessRequestV3 {
option (gogoproto.goproto_stringer) = false;
option (gogoproto.stringer) = false;
// Kind is a resource kind
string Kind = 1 [(gogoproto.jsontag) = "kind"];
// SubKind is an optional resource sub kind, used in some resources
string SubKind = 2 [(gogoproto.jsontag) = "sub_kind,omitempty"];
// Version is version
string Version = 3 [(gogoproto.jsontag) = "version"];
// Metadata is AccessRequest metadata
Metadata Metadata = 4 [(gogoproto.nullable) = false, (gogoproto.jsontag) = "metadata"];
// Spec is an AccessReqeust specification
AccessRequestSpecV3 Spec = 5 [(gogoproto.nullable) = false, (gogoproto.jsontag) = "spec"];
}
// RequestState represents the state of a request for escalated privilege.
enum RequestState {
NONE = 0;
PENDING = 1;
APPROVED = 2;
DENIED = 3;
}
// AccessRequestSpec is the specification for AccessRequest
message AccessRequestSpecV3 {
// User is the name of the user to whom the roles will be applied.
string User = 1 [(gogoproto.jsontag) = "user"];
// Roles is the name of the roles being requested.
repeated string Roles = 2 [(gogoproto.jsontag) = "roles"];
// State is the current state of this access request.
RequestState State = 3 [(gogoproto.jsontag) = "state,omitempty"];
// Created encodes the time at which the request was registered with the auth server.
google.protobuf.Timestamp Created = 4 [(gogoproto.stdtime) = true, (gogoproto.nullable) = false, (gogoproto.jsontag) = "created,omitempty"];
// Expires constrains the maximum lifetime of any login session for which this request is active.
google.protobuf.Timestamp Expires = 5 [(gogoproto.stdtime) = true, (gogoproto.nullable) = false, (gogoproto.jsontag) = "expires,omitempty"];
}
// AccessRequestFilter encodes filter params for access requests.
message AccessRequestFilter {
// ID specifies a request ID if set.
string ID = 1 [(gogoproto.jsontag) = "id"];
// User specifies a username if set.
string User = 2 [(gogoproto.jsontag) = "user"];
// RequestState filters for requests in a specific state.
RequestState State = 3 [(gogoproto.jsontag) = "state"];
}
// RoleV3 represents role resource specification
message RoleV3 {
option (gogoproto.goproto_stringer) = false;
@ -444,6 +493,14 @@ message RoleConditions {
// KubeGroups is a list of kubernetes groups
repeated string KubeGroups = 5 [(gogoproto.jsontag) = "kubernetes_groups,omitempty"];
AccessRequestConditions Request = 6 [(gogoproto.jsontag) = "request,omitempty"];
}
// AccessRequestConditions is a matcher for allow/deny restrictions on access-requests.
message AccessRequestConditions {
// Roles is the name of roles which will match the request rule.
repeated string Roles = 1 [(gogoproto.jsontag) = "roles,omitempty"];
}
// Rule represents allow or deny rule that is executed to check

View file

@ -0,0 +1,171 @@
/*
Copyright 2019 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package common
import (
"encoding/json"
"fmt"
"os"
"strings"
"time"
"github.com/gravitational/kingpin"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/asciitable"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/trace"
)
// AccessRequestCommand implements `tctl users` set of commands
// It implements CLICommand interface
type AccessRequestCommand struct {
config *service.Config
reqIDs string
user string
roles string
// format is the output format, e.g. text or json
format string
requestList *kingpin.CmdClause
requestApprove *kingpin.CmdClause
requestDeny *kingpin.CmdClause
requestCreate *kingpin.CmdClause
requestDelete *kingpin.CmdClause
}
// Initialize allows AccessRequestCommand to plug itself into the CLI parser
func (c *AccessRequestCommand) Initialize(app *kingpin.Application, config *service.Config) {
c.config = config
requests := app.Command("requests", "Manage access requests").Alias("request")
c.requestList = requests.Command("ls", "Show active access requests")
c.requestList.Flag("format", "Output format, 'text' or 'json'").Hidden().Default(teleport.Text).StringVar(&c.format)
c.requestApprove = requests.Command("approve", "Approve pending access request")
c.requestApprove.Arg("request-id", "ID of target request(s)").Required().StringVar(&c.reqIDs)
c.requestDeny = requests.Command("deny", "Deny pending access request")
c.requestDeny.Arg("request-id", "ID of target request(s)").Required().StringVar(&c.reqIDs)
c.requestCreate = requests.Command("create", "Create pending access request")
c.requestCreate.Arg("username", "Name of target user").Required().StringVar(&c.user)
c.requestCreate.Flag("roles", "Roles to be requested").Required().StringVar(&c.roles)
c.requestDelete = requests.Command("rm", "Delete an access request")
c.requestDelete.Arg("request-id", "ID of target request(s)").Required().StringVar(&c.reqIDs)
}
// TryRun takes the CLI command as an argument (like "access-request list") and executes it.
func (c *AccessRequestCommand) TryRun(cmd string, client auth.ClientI) (match bool, err error) {
switch cmd {
case c.requestList.FullCommand():
err = c.List(client)
case c.requestApprove.FullCommand():
err = c.Approve(client)
case c.requestDeny.FullCommand():
err = c.Deny(client)
case c.requestCreate.FullCommand():
err = c.Create(client)
default:
return false, nil
}
return true, trace.Wrap(err)
}
func (c *AccessRequestCommand) List(client auth.ClientI) error {
reqs, err := client.GetAccessRequests(services.AccessRequestFilter{})
if err != nil {
return trace.Wrap(err)
}
if err := c.PrintAccessRequests(client, reqs, c.format); err != nil {
return trace.Wrap(err)
}
return nil
}
func (c *AccessRequestCommand) Approve(client auth.ClientI) error {
for _, reqID := range strings.Split(c.reqIDs, ",") {
if err := client.SetAccessRequestState(reqID, services.RequestState_APPROVED); err != nil {
return trace.Wrap(err)
}
}
return nil
}
func (c *AccessRequestCommand) Deny(client auth.ClientI) error {
for _, reqID := range strings.Split(c.reqIDs, ",") {
if err := client.SetAccessRequestState(reqID, services.RequestState_DENIED); err != nil {
return trace.Wrap(err)
}
}
return nil
}
func (c *AccessRequestCommand) Create(client auth.ClientI) error {
roles := strings.Split(c.roles, ",")
req, err := services.NewAccessRequest(c.user, roles...)
if err != nil {
return trace.Wrap(err)
}
if err := client.CreateAccessRequest(req); err != nil {
return trace.Wrap(err)
}
fmt.Printf("%s\n", req.GetName())
return nil
}
func (c *AccessRequestCommand) Delete(client auth.ClientI) error {
for _, reqID := range strings.Split(c.reqIDs, ",") {
if err := client.DeleteAccessRequest(reqID); err != nil {
return trace.Wrap(err)
}
}
return nil
}
// PrintAccessRequests prints access requests
func (c *AccessRequestCommand) PrintAccessRequests(client auth.ClientI, reqs []services.AccessRequest, format string) error {
if format == teleport.Text {
table := asciitable.MakeTable([]string{"Token", "Requestor", "Metadata", "Created At (UTC)", "Status"})
now := time.Now()
for _, req := range reqs {
if now.After(req.GetAccessExpiry()) {
continue
}
params := fmt.Sprintf("roles=%s", strings.Join(req.GetRoles(), ","))
table.AddRow([]string{
req.GetName(),
req.GetUser(),
params,
req.GetCreationTime().Format(time.RFC822),
req.GetState().String(),
})
}
_, err := table.AsBuffer().WriteTo(os.Stdout)
return trace.Wrap(err)
} else {
out, err := json.MarshalIndent(reqs, "", " ")
if err != nil {
return trace.Wrap(err, "failed to marshal requests")
}
fmt.Printf("%s\n", out)
}
return nil
}

View file

@ -70,7 +70,7 @@ func (a *AuthCommand) Initialize(app *kingpin.Application, config *service.Confi
a.authSign.Flag("user", "Teleport user name").StringVar(&a.genUser)
a.authSign.Flag("host", "Teleport host name").StringVar(&a.genHost)
a.authSign.Flag("out", "identity output").Short('o').StringVar(&a.output)
a.authSign.Flag("format", fmt.Sprintf("identity format: %q (default) or %q", client.IdentityFormatFile, client.IdentityFormatOpenSSH)).Default(string(client.DefaultIdentityFormat)).StringVar((*string)(&a.outputFormat))
a.authSign.Flag("format", fmt.Sprintf("identity format: %q (default), %q, or %q", client.IdentityFormatFile, client.IdentityFormatOpenSSH, client.IdentityFormatTLS)).Default(string(client.DefaultIdentityFormat)).StringVar((*string)(&a.outputFormat))
a.authSign.Flag("ttl", "TTL (time to live) for the generated certificate").Default(fmt.Sprintf("%v", defaults.CertDuration)).DurationVar(&a.genTTL)
a.authSign.Flag("compat", "OpenSSH compatibility flag").StringVar(&a.compatibility)
@ -348,7 +348,7 @@ func (a *AuthCommand) generateUserKeys(clusterApi auth.ClientI) error {
key.TLSCert = certs.TLS
var certAuthorities []services.CertAuthority
if a.outputFormat == client.IdentityFormatFile {
if a.outputFormat != client.IdentityFormatOpenSSH {
certAuthorities, err = clusterApi.GetCertAuthorities(services.HostCA, false)
if err != nil {
return trace.Wrap(err)

View file

@ -186,10 +186,13 @@ func (c *TokenCommand) List(client auth.ClientI) error {
if c.format == teleport.Text {
tokensView := func() string {
table := asciitable.MakeTable([]string{"Token", "Type", "Expiry Time (UTC)"})
now := time.Now()
for _, t := range tokens {
expiry := "never"
if t.Expiry().Unix() > 0 {
expiry = t.Expiry().Format(time.RFC822)
exptime := t.Expiry().Format(time.RFC822)
expdur := t.Expiry().Sub(now).Round(time.Second)
expiry = fmt.Sprintf("%s (%s)", exptime, expdur.String())
}
table.AddRow([]string{t.GetName(), t.GetRoles().String(), expiry})
}

View file

@ -29,6 +29,7 @@ func main() {
&common.ResourceCommand{},
&common.StatusCommand{},
&common.TopCommand{},
&common.AccessRequestCommand{},
}
common.Run(commands)
}

View file

@ -18,8 +18,11 @@ package main
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/signal"
@ -34,6 +37,8 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/asciitable"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
kubeclient "github.com/gravitational/teleport/lib/kube/client"
@ -61,6 +66,8 @@ type CLIConf struct {
UserHost string
// Commands to execute on a remote host
RemoteCommand []string
// DesiredRoles indicates one or more roles which should be requested.
DesiredRoles string
// Username is the Teleport user's username (to login into proxies)
Username string
// Proxy keeps the hostname:port of the SSH proxy to use
@ -262,6 +269,7 @@ func Run(args []string, underTest bool) {
login.Flag("format", fmt.Sprintf("Identity format [%s] or %s (for OpenSSH compatibility)",
client.DefaultIdentityFormat,
client.IdentityFormatOpenSSH)).Default(string(client.DefaultIdentityFormat)).StringVar((*string)(&cf.IdentityFormat))
login.Flag("request-roles", "Request one or more extra roles").StringVar(&cf.DesiredRoles)
login.Arg("cluster", clusterHelp).StringVar(&cf.SiteName)
login.Alias(loginUsageFooter)
@ -411,18 +419,23 @@ func onLogin(cf *CLIConf) {
if profile != nil && !profile.IsExpired(clockwork.NewRealClock()) {
switch {
// in case if nothing is specified, print current status
case cf.Proxy == "" && cf.SiteName == "":
case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "":
printProfiles(cf.Debug, profile, profiles)
return
// in case if parameters match, print current status
case host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster:
case host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "":
printProfiles(cf.Debug, profile, profiles)
return
// proxy is unspecified or the same as the currently provided proxy,
// but cluster is specified, treat this as selecting a new cluster
// for the same proxy
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.SiteName != "":
if err := tc.GenerateCertsForCluster(cf.Context, cf.SiteName); err != nil {
// trigger reissue, preserving any active requests.
err = tc.ReissueUserCerts(cf.Context, client.ReissueParams{
AccessRequests: profile.ActiveRequests.AccessRequests,
RouteToCluster: cf.SiteName,
})
if err != nil {
utils.FatalError(err)
}
tc.SaveProfile("", "")
@ -431,6 +444,12 @@ func onLogin(cf *CLIConf) {
}
onStatus(cf)
return
// proxy is unspecified or the same as the currently provided proxy,
// but desired roles are specified, treat this as a privilege escalation
// request for the same login session.
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.DesiredRoles != "":
executeAccessRequest(cf)
return
// otherwise just passthrough to standard login
default:
}
@ -477,7 +496,12 @@ func onLogin(cf *CLIConf) {
// advertised settings are picked up.
webProxyHost, _ := tc.WebProxyHostPort()
cf.Proxy = webProxyHost
onStatus(cf)
if cf.DesiredRoles != "" {
fmt.Println("") // visually separate onRequestExecute output
executeAccessRequest(cf)
} else {
onStatus(cf)
}
}
// setupNoninteractiveClient sets up existing client to use
@ -665,6 +689,33 @@ func onListNodes(cf *CLIConf) {
}
}
func executeAccessRequest(cf *CLIConf) {
if cf.DesiredRoles == "" {
utils.FatalError(trace.BadParameter("one or more roles must be specified"))
}
roles := strings.Split(cf.DesiredRoles, ",")
tc, err := makeClient(cf, true)
if err != nil {
utils.FatalError(err)
}
if cf.Username == "" {
cf.Username = tc.Username
}
req, err := services.NewAccessRequest(cf.Username, roles...)
if err != nil {
utils.FatalError(err)
}
fmt.Fprintf(os.Stderr, "Seeking request approval... (id: %s)\n", req.GetName())
if err := getRequestApproval(cf, tc, req); err != nil {
utils.FatalError(err)
}
fmt.Fprintf(os.Stderr, "Approval received, getting updated certificates...\n\n")
if err := reissueWithRequests(cf, tc, req.GetName()); err != nil {
utils.FatalError(err)
}
onStatus(cf)
}
// chunkLabels breaks labels into sized chunks. Used to improve readability
// of "tsh ls".
func chunkLabels(labels map[string]string, chunkSize int) [][]string {
@ -1028,6 +1079,104 @@ func refuseArgs(command string, args []string) {
}
}
// loadIdentity loads the private key + certificate from a file
// Returns:
// - client key: user's private key+cert
// - host auth callback: function to validate the host (may be null)
// - error, if somthing happens when reading the identityf file
//
// If the "host auth callback" is not returned, user will be prompted to
// trust the proxy server.
func loadIdentity(idFn string) (*client.Key, ssh.HostKeyCallback, error) {
log.Infof("Reading identity file: %v", idFn)
f, err := os.Open(idFn)
if err != nil {
return nil, nil, trace.Wrap(err)
}
defer f.Close()
ident, err := client.DecodeIdentityFile(f)
if err != nil {
return nil, nil, trace.Wrap(err, "failed to parse identity file")
}
// did not find the certificate in the file? look in a separate file with
// -cert.pub prefix
if len(ident.Certs.SSH) == 0 {
certFn := idFn + "-cert.pub"
log.Infof("Certificate not found in %s. Looking in %s.", idFn, certFn)
ident.Certs.SSH, err = ioutil.ReadFile(certFn)
if err != nil {
return nil, nil, trace.Wrap(err)
}
}
// validate both by parsing them:
privKey, err := ssh.ParseRawPrivateKey(ident.PrivateKey)
if err != nil {
return nil, nil, trace.BadParameter("invalid identity: %s. %v", idFn, err)
}
signer, err := ssh.NewSignerFromKey(privKey)
if err != nil {
return nil, nil, trace.Wrap(err)
}
// validate TLS Cert (if present):
if len(ident.Certs.TLS) > 0 {
_, err := tls.X509KeyPair(ident.Certs.TLS, ident.PrivateKey)
if err != nil {
return nil, nil, trace.Wrap(err)
}
}
// Validate TLS CA certs (if present).
var trustedCA []auth.TrustedCerts
if len(ident.CACerts.TLS) > 0 {
var trustedCerts auth.TrustedCerts
pool := x509.NewCertPool()
for i, certPEM := range ident.CACerts.TLS {
if !pool.AppendCertsFromPEM(certPEM) {
return nil, nil, trace.BadParameter("identity file contains invalid TLS CA cert (#%v)", i+1)
}
trustedCerts.TLSCertificates = append(trustedCerts.TLSCertificates, certPEM)
}
trustedCA = []auth.TrustedCerts{trustedCerts}
}
var hostAuthFunc ssh.HostKeyCallback = nil
// validate CA (cluster) cert
if len(ident.CACerts.SSH) > 0 {
var trustedKeys []ssh.PublicKey
for _, caCert := range ident.CACerts.SSH {
_, _, publicKey, _, _, err := ssh.ParseKnownHosts(caCert)
if err != nil {
return nil, nil, trace.BadParameter("CA cert parsing error: %v. cert line :%v",
err.Error(), string(caCert))
}
trustedKeys = append(trustedKeys, publicKey)
}
// found CA cert in the indentity file? construct the host key checking function
// and return it:
hostAuthFunc = func(host string, a net.Addr, hostKey ssh.PublicKey) error {
clusterCert, ok := hostKey.(*ssh.Certificate)
if ok {
hostKey = clusterCert.SignatureKey
}
for _, trustedKey := range trustedKeys {
if sshutils.KeysEqual(trustedKey, hostKey) {
return nil
}
}
err = trace.AccessDenied("host %v is untrusted", host)
log.Error(err)
return err
}
}
return &client.Key{
Priv: ident.PrivateKey,
Pub: signer.PublicKey().Marshal(),
Cert: ident.Certs.SSH,
TLSCert: ident.Certs.TLS,
TrustedCA: trustedCA,
}, hostAuthFunc, nil
}
// authFromIdentity returns a standard ssh.Authmethod for a given identity file
func authFromIdentity(k *client.Key) (ssh.AuthMethod, error) {
signer, err := sshutils.NewSigner(k.Priv, k.Cert)
@ -1147,3 +1296,80 @@ func host(in string) string {
}
return out
}
// getRequestApproval registers an access request with the auth server and waits for it to be approved.
func getRequestApproval(cf *CLIConf, tc *client.TeleportClient, req services.AccessRequest) error {
// set up request watcher before submitting the request to the admin server
// in order to avoid potential race.
filter := services.AccessRequestFilter{
User: tc.Username,
}
watcher, err := tc.NewWatcher(cf.Context, services.Watch{
Name: "await-request-approval",
Kinds: []services.WatchKind{
services.WatchKind{
Kind: services.KindAccessRequest,
Filter: filter.IntoMap(),
},
},
})
if err != nil {
return trace.Wrap(err)
}
defer watcher.Close()
if err := tc.CreateAccessRequest(cf.Context, req); err != nil {
utils.FatalError(err)
}
Loop:
for {
select {
case event := <-watcher.Events():
if event.Type != backend.OpPut {
continue Loop
}
r, ok := event.Resource.(*services.AccessRequestV3)
if !ok {
return trace.Errorf("unexpected resource type %T", event.Resource)
}
if r.GetName() != req.GetName() || r.GetState().IsPending() {
continue Loop
}
if !r.GetState().IsApproved() {
return trace.Errorf("request %s has been set to %s", r.GetName(), r.GetState().String())
}
return nil
case <-watcher.Done():
utils.FatalError(watcher.Error())
}
}
}
// reissueWithRequests handles a certificate reissue, applying new requests by ID,
// and saving the updated profile.
func reissueWithRequests(cf *CLIConf, tc *client.TeleportClient, reqIDs ...string) error {
profile, _, err := client.Status("", cf.Proxy)
if err != nil {
return trace.Wrap(err)
}
params := client.ReissueParams{
AccessRequests: reqIDs,
RouteToCluster: cf.SiteName,
}
// if the certificate already had active requests, add them to our inputs parameters.
if len(profile.ActiveRequests.AccessRequests) > 0 {
params.AccessRequests = append(params.AccessRequests, profile.ActiveRequests.AccessRequests...)
}
if params.RouteToCluster == "" {
params.RouteToCluster = profile.Cluster
}
if err := tc.ReissueUserCerts(cf.Context, params); err != nil {
return trace.Wrap(err)
}
if err := tc.SaveProfile("", ""); err != nil {
return trace.Wrap(err)
}
if err := kubeclient.UpdateKubeconfig(tc); err != nil {
return trace.Wrap(err)
}
return nil
}