From a47e24404c81c3ca90f76cf9efcf79ce7f1c570e Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Tue, 20 Jan 2026 16:49:36 +0100 Subject: [PATCH] [auth] Change to jwt.RegisteredClaims, support multiple audiences --- pkg/auth/auth.go | 22 ++-- pkg/auth/auth_test.go | 249 +++++++++++++++++++++++++++++++++++++- pkg/auth/claims/claims.go | 8 +- 3 files changed, 263 insertions(+), 16 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index a071f1d25..f91823230 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -124,7 +124,7 @@ type Authorizer struct { logger *zap.Logger keys []interface{} keyGuard sync.RWMutex - acceptedAudiences map[string]bool + acceptedAudiences []string } // Configuration bundles up creation-time parameters for an Authorizer instance. @@ -143,13 +143,8 @@ func NewRSAAuthorizer(ctx context.Context, configuration Configuration) (*Author return nil, stacktrace.Propagate(err, "Unable to resolve keys") } - auds := make(map[string]bool) - for _, s := range configuration.AcceptedAudiences { - auds[s] = true - } - authorizer := &Authorizer{ - acceptedAudiences: auds, + acceptedAudiences: configuration.AcceptedAudiences, logger: logger, keys: keys, } @@ -205,7 +200,18 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio return api.AuthorizationResult{Error: stacktrace.Propagate(err, "Error retrieving claims from context")} } - if !a.acceptedAudiences[keyClaims.Audience] { + validAudience := false + + for _, audience := range a.acceptedAudiences { + + if keyClaims.VerifyAudience(audience, true) { + validAudience = true + break + } + + } + + if !validAudience { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", keyClaims.Audience)} } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 02141713c..243cc6063 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -27,6 +27,7 @@ func rsaTokenReq(key *rsa.PrivateKey, exp, nbf int64) *http.Request { "nbf": nbf, "sub": "real_owner", "iss": "baz", + "aud": "test-aud", }) // Sign and get the complete encoded token as a string using the secret @@ -42,6 +43,46 @@ func rsaTokenReqWithMissingIssuer(key *rsa.PrivateKey, exp, nbf int64) *http.Req "exp": exp, "nbf": nbf, "sub": "real_owner", + "aud": "test-aud", + }) + + // Sign and get the complete encoded token as a string using the secret + // Ignore the error, it will fail the test anyways if it is not nil. + tokenString, _ := token.SignedString(key) + req := &http.Request{Header: make(http.Header)} + req.Header.Set("Authorization", "Bearer "+tokenString) + return req +} + +func rsaTokenReqWithMissingAudience(key *rsa.PrivateKey, exp, nbf int64) *http.Request { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "foo": "bar", + "exp": exp, + "nbf": nbf, + "sub": "real_owner", + "iss": "baz", + }) + + // Sign and get the complete encoded token as a string using the secret + // Ignore the error, it will fail the test anyways if it is not nil. + tokenString, _ := token.SignedString(key) + req := &http.Request{Header: make(http.Header)} + req.Header.Set("Authorization", "Bearer "+tokenString) + return req +} + +func rsaTokenReqWithMultipleAudience(key *rsa.PrivateKey, exp, nbf int64) *http.Request { + return rsaTokenReqWithAudiences(key, exp, nbf, []string{"test-aud", "test-aud2"}) +} + +func rsaTokenReqWithAudiences(key *rsa.PrivateKey, exp, nbf int64, audiences interface{}) *http.Request { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "foo": "bar", + "exp": exp, + "nbf": nbf, + "sub": "real_owner", + "iss": "baz", + "aud": audiences, }) // Sign and get the complete encoded token as a string using the secret @@ -99,6 +140,8 @@ func TestRSAAuthInterceptor(t *testing.T) { {noTokenReq, dsserr.Unauthenticated}, {rsaTokenReq(badKey, 100, 20), dsserr.Unauthenticated}, {rsaTokenReq(key, 100, 20), stacktrace.NoCode}, + {rsaTokenReqWithMultipleAudience(key, 100, 20), stacktrace.NoCode}, + {rsaTokenReqWithMissingAudience(key, 100, 20), dsserr.Unauthenticated}, {rsaTokenReqWithMissingIssuer(key, 100, 20), dsserr.Unauthenticated}, {rsaTokenReq(key, 30, 20), dsserr.Unauthenticated}, {rsaTokenReq(key, 100, 50), dsserr.Unauthenticated}, @@ -109,7 +152,7 @@ func TestRSAAuthInterceptor(t *testing.T) { Keys: []interface{}{&key.PublicKey}, }, KeyRefreshTimeout: 1 * time.Millisecond, - AcceptedAudiences: []string{""}, + AcceptedAudiences: []string{"test-aud"}, }) require.NoError(t, err) @@ -133,6 +176,204 @@ func TestRSAAuthInterceptor(t *testing.T) { } } +func TestRSAAuthAudiences(t *testing.T) { + + var tests = []struct { + Accepted []string + Provided interface{} + ShouldBeAuthorized bool + }{ + { + []string{"aud1", "aud2"}, + []string{"aud1", "aud2", "aud3"}, + true, + }, + { + []string{"aud1", "aud2"}, + []string{"aud2", "aud3"}, + true, + }, + { + []string{"aud1", "aud2"}, + []string{"aud1", "aud3"}, + true, + }, + { + []string{"aud1", "aud2"}, + []string{"aud1", "aud2"}, + true, + }, + { + []string{"aud1", "aud2"}, + []string{"aud1"}, + true, + }, + { + []string{"aud1", "aud2"}, + []string{"aud2"}, + true, + }, + { + []string{"aud1", "aud2"}, + []string{"aud3"}, + false, + }, + { + []string{"aud1", "aud2"}, + "aud1", + true, + }, + { + []string{"aud1", "aud2"}, + "aud2", + true, + }, + { + []string{"aud1", "aud2"}, + "aud3", + false, + }, + { + []string{"aud1", "aud2"}, + []string{}, + false, + }, + { + []string{"aud1", "aud2"}, + "", + false, + }, + { + []string{"aud1", "aud2"}, + nil, + false, + }, + { + []string{"aud1"}, + []string{"aud1", "aud2", "aud3"}, + true, + }, + { + []string{"aud1"}, + []string{"aud2", "aud3"}, + false, + }, + { + []string{"aud1"}, + []string{"aud2", "aud1"}, + true, + }, + { + []string{"aud1"}, + []string{"aud1"}, + true, + }, + { + []string{"aud1"}, + []string{"aud2"}, + false, + }, + { + []string{"aud1"}, + "aud1", + true, + }, + { + []string{"aud1"}, + "aud2", + false, + }, + { + []string{"aud1"}, + []string{}, + false, + }, + { + []string{"aud1"}, + "", + false, + }, + { + []string{"aud1"}, + nil, + false, + }, + { + []string{}, + []string{"aud1"}, + false, + }, + { + []string{}, + "aud1", + false, + }, + { + []string{}, + []string{}, + false, + }, + { + []string{}, + "", + false, + }, + { + []string{}, + nil, + false, + }, + } + + jwt.TimeFunc = func() time.Time { + return time.Unix(42, 0) + } + + defer func() { + jwt.TimeFunc = time.Now + }() + + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + + a, err := NewRSAAuthorizer(t.Context(), Configuration{ + KeyResolver: &fromMemoryKeyResolver{ + Keys: []interface{}{&key.PublicKey}, + }, + KeyRefreshTimeout: 1 * time.Millisecond, + AcceptedAudiences: test.Accepted, + }) + require.NoError(t, err) + + req := rsaTokenReqWithAudiences(key, 100, 20, test.Provided) + code := dsserr.Unauthenticated + + if test.ShouldBeAuthorized { + code = stacktrace.NoCode + } + + ctx := t.Context() + claimsValue, err := a.extractClaims(req) + if err != nil { + ctx = claims.NewContextFromError(ctx, err) + } else { + ctx = claims.NewContext(ctx, claimsValue) + } + + res := a.Authorize(nil, req.WithContext(ctx), []api.AuthorizationOption{}) + if code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != code { + t.Logf("%v", res.Error) + t.Errorf("expected: %v, got: %v, with message %s", code, stacktrace.GetCode(res.Error), res.Error.Error()) + } + }) + } +} + func TestMissingScopes(t *testing.T) { authOptions := []api.AuthorizationOption{ {"TestAuth1": {"required1"}}, @@ -217,18 +458,18 @@ func TestClaimsValidation(t *testing.T) { require.Error(t, claims.Valid()) claims.Subject = "real_owner" - claims.ExpiresAt = 45 + claims.ExpiresAt = jwt.NewNumericDate(time.Unix(45, 0)) claims.Issuer = "real_issuer" require.NoError(t, claims.Valid()) // Test error out on expired token Now.Unix() = 42 - claims.ExpiresAt = 41 + claims.ExpiresAt = jwt.NewNumericDate(time.Unix(41, 0)) require.Error(t, claims.Valid()) // Test error out on missing Issuer URI claims.Issuer = "" - claims.ExpiresAt = 45 + claims.ExpiresAt = jwt.NewNumericDate(time.Unix(45, 0)) require.Error(t, claims.Valid()) } diff --git a/pkg/auth/claims/claims.go b/pkg/auth/claims/claims.go index 92f2b11f7..df055b335 100644 --- a/pkg/auth/claims/claims.go +++ b/pkg/auth/claims/claims.go @@ -91,7 +91,7 @@ func (s *ScopeSet) ToStringSlice() []string { } type Claims struct { - jwt.StandardClaims + jwt.RegisteredClaims Scopes ScopeSet `json:"scope"` } @@ -101,9 +101,9 @@ func (c *Claims) Valid() error { } now := Now() - c.VerifyExpiresAt(now.Unix(), true) + c.VerifyExpiresAt(now, true) - if c.ExpiresAt > now.Add(time.Hour).Unix() { + if c.ExpiresAt.After(now.Add(time.Hour)) { return errTokenExpireTooFar } @@ -111,5 +111,5 @@ func (c *Claims) Valid() error { return errMissingIssuer } - return c.StandardClaims.Valid() + return c.RegisteredClaims.Valid() }