Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package oauth

import (
"fmt"
"strconv"

"github.com/tuannvm/oauth-mcp-proxy/provider"
)
Expand Down Expand Up @@ -31,6 +32,11 @@ type Config struct {
// Implement the Logger interface (Debug, Info, Warn, Error methods) to
// integrate with your application's logging system (e.g., zap, logrus).
Logger Logger

// Validation skip configuration
SkipIssuerCheck bool
SkipAudienceCheck bool
SkipExpiryCheck bool
}

// Validate validates the configuration
Expand Down Expand Up @@ -119,11 +125,14 @@ func SetupOAuth(cfg *Config) (provider.TokenValidator, error) {
func createValidator(cfg *Config, logger Logger) (provider.TokenValidator, error) {
// Convert root Config to provider.Config
providerCfg := &provider.Config{
Provider: cfg.Provider,
Issuer: cfg.Issuer,
Audience: cfg.Audience,
JWTSecret: cfg.JWTSecret,
Logger: logger,
Provider: cfg.Provider,
Issuer: cfg.Issuer,
Audience: cfg.Audience,
JWTSecret: cfg.JWTSecret,
Logger: logger,
SkipIssuerCheck: cfg.SkipIssuerCheck,
SkipAudienceCheck: cfg.SkipAudienceCheck,
SkipExpiryCheck: cfg.SkipExpiryCheck,
}

var validator provider.TokenValidator
Expand Down Expand Up @@ -223,6 +232,24 @@ func (b *ConfigBuilder) WithLogger(logger Logger) *ConfigBuilder {
return b
}

// WithSkipIssuerCheck sets issuer check toggle
func (b *ConfigBuilder) WithSkipIssuerCheck(skipIssuerCheck bool) *ConfigBuilder {
b.config.SkipIssuerCheck = skipIssuerCheck
return b
}

// WithSkipAudienceCheck sets audience check toggle
func (b *ConfigBuilder) WithSkipAudienceCheck(skipAudienceCheck bool) *ConfigBuilder {
b.config.SkipAudienceCheck = skipAudienceCheck
return b
}

// WithSkipExpiryCheck sets expiry check toggle
func (b *ConfigBuilder) WithSkipExpiryCheck(skipExpiryCheck bool) *ConfigBuilder {
b.config.SkipExpiryCheck = skipExpiryCheck
return b
}

// WithServerURL sets the full server URL directly
func (b *ConfigBuilder) WithServerURL(url string) *ConfigBuilder {
b.config.ServerURL = url
Expand Down Expand Up @@ -289,7 +316,23 @@ func FromEnv() (*Config, error) {
WithAudience(getEnv("OIDC_AUDIENCE", "")).
WithClientID(getEnv("OIDC_CLIENT_ID", "")).
WithClientSecret(getEnv("OIDC_CLIENT_SECRET", "")).
WithSkipAudienceCheck(parseBoolEnv("OIDC_SKIP_AUDIENCE_CHECK", false)).
WithSkipIssuerCheck(parseBoolEnv("OIDC_SKIP_ISSUER_CHECK", false)).
WithSkipExpiryCheck(parseBoolEnv("OIDC_SKIP_EXPIRY_CHECK", false)).
WithServerURL(serverURL).
WithJWTSecret([]byte(jwtSecret)).
Build()
}

// parseBoolEnv parses a boolean environment variable
func parseBoolEnv(key string, defaultVal bool) bool {
val := getEnv(key, "")
if val == "" {
return defaultVal
}
parsed, err := strconv.ParseBool(val)
if err != nil {
return defaultVal
}
return parsed
}
20 changes: 11 additions & 9 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ type Logger interface {

// Config holds OAuth configuration (subset needed by provider)
type Config struct {
Provider string
Issuer string
Audience string
JWTSecret []byte
Logger Logger
Provider string
Issuer string
Audience string
JWTSecret []byte
Logger Logger
SkipIssuerCheck bool
SkipAudienceCheck bool
SkipExpiryCheck bool
}

// TokenValidator interface for OAuth token validation
Expand Down Expand Up @@ -90,7 +93,6 @@ func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (
}
return []byte(v.secret), nil
})

if err != nil {
return nil, fmt.Errorf("failed to parse and validate token: %w", err)
}
Expand Down Expand Up @@ -204,9 +206,9 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
verifier := provider.Verifier(&oidc.Config{
ClientID: cfg.Audience, // Note: go-oidc uses ClientID field for audience validation - see https://github.com/coreos/go-oidc/blob/v3/oidc/verify.go#L85
SupportedSigningAlgs: []string{oidc.RS256, oidc.ES256},
SkipClientIDCheck: false, // Always validate if ClientID is provided
SkipExpiryCheck: false, // Verify expiration
SkipIssuerCheck: false, // Verify issuer
SkipClientIDCheck: cfg.SkipAudienceCheck,
SkipExpiryCheck: cfg.SkipExpiryCheck,
SkipIssuerCheck: cfg.SkipIssuerCheck,
})

v.logger.Info("OAuth: OIDC validator initialized with audience validation: %s", cfg.Audience)
Expand Down