Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 134 additions & 59 deletions auth.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package authkit

import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
Expand All @@ -14,6 +16,9 @@ import (
// It extracts the Bearer token from the Authorization header (or "token" query
// parameter for WebSocket upgrades), validates it against the JWKS endpoint,
// and stores the parsed claims in the Gin context.
//
// If ServiceToken is configured and JWT validation fails, it will attempt to
// validate the token as a Zitadel session token.
func AuthN(cfg Config) gin.HandlerFunc {
jwks := NewJWKSCache(cfg.IssuerURL + "/oauth/v2/keys")

Expand All @@ -22,8 +27,8 @@ func AuthN(cfg Config) gin.HandlerFunc {
skipSet[p] = true
}

log.Printf("[authkit] Initialized AuthN middleware (issuer=%s, audience=%s, skip=%d paths)",
cfg.IssuerURL, cfg.Audience, len(cfg.SkipPaths))
log.Printf("[authkit] Initialized AuthN middleware (issuer=%s, audience=%s, skip=%d paths, session_auth=%v)",
cfg.IssuerURL, cfg.Audience, len(cfg.SkipPaths), cfg.ServiceToken != "")

return func(c *gin.Context) {
// Skip configured paths
Expand All @@ -41,78 +46,148 @@ func AuthN(cfg Config) gin.HandlerFunc {
return
}

// Parse and validate the JWT
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
// Verify signing method
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

// Get the key ID from the token header
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, fmt.Errorf("missing kid in token header")
}

// Fetch the public key from JWKS cache
key, err := jwks.GetKey(kid)
if err != nil {
return nil, err
}
return key, nil
},
jwt.WithIssuer(cfg.IssuerURL),
jwt.WithValidMethods([]string{"RS256"}),
)

if err != nil || !token.Valid {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid or expired token",
})
// Try JWT validation first
claims, jwtErr := validateJWT(tokenStr, jwks, cfg)
if jwtErr == nil {
SetClaims(c, claims)
c.Next()
return
}

// Also validate audience if configured
if cfg.Audience != "" {
if err := validateAudience(token, cfg.Audience); err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "token audience mismatch",
})
// If JWT validation failed and ServiceToken is configured, try session validation
if cfg.ServiceToken != "" {
claims, sessionErr := validateSessionToken(tokenStr, cfg)
if sessionErr == nil {
SetClaims(c, claims)
c.Next()
return
}
// Log session validation error for debugging
log.Printf("[authkit] Session validation failed: %v (JWT error: %v)", sessionErr, jwtErr)
}

// Extract claims into our struct
mapClaims, ok := token.Claims.(jwt.MapClaims)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid or expired token",
})
}
}

// validateJWT validates a JWT token and returns claims
func validateJWT(tokenStr string, jwks *JWKSCache, cfg Config) (*Claims, error) {
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
kid, ok := token.Header["kid"].(string)
if !ok {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid token claims",
})
return
return nil, fmt.Errorf("missing kid in token header")
}

claims := &Claims{
Sub: getStringClaim(mapClaims, "sub"),
Email: getStringClaim(mapClaims, "email"),
OrgID: getStringClaim(mapClaims, "urn:zitadel:iam:org:id"),
OrgDomain: getStringClaim(mapClaims, "urn:zitadel:iam:user:resourceowner:primary_domain"),
key, err := jwks.GetKey(kid)
if err != nil {
return nil, err
}
return key, nil
},
jwt.WithIssuer(cfg.IssuerURL),
jwt.WithValidMethods([]string{"RS256"}),
)

// Extract project roles
if roles, ok := mapClaims["urn:zitadel:iam:org:project:roles"].(map[string]interface{}); ok {
claims.Roles = roles
}
if err != nil || !token.Valid {
return nil, fmt.Errorf("invalid JWT: %w", err)
}

// Fallback: extract org ID from roles claim if not present as a top-level claim.
// Zitadel embeds the org ID as the key inside each role grant, e.g.:
// "urn:zitadel:iam:org:project:roles": { "user": { "<orgID>": "domain" } }
if claims.OrgID == "" && claims.Roles != nil {
claims.OrgID = extractOrgIDFromRoles(claims.Roles)
if cfg.Audience != "" {
if err := validateAudience(token, cfg.Audience); err != nil {
return nil, err
}
}

SetClaims(c, claims)
c.Next()
mapClaims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid claims type")
}

claims := &Claims{
Sub: getStringClaim(mapClaims, "sub"),
Email: getStringClaim(mapClaims, "email"),
OrgID: getStringClaim(mapClaims, "urn:zitadel:iam:org:id"),
OrgDomain: getStringClaim(mapClaims, "urn:zitadel:iam:user:resourceowner:primary_domain"),
}

if roles, ok := mapClaims["urn:zitadel:iam:org:project:roles"].(map[string]interface{}); ok {
claims.Roles = roles
}

if claims.OrgID == "" && claims.Roles != nil {
claims.OrgID = extractOrgIDFromRoles(claims.Roles)
}

return claims, nil
}

// validateSessionToken validates a Zitadel session token and returns claims
// Token format: sessionId:sessionToken
func validateSessionToken(tokenStr string, cfg Config) (*Claims, error) {
// Parse sessionId:sessionToken format
parts := strings.SplitN(tokenStr, ":", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid session token format (expected sessionId:sessionToken)")
}

sessionID := parts[0]
sessionToken := parts[1]

// Validate session with Zitadel's Session API
url := fmt.Sprintf("%s/v2/sessions/%s", cfg.IssuerURL, sessionID)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+cfg.ServiceToken)
req.Header.Set("x-zitadel-session-token", sessionToken)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to validate session: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("session validation failed (status %d): %s", resp.StatusCode, string(body))
}

var result struct {
Session struct {
ID string `json:"id"`
Factors struct {
User struct {
ID string `json:"id"`
LoginName string `json:"loginName"`
DisplayName string `json:"displayName"`
OrganizationID string `json:"organizationId"`
} `json:"user"`
} `json:"factors"`
} `json:"session"`
}

if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode session response: %w", err)
}

userFactor := result.Session.Factors.User

if userFactor.ID == "" {
return nil, fmt.Errorf("session has no user factor")
}

return &Claims{
Sub: userFactor.ID,
Email: userFactor.LoginName,
OrgID: userFactor.OrganizationID,
}, nil
}

// extractToken gets the JWT from the Authorization header or "token" query param.
Expand Down
21 changes: 3 additions & 18 deletions claims.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,14 @@
package authkit

import (
"github.com/Prescott-Data/dromos-authkit/internal/models"
"github.com/gin-gonic/gin"
)

const claimsKey = "dromos_auth_claims"

// Claims represents the validated JWT claims from Zitadel.
type Claims struct {
// Sub is the Zitadel user ID.
Sub string `json:"sub"`

// Email is the user's email address.
Email string `json:"email"`

// OrgID is the Zitadel organization ID the user belongs to.
OrgID string `json:"urn:zitadel:iam:org:id"`

// OrgDomain is the primary domain of the user's resource owner organization.
OrgDomain string `json:"urn:zitadel:iam:user:resourceowner:primary_domain"`

// Roles maps role names to their grant details.
// The keys are role names (e.g. "admin", "editor").
Roles map[string]interface{} `json:"urn:zitadel:iam:org:project:roles"`
}
// Claims is an alias to models.Claims for backward compatibility.
type Claims = models.Claims

// SetClaims stores validated claims in the Gin context.
func SetClaims(c *gin.Context, claims *Claims) {
Expand Down
14 changes: 3 additions & 11 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
package authkit

// Config holds the configuration for the auth middleware.
type Config struct {
// IssuerURL is the Zitadel issuer URL (e.g. "http://172.191.51.250:8080").
IssuerURL string
import "github.com/Prescott-Data/dromos-authkit/internal/models"

// Audience is the expected audience claim (Zitadel project ID).
Audience string

// SkipPaths lists route paths that bypass authentication (e.g. health checks).
// These should match Gin's FullPath() patterns (e.g. "/api/v1/health").
SkipPaths []string
}
// Config is an alias to models.Config for backward compatibility.
type Config = models.Config
32 changes: 32 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package authkit

import "errors"

// Invitation and access code errors.
var (
// ErrInvalidAccessCode is returned when an access code format is invalid.
ErrInvalidAccessCode = errors.New("invalid access code format")

// ErrAccessCodeExpired is returned when an access code has expired.
ErrAccessCodeExpired = errors.New("access code has expired")

// ErrAccessCodeUsed is returned when an access code has already been used.
ErrAccessCodeUsed = errors.New("access code has already been used")

// ErrInvitationNotFound is returned when an invitation cannot be found.
ErrInvitationNotFound = errors.New("invitation not found")

// ErrInvitationExpired is returned when an invitation has expired.
ErrInvitationExpired = errors.New("invitation has expired")
)

// Organization errors.
var (
// ErrUnauthorizedOrgAction is returned when a user attempts an action
// they don't have permission for within an organization.
ErrUnauthorizedOrgAction = errors.New("unauthorized organization action")

// ErrUserAlreadyExists is returned when attempting to create a user
// that already exists in the identity provider.
ErrUserAlreadyExists = errors.New("user already exists")
)
10 changes: 10 additions & 0 deletions internal/models/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package models

// Claims represents the validated JWT claims from Zitadel.
type Claims struct {
Sub string `json:"sub"`
Email string `json:"email"`
OrgID string `json:"urn:zitadel:iam:org:id"`
OrgDomain string `json:"urn:zitadel:iam:user:resourceowner:primary_domain"`
Roles map[string]interface{} `json:"urn:zitadel:iam:org:project:roles"`
}
9 changes: 9 additions & 0 deletions internal/models/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package models

// Config holds the configuration for the auth middleware.
type Config struct {
IssuerURL string
Audience string
SkipPaths []string
ServiceToken string // enables session token validation as fallback
}
21 changes: 21 additions & 0 deletions internal/models/invitation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package models

import "time"

// InvitationClaims represents the claims embedded in an invitation token.
type InvitationClaims struct {
InvitationID string `json:"invitation_id"`
OrgID string `json:"org_id"`
Email string `json:"email"`
Role string `json:"role"`
ExpiresAt time.Time `json:"expires_at"`
CreatedBy string `json:"created_by"`
}

// AccessCode represents a secure one-time access code for invitation acceptance.
type AccessCode struct {
Code string `json:"code"`
InvitationID string `json:"invitation_id"`
ExpiresAt time.Time `json:"expires_at"`
UsedAt *time.Time `json:"used_at,omitempty"`
}
Loading