Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ func API(application *application.Application) (*echo.Echo, error) {
if application.AuthDB() != nil {
e.Use(auth.RequireRouteFeature(application.AuthDB()))
e.Use(auth.RequireModelAccess(application.AuthDB()))
e.Use(auth.RequireQuota(application.AuthDB()))
}

// CORS middleware
Expand Down
6 changes: 5 additions & 1 deletion core/http/auth/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ func InitDB(databaseURL string) (*gorm.DB, error) {
return nil, fmt.Errorf("failed to open auth database: %w", err)
}

if err := db.AutoMigrate(&User{}, &Session{}, &UserAPIKey{}, &UsageRecord{}, &UserPermission{}, &InviteCode{}); err != nil {
if err := db.AutoMigrate(&User{}, &Session{}, &UserAPIKey{}, &UsageRecord{}, &UserPermission{}, &InviteCode{}, &QuotaRule{}); err != nil {
return nil, fmt.Errorf("failed to migrate auth tables: %w", err)
}

// Backfill: users created before the provider column existed have an empty
// provider — treat them as local accounts so the UI can identify them.
db.Exec("UPDATE users SET provider = ? WHERE provider = '' OR provider IS NULL", ProviderLocal)

// Create composite index on users(provider, subject) for fast OAuth lookups
if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil {
// Ignore error on postgres if index already exists
Expand Down
50 changes: 50 additions & 0 deletions core/http/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto/subtle"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
Expand Down Expand Up @@ -370,6 +371,55 @@ func extractModelFromRequest(c echo.Context) string {
return ""
}

// RequireQuota returns a global middleware that enforces per-user quota rules.
// If no auth DB is provided, it's a no-op. Admin users always bypass quotas.
// Only inference routes (those listed in RouteFeatureRegistry) count toward quota.
func RequireQuota(db *gorm.DB) echo.MiddlewareFunc {
if db == nil {
return NoopMiddleware()
}
// Pre-build lookup set from RouteFeatureRegistry — only these routes
// should count toward quota. Mirrors RequireRouteFeature's approach.
inferenceRoutes := map[string]bool{}
for _, rf := range RouteFeatureRegistry {
inferenceRoutes[rf.Method+":"+rf.Pattern] = true
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Only enforce quotas on inference routes
path := c.Path()
method := c.Request().Method
if !inferenceRoutes[method+":"+path] && !inferenceRoutes["*:"+path] {
return next(c)
}

user := GetUser(c)
if user == nil {
return next(c)
}
if user.Role == RoleAdmin {
return next(c)
}

model := extractModelFromRequest(c)

exceeded, retryAfter, msg := QuotaExceeded(db, user.ID, model)
if exceeded {
c.Response().Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
return c.JSON(http.StatusTooManyRequests, schema.ErrorResponse{
Error: &schema.APIError{
Message: msg,
Code: http.StatusTooManyRequests,
Type: "quota_exceeded",
},
})
}

return next(c)
}
}
}

// tryAuthenticate attempts to authenticate the request using the database.
func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User {
hmacSecret := appConfig.Auth.APIKeyHMACSecret
Expand Down
Loading
Loading