mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 17:53:28 +00:00
335 lines
10 KiB
Go
335 lines
10 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/flate"
|
|
"encoding/base64"
|
|
"io/ioutil"
|
|
"time"
|
|
|
|
"github.com/gravitational/teleport"
|
|
"github.com/gravitational/teleport/lib/defaults"
|
|
"github.com/gravitational/teleport/lib/services"
|
|
"github.com/gravitational/teleport/lib/utils"
|
|
"github.com/gravitational/trace"
|
|
|
|
log "github.com/Sirupsen/logrus"
|
|
"github.com/beevik/etree"
|
|
saml2 "github.com/russellhaering/gosaml2"
|
|
)
|
|
|
|
func (s *AuthServer) UpsertSAMLConnector(connector services.SAMLConnector) error {
|
|
return s.Identity.UpsertSAMLConnector(connector)
|
|
}
|
|
|
|
func (s *AuthServer) DeleteSAMLConnector(connectorName string) error {
|
|
return s.Identity.DeleteSAMLConnector(connectorName)
|
|
}
|
|
|
|
func (s *AuthServer) CreateSAMLAuthRequest(req services.SAMLAuthRequest) (*services.SAMLAuthRequest, error) {
|
|
connector, err := s.Identity.GetSAMLConnector(req.ConnectorID, true)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
provider, err := s.getSAMLProvider(connector)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
|
|
doc, err := provider.BuildAuthRequestDocument()
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
|
|
attr := doc.Root().SelectAttr("ID")
|
|
if attr == nil || attr.Value == "" {
|
|
return nil, trace.BadParameter("missing auth request ID")
|
|
}
|
|
|
|
req.ID = attr.Value
|
|
req.RedirectURL, err = provider.BuildAuthURLFromDocument("", doc)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
|
|
err = s.Identity.CreateSAMLAuthRequest(req, defaults.SAMLAuthRequestTTL)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
return &req, nil
|
|
}
|
|
|
|
func (s *AuthServer) getSAMLProvider(conn services.SAMLConnector) (*saml2.SAMLServiceProvider, error) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
providerPack, ok := s.samlProviders[conn.GetName()]
|
|
if ok && providerPack.connector.Equals(conn) {
|
|
return providerPack.provider, nil
|
|
}
|
|
delete(s.samlProviders, conn.GetName())
|
|
|
|
serviceProvider, err := conn.GetServiceProvider(s.clock)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
|
|
s.samlProviders[conn.GetName()] = &samlProvider{connector: conn, provider: serviceProvider}
|
|
|
|
return serviceProvider, nil
|
|
}
|
|
|
|
// buildSAMLRoles takes a connector and claims and returns a slice of roles. If the claims
|
|
// match a concrete roles in the connector, those roles are returned directly. If the
|
|
// claims match a template role in the connector, then that role is first created from
|
|
// the template, then returned.
|
|
func (a *AuthServer) buildSAMLRoles(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, expiresAt time.Time) ([]string, error) {
|
|
roles := connector.MapAttributes(assertionInfo)
|
|
if len(roles) == 0 {
|
|
role, err := connector.RoleFromTemplate(assertionInfo)
|
|
if err != nil {
|
|
log.Warningf("[SAML] Unable to map claims to roles or role templates for %q", connector.GetName())
|
|
return nil, trace.AccessDenied("unable to map claims to roles or role templates for %q", connector.GetName())
|
|
}
|
|
|
|
// figure out ttl for role. expires = now + ttl => ttl = expires - now
|
|
ttl := expiresAt.Sub(a.clock.Now())
|
|
|
|
// upsert templated role
|
|
err = a.Access.UpsertRole(role, ttl)
|
|
if err != nil {
|
|
log.Warningf("[SAML] Unable to upsert templated role for connector: %q", connector.GetName())
|
|
return nil, trace.AccessDenied("unable to upsert templated role: %q", connector.GetName())
|
|
}
|
|
|
|
roles = []string{role.GetName()}
|
|
}
|
|
|
|
return roles, nil
|
|
}
|
|
|
|
func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, expiresAt time.Time) error {
|
|
roles, err := a.buildSAMLRoles(connector, assertionInfo, expiresAt)
|
|
if err != nil {
|
|
return trace.Wrap(err)
|
|
}
|
|
|
|
log.Debugf("[SAML] %v/%v is a dynamic identity, generating user with roles: %v", connector.GetName(), assertionInfo.NameID, roles)
|
|
user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{
|
|
Kind: services.KindUser,
|
|
Version: services.V2,
|
|
Metadata: services.Metadata{
|
|
Name: assertionInfo.NameID,
|
|
Namespace: defaults.Namespace,
|
|
},
|
|
Spec: services.UserSpecV2{
|
|
Roles: roles,
|
|
Expires: expiresAt,
|
|
SAMLIdentities: []services.ExternalIdentity{{ConnectorID: connector.GetName(), Username: assertionInfo.NameID}},
|
|
CreatedBy: services.CreatedBy{
|
|
User: services.UserRef{Name: "system"},
|
|
Time: time.Now().UTC(),
|
|
Connector: &services.ConnectorRef{
|
|
Type: teleport.ConnectorSAML,
|
|
ID: connector.GetName(),
|
|
Identity: assertionInfo.NameID,
|
|
},
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return trace.Wrap(err)
|
|
}
|
|
|
|
// check if a user exists already
|
|
existingUser, err := a.GetUser(assertionInfo.NameID)
|
|
if err != nil {
|
|
if !trace.IsNotFound(err) {
|
|
return trace.Wrap(err)
|
|
}
|
|
}
|
|
|
|
// check if exisiting user is a non-saml user, if so, return an error
|
|
if existingUser != nil {
|
|
connectorRef := existingUser.GetCreatedBy().Connector
|
|
if connectorRef == nil || connectorRef.Type != teleport.ConnectorSAML || connectorRef.ID != connector.GetName() {
|
|
return trace.AlreadyExists("user %q already exists and is not SAML user", existingUser.GetName())
|
|
}
|
|
}
|
|
|
|
// no non-saml user exists, create or update the exisiting saml user
|
|
err = a.UpsertUser(user)
|
|
if err != nil {
|
|
return trace.Wrap(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseSAMLInResponseTo(response string) (string, error) {
|
|
raw, _ := base64.StdEncoding.DecodeString(response)
|
|
|
|
doc := etree.NewDocument()
|
|
err := doc.ReadFromBytes(raw)
|
|
if err != nil {
|
|
// Attempt to inflate the response in case it happens to be compressed (as with one case at saml.oktadev.com)
|
|
buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(raw)))
|
|
if err != nil {
|
|
return "", trace.Wrap(err)
|
|
}
|
|
|
|
doc = etree.NewDocument()
|
|
err = doc.ReadFromBytes(buf)
|
|
if err != nil {
|
|
return "", trace.Wrap(err)
|
|
}
|
|
}
|
|
|
|
if doc.Root() == nil {
|
|
return "", trace.BadParameter("unable to parse response")
|
|
}
|
|
|
|
el := doc.Root()
|
|
responseTo := el.SelectAttr("InResponseTo")
|
|
if responseTo == nil {
|
|
return "", trace.BadParameter("identity provider initiated flows are not supported")
|
|
}
|
|
if responseTo.Value == "" {
|
|
return "", trace.BadParameter("InResponseTo can not be empty")
|
|
}
|
|
return responseTo.Value, nil
|
|
}
|
|
|
|
// SAMLAuthResponse is returned when the auth server has validated the
// callback parameters returned from a SAML identity provider.
type SAMLAuthResponse struct {
	// Username is the authenticated Teleport username.
	Username string `json:"username"`
	// Identity contains the validated SAML identity (connector ID and
	// username).
	Identity services.ExternalIdentity `json:"identity"`
	// Session is the web session generated by the auth server, set only if
	// requested in the SAMLAuthRequest.
	Session services.WebSession `json:"session,omitempty"`
	// Cert is the user certificate generated by the certificate authority,
	// set only when the SAMLAuthRequest carried a public key.
	Cert []byte `json:"cert,omitempty"`
	// Req is the original SAML auth request.
	Req services.SAMLAuthRequest `json:"req"`
	// HostSigners is a list of signing host public keys
	// trusted by the proxy, used in console login.
	HostSigners []services.CertAuthority `json:"host_signers"`
}
|
|
|
|
// ValidateSAMLResponse consumes attribute statements from SAML identity provider
|
|
func (a *AuthServer) ValidateSAMLResponse(samlResponse string) (*SAMLAuthResponse, error) {
|
|
requestID, err := parseSAMLInResponseTo(samlResponse)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
request, err := a.Identity.GetSAMLAuthRequest(requestID)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
connector, err := a.Identity.GetSAMLConnector(request.ConnectorID, true)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
provider, err := a.getSAMLProvider(connector)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
assertionInfo, err := provider.RetrieveAssertionInfo(samlResponse)
|
|
if err != nil {
|
|
log.Warningf("SAML error: %v", err)
|
|
return nil, trace.AccessDenied("bad SAML response")
|
|
}
|
|
|
|
if assertionInfo.WarningInfo.InvalidTime {
|
|
log.Warningf("SAML error, invalid time")
|
|
return nil, trace.AccessDenied("bad SAML response")
|
|
}
|
|
|
|
if assertionInfo.WarningInfo.NotInAudience {
|
|
log.Warningf("SAML error, not in audience")
|
|
return nil, trace.AccessDenied("bad SAML response")
|
|
}
|
|
|
|
log.Debugf("[SAML] Obtained Assertions for %q", assertionInfo.NameID)
|
|
for key, val := range assertionInfo.Values {
|
|
var vals []string
|
|
for _, vv := range val.Values {
|
|
vals = append(vals, vv.Value)
|
|
}
|
|
log.Debugf("[SAML] Assertion: %q: %q", key, vals)
|
|
}
|
|
log.Debugf("[SAML] Assertion Warnings: %+v", assertionInfo.WarningInfo)
|
|
|
|
log.Debugf("[SAML] Applying %v claims to roles mappings", len(connector.GetAttributesToRoles()))
|
|
if len(connector.GetAttributesToRoles()) == 0 {
|
|
return nil, trace.BadParameter("SAML does not support binding to local users")
|
|
}
|
|
// TODO(klizhentas) use SessionNotOnOrAfter to calculate expiration time
|
|
expiresAt := a.clock.Now().Add(defaults.CertDuration)
|
|
if err := a.createSAMLUser(connector, *assertionInfo, expiresAt); err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
|
|
identity := services.ExternalIdentity{
|
|
ConnectorID: request.ConnectorID,
|
|
Username: assertionInfo.NameID,
|
|
}
|
|
user, err := a.Identity.GetUserBySAMLIdentity(identity)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
response := &SAMLAuthResponse{
|
|
Req: *request,
|
|
Identity: identity,
|
|
Username: user.GetName(),
|
|
}
|
|
|
|
var roles services.RoleSet
|
|
roles, err = services.FetchRoles(user.GetRoles(), a.Access)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
sessionTTL := roles.AdjustSessionTTL(utils.ToTTL(a.clock, expiresAt))
|
|
bearerTokenTTL := utils.MinTTL(BearerTokenTTL, sessionTTL)
|
|
|
|
if request.CreateWebSession {
|
|
sess, err := a.NewWebSession(user.GetName())
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
// session will expire based on identity TTL and allowed session TTL
|
|
sess.SetExpiryTime(a.clock.Now().UTC().Add(sessionTTL))
|
|
// bearer token will expire based on the expected session renewal
|
|
sess.SetBearerTokenExpiryTime(a.clock.Now().UTC().Add(bearerTokenTTL))
|
|
if err := a.UpsertWebSession(user.GetName(), sess); err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
response.Session = sess
|
|
}
|
|
|
|
if len(request.PublicKey) != 0 {
|
|
certTTL := utils.MinTTL(utils.ToTTL(a.clock, expiresAt), request.CertTTL)
|
|
allowedLogins, err := roles.CheckLogins(certTTL)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
cert, err := a.GenerateUserCert(request.PublicKey, user, allowedLogins, certTTL, roles.CanForwardAgents())
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
response.Cert = cert
|
|
|
|
authorities, err := a.GetCertAuthorities(services.HostCA, false)
|
|
if err != nil {
|
|
return nil, trace.Wrap(err)
|
|
}
|
|
for _, authority := range authorities {
|
|
response.HostSigners = append(response.HostSigners, authority)
|
|
}
|
|
}
|
|
|
|
return response, nil
|
|
}
|