Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions createFeatureFlag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,16 @@ func handler(ctx context.Context, req events.APIGatewayProxyRequest) (events.API
return corsResponse, err
}

jwtResponse, _, err := jwt.JWTMiddleware()(req)
// Use enhanced middleware with user verification (Week 2 migration)
jwtResponse, userContext, err := jwt.JWTMiddlewareWithUserVerification()(req)
if err != nil || jwtResponse.StatusCode != http.StatusOK {
return jwtResponse, err
}

if userContext == nil {
return utils.ClientError(http.StatusUnauthorized, "User context not available")
}

corsHeaders := middleware.GetCORSHeadersV1(req.Headers)

err = json.Unmarshal([]byte(req.Body), &createFeatureFlagRequest)
Expand All @@ -85,11 +90,16 @@ func handler(ctx context.Context, req events.APIGatewayProxyRequest) (events.API

if err := validate.Struct(&createFeatureFlagRequest); err != nil {
return events.APIGatewayProxyResponse{
Body: "Check the request body passed name, description and userId are required.",
Body: "Check the request body passed name and description are required.",
StatusCode: http.StatusBadRequest,
Headers: corsHeaders,
}, nil
}

// Use userId from authenticated user context (Week 2 migration)
// Override any userId in request body with authenticated user
createFeatureFlagRequest.UserId = userContext.UserId

featureFlag, err := createFeatureFlag(ctx, db, createFeatureFlagRequest)
if err != nil {
log.Printf("Error while creating feature flag: \n %v ", err)
Expand Down
32 changes: 32 additions & 0 deletions layer/database/dynamodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,38 @@ func ProcessGetFeatureFlagByHashKey(attributeName string, attributeValue string)
return featureFlagResponse, nil
}

func GetUserById(ctx context.Context, userId string) (*models.User, error) {
db := CreateDynamoDB()

input := &dynamodb.GetItemInput{
TableName: aws.String(utils.USER_TABLE_NAME),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{
Value: userId,
},
},
}

result, err := db.GetItem(ctx, input)
if err != nil {
utils.DdbError(err)
return nil, err
}

if len(result.Item) == 0 {
return nil, nil
}

var user models.User
err = UnmarshalMap(result.Item, &user)
if err != nil {
log.Println(err, " is the error while converting to user object")
return nil, err
}

return &user, nil
}

func AddUserFeatureFlagMapping(featureFlagUserMappings []models.FeatureFlagUserMapping) ([]models.FeatureFlagUserMapping, error) {
ctx := context.TODO()
db := CreateDynamoDB()
Expand Down
126 changes: 126 additions & 0 deletions layer/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync"
"time"

"feature-flag-backend/layer/database"
"feature-flag-backend/layer/utils"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -260,6 +261,13 @@ func handleMiddlewareResponse(statusCode int, message string) (events.APIGateway
}, "", nil
}

// UserContextResponse holds the user context and response
type UserContextResponse struct {
Response events.APIGatewayProxyResponse
UserContext *utils.UserContext
Error error
}

func JWTMiddleware() func(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, string, error) {
return func(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, string, error) {
jwtUtils, err := GetInstance()
Expand Down Expand Up @@ -319,6 +327,124 @@ func JWTMiddleware() func(req events.APIGatewayProxyRequest) (events.APIGatewayP
return handleMiddlewareResponse(http.StatusUnauthorized, "Unauthorized")
}

// Return userId (backward compatible)
// Role extraction is available in JWTMiddlewareWithUserVerification
return handleMiddlewareResponse(http.StatusOK, userId)
}
}

// handleMiddlewareResponseWithContext is a helper for the enhanced middleware
func handleMiddlewareResponseWithContext(statusCode int, message string) (events.APIGatewayProxyResponse, *utils.UserContext, error) {
return events.APIGatewayProxyResponse{
StatusCode: statusCode,
Body: message,
}, nil, nil
}

// JWTMiddlewareWithUserVerification is an enhanced middleware that verifies user exists in database
// and returns UserContext. This is for Week 2 migration to internal authentication.
func JWTMiddlewareWithUserVerification() func(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, *utils.UserContext, error) {
return func(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, *utils.UserContext, error) {
jwtUtils, err := GetInstance()
if err != nil {
log.Printf("Failed to get JWTUtils instance: %v", err)
resp, _, _ := handleMiddlewareResponseWithContext(http.StatusInternalServerError, "Internal server error")
return resp, nil, err
}

cookie := ""
for key, val := range req.Headers {
if strings.ToLower(key) == "cookie" {
cookie = val
break
}
}
if cookie == "" {
resp, _, _ := handleMiddlewareResponseWithContext(http.StatusUnauthorized, "Unauthenticated")
return resp, nil, nil
}

envConfig, _ := LoadEnvConfig()
cookieName := envConfig.SessionCookieName
if cookieName == "" {
switch envConfig.Environment {
case utils.PROD:
cookieName = utils.SESSION_COOKIE_NAME_PROD
case utils.DEV:
cookieName = utils.SESSION_COOKIE_NAME_DEV
default:
cookieName = utils.SESSION_COOKIE_NAME_LOCAL
}
}

var jwtToken string
cookies := strings.Split(cookie, ";")
for _, c := range cookies {
c = strings.TrimSpace(c)
if strings.HasPrefix(c, cookieName+"=") {
jwtToken = strings.TrimPrefix(c, cookieName+"=")
break
}
}

if jwtToken == "" {
resp, _, _ := handleMiddlewareResponseWithContext(http.StatusUnauthorized, "Unauthenticated")
return resp, nil, nil
}

claims, err := jwtUtils.ValidateToken(jwtToken)
if err != nil {
log.Printf("Token validation failed: %v", err)
resp, _, _ := handleMiddlewareResponseWithContext(http.StatusUnauthorized, "Invalid token")
return resp, nil, nil
}

userId, err := jwtUtils.ExtractClaim(claims, "userId")
if err != nil {
resp, _, _ := handleMiddlewareResponseWithContext(http.StatusUnauthorized, "Unauthorized")
return resp, nil, nil
}

// Extract role from token
role, _ := jwtUtils.ExtractClaim(claims, "role")
if role == "" {
role = utils.ROLE_VIEWER
}

// Verify user exists in database and is active
ctx := context.Background()
user, err := database.GetUserById(ctx, userId)
if err != nil {
log.Printf("Error fetching user from database: %v", err)
resp, _, _ := handleMiddlewareResponseWithContext(http.StatusInternalServerError, "Internal server error")
return resp, nil, err
}

if user == nil {
log.Printf("User not found in database: %s", userId)
resp, _, _ := handleMiddlewareResponseWithContext(http.StatusUnauthorized, "User not found")
return resp, nil, nil
}

if !user.IsActive {
log.Printf("User account is inactive: %s", userId)
resp, _, _ := handleMiddlewareResponseWithContext(http.StatusForbidden, "User account is inactive")
return resp, nil, nil
}

// Use role from database (source of truth) if it differs from token
// This allows role updates without requiring re-login
if user.Role != "" {
role = user.Role
}

userContext := &utils.UserContext{
UserId: userId,
Role: role,
Email: user.Email,
}

resp, _, _ := handleMiddlewareResponseWithContext(http.StatusOK, "")
return resp, userContext, nil
}
}
2 changes: 1 addition & 1 deletion layer/utils/RequestResponse.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type UpdateFeatureFlagRequest struct {
type CreateFeatureFlagRequest struct {
FlagName string `json:"name" validate:"required"`
Description string `json:"description" validate:"required"`
UserId string `json:"userId" validate:"required"`
UserId string `json:"userId"` // Optional - will be set from authenticated user context
}

type FeatureFlagResponse struct {
Expand Down
9 changes: 9 additions & 0 deletions layer/utils/UserContext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package utils

// UserContext holds authenticated user information
type UserContext struct {
UserId string
Role string
Email string
}