change service account embedded policy size limit (#19840)

Bonus: trim-off all the unnecessary spaces to allow
for real 2048 characters in policies for STS handlers
and re-use the code in all STS handlers.
This commit is contained in:
Harshavardhana 2024-05-30 11:10:41 -07:00 committed by GitHub
parent 4af31e654b
commit 8f93e81afb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 67 additions and 82 deletions

View file

@ -2371,7 +2371,7 @@ func (store *IAMStoreSys) UpdateServiceAccount(ctx context.Context, accessKey st
return updatedAt, err
}
if len(policyBuf) > 2048 {
if len(policyBuf) > maxSVCSessionPolicySize {
return updatedAt, errSessionPolicyTooLarge
}

View file

@ -78,6 +78,10 @@ const (
inheritedPolicyType = "inherited-policy"
)
const (
maxSVCSessionPolicySize = 4096
)
// IAMSys - config system.
type IAMSys struct {
// Need to keep them here to keep alignment - ref: https://golang.org/pkg/sync/atomic/#pkg-note-BUG
@ -977,7 +981,7 @@ func (sys *IAMSys) NewServiceAccount(ctx context.Context, parentUser string, gro
if err != nil {
return auth.Credentials{}, time.Time{}, err
}
if len(policyBuf) > 2048 {
if len(policyBuf) > maxSVCSessionPolicySize {
return auth.Credentials{}, time.Time{}, errSessionPolicyTooLarge
}
}

View file

@ -22,9 +22,11 @@ import (
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
@ -82,8 +84,50 @@ const (
// Role Claim key
roleArnClaim = "roleArn"
// maximum supported STS session policy size
maxSTSSessionPolicySize = 2048
)
type stsClaims map[string]interface{}
func (c stsClaims) populateSessionPolicy(form url.Values) error {
if len(form) == 0 {
return nil
}
sessionPolicyStr := form.Get(stsPolicy)
if len(sessionPolicyStr) == 0 {
return nil
}
sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr)))
if err != nil {
return err
}
// Version in policy must not be empty
if sessionPolicy.Version == "" {
return errors.New("Version cannot be empty expecting '2012-10-17'")
}
policyBuf, err := json.Marshal(sessionPolicy)
if err != nil {
return err
}
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html
// The plain text that you use for both inline and managed session
// policies shouldn't exceed maxSTSSessionPolicySize characters.
if len(policyBuf) > maxSTSSessionPolicySize {
return errSessionPolicyTooLarge
}
c[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString(policyBuf)
return nil
}
// stsAPIHandlers implements and provides http handlers for AWS STS API.
type stsAPIHandlers struct{}
@ -212,7 +256,7 @@ func getTokenSigningKey() (string, error) {
func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) {
ctx := newContext(r, w, "AssumeRole")
claims := make(map[string]interface{})
claims := stsClaims{}
defer logger.AuditLog(ctx, w, r, claims)
// Check auth here (otherwise r.Form will have unexpected values from
@ -249,29 +293,11 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) {
return
}
sessionPolicyStr := r.Form.Get(stsPolicy)
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html
// The plain text that you use for both inline and managed session
// policies shouldn't exceed 2048 characters.
if len(sessionPolicyStr) > 2048 {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, errSessionPolicyTooLarge)
if err := claims.populateSessionPolicy(r.Form); err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
if len(sessionPolicyStr) > 0 {
sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr)))
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
// Version in policy must not be empty
if sessionPolicy.Version == "" {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Version cannot be empty expecting '2012-10-17'"))
return
}
}
duration, err := openid.GetDefaultExpiration(r.Form.Get(stsDurationSeconds))
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
@ -288,10 +314,6 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) {
return
}
if len(sessionPolicyStr) > 0 {
claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr))
}
secret, err := getTokenSigningKey()
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err)
@ -342,7 +364,7 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) {
func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Request) {
ctx := newContext(r, w, "AssumeRoleSSOCommon")
claims := make(map[string]interface{})
claims := stsClaims{}
defer logger.AuditLog(ctx, w, r, claims)
// Parse the incoming form data.
@ -449,31 +471,11 @@ func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Requ
claims[iamPolicyClaimNameOpenID()] = policyName
}
sessionPolicyStr := r.Form.Get(stsPolicy)
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
// The plain text that you use for both inline and managed session
// policies shouldn't exceed 2048 characters.
if len(sessionPolicyStr) > 2048 {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Session policy should not exceed 2048 characters"))
if err := claims.populateSessionPolicy(r.Form); err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
if len(sessionPolicyStr) > 0 {
sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr)))
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
// Version in policy must not be empty
if sessionPolicy.Version == "" {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Invalid session policy version"))
return
}
claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr))
}
secret, err := getTokenSigningKey()
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err)
@ -612,7 +614,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithClientGrants(w http.ResponseWriter, r *
func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r *http.Request) {
ctx := newContext(r, w, "AssumeRoleWithLDAPIdentity")
claims := make(map[string]interface{})
claims := stsClaims{}
defer logger.AuditLog(ctx, w, r, claims, stsLDAPPassword)
// Parse the incoming form data.
@ -643,29 +645,11 @@ func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r *
return
}
sessionPolicyStr := r.Form.Get(stsPolicy)
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html
// The plain text that you use for both inline and managed session
// policies shouldn't exceed 2048 characters.
if len(sessionPolicyStr) > 2048 {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Session policy should not exceed 2048 characters"))
if err := claims.populateSessionPolicy(r.Form); err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
if len(sessionPolicyStr) > 0 {
sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr)))
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
// Version in policy must not be empty
if sessionPolicy.Version == "" {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Version needs to be specified in session policy"))
return
}
}
if !globalIAMSys.Initialized() {
writeSTSErrorResponse(ctx, w, ErrSTSIAMNotInitialized, errIAMNotInitialized)
return
@ -708,10 +692,6 @@ func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r *
claims[ldapAttribPrefix+attrib] = value
}
if len(sessionPolicyStr) > 0 {
claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr))
}
secret, err := getTokenSigningKey()
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err)

View file

@ -133,7 +133,7 @@ const (
)
// Validate - validates the id_token.
func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, dsecs string, claims jwtgo.MapClaims) error {
func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, dsecs string, claims map[string]interface{}) error {
jp := new(jwtgo.Parser)
jp.ValidMethods = []string{
"RS256", "RS384", "RS512",
@ -156,14 +156,15 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken,
return fmt.Errorf("Role %s does not exist", arn)
}
jwtToken, err := jp.ParseWithClaims(token, &claims, keyFuncCallback)
mclaims := jwtgo.MapClaims(claims)
jwtToken, err := jp.ParseWithClaims(token, &mclaims, keyFuncCallback)
if err != nil {
// Re-populate the public key in-case the JWKS
// pubkeys are refreshed
if err = r.PopulatePublicKey(arn); err != nil {
return err
}
jwtToken, err = jwtgo.ParseWithClaims(token, &claims, keyFuncCallback)
jwtToken, err = jwtgo.ParseWithClaims(token, &mclaims, keyFuncCallback)
if err != nil {
return err
}
@ -173,11 +174,11 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken,
return ErrTokenExpired
}
if err = updateClaimsExpiry(dsecs, claims); err != nil {
if err = updateClaimsExpiry(dsecs, mclaims); err != nil {
return err
}
if err = r.updateUserinfoClaims(ctx, arn, accessToken, claims); err != nil {
if err = r.updateUserinfoClaims(ctx, arn, accessToken, mclaims); err != nil {
return err
}
@ -190,7 +191,7 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken,
// array of case sensitive strings. In the common special case
// when there is one audience, the aud value MAY be a single
// case sensitive
audValues, ok := policy.GetValuesFromClaims(claims, audClaim)
audValues, ok := policy.GetValuesFromClaims(mclaims, audClaim)
if !ok {
return errors.New("STS JWT Token has `aud` claim invalid, `aud` must match configured OpenID Client ID")
}
@ -204,7 +205,7 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken,
// be included even when the authorized party is the same
// as the sole audience. The azp value is a case sensitive
// string containing a StringOrURI value
azpValues, ok := policy.GetValuesFromClaims(claims, azpClaim)
azpValues, ok := policy.GetValuesFromClaims(mclaims, azpClaim)
if !ok {
return errors.New("STS JWT Token has `azp` claim invalid, `azp` must match configured OpenID Client ID")
}