diff --git a/app.py b/app.py index ee8e978..625b97d 100644 --- a/app.py +++ b/app.py @@ -29,7 +29,7 @@ from lib.errors import BlockExecutionError, BlockNotFoundError, ValidationError from lib.job_processor import process_job_in_thread from lib.job_queue import JobQueue -from lib.llm_config import LLMConfigManager, LLMConfigNotFoundError +from lib.llm_config import LLMConfigError, LLMConfigManager, LLMConfigNotFoundError from lib.storage import Storage from lib.templates import template_registry from lib.workflow import Pipeline as WorkflowPipeline @@ -698,6 +698,19 @@ async def delete_llm_model(name: str) -> dict[str, str]: raise HTTPException(status_code=404, detail=e.message) +@api_router.put("/llm-models/{name}/default") +async def set_default_llm_model(name: str) -> dict[str, str]: + """set default llm model""" + try: + await llm_config_manager.set_default_llm_model(name) + return {"message": "llm model set as default successfully"} + except LLMConfigNotFoundError as e: + raise HTTPException(status_code=404, detail=e.message) + except LLMConfigError as e: + logger.exception(f"failed to set default llm model {name}") + raise HTTPException(status_code=400, detail=e.message) from e + + @api_router.post("/llm-models/test") async def test_llm_connection(config: LLMModelConfig) -> ConnectionTestResult: """test llm connection""" @@ -753,6 +766,19 @@ async def delete_embedding_model(name: str) -> dict[str, str]: raise HTTPException(status_code=404, detail=e.message) +@api_router.put("/embedding-models/{name}/default") +async def set_default_embedding_model(name: str) -> dict[str, str]: + """set default embedding model""" + try: + await llm_config_manager.set_default_embedding_model(name) + return {"message": "embedding model set as default successfully"} + except LLMConfigNotFoundError as e: + raise HTTPException(status_code=404, detail=e.message) + except LLMConfigError as e: + logger.exception(f"failed to set default embedding model {name}") + raise 
HTTPException(status_code=400, detail=e.message) from e + + @api_router.post("/embedding-models/test") async def test_embedding_connection( config: EmbeddingModelConfig, diff --git a/frontend/src/components/settings/ModelCard.tsx b/frontend/src/components/settings/ModelCard.tsx new file mode 100644 index 0000000..f2c92aa --- /dev/null +++ b/frontend/src/components/settings/ModelCard.tsx @@ -0,0 +1,159 @@ +import type { ReactNode } from "react"; +import { Box, Text, Button, IconButton, Spinner, Tooltip } from "@primer/react"; +import { + TrashIcon, + PencilIcon, + CheckCircleIcon, + CheckCircleFillIcon, + StarIcon, +} from "@primer/octicons-react"; +import type { LLMModelConfig, EmbeddingModelConfig } from "../../types"; + +interface ModelCardStatus { + isDefault: boolean; + isTesting: boolean; + isSettingDefault: boolean; +} + +interface ModelCardActions { + onSetDefault: () => void; + onTest: () => void; + onEdit: () => void; + onDelete: () => void; +} + +interface ModelCardProps { + model: T; + status: ModelCardStatus; + actions: ModelCardActions; + extraDetails?: ReactNode; +} + +export function ModelCard({ + model, + status, + actions, + extraDetails, +}: ModelCardProps) { + const { isDefault, isTesting, isSettingDefault } = status; + const { onSetDefault, onTest, onEdit, onDelete } = actions; + + return ( + + + + {/* name and badges row */} + + {model.name} + + {model.provider} + + {/* isDefault renders the default badge with CheckCircleFillIcon to visually distinguish the selected model */} + {isDefault && ( + + + Default + + )} + + + {/* model details - model.model_name and model.endpoint; extraDetails may be appended for additional info like embedding dimensions */} + + model: {model.model_name} + {extraDetails} + + {model.endpoint} + + + {/* action buttons */} + + {!isDefault && ( + + + + )} + + + + + + + ); +} diff --git a/frontend/src/pages/Settings.tsx b/frontend/src/pages/Settings.tsx index 3c6dea0..cc5076f 100644 --- a/frontend/src/pages/Settings.tsx 
+++ b/frontend/src/pages/Settings.tsx @@ -1,19 +1,13 @@ -import { useEffect, useState } from "react"; -import { Box, Heading, Text, Button, IconButton, Spinner, Tooltip } from "@primer/react"; -import { - PlusIcon, - TrashIcon, - PencilIcon, - CheckCircleIcon, - CircleIcon, - CheckCircleFillIcon, -} from "@primer/octicons-react"; +import { useEffect, useState, useRef } from "react"; +import { Box, Heading, Text, Button, Spinner, Tooltip } from "@primer/react"; +import { PlusIcon, CircleIcon, CheckCircleFillIcon } from "@primer/octicons-react"; import { toast } from "sonner"; import type { LLMModelConfig, EmbeddingModelConfig } from "../types"; import { llmConfigApi } from "../services/llmConfigApi"; import LLMFormModal from "../components/settings/LLMFormModal"; import EmbeddingFormModal from "../components/settings/EmbeddingFormModal"; import { ConfirmModal } from "../components/ui/confirm-modal"; +import { ModelCard } from "../components/settings/ModelCard"; export default function Settings() { const [llmModels, setLlmModels] = useState([]); @@ -26,20 +20,30 @@ export default function Settings() { const [testingEmbedding, setTestingEmbedding] = useState(null); const [deletingLlm, setDeletingLlm] = useState(null); const [deletingEmbedding, setDeletingEmbedding] = useState(null); + const [settingDefaultLlm, setSettingDefaultLlm] = useState(null); + const [settingDefaultEmbedding, setSettingDefaultEmbedding] = useState(null); const [langfuseEnabled, setLangfuseEnabled] = useState(false); const [langfuseHost, setLangfuseHost] = useState(null); const [loadingLangfuse, setLoadingLangfuse] = useState(true); + const isMountedRef = useRef(true); + useEffect(() => { loadLlmModels(); loadEmbeddingModels(); loadLangfuseStatus(); + + return () => { + isMountedRef.current = false; + }; }, []); const loadLlmModels = async () => { try { const models = await llmConfigApi.listLLMModels(); - setLlmModels(models); + if (isMountedRef.current) { + setLlmModels(models); + } } catch 
(error) { const message = error instanceof Error ? error.message : "Unknown error"; toast.error(`Failed to load LLM models: ${message}`); @@ -49,7 +53,9 @@ export default function Settings() { const loadEmbeddingModels = async () => { try { const models = await llmConfigApi.listEmbeddingModels(); - setEmbeddingModels(models); + if (isMountedRef.current) { + setEmbeddingModels(models); + } } catch (error) { const message = error instanceof Error ? error.message : "Unknown error"; toast.error(`Failed to load embedding models: ${message}`); @@ -63,13 +69,17 @@ export default function Settings() { throw new Error(`http ${res.status}`); } const data = await res.json(); - setLangfuseEnabled(data.enabled); - setLangfuseHost(data.host); + if (isMountedRef.current) { + setLangfuseEnabled(data.enabled); + setLangfuseHost(data.host); + } } catch (error) { const message = error instanceof Error ? error.message : "Unknown error"; console.error("Failed to load Langfuse status:", message); } finally { - setLoadingLangfuse(false); + if (isMountedRef.current) { + setLoadingLangfuse(false); + } } }; @@ -105,10 +115,13 @@ export default function Settings() { toast.error(`Connection test failed: ${result.message}`); } } catch (error) { + console.error(error); const message = error instanceof Error ? error.message : "Unknown error"; toast.error(`Connection test failed: ${message}`); } finally { - setTestingLlm(null); + if (isMountedRef.current) { + setTestingLlm(null); + } } }; @@ -122,10 +135,49 @@ export default function Settings() { toast.error(`Connection test failed: ${result.message}`); } } catch (error) { + console.error(error); const message = error instanceof Error ? 
error.message : "Unknown error"; toast.error(`Connection test failed: ${message}`); } finally { - setTestingEmbedding(null); + if (isMountedRef.current) { + setTestingEmbedding(null); + } + } + }; + + const handleSetDefaultLlm = async (name: string) => { + if (settingDefaultLlm === name) return; + setSettingDefaultLlm(name); + try { + await llmConfigApi.setDefaultLLMModel(name); + toast.success("Default LLM model updated"); + loadLlmModels(); + } catch (error) { + console.error(error); + const message = error instanceof Error ? error.message : "Unknown error"; + toast.error(`Failed to set default LLM model: ${message}`); + } finally { + if (isMountedRef.current) { + setSettingDefaultLlm(null); + } + } + }; + + const handleSetDefaultEmbedding = async (name: string) => { + if (settingDefaultEmbedding === name) return; + setSettingDefaultEmbedding(name); + try { + await llmConfigApi.setDefaultEmbeddingModel(name); + toast.success("Default embedding model updated"); + loadEmbeddingModels(); + } catch (error) { + console.error(error); + const message = error instanceof Error ? 
error.message : "Unknown error"; + toast.error(`Failed to set default embedding model: ${message}`); + } finally { + if (isMountedRef.current) { + setSettingDefaultEmbedding(null); + } } }; @@ -209,106 +261,24 @@ export default function Settings() { ) : ( {llmModels.map((model) => ( - - - - - - {model.name} - - - {model.provider} - - {model.name === "default" && ( - - default - - )} - - - - model: {model.model_name} - - - {model.endpoint} - - - - - - - { - setEditingLlm(model); - setLlmModalOpen(true); - }} - /> - setDeletingLlm(model.name)} - /> - - - + actions={{ + onSetDefault: () => handleSetDefaultLlm(model.name), + onTest: () => handleTestLlm(model), + onEdit: () => { + setEditingLlm(model); + setLlmModalOpen(true); + }, + onDelete: () => setDeletingLlm(model.name), + }} + /> ))} )} @@ -352,90 +322,25 @@ export default function Settings() { ) : ( {embeddingModels.map((model) => ( - - - - - - {model.name} - - - {model.provider} - - - - model: {model.model_name} - {model.dimensions && ` (${model.dimensions}d)`} - - - {model.endpoint} - - - - - - { - setEditingEmbedding(model); - setEmbeddingModalOpen(true); - }} - /> - setDeletingEmbedding(model.name)} - /> - - - + actions={{ + onSetDefault: () => handleSetDefaultEmbedding(model.name), + onTest: () => handleTestEmbedding(model), + onEdit: () => { + setEditingEmbedding(model); + setEmbeddingModalOpen(true); + }, + onDelete: () => setDeletingEmbedding(model.name), + }} + extraDetails={model.dimensions ? 
` (${model.dimensions}d)` : undefined} + /> ))} )} diff --git a/frontend/src/services/llmConfigApi.ts b/frontend/src/services/llmConfigApi.ts index 5b572d7..8b1a9dd 100644 --- a/frontend/src/services/llmConfigApi.ts +++ b/frontend/src/services/llmConfigApi.ts @@ -50,6 +50,16 @@ class LLMConfigApi { } } + async setDefaultLLMModel(name: string): Promise { + const response = await fetch(`${API_BASE}/llm-models/${encodeURIComponent(name)}/default`, { + method: "PUT", + }); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || `http ${response.status}`); + } + } + async testLLMConnection(config: LLMModelConfig): Promise { const response = await fetch(`${API_BASE}/llm-models/test`, { method: "POST", @@ -107,6 +117,19 @@ class LLMConfigApi { } } + async setDefaultEmbeddingModel(name: string): Promise { + const response = await fetch( + `${API_BASE}/embedding-models/${encodeURIComponent(name)}/default`, + { + method: "PUT", + } + ); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || `http ${response.status}`); + } + } + async testEmbeddingConnection(config: EmbeddingModelConfig): Promise { const response = await fetch(`${API_BASE}/embedding-models/test`, { method: "POST", diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index c736226..5e2a1e1 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -95,6 +95,7 @@ export interface LLMModelConfig { endpoint: string; api_key: string | null; model_name: string; + is_default?: boolean; } export interface EmbeddingModelConfig { @@ -104,6 +105,7 @@ export interface EmbeddingModelConfig { api_key: string | null; model_name: string; dimensions: number | null; + is_default?: boolean; } export interface ConnectionTestResult { diff --git a/lib/entities/llm_config.py b/lib/entities/llm_config.py index 46dbb32..52e2519 100644 --- a/lib/entities/llm_config.py +++ b/lib/entities/llm_config.py @@ -16,6 +16,7 @@ 
class LLMModelConfig(BaseModel): endpoint: str = "" api_key: str = "" model_name: str = Field(..., min_length=1) + is_default: bool = False @field_validator("endpoint", "api_key", mode="before") @classmethod @@ -30,6 +31,7 @@ class EmbeddingModelConfig(BaseModel): endpoint: str = "" api_key: str = "" model_name: str = Field(..., min_length=1) + is_default: bool = False dimensions: int = 0 @field_validator("endpoint", "api_key", mode="before") diff --git a/lib/llm_config.py b/lib/llm_config.py index 420d2ad..4030678 100644 --- a/lib/llm_config.py +++ b/lib/llm_config.py @@ -41,9 +41,10 @@ async def get_llm_model(self, name: str | None = None) -> LLMModelConfig: uses fallback chain to ensure blocks always have a model available: 1. requested name - 2. model named "default" - 3. first model in db - 4. .env fallback (LLM_ENDPOINT, LLM_API_KEY, LLM_MODEL) + 2. model marked as default (is_default=True) + 3. model named "default" (legacy) + 4. first model in db + 5. .env fallback (LLM_ENDPOINT, LLM_API_KEY, LLM_MODEL) """ if name: config = await self.storage.get_llm_model(name) @@ -53,14 +54,18 @@ async def get_llm_model(self, name: str | None = None) -> LLMModelConfig: f"llm model '{name}' not found", detail={"requested_name": name} ) - # try default model - config = await self.storage.get_llm_model("default") - if config: - return config - - # try first model + # try explicit default model or model named "default" all_models = await self.storage.list_llm_models() if all_models: + # check for is_default=True + for model in all_models: + if model.is_default: + return model + # fallback to name="default" + for model in all_models: + if model.name == "default": + return model + # fallback to first model return all_models[0] # fallback to .env @@ -93,6 +98,12 @@ async def delete_llm_model(self, name: str) -> None: if not success: raise LLMConfigNotFoundError(f"llm model '{name}' not found", detail={"name": name}) + async def set_default_llm_model(self, name: str) -> None: + 
"""set default llm model""" + success = await self.storage.set_default_llm_model(name) + if not success: + raise LLMConfigNotFoundError(f"llm model '{name}' not found", detail={"name": name}) + async def test_llm_connection(self, config: LLMModelConfig) -> ConnectionTestResult: """test llm connection with simple prompt @@ -122,8 +133,9 @@ async def get_embedding_model(self, name: str | None = None) -> EmbeddingModelCo fallback chain: 1. requested name - 2. model named "default" - 3. first model in db + 2. model marked as default (is_default=True) + 3. model named "default" (legacy) + 4. first model in db """ if name: config = await self.storage.get_embedding_model(name) @@ -133,14 +145,18 @@ async def get_embedding_model(self, name: str | None = None) -> EmbeddingModelCo f"embedding model '{name}' not found", detail={"requested_name": name} ) - # try default model - config = await self.storage.get_embedding_model("default") - if config: - return config - - # try first model + # try explicit default model or model named "default" all_models = await self.storage.list_embedding_models() if all_models: + # check for is_default=True + for model in all_models: + if model.is_default: + return model + # fallback to name="default" + for model in all_models: + if model.name == "default": + return model + # fallback to first model return all_models[0] raise LLMConfigNotFoundError( @@ -163,6 +179,14 @@ async def delete_embedding_model(self, name: str) -> None: f"embedding model '{name}' not found", detail={"name": name} ) + async def set_default_embedding_model(self, name: str) -> None: + """set default embedding model""" + success = await self.storage.set_default_embedding_model(name) + if not success: + raise LLMConfigNotFoundError( + f"embedding model '{name}' not found", detail={"name": name} + ) + async def test_embedding_connection(self, config: EmbeddingModelConfig) -> ConnectionTestResult: """test embedding connection with simple text diff --git a/lib/storage.py 
b/lib/storage.py index b4b56b1..ac233b9 100644 --- a/lib/storage.py +++ b/lib/storage.py @@ -192,6 +192,22 @@ async def _migrate_schema(self, db: Connection) -> None: if "metadata" not in job_column_names: await db.execute("ALTER TABLE jobs ADD COLUMN metadata TEXT") + # migrate llm_models table + cursor = await db.execute("PRAGMA table_info(llm_models)") + llm_columns = await cursor.fetchall() + llm_column_names = [col[1] for col in llm_columns] + + if "is_default" not in llm_column_names: + await db.execute("ALTER TABLE llm_models ADD COLUMN is_default BOOLEAN DEFAULT 0") + + # migrate embedding_models table + cursor = await db.execute("PRAGMA table_info(embedding_models)") + embedding_columns = await cursor.fetchall() + embedding_column_names = [col[1] for col in embedding_columns] + + if "is_default" not in embedding_column_names: + await db.execute("ALTER TABLE embedding_models ADD COLUMN is_default BOOLEAN DEFAULT 0") + async def _migrate_env_to_db(self, db: Connection) -> None: """migrate .env config to database if no models configured""" # check if any llm models exist @@ -214,8 +230,8 @@ async def _migrate_env_to_db(self, db: Connection) -> None: # create default model from .env await db.execute( """ - INSERT INTO llm_models (name, provider, endpoint, api_key, model_name) - VALUES (?, ?, ?, ?, ?) + INSERT INTO llm_models (name, provider, endpoint, api_key, model_name, is_default) + VALUES (?, ?, ?, ?, ?, ?) 
""", ( "default", @@ -223,6 +239,7 @@ async def _migrate_env_to_db(self, db: Connection) -> None: settings.LLM_ENDPOINT, settings.LLM_API_KEY if settings.LLM_API_KEY else None, settings.LLM_MODEL, + True, # make env model default if it's the only one ), ) @@ -622,6 +639,7 @@ async def _list(db: Connection) -> list[LLMModelConfig]: endpoint=row["endpoint"], api_key=row["api_key"], model_name=row["model_name"], + is_default=bool(row["is_default"]), ) for row in rows ] @@ -643,6 +661,7 @@ async def _get(db: Connection) -> LLMModelConfig | None: endpoint=row["endpoint"], api_key=row["api_key"], model_name=row["model_name"], + is_default=bool(row["is_default"]), ) return await self._execute_with_connection(_get) @@ -651,24 +670,43 @@ async def save_llm_model(self, config: LLMModelConfig) -> None: """create or update llm model config (upsert)""" async def _save(db: Connection) -> None: - await db.execute( - """ - INSERT INTO llm_models (name, provider, endpoint, api_key, model_name) - VALUES (?, ?, ?, ?, ?) - ON CONFLICT(name) DO UPDATE SET - provider = excluded.provider, - endpoint = excluded.endpoint, - api_key = excluded.api_key, - model_name = excluded.model_name - """, - ( - config.name, - config.provider.value, - config.endpoint, - config.api_key, - config.model_name, - ), - ) + await db.execute("BEGIN") + try: + # check if this is the first model inside transaction + cursor = await db.execute("SELECT COUNT(*) FROM llm_models") + row = await cursor.fetchone() + count = row[0] if row else 0 + + final_is_default = config.is_default or count == 0 + + if final_is_default: + await db.execute("UPDATE llm_models SET is_default = 0") + + await db.execute( + """ + INSERT INTO llm_models + (name, provider, endpoint, api_key, model_name, is_default) + VALUES (?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(name) DO UPDATE SET + provider = excluded.provider, + endpoint = excluded.endpoint, + api_key = excluded.api_key, + model_name = excluded.model_name, + is_default = excluded.is_default + """, + ( + config.name, + config.provider.value, + config.endpoint, + config.api_key, + config.model_name, + final_is_default, + ), + ) + await db.execute("COMMIT") + except Exception: + await db.execute("ROLLBACK") + raise await self._execute_with_connection(_save) @@ -676,11 +714,54 @@ async def delete_llm_model(self, name: str) -> bool: """delete llm model config""" async def _delete(db: Connection) -> bool: - cursor = await db.execute("DELETE FROM llm_models WHERE name = ?", (name,)) - return cursor.rowcount > 0 + await db.execute("BEGIN") + try: + cursor = await db.execute("DELETE FROM llm_models WHERE name = ?", (name,)) + deleted = cursor.rowcount > 0 + + if deleted: + # if we deleted the default model (or the last default), pick a new one + # this query updates a model to default ONLY IF no default currently exists + await db.execute( + """ + UPDATE llm_models + SET is_default = 1 + WHERE name = (SELECT name FROM llm_models ORDER BY name LIMIT 1) + AND (SELECT COUNT(*) FROM llm_models WHERE is_default = 1) = 0 + """ + ) + + await db.execute("COMMIT") + return deleted + except Exception: + await db.execute("ROLLBACK") + raise return await self._execute_with_connection(_delete) + async def set_default_llm_model(self, name: str) -> bool: + """set default llm model""" + + async def _set_default(db: Connection) -> bool: + # check if model exists + cursor = await db.execute("SELECT 1 FROM llm_models WHERE name = ?", (name,)) + if not await cursor.fetchone(): + return False + + await db.execute("BEGIN") + try: + # reset all to false + await db.execute("UPDATE llm_models SET is_default = 0") + # set selected to true + await db.execute("UPDATE llm_models SET is_default = 1 WHERE name = ?", (name,)) + await db.execute("COMMIT") + return True + except Exception: + await 
db.execute("ROLLBACK") + raise + + return await self._execute_with_connection(_set_default) + async def list_embedding_models(self) -> list[EmbeddingModelConfig]: """list all configured embedding models""" @@ -695,6 +776,7 @@ async def _list(db: Connection) -> list[EmbeddingModelConfig]: endpoint=row["endpoint"], api_key=row["api_key"], model_name=row["model_name"], + is_default=bool(row["is_default"]), dimensions=row["dimensions"] or 0, ) for row in rows @@ -717,6 +799,7 @@ async def _get(db: Connection) -> EmbeddingModelConfig | None: endpoint=row["endpoint"], api_key=row["api_key"], model_name=row["model_name"], + is_default=bool(row["is_default"]), dimensions=row["dimensions"] or 0, ) @@ -726,27 +809,45 @@ async def save_embedding_model(self, config: EmbeddingModelConfig) -> None: """create or update embedding model config (upsert)""" async def _save(db: Connection) -> None: - await db.execute( - """ - INSERT INTO embedding_models - (name, provider, endpoint, api_key, model_name, dimensions) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(name) DO UPDATE SET - provider = excluded.provider, - endpoint = excluded.endpoint, - api_key = excluded.api_key, - model_name = excluded.model_name, - dimensions = excluded.dimensions - """, - ( - config.name, - config.provider.value, - config.endpoint, - config.api_key, - config.model_name, - config.dimensions, - ), - ) + await db.execute("BEGIN") + try: + # check if this is the first model inside transaction + cursor = await db.execute("SELECT COUNT(*) FROM embedding_models") + row = await cursor.fetchone() + count = row[0] if row else 0 + + final_is_default = config.is_default or count == 0 + + if final_is_default: + await db.execute("UPDATE embedding_models SET is_default = 0") + + await db.execute( + """ + INSERT INTO embedding_models + (name, provider, endpoint, api_key, model_name, dimensions, is_default) + VALUES (?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(name) DO UPDATE SET + provider = excluded.provider, + endpoint = excluded.endpoint, + api_key = excluded.api_key, + model_name = excluded.model_name, + dimensions = excluded.dimensions, + is_default = excluded.is_default + """, + ( + config.name, + config.provider.value, + config.endpoint, + config.api_key, + config.model_name, + config.dimensions, + final_is_default, + ), + ) + await db.execute("COMMIT") + except Exception: + await db.execute("ROLLBACK") + raise await self._execute_with_connection(_save) @@ -754,11 +855,55 @@ async def delete_embedding_model(self, name: str) -> bool: """delete embedding model config""" async def _delete(db: Connection) -> bool: - cursor = await db.execute("DELETE FROM embedding_models WHERE name = ?", (name,)) - return cursor.rowcount > 0 + await db.execute("BEGIN") + try: + cursor = await db.execute("DELETE FROM embedding_models WHERE name = ?", (name,)) + deleted = cursor.rowcount > 0 + + if deleted: + # if we deleted the default model (or the last default), pick a new one + await db.execute( + """ + UPDATE embedding_models + SET is_default = 1 + WHERE name = (SELECT name FROM embedding_models ORDER BY name LIMIT 1) + AND (SELECT COUNT(*) FROM embedding_models WHERE is_default = 1) = 0 + """ + ) + + await db.execute("COMMIT") + return deleted + except Exception: + await db.execute("ROLLBACK") + raise return await self._execute_with_connection(_delete) + async def set_default_embedding_model(self, name: str) -> bool: + """set default embedding model""" + + async def _set_default(db: Connection) -> bool: + # check if model exists + cursor = await db.execute("SELECT 1 FROM embedding_models WHERE name = ?", (name,)) + if not await cursor.fetchone(): + return False + + await db.execute("BEGIN") + try: + # reset all to false + await db.execute("UPDATE embedding_models SET is_default = 0") + # set selected to true + await db.execute( + "UPDATE embedding_models SET is_default = 1 WHERE name = ?", (name,) + ) + await 
db.execute("COMMIT") + return True + except Exception: + await db.execute("ROLLBACK") + raise + + return await self._execute_with_connection(_set_default) + def _row_to_record(self, row: aiosqlite.Row) -> Record: return Record( id=row["id"], diff --git a/llm/state-backend.md b/llm/state-backend.md index d6660f3..817877c 100644 --- a/llm/state-backend.md +++ b/llm/state-backend.md @@ -85,6 +85,7 @@ config.py # env Settings - `POST /api/llm-models` - create config - `PUT /api/llm-models/{name}` - update config - `DELETE /api/llm-models/{name}` - delete config +- `PUT /api/llm-models/{name}/default` - set default model - `POST /api/llm-models/test` - test connection ### embedding config @@ -93,6 +94,7 @@ config.py # env Settings - `POST /api/embedding-models` - create config - `PUT /api/embedding-models/{name}` - update config - `DELETE /api/embedding-models/{name}` - delete config +- `PUT /api/embedding-models/{name}/default` - set default model - `POST /api/embedding-models/test` - test connection ## database schema diff --git a/llm/state-frontend.md b/llm/state-frontend.md index e4248f5..be4f406 100644 --- a/llm/state-frontend.md +++ b/llm/state-frontend.md @@ -30,6 +30,7 @@ frontend/src/ StartEndNode.tsx # circular start/end utils.ts # format conversion settings/ + ModelCard.tsx # reusable model card (status, actions objects) LLMFormModal.tsx # llm config form EmbeddingFormModal.tsx # embedding config form ui/ # shadcn components @@ -69,11 +70,14 @@ frontend/src/ - view stability: tracks by ID, single mode preserves current record ### Settings.tsx -- LLM/embedding model management -- provider/model selection (OpenAI, Anthropic, Ollama, etc) +- LLM/embedding model management via ModelCard components +- provider/model selection (OpenAI, Anthropic, Ollama, etc.) 
- API key configuration -- connection testing -- default model selection +- connection testing with loading states +- explicit "Set Default" button per model (shows spinner while setting) +- default model badge with CheckCircleFillIcon for visual distinction +- mounted guards in async handlers to prevent state updates on unmount +- console.error logging before toast.error for debugging ## components @@ -157,7 +161,7 @@ shadcn radix-ui dialog, replaces browser confirm() **endpoints:** - GET /api/blocks, /api/templates, /api/pipelines, /api/jobs/active, /api/jobs/{id}, /api/records - POST /api/pipelines, /api/pipelines/from_template/{id}, /api/generate, /api/seeds/validate -- PUT /api/records/{id}, /api/llm-models/{name}, /api/embedding-models/{name} +- PUT /api/records/{id}, /api/llm-models/{name}, /api/embedding-models/{name}, /api/llm-models/{name}/default, /api/embedding-models/{name}/default - DELETE /api/pipelines/{id}, /api/jobs/{id}, /api/records - GET /api/export/download, /api/llm-models, /api/embedding-models diff --git a/llm/state-project.md b/llm/state-project.md index 9c571a3..5a8f51f 100644 --- a/llm/state-project.md +++ b/llm/state-project.md @@ -388,7 +388,7 @@ production-ready full-stack data generation platform - structured errors with context - sqlite with migrations - type-safe BlockExecutionContext -- LLM/embedding config management (multi-provider) +- LLM/embedding config management (multi-provider) + default model selection - 4 pages: Pipelines, Generator, Review, Settings - primer + dark mode - accumulated state visualization diff --git a/tests/integration/test_auto_default_logic.py b/tests/integration/test_auto_default_logic.py new file mode 100644 index 0000000..9e732c7 --- /dev/null +++ b/tests/integration/test_auto_default_logic.py @@ -0,0 +1,101 @@ +import pytest + +from lib.entities import EmbeddingModelConfig, LLMModelConfig, LLMProvider +from lib.storage import Storage + + +@pytest.mark.asyncio +async def 
test_llm_auto_default_logic(storage: Storage): + # Clear tables to remove auto-migrated models + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM llm_models")) + + # 1. Test auto-default on first creation + model1 = LLMModelConfig( + name="model1", + provider=LLMProvider.OPENAI, + model_name="gpt-4", + is_default=False, # Explicitly False + ) + await storage.save_llm_model(model1) + + saved_model1 = await storage.get_llm_model("model1") + assert saved_model1 is not None + assert saved_model1.is_default is True, ( + "First model should be auto-set to default even if is_default=False" + ) + + # 2. Test adds second model (should NOT be default) + model2 = LLMModelConfig( + name="model2", provider=LLMProvider.ANTHROPIC, model_name="claude-3", is_default=False + ) + await storage.save_llm_model(model2) + + saved_model2 = await storage.get_llm_model("model2") + assert saved_model2.is_default is False + + # Verify model1 is still default + saved_model1 = await storage.get_llm_model("model1") + assert saved_model1.is_default is True + + # 3. Test auto-default on delete to one + # Delete model1 (default), model2 should become default + await storage.delete_llm_model("model1") + + saved_model2 = await storage.get_llm_model("model2") + assert saved_model2.is_default is True, "Remaining single model should become default" + + # 4. Test default reassignment when multiple models exist + # Setup: Create model3, ensure model2 is default. + model3 = LLMModelConfig( + name="model3", provider=LLMProvider.OLLAMA, model_name="llama2", is_default=False + ) + await storage.save_llm_model(model3) + + # model2 is currently default. model3 is not. 
+ m2 = await storage.get_llm_model("model2") + m3 = await storage.get_llm_model("model3") + assert m2.is_default is True + assert m3.is_default is False + + # Delete the current default (model2) + # We expect model3 to become default (since it's the only other one, or alphabetical) + await storage.delete_llm_model("model2") + + saved_model3 = await storage.get_llm_model("model3") + assert saved_model3.is_default is True, ( + "Deleting default model should reassign default to available model" + ) + + +@pytest.mark.asyncio +async def test_embedding_auto_default_logic(storage: Storage): + # Clear tables to remove auto-migrated models + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM embedding_models")) + + # 1. Test auto-default on first creation + model1 = EmbeddingModelConfig( + name="emb1", + provider=LLMProvider.OPENAI, + model_name="text-embedding-3-small", + is_default=False, + ) + await storage.save_embedding_model(model1) + + saved_model1 = await storage.get_embedding_model("emb1") + assert saved_model1 is not None + assert saved_model1.is_default is True, "First embedding model should be auto-set to default" + + # 2. Add second model + model2 = EmbeddingModelConfig( + name="emb2", provider=LLMProvider.GEMINI, model_name="embedding-001", is_default=False + ) + await storage.save_embedding_model(model2) + + saved_model2 = await storage.get_embedding_model("emb2") + assert saved_model2.is_default is False + + # 3. 
Test delete to one + await storage.delete_embedding_model("emb1") + + saved_model2 = await storage.get_embedding_model("emb2") + assert saved_model2.is_default is True, "Remaining single embedding model should become default" diff --git a/tests/integration/test_default_model_selection_integration.py b/tests/integration/test_default_model_selection_integration.py new file mode 100644 index 0000000..9541365 --- /dev/null +++ b/tests/integration/test_default_model_selection_integration.py @@ -0,0 +1,147 @@ +import pytest +import pytest_asyncio + +from lib.entities import EmbeddingModelConfig, LLMModelConfig, LLMProvider +from lib.llm_config import LLMConfigManager, LLMConfigNotFoundError +from lib.storage import Storage + + +@pytest_asyncio.fixture +async def storage(): + """create in-memory storage for testing""" + storage = Storage(":memory:") + await storage.init_db() + + # Clear any models created by auto-migration from env + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM llm_models")) + await storage._execute_with_connection(lambda db: db.execute("DELETE FROM embedding_models")) + + yield storage + await storage.close() + + +@pytest_asyncio.fixture +async def llm_config_manager(storage): + """create llm config manager with test storage""" + return LLMConfigManager(storage) + + +@pytest.mark.asyncio +async def test_llm_default_selection_flow(llm_config_manager): + """ + Test the flow of setting and retrieving default LLM models. + + Verifies: + 1. Fallback to one of the saved models when no default is set. + 2. Explicit default selection. + 3. Ensuring only one model is default at a time. + 4. Switching the default model updates correctly. + """ + + # 1. 
Create a few models + model1 = LLMModelConfig( + name="gpt-4", provider=LLMProvider.OPENAI, model_name="gpt-4", is_default=False + ) + model2 = LLMModelConfig( + name="claude-3", + provider=LLMProvider.ANTHROPIC, + model_name="claude-3-opus", + is_default=False, + ) + model3 = LLMModelConfig( + name="gemini-pro", provider=LLMProvider.GEMINI, model_name="gemini-pro", is_default=False + ) + + await llm_config_manager.save_llm_model(model1) + await llm_config_manager.save_llm_model(model2) + await llm_config_manager.save_llm_model(model3) + + # Validation 1: No explicit default, should return first one (ordering might depend on DB, usually insertion order) + # We just ensure it returns *one* of them. + default_model = await llm_config_manager.get_llm_model(None) + assert default_model.name in ["gpt-4", "claude-3", "gemini-pro"] + + # Validation 2: Set model2 as default + await llm_config_manager.set_default_llm_model("claude-3") + + # Check if retrieval returns model2 + default_model = await llm_config_manager.get_llm_model(None) + assert default_model.name == "claude-3" + assert default_model.is_default is True + + # Verify others are NOT default + m1 = await llm_config_manager.get_llm_model("gpt-4") + m3 = await llm_config_manager.get_llm_model("gemini-pro") + assert m1.is_default is False + assert m3.is_default is False + + # Validation 3: Switch default to model3 + await llm_config_manager.set_default_llm_model("gemini-pro") + + default_model = await llm_config_manager.get_llm_model(None) + assert default_model.name == "gemini-pro" + assert default_model.is_default is True + + # Verify model2 is no longer default + m2 = await llm_config_manager.get_llm_model("claude-3") + assert m2.is_default is False + + +@pytest.mark.asyncio +async def test_embedding_default_selection_flow(llm_config_manager): + """ + Test the flow of setting and retrieving default Embedding models. + + Verifies: + 1. Fallback to one of the saved models when no default is set. + 2. 
Explicit default selection. + 3. Ensuring only one model is default at a time. + 4. Switching default model updates correctly. + """ + embed1 = EmbeddingModelConfig( + name="openai-embed", + provider=LLMProvider.OPENAI, + model_name="text-embedding-3-small", + is_default=False, + ) + embed2 = EmbeddingModelConfig( + name="local-embed", + provider=LLMProvider.OLLAMA, + model_name="nomic-embed-text", + is_default=False, + ) + + await llm_config_manager.save_embedding_model(embed1) + await llm_config_manager.save_embedding_model(embed2) + + # 1. No default set, returns one of them + default_model = await llm_config_manager.get_embedding_model(None) + assert default_model.name in ["openai-embed", "local-embed"] + + # 2. Set default + await llm_config_manager.set_default_embedding_model("local-embed") + + default_model = await llm_config_manager.get_embedding_model(None) + assert default_model.name == "local-embed" + assert default_model.is_default is True + + # Check other is not default + e1 = await llm_config_manager.get_embedding_model("openai-embed") + assert e1.is_default is False + + # 3. 
Switch default + await llm_config_manager.set_default_embedding_model("openai-embed") + + default_model = await llm_config_manager.get_embedding_model(None) + assert default_model.name == "openai-embed" + assert default_model.is_default is True + + e2 = await llm_config_manager.get_embedding_model("local-embed") + assert e2.is_default is False + + +@pytest.mark.asyncio +async def test_set_nonexistent_default_raises_error(llm_config_manager): + """Test setting a non-existent model as default raises LLMConfigNotFoundError""" + with pytest.raises(LLMConfigNotFoundError): + await llm_config_manager.set_default_llm_model("non_existent_model") diff --git a/tests/test_api.py b/tests/test_api.py index 6b793a2..2694842 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -523,3 +523,43 @@ def test_execute_nonexistent_pipeline(self, client): """Test executing non-existent pipeline""" response = client.post("/api/pipelines/999999/execute", json={"text": "test"}) assert response.status_code == 404 + + +class TestAPIDefaultModelSelection: + """Test default model selection API endpoints""" + + def test_set_default_llm_model_success_returns_message(self, client): + """Test PUT /api/llm-models/{name}/default - success""" + model_config = { + "name": "test-llm", + "provider": "openai", + "model_name": "gpt-4", + "api_key": "test-key", + } + client.post("/api/llm-models", json=model_config) + response = client.put("/api/llm-models/test-llm/default") + assert response.status_code == 200 + assert response.json()["message"] == "llm model set as default successfully" + + def test_set_default_llm_model_nonexistent_returns_404(self, client): + """Test PUT /api/llm-models/{name}/default - not found""" + response = client.put("/api/llm-models/nonexistent/default") + assert response.status_code == 404 + + def test_set_default_embedding_model_success_returns_message(self, client): + """Test PUT /api/embedding-models/{name}/default - success""" + model_config = { + "name": "test-embed", + 
"provider": "openai", + "model_name": "text-embedding-3-small", + "api_key": "test-key", + } + client.post("/api/embedding-models", json=model_config) + response = client.put("/api/embedding-models/test-embed/default") + assert response.status_code == 200 + assert response.json()["message"] == "embedding model set as default successfully" + + def test_set_default_embedding_model_nonexistent_returns_404(self, client): + """Test PUT /api/embedding-models/{name}/default - not found""" + response = client.put("/api/embedding-models/nonexistent/default") + assert response.status_code == 404