Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ type Authorizer struct {
logger *zap.Logger
keys []interface{}
keyGuard sync.RWMutex
acceptedAudiences map[string]bool
acceptedAudiences []string
}

// Configuration bundles up creation-time parameters for an Authorizer instance.
Expand All @@ -143,13 +143,8 @@ func NewRSAAuthorizer(ctx context.Context, configuration Configuration) (*Author
return nil, stacktrace.Propagate(err, "Unable to resolve keys")
}

auds := make(map[string]bool)
for _, s := range configuration.AcceptedAudiences {
auds[s] = true
}

authorizer := &Authorizer{
acceptedAudiences: auds,
acceptedAudiences: configuration.AcceptedAudiences,
logger: logger,
keys: keys,
}
Expand Down Expand Up @@ -205,7 +200,18 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio
return api.AuthorizationResult{Error: stacktrace.Propagate(err, "Error retrieving claims from context")}
}

if !a.acceptedAudiences[keyClaims.Audience] {
validAudience := false

for _, audience := range a.acceptedAudiences {

if keyClaims.VerifyAudience(audience, true) {
validAudience = true
break
}

}

if !validAudience {
return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Invalid access token audience: %v", keyClaims.Audience)}
}

Expand Down
249 changes: 245 additions & 4 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func rsaTokenReq(key *rsa.PrivateKey, exp, nbf int64) *http.Request {
"nbf": nbf,
"sub": "real_owner",
"iss": "baz",
"aud": "test-aud",
})

// Sign and get the complete encoded token as a string using the secret
Expand All @@ -42,6 +43,46 @@ func rsaTokenReqWithMissingIssuer(key *rsa.PrivateKey, exp, nbf int64) *http.Req
"exp": exp,
"nbf": nbf,
"sub": "real_owner",
"aud": "test-aud",
})

// Sign and get the complete encoded token as a string using the secret
// Ignore the error, it will fail the test anyways if it is not nil.
tokenString, _ := token.SignedString(key)
req := &http.Request{Header: make(http.Header)}
req.Header.Set("Authorization", "Bearer "+tokenString)
return req
}

func rsaTokenReqWithMissingAudience(key *rsa.PrivateKey, exp, nbf int64) *http.Request {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"foo": "bar",
"exp": exp,
"nbf": nbf,
"sub": "real_owner",
"iss": "baz",
})

// Sign and get the complete encoded token as a string using the secret
// Ignore the error, it will fail the test anyways if it is not nil.
tokenString, _ := token.SignedString(key)
req := &http.Request{Header: make(http.Header)}
req.Header.Set("Authorization", "Bearer "+tokenString)
return req
}

func rsaTokenReqWithMultipleAudience(key *rsa.PrivateKey, exp, nbf int64) *http.Request {
return rsaTokenReqWithAudiences(key, exp, nbf, []string{"test-aud", "test-aud2"})
}

func rsaTokenReqWithAudiences(key *rsa.PrivateKey, exp, nbf int64, audiences interface{}) *http.Request {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"foo": "bar",
"exp": exp,
"nbf": nbf,
"sub": "real_owner",
"iss": "baz",
"aud": audiences,
})

// Sign and get the complete encoded token as a string using the secret
Expand Down Expand Up @@ -99,6 +140,8 @@ func TestRSAAuthInterceptor(t *testing.T) {
{noTokenReq, dsserr.Unauthenticated},
{rsaTokenReq(badKey, 100, 20), dsserr.Unauthenticated},
{rsaTokenReq(key, 100, 20), stacktrace.NoCode},
{rsaTokenReqWithMultipleAudience(key, 100, 20), stacktrace.NoCode},
{rsaTokenReqWithMissingAudience(key, 100, 20), dsserr.Unauthenticated},
{rsaTokenReqWithMissingIssuer(key, 100, 20), dsserr.Unauthenticated},
{rsaTokenReq(key, 30, 20), dsserr.Unauthenticated},
{rsaTokenReq(key, 100, 50), dsserr.Unauthenticated},
Expand All @@ -109,7 +152,7 @@ func TestRSAAuthInterceptor(t *testing.T) {
Keys: []interface{}{&key.PublicKey},
},
KeyRefreshTimeout: 1 * time.Millisecond,
AcceptedAudiences: []string{""},
AcceptedAudiences: []string{"test-aud"},
})

require.NoError(t, err)
Expand All @@ -133,6 +176,204 @@ func TestRSAAuthInterceptor(t *testing.T) {
}
}

func TestRSAAuthAudiences(t *testing.T) {

var tests = []struct {
Accepted []string
Provided interface{}
ShouldBeAuthorized bool
}{
{
[]string{"aud1", "aud2"},
[]string{"aud1", "aud2", "aud3"},
true,
},
{
[]string{"aud1", "aud2"},
[]string{"aud2", "aud3"},
true,
},
{
[]string{"aud1", "aud2"},
[]string{"aud1", "aud3"},
true,
},
{
[]string{"aud1", "aud2"},
[]string{"aud1", "aud2"},
true,
},
{
[]string{"aud1", "aud2"},
[]string{"aud1"},
true,
},
{
[]string{"aud1", "aud2"},
[]string{"aud2"},
true,
},
{
[]string{"aud1", "aud2"},
[]string{"aud3"},
false,
},
{
[]string{"aud1", "aud2"},
"aud1",
true,
},
{
[]string{"aud1", "aud2"},
"aud2",
true,
},
{
[]string{"aud1", "aud2"},
"aud3",
false,
},
{
[]string{"aud1", "aud2"},
[]string{},
false,
},
{
[]string{"aud1", "aud2"},
"",
false,
},
{
[]string{"aud1", "aud2"},
nil,
false,
},
{
[]string{"aud1"},
[]string{"aud1", "aud2", "aud3"},
true,
},
{
[]string{"aud1"},
[]string{"aud2", "aud3"},
false,
},
{
[]string{"aud1"},
[]string{"aud2", "aud1"},
true,
},
{
[]string{"aud1"},
[]string{"aud1"},
true,
},
{
[]string{"aud1"},
[]string{"aud2"},
false,
},
{
[]string{"aud1"},
"aud1",
true,
},
{
[]string{"aud1"},
"aud2",
false,
},
{
[]string{"aud1"},
[]string{},
false,
},
{
[]string{"aud1"},
"",
false,
},
{
[]string{"aud1"},
nil,
false,
},
{
[]string{},
[]string{"aud1"},
false,
},
{
[]string{},
"aud1",
false,
},
{
[]string{},
[]string{},
false,
},
{
[]string{},
"",
false,
},
{
[]string{},
nil,
false,
},
}

jwt.TimeFunc = func() time.Time {
return time.Unix(42, 0)
}

defer func() {
jwt.TimeFunc = time.Now
}()

key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}

for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {

a, err := NewRSAAuthorizer(t.Context(), Configuration{
KeyResolver: &fromMemoryKeyResolver{
Keys: []interface{}{&key.PublicKey},
},
KeyRefreshTimeout: 1 * time.Millisecond,
AcceptedAudiences: test.Accepted,
})
require.NoError(t, err)

req := rsaTokenReqWithAudiences(key, 100, 20, test.Provided)
code := dsserr.Unauthenticated

if test.ShouldBeAuthorized {
code = stacktrace.NoCode
}

ctx := t.Context()
claimsValue, err := a.extractClaims(req)
if err != nil {
ctx = claims.NewContextFromError(ctx, err)
} else {
ctx = claims.NewContext(ctx, claimsValue)
}

res := a.Authorize(nil, req.WithContext(ctx), []api.AuthorizationOption{})
if code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != code {
t.Logf("%v", res.Error)
t.Errorf("expected: %v, got: %v, with message %s", code, stacktrace.GetCode(res.Error), res.Error.Error())
}
})
}
}

func TestMissingScopes(t *testing.T) {
authOptions := []api.AuthorizationOption{
{"TestAuth1": {"required1"}},
Expand Down Expand Up @@ -217,18 +458,18 @@ func TestClaimsValidation(t *testing.T) {
require.Error(t, claims.Valid())

claims.Subject = "real_owner"
claims.ExpiresAt = 45
claims.ExpiresAt = jwt.NewNumericDate(time.Unix(45, 0))
claims.Issuer = "real_issuer"

require.NoError(t, claims.Valid())

// Test error out on expired token Now.Unix() = 42
claims.ExpiresAt = 41
claims.ExpiresAt = jwt.NewNumericDate(time.Unix(41, 0))
require.Error(t, claims.Valid())

// Test error out on missing Issuer URI
claims.Issuer = ""
claims.ExpiresAt = 45
claims.ExpiresAt = jwt.NewNumericDate(time.Unix(45, 0))
require.Error(t, claims.Valid())
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/auth/claims/claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (s *ScopeSet) ToStringSlice() []string {
}

type Claims struct {
jwt.StandardClaims
jwt.RegisteredClaims
Scopes ScopeSet `json:"scope"`
}

Expand All @@ -101,15 +101,15 @@ func (c *Claims) Valid() error {
}
now := Now()

c.VerifyExpiresAt(now.Unix(), true)
c.VerifyExpiresAt(now, true)

if c.ExpiresAt > now.Add(time.Hour).Unix() {
if c.ExpiresAt.After(now.Add(time.Hour)) {
return errTokenExpireTooFar
}

if c.Issuer == "" {
return errMissingIssuer
}

return c.StandardClaims.Valid()
return c.RegisteredClaims.Valid()
}
Loading