Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 217 additions & 24 deletions app/src/components/ServerSettings/ModelManagement.tsx

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions app/src/lib/api/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,22 @@ class ApiClient {
});
}

async cancelDownload(modelName: string): Promise<{ message: string }> {
return this.request<{ message: string }>('/models/download/cancel', {
method: 'POST',
body: JSON.stringify({ model_name: modelName } as ModelDownloadRequest),
});
}

// Task Management
async getActiveTasks(): Promise<ActiveTasksResponse> {
return this.request<ActiveTasksResponse>('/tasks/active');
}

async clearAllTasks(): Promise<{ message: string }> {
return this.request<{ message: string }>('/tasks/clear', { method: 'POST' });
}

// Audio Channels
async listChannels(): Promise<
Array<{
Expand Down
1 change: 1 addition & 0 deletions app/src/lib/api/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ export interface ActiveDownloadTask {
model_name: string;
status: string;
started_at: string;
error?: string;
}

export interface ActiveGenerationTask {
Expand Down
9 changes: 4 additions & 5 deletions app/src/lib/hooks/useModelDownloadToast.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ interface UseModelDownloadToastOptions {
displayName: string;
enabled?: boolean;
onComplete?: () => void;
onError?: () => void;
onError?: (error: string) => void;
}

/**
Expand Down Expand Up @@ -101,7 +101,7 @@ export function useModelDownloadToast({
break;
case 'error':
statusIcon = <XCircle className="h-4 w-4 text-destructive" />;
statusText = `Error: ${progress.error || 'Unknown error'}`;
statusText = 'Download failed. See Problems panel for details.';
break;
case 'downloading':
statusIcon = <Loader2 className="h-4 w-4 animate-spin" />;
Expand Down Expand Up @@ -131,8 +131,7 @@ export function useModelDownloadToast({
)}
</div>
),
duration: progress.status === 'complete' ? 5000 : Infinity,
variant: progress.status === 'error' ? 'destructive' : 'default',
duration: progress.status === 'complete' || progress.status === 'error' ? 5000 : Infinity,
});

// Close connection and dismiss toast on completion or error
Expand Down Expand Up @@ -169,7 +168,7 @@ export function useModelDownloadToast({
onComplete();
} else if (isError && onError) {
console.log('[useModelDownloadToast] Download error, calling onError callback');
onError();
onError(progress.error || 'Unknown error');
}
}
}
Expand Down
16 changes: 12 additions & 4 deletions backend/backends/mlx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,17 @@ def _generate_sync():
return audio, sample_rate


# Map short Whisper size names to HuggingFace repo IDs.
# NOTE: "large" intentionally resolves to the v3 checkpoint so the cache
# check and the model load below agree on the same repo directory.
# Sizes missing from this map fall back to f"openai/whisper-{size}" at
# the call sites.
WHISPER_HF_REPOS = {
    "base": "openai/whisper-base",
    "small": "openai/whisper-small",
    "medium": "openai/whisper-medium",
    "large": "openai/whisper-large-v3",
}


class MLXSTTBackend:
"""MLX-based STT backend using mlx-audio Whisper."""

def __init__(self, model_size: str = "base"):
self.model = None
self.model_size = model_size
Expand All @@ -402,8 +410,8 @@ def _is_model_cached(self, model_size: str) -> bool:
"""
try:
from huggingface_hub import constants as hf_constants
model_name = f"openai/whisper-{model_size}"
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_name.replace("/", "--"))
hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--"))

if not repo_cache.exists():
return False
Expand Down Expand Up @@ -474,7 +482,7 @@ def _load_model_sync(self, model_size: str):
from mlx_audio.stt import load

# MLX Whisper uses the standard OpenAI models
model_name = f"openai/whisper-{model_size}"
model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")

print(f"Loading MLX Whisper model {model_size}...")

Expand Down
26 changes: 17 additions & 9 deletions backend/backends/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,17 @@ def _generate_sync():
return audio, sample_rate


# Map short Whisper size names to HuggingFace repo IDs.
# NOTE: "large" intentionally resolves to the v3 checkpoint so the cache
# check and the model load below agree on the same repo directory.
# NOTE(review): this dict is duplicated in mlx_backend.py — consider
# hoisting it to a shared module so the two backends cannot drift.
WHISPER_HF_REPOS = {
    "base": "openai/whisper-base",
    "small": "openai/whisper-small",
    "medium": "openai/whisper-medium",
    "large": "openai/whisper-large-v3",
}


class PyTorchSTTBackend:
"""PyTorch-based STT backend using Whisper."""

def __init__(self, model_size: str = "base"):
self.model = None
self.processor = None
Expand Down Expand Up @@ -416,18 +424,18 @@ def _is_model_cached(self, model_size: str) -> bool:
"""
try:
from huggingface_hub import constants as hf_constants
model_name = f"openai/whisper-{model_size}"
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_name.replace("/", "--"))
hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--"))

if not repo_cache.exists():
return False

# Check for .incomplete files - if any exist, download is still in progress
blobs_dir = repo_cache / "blobs"
if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
print(f"[_is_model_cached] Found .incomplete files for whisper-{model_size}, treating as not cached")
return False

# Check that actual model weight files exist in snapshots
snapshots_dir = repo_cache / "snapshots"
if snapshots_dir.exists():
Expand All @@ -438,12 +446,12 @@ def _is_model_cached(self, model_size: str) -> bool:
if not has_weights:
print(f"[_is_model_cached] No model weights found for whisper-{model_size}, treating as not cached")
return False

return True
except Exception as e:
print(f"[_is_model_cached] Error checking cache for whisper-{model_size}: {e}")
return False

async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the Whisper model.
Expand Down Expand Up @@ -494,7 +502,7 @@ def _load_model_sync(self, model_size: str):
# Import transformers
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_name = f"openai/whisper-{model_size}"
model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
print(f"[DEBUG] Model name: {model_name}")

print(f"Loading Whisper model {model_size} on {self.device}...")
Expand Down
59 changes: 54 additions & 5 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,11 @@ async def transcribe_audio(

# Check if Whisper model is downloaded (uses default size "base")
model_size = whisper_model.model_size
model_name = f"openai/whisper-{model_size}"
# Map model sizes to HF repo IDs (whisper-large needs -v3 suffix)
whisper_hf_repos = {
"large": "openai/whisper-large-v3",
}
model_name = whisper_hf_repos.get(model_size, f"openai/whisper-{model_size}")

# Check if model is cached
from huggingface_hub import constants as hf_constants
Expand Down Expand Up @@ -1310,14 +1314,14 @@ def check_whisper_loaded(model_size: str):
whisper_base_id = "openai/whisper-base"
whisper_small_id = "openai/whisper-small"
whisper_medium_id = "openai/whisper-medium"
whisper_large_id = "openai/whisper-large"
whisper_large_id = "openai/whisper-large-v3"
else:
tts_1_7b_id = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
tts_0_6b_id = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
whisper_base_id = "openai/whisper-base"
whisper_small_id = "openai/whisper-small"
whisper_medium_id = "openai/whisper-medium"
whisper_large_id = "openai/whisper-large"
whisper_large_id = "openai/whisper-large-v3"

model_configs = [
{
Expand Down Expand Up @@ -1586,6 +1590,42 @@ async def download_in_background():
return {"message": f"Model {request.model_name} download started"}


@app.post("/models/download/cancel")
async def cancel_model_download(request: models.ModelDownloadRequest):
"""Cancel or dismiss an errored/stale download task."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()

removed = task_manager.cancel_download(request.model_name)

# Also clear progress state so the model doesn't show as downloading
progress_removed = False
with progress_manager._lock:
if request.model_name in progress_manager._progress:
del progress_manager._progress[request.model_name]
progress_removed = True

if removed or progress_removed:
return {"message": f"Download task for {request.model_name} cancelled"}
return {"message": f"No active task found for {request.model_name}"}


@app.post("/tasks/clear")
async def clear_all_tasks():
"""Clear all download tasks and progress state. Does not delete downloaded files."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()

task_manager.clear_all()

with progress_manager._lock:
progress_manager._progress.clear()
progress_manager._last_notify_time.clear()
progress_manager._last_notify_progress.clear()

return {"message": "All task state cleared"}


@app.delete("/models/{model_name}")
async def delete_model(model_name: str):
"""Delete a downloaded model from the HuggingFace cache."""
Expand Down Expand Up @@ -1621,12 +1661,12 @@ async def delete_model(model_name: str):
"model_type": "whisper",
},
"whisper-large": {
"hf_repo_id": "openai/whisper-large",
"hf_repo_id": "openai/whisper-large-v3",
"model_size": "large",
"model_type": "whisper",
},
}

if model_name not in model_configs:
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")

Expand Down Expand Up @@ -1710,10 +1750,18 @@ async def get_active_tasks():
progress = progress_map.get(model_name)

if task:
# Prefer task error, fall back to progress manager error
error = task.error
if not error:
with progress_manager._lock:
pm_data = progress_manager._progress.get(model_name)
if pm_data:
error = pm_data.get("error")
active_downloads.append(models.ActiveDownloadTask(
model_name=model_name,
status=task.status,
started_at=task.started_at,
error=error,
))
elif progress:
# Progress exists but no task - create from progress data
Expand All @@ -1730,6 +1778,7 @@ async def get_active_tasks():
model_name=model_name,
status=progress.get("status", "downloading"),
started_at=started_at,
error=progress.get("error"),
))

# Get active generations
Expand Down
1 change: 1 addition & 0 deletions backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class ActiveDownloadTask(BaseModel):
model_name: str
status: str
started_at: datetime
error: Optional[str] = None


class ActiveGenerationTask(BaseModel):
Expand Down
9 changes: 9 additions & 0 deletions backend/utils/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ def get_active_generations(self) -> List[GenerationTask]:
"""Get all active generations."""
return list(self._active_generations.values())

def cancel_download(self, model_name: str) -> bool:
"""Cancel/dismiss a download task (removes it from active list)."""
return self._active_downloads.pop(model_name, None) is not None

def clear_all(self) -> None:
"""Clear all download and generation tasks."""
self._active_downloads.clear()
self._active_generations.clear()

def is_download_active(self, model_name: str) -> bool:
"""Check if a download is active."""
return model_name in self._active_downloads
Expand Down