mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 00:33:50 +00:00
Implment access-request system (workflow API)
This commit is contained in:
parent
a32468033a
commit
ec327b6e03
|
@ -373,6 +373,9 @@ const (
|
|||
CertExtensionTeleportRouteToCluster = "teleport-route-to-cluster"
|
||||
// CertExtensionTeleportTraits is used to propagate traits about the user.
|
||||
CertExtensionTeleportTraits = "teleport-traits"
|
||||
// CertExtensionTeleportActiveRequests is used to track which privilege
|
||||
// escalation requests were used to construct the certificate.
|
||||
CertExtensionTeleportActiveRequests = "teleport-active-requests"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -77,6 +77,9 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro
|
|||
if cfg.Access == nil {
|
||||
cfg.Access = local.NewAccessService(cfg.Backend)
|
||||
}
|
||||
if cfg.DynamicAccess == nil {
|
||||
cfg.DynamicAccess = local.NewDynamicAccessService(cfg.Backend)
|
||||
}
|
||||
if cfg.ClusterConfiguration == nil {
|
||||
cfg.ClusterConfiguration = local.NewClusterConfigurationService(cfg.Backend)
|
||||
}
|
||||
|
@ -111,6 +114,7 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro
|
|||
Provisioner: cfg.Provisioner,
|
||||
Identity: cfg.Identity,
|
||||
Access: cfg.Access,
|
||||
DynamicAccess: cfg.DynamicAccess,
|
||||
ClusterConfiguration: cfg.ClusterConfiguration,
|
||||
IAuditLog: cfg.AuditLog,
|
||||
Events: cfg.Events,
|
||||
|
@ -132,6 +136,7 @@ type AuthServices struct {
|
|||
services.Provisioner
|
||||
services.Identity
|
||||
services.Access
|
||||
services.DynamicAccess
|
||||
services.ClusterConfiguration
|
||||
services.Events
|
||||
events.IAuditLog
|
||||
|
@ -408,6 +413,9 @@ type certRequest struct {
|
|||
routeToCluster string
|
||||
// traits hold claim data used to populate a role at runtime.
|
||||
traits wrappers.Traits
|
||||
// activeRequests tracks privilege escalation requests applied
|
||||
// during the construction of the certificate.
|
||||
activeRequests services.RequestIDs
|
||||
}
|
||||
|
||||
// GenerateUserTestCerts is used to generate user certificate, used internally for tests
|
||||
|
@ -509,6 +517,7 @@ func (s *AuthServer) generateUserCert(req certRequest) (*certs, error) {
|
|||
PermitAgentForwarding: req.checker.CanForwardAgents(),
|
||||
RouteToCluster: req.routeToCluster,
|
||||
Traits: req.traits,
|
||||
ActiveRequests: req.activeRequests,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
|
@ -1350,6 +1359,77 @@ func (a *AuthServer) DeleteRole(name string) error {
|
|||
return a.Access.DeleteRole(name)
|
||||
}
|
||||
|
||||
func (a *AuthServer) CreateAccessRequest(req services.AccessRequest) error {
|
||||
if err := services.ValidateAccessRequest(a, req); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
ttl, err := a.calculateMaxAccessTTL(req)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
now := a.clock.Now().UTC()
|
||||
req.SetCreationTime(now)
|
||||
exp := now.Add(ttl)
|
||||
// Set acccess expiry if an allowable default was not provided.
|
||||
if req.GetAccessExpiry().Before(now) || req.GetAccessExpiry().After(exp) {
|
||||
req.SetAccessExpiry(exp)
|
||||
}
|
||||
// By default, resource expiry should match access expiry.
|
||||
req.SetExpiry(req.GetAccessExpiry())
|
||||
// If the access-request is in a pending state, then the expiry of the underlying resource
|
||||
// is capped to to PendingAccessDuration in order to limit orphaned access requests.
|
||||
if req.GetState().IsPending() {
|
||||
pexp := now.Add(defaults.PendingAccessDuration)
|
||||
if pexp.Before(req.Expiry()) {
|
||||
req.SetExpiry(pexp)
|
||||
}
|
||||
}
|
||||
if err := a.DynamicAccess.CreateAccessRequest(req); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
err = a.EmitAuditEvent(events.AccessRequestCreated, events.EventFields{
|
||||
events.AccessRequestID: req.GetName(),
|
||||
events.EventUser: req.GetUser(),
|
||||
events.UserRoles: req.GetRoles(),
|
||||
events.AccessRequestState: req.GetState().String(),
|
||||
})
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (a *AuthServer) SetAccessRequestState(reqID string, state services.RequestState, updatedBy ...string) error {
|
||||
if err := a.DynamicAccess.SetAccessRequestState(reqID, state); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
u := "unknown"
|
||||
if len(updatedBy) == 1 {
|
||||
u = updatedBy[0]
|
||||
}
|
||||
err := a.EmitAuditEvent(events.AccessRequestUpdated, events.EventFields{
|
||||
events.AccessRequestID: reqID,
|
||||
events.AccessRequestState: state.String(),
|
||||
events.AccessRequestUpdateBy: u,
|
||||
})
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// calculateMaxAccessTTL determines the maximum allowable TTL for a given access request
|
||||
// based on the MaxSessionTTLs of the roles being requested (a access request's life cannot
|
||||
// exceed the smallest allowable MaxSessionTTL value of the roles that it requests).
|
||||
func (a *AuthServer) calculateMaxAccessTTL(req services.AccessRequest) (time.Duration, error) {
|
||||
minTTL := defaults.MaxAccessDuration
|
||||
for _, roleName := range req.GetRoles() {
|
||||
role, err := a.GetRole(roleName)
|
||||
if err != nil {
|
||||
return 0, trace.Wrap(err)
|
||||
}
|
||||
roleTTL := time.Duration(role.GetOptions().MaxSessionTTL)
|
||||
if roleTTL > 0 && roleTTL < minTTL {
|
||||
minTTL = roleTTL
|
||||
}
|
||||
}
|
||||
return minTTL, nil
|
||||
}
|
||||
|
||||
// NewKeepAliver returns a new instance of keep aliver
|
||||
func (a *AuthServer) NewKeepAliver(ctx context.Context) (services.KeepAliver, error) {
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
|
|
|
@ -452,7 +452,16 @@ func (a *AuthWithRoles) NewWatcher(ctx context.Context, watch services.Watch) (s
|
|||
return nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
case services.KindAccessRequest:
|
||||
var filter services.AccessRequestFilter
|
||||
if err := filter.FromMap(kind.Filter); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if filter.User == "" || a.currentUserAction(filter.User) != nil {
|
||||
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbRead); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, trace.AccessDenied("not authorized to watch %v events", kind.Kind)
|
||||
}
|
||||
|
@ -778,6 +787,47 @@ func (a *AuthWithRoles) DeleteWebSession(user string, sid string) error {
|
|||
return a.authServer.DeleteWebSession(user, sid)
|
||||
}
|
||||
|
||||
func (a *AuthWithRoles) GetAccessRequests(filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
|
||||
// An exception is made to allow users to get their own access requests.
|
||||
if filter.User == "" || a.currentUserAction(filter.User) != nil {
|
||||
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbList); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbRead); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
return a.authServer.GetAccessRequests(filter)
|
||||
}
|
||||
|
||||
func (a *AuthWithRoles) CreateAccessRequest(req services.AccessRequest) error {
|
||||
// An exception is made to allow users to create access *pending* requests for themselves.
|
||||
if !req.GetState().IsPending() || a.currentUserAction(req.GetUser()) != nil {
|
||||
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbCreate); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
// Ensure that an access request cannot outlive the identity that creates it.
|
||||
if req.GetAccessExpiry().Before(a.authServer.GetClock().Now()) || req.GetAccessExpiry().After(a.identity.Expires) {
|
||||
req.SetAccessExpiry(a.identity.Expires)
|
||||
}
|
||||
return a.authServer.CreateAccessRequest(req)
|
||||
}
|
||||
|
||||
func (a *AuthWithRoles) SetAccessRequestState(reqID string, state services.RequestState) error {
|
||||
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbUpdate); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return a.authServer.SetAccessRequestState(reqID, state, a.user.GetName())
|
||||
}
|
||||
|
||||
func (a *AuthWithRoles) DeleteAccessRequest(name string) error {
|
||||
if err := a.action(defaults.Namespace, services.KindAccessRequest, services.VerbUpdate); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return a.authServer.DeleteAccessRequest(name)
|
||||
}
|
||||
|
||||
func (a *AuthWithRoles) GetUsers(withSecrets bool) ([]services.User, error) {
|
||||
if withSecrets {
|
||||
// TODO(fspmarshall): replace admin requirement with VerbReadWithSecrets once we've
|
||||
|
@ -908,6 +958,44 @@ func (a *AuthWithRoles) GenerateUserCerts(ctx context.Context, req proto.UserCer
|
|||
return nil, trace.AccessDenied("this request can be only executed by an admin")
|
||||
}
|
||||
|
||||
// TODO(fspmarshall): Move this logic to AuthServer.
|
||||
if len(req.AccessRequests) > 0 {
|
||||
// add any applicable access request values.
|
||||
for _, reqID := range req.AccessRequests {
|
||||
accessReq, err := a.authServer.GetAccessRequest(reqID)
|
||||
if err != nil {
|
||||
if trace.IsNotFound(err) {
|
||||
return nil, trace.AccessDenied("invalid access request %q", reqID)
|
||||
}
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if accessReq.GetUser() != req.Username {
|
||||
return nil, trace.AccessDenied("invalid access request %q", reqID)
|
||||
}
|
||||
if !accessReq.GetState().IsApproved() {
|
||||
if accessReq.GetState().IsDenied() {
|
||||
return nil, trace.AccessDenied("access-request %q has been denied", reqID)
|
||||
}
|
||||
return nil, trace.AccessDenied("access-request %q is awaiting approval", reqID)
|
||||
}
|
||||
if err := services.ValidateAccessRequest(a.authServer, accessReq); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
aexp := accessReq.GetAccessExpiry()
|
||||
if aexp.Before(a.authServer.GetClock().Now()) {
|
||||
return nil, trace.AccessDenied("access-request %q is expired", reqID)
|
||||
}
|
||||
if aexp.Before(req.Expires) {
|
||||
// cannot generate a cert that would outlive the access request
|
||||
req.Expires = aexp
|
||||
}
|
||||
roles = append(roles, accessReq.GetRoles()...)
|
||||
}
|
||||
// nothing prevents an access-request from including roles already posessed by the
|
||||
// user, so we must make sure to trim duplicate roles.
|
||||
roles = utils.Deduplicate(roles)
|
||||
}
|
||||
|
||||
// Extract the user and role set for whom the certificate will be generated.
|
||||
user, err := a.GetUser(req.Username, false)
|
||||
if err != nil {
|
||||
|
@ -929,6 +1017,9 @@ func (a *AuthWithRoles) GenerateUserCerts(ctx context.Context, req proto.UserCer
|
|||
routeToCluster: req.RouteToCluster,
|
||||
checker: checker,
|
||||
traits: traits,
|
||||
activeRequests: services.RequestIDs{
|
||||
AccessRequests: req.AccessRequests,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
|
|
|
@ -123,7 +123,7 @@ func DecodeClusterName(serverName string) (string, error) {
|
|||
}
|
||||
const suffix = "." + teleport.APIDomain
|
||||
if !strings.HasSuffix(serverName, suffix) {
|
||||
return "", trace.BadParameter("unrecognized name, expected suffix %v, got %q", teleport.APIDomain, serverName)
|
||||
return "", trace.NotFound("no cluster name is encoded")
|
||||
}
|
||||
clusterName := strings.TrimSuffix(serverName, suffix)
|
||||
|
||||
|
@ -810,6 +810,7 @@ func (c *Client) NewWatcher(ctx context.Context, watch services.Watch) (services
|
|||
Name: kind.Name,
|
||||
Kind: kind.Kind,
|
||||
LoadSecrets: kind.LoadSecrets,
|
||||
Filter: kind.Filter,
|
||||
})
|
||||
}
|
||||
stream, err := clt.WatchEvents(cancelCtx, &protoWatch)
|
||||
|
@ -2530,6 +2531,67 @@ func (c *Client) DeleteTrustedCluster(name string) error {
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (c *Client) GetAccessRequests(filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
|
||||
clt, err := c.grpc()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
rsp, err := clt.GetAccessRequests(context.TODO(), &filter)
|
||||
if err != nil {
|
||||
return nil, trail.FromGRPC(err)
|
||||
}
|
||||
reqs := make([]services.AccessRequest, 0, len(rsp.AccessRequests))
|
||||
for _, req := range rsp.AccessRequests {
|
||||
reqs = append(reqs, req)
|
||||
}
|
||||
return reqs, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateAccessRequest(req services.AccessRequest) error {
|
||||
r, ok := req.(*services.AccessRequestV3)
|
||||
if !ok {
|
||||
return trace.BadParameter("unexpected access request type %T", req)
|
||||
}
|
||||
clt, err := c.grpc()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
_, err = clt.CreateAccessRequest(context.TODO(), r)
|
||||
if err != nil {
|
||||
return trail.FromGRPC(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) DeleteAccessRequest(reqID string) error {
|
||||
clt, err := c.grpc()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
_, err = clt.DeleteAccessRequest(context.TODO(), &proto.RequestID{
|
||||
ID: reqID,
|
||||
})
|
||||
if err != nil {
|
||||
return trail.FromGRPC(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) SetAccessRequestState(reqID string, state services.RequestState) error {
|
||||
clt, err := c.grpc()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
_, err = clt.SetAccessRequestState(context.TODO(), &proto.RequestStateSetter{
|
||||
ID: reqID,
|
||||
State: state,
|
||||
})
|
||||
if err != nil {
|
||||
return trail.FromGRPC(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WebService implements features used by Web UI clients
|
||||
type WebService interface {
|
||||
// GetWebSessionInfo checks if a web sesion is valid, returns session id in case if
|
||||
|
@ -2750,4 +2812,12 @@ type ClientI interface {
|
|||
// ProcessKubeCSR processes CSR request against Kubernetes CA, returns
|
||||
// signed certificate if sucessful.
|
||||
ProcessKubeCSR(req KubeCSR) (*KubeCSRResponse, error)
|
||||
// GetAccessRequests lists all existing access requests.
|
||||
GetAccessRequests(services.AccessRequestFilter) ([]services.AccessRequest, error)
|
||||
// CreateAccessRequest creates a new access request.
|
||||
CreateAccessRequest(req services.AccessRequest) error
|
||||
// DeleteAccessRequest deletes an access request.
|
||||
DeleteAccessRequest(reqID string) error
|
||||
// SetAccessRequestState updates the state of an existing access request.
|
||||
SetAccessRequestState(reqID string, state services.RequestState) error
|
||||
}
|
||||
|
|
|
@ -85,6 +85,7 @@ func (g *GRPCServer) WatchEvents(watch *proto.Watch, stream proto.AuthService_Wa
|
|||
Name: kind.Name,
|
||||
Kind: kind.Kind,
|
||||
LoadSecrets: kind.LoadSecrets,
|
||||
Filter: kind.Filter,
|
||||
})
|
||||
}
|
||||
watcher, err := auth.NewWatcher(stream.Context(), servicesWatch)
|
||||
|
@ -185,6 +186,66 @@ func (g *GRPCServer) GetUsers(req *proto.GetUsersRequest, stream proto.AuthServi
|
|||
return nil
|
||||
}
|
||||
|
||||
func (g *GRPCServer) GetAccessRequests(ctx context.Context, f *services.AccessRequestFilter) (*proto.AccessRequests, error) {
|
||||
auth, err := g.authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
var filter services.AccessRequestFilter
|
||||
if f != nil {
|
||||
filter = *f
|
||||
}
|
||||
reqs, err := auth.AuthWithRoles.GetAccessRequests(filter)
|
||||
if err != nil {
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
collector := make([]*services.AccessRequestV3, 0, len(reqs))
|
||||
for _, req := range reqs {
|
||||
r, ok := req.(*services.AccessRequestV3)
|
||||
if !ok {
|
||||
err = trace.BadParameter("unexpected access request type %T", req)
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
collector = append(collector, r)
|
||||
}
|
||||
return &proto.AccessRequests{
|
||||
AccessRequests: collector,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *GRPCServer) CreateAccessRequest(ctx context.Context, req *services.AccessRequestV3) (*empty.Empty, error) {
|
||||
auth, err := g.authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
if err := auth.AuthWithRoles.CreateAccessRequest(req); err != nil {
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
return &empty.Empty{}, nil
|
||||
}
|
||||
|
||||
func (g *GRPCServer) DeleteAccessRequest(ctx context.Context, id *proto.RequestID) (*empty.Empty, error) {
|
||||
auth, err := g.authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
if err := auth.AuthWithRoles.DeleteAccessRequest(id.ID); err != nil {
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
return &empty.Empty{}, nil
|
||||
}
|
||||
|
||||
func (g *GRPCServer) SetAccessRequestState(ctx context.Context, req *proto.RequestStateSetter) (*empty.Empty, error) {
|
||||
auth, err := g.authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
if err := auth.SetAccessRequestState(req.ID, req.State); err != nil {
|
||||
return nil, trail.ToGRPC(err)
|
||||
}
|
||||
return &empty.Empty{}, nil
|
||||
}
|
||||
|
||||
type grpcContext struct {
|
||||
*AuthContext
|
||||
*AuthWithRoles
|
||||
|
@ -305,6 +366,10 @@ func eventToGRPC(in services.Event) (*proto.Event, error) {
|
|||
out.Resource = &proto.Event_TunnelConnection{
|
||||
TunnelConnection: r,
|
||||
}
|
||||
case *services.AccessRequestV3:
|
||||
out.Resource = &proto.Event_AccessRequest{
|
||||
AccessRequest: r,
|
||||
}
|
||||
default:
|
||||
return nil, trace.BadParameter("resource type %T is not supported", in.Resource)
|
||||
}
|
||||
|
@ -371,6 +436,9 @@ func eventFromGRPC(in proto.Event) (*services.Event, error) {
|
|||
} else if r := in.GetTunnelConnection(); r != nil {
|
||||
out.Resource = r
|
||||
return &out, nil
|
||||
} else if r := in.GetAccessRequest(); r != nil {
|
||||
out.Resource = r
|
||||
return &out, nil
|
||||
} else {
|
||||
return nil, trace.BadParameter("received unsupported resource %T", in.Resource)
|
||||
}
|
||||
|
|
|
@ -641,6 +641,37 @@ type clt interface {
|
|||
UpsertUser(services.User) error
|
||||
}
|
||||
|
||||
// CreateUserRoleAndRequestable creates two roles for a user, one base role with allowed login
|
||||
// matching username, and another role with a login matching rolename that can be requested.
|
||||
func CreateUserRoleAndRequestable(clt clt, username string, rolename string) (services.User, error) {
|
||||
user, err := services.NewUser(username)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
baseRole := services.RoleForUser(user)
|
||||
baseRole.SetLogins(services.Allow, []string{username})
|
||||
baseRole.SetAccessRequestConditions(services.Allow, services.AccessRequestConditions{
|
||||
Roles: []string{rolename},
|
||||
})
|
||||
err = clt.UpsertRole(baseRole)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
user.AddRole(baseRole.GetName())
|
||||
err = clt.UpsertUser(user)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
requestableRole := services.RoleForUser(user)
|
||||
requestableRole.SetName(rolename)
|
||||
requestableRole.SetLogins(services.Allow, []string{rolename})
|
||||
err = clt.UpsertRole(requestableRole)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// CreateUserAndRole creates user and role and assignes role to a user, used in tests
|
||||
func CreateUserAndRole(clt clt, username string, allowedLogins []string) (services.User, services.Role, error) {
|
||||
user, err := services.NewUser(username)
|
||||
|
|
|
@ -105,6 +105,9 @@ type InitConfig struct {
|
|||
// Access is service controlling access to resources
|
||||
Access services.Access
|
||||
|
||||
// DynamicAccess is a service that manages dynamic RBAC.
|
||||
DynamicAccess services.DynamicAccess
|
||||
|
||||
// Events is an event service
|
||||
Events services.Events
|
||||
|
||||
|
|
|
@ -291,6 +291,13 @@ func (k *Keygen) GenerateUserCert(c services.UserCertParams) ([]byte, error) {
|
|||
if c.RouteToCluster != "" {
|
||||
cert.Permissions.Extensions[teleport.CertExtensionTeleportRouteToCluster] = c.RouteToCluster
|
||||
}
|
||||
if !c.ActiveRequests.IsEmpty() {
|
||||
requests, err := c.ActiveRequests.Marshal()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
cert.Permissions.Extensions[teleport.CertExtensionTeleportActiveRequests] = string(requests)
|
||||
}
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(c.PrivateCASigningKey)
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -53,6 +53,8 @@ message Event {
|
|||
services.ReverseTunnelV2 ReverseTunnel = 12 [(gogoproto.jsontag) = "reverse_tunnel,omitempty"];
|
||||
// TunnelConnection is a resource for tunnel connnections
|
||||
services.TunnelConnectionV2 TunnelConnection = 13 [(gogoproto.jsontag) = "tunnel_connection,omitempty"];
|
||||
// AccessRequest is a resource for access requests
|
||||
services.AccessRequestV3 AccessRequest = 14 [(gogoproto.jsontag) = "access_request,omitempty"];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,6 +74,9 @@ message WatchKind {
|
|||
// if specified only the events with a specific resource
|
||||
// name will be sent
|
||||
string Name = 3 [(gogoproto.jsontag) = "name"];
|
||||
// Filter is an optional mapping of custom filter parameters.
|
||||
// Valid values vary by resource kind.
|
||||
map <string, string> Filter = 4 [(gogoproto.jsontag) = "filter,omitempty"];
|
||||
}
|
||||
|
||||
// Set of certificates corresponding to a single public key.
|
||||
|
@ -100,6 +105,9 @@ message UserCertsRequest {
|
|||
// so that requests originating with this certificate will be redirected
|
||||
// to this cluster
|
||||
string RouteToCluster = 5 [(gogoproto.jsontag) = "route_to_cluster,omitempty"];
|
||||
// AccessRequests is an optional list of request IDs indicating requests whose
|
||||
// escalated privileges should be added to the certificate.
|
||||
repeated string AccessRequests = 6 [(gogoproto.jsontag) = "access_requests,omitempty"];
|
||||
}
|
||||
|
||||
// GetUserRequest specifies paramters for the GetUser method.
|
||||
|
@ -116,6 +124,23 @@ message GetUsersRequest {
|
|||
bool WithSecrets = 1 [(gogoproto.jsontag) = "with_secrets"];
|
||||
}
|
||||
|
||||
// AccessRequests is a collection of AccessRequest values.
|
||||
message AccessRequests {
|
||||
repeated services.AccessRequestV3 AccessRequests = 1 [(gogoproto.jsontag) = "access_requests"];
|
||||
}
|
||||
|
||||
// RequestStateSetter encodes the paramters necessary to update the
|
||||
// state of a privilege escalation request.
|
||||
message RequestStateSetter {
|
||||
string ID = 1 [(gogoproto.jsontag) = "id"];
|
||||
services.RequestState State = 2 [(gogoproto.jsontag) = "state"];
|
||||
}
|
||||
|
||||
// RequestID is the unique identifier of an access request.
|
||||
message RequestID {
|
||||
string ID = 1 [(gogoproto.jsontag) = "id"];
|
||||
}
|
||||
|
||||
// AuthService is authentication/authorization service implementation
|
||||
service AuthService {
|
||||
// SendKeepAlives allows node to send a stream of keep alive requests
|
||||
|
@ -130,4 +155,12 @@ service AuthService {
|
|||
rpc GetUser(GetUserRequest) returns (services.UserV2);
|
||||
// GetUsers gets all current user resources.
|
||||
rpc GetUsers(GetUsersRequest) returns (stream services.UserV2);
|
||||
// GetAccessRequests gets all pending access requests.
|
||||
rpc GetAccessRequests(services.AccessRequestFilter) returns (AccessRequests);
|
||||
// CreateAccessRequest creates a new access request.
|
||||
rpc CreateAccessRequest(services.AccessRequestV3) returns (google.protobuf.Empty);
|
||||
// DeleteAccessRequest deletes an access request.
|
||||
rpc DeleteAccessRequest(RequestID) returns (google.protobuf.Empty);
|
||||
// SetAccessRequestState sets the state of an access request.
|
||||
rpc SetAccessRequestState(RequestStateSetter) returns (google.protobuf.Empty);
|
||||
}
|
||||
|
|
|
@ -1357,6 +1357,99 @@ func (s *TLSSuite) TestGetCertAuthority(c *check.C) {
|
|||
fixtures.ExpectAccessDenied(c, err)
|
||||
}
|
||||
|
||||
func (s *TLSSuite) TestAccessRequest(c *check.C) {
|
||||
priv, pub, err := s.server.Auth().GenerateKeyPair("")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// make sure we can parse the private and public key
|
||||
privateKey, err := ssh.ParseRawPrivateKey(priv)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = tlsca.MarshalPublicKeyFromPrivateKeyPEM(privateKey)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, _, _, _, err = ssh.ParseAuthorizedKey(pub)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
user := "user1"
|
||||
role := "some-role"
|
||||
_, err = CreateUserRoleAndRequestable(s.server.Auth(), user, role)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testUser := TestUser(user)
|
||||
testUser.TTL = time.Hour
|
||||
userClient, err := s.server.NewClient(testUser)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
req, err := services.NewAccessRequest(user, role)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(userClient.CreateAccessRequest(req), check.IsNil)
|
||||
|
||||
// sanity check; ensure that roles for which no `allow` directive
|
||||
// exists cannot be requested.
|
||||
badReq, err := services.NewAccessRequest(user, "some-fake-role")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(userClient.CreateAccessRequest(badReq), check.NotNil)
|
||||
|
||||
// generateCerts executes a GenerateUserCerts request, optionally applying
|
||||
// one or more access-requests to the certificate.
|
||||
generateCerts := func(reqIDs ...string) (*proto.Certs, error) {
|
||||
return userClient.GenerateUserCerts(context.TODO(), proto.UserCertsRequest{
|
||||
PublicKey: pub,
|
||||
Username: user,
|
||||
Expires: time.Now().Add(time.Hour).UTC(),
|
||||
Format: teleport.CertificateFormatStandard,
|
||||
AccessRequests: reqIDs,
|
||||
})
|
||||
}
|
||||
|
||||
// certContainsRole checks if a PEM encoded TLS cert contains the
|
||||
// specified role.
|
||||
certContainsRole := func(cert []byte, role string) bool {
|
||||
tlsCert, err := tlsca.ParseCertificatePEM(cert)
|
||||
c.Assert(err, check.IsNil)
|
||||
identity, err := tlsca.FromSubject(tlsCert.Subject, tlsCert.NotAfter)
|
||||
c.Assert(err, check.IsNil)
|
||||
return utils.SliceContainsStr(identity.Groups, role)
|
||||
}
|
||||
|
||||
// sanity check; ensure that role is not held if no request is applied.
|
||||
userCerts, err := generateCerts()
|
||||
c.Assert(err, check.IsNil)
|
||||
if certContainsRole(userCerts.TLS, role) {
|
||||
c.Errorf("unexpected role %s", role)
|
||||
}
|
||||
|
||||
// attempt to apply request in PENDING state (should fail)
|
||||
_, err = generateCerts(req.GetName())
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
// verify that user does not have the ability to approve their own request (not a special case, this
|
||||
// user just wasn't created with the necessary roles for request management).
|
||||
c.Assert(userClient.SetAccessRequestState(req.GetName(), services.RequestState_APPROVED), check.NotNil)
|
||||
|
||||
// attempt to apply request in APPROVED state (should succeed)
|
||||
c.Assert(s.server.Auth().SetAccessRequestState(req.GetName(), services.RequestState_APPROVED), check.IsNil)
|
||||
userCerts, err = generateCerts(req.GetName())
|
||||
c.Assert(err, check.IsNil)
|
||||
// ensure that the requested role was actually applied to the cert
|
||||
if !certContainsRole(userCerts.TLS, role) {
|
||||
c.Errorf("missing requested role %s", role)
|
||||
}
|
||||
|
||||
// attempt to apply request in DENIED state (should fail)
|
||||
c.Assert(s.server.Auth().SetAccessRequestState(req.GetName(), services.RequestState_DENIED), check.IsNil)
|
||||
_, err = generateCerts(req.GetName())
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
// ensure that once in the DENIED state, a request cannot be set back to PENDING state.
|
||||
c.Assert(s.server.Auth().SetAccessRequestState(req.GetName(), services.RequestState_PENDING), check.NotNil)
|
||||
|
||||
// ensure that once in the DENIED state, a request cannot be set back to APPROVED state.
|
||||
c.Assert(s.server.Auth().SetAccessRequestState(req.GetName(), services.RequestState_APPROVED), check.NotNil)
|
||||
}
|
||||
|
||||
// TestGenerateCerts tests edge cases around authorization of
|
||||
// certificate generation for servers and users
|
||||
func (s *TLSSuite) TestGenerateCerts(c *check.C) {
|
||||
|
|
|
@ -297,6 +297,10 @@ type ProfileStatus struct {
|
|||
|
||||
// Traits hold claim data used to populate a role at runtime.
|
||||
Traits wrappers.Traits
|
||||
|
||||
// ActiveRequests tracks the privilege escalation requests applied
|
||||
// during certificate construction.
|
||||
ActiveRequests services.RequestIDs
|
||||
}
|
||||
|
||||
// IsExpired returns true if profile is not expired yet
|
||||
|
@ -397,13 +401,22 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
|
|||
}
|
||||
}
|
||||
|
||||
var activeRequests services.RequestIDs
|
||||
rawRequests, ok := cert.Extensions[teleport.CertExtensionTeleportActiveRequests]
|
||||
if ok {
|
||||
if err := activeRequests.Unmarshal([]byte(rawRequests)); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract extensions from certificate. This lists the abilities of the
|
||||
// certificate (like can the user request a PTY, port forwarding, etc.)
|
||||
var extensions []string
|
||||
for ext, _ := range cert.Extensions {
|
||||
if ext == teleport.CertExtensionTeleportRoles ||
|
||||
ext == teleport.CertExtensionTeleportTraits ||
|
||||
ext == teleport.CertExtensionTeleportRouteToCluster {
|
||||
ext == teleport.CertExtensionTeleportRouteToCluster ||
|
||||
ext == teleport.CertExtensionTeleportActiveRequests {
|
||||
continue
|
||||
}
|
||||
extensions = append(extensions, ext)
|
||||
|
@ -426,13 +439,14 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error)
|
|||
Scheme: "https",
|
||||
Host: profile.WebProxyAddr,
|
||||
},
|
||||
Username: profile.Username,
|
||||
Logins: cert.ValidPrincipals,
|
||||
ValidUntil: validUntil,
|
||||
Extensions: extensions,
|
||||
Roles: roles,
|
||||
Cluster: clusterName,
|
||||
Traits: traits,
|
||||
Username: profile.Username,
|
||||
Logins: cert.ValidPrincipals,
|
||||
ValidUntil: validUntil,
|
||||
Extensions: extensions,
|
||||
Roles: roles,
|
||||
Cluster: clusterName,
|
||||
Traits: traits,
|
||||
ActiveRequests: activeRequests,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -875,6 +889,38 @@ func (tc *TeleportClient) GenerateCertsForCluster(ctx context.Context, routeToCl
|
|||
return proxyClient.GenerateCertsForCluster(ctx, routeToCluster)
|
||||
}
|
||||
|
||||
func (tc *TeleportClient) ReissueUserCerts(ctx context.Context, params ReissueParams) error {
|
||||
proxyClient, err := tc.ConnectToProxy(ctx)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return proxyClient.ReissueUserCerts(ctx, params)
|
||||
}
|
||||
|
||||
func (tc *TeleportClient) CreateAccessRequest(ctx context.Context, req services.AccessRequest) error {
|
||||
proxyClient, err := tc.ConnectToProxy(ctx)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return proxyClient.CreateAccessRequest(ctx, req)
|
||||
}
|
||||
|
||||
func (tc *TeleportClient) GetAccessRequests(ctx context.Context, filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
|
||||
proxyClient, err := tc.ConnectToProxy(ctx)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return proxyClient.GetAccessRequests(ctx, filter)
|
||||
}
|
||||
|
||||
func (tc *TeleportClient) NewWatcher(ctx context.Context, watch services.Watch) (services.Watcher, error) {
|
||||
proxyClient, err := tc.ConnectToProxy(ctx)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return proxyClient.NewWatcher(ctx, watch)
|
||||
}
|
||||
|
||||
// SSH connects to a node and, if 'command' is specified, executes the command on it,
|
||||
// otherwise runs interactive shell
|
||||
//
|
||||
|
|
|
@ -166,6 +166,104 @@ func (proxy *ProxyClient) GenerateCertsForCluster(ctx context.Context, routeToCl
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// ReissueParams encodes optional paramters for
|
||||
// user certificate reissue.
|
||||
type ReissueParams struct {
|
||||
RouteToCluster string
|
||||
AccessRequests []string
|
||||
}
|
||||
|
||||
// ReissueUserCerts generates certificates for the user
|
||||
// that have a metadata instructing server to route the requests to the cluster
|
||||
func (proxy *ProxyClient) ReissueUserCerts(ctx context.Context, params ReissueParams) error {
|
||||
localAgent := proxy.teleportClient.LocalAgent()
|
||||
key, err := localAgent.GetKey()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
cert, err := key.SSHCert()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
tlsCert, err := key.TLSCertificate()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
clusterName, err := tlsca.ClusterName(tlsCert.Issuer)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
clt, err := proxy.ConnectToCluster(ctx, clusterName, true)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
if params.RouteToCluster != "" {
|
||||
// Before requesting a certificate, check if the requested cluster is valid.
|
||||
_, err = clt.GetCertAuthority(services.CertAuthID{
|
||||
Type: services.HostCA,
|
||||
DomainName: params.RouteToCluster,
|
||||
}, false)
|
||||
if err != nil {
|
||||
return trace.NotFound("cluster %v not found", params.RouteToCluster)
|
||||
}
|
||||
}
|
||||
req := proto.UserCertsRequest{
|
||||
Username: cert.KeyId,
|
||||
PublicKey: key.Pub,
|
||||
Expires: time.Unix(int64(cert.ValidBefore), 0),
|
||||
RouteToCluster: params.RouteToCluster,
|
||||
AccessRequests: params.AccessRequests,
|
||||
}
|
||||
if _, ok := cert.Permissions.Extensions[teleport.CertExtensionTeleportRoles]; !ok {
|
||||
req.Format = teleport.CertificateFormatOldSSH
|
||||
}
|
||||
|
||||
certs, err := clt.GenerateUserCerts(ctx, req)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
key.Cert = certs.SSH
|
||||
key.TLSCert = certs.TLS
|
||||
|
||||
// save the cert to the local storage (~/.tsh usually):
|
||||
_, err = localAgent.AddKey(key)
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// CreateAccessRequest attempts to create a new request for escalated privilege.
|
||||
func (proxy *ProxyClient) CreateAccessRequest(ctx context.Context, req services.AccessRequest) error {
|
||||
site, err := proxy.ConnectToCurrentCluster(ctx, false)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return site.CreateAccessRequest(req)
|
||||
}
|
||||
|
||||
func (proxy *ProxyClient) GetAccessRequests(ctx context.Context, filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
|
||||
site, err := proxy.ConnectToCurrentCluster(ctx, false)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
reqs, err := site.GetAccessRequests(filter)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return reqs, nil
|
||||
}
|
||||
|
||||
func (proxy *ProxyClient) NewWatcher(ctx context.Context, watch services.Watch) (services.Watcher, error) {
|
||||
site, err := proxy.ConnectToCurrentCluster(ctx, false)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
watcher, err := site.NewWatcher(ctx, watch)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return watcher, nil
|
||||
}
|
||||
|
||||
// FindServersByLabels returns list of the nodes which have labels exactly matching
|
||||
// the given label set.
|
||||
//
|
||||
|
|
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -53,6 +55,10 @@ const (
|
|||
// two different files (in the same directory)
|
||||
IdentityFormatOpenSSH IdentityFileFormat = "openssh"
|
||||
|
||||
// IdentityFormatTLS is a standard TLS format used by common TLS clients (e.g. GRPC) where
|
||||
// certificate and key are stored in separate files.
|
||||
IdentityFormatTLS IdentityFileFormat = "tls"
|
||||
|
||||
// DefaultIdentityFormat is what Teleport uses by default
|
||||
DefaultIdentityFormat = IdentityFormatFile
|
||||
)
|
||||
|
@ -130,9 +136,119 @@ func MakeIdentityFile(filePath string, key *Key, format IdentityFileFormat, cert
|
|||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
case IdentityFormatTLS:
|
||||
keyPath := filePath + ".key"
|
||||
certPath := filePath + ".crt"
|
||||
casPath := filePath + ".cas"
|
||||
|
||||
err = ioutil.WriteFile(certPath, key.TLSCert, fileMode)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(keyPath, key.Priv, fileMode)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
var caCerts []byte
|
||||
for _, ca := range certAuthorities {
|
||||
for _, keyPair := range ca.GetTLSKeyPairs() {
|
||||
caCerts = append(caCerts, keyPair.Cert...)
|
||||
}
|
||||
}
|
||||
err = ioutil.WriteFile(casPath, caCerts, fileMode)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
default:
|
||||
return trace.BadParameter("unsupported identity format: %q, use either %q or %q",
|
||||
format, IdentityFormatFile, IdentityFormatOpenSSH)
|
||||
return trace.BadParameter("unsupported identity format: %q, use one of %q, %q, or %q",
|
||||
format, IdentityFormatFile, IdentityFormatOpenSSH, IdentityFormatTLS)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IdentityFile represents the basic components of an identity file.
|
||||
type IdentityFile struct {
|
||||
PrivateKey []byte
|
||||
Certs struct {
|
||||
SSH []byte
|
||||
TLS []byte
|
||||
}
|
||||
CACerts struct {
|
||||
SSH [][]byte
|
||||
TLS [][]byte
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeIdentityFile attempts to break up the contents of an identity file
|
||||
// into its respective components.
|
||||
func DecodeIdentityFile(r io.Reader) (*IdentityFile, error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
var ident IdentityFile
|
||||
// Subslice of scanner's buffer pointing to current line
|
||||
// with leading and trailing whitespace trimmed.
|
||||
var line []byte
|
||||
// Attempt to scan to the next line.
|
||||
scanln := func() bool {
|
||||
if !scanner.Scan() {
|
||||
line = nil
|
||||
return false
|
||||
}
|
||||
line = bytes.TrimSpace(scanner.Bytes())
|
||||
return true
|
||||
}
|
||||
// Check if the current line starts with prefix `p`.
|
||||
peekln := func(p string) bool {
|
||||
return bytes.HasPrefix(line, []byte(p))
|
||||
}
|
||||
// Get an "owned" copy of the current line.
|
||||
cloneln := func() []byte {
|
||||
ln := make([]byte, len(line))
|
||||
copy(ln, line)
|
||||
return ln
|
||||
}
|
||||
// Scan through all lines of identity file. Lines with a known prefix
|
||||
// are copied out of the scanner's buffer. All others are ignored.
|
||||
for scanln() {
|
||||
switch {
|
||||
case peekln("ssh"):
|
||||
ident.Certs.SSH = cloneln()
|
||||
case peekln("@cert-authority"):
|
||||
ident.CACerts.SSH = append(ident.CACerts.SSH, cloneln())
|
||||
case peekln("-----BEGIN"):
|
||||
// Current line marks the beginning of a PEM block. Consume all
|
||||
// lines until a corresponding END is found.
|
||||
var pemBlock []byte
|
||||
for {
|
||||
pemBlock = append(pemBlock, line...)
|
||||
pemBlock = append(pemBlock, '\n')
|
||||
if peekln("-----END") {
|
||||
break
|
||||
}
|
||||
if !scanln() {
|
||||
// If scanner has terminated in the middle of a PEM block, either
|
||||
// the reader encountered an error, or the PEM block is a fragment.
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return nil, trace.BadParameter("invalid PEM block (fragment)")
|
||||
}
|
||||
}
|
||||
// Decide where to place the pem block based on
|
||||
// which pem blocks have already been found.
|
||||
switch {
|
||||
case ident.PrivateKey == nil:
|
||||
ident.PrivateKey = pemBlock
|
||||
case ident.Certs.TLS == nil:
|
||||
ident.Certs.TLS = pemBlock
|
||||
default:
|
||||
ident.CACerts.TLS = append(ident.CACerts.TLS, pemBlock)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return &ident, nil
|
||||
}
|
||||
|
|
|
@ -368,6 +368,11 @@ const (
|
|||
// certificate rotations, by default to set to maximum allowed user
|
||||
// cert duration
|
||||
RotationGracePeriod = MaxCertDuration
|
||||
// PendingAccessDuration defines the expiry of a pending access request.
|
||||
PendingAccessDuration = time.Hour
|
||||
// MaxAccessDuration defines the maximum time for which an access request
|
||||
// can be active.
|
||||
MaxAccessDuration = MaxCertDuration
|
||||
)
|
||||
|
||||
// list of roles teleport service can run as:
|
||||
|
|
|
@ -139,6 +139,17 @@ const (
|
|||
// UserConnector is the connector used to create the user.
|
||||
UserConnector = "connector"
|
||||
|
||||
// AccessRequestCreateEvent is emitted when a new access request is created.
|
||||
AccessRequestCreateEvent = "access_request.create"
|
||||
// AccessRequestUpdateEvent is emitted when a request's state is updated.
|
||||
AccessRequestUpdateEvent = "access_request.update"
|
||||
// AccessRequestUpdateBy indicates the user that updated the request state.
|
||||
AccessRequestUpdateBy = "updated_by"
|
||||
// AccessRequestState is the state of a request.
|
||||
AccessRequestState = "state"
|
||||
// AccessRequestID is the ID of an access request.
|
||||
AccessRequestID = "id"
|
||||
|
||||
// ExecEvent is an exec command executed by script or user on
|
||||
// the server side
|
||||
ExecEvent = "exec"
|
||||
|
|
|
@ -150,6 +150,15 @@ var (
|
|||
Name: AuthAttemptEvent,
|
||||
Code: AuthAttemptFailureCode,
|
||||
}
|
||||
// AccessRequestCreated is emitted when an access request is created.
|
||||
AccessRequestCreated = Event{
|
||||
Name: AccessRequestCreateEvent,
|
||||
Code: AccessRequestCreateCode,
|
||||
}
|
||||
AccessRequestUpdated = Event{
|
||||
Name: AccessRequestUpdateEvent,
|
||||
Code: AccessRequestUpdateCode,
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -203,4 +212,8 @@ var (
|
|||
ClientDisconnectCode = "T3006I"
|
||||
// AuthAttemptFailureCode is the auth attempt failure event code.
|
||||
AuthAttemptFailureCode = "T3007W"
|
||||
// AccessRequestCreateCode is the the access request creation code.
|
||||
AccessRequestCreateCode = "T5000I"
|
||||
// AccessRequestUpdateCode is the access request state update code.
|
||||
AccessRequestUpdateCode = "T5001I"
|
||||
)
|
||||
|
|
508
lib/services/access_request.go
Normal file
508
lib/services/access_request.go
Normal file
|
@ -0,0 +1,508 @@
|
|||
/*
|
||||
Copyright 2019 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/pborman/uuid"
|
||||
)
|
||||
|
||||
// RequestIDs is a collection of IDs for privelege escalation requests.
|
||||
type RequestIDs struct {
|
||||
AccessRequests []string `json:"access_requests,omitempty"`
|
||||
}
|
||||
|
||||
func (r *RequestIDs) Marshal() ([]byte, error) {
|
||||
data, err := utils.FastMarshal(r)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (r *RequestIDs) Unmarshal(data []byte) error {
|
||||
if err := utils.FastUnmarshal(data, r); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return trace.Wrap(r.Check())
|
||||
}
|
||||
|
||||
func (r *RequestIDs) Check() error {
|
||||
for _, id := range r.AccessRequests {
|
||||
if uuid.Parse(id) == nil {
|
||||
return trace.BadParameter("invalid request id %q", id)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RequestIDs) IsEmpty() bool {
|
||||
return len(r.AccessRequests) < 1
|
||||
}
|
||||
|
||||
// stateVariants allows iteration of the expected variants
|
||||
// of RequestState.
|
||||
var stateVariants = [4]RequestState{
|
||||
RequestState_NONE,
|
||||
RequestState_PENDING,
|
||||
RequestState_APPROVED,
|
||||
RequestState_DENIED,
|
||||
}
|
||||
|
||||
// Parse attempts to interpret a value as a string representation
|
||||
// of a RequestState.
|
||||
func (s *RequestState) Parse(val string) error {
|
||||
for _, state := range stateVariants {
|
||||
if state.String() == val {
|
||||
*s = state
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return trace.BadParameter("unknown request state: %q", val)
|
||||
}
|
||||
|
||||
// key values for map encoding of request filter
|
||||
const (
|
||||
keyID = "id"
|
||||
keyUser = "user"
|
||||
keyState = "state"
|
||||
)
|
||||
|
||||
func (f *AccessRequestFilter) IntoMap() map[string]string {
|
||||
m := make(map[string]string)
|
||||
if f.ID != "" {
|
||||
m[keyID] = f.ID
|
||||
}
|
||||
if f.User != "" {
|
||||
m[keyUser] = f.User
|
||||
}
|
||||
if !f.State.IsNone() {
|
||||
m[keyState] = f.State.String()
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (f *AccessRequestFilter) FromMap(m map[string]string) error {
|
||||
for key, val := range m {
|
||||
switch key {
|
||||
case keyID:
|
||||
f.ID = val
|
||||
case keyUser:
|
||||
f.User = val
|
||||
case keyState:
|
||||
if err := f.State.Parse(val); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
default:
|
||||
return trace.BadParameter("unknown filter key %s", key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Match checks if a given access request matches this filter.
|
||||
func (f *AccessRequestFilter) Match(req AccessRequest) bool {
|
||||
if f.ID != "" && req.GetName() != f.ID {
|
||||
return false
|
||||
}
|
||||
if f.User != "" && req.GetUser() != f.User {
|
||||
return false
|
||||
}
|
||||
if !f.State.IsNone() && req.GetState() != f.State {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *AccessRequestFilter) Equals(o AccessRequestFilter) bool {
|
||||
return f.ID == o.ID && f.User == o.User && f.State == o.State
|
||||
}
|
||||
|
||||
// DynamicAccess is a service which manages dynamic RBAC.
|
||||
type DynamicAccess interface {
|
||||
// CreateAccessRequest stores a new access request.
|
||||
CreateAccessRequest(AccessRequest) error
|
||||
// SetAccessRequestState updates the state of an existing access request.
|
||||
SetAccessRequestState(reqID string, state RequestState) error
|
||||
// GetAccessRequest gets an access request by name (uuid).
|
||||
GetAccessRequest(string) (AccessRequest, error)
|
||||
// GetAccessRequests gets all currently active access requests.
|
||||
GetAccessRequests(AccessRequestFilter) ([]AccessRequest, error)
|
||||
// DeleteAccessRequest deletes an access request.
|
||||
DeleteAccessRequest(string) error
|
||||
}
|
||||
|
||||
// AccessRequest is a request for temporarily granted roles
|
||||
type AccessRequest interface {
|
||||
Resource
|
||||
// GetUser gets the name of the requesting user
|
||||
GetUser() string
|
||||
// GetRoles gets the roles being requested by the user
|
||||
GetRoles() []string
|
||||
// GetState gets the current state of the request
|
||||
GetState() RequestState
|
||||
// SetState sets the approval state of the request
|
||||
SetState(RequestState) error
|
||||
// GetCreationTime gets the time at which the request was
|
||||
// originally registered with the auth server.
|
||||
GetCreationTime() time.Time
|
||||
// SetCreationTime sets the creation time of the request.
|
||||
SetCreationTime(time.Time)
|
||||
// GetAccessExpiry gets the upper limit for which this request
|
||||
// may be considered active.
|
||||
GetAccessExpiry() time.Time
|
||||
// SetAccessExpiry sets the upper limit for which this request
|
||||
// may be considered active.
|
||||
SetAccessExpiry(time.Time)
|
||||
// CheckAndSetDefaults validates the access request and
|
||||
// supplies default values where appropriate.
|
||||
CheckAndSetDefaults() error
|
||||
// Equals checks equality between access request values.
|
||||
Equals(AccessRequest) bool
|
||||
}
|
||||
|
||||
func (s RequestState) IsNone() bool {
|
||||
return s == RequestState_NONE
|
||||
}
|
||||
|
||||
func (s RequestState) IsPending() bool {
|
||||
return s == RequestState_PENDING
|
||||
}
|
||||
|
||||
func (s RequestState) IsApproved() bool {
|
||||
return s == RequestState_APPROVED
|
||||
}
|
||||
|
||||
func (s RequestState) IsDenied() bool {
|
||||
return s == RequestState_DENIED
|
||||
}
|
||||
|
||||
// NewAccessRequest assembled an AccessReqeust resource.
|
||||
func NewAccessRequest(user string, roles ...string) (AccessRequest, error) {
|
||||
req := AccessRequestV3{
|
||||
Kind: KindAccessRequest,
|
||||
Version: V3,
|
||||
Metadata: Metadata{
|
||||
Name: uuid.New(),
|
||||
},
|
||||
Spec: AccessRequestSpecV3{
|
||||
User: user,
|
||||
Roles: roles,
|
||||
State: RequestState_PENDING,
|
||||
},
|
||||
}
|
||||
if err := req.CheckAndSetDefaults(); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
type UserAndRoleGetter interface {
|
||||
UserGetter
|
||||
RoleGetter
|
||||
}
|
||||
|
||||
func ValidateAccessRequest(getter UserAndRoleGetter, req AccessRequest) error {
|
||||
user, err := getter.GetUser(req.GetUser(), false)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
type rstate struct {
|
||||
allowed bool
|
||||
denied bool
|
||||
}
|
||||
roleStates := make(map[string]rstate, len(req.GetRoles()))
|
||||
for _, r := range req.GetRoles() {
|
||||
roleStates[r] = rstate{false, false}
|
||||
}
|
||||
for _, roleName := range user.GetRoles() {
|
||||
role, err := getter.GetRole(roleName)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
Allow:
|
||||
for _, r := range role.GetAccessRequestConditions(Allow).Roles {
|
||||
s, ok := roleStates[r]
|
||||
if !ok {
|
||||
continue Allow
|
||||
}
|
||||
s.allowed = true
|
||||
roleStates[r] = s
|
||||
}
|
||||
Deny:
|
||||
for _, r := range role.GetAccessRequestConditions(Deny).Roles {
|
||||
s, ok := roleStates[r]
|
||||
if !ok {
|
||||
continue Deny
|
||||
}
|
||||
s.denied = true
|
||||
roleStates[r] = s
|
||||
}
|
||||
}
|
||||
for roleName, roleState := range roleStates {
|
||||
if roleState.denied || !roleState.allowed {
|
||||
return trace.BadParameter("user %q cannot request role %q", req.GetUser(), roleName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetUser() string {
|
||||
return r.Spec.User
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetRoles() []string {
|
||||
return r.Spec.Roles
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetState() RequestState {
|
||||
return r.Spec.State
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) SetState(state RequestState) error {
|
||||
if r.Spec.State.IsDenied() {
|
||||
if state.IsDenied() {
|
||||
return nil
|
||||
}
|
||||
return trace.BadParameter("cannot set request-state %q (already denied)", state.String())
|
||||
}
|
||||
r.Spec.State = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetCreationTime() time.Time {
|
||||
return r.Spec.Created
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) SetCreationTime(t time.Time) {
|
||||
r.Spec.Created = t
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetAccessExpiry() time.Time {
|
||||
return r.Spec.Expires
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) SetAccessExpiry(expiry time.Time) {
|
||||
r.Spec.Expires = expiry
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) CheckAndSetDefaults() error {
|
||||
if err := r.Metadata.CheckAndSetDefaults(); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if r.GetState().IsNone() {
|
||||
r.SetState(RequestState_PENDING)
|
||||
}
|
||||
if err := r.Check(); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) Check() error {
|
||||
if r.Kind == "" {
|
||||
return trace.BadParameter("access request kind not set")
|
||||
}
|
||||
if r.Version == "" {
|
||||
return trace.BadParameter("access request version not set")
|
||||
}
|
||||
if r.GetName() == "" {
|
||||
return trace.BadParameter("access request id not set")
|
||||
}
|
||||
if uuid.Parse(r.GetName()) == nil {
|
||||
return trace.BadParameter("invalid access request id %q", r.GetName())
|
||||
}
|
||||
if r.GetUser() == "" {
|
||||
return trace.BadParameter("access request user name not set")
|
||||
}
|
||||
if len(r.GetRoles()) < 1 {
|
||||
return trace.BadParameter("access request does not specify any roles")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) Equals(other AccessRequest) bool {
|
||||
o, ok := other.(*AccessRequestV3)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if r.GetName() != o.GetName() {
|
||||
return false
|
||||
}
|
||||
return r.Spec.Equals(&o.Spec)
|
||||
}
|
||||
|
||||
func (s *AccessRequestSpecV3) Equals(other *AccessRequestSpecV3) bool {
|
||||
if s.User != other.User {
|
||||
return false
|
||||
}
|
||||
if len(s.Roles) != len(other.Roles) {
|
||||
return false
|
||||
}
|
||||
for i, role := range s.Roles {
|
||||
if role != other.Roles[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if s.Created != other.Created {
|
||||
return false
|
||||
}
|
||||
if s.Expires != other.Expires {
|
||||
return false
|
||||
}
|
||||
return s.State == other.State
|
||||
}
|
||||
|
||||
type AccessRequestMarshaler interface {
|
||||
MarshalAccessRequest(req AccessRequest, opts ...MarshalOption) ([]byte, error)
|
||||
UnmarshalAccessRequest(bytes []byte, opts ...MarshalOption) (AccessRequest, error)
|
||||
}
|
||||
|
||||
type accessRequestMarshaler struct{}
|
||||
|
||||
func (r *accessRequestMarshaler) MarshalAccessRequest(req AccessRequest, opts ...MarshalOption) ([]byte, error) {
|
||||
cfg, err := collectOptions(opts)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
switch r := req.(type) {
|
||||
case *AccessRequestV3:
|
||||
if !cfg.PreserveResourceID {
|
||||
// avoid modifying the original object
|
||||
// to prevent unexpected data races
|
||||
cp := *r
|
||||
cp.SetResourceID(0)
|
||||
r = &cp
|
||||
}
|
||||
return utils.FastMarshal(r)
|
||||
default:
|
||||
return nil, trace.BadParameter("unrecognized access request type: %T", req)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accessRequestMarshaler) UnmarshalAccessRequest(data []byte, opts ...MarshalOption) (AccessRequest, error) {
|
||||
cfg, err := collectOptions(opts)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
var req AccessRequestV3
|
||||
if cfg.SkipValidation {
|
||||
if err := utils.FastUnmarshal(data, &req); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
} else {
|
||||
if err := utils.UnmarshalWithSchema(GetAccessRequestSchema(), &req, data); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
if err := req.CheckAndSetDefaults(); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if cfg.ID != 0 {
|
||||
req.SetResourceID(cfg.ID)
|
||||
}
|
||||
if !cfg.Expires.IsZero() {
|
||||
req.SetExpiry(cfg.Expires)
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
var accessRequestMarshalerInstance AccessRequestMarshaler = &accessRequestMarshaler{}
|
||||
|
||||
func GetAccessRequestMarshaler() AccessRequestMarshaler {
|
||||
marshalerMutex.Lock()
|
||||
defer marshalerMutex.Unlock()
|
||||
return accessRequestMarshalerInstance
|
||||
}
|
||||
|
||||
const AccessRequestSpecSchema = `{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"user": { "type": "string" },
|
||||
"roles": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" }
|
||||
},
|
||||
"state": { "type": "integer" },
|
||||
"created": { "type": "string" },
|
||||
"expires": { "type": "string" }
|
||||
}
|
||||
}`
|
||||
|
||||
func GetAccessRequestSchema() string {
|
||||
return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, AccessRequestSpecSchema, DefaultDefinitions)
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetKind() string {
|
||||
return r.Kind
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetSubKind() string {
|
||||
return r.SubKind
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) SetSubKind(subKind string) {
|
||||
r.SubKind = subKind
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetVersion() string {
|
||||
return r.Version
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetName() string {
|
||||
return r.Metadata.Name
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) SetName(name string) {
|
||||
r.Metadata.Name = name
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) Expiry() time.Time {
|
||||
return r.Metadata.Expiry()
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) SetExpiry(expiry time.Time) {
|
||||
r.Metadata.SetExpiry(expiry)
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) SetTTL(clock clockwork.Clock, ttl time.Duration) {
|
||||
r.Metadata.SetTTL(clock, ttl)
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetMetadata() Metadata {
|
||||
return r.Metadata
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) GetResourceID() int64 {
|
||||
return r.Metadata.GetID()
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) SetResourceID(id int64) {
|
||||
r.Metadata.SetID(id)
|
||||
}
|
||||
|
||||
func (r *AccessRequestV3) String() string {
|
||||
return fmt.Sprintf("AccessRequest(user=%v,roles=%+v)", r.Spec.User, r.Spec.Roles)
|
||||
}
|
129
lib/services/access_request_test.go
Normal file
129
lib/services/access_request_test.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
Copyright 2019 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
|
||||
. "gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
type AccessRequestSuite struct {
|
||||
}
|
||||
|
||||
var _ = Suite(&AccessRequestSuite{})
|
||||
var _ = fmt.Printf
|
||||
|
||||
func (s *AccessRequestSuite) SetUpSuite(c *C) {
|
||||
utils.InitLoggerForTests()
|
||||
}
|
||||
|
||||
// TestRequestMarshaling verifies that marshaling/unmarshaling access requests
|
||||
// works as expected (failures likely indicate a problem with json schema).
|
||||
func (s *AccessRequestSuite) TestRequestMarshaling(c *C) {
|
||||
req1, err := NewAccessRequest("some-user", "role-1", "role-2")
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
marshaled, err := GetAccessRequestMarshaler().MarshalAccessRequest(req1)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
req2, err := GetAccessRequestMarshaler().UnmarshalAccessRequest(marshaled)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
if !req1.Equals(req2) {
|
||||
c.Errorf("unexpected inequality %+v <---> %+v", req1, req2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestFilterMatching verifies expected matching behavior for AccessRequestFilter.
|
||||
func (s *AccessRequestSuite) TestRequestFilterMatching(c *C) {
|
||||
reqA, err := NewAccessRequest("alice", "role-a")
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
reqB, err := NewAccessRequest("bob", "role-b")
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
testCases := []struct {
|
||||
user string
|
||||
id string
|
||||
matchA bool
|
||||
matchB bool
|
||||
}{
|
||||
{"", "", true, true},
|
||||
{"alice", "", true, false},
|
||||
{"", reqA.GetName(), true, false},
|
||||
{"bob", reqA.GetName(), false, false},
|
||||
{"carol", "", false, false},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
m := AccessRequestFilter{
|
||||
User: tc.user,
|
||||
ID: tc.id,
|
||||
}
|
||||
if m.Match(reqA) != tc.matchA {
|
||||
c.Errorf("bad filter behavior (a) %+v", tc)
|
||||
}
|
||||
if m.Match(reqB) != tc.matchB {
|
||||
c.Errorf("bad filter behavior (b) %+v", tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestFilterConversion verifies that filters convert to and from
|
||||
// maps correctly.
|
||||
func (s *AccessRequestSuite) TestRequestFilterConversion(c *C) {
|
||||
testCases := []struct {
|
||||
f AccessRequestFilter
|
||||
m map[string]string
|
||||
}{
|
||||
{
|
||||
AccessRequestFilter{User: "alice", ID: "foo", State: RequestState_PENDING},
|
||||
map[string]string{"user": "alice", "id": "foo", "state": "PENDING"},
|
||||
},
|
||||
{
|
||||
AccessRequestFilter{User: "bob"},
|
||||
map[string]string{"user": "bob"},
|
||||
},
|
||||
{
|
||||
AccessRequestFilter{},
|
||||
map[string]string{},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
|
||||
if m := tc.f.IntoMap(); !utils.StringMapsEqual(m, tc.m) {
|
||||
c.Errorf("bad map encoding: expected %+v, got %+v", tc.m, m)
|
||||
}
|
||||
var f AccessRequestFilter
|
||||
if err := f.FromMap(tc.m); err != nil {
|
||||
c.Errorf("failed to parse %+v: %s", tc.m, err)
|
||||
}
|
||||
if !f.Equals(tc.f) {
|
||||
c.Errorf("bad map decoding: expected %+v, got %+v", tc.f, f)
|
||||
}
|
||||
}
|
||||
badMaps := []map[string]string{
|
||||
{"food": "carrots"},
|
||||
{"state": "homesick"},
|
||||
}
|
||||
for _, m := range badMaps {
|
||||
var f AccessRequestFilter
|
||||
c.Assert(f.FromMap(m), NotNil)
|
||||
}
|
||||
}
|
|
@ -111,6 +111,9 @@ type UserCertParams struct {
|
|||
RouteToCluster string
|
||||
// Traits hold claim data used to populate a role at runtime.
|
||||
Traits wrappers.Traits
|
||||
// ActiveRequests tracks privilege escalation requests applied during
|
||||
// certificate construction.
|
||||
ActiveRequests RequestIDs
|
||||
}
|
||||
|
||||
// CertRoles defines certificate roles
|
||||
|
|
|
@ -48,6 +48,9 @@ type WatchKind struct {
|
|||
Name string
|
||||
// LoadSecrets specifies whether to load secrets
|
||||
LoadSecrets bool
|
||||
// Filter supplies custom event filter parameters that differ by
|
||||
// resource (e.g. "state":"pending" for access requests).
|
||||
Filter map[string]string
|
||||
}
|
||||
|
||||
// Event represents an event that happened in the backend
|
||||
|
|
164
lib/services/local/dynamic_access.go
Normal file
164
lib/services/local/dynamic_access.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
/*
|
||||
Copyright 2019 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package local
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/gravitational/teleport/lib/backend"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// DynamicAccessService manages dynamic RBAC
|
||||
type DynamicAccessService struct {
|
||||
backend.Backend
|
||||
}
|
||||
|
||||
// NewDynamicAccessService returns new dynamic access service instance
|
||||
func NewDynamicAccessService(backend backend.Backend) *AccessService {
|
||||
return &AccessService{Backend: backend}
|
||||
}
|
||||
|
||||
func (s *AccessService) CreateAccessRequest(req services.AccessRequest) error {
|
||||
if err := req.CheckAndSetDefaults(); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
item, err := itemFromAccessRequest(req)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if _, err := s.Create(context.TODO(), item); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccessService) SetAccessRequestState(name string, state services.RequestState) error {
|
||||
item, err := s.Get(context.TODO(), accessRequestKey(name))
|
||||
if err != nil {
|
||||
if trace.IsNotFound(err) {
|
||||
return trace.NotFound("cannot set state of access request %q (not found)", name)
|
||||
}
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
req, err := itemToAccessRequest(*item)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if err := req.SetState(state); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
// approved requests should have a resource expiry which matches
|
||||
// the underlying access expiry.
|
||||
if state.IsApproved() {
|
||||
req.SetExpiry(req.GetAccessExpiry())
|
||||
}
|
||||
newItem, err := itemFromAccessRequest(req)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if _, err := s.CompareAndSwap(context.TODO(), *item, newItem); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccessService) GetAccessRequest(name string) (services.AccessRequest, error) {
|
||||
item, err := s.Get(context.TODO(), accessRequestKey(name))
|
||||
if err != nil {
|
||||
if trace.IsNotFound(err) {
|
||||
return nil, trace.NotFound("access request %q not found", name)
|
||||
}
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
req, err := itemToAccessRequest(*item)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *AccessService) GetAccessRequests(filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
|
||||
result, err := s.GetRange(context.TODO(), backend.Key(accessRequestsPrefix), backend.RangeEnd(backend.Key(accessRequestsPrefix)), backend.NoLimit)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
var requests []services.AccessRequest
|
||||
for _, item := range result.Items {
|
||||
if !bytes.HasSuffix(item.Key, []byte(paramsPrefix)) {
|
||||
continue
|
||||
}
|
||||
req, err := itemToAccessRequest(item)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if !filter.Match(req) {
|
||||
// TODO(fspmarshall): optimize filtering to
|
||||
// avoid full query/iteration in some cases.
|
||||
continue
|
||||
}
|
||||
requests = append(requests, req)
|
||||
}
|
||||
return requests, nil
|
||||
}
|
||||
|
||||
func (s *AccessService) DeleteAccessRequest(name string) error {
|
||||
err := s.Delete(context.TODO(), accessRequestKey(name))
|
||||
if err != nil {
|
||||
if trace.IsNotFound(err) {
|
||||
return trace.NotFound("cannot delete access request %q (not found)", name)
|
||||
}
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func itemFromAccessRequest(req services.AccessRequest) (backend.Item, error) {
|
||||
value, err := services.GetAccessRequestMarshaler().MarshalAccessRequest(req)
|
||||
if err != nil {
|
||||
return backend.Item{}, trace.Wrap(err)
|
||||
}
|
||||
return backend.Item{
|
||||
Key: accessRequestKey(req.GetName()),
|
||||
Value: value,
|
||||
Expires: req.Expiry(),
|
||||
ID: req.GetResourceID(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func itemToAccessRequest(item backend.Item) (services.AccessRequest, error) {
|
||||
req, err := services.GetAccessRequestMarshaler().UnmarshalAccessRequest(
|
||||
item.Value,
|
||||
services.WithResourceID(item.ID),
|
||||
services.WithExpires(item.Expires),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func accessRequestKey(name string) []byte {
|
||||
return backend.Key(accessRequestsPrefix, name, paramsPrefix)
|
||||
}
|
||||
|
||||
const (
|
||||
accessRequestsPrefix = "access_requests"
|
||||
)
|
|
@ -82,6 +82,12 @@ func (e *EventsService) NewWatcher(ctx context.Context, watch services.Watch) (s
|
|||
parser = newTunnelConnectionParser()
|
||||
case services.KindReverseTunnel:
|
||||
parser = newReverseTunnelParser()
|
||||
case services.KindAccessRequest:
|
||||
p, err := newAccessRequestParser(kind.Filter)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
parser = p
|
||||
default:
|
||||
return nil, trace.BadParameter("watcher on object kind %v is not supported", kind)
|
||||
}
|
||||
|
@ -134,6 +140,10 @@ func (w *watcher) parseEvent(e backend.Event) (*services.Event, error) {
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
// if resource is nil, then it was well-formed but is being filtered out.
|
||||
if resource == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &services.Event{Type: e.Type, Resource: resource}, nil
|
||||
}
|
||||
}
|
||||
|
@ -157,6 +167,10 @@ func (w *watcher) forwardEvents() {
|
|||
}
|
||||
continue
|
||||
}
|
||||
// event is being filtered out
|
||||
if converted == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case w.eventsC <- *converted:
|
||||
case <-w.backendWatcher.Done():
|
||||
|
@ -485,6 +499,56 @@ func (p *roleParser) parse(event backend.Event) (services.Resource, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func newAccessRequestParser(m map[string]string) (*accessRequestParser, error) {
|
||||
var filter services.AccessRequestFilter
|
||||
if err := filter.FromMap(m); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return &accessRequestParser{
|
||||
filter: filter,
|
||||
matchPrefix: backend.Key(accessRequestsPrefix),
|
||||
matchSuffix: backend.Key(paramsPrefix),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type accessRequestParser struct {
|
||||
filter services.AccessRequestFilter
|
||||
matchPrefix []byte
|
||||
matchSuffix []byte
|
||||
}
|
||||
|
||||
func (p *accessRequestParser) prefix() []byte {
|
||||
return p.matchPrefix
|
||||
}
|
||||
|
||||
func (p *accessRequestParser) match(key []byte) bool {
|
||||
if !bytes.HasPrefix(key, p.matchPrefix) {
|
||||
return false
|
||||
}
|
||||
if !bytes.HasSuffix(key, p.matchSuffix) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *accessRequestParser) parse(event backend.Event) (services.Resource, error) {
|
||||
switch event.Type {
|
||||
case backend.OpDelete:
|
||||
return resourceHeader(event, services.KindAccessRequest, services.V3, 1)
|
||||
case backend.OpPut:
|
||||
req, err := itemToAccessRequest(event.Item)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if !p.filter.Match(req) {
|
||||
return nil, nil
|
||||
}
|
||||
return req, nil
|
||||
default:
|
||||
return nil, trace.BadParameter("event %v is not supported", event.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func newUserParser() *userParser {
|
||||
return &userParser{
|
||||
matchPrefix: backend.Key(webPrefix, usersPrefix),
|
||||
|
|
|
@ -62,6 +62,9 @@ const (
|
|||
// KindRole is a role resource
|
||||
KindRole = "role"
|
||||
|
||||
// KindAccessRequest is an AccessReqeust resource
|
||||
KindAccessRequest = "access_request"
|
||||
|
||||
// KindOIDC is OIDC connector resource
|
||||
KindOIDC = "oidc"
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ func RoleNameForCertAuthority(name string) string {
|
|||
}
|
||||
|
||||
// NewAdminRole is the default admin role for all local users if another role
|
||||
// is not explicitly assigned (Enterprise only).
|
||||
// is not explicitly assigned (this role applies to all users in OSS version).
|
||||
func NewAdminRole() Role {
|
||||
role := &RoleV3{
|
||||
Kind: KindRole,
|
||||
|
@ -265,6 +265,11 @@ type Role interface {
|
|||
GetKubeGroups(RoleConditionType) []string
|
||||
// SetKubeGroups sets kubernetes groups for allow or deny condition.
|
||||
SetKubeGroups(RoleConditionType, []string)
|
||||
|
||||
// GetAccessRequestConditions gets allow/deny conditions for access requests.
|
||||
GetAccessRequestConditions(RoleConditionType) AccessRequestConditions
|
||||
// SetAccessRequestConditions sets allow/deny conditions for access requests.
|
||||
SetAccessRequestConditions(RoleConditionType, AccessRequestConditions)
|
||||
}
|
||||
|
||||
// ApplyTraits applies the passed in traits to any variables within the role
|
||||
|
@ -516,6 +521,27 @@ func (r *RoleV3) SetKubeGroups(rct RoleConditionType, groups []string) {
|
|||
}
|
||||
}
|
||||
|
||||
// GetAccessRequestConditions gets conditions for access requests.
|
||||
func (r *RoleV3) GetAccessRequestConditions(rct RoleConditionType) AccessRequestConditions {
|
||||
cond := r.Spec.Deny.Request
|
||||
if rct == Allow {
|
||||
cond = r.Spec.Allow.Request
|
||||
}
|
||||
if cond == nil {
|
||||
return AccessRequestConditions{}
|
||||
}
|
||||
return *cond
|
||||
}
|
||||
|
||||
// SetAccessRequestConditions sets allow/deny conditions for access requests.
|
||||
func (r *RoleV3) SetAccessRequestConditions(rct RoleConditionType, cond AccessRequestConditions) {
|
||||
if rct == Allow {
|
||||
r.Spec.Allow.Request = &cond
|
||||
} else {
|
||||
r.Spec.Deny.Request = &cond
|
||||
}
|
||||
}
|
||||
|
||||
// GetNamespaces gets a list of namespaces this role is allowed or denied access to.
|
||||
func (r *RoleV3) GetNamespaces(rct RoleConditionType) []string {
|
||||
if rct == Allow {
|
||||
|
@ -2198,6 +2224,16 @@ const RoleSpecV3SchemaDefinitions = `
|
|||
"type": "array",
|
||||
"items": { "type": "string" }
|
||||
},
|
||||
"request": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"roles": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"rules": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -371,6 +371,55 @@ message Namespace {
|
|||
message NamespaceSpec {
|
||||
}
|
||||
|
||||
// AccessRequest represents an access request resource specification
|
||||
message AccessRequestV3 {
|
||||
option (gogoproto.goproto_stringer) = false;
|
||||
option (gogoproto.stringer) = false;
|
||||
|
||||
// Kind is a resource kind
|
||||
string Kind = 1 [(gogoproto.jsontag) = "kind"];
|
||||
// SubKind is an optional resource sub kind, used in some resources
|
||||
string SubKind = 2 [(gogoproto.jsontag) = "sub_kind,omitempty"];
|
||||
// Version is version
|
||||
string Version = 3 [(gogoproto.jsontag) = "version"];
|
||||
// Metadata is AccessRequest metadata
|
||||
Metadata Metadata = 4 [(gogoproto.nullable) = false, (gogoproto.jsontag) = "metadata"];
|
||||
// Spec is an AccessReqeust specification
|
||||
AccessRequestSpecV3 Spec = 5 [(gogoproto.nullable) = false, (gogoproto.jsontag) = "spec"];
|
||||
}
|
||||
|
||||
// RequestState represents the state of a request for escalated privilege.
|
||||
enum RequestState {
|
||||
NONE = 0;
|
||||
PENDING = 1;
|
||||
APPROVED = 2;
|
||||
DENIED = 3;
|
||||
}
|
||||
|
||||
// AccessRequestSpec is the specification for AccessRequest
|
||||
message AccessRequestSpecV3 {
|
||||
// User is the name of the user to whom the roles will be applied.
|
||||
string User = 1 [(gogoproto.jsontag) = "user"];
|
||||
// Roles is the name of the roles being requested.
|
||||
repeated string Roles = 2 [(gogoproto.jsontag) = "roles"];
|
||||
// State is the current state of this access request.
|
||||
RequestState State = 3 [(gogoproto.jsontag) = "state,omitempty"];
|
||||
// Created encodes the time at which the request was registered with the auth server.
|
||||
google.protobuf.Timestamp Created = 4 [(gogoproto.stdtime) = true, (gogoproto.nullable) = false, (gogoproto.jsontag) = "created,omitempty"];
|
||||
// Expires constrains the maximum lifetime of any login session for which this request is active.
|
||||
google.protobuf.Timestamp Expires = 5 [(gogoproto.stdtime) = true, (gogoproto.nullable) = false, (gogoproto.jsontag) = "expires,omitempty"];
|
||||
}
|
||||
|
||||
// AccessRequestFilter encodes filter params for access requests.
|
||||
message AccessRequestFilter {
|
||||
// ID specifies a request ID if set.
|
||||
string ID = 1 [(gogoproto.jsontag) = "id"];
|
||||
// User specifies a username if set.
|
||||
string User = 2 [(gogoproto.jsontag) = "user"];
|
||||
// RequestState filters for requests in a specific state.
|
||||
RequestState State = 3 [(gogoproto.jsontag) = "state"];
|
||||
}
|
||||
|
||||
// RoleV3 represents role resource specification
|
||||
message RoleV3 {
|
||||
option (gogoproto.goproto_stringer) = false;
|
||||
|
@ -444,6 +493,14 @@ message RoleConditions {
|
|||
|
||||
// KubeGroups is a list of kubernetes groups
|
||||
repeated string KubeGroups = 5 [(gogoproto.jsontag) = "kubernetes_groups,omitempty"];
|
||||
|
||||
AccessRequestConditions Request = 6 [(gogoproto.jsontag) = "request,omitempty"];
|
||||
}
|
||||
|
||||
// AccessRequestConditions is a matcher for allow/deny restrictions on access-requests.
|
||||
message AccessRequestConditions {
|
||||
// Roles is the name of roles which will match the request rule.
|
||||
repeated string Roles = 1 [(gogoproto.jsontag) = "roles,omitempty"];
|
||||
}
|
||||
|
||||
// Rule represents allow or deny rule that is executed to check
|
||||
|
|
171
tool/tctl/common/access_request_command.go
Normal file
171
tool/tctl/common/access_request_command.go
Normal file
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
Copyright 2019 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/kingpin"
|
||||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/teleport/lib/asciitable"
|
||||
"github.com/gravitational/teleport/lib/auth"
|
||||
"github.com/gravitational/teleport/lib/service"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// AccessRequestCommand implements `tctl users` set of commands
|
||||
// It implements CLICommand interface
|
||||
type AccessRequestCommand struct {
|
||||
config *service.Config
|
||||
reqIDs string
|
||||
|
||||
user string
|
||||
roles string
|
||||
// format is the output format, e.g. text or json
|
||||
format string
|
||||
|
||||
requestList *kingpin.CmdClause
|
||||
requestApprove *kingpin.CmdClause
|
||||
requestDeny *kingpin.CmdClause
|
||||
requestCreate *kingpin.CmdClause
|
||||
requestDelete *kingpin.CmdClause
|
||||
}
|
||||
|
||||
// Initialize allows AccessRequestCommand to plug itself into the CLI parser
|
||||
func (c *AccessRequestCommand) Initialize(app *kingpin.Application, config *service.Config) {
|
||||
c.config = config
|
||||
requests := app.Command("requests", "Manage access requests").Alias("request")
|
||||
|
||||
c.requestList = requests.Command("ls", "Show active access requests")
|
||||
c.requestList.Flag("format", "Output format, 'text' or 'json'").Hidden().Default(teleport.Text).StringVar(&c.format)
|
||||
|
||||
c.requestApprove = requests.Command("approve", "Approve pending access request")
|
||||
c.requestApprove.Arg("request-id", "ID of target request(s)").Required().StringVar(&c.reqIDs)
|
||||
|
||||
c.requestDeny = requests.Command("deny", "Deny pending access request")
|
||||
c.requestDeny.Arg("request-id", "ID of target request(s)").Required().StringVar(&c.reqIDs)
|
||||
|
||||
c.requestCreate = requests.Command("create", "Create pending access request")
|
||||
c.requestCreate.Arg("username", "Name of target user").Required().StringVar(&c.user)
|
||||
c.requestCreate.Flag("roles", "Roles to be requested").Required().StringVar(&c.roles)
|
||||
|
||||
c.requestDelete = requests.Command("rm", "Delete an access request")
|
||||
c.requestDelete.Arg("request-id", "ID of target request(s)").Required().StringVar(&c.reqIDs)
|
||||
}
|
||||
|
||||
// TryRun takes the CLI command as an argument (like "access-request list") and executes it.
|
||||
func (c *AccessRequestCommand) TryRun(cmd string, client auth.ClientI) (match bool, err error) {
|
||||
switch cmd {
|
||||
case c.requestList.FullCommand():
|
||||
err = c.List(client)
|
||||
case c.requestApprove.FullCommand():
|
||||
err = c.Approve(client)
|
||||
case c.requestDeny.FullCommand():
|
||||
err = c.Deny(client)
|
||||
case c.requestCreate.FullCommand():
|
||||
err = c.Create(client)
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
return true, trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (c *AccessRequestCommand) List(client auth.ClientI) error {
|
||||
reqs, err := client.GetAccessRequests(services.AccessRequestFilter{})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if err := c.PrintAccessRequests(client, reqs, c.format); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AccessRequestCommand) Approve(client auth.ClientI) error {
|
||||
for _, reqID := range strings.Split(c.reqIDs, ",") {
|
||||
if err := client.SetAccessRequestState(reqID, services.RequestState_APPROVED); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AccessRequestCommand) Deny(client auth.ClientI) error {
|
||||
for _, reqID := range strings.Split(c.reqIDs, ",") {
|
||||
if err := client.SetAccessRequestState(reqID, services.RequestState_DENIED); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AccessRequestCommand) Create(client auth.ClientI) error {
|
||||
roles := strings.Split(c.roles, ",")
|
||||
req, err := services.NewAccessRequest(c.user, roles...)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if err := client.CreateAccessRequest(req); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
fmt.Printf("%s\n", req.GetName())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AccessRequestCommand) Delete(client auth.ClientI) error {
|
||||
for _, reqID := range strings.Split(c.reqIDs, ",") {
|
||||
if err := client.DeleteAccessRequest(reqID); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrintAccessRequests prints access requests
|
||||
func (c *AccessRequestCommand) PrintAccessRequests(client auth.ClientI, reqs []services.AccessRequest, format string) error {
|
||||
if format == teleport.Text {
|
||||
table := asciitable.MakeTable([]string{"Token", "Requestor", "Metadata", "Created At (UTC)", "Status"})
|
||||
now := time.Now()
|
||||
for _, req := range reqs {
|
||||
if now.After(req.GetAccessExpiry()) {
|
||||
continue
|
||||
}
|
||||
params := fmt.Sprintf("roles=%s", strings.Join(req.GetRoles(), ","))
|
||||
table.AddRow([]string{
|
||||
req.GetName(),
|
||||
req.GetUser(),
|
||||
params,
|
||||
req.GetCreationTime().Format(time.RFC822),
|
||||
req.GetState().String(),
|
||||
})
|
||||
}
|
||||
_, err := table.AsBuffer().WriteTo(os.Stdout)
|
||||
return trace.Wrap(err)
|
||||
} else {
|
||||
out, err := json.MarshalIndent(reqs, "", " ")
|
||||
if err != nil {
|
||||
return trace.Wrap(err, "failed to marshal requests")
|
||||
}
|
||||
fmt.Printf("%s\n", out)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -70,7 +70,7 @@ func (a *AuthCommand) Initialize(app *kingpin.Application, config *service.Confi
|
|||
a.authSign.Flag("user", "Teleport user name").StringVar(&a.genUser)
|
||||
a.authSign.Flag("host", "Teleport host name").StringVar(&a.genHost)
|
||||
a.authSign.Flag("out", "identity output").Short('o').StringVar(&a.output)
|
||||
a.authSign.Flag("format", fmt.Sprintf("identity format: %q (default) or %q", client.IdentityFormatFile, client.IdentityFormatOpenSSH)).Default(string(client.DefaultIdentityFormat)).StringVar((*string)(&a.outputFormat))
|
||||
a.authSign.Flag("format", fmt.Sprintf("identity format: %q (default), %q, or %q", client.IdentityFormatFile, client.IdentityFormatOpenSSH, client.IdentityFormatTLS)).Default(string(client.DefaultIdentityFormat)).StringVar((*string)(&a.outputFormat))
|
||||
a.authSign.Flag("ttl", "TTL (time to live) for the generated certificate").Default(fmt.Sprintf("%v", defaults.CertDuration)).DurationVar(&a.genTTL)
|
||||
a.authSign.Flag("compat", "OpenSSH compatibility flag").StringVar(&a.compatibility)
|
||||
|
||||
|
@ -348,7 +348,7 @@ func (a *AuthCommand) generateUserKeys(clusterApi auth.ClientI) error {
|
|||
key.TLSCert = certs.TLS
|
||||
|
||||
var certAuthorities []services.CertAuthority
|
||||
if a.outputFormat == client.IdentityFormatFile {
|
||||
if a.outputFormat != client.IdentityFormatOpenSSH {
|
||||
certAuthorities, err = clusterApi.GetCertAuthorities(services.HostCA, false)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
|
|
@ -186,10 +186,13 @@ func (c *TokenCommand) List(client auth.ClientI) error {
|
|||
if c.format == teleport.Text {
|
||||
tokensView := func() string {
|
||||
table := asciitable.MakeTable([]string{"Token", "Type", "Expiry Time (UTC)"})
|
||||
now := time.Now()
|
||||
for _, t := range tokens {
|
||||
expiry := "never"
|
||||
if t.Expiry().Unix() > 0 {
|
||||
expiry = t.Expiry().Format(time.RFC822)
|
||||
exptime := t.Expiry().Format(time.RFC822)
|
||||
expdur := t.Expiry().Sub(now).Round(time.Second)
|
||||
expiry = fmt.Sprintf("%s (%s)", exptime, expdur.String())
|
||||
}
|
||||
table.AddRow([]string{t.GetName(), t.GetRoles().String(), expiry})
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ func main() {
|
|||
&common.ResourceCommand{},
|
||||
&common.StatusCommand{},
|
||||
&common.TopCommand{},
|
||||
&common.AccessRequestCommand{},
|
||||
}
|
||||
common.Run(commands)
|
||||
}
|
||||
|
|
234
tool/tsh/tsh.go
234
tool/tsh/tsh.go
|
@ -18,8 +18,11 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
@ -34,6 +37,8 @@ import (
|
|||
|
||||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/teleport/lib/asciitable"
|
||||
"github.com/gravitational/teleport/lib/auth"
|
||||
"github.com/gravitational/teleport/lib/backend"
|
||||
"github.com/gravitational/teleport/lib/client"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
kubeclient "github.com/gravitational/teleport/lib/kube/client"
|
||||
|
@ -61,6 +66,8 @@ type CLIConf struct {
|
|||
UserHost string
|
||||
// Commands to execute on a remote host
|
||||
RemoteCommand []string
|
||||
// DesiredRoles indicates one or more roles which should be requested.
|
||||
DesiredRoles string
|
||||
// Username is the Teleport user's username (to login into proxies)
|
||||
Username string
|
||||
// Proxy keeps the hostname:port of the SSH proxy to use
|
||||
|
@ -262,6 +269,7 @@ func Run(args []string, underTest bool) {
|
|||
login.Flag("format", fmt.Sprintf("Identity format [%s] or %s (for OpenSSH compatibility)",
|
||||
client.DefaultIdentityFormat,
|
||||
client.IdentityFormatOpenSSH)).Default(string(client.DefaultIdentityFormat)).StringVar((*string)(&cf.IdentityFormat))
|
||||
login.Flag("request-roles", "Request one or more extra roles").StringVar(&cf.DesiredRoles)
|
||||
login.Arg("cluster", clusterHelp).StringVar(&cf.SiteName)
|
||||
login.Alias(loginUsageFooter)
|
||||
|
||||
|
@ -411,18 +419,23 @@ func onLogin(cf *CLIConf) {
|
|||
if profile != nil && !profile.IsExpired(clockwork.NewRealClock()) {
|
||||
switch {
|
||||
// in case if nothing is specified, print current status
|
||||
case cf.Proxy == "" && cf.SiteName == "":
|
||||
case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "":
|
||||
printProfiles(cf.Debug, profile, profiles)
|
||||
return
|
||||
// in case if parameters match, print current status
|
||||
case host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster:
|
||||
case host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "":
|
||||
printProfiles(cf.Debug, profile, profiles)
|
||||
return
|
||||
// proxy is unspecified or the same as the currently provided proxy,
|
||||
// but cluster is specified, treat this as selecting a new cluster
|
||||
// for the same proxy
|
||||
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.SiteName != "":
|
||||
if err := tc.GenerateCertsForCluster(cf.Context, cf.SiteName); err != nil {
|
||||
// trigger reissue, preserving any active requests.
|
||||
err = tc.ReissueUserCerts(cf.Context, client.ReissueParams{
|
||||
AccessRequests: profile.ActiveRequests.AccessRequests,
|
||||
RouteToCluster: cf.SiteName,
|
||||
})
|
||||
if err != nil {
|
||||
utils.FatalError(err)
|
||||
}
|
||||
tc.SaveProfile("", "")
|
||||
|
@ -431,6 +444,12 @@ func onLogin(cf *CLIConf) {
|
|||
}
|
||||
onStatus(cf)
|
||||
return
|
||||
// proxy is unspecified or the same as the currently provided proxy,
|
||||
// but desired roles are specified, treat this as a privilege escalation
|
||||
// request for the same login session.
|
||||
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.DesiredRoles != "":
|
||||
executeAccessRequest(cf)
|
||||
return
|
||||
// otherwise just passthrough to standard login
|
||||
default:
|
||||
}
|
||||
|
@ -477,7 +496,12 @@ func onLogin(cf *CLIConf) {
|
|||
// advertised settings are picked up.
|
||||
webProxyHost, _ := tc.WebProxyHostPort()
|
||||
cf.Proxy = webProxyHost
|
||||
onStatus(cf)
|
||||
if cf.DesiredRoles != "" {
|
||||
fmt.Println("") // visually separate onRequestExecute output
|
||||
executeAccessRequest(cf)
|
||||
} else {
|
||||
onStatus(cf)
|
||||
}
|
||||
}
|
||||
|
||||
// setupNoninteractiveClient sets up existing client to use
|
||||
|
@ -665,6 +689,33 @@ func onListNodes(cf *CLIConf) {
|
|||
}
|
||||
}
|
||||
|
||||
func executeAccessRequest(cf *CLIConf) {
|
||||
if cf.DesiredRoles == "" {
|
||||
utils.FatalError(trace.BadParameter("one or more roles must be specified"))
|
||||
}
|
||||
roles := strings.Split(cf.DesiredRoles, ",")
|
||||
tc, err := makeClient(cf, true)
|
||||
if err != nil {
|
||||
utils.FatalError(err)
|
||||
}
|
||||
if cf.Username == "" {
|
||||
cf.Username = tc.Username
|
||||
}
|
||||
req, err := services.NewAccessRequest(cf.Username, roles...)
|
||||
if err != nil {
|
||||
utils.FatalError(err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Seeking request approval... (id: %s)\n", req.GetName())
|
||||
if err := getRequestApproval(cf, tc, req); err != nil {
|
||||
utils.FatalError(err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Approval received, getting updated certificates...\n\n")
|
||||
if err := reissueWithRequests(cf, tc, req.GetName()); err != nil {
|
||||
utils.FatalError(err)
|
||||
}
|
||||
onStatus(cf)
|
||||
}
|
||||
|
||||
// chunkLabels breaks labels into sized chunks. Used to improve readability
|
||||
// of "tsh ls".
|
||||
func chunkLabels(labels map[string]string, chunkSize int) [][]string {
|
||||
|
@ -1028,6 +1079,104 @@ func refuseArgs(command string, args []string) {
|
|||
}
|
||||
}
|
||||
|
||||
// loadIdentity loads the private key + certificate from a file
|
||||
// Returns:
|
||||
// - client key: user's private key+cert
|
||||
// - host auth callback: function to validate the host (may be null)
|
||||
// - error, if somthing happens when reading the identityf file
|
||||
//
|
||||
// If the "host auth callback" is not returned, user will be prompted to
|
||||
// trust the proxy server.
|
||||
func loadIdentity(idFn string) (*client.Key, ssh.HostKeyCallback, error) {
|
||||
log.Infof("Reading identity file: %v", idFn)
|
||||
|
||||
f, err := os.Open(idFn)
|
||||
if err != nil {
|
||||
return nil, nil, trace.Wrap(err)
|
||||
}
|
||||
defer f.Close()
|
||||
ident, err := client.DecodeIdentityFile(f)
|
||||
if err != nil {
|
||||
return nil, nil, trace.Wrap(err, "failed to parse identity file")
|
||||
}
|
||||
// did not find the certificate in the file? look in a separate file with
|
||||
// -cert.pub prefix
|
||||
if len(ident.Certs.SSH) == 0 {
|
||||
certFn := idFn + "-cert.pub"
|
||||
log.Infof("Certificate not found in %s. Looking in %s.", idFn, certFn)
|
||||
ident.Certs.SSH, err = ioutil.ReadFile(certFn)
|
||||
if err != nil {
|
||||
return nil, nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
// validate both by parsing them:
|
||||
privKey, err := ssh.ParseRawPrivateKey(ident.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, nil, trace.BadParameter("invalid identity: %s. %v", idFn, err)
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(privKey)
|
||||
if err != nil {
|
||||
return nil, nil, trace.Wrap(err)
|
||||
}
|
||||
// validate TLS Cert (if present):
|
||||
if len(ident.Certs.TLS) > 0 {
|
||||
_, err := tls.X509KeyPair(ident.Certs.TLS, ident.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
// Validate TLS CA certs (if present).
|
||||
var trustedCA []auth.TrustedCerts
|
||||
if len(ident.CACerts.TLS) > 0 {
|
||||
var trustedCerts auth.TrustedCerts
|
||||
pool := x509.NewCertPool()
|
||||
for i, certPEM := range ident.CACerts.TLS {
|
||||
if !pool.AppendCertsFromPEM(certPEM) {
|
||||
return nil, nil, trace.BadParameter("identity file contains invalid TLS CA cert (#%v)", i+1)
|
||||
}
|
||||
trustedCerts.TLSCertificates = append(trustedCerts.TLSCertificates, certPEM)
|
||||
}
|
||||
trustedCA = []auth.TrustedCerts{trustedCerts}
|
||||
}
|
||||
var hostAuthFunc ssh.HostKeyCallback = nil
|
||||
// validate CA (cluster) cert
|
||||
if len(ident.CACerts.SSH) > 0 {
|
||||
var trustedKeys []ssh.PublicKey
|
||||
for _, caCert := range ident.CACerts.SSH {
|
||||
_, _, publicKey, _, _, err := ssh.ParseKnownHosts(caCert)
|
||||
if err != nil {
|
||||
return nil, nil, trace.BadParameter("CA cert parsing error: %v. cert line :%v",
|
||||
err.Error(), string(caCert))
|
||||
}
|
||||
trustedKeys = append(trustedKeys, publicKey)
|
||||
}
|
||||
|
||||
// found CA cert in the indentity file? construct the host key checking function
|
||||
// and return it:
|
||||
hostAuthFunc = func(host string, a net.Addr, hostKey ssh.PublicKey) error {
|
||||
clusterCert, ok := hostKey.(*ssh.Certificate)
|
||||
if ok {
|
||||
hostKey = clusterCert.SignatureKey
|
||||
}
|
||||
for _, trustedKey := range trustedKeys {
|
||||
if sshutils.KeysEqual(trustedKey, hostKey) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
err = trace.AccessDenied("host %v is untrusted", host)
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return &client.Key{
|
||||
Priv: ident.PrivateKey,
|
||||
Pub: signer.PublicKey().Marshal(),
|
||||
Cert: ident.Certs.SSH,
|
||||
TLSCert: ident.Certs.TLS,
|
||||
TrustedCA: trustedCA,
|
||||
}, hostAuthFunc, nil
|
||||
}
|
||||
|
||||
// authFromIdentity returns a standard ssh.Authmethod for a given identity file
|
||||
func authFromIdentity(k *client.Key) (ssh.AuthMethod, error) {
|
||||
signer, err := sshutils.NewSigner(k.Priv, k.Cert)
|
||||
|
@ -1147,3 +1296,80 @@ func host(in string) string {
|
|||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// getRequestApproval registers an access request with the auth server and waits for it to be approved.
|
||||
func getRequestApproval(cf *CLIConf, tc *client.TeleportClient, req services.AccessRequest) error {
|
||||
// set up request watcher before submitting the request to the admin server
|
||||
// in order to avoid potential race.
|
||||
filter := services.AccessRequestFilter{
|
||||
User: tc.Username,
|
||||
}
|
||||
watcher, err := tc.NewWatcher(cf.Context, services.Watch{
|
||||
Name: "await-request-approval",
|
||||
Kinds: []services.WatchKind{
|
||||
services.WatchKind{
|
||||
Kind: services.KindAccessRequest,
|
||||
Filter: filter.IntoMap(),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
defer watcher.Close()
|
||||
if err := tc.CreateAccessRequest(cf.Context, req); err != nil {
|
||||
utils.FatalError(err)
|
||||
}
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case event := <-watcher.Events():
|
||||
if event.Type != backend.OpPut {
|
||||
continue Loop
|
||||
}
|
||||
r, ok := event.Resource.(*services.AccessRequestV3)
|
||||
if !ok {
|
||||
return trace.Errorf("unexpected resource type %T", event.Resource)
|
||||
}
|
||||
if r.GetName() != req.GetName() || r.GetState().IsPending() {
|
||||
continue Loop
|
||||
}
|
||||
if !r.GetState().IsApproved() {
|
||||
return trace.Errorf("request %s has been set to %s", r.GetName(), r.GetState().String())
|
||||
}
|
||||
return nil
|
||||
case <-watcher.Done():
|
||||
utils.FatalError(watcher.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reissueWithRequests handles a certificate reissue, applying new requests by ID,
|
||||
// and saving the updated profile.
|
||||
func reissueWithRequests(cf *CLIConf, tc *client.TeleportClient, reqIDs ...string) error {
|
||||
profile, _, err := client.Status("", cf.Proxy)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
params := client.ReissueParams{
|
||||
AccessRequests: reqIDs,
|
||||
RouteToCluster: cf.SiteName,
|
||||
}
|
||||
// if the certificate already had active requests, add them to our inputs parameters.
|
||||
if len(profile.ActiveRequests.AccessRequests) > 0 {
|
||||
params.AccessRequests = append(params.AccessRequests, profile.ActiveRequests.AccessRequests...)
|
||||
}
|
||||
if params.RouteToCluster == "" {
|
||||
params.RouteToCluster = profile.Cluster
|
||||
}
|
||||
if err := tc.ReissueUserCerts(cf.Context, params); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if err := tc.SaveProfile("", ""); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if err := kubeclient.UpdateKubeconfig(tc); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue