diff --git a/Dockerfile b/Dockerfile index 17c783ec3ae7..4318398193b5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -256,7 +256,7 @@ RUN apt-get update && \ FROM build-requirements AS builder-base -ARG GO_TAGS="" +ARG GO_TAGS="auth" ARG GRPC_BACKENDS ARG MAKEFLAGS ARG LD_FLAGS="-s -w" diff --git a/core/application/application.go b/core/application/application.go index f1adc71449ed..c636be38f137 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -11,6 +11,7 @@ import ( "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" + "gorm.io/gorm" ) type Application struct { @@ -22,6 +23,7 @@ type Application struct { galleryService *services.GalleryService agentJobService *services.AgentJobService agentPoolService atomic.Pointer[services.AgentPoolService] + authDB *gorm.DB watchdogMutex sync.Mutex watchdogStop chan bool p2pMutex sync.Mutex @@ -74,6 +76,11 @@ func (a *Application) AgentPoolService() *services.AgentPoolService { return a.agentPoolService.Load() } +// AuthDB returns the auth database connection, or nil if auth is not enabled. +func (a *Application) AuthDB() *gorm.DB { + return a.authDB +} + // StartupConfig returns the original startup configuration (from env vars, before file loading) func (a *Application) StartupConfig() *config.ApplicationConfig { return a.startupConfig @@ -118,9 +125,23 @@ func (a *Application) StartAgentPool() { xlog.Error("Failed to create agent pool service", "error", err) return } + if a.authDB != nil { + aps.SetAuthDB(a.authDB) + } if err := aps.Start(a.applicationConfig.Context); err != nil { xlog.Error("Failed to start agent pool", "error", err) return } + + // Wire per-user scoped services so collections, skills, and jobs are isolated per user + usm := services.NewUserServicesManager( + aps.UserStorage(), + a.applicationConfig, + a.modelLoader, + a.backendLoader, + a.templatesEvaluator, + ) + aps.SetUserServicesManager(usm) + a.agentPoolService.Store(aps) } diff --git a/core/application/config_file_watcher.go b/core/application/config_file_watcher.go index d9b1671ac8c7..094eebbbe69d 100644 --- a/core/application/config_file_watcher.go +++ b/core/application/config_file_watcher.go @@ -207,7 +207,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand envF16 := appConfig.F16 == startupAppConfig.F16 envDebug := appConfig.Debug == startupAppConfig.Debug envCORS := appConfig.CORS == startupAppConfig.CORS - envCSRF := appConfig.CSRF == startupAppConfig.CSRF + envCSRF := appConfig.DisableCSRF == startupAppConfig.DisableCSRF envCORSAllowOrigins := appConfig.CORSAllowOrigins == startupAppConfig.CORSAllowOrigins envP2PToken := appConfig.P2PToken == startupAppConfig.P2PToken envP2PNetworkID := appConfig.P2PNetworkID == startupAppConfig.P2PNetworkID @@ -313,7 +313,7 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand appConfig.CORS = *settings.CORS } if settings.CSRF != nil && !envCSRF { - appConfig.CSRF = *settings.CSRF + appConfig.DisableCSRF = *settings.CSRF } if settings.CORSAllowOrigins != nil && !envCORSAllowOrigins { appConfig.CORSAllowOrigins = *settings.CORSAllowOrigins diff --git a/core/application/startup.go b/core/application/startup.go index 0d69763b486e..17d3b234657a 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -1,6 +1,8 @@ package application import ( + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" "os" @@ -10,6 +12,7 @@ import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/services" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" @@ -81,6 +84,45 @@ func New(opts ...config.AppOption) (*Application, error) { } } + // Initialize auth database if auth is enabled + if options.Auth.Enabled { + // Auto-generate HMAC secret if not provided + if options.Auth.APIKeyHMACSecret == "" { + secretFile := filepath.Join(options.DataPath, ".hmac_secret") + secret, err := loadOrGenerateHMACSecret(secretFile) + if err != nil { + return nil, fmt.Errorf("failed to initialize HMAC secret: %w", err) + } + options.Auth.APIKeyHMACSecret = secret + } + + authDB, err := auth.InitDB(options.Auth.DatabaseURL) + if err != nil { + return nil, fmt.Errorf("failed to initialize auth database: %w", err) + } + application.authDB = authDB + xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL) + + // Start session and expired API key cleanup goroutine + go func() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + for { + select { + case <-options.Context.Done(): + return + case <-ticker.C: + if err := auth.CleanExpiredSessions(authDB); err != nil { + xlog.Error("failed to clean expired sessions", "error", err) + } + if err := auth.CleanExpiredAPIKeys(authDB); err != nil { + xlog.Error("failed to clean expired API keys", "error", err) + } + } + } + }() + } + if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { xlog.Error("error installing models", "error", err) } @@ -434,6 +476,31 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon } } +// loadOrGenerateHMACSecret loads an HMAC secret from the given file path, +// or generates a random 32-byte secret and persists it if the file doesn't exist. +func loadOrGenerateHMACSecret(path string) (string, error) { + data, err := os.ReadFile(path) + if err == nil { + secret := string(data) + if len(secret) >= 32 { + return secret, nil + } + } + + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate HMAC secret: %w", err) + } + secret := hex.EncodeToString(b) + + if err := os.WriteFile(path, []byte(secret), 0600); err != nil { + return "", fmt.Errorf("failed to persist HMAC secret: %w", err) + } + + xlog.Info("Generated new HMAC secret for API key hashing", "path", path) + return secret, nil +} + // migrateDataFiles moves persistent data files from the old config directory // to the new data directory. Only moves files that exist in src but not in dst. func migrateDataFiles(srcDir, dstDir string) { diff --git a/core/cli/run.go b/core/cli/run.go index 163797ac08aa..c614e123d05b 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -58,7 +58,7 @@ type RunCMD struct { Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` - CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"` + DisableCSRF bool `env:"LOCALAI_DISABLE_CSRF" help:"Disable CSRF middleware (enabled by default)" group:"api"` UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"` APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"` DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"` @@ -121,6 +121,21 @@ type RunCMD struct { AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"` AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"` + // Authentication + AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"` + AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"` + GitHubClientID string `env:"GITHUB_CLIENT_ID" help:"GitHub OAuth App Client ID (auto-enables auth when set)" group:"auth"` + GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET" help:"GitHub OAuth App Client Secret" group:"auth"` + OIDCIssuer string `env:"LOCALAI_OIDC_ISSUER" help:"OIDC issuer URL for auto-discovery" group:"auth"` + OIDCClientID string `env:"LOCALAI_OIDC_CLIENT_ID" help:"OIDC Client ID (auto-enables auth)" group:"auth"` + OIDCClientSecret string `env:"LOCALAI_OIDC_CLIENT_SECRET" help:"OIDC Client Secret" group:"auth"` + AuthBaseURL string `env:"LOCALAI_BASE_URL" help:"Base URL for OAuth callbacks (e.g. http://localhost:8080)" group:"auth"` + AuthAdminEmail string `env:"LOCALAI_ADMIN_EMAIL" help:"Email address to auto-promote to admin role" group:"auth"` + AuthRegistrationMode string `env:"LOCALAI_REGISTRATION_MODE" default:"open" help:"Registration mode: 'open' (default), 'approval', or 'invite' (invite code required)" group:"auth"` + DisableLocalAuth bool `env:"LOCALAI_DISABLE_LOCAL_AUTH" default:"false" help:"Disable local email/password registration and login (use with OAuth/OIDC-only setups)" group:"auth"` + AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"` + DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"` + Version bool } @@ -165,7 +180,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithBackendGalleries(r.BackendGalleries), config.WithCors(r.CORS), config.WithCorsAllowOrigins(r.CORSAllowOrigins), - config.WithCsrf(r.CSRF), + config.WithDisableCSRF(r.DisableCSRF), config.WithThreads(r.Threads), config.WithUploadLimitMB(r.UploadLimit), config.WithApiKeys(r.APIKeys), @@ -311,6 +326,46 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { opts = append(opts, config.WithAgentHubURL(r.AgentHubURL)) } + // Authentication + authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != "" + if authEnabled { + opts = append(opts, config.WithAuthEnabled(true)) + + dbURL := r.AuthDatabaseURL + if dbURL == "" { + dbURL = filepath.Join(r.DataPath, "database.db") + } + opts = append(opts, config.WithAuthDatabaseURL(dbURL)) + + if r.GitHubClientID != "" { + opts = append(opts, config.WithAuthGitHubClientID(r.GitHubClientID)) + opts = append(opts, config.WithAuthGitHubClientSecret(r.GitHubClientSecret)) + } + if r.OIDCClientID != "" { + opts = append(opts, config.WithAuthOIDCIssuer(r.OIDCIssuer)) + opts = append(opts, config.WithAuthOIDCClientID(r.OIDCClientID)) + opts = append(opts, config.WithAuthOIDCClientSecret(r.OIDCClientSecret)) + } + if r.AuthBaseURL != "" { + opts = append(opts, config.WithAuthBaseURL(r.AuthBaseURL)) + } + if r.AuthAdminEmail != "" { + opts = append(opts, config.WithAuthAdminEmail(r.AuthAdminEmail)) + } + if r.AuthRegistrationMode != "" { + opts = append(opts, config.WithAuthRegistrationMode(r.AuthRegistrationMode)) + } + if r.DisableLocalAuth { + opts = append(opts, config.WithAuthDisableLocalAuth(true)) + } + if r.AuthAPIKeyHMACSecret != "" { + opts = append(opts, config.WithAuthAPIKeyHMACSecret(r.AuthAPIKeyHMACSecret)) + } + if r.DefaultAPIKeyExpiry != "" { + opts = append(opts, config.WithAuthDefaultAPIKeyExpiry(r.DefaultAPIKeyExpiry)) + } + } + if idleWatchDog || busyWatchDog { opts = append(opts, config.EnableWatchDog) if idleWatchDog { diff --git a/core/config/application_config.go b/core/config/application_config.go index 74c3511a6594..9c1be82d980a 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -30,7 +30,7 @@ type ApplicationConfig struct { DynamicConfigsDir string DynamicConfigsDirPollInterval time.Duration CORS bool - CSRF bool + DisableCSRF bool PreloadJSONModels string PreloadModelsFromPath string CORSAllowOrigins string @@ -96,6 +96,26 @@ type ApplicationConfig struct { // Agent Pool (LocalAGI integration) AgentPool AgentPoolConfig + + // Authentication & Authorization + Auth AuthConfig +} + +// AuthConfig holds configuration for user authentication and authorization. +type AuthConfig struct { + Enabled bool + DatabaseURL string // "postgres://..." or file path for SQLite + GitHubClientID string + GitHubClientSecret string + OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com) + OIDCClientID string + OIDCClientSecret string + BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080") + AdminEmail string // auto-promote to admin on login + RegistrationMode string // "open", "approval" (default when empty), "invite" + DisableLocalAuth bool // disable local email/password registration and login + APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty + DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry } // AgentPoolConfig holds configuration for the LocalAGI agent pool integration. @@ -150,6 +170,8 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig { "/favicon.svg", "/readyz", "/healthz", + "/api/auth/", + "/assets/", }, } for _, oo := range o { @@ -194,9 +216,9 @@ func WithP2PNetworkID(s string) AppOption { } } -func WithCsrf(b bool) AppOption { +func WithDisableCSRF(b bool) AppOption { return func(o *ApplicationConfig) { - o.CSRF = b + o.DisableCSRF = b } } @@ -711,6 +733,86 @@ func WithAgentHubURL(url string) AppOption { } } +// Auth options + +func WithAuthEnabled(enabled bool) AppOption { + return func(o *ApplicationConfig) { + o.Auth.Enabled = enabled + } +} + +func WithAuthDatabaseURL(url string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.DatabaseURL = url + } +} + +func WithAuthGitHubClientID(clientID string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.GitHubClientID = clientID + } +} + +func WithAuthGitHubClientSecret(clientSecret string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.GitHubClientSecret = clientSecret + } +} + +func WithAuthBaseURL(baseURL string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.BaseURL = baseURL + } +} + +func WithAuthAdminEmail(email string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.AdminEmail = email + } +} + +func WithAuthRegistrationMode(mode string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.RegistrationMode = mode + } +} + +func WithAuthDisableLocalAuth(disable bool) AppOption { + return func(o *ApplicationConfig) { + o.Auth.DisableLocalAuth = disable + } +} + +func WithAuthOIDCIssuer(issuer string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.OIDCIssuer = issuer + } +} + +func WithAuthOIDCClientID(clientID string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.OIDCClientID = clientID + } +} + +func WithAuthOIDCClientSecret(clientSecret string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.OIDCClientSecret = clientSecret + } +} + +func WithAuthAPIKeyHMACSecret(secret string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.APIKeyHMACSecret = secret + } +} + +func WithAuthDefaultAPIKeyExpiry(expiry string) AppOption { + return func(o *ApplicationConfig) { + o.Auth.DefaultAPIKeyExpiry = expiry + } +} + // ToConfigLoaderOptions returns a slice of ConfigLoader Option. // Some options defined at the application level are going to be passed as defaults for // all the configuration for the models. @@ -750,7 +852,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { enableTracing := o.EnableTracing enableBackendLogging := o.EnableBackendLogging cors := o.CORS - csrf := o.CSRF + csrf := o.DisableCSRF corsAllowOrigins := o.CORSAllowOrigins p2pToken := o.P2PToken p2pNetworkID := o.P2PNetworkID @@ -958,7 +1060,7 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req o.CORS = *settings.CORS } if settings.CSRF != nil { - o.CSRF = *settings.CSRF + o.DisableCSRF = *settings.CSRF } if settings.CORSAllowOrigins != nil { o.CORSAllowOrigins = *settings.CORSAllowOrigins diff --git a/core/config/application_config_test.go b/core/config/application_config_test.go index c6d4fbecd6bc..c6c6c15b98a1 100644 --- a/core/config/application_config_test.go +++ b/core/config/application_config_test.go @@ -26,7 +26,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { F16: true, Debug: true, CORS: true, - CSRF: true, + DisableCSRF: true, CORSAllowOrigins: "https://example.com", P2PToken: "test-token", P2PNetworkID: "test-network", @@ -377,7 +377,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.CORS).To(BeTrue()) - Expect(appConfig.CSRF).To(BeTrue()) + Expect(appConfig.DisableCSRF).To(BeTrue()) Expect(appConfig.CORSAllowOrigins).To(Equal("https://example.com,https://other.com")) }) @@ -463,7 +463,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { F16: true, Debug: false, CORS: true, - CSRF: false, + DisableCSRF: false, CORSAllowOrigins: "https://test.com", P2PToken: "round-trip-token", P2PNetworkID: "round-trip-network", @@ -495,7 +495,7 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { Expect(target.F16).To(Equal(original.F16)) Expect(target.Debug).To(Equal(original.Debug)) Expect(target.CORS).To(Equal(original.CORS)) - Expect(target.CSRF).To(Equal(original.CSRF)) + Expect(target.DisableCSRF).To(Equal(original.DisableCSRF)) Expect(target.CORSAllowOrigins).To(Equal(original.CORSAllowOrigins)) Expect(target.P2PToken).To(Equal(original.P2PToken)) Expect(target.P2PNetworkID).To(Equal(original.P2PNetworkID)) diff --git a/core/http/app.go b/core/http/app.go index 138515fb7edc..e2da479b5442 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -14,6 +14,7 @@ import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/http/endpoints/localai" httpMiddleware "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" @@ -170,11 +171,9 @@ func API(application *application.Application) (*echo.Echo, error) { // Health Checks should always be exempt from auth, so register these first routes.HealthRoutes(e) - // Get key auth middleware - keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig()) - if err != nil { - return nil, fmt.Errorf("failed to create key auth config: %w", err) - } + // Build auth middleware: use the new auth.Middleware when auth is enabled or + // as a unified replacement for the legacy key-auth middleware. + authMiddleware := auth.Middleware(application.AuthDB(), application.ApplicationConfig()) // Favicon handler e.GET("/favicon.svg", func(c echo.Context) error { @@ -209,8 +208,20 @@ func API(application *application.Application) (*echo.Echo, error) { e.Static("/generated-videos", videoPath) } - // Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration - e.Use(keyAuthMiddleware) + // Initialize usage recording when auth DB is available + if application.AuthDB() != nil { + httpMiddleware.InitUsageRecorder(application.AuthDB()) + } + + // Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is + // the role of the exempt-path logic inside the middleware. + e.Use(authMiddleware) + + // Feature and model access control (after auth middleware, before routes) + if application.AuthDB() != nil { + e.Use(auth.RequireRouteFeature(application.AuthDB())) + e.Use(auth.RequireModelAccess(application.AuthDB())) + } // CORS middleware if application.ApplicationConfig().CORS { @@ -223,14 +234,63 @@ func API(application *application.Application) (*echo.Echo, error) { e.Use(middleware.CORS()) } - // CSRF middleware - if application.ApplicationConfig().CSRF { - xlog.Debug("Enabling CSRF middleware. Tokens are now required for state-modifying requests") - e.Use(middleware.CSRF()) + // CSRF middleware (enabled by default, disable with LOCALAI_DISABLE_CSRF=true) + // + // Protection relies on Echo's Sec-Fetch-Site header check (supported by all + // modern browsers). The legacy cookie+token approach is removed because + // Echo's Sec-Fetch-Site short-circuit never sets the cookie, so the frontend + // could never read a token to send back. + if !application.ApplicationConfig().DisableCSRF { + xlog.Debug("Enabling CSRF middleware (Sec-Fetch-Site mode)") + e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ + Skipper: func(c echo.Context) bool { + // Skip CSRF for API clients using auth headers (may be cross-origin) + if c.Request().Header.Get("Authorization") != "" { + return true + } + if c.Request().Header.Get("x-api-key") != "" || c.Request().Header.Get("xi-api-key") != "" { + return true + } + // Skip when Sec-Fetch-Site header is absent (older browsers, reverse + // proxies that strip the header). The SameSite=Lax cookie attribute + // provides baseline CSRF protection for these clients. + if c.Request().Header.Get("Sec-Fetch-Site") == "" { + return true + } + return false + }, + // Allow same-site requests (subdomains / different ports) in addition + // to same-origin which Echo already permits by default. + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + secFetchSite := c.Request().Header.Get("Sec-Fetch-Site") + if secFetchSite == "same-site" { + return true, nil + } + // cross-site: block + return false, nil + }, + })) } + // Admin middleware: enforces admin role when auth is enabled, no-op otherwise + var adminMiddleware echo.MiddlewareFunc + if application.AuthDB() != nil { + adminMiddleware = auth.RequireAdmin() + } else { + adminMiddleware = auth.NoopMiddleware() + } + + // Feature middlewares: per-feature access control + agentsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureAgents) + skillsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureSkills) + collectionsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureCollections) + mcpJobsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCPJobs) + requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + // Register auth routes (login, callback, API keys, user management) + routes.RegisterAuthRoutes(e, application) + routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) // Create opcache for tracking UI operations (used by both UI and LocalAI routes) @@ -239,14 +299,15 @@ func API(application *application.Application) (*echo.Echo, error) { opcache = services.NewOpCache(application.GalleryService()) } - routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application) - routes.RegisterAgentPoolRoutes(e, application) + mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP) + routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw) + routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw) routes.RegisterOpenAIRoutes(e, requestExtractor, application) routes.RegisterAnthropicRoutes(e, requestExtractor, application) routes.RegisterOpenResponsesRoutes(e, requestExtractor, application) if !application.ApplicationConfig().DisableWebUI { - routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application) - routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService()) + routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware) + routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware) // Serve React SPA from / with SPA fallback via 404 handler reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist") diff --git a/core/http/app_test.go b/core/http/app_test.go index a2742aa8ee04..903aae17b19a 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -428,8 +428,10 @@ var _ = Describe("API test", func() { "X-Forwarded-Prefix": {"/myprefix/"}, }) Expect(err).To(BeNil(), "error") - Expect(sc).To(Equal(401), "status code") + Expect(sc).To(Equal(200), "status code") + // Non-API paths pass through to the React SPA (which handles login client-side) Expect(string(body)).To(ContainSubstring(``), "body") + Expect(string(body)).To(ContainSubstring(`
`), "should serve React SPA") }) It("Should support reverse-proxy when authenticated", func() { diff --git a/core/http/auth/apikeys.go b/core/http/auth/apikeys.go new file mode 100644 index 000000000000..295030f2cad6 --- /dev/null +++ b/core/http/auth/apikeys.go @@ -0,0 +1,121 @@ +package auth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +const ( + apiKeyPrefix = "lai-" + apiKeyRandBytes = 32 // 32 bytes = 64 hex chars + keyPrefixLen = 8 // display prefix length (from the random part) +) + +// GenerateAPIKey generates a new API key. Returns the plaintext key, +// its HMAC-SHA256 hash, and a display prefix. +func GenerateAPIKey(hmacSecret string) (plaintext, hash, prefix string, err error) { + b := make([]byte, apiKeyRandBytes) + if _, err := rand.Read(b); err != nil { + return "", "", "", fmt.Errorf("failed to generate API key: %w", err) + } + + randHex := hex.EncodeToString(b) + plaintext = apiKeyPrefix + randHex + hash = HashAPIKey(plaintext, hmacSecret) + prefix = plaintext[:len(apiKeyPrefix)+keyPrefixLen] + + return plaintext, hash, prefix, nil +} + +// HashAPIKey returns the HMAC-SHA256 hex digest of the given plaintext key. +// If hmacSecret is empty, falls back to plain SHA-256 for backward compatibility. +func HashAPIKey(plaintext, hmacSecret string) string { + if hmacSecret == "" { + h := sha256.Sum256([]byte(plaintext)) + return hex.EncodeToString(h[:]) + } + mac := hmac.New(sha256.New, []byte(hmacSecret)) + mac.Write([]byte(plaintext)) + return hex.EncodeToString(mac.Sum(nil)) +} + +// CreateAPIKey generates and stores a new API key for the given user. +// Returns the plaintext key (shown once) and the database record. +func CreateAPIKey(db *gorm.DB, userID, name, role, hmacSecret string, expiresAt *time.Time) (string, *UserAPIKey, error) { + plaintext, hash, prefix, err := GenerateAPIKey(hmacSecret) + if err != nil { + return "", nil, err + } + + record := &UserAPIKey{ + ID: uuid.New().String(), + UserID: userID, + Name: name, + KeyHash: hash, + KeyPrefix: prefix, + Role: role, + ExpiresAt: expiresAt, + } + + if err := db.Create(record).Error; err != nil { + return "", nil, fmt.Errorf("failed to store API key: %w", err) + } + + return plaintext, record, nil +} + +// ValidateAPIKey looks up an API key by hashing the plaintext and searching +// the database. Returns the key record if found, or an error. +// Updates LastUsed on successful validation. +func ValidateAPIKey(db *gorm.DB, plaintext, hmacSecret string) (*UserAPIKey, error) { + hash := HashAPIKey(plaintext, hmacSecret) + + var key UserAPIKey + if err := db.Preload("User").Where("key_hash = ?", hash).First(&key).Error; err != nil { + return nil, fmt.Errorf("invalid API key") + } + + if key.ExpiresAt != nil && time.Now().After(*key.ExpiresAt) { + return nil, fmt.Errorf("API key expired") + } + + if key.User.Status != StatusActive { + return nil, fmt.Errorf("user account is not active") + } + + // Update LastUsed + now := time.Now() + db.Model(&key).Update("last_used", now) + + return &key, nil +} + +// ListAPIKeys returns all API keys for the given user (without plaintext). +func ListAPIKeys(db *gorm.DB, userID string) ([]UserAPIKey, error) { + var keys []UserAPIKey + if err := db.Where("user_id = ?", userID).Order("created_at DESC").Find(&keys).Error; err != nil { + return nil, err + } + return keys, nil +} + +// RevokeAPIKey deletes an API key. Only the owner can revoke their own key. +func RevokeAPIKey(db *gorm.DB, keyID, userID string) error { + result := db.Where("id = ? AND user_id = ?", keyID, userID).Delete(&UserAPIKey{}) + if result.RowsAffected == 0 { + return fmt.Errorf("API key not found or not owned by user") + } + return result.Error +} + +// CleanExpiredAPIKeys removes all API keys that have passed their expiry time. +func CleanExpiredAPIKeys(db *gorm.DB) error { + return db.Where("expires_at IS NOT NULL AND expires_at < ?", time.Now()).Delete(&UserAPIKey{}).Error +} diff --git a/core/http/auth/apikeys_test.go b/core/http/auth/apikeys_test.go new file mode 100644 index 000000000000..4441af07af7d --- /dev/null +++ b/core/http/auth/apikeys_test.go @@ -0,0 +1,212 @@ +//go:build auth + +package auth_test + +import ( + "strings" + + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("API Keys", func() { + var ( + db *gorm.DB + user *auth.User + ) + + // Use empty HMAC secret for tests (falls back to plain SHA-256) + hmacSecret := "" + + BeforeEach(func() { + db = testDB() + user = createTestUser(db, "apikey@example.com", auth.RoleUser, auth.ProviderGitHub) + }) + + Describe("GenerateAPIKey", func() { + It("returns key with 'lai-' prefix", func() { + plaintext, _, _, err := auth.GenerateAPIKey(hmacSecret) + Expect(err).ToNot(HaveOccurred()) + Expect(plaintext).To(HavePrefix("lai-")) + }) + + It("returns consistent hash for same plaintext", func() { + plaintext, hash, _, err := auth.GenerateAPIKey(hmacSecret) + Expect(err).ToNot(HaveOccurred()) + Expect(auth.HashAPIKey(plaintext, hmacSecret)).To(Equal(hash)) + }) + + It("returns prefix for display", func() { + _, _, prefix, err := auth.GenerateAPIKey(hmacSecret) + Expect(err).ToNot(HaveOccurred()) + Expect(prefix).To(HavePrefix("lai-")) + Expect(len(prefix)).To(Equal(12)) // "lai-" + 8 chars + }) + + It("generates unique keys", func() { + key1, _, _, _ := auth.GenerateAPIKey(hmacSecret) + key2, _, _, _ := auth.GenerateAPIKey(hmacSecret) + Expect(key1).ToNot(Equal(key2)) + }) + }) + + Describe("CreateAPIKey", func() { + It("stores hashed key in DB", func() { + plaintext, record, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser, hmacSecret, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(plaintext).To(HavePrefix("lai-")) + Expect(record.KeyHash).To(Equal(auth.HashAPIKey(plaintext, hmacSecret))) + }) + + It("does not store plaintext in DB", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser, hmacSecret, nil) + Expect(err).ToNot(HaveOccurred()) + + var keys []auth.UserAPIKey + db.Find(&keys) + for _, k := range keys { + Expect(k.KeyHash).ToNot(Equal(plaintext)) + Expect(strings.Contains(k.KeyHash, "lai-")).To(BeFalse()) + } + }) + + It("inherits role from parameter", func() { + _, record, err := auth.CreateAPIKey(db, user.ID, "admin key", auth.RoleAdmin, hmacSecret, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(record.Role).To(Equal(auth.RoleAdmin)) + }) + }) + + Describe("ValidateAPIKey", func() { + It("returns UserAPIKey for valid key", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "valid key", auth.RoleUser, hmacSecret, nil) + Expect(err).ToNot(HaveOccurred()) + + found, err := auth.ValidateAPIKey(db, plaintext, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + Expect(found).ToNot(BeNil()) + Expect(found.UserID).To(Equal(user.ID)) + }) + + It("returns error for invalid key", func() { + _, err := auth.ValidateAPIKey(db, "lai-invalidkey12345678901234567890", hmacSecret) + Expect(err).To(HaveOccurred()) + }) + + It("updates LastUsed timestamp", func() { + plaintext, record, err := auth.CreateAPIKey(db, user.ID, "used key", auth.RoleUser, hmacSecret, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(record.LastUsed).To(BeNil()) + + _, err = auth.ValidateAPIKey(db, plaintext, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + var updated auth.UserAPIKey + db.First(&updated, "id = ?", record.ID) + Expect(updated.LastUsed).ToNot(BeNil()) + }) + + It("loads associated user", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "with user", auth.RoleUser, hmacSecret, nil) + Expect(err).ToNot(HaveOccurred()) + + found, err := auth.ValidateAPIKey(db, plaintext, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + Expect(found.User.ID).To(Equal(user.ID)) + Expect(found.User.Email).To(Equal("apikey@example.com")) + }) + }) + + Describe("ListAPIKeys", func() { + It("returns all keys for the user", func() { + auth.CreateAPIKey(db, user.ID, "key1", auth.RoleUser, hmacSecret, nil) + auth.CreateAPIKey(db, user.ID, "key2", auth.RoleUser, hmacSecret, nil) + + keys, err := auth.ListAPIKeys(db, user.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(keys).To(HaveLen(2)) + }) + + It("does not return other users' keys", func() { + other := createTestUser(db, "other@example.com", auth.RoleUser, auth.ProviderGitHub) + auth.CreateAPIKey(db, user.ID, "my key", auth.RoleUser, hmacSecret, nil) + auth.CreateAPIKey(db, other.ID, "other key", auth.RoleUser, hmacSecret, nil) + + keys, err := auth.ListAPIKeys(db, user.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(keys).To(HaveLen(1)) + Expect(keys[0].Name).To(Equal("my key")) + }) + }) + + Context("with HMAC secret", func() { + hmacSecretVal := "test-hmac-secret-456" + + It("generates different hash than empty secret", func() { + plaintext, _, _, err := auth.GenerateAPIKey("") + Expect(err).ToNot(HaveOccurred()) + + hashEmpty := auth.HashAPIKey(plaintext, "") + hashHMAC := auth.HashAPIKey(plaintext, hmacSecretVal) + Expect(hashEmpty).ToNot(Equal(hashHMAC)) + }) + + It("round-trips CreateAPIKey and ValidateAPIKey with HMAC secret", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "hmac key", auth.RoleUser, hmacSecretVal, nil) + Expect(err).ToNot(HaveOccurred()) + + found, err := auth.ValidateAPIKey(db, plaintext, hmacSecretVal) + Expect(err).ToNot(HaveOccurred()) + Expect(found).ToNot(BeNil()) + Expect(found.UserID).To(Equal(user.ID)) + }) + + It("does not validate with wrong HMAC secret", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "hmac key2", auth.RoleUser, hmacSecretVal, nil) + Expect(err).ToNot(HaveOccurred()) + + _, err = auth.ValidateAPIKey(db, plaintext, "wrong-secret") + Expect(err).To(HaveOccurred()) + }) + + It("does not validate key created with empty secret using non-empty secret", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "empty-secret key", auth.RoleUser, "", nil) + Expect(err).ToNot(HaveOccurred()) + + _, err = auth.ValidateAPIKey(db, plaintext, hmacSecretVal) + Expect(err).To(HaveOccurred()) + }) + + It("does not validate key created with non-empty secret using empty secret", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "nonempty-secret key", auth.RoleUser, hmacSecretVal, nil) + Expect(err).ToNot(HaveOccurred()) + + _, err = auth.ValidateAPIKey(db, plaintext, "") + Expect(err).To(HaveOccurred()) + }) + }) + + Describe("RevokeAPIKey", func() { + It("deletes the key record", func() { + plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser, hmacSecret, nil) + Expect(err).ToNot(HaveOccurred()) + + err = auth.RevokeAPIKey(db, record.ID, user.ID) + Expect(err).ToNot(HaveOccurred()) + + _, err = auth.ValidateAPIKey(db, plaintext, hmacSecret) + Expect(err).To(HaveOccurred()) + }) + + It("only allows owner to revoke their own key", func() { + _, record, err := auth.CreateAPIKey(db, user.ID, "mine", auth.RoleUser, hmacSecret, nil) + Expect(err).ToNot(HaveOccurred()) + + other := createTestUser(db, "attacker@example.com", auth.RoleUser, auth.ProviderGitHub) + err = auth.RevokeAPIKey(db, record.ID, other.ID) + Expect(err).To(HaveOccurred()) + }) + }) +}) diff --git a/core/http/auth/auth_suite_test.go b/core/http/auth/auth_suite_test.go new file mode 100644 index 000000000000..c32c18ed19df --- /dev/null +++ b/core/http/auth/auth_suite_test.go @@ -0,0 +1,15 @@ +//go:build auth + +package auth_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestAuth(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Auth Suite") +} diff --git a/core/http/auth/db.go b/core/http/auth/db.go new file mode 100644 index 000000000000..f3b2f0d3866f --- /dev/null +++ b/core/http/auth/db.go @@ -0,0 +1,49 @@ +package auth + +import ( + "fmt" + "strings" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// InitDB initializes the auth database. If databaseURL starts with "postgres://" +// or "postgresql://", it connects to PostgreSQL; otherwise it treats the value +// as a SQLite file path (use ":memory:" for in-memory). +// SQLite support requires building with the "auth" build tag (CGO). +func InitDB(databaseURL string) (*gorm.DB, error) { + var dialector gorm.Dialector + + if strings.HasPrefix(databaseURL, "postgres://") || strings.HasPrefix(databaseURL, "postgresql://") { + dialector = postgres.Open(databaseURL) + } else { + d, err := openSQLiteDialector(databaseURL) + if err != nil { + return nil, err + } + dialector = d + } + + db, err := gorm.Open(dialector, &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + return nil, fmt.Errorf("failed to open auth database: %w", err) + } + + if err := db.AutoMigrate(&User{}, &Session{}, &UserAPIKey{}, &UsageRecord{}, &UserPermission{}, &InviteCode{}); err != nil { + return nil, fmt.Errorf("failed to migrate auth tables: %w", err) + } + + // Create composite index on users(provider, subject) for fast OAuth lookups + if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil { + // Ignore error on postgres if index already exists + if !strings.Contains(err.Error(), "already exists") { + return nil, fmt.Errorf("failed to create composite index: %w", err) + } + } + + return db, nil +} diff --git a/core/http/auth/db_nosqlite.go b/core/http/auth/db_nosqlite.go new file mode 100644 index 000000000000..73233bf4543c --- /dev/null +++ b/core/http/auth/db_nosqlite.go @@ -0,0 +1,13 @@ +//go:build !auth + +package auth + +import ( + "fmt" + + "gorm.io/gorm" +) + +func openSQLiteDialector(path string) (gorm.Dialector, error) { + return nil, fmt.Errorf("SQLite auth database requires building with -tags auth (CGO); use DATABASE_URL with PostgreSQL instead") +} diff --git a/core/http/auth/db_sqlite.go b/core/http/auth/db_sqlite.go new file mode 100644 index 000000000000..5c13ecf05cc4 --- /dev/null +++ b/core/http/auth/db_sqlite.go @@ -0,0 +1,12 @@ +//go:build auth + +package auth + +import ( + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func openSQLiteDialector(path string) (gorm.Dialector, error) { + return sqlite.Open(path), nil +} diff --git a/core/http/auth/db_test.go b/core/http/auth/db_test.go new file mode 100644 index 000000000000..6d603d85da68 --- /dev/null +++ b/core/http/auth/db_test.go @@ -0,0 +1,53 @@ +//go:build auth + +package auth_test + +import ( + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("InitDB", func() { + Context("SQLite", func() { + It("creates all tables with in-memory SQLite", func() { + db, err := auth.InitDB(":memory:") + Expect(err).ToNot(HaveOccurred()) + Expect(db).ToNot(BeNil()) + + // Verify tables exist + Expect(db.Migrator().HasTable(&auth.User{})).To(BeTrue()) + Expect(db.Migrator().HasTable(&auth.Session{})).To(BeTrue()) + Expect(db.Migrator().HasTable(&auth.UserAPIKey{})).To(BeTrue()) + }) + + It("is idempotent - running twice does not error", func() { + db, err := auth.InitDB(":memory:") + Expect(err).ToNot(HaveOccurred()) + + // Re-migrate on same DB should succeed + err = db.AutoMigrate(&auth.User{}, &auth.Session{}, &auth.UserAPIKey{}) + Expect(err).ToNot(HaveOccurred()) + }) + + It("creates composite index on users(provider, subject)", func() { + db, err := auth.InitDB(":memory:") + Expect(err).ToNot(HaveOccurred()) + + // Insert a user to verify the index doesn't prevent normal operations + user := &auth.User{ + ID: "test-1", + Provider: auth.ProviderGitHub, + Subject: "12345", + Role: "admin", + Status: auth.StatusActive, + } + Expect(db.Create(user).Error).ToNot(HaveOccurred()) + + // Query using the indexed columns should work + var found auth.User + Expect(db.Where("provider = ? AND subject = ?", auth.ProviderGitHub, "12345").First(&found).Error).ToNot(HaveOccurred()) + Expect(found.ID).To(Equal("test-1")) + }) + }) +}) diff --git a/core/http/auth/features.go b/core/http/auth/features.go new file mode 100644 index 000000000000..85b4e60ec982 --- /dev/null +++ b/core/http/auth/features.go @@ -0,0 +1,125 @@ +package auth + +// RouteFeature maps a route pattern + HTTP method to a required feature. +type RouteFeature struct { + Method string // "POST", "GET", "*" (any) + Pattern string // Echo route pattern, e.g. "/v1/chat/completions" + Feature string // Feature constant, e.g. FeatureChat +} + +// RouteFeatureRegistry is the single source of truth for endpoint -> feature mappings. +// To gate a new endpoint, add an entry here -- no other file changes needed. +var RouteFeatureRegistry = []RouteFeature{ + // Chat / Completions + {"POST", "/v1/chat/completions", FeatureChat}, + {"POST", "/chat/completions", FeatureChat}, + {"POST", "/v1/completions", FeatureChat}, + {"POST", "/completions", FeatureChat}, + {"POST", "/v1/engines/:model/completions", FeatureChat}, + {"POST", "/v1/edits", FeatureChat}, + {"POST", "/edits", FeatureChat}, + + // Anthropic + {"POST", "/v1/messages", FeatureChat}, + {"POST", "/messages", FeatureChat}, + + // Open Responses + {"POST", "/v1/responses", FeatureChat}, + {"POST", "/responses", FeatureChat}, + {"GET", "/v1/responses", FeatureChat}, + {"GET", "/responses", FeatureChat}, + + // Embeddings + {"POST", "/v1/embeddings", FeatureEmbeddings}, + {"POST", "/embeddings", FeatureEmbeddings}, + {"POST", "/v1/engines/:model/embeddings", FeatureEmbeddings}, + + // Images + {"POST", "/v1/images/generations", FeatureImages}, + {"POST", "/images/generations", FeatureImages}, + {"POST", "/v1/images/inpainting", FeatureImages}, + {"POST", "/images/inpainting", FeatureImages}, + + // Audio transcription + {"POST", "/v1/audio/transcriptions", FeatureAudioTranscription}, + {"POST", "/audio/transcriptions", FeatureAudioTranscription}, + + // Audio speech / TTS + {"POST", "/v1/audio/speech", FeatureAudioSpeech}, + {"POST", "/audio/speech", FeatureAudioSpeech}, + {"POST", "/tts", FeatureAudioSpeech}, + {"POST", "/v1/text-to-speech/:voice-id", FeatureAudioSpeech}, + + // VAD + {"POST", "/vad", FeatureVAD}, + {"POST", "/v1/vad", FeatureVAD}, + + // Detection + {"POST", "/v1/detection", FeatureDetection}, + + // Video + {"POST", "/video", FeatureVideo}, + + // Sound generation + {"POST", "/v1/sound-generation", FeatureSound}, + + // Realtime + {"GET", "/v1/realtime", FeatureRealtime}, + {"POST", "/v1/realtime/sessions", FeatureRealtime}, + {"POST", "/v1/realtime/transcription_session", FeatureRealtime}, + {"POST", "/v1/realtime/calls", FeatureRealtime}, + + // MCP + {"POST", "/v1/mcp/chat/completions", FeatureMCP}, + {"POST", "/mcp/v1/chat/completions", FeatureMCP}, + {"POST", "/mcp/chat/completions", FeatureMCP}, + + // Tokenize + {"POST", "/v1/tokenize", FeatureTokenize}, + + // Rerank + {"POST", "/v1/rerank", FeatureRerank}, + + // Stores + {"POST", "/stores/set", FeatureStores}, + {"POST", "/stores/delete", FeatureStores}, + {"POST", "/stores/get", FeatureStores}, + {"POST", "/stores/find", FeatureStores}, +} + +// FeatureMeta describes a feature for the admin API/UI. +type FeatureMeta struct { + Key string `json:"key"` + Label string `json:"label"` + DefaultValue bool `json:"default"` +} + +// AgentFeatureMetas returns metadata for agent features. +func AgentFeatureMetas() []FeatureMeta { + return []FeatureMeta{ + {FeatureAgents, "Agents", false}, + {FeatureSkills, "Skills", false}, + {FeatureCollections, "Collections", false}, + {FeatureMCPJobs, "MCP CI Jobs", false}, + } +} + +// APIFeatureMetas returns metadata for API endpoint features. +func APIFeatureMetas() []FeatureMeta { + return []FeatureMeta{ + {FeatureChat, "Chat Completions", true}, + {FeatureImages, "Image Generation", true}, + {FeatureAudioSpeech, "Audio Speech / TTS", true}, + {FeatureAudioTranscription, "Audio Transcription", true}, + {FeatureVAD, "Voice Activity Detection", true}, + {FeatureDetection, "Detection", true}, + {FeatureVideo, "Video Generation", true}, + {FeatureEmbeddings, "Embeddings", true}, + {FeatureSound, "Sound Generation", true}, + {FeatureRealtime, "Realtime", true}, + {FeatureRerank, "Rerank", true}, + {FeatureTokenize, "Tokenize", true}, + {FeatureMCP, "MCP", true}, + {FeatureStores, "Stores", true}, + } +} diff --git a/core/http/auth/helpers_test.go b/core/http/auth/helpers_test.go new file mode 100644 index 000000000000..1fcf9e449479 --- /dev/null +++ b/core/http/auth/helpers_test.go @@ -0,0 +1,155 @@ +//go:build auth + +package auth_test + +import ( + "net/http" + "net/http/httptest" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +// testDB creates an in-memory SQLite GORM instance with auto-migration. +func testDB() *gorm.DB { + db, err := auth.InitDB(":memory:") + Expect(err).ToNot(HaveOccurred()) + return db +} + +// createTestUser inserts a user directly into the DB for test setup. +func createTestUser(db *gorm.DB, email, role, provider string) *auth.User { + user := &auth.User{ + ID: generateTestID(), + Email: email, + Name: "Test User", + Provider: provider, + Subject: generateTestID(), + Role: role, + Status: auth.StatusActive, + } + err := db.Create(user).Error + Expect(err).ToNot(HaveOccurred()) + return user +} + +// createTestSession creates a session for a user, returns plaintext session token. +func createTestSession(db *gorm.DB, userID string) string { + sessionID, err := auth.CreateSession(db, userID, "") + Expect(err).ToNot(HaveOccurred()) + return sessionID +} + +var testIDCounter int + +func generateTestID() string { + testIDCounter++ + return "test-id-" + string(rune('a'+testIDCounter)) +} + +// ok is a simple handler that returns 200 OK. +func ok(c echo.Context) error { + return c.String(http.StatusOK, "ok") +} + +// newAuthTestApp creates a minimal Echo app with the new auth middleware. +func newAuthTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo { + e := echo.New() + e.Use(auth.Middleware(db, appConfig)) + + // API routes (require auth) + e.GET("/v1/models", ok) + e.POST("/v1/chat/completions", ok) + e.GET("/api/settings", ok) + e.POST("/api/settings", ok) + + // Auth routes (exempt) + e.GET("/api/auth/status", ok) + e.GET("/api/auth/github/login", ok) + + // Static routes + e.GET("/app", ok) + e.GET("/app/*", ok) + + return e +} + +// newAdminTestApp creates an Echo app with admin-protected routes. +func newAdminTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo { + e := echo.New() + e.Use(auth.Middleware(db, appConfig)) + + // Regular routes + e.GET("/v1/models", ok) + e.POST("/v1/chat/completions", ok) + + // Admin-only routes + adminMw := auth.RequireAdmin() + e.POST("/api/settings", ok, adminMw) + e.POST("/models/apply", ok, adminMw) + e.POST("/backends/apply", ok, adminMw) + e.GET("/api/agents", ok, adminMw) + + // Trace/log endpoints (admin only) + e.GET("/api/traces", ok, adminMw) + e.POST("/api/traces/clear", ok, adminMw) + e.GET("/api/backend-logs", ok, adminMw) + e.GET("/api/backend-logs/:modelId", ok, adminMw) + + // Gallery/management reads (admin only) + e.GET("/api/operations", ok, adminMw) + e.GET("/api/models", ok, adminMw) + e.GET("/api/backends", ok, adminMw) + e.GET("/api/resources", ok, adminMw) + e.GET("/api/p2p/workers", ok, adminMw) + + // Agent task/job routes (admin only) + e.POST("/api/agent/tasks", ok, adminMw) + e.GET("/api/agent/tasks", ok, adminMw) + e.GET("/api/agent/jobs", ok, adminMw) + + // System info (admin only) + e.GET("/system", ok, adminMw) + e.GET("/backend/monitor", ok, adminMw) + + return e +} + +// doRequest performs an HTTP request against the given Echo app and returns the recorder. +func doRequest(e *echo.Echo, method, path string, opts ...func(*http.Request)) *httptest.ResponseRecorder { + req := httptest.NewRequest(method, path, nil) + req.Header.Set("Content-Type", "application/json") + for _, opt := range opts { + opt(req) + } + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec +} + +func withBearerToken(token string) func(*http.Request) { + return func(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+token) + } +} + +func withXApiKey(key string) func(*http.Request) { + return func(req *http.Request) { + req.Header.Set("x-api-key", key) + } +} + +func withSessionCookie(sessionID string) func(*http.Request) { + return func(req *http.Request) { + req.AddCookie(&http.Cookie{Name: "session", Value: sessionID}) + } +} + +func withTokenCookie(token string) func(*http.Request) { + return func(req *http.Request) { + req.AddCookie(&http.Cookie{Name: "token", Value: token}) + } +} diff --git a/core/http/auth/middleware.go b/core/http/auth/middleware.go new file mode 100644 index 000000000000..9525b9f22a3a --- /dev/null +++ b/core/http/auth/middleware.go @@ -0,0 +1,522 @@ +package auth + +import ( + "bytes" + "crypto/subtle" + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "gorm.io/gorm" +) + +const ( + contextKeyUser = "auth_user" + contextKeyRole = "auth_role" +) + +// Middleware returns an Echo middleware that handles authentication. +// +// Resolution order: +// 1. If auth not enabled AND no legacy API keys → pass through +// 2. Skip auth for exempt paths (PathWithoutAuth + /api/auth/) +// 3. If auth enabled (db != nil): +// a. Try "session" cookie → DB lookup +// b. Try Authorization: Bearer → session ID, then user API key +// c. Try x-api-key / xi-api-key → user API key +// d. Try "token" cookie → legacy API key check +// e. Check all extracted keys against legacy ApiKeys → synthetic admin +// 4. If auth not enabled → delegate to legacy API key validation +// 5. If no auth found for /api/ or /v1/ paths → 401 +// 6. Otherwise pass through (static assets, UI pages, etc.) +func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + authEnabled := db != nil + hasLegacyKeys := len(appConfig.ApiKeys) > 0 + + // 1. No auth at all + if !authEnabled && !hasLegacyKeys { + return next(c) + } + + path := c.Request().URL.Path + exempt := isExemptPath(path, appConfig) + authenticated := false + + // 2. Try to authenticate (populates user in context if possible) + if authEnabled { + user := tryAuthenticate(c, db, appConfig) + if user != nil { + c.Set(contextKeyUser, user) + c.Set(contextKeyRole, user.Role) + authenticated = true + + // Session rotation for cookie-based sessions + if session, ok := c.Get("_auth_session").(*Session); ok { + MaybeRotateSession(c, db, session, appConfig.Auth.APIKeyHMACSecret) + } + } + } + + // 3. Legacy API key validation (works whether auth is enabled or not) + if !authenticated && hasLegacyKeys { + key := extractKey(c) + if key != "" && isValidLegacyKey(key, appConfig) { + syntheticUser := &User{ + ID: "legacy-api-key", + Name: "API Key User", + Role: RoleAdmin, + } + c.Set(contextKeyUser, syntheticUser) + c.Set(contextKeyRole, RoleAdmin) + authenticated = true + } + } + + // 4. If authenticated or exempt path, proceed + if authenticated || exempt { + return next(c) + } + + // 5. Require auth for API paths + if isAPIPath(path) { + // Check GET exemptions for legacy keys + if hasLegacyKeys && appConfig.DisableApiKeyRequirementForHttpGet && c.Request().Method == http.MethodGet { + for _, rx := range appConfig.HttpGetExemptedEndpoints { + if rx.MatchString(c.Path()) { + return next(c) + } + } + } + return authError(c, appConfig) + } + + // 6. Non-API paths (UI, static assets) pass through. + // The React UI handles login redirects client-side. + return next(c) + } + } +} + +// RequireAdmin returns middleware that checks the user has admin role. +func RequireAdmin() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + user := GetUser(c) + if user == nil { + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "Authentication required", + Code: http.StatusUnauthorized, + Type: "authentication_error", + }, + }) + } + if user.Role != RoleAdmin { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "Admin access required", + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + return next(c) + } + } +} + +// NoopMiddleware returns a middleware that does nothing (pass-through). +// Used when auth is disabled to satisfy route registration that expects +// an admin middleware parameter. +func NoopMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return next + } +} + +// RequireFeature returns middleware that checks the user has access to the given feature. +// If no auth DB is provided, it passes through (backward compat). +// Admins always pass. Regular users must have the feature enabled in their permissions. +func RequireFeature(db *gorm.DB, feature string) echo.MiddlewareFunc { + if db == nil { + return NoopMiddleware() + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + user := GetUser(c) + if user == nil { + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "Authentication required", + Code: http.StatusUnauthorized, + Type: "authentication_error", + }, + }) + } + if user.Role == RoleAdmin { + return next(c) + } + perm, err := GetCachedUserPermissions(c, db, user.ID) + if err != nil { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "feature not enabled for your account", + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + val, exists := perm.Permissions[feature] + if !exists { + if !isDefaultOnFeature(feature) { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "feature not enabled for your account", + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + } else if !val { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "feature not enabled for your account", + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + return next(c) + } + } +} + +// GetUser returns the authenticated user from the echo context, or nil. +func GetUser(c echo.Context) *User { + u, ok := c.Get(contextKeyUser).(*User) + if !ok { + return nil + } + return u +} + +// GetUserRole returns the role of the authenticated user, or empty string. +func GetUserRole(c echo.Context) string { + role, _ := c.Get(contextKeyRole).(string) + return role +} + +// RequireRouteFeature returns a global middleware that checks the user has access +// to the feature required by the matched route. It uses the RouteFeatureRegistry +// to look up the required feature for each route pattern + HTTP method. +// If no entry matches, the request passes through (no restriction). +func RequireRouteFeature(db *gorm.DB) echo.MiddlewareFunc { + if db == nil { + return NoopMiddleware() + } + // Pre-build lookup map: "METHOD:pattern" -> feature + lookup := map[string]string{} + for _, rf := range RouteFeatureRegistry { + lookup[rf.Method+":"+rf.Pattern] = rf.Feature + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + path := c.Path() // Echo route pattern (e.g. "/v1/engines/:model/completions") + method := c.Request().Method + feature := lookup[method+":"+path] + if feature == "" { + feature = lookup["*:"+path] + } + if feature == "" { + return next(c) // no restriction for this route + } + user := GetUser(c) + if user == nil { + return next(c) // auth middleware handles unauthenticated + } + if user.Role == RoleAdmin { + return next(c) + } + perm, err := GetCachedUserPermissions(c, db, user.ID) + if err != nil { + return c.JSON(http.StatusInternalServerError, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "failed to check permissions", + Code: http.StatusInternalServerError, + Type: "server_error", + }, + }) + } + val, exists := perm.Permissions[feature] + if !exists { + if !isDefaultOnFeature(feature) { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "feature not enabled for your account: " + feature, + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + } else if !val { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "feature not enabled for your account: " + feature, + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + return next(c) + } + } +} + +// RequireModelAccess returns a global middleware that checks the user is allowed +// to use the resolved model. It extracts the model name directly from the request +// (path param, query param, JSON body, or form value) rather than relying on a +// context key set by downstream route-specific middleware. +func RequireModelAccess(db *gorm.DB) echo.MiddlewareFunc { + if db == nil { + return NoopMiddleware() + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + user := GetUser(c) + if user == nil { + return next(c) + } + if user.Role == RoleAdmin { + return next(c) + } + + // Check if this user even has a model allowlist enabled before + // doing the expensive body read. Most users won't have restrictions. + // Uses request-scoped cache to avoid duplicate DB hit when + // RequireRouteFeature already fetched permissions. + perm, err := GetCachedUserPermissions(c, db, user.ID) + if err != nil { + return c.JSON(http.StatusInternalServerError, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "failed to check permissions", + Code: http.StatusInternalServerError, + Type: "server_error", + }, + }) + } + allowlist := perm.AllowedModels + if !allowlist.Enabled { + return next(c) + } + + modelName := extractModelFromRequest(c) + if modelName == "" { + return next(c) + } + + for _, m := range allowlist.Models { + if m == modelName { + return next(c) + } + } + + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "access denied to model: " + modelName, + Code: http.StatusForbidden, + Type: "authorization_error", + }, + }) + } + } +} + +// extractModelFromRequest extracts the model name from various request sources. +// It checks URL path params, query params, JSON body, and form values. +// For JSON bodies, it peeks at the body and resets it so downstream handlers +// can still read it. +func extractModelFromRequest(c echo.Context) string { + // 1. URL path param (e.g. /v1/engines/:model/completions) + if model := c.Param("model"); model != "" { + return model + } + // 2. Query param + if model := c.QueryParam("model"); model != "" { + return model + } + // 3. Peek at JSON body + if strings.HasPrefix(c.Request().Header.Get("Content-Type"), "application/json") { + body, err := io.ReadAll(c.Request().Body) + c.Request().Body = io.NopCloser(bytes.NewReader(body)) // always reset + if err == nil && len(body) > 0 { + var m struct { + Model string `json:"model"` + } + if json.Unmarshal(body, &m) == nil && m.Model != "" { + return m.Model + } + } + } + // 4. Form value (multipart/form-data) + if model := c.FormValue("model"); model != "" { + return model + } + return "" +} + +// tryAuthenticate attempts to authenticate the request using the database. +func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User { + hmacSecret := appConfig.Auth.APIKeyHMACSecret + + // a. Session cookie + if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" { + if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil { + // Store session for rotation check in middleware + c.Set("_auth_session", session) + return user + } + } + + // b. Authorization: Bearer token + authHeader := c.Request().Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + token := strings.TrimPrefix(authHeader, "Bearer ") + + // Try as session ID first + if user, _ := ValidateSession(db, token, hmacSecret); user != nil { + return user + } + + // Try as user API key + if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil { + return &key.User + } + } + + // c. x-api-key / xi-api-key headers + for _, header := range []string{"x-api-key", "xi-api-key"} { + if key := c.Request().Header.Get(header); key != "" { + if apiKey, err := ValidateAPIKey(db, key, hmacSecret); err == nil { + return &apiKey.User + } + } + } + + // d. token cookie (legacy) + if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" { + // Try as user API key + if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil { + return &key.User + } + } + + return nil +} + +// extractKey extracts an API key from the request (all sources). +func extractKey(c echo.Context) string { + // Authorization header + auth := c.Request().Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + if auth != "" { + return auth + } + + // x-api-key + if key := c.Request().Header.Get("x-api-key"); key != "" { + return key + } + + // xi-api-key + if key := c.Request().Header.Get("xi-api-key"); key != "" { + return key + } + + // token cookie + if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" { + return cookie.Value + } + + return "" +} + +// isValidLegacyKey checks if the key matches any configured API key +// using constant-time comparison to prevent timing attacks. +func isValidLegacyKey(key string, appConfig *config.ApplicationConfig) bool { + for _, validKey := range appConfig.ApiKeys { + if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { + return true + } + } + return false +} + +// isExemptPath returns true if the path should skip authentication. +func isExemptPath(path string, appConfig *config.ApplicationConfig) bool { + // Auth endpoints are always public + if strings.HasPrefix(path, "/api/auth/") { + return true + } + + // Check configured exempt paths + for _, p := range appConfig.PathWithoutAuth { + if strings.HasPrefix(path, p) { + return true + } + } + + return false +} + +// isAPIPath returns true for paths that always require authentication. +func isAPIPath(path string) bool { + return strings.HasPrefix(path, "/api/") || + strings.HasPrefix(path, "/v1/") || + strings.HasPrefix(path, "/models/") || + strings.HasPrefix(path, "/backends/") || + strings.HasPrefix(path, "/backend/") || + strings.HasPrefix(path, "/tts") || + strings.HasPrefix(path, "/vad") || + strings.HasPrefix(path, "/video") || + strings.HasPrefix(path, "/stores/") || + strings.HasPrefix(path, "/system") || + strings.HasPrefix(path, "/ws/") || + strings.HasPrefix(path, "/generated-") || + path == "/metrics" +} + +// authError returns an appropriate error response. +func authError(c echo.Context, appConfig *config.ApplicationConfig) error { + c.Response().Header().Set("WWW-Authenticate", "Bearer") + + if appConfig.OpaqueErrors { + return c.NoContent(http.StatusUnauthorized) + } + + contentType := c.Request().Header.Get("Content-Type") + if strings.Contains(contentType, "application/json") { + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "An authentication key is required", + Code: http.StatusUnauthorized, + Type: "invalid_request_error", + }, + }) + } + + return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: "An authentication key is required", + Code: http.StatusUnauthorized, + Type: "invalid_request_error", + }, + }) +} diff --git a/core/http/auth/middleware_test.go b/core/http/auth/middleware_test.go new file mode 100644 index 000000000000..e7b4daa60070 --- /dev/null +++ b/core/http/auth/middleware_test.go @@ -0,0 +1,306 @@ +//go:build auth + +package auth_test + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("Auth Middleware", func() { + + Context("auth disabled, no API keys", func() { + var app *echo.Echo + + BeforeEach(func() { + appConfig := config.NewApplicationConfig() + app = newAuthTestApp(nil, appConfig) + }) + + It("passes through all requests", func() { + rec := doRequest(app, http.MethodGet, "/v1/models") + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("passes through POST requests", func() { + rec := doRequest(app, http.MethodPost, "/v1/chat/completions") + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + }) + + Context("auth disabled, API keys configured", func() { + var app *echo.Echo + const validKey = "sk-test-key-123" + + BeforeEach(func() { + appConfig := config.NewApplicationConfig() + appConfig.ApiKeys = []string{validKey} + app = newAuthTestApp(nil, appConfig) + }) + + It("returns 401 for request without key", func() { + rec := doRequest(app, http.MethodGet, "/v1/models") + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("passes with valid Bearer token", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(validKey)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("passes with valid x-api-key header", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withXApiKey(validKey)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("passes with valid token cookie", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withTokenCookie(validKey)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("returns 401 for invalid key", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("wrong-key")) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + }) + + Context("auth enabled with database", func() { + var ( + db *gorm.DB + app *echo.Echo + appConfig *config.ApplicationConfig + user *auth.User + ) + + BeforeEach(func() { + db = testDB() + appConfig = config.NewApplicationConfig() + app = newAuthTestApp(db, appConfig) + user = createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) + }) + + It("allows requests with valid session cookie", func() { + sessionID := createTestSession(db, user.ID) + rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows requests with valid session as Bearer token", func() { + sessionID := createTestSession(db, user.ID) + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows requests with valid user API key as Bearer token", func() { + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test", auth.RoleUser, "", nil) + Expect(err).ToNot(HaveOccurred()) + + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows requests with legacy API_KEY as admin bypass", func() { + appConfig.ApiKeys = []string{"legacy-key-123"} + app = newAuthTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("legacy-key-123")) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("returns 401 for expired session", func() { + sessionID := createTestSession(db, user.ID) + // Manually expire (session ID in DB is the hash) + hash := auth.HashAPIKey(sessionID, "") + db.Model(&auth.Session{}).Where("id = ?", hash). + Update("expires_at", "2020-01-01") + + rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("returns 401 for invalid session ID", func() { + rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie("invalid-session-id")) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("returns 401 for revoked API key", func() { + plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser, "", nil) + Expect(err).ToNot(HaveOccurred()) + + err = auth.RevokeAPIKey(db, record.ID, user.ID) + Expect(err).ToNot(HaveOccurred()) + + rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext)) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("skips auth for /api/auth/* paths", func() { + rec := doRequest(app, http.MethodGet, "/api/auth/status") + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("skips auth for PathWithoutAuth paths", func() { + rec := doRequest(app, http.MethodGet, "/healthz") + // healthz is not registered in our test app, so it'll be 404/405 but NOT 401 + Expect(rec.Code).ToNot(Equal(http.StatusUnauthorized)) + }) + + It("returns 401 for unauthenticated API requests", func() { + rec := doRequest(app, http.MethodGet, "/v1/models") + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("allows unauthenticated access to non-API paths when no legacy keys", func() { + rec := doRequest(app, http.MethodGet, "/app") + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + }) + + Describe("RequireAdmin", func() { + var ( + db *gorm.DB + appConfig *config.ApplicationConfig + ) + + BeforeEach(func() { + db = testDB() + appConfig = config.NewApplicationConfig() + }) + + It("passes for admin user", func() { + admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub) + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("returns 403 for user role", func() { + user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("returns 401 when no user in context", func() { + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/api/settings") + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("allows admin to access model management", func() { + admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub) + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("blocks user from model management", func() { + user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("allows user to access regular inference endpoints", func() { + user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/v1/chat/completions", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows legacy API key (admin bypass) on admin routes", func() { + appConfig.ApiKeys = []string{"admin-key"} + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodPost, "/api/settings", withBearerToken("admin-key")) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("allows admin to access trace endpoints", func() { + admin := createTestUser(db, "admin2@example.com", auth.RoleAdmin, auth.ProviderGitHub) + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + + rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("blocks non-admin from trace endpoints", func() { + user := createTestUser(db, "user2@example.com", auth.RoleUser, auth.ProviderGitHub) + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + + rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("allows admin to access agent job endpoints", func() { + admin := createTestUser(db, "admin3@example.com", auth.RoleAdmin, auth.ProviderGitHub) + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + + rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK)) + }) + + It("blocks non-admin from agent job endpoints", func() { + user := createTestUser(db, "user3@example.com", auth.RoleUser, auth.ProviderGitHub) + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + + rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("blocks non-admin from system/management endpoints", func() { + user := createTestUser(db, "user4@example.com", auth.RoleUser, auth.ProviderGitHub) + sessionID := createTestSession(db, user.ID) + app := newAdminTestApp(db, appConfig) + + for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} { + rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusForbidden), "expected 403 for path: "+path) + } + }) + + It("allows admin to access system/management endpoints", func() { + admin := createTestUser(db, "admin4@example.com", auth.RoleAdmin, auth.ProviderGitHub) + sessionID := createTestSession(db, admin.ID) + app := newAdminTestApp(db, appConfig) + + for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} { + rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID)) + Expect(rec.Code).To(Equal(http.StatusOK), "expected 200 for path: "+path) + } + }) + }) +}) diff --git a/core/http/auth/models.go b/core/http/auth/models.go new file mode 100644 index 000000000000..598c0342cc6f --- /dev/null +++ b/core/http/auth/models.go @@ -0,0 +1,148 @@ +package auth + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "time" +) + +// Auth provider constants. +const ( + ProviderLocal = "local" + ProviderGitHub = "github" + ProviderOIDC = "oidc" +) + +// User represents an authenticated user. +type User struct { + ID string `gorm:"primaryKey;size:36"` + Email string `gorm:"size:255;index"` + Name string `gorm:"size:255"` + AvatarURL string `gorm:"size:512"` + Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC + Subject string `gorm:"size:255"` // provider-specific user ID + PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users + Role string `gorm:"size:20;default:user"` + Status string `gorm:"size:20;default:active"` // "active", "pending" + CreatedAt time.Time + UpdatedAt time.Time +} + +// Session represents a user login session. +type Session struct { + ID string `gorm:"primaryKey;size:64"` // HMAC-SHA256 hash of session token + UserID string `gorm:"size:36;index"` + ExpiresAt time.Time + RotatedAt time.Time + CreatedAt time.Time + User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` +} + +// UserAPIKey represents a user-generated API key for programmatic access. +type UserAPIKey struct { + ID string `gorm:"primaryKey;size:36"` + UserID string `gorm:"size:36;index"` + Name string `gorm:"size:255"` // user-provided label + KeyHash string `gorm:"size:64;uniqueIndex"` + KeyPrefix string `gorm:"size:12"` // first 8 chars of key for display + Role string `gorm:"size:20"` + CreatedAt time.Time + ExpiresAt *time.Time `gorm:"index"` + LastUsed *time.Time + User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` +} + +// PermissionMap is a flexible map of feature -> enabled, stored as JSON text. +// Known features: "agents", "skills", "collections", "mcp_jobs". +// New features can be added without schema changes. +type PermissionMap map[string]bool + +// Value implements driver.Valuer for GORM JSON serialization. +func (p PermissionMap) Value() (driver.Value, error) { + if p == nil { + return "{}", nil + } + b, err := json.Marshal(p) + if err != nil { + return nil, fmt.Errorf("failed to marshal PermissionMap: %w", err) + } + return string(b), nil +} + +// Scan implements sql.Scanner for GORM JSON deserialization. +func (p *PermissionMap) Scan(value any) error { + if value == nil { + *p = PermissionMap{} + return nil + } + var bytes []byte + switch v := value.(type) { + case string: + bytes = []byte(v) + case []byte: + bytes = v + default: + return fmt.Errorf("cannot scan %T into PermissionMap", value) + } + return json.Unmarshal(bytes, p) +} + +// InviteCode represents an admin-generated invitation for user registration. +type InviteCode struct { + ID string `gorm:"primaryKey;size:36"` + Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code + CodePrefix string `gorm:"size:12"` // first 8 chars for admin display + CreatedBy string `gorm:"size:36;not null"` + UsedBy *string `gorm:"size:36"` + UsedAt *time.Time + ExpiresAt time.Time `gorm:"not null;index"` + CreatedAt time.Time + Creator User `gorm:"foreignKey:CreatedBy"` + Consumer *User `gorm:"foreignKey:UsedBy"` +} + +// ModelAllowlist controls which models a user can access. +// When Enabled is false (default), all models are allowed. +type ModelAllowlist struct { + Enabled bool `json:"enabled"` + Models []string `json:"models,omitempty"` +} + +// Value implements driver.Valuer for GORM JSON serialization. +func (m ModelAllowlist) Value() (driver.Value, error) { + b, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("failed to marshal ModelAllowlist: %w", err) + } + return string(b), nil +} + +// Scan implements sql.Scanner for GORM JSON deserialization. +func (m *ModelAllowlist) Scan(value any) error { + if value == nil { + *m = ModelAllowlist{} + return nil + } + var bytes []byte + switch v := value.(type) { + case string: + bytes = []byte(v) + case []byte: + bytes = v + default: + return fmt.Errorf("cannot scan %T into ModelAllowlist", value) + } + return json.Unmarshal(bytes, m) +} + +// UserPermission stores per-user feature permissions. +type UserPermission struct { + ID string `gorm:"primaryKey;size:36"` + UserID string `gorm:"size:36;uniqueIndex"` + Permissions PermissionMap `gorm:"type:text"` + AllowedModels ModelAllowlist `gorm:"type:text"` + CreatedAt time.Time + UpdatedAt time.Time + User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` +} diff --git a/core/http/auth/oauth.go b/core/http/auth/oauth.go new file mode 100644 index 000000000000..e4b37c4e9c87 --- /dev/null +++ b/core/http/auth/oauth.go @@ -0,0 +1,439 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/subtle" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/xlog" + "golang.org/x/oauth2" + githubOAuth "golang.org/x/oauth2/github" + "gorm.io/gorm" +) + +// providerEntry holds the OAuth2/OIDC config for a single provider. +type providerEntry struct { + oauth2Config oauth2.Config + oidcVerifier *oidc.IDTokenVerifier // nil for GitHub (API-based user info) + name string + userInfoURL string // only used for GitHub +} + +// oauthUserInfo is a provider-agnostic representation of an authenticated user. +type oauthUserInfo struct { + Subject string + Email string + Name string + AvatarURL string +} + +// OAuthManager manages multiple OAuth/OIDC providers. +type OAuthManager struct { + providers map[string]*providerEntry +} + +// OAuthParams groups the parameters needed to create an OAuthManager. +type OAuthParams struct { + GitHubClientID string + GitHubClientSecret string + OIDCIssuer string + OIDCClientID string + OIDCClientSecret string +} + +// NewOAuthManager creates an OAuthManager from the given params. +func NewOAuthManager(baseURL string, params OAuthParams) (*OAuthManager, error) { + m := &OAuthManager{providers: make(map[string]*providerEntry)} + + if params.GitHubClientID != "" { + m.providers[ProviderGitHub] = &providerEntry{ + name: ProviderGitHub, + oauth2Config: oauth2.Config{ + ClientID: params.GitHubClientID, + ClientSecret: params.GitHubClientSecret, + Endpoint: githubOAuth.Endpoint, + RedirectURL: baseURL + "/api/auth/github/callback", + Scopes: []string{"user:email", "read:user"}, + }, + userInfoURL: "https://api.github.com/user", + } + } + + if params.OIDCClientID != "" && params.OIDCIssuer != "" { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + provider, err := oidc.NewProvider(ctx, params.OIDCIssuer) + if err != nil { + return nil, fmt.Errorf("OIDC discovery failed for %s: %w", params.OIDCIssuer, err) + } + + verifier := provider.Verifier(&oidc.Config{ClientID: params.OIDCClientID}) + + m.providers[ProviderOIDC] = &providerEntry{ + name: ProviderOIDC, + oauth2Config: oauth2.Config{ + ClientID: params.OIDCClientID, + ClientSecret: params.OIDCClientSecret, + Endpoint: provider.Endpoint(), + RedirectURL: baseURL + "/api/auth/oidc/callback", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + }, + oidcVerifier: verifier, + } + } + + return m, nil +} + +// Providers returns the list of configured provider names. +func (m *OAuthManager) Providers() []string { + names := make([]string, 0, len(m.providers)) + for name := range m.providers { + names = append(names, name) + } + return names +} + +// LoginHandler redirects the user to the OAuth provider's login page. +func (m *OAuthManager) LoginHandler(providerName string) echo.HandlerFunc { + return func(c echo.Context) error { + provider, ok := m.providers[providerName] + if !ok { + return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"}) + } + + state, err := generateState() + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to generate state"}) + } + + secure := isSecure(c) + c.SetCookie(&http.Cookie{ + Name: "oauth_state", + Value: state, + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + MaxAge: 600, // 10 minutes + }) + + // Store invite code in cookie if provided + if inviteCode := c.QueryParam("invite_code"); inviteCode != "" { + c.SetCookie(&http.Cookie{ + Name: "invite_code", + Value: inviteCode, + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + MaxAge: 600, + }) + } + + url := provider.oauth2Config.AuthCodeURL(state) + return c.Redirect(http.StatusTemporaryRedirect, url) + } +} + +// CallbackHandler handles the OAuth callback, creates/updates the user, and +// creates a session. +func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEmail, registrationMode, hmacSecret string) echo.HandlerFunc { + return func(c echo.Context) error { + provider, ok := m.providers[providerName] + if !ok { + return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"}) + } + + // Validate state + stateCookie, err := c.Cookie("oauth_state") + if err != nil || stateCookie.Value == "" || subtle.ConstantTimeCompare([]byte(stateCookie.Value), []byte(c.QueryParam("state"))) != 1 { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid OAuth state"}) + } + + // Clear state cookie + c.SetCookie(&http.Cookie{ + Name: "oauth_state", + Value: "", + Path: "/", + HttpOnly: true, + Secure: isSecure(c), + MaxAge: -1, + }) + + // Exchange code for token + code := c.QueryParam("code") + if code == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing authorization code"}) + } + + ctx, cancel := context.WithTimeout(c.Request().Context(), 30*time.Second) + defer cancel() + + token, err := provider.oauth2Config.Exchange(ctx, code) + if err != nil { + xlog.Error("OAuth code exchange failed", "provider", providerName, "error", err) + return c.JSON(http.StatusBadRequest, map[string]string{"error": "OAuth authentication failed"}) + } + + // Fetch user info — branch based on provider type + var userInfo *oauthUserInfo + if provider.oidcVerifier != nil { + userInfo, err = extractOIDCUserInfo(ctx, provider.oidcVerifier, token) + } else { + userInfo, err = fetchGitHubUserInfoAsOAuth(ctx, token.AccessToken) + } + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to fetch user info"}) + } + + // Retrieve invite code from cookie if present + var inviteCode string + if ic, err := c.Cookie("invite_code"); err == nil && ic.Value != "" { + inviteCode = ic.Value + // Clear the invite code cookie + c.SetCookie(&http.Cookie{ + Name: "invite_code", + Value: "", + Path: "/", + HttpOnly: true, + Secure: isSecure(c), + MaxAge: -1, + }) + } + + // Upsert user (with invite code support) + user, err := upsertOAuthUser(db, providerName, userInfo, adminEmail, registrationMode) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create user"}) + } + + // For new users that are pending, check if they have a valid invite + if user.Status != StatusActive && inviteCode != "" { + if invite, err := ValidateInvite(db, inviteCode, hmacSecret); err == nil { + user.Status = StatusActive + db.Model(user).Update("status", StatusActive) + ConsumeInvite(db, invite, user.ID) + } + } + + if user.Status != StatusActive { + if registrationMode == "invite" { + return c.JSON(http.StatusForbidden, map[string]string{"error": "a valid invite code is required to register"}) + } + return c.JSON(http.StatusForbidden, map[string]string{"error": "account pending approval"}) + } + + // Maybe promote on login + MaybePromote(db, user, adminEmail) + + // Create session + sessionID, err := CreateSession(db, user.ID, hmacSecret) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create session"}) + } + + SetSessionCookie(c, sessionID) + return c.Redirect(http.StatusTemporaryRedirect, "/app") + } +} + +// extractOIDCUserInfo extracts user info from the OIDC ID token. +func extractOIDCUserInfo(ctx context.Context, verifier *oidc.IDTokenVerifier, token *oauth2.Token) (*oauthUserInfo, error) { + rawIDToken, ok := token.Extra("id_token").(string) + if !ok || rawIDToken == "" { + return nil, fmt.Errorf("no id_token in token response") + } + + idToken, err := verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, fmt.Errorf("failed to verify ID token: %w", err) + } + + var claims struct { + Sub string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + Picture string `json:"picture"` + } + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse ID token claims: %w", err) + } + + return &oauthUserInfo{ + Subject: claims.Sub, + Email: claims.Email, + Name: claims.Name, + AvatarURL: claims.Picture, + }, nil +} + +type githubUserInfo struct { + ID int `json:"id"` + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` + AvatarURL string `json:"avatar_url"` +} + +type githubEmail struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` +} + +// fetchGitHubUserInfoAsOAuth fetches GitHub user info and returns it as oauthUserInfo. +func fetchGitHubUserInfoAsOAuth(ctx context.Context, accessToken string) (*oauthUserInfo, error) { + info, err := fetchGitHubUserInfo(ctx, accessToken) + if err != nil { + return nil, err + } + return &oauthUserInfo{ + Subject: fmt.Sprintf("%d", info.ID), + Email: info.Email, + Name: info.Name, + AvatarURL: info.AvatarURL, + }, nil +} + +func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserInfo, error) { + client := &http.Client{Timeout: 10 * time.Second} + + req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var info githubUserInfo + if err := json.Unmarshal(body, &info); err != nil { + return nil, err + } + + // If no public email, fetch from /user/emails + if info.Email == "" { + info.Email, _ = fetchGitHubPrimaryEmail(ctx, accessToken) + } + + return &info, nil +} + +func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) { + client := &http.Client{Timeout: 10 * time.Second} + + req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + var emails []githubEmail + if err := json.Unmarshal(body, &emails); err != nil { + return "", err + } + + for _, e := range emails { + if e.Primary && e.Verified { + return e.Email, nil + } + } + + // Fall back to first verified email + for _, e := range emails { + if e.Verified { + return e.Email, nil + } + } + + return "", fmt.Errorf("no verified email found") +} + +func upsertOAuthUser(db *gorm.DB, provider string, info *oauthUserInfo, adminEmail, registrationMode string) (*User, error) { + // Normalize email from provider (#10) + if info.Email != "" { + info.Email = strings.ToLower(strings.TrimSpace(info.Email)) + } + + var user User + err := db.Where("provider = ? AND subject = ?", provider, info.Subject).First(&user).Error + if err == nil { + // Existing user — update profile fields + user.Name = info.Name + user.AvatarURL = info.AvatarURL + if info.Email != "" { + user.Email = info.Email + } + db.Save(&user) + return &user, nil + } + + // New user — empty registration mode defaults to "approval" + effectiveMode := registrationMode + if effectiveMode == "" { + effectiveMode = "approval" + } + status := StatusActive + if effectiveMode == "approval" || effectiveMode == "invite" { + status = StatusPending + } + + role := AssignRole(db, info.Email, adminEmail) + // First user is always active regardless of registration mode + if role == RoleAdmin { + status = StatusActive + } + + user = User{ + ID: uuid.New().String(), + Email: info.Email, + Name: info.Name, + AvatarURL: info.AvatarURL, + Provider: provider, + Subject: info.Subject, + Role: role, + Status: status, + } + + if err := db.Create(&user).Error; err != nil { + return nil, err + } + + return &user, nil +} + +func generateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} diff --git a/core/http/auth/password.go b/core/http/auth/password.go new file mode 100644 index 000000000000..4c88fedb7267 --- /dev/null +++ b/core/http/auth/password.go @@ -0,0 +1,14 @@ +package auth + +import "golang.org/x/crypto/bcrypt" + +// HashPassword returns a bcrypt hash of the given password. +func HashPassword(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bytes), err +} + +// CheckPassword compares a bcrypt hash with a plaintext password. +func CheckPassword(hash, password string) bool { + return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil +} diff --git a/core/http/auth/permissions.go b/core/http/auth/permissions.go new file mode 100644 index 000000000000..b2408ad4ee40 --- /dev/null +++ b/core/http/auth/permissions.go @@ -0,0 +1,211 @@ +package auth + +import ( + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "gorm.io/gorm" +) + +const contextKeyPermissions = "auth_permissions" + +// GetCachedUserPermissions returns the user's permission record, using a +// request-scoped cache stored in the echo context. This avoids duplicate +// DB lookups when multiple middlewares (RequireRouteFeature, RequireModelAccess) +// both need permissions in the same request. +func GetCachedUserPermissions(c echo.Context, db *gorm.DB, userID string) (*UserPermission, error) { + if perm, ok := c.Get(contextKeyPermissions).(*UserPermission); ok && perm != nil { + return perm, nil + } + perm, err := GetUserPermissions(db, userID) + if err != nil { + return nil, err + } + c.Set(contextKeyPermissions, perm) + return perm, nil +} + +// Feature name constants — all code must use these, never bare strings. +const ( + // Agent features (default OFF for new users) + FeatureAgents = "agents" + FeatureSkills = "skills" + FeatureCollections = "collections" + FeatureMCPJobs = "mcp_jobs" + + // API features (default ON for new users) + FeatureChat = "chat" + FeatureImages = "images" + FeatureAudioSpeech = "audio_speech" + FeatureAudioTranscription = "audio_transcription" + FeatureVAD = "vad" + FeatureDetection = "detection" + FeatureVideo = "video" + FeatureEmbeddings = "embeddings" + FeatureSound = "sound" + FeatureRealtime = "realtime" + FeatureRerank = "rerank" + FeatureTokenize = "tokenize" + FeatureMCP = "mcp" + FeatureStores = "stores" +) + +// AgentFeatures lists agent-related features (default OFF). +var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs} + +// APIFeatures lists API endpoint features (default ON). +var APIFeatures = []string{ + FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription, + FeatureVAD, FeatureDetection, FeatureVideo, FeatureEmbeddings, FeatureSound, + FeatureRealtime, FeatureRerank, FeatureTokenize, FeatureMCP, FeatureStores, +} + +// AllFeatures lists all known features (used by UI and validation). +var AllFeatures = append(append([]string{}, AgentFeatures...), APIFeatures...) + +// defaultOnFeatures is the set of features that default to ON when absent from a user's permission map. +var defaultOnFeatures = func() map[string]bool { + m := map[string]bool{} + for _, f := range APIFeatures { + m[f] = true + } + return m +}() + +// isDefaultOnFeature returns true if the feature defaults to ON when not explicitly set. +func isDefaultOnFeature(feature string) bool { + return defaultOnFeatures[feature] +} + +// GetUserPermissions returns the permission record for a user, creating a default +// (empty map = all disabled) if none exists. +func GetUserPermissions(db *gorm.DB, userID string) (*UserPermission, error) { + var perm UserPermission + err := db.Where("user_id = ?", userID).First(&perm).Error + if err == gorm.ErrRecordNotFound { + perm = UserPermission{ + ID: uuid.New().String(), + UserID: userID, + Permissions: PermissionMap{}, + } + if err := db.Create(&perm).Error; err != nil { + return nil, err + } + return &perm, nil + } + if err != nil { + return nil, err + } + return &perm, nil +} + +// UpdateUserPermissions upserts the permission map for a user. +func UpdateUserPermissions(db *gorm.DB, userID string, perms PermissionMap) error { + var perm UserPermission + err := db.Where("user_id = ?", userID).First(&perm).Error + if err == gorm.ErrRecordNotFound { + perm = UserPermission{ + ID: uuid.New().String(), + UserID: userID, + Permissions: perms, + } + return db.Create(&perm).Error + } + if err != nil { + return err + } + perm.Permissions = perms + return db.Save(&perm).Error +} + +// HasFeatureAccess returns true if the user is an admin or has the given feature enabled. +// When a feature key is absent from the user's permission map, it checks whether the +// feature defaults to ON (API features) or OFF (agent features) for backward compatibility. +func HasFeatureAccess(db *gorm.DB, user *User, feature string) bool { + if user == nil { + return false + } + if user.Role == RoleAdmin { + return true + } + perm, err := GetUserPermissions(db, user.ID) + if err != nil { + return false + } + val, exists := perm.Permissions[feature] + if !exists { + return isDefaultOnFeature(feature) + } + return val +} + +// GetPermissionMapForUser returns the effective permission map for a user. +// Admins get all features as true (virtual). +// For regular users, absent keys are filled with their defaults so the +// UI/API always returns a complete picture. +func GetPermissionMapForUser(db *gorm.DB, user *User) PermissionMap { + if user == nil { + return PermissionMap{} + } + if user.Role == RoleAdmin { + m := PermissionMap{} + for _, f := range AllFeatures { + m[f] = true + } + return m + } + perm, err := GetUserPermissions(db, user.ID) + if err != nil { + return PermissionMap{} + } + // Fill in defaults for absent keys + effective := PermissionMap{} + for _, f := range AllFeatures { + val, exists := perm.Permissions[f] + if exists { + effective[f] = val + } else { + effective[f] = isDefaultOnFeature(f) + } + } + return effective +} + +// GetModelAllowlist returns the model allowlist for a user. +func GetModelAllowlist(db *gorm.DB, userID string) ModelAllowlist { + perm, err := GetUserPermissions(db, userID) + if err != nil { + return ModelAllowlist{} + } + return perm.AllowedModels +} + +// UpdateModelAllowlist updates the model allowlist for a user. +func UpdateModelAllowlist(db *gorm.DB, userID string, allowlist ModelAllowlist) error { + perm, err := GetUserPermissions(db, userID) + if err != nil { + return err + } + perm.AllowedModels = allowlist + return db.Save(perm).Error +} + +// IsModelAllowed returns true if the user is allowed to use the given model. +// Admins always have access. If the allowlist is not enabled, all models are allowed. +func IsModelAllowed(db *gorm.DB, user *User, modelName string) bool { + if user == nil { + return false + } + if user.Role == RoleAdmin { + return true + } + allowlist := GetModelAllowlist(db, user.ID) + if !allowlist.Enabled { + return true + } + for _, m := range allowlist.Models { + if m == modelName { + return true + } + } + return false +} diff --git a/core/http/auth/roles.go b/core/http/auth/roles.go new file mode 100644 index 000000000000..1f01889590c9 --- /dev/null +++ b/core/http/auth/roles.go @@ -0,0 +1,103 @@ +package auth + +import ( + "fmt" + "strings" + "time" + + "gorm.io/gorm" +) + +const ( + RoleAdmin = "admin" + RoleUser = "user" + + StatusActive = "active" + StatusPending = "pending" + StatusDisabled = "disabled" +) + +// AssignRole determines the role for a new user. +// First user in the database becomes admin. If adminEmail is set and matches, +// the user becomes admin. Otherwise, the user gets the "user" role. +// Must be called within a transaction that also creates the user to prevent +// race conditions on the first-user admin assignment. +func AssignRole(tx *gorm.DB, email, adminEmail string) string { + var count int64 + tx.Model(&User{}).Count(&count) + if count == 0 { + return RoleAdmin + } + + if adminEmail != "" && strings.EqualFold(email, adminEmail) { + return RoleAdmin + } + + return RoleUser +} + +// MaybePromote promotes a user to admin on login if their email matches +// adminEmail. It does not demote existing admins. Returns true if the user +// was promoted. +func MaybePromote(db *gorm.DB, user *User, adminEmail string) bool { + if user.Role == RoleAdmin { + return false + } + + if adminEmail != "" && strings.EqualFold(user.Email, adminEmail) { + user.Role = RoleAdmin + db.Model(user).Update("role", RoleAdmin) + return true + } + + return false +} + +// ValidateInvite checks that an invite code exists, is unused, and has not expired. +// The code is hashed with HMAC-SHA256 before lookup. +func ValidateInvite(db *gorm.DB, code, hmacSecret string) (*InviteCode, error) { + hash := HashAPIKey(code, hmacSecret) + var invite InviteCode + if err := db.Where("code = ?", hash).First(&invite).Error; err != nil { + return nil, fmt.Errorf("invite code not found") + } + if invite.UsedBy != nil { + return nil, fmt.Errorf("invite code already used") + } + if time.Now().After(invite.ExpiresAt) { + return nil, fmt.Errorf("invite code expired") + } + return &invite, nil +} + +// ConsumeInvite marks an invite code as used by the given user. +func ConsumeInvite(db *gorm.DB, invite *InviteCode, userID string) { + now := time.Now() + invite.UsedBy = &userID + invite.UsedAt = &now + db.Save(invite) +} + +// NeedsInviteOrApproval returns true if registration gating applies for the given mode. +// Admins (first user or matching adminEmail) are never gated. +// Must be called within a transaction that also creates the user. +func NeedsInviteOrApproval(tx *gorm.DB, email, adminEmail, registrationMode string) bool { + // Empty registration mode defaults to "approval" + if registrationMode == "" { + registrationMode = "approval" + } + if registrationMode != "approval" && registrationMode != "invite" { + return false + } + // Admin email is never gated + if adminEmail != "" && strings.EqualFold(email, adminEmail) { + return false + } + // First user is never gated + var count int64 + tx.Model(&User{}).Count(&count) + if count == 0 { + return false + } + return true +} diff --git a/core/http/auth/roles_test.go b/core/http/auth/roles_test.go new file mode 100644 index 000000000000..a12e237d48bc --- /dev/null +++ b/core/http/auth/roles_test.go @@ -0,0 +1,84 @@ +//go:build auth + +package auth_test + +import ( + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("Roles", func() { + var db *gorm.DB + + BeforeEach(func() { + db = testDB() + }) + + Describe("AssignRole", func() { + It("returns admin for the first user (empty DB)", func() { + role := auth.AssignRole(db, "first@example.com", "") + Expect(role).To(Equal(auth.RoleAdmin)) + }) + + It("returns user for the second user", func() { + createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub) + + role := auth.AssignRole(db, "second@example.com", "") + Expect(role).To(Equal(auth.RoleUser)) + }) + + It("returns admin when email matches adminEmail", func() { + createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub) + + role := auth.AssignRole(db, "admin@example.com", "admin@example.com") + Expect(role).To(Equal(auth.RoleAdmin)) + }) + + It("is case-insensitive for admin email match", func() { + createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub) + + role := auth.AssignRole(db, "Admin@Example.COM", "admin@example.com") + Expect(role).To(Equal(auth.RoleAdmin)) + }) + + It("returns user when email does not match adminEmail", func() { + createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub) + + role := auth.AssignRole(db, "other@example.com", "admin@example.com") + Expect(role).To(Equal(auth.RoleUser)) + }) + }) + + Describe("MaybePromote", func() { + It("promotes user to admin when email matches", func() { + user := createTestUser(db, "promoted@example.com", auth.RoleUser, auth.ProviderGitHub) + + promoted := auth.MaybePromote(db, user, "promoted@example.com") + Expect(promoted).To(BeTrue()) + Expect(user.Role).To(Equal(auth.RoleAdmin)) + + // Verify in DB + var dbUser auth.User + db.First(&dbUser, "id = ?", user.ID) + Expect(dbUser.Role).To(Equal(auth.RoleAdmin)) + }) + + It("does not promote when email does not match", func() { + user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) + + promoted := auth.MaybePromote(db, user, "admin@example.com") + Expect(promoted).To(BeFalse()) + Expect(user.Role).To(Equal(auth.RoleUser)) + }) + + It("does not demote an existing admin", func() { + user := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub) + + promoted := auth.MaybePromote(db, user, "other@example.com") + Expect(promoted).To(BeFalse()) + Expect(user.Role).To(Equal(auth.RoleAdmin)) + }) + }) +}) diff --git a/core/http/auth/session.go b/core/http/auth/session.go new file mode 100644 index 000000000000..7c8bf68b68fd --- /dev/null +++ b/core/http/auth/session.go @@ -0,0 +1,182 @@ +package auth + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "net/http" + "time" + + "github.com/labstack/echo/v4" + "gorm.io/gorm" +) + +const ( + sessionDuration = 30 * 24 * time.Hour // 30 days + sessionIDBytes = 32 // 32 bytes = 64 hex chars + sessionCookie = "session" + sessionRotationInterval = 1 * time.Hour +) + +// CreateSession creates a new session for the given user, returning the +// plaintext token (64-char hex string). The stored session ID is the +// HMAC-SHA256 hash of the token. +func CreateSession(db *gorm.DB, userID, hmacSecret string) (string, error) { + b := make([]byte, sessionIDBytes) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate session ID: %w", err) + } + + plaintext := hex.EncodeToString(b) + hash := HashAPIKey(plaintext, hmacSecret) + + now := time.Now() + session := Session{ + ID: hash, + UserID: userID, + ExpiresAt: now.Add(sessionDuration), + RotatedAt: now, + } + + if err := db.Create(&session).Error; err != nil { + return "", fmt.Errorf("failed to create session: %w", err) + } + + return plaintext, nil +} + +// ValidateSession hashes the plaintext token and looks up the session. +// Returns the associated user and session, or (nil, nil) if not found/expired. +func ValidateSession(db *gorm.DB, token, hmacSecret string) (*User, *Session) { + hash := HashAPIKey(token, hmacSecret) + + var session Session + if err := db.Preload("User").Where("id = ? AND expires_at > ?", hash, time.Now()).First(&session).Error; err != nil { + return nil, nil + } + if session.User.Status != StatusActive { + return nil, nil + } + return &session.User, &session +} + +// DeleteSession removes a session by hashing the plaintext token. +func DeleteSession(db *gorm.DB, token, hmacSecret string) error { + hash := HashAPIKey(token, hmacSecret) + return db.Where("id = ?", hash).Delete(&Session{}).Error +} + +// CleanExpiredSessions removes all sessions that have passed their expiry time. +func CleanExpiredSessions(db *gorm.DB) error { + return db.Where("expires_at < ?", time.Now()).Delete(&Session{}).Error +} + +// DeleteUserSessions removes all sessions for the given user. +func DeleteUserSessions(db *gorm.DB, userID string) error { + return db.Where("user_id = ?", userID).Delete(&Session{}).Error +} + +// RotateSession creates a new session for the same user, deletes the old one, +// and returns the new plaintext token. +func RotateSession(db *gorm.DB, oldSession *Session, hmacSecret string) (string, error) { + b := make([]byte, sessionIDBytes) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate session ID: %w", err) + } + + plaintext := hex.EncodeToString(b) + hash := HashAPIKey(plaintext, hmacSecret) + + now := time.Now() + newSession := Session{ + ID: hash, + UserID: oldSession.UserID, + ExpiresAt: oldSession.ExpiresAt, + RotatedAt: now, + } + + err := db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(&newSession).Error; err != nil { + return err + } + return tx.Where("id = ?", oldSession.ID).Delete(&Session{}).Error + }) + if err != nil { + return "", fmt.Errorf("failed to rotate session: %w", err) + } + + return plaintext, nil +} + +// MaybeRotateSession checks if the session should be rotated and does so if needed. +// Called from the auth middleware after successful cookie-based authentication. +func MaybeRotateSession(c echo.Context, db *gorm.DB, session *Session, hmacSecret string) { + if session == nil { + return + } + + rotatedAt := session.RotatedAt + if rotatedAt.IsZero() { + rotatedAt = session.CreatedAt + } + + if time.Since(rotatedAt) < sessionRotationInterval { + return + } + + newToken, err := RotateSession(db, session, hmacSecret) + if err != nil { + // Rotation failure is non-fatal; the old session remains valid + return + } + + SetSessionCookie(c, newToken) +} + +// isSecure returns true when the request arrived over HTTPS, either directly +// or via a reverse proxy that sets X-Forwarded-Proto. +func isSecure(c echo.Context) bool { + return c.Scheme() == "https" +} + +// SetSessionCookie sets the session cookie on the response. +func SetSessionCookie(c echo.Context, sessionID string) { + cookie := &http.Cookie{ + Name: sessionCookie, + Value: sessionID, + Path: "/", + HttpOnly: true, + Secure: isSecure(c), + SameSite: http.SameSiteLaxMode, + MaxAge: int(sessionDuration.Seconds()), + } + c.SetCookie(cookie) +} + +// SetTokenCookie sets an httpOnly "token" cookie for legacy API key auth. +func SetTokenCookie(c echo.Context, token string) { + cookie := &http.Cookie{ + Name: "token", + Value: token, + Path: "/", + HttpOnly: true, + Secure: isSecure(c), + SameSite: http.SameSiteLaxMode, + MaxAge: int(sessionDuration.Seconds()), + } + c.SetCookie(cookie) +} + +// ClearSessionCookie clears the session cookie. +func ClearSessionCookie(c echo.Context) { + cookie := &http.Cookie{ + Name: sessionCookie, + Value: "", + Path: "/", + HttpOnly: true, + Secure: isSecure(c), + SameSite: http.SameSiteLaxMode, + MaxAge: -1, + } + c.SetCookie(cookie) +} diff --git a/core/http/auth/session_test.go b/core/http/auth/session_test.go new file mode 100644 index 000000000000..e24e02c0f344 --- /dev/null +++ b/core/http/auth/session_test.go @@ -0,0 +1,272 @@ +//go:build auth + +package auth_test + +import ( + "time" + + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("Sessions", func() { + var ( + db *gorm.DB + user *auth.User + ) + + // Use empty HMAC secret for basic tests + hmacSecret := "" + + BeforeEach(func() { + db = testDB() + user = createTestUser(db, "session@example.com", auth.RoleUser, auth.ProviderGitHub) + }) + + Describe("CreateSession", func() { + It("creates a session and returns 64-char hex plaintext token", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + Expect(token).To(HaveLen(64)) + }) + + It("stores the hash (not plaintext) in the DB", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + hash := auth.HashAPIKey(token, hmacSecret) + var session auth.Session + err = db.First(&session, "id = ?", hash).Error + Expect(err).ToNot(HaveOccurred()) + Expect(session.UserID).To(Equal(user.ID)) + // The plaintext token should NOT be stored as the ID + Expect(session.ID).ToNot(Equal(token)) + Expect(session.ID).To(Equal(hash)) + }) + + It("sets expiry to approximately 30 days from now", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + hash := auth.HashAPIKey(token, hmacSecret) + var session auth.Session + db.First(&session, "id = ?", hash) + + expectedExpiry := time.Now().Add(30 * 24 * time.Hour) + Expect(session.ExpiresAt).To(BeTemporally("~", expectedExpiry, time.Minute)) + }) + + It("sets RotatedAt on creation", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + hash := auth.HashAPIKey(token, hmacSecret) + var session auth.Session + db.First(&session, "id = ?", hash) + + Expect(session.RotatedAt).To(BeTemporally("~", time.Now(), time.Minute)) + }) + + It("associates session with correct user", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + hash := auth.HashAPIKey(token, hmacSecret) + var session auth.Session + db.First(&session, "id = ?", hash) + Expect(session.UserID).To(Equal(user.ID)) + }) + }) + + Describe("ValidateSession", func() { + It("returns user for valid session", func() { + token := createTestSession(db, user.ID) + + found, session := auth.ValidateSession(db, token, hmacSecret) + Expect(found).ToNot(BeNil()) + Expect(found.ID).To(Equal(user.ID)) + Expect(session).ToNot(BeNil()) + }) + + It("returns nil for non-existent session", func() { + found, session := auth.ValidateSession(db, "nonexistent-session-id", hmacSecret) + Expect(found).To(BeNil()) + Expect(session).To(BeNil()) + }) + + It("returns nil for expired session", func() { + token := createTestSession(db, user.ID) + hash := auth.HashAPIKey(token, hmacSecret) + + // Manually expire the session + db.Model(&auth.Session{}).Where("id = ?", hash). + Update("expires_at", time.Now().Add(-1*time.Hour)) + + found, _ := auth.ValidateSession(db, token, hmacSecret) + Expect(found).To(BeNil()) + }) + }) + + Describe("DeleteSession", func() { + It("removes the session from DB", func() { + token := createTestSession(db, user.ID) + + err := auth.DeleteSession(db, token, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + found, _ := auth.ValidateSession(db, token, hmacSecret) + Expect(found).To(BeNil()) + }) + + It("does not error on non-existent session", func() { + err := auth.DeleteSession(db, "nonexistent", hmacSecret) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Describe("CleanExpiredSessions", func() { + It("removes expired sessions", func() { + token := createTestSession(db, user.ID) + hash := auth.HashAPIKey(token, hmacSecret) + + // Manually expire the session + db.Model(&auth.Session{}).Where("id = ?", hash). + Update("expires_at", time.Now().Add(-1*time.Hour)) + + err := auth.CleanExpiredSessions(db) + Expect(err).ToNot(HaveOccurred()) + + var count int64 + db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count) + Expect(count).To(Equal(int64(0))) + }) + + It("keeps active sessions", func() { + token := createTestSession(db, user.ID) + hash := auth.HashAPIKey(token, hmacSecret) + + err := auth.CleanExpiredSessions(db) + Expect(err).ToNot(HaveOccurred()) + + var count int64 + db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count) + Expect(count).To(Equal(int64(1))) + }) + }) + + Describe("RotateSession", func() { + It("creates a new session and deletes the old one", func() { + token := createTestSession(db, user.ID) + hash := auth.HashAPIKey(token, hmacSecret) + + // Get the old session + var oldSession auth.Session + db.First(&oldSession, "id = ?", hash) + + newToken, err := auth.RotateSession(db, &oldSession, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + Expect(newToken).To(HaveLen(64)) + Expect(newToken).ToNot(Equal(token)) + + // Old session should be gone + var count int64 + db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count) + Expect(count).To(Equal(int64(0))) + + // New session should exist and validate + found, _ := auth.ValidateSession(db, newToken, hmacSecret) + Expect(found).ToNot(BeNil()) + Expect(found.ID).To(Equal(user.ID)) + }) + + It("preserves user ID and expiry", func() { + token := createTestSession(db, user.ID) + hash := auth.HashAPIKey(token, hmacSecret) + + var oldSession auth.Session + db.First(&oldSession, "id = ?", hash) + + newToken, err := auth.RotateSession(db, &oldSession, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + newHash := auth.HashAPIKey(newToken, hmacSecret) + var newSession auth.Session + db.First(&newSession, "id = ?", newHash) + + Expect(newSession.UserID).To(Equal(oldSession.UserID)) + Expect(newSession.ExpiresAt).To(BeTemporally("~", oldSession.ExpiresAt, time.Second)) + }) + }) + + Context("with HMAC secret", func() { + hmacSecret := "test-hmac-secret-123" + + It("creates and validates sessions with HMAC secret", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + found, session := auth.ValidateSession(db, token, hmacSecret) + Expect(found).ToNot(BeNil()) + Expect(found.ID).To(Equal(user.ID)) + Expect(session).ToNot(BeNil()) + }) + + It("does not validate with wrong HMAC secret", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + found, _ := auth.ValidateSession(db, token, "wrong-secret") + Expect(found).To(BeNil()) + }) + + It("does not validate with empty HMAC secret", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + found, _ := auth.ValidateSession(db, token, "") + Expect(found).To(BeNil()) + }) + + It("session created with empty secret does not validate with non-empty secret", func() { + token, err := auth.CreateSession(db, user.ID, "") + Expect(err).ToNot(HaveOccurred()) + + found, _ := auth.ValidateSession(db, token, hmacSecret) + Expect(found).To(BeNil()) + }) + + It("deletes session with correct HMAC secret", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + err = auth.DeleteSession(db, token, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + found, _ := auth.ValidateSession(db, token, hmacSecret) + Expect(found).To(BeNil()) + }) + + It("rotates session with HMAC secret", func() { + token, err := auth.CreateSession(db, user.ID, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + hash := auth.HashAPIKey(token, hmacSecret) + var oldSession auth.Session + db.First(&oldSession, "id = ?", hash) + + newToken, err := auth.RotateSession(db, &oldSession, hmacSecret) + Expect(err).ToNot(HaveOccurred()) + + // Old token should not validate + found, _ := auth.ValidateSession(db, token, hmacSecret) + Expect(found).To(BeNil()) + + // New token should validate + found, _ = auth.ValidateSession(db, newToken, hmacSecret) + Expect(found).ToNot(BeNil()) + Expect(found.ID).To(Equal(user.ID)) + }) + }) +}) diff --git a/core/http/auth/usage.go b/core/http/auth/usage.go new file mode 100644 index 000000000000..08841a442dbc --- /dev/null +++ b/core/http/auth/usage.go @@ -0,0 +1,151 @@ +package auth + +import ( + "fmt" + "strings" + "time" + + "gorm.io/gorm" +) + +// UsageRecord represents a single API request's token usage. +type UsageRecord struct { + ID uint `gorm:"primaryKey;autoIncrement"` + UserID string `gorm:"size:36;index:idx_usage_user_time"` + UserName string `gorm:"size:255"` + Model string `gorm:"size:255;index"` + Endpoint string `gorm:"size:255"` + PromptTokens int64 + CompletionTokens int64 + TotalTokens int64 + Duration int64 // milliseconds + CreatedAt time.Time `gorm:"index:idx_usage_user_time"` +} + +// RecordUsage inserts a usage record. +func RecordUsage(db *gorm.DB, record *UsageRecord) error { + return db.Create(record).Error +} + +// UsageBucket is an aggregated time bucket for the dashboard. +type UsageBucket struct { + Bucket string `json:"bucket"` + Model string `json:"model"` + UserID string `json:"user_id,omitempty"` + UserName string `json:"user_name,omitempty"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` +} + +// UsageTotals is a summary of all usage. +type UsageTotals struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + RequestCount int64 `json:"request_count"` +} + +// periodToWindow returns the time window and SQL date format for a period. +func periodToWindow(period string, isSQLite bool) (time.Time, string) { + now := time.Now() + var since time.Time + var dateFmt string + + switch period { + case "day": + since = now.Add(-24 * time.Hour) + if isSQLite { + dateFmt = "strftime('%Y-%m-%d %H:00', created_at)" + } else { + dateFmt = "to_char(date_trunc('hour', created_at), 'YYYY-MM-DD HH24:00')" + } + case "week": + since = now.Add(-7 * 24 * time.Hour) + if isSQLite { + dateFmt = "strftime('%Y-%m-%d', created_at)" + } else { + dateFmt = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')" + } + case "all": + since = time.Time{} // zero time = no filter + if isSQLite { + dateFmt = "strftime('%Y-%m', created_at)" + } else { + dateFmt = "to_char(date_trunc('month', created_at), 'YYYY-MM')" + } + default: // "month" + since = now.Add(-30 * 24 * time.Hour) + if isSQLite { + dateFmt = "strftime('%Y-%m-%d', created_at)" + } else { + dateFmt = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')" + } + } + + return since, dateFmt +} + +func isSQLiteDB(db *gorm.DB) bool { + return strings.Contains(db.Dialector.Name(), "sqlite") +} + +// GetUserUsage returns aggregated usage for a single user. +func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) { + sqlite := isSQLiteDB(db) + since, dateFmt := periodToWindow(period, sqlite) + + bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) + + query := db.Model(&UsageRecord{}). + Select(bucketExpr+", model, "+ + "SUM(prompt_tokens) as prompt_tokens, "+ + "SUM(completion_tokens) as completion_tokens, "+ + "SUM(total_tokens) as total_tokens, "+ + "COUNT(*) as request_count"). + Where("user_id = ?", userID). + Group("bucket, model"). + Order("bucket ASC") + + if !since.IsZero() { + query = query.Where("created_at >= ?", since) + } + + var buckets []UsageBucket + if err := query.Find(&buckets).Error; err != nil { + return nil, err + } + return buckets, nil +} + +// GetAllUsage returns aggregated usage for all users (admin). Optional userID filter. +func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) { + sqlite := isSQLiteDB(db) + since, dateFmt := periodToWindow(period, sqlite) + + bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) + + query := db.Model(&UsageRecord{}). + Select(bucketExpr+", model, user_id, user_name, "+ + "SUM(prompt_tokens) as prompt_tokens, "+ + "SUM(completion_tokens) as completion_tokens, "+ + "SUM(total_tokens) as total_tokens, "+ + "COUNT(*) as request_count"). + Group("bucket, model, user_id, user_name"). + Order("bucket ASC") + + if !since.IsZero() { + query = query.Where("created_at >= ?", since) + } + + if userID != "" { + query = query.Where("user_id = ?", userID) + } + + var buckets []UsageBucket + if err := query.Find(&buckets).Error; err != nil { + return nil, err + } + return buckets, nil +} diff --git a/core/http/auth/usage_test.go b/core/http/auth/usage_test.go new file mode 100644 index 000000000000..0c3fa5df5846 --- /dev/null +++ b/core/http/auth/usage_test.go @@ -0,0 +1,161 @@ +//go:build auth + +package auth_test + +import ( + "time" + + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Usage", func() { + Describe("RecordUsage", func() { + It("inserts a usage record", func() { + db := testDB() + record := &auth.UsageRecord{ + UserID: "user-1", + UserName: "Test User", + Model: "gpt-4", + Endpoint: "/v1/chat/completions", + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + Duration: 1200, + CreatedAt: time.Now(), + } + err := auth.RecordUsage(db, record) + Expect(err).ToNot(HaveOccurred()) + Expect(record.ID).ToNot(BeZero()) + }) + }) + + Describe("GetUserUsage", func() { + It("returns aggregated usage for a specific user", func() { + db := testDB() + + // Insert records for two users + for i := 0; i < 3; i++ { + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-a", + UserName: "Alice", + Model: "gpt-4", + Endpoint: "/v1/chat/completions", + PromptTokens: 100, + TotalTokens: 150, + CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + } + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-b", + UserName: "Bob", + Model: "gpt-4", + PromptTokens: 200, + TotalTokens: 300, + CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + + buckets, err := auth.GetUserUsage(db, "user-a", "month") + Expect(err).ToNot(HaveOccurred()) + Expect(buckets).ToNot(BeEmpty()) + + // All returned buckets should be for user-a's model + totalPrompt := int64(0) + for _, b := range buckets { + totalPrompt += b.PromptTokens + } + Expect(totalPrompt).To(Equal(int64(300))) + }) + + It("filters by period", func() { + db := testDB() + + // Record in the past (beyond day window) + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-c", + UserName: "Carol", + Model: "gpt-4", + PromptTokens: 100, + TotalTokens: 100, + CreatedAt: time.Now().Add(-48 * time.Hour), + }) + Expect(err).ToNot(HaveOccurred()) + + // Record now + err = auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-c", + UserName: "Carol", + Model: "gpt-4", + PromptTokens: 200, + TotalTokens: 200, + CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + + // Day period should only include recent record + buckets, err := auth.GetUserUsage(db, "user-c", "day") + Expect(err).ToNot(HaveOccurred()) + totalPrompt := int64(0) + for _, b := range buckets { + totalPrompt += b.PromptTokens + } + Expect(totalPrompt).To(Equal(int64(200))) + + // Month period should include both + buckets, err = auth.GetUserUsage(db, "user-c", "month") + Expect(err).ToNot(HaveOccurred()) + totalPrompt = 0 + for _, b := range buckets { + totalPrompt += b.PromptTokens + } + Expect(totalPrompt).To(Equal(int64(300))) + }) + }) + + Describe("GetAllUsage", func() { + It("returns usage for all users", func() { + db := testDB() + + for _, uid := range []string{"user-x", "user-y"} { + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: uid, + UserName: uid, + Model: "gpt-4", + PromptTokens: 100, + TotalTokens: 150, + CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + } + + buckets, err := auth.GetAllUsage(db, "month", "") + Expect(err).ToNot(HaveOccurred()) + Expect(len(buckets)).To(BeNumerically(">=", 2)) + }) + + It("filters by user ID when specified", func() { + db := testDB() + + err := auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-p", UserName: "Pat", Model: "gpt-4", + PromptTokens: 100, TotalTokens: 100, CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + + err = auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "user-q", UserName: "Quinn", Model: "gpt-4", + PromptTokens: 200, TotalTokens: 200, CreatedAt: time.Now(), + }) + Expect(err).ToNot(HaveOccurred()) + + buckets, err := auth.GetAllUsage(db, "month", "user-p") + Expect(err).ToNot(HaveOccurred()) + for _, b := range buckets { + Expect(b.UserID).To(Equal("user-p")) + } + }) + }) +}) diff --git a/core/http/endpoints/localai/agent_collections.go b/core/http/endpoints/localai/agent_collections.go index 49b6ea386dc8..022035ef4096 100644 --- a/core/http/endpoints/localai/agent_collections.go +++ b/core/http/endpoints/localai/agent_collections.go @@ -12,27 +12,54 @@ import ( func ListCollectionsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - collections, err := svc.ListCollections() + userID := getUserID(c) + cols, err := svc.ListCollectionsForUser(userID) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } - return c.JSON(http.StatusOK, map[string]any{ - "collections": collections, - "count": len(collections), - }) + + resp := map[string]any{ + "collections": cols, + "count": len(cols), + } + + // Admin cross-user aggregation + if wantsAllUsers(c) { + usm := svc.UserServicesManager() + if usm != nil { + userIDs, _ := usm.ListAllUserIDs() + userGroups := map[string]any{} + for _, uid := range userIDs { + if uid == userID { + continue + } + userCols, err := svc.ListCollectionsForUser(uid) + if err != nil || len(userCols) == 0 { + continue + } + userGroups[uid] = map[string]any{"collections": userCols} + } + if len(userGroups) > 0 { + resp["user_groups"] = userGroups + } + } + } + + return c.JSON(http.StatusOK, resp) } } func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Name string `json:"name"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.CreateCollection(payload.Name); err != nil { + if err := svc.CreateCollectionForUser(userID, payload.Name); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok", "name": payload.Name}) @@ -42,20 +69,18 @@ func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc { func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"}) } - if svc.CollectionEntryExists(name, file.Filename) { - return c.JSON(http.StatusConflict, map[string]string{"error": "entry already exists"}) - } src, err := file.Open() if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } defer src.Close() - if err := svc.UploadToCollection(name, file.Filename, src); err != nil { + if err := svc.UploadToCollectionForUser(userID, name, file.Filename, src); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -68,7 +93,8 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc { func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - entries, err := svc.ListCollectionEntries(c.Param("name")) + userID := effectiveUserID(c) + entries, err := svc.ListCollectionEntriesForUser(userID, c.Param("name")) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -85,12 +111,13 @@ func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFun func GetCollectionEntryContentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) entryParam := c.Param("*") entry, err := url.PathUnescape(entryParam) if err != nil { entry = entryParam } - content, chunkCount, err := svc.GetCollectionEntryContent(c.Param("name"), entry) + content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, c.Param("name"), entry) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -107,6 +134,7 @@ func GetCollectionEntryContentEndpoint(app *application.Application) echo.Handle func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) var payload struct { Query string `json:"query"` MaxResults int `json:"max_results"` @@ -114,7 +142,7 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - results, err := svc.SearchCollection(c.Param("name"), payload.Query, payload.MaxResults) + results, err := svc.SearchCollectionForUser(userID, c.Param("name"), payload.Query, payload.MaxResults) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -131,7 +159,8 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc { func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.ResetCollection(c.Param("name")); err != nil { + userID := effectiveUserID(c) + if err := svc.ResetCollectionForUser(userID, c.Param("name")); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -144,13 +173,14 @@ func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc { func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) var payload struct { Entry string `json:"entry"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - remaining, err := svc.DeleteCollectionEntry(c.Param("name"), payload.Entry) + remaining, err := svc.DeleteCollectionEntryForUser(userID, c.Param("name"), payload.Entry) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -167,6 +197,7 @@ func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFun func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) var payload struct { URL string `json:"url"` UpdateInterval int `json:"update_interval"` @@ -177,7 +208,7 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc if payload.UpdateInterval < 1 { payload.UpdateInterval = 60 } - if err := svc.AddCollectionSource(c.Param("name"), payload.URL, payload.UpdateInterval); err != nil { + if err := svc.AddCollectionSourceForUser(userID, c.Param("name"), payload.URL, payload.UpdateInterval); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -190,13 +221,14 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) var payload struct { URL string `json:"url"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.RemoveCollectionSource(c.Param("name"), payload.URL); err != nil { + if err := svc.RemoveCollectionSourceForUser(userID, c.Param("name"), payload.URL); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -207,12 +239,13 @@ func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFu func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) entryParam := c.Param("*") entry, err := url.PathUnescape(entryParam) if err != nil { entry = entryParam } - fpath, err := svc.GetCollectionEntryFilePath(c.Param("name"), entry) + fpath, err := svc.GetCollectionEntryFilePathForUser(userID, c.Param("name"), entry) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -226,7 +259,8 @@ func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.Handle func ListCollectionSourcesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - sources, err := svc.ListCollectionSources(c.Param("name")) + userID := effectiveUserID(c) + sources, err := svc.ListCollectionSourcesForUser(userID, c.Param("name")) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) diff --git a/core/http/endpoints/localai/agent_jobs.go b/core/http/endpoints/localai/agent_jobs.go index c46a0208a10f..8ed20d7df446 100644 --- a/core/http/endpoints/localai/agent_jobs.go +++ b/core/http/endpoints/localai/agent_jobs.go @@ -8,19 +8,27 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" ) -// CreateTaskEndpoint creates a new agent task -// @Summary Create a new agent task -// @Description Create a new reusable agent task with prompt template and configuration -// @Tags agent-jobs -// @Accept json -// @Produce json -// @Param task body schema.Task true "Task definition" -// @Success 201 {object} map[string]string "Task created" -// @Failure 400 {object} map[string]string "Invalid request" -// @Failure 500 {object} map[string]string "Internal server error" -// @Router /api/agent/tasks [post] +// getJobService returns the job service for the current user. +// Falls back to the global service when no user is authenticated. +func getJobService(app *application.Application, c echo.Context) *services.AgentJobService { + userID := getUserID(c) + if userID == "" { + return app.AgentJobService() + } + svc := app.AgentPoolService() + if svc == nil { + return app.AgentJobService() + } + jobSvc, err := svc.JobServiceForUser(userID) + if err != nil { + return app.AgentJobService() + } + return jobSvc +} + func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var task schema.Task @@ -28,7 +36,7 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()}) } - id, err := app.AgentJobService().CreateTask(task) + id, err := getJobService(app, c).CreateTask(task) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -37,18 +45,6 @@ func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc { } } -// UpdateTaskEndpoint updates an existing task -// @Summary Update an agent task -// @Description Update an existing agent task -// @Tags agent-jobs -// @Accept json -// @Produce json -// @Param id path string true "Task ID" -// @Param task body schema.Task true "Updated task definition" -// @Success 200 {object} map[string]string "Task updated" -// @Failure 400 {object} map[string]string "Invalid request" -// @Failure 404 {object} map[string]string "Task not found" -// @Router /api/agent/tasks/{id} [put] func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") @@ -57,7 +53,7 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()}) } - if err := app.AgentJobService().UpdateTask(id, task); err != nil { + if err := getJobService(app, c).UpdateTask(id, task); err != nil { if err.Error() == "task not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -68,19 +64,10 @@ func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc { } } -// DeleteTaskEndpoint deletes a task -// @Summary Delete an agent task -// @Description Delete an agent task by ID -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Task ID" -// @Success 200 {object} map[string]string "Task deleted" -// @Failure 404 {object} map[string]string "Task not found" -// @Router /api/agent/tasks/{id} [delete] func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - if err := app.AgentJobService().DeleteTask(id); err != nil { + if err := getJobService(app, c).DeleteTask(id); err != nil { if err.Error() == "task not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -91,33 +78,52 @@ func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc { } } -// ListTasksEndpoint lists all tasks -// @Summary List all agent tasks -// @Description Get a list of all agent tasks -// @Tags agent-jobs -// @Produce json -// @Success 200 {array} schema.Task "List of tasks" -// @Router /api/agent/tasks [get] func ListTasksEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { - tasks := app.AgentJobService().ListTasks() + jobSvc := getJobService(app, c) + tasks := jobSvc.ListTasks() + + // Admin cross-user aggregation + if wantsAllUsers(c) { + svc := app.AgentPoolService() + if svc != nil { + usm := svc.UserServicesManager() + if usm != nil { + userID := getUserID(c) + userIDs, _ := usm.ListAllUserIDs() + userGroups := map[string]any{} + for _, uid := range userIDs { + if uid == userID { + continue + } + userJobSvc, err := svc.JobServiceForUser(uid) + if err != nil { + continue + } + userTasks := userJobSvc.ListTasks() + if len(userTasks) == 0 { + continue + } + userGroups[uid] = map[string]any{"tasks": userTasks} + } + if len(userGroups) > 0 { + return c.JSON(http.StatusOK, map[string]any{ + "tasks": tasks, + "user_groups": userGroups, + }) + } + } + } + } + return c.JSON(http.StatusOK, tasks) } } -// GetTaskEndpoint gets a task by ID -// @Summary Get an agent task -// @Description Get an agent task by ID -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Task ID" -// @Success 200 {object} schema.Task "Task details" -// @Failure 404 {object} map[string]string "Task not found" -// @Router /api/agent/tasks/{id} [get] func GetTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - task, err := app.AgentJobService().GetTask(id) + task, err := getJobService(app, c).GetTask(id) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -126,16 +132,6 @@ func GetTaskEndpoint(app *application.Application) echo.HandlerFunc { } } -// ExecuteJobEndpoint executes a job -// @Summary Execute an agent job -// @Description Create and execute a new agent job -// @Tags agent-jobs -// @Accept json -// @Produce json -// @Param request body schema.JobExecutionRequest true "Job execution request" -// @Success 201 {object} schema.JobExecutionResponse "Job created" -// @Failure 400 {object} map[string]string "Invalid request" -// @Router /api/agent/jobs/execute [post] func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var req schema.JobExecutionRequest @@ -147,7 +143,6 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { req.Parameters = make(map[string]string) } - // Build multimedia struct from request var multimedia *schema.MultimediaAttachment if len(req.Images) > 0 || len(req.Videos) > 0 || len(req.Audios) > 0 || len(req.Files) > 0 { multimedia = &schema.MultimediaAttachment{ @@ -158,7 +153,7 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { } } - jobID, err := app.AgentJobService().ExecuteJob(req.TaskID, req.Parameters, "api", multimedia) + jobID, err := getJobService(app, c).ExecuteJob(req.TaskID, req.Parameters, "api", multimedia) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -172,19 +167,10 @@ func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { } } -// GetJobEndpoint gets a job by ID -// @Summary Get an agent job -// @Description Get an agent job by ID -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Job ID" -// @Success 200 {object} schema.Job "Job details" -// @Failure 404 {object} map[string]string "Job not found" -// @Router /api/agent/jobs/{id} [get] func GetJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - job, err := app.AgentJobService().GetJob(id) + job, err := getJobService(app, c).GetJob(id) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -193,16 +179,6 @@ func GetJobEndpoint(app *application.Application) echo.HandlerFunc { } } -// ListJobsEndpoint lists jobs with optional filtering -// @Summary List agent jobs -// @Description Get a list of agent jobs, optionally filtered by task_id and status -// @Tags agent-jobs -// @Produce json -// @Param task_id query string false "Filter by task ID" -// @Param status query string false "Filter by status (pending, running, completed, failed, cancelled)" -// @Param limit query int false "Limit number of results" -// @Success 200 {array} schema.Job "List of jobs" -// @Router /api/agent/jobs [get] func ListJobsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var taskID *string @@ -224,25 +200,50 @@ func ListJobsEndpoint(app *application.Application) echo.HandlerFunc { } } - jobs := app.AgentJobService().ListJobs(taskID, status, limit) + jobSvc := getJobService(app, c) + jobs := jobSvc.ListJobs(taskID, status, limit) + + // Admin cross-user aggregation + if wantsAllUsers(c) { + svc := app.AgentPoolService() + if svc != nil { + usm := svc.UserServicesManager() + if usm != nil { + userID := getUserID(c) + userIDs, _ := usm.ListAllUserIDs() + userGroups := map[string]any{} + for _, uid := range userIDs { + if uid == userID { + continue + } + userJobSvc, err := svc.JobServiceForUser(uid) + if err != nil { + continue + } + userJobs := userJobSvc.ListJobs(taskID, status, limit) + if len(userJobs) == 0 { + continue + } + userGroups[uid] = map[string]any{"jobs": userJobs} + } + if len(userGroups) > 0 { + return c.JSON(http.StatusOK, map[string]any{ + "jobs": jobs, + "user_groups": userGroups, + }) + } + } + } + } + return c.JSON(http.StatusOK, jobs) } } -// CancelJobEndpoint cancels a running job -// @Summary Cancel an agent job -// @Description Cancel a running or pending agent job -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Job ID" -// @Success 200 {object} map[string]string "Job cancelled" -// @Failure 400 {object} map[string]string "Job cannot be cancelled" -// @Failure 404 {object} map[string]string "Job not found" -// @Router /api/agent/jobs/{id}/cancel [post] func CancelJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - if err := app.AgentJobService().CancelJob(id); err != nil { + if err := getJobService(app, c).CancelJob(id); err != nil { if err.Error() == "job not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -253,19 +254,10 @@ func CancelJobEndpoint(app *application.Application) echo.HandlerFunc { } } -// DeleteJobEndpoint deletes a job -// @Summary Delete an agent job -// @Description Delete an agent job by ID -// @Tags agent-jobs -// @Produce json -// @Param id path string true "Job ID" -// @Success 200 {object} map[string]string "Job deleted" -// @Failure 404 {object} map[string]string "Job not found" -// @Router /api/agent/jobs/{id} [delete] func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") - if err := app.AgentJobService().DeleteJob(id); err != nil { + if err := getJobService(app, c).DeleteJob(id); err != nil { if err.Error() == "job not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -276,52 +268,33 @@ func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc { } } -// ExecuteTaskByNameEndpoint executes a task by name -// @Summary Execute a task by name -// @Description Execute an agent task by its name (convenience endpoint). Parameters can be provided in the request body as a JSON object with string values. -// @Tags agent-jobs -// @Accept json -// @Produce json -// @Param name path string true "Task name" -// @Param request body map[string]string false "Template parameters (JSON object with string values)" -// @Success 201 {object} schema.JobExecutionResponse "Job created" -// @Failure 400 {object} map[string]string "Invalid request" -// @Failure 404 {object} map[string]string "Task not found" -// @Router /api/agent/tasks/{name}/execute [post] func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { name := c.Param("name") var params map[string]string - // Try to bind parameters from request body - // If body is empty or invalid, use empty params if c.Request().ContentLength > 0 { if err := c.Bind(¶ms); err != nil { - // If binding fails, try to read as raw JSON body := make(map[string]interface{}) if err := c.Bind(&body); err == nil { - // Convert interface{} values to strings params = make(map[string]string) for k, v := range body { if str, ok := v.(string); ok { params[k] = str } else { - // Convert non-string values to string params[k] = fmt.Sprintf("%v", v) } } } else { - // If all binding fails, use empty params params = make(map[string]string) } } } else { - // No body provided, use empty params params = make(map[string]string) } - // Find task by name - tasks := app.AgentJobService().ListTasks() + jobSvc := getJobService(app, c) + tasks := jobSvc.ListTasks() var task *schema.Task for _, t := range tasks { if t.Name == name { @@ -334,7 +307,7 @@ func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusNotFound, map[string]string{"error": "Task not found: " + name}) } - jobID, err := app.AgentJobService().ExecuteJob(task.ID, params, "api", nil) + jobID, err := jobSvc.ExecuteJob(task.ID, params, "api", nil) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } diff --git a/core/http/endpoints/localai/agent_skills.go b/core/http/endpoints/localai/agent_skills.go index 0a9d998c4ac0..2256db2bbaa5 100644 --- a/core/http/endpoints/localai/agent_skills.go +++ b/core/http/endpoints/localai/agent_skills.go @@ -44,10 +44,38 @@ func skillsToResponses(skills []skilldomain.Skill) []skillResponse { func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - skills, err := svc.ListSkills() + userID := getUserID(c) + skills, err := svc.ListSkillsForUser(userID) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } + + // Admin cross-user aggregation + if wantsAllUsers(c) { + usm := svc.UserServicesManager() + if usm != nil { + userIDs, _ := usm.ListAllUserIDs() + userGroups := map[string]any{} + for _, uid := range userIDs { + if uid == userID { + continue + } + userSkills, err := svc.ListSkillsForUser(uid) + if err != nil || len(userSkills) == 0 { + continue + } + userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)} + } + resp := map[string]any{ + "skills": skillsToResponses(skills), + } + if len(userGroups) > 0 { + resp["user_groups"] = userGroups + } + return c.JSON(http.StatusOK, resp) + } + } + return c.JSON(http.StatusOK, skillsToResponses(skills)) } } @@ -55,7 +83,8 @@ func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - cfg := svc.GetSkillsConfig() + userID := getUserID(c) + cfg := svc.GetSkillsConfigForUser(userID) return c.JSON(http.StatusOK, cfg) } } @@ -63,8 +92,9 @@ func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc { func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) query := c.QueryParam("q") - skills, err := svc.SearchSkills(query) + skills, err := svc.SearchSkillsForUser(userID, query) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } @@ -75,6 +105,7 @@ func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc { func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Name string `json:"name"` Description string `json:"description"` @@ -87,7 +118,7 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.CreateSkill(payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) + skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) if err != nil { if strings.Contains(err.Error(), "already exists") { return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()}) @@ -101,7 +132,8 @@ func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - skill, err := svc.GetSkill(c.Param("name")) + userID := effectiveUserID(c) + skill, err := svc.GetSkillForUser(userID, c.Param("name")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -112,6 +144,7 @@ func GetSkillEndpoint(app *application.Application) echo.HandlerFunc { func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) var payload struct { Description string `json:"description"` Content string `json:"content"` @@ -123,7 +156,7 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.UpdateSkill(c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) + skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -137,7 +170,8 @@ func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.DeleteSkill(c.Param("name")); err != nil { + userID := effectiveUserID(c) + if err := svc.DeleteSkillForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -147,9 +181,9 @@ func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc { func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - // The wildcard param captures the path after /export/ + userID := effectiveUserID(c) name := c.Param("*") - data, err := svc.ExportSkill(name) + data, err := svc.ExportSkillForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -162,6 +196,7 @@ func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc { func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"}) @@ -175,7 +210,7 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - skill, err := svc.ImportSkill(data) + skill, err := svc.ImportSkillForUser(userID, data) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -188,7 +223,8 @@ func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - resources, skill, err := svc.ListSkillResources(c.Param("name")) + userID := effectiveUserID(c) + resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -225,7 +261,8 @@ func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc { func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - content, info, err := svc.GetSkillResource(c.Param("name"), c.Param("*")) + userID := effectiveUserID(c) + content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -245,6 +282,7 @@ func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"}) @@ -262,7 +300,7 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } - if err := svc.CreateSkillResource(c.Param("name"), path, data); err != nil { + if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"path": path}) @@ -272,13 +310,14 @@ func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { Content string `json:"content"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.UpdateSkillResource(c.Param("name"), c.Param("*"), payload.Content); err != nil { + if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -288,7 +327,8 @@ func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.DeleteSkillResource(c.Param("name"), c.Param("*")); err != nil { + userID := getUserID(c) + if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -300,7 +340,8 @@ func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - repos, err := svc.ListGitRepos() + userID := getUserID(c) + repos, err := svc.ListGitReposForUser(userID) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } @@ -311,13 +352,14 @@ func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc { func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { URL string `json:"url"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - repo, err := svc.AddGitRepo(payload.URL) + repo, err := svc.AddGitRepoForUser(userID, payload.URL) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } @@ -328,6 +370,7 @@ func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var payload struct { URL string `json:"url"` Enabled *bool `json:"enabled"` @@ -335,7 +378,7 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - repo, err := svc.UpdateGitRepo(c.Param("id"), payload.URL, payload.Enabled) + repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -349,7 +392,8 @@ func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.DeleteGitRepo(c.Param("id")); err != nil { + userID := getUserID(c) + if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -362,7 +406,8 @@ func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.SyncGitRepo(c.Param("id")); err != nil { + userID := getUserID(c) + if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"}) @@ -372,7 +417,8 @@ func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc { func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - repo, err := svc.ToggleGitRepo(c.Param("id")) + userID := getUserID(c) + repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id")) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } diff --git a/core/http/endpoints/localai/agents.go b/core/http/endpoints/localai/agents.go index 5226f7edfc78..d2bc25c48e30 100644 --- a/core/http/endpoints/localai/agents.go +++ b/core/http/endpoints/localai/agents.go @@ -12,6 +12,7 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAGI/core/state" @@ -19,10 +20,42 @@ import ( agiServices "github.com/mudler/LocalAGI/services" ) +// getUserID extracts the scoped user ID from the request context. +// Returns empty string when auth is not active (backward compat). +func getUserID(c echo.Context) string { + user := auth.GetUser(c) + if user == nil { + return "" + } + return user.ID +} + +// isAdminUser returns true if the authenticated user has admin role. +func isAdminUser(c echo.Context) bool { + user := auth.GetUser(c) + return user != nil && user.Role == auth.RoleAdmin +} + +// wantsAllUsers returns true if the request has ?all_users=true and the user is admin. +func wantsAllUsers(c echo.Context) bool { + return c.QueryParam("all_users") == "true" && isAdminUser(c) +} + +// effectiveUserID returns the user ID to scope operations to. +// SECURITY: Only admins may supply ?user_id= to operate on another user's +// resources. Non-admin callers always get their own ID regardless of query params. +func effectiveUserID(c echo.Context) string { + if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) { + return targetUID + } + return getUserID(c) +} + func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - statuses := svc.ListAgents() + userID := getUserID(c) + statuses := svc.ListAgentsForUser(userID) agents := make([]string, 0, len(statuses)) for name := range statuses { agents = append(agents, name) @@ -38,6 +71,22 @@ func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { if hubURL := svc.AgentHubURL(); hubURL != "" { resp["agent_hub_url"] = hubURL } + + // Admin cross-user aggregation + if wantsAllUsers(c) { + grouped := svc.ListAllAgentsGrouped() + userGroups := map[string]any{} + for uid, agentList := range grouped { + if uid == userID || uid == "" { + continue + } + userGroups[uid] = map[string]any{"agents": agentList} + } + if len(userGroups) > 0 { + resp["user_groups"] = userGroups + } + } + return c.JSON(http.StatusOK, resp) } } @@ -45,11 +94,12 @@ func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) var cfg state.AgentConfig if err := c.Bind(&cfg); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.CreateAgent(&cfg); err != nil { + if err := svc.CreateAgentForUser(userID, &cfg); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) @@ -59,8 +109,9 @@ func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") - ag := svc.GetAgent(name) + ag := svc.GetAgentForUser(userID, name) if ag == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } @@ -73,12 +124,13 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc { func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") var cfg state.AgentConfig if err := c.Bind(&cfg); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.UpdateAgent(name, &cfg); err != nil { + if err := svc.UpdateAgentForUser(userID, name, &cfg); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -91,8 +143,9 @@ func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc { func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") - if err := svc.DeleteAgent(name); err != nil { + if err := svc.DeleteAgentForUser(userID, name); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -102,8 +155,9 @@ func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") - cfg := svc.GetAgentConfig(name) + cfg := svc.GetAgentConfigForUser(userID, name) if cfg == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } @@ -114,7 +168,8 @@ func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc { func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.PauseAgent(c.Param("name")); err != nil { + userID := effectiveUserID(c) + if err := svc.PauseAgentForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -124,7 +179,8 @@ func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc { func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() - if err := svc.ResumeAgent(c.Param("name")); err != nil { + userID := effectiveUserID(c) + if err := svc.ResumeAgentForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) @@ -134,8 +190,9 @@ func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") - history := svc.GetAgentStatus(name) + history := svc.GetAgentStatusForUser(userID, name) if history == nil { history = &state.Status{ActionResults: []coreTypes.ActionState{}} } @@ -162,8 +219,9 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc { func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") - history, err := svc.GetAgentObservables(name) + history, err := svc.GetAgentObservablesForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -177,8 +235,9 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") - if err := svc.ClearAgentObservables(name); err != nil { + if err := svc.ClearAgentObservablesForUser(userID, name); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]any{"Name": name, "cleared": true}) @@ -188,6 +247,7 @@ func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFun func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") var payload struct { Message string `json:"message"` @@ -199,7 +259,7 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc { if message == "" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Message cannot be empty"}) } - messageID, err := svc.Chat(name, message) + messageID, err := svc.ChatForUser(userID, name, message) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) @@ -216,8 +276,9 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc { func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") - manager := svc.GetSSEManager(name) + manager := svc.GetSSEManagerForUser(userID, name) if manager == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } @@ -243,8 +304,9 @@ func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc { func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := effectiveUserID(c) name := c.Param("name") - data, err := svc.ExportAgent(name) + data, err := svc.ExportAgentForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } @@ -256,6 +318,7 @@ func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc { func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() + userID := getUserID(c) // Try multipart form file first file, err := c.FormFile("file") @@ -269,7 +332,7 @@ func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc { if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to read file"}) } - if err := svc.ImportAgent(data); err != nil { + if err := svc.ImportAgentForUser(userID, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) @@ -284,7 +347,7 @@ func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc { if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } - if err := svc.ImportAgent(data); err != nil { + if err := svc.ImportAgentForUser(userID, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) @@ -358,10 +421,16 @@ func AgentFileEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusNotFound, map[string]string{"error": "file not found"}) } - // Only serve files from the outputs subdirectory - outputsDir, _ := filepath.EvalSymlinks(filepath.Clean(svc.OutputsDir())) + // Determine the allowed outputs directory — scoped to the user when auth is active + allowedDir := svc.OutputsDir() + user := auth.GetUser(c) + if user != nil { + allowedDir = filepath.Join(allowedDir, user.ID) + } + + allowedDirResolved, _ := filepath.EvalSymlinks(filepath.Clean(allowedDir)) - if utils.InTrustedRoot(resolved, outputsDir) != nil { + if utils.InTrustedRoot(resolved, allowedDirResolved) != nil { return c.JSON(http.StatusForbidden, map[string]string{"error": "access denied"}) } diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go index 47501dd934f8..1f722bacf90e 100644 --- a/core/http/endpoints/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -3,16 +3,22 @@ package openai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" model "github.com/mudler/LocalAI/pkg/model" + "gorm.io/gorm" ) // ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models // @Summary List and describe the various models available in the API. // @Success 200 {object} schema.ModelsDataResponse "Response" // @Router /v1/models [get] -func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { +func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, db ...*gorm.DB) echo.HandlerFunc { + var authDB *gorm.DB + if len(db) > 0 { + authDB = db[0] + } return func(c echo.Context) error { // If blank, no filter is applied. filter := c.QueryParam("filter") @@ -36,6 +42,26 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap return err } + // Filter models by user's allowlist if auth is enabled + if authDB != nil { + if user := auth.GetUser(c); user != nil && user.Role != auth.RoleAdmin { + perm, err := auth.GetCachedUserPermissions(c, authDB, user.ID) + if err == nil && perm.AllowedModels.Enabled { + allowed := map[string]bool{} + for _, m := range perm.AllowedModels.Models { + allowed[m] = true + } + filtered := make([]string, 0, len(modelNames)) + for _, m := range modelNames { + if allowed[m] { + filtered = append(filtered, m) + } + } + modelNames = filtered + } + } + } + // Map from a slice of names to a slice of OpenAIModel response objects dataModels := []schema.OpenAIModel{} for _, m := range modelNames { diff --git a/core/http/middleware/trace.go b/core/http/middleware/trace.go index 800b824c8789..22049083d266 100644 --- a/core/http/middleware/trace.go +++ b/core/http/middleware/trace.go @@ -2,15 +2,16 @@ package middleware import ( "bytes" - "github.com/emirpasic/gods/v2/queues/circularbuffer" "io" "net/http" "sort" "sync" "time" + "github.com/emirpasic/gods/v2/queues/circularbuffer" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/xlog" ) @@ -33,6 +34,8 @@ type APIExchange struct { Request APIExchangeRequest `json:"request"` Response APIExchangeResponse `json:"response"` Error string `json:"error,omitempty"` + UserID string `json:"user_id,omitempty"` + UserName string `json:"user_name,omitempty"` } var traceBuffer *circularbuffer.Queue[APIExchange] @@ -147,6 +150,11 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc { exchange.Error = handlerErr.Error() } + if user := auth.GetUser(c); user != nil { + exchange.UserID = user.ID + exchange.UserName = user.Name + } + select { case logChan <- exchange: default: diff --git a/core/http/middleware/usage.go b/core/http/middleware/usage.go new file mode 100644 index 000000000000..b82c1ee3f506 --- /dev/null +++ b/core/http/middleware/usage.go @@ -0,0 +1,185 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "sync" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/xlog" + "gorm.io/gorm" +) + +const ( + usageFlushInterval = 5 * time.Second + usageMaxPending = 5000 +) + +// usageBatcher accumulates usage records and flushes them to the DB periodically. +type usageBatcher struct { + mu sync.Mutex + pending []*auth.UsageRecord + db *gorm.DB +} + +func (b *usageBatcher) add(r *auth.UsageRecord) { + b.mu.Lock() + b.pending = append(b.pending, r) + b.mu.Unlock() +} + +func (b *usageBatcher) flush() { + b.mu.Lock() + batch := b.pending + b.pending = nil + b.mu.Unlock() + + if len(batch) == 0 { + return + } + + if err := b.db.Create(&batch).Error; err != nil { + xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err) + // Re-queue failed records with a cap to avoid unbounded growth + b.mu.Lock() + if len(b.pending) < usageMaxPending { + b.pending = append(batch, b.pending...) + } + b.mu.Unlock() + } +} + +var batcher *usageBatcher + +// InitUsageRecorder starts a background goroutine that periodically flushes +// accumulated usage records to the database. +func InitUsageRecorder(db *gorm.DB) { + if db == nil { + return + } + batcher = &usageBatcher{db: db} + go func() { + ticker := time.NewTicker(usageFlushInterval) + defer ticker.Stop() + for range ticker.C { + batcher.flush() + } + }() +} + +// usageResponseBody is the minimal structure we need from the response JSON. +type usageResponseBody struct { + Model string `json:"model"` + Usage *struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + } `json:"usage"` +} + +// UsageMiddleware extracts token usage from OpenAI-compatible response JSON +// and records it per-user. +func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if db == nil || batcher == nil { + return next(c) + } + + startTime := time.Now() + + // Wrap response writer to capture body + resBody := new(bytes.Buffer) + origWriter := c.Response().Writer + mw := &bodyWriter{ + ResponseWriter: origWriter, + body: resBody, + } + c.Response().Writer = mw + + handlerErr := next(c) + + // Restore original writer + c.Response().Writer = origWriter + + // Only record on successful responses + if c.Response().Status < 200 || c.Response().Status >= 300 { + return handlerErr + } + + // Get authenticated user + user := auth.GetUser(c) + if user == nil { + return handlerErr + } + + // Try to parse usage from response + responseBytes := resBody.Bytes() + if len(responseBytes) == 0 { + return handlerErr + } + + // Check content type + ct := c.Response().Header().Get("Content-Type") + isJSON := ct == "" || ct == "application/json" || bytes.HasPrefix([]byte(ct), []byte("application/json")) + isSSE := bytes.HasPrefix([]byte(ct), []byte("text/event-stream")) + + if !isJSON && !isSSE { + return handlerErr + } + + var resp usageResponseBody + if isSSE { + last, ok := lastSSEData(responseBytes) + if !ok { + return handlerErr + } + if err := json.Unmarshal(last, &resp); err != nil { + return handlerErr + } + } else { + if err := json.Unmarshal(responseBytes, &resp); err != nil { + return handlerErr + } + } + + if resp.Usage == nil { + return handlerErr + } + + record := &auth.UsageRecord{ + UserID: user.ID, + UserName: user.Name, + Model: resp.Model, + Endpoint: c.Request().URL.Path, + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + Duration: time.Since(startTime).Milliseconds(), + CreatedAt: startTime, + } + + batcher.add(record) + + return handlerErr + } + } +} + +// lastSSEData returns the payload of the last "data: " line whose content is not "[DONE]". +func lastSSEData(b []byte) ([]byte, bool) { + prefix := []byte("data: ") + var last []byte + for _, line := range bytes.Split(b, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + if bytes.HasPrefix(line, prefix) { + payload := line[len(prefix):] + if !bytes.Equal(payload, []byte("[DONE]")) { + last = payload + } + } + } + return last, last != nil +} diff --git a/core/http/react-ui/e2e/navigation.spec.js b/core/http/react-ui/e2e/navigation.spec.js index 99b5bb91954d..d61b290af397 100644 --- a/core/http/react-ui/e2e/navigation.spec.js +++ b/core/http/react-ui/e2e/navigation.spec.js @@ -9,7 +9,7 @@ test.describe('Navigation', () => { test('/app shows home page with LocalAI title', async ({ page }) => { await page.goto('/app') await expect(page.locator('.sidebar')).toBeVisible() - await expect(page.getByRole('heading', { name: 'How can I help you today?' })).toBeVisible() + await expect(page.locator('.home-page')).toBeVisible() }) test('sidebar traces link navigates to /app/traces', async ({ page }) => { diff --git a/core/http/react-ui/src/App.css b/core/http/react-ui/src/App.css index d0f44789baed..16132da70a2c 100644 --- a/core/http/react-ui/src/App.css +++ b/core/http/react-ui/src/App.css @@ -142,6 +142,7 @@ box-shadow: var(--shadow-sidebar); transition: width var(--duration-normal) var(--ease-default), transform var(--duration-normal) var(--ease-default); + will-change: transform; } .sidebar-overlay { @@ -244,6 +245,7 @@ flex: 1; overflow: hidden; text-overflow: ellipsis; + transition: opacity 150ms ease; } .nav-external { @@ -260,6 +262,92 @@ align-items: center; justify-content: space-between; gap: var(--spacing-xs); + flex-wrap: wrap; +} + +.sidebar-user { + display: flex; + align-items: center; + gap: var(--spacing-xs); + width: 100%; + padding: var(--spacing-xs) 0; + font-size: 0.75rem; + color: var(--color-text-secondary); + overflow: hidden; +} + +.sidebar-user-avatar { + width: 20px; + height: 20px; + border-radius: var(--radius-full); + flex-shrink: 0; +} + +.sidebar-user-avatar-icon { + font-size: 1.25rem; + color: var(--color-text-muted); + flex-shrink: 0; +} + +.sidebar-user-link { + display: flex; + align-items: center; + gap: var(--spacing-xs); + flex: 1; + min-width: 0; + background: none; + border: none; + padding: 2px var(--spacing-xs); + margin: -2px calc(-1 * var(--spacing-xs)); + border-radius: var(--radius-sm); + color: inherit; + font: inherit; + cursor: pointer; + transition: background var(--duration-fast), color var(--duration-fast); +} + +.sidebar-user-link:hover { + background: var(--color-bg-hover); + color: var(--color-text-primary); +} + +.sidebar-user-name { + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + text-align: left; +} + +.sidebar-logout-btn { + background: none; + border: none; + color: var(--color-text-muted); + cursor: pointer; + padding: 2px 4px; + border-radius: var(--radius-sm); + font-size: 0.75rem; + flex-shrink: 0; + transition: color var(--duration-fast); +} + +.sidebar-logout-btn:hover { + color: var(--color-error); +} + +.sidebar.collapsed .sidebar-user { + justify-content: center; +} + +.sidebar.collapsed .sidebar-user-link { + flex: 0; + margin: 0; + padding: 2px; +} + +.sidebar.collapsed .sidebar-user-name, +.sidebar.collapsed .sidebar-logout-btn { + display: none; } .sidebar-collapse-btn { @@ -423,7 +511,7 @@ position: fixed; top: var(--spacing-lg); right: var(--spacing-lg); - z-index: 100; + z-index: 1100; display: flex; flex-direction: column; gap: var(--spacing-sm); @@ -446,24 +534,30 @@ transform: translateX(20px); } +.toast-exit { + opacity: 0; + transform: translateX(20px); + transition: opacity 150ms ease, transform 150ms ease; +} + .toast-success { - background: rgba(20, 184, 166, 0.15); - border: 1px solid rgba(20, 184, 166, 0.3); + background: var(--color-success-light); + border: 1px solid var(--color-success-border); color: var(--color-success); } .toast-error { - background: rgba(239, 68, 68, 0.15); - border: 1px solid rgba(239, 68, 68, 0.3); + background: var(--color-error-light); + border: 1px solid var(--color-error-border); color: var(--color-error); } .toast-warning { - background: rgba(245, 158, 11, 0.15); - border: 1px solid rgba(245, 158, 11, 0.3); + background: var(--color-warning-light); + border: 1px solid var(--color-warning-border); color: var(--color-warning); } .toast-info { - background: rgba(56, 189, 248, 0.15); - border: 1px solid rgba(56, 189, 248, 0.3); + background: var(--color-info-light); + border: 1px solid var(--color-info-border); color: var(--color-info); } @@ -494,6 +588,14 @@ .spinner-md .spinner-ring { width: 24px; height: 24px; } .spinner-lg .spinner-ring { width: 40px; height: 40px; } +.spinner-logo { + animation: pulse 1.2s ease-in-out infinite; + object-fit: contain; +} +.spinner-sm .spinner-logo { width: 16px; height: 16px; } +.spinner-md .spinner-logo { width: 24px; height: 24px; } +.spinner-lg .spinner-logo { width: 40px; height: 40px; } + /* Model selector */ .model-selector { background: var(--color-bg-tertiary); @@ -623,6 +725,7 @@ max-width: 1200px; margin: 0 auto; width: 100%; + animation: fadeIn var(--duration-normal) var(--ease-default); } .page-header { @@ -646,11 +749,13 @@ border: 1px solid var(--color-border-subtle); border-radius: var(--radius-lg); padding: var(--spacing-md); - transition: border-color var(--duration-fast), box-shadow var(--duration-fast); + transition: border-color var(--duration-fast), box-shadow var(--duration-fast), transform var(--duration-fast); } .card:hover { border-color: var(--color-border-default); + box-shadow: var(--shadow-sm); + transform: translateY(-1px); } .card-grid { @@ -671,7 +776,7 @@ font-weight: 500; cursor: pointer; border: none; - transition: all var(--duration-fast) var(--ease-default); + transition: background var(--duration-fast) var(--ease-default), color var(--duration-fast) var(--ease-default), border-color var(--duration-fast) var(--ease-default), box-shadow var(--duration-fast) var(--ease-default); text-decoration: none; } @@ -707,6 +812,10 @@ font-size: 0.8125rem; } +.btn:active:not(:disabled) { + transform: translateY(1px); +} + .btn:disabled { opacity: 0.5; cursor: not-allowed; @@ -727,6 +836,7 @@ } .input:focus { border-color: var(--color-border-strong); + box-shadow: 0 0 0 2px var(--color-primary-light); } .textarea { @@ -745,6 +855,7 @@ } .textarea:focus { border-color: var(--color-border-strong); + box-shadow: 0 0 0 2px var(--color-primary-light); } /* Code editor (syntax-highlighted textarea overlay) */ @@ -877,6 +988,7 @@ padding: var(--spacing-sm) var(--spacing-md); border-bottom: 1px solid var(--color-border-divider); color: var(--color-text-primary); + transition: background var(--duration-fast) var(--ease-default); } .table tr:last-child td { @@ -933,6 +1045,103 @@ background: white; } +/* Model checkbox list */ +.model-list { + display: flex; + flex-direction: column; + gap: 2px; + max-height: 200px; + overflow: auto; + padding: var(--spacing-xs); + background: var(--color-bg-secondary); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-md); +} + +.model-list::-webkit-scrollbar { + width: 6px; +} + +.model-list::-webkit-scrollbar-track { + background: transparent; +} + +.model-list::-webkit-scrollbar-thumb { + background: var(--color-border-default); + border-radius: var(--radius-full); +} + +.model-item { + display: flex; + align-items: center; + gap: var(--spacing-sm); + padding: 6px var(--spacing-sm); + cursor: pointer; + border-radius: var(--radius-sm); + transition: background var(--duration-fast) var(--ease-default); + user-select: none; +} + +.model-item:hover { + background: var(--color-primary-light); +} + +.model-item.model-item-checked { + background: var(--color-primary-light); +} + +.model-item input[type="checkbox"] { + display: none; +} + +.model-item-check { + width: 18px; + height: 18px; + border-radius: var(--radius-sm); + border: 2px solid var(--color-border-default); + display: flex; + align-items: center; + justify-content: center; + flex-shrink: 0; + transition: all var(--duration-fast) var(--ease-default); + background: transparent; +} + +.model-item:hover .model-item-check { + border-color: var(--color-primary); +} + +.model-item-checked .model-item-check { + background: var(--color-primary); + border-color: var(--color-primary); + box-shadow: 0 0 0 1px var(--color-primary-light); +} + +.model-item-checked .model-item-check i { + color: white; + font-size: 10px; + animation: checkPop var(--duration-fast) var(--ease-default); +} + +@keyframes checkPop { + 0% { transform: scale(0); } + 60% { transform: scale(1.2); } + 100% { transform: scale(1); } +} + +.model-item-name { + font-family: 'JetBrains Mono', monospace; + font-size: 0.8rem; + color: var(--color-text-primary); + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.model-item-checked .model-item-name { + color: var(--color-primary); +} + /* Collapsible */ .collapsible-header { display: flex; @@ -1039,10 +1248,131 @@ border-color: var(--color-primary-border); } +/* Login page */ +.login-page { + min-height: 100vh; + min-height: 100dvh; + background: var(--color-bg-primary); + display: flex; + align-items: center; + justify-content: center; + padding: var(--spacing-xl); +} + +.login-card { + width: 100%; + max-width: 400px; + padding: var(--spacing-xl); +} + +.login-header { + text-align: center; + margin-bottom: var(--spacing-xl); +} + +.login-logo { + width: 56px; + height: 56px; + margin-bottom: var(--spacing-md); +} + +.login-title { + font-size: 1.5rem; + font-weight: 700; + margin-bottom: var(--spacing-xs); + color: var(--color-text-primary); +} + +.login-subtitle { + color: var(--color-text-secondary); + font-size: 0.875rem; +} + +.login-alert { + padding: var(--spacing-sm) var(--spacing-md); + border-radius: var(--radius-md); + font-size: 0.8125rem; + margin-bottom: var(--spacing-md); +} + +.login-alert-error { + background: var(--color-error-light); + color: var(--color-error); + border: 1px solid var(--color-error-border); +} + +.login-alert-success { + background: var(--color-success-light); + color: var(--color-success); + border: 1px solid var(--color-success-border); +} + +.login-divider { + display: flex; + align-items: center; + gap: var(--spacing-md); + margin: var(--spacing-lg) 0; + color: var(--color-text-muted); + font-size: 0.8125rem; +} + +.login-divider::before, +.login-divider::after { + content: ''; + flex: 1; + height: 1px; + background: var(--color-border-subtle); +} + +.login-footer { + text-align: center; + margin-top: var(--spacing-md); + font-size: 0.8125rem; + color: var(--color-text-secondary); +} + +.login-link { + background: none; + border: none; + color: var(--color-primary); + cursor: pointer; + padding: 0; + font: inherit; +} + +.login-link:hover { + color: var(--color-primary-hover); +} + +.login-token-toggle { + margin-top: var(--spacing-lg); + text-align: center; +} + +.login-token-toggle > button { + background: none; + border: none; + color: var(--color-text-muted); + cursor: pointer; + font-size: 0.75rem; + padding: 0; + font: inherit; + font-size: 0.75rem; +} + +.login-token-toggle > button:hover { + color: var(--color-text-secondary); +} + +.login-token-form { + margin-top: var(--spacing-sm); +} + /* Empty state */ .empty-state { text-align: center; padding: var(--spacing-3xl, 4rem) var(--spacing-xl); + animation: fadeIn var(--duration-normal) var(--ease-default); } .empty-state-icon { @@ -1083,14 +1413,45 @@ 50% { opacity: 0.5; } } -/* Chat-specific styles */ -.chat-layout { +@keyframes messageSlideIn { + from { opacity: 0; transform: translateY(8px); } + to { opacity: 1; transform: translateY(0); } +} + +@keyframes dropdownIn { + from { opacity: 0; transform: translateY(-4px); } + to { opacity: 1; transform: translateY(0); } +} + +@keyframes completionGlow { + 0% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0.2); } + 50% { box-shadow: 0 0 0 4px rgba(59, 130, 246, 0.1); } + 100% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0); } +} + +/* Page route transitions */ +.page-transition { + animation: fadeIn 200ms ease; display: flex; + flex-direction: column; flex: 1; min-height: 0; min-width: 0; - overflow: hidden; - position: relative; +} + +/* Completion glow on streaming finish */ +.chat-message-new .chat-message-content { + animation: completionGlow 600ms ease-out; +} + +/* Chat-specific styles */ +.chat-layout { + display: flex; + flex: 1; + min-height: 0; + min-width: 0; + overflow: hidden; + position: relative; } .chat-sidebar { @@ -1141,12 +1502,13 @@ cursor: pointer; font-size: 0.8125rem; color: var(--color-text-secondary); - transition: all var(--duration-fast); + transition: background var(--duration-fast), color var(--duration-fast), transform var(--duration-fast); margin-bottom: 2px; } .chat-list-item:hover { background: var(--color-primary-light); + transform: translateX(2px); } .chat-list-item.active { @@ -1243,7 +1605,7 @@ gap: var(--spacing-sm); max-width: 80%; min-width: 0; - animation: fadeIn 200ms ease; + animation: messageSlideIn 250ms ease-out; } .chat-message-user { @@ -1413,7 +1775,7 @@ } .chat-input-wrapper:focus-within { border-color: var(--color-primary-border); - box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.1); + box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.12), 0 0 12px rgba(59, 130, 246, 0.06); } .chat-attach-btn { @@ -1471,6 +1833,9 @@ opacity: 0.3; cursor: not-allowed; } +.chat-send-btn:active:not(:disabled) { + transform: scale(0.92); +} .chat-stop-btn { padding: var(--spacing-xs); @@ -1511,6 +1876,48 @@ vertical-align: text-bottom; } +/* Inline streaming speed indicator */ +.chat-streaming-speed { + font-size: 0.6875rem; + color: var(--color-text-muted); + padding-top: var(--spacing-xs); + font-family: 'JetBrains Mono', monospace; + display: flex; + align-items: center; + gap: 4px; +} + +/* Thinking dots animation */ +.chat-thinking-indicator { + display: flex; + align-items: center; + min-height: 24px; +} +.chat-thinking-dots { + display: inline-flex; + gap: 4px; + align-items: center; +} +.chat-thinking-dots span { + width: 6px; + height: 6px; + border-radius: 50%; + background: var(--color-text-muted); + animation: thinkingBounce 1.4s infinite ease-in-out both; +} +.chat-thinking-dots span:nth-child(1) { animation-delay: -0.32s; } +.chat-thinking-dots span:nth-child(2) { animation-delay: -0.16s; } +.chat-thinking-dots span:nth-child(3) { animation-delay: 0s; } +@keyframes thinkingBounce { + 0%, 80%, 100% { transform: scale(0.6); opacity: 0.4; } + 40% { transform: scale(1); opacity: 1; } +} + +/* Message completion flash */ +.chat-message-bubble { + transition: border-color 300ms ease; +} + /* Chat empty state */ .chat-empty-state { display: flex; @@ -1546,6 +1953,30 @@ margin: 0 0 var(--spacing-lg); max-width: 400px; } +.chat-empty-suggestions { + display: flex; + flex-wrap: wrap; + gap: var(--spacing-xs); + justify-content: center; + margin-bottom: var(--spacing-lg); + max-width: 500px; +} +.chat-empty-suggestion { + padding: var(--spacing-xs) var(--spacing-md); + background: var(--color-bg-tertiary); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-full); + font-size: 0.8125rem; + font-family: inherit; + color: var(--color-text-secondary); + cursor: pointer; + transition: all var(--duration-fast); +} +.chat-empty-suggestion:hover { + border-color: var(--color-primary-border); + color: var(--color-primary); + background: var(--color-bg-secondary); +} .chat-empty-hints { display: flex; gap: var(--spacing-md); @@ -1735,13 +2166,13 @@ overflow-y: auto; } .chat-activity-thinking { - border-left-color: rgba(99, 102, 241, 0.3); + border-left-color: rgba(59, 130, 246, 0.3); } .chat-activity-tool-call { - border-left-color: rgba(139, 92, 246, 0.3); + border-left-color: rgba(245, 158, 11, 0.3); } .chat-activity-tool-result { - border-left-color: rgba(20, 184, 166, 0.3); + border-left-color: rgba(34, 197, 94, 0.3); } /* Context window progress bar */ @@ -1837,6 +2268,7 @@ border: 1px solid var(--color-border-subtle); border-radius: var(--radius-md); box-shadow: var(--shadow-lg); + animation: dropdownIn 120ms ease-out; } .chat-mcp-dropdown-loading, .chat-mcp-dropdown-empty { @@ -1912,15 +2344,15 @@ background: var(--color-text-tertiary); } .chat-client-mcp-status-connected { - background: #22c55e; + background: var(--color-success); box-shadow: 0 0 4px rgba(34, 197, 94, 0.5); } .chat-client-mcp-status-connecting { - background: #f59e0b; + background: var(--color-warning); animation: pulse 1s infinite; } .chat-client-mcp-status-error { - background: #ef4444; + background: var(--color-error); } .chat-client-mcp-status-disconnected { background: var(--color-text-tertiary); @@ -2001,6 +2433,7 @@ transform: translateX(100%); transition: transform 250ms var(--ease-default); box-shadow: var(--shadow-lg); + will-change: transform; } .chat-settings-drawer.open { transform: translateX(0); @@ -2095,7 +2528,7 @@ /* Max tokens/sec badge */ .chat-max-tps-badge { - background: rgba(99, 102, 241, 0.15); + background: rgba(59, 130, 246, 0.15); color: var(--color-primary); padding: 1px 6px; border-radius: var(--radius-full); @@ -2149,7 +2582,7 @@ align-items: center; gap: 4px; padding: 2px 6px; - background: rgba(99, 102, 241, 0.1); + background: rgba(59, 130, 246, 0.1); border-radius: var(--radius-sm); font-size: 0.7rem; color: var(--color-text-secondary); @@ -2642,6 +3075,11 @@ flex-direction: column; gap: var(--spacing-xs); } + + .chat-empty-suggestions { + flex-direction: column; + align-items: stretch; + } } /* MCP App Frame */ @@ -2677,3 +3115,408 @@ background: var(--color-bg-secondary); border-top: 1px solid var(--color-border-subtle); } + +/* Confirm Dialog */ +.confirm-dialog-backdrop { + position: fixed; + inset: 0; + z-index: 1050; + display: flex; + align-items: center; + justify-content: center; + background: var(--color-modal-backdrop); + backdrop-filter: blur(4px); + animation: fadeIn 150ms ease; +} +.confirm-dialog { + background: var(--color-bg-secondary); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-lg); + max-width: 420px; + width: 90%; + padding: var(--spacing-lg); + box-shadow: var(--shadow-lg); + animation: slideUp 150ms ease; + will-change: transform, opacity; +} +@keyframes slideUp { + from { opacity: 0; transform: translateY(8px); } + to { opacity: 1; transform: translateY(0); } +} +.confirm-dialog-header { + display: flex; + align-items: center; + gap: var(--spacing-sm); + margin-bottom: var(--spacing-md); +} +.confirm-dialog-danger-icon { + color: var(--color-error); + font-size: 1.125rem; +} +.confirm-dialog-title { + font-size: 1rem; + font-weight: 600; + color: var(--color-text-primary); +} +.confirm-dialog-body { + font-size: 0.875rem; + color: var(--color-text-secondary); + margin-bottom: var(--spacing-lg); + line-height: 1.5; +} +.confirm-dialog-actions { + display: flex; + justify-content: flex-end; + gap: var(--spacing-sm); +} +.btn-danger { + background: var(--color-error); + color: white; + border: none; +} +.btn-danger:hover { + background: var(--color-error-hover, #dc2626); +} + +/* Home page */ +.home-page { + flex: 1; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + max-width: 48rem; + margin: 0 auto; + padding: var(--spacing-xl); + width: 100%; +} +.home-hero { + text-align: center; + padding: var(--spacing-md) 0; +} +.home-logo { + width: 80px; + height: auto; + margin: 0 auto var(--spacing-sm); + display: block; +} + +/* Home resource bar - prominent */ +.home-resource-bar { + width: 100%; + max-width: 320px; + padding: var(--spacing-sm) var(--spacing-md); + background: var(--color-bg-secondary); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-lg); + margin-bottom: var(--spacing-md); +} +.home-resource-bar-header { + display: flex; + align-items: center; + gap: var(--spacing-xs); + font-size: 0.8125rem; + color: var(--color-text-secondary); + margin-bottom: var(--spacing-xs); +} +.home-resource-label { + font-weight: 500; +} +.home-resource-pct { + margin-left: auto; + font-family: 'JetBrains Mono', monospace; + font-weight: 500; +} +.home-resource-track { + width: 100%; + height: 6px; + background: var(--color-bg-tertiary); + border-radius: 3px; + overflow: hidden; +} +.home-resource-fill { + height: 100%; + border-radius: 3px; + transition: width 500ms ease; +} + +/* Home chat card */ +.home-chat-card { + width: 100%; + background: var(--color-bg-secondary); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-lg); + padding: var(--spacing-md); + margin-bottom: var(--spacing-md); +} +.home-model-row { + display: flex; + align-items: center; + gap: var(--spacing-sm); + margin-bottom: var(--spacing-sm); +} +.home-file-tags { + display: flex; + flex-wrap: wrap; + gap: var(--spacing-xs); + margin-bottom: var(--spacing-sm); +} +.home-file-tag { + display: inline-flex; + align-items: center; + gap: 4px; + padding: 2px 8px; + background: var(--color-bg-tertiary); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-full); + font-size: 0.75rem; + color: var(--color-text-secondary); +} +.home-file-tag button { + background: none; + border: none; + color: var(--color-text-muted); + cursor: pointer; + padding: 0; + font-size: 0.625rem; +} + +/* Home input container */ +.home-input-container { + background: var(--color-bg-tertiary); + border: 1px solid var(--color-border-default); + border-radius: var(--radius-lg); + transition: border-color var(--duration-fast), box-shadow var(--duration-fast); +} +.home-input-container:focus-within { + border-color: var(--color-primary-border); + box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.12), 0 0 12px rgba(59, 130, 246, 0.06); +} +.home-textarea { + width: 100%; + background: transparent; + color: var(--color-text-primary); + border: none; + border-radius: var(--radius-lg) var(--radius-lg) 0 0; + padding: var(--spacing-sm) var(--spacing-md); + font-size: 0.875rem; + font-family: inherit; + outline: none; + resize: none; + min-height: 80px; + line-height: 1.5; +} +.home-textarea::placeholder { + color: var(--color-text-muted); +} +.home-input-footer { + display: flex; + align-items: center; + padding: var(--spacing-xs) var(--spacing-sm); + border-top: 1px solid var(--color-border-subtle); +} +.home-attach-buttons { + display: flex; + gap: 2px; +} +.home-attach-btn { + background: none; + border: none; + color: var(--color-text-muted); + cursor: pointer; + padding: 4px 8px; + font-size: 0.875rem; + border-radius: var(--radius-sm); + transition: color var(--duration-fast); +} +.home-attach-btn:hover { + color: var(--color-primary); +} +.home-input-hint { + flex: 1; + text-align: center; + font-size: 0.6875rem; + color: var(--color-text-muted); +} +.home-send-btn { + display: flex; + align-items: center; + justify-content: center; + width: 32px; + height: 32px; + background: var(--color-primary); + color: var(--color-primary-text); + border: none; + border-radius: 50%; + font-size: 0.8125rem; + cursor: pointer; + transition: background var(--duration-fast), transform 100ms; + flex-shrink: 0; +} +.home-send-btn:hover:not(:disabled) { + background: var(--color-primary-hover); + transform: scale(1.05); +} +.home-send-btn:disabled { + opacity: 0.3; + cursor: not-allowed; +} +.home-send-btn:active:not(:disabled) { + transform: scale(0.92); +} + +/* Home quick links */ +.home-quick-links { + display: flex; + flex-wrap: wrap; + gap: var(--spacing-sm); + justify-content: center; + margin: var(--spacing-md) 0; +} +.home-link-btn { + display: inline-flex; + align-items: center; + gap: var(--spacing-xs); + padding: var(--spacing-xs) var(--spacing-md); + background: var(--color-bg-tertiary); + color: var(--color-text-secondary); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-full); + font-size: 0.8125rem; + font-family: inherit; + cursor: pointer; + text-decoration: none; + transition: all var(--duration-fast); +} +.home-link-btn:hover { + border-color: var(--color-primary-border); + color: var(--color-primary); + transform: translateY(-1px); +} + +/* Home loaded models */ +.home-loaded-models { + display: flex; + flex-wrap: wrap; + align-items: center; + gap: var(--spacing-xs); + padding: var(--spacing-sm); + background: var(--color-bg-secondary); + border: 1px solid var(--color-border-subtle); + border-radius: var(--radius-lg); + font-size: 0.8125rem; + color: var(--color-text-secondary); + width: 100%; +} +.home-loaded-dot { + width: 6px; + height: 6px; + border-radius: 50%; + background: var(--color-success); +} +.home-loaded-text { + font-weight: 500; + margin-right: var(--spacing-xs); +} +.home-loaded-list { + display: flex; + flex-wrap: wrap; + gap: var(--spacing-xs); +} +.home-loaded-item { + display: inline-flex; + align-items: center; + gap: 4px; + padding: 2px 8px; + background: var(--color-bg-tertiary); + border-radius: var(--radius-full); + font-size: 0.75rem; +} +.home-loaded-item button { + background: none; + border: none; + color: var(--color-error); + cursor: pointer; + padding: 0; + font-size: 0.625rem; +} +.home-stop-all { + margin-left: auto; + background: none; + border: 1px solid var(--color-error); + color: var(--color-error); + padding: 2px 8px; + border-radius: var(--radius-full); + font-size: 0.75rem; + cursor: pointer; + font-family: inherit; +} + +/* Home wizard (no models) */ +.home-wizard { + max-width: 48rem; + width: 100%; +} +.home-wizard-hero { + text-align: center; + padding: var(--spacing-xl) 0; +} +.home-wizard-hero h1 { + font-size: 1.5rem; + font-weight: 600; + margin-bottom: var(--spacing-sm); +} +.home-wizard-hero p { + color: var(--color-text-secondary); + font-size: 0.9375rem; +} +.home-wizard-steps { + margin-bottom: var(--spacing-xl); +} +.home-wizard-steps h2 { + font-size: 1.125rem; + font-weight: 600; + margin-bottom: var(--spacing-md); +} +.home-wizard-step { + display: flex; + gap: var(--spacing-md); + align-items: flex-start; + padding: var(--spacing-sm) 0; +} +.home-wizard-step-num { + width: 28px; + height: 28px; + border-radius: 50%; + background: var(--color-primary); + color: white; + display: flex; + align-items: center; + justify-content: center; + font-size: 0.8125rem; + font-weight: 600; + flex-shrink: 0; +} +.home-wizard-step strong { + display: block; + margin-bottom: 2px; +} +.home-wizard-step p { + font-size: 0.8125rem; + color: var(--color-text-secondary); + margin: 0; +} +.home-wizard-actions { + display: flex; + gap: var(--spacing-sm); + justify-content: center; +} + +/* Reduced motion accessibility */ +@media (prefers-reduced-motion: reduce) { + *, *::before, *::after { + animation-duration: 0.01ms !important; + animation-iteration-count: 1 !important; + transition-duration: 0.01ms !important; + } +} diff --git a/core/http/react-ui/src/App.jsx b/core/http/react-ui/src/App.jsx index 421441071c96..f06fe788dc11 100644 --- a/core/http/react-ui/src/App.jsx +++ b/core/http/react-ui/src/App.jsx @@ -29,6 +29,11 @@ export default function App() { return () => window.removeEventListener('sidebar-collapse', handler) }, []) + // Scroll to top on route change + useEffect(() => { + window.scrollTo(0, 0) + }, [location.pathname]) + const layoutClasses = [ 'app-layout', isChatRoute ? 'app-layout-chat' : '', @@ -51,7 +56,9 @@ export default function App() { LocalAI
- +
+ +
{!isChatRoute && (