Skip to content

Commit fdcf9ab

Browse files
committed
add majority of tests
1 parent 60b659a commit fdcf9ab

File tree

10 files changed

+837
-33
lines changed

10 files changed

+837
-33
lines changed

internal/api/oauthserver/handlers.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon
437437

438438
// Issue the refresh token and access token
439439
var terr error
440-
tokenResponse, terr = tokenService.IssueRefreshToken(r, tx, user, authMethod, grantParams)
440+
tokenResponse, terr = tokenService.IssueRefreshToken(r, w.Header(), tx, user, authMethod, grantParams)
441441
if terr != nil {
442442
return terr
443443
}
@@ -488,7 +488,7 @@ func (s *Server) handleRefreshTokenGrant(ctx context.Context, w http.ResponseWri
488488
}
489489

490490
db := s.db.WithContext(ctx)
491-
tokenResponse, err := tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{
491+
tokenResponse, err := tokenService.RefreshTokenGrant(ctx, db, r, w.Header(), tokens.RefreshTokenGrantParams{
492492
RefreshToken: params.RefreshToken,
493493
ClientID: clientID,
494494
})

internal/api/token.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
190190
}); terr != nil {
191191
return terr
192192
}
193-
token, terr = a.tokenService.IssueRefreshToken(r, tx, user, models.PasswordGrant, grantParams)
193+
token, terr = a.tokenService.IssueRefreshToken(r, w.Header(), tx, user, models.PasswordGrant, grantParams)
194194
if terr != nil {
195195
return terr
196196
}
@@ -260,7 +260,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
260260
}); terr != nil {
261261
return terr
262262
}
263-
token, terr = a.tokenService.IssueRefreshToken(r, tx, user, authMethod, grantParams)
263+
token, terr = a.tokenService.IssueRefreshToken(r, w.Header(), tx, user, authMethod, grantParams)
264264
if terr != nil {
265265
// error type is already handled in issueRefreshToken
266266
return terr
@@ -295,7 +295,7 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user
295295
}
296296

297297
func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*tokens.AccessTokenResponse, error) {
298-
return a.tokenService.IssueRefreshToken(r, conn, user, authenticationMethod, grantParams)
298+
return a.tokenService.IssueRefreshToken(r, make(http.Header), conn, user, authenticationMethod, grantParams)
299299
}
300300

301301
func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*tokens.AccessTokenResponse, error) {

internal/api/token_refresh.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package api
33
import (
44
"context"
55
"net/http"
6+
"regexp"
67

8+
"github.com/supabase/auth/internal/api/apierrors"
9+
"github.com/supabase/auth/internal/crypto"
710
"github.com/supabase/auth/internal/tokens"
811
)
912

@@ -12,15 +15,42 @@ type RefreshTokenGrantParams struct {
1215
RefreshToken string `json:"refresh_token"`
1316
}
1417

18+
var legacyRefreshTokenPattern = regexp.MustCompile("^[a-z0-9]{12}$")
19+
20+
func (p *RefreshTokenGrantParams) Validate() error {
21+
if len(p.RefreshToken) < 12 {
22+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid")
23+
}
24+
25+
if len(p.RefreshToken) == 12 {
26+
if !legacyRefreshTokenPattern.MatchString(p.RefreshToken) {
27+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid")
28+
}
29+
30+
return nil
31+
}
32+
33+
_, err := crypto.ParseRefreshToken(p.RefreshToken)
34+
if err != nil {
35+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid").WithInternalError(err)
36+
}
37+
38+
return nil
39+
}
40+
1541
// RefreshTokenGrant implements the refresh_token grant type flow
1642
func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
1743
params := &RefreshTokenGrantParams{}
1844
if err := retrieveRequestParams(r, params); err != nil {
1945
return err
2046
}
2147

48+
if err := params.Validate(); err != nil {
49+
return err
50+
}
51+
2252
db := a.db.WithContext(ctx)
23-
tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{
53+
tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, w.Header(), tokens.RefreshTokenGrantParams{
2454
RefreshToken: params.RefreshToken,
2555
})
2656
if err != nil {

internal/crypto/crypto_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,5 @@ func TestEncryptedStringDecryptNegative(t *testing.T) {
105105

106106
func TestSecureToken(t *testing.T) {
107107
assert.Equal(t, len(SecureAlphanumeric(22)), 22)
108+
assert.Equal(t, len(SecureAlphanumeric(7)), 8)
108109
}

internal/crypto/refresh_tokens.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func (r *RefreshToken) Encode(hmacSha256Key []byte) string {
6161

6262
result = append(result, 0)
6363
result = append(result, r.SessionID.Bytes()...)
64-
result = binary.AppendUvarint(result, uint64(r.Counter))
64+
result = binary.AppendUvarint(result, safeUint64(r.Counter))
6565

6666
// Note on truncating the HMAC-SHA-256 output:
6767
// This does not impact security as the brute-force space is 2^128 and
@@ -90,6 +90,22 @@ var (
9090
ErrRefreshTokenCounterInvalid = errors.New("crypto: refresh token's counter is not valid")
9191
)
9292

93+
func safeInt64(v uint64) int64 {
94+
if v > math.MaxInt64 {
95+
return math.MaxInt64
96+
}
97+
98+
return int64(v)
99+
}
100+
101+
func safeUint64(v int64) uint64 {
102+
if v < 0 {
103+
return 0
104+
}
105+
106+
return uint64(v)
107+
}
108+
93109
func ParseRefreshToken(token string) (*RefreshToken, error) {
94110
bytes, err := base64.RawURLEncoding.DecodeString(token)
95111
if err != nil {
@@ -111,10 +127,7 @@ func ParseRefreshToken(token string) (*RefreshToken, error) {
111127
return nil, ErrRefreshTokenChecksumInvalid
112128
}
113129

114-
sessionID, err := uuid.FromBytes(parseFrom[0:16])
115-
if err != nil {
116-
return nil, err
117-
}
130+
sessionID := uuid.FromBytesOrNil(parseFrom[0:16])
118131

119132
parseFrom = parseFrom[16:]
120133

@@ -123,10 +136,6 @@ func ParseRefreshToken(token string) (*RefreshToken, error) {
123136
return nil, ErrRefreshTokenCounterInvalid
124137
}
125138

126-
if counter > math.MaxInt64 {
127-
return nil, ErrRefreshTokenCounterInvalid
128-
}
129-
130139
parseFrom = parseFrom[counterBytes:]
131140

132141
if len(parseFrom) != 16 {
@@ -140,7 +149,7 @@ func ParseRefreshToken(token string) (*RefreshToken, error) {
140149

141150
Version: 0,
142151
SessionID: sessionID,
143-
Counter: int64(counter),
152+
Counter: safeInt64(counter),
144153
Signature: signature,
145154
}, nil
146155
}

internal/crypto/refresh_tokens_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@ import (
44
"crypto/sha256"
55
"encoding/base64"
66
"fmt"
7+
"math"
78
"strings"
89
"testing"
910

1011
"github.com/gofrs/uuid"
1112
"github.com/stretchr/testify/require"
1213
)
1314

15+
func TestSafeIntegers(t *testing.T) {
16+
require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxUint64))
17+
require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxInt64))
18+
19+
require.Equal(t, int64(0), safeUint64(-1))
20+
require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxInt64))
21+
}
22+
1423
func TestRefreshTokenParse(t *testing.T) {
1524
negativeExamples := []struct {
1625
value []byte
@@ -84,3 +93,9 @@ func TestRefreshTokenParse(t *testing.T) {
8493
require.Equal(t, original.Raw, parsed.Raw)
8594
require.Equal(t, original.Signature, parsed.Signature)
8695
}
96+
97+
func TestRefreshTokenTableName(t *testing.T) {
98+
require.Panics(t, func() {
99+
RefreshToken{}.TableName()
100+
})
101+
}

internal/models/sessions.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,23 @@ func (s *Session) UpdateOnlyRefreshToken(tx *storage.Connection) error {
167167
return tx.UpdateOnly(s, "refresh_token_counter")
168168
}
169169

170+
func (s *Session) ReEncryptRefreshTokenHmacKey(tx *storage.Connection, dbEncryption conf.DatabaseEncryptionConfiguration) error {
171+
key, _, err := s.GetRefreshTokenHmacKey(dbEncryption)
172+
if err != nil {
173+
return err
174+
}
175+
176+
es, err := crypto.NewEncryptedString(s.ID.String(), []byte(base64.RawURLEncoding.EncodeToString(key)), dbEncryption.EncryptionKeyID, dbEncryption.EncryptionKey)
177+
if err != nil {
178+
return err
179+
}
180+
181+
encryptedValue := es.String()
182+
s.RefreshTokenHmacKey = &encryptedValue
183+
184+
return tx.UpdateOnly(s, "refresh_token_hmac_key")
185+
}
186+
170187
type SessionValidityReason = int
171188

172189
const (

internal/models/user.go

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -636,8 +636,13 @@ func FindUserByID(tx *storage.Connection, id uuid.UUID) (*User, error) {
636636
// lock will only be acquired if there's no other lock. In case there is a
637637
// lock, a IsNotFound(err) error will be returned.
638638
//
639-
// Second value returned is either *models.RefreshToken or *models.TODO.
639+
// Second value returned is either *models.RefreshToken or *crypto.RefreshToken.
640640
func FindUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.DatabaseEncryptionConfiguration, token string, forUpdate bool) (*User, any, *Session, error) {
641+
if len(token) < 12 {
642+
// not a valid refresh token so don't bother looking it up in the database
643+
return nil, nil, nil, SessionNotFoundError{}
644+
}
645+
641646
if len(token) == 12 {
642647
return findUserWithLegacyRefreshToken(tx, token, forUpdate)
643648
}
@@ -648,11 +653,11 @@ func FindUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.Database
648653
func findUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.DatabaseEncryptionConfiguration, token string, forUpdate bool) (*User, *crypto.RefreshToken, *Session, error) {
649654
refreshToken, err := crypto.ParseRefreshToken(token)
650655
if err != nil {
651-
return nil, nil, nil, err
656+
// refresh token is not valid
657+
return nil, nil, nil, SessionNotFoundError{}
652658
}
653659

654-
// first find the session to check the token's signature
655-
session, err := FindSessionByID(tx, refreshToken.SessionID, false)
660+
session, err := FindSessionByID(tx, refreshToken.SessionID, forUpdate)
656661
if err != nil {
657662
return nil, nil, nil, err
658663
}
@@ -663,28 +668,28 @@ func findUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.Database
663668
return nil, nil, nil, SessionNotFoundError{}
664669
}
665670

666-
key, _, err := session.GetRefreshTokenHmacKey(dbEncryption)
671+
key, shouldReEncrypt, err := session.GetRefreshTokenHmacKey(dbEncryption)
667672
if err != nil {
668673
return nil, nil, nil, err
669674
}
670675

671676
if !refreshToken.CheckSignature(key) {
672-
// TODO: return SessionNotFound, log informational
673-
return nil, nil, nil, fmt.Errorf("refresh token for session %s with counter %v has invalid signature", session.ID.String(), refreshToken.Counter)
674-
}
675-
676-
user, err := FindUserByID(tx, session.UserID)
677-
if err != nil {
678-
return nil, nil, nil, err
677+
// refresh token signature is not valid for this session
678+
return nil, nil, nil, SessionNotFoundError{}
679679
}
680680

681-
if forUpdate {
682-
session, err = FindSessionByID(tx, refreshToken.SessionID, forUpdate)
681+
if shouldReEncrypt && forUpdate {
682+
err := session.ReEncryptRefreshTokenHmacKey(tx, dbEncryption)
683683
if err != nil {
684684
return nil, nil, nil, err
685685
}
686686
}
687687

688+
user, err := FindUserByID(tx, session.UserID)
689+
if err != nil {
690+
return nil, nil, nil, err
691+
}
692+
688693
return user, refreshToken, session, nil
689694
}
690695

0 commit comments

Comments
 (0)