Skip to content

Commit decd3ad

Browse files
committed
Add auth req store
1 parent 3472f6e commit decd3ad

File tree

25 files changed

+1686
-1074
lines changed

25 files changed

+1686
-1074
lines changed

backend/dbscripts/runtimedb/postgres.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@ CREATE TABLE AUTHORIZATION_CODE (
1515
-- Index for deployment isolation on AUTHORIZATION_CODE
1616
CREATE INDEX idx_authorization_code_deployment_id ON AUTHORIZATION_CODE (DEPLOYMENT_ID);
1717

18+
-- Table to store OAuth2 authorization request context
19+
CREATE TABLE AUTHORIZATION_REQUEST (
20+
AUTH_ID VARCHAR(36) NOT NULL,
21+
DEPLOYMENT_ID VARCHAR(255) NOT NULL,
22+
REQUEST_DATA JSONB NOT NULL,
23+
EXPIRY_TIME TIMESTAMP NOT NULL,
24+
CREATED_AT TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
25+
PRIMARY KEY (AUTH_ID, DEPLOYMENT_ID)
26+
);
27+
28+
-- Index for deployment isolation on AUTHORIZATION_REQUEST
29+
CREATE INDEX idx_authorization_request_deployment_id ON AUTHORIZATION_REQUEST (DEPLOYMENT_ID);
30+
31+
-- Index for expiry time on AUTHORIZATION_REQUEST
32+
CREATE INDEX idx_authorization_request_expiry_time ON AUTHORIZATION_REQUEST (EXPIRY_TIME);
33+
1834
-- Table to store flow context metadata and state
1935
CREATE TABLE FLOW_CONTEXT (
2036
FLOW_ID VARCHAR(36) NOT NULL,

backend/dbscripts/runtimedb/sqlite.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@ CREATE TABLE AUTHORIZATION_CODE (
1515
-- Index for deployment isolation on AUTHORIZATION_CODE
1616
CREATE INDEX idx_authorization_code_deployment_id ON AUTHORIZATION_CODE (DEPLOYMENT_ID);
1717

18+
-- Table to store OAuth2 authorization request context
19+
CREATE TABLE AUTHORIZATION_REQUEST (
20+
AUTH_ID VARCHAR(36) NOT NULL,
21+
DEPLOYMENT_ID VARCHAR(255) NOT NULL,
22+
REQUEST_DATA TEXT NOT NULL,
23+
EXPIRY_TIME DATETIME NOT NULL,
24+
CREATED_AT TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
25+
PRIMARY KEY (AUTH_ID, DEPLOYMENT_ID)
26+
);
27+
28+
-- Index for deployment isolation on AUTHORIZATION_REQUEST
29+
CREATE INDEX idx_authorization_request_deployment_id ON AUTHORIZATION_REQUEST (DEPLOYMENT_ID);
30+
31+
-- Index for expiry time on AUTHORIZATION_REQUEST
32+
CREATE INDEX idx_authorization_request_expiry_time ON AUTHORIZATION_REQUEST (EXPIRY_TIME);
33+
1834
-- Table to store flow context metadata and state
1935
CREATE TABLE FLOW_CONTEXT (
2036
FLOW_ID VARCHAR(36) NOT NULL,

backend/internal/oauth/oauth2/authz/store.go renamed to backend/internal/oauth/oauth2/authz/auth_code_store.go

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import (
2222
"encoding/json"
2323
"errors"
2424
"fmt"
25-
"strings"
26-
"time"
2725

2826
"github.com/asgardeo/thunder/internal/system/config"
2927
"github.com/asgardeo/thunder/internal/system/database/provider"
@@ -269,32 +267,3 @@ func appendAuthzDataJSON(row map[string]interface{}, authzCode *AuthorizationCod
269267

270268
return authzCode, nil
271269
}
272-
273-
// parseTimeField parses a time field from the database result.
274-
func parseTimeField(field interface{}, fieldName string) (time.Time, error) {
275-
const customTimeFormat = "2006-01-02 15:04:05.999999999"
276-
277-
switch v := field.(type) {
278-
case string:
279-
trimmedTime := trimTimeString(v)
280-
parsedTime, err := time.Parse(customTimeFormat, trimmedTime)
281-
if err != nil {
282-
return time.Time{}, fmt.Errorf("error parsing %s: %w", fieldName, err)
283-
}
284-
return parsedTime, nil
285-
case time.Time:
286-
return v, nil
287-
default:
288-
return time.Time{}, fmt.Errorf("unexpected type for %s", fieldName)
289-
}
290-
}
291-
292-
// trimTimeString trims extra information from a time string to match the expected format.
293-
func trimTimeString(timeStr string) string {
294-
// Split the string into parts by spaces and retain only the first two parts.
295-
parts := strings.SplitN(timeStr, " ", 3)
296-
if len(parts) >= 2 {
297-
return parts[0] + " " + parts[1]
298-
}
299-
return timeStr
300-
}

backend/internal/oauth/oauth2/authz/store_test.go renamed to backend/internal/oauth/oauth2/authz/auth_code_store_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,11 @@ func (suite *AuthorizationCodeStoreTestSuite) TestUpdateAuthorizationCodeState_E
292292
suite.mockdbProvider.AssertExpectations(suite.T())
293293
}
294294

295+
const testTimeString = "2023-12-01 10:30:45.123456789"
296+
295297
func (suite *AuthorizationCodeStoreTestSuite) TestParseTimeField_StringInput() {
296-
testTime := "2023-12-01 10:30:45.123456789 extra content"
297-
expectedTime, _ := time.Parse("2006-01-02 15:04:05.999999999", "2023-12-01 10:30:45.123456789")
298+
testTime := testTimeString + " extra content"
299+
expectedTime, _ := time.Parse("2006-01-02 15:04:05.999999999", testTimeString)
298300

299301
result, err := parseTimeField(testTime, "test_field")
300302
assert.NoError(suite.T(), err)
@@ -310,8 +312,8 @@ func (suite *AuthorizationCodeStoreTestSuite) TestParseTimeField_TimeInput() {
310312
}
311313

312314
func (suite *AuthorizationCodeStoreTestSuite) TestTrimTimeString() {
313-
input := "2023-12-01 10:30:45.123456789 extra content here"
314-
expected := "2023-12-01 10:30:45.123456789"
315+
input := testTimeString + " extra content here"
316+
expected := testTimeString
315317

316318
result := trimTimeString(input)
317319
assert.Equal(suite.T(), expected, result)
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
/*
2+
* Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
3+
*
4+
* WSO2 LLC. licenses this file to you under the Apache License,
5+
* Version 2.0 (the "License"); you may not use this file except
6+
* in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing,
12+
* software distributed under the License is distributed on an
13+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
* KIND, either express or implied. See the License for the
15+
* specific language governing permissions and limitations
16+
* under the License.
17+
*/
18+
19+
package authz
20+
21+
import (
22+
"encoding/json"
23+
"fmt"
24+
"time"
25+
26+
"github.com/asgardeo/thunder/internal/oauth/oauth2/model"
27+
"github.com/asgardeo/thunder/internal/system/config"
28+
"github.com/asgardeo/thunder/internal/system/database/provider"
29+
"github.com/asgardeo/thunder/internal/system/log"
30+
"github.com/asgardeo/thunder/internal/system/utils"
31+
)
32+
33+
// authRequestContext holds OAuth authorization request information.
34+
type authRequestContext struct {
35+
OAuthParameters model.OAuthParameters
36+
}
37+
38+
// authorizationRequestStoreInterface defines the interface for authorization request storage.
39+
type authorizationRequestStoreInterface interface {
40+
AddRequest(value authRequestContext) string
41+
GetRequest(key string) (bool, authRequestContext)
42+
ClearRequest(key string)
43+
}
44+
45+
// authorizationRequestStore provides the authorization request store functionality using database.
46+
type authorizationRequestStore struct {
47+
dbProvider provider.DBProviderInterface
48+
validityPeriod time.Duration
49+
deploymentID string
50+
logger *log.Logger
51+
}
52+
53+
// newAuthorizationRequestStore creates a new instance of authorizationRequestStore with injected dependencies.
54+
func newAuthorizationRequestStore() authorizationRequestStoreInterface {
55+
return &authorizationRequestStore{
56+
dbProvider: provider.GetDBProvider(),
57+
validityPeriod: 10 * time.Minute,
58+
deploymentID: config.GetThunderRuntime().Config.Server.Identifier,
59+
logger: log.GetLogger().With(log.String(log.LoggerKeyComponentName, "AuthorizationRequestStore")),
60+
}
61+
}
62+
63+
// AddRequest adds an authorization request context entry to the store.
64+
func (authzRS *authorizationRequestStore) AddRequest(value authRequestContext) string {
65+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
66+
if err != nil {
67+
authzRS.logger.Error("Failed to get database client", log.Error(err))
68+
return ""
69+
}
70+
71+
key := utils.GenerateUUID()
72+
// Calculate expiry based on current time
73+
requestInitiatedTime := time.Now()
74+
expiryTime := requestInitiatedTime.Add(authzRS.validityPeriod)
75+
76+
// Serialize authRequestContext to JSON
77+
jsonDataBytes, err := authzRS.getJSONDataBytes(value)
78+
if err != nil {
79+
authzRS.logger.Error("Failed to marshal request context to JSON", log.Error(err))
80+
return ""
81+
}
82+
83+
_, err = dbClient.Execute(queryInsertAuthRequest, key, jsonDataBytes, expiryTime, authzRS.deploymentID)
84+
if err != nil {
85+
authzRS.logger.Error("Failed to insert authorization request", log.Error(err))
86+
return ""
87+
}
88+
89+
return key
90+
}
91+
92+
// GetRequest retrieves an authorization request context entry from the store.
93+
func (authzRS *authorizationRequestStore) GetRequest(key string) (bool, authRequestContext) {
94+
if key == "" {
95+
return false, authRequestContext{}
96+
}
97+
98+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
99+
if err != nil {
100+
authzRS.logger.Error("Failed to get database client", log.Error(err))
101+
return false, authRequestContext{}
102+
}
103+
104+
// Check expiry by comparing with current time
105+
now := time.Now()
106+
results, err := dbClient.Query(queryGetAuthRequest, key, now, authzRS.deploymentID)
107+
if err != nil {
108+
authzRS.logger.Error("Failed to query authorization request", log.Error(err))
109+
return false, authRequestContext{}
110+
}
111+
112+
if len(results) == 0 {
113+
return false, authRequestContext{}
114+
}
115+
116+
row := results[0]
117+
authRequestCtx, err := authzRS.buildAuthRequestContextFromResultRow(row)
118+
if err != nil {
119+
authzRS.logger.Error("Failed to build authorization request context from result", log.Error(err))
120+
return false, authRequestContext{}
121+
}
122+
123+
return true, authRequestCtx
124+
}
125+
126+
// ClearRequest removes a specific authorization request context entry from the store.
127+
func (authzRS *authorizationRequestStore) ClearRequest(key string) {
128+
if key == "" {
129+
return
130+
}
131+
132+
dbClient, err := authzRS.dbProvider.GetRuntimeDBClient()
133+
if err != nil {
134+
authzRS.logger.Error("Failed to get database client", log.Error(err))
135+
return
136+
}
137+
138+
_, err = dbClient.Execute(queryDeleteAuthRequest, key, authzRS.deploymentID)
139+
if err != nil {
140+
authzRS.logger.Error("Failed to delete authorization request", log.Error(err))
141+
}
142+
}
143+
144+
// getJSONDataBytes prepares the JSON data bytes for the authorization request context.
145+
func (authzRS *authorizationRequestStore) getJSONDataBytes(authRequestCtx authRequestContext) ([]byte, error) {
146+
jsonData := map[string]interface{}{
147+
jsonKeyState: authRequestCtx.OAuthParameters.State,
148+
jsonKeyClientID: authRequestCtx.OAuthParameters.ClientID,
149+
jsonKeyRedirectURI: authRequestCtx.OAuthParameters.RedirectURI,
150+
jsonKeyResponseType: authRequestCtx.OAuthParameters.ResponseType,
151+
jsonKeyStandardScopes: authRequestCtx.OAuthParameters.StandardScopes,
152+
jsonKeyPermissionScopes: authRequestCtx.OAuthParameters.PermissionScopes,
153+
jsonKeyCodeChallenge: authRequestCtx.OAuthParameters.CodeChallenge,
154+
jsonKeyCodeChallengeMethod: authRequestCtx.OAuthParameters.CodeChallengeMethod,
155+
jsonKeyResource: authRequestCtx.OAuthParameters.Resource,
156+
}
157+
158+
jsonDataBytes, err := json.Marshal(jsonData)
159+
if err != nil {
160+
return nil, fmt.Errorf("error marshaling request context to JSON: %w", err)
161+
}
162+
return jsonDataBytes, nil
163+
}
164+
165+
// buildAuthRequestContextFromResultRow builds an authRequestContext from a database result row.
166+
func (authzRS *authorizationRequestStore) buildAuthRequestContextFromResultRow(
167+
row map[string]interface{},
168+
) (authRequestContext, error) {
169+
var dataJSON string
170+
if val, ok := row[dbColumnRequestData].(string); ok && val != "" {
171+
dataJSON = val
172+
} else if val, ok := row[dbColumnRequestData].([]byte); ok && len(val) > 0 {
173+
dataJSON = string(val)
174+
} else {
175+
return authRequestContext{}, fmt.Errorf("%s is missing or of unexpected type", dbColumnRequestData)
176+
}
177+
178+
var requestDataMap map[string]interface{}
179+
if err := json.Unmarshal([]byte(dataJSON), &requestDataMap); err != nil {
180+
return authRequestContext{}, fmt.Errorf("failed to unmarshal %s JSON: %w", dbColumnRequestData, err)
181+
}
182+
183+
// Build OAuthParameters from JSON
184+
oauthParams := model.OAuthParameters{}
185+
oauthParams.StandardScopes = []string{}
186+
oauthParams.PermissionScopes = []string{}
187+
188+
if state, ok := requestDataMap[jsonKeyState].(string); ok {
189+
oauthParams.State = state
190+
}
191+
if clientID, ok := requestDataMap[jsonKeyClientID].(string); ok {
192+
oauthParams.ClientID = clientID
193+
}
194+
if redirectURI, ok := requestDataMap[jsonKeyRedirectURI].(string); ok {
195+
oauthParams.RedirectURI = redirectURI
196+
}
197+
if responseType, ok := requestDataMap[jsonKeyResponseType].(string); ok {
198+
oauthParams.ResponseType = responseType
199+
}
200+
// Handle standard_scopes
201+
if standardScopes, ok := requestDataMap[jsonKeyStandardScopes].([]interface{}); ok {
202+
oauthParams.StandardScopes = convertToStringArray(standardScopes)
203+
} else if standardScopes, ok := requestDataMap[jsonKeyStandardScopes].([]string); ok {
204+
oauthParams.StandardScopes = standardScopes
205+
} else if requestDataMap[jsonKeyStandardScopes] == nil {
206+
oauthParams.StandardScopes = []string{}
207+
}
208+
// Handle permission_scopes
209+
if permissionScopes, ok := requestDataMap[jsonKeyPermissionScopes].([]interface{}); ok {
210+
oauthParams.PermissionScopes = convertToStringArray(permissionScopes)
211+
} else if permissionScopes, ok := requestDataMap[jsonKeyPermissionScopes].([]string); ok {
212+
oauthParams.PermissionScopes = permissionScopes
213+
} else if requestDataMap[jsonKeyPermissionScopes] == nil {
214+
oauthParams.PermissionScopes = []string{}
215+
}
216+
if codeChallenge, ok := requestDataMap[jsonKeyCodeChallenge].(string); ok {
217+
oauthParams.CodeChallenge = codeChallenge
218+
}
219+
if codeChallengeMethod, ok := requestDataMap[jsonKeyCodeChallengeMethod].(string); ok {
220+
oauthParams.CodeChallengeMethod = codeChallengeMethod
221+
}
222+
if resource, ok := requestDataMap[jsonKeyResource].(string); ok {
223+
oauthParams.Resource = resource
224+
}
225+
226+
return authRequestContext{
227+
OAuthParameters: oauthParams,
228+
}, nil
229+
}
230+
231+
// convertToStringArray converts []interface{} to []string.
232+
func convertToStringArray(arr []interface{}) []string {
233+
result := make([]string, 0, len(arr))
234+
for _, v := range arr {
235+
if str, ok := v.(string); ok {
236+
result = append(result, str)
237+
}
238+
}
239+
return result
240+
}

0 commit comments

Comments
 (0)