Skip to content

Commit e48eea9

Browse files
committed
Add auth req store
1 parent 71c7c1d commit e48eea9

File tree

22 files changed

+1406
-1025
lines changed

22 files changed

+1406
-1025
lines changed

backend/dbscripts/runtimedb/postgres.sql

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ CREATE TABLE AUTHORIZATION_CODE (
1010
EXPIRY_TIME TIMESTAMP NOT NULL
1111
);
1212

13+
-- Table to store OAuth2 authorization request context
14+
CREATE TABLE AUTHORIZATION_REQUEST (
15+
AUTH_ID VARCHAR(36) PRIMARY KEY,
16+
REQUEST_DATA JSONB NOT NULL,
17+
EXPIRY_TIME TIMESTAMP NOT NULL,
18+
CREATED_AT TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
19+
);
20+
21+
-- Index for efficient expiry cleanup
22+
CREATE INDEX IDX_AUTHORIZATION_REQUEST_EXPIRY ON AUTHORIZATION_REQUEST(EXPIRY_TIME);
23+
1324
-- Table to store flow context metadata and state
1425
CREATE TABLE FLOW_CONTEXT (
1526
FLOW_ID VARCHAR(36) PRIMARY KEY,

backend/dbscripts/runtimedb/sqlite.sql

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ CREATE TABLE AUTHORIZATION_CODE (
1010
EXPIRY_TIME DATETIME NOT NULL
1111
);
1212

13+
-- Table to store OAuth2 authorization request context
14+
CREATE TABLE AUTHORIZATION_REQUEST (
15+
AUTH_ID VARCHAR(36) PRIMARY KEY,
16+
REQUEST_DATA TEXT NOT NULL,
17+
EXPIRY_TIME DATETIME NOT NULL,
18+
CREATED_AT TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
19+
);
20+
21+
-- Index for efficient expiry cleanup
22+
CREATE INDEX IDX_AUTHORIZATION_REQUEST_EXPIRY ON AUTHORIZATION_REQUEST(EXPIRY_TIME);
23+
1324
-- Table to store flow context metadata and state
1425
CREATE TABLE FLOW_CONTEXT (
1526
FLOW_ID VARCHAR(36) PRIMARY KEY,

backend/internal/oauth/oauth2/authz/store.go renamed to backend/internal/oauth/oauth2/authz/auth_code_store.go

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import (
2222
"encoding/json"
2323
"errors"
2424
"fmt"
25-
"strings"
26-
"time"
2725

2826
"github.com/asgardeo/thunder/internal/system/database/provider"
2927
)
@@ -266,32 +264,3 @@ func appendAuthzDataJSON(row map[string]interface{}, authzCode *AuthorizationCod
266264

267265
return authzCode, nil
268266
}
269-
270-
// parseTimeField parses a time field from the database result.
271-
func parseTimeField(field interface{}, fieldName string) (time.Time, error) {
272-
const customTimeFormat = "2006-01-02 15:04:05.999999999"
273-
274-
switch v := field.(type) {
275-
case string:
276-
trimmedTime := trimTimeString(v)
277-
parsedTime, err := time.Parse(customTimeFormat, trimmedTime)
278-
if err != nil {
279-
return time.Time{}, fmt.Errorf("error parsing %s: %w", fieldName, err)
280-
}
281-
return parsedTime, nil
282-
case time.Time:
283-
return v, nil
284-
default:
285-
return time.Time{}, fmt.Errorf("unexpected type for %s", fieldName)
286-
}
287-
}
288-
289-
// trimTimeString trims extra information from a time string to match the expected format.
290-
func trimTimeString(timeStr string) string {
291-
// Split the string into parts by spaces and retain only the first two parts.
292-
parts := strings.SplitN(timeStr, " ", 3)
293-
if len(parts) >= 2 {
294-
return parts[0] + " " + parts[1]
295-
}
296-
return timeStr
297-
}

backend/internal/oauth/oauth2/authz/store_test.go renamed to backend/internal/oauth/oauth2/authz/auth_code_store_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,11 @@ func (suite *AuthorizationCodeStoreTestSuite) TestUpdateAuthorizationCodeState_E
287287
suite.mockdbProvider.AssertExpectations(suite.T())
288288
}
289289

290+
const testTimeString = "2023-12-01 10:30:45.123456789"
291+
290292
func (suite *AuthorizationCodeStoreTestSuite) TestParseTimeField_StringInput() {
291-
testTime := "2023-12-01 10:30:45.123456789 extra content"
292-
expectedTime, _ := time.Parse("2006-01-02 15:04:05.999999999", "2023-12-01 10:30:45.123456789")
293+
testTime := testTimeString + " extra content"
294+
expectedTime, _ := time.Parse("2006-01-02 15:04:05.999999999", testTimeString)
293295

294296
result, err := parseTimeField(testTime, "test_field")
295297
assert.NoError(suite.T(), err)
@@ -305,8 +307,8 @@ func (suite *AuthorizationCodeStoreTestSuite) TestParseTimeField_TimeInput() {
305307
}
306308

307309
func (suite *AuthorizationCodeStoreTestSuite) TestTrimTimeString() {
308-
input := "2023-12-01 10:30:45.123456789 extra content here"
309-
expected := "2023-12-01 10:30:45.123456789"
310+
input := testTimeString + " extra content here"
311+
expected := testTimeString
310312

311313
result := trimTimeString(input)
312314
assert.Equal(suite.T(), expected, result)
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
/*
2+
* Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
3+
*
4+
* WSO2 LLC. licenses this file to you under the Apache License,
5+
* Version 2.0 (the "License"); you may not use this file except
6+
* in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing,
12+
* software distributed under the License is distributed on an
13+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
* KIND, either express or implied. See the License for the
15+
* specific language governing permissions and limitations
16+
* under the License.
17+
*/
18+
19+
package authz
20+
21+
import (
22+
"encoding/json"
23+
"fmt"
24+
"time"
25+
26+
"github.com/asgardeo/thunder/internal/oauth/oauth2/model"
27+
"github.com/asgardeo/thunder/internal/system/database/provider"
28+
"github.com/asgardeo/thunder/internal/system/log"
29+
"github.com/asgardeo/thunder/internal/system/utils"
30+
)
31+
32+
// AuthRequestContext holds OAuth authorization request information.
33+
type AuthRequestContext struct {
34+
OAuthParameters model.OAuthParameters
35+
}
36+
37+
// authorizationRequestStoreInterface defines the interface for authorization request storage.
38+
type authorizationRequestStoreInterface interface {
39+
AddRequest(value AuthRequestContext) string
40+
GetRequest(key string) (bool, AuthRequestContext)
41+
ClearRequest(key string)
42+
}
43+
44+
// authorizationRequestStore provides the authorization request store functionality using database.
45+
type authorizationRequestStore struct {
46+
dbProvider provider.DBProviderInterface
47+
validityPeriod time.Duration
48+
}
49+
50+
// newAuthorizationRequestStore creates a new instance of authorizationRequestStore with injected dependencies.
51+
func newAuthorizationRequestStore() authorizationRequestStoreInterface {
52+
return &authorizationRequestStore{
53+
dbProvider: provider.GetDBProvider(),
54+
validityPeriod: 10 * time.Minute, // Set a default validity period.
55+
}
56+
}
57+
58+
// AddRequest adds an authorization request context entry to the store.
59+
func (authzRS *authorizationRequestStore) AddRequest(value AuthRequestContext) string {
60+
logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "AuthorizationRequestStore"))
61+
62+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
63+
if err != nil {
64+
logger.Error("Failed to get database client", log.Error(err))
65+
return ""
66+
}
67+
68+
key := utils.GenerateUUID()
69+
// Calculate expiry based on current time
70+
requestInitiatedTime := time.Now()
71+
expiryTime := requestInitiatedTime.Add(authzRS.validityPeriod)
72+
73+
// Serialize AuthRequestContext to JSON
74+
jsonDataBytes, err := authzRS.getJSONDataBytes(value)
75+
if err != nil {
76+
logger.Error("Failed to marshal request context to JSON", log.Error(err))
77+
return ""
78+
}
79+
80+
_, err = dbClient.Execute(queryInsertAuthRequest, key, jsonDataBytes, expiryTime)
81+
if err != nil {
82+
logger.Error("Failed to insert authorization request", log.Error(err))
83+
return ""
84+
}
85+
86+
return key
87+
}
88+
89+
// GetRequest retrieves an authorization request context entry from the store.
90+
func (authzRS *authorizationRequestStore) GetRequest(key string) (bool, AuthRequestContext) {
91+
logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "AuthorizationRequestStore"))
92+
93+
if key == "" {
94+
return false, AuthRequestContext{}
95+
}
96+
97+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
98+
if err != nil {
99+
logger.Error("Failed to get database client", log.Error(err))
100+
return false, AuthRequestContext{}
101+
}
102+
103+
// Check expiry by comparing with current time
104+
now := time.Now()
105+
results, err := dbClient.Query(queryGetAuthRequest, key, now)
106+
if err != nil {
107+
logger.Error("Failed to query authorization request", log.Error(err))
108+
return false, AuthRequestContext{}
109+
}
110+
111+
if len(results) == 0 {
112+
return false, AuthRequestContext{}
113+
}
114+
115+
row := results[0]
116+
authRequestContext, err := authzRS.buildAuthRequestContextFromResultRow(row)
117+
if err != nil {
118+
logger.Error("Failed to build authorization request context from result", log.Error(err))
119+
return false, AuthRequestContext{}
120+
}
121+
122+
return true, authRequestContext
123+
}
124+
125+
// ClearRequest removes a specific authorization request context entry from the store.
126+
func (authzRS *authorizationRequestStore) ClearRequest(key string) {
127+
if key == "" {
128+
return
129+
}
130+
131+
logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "AuthorizationRequestStore"))
132+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
133+
if err != nil {
134+
logger.Error("Failed to get database client", log.Error(err))
135+
return
136+
}
137+
138+
_, err = dbClient.Execute(queryDeleteAuthRequest, key)
139+
if err != nil {
140+
logger.Error("Failed to delete authorization request", log.Error(err))
141+
}
142+
}
143+
144+
// getJSONDataBytes prepares the JSON data bytes for the authorization request context.
145+
func (authzRS *authorizationRequestStore) getJSONDataBytes(authRequestContext AuthRequestContext) ([]byte, error) {
146+
jsonData := map[string]interface{}{
147+
"state": authRequestContext.OAuthParameters.State,
148+
"client_id": authRequestContext.OAuthParameters.ClientID,
149+
"redirect_uri": authRequestContext.OAuthParameters.RedirectURI,
150+
"response_type": authRequestContext.OAuthParameters.ResponseType,
151+
"standard_scopes": authRequestContext.OAuthParameters.StandardScopes,
152+
"permission_scopes": authRequestContext.OAuthParameters.PermissionScopes,
153+
"code_challenge": authRequestContext.OAuthParameters.CodeChallenge,
154+
"code_challenge_method": authRequestContext.OAuthParameters.CodeChallengeMethod,
155+
"resource": authRequestContext.OAuthParameters.Resource,
156+
}
157+
158+
jsonDataBytes, err := json.Marshal(jsonData)
159+
if err != nil {
160+
return nil, fmt.Errorf("error marshaling request context to JSON: %w", err)
161+
}
162+
return jsonDataBytes, nil
163+
}
164+
165+
// buildAuthRequestContextFromResultRow builds an AuthRequestContext from a database result row.
166+
func (authzRS *authorizationRequestStore) buildAuthRequestContextFromResultRow(
167+
row map[string]interface{},
168+
) (AuthRequestContext, error) {
169+
// Parse request_data JSON
170+
var dataJSON string
171+
if val, ok := row["request_data"].(string); ok && val != "" {
172+
dataJSON = val
173+
} else if val, ok := row["request_data"].([]byte); ok && len(val) > 0 {
174+
dataJSON = string(val)
175+
} else {
176+
return AuthRequestContext{}, fmt.Errorf("request_data is missing or of unexpected type")
177+
}
178+
179+
var requestDataMap map[string]interface{}
180+
if err := json.Unmarshal([]byte(dataJSON), &requestDataMap); err != nil {
181+
return AuthRequestContext{}, fmt.Errorf("failed to unmarshal request_data JSON: %w", err)
182+
}
183+
184+
// Build OAuthParameters from JSON
185+
oauthParams := model.OAuthParameters{}
186+
// Initialize slices to empty (not nil) to match original behavior
187+
oauthParams.StandardScopes = []string{}
188+
oauthParams.PermissionScopes = []string{}
189+
190+
if state, ok := requestDataMap["state"].(string); ok {
191+
oauthParams.State = state
192+
}
193+
if clientID, ok := requestDataMap["client_id"].(string); ok {
194+
oauthParams.ClientID = clientID
195+
}
196+
if redirectURI, ok := requestDataMap["redirect_uri"].(string); ok {
197+
oauthParams.RedirectURI = redirectURI
198+
}
199+
if responseType, ok := requestDataMap["response_type"].(string); ok {
200+
oauthParams.ResponseType = responseType
201+
}
202+
// Handle standard_scopes - can be []interface{} or []string or nil
203+
if standardScopes, ok := requestDataMap["standard_scopes"].([]interface{}); ok {
204+
oauthParams.StandardScopes = convertToStringArray(standardScopes)
205+
} else if standardScopes, ok := requestDataMap["standard_scopes"].([]string); ok {
206+
oauthParams.StandardScopes = standardScopes
207+
} else if requestDataMap["standard_scopes"] == nil {
208+
// Handle nil case - set to empty slice
209+
oauthParams.StandardScopes = []string{}
210+
}
211+
// Handle permission_scopes - can be []interface{} or []string or nil
212+
if permissionScopes, ok := requestDataMap["permission_scopes"].([]interface{}); ok {
213+
oauthParams.PermissionScopes = convertToStringArray(permissionScopes)
214+
} else if permissionScopes, ok := requestDataMap["permission_scopes"].([]string); ok {
215+
oauthParams.PermissionScopes = permissionScopes
216+
} else if requestDataMap["permission_scopes"] == nil {
217+
// Handle nil case - set to empty slice
218+
oauthParams.PermissionScopes = []string{}
219+
}
220+
if codeChallenge, ok := requestDataMap["code_challenge"].(string); ok {
221+
oauthParams.CodeChallenge = codeChallenge
222+
}
223+
if codeChallengeMethod, ok := requestDataMap["code_challenge_method"].(string); ok {
224+
oauthParams.CodeChallengeMethod = codeChallengeMethod
225+
}
226+
if resource, ok := requestDataMap["resource"].(string); ok {
227+
oauthParams.Resource = resource
228+
}
229+
230+
return AuthRequestContext{
231+
OAuthParameters: oauthParams,
232+
}, nil
233+
}
234+
235+
// convertToStringArray converts []interface{} to []string.
236+
func convertToStringArray(arr []interface{}) []string {
237+
result := make([]string, 0, len(arr))
238+
for _, v := range arr {
239+
if str, ok := v.(string); ok {
240+
result = append(result, str)
241+
}
242+
}
243+
return result
244+
}

0 commit comments

Comments
 (0)