Skip to content

Commit 15d5cf4

Browse files
committed
Add auth req store
1 parent 71c7c1d commit 15d5cf4

File tree

21 files changed

+1456
-983
lines changed

21 files changed

+1456
-983
lines changed

backend/dbscripts/runtimedb/postgres.sql

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ CREATE TABLE AUTHORIZATION_CODE (
1010
EXPIRY_TIME TIMESTAMP NOT NULL
1111
);
1212

13+
-- Table to store OAuth2 authorization request context
14+
CREATE TABLE AUTHORIZATION_REQUEST (
15+
AUTH_REQUEST_ID VARCHAR(36) PRIMARY KEY,
16+
REQUEST_DATA JSONB NOT NULL,
17+
AUTH_TIME TIMESTAMP NOT NULL,
18+
EXPIRY_TIME TIMESTAMP NOT NULL,
19+
CREATED_AT TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
20+
);
21+
22+
-- Index for efficient expiry cleanup
23+
CREATE INDEX IDX_AUTHORIZATION_REQUEST_EXPIRY ON AUTHORIZATION_REQUEST(EXPIRY_TIME);
24+
1325
-- Table to store flow context metadata and state
1426
CREATE TABLE FLOW_CONTEXT (
1527
FLOW_ID VARCHAR(36) PRIMARY KEY,

backend/dbscripts/runtimedb/sqlite.sql

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ CREATE TABLE AUTHORIZATION_CODE (
1010
EXPIRY_TIME DATETIME NOT NULL
1111
);
1212

13+
-- Table to store OAuth2 authorization request context
14+
CREATE TABLE AUTHORIZATION_REQUEST (
15+
AUTH_REQUEST_ID VARCHAR(36) PRIMARY KEY,
16+
REQUEST_DATA TEXT NOT NULL,
17+
AUTH_TIME DATETIME NOT NULL,
18+
EXPIRY_TIME DATETIME NOT NULL,
19+
CREATED_AT TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
20+
);
21+
22+
-- Index for efficient expiry cleanup
23+
CREATE INDEX IDX_AUTHORIZATION_REQUEST_EXPIRY ON AUTHORIZATION_REQUEST(EXPIRY_TIME);
24+
1325
-- Table to store flow context metadata and state
1426
CREATE TABLE FLOW_CONTEXT (
1527
FLOW_ID VARCHAR(36) PRIMARY KEY,

backend/internal/oauth/oauth2/authz/store.go renamed to backend/internal/oauth/oauth2/authz/auth_code_store.go

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import (
2222
"encoding/json"
2323
"errors"
2424
"fmt"
25-
"strings"
26-
"time"
2725

2826
"github.com/asgardeo/thunder/internal/system/database/provider"
2927
)
@@ -266,32 +264,3 @@ func appendAuthzDataJSON(row map[string]interface{}, authzCode *AuthorizationCod
266264

267265
return authzCode, nil
268266
}
269-
270-
// parseTimeField parses a time field from the database result.
271-
func parseTimeField(field interface{}, fieldName string) (time.Time, error) {
272-
const customTimeFormat = "2006-01-02 15:04:05.999999999"
273-
274-
switch v := field.(type) {
275-
case string:
276-
trimmedTime := trimTimeString(v)
277-
parsedTime, err := time.Parse(customTimeFormat, trimmedTime)
278-
if err != nil {
279-
return time.Time{}, fmt.Errorf("error parsing %s: %w", fieldName, err)
280-
}
281-
return parsedTime, nil
282-
case time.Time:
283-
return v, nil
284-
default:
285-
return time.Time{}, fmt.Errorf("unexpected type for %s", fieldName)
286-
}
287-
}
288-
289-
// trimTimeString trims extra information from a time string to match the expected format.
290-
func trimTimeString(timeStr string) string {
291-
// Split the string into parts by spaces and retain only the first two parts.
292-
parts := strings.SplitN(timeStr, " ", 3)
293-
if len(parts) >= 2 {
294-
return parts[0] + " " + parts[1]
295-
}
296-
return timeStr
297-
}

backend/internal/oauth/oauth2/authz/store_test.go renamed to backend/internal/oauth/oauth2/authz/auth_code_store_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,11 @@ func (suite *AuthorizationCodeStoreTestSuite) TestUpdateAuthorizationCodeState_E
287287
suite.mockdbProvider.AssertExpectations(suite.T())
288288
}
289289

290+
const testTimeString = "2023-12-01 10:30:45.123456789"
291+
290292
func (suite *AuthorizationCodeStoreTestSuite) TestParseTimeField_StringInput() {
291-
testTime := "2023-12-01 10:30:45.123456789 extra content"
292-
expectedTime, _ := time.Parse("2006-01-02 15:04:05.999999999", "2023-12-01 10:30:45.123456789")
293+
testTime := testTimeString + " extra content"
294+
expectedTime, _ := time.Parse("2006-01-02 15:04:05.999999999", testTimeString)
293295

294296
result, err := parseTimeField(testTime, "test_field")
295297
assert.NoError(suite.T(), err)
@@ -305,8 +307,8 @@ func (suite *AuthorizationCodeStoreTestSuite) TestParseTimeField_TimeInput() {
305307
}
306308

307309
func (suite *AuthorizationCodeStoreTestSuite) TestTrimTimeString() {
308-
input := "2023-12-01 10:30:45.123456789 extra content here"
309-
expected := "2023-12-01 10:30:45.123456789"
310+
input := testTimeString + " extra content here"
311+
expected := testTimeString
310312

311313
result := trimTimeString(input)
312314
assert.Equal(suite.T(), expected, result)
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
/*
2+
* Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
3+
*
4+
* WSO2 LLC. licenses this file to you under the Apache License,
5+
* Version 2.0 (the "License"); you may not use this file except
6+
* in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing,
12+
* software distributed under the License is distributed on an
13+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
* KIND, either express or implied. See the License for the
15+
* specific language governing permissions and limitations
16+
* under the License.
17+
*/
18+
19+
package authz
20+
21+
import (
22+
"encoding/json"
23+
"fmt"
24+
"time"
25+
26+
"github.com/asgardeo/thunder/internal/oauth/oauth2/model"
27+
"github.com/asgardeo/thunder/internal/system/database/provider"
28+
"github.com/asgardeo/thunder/internal/system/log"
29+
"github.com/asgardeo/thunder/internal/system/utils"
30+
)
31+
32+
// AuthRequestContext holds OAuth authorization request information including parameters and authentication time.
33+
type AuthRequestContext struct {
34+
OAuthParameters model.OAuthParameters
35+
AuthTime time.Time
36+
}
37+
38+
// authorizationRequestStoreInterface defines the interface for authorization request storage.
39+
type authorizationRequestStoreInterface interface {
40+
AddRequest(value AuthRequestContext) string
41+
GetRequest(key string) (bool, AuthRequestContext)
42+
ClearRequest(key string)
43+
}
44+
45+
// authorizationRequestStore provides the authorization request store functionality using database.
46+
type authorizationRequestStore struct {
47+
dbProvider provider.DBProviderInterface
48+
validityPeriod time.Duration
49+
}
50+
51+
// newAuthorizationRequestStore creates a new instance of authorizationRequestStore with injected dependencies.
52+
func newAuthorizationRequestStore() authorizationRequestStoreInterface {
53+
return &authorizationRequestStore{
54+
dbProvider: provider.GetDBProvider(),
55+
validityPeriod: 10 * time.Minute, // Set a default validity period.
56+
}
57+
}
58+
59+
// AddRequest adds an authorization request context entry to the store.
60+
func (authzRS *authorizationRequestStore) AddRequest(value AuthRequestContext) string {
61+
logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "AuthorizationRequestStore"))
62+
63+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
64+
if err != nil {
65+
logger.Error("Failed to get database client", log.Error(err))
66+
return ""
67+
}
68+
69+
key := utils.GenerateUUID()
70+
expiryTime := value.AuthTime.Add(authzRS.validityPeriod)
71+
72+
// Serialize AuthRequestContext to JSON
73+
jsonDataBytes, err := authzRS.getJSONDataBytes(value)
74+
if err != nil {
75+
logger.Error("Failed to marshal request context to JSON", log.Error(err))
76+
return ""
77+
}
78+
79+
_, err = dbClient.Execute(queryInsertAuthRequest, key, jsonDataBytes, value.AuthTime, expiryTime)
80+
if err != nil {
81+
logger.Error("Failed to insert authorization request", log.Error(err))
82+
return ""
83+
}
84+
85+
return key
86+
}
87+
88+
// GetRequest retrieves an authorization request context entry from the store.
89+
func (authzRS *authorizationRequestStore) GetRequest(key string) (bool, AuthRequestContext) {
90+
logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "AuthorizationRequestStore"))
91+
92+
if key == "" {
93+
return false, AuthRequestContext{}
94+
}
95+
96+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
97+
if err != nil {
98+
logger.Error("Failed to get database client", log.Error(err))
99+
return false, AuthRequestContext{}
100+
}
101+
102+
// Check expiry by comparing with current time
103+
now := time.Now()
104+
results, err := dbClient.Query(queryGetAuthRequest, key, now)
105+
if err != nil {
106+
logger.Error("Failed to query authorization request", log.Error(err))
107+
return false, AuthRequestContext{}
108+
}
109+
110+
if len(results) == 0 {
111+
return false, AuthRequestContext{}
112+
}
113+
114+
row := results[0]
115+
authRequestContext, err := authzRS.buildAuthRequestContextFromResultRow(row)
116+
if err != nil {
117+
logger.Error("Failed to build authorization request context from result", log.Error(err))
118+
return false, AuthRequestContext{}
119+
}
120+
121+
return true, authRequestContext
122+
}
123+
124+
// ClearRequest removes a specific authorization request context entry from the store.
125+
func (authzRS *authorizationRequestStore) ClearRequest(key string) {
126+
if key == "" {
127+
return
128+
}
129+
130+
logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "AuthorizationRequestStore"))
131+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
132+
if err != nil {
133+
logger.Error("Failed to get database client", log.Error(err))
134+
return
135+
}
136+
137+
_, err = dbClient.Execute(queryDeleteAuthRequest, key)
138+
if err != nil {
139+
logger.Error("Failed to delete authorization request", log.Error(err))
140+
}
141+
}
142+
143+
// getJSONDataBytes prepares the JSON data bytes for the authorization request context.
144+
func (authzRS *authorizationRequestStore) getJSONDataBytes(authRequestContext AuthRequestContext) ([]byte, error) {
145+
jsonData := map[string]interface{}{
146+
"state": authRequestContext.OAuthParameters.State,
147+
"client_id": authRequestContext.OAuthParameters.ClientID,
148+
"redirect_uri": authRequestContext.OAuthParameters.RedirectURI,
149+
"response_type": authRequestContext.OAuthParameters.ResponseType,
150+
"standard_scopes": authRequestContext.OAuthParameters.StandardScopes,
151+
"permission_scopes": authRequestContext.OAuthParameters.PermissionScopes,
152+
"code_challenge": authRequestContext.OAuthParameters.CodeChallenge,
153+
"code_challenge_method": authRequestContext.OAuthParameters.CodeChallengeMethod,
154+
"resource": authRequestContext.OAuthParameters.Resource,
155+
}
156+
157+
jsonDataBytes, err := json.Marshal(jsonData)
158+
if err != nil {
159+
return nil, fmt.Errorf("error marshaling request context to JSON: %w", err)
160+
}
161+
return jsonDataBytes, nil
162+
}
163+
164+
// buildAuthRequestContextFromResultRow builds an AuthRequestContext from a database result row.
165+
func (authzRS *authorizationRequestStore) buildAuthRequestContextFromResultRow(
166+
row map[string]interface{},
167+
) (AuthRequestContext, error) {
168+
// Parse request_data JSON
169+
var dataJSON string
170+
if val, ok := row["request_data"].(string); ok && val != "" {
171+
dataJSON = val
172+
} else if val, ok := row["request_data"].([]byte); ok && len(val) > 0 {
173+
dataJSON = string(val)
174+
} else {
175+
return AuthRequestContext{}, fmt.Errorf("request_data is missing or of unexpected type")
176+
}
177+
178+
var requestDataMap map[string]interface{}
179+
if err := json.Unmarshal([]byte(dataJSON), &requestDataMap); err != nil {
180+
return AuthRequestContext{}, fmt.Errorf("failed to unmarshal request_data JSON: %w", err)
181+
}
182+
183+
// Parse auth_time
184+
authTime, err := parseTimeField(row["auth_time"], "auth_time")
185+
if err != nil {
186+
return AuthRequestContext{}, err
187+
}
188+
189+
// Build OAuthParameters from JSON
190+
oauthParams := model.OAuthParameters{}
191+
// Initialize slices to empty (not nil) to match original behavior
192+
oauthParams.StandardScopes = []string{}
193+
oauthParams.PermissionScopes = []string{}
194+
195+
if state, ok := requestDataMap["state"].(string); ok {
196+
oauthParams.State = state
197+
}
198+
if clientID, ok := requestDataMap["client_id"].(string); ok {
199+
oauthParams.ClientID = clientID
200+
}
201+
if redirectURI, ok := requestDataMap["redirect_uri"].(string); ok {
202+
oauthParams.RedirectURI = redirectURI
203+
}
204+
if responseType, ok := requestDataMap["response_type"].(string); ok {
205+
oauthParams.ResponseType = responseType
206+
}
207+
// Handle standard_scopes - can be []interface{} or []string or nil
208+
if standardScopes, ok := requestDataMap["standard_scopes"].([]interface{}); ok {
209+
oauthParams.StandardScopes = convertToStringArray(standardScopes)
210+
} else if standardScopes, ok := requestDataMap["standard_scopes"].([]string); ok {
211+
oauthParams.StandardScopes = standardScopes
212+
} else if requestDataMap["standard_scopes"] == nil {
213+
// Handle nil case - set to empty slice
214+
oauthParams.StandardScopes = []string{}
215+
}
216+
// Handle permission_scopes - can be []interface{} or []string or nil
217+
if permissionScopes, ok := requestDataMap["permission_scopes"].([]interface{}); ok {
218+
oauthParams.PermissionScopes = convertToStringArray(permissionScopes)
219+
} else if permissionScopes, ok := requestDataMap["permission_scopes"].([]string); ok {
220+
oauthParams.PermissionScopes = permissionScopes
221+
} else if requestDataMap["permission_scopes"] == nil {
222+
// Handle nil case - set to empty slice
223+
oauthParams.PermissionScopes = []string{}
224+
}
225+
if codeChallenge, ok := requestDataMap["code_challenge"].(string); ok {
226+
oauthParams.CodeChallenge = codeChallenge
227+
}
228+
if codeChallengeMethod, ok := requestDataMap["code_challenge_method"].(string); ok {
229+
oauthParams.CodeChallengeMethod = codeChallengeMethod
230+
}
231+
if resource, ok := requestDataMap["resource"].(string); ok {
232+
oauthParams.Resource = resource
233+
}
234+
235+
return AuthRequestContext{
236+
OAuthParameters: oauthParams,
237+
AuthTime: authTime,
238+
}, nil
239+
}
240+
241+
// convertToStringArray converts []interface{} to []string.
242+
func convertToStringArray(arr []interface{}) []string {
243+
result := make([]string, 0, len(arr))
244+
for _, v := range arr {
245+
if str, ok := v.(string); ok {
246+
result = append(result, str)
247+
}
248+
}
249+
return result
250+
}

0 commit comments

Comments
 (0)