diff --git a/app/src/components/ServerSettings/ModelManagement.tsx b/app/src/components/ServerSettings/ModelManagement.tsx index 4a5fd439..a5fafdbf 100644 --- a/app/src/components/ServerSettings/ModelManagement.tsx +++ b/app/src/components/ServerSettings/ModelManagement.tsx @@ -1,5 +1,5 @@ import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; -import { Download, Loader2, Trash2 } from 'lucide-react'; +import { ChevronDown, ChevronUp, Download, Loader2, RotateCcw, Trash2, X } from 'lucide-react'; import { useCallback, useState } from 'react'; import { AlertDialog, @@ -16,6 +16,7 @@ import { Button } from '@/components/ui/button'; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; import { useToast } from '@/components/ui/use-toast'; import { apiClient } from '@/lib/api/client'; +import type { ActiveDownloadTask } from '@/lib/api/types'; import { useModelDownloadToast } from '@/lib/hooks/useModelDownloadToast'; export function ModelManagement() { @@ -23,6 +24,9 @@ export function ModelManagement() { const queryClient = useQueryClient(); const [downloadingModel, setDownloadingModel] = useState(null); const [downloadingDisplayName, setDownloadingDisplayName] = useState(null); + const [consoleOpen, setConsoleOpen] = useState(false); + const [dismissedErrors, setDismissedErrors] = useState>(new Set()); + const [localErrors, setLocalErrors] = useState>(new Map()); const { data: modelStatus, isLoading } = useQuery({ queryKey: ['modelStatus'], @@ -35,19 +39,57 @@ export function ModelManagement() { refetchInterval: 5000, // Refresh every 5 seconds }); + const { data: activeTasks } = useQuery({ + queryKey: ['activeTasks'], + queryFn: () => apiClient.getActiveTasks(), + refetchInterval: 5000, + }); + + // Build a map of errored downloads for quick lookup, excluding dismissed ones + // Merge server errors with locally captured SSE errors + const erroredDownloads = new Map(); + if (activeTasks?.downloads) { + for (const dl of activeTasks.downloads) { + if (dl.status === 'error' && !dismissedErrors.has(dl.model_name)) { + // Prefer locally captured error (from SSE) over server error + const localErr = localErrors.get(dl.model_name); + erroredDownloads.set(dl.model_name, localErr ? { ...dl, error: localErr } : dl); + } + } + } + // Also add locally captured errors that aren't in server response yet + for (const [modelName, error] of localErrors) { + if (!erroredDownloads.has(modelName) && !dismissedErrors.has(modelName)) { + erroredDownloads.set(modelName, { + model_name: modelName, + status: 'error', + started_at: new Date().toISOString(), + error, + }); + } + } + + const errorCount = erroredDownloads.size; + // Callbacks for download completion const handleDownloadComplete = useCallback(() => { console.log('[ModelManagement] Download complete, clearing state'); setDownloadingModel(null); setDownloadingDisplayName(null); queryClient.invalidateQueries({ queryKey: ['modelStatus'] }); + queryClient.invalidateQueries({ queryKey: ['activeTasks'] }); }, [queryClient]); - const handleDownloadError = useCallback(() => { + const handleDownloadError = useCallback((error: string) => { console.log('[ModelManagement] Download error, clearing state'); + if (downloadingModel) { + setLocalErrors((prev) => new Map(prev).set(downloadingModel, error)); + setConsoleOpen(true); + } setDownloadingModel(null); setDownloadingDisplayName(null); - }, []); + queryClient.invalidateQueries({ queryKey: ['activeTasks'] }); + }, [queryClient, downloadingModel]); // Use progress toast hook for the downloading model useModelDownloadToast({ @@ -67,11 +109,17 @@ export function ModelManagement() { const handleDownload = async (modelName: string) => { console.log('[Download] Button clicked for:', modelName, 'at', new Date().toISOString()); - + // Clear any previous dismissal so fresh errors can appear + setDismissedErrors((prev) => { + const next = new Set(prev); + next.delete(modelName); + return next; + }); + // Find display name const model = modelStatus?.models.find((m) => m.model_name === modelName); const displayName = model?.display_name || modelName; - + try { // IMPORTANT: Call the API FIRST before setting state // Setting state enables the SSE EventSource in useModelDownloadToast, @@ -79,14 +127,15 @@ export function ModelManagement() { console.log('[Download] Calling download API for:', modelName); const result = await apiClient.triggerModelDownload(modelName); console.log('[Download] Download API responded:', result); - + // NOW set state to enable SSE tracking (after download has started on backend) setDownloadingModel(modelName); setDownloadingDisplayName(displayName); - + // Download initiated successfully - state will be cleared when SSE reports completion // or by the polling interval detecting the model is downloaded queryClient.invalidateQueries({ queryKey: ['modelStatus'] }); + queryClient.invalidateQueries({ queryKey: ['activeTasks'] }); } catch (error) { console.error('[Download] Download failed:', error); setDownloadingModel(null); @@ -99,6 +148,53 @@ export function ModelManagement() { } }; + const cancelMutation = useMutation({ + mutationFn: (modelName: string) => apiClient.cancelDownload(modelName), + onSuccess: async () => { + await queryClient.invalidateQueries({ queryKey: ['modelStatus'], refetchType: 'all' }); + await queryClient.invalidateQueries({ queryKey: ['activeTasks'], refetchType: 'all' }); + }, + }); + + const handleCancel = (modelName: string) => { + // Snapshot previous state for rollback + const prevDismissed = dismissedErrors; + const prevLocalErrors = localErrors; + const prevDownloadingModel = downloadingModel; + const prevDownloadingDisplayName = downloadingDisplayName; + + // Optimistically hide the error and suppress downloading state in UI + setDismissedErrors((prev) => new Set(prev).add(modelName)); + setLocalErrors((prev) => { const next = new Map(prev); next.delete(modelName); return next; }); + if (downloadingModel === modelName) { + setDownloadingModel(null); + setDownloadingDisplayName(null); + } + + cancelMutation.mutate(modelName, { + onError: () => { + // Rollback optimistic updates on failure + setDismissedErrors(prevDismissed); + setLocalErrors(prevLocalErrors); + setDownloadingModel(prevDownloadingModel); + setDownloadingDisplayName(prevDownloadingDisplayName); + toast({ title: 'Cancel failed', description: 'Could not cancel the download task.', variant: 'destructive' }); + }, + }); + }; + + const clearAllMutation = useMutation({ + mutationFn: () => apiClient.clearAllTasks(), + onSuccess: async () => { + setDismissedErrors(new Set()); + setLocalErrors(new Map()); + setDownloadingModel(null); + setDownloadingDisplayName(null); + await queryClient.invalidateQueries({ queryKey: ['modelStatus'], refetchType: 'all' }); + await queryClient.invalidateQueries({ queryKey: ['activeTasks'], refetchType: 'all' }); + }, + }); + const deleteMutation = useMutation({ mutationFn: async (modelName: string) => { console.log('[Delete] Deleting model:', modelName); @@ -114,14 +210,11 @@ export function ModelManagement() { }); setDeleteDialogOpen(false); setModelToDelete(null); - // Invalidate AND explicitly refetch to ensure UI updates - // Using refetchType: 'all' ensures we refetch even if the query is stale console.log('[Delete] Invalidating modelStatus query'); - await queryClient.invalidateQueries({ + await queryClient.invalidateQueries({ queryKey: ['modelStatus'], refetchType: 'all', }); - // Also explicitly refetch to guarantee fresh data console.log('[Delete] Explicitly refetching modelStatus query'); await queryClient.refetchQueries({ queryKey: ['modelStatus'] }); console.log('[Delete] Query refetched'); @@ -178,7 +271,11 @@ export function ModelManagement() { }); setDeleteDialogOpen(true); }} + onCancel={() => handleCancel(model.model_name)} isDownloading={downloadingModel === model.model_name} + isCancelling={cancelMutation.isPending && cancelMutation.variables === model.model_name} + isDismissed={dismissedErrors.has(model.model_name)} + erroredDownload={erroredDownloads.get(model.model_name)} formatSize={formatSize} /> ))} @@ -206,13 +303,73 @@ export function ModelManagement() { }); setDeleteDialogOpen(true); }} + onCancel={() => handleCancel(model.model_name)} isDownloading={downloadingModel === model.model_name} + isCancelling={cancelMutation.isPending && cancelMutation.variables === model.model_name} + isDismissed={dismissedErrors.has(model.model_name)} + erroredDownload={erroredDownloads.get(model.model_name)} formatSize={formatSize} /> ))} + {/* Console Panel */} + {errorCount > 0 && ( +
+
+ + +
+ {consoleOpen && ( +
+ {Array.from(erroredDownloads.entries()).map(([modelName, dl]) => ( +
+ [error]{' '} + {modelName} + {dl.error ? ( + <> + {': '} + {dl.error} + + ) : ( + <> + {': '} + No error details available. Try downloading again. + + )} +
+ started at {new Date(dl.started_at).toLocaleString()} +
+
+ ))} +
+ )} +
+ )} ) : null} @@ -271,17 +428,22 @@ interface ModelItemProps { }; onDownload: () => void; onDelete: () => void; + onCancel: () => void; isDownloading: boolean; // Local state - true if user just clicked download + isCancelling: boolean; + isDismissed: boolean; + erroredDownload?: ActiveDownloadTask; formatSize: (sizeMb?: number) => string; } -function ModelItem({ model, onDownload, onDelete, isDownloading, formatSize }: ModelItemProps) { +function ModelItem({ model, onDownload, onDelete, onCancel, isDownloading, isCancelling, isDismissed, erroredDownload, formatSize }: ModelItemProps) { // Use server's downloading state OR local state (for immediate feedback before server updates) - const showDownloading = model.downloading || isDownloading; - + // Suppress downloading if user just dismissed/cancelled this model + const showDownloading = (model.downloading || isDownloading) && !erroredDownload && !isDismissed; + return (
-
+
{model.display_name} {model.loaded && ( @@ -289,21 +451,41 @@ function ModelItem({ model, onDownload, onDelete, isDownloading, formatSize }: M Loaded )} - {/* Only show Downloaded if actually downloaded AND not downloading */} - {model.downloaded && !model.loaded && !showDownloading && ( + {model.downloaded && !model.loaded && !showDownloading && !erroredDownload && ( Downloaded )} + {erroredDownload && ( + + Error + + )}
- {model.downloaded && model.size_mb && !showDownloading && ( + {model.downloaded && model.size_mb && !showDownloading && !erroredDownload && (
Size: {formatSize(model.size_mb)}
)}
-
- {model.downloaded && !showDownloading ? ( +
+ {erroredDownload ? ( +
+ + +
+ ) : model.downloaded && !showDownloading ? (
Ready @@ -319,10 +501,21 @@ function ModelItem({ model, onDownload, onDelete, isDownloading, formatSize }: M
) : showDownloading ? ( - +
+ + +
) : (
), - duration: progress.status === 'complete' ? 5000 : Infinity, - variant: progress.status === 'error' ? 'destructive' : 'default', + duration: progress.status === 'complete' || progress.status === 'error' ? 5000 : Infinity, }); // Close connection and dismiss toast on completion or error @@ -169,7 +168,7 @@ export function useModelDownloadToast({ onComplete(); } else if (isError && onError) { console.log('[useModelDownloadToast] Download error, calling onError callback'); - onError(); + onError(progress.error || 'Unknown error'); } } } diff --git a/backend/backends/mlx_backend.py b/backend/backends/mlx_backend.py index c4ecc090..60c015de 100644 --- a/backend/backends/mlx_backend.py +++ b/backend/backends/mlx_backend.py @@ -379,9 +379,17 @@ def _generate_sync(): return audio, sample_rate +WHISPER_HF_REPOS = { + "base": "openai/whisper-base", + "small": "openai/whisper-small", + "medium": "openai/whisper-medium", + "large": "openai/whisper-large-v3", +} + + class MLXSTTBackend: """MLX-based STT backend using mlx-audio Whisper.""" - + def __init__(self, model_size: str = "base"): self.model = None self.model_size = model_size @@ -402,8 +410,8 @@ def _is_model_cached(self, model_size: str) -> bool: """ try: from huggingface_hub import constants as hf_constants - model_name = f"openai/whisper-{model_size}" - repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_name.replace("/", "--")) + hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}") + repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--")) if not repo_cache.exists(): return False @@ -474,7 +482,7 @@ def _load_model_sync(self, model_size: str): from mlx_audio.stt import load # MLX Whisper uses the standard OpenAI models - model_name = f"openai/whisper-{model_size}" + model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}") print(f"Loading MLX Whisper model {model_size}...") diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index d0cba11a..8059dc09 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -369,9 +369,17 @@ def _generate_sync(): return audio, sample_rate +WHISPER_HF_REPOS = { + "base": "openai/whisper-base", + "small": "openai/whisper-small", + "medium": "openai/whisper-medium", + "large": "openai/whisper-large-v3", +} + + class PyTorchSTTBackend: """PyTorch-based STT backend using Whisper.""" - + def __init__(self, model_size: str = "base"): self.model = None self.processor = None @@ -416,18 +424,18 @@ def _is_model_cached(self, model_size: str) -> bool: """ try: from huggingface_hub import constants as hf_constants - model_name = f"openai/whisper-{model_size}" - repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_name.replace("/", "--")) - + hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}") + repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--")) + if not repo_cache.exists(): return False - + # Check for .incomplete files - if any exist, download is still in progress blobs_dir = repo_cache / "blobs" if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")): print(f"[_is_model_cached] Found .incomplete files for whisper-{model_size}, treating as not cached") return False - + # Check that actual model weight files exist in snapshots snapshots_dir = repo_cache / "snapshots" if snapshots_dir.exists(): @@ -438,12 +446,12 @@ def _is_model_cached(self, model_size: str) -> bool: if not has_weights: print(f"[_is_model_cached] No model weights found for whisper-{model_size}, treating as not cached") return False - + return True except Exception as e: print(f"[_is_model_cached] Error checking cache for whisper-{model_size}: {e}") return False - + async def load_model_async(self, model_size: Optional[str] = None): """ Lazy load the Whisper model. @@ -494,7 +502,7 @@ def _load_model_sync(self, model_size: str): # Import transformers from transformers import WhisperProcessor, WhisperForConditionalGeneration - model_name = f"openai/whisper-{model_size}" + model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}") print(f"[DEBUG] Model name: {model_name}") print(f"Loading Whisper model {model_size} on {self.device}...") diff --git a/backend/main.py b/backend/main.py index e218d237..c9d7e7f7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -932,7 +932,11 @@ async def transcribe_audio( # Check if Whisper model is downloaded (uses default size "base") model_size = whisper_model.model_size - model_name = f"openai/whisper-{model_size}" + # Map model sizes to HF repo IDs (whisper-large needs -v3 suffix) + whisper_hf_repos = { + "large": "openai/whisper-large-v3", + } + model_name = whisper_hf_repos.get(model_size, f"openai/whisper-{model_size}") # Check if model is cached from huggingface_hub import constants as hf_constants @@ -1310,14 +1314,14 @@ def check_whisper_loaded(model_size: str): whisper_base_id = "openai/whisper-base" whisper_small_id = "openai/whisper-small" whisper_medium_id = "openai/whisper-medium" - whisper_large_id = "openai/whisper-large" + whisper_large_id = "openai/whisper-large-v3" else: tts_1_7b_id = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" tts_0_6b_id = "Qwen/Qwen3-TTS-12Hz-0.6B-Base" whisper_base_id = "openai/whisper-base" whisper_small_id = "openai/whisper-small" whisper_medium_id = "openai/whisper-medium" - whisper_large_id = "openai/whisper-large" + whisper_large_id = "openai/whisper-large-v3" model_configs = [ { @@ -1586,6 +1590,42 @@ async def download_in_background(): return {"message": f"Model {request.model_name} download started"} +@app.post("/models/download/cancel") +async def cancel_model_download(request: models.ModelDownloadRequest): + """Cancel or dismiss an errored/stale download task.""" + task_manager = get_task_manager() + progress_manager = get_progress_manager() + + removed = task_manager.cancel_download(request.model_name) + + # Also clear progress state so the model doesn't show as downloading + progress_removed = False + with progress_manager._lock: + if request.model_name in progress_manager._progress: + del progress_manager._progress[request.model_name] + progress_removed = True + + if removed or progress_removed: + return {"message": f"Download task for {request.model_name} cancelled"} + return {"message": f"No active task found for {request.model_name}"} + + +@app.post("/tasks/clear") +async def clear_all_tasks(): + """Clear all download tasks and progress state. Does not delete downloaded files.""" + task_manager = get_task_manager() + progress_manager = get_progress_manager() + + task_manager.clear_all() + + with progress_manager._lock: + progress_manager._progress.clear() + progress_manager._last_notify_time.clear() + progress_manager._last_notify_progress.clear() + + return {"message": "All task state cleared"} + + @app.delete("/models/{model_name}") async def delete_model(model_name: str): """Delete a downloaded model from the HuggingFace cache.""" @@ -1621,12 +1661,12 @@ async def delete_model(model_name: str): "model_type": "whisper", }, "whisper-large": { - "hf_repo_id": "openai/whisper-large", + "hf_repo_id": "openai/whisper-large-v3", "model_size": "large", "model_type": "whisper", }, } - + if model_name not in model_configs: raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}") @@ -1710,10 +1750,18 @@ async def get_active_tasks(): progress = progress_map.get(model_name) if task: + # Prefer task error, fall back to progress manager error + error = task.error + if not error: + with progress_manager._lock: + pm_data = progress_manager._progress.get(model_name) + if pm_data: + error = pm_data.get("error") active_downloads.append(models.ActiveDownloadTask( model_name=model_name, status=task.status, started_at=task.started_at, + error=error, )) elif progress: # Progress exists but no task - create from progress data @@ -1730,6 +1778,7 @@ async def get_active_tasks(): model_name=model_name, status=progress.get("status", "downloading"), started_at=started_at, + error=progress.get("error"), )) # Get active generations diff --git a/backend/models.py b/backend/models.py index 59e45405..3f55b591 100644 --- a/backend/models.py +++ b/backend/models.py @@ -154,6 +154,7 @@ class ActiveDownloadTask(BaseModel): model_name: str status: str started_at: datetime + error: Optional[str] = None class ActiveGenerationTask(BaseModel): diff --git a/backend/utils/tasks.py b/backend/utils/tasks.py index 05b8e019..8baf71c3 100644 --- a/backend/utils/tasks.py +++ b/backend/utils/tasks.py @@ -72,6 +72,15 @@ def get_active_generations(self) -> List[GenerationTask]: """Get all active generations.""" return list(self._active_generations.values()) + def cancel_download(self, model_name: str) -> bool: + """Cancel/dismiss a download task (removes it from active list).""" + return self._active_downloads.pop(model_name, None) is not None + + def clear_all(self) -> None: + """Clear all download and generation tasks.""" + self._active_downloads.clear() + self._active_generations.clear() + def is_download_active(self, model_name: str) -> bool: """Check if a download is active.""" return model_name in self._active_downloads