Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/oauth/callback/callback.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package callback

import (
"encoding/json"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -198,6 +199,8 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if needsUserInfo {
sensitiveProps["email"] = userInfo.Email
sensitiveProps["name"] = userInfo.Name
infoJSON, _ := json.Marshal(userInfo)
sensitiveProps["info"] = string(infoJSON)
}

// Initialize props map
Expand Down
58 changes: 38 additions & 20 deletions pkg/oauth/validate/validatetoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ import (
)

type TokenValidator struct {
tokenManager *tokens.TokenManager
encryptionKey []byte
mcpUIManager *mcpui.Manager // Optional MCP UI manager for JWT handling
db TokenStore // Database for refresh operations
provider providers.Provider // OAuth provider for generating auth URLs
clientID string // OAuth client ID
clientSecret string // OAuth client secret
scopesSupported []string // Supported OAuth scopes
routePrefix string
tokenManager *tokens.TokenManager
encryptionKey []byte
mcpUIManager *mcpui.Manager // Optional MCP UI manager for JWT handling
db TokenStore // Database for refresh operations
provider providers.Provider // OAuth provider for generating auth URLs
clientID string // OAuth client ID
clientSecret string // OAuth client secret
scopesSupported []string // Supported OAuth scopes
routePrefix string
requiredAuthPaths []string
}

// TokenStore interface for database operations needed by validator
Expand All @@ -39,17 +40,18 @@ type TokenStore interface {
StoreAuthRequest(key string, data map[string]any) error
}

func NewTokenValidator(tokenManager *tokens.TokenManager, mcpUIManager *mcpui.Manager, encryptionKey []byte, db TokenStore, provider providers.Provider, clientID, clientSecret string, scopesSupported []string, routePrefix string) *TokenValidator {
func NewTokenValidator(tokenManager *tokens.TokenManager, mcpUIManager *mcpui.Manager, encryptionKey []byte, db TokenStore, provider providers.Provider, clientID, clientSecret string, scopesSupported []string, routePrefix string, requiredAuthPaths []string) *TokenValidator {
return &TokenValidator{
mcpUIManager: mcpUIManager,
tokenManager: tokenManager,
encryptionKey: encryptionKey,
db: db,
provider: provider,
clientID: clientID,
clientSecret: clientSecret,
scopesSupported: scopesSupported,
routePrefix: routePrefix,
mcpUIManager: mcpUIManager,
tokenManager: tokenManager,
encryptionKey: encryptionKey,
db: db,
provider: provider,
clientID: clientID,
clientSecret: clientSecret,
scopesSupported: scopesSupported,
routePrefix: routePrefix,
requiredAuthPaths: requiredAuthPaths,
}
}

Expand Down Expand Up @@ -181,6 +183,21 @@ func (p *TokenValidator) setCookiesForRefresh(w http.ResponseWriter, r *http.Req
func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" && len(p.requiredAuthPaths) > 0 {
matches := false
for _, path := range p.requiredAuthPaths {
if strings.HasPrefix(r.URL.Path, path) {
matches = true
break
}
}
if !matches {
// Not a protected path, skip validation
next.ServeHTTP(w, r)
return
}
}

if authHeader == "" {
// Try cookie-based authentication with refresh capability
var bearerTokenFromCookie string
Expand Down Expand Up @@ -351,7 +368,8 @@ func (p *TokenValidator) handleOauthFlow(w http.ResponseWriter, r *http.Request)
}

func GetTokenInfo(r *http.Request) *tokens.TokenInfo {
return r.Context().Value(tokenInfoKey{}).(*tokens.TokenInfo)
v, _ := r.Context().Value(tokenInfoKey{}).(*tokens.TokenInfo)
return v
}

type tokenInfoKey struct{}
Expand Down
44 changes: 23 additions & 21 deletions pkg/providers/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,24 +175,25 @@ func (p *GenericProvider) GetUserInfo(ctx context.Context, accessToken string) (
return nil, fmt.Errorf("failed to decode user info response: %w", err)
}

var userInfo *UserInfo
if p.metadata.UserinfoEndpoint == "https://api.github.com/user" {
userInfo = &UserInfo{
ID: getString(userInfoResp, "login"),
Email: getString(userInfoResp, "email"),
Name: getString(userInfoResp, "name"),
}
} else {
userInfo = &UserInfo{
ID: getString(userInfoResp, "sub"),
Email: getString(userInfoResp, "email"),
Name: getString(userInfoResp, "name"),
}
userInfo := &UserInfo{
ID: getString(userInfoResp, "id"),
Sub: getString(userInfoResp, "sub"),
Login: getString(userInfoResp, "login"),
Email: getString(userInfoResp, "email"),
EmailVerified: getBool(userInfoResp, "email_verified"),
Name: getString(userInfoResp, "name"),
Picture: getString(userInfoResp, "picture"),
GivenName: getString(userInfoResp, "given_name"),
FamilyName: getString(userInfoResp, "family_name"),
Locale: getString(userInfoResp, "locale"),
}

// If sub is not available, try other common ID fields
if userInfo.ID == "" {
userInfo.ID = getString(userInfoResp, "id")
userInfo.ID = userInfo.Sub
}

if userInfo.ID == "" && p.metadata.UserinfoEndpoint == "https://api.github.com/user" {
userInfo.ID = userInfo.Login
}

return userInfo, nil
Expand Down Expand Up @@ -231,12 +232,13 @@ func (p *GenericProvider) GetName() string {
return "generic"
}

func getBool(m map[string]any, key string) bool {
b, _ := m[key].(bool)
return b
}

// Helper functions
func getString(m map[string]any, key string) string {
if val, ok := m[key]; ok {
if str, ok := val.(string); ok {
return str
}
}
return ""
str, _ := m[key].(string)
return str
}
13 changes: 10 additions & 3 deletions pkg/providers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@ import (

// UserInfo represents user information from OAuth provider
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
ID string `json:"id"`
Sub string `json:"sub"`
Login string `json:"login"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Name string `json:"name"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Picture string `json:"picture"`
Locale string `json:"locale"`
}

// TokenInfo represents token information from OAuth provider
Expand Down
44 changes: 21 additions & 23 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import (
"encoding/base64"
"fmt"
"log"
"maps"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"

"github.com/gorilla/handlers"
Expand Down Expand Up @@ -43,7 +43,6 @@ type OAuthProxy struct {
provider string
encryptionKey []byte
resourceName string
lock sync.Mutex
config *types.Config

ctx context.Context
Expand All @@ -53,6 +52,7 @@ type OAuthProxy struct {
const (
ModeProxy = "proxy"
ModeForwardAuth = "forward_auth"
Middleware = "middleware"
)

func NewOAuthProxy(config *types.Config) (*OAuthProxy, error) {
Expand Down Expand Up @@ -206,7 +206,7 @@ func (p *OAuthProxy) Start(ctx context.Context) error {
return nil
}

func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux, next http.Handler) {
provider, err := p.providers.GetProvider(p.provider)
if err != nil {
log.Fatalf("Failed to get provider: %v", err)
Expand All @@ -216,7 +216,7 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
tokenHandler := token.NewHandler(p.db)
callbackHandler := callback.NewHandler(p.db, provider, p.encryptionKey, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.config.RoutePrefix, p.mcpUIManager)
revokeHandler := revoke.NewHandler(p.db)
tokenValidator := validate.NewTokenValidator(p.tokenManager, p.mcpUIManager, p.encryptionKey, p.db, provider, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.metadata.ScopesSupported, p.config.RoutePrefix)
tokenValidator := validate.NewTokenValidator(p.tokenManager, p.mcpUIManager, p.encryptionKey, p.db, provider, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.metadata.ScopesSupported, p.config.RoutePrefix, p.config.RequiredAuthPaths)
successHandler := success.NewHandler()

// Get route prefix from config
Expand All @@ -239,13 +239,15 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET "+prefix+"/auth/mcp-ui/success", p.withCORS(p.withRateLimit(successHandler)))

// Protect everything else
mux.HandleFunc(prefix+"/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler))))
mux.HandleFunc(prefix+"/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.mcpProxyHandler(w, r, next)
})))))
}

// GetHandler returns an http.Handler for the OAuth proxy
func (p *OAuthProxy) GetHandler() http.Handler {
mux := http.NewServeMux()
p.SetupRoutes(mux)
p.SetupRoutes(mux, nil)

// Wrap with logging middleware
loggedHandler := handlers.LoggingHandler(os.Stdout, mux)
Expand Down Expand Up @@ -335,20 +337,17 @@ func (p *OAuthProxy) protectedResourceMetadataHandler(w http.ResponseWriter, r *
handlerutils.JSON(w, http.StatusOK, metadata)
}

func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request, next http.Handler) {
tokenInfo := validate.GetTokenInfo(r)
path := r.PathValue("path")

// Check if the access token is expired and refresh if needed
if tokenInfo.Props != nil {
if tokenInfo != nil && tokenInfo.Props != nil {
if _, ok := tokenInfo.Props["access_token"].(string); ok {
// Check if token is expired (with a 5-minute buffer)
expiresAt, ok := tokenInfo.Props["expires_at"].(float64)
if ok && expiresAt > 0 {
if time.Now().Add(5 * time.Minute).After(time.Unix(int64(expiresAt), 0)) {
// when refreshing token, we need to lock the database to avoid race conditions
// otherwise we could get save the old access token into the database when another refresh process is running
p.lock.Lock()
log.Printf("Access token is expired or will expire soon, attempting to refresh")

// Get the refresh token
Expand All @@ -359,7 +358,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
"error": "invalid_token",
"error_description": "Access token expired and no refresh token available",
})
p.lock.Unlock()
return
}

Expand All @@ -371,7 +369,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
"error": "server_error",
"error_description": "Failed to refresh token",
})
p.lock.Unlock()
return
}

Expand All @@ -384,7 +381,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
"error": "server_error",
"error_description": "OAuth credentials not configured",
})
p.lock.Unlock()
return
}

Expand All @@ -396,7 +392,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
"error": "invalid_token",
"error_description": "Failed to refresh access token",
})
p.lock.Unlock()
return
}

Expand All @@ -407,21 +402,20 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
"error": "server_error",
"error_description": "Failed to update grant with new token",
})
p.lock.Unlock()
return
}

// Update the token info with the new access token for the current request
tokenInfo.Props["access_token"] = newTokenInfo.AccessToken
p.lock.Unlock()

log.Printf("Successfully refreshed access token")
}
}
}
}

switch p.config.Mode {
case Middleware:
next.ServeHTTP(w, r)
case ModeForwardAuth:
setHeaders(w.Header(), tokenInfo.Props)
case ModeProxy:
Expand Down Expand Up @@ -508,13 +502,17 @@ func (p *OAuthProxy) updateGrant(grantID, userID string, oldTokenInfo *tokens.To
return fmt.Errorf("failed to get grant: %w", err)
}

// Prepare sensitive props data
sensitiveProps := map[string]any{
"access_token": newTokenInfo.AccessToken,
"refresh_token": newTokenInfo.RefreshToken,
"expires_at": newTokenInfo.Expiry.Unix(),
sensitiveProps := map[string]any{}
if oldTokenInfo.Props != nil {
// keep all the old props, that include a lot of the user info
maps.Copy(sensitiveProps, oldTokenInfo.Props)
}

// Prepare sensitive props data
sensitiveProps["access_token"] = newTokenInfo.AccessToken
sensitiveProps["refresh_token"] = newTokenInfo.RefreshToken
sensitiveProps["expires_at"] = newTokenInfo.Expiry.Unix()

// Add existing user info if available
if grant.Props != nil {
if email, ok := grant.Props["email"].(string); ok {
Expand Down
8 changes: 7 additions & 1 deletion pkg/ratelimit/ratelimiter.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package ratelimit

import "time"
import (
"sync"
"time"
)

// RateLimiter simple in-memory rate limiter
type RateLimiter struct {
requests map[string][]time.Time
lock sync.Mutex
window time.Duration
max int
}
Expand All @@ -18,6 +22,8 @@ func NewRateLimiter(window time.Duration, max int) *RateLimiter {
}

func (rl *RateLimiter) Allow(key string) bool {
rl.lock.Lock()
defer rl.lock.Unlock()
now := time.Now()
windowStart := now.Add(-rl.window)

Expand Down
1 change: 1 addition & 0 deletions pkg/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Config struct {
MCPServerURL string
Mode string
RoutePrefix string
RequiredAuthPaths []string
}

// TokenData represents stored token data for OAuth 2.1 compliance
Expand Down
Loading