diff --git a/src-tauri/capabilities/main.json b/src-tauri/capabilities/main.json index 2195cacf..15fd230d 100644 --- a/src-tauri/capabilities/main.json +++ b/src-tauri/capabilities/main.json @@ -53,6 +53,9 @@ "allow": [ { "url": "https://generativelanguage.googleapis.com/**" + }, + { + "url": "http://localhost:11434/**" } ] }, diff --git a/src/features/ai/components/history/sidebar.tsx b/src/features/ai/components/history/sidebar.tsx index e0aa6a2b..1c476fac 100644 --- a/src/features/ai/components/history/sidebar.tsx +++ b/src/features/ai/components/history/sidebar.tsx @@ -104,7 +104,7 @@ export default function ChatHistorySidebar({ onClose(); }} isSelected={index === selectedIndex} - className="px-3 py-1.5" + className="group px-3 py-1.5" >
{chat.title}
diff --git a/src/features/ai/components/input/chat-input-bar.tsx b/src/features/ai/components/input/chat-input-bar.tsx index 00aaa2c0..74349215 100644 --- a/src/features/ai/components/input/chat-input-bar.tsx +++ b/src/features/ai/components/input/chat-input-bar.tsx @@ -4,6 +4,7 @@ import { useAIChatStore } from "@/features/ai/store/store"; import type { AIChatInputBarProps } from "@/features/ai/types/ai-chat"; import { getModelById } from "@/features/ai/types/providers"; import { useEditorSettingsStore } from "@/features/editor/stores/settings-store"; +import { useToast } from "@/features/layout/contexts/toast-context"; import { useSettingsStore } from "@/features/settings/store"; import { useUIState } from "@/stores/ui-state-store"; import Button from "@/ui/button"; @@ -528,9 +529,17 @@ const AIChatInputBar = memo(function AIChatInputBar({ { + const { dynamicModels } = useAIChatStore.getState(); + const providerModels = dynamicModels[settings.aiProviderId]; + const dynamicModel = providerModels?.find((m) => m.id === settings.aiModelId); + if (dynamicModel) return dynamicModel.name; + + return ( + getModelById(settings.aiProviderId, settings.aiModelId)?.name || + settings.aiModelId + ); + })()} onSelect={(providerId, modelId) => { updateSetting("aiProviderId", providerId); updateSetting("aiModelId", modelId); diff --git a/src/features/ai/components/selectors/model-selector-dropdown.tsx b/src/features/ai/components/selectors/model-selector-dropdown.tsx index 85b4acf7..dcf369ed 100644 --- a/src/features/ai/components/selectors/model-selector-dropdown.tsx +++ b/src/features/ai/components/selectors/model-selector-dropdown.tsx @@ -1,7 +1,10 @@ import { Check, ChevronDown, Key, Search } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { useAIChatStore } from "@/features/ai/store/store"; import { getAvailableProviders } from "@/features/ai/types/providers"; import { cn } from "@/utils/cn"; +import { getProvider } from "@/utils/providers"; +import type { ProviderModel } from "@/utils/providers/provider-interface"; interface ModelSelectorDropdownProps { currentProviderId: string; @@ -23,12 +26,41 @@ export function ModelSelectorDropdown({ const [isOpen, setIsOpen] = useState(false); const [search, setSearch] = useState(""); const [selectedIndex, setSelectedIndex] = useState(0); + const { dynamicModels, setDynamicModels } = useAIChatStore(); const triggerRef = useRef(null); const dropdownRef = useRef(null); const inputRef = useRef(null); const providers = getAvailableProviders(); + // Fetch dynamic models on mount + useEffect(() => { + const fetchModels = async () => { + for (const provider of providers) { + // Skip if we already have models for this provider + if (dynamicModels[provider.id]?.length > 0) continue; + + // Only fetch dynamic models if provider supports it AND does not require an API key + // This enforces static lists for cloud providers like OpenAI + if (provider.requiresApiKey) continue; + + const providerInstance = getProvider(provider.id); + if (providerInstance?.getModels) { + try { + const models = await providerInstance.getModels(); + if (models.length > 0) { + setDynamicModels(provider.id, models); + } + } catch (error) { + console.error(`Failed to fetch models for ${provider.id}:`, error); + } + } + } + }; + + fetchModels(); + }, [providers, dynamicModels, setDynamicModels]); + const filteredItems = useMemo(() => { const items: Array<{ type: "provider" | "model"; @@ -45,7 +77,11 @@ export function ModelSelectorDropdown({ for (const provider of providers) { const providerMatches = provider.name.toLowerCase().includes(searchLower); const providerHasKey = !provider.requiresApiKey || hasApiKey(provider.id); - const matchingModels = provider.models.filter( + + // Use dynamic models if available, otherwise use static models + const models = dynamicModels[provider.id] || provider.models; + + const matchingModels = models.filter( (model) => providerMatches || model.name.toLowerCase().includes(searchLower) || @@ -63,7 +99,7 @@ export function ModelSelectorDropdown({ // Only show models if provider has API key or doesn't require one if (providerHasKey) { - const modelsToShow = search ? matchingModels : provider.models; + const modelsToShow = search ? matchingModels : models; for (const model of modelsToShow) { items.push({ type: "model", @@ -78,7 +114,7 @@ export function ModelSelectorDropdown({ } return items; - }, [providers, search, hasApiKey]); + }, [providers, search, hasApiKey, dynamicModels]); const selectableItems = useMemo( () => filteredItems.filter((item) => item.type === "model"), @@ -183,7 +219,12 @@ export function ModelSelectorDropdown({
{filteredItems.length === 0 ? ( -
No models found
+
+ {providers.find((p) => p.id === currentProviderId)?.id === "ollama" && + !dynamicModels["ollama"]?.length + ? "No models detected. Please install a model." + : "No models found"} +
) : ( filteredItems.map((item) => { if (item.type === "provider") { diff --git a/src/features/ai/store/store.ts b/src/features/ai/store/store.ts index b8911df0..422007b4 100644 --- a/src/features/ai/store/store.ts +++ b/src/features/ai/store/store.ts @@ -18,6 +18,7 @@ import { loadChatFromDb, saveChatToDb, } from "@/utils/chat-history-db"; +import type { ProviderModel } from "@/utils/providers/provider-interface"; import type { AIChatActions, AIChatState } from "./types"; export const useAIChatStore = create()( @@ -45,6 +46,7 @@ export const useAIChatStore = create()( providerApiKeys: new Map(), apiKeyModalState: { isOpen: false, providerId: null }, + dynamicModels: {}, mentionState: { active: false, @@ -313,8 +315,10 @@ export const useAIChatStore = create()( checkApiKey: async (providerId) => { try { - // Claude Code doesn't require an API key in the frontend - if (providerId === "claude-code") { + const provider = AI_PROVIDERS.find((p) => p.id === providerId); + + // If provider doesn't require an API key, set hasApiKey to true + if (provider && !provider.requiresApiKey) { set((state) => { state.hasApiKey = true; }); @@ -338,8 +342,8 @@ export const useAIChatStore = create()( for (const provider of AI_PROVIDERS) { try { - // Claude Code doesn't require an API key in the frontend - if (provider.id === "claude-code") { + // If provider doesn't require an API key, mark it as having one + if (!provider.requiresApiKey) { newApiKeyMap.set(provider.id, true); continue; } @@ -366,7 +370,7 @@ export const useAIChatStore = create()( const newApiKeyMap = new Map(); for (const provider of AI_PROVIDERS) { try { - if (provider.id === "claude-code") { + if (!provider.requiresApiKey) { newApiKeyMap.set(provider.id, true); continue; } @@ -381,7 +385,8 @@ export const useAIChatStore = create()( }); // Update hasApiKey for current provider - if (providerId === "claude-code") { + const currentProvider = AI_PROVIDERS.find((p) => p.id === providerId); + if (currentProvider && !currentProvider.requiresApiKey) { set((state) => { state.hasApiKey = true; }); @@ -409,7 +414,7 @@ export const useAIChatStore = create()( const newApiKeyMap = new Map(); for (const provider of AI_PROVIDERS) { try { - if (provider.id === "claude-code") { + if (!provider.requiresApiKey) { newApiKeyMap.set(provider.id, true); continue; } @@ -424,7 +429,8 @@ export const useAIChatStore = create()( }); // Update hasApiKey for current provider - if (providerId === "claude-code") { + const currentProvider = AI_PROVIDERS.find((p) => p.id === providerId); + if (currentProvider && !currentProvider.requiresApiKey) { set((state) => { state.hasApiKey = true; }); @@ -443,6 +449,11 @@ export const useAIChatStore = create()( return get().providerApiKeys.get(providerId) || false; }, + setDynamicModels: (providerId, models) => + set((state) => { + state.dynamicModels[providerId] = models; + }), + // Mention actions showMention: (position, search, startIndex) => set((state) => { diff --git a/src/features/ai/store/types.ts b/src/features/ai/store/types.ts index 47101687..fab3ac0f 100644 --- a/src/features/ai/store/types.ts +++ b/src/features/ai/store/types.ts @@ -1,5 +1,6 @@ import type { Chat, Message } from "@/features/ai/types/ai-chat"; import type { FileEntry } from "@/features/file-system/types/app"; +import type { ProviderModel } from "@/utils/providers/provider-interface"; export type OutputStyle = "default" | "explanatory" | "learning" | "custom"; export type ChatMode = "chat" | "plan"; @@ -34,6 +35,9 @@ export interface AIChatState { providerApiKeys: Map; apiKeyModalState: { isOpen: boolean; providerId: string | null }; + // Dynamic models state + dynamicModels: Record; + // Mention state mentionState: { active: boolean; @@ -93,6 +97,9 @@ export interface AIChatActions { removeApiKey: (providerId: string) => Promise; hasProviderApiKey: (providerId: string) => boolean; + // Dynamic models actions + setDynamicModels: (providerId: string, models: ProviderModel[]) => void; + // Mention actions showMention: ( position: { top: number; left: number }, diff --git a/src/features/ai/types/providers.ts b/src/features/ai/types/providers.ts index 39bdcfc3..b9e5d4f8 100644 --- a/src/features/ai/types/providers.ts +++ b/src/features/ai/types/providers.ts @@ -408,6 +408,13 @@ export const AI_PROVIDERS: ModelProvider[] = [ }, ], }, + { + id: "ollama", + name: "Ollama (Local)", + apiUrl: "http://localhost:11434/v1/chat/completions", + requiresApiKey: false, + models: [], + }, ]; // Track Claude Code availability diff --git a/src/features/settings/components/tabs/ai-settings.tsx b/src/features/settings/components/tabs/ai-settings.tsx index a43ed981..14d20295 100644 --- a/src/features/settings/components/tabs/ai-settings.tsx +++ b/src/features/settings/components/tabs/ai-settings.tsx @@ -1,5 +1,15 @@ import { invoke } from "@tauri-apps/api/core"; -import { AlertCircle, Check, CheckCircle, Eye, EyeOff, Key, Trash2, X } from "lucide-react"; +import { + AlertCircle, + Check, + CheckCircle, + Eye, + EyeOff, + Key, + RefreshCw, + Trash2, + X, +} from "lucide-react"; import { useEffect, useState } from "react"; import { useAIChatStore } from "@/features/ai/store/store"; import type { ClaudeStatus } from "@/features/ai/types/claude"; @@ -9,6 +19,7 @@ import Button from "@/ui/button"; import Dropdown from "@/ui/dropdown"; import Section, { SettingRow } from "@/ui/section"; import { cn } from "@/utils/cn"; +import { getProvider } from "@/utils/providers"; export const AISettings = () => { const { settings, updateSetting } = useSettingsStore(); @@ -37,6 +48,11 @@ export const AISettings = () => { message?: string; }>({ providerId: null, status: null }); + // Dynamic models state + const { dynamicModels, setDynamicModels } = useAIChatStore(); + const [isLoadingModels, setIsLoadingModels] = useState(false); + const [modelFetchError, setModelFetchError] = useState(null); + // API Key functions from AI chat store const saveApiKey = useAIChatStore((state) => state.saveApiKey); const removeApiKey = useAIChatStore((state) => state.removeApiKey); @@ -48,31 +64,65 @@ export const AISettings = () => { checkAllProviderApiKeys(); }, [checkAllProviderApiKeys]); - const currentProvider = getAvailableProviders().find((p) => p.id === settings.aiProviderId); + const providers = getAvailableProviders(); + const currentProvider = providers.find((p) => p.id === settings.aiProviderId); + + // Fetch dynamic models if provider supports it + const fetchDynamicModels = async () => { + const providerInstance = getProvider(settings.aiProviderId); + const providerConfig = providers.find((p) => p.id === settings.aiProviderId); + + // Always clear error when fetching/switching + setModelFetchError(null); + + // Only fetch dynamic models if provider supports it AND does not require an API key (unless explicitly allowed) + // This enforces static lists for cloud providers like OpenAI as requested + if (providerInstance?.getModels && !providerConfig?.requiresApiKey) { + setIsLoadingModels(true); + try { + const models = await providerInstance.getModels(); + if (models.length > 0) { + setDynamicModels(settings.aiProviderId, models); + // If current model is not in the list, select the first one + if (!models.find((m) => m.id === settings.aiModelId)) { + updateSetting("aiModelId", models[0].id); + } + } else { + setDynamicModels(settings.aiProviderId, []); + const errorMessage = + settings.aiProviderId === "ollama" + ? "No models detected. Please install a model in Ollama." + : "No models found."; + setModelFetchError(errorMessage); + } + } catch (error) { + console.error("Failed to fetch models:", error); + setModelFetchError("Failed to fetch models"); + } finally { + setIsLoadingModels(false); + } + } + }; + + useEffect(() => { + fetchDynamicModels(); + }, [settings.aiProviderId, updateSetting, setDynamicModels]); const providerOptions = getAvailableProviders().map((provider) => ({ value: provider.id, label: provider.name, })); - const modelOptions = - currentProvider?.models.map((model) => ({ - value: model.id, - label: model.name, - })) || []; - const handleProviderChange = (providerId: string) => { const provider = getAvailableProviders().find((p) => p.id === providerId); - if (provider && provider.models.length > 0) { + if (provider) { updateSetting("aiProviderId", providerId); - updateSetting("aiModelId", provider.models[0].id); + // Reset model ID, it will be updated by fetchDynamicModels or default logic + if (provider.models.length > 0) { + updateSetting("aiModelId", provider.models[0].id); + } } }; - - const handleModelChange = (modelId: string) => { - updateSetting("aiModelId", modelId); - }; - const startEditing = (providerId: string) => { setEditingProvider(providerId); setApiKeyInput(""); @@ -257,6 +307,9 @@ export const AISettings = () => { (p) => p.requiresAuth && !p.requiresApiKey, ); + const providerInstance = getProvider(settings.aiProviderId); + const supportsDynamicModels = !!providerInstance?.getModels; + return (
@@ -271,13 +324,40 @@ export const AISettings = () => { - +
+
+ +
+ {supportsDynamicModels && ( + + )} +
+ {modelFetchError && ( +
+ + {modelFetchError} +
+ )}
diff --git a/src/utils/ai-chat.ts b/src/utils/ai-chat.ts index 4bfaabf8..593c7d68 100644 --- a/src/utils/ai-chat.ts +++ b/src/utils/ai-chat.ts @@ -1,4 +1,5 @@ import { fetch as tauriFetch } from "@tauri-apps/plugin-http"; +import { useAIChatStore } from "@/features/ai/store/store"; import type { ChatMode, OutputStyle } from "@/features/ai/store/types"; import type { AIMessage } from "@/features/ai/types/messages"; import { getModelById, getProviderById } from "@/features/ai/types/providers"; @@ -35,7 +36,20 @@ export const getChatCompletionStream = async ( ): Promise => { try { const provider = getProviderById(providerId); - const model = getModelById(providerId, modelId); + + // Check for model in static list or dynamic store + let model = getModelById(providerId, modelId); + if (!model) { + const { dynamicModels } = useAIChatStore.getState(); + const providerModels = dynamicModels[providerId]; + const dynamicModel = providerModels?.find((m) => m.id === modelId); + if (dynamicModel) { + model = { + ...dynamicModel, + maxTokens: dynamicModel.maxTokens || 4096, // Default max tokens if missing + }; + } + } if (!provider || !model) { throw new Error(`Provider or model not found: ${providerId}/${modelId}`); @@ -102,8 +116,8 @@ export const getChatCompletionStream = async ( console.log(`Making ${provider.name} streaming chat request with model ${model.name}...`); - // Use Tauri's fetch for Gemini to bypass CORS restrictions - const fetchFn = providerId === "gemini" ? tauriFetch : fetch; + // Use Tauri's fetch for Gemini and Ollama to bypass CORS restrictions + const fetchFn = providerId === "gemini" || providerId === "ollama" ? tauriFetch : fetch; const response = await fetchFn(url, { method: "POST", headers, @@ -119,10 +133,9 @@ export const getChatCompletionStream = async ( return; } - // Use stream processing utility await processStreamingResponse(response, onChunk, onComplete, onError); - } catch (error) { + } catch (error: any) { console.error(`${providerId} streaming chat completion error:`, error); - onError(`Failed to connect to ${providerId} API`); + onError(`Failed to connect to ${providerId} API: ${error.message || error}`); } }; diff --git a/src/utils/providers/index.ts b/src/utils/providers/index.ts index 28cc9a8e..dac148b0 100644 --- a/src/utils/providers/index.ts +++ b/src/utils/providers/index.ts @@ -1,5 +1,6 @@ import { GeminiProvider } from "./gemini-provider"; import { GrokProvider } from "./grok-provider"; +import { OllamaProvider } from "./ollama-provider"; import { OpenAIProvider } from "./openai-provider"; import { OpenRouterProvider } from "./openrouter-provider"; import type { AIProvider, ProviderConfig } from "./provider-interface"; @@ -43,6 +44,15 @@ function initializeProviders(): void { maxTokens: 131072, }; providers.set("grok", new GrokProvider(grokConfig)); + + const ollamaConfig: ProviderConfig = { + id: "ollama", + name: "Ollama", + apiUrl: "http://localhost:11434/v1/chat/completions", + requiresApiKey: false, + maxTokens: 4096, + }; + providers.set("ollama", new OllamaProvider(ollamaConfig)); } export function getProvider(providerId: string): AIProvider | undefined { diff --git a/src/utils/providers/ollama-provider.ts b/src/utils/providers/ollama-provider.ts new file mode 100644 index 00000000..a8e95c2d --- /dev/null +++ b/src/utils/providers/ollama-provider.ts @@ -0,0 +1,48 @@ +import { AIProvider, type ProviderHeaders, type StreamRequest } from "./provider-interface"; + +export class OllamaProvider extends AIProvider { + buildHeaders(_apiKey?: string): ProviderHeaders { + // Ollama typically doesn't require headers, but we can add Content-Type + return { + "Content-Type": "application/json", + }; + } + + buildPayload(request: StreamRequest): any { + return { + model: request.modelId, + messages: request.messages, + stream: true, + temperature: request.temperature, + max_tokens: request.maxTokens, + }; + } + + async validateApiKey(_apiKey: string): Promise { + // Ollama doesn't require an API key by default + return true; + } + + // Override buildUrl to point to the local Ollama instance + buildUrl(_request: StreamRequest): string { + return "http://localhost:11434/v1/chat/completions"; + } + + async getModels(): Promise { + try { + const response = await fetch("http://localhost:11434/api/tags"); + if (!response.ok) { + throw new Error("Failed to fetch models"); + } + const data = await response.json(); + return data.models.map((model: any) => ({ + id: model.name, + name: model.name, + maxTokens: 4096, // Default for now as Ollama doesn't always provide this + })); + } catch (error) { + console.error("Failed to fetch Ollama models:", error); + return []; + } + } +} diff --git a/src/utils/providers/provider-interface.ts b/src/utils/providers/provider-interface.ts index 17b67c49..79069743 100644 --- a/src/utils/providers/provider-interface.ts +++ b/src/utils/providers/provider-interface.ts @@ -20,6 +20,12 @@ export interface StreamRequest { apiKey?: string; } +export interface ProviderModel { + id: string; + name: string; + maxTokens?: number; +} + export abstract class AIProvider { constructor(protected config: ProviderConfig) {} @@ -30,6 +36,11 @@ export abstract class AIProvider { // Optional: Allows providers to customize the URL (e.g., add API key as query param) buildUrl?(request: StreamRequest): string; + // Optional: Allows providers to fetch available models dynamically + async getModels?(): Promise { + return []; + } + get id(): string { return this.config.id; }