From e5904ae5a9ee665db9b67817507683e8e06c67d1 Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Fri, 10 Oct 2025 17:00:13 +0200 Subject: [PATCH 1/7] feat: introduce v2 refresh token algorithm --- internal/api/token_test.go | 7 +- internal/conf/configuration.go | 2 + internal/crypto/crypto.go | 1 + internal/crypto/refresh_tokens.go | 146 +++++++++++ internal/crypto/refresh_tokens_test.go | 86 ++++++ internal/models/refresh_token.go | 65 +++-- internal/models/refresh_token_test.go | 10 +- internal/models/sessions.go | 41 +++ internal/models/user.go | 56 +++- internal/models/user_test.go | 10 +- internal/tokens/service.go | 245 ++++++++++++++---- ...0_add_session_refresh_token_columns.up.sql | 6 + 12 files changed, 593 insertions(+), 82 deletions(-) create mode 100644 internal/crypto/refresh_tokens.go create mode 100644 internal/crypto/refresh_tokens_test.go create mode 100644 migrations/20251007112900_add_session_refresh_token_columns.up.sql diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 166bfc900..919248c8a 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -435,9 +435,11 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() { // ensure that the 4 refresh tokens are setup correctly for i, refreshToken := range refreshTokens { - _, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false) + _, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false) require.NoError(ts.T(), err) + token := anyToken.(*models.RefreshToken) + if i == len(refreshTokens)-1 { require.False(ts.T(), token.Revoked) } else { @@ -470,9 +472,10 @@ func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() { // ensure that the refresh tokens are marked as revoked in the database for _, refreshToken := range refreshTokens { - _, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false) + _, anyToken, _, err := models.FindUserWithRefreshToken(ts.API.db, ts.Config.Security.DBEncryption, refreshToken, false) require.NoError(ts.T(), err) + token := anyToken.(*models.RefreshToken) require.True(ts.T(), token.Revoked) } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 84ec2ee16..3f4bbf833 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -723,8 +723,10 @@ func (c *DatabaseEncryptionConfiguration) Validate() error { type SecurityConfiguration struct { Captcha CaptchaConfiguration `json:"captcha"` + RefreshTokenAlgorithmVersion int `json:"refresh_token_algorithm_version" split_words:"true"` RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"` RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"` + RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"` UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"` ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"` diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index bb8688bb6..70b148add 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -29,6 +29,7 @@ func GenerateOtp(digits int) string { return otp } + func GenerateTokenHash(emailOrPhone, otp string) string { return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp))) } diff --git a/internal/crypto/refresh_tokens.go b/internal/crypto/refresh_tokens.go new file mode 100644 index 000000000..2744c03cf --- /dev/null +++ b/internal/crypto/refresh_tokens.go @@ -0,0 +1,146 @@ +package crypto + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "encoding/binary" + "errors" + "math" + + "github.com/gofrs/uuid" +) + +func GenerateRefreshTokenHmacKey() []byte { + key := make([]byte, 32) + must(rand.Read(key)) + + return key +} + +const refreshTokenChecksumLength = 4 +const refreshTokenSignatureLength = 16 +const minRefreshTokenLength = 1 + 16 + 1 + refreshTokenSignatureLength + refreshTokenChecksumLength +const maxRefreshTokenLength = minRefreshTokenLength + 8 + +// RefreshToken is an object that encodes a cryptographically authenticated +// (signed) message containing a version, session ID and monotonically +// increasing non-negative counter. +// +// The signature is a truncated (first 128 bits) of HMAC-SHA-256, which saves +// on encoded length without sacrificing security. The checksum of 4 bytes at +// the end is to lessen the load on the server with invalid strings (those that +// are not likely to be a proper refresh token). +type RefreshToken struct { + Raw []byte + + Version byte + SessionID uuid.UUID + Counter int64 + Signature []byte +} + +func (RefreshToken) TableName() string { + panic("crypto.RefreshToken is not meant to be saved in the database") +} + +func (r *RefreshToken) CheckSignature(hmacSha256Key []byte) bool { + bytes := r.Raw[:len(r.Raw)-refreshTokenSignatureLength-refreshTokenChecksumLength] + + h := hmac.New(sha256.New, hmacSha256Key) + h.Write(bytes) + signature := h.Sum(nil)[:refreshTokenSignatureLength] + + return hmac.Equal(signature, r.Signature) +} + +func (r *RefreshToken) Encode(hmacSha256Key []byte) string { + result := make([]byte, 0, maxRefreshTokenLength) + + result = append(result, 0) + result = append(result, r.SessionID.Bytes()...) + result = binary.AppendUvarint(result, uint64(r.Counter)) + + // Note on truncating the HMAC-SHA-256 output: + // This does not impact security as the brute-force space is 2^128 and + // the collision space is 2^64, both unattainable in practice. + + h := hmac.New(sha256.New, hmacSha256Key) + h.Write(result) + signature := h.Sum(nil)[:refreshTokenSignatureLength] + + result = append(result, signature...) + + checksum := sha256.Sum256(result) + result = append(result, checksum[:refreshTokenChecksumLength]...) + + r.Version = 0 + r.Raw = result + r.Signature = signature + + return base64.RawURLEncoding.EncodeToString(result) +} + +var ( + ErrRefreshTokenLength = errors.New("crypto: refresh token length is not valid") + ErrRefreshTokenUnknownVersion = errors.New("crypto: refresh token version is not 0") + ErrRefreshTokenChecksumInvalid = errors.New("crypto: refresh token checksum is not valid") + ErrRefreshTokenCounterInvalid = errors.New("crypto: refresh token's counter is not valid") +) + +func ParseRefreshToken(token string) (*RefreshToken, error) { + bytes, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return nil, err + } + + if len(bytes) < minRefreshTokenLength { + return nil, ErrRefreshTokenLength + } + + if bytes[0] != 0 { + return nil, ErrRefreshTokenUnknownVersion + } + + parseFrom := bytes[1 : len(bytes)-refreshTokenChecksumLength] + + checksum256 := sha256.Sum256(bytes[:len(bytes)-refreshTokenChecksumLength]) + if subtle.ConstantTimeCompare(checksum256[:refreshTokenChecksumLength], bytes[len(bytes)-refreshTokenChecksumLength:]) != 1 { + return nil, ErrRefreshTokenChecksumInvalid + } + + sessionID, err := uuid.FromBytes(parseFrom[0:16]) + if err != nil { + return nil, err + } + + parseFrom = parseFrom[16:] + + counter, counterBytes := binary.Uvarint(parseFrom) + if counterBytes <= 0 { + return nil, ErrRefreshTokenCounterInvalid + } + + if counter > math.MaxInt64 { + return nil, ErrRefreshTokenCounterInvalid + } + + parseFrom = parseFrom[counterBytes:] + + if len(parseFrom) != 16 { + return nil, ErrRefreshTokenLength + } + + signature := parseFrom + + return &RefreshToken{ + Raw: bytes, + + Version: 0, + SessionID: sessionID, + Counter: int64(counter), + Signature: signature, + }, nil +} diff --git a/internal/crypto/refresh_tokens_test.go b/internal/crypto/refresh_tokens_test.go new file mode 100644 index 000000000..f5aabad77 --- /dev/null +++ b/internal/crypto/refresh_tokens_test.go @@ -0,0 +1,86 @@ +package crypto + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "strings" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" +) + +func TestRefreshTokenParse(t *testing.T) { + negativeExamples := []struct { + value []byte + error error + }{ + { + value: make([]byte, minRefreshTokenLength-1), + error: ErrRefreshTokenLength, + }, + { + value: make([]byte, minRefreshTokenLength), + error: ErrRefreshTokenChecksumInvalid, + }, + { + value: func() []byte { + b := make([]byte, minRefreshTokenLength) + b[0] = 1 + return b + }(), + error: ErrRefreshTokenUnknownVersion, + }, + { + value: func() []byte { + b := make([]byte, minRefreshTokenLength) + for i := 1 + 16; i < len(b); i += 1 { + b[i] = 0xFF + } + + checksum := sha256.Sum256(b[:len(b)-refreshTokenChecksumLength]) + copy(b[len(b)-refreshTokenChecksumLength:], checksum[:refreshTokenChecksumLength]) + return b + }(), + error: ErrRefreshTokenCounterInvalid, + }, + { + value: func() []byte { + b := make([]byte, minRefreshTokenLength) + b[1+16] = 0xFF + b[1+16+1] = 0 + + checksum := sha256.Sum256(b[:len(b)-refreshTokenChecksumLength]) + copy(b[len(b)-refreshTokenChecksumLength:], checksum[:refreshTokenChecksumLength]) + return b + }(), + error: ErrRefreshTokenLength, + }, + } + + for i, example := range negativeExamples { + t.Run(fmt.Sprintf("negative example %d", i), func(t *testing.T) { + rt, err := ParseRefreshToken(base64.RawURLEncoding.EncodeToString(example.value)) + require.Nil(t, rt) + require.Error(t, err) + require.Equal(t, err, example.error) + }) + } + + rt, err := ParseRefreshToken(strings.Repeat("!", (4*minRefreshTokenLength)/3)) + require.Nil(t, rt) + require.Error(t, err) + + original := &RefreshToken{ + SessionID: uuid.Must(uuid.NewV4()), + Counter: 9223372036854775807, + } + + parsed, err := ParseRefreshToken(original.Encode(make([]byte, 32))) + require.Nil(t, err) + require.Equal(t, original.SessionID.String(), parsed.SessionID.String()) + require.Equal(t, original.Counter, parsed.Counter) + require.Equal(t, original.Raw, parsed.Raw) + require.Equal(t, original.Signature, parsed.Signature) +} diff --git a/internal/models/refresh_token.go b/internal/models/refresh_token.go index d0a070b23..604e230b5 100644 --- a/internal/models/refresh_token.go +++ b/internal/models/refresh_token.go @@ -2,6 +2,7 @@ package models import ( "database/sql" + "encoding/base64" "net/http" "time" @@ -118,6 +119,50 @@ func FindTokenBySessionID(tx *storage.Connection, sessionId *uuid.UUID) (*Refres return refreshToken, nil } +func (s *Session) ApplyGrantParams(params *GrantParams) { + s.FactorID = params.FactorID + + if params.SessionNotAfter != nil { + s.NotAfter = params.SessionNotAfter + } + + if params.UserAgent != "" { + s.UserAgent = ¶ms.UserAgent + } + + if params.IP != "" { + s.IP = ¶ms.IP + } + + if params.SessionTag != nil && *params.SessionTag != "" { + s.Tag = params.SessionTag + } + + if params.OAuthClientID != nil && *params.OAuthClientID != uuid.Nil { + s.OAuthClientID = params.OAuthClientID + } +} + +func (s *Session) SetupRefreshTokenData(dbEncryption conf.DatabaseEncryptionConfiguration) error { + hmacKey := base64.RawURLEncoding.EncodeToString(crypto.GenerateRefreshTokenHmacKey()) + + if dbEncryption.Encrypt { + es, err := crypto.NewEncryptedString(s.ID.String(), []byte(hmacKey), dbEncryption.EncryptionKeyID, dbEncryption.EncryptionKey) + if err != nil { + return err + } + + hmacKey = es.String() + } + + counter := int64(0) + + s.RefreshTokenHmacKey = &hmacKey + s.RefreshTokenCounter = &counter + + return nil +} + func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken, params *GrantParams) (*RefreshToken, error) { token := &RefreshToken{ UserID: user.ID, @@ -135,25 +180,7 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok return nil, errors.Wrap(err, "error instantiating new session object") } - if params.SessionNotAfter != nil { - session.NotAfter = params.SessionNotAfter - } - - if params.UserAgent != "" { - session.UserAgent = ¶ms.UserAgent - } - - if params.IP != "" { - session.IP = ¶ms.IP - } - - if params.SessionTag != nil && *params.SessionTag != "" { - session.Tag = params.SessionTag - } - - if params.OAuthClientID != nil && *params.OAuthClientID != uuid.Nil { - session.OAuthClientID = params.OAuthClientID - } + session.ApplyGrantParams(params) if err := tx.Create(session); err != nil { return nil, errors.Wrap(err, "error creating new session") diff --git a/internal/models/refresh_token_test.go b/internal/models/refresh_token_test.go index 7f687f764..b527081b7 100644 --- a/internal/models/refresh_token_test.go +++ b/internal/models/refresh_token_test.go @@ -54,9 +54,11 @@ func (ts *RefreshTokenTestSuite) TestGrantRefreshTokenSwap() { s, err := GrantRefreshTokenSwap(ts.config.AuditLog, &http.Request{}, ts.db, u, r) require.NoError(ts.T(), err) - _, nr, _, err := FindUserWithRefreshToken(ts.db, r.Token, false) + _, anyNR, _, err := FindUserWithRefreshToken(ts.db, ts.config.Security.DBEncryption, r.Token, false) require.NoError(ts.T(), err) + nr := anyNR.(*RefreshToken) + require.Equal(ts.T(), r.ID, nr.ID) require.True(ts.T(), nr.Revoked, "expected old token to be revoked") @@ -69,9 +71,11 @@ func (ts *RefreshTokenTestSuite) TestLogout() { r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) require.NoError(ts.T(), err) + var anyR any + require.NoError(ts.T(), Logout(ts.db, u.ID)) - u, r, _, err = FindUserWithRefreshToken(ts.db, r.Token, false) - require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, r) + u, anyR, _, err = FindUserWithRefreshToken(ts.db, ts.config.Security.DBEncryption, r.Token, false) + require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, anyR) require.True(ts.T(), IsNotFoundError(err), "expected NotFoundError") } diff --git a/internal/models/sessions.go b/internal/models/sessions.go index 9d06e4bbe..24bf83a37 100644 --- a/internal/models/sessions.go +++ b/internal/models/sessions.go @@ -2,6 +2,7 @@ package models import ( "database/sql" + "encoding/base64" "fmt" "slices" "sort" @@ -11,6 +12,8 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/storage" ) @@ -91,6 +94,9 @@ type Session struct { Tag *string `json:"tag" db:"tag"` OAuthClientID *uuid.UUID `json:"oauth_client_id" db:"oauth_client_id"` + + RefreshTokenHmacKey *string `json:"-" db:"refresh_token_hmac_key"` + RefreshTokenCounter *int64 `json:"-" db:"refresh_token_counter"` } func (Session) TableName() string { @@ -98,6 +104,37 @@ func (Session) TableName() string { return tableName } +func (s *Session) GetRefreshTokenHmacKey(dbEncryption conf.DatabaseEncryptionConfiguration) ([]byte, bool, error) { + if s.RefreshTokenHmacKey == nil { + return nil, false, nil + } + + if es := crypto.ParseEncryptedString(*s.RefreshTokenHmacKey); es != nil { + bytes, err := es.Decrypt(s.ID.String(), dbEncryption.DecryptionKeys) + if err != nil { + return nil, false, err + } + + hmacKey, err := base64.RawURLEncoding.DecodeString(string(bytes)) + if err != nil { + return nil, false, err + } + + return hmacKey, dbEncryption.Encrypt && es.ShouldReEncrypt(dbEncryption.EncryptionKeyID), nil + } + + if s.RefreshTokenHmacKey == nil { + return nil, false, nil + } + + hmacKey, err := base64.RawURLEncoding.DecodeString(*s.RefreshTokenHmacKey) + if err != nil { + return nil, false, err + } + + return hmacKey, dbEncryption.Encrypt, nil +} + func (s *Session) LastRefreshedAt(refreshTokenTime *time.Time) time.Time { refreshedAt := s.RefreshedAt @@ -126,6 +163,10 @@ func (s *Session) UpdateOnlyRefreshInfo(tx *storage.Connection) error { return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip") } +func (s *Session) UpdateOnlyRefreshToken(tx *storage.Connection) error { + return tx.UpdateOnly(s, "refresh_token_counter") +} + type SessionValidityReason = int const ( diff --git a/internal/models/user.go b/internal/models/user.go index 068b0c970..f611e3ba9 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -13,6 +13,7 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" + "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/storage" "golang.org/x/crypto/bcrypt" @@ -634,7 +635,60 @@ func FindUserByID(tx *storage.Connection, id uuid.UUID) (*User, error) { // the form SELECT ... FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE // lock will only be acquired if there's no other lock. In case there is a // lock, a IsNotFound(err) error will be returned. -func FindUserWithRefreshToken(tx *storage.Connection, token string, forUpdate bool) (*User, *RefreshToken, *Session, error) { +// +// Second value returned is either *models.RefreshToken or *models.TODO. +func FindUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.DatabaseEncryptionConfiguration, token string, forUpdate bool) (*User, any, *Session, error) { + if len(token) == 12 { + return findUserWithLegacyRefreshToken(tx, token, forUpdate) + } + + return findUserWithRefreshToken(tx, dbEncryption, token, forUpdate) +} + +func findUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.DatabaseEncryptionConfiguration, token string, forUpdate bool) (*User, *crypto.RefreshToken, *Session, error) { + refreshToken, err := crypto.ParseRefreshToken(token) + if err != nil { + return nil, nil, nil, err + } + + // first find the session to check the token's signature + session, err := FindSessionByID(tx, refreshToken.SessionID, false) + if err != nil { + return nil, nil, nil, err + } + + if session.RefreshTokenHmacKey == nil || session.RefreshTokenCounter == nil { + // if the session is not set up to support these encoded refresh tokens it's as if it doesn't exist + // meaning someone is hand-crafting tokens for uuids that exist + return nil, nil, nil, SessionNotFoundError{} + } + + key, _, err := session.GetRefreshTokenHmacKey(dbEncryption) + if err != nil { + return nil, nil, nil, err + } + + if !refreshToken.CheckSignature(key) { + // TODO: return SessionNotFound, log informational + return nil, nil, nil, fmt.Errorf("refresh token for session %s with counter %v has invalid signature", session.ID.String(), refreshToken.Counter) + } + + user, err := FindUserByID(tx, session.UserID) + if err != nil { + return nil, nil, nil, err + } + + if forUpdate { + session, err = FindSessionByID(tx, refreshToken.SessionID, forUpdate) + if err != nil { + return nil, nil, nil, err + } + } + + return user, refreshToken, session, nil +} + +func findUserWithLegacyRefreshToken(tx *storage.Connection, token string, forUpdate bool) (*User, any, *Session, error) { refreshToken := &RefreshToken{} if forUpdate { diff --git a/internal/models/user_test.go b/internal/models/user_test.go index 26bfb0792..28e9f9cc6 100644 --- a/internal/models/user_test.go +++ b/internal/models/user_test.go @@ -22,7 +22,8 @@ func init() { type UserTestSuite struct { suite.Suite - db *storage.Connection + db *storage.Connection + config *conf.GlobalConfiguration } func (ts *UserTestSuite) SetupTest() { @@ -37,7 +38,8 @@ func TestUser(t *testing.T) { require.NoError(t, err) ts := &UserTestSuite{ - db: conn, + db: conn, + config: globalConfig, } defer ts.db.Close() @@ -152,8 +154,10 @@ func (ts *UserTestSuite) TestFindUserWithRefreshToken() { r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) require.NoError(ts.T(), err) - n, nr, s, err := FindUserWithRefreshToken(ts.db, r.Token, true /* forUpdate */) + n, anyNR, s, err := FindUserWithRefreshToken(ts.db, ts.config.Security.DBEncryption, r.Token, true /* forUpdate */) require.NoError(ts.T(), err) + + nr := anyNR.(*RefreshToken) require.Equal(ts.T(), r.ID, nr.ID) require.Equal(ts.T(), u.ID, n.ID) require.NotNil(ts.T(), s) diff --git a/internal/tokens/service.go b/internal/tokens/service.go index 673392764..326f2cd81 100644 --- a/internal/tokens/service.go +++ b/internal/tokens/service.go @@ -15,6 +15,7 @@ import ( "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/hooks/v0hooks" "github.com/supabase/auth/internal/metering" "github.com/supabase/auth/internal/models" @@ -130,7 +131,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, for retry && time.Since(retryStart).Seconds() < retryLoopDuration { retry = false - user, token, session, err := models.FindUserWithRefreshToken(db, params.RefreshToken, false) + user, anyToken, session, err := models.FindUserWithRefreshToken(db, config.Security.DBEncryption, params.RefreshToken, false) if err != nil { if models.IsNotFoundError(err) { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenNotFound, "Invalid Refresh Token: Refresh Token Not Found") @@ -143,9 +144,11 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, } if session == nil { - // a refresh token won't have a session if it's created prior to the sessions table introduced - if err := db.Destroy(token); err != nil { - return nil, apierrors.NewInternalServerError("Error deleting refresh token with missing session").WithInternalError(err) + if token, ok := anyToken.(*models.RefreshToken); ok { + // a refresh token won't have a session if it's created prior to the sessions table introduced + if err := db.Destroy(token); err != nil { + return nil, apierrors.NewInternalServerError("Error deleting refresh token with missing session").WithInternalError(err) + } } return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeSessionNotFound, "Invalid Refresh Token: No Valid Session Found") } @@ -159,7 +162,12 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, AllowLowAAL: config.Sessions.AllowLowAAL, } - result := session.CheckValidity(sessionValidityConfig, retryStart, &token.UpdatedAt, user.HighestPossibleAAL()) + var refreshTokenTime *time.Time + if token, ok := anyToken.(*models.RefreshToken); ok { + refreshTokenTime = &token.UpdatedAt + } + + result := session.CheckValidity(sessionValidityConfig, retryStart, refreshTokenTime, user.HighestPossibleAAL()) switch result { case models.SessionValid: @@ -187,7 +195,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, var newTokenResponse *AccessTokenResponse err = db.Transaction(func(tx *storage.Connection) error { - user, token, session, terr := models.FindUserWithRefreshToken(tx, params.RefreshToken, true /* forUpdate */) + user, anyToken, session, terr := models.FindUserWithRefreshToken(tx, config.Security.DBEncryption, params.RefreshToken, true /* forUpdate */) if terr != nil { if models.IsNotFoundError(terr) { // because forUpdate was set, and the @@ -266,7 +274,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, // since token is not the refresh token // of s, we can't use it's UpdatedAt // time to compare! - if s.LastRefreshedAt(nil).After(session.LastRefreshedAt(&token.UpdatedAt)) { + if s.LastRefreshedAt(nil).After(session.LastRefreshedAt(refreshTokenTime)) { // session is not the most // recently active one return apierrors.NewBadRequestError(apierrors.ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Revoked by Newer Login)") @@ -279,60 +287,157 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, // refresh token row and session are locked at this // point, cannot be concurrently refreshed - var issuedToken *models.RefreshToken + var issuedToken string - if token.Revoked { - activeRefreshToken, terr := session.FindCurrentlyActiveRefreshToken(tx) - if terr != nil && !models.IsNotFoundError(terr) { - return apierrors.NewInternalServerError(terr.Error()) - } + if token, ok := anyToken.(*models.RefreshToken); ok { + if token.Revoked { + activeRefreshToken, terr := session.FindCurrentlyActiveRefreshToken(tx) + if terr != nil && !models.IsNotFoundError(terr) { + return apierrors.NewInternalServerError(terr.Error()) + } - if activeRefreshToken != nil && activeRefreshToken.Parent.String() == token.Token { - // Token was revoked, but it's the - // parent of the currently active one. - // This indicates that the client was - // not able to store the result when it - // refreshed token. This case is - // allowed, provided we return back the - // active refresh token instead of - // creating a new one. - issuedToken = activeRefreshToken - } else { - // For a revoked refresh token to be reused, it - // has to fall within the reuse interval. - reuseUntil := token.UpdatedAt.Add( - time.Second * time.Duration(config.Security.RefreshTokenReuseInterval)) - - if s.now().After(reuseUntil) { - // not OK to reuse this token - if config.Security.RefreshTokenRotationEnabled { - // Revoke all tokens in token family - if err := models.RevokeTokenFamily(tx, token); err != nil { - return apierrors.NewInternalServerError(err.Error()) + if activeRefreshToken != nil && activeRefreshToken.Parent.String() == token.Token { + // Token was revoked, but it's the + // parent of the currently active one. + // This indicates that the client was + // not able to store the result when it + // refreshed token. This case is + // allowed, provided we return back the + // active refresh token instead of + // creating a new one. + issuedToken = activeRefreshToken.Token + } else { + // For a revoked refresh token to be reused, it + // has to fall within the reuse interval. + reuseUntil := token.UpdatedAt.Add( + time.Second * time.Duration(config.Security.RefreshTokenReuseInterval)) + + if s.now().After(reuseUntil) { + // not OK to reuse this token + if config.Security.RefreshTokenRotationEnabled { + // Revoke all tokens in token family + if err := models.RevokeTokenFamily(tx, token); err != nil { + return apierrors.NewInternalServerError(err.Error()) + } } + + return storage.NewCommitWithError(apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)) } + } + } + + if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.TokenRefreshedAction, "", nil); terr != nil { + return terr + } - return storage.NewCommitWithError(apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)) + if issuedToken == "" { + newToken, terr := models.GrantRefreshTokenSwap(config.AuditLog, r, tx, user, token) + if terr != nil { + return terr } + + issuedToken = newToken.Token + } + } else if token, ok := anyToken.(*crypto.RefreshToken); ok { + signingKey, _, kerr := session.GetRefreshTokenHmacKey(config.Security.DBEncryption) + if kerr != nil { + return apierrors.NewInternalServerError("failed to load session from database").WithInternalError(terr) } - } - if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.TokenRefreshedAction, "", nil); terr != nil { - return terr - } + counterDifference := *session.RefreshTokenCounter - token.Counter + + if counterDifference < 0 { + // refresh token was not issued by this server + } else if counterDifference == 0 || config.Security.RefreshTokenAllowReuse { + // normal refresh token use + counter := *session.RefreshTokenCounter + 1 + session.RefreshTokenCounter = &counter + + issuedToken = (&crypto.RefreshToken{ + Version: 0, + SessionID: session.ID, + Counter: *session.RefreshTokenCounter, + }).Encode(signingKey) + } else if counterDifference > 0 { + // refresh token is being reused + + // This is caused when the client has + // failed to receive or save the + // response from the last refresh token + // requests. This occurs more + // frequently than you can imagine, so + // it's an allowed reuse. + likelyNotSavedByClient := counterDifference == 1 + + // Concurrent refreshes occur when the + // client sends off multiple refresh + // token requests at once or close by. + // Often this happens when your browser + // remembers multiple tabs of the app, + // which were paused by it or by the OS + // (such as you quitting the browser) + // and then opening it back up. If the + // app uses SSR it is likely that the N + // open tabs will immediately send a + // request to the app's hosting server, + // which will attempt to concurrently + // refresh the session at once using + // refresh token. + likelyConcurrentRefreshes := retryStart.Sub(session.LastRefreshedAt(nil)).Abs() < time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + + reuseAllowed := likelyNotSavedByClient || likelyConcurrentRefreshes || config.Security.RefreshTokenAllowReuse + + if reuseAllowed { + // When reuse is allowed, we do + // not increment the counter. + // This allows all of the + // concurrent clients to + // synchronize their state + // within the refresh token + // reuse interval to the + // currently active refresh + // token. + + issuedToken = (&crypto.RefreshToken{ + Version: 0, + SessionID: session.ID, + Counter: *session.RefreshTokenCounter, + }).Encode(signingKey) + } else if config.Security.RefreshTokenRotationEnabled { + // Reuse is not allowed, in + // which case the whole session + // must go preventing any + // client with any refresh and + // access token for this + // session from being used. + + if terr := models.LogoutSession(tx, session.ID); terr != nil { + return apierrors.NewInternalServerError("destroying session after detected refresh token reuse failed").WithInternalError(terr) + } - if issuedToken == nil { - newToken, terr := models.GrantRefreshTokenSwap(config.AuditLog, r, tx, user, token) - if terr != nil { - return terr + return storage.NewCommitWithError(apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Refresh token behind current counter by %v, session %v is terminated due to refresh token reuse", counterDifference, session.ID.String())) + } else { + // Reuse is not allowed, but no + // refresh token rotation + // enabled. So only fail this + // request. + + return storage.NewCommitWithError(apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Refresh token behind current counter by %v, session %v is not terminated", counterDifference, session.ID.String())) + } } - issuedToken = newToken + if terr := session.UpdateOnlyRefreshToken(tx); terr != nil { + return apierrors.NewInternalServerError("failed saving session").WithInternalError(terr) + } + + if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.TokenRefreshedAction, "", nil); terr != nil { + return terr + } } tokenString, expiresAt, terr = s.GenerateAccessToken(r, tx, GenerateAccessTokenParams{ User: user, - SessionID: issuedToken.SessionId, + SessionID: &session.ID, AuthenticationMethod: models.TokenRefresh, ClientID: sessionClientID, }) @@ -370,7 +475,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, TokenType: "bearer", ExpiresIn: config.JWT.Exp, ExpiresAt: expiresAt, - RefreshToken: issuedToken.Token, + RefreshToken: issuedToken, User: user, } @@ -473,28 +578,60 @@ func (s *Service) IssueRefreshToken(r *http.Request, conn *storage.Connection, u var tokenString string var expiresAt int64 - var refreshToken *models.RefreshToken + var refreshToken string + var sessionID uuid.UUID var oAuthClientID *uuid.UUID err := conn.Transaction(func(tx *storage.Connection) error { var terr error - refreshToken, terr = models.GrantAuthenticatedUser(tx, user, grantParams) - if terr != nil { - return apierrors.NewInternalServerError("Database error granting user").WithInternalError(terr) + if config.Security.RefreshTokenAlgorithmVersion == 2 { + session, terr := models.NewSession(user.ID, grantParams.FactorID) + if terr != nil { + return apierrors.NewInternalServerError("Failed to create new session").WithInternalError(terr) + } + + session.ApplyGrantParams(&grantParams) + if terr := session.SetupRefreshTokenData(config.Security.DBEncryption); terr != nil { + return apierrors.NewInternalServerError("Failed to setup refresh token data for session").WithInternalError(terr) + } + + if terr := tx.Create(session); terr != nil { + return apierrors.NewInternalServerError("Database error creating new session").WithInternalError(terr) + } + + signingKey, _, terr := session.GetRefreshTokenHmacKey(config.Security.DBEncryption) + if terr != nil { + return apierrors.NewInternalServerError("Failed to get session's refresh token key").WithInternalError(terr) + } + + sessionID = session.ID + refreshToken = (&crypto.RefreshToken{ + SessionID: session.ID, + Counter: *session.RefreshTokenCounter, + }).Encode(signingKey) + } else { + rt, terr := models.GrantAuthenticatedUser(tx, user, grantParams) + if terr != nil { + return apierrors.NewInternalServerError("Database error granting user").WithInternalError(terr) + } + + sessionID = *rt.SessionId + refreshToken = rt.Token } + if grantParams.OAuthClientID != nil && *grantParams.OAuthClientID != uuid.Nil { oAuthClientID = grantParams.OAuthClientID } - terr = models.AddClaimToSession(tx, *refreshToken.SessionId, authenticationMethod) + terr = models.AddClaimToSession(tx, sessionID, authenticationMethod) if terr != nil { return terr } tokenString, expiresAt, terr = s.GenerateAccessToken(r, tx, GenerateAccessTokenParams{ User: user, - SessionID: refreshToken.SessionId, + SessionID: &sessionID, AuthenticationMethod: authenticationMethod, ClientID: oAuthClientID, }) @@ -516,7 +653,7 @@ func (s *Service) IssueRefreshToken(r *http.Request, conn *storage.Connection, u TokenType: "bearer", ExpiresIn: config.JWT.Exp, ExpiresAt: expiresAt, - RefreshToken: refreshToken.Token, + RefreshToken: refreshToken, User: user, }, nil } diff --git a/migrations/20251007112900_add_session_refresh_token_columns.up.sql b/migrations/20251007112900_add_session_refresh_token_columns.up.sql new file mode 100644 index 000000000..abcf44228 --- /dev/null +++ b/migrations/20251007112900_add_session_refresh_token_columns.up.sql @@ -0,0 +1,6 @@ +ALTER TABLE {{ index .Options "Namespace" }}.sessions + ADD COLUMN IF NOT EXISTS refresh_token_hmac_key text, + ADD COLUMN IF NOT EXISTS refresh_token_counter bigint; + +COMMENT ON COLUMN {{ index .Options "Namespace" }}.sessions.refresh_token_hmac_key IS 'Holds a HMAC-SHA256 key used to sign refresh tokens for this session.'; +COMMENT ON COLUMN {{ index .Options "Namespace" }}.sessions.refresh_token_counter IS 'Holds the ID (counter) of the last issued refresh token.'; From 55a4c12f314aa19fb53721cac0def3f33ed266f7 Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Wed, 22 Oct 2025 17:03:31 +0200 Subject: [PATCH 2/7] add majority of tests --- internal/api/oauthserver/handlers.go | 4 +- internal/api/token.go | 6 +- internal/api/token_refresh.go | 32 +- internal/crypto/crypto_test.go | 1 + internal/crypto/refresh_tokens.go | 29 +- internal/crypto/refresh_tokens_test.go | 15 + internal/models/sessions.go | 17 + internal/models/user.go | 33 +- internal/tokens/service.go | 48 +- internal/tokens/service_test.go | 685 +++++++++++++++++++++++++ 10 files changed, 837 insertions(+), 33 deletions(-) create mode 100644 internal/tokens/service_test.go diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index 476e01329..1c59a4034 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -437,7 +437,7 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon // Issue the refresh token and access token var terr error - tokenResponse, terr = tokenService.IssueRefreshToken(r, tx, user, authMethod, grantParams) + tokenResponse, terr = tokenService.IssueRefreshToken(r, w.Header(), tx, user, authMethod, grantParams) if terr != nil { return terr } @@ -488,7 +488,7 @@ func (s *Server) handleRefreshTokenGrant(ctx context.Context, w http.ResponseWri } db := s.db.WithContext(ctx) - tokenResponse, err := tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{ + tokenResponse, err := tokenService.RefreshTokenGrant(ctx, db, r, w.Header(), tokens.RefreshTokenGrantParams{ RefreshToken: params.RefreshToken, ClientID: clientID, }) diff --git a/internal/api/token.go b/internal/api/token.go index a2b9bdcd1..b85a01877 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -190,7 +190,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri }); terr != nil { return terr } - token, terr = a.tokenService.IssueRefreshToken(r, tx, user, models.PasswordGrant, grantParams) + token, terr = a.tokenService.IssueRefreshToken(r, w.Header(), tx, user, models.PasswordGrant, grantParams) if terr != nil { return terr } @@ -260,7 +260,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) }); terr != nil { return terr } - token, terr = a.tokenService.IssueRefreshToken(r, tx, user, authMethod, grantParams) + token, terr = a.tokenService.IssueRefreshToken(r, w.Header(), tx, user, authMethod, grantParams) if terr != nil { // error type is already handled in issueRefreshToken return terr @@ -295,7 +295,7 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user } func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*tokens.AccessTokenResponse, error) { - return a.tokenService.IssueRefreshToken(r, conn, user, authenticationMethod, grantParams) + return a.tokenService.IssueRefreshToken(r, make(http.Header), conn, user, authenticationMethod, grantParams) } func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*tokens.AccessTokenResponse, error) { diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 2ba3dbfd2..3178b9ec9 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -3,7 +3,10 @@ package api import ( "context" "net/http" + "regexp" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/tokens" ) @@ -12,6 +15,29 @@ type RefreshTokenGrantParams struct { RefreshToken string `json:"refresh_token"` } +var legacyRefreshTokenPattern = regexp.MustCompile("^[a-z0-9]{12}$") + +func (p *RefreshTokenGrantParams) Validate() error { + if len(p.RefreshToken) < 12 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid") + } + + if len(p.RefreshToken) == 12 { + if !legacyRefreshTokenPattern.MatchString(p.RefreshToken) { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid") + } + + return nil + } + + _, err := crypto.ParseRefreshToken(p.RefreshToken) + if err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Refresh token is not valid").WithInternalError(err) + } + + return nil +} + // RefreshTokenGrant implements the refresh_token grant type flow func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { params := &RefreshTokenGrantParams{} @@ -19,8 +45,12 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h return err } + if err := params.Validate(); err != nil { + return err + } + db := a.db.WithContext(ctx) - tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{ + tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, w.Header(), tokens.RefreshTokenGrantParams{ RefreshToken: params.RefreshToken, }) if err != nil { diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go index 4541c91fc..de05dbfb7 100644 --- a/internal/crypto/crypto_test.go +++ b/internal/crypto/crypto_test.go @@ -105,4 +105,5 @@ func TestEncryptedStringDecryptNegative(t *testing.T) { func TestSecureToken(t *testing.T) { assert.Equal(t, len(SecureAlphanumeric(22)), 22) + assert.Equal(t, len(SecureAlphanumeric(7)), 8) } diff --git a/internal/crypto/refresh_tokens.go b/internal/crypto/refresh_tokens.go index 2744c03cf..966cb5f0b 100644 --- a/internal/crypto/refresh_tokens.go +++ b/internal/crypto/refresh_tokens.go @@ -61,7 +61,7 @@ func (r *RefreshToken) Encode(hmacSha256Key []byte) string { result = append(result, 0) result = append(result, r.SessionID.Bytes()...) - result = binary.AppendUvarint(result, uint64(r.Counter)) + result = binary.AppendUvarint(result, safeUint64(r.Counter)) // Note on truncating the HMAC-SHA-256 output: // This does not impact security as the brute-force space is 2^128 and @@ -90,6 +90,22 @@ var ( ErrRefreshTokenCounterInvalid = errors.New("crypto: refresh token's counter is not valid") ) +func safeInt64(v uint64) int64 { + if v > math.MaxInt64 { + return math.MaxInt64 + } + + return int64(v) +} + +func safeUint64(v int64) uint64 { + if v < 0 { + return 0 + } + + return uint64(v) +} + func ParseRefreshToken(token string) (*RefreshToken, error) { bytes, err := base64.RawURLEncoding.DecodeString(token) if err != nil { @@ -111,10 +127,7 @@ func ParseRefreshToken(token string) (*RefreshToken, error) { return nil, ErrRefreshTokenChecksumInvalid } - sessionID, err := uuid.FromBytes(parseFrom[0:16]) - if err != nil { - return nil, err - } + sessionID := uuid.FromBytesOrNil(parseFrom[0:16]) parseFrom = parseFrom[16:] @@ -123,10 +136,6 @@ func ParseRefreshToken(token string) (*RefreshToken, error) { return nil, ErrRefreshTokenCounterInvalid } - if counter > math.MaxInt64 { - return nil, ErrRefreshTokenCounterInvalid - } - parseFrom = parseFrom[counterBytes:] if len(parseFrom) != 16 { @@ -140,7 +149,7 @@ func ParseRefreshToken(token string) (*RefreshToken, error) { Version: 0, SessionID: sessionID, - Counter: int64(counter), + Counter: safeInt64(counter), Signature: signature, }, nil } diff --git a/internal/crypto/refresh_tokens_test.go b/internal/crypto/refresh_tokens_test.go index f5aabad77..759eb42ec 100644 --- a/internal/crypto/refresh_tokens_test.go +++ b/internal/crypto/refresh_tokens_test.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/base64" "fmt" + "math" "strings" "testing" @@ -11,6 +12,14 @@ import ( "github.com/stretchr/testify/require" ) +func TestSafeIntegers(t *testing.T) { + require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxUint64)) + require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxInt64)) + + require.Equal(t, int64(0), safeUint64(-1)) + require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxInt64)) +} + func TestRefreshTokenParse(t *testing.T) { negativeExamples := []struct { value []byte @@ -84,3 +93,9 @@ func TestRefreshTokenParse(t *testing.T) { require.Equal(t, original.Raw, parsed.Raw) require.Equal(t, original.Signature, parsed.Signature) } + +func TestRefreshTokenTableName(t *testing.T) { + require.Panics(t, func() { + RefreshToken{}.TableName() + }) +} diff --git a/internal/models/sessions.go b/internal/models/sessions.go index 24bf83a37..b5474047d 100644 --- a/internal/models/sessions.go +++ b/internal/models/sessions.go @@ -167,6 +167,23 @@ func (s *Session) UpdateOnlyRefreshToken(tx *storage.Connection) error { return tx.UpdateOnly(s, "refresh_token_counter") } +func (s *Session) ReEncryptRefreshTokenHmacKey(tx *storage.Connection, dbEncryption conf.DatabaseEncryptionConfiguration) error { + key, _, err := s.GetRefreshTokenHmacKey(dbEncryption) + if err != nil { + return err + } + + es, err := crypto.NewEncryptedString(s.ID.String(), []byte(base64.RawURLEncoding.EncodeToString(key)), dbEncryption.EncryptionKeyID, dbEncryption.EncryptionKey) + if err != nil { + return err + } + + encryptedValue := es.String() + s.RefreshTokenHmacKey = &encryptedValue + + return tx.UpdateOnly(s, "refresh_token_hmac_key") +} + type SessionValidityReason = int const ( diff --git a/internal/models/user.go b/internal/models/user.go index f611e3ba9..3c706b80e 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -636,8 +636,13 @@ func FindUserByID(tx *storage.Connection, id uuid.UUID) (*User, error) { // lock will only be acquired if there's no other lock. In case there is a // lock, a IsNotFound(err) error will be returned. // -// Second value returned is either *models.RefreshToken or *models.TODO. +// Second value returned is either *models.RefreshToken or *crypto.RefreshToken. func FindUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.DatabaseEncryptionConfiguration, token string, forUpdate bool) (*User, any, *Session, error) { + if len(token) < 12 { + // not a valid refresh token so don't bother looking it up in the database + return nil, nil, nil, SessionNotFoundError{} + } + if len(token) == 12 { return findUserWithLegacyRefreshToken(tx, token, forUpdate) } @@ -648,11 +653,11 @@ func FindUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.Database func findUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.DatabaseEncryptionConfiguration, token string, forUpdate bool) (*User, *crypto.RefreshToken, *Session, error) { refreshToken, err := crypto.ParseRefreshToken(token) if err != nil { - return nil, nil, nil, err + // refresh token is not valid + return nil, nil, nil, SessionNotFoundError{} } - // first find the session to check the token's signature - session, err := FindSessionByID(tx, refreshToken.SessionID, false) + session, err := FindSessionByID(tx, refreshToken.SessionID, forUpdate) if err != nil { return nil, nil, nil, err } @@ -663,28 +668,28 @@ func findUserWithRefreshToken(tx *storage.Connection, dbEncryption conf.Database return nil, nil, nil, SessionNotFoundError{} } - key, _, err := session.GetRefreshTokenHmacKey(dbEncryption) + key, shouldReEncrypt, err := session.GetRefreshTokenHmacKey(dbEncryption) if err != nil { return nil, nil, nil, err } if !refreshToken.CheckSignature(key) { - // TODO: return SessionNotFound, log informational - return nil, nil, nil, fmt.Errorf("refresh token for session %s with counter %v has invalid signature", session.ID.String(), refreshToken.Counter) - } - - user, err := FindUserByID(tx, session.UserID) - if err != nil { - return nil, nil, nil, err + // refresh token signature is not valid for this session + return nil, nil, nil, SessionNotFoundError{} } - if forUpdate { - session, err = FindSessionByID(tx, refreshToken.SessionID, forUpdate) + if shouldReEncrypt && forUpdate { + err := session.ReEncryptRefreshTokenHmacKey(tx, dbEncryption) if err != nil { return nil, nil, nil, err } } + user, err := FindUserByID(tx, session.UserID) + if err != nil { + return nil, nil, nil, err + } + return user, refreshToken, session, nil } diff --git a/internal/tokens/service.go b/internal/tokens/service.go index 326f2cd81..101226578 100644 --- a/internal/tokens/service.go +++ b/internal/tokens/service.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "time" "github.com/gofrs/uuid" @@ -112,7 +113,7 @@ func (s *Service) SetTimeFunc(timeFunc func() time.Time) { } // RefreshTokenGrant implements the refresh_token grant type flow -func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, r *http.Request, params RefreshTokenGrantParams) (*AccessTokenResponse, error) { +func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, r *http.Request, responseHeaders http.Header, params RefreshTokenGrantParams) (*AccessTokenResponse, error) { db = db.WithContext(ctx) config := s.config @@ -139,6 +140,8 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, return nil, apierrors.NewInternalServerError(err.Error()) } + responseHeaders.Set("sb-auth-user-id", user.ID.String()) + if user.IsBanned() { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeUserBanned, "Invalid Refresh Token: User Banned") } @@ -153,6 +156,8 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeSessionNotFound, "Invalid Refresh Token: No Valid Session Found") } + responseHeaders.Set("sb-auth-session-id", session.ID.String()) + // OAuth client validation will be done inside the transaction var sessionClientID *uuid.UUID @@ -338,6 +343,8 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, issuedToken = newToken.Token } + + responseHeaders.Set("sb-auth-refresh-token-prefix", issuedToken[0:5]) } else if token, ok := anyToken.(*crypto.RefreshToken); ok { signingKey, _, kerr := session.GetRefreshTokenHmacKey(config.Security.DBEncryption) if kerr != nil { @@ -348,6 +355,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, if counterDifference < 0 { // refresh token was not issued by this server + apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid Refresh Token: Not Issued By This Server").WithInternalMessage("Refresh token for session %s has a counter that's ahead %d of the database state", session.ID.String(), counterDifference) } else if counterDifference == 0 || config.Security.RefreshTokenAllowReuse { // normal refresh token use counter := *session.RefreshTokenCounter + 1 @@ -358,6 +366,8 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, SessionID: session.ID, Counter: *session.RefreshTokenCounter, }).Encode(signingKey) + + responseHeaders.Set("sb-auth-refresh-token-reuse", "false") } else if counterDifference > 0 { // refresh token is being reused @@ -385,6 +395,24 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, // refresh token. likelyConcurrentRefreshes := retryStart.Sub(session.LastRefreshedAt(nil)).Abs() < time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + var causes []string + if likelyConcurrentRefreshes { + causes = append(causes, "concurrent-refresh") + } + + if likelyNotSavedByClient { + causes = append(causes, "fail-to-save") + } + + if config.Security.RefreshTokenAllowReuse { + causes = append(causes, "always-allow") + } + + headerValue := strings.Join(causes, ",") + if headerValue != "" { + responseHeaders.Set("sb-auth-refresh-token-reuse-cause", headerValue) + } + reuseAllowed := likelyNotSavedByClient || likelyConcurrentRefreshes || config.Security.RefreshTokenAllowReuse if reuseAllowed { @@ -403,6 +431,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, SessionID: session.ID, Counter: *session.RefreshTokenCounter, }).Encode(signingKey) + } else if config.Security.RefreshTokenRotationEnabled { // Reuse is not allowed, in // which case the whole session @@ -415,6 +444,8 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, return apierrors.NewInternalServerError("destroying session after detected refresh token reuse failed").WithInternalError(terr) } + responseHeaders.Set("sb-auth-refresh-token-rotated", "true") + return storage.NewCommitWithError(apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Refresh token behind current counter by %v, session %v is terminated due to refresh token reuse", counterDifference, session.ID.String())) } else { // Reuse is not allowed, but no @@ -422,7 +453,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, // enabled. So only fail this // request. - return storage.NewCommitWithError(apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Refresh token behind current counter by %v, session %v is not terminated", counterDifference, session.ID.String())) + return storage.NewCommitWithError(apierrors.NewBadRequestError(apierrors.ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Refresh token behind current counter by %v, session %v is not terminated but error is returned", counterDifference, session.ID.String())) } } @@ -430,6 +461,8 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, return apierrors.NewInternalServerError("failed saving session").WithInternalError(terr) } + responseHeaders.Set("sb-auth-refresh-token-counter", strconv.FormatInt(*session.RefreshTokenCounter, 10)) + if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.TokenRefreshedAction, "", nil); terr != nil { return terr } @@ -570,7 +603,7 @@ func (s *Service) GenerateAccessToken(r *http.Request, tx *storage.Connection, p } // IssueRefreshToken creates a new refresh token and access token -func (s *Service) IssueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { +func (s *Service) IssueRefreshToken(r *http.Request, responseHeaders http.Header, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { config := s.config now := s.now() @@ -582,6 +615,8 @@ func (s *Service) IssueRefreshToken(r *http.Request, conn *storage.Connection, u var sessionID uuid.UUID var oAuthClientID *uuid.UUID + responseHeaders.Set("sb-auth-user-id", user.ID.String()) + err := conn.Transaction(func(tx *storage.Connection) error { var terr error @@ -610,6 +645,9 @@ func (s *Service) IssueRefreshToken(r *http.Request, conn *storage.Connection, u SessionID: session.ID, Counter: *session.RefreshTokenCounter, }).Encode(signingKey) + + responseHeaders.Set("sb-auth-session-id", sessionID.String()) + responseHeaders.Set("sb-auth-refresh-token-counter", strconv.FormatInt(*session.RefreshTokenCounter, 10)) } else { rt, terr := models.GrantAuthenticatedUser(tx, user, grantParams) if terr != nil { @@ -618,6 +656,9 @@ func (s *Service) IssueRefreshToken(r *http.Request, conn *storage.Connection, u sessionID = *rt.SessionId refreshToken = rt.Token + + responseHeaders.Set("sb-auth-session-id", sessionID.String()) + responseHeaders.Set("sb-auth-refresh-token-prefix", refreshToken[0:5]) } if grantParams.OAuthClientID != nil && *grantParams.OAuthClientID != uuid.Nil { @@ -642,6 +683,7 @@ func (s *Service) IssueRefreshToken(r *http.Request, conn *storage.Connection, u } return apierrors.NewInternalServerError("error generating jwt token").WithInternalError(terr) } + return nil }) if err != nil { diff --git a/internal/tokens/service_test.go b/internal/tokens/service_test.go new file mode 100644 index 000000000..1d0ed2216 --- /dev/null +++ b/internal/tokens/service_test.go @@ -0,0 +1,685 @@ +package tokens + +import ( + "context" + "crypto/rand" + "encoding/base64" + "net/http" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type panicHookManager struct { +} + +func (m *panicHookManager) InvokeHook(tx *storage.Connection, r *http.Request, input any, output any) error { + panic("must not be called") +} + +type RefreshTokenV2Suite struct { + suite.Suite + + Conn *storage.Connection + + User *models.User +} + +func TestRefreshTokenV2(t *testing.T) { + ts := &RefreshTokenV2Suite{} + + conn, err := test.SetupDBConnection(ts.config()) + require.NoError(t, err) + + ts.Conn = conn + defer conn.Close() + + suite.Run(t, ts) +} + +func (ts *RefreshTokenV2Suite) SetupTest() { + models.TruncateAll(ts.Conn) + u, err := models.NewUser("", "test@example.com", "password", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.Conn.Create(u)) + + ts.User = u +} + +func (ts *RefreshTokenV2Suite) config() *conf.GlobalConfiguration { + config, err := conf.LoadGlobal("../../hack/test.env") + if err != nil { + panic("failed to load config") + } + + config.Security.RefreshTokenAlgorithmVersion = 2 + + return config +} + +func (ts *RefreshTokenV2Suite) TestNormalUse() { + config := ts.config() + require.Equal(ts.T(), 2, config.Security.RefreshTokenAlgorithmVersion) + + config.Security.RefreshTokenRotationEnabled = false + config.Security.RefreshTokenReuseInterval = 1 + config.Security.RefreshTokenAllowReuse = false + + clock := time.Now() + + srv := NewService(config, &panicHookManager{}) + srv.SetTimeFunc(func() time.Time { + return clock + }) + + req, err := http.NewRequest("POST", "https://example.com/", nil) + require.NoError(ts.T(), err) + + req = req.WithContext(context.Background()) + responseHeaders := make(http.Header) + + at, err := srv.IssueRefreshToken( + req, + responseHeaders, + ts.Conn, + ts.User, + models.PasswordGrant, + models.GrantParams{}, + ) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), at) + + prt, err := crypto.ParseRefreshToken(at.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), prt) + require.Equal(ts.T(), int64(0), prt.Counter) + + session, err := models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + require.Equal(ts.T(), session.UserID.String(), ts.User.ID.String()) + require.NotNil(ts.T(), session.RefreshTokenCounter) + require.NotNil(ts.T(), session.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(0), *session.RefreshTokenCounter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), "0", responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshTokenToUse := at.RefreshToken + + // 128 is used here to force multi-byte encoding of the refresh token counter + for i := 1; i < 128; i += 1 { + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + responseHeaders := make(http.Header) + + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshTokenToUse, + }) + require.NoError(ts.T(), err) + + pnrt, err := crypto.ParseRefreshToken(nrt.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), pnrt) + require.Equal(ts.T(), pnrt.SessionID.String(), prt.SessionID.String()) + require.Equal(ts.T(), int64(i), pnrt.Counter) + + refreshedSession, err := models.FindSessionByID(ts.Conn, pnrt.SessionID, false) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), refreshedSession.RefreshTokenCounter) + require.NotNil(ts.T(), refreshedSession.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(i), *refreshedSession.RefreshTokenCounter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), strconv.FormatInt(int64(i), 10), responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshTokenToUse = nrt.RefreshToken + } +} + +func (ts *RefreshTokenV2Suite) TestMaliciousReuse() { + config := ts.config() + require.Equal(ts.T(), 2, config.Security.RefreshTokenAlgorithmVersion) + + config.Security.RefreshTokenRotationEnabled = false + config.Security.RefreshTokenReuseInterval = 1 + config.Security.RefreshTokenAllowReuse = false + + clock := time.Now() + + srv := NewService(config, &panicHookManager{}) + srv.SetTimeFunc(func() time.Time { + return clock + }) + + req, err := http.NewRequest("POST", "https://example.com/", nil) + require.NoError(ts.T(), err) + responseHeaders := make(http.Header) + + req = req.WithContext(context.Background()) + + at, err := srv.IssueRefreshToken( + req, + responseHeaders, + ts.Conn, + ts.User, + models.PasswordGrant, + models.GrantParams{}, + ) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), at) + + prt, err := crypto.ParseRefreshToken(at.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), prt) + require.Equal(ts.T(), int64(0), prt.Counter) + + session, err := models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + require.Equal(ts.T(), session.UserID.String(), ts.User.ID.String()) + require.NotNil(ts.T(), session.RefreshTokenCounter) + require.NotNil(ts.T(), session.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(0), *session.RefreshTokenCounter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), "0", responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshTokenToUse := at.RefreshToken + + refreshTokens := []string{at.RefreshToken} + + // run through a few regular refresh tokens + for i := 1; i < 4; i += 1 { + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + responseHeaders := make(http.Header) + + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshTokenToUse, + }) + require.NoError(ts.T(), err) + + pnrt, err := crypto.ParseRefreshToken(nrt.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), pnrt) + require.Equal(ts.T(), pnrt.SessionID.String(), prt.SessionID.String()) + require.Equal(ts.T(), int64(i), pnrt.Counter) + + refreshedSession, err := models.FindSessionByID(ts.Conn, pnrt.SessionID, false) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), refreshedSession.RefreshTokenCounter) + require.NotNil(ts.T(), refreshedSession.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(i), *refreshedSession.RefreshTokenCounter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), strconv.FormatInt(int64(i), 10), responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshTokenToUse = nrt.RefreshToken + refreshTokens = append(refreshTokens, nrt.RefreshToken) + } + + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + + // all but the last two must fail refreshing + for _, refreshToken := range refreshTokens[:len(refreshTokens)-2] { + responseHeaders := make(http.Header) + + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshToken, + }) + require.Error(ts.T(), err) + require.Nil(ts.T(), nrt) + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + + refreshedSession, err := models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), refreshedSession.RefreshTokenCounter) + require.NotNil(ts.T(), refreshedSession.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(len(refreshTokens)-1), *refreshedSession.RefreshTokenCounter) + } + + // make sure that the last two allow refreshing + for _, refreshToken := range refreshTokens[len(refreshTokens)-2:] { + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + + responseHeaders := make(http.Header) + + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshToken, + }) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), nrt) + } + + session, err = models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + require.Equal(ts.T(), int64(len(refreshTokens)), *session.RefreshTokenCounter) + + // now update service to use rotation, meaning that after the first reuse + config.Security.RefreshTokenRotationEnabled = true + + srv = NewService(config, &panicHookManager{}) + srv.SetTimeFunc(func() time.Time { + return clock + }) + + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + + responseHeaders = make(http.Header) + + // reuse the first refresh token, causing the session to be completely deleted + rrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshTokens[0], + }) + require.Error(ts.T(), err) + require.Nil(ts.T(), rrt) + + deletedSession, err := models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.Error(ts.T(), err) + require.True(ts.T(), models.IsNotFoundError(err)) + require.Nil(ts.T(), deletedSession) +} + +func (ts *RefreshTokenV2Suite) TestConcurrentReuse() { + config := ts.config() + require.Equal(ts.T(), 2, config.Security.RefreshTokenAlgorithmVersion) + + config.Security.RefreshTokenRotationEnabled = true + config.Security.RefreshTokenReuseInterval = 1 + config.Security.RefreshTokenAllowReuse = false + + clock := time.Now() + + srv := NewService(config, &panicHookManager{}) + srv.SetTimeFunc(func() time.Time { + return clock + }) + + req, err := http.NewRequest("POST", "https://example.com/", nil) + require.NoError(ts.T(), err) + responseHeaders := make(http.Header) + + req = req.WithContext(context.Background()) + + at, err := srv.IssueRefreshToken( + req, + responseHeaders, + ts.Conn, + ts.User, + models.PasswordGrant, + models.GrantParams{}, + ) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), at) + + prt, err := crypto.ParseRefreshToken(at.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), prt) + require.Equal(ts.T(), int64(0), prt.Counter) + + session, err := models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + require.Equal(ts.T(), session.UserID.String(), ts.User.ID.String()) + require.NotNil(ts.T(), session.RefreshTokenCounter) + require.NotNil(ts.T(), session.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(0), *session.RefreshTokenCounter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), "0", responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshTokenToUse := at.RefreshToken + refreshTokens := []string{at.RefreshToken} + + // refresh the token serially a few times, to mimic a more real-world scenario + for i := 1; i < 4; i += 1 { + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + responseHeaders := make(http.Header) + + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshTokenToUse, + }) + require.NoError(ts.T(), err) + + pnrt, err := crypto.ParseRefreshToken(nrt.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), pnrt) + require.Equal(ts.T(), pnrt.SessionID.String(), prt.SessionID.String()) + require.Equal(ts.T(), int64(i), pnrt.Counter) + + refreshedSession, err := models.FindSessionByID(ts.Conn, pnrt.SessionID, false) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), refreshedSession.RefreshTokenCounter) + require.NotNil(ts.T(), refreshedSession.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(i), *refreshedSession.RefreshTokenCounter) + + refreshTokenToUse = nrt.RefreshToken + refreshTokens = append(refreshTokens, nrt.RefreshToken) + } + + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + + var wg sync.WaitGroup + + endTimeChan := make(chan time.Time) + defer close(endTimeChan) + + concurrency := 50 // going too large might use too many database connections + wg.Add(concurrency + 2) + + endTimes := make([]time.Time, 0, concurrency+1) + + go func() { + defer wg.Done() + endTimes = append(endTimes, time.Now()) + + for i := 0; i < concurrency; i += 1 { + endTimes = append(endTimes, <-endTimeChan) + } + }() + + causesChan := make(chan string) + defer close(causesChan) + + causes := make([]string, 0, concurrency) + + go func() { + defer wg.Done() + + for i := 0; i < concurrency; i += 1 { + causes = append(causes, <-causesChan) + } + }() + + for i := 0; i < concurrency; i += 1 { + go func() { + defer wg.Done() + + responseHeaders := make(http.Header) + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshTokenToUse, + }) + endTimeChan <- time.Now() + + require.NoError(ts.T(), err) + causesChan <- responseHeaders.Get("sb-auth-refresh-token-reuse-cause") + + pnrt, err := crypto.ParseRefreshToken(nrt.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), pnrt) + require.Equal(ts.T(), pnrt.SessionID.String(), prt.SessionID.String()) + require.Equal(ts.T(), int64(len(refreshTokens)), pnrt.Counter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), strconv.FormatInt(int64(len(refreshTokens)), 10), responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshedSession, err := models.FindSessionByID(ts.Conn, pnrt.SessionID, false) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), refreshedSession.RefreshTokenCounter) + require.NotNil(ts.T(), refreshedSession.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(len(refreshTokens)), *refreshedSession.RefreshTokenCounter) + }() + } + + wg.Wait() + + // check that the session exists and was not refreshed concurrency times, but only once + session, err = models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + require.Equal(ts.T(), session.UserID.String(), ts.User.ID.String()) + require.NotNil(ts.T(), session.RefreshTokenCounter) + require.NotNil(ts.T(), session.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(len(refreshTokens)), *session.RefreshTokenCounter) + + // ensure that the end times are naturally sorted, indicating an exclusive lock was used + for i := 1; i < len(endTimes); i += 1 { + require.True(ts.T(), endTimes[i-1].Before(endTimes[i])) + } + + // first refresh is OK + require.Equal(ts.T(), "", causes[0]) + + // second refresh is either concurrent-refresh or fail-to-save + for _, cause := range causes[1:] { + require.Equal(ts.T(), "concurrent-refresh,fail-to-save", cause) + } +} + +func (ts *RefreshTokenV2Suite) TestFailToSaveReuse() { + config := ts.config() + require.Equal(ts.T(), 2, config.Security.RefreshTokenAlgorithmVersion) + + config.Security.RefreshTokenRotationEnabled = false + config.Security.RefreshTokenReuseInterval = 1 + + clock := time.Now() + + srv := NewService(config, &panicHookManager{}) + srv.SetTimeFunc(func() time.Time { + return clock + }) + + req, err := http.NewRequest("POST", "https://example.com/", nil) + require.NoError(ts.T(), err) + responseHeaders := make(http.Header) + + req = req.WithContext(context.Background()) + + at, err := srv.IssueRefreshToken( + req, + responseHeaders, + ts.Conn, + ts.User, + models.PasswordGrant, + models.GrantParams{}, + ) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), at) + + prt, err := crypto.ParseRefreshToken(at.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), prt) + require.Equal(ts.T(), int64(0), prt.Counter) + + session, err := models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + require.Equal(ts.T(), session.UserID.String(), ts.User.ID.String()) + require.NotNil(ts.T(), session.RefreshTokenCounter) + require.NotNil(ts.T(), session.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(0), *session.RefreshTokenCounter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), "0", responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshTokens := []string{at.RefreshToken} + + // a few regular refresh token calls to prime a real world scenario + for i := 1; i < 4; i += 1 { + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + responseHeaders := make(http.Header) + + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshTokens[len(refreshTokens)-1], + }) + require.NoError(ts.T(), err) + + pnrt, err := crypto.ParseRefreshToken(nrt.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), pnrt) + require.Equal(ts.T(), pnrt.SessionID.String(), prt.SessionID.String()) + require.Equal(ts.T(), int64(i), pnrt.Counter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), strconv.FormatInt(int64(i), 10), responseHeaders.Get("sb-auth-refresh-token-counter")) + require.Equal(ts.T(), "", responseHeaders.Get("sb-auth-refresh-token-reuse-cause")) + + refreshedSession, err := models.FindSessionByID(ts.Conn, pnrt.SessionID, false) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), refreshedSession.RefreshTokenCounter) + require.NotNil(ts.T(), refreshedSession.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(i), *refreshedSession.RefreshTokenCounter) + + refreshTokens = append(refreshTokens, nrt.RefreshToken) + } + + for i := 0; i < 10; i += 1 { + // ensure refreshes occur outside of the allowed reuse interval + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + responseHeaders := make(http.Header) + + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshTokens[len(refreshTokens)-2], + }) + require.NoError(ts.T(), err) + + pnrt, err := crypto.ParseRefreshToken(nrt.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), pnrt) + require.Equal(ts.T(), pnrt.SessionID.String(), prt.SessionID.String()) + + // key assertion, ensuring the refresh token returned from the "failed to save" scenario is always the currently active refresh token + require.Equal(ts.T(), int64(len(refreshTokens)-1), pnrt.Counter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), strconv.FormatInt(int64(len(refreshTokens)-1), 10), responseHeaders.Get("sb-auth-refresh-token-counter")) + require.Equal(ts.T(), "fail-to-save", responseHeaders.Get("sb-auth-refresh-token-reuse-cause")) + + refreshedSession, err := models.FindSessionByID(ts.Conn, pnrt.SessionID, false) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), refreshedSession.RefreshTokenCounter) + require.NotNil(ts.T(), refreshedSession.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(len(refreshTokens)-1), *refreshedSession.RefreshTokenCounter) + } +} + +func (ts *RefreshTokenV2Suite) TestDBEncryption() { + config := ts.config() + require.Equal(ts.T(), 2, config.Security.RefreshTokenAlgorithmVersion) + + config.Security.RefreshTokenRotationEnabled = false + config.Security.RefreshTokenReuseInterval = 1 + config.Security.RefreshTokenAllowReuse = false + + encryptionKeyA := make([]byte, 32) + encryptionKeyB := make([]byte, 32) + + rand.Read(encryptionKeyA) + rand.Read(encryptionKeyB) + config.Security.DBEncryption.Encrypt = true + config.Security.DBEncryption.DecryptionKeys = map[string]string{ + "A": base64.RawURLEncoding.EncodeToString(encryptionKeyA), + "B": base64.RawURLEncoding.EncodeToString(encryptionKeyB), + } + config.Security.DBEncryption.EncryptionKeyID = "A" + config.Security.DBEncryption.EncryptionKey = config.Security.DBEncryption.DecryptionKeys[config.Security.DBEncryption.EncryptionKeyID] + + clock := time.Now() + + srv := NewService(config, &panicHookManager{}) + srv.SetTimeFunc(func() time.Time { + return clock + }) + + req, err := http.NewRequest("POST", "https://example.com/", nil) + require.NoError(ts.T(), err) + + req = req.WithContext(context.Background()) + responseHeaders := make(http.Header) + + at, err := srv.IssueRefreshToken( + req, + responseHeaders, + ts.Conn, + ts.User, + models.PasswordGrant, + models.GrantParams{}, + ) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), at) + + prt, err := crypto.ParseRefreshToken(at.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), prt) + require.Equal(ts.T(), int64(0), prt.Counter) + + session, err := models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + require.Equal(ts.T(), session.UserID.String(), ts.User.ID.String()) + require.NotNil(ts.T(), session.RefreshTokenCounter) + require.NotNil(ts.T(), session.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(0), *session.RefreshTokenCounter) + + // key assertion + require.True(ts.T(), strings.Contains(*session.RefreshTokenHmacKey, "\"key_id\":\"A\"")) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), "0", responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshTokenToUse := at.RefreshToken + + var encryptedStrings []string + + for i := 1; i < 3; i += 1 { + clock = clock.Add(time.Duration(config.Security.RefreshTokenReuseInterval)*time.Second + time.Duration(100)*time.Millisecond) + responseHeaders := make(http.Header) + + // switch the encryption key to trigger re-encryption + if i%2 == 0 { + config.Security.DBEncryption.EncryptionKeyID = "A" + } else { + config.Security.DBEncryption.EncryptionKeyID = "B" + } + config.Security.DBEncryption.EncryptionKey = config.Security.DBEncryption.DecryptionKeys[config.Security.DBEncryption.EncryptionKeyID] + + srv := NewService(config, &panicHookManager{}) + srv.SetTimeFunc(func() time.Time { + return clock + }) + + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: refreshTokenToUse, + }) + require.NoError(ts.T(), err) + + pnrt, err := crypto.ParseRefreshToken(nrt.RefreshToken) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), pnrt) + require.Equal(ts.T(), pnrt.SessionID.String(), prt.SessionID.String()) + require.Equal(ts.T(), int64(i), pnrt.Counter) + + refreshedSession, err := models.FindSessionByID(ts.Conn, pnrt.SessionID, false) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), refreshedSession.RefreshTokenCounter) + require.NotNil(ts.T(), refreshedSession.RefreshTokenHmacKey) + require.Equal(ts.T(), int64(i), *refreshedSession.RefreshTokenCounter) + + require.Equal(ts.T(), session.UserID.String(), responseHeaders.Get("sb-auth-user-id")) + require.Equal(ts.T(), session.ID.String(), responseHeaders.Get("sb-auth-session-id")) + require.Equal(ts.T(), strconv.FormatInt(int64(i), 10), responseHeaders.Get("sb-auth-refresh-token-counter")) + + refreshTokenToUse = nrt.RefreshToken + encryptedStrings = append(encryptedStrings, *refreshedSession.RefreshTokenHmacKey) + } + + require.Equal(ts.T(), 2, len(encryptedStrings)) + require.NotEqual(ts.T(), encryptedStrings[0], encryptedStrings[1]) + require.True(ts.T(), strings.Contains(encryptedStrings[0], "\"key_id\":\"B\"")) + require.True(ts.T(), strings.Contains(encryptedStrings[1], "\"key_id\":\"A\"")) +} From cf13bb4de7178b325544fb9e28f6bdd2857a50db Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Thu, 23 Oct 2025 10:59:02 +0200 Subject: [PATCH 3/7] fix TestSafeIntegers --- internal/crypto/refresh_tokens_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/crypto/refresh_tokens_test.go b/internal/crypto/refresh_tokens_test.go index 759eb42ec..2c1ccc85c 100644 --- a/internal/crypto/refresh_tokens_test.go +++ b/internal/crypto/refresh_tokens_test.go @@ -17,7 +17,7 @@ func TestSafeIntegers(t *testing.T) { require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxInt64)) require.Equal(t, int64(0), safeUint64(-1)) - require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxInt64)) + require.Equal(t, int64(math.MaxInt64), safeUint64(math.MaxInt64)) } func TestRefreshTokenParse(t *testing.T) { From e4be94d45cbee9d215a721fec1d599eb6b3a3036 Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Thu, 23 Oct 2025 11:18:39 +0200 Subject: [PATCH 4/7] adjust concurrency down --- internal/tokens/service_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/tokens/service_test.go b/internal/tokens/service_test.go index 1d0ed2216..1290f46d7 100644 --- a/internal/tokens/service_test.go +++ b/internal/tokens/service_test.go @@ -376,7 +376,8 @@ func (ts *RefreshTokenV2Suite) TestConcurrentReuse() { endTimeChan := make(chan time.Time) defer close(endTimeChan) - concurrency := 50 // going too large might use too many database connections + // in CI this can cause quite a bit of issues due to the limited number of connections to the database + concurrency := 20 wg.Add(concurrency + 2) endTimes := make([]time.Time, 0, concurrency+1) From e95fcaa39ccd94c95a6366fd967d733498167b9c Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Thu, 23 Oct 2025 12:03:46 +0200 Subject: [PATCH 5/7] actually fix TestSafeIntegers --- internal/crypto/refresh_tokens_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/crypto/refresh_tokens_test.go b/internal/crypto/refresh_tokens_test.go index 2c1ccc85c..f511a30a7 100644 --- a/internal/crypto/refresh_tokens_test.go +++ b/internal/crypto/refresh_tokens_test.go @@ -16,8 +16,8 @@ func TestSafeIntegers(t *testing.T) { require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxUint64)) require.Equal(t, int64(math.MaxInt64), safeInt64(math.MaxInt64)) - require.Equal(t, int64(0), safeUint64(-1)) - require.Equal(t, int64(math.MaxInt64), safeUint64(math.MaxInt64)) + require.Equal(t, uint64(0), safeUint64(-1)) + require.Equal(t, uint64(math.MaxInt64), safeUint64(math.MaxInt64)) } func TestRefreshTokenParse(t *testing.T) { From 0ad7fa5dd18ff12478540f822eb96f56221df75f Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Thu, 23 Oct 2025 15:13:52 +0200 Subject: [PATCH 6/7] add tests for invalid tokens --- internal/tokens/service.go | 2 +- internal/tokens/service_test.go | 81 +++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/internal/tokens/service.go b/internal/tokens/service.go index 101226578..87118b28e 100644 --- a/internal/tokens/service.go +++ b/internal/tokens/service.go @@ -355,7 +355,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, if counterDifference < 0 { // refresh token was not issued by this server - apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid Refresh Token: Not Issued By This Server").WithInternalMessage("Refresh token for session %s has a counter that's ahead %d of the database state", session.ID.String(), counterDifference) + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid Refresh Token: Not Issued By This Server").WithInternalMessage("Refresh token for session %s has a counter that's ahead %d of the database state", session.ID.String(), -counterDifference) } else if counterDifference == 0 || config.Security.RefreshTokenAllowReuse { // normal refresh token use counter := *session.RefreshTokenCounter + 1 diff --git a/internal/tokens/service_test.go b/internal/tokens/service_test.go index 1290f46d7..c1735420c 100644 --- a/internal/tokens/service_test.go +++ b/internal/tokens/service_test.go @@ -684,3 +684,84 @@ func (ts *RefreshTokenV2Suite) TestDBEncryption() { require.True(ts.T(), strings.Contains(encryptedStrings[0], "\"key_id\":\"B\"")) require.True(ts.T(), strings.Contains(encryptedStrings[1], "\"key_id\":\"A\"")) } + +func (ts *RefreshTokenV2Suite) TestInvalidRefreshTokens() { + config := ts.config() + require.Equal(ts.T(), 2, config.Security.RefreshTokenAlgorithmVersion) + + config.Security.RefreshTokenRotationEnabled = false + config.Security.RefreshTokenReuseInterval = 1 + config.Security.RefreshTokenAllowReuse = false + + clock := time.Now() + + srv := NewService(config, &panicHookManager{}) + srv.SetTimeFunc(func() time.Time { + return clock + }) + + req, err := http.NewRequest("POST", "https://example.com/", nil) + require.NoError(ts.T(), err) + + req = req.WithContext(context.Background()) + responseHeaders := make(http.Header) + + at, err := srv.IssueRefreshToken( + req, + responseHeaders, + ts.Conn, + ts.User, + models.PasswordGrant, + models.GrantParams{}, + ) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), at) + + prt, err := crypto.ParseRefreshToken(at.RefreshToken) + require.NoError(ts.T(), err) + + session, err := models.FindSessionByID(ts.Conn, prt.SessionID, false) + require.NoError(ts.T(), err) + + key, _, err := session.GetRefreshTokenHmacKey(config.Security.DBEncryption) + require.NoError(ts.T(), err) + + // tamper with counter + prt.Counter += 1 + tamperedRefreshToken := prt.Encode(key) + + responseHeaders = make(http.Header) + nrt, err := srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: tamperedRefreshToken, + }) + require.Error(ts.T(), err) + require.Nil(ts.T(), nrt) + + require.Equal(ts.T(), prt.SessionID.String(), responseHeaders.Get("sb-auth-session-id")) + + // tamper with signature + prt.Counter = 0 + tamperedRefreshToken = prt.Encode(make([]byte, 32)) + + responseHeaders = make(http.Header) + nrt, err = srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: tamperedRefreshToken, + }) + require.Error(ts.T(), err) + require.Nil(ts.T(), nrt) + + require.Equal(ts.T(), "", responseHeaders.Get("sb-auth-session-id")) + + // remove the session + err = models.LogoutSession(ts.Conn, prt.SessionID) + require.NoError(ts.T(), err) + + responseHeaders = make(http.Header) + nrt, err = srv.RefreshTokenGrant(context.Background(), ts.Conn, req, responseHeaders, RefreshTokenGrantParams{ + RefreshToken: at.RefreshToken, + }) + require.Error(ts.T(), err) + require.Nil(ts.T(), nrt) + + require.Equal(ts.T(), "", responseHeaders.Get("sb-auth-session-id")) +} From c230d8a22dfcb2e9503183cfabf0de04d94c839a Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Thu, 23 Oct 2025 15:50:27 +0200 Subject: [PATCH 7/7] bump up coverage, change slightly always allow implementation --- internal/api/token_test.go | 24 ++++++++++++++++++++++++ internal/models/sessions.go | 4 ---- internal/models/sessions_test.go | 20 ++++++++++++++++++++ internal/tokens/service.go | 2 +- 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 919248c8a..e3390c94e 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/models" ) @@ -890,3 +891,26 @@ $$;` }) } } + +func TestRefreshTokenGrantParamsValidate(t *testing.T) { + examples := []string{ + "", + "01234567890", + "AAAAAAAAAAAA", + "------------", + "0000000000000", + } + + p := &RefreshTokenGrantParams{} + + for _, example := range examples { + p.RefreshToken = example + require.Error(t, p.Validate()) + } + + p.RefreshToken = "0123456abcde" + require.NoError(t, p.Validate()) + + p.RefreshToken = (&crypto.RefreshToken{}).Encode(make([]byte, 32)) + require.NoError(t, p.Validate()) +} diff --git a/internal/models/sessions.go b/internal/models/sessions.go index b5474047d..c1d0661b5 100644 --- a/internal/models/sessions.go +++ b/internal/models/sessions.go @@ -123,10 +123,6 @@ func (s *Session) GetRefreshTokenHmacKey(dbEncryption conf.DatabaseEncryptionCon return hmacKey, dbEncryption.Encrypt && es.ShouldReEncrypt(dbEncryption.EncryptionKeyID), nil } - if s.RefreshTokenHmacKey == nil { - return nil, false, nil - } - hmacKey, err := base64.RawURLEncoding.DecodeString(*s.RefreshTokenHmacKey) if err != nil { return nil, false, err diff --git a/internal/models/sessions_test.go b/internal/models/sessions_test.go index c39890a1e..3631dd1c7 100644 --- a/internal/models/sessions_test.go +++ b/internal/models/sessions_test.go @@ -1,10 +1,12 @@ package models import ( + "encoding/base64" "strings" "testing" "time" + "github.com/gofrs/uuid" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/conf" @@ -160,3 +162,21 @@ func TestCheckValidity(t *testing.T) { }) } } + +func TestSessionGetRefreshTokenHmacKey(t *testing.T) { + s, err := NewSession(uuid.Must(uuid.NewV4()), nil) + require.NoError(t, err) + + hmacKey, shouldReEncrypt, err := s.GetRefreshTokenHmacKey(conf.DatabaseEncryptionConfiguration{}) + require.NoError(t, err) + require.Nil(t, hmacKey) + require.False(t, shouldReEncrypt) + + key := base64.RawURLEncoding.EncodeToString(make([]byte, 32)) + s.RefreshTokenHmacKey = &key + + hmacKey, shouldReEncrypt, err = s.GetRefreshTokenHmacKey(conf.DatabaseEncryptionConfiguration{}) + require.NoError(t, err) + require.Equal(t, make([]byte, 32), hmacKey) + require.False(t, shouldReEncrypt) +} diff --git a/internal/tokens/service.go b/internal/tokens/service.go index 87118b28e..8650351d9 100644 --- a/internal/tokens/service.go +++ b/internal/tokens/service.go @@ -356,7 +356,7 @@ func (s *Service) RefreshTokenGrant(ctx context.Context, db *storage.Connection, if counterDifference < 0 { // refresh token was not issued by this server return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid Refresh Token: Not Issued By This Server").WithInternalMessage("Refresh token for session %s has a counter that's ahead %d of the database state", session.ID.String(), -counterDifference) - } else if counterDifference == 0 || config.Security.RefreshTokenAllowReuse { + } else if counterDifference == 0 { // normal refresh token use counter := *session.RefreshTokenCounter + 1 session.RefreshTokenCounter = &counter