diff --git a/frontend/src/components/settings/EmbeddingFormModal.tsx b/frontend/src/components/settings/EmbeddingFormModal.tsx index 4fc60c0..74e07df 100644 --- a/frontend/src/components/settings/EmbeddingFormModal.tsx +++ b/frontend/src/components/settings/EmbeddingFormModal.tsx @@ -1,6 +1,7 @@ import { useState, useEffect } from "react"; import { Box, Button, TextInput, FormControl, Select, Dialog } from "@primer/react"; import type { EmbeddingModelConfig, LLMProvider } from "../../types"; +import { isLLMProvider, LLM_PROVIDERS } from "../../types"; interface Props { isOpen: boolean; @@ -9,13 +10,6 @@ interface Props { initialData?: EmbeddingModelConfig; } -const PROVIDERS: { value: LLMProvider; label: string }[] = [ - { value: "openai", label: "OpenAI" }, - { value: "anthropic", label: "Anthropic" }, - { value: "gemini", label: "Google Gemini" }, - { value: "ollama", label: "Ollama" }, -]; - const PROVIDER_DEFAULTS: Record< LLMProvider, { endpoint: string; model: string; dimensions?: number } @@ -58,14 +52,18 @@ export default function EmbeddingFormModal({ isOpen, onClose, onSave, initialDat setApiKey(initialData.api_key || ""); setModelName(initialData.model_name); setDimensions(initialData.dimensions?.toString() || ""); - } else { - // set defaults for new model - const defaults = PROVIDER_DEFAULTS[provider]; + } else if (isOpen) { + // set defaults for new model only when opening + const defaultProvider: LLMProvider = "openai"; + const defaults = PROVIDER_DEFAULTS[defaultProvider]; + setName(""); + setProvider(defaultProvider); setEndpoint(defaults.endpoint); setModelName(defaults.model); setDimensions(defaults.dimensions?.toString() || ""); + setApiKey(""); } - }, [initialData, provider]); + }, [isOpen, initialData]); const handleProviderChange = (newProvider: LLMProvider) => { setProvider(newProvider); @@ -92,8 +90,8 @@ export default function EmbeddingFormModal({ isOpen, onClose, onSave, initialDat if (provider !== "ollama" && !apiKey.trim()) { newErrors.apiKey = "api key is required for this provider"; } - if (dimensions && isNaN(parseInt(dimensions))) { - newErrors.dimensions = "dimensions must be a number"; + if (dimensions && (isNaN(Number(dimensions)) || Number(dimensions) < 1)) { + newErrors.dimensions = "dimensions must be a number greater than 0"; } setErrors(newErrors); @@ -110,6 +108,7 @@ export default function EmbeddingFormModal({ isOpen, onClose, onSave, initialDat api_key: apiKey.trim() || null, model_name: modelName.trim(), dimensions: dimensions ? parseInt(dimensions) : null, + is_default: initialData?.is_default ?? false, }; setSaving(true); @@ -163,10 +162,13 @@ export default function EmbeddingFormModal({ isOpen, onClose, onSave, initialDat Provider handleProviderChange(e.target.value as LLMProvider)} + onChange={(e) => { + const val = e.target.value; + if (isLLMProvider(val)) handleProviderChange(val); + }} block > - {PROVIDERS.map((p) => ( + {LLM_PROVIDERS.map((p) => ( {p.label} diff --git a/frontend/src/pages/Settings.tsx b/frontend/src/pages/Settings.tsx index cc5076f..aa12ec9 100644 --- a/frontend/src/pages/Settings.tsx +++ b/frontend/src/pages/Settings.tsx @@ -29,6 +29,9 @@ export default function Settings() { const isMountedRef = useRef(true); useEffect(() => { + // Reset to true on each mount (important for React StrictMode double-mount) + isMountedRef.current = true; + loadLlmModels(); loadEmbeddingModels(); loadLangfuseStatus(); @@ -64,14 +67,10 @@ export default function Settings() { const loadLangfuseStatus = async () => { try { - const res = await fetch("/api/langfuse/status"); - if (!res.ok) { - throw new Error(`http ${res.status}`); - } - const data = await res.json(); + const data = await llmConfigApi.getLangfuseStatus(); if (isMountedRef.current) { setLangfuseEnabled(data.enabled); - setLangfuseHost(data.host); + setLangfuseHost(data.host ?? null); } } catch (error) { const message = error instanceof Error ? error.message : "Unknown error"; @@ -214,7 +213,7 @@ export default function Settings() { loadEmbeddingModels(); } catch (error) { const message = error instanceof Error ? error.message : "Unknown error"; - toast.error(`Failed to save LLM model: ${message}`); + toast.error(`Failed to save embedding model: ${message}`); throw error; } }; diff --git a/frontend/src/services/llmConfigApi.ts b/frontend/src/services/llmConfigApi.ts index 8b1a9dd..fb10cec 100644 --- a/frontend/src/services/llmConfigApi.ts +++ b/frontend/src/services/llmConfigApi.ts @@ -139,6 +139,17 @@ class LLMConfigApi { if (!response.ok) throw new Error(`http ${response.status}`); return response.json(); } + + async getLangfuseStatus(): Promise<{ + enabled: boolean; + host?: string; + public_key?: string; + error?: string; + }> { + const response = await fetch(`${API_BASE}/langfuse/status`); + if (!response.ok) throw new Error(`http ${response.status}`); + return response.json(); + } } export const llmConfigApi = new LLMConfigApi(); diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 5e2a1e1..2ee19fb 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -89,6 +89,16 @@ export interface BlockSchema { export type LLMProvider = "openai" | "anthropic" | "gemini" | "ollama"; +export const LLM_PROVIDERS: { value: LLMProvider; label: string }[] = [ + { value: "openai", label: "OpenAI" }, + { value: "anthropic", label: "Anthropic" }, + { value: "gemini", label: "Google Gemini" }, + { value: "ollama", label: "Ollama" }, +]; + +export const isLLMProvider = (v: string): v is LLMProvider => + LLM_PROVIDERS.some((p) => p.value === v); + export interface LLMModelConfig { name: string; provider: LLMProvider; diff --git a/lib/entities/llm_config.py b/lib/entities/llm_config.py index 52e2519..13b25a7 100644 --- a/lib/entities/llm_config.py +++ b/lib/entities/llm_config.py @@ -40,6 +40,12 @@ def validate_str_fields(cls, v: str | None) -> str: """convert None to empty string for database compatibility""" return v if v is not None else "" + @field_validator("dimensions", mode="before") + @classmethod + def validate_dimensions(cls, v: int | None) -> int: + """coerce None to 0""" + return v if v is not None else 0 + class ConnectionTestResult(BaseModel): success: bool diff --git a/lib/storage.py b/lib/storage.py index ac233b9..85af99b 100644 --- a/lib/storage.py +++ b/lib/storage.py @@ -200,6 +200,16 @@ async def _migrate_schema(self, db: Connection) -> None: if "is_default" not in llm_column_names: await db.execute("ALTER TABLE llm_models ADD COLUMN is_default BOOLEAN DEFAULT 0") + # ensure at least one llm model is default if models exist + await db.execute( + """ + UPDATE llm_models + SET is_default = 1 + WHERE name = (SELECT name FROM llm_models ORDER BY name LIMIT 1) + AND (SELECT COUNT(*) FROM llm_models WHERE is_default = 1) = 0 + """ + ) + # migrate embedding_models table cursor = await db.execute("PRAGMA table_info(embedding_models)") embedding_columns = await cursor.fetchall() @@ -208,6 +218,16 @@ async def _migrate_schema(self, db: Connection) -> None: if "is_default" not in embedding_column_names: await db.execute("ALTER TABLE embedding_models ADD COLUMN is_default BOOLEAN DEFAULT 0") + # ensure at least one embedding model is default if models exist + await db.execute( + """ + UPDATE embedding_models + SET is_default = 1 + WHERE name = (SELECT name FROM embedding_models ORDER BY name LIMIT 1) + AND (SELECT COUNT(*) FROM embedding_models WHERE is_default = 1) = 0 + """ + ) + async def _migrate_env_to_db(self, db: Connection) -> None: """migrate .env config to database if no models configured""" # check if any llm models exist @@ -245,14 +265,26 @@ async def _migrate_env_to_db(self, db: Connection) -> None: async def _execute_with_connection(self, func: Callable[[Connection], Any]) -> Any: if self._conn: - result = await func(self._conn) - await self._conn.commit() - return result + try: + result = await func(self._conn) + await self._conn.commit() + except Exception: + logger.exception("transaction failed during _execute_with_connection") + await self._conn.rollback() + raise + else: + return result async with aiosqlite.connect(self.db_path) as db: - result = await func(db) - await db.commit() - return result + try: + result = await func(db) + await db.commit() + except Exception: + logger.exception("transaction failed during _execute_with_connection") + await db.rollback() + raise + else: + return result async def save_record( self, record: RecordCreate, pipeline_id: int | None = None, job_id: int | None = None @@ -670,43 +702,38 @@ async def save_llm_model(self, config: LLMModelConfig) -> None: """create or update llm model config (upsert)""" async def _save(db: Connection) -> None: - await db.execute("BEGIN") - try: - # check if this is the first model inside transaction - cursor = await db.execute("SELECT COUNT(*) FROM llm_models") - row = await cursor.fetchone() - count = row[0] if row else 0 - - final_is_default = config.is_default or count == 0 + if config.is_default: + await db.execute("UPDATE llm_models SET is_default = 0") - if final_is_default: - await db.execute("UPDATE llm_models SET is_default = 0") + await db.execute( + """ + INSERT INTO llm_models + (name, provider, endpoint, api_key, model_name, is_default) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(name) DO UPDATE SET + provider = excluded.provider, + endpoint = excluded.endpoint, + api_key = excluded.api_key, + model_name = excluded.model_name, + is_default = excluded.is_default + """, + ( + config.name, + config.provider.value, + config.endpoint, + config.api_key, + config.model_name, + config.is_default, + ), + ) + # self-healing: ensure at least one default model exists + cursor = await db.execute("SELECT COUNT(*) FROM llm_models WHERE is_default = 1") + row = await cursor.fetchone() + if not row or row[0] == 0: await db.execute( - """ - INSERT INTO llm_models - (name, provider, endpoint, api_key, model_name, is_default) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(name) DO UPDATE SET - provider = excluded.provider, - endpoint = excluded.endpoint, - api_key = excluded.api_key, - model_name = excluded.model_name, - is_default = excluded.is_default - """, - ( - config.name, - config.provider.value, - config.endpoint, - config.api_key, - config.model_name, - final_is_default, - ), + "UPDATE llm_models SET is_default = 1 WHERE name = ?", (config.name,) ) - await db.execute("COMMIT") - except Exception: - await db.execute("ROLLBACK") - raise await self._execute_with_connection(_save) @@ -734,6 +761,7 @@ async def _delete(db: Connection) -> bool: await db.execute("COMMIT") return deleted except Exception: + logger.exception(f"transaction failed during delete_llm_model for name={name}") await db.execute("ROLLBACK") raise @@ -757,6 +785,7 @@ async def _set_default(db: Connection) -> bool: await db.execute("COMMIT") return True except Exception: + logger.exception(f"transaction failed during set_default_llm_model for name={name}") await db.execute("ROLLBACK") raise @@ -809,45 +838,40 @@ async def save_embedding_model(self, config: EmbeddingModelConfig) -> None: """create or update embedding model config (upsert)""" async def _save(db: Connection) -> None: - await db.execute("BEGIN") - try: - # check if this is the first model inside transaction - cursor = await db.execute("SELECT COUNT(*) FROM embedding_models") - row = await cursor.fetchone() - count = row[0] if row else 0 - - final_is_default = config.is_default or count == 0 + if config.is_default: + await db.execute("UPDATE embedding_models SET is_default = 0") - if final_is_default: - await db.execute("UPDATE embedding_models SET is_default = 0") + await db.execute( + """ + INSERT INTO embedding_models + (name, provider, endpoint, api_key, model_name, dimensions, is_default) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(name) DO UPDATE SET + provider = excluded.provider, + endpoint = excluded.endpoint, + api_key = excluded.api_key, + model_name = excluded.model_name, + dimensions = excluded.dimensions, + is_default = excluded.is_default + """, + ( + config.name, + config.provider.value, + config.endpoint, + config.api_key, + config.model_name, + config.dimensions, + config.is_default, + ), + ) + # self-healing: ensure at least one default model exists + cursor = await db.execute("SELECT COUNT(*) FROM embedding_models WHERE is_default = 1") + row = await cursor.fetchone() + if not row or row[0] == 0: await db.execute( - """ - INSERT INTO embedding_models - (name, provider, endpoint, api_key, model_name, dimensions, is_default) - VALUES (?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(name) DO UPDATE SET - provider = excluded.provider, - endpoint = excluded.endpoint, - api_key = excluded.api_key, - model_name = excluded.model_name, - dimensions = excluded.dimensions, - is_default = excluded.is_default - """, - ( - config.name, - config.provider.value, - config.endpoint, - config.api_key, - config.model_name, - config.dimensions, - final_is_default, - ), + "UPDATE embedding_models SET is_default = 1 WHERE name = ?", (config.name,) ) - await db.execute("COMMIT") - except Exception: - await db.execute("ROLLBACK") - raise await self._execute_with_connection(_save) @@ -874,6 +898,9 @@ async def _delete(db: Connection) -> bool: await db.execute("COMMIT") return deleted except Exception: + logger.exception( + f"transaction failed during delete_embedding_model for name={name}" + ) await db.execute("ROLLBACK") raise @@ -899,6 +926,9 @@ async def _set_default(db: Connection) -> bool: await db.execute("COMMIT") return True except Exception: + logger.exception( + f"transaction failed during set_default_embedding_model for name={name}" + ) await db.execute("ROLLBACK") raise diff --git a/tests/integration/test_auto_default_logic.py b/tests/integration/test_auto_default_logic.py index 9e732c7..4193231 100644 --- a/tests/integration/test_auto_default_logic.py +++ b/tests/integration/test_auto_default_logic.py @@ -99,3 +99,186 @@ async def test_embedding_auto_default_logic(storage: Storage): saved_model2 = await storage.get_embedding_model("emb2") assert saved_model2.is_default is True, "Remaining single embedding model should become default" + + +@pytest.mark.asyncio +async def test_model_update_preserves_state(storage: Storage): + # Clear tables + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM llm_models")) + + # 1. Create a default model + model = LLMModelConfig( + name="test-model", + provider=LLMProvider.OPENAI, + model_name="gpt-4", + is_default=True, + ) + await storage.save_llm_model(model) + + # 2. Update the model (changing provider and model_name) + updated_model = LLMModelConfig( + name="test-model", + provider=LLMProvider.ANTHROPIC, + model_name="claude-3", + is_default=True, # Frontend will now send this + endpoint="https://api.anthropic.com", + ) + await storage.save_llm_model(updated_model) + + # 3. Verify all fields updated and is_default is still True + saved = await storage.get_llm_model("test-model") + assert saved is not None + assert saved.provider == LLMProvider.ANTHROPIC + assert saved.model_name == "claude-3" + assert saved.endpoint == "https://api.anthropic.com" + assert saved.is_default is True + + +@pytest.mark.asyncio +async def test_model_update_non_default_stays_non_default(storage: Storage): + # Clear tables + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM llm_models")) + + # 1. Create two models, first becomes default + model1 = LLMModelConfig( + name="m1", provider=LLMProvider.OPENAI, model_name="gpt-4", is_default=True + ) + model2 = LLMModelConfig( + name="m2", provider=LLMProvider.ANTHROPIC, model_name="claude-3", is_default=False + ) + await storage.save_llm_model(model1) + await storage.save_llm_model(model2) + + # 2. Update non-default model + updated = LLMModelConfig( + name="m2", provider=LLMProvider.OLLAMA, model_name="llama3", is_default=False + ) + await storage.save_llm_model(updated) + + saved = await storage.get_llm_model("m2") + assert saved is not None + assert saved.is_default is False + # verify m1 is still default + m1 = await storage.get_llm_model("m1") + assert m1 is not None + assert m1.is_default is True + + +@pytest.mark.asyncio +async def test_model_update_forces_default_if_only_one(storage: Storage): + # Clear tables + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM llm_models")) + + # 1. Create a model with is_default=False (but it will be forced to True as it's the only one) + model = LLMModelConfig( + name="only-one", provider=LLMProvider.OPENAI, model_name="gpt-4", is_default=False + ) + await storage.save_llm_model(model) + + saved = await storage.get_llm_model("only-one") + assert saved is not None + assert saved.is_default is True + + # 2. Update it specifically with is_default=False + updated = LLMModelConfig( + name="only-one", provider=LLMProvider.OPENAI, model_name="gpt-4", is_default=False + ) + await storage.save_llm_model(updated) + + # 3. Verify it is STILL default (self-healing) + saved = await storage.get_llm_model("only-one") + assert saved is not None + assert saved.is_default is True + + +@pytest.mark.asyncio +async def test_embedding_update_preserves_state(storage: Storage): + # Clear tables + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM embedding_models")) + + # 1. Create a default model + model = EmbeddingModelConfig( + name="test-embed", + provider=LLMProvider.OPENAI, + model_name="text-embedding-3-small", + is_default=True, + dimensions=1536, + ) + await storage.save_embedding_model(model) + + # 2. Update the model + updated_model = EmbeddingModelConfig( + name="test-embed", + provider=LLMProvider.OLLAMA, + model_name="mxbai-embed-large", + is_default=True, + dimensions=1024, + ) + await storage.save_embedding_model(updated_model) + + # 3. Verify + saved = await storage.get_embedding_model("test-embed") + assert saved is not None + assert saved.provider == LLMProvider.OLLAMA + assert saved.model_name == "mxbai-embed-large" + assert saved.dimensions == 1024 + assert saved.is_default is True + + +@pytest.mark.asyncio +async def test_embedding_update_non_default_stays_non_default(storage: Storage): + # Clear tables + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM embedding_models")) + + # 1. Create two models + m1 = EmbeddingModelConfig( + name="e1", provider=LLMProvider.OPENAI, model_name="text-3", is_default=True + ) + m2 = EmbeddingModelConfig( + name="e2", provider=LLMProvider.OPENAI, model_name="text-3", is_default=False + ) + await storage.save_embedding_model(m1) + await storage.save_embedding_model(m2) + + # 2. Update non-default + updated = EmbeddingModelConfig( + name="e2", provider=LLMProvider.GEMINI, model_name="embed-001", is_default=False + ) + await storage.save_embedding_model(updated) + + saved = await storage.get_embedding_model("e2") + assert saved is not None + assert saved.is_default is False + assert saved.provider == LLMProvider.GEMINI + + # verify e1 is still default + e1 = await storage.get_embedding_model("e1") + assert e1 is not None + assert e1.is_default is True + + +@pytest.mark.asyncio +async def test_embedding_update_forces_default_if_only_one(storage: Storage): + # Clear tables + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM embedding_models")) + + # 1. Create a model with is_default=False (but it will be forced to True as it's the only one) + model = EmbeddingModelConfig( + name="only-embed", provider=LLMProvider.OPENAI, model_name="text-3", is_default=False + ) + await storage.save_embedding_model(model) + + saved = await storage.get_embedding_model("only-embed") + assert saved is not None + assert saved.is_default is True + + # 2. Update it specifically with is_default=False + updated = EmbeddingModelConfig( + name="only-embed", provider=LLMProvider.OPENAI, model_name="text-3", is_default=False + ) + await storage.save_embedding_model(updated) + + # 3. Verify it is STILL default (self-healing) + saved = await storage.get_embedding_model("only-embed") + assert saved is not None + assert saved.is_default is True diff --git a/tests/test_api.py b/tests/test_api.py index 2694842..5f98a58 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -563,3 +563,15 @@ def test_set_default_embedding_model_nonexistent_returns_404(self, client): """Test PUT /api/embedding-models/{name}/default - not found""" response = client.put("/api/embedding-models/nonexistent/default") assert response.status_code == 404 + + +class TestAPILangfuse: + """Test Langfuse-related API endpoints""" + + def test_get_langfuse_status(self, client): + """Test GET /api/langfuse/status""" + response = client.get("/api/langfuse/status") + assert response.status_code == 200 + data = response.json() + assert "enabled" in data + assert isinstance(data["enabled"], bool)