From 9c97815e1f057673290aba372e2eef34f278c9b3 Mon Sep 17 00:00:00 2001 From: AJ Date: Mon, 2 Mar 2026 03:26:42 +0530 Subject: [PATCH] feat: add custom HuggingFace voice model support - New custom_models.py module for CRUD management of user-defined HF TTS models - New /custom-models API endpoints (list, add, get, delete) - Updated MLX and PyTorch backends to resolve custom model paths (custom:slug format) - Added Custom Models section to ModelManagement UI with add/remove dialogs - Updated GenerationForm and FloatingGenerateBox with grouped model selectors - Added CustomModelCreate/Response types and API client methods - Added instruct field to GenerationRequest type - Graceful actool fallback in build.rs for non-Xcode environments - Added custom_models hidden import for PyInstaller bundling Author: AJ - Kamyab (Ankit Jain) --- .../Generation/FloatingGenerateBox.tsx | 58 +++- .../components/Generation/GenerationForm.tsx | 53 ++- .../ServerSettings/ModelManagement.tsx | 302 +++++++++++++++++- app/src/lib/api/client.ts | 39 ++- app/src/lib/api/types.ts | 33 +- app/src/lib/hooks/useGenerationForm.ts | 46 ++- backend/backends/mlx_backend.py | 14 +- backend/backends/pytorch_backend.py | 13 +- backend/build_binary.py | 1 + backend/custom_models.py | 165 ++++++++++ backend/main.py | 235 +++++++++++++- backend/models.py | 22 +- backend/voicebox-server.spec | 17 +- bun.lock | 9 +- data/custom_models.json | 10 + tauri/src-tauri/Cargo.lock | 2 +- tauri/src-tauri/build.rs | 9 +- 17 files changed, 971 insertions(+), 57 deletions(-) create mode 100644 backend/custom_models.py create mode 100644 data/custom_models.json diff --git a/app/src/components/Generation/FloatingGenerateBox.tsx b/app/src/components/Generation/FloatingGenerateBox.tsx index a8d556a6..29e418f3 100644 --- a/app/src/components/Generation/FloatingGenerateBox.tsx +++ b/app/src/components/Generation/FloatingGenerateBox.tsx @@ -2,18 +2,22 @@ import { useMatchRoute } from '@tanstack/react-router'; import { 
AnimatePresence, motion } from 'framer-motion'; import { Loader2, SlidersHorizontal, Sparkles } from 'lucide-react'; import { useEffect, useRef, useState } from 'react'; +import { useQuery } from '@tanstack/react-query'; import { Button } from '@/components/ui/button'; import { Form, FormControl, FormField, FormItem, FormMessage } from '@/components/ui/form'; import { Select, SelectContent, + SelectGroup, SelectItem, + SelectLabel, SelectTrigger, SelectValue, } from '@/components/ui/select'; import { Textarea } from '@/components/ui/textarea'; import { useToast } from '@/components/ui/use-toast'; import { LANGUAGE_OPTIONS } from '@/lib/constants/languages'; +import { apiClient } from '@/lib/api/client'; import { useGenerationForm } from '@/lib/hooks/useGenerationForm'; import { useProfile, useProfiles } from '@/lib/hooks/useProfiles'; import { useAddStoryItem, useStory } from '@/lib/hooks/useStories'; @@ -46,6 +50,20 @@ export function FloatingGenerateBox({ const addStoryItem = useAddStoryItem(); const { toast } = useToast(); + // Fetch model status to dynamically populate the model selector dropdown. + // Models are split into "Built-in" (qwen-tts-*) and "Custom" (is_custom flag) + // groups, keeping the same structure as GenerationForm.tsx. + // @modified AJ - Kamyab (Ankit Jain) — Added custom model grouping in selector + const { data: modelStatus } = useQuery({ + queryKey: ['modelStatus'], + queryFn: () => apiClient.getModelStatus(), + refetchInterval: 10000, + }); + + // Separate built-in TTS models from user-added custom models + const builtInModels = modelStatus?.models.filter((m) => m.model_name.startsWith('qwen-tts')) || []; + const customModels = modelStatus?.models.filter((m) => m.is_custom) || []; + // Calculate if track editor is visible (on stories route with items) const hasTrackEditor = isStoriesRoute && currentStory && currentStory.items.length > 0; @@ -173,7 +191,7 @@ export function FloatingGenerateBox({ 'fixed right-auto', isStoriesRoute ? 
// Position aligned with story list: after sidebar + padding, width 360px - 'left-[calc(5rem+2rem)] w-[360px]' + 'left-[calc(5rem+2rem)] w-[360px]' : 'left-[calc(5rem+2rem)] w-[calc((100%-5rem-4rem)/2-1rem)]', )} style={{ @@ -414,12 +432,38 @@ export function FloatingGenerateBox({ - - Qwen3-TTS 1.7B - - - Qwen3-TTS 0.6B - + + Built-in + {builtInModels.length > 0 ? ( + builtInModels.map((model) => { + const sizeValue = model.model_name.replace('qwen-tts-', ''); + return ( + + {model.display_name} + + ); + }) + ) : ( + <> + + Qwen3-TTS 1.7B + + + Qwen3-TTS 0.6B + + + )} + + {customModels.length > 0 && ( + + Custom + {customModels.map((model) => ( + + {model.display_name} + + ))} + + )} diff --git a/app/src/components/Generation/GenerationForm.tsx b/app/src/components/Generation/GenerationForm.tsx index 31b100f8..4f88f54a 100644 --- a/app/src/components/Generation/GenerationForm.tsx +++ b/app/src/components/Generation/GenerationForm.tsx @@ -1,4 +1,5 @@ import { Loader2, Mic } from 'lucide-react'; +import { useQuery } from '@tanstack/react-query'; import { Button } from '@/components/ui/button'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { @@ -14,12 +15,15 @@ import { Input } from '@/components/ui/input'; import { Select, SelectContent, + SelectGroup, SelectItem, + SelectLabel, SelectTrigger, SelectValue, } from '@/components/ui/select'; import { Textarea } from '@/components/ui/textarea'; import { LANGUAGE_OPTIONS } from '@/lib/constants/languages'; +import { apiClient } from '@/lib/api/client'; import { useGenerationForm } from '@/lib/hooks/useGenerationForm'; import { useProfile } from '@/lib/hooks/useProfiles'; import { useUIStore } from '@/stores/uiStore'; @@ -30,6 +34,20 @@ export function GenerationForm() { const { form, handleSubmit, isPending } = useGenerationForm(); + // Fetch model status to dynamically populate the model selector dropdown. 
+ // Models are split into "Built-in" (qwen-tts-*) and "Custom" (is_custom flag) + // groups so users can easily distinguish between them. + // @modified AJ - Kamyab (Ankit Jain) — Added custom model grouping in selector + const { data: modelStatus } = useQuery({ + queryKey: ['modelStatus'], + queryFn: () => apiClient.getModelStatus(), + refetchInterval: 10000, + }); + + // Separate built-in TTS models from user-added custom models + const builtInModels = modelStatus?.models.filter((m) => m.model_name.startsWith('qwen-tts')) || []; + const customModels = modelStatus?.models.filter((m) => m.is_custom) || []; + async function onSubmit(data: Parameters[0]) { await handleSubmit(data, selectedProfileId); } @@ -129,7 +147,7 @@ export function GenerationForm() { name="modelSize" render={({ field }) => ( - Model Size + Model - Larger models produce better quality + Select voice generation model )} diff --git a/app/src/components/ServerSettings/ModelManagement.tsx b/app/src/components/ServerSettings/ModelManagement.tsx index 4a5fd439..3f1f44f3 100644 --- a/app/src/components/ServerSettings/ModelManagement.tsx +++ b/app/src/components/ServerSettings/ModelManagement.tsx @@ -1,5 +1,5 @@ import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; -import { Download, Loader2, Trash2 } from 'lucide-react'; +import { Download, Loader2, Plus, Trash2, X } from 'lucide-react'; import { useCallback, useState } from 'react'; import { AlertDialog, @@ -14,10 +14,33 @@ import { import { Badge } from '@/components/ui/badge'; import { Button } from '@/components/ui/button'; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from '@/components/ui/dialog'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; import { useToast } from '@/components/ui/use-toast'; import { apiClient 
} from '@/lib/api/client'; import { useModelDownloadToast } from '@/lib/hooks/useModelDownloadToast'; +/** + * Model Management panel — displayed in the Settings page. + * + * Renders three sections: + * 1. Built-in Voice Generation models (Qwen TTS 1.7B / 0.6B) + * 2. Transcription models (Whisper variants) + * 3. Custom Models — user-added HuggingFace TTS models + * + * Custom models use a "custom:<slug>" naming convention throughout the + * frontend and backend so they can be distinguished from built-in models. + * + * @modified AJ - Kamyab (Ankit Jain) — Added Custom Models section, add/remove mutations, and CustomModelItem component + */ export function ModelManagement() { const { toast } = useToast(); const queryClient = useQueryClient(); @@ -65,13 +88,18 @@ export function ModelManagement() { sizeMb?: number; } | null>(null); + // Add Custom Model dialog state + const [addDialogOpen, setAddDialogOpen] = useState(false); + const [newModelRepoId, setNewModelRepoId] = useState(''); + const [newModelDisplayName, setNewModelDisplayName] = useState(''); + const handleDownload = async (modelName: string) => { console.log('[Download] Button clicked for:', modelName, 'at', new Date().toISOString()); - + // Find display name const model = modelStatus?.models.find((m) => m.model_name === modelName); const displayName = model?.display_name || modelName; - + try { // IMPORTANT: Call the API FIRST before setting state // Setting state enables the SSE EventSource in useModelDownloadToast, @@ -79,11 +107,11 @@ export function ModelManagement() { console.log('[Download] Calling download API for:', modelName); const result = await apiClient.triggerModelDownload(modelName); console.log('[Download] Download API responded:', result); - + // NOW set state to enable SSE tracking (after download has started on backend) setDownloadingModel(modelName); setDownloadingDisplayName(displayName); - + // Download initiated successfully - state will be cleared when SSE reports completion // or by 
the polling interval detecting the model is downloaded queryClient.invalidateQueries({ queryKey: ['modelStatus'] }); @@ -114,14 +142,11 @@ export function ModelManagement() { }); setDeleteDialogOpen(false); setModelToDelete(null); - // Invalidate AND explicitly refetch to ensure UI updates - // Using refetchType: 'all' ensures we refetch even if the query is stale console.log('[Delete] Invalidating modelStatus query'); - await queryClient.invalidateQueries({ + await queryClient.invalidateQueries({ queryKey: ['modelStatus'], refetchType: 'all', }); - // Also explicitly refetch to guarantee fresh data console.log('[Delete] Explicitly refetching modelStatus query'); await queryClient.refetchQueries({ queryKey: ['modelStatus'] }); console.log('[Delete] Query refetched'); @@ -136,12 +161,70 @@ export function ModelManagement() { }, }); + // ── Add custom model mutation ─────────────────────────────────────── + // Registers a new HuggingFace model in data/custom_models.json. + // This does NOT trigger a download — the user must click "Download". + const addCustomModelMutation = useMutation({ + mutationFn: async (data: { hf_repo_id: string; display_name: string }) => { + return apiClient.addCustomModel(data); + }, + onSuccess: async () => { + toast({ + title: 'Custom model added', + description: `${newModelDisplayName} has been added successfully.`, + }); + setAddDialogOpen(false); + setNewModelRepoId(''); + setNewModelDisplayName(''); + await queryClient.invalidateQueries({ queryKey: ['modelStatus'] }); + }, + onError: (error: Error) => { + toast({ + title: 'Failed to add model', + description: error.message, + variant: 'destructive', + }); + }, + }); + + // ── Remove custom model mutation ──────────────────────────────────── + // Removes the model entry from data/custom_models.json. + // Does NOT delete cached model files from the HuggingFace cache. + // To delete cache, the user should click the trash icon (onDeleteCache). 
+ const removeCustomModelMutation = useMutation({ + mutationFn: async (modelId: string) => { + return apiClient.removeCustomModel(modelId); + }, + onSuccess: async () => { + toast({ + title: 'Custom model removed', + description: 'The custom model has been removed from your list.', + }); + await queryClient.invalidateQueries({ queryKey: ['modelStatus'] }); + }, + onError: (error: Error) => { + toast({ + title: 'Failed to remove model', + description: error.message, + variant: 'destructive', + }); + }, + }); + const formatSize = (sizeMb?: number): string => { if (!sizeMb) return 'Unknown'; if (sizeMb < 1024) return `${sizeMb.toFixed(1)} MB`; return `${(sizeMb / 1024).toFixed(2)} GB`; }; + const handleAddCustomModel = () => { + if (!newModelRepoId.trim() || !newModelDisplayName.trim()) return; + addCustomModelMutation.mutate({ + hf_repo_id: newModelRepoId.trim(), + display_name: newModelDisplayName.trim(), + }); + }; + return ( @@ -213,6 +296,54 @@ export function ModelManagement() { + {/* Custom Models */} +
+
+

+ Custom Models +

+ +
+
+ {modelStatus.models + .filter((m) => m.is_custom) + .map((model) => ( + handleDownload(model.model_name)} + onDeleteCache={() => { + setModelToDelete({ + name: model.model_name, + displayName: model.display_name, + sizeMb: model.size_mb, + }); + setDeleteDialogOpen(true); + }} + onRemove={() => { + // Extract custom ID from "custom:slug" format + const customId = model.model_name.replace('custom:', ''); + removeCustomModelMutation.mutate(customId); + }} + isDownloading={downloadingModel === model.model_name} + formatSize={formatSize} + /> + ))} + {modelStatus.models.filter((m) => m.is_custom).length === 0 && ( +
+ No custom models added yet. Click "Add Model" to add a HuggingFace model. +
+ )} +
+
+ ) : null} @@ -256,6 +387,64 @@ export function ModelManagement() { + + {/* Add Custom Model Dialog */} + + + + Add Custom Model + + Add a HuggingFace model to use for voice generation. The model must be compatible + with the TTS backend. + + +
+
+ + setNewModelRepoId(e.target.value)} + /> +

+ The full repo ID from HuggingFace (owner/model-name) +

+
+
+ + setNewModelDisplayName(e.target.value)} + /> +
+
+ + + + +
+
); } @@ -268,6 +457,7 @@ interface ModelItemProps { downloading?: boolean; // From server - true if download in progress size_mb?: number; loaded: boolean; + is_custom?: boolean; }; onDownload: () => void; onDelete: () => void; @@ -275,10 +465,14 @@ interface ModelItemProps { formatSize: (sizeMb?: number) => string; } +/** + * A single row in the built-in model list (Qwen TTS / Whisper). + * Shows download status, size, and delete/download actions. + */ function ModelItem({ model, onDownload, onDelete, isDownloading, formatSize }: ModelItemProps) { // Use server's downloading state OR local state (for immediate feedback before server updates) const showDownloading = model.downloading || isDownloading; - + return (
@@ -333,3 +527,91 @@ function ModelItem({ model, onDownload, onDelete, isDownloading, formatSize }: M
); } + +interface CustomModelItemProps { + model: { + model_name: string; + display_name: string; + downloaded: boolean; + downloading?: boolean; + size_mb?: number; + loaded: boolean; + }; + onDownload: () => void; + onDeleteCache: () => void; + onRemove: () => void; + isDownloading: boolean; + formatSize: (sizeMb?: number) => string; +} + +/** + * A single row in the custom model list. + * In addition to download/delete-cache, custom models have a "remove" button + * (X icon) that un-registers the model from the config without deleting cached files. + */ +function CustomModelItem({ model, onDownload, onDeleteCache, onRemove, isDownloading, formatSize }: CustomModelItemProps) { + const showDownloading = model.downloading || isDownloading; + + return ( +
+
+
+ {model.display_name} + Custom + {model.loaded && ( + + Loaded + + )} + {model.downloaded && !model.loaded && !showDownloading && ( + + Downloaded + + )} +
+ {model.downloaded && model.size_mb && !showDownloading && ( +
+ Size: {formatSize(model.size_mb)} +
+ )} +
+
+ {model.downloaded && !showDownloading ? ( +
+
+ Ready +
+ +
+ ) : showDownloading ? ( + + ) : ( + + )} + +
+
+ ); +} diff --git a/app/src/lib/api/client.ts b/app/src/lib/api/client.ts index c5b079b2..6edb8fbc 100644 --- a/app/src/lib/api/client.ts +++ b/app/src/lib/api/client.ts @@ -14,6 +14,9 @@ import type { ModelStatusListResponse, ModelDownloadRequest, ActiveTasksResponse, + CustomModelCreate, + CustomModelResponse, + CustomModelListResponse, StoryCreate, StoryResponse, StoryDetailResponse, @@ -319,8 +322,42 @@ class ApiClient { return result; } + /** + * Delete a model's cached files from disk. + * Uses encodeURIComponent because custom model names contain colons ("custom:slug"). + */ async deleteModel(modelName: string): Promise<{ message: string }> { - return this.request<{ message: string }>(`/models/${modelName}`, { + return this.request<{ message: string }>(`/models/${encodeURIComponent(modelName)}`, { + method: 'DELETE', + }); + } + + // ── Custom Models ───────────────────────────────────────────────────── + // CRUD operations for user-defined HuggingFace TTS models. + // Custom models are persisted in data/custom_models.json on the backend. + // + // @author AJ - Kamyab (Ankit Jain) + + /** List all registered custom models. */ + async listCustomModels(): Promise { + return this.request('/custom-models'); + } + + /** Register a new custom HuggingFace model (does NOT trigger download). */ + async addCustomModel(data: CustomModelCreate): Promise { + return this.request('/custom-models', { + method: 'POST', + body: JSON.stringify(data), + }); + } + + /** + * Remove a custom model from the config. + * This only removes the registration — cached HuggingFace files are NOT deleted. + * Use deleteModel("custom:slug") to also clear the HF cache. 
+ */ + async removeCustomModel(modelId: string): Promise<{ message: string }> { + return this.request<{ message: string }>(`/custom-models/${modelId}`, { method: 'DELETE', }); } diff --git a/app/src/lib/api/types.ts b/app/src/lib/api/types.ts index 131c1be5..38fbf14e 100644 --- a/app/src/lib/api/types.ts +++ b/app/src/lib/api/types.ts @@ -33,7 +33,10 @@ export interface GenerationRequest { text: string; language: LanguageCode; seed?: number; - model_size?: '1.7B' | '0.6B'; + /** Model identifier — built-in size ("1.7B", "0.6B") or custom model ID ("custom:slug") */ + model_size?: string; + /** Natural language instruction for speech delivery control (e.g. "speak slowly") */ + instruct?: string; } export interface GenerationResponse { @@ -99,6 +102,8 @@ export interface ModelStatus { downloading: boolean; // True if download is in progress size_mb?: number; loaded: boolean; + /** True for user-added custom HuggingFace models (model_name uses "custom:slug" format) */ + is_custom?: boolean; } export interface ModelStatusListResponse { @@ -109,6 +114,32 @@ export interface ModelDownloadRequest { model_name: string; } +/** + * Request payload for registering a custom HuggingFace TTS model. + * After adding, the model appears in model management and generation dropdowns. + * + * @author AJ - Kamyab (Ankit Jain) + */ +export interface CustomModelCreate { + /** Full HuggingFace repository ID, e.g. "AryanNsc/IND-QWENTTS-V1" */ + hf_repo_id: string; + /** User-friendly name shown in the UI */ + display_name: string; +} + +/** Custom model as returned by the backend after creation or listing. */ +export interface CustomModelResponse { + /** Auto-generated slug ID derived from the repo path (e.g. 
"aryannsc-ind-qwentts-v1") */ + id: string; + hf_repo_id: string; + display_name: string; + added_at: string; +} + +export interface CustomModelListResponse { + models: CustomModelResponse[]; +} + export interface ActiveDownloadTask { model_name: string; status: string; diff --git a/app/src/lib/hooks/useGenerationForm.ts b/app/src/lib/hooks/useGenerationForm.ts index c6fdba50..92e56aab 100644 --- a/app/src/lib/hooks/useGenerationForm.ts +++ b/app/src/lib/hooks/useGenerationForm.ts @@ -10,11 +10,20 @@ import { useModelDownloadToast } from '@/lib/hooks/useModelDownloadToast'; import { useGenerationStore } from '@/stores/generationStore'; import { usePlayerStore } from '@/stores/playerStore'; +/** + * Zod schema for the generation form. + * + * `modelSize` is a free-form string rather than a strict enum + * because it can be either a built-in size ("1.7B", "0.6B") or + * a custom model identifier ("custom:<slug>"). + * + * @modified AJ - Kamyab (Ankit Jain) — Changed modelSize from enum to string for custom model support + */ const generationSchema = z.object({ text: z.string().min(1, 'Text is required').max(5000), language: z.enum(LANGUAGE_CODES as [LanguageCode, ...LanguageCode[]]), seed: z.number().int().optional(), - modelSize: z.enum(['1.7B', '0.6B']).optional(), + modelSize: z.string().optional(), instruct: z.string().max(500).optional(), }); @@ -67,18 +76,41 @@ export function useGenerationForm(options: UseGenerationFormOptions = {}) { try { setIsGenerating(true); - const modelName = `qwen-tts-${data.modelSize}`; - const displayName = data.modelSize === '1.7B' ? 'Qwen TTS 1.7B' : 'Qwen TTS 0.6B'; + const modelSize = data.modelSize || '1.7B'; + + // Derive model tracking name and display name. + // Built-in models use "qwen-tts-<size>" format for tracking. + // Custom models use the full "custom:<slug>" identifier as-is. 
+ let modelName: string; + let displayName: string; + + if (modelSize.startsWith('custom:')) { + // Custom model: use the full "custom:slug" as the tracking key + modelName = modelSize; + displayName = modelSize.replace('custom:', ''); + } else { + // Built-in model: construct the standard tracking name + modelName = `qwen-tts-${modelSize}`; + displayName = modelSize === '1.7B' ? 'Qwen TTS 1.7B' : 'Qwen TTS 0.6B'; + } + // Pre-flight check: query model status to get the accurate display name + // and to detect if the model needs downloading first. + // If the model isn't downloaded yet, enable the SSE download progress toast. try { const modelStatus = await apiClient.getModelStatus(); const model = modelStatus.models.find((m) => m.model_name === modelName); - if (model && !model.downloaded) { - setDownloadingModelName(modelName); - setDownloadingDisplayName(displayName); + if (model) { + displayName = model.display_name; + if (!model.downloaded) { + // Not yet downloaded — enable progress tracking UI + setDownloadingModelName(modelName); + setDownloadingDisplayName(displayName); + } } } catch (error) { + // Non-fatal: generation will still attempt and may trigger download on the backend console.error('Failed to check model status:', error); } @@ -87,7 +119,7 @@ export function useGenerationForm(options: UseGenerationFormOptions = {}) { text: data.text, language: data.language, seed: data.seed, - model_size: data.modelSize, + model_size: modelSize, instruct: data.instruct || undefined, }); diff --git a/backend/backends/mlx_backend.py b/backend/backends/mlx_backend.py index c4ecc090..d222090c 100644 --- a/backend/backends/mlx_backend.py +++ b/backend/backends/mlx_backend.py @@ -32,11 +32,22 @@ def _get_model_path(self, model_size: str) -> str: Get the MLX model path. 
Args: - model_size: Model size (1.7B or 0.6B) + model_size: Model size (1.7B or 0.6B) or custom model ID (custom:slug) Returns: HuggingFace Hub model ID for MLX """ + # Handle custom model IDs + # @modified AJ - Kamyab (Ankit Jain) — Added custom model path resolution + if model_size.startswith("custom:"): + custom_id = model_size[len("custom:"):] + from ..custom_models import get_hf_repo_id_for_custom_model + hf_repo_id = get_hf_repo_id_for_custom_model(custom_id) + if not hf_repo_id: + raise ValueError(f"Custom model '{custom_id}' not found") + print(f"Will download custom model from HuggingFace Hub: {hf_repo_id}") + return hf_repo_id + # MLX model mapping mlx_model_map = { "1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16", @@ -51,6 +62,7 @@ def _get_model_path(self, model_size: str) -> str: print(f"Will download MLX model from HuggingFace Hub: {hf_model_id}") return hf_model_id + def _is_model_cached(self, model_size: str) -> bool: """ diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index d0cba11a..0585e6c9 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -57,11 +57,21 @@ def _get_model_path(self, model_size: str) -> str: Get the HuggingFace Hub model ID. 
Args: - model_size: Model size (1.7B or 0.6B) + model_size: Model size (1.7B or 0.6B) or custom model ID (custom:slug) Returns: HuggingFace Hub model ID """ + # Handle custom model IDs + # @modified AJ - Kamyab (Ankit Jain) — Added custom model path resolution + if model_size.startswith("custom:"): + custom_id = model_size[len("custom:"):] + from ..custom_models import get_hf_repo_id_for_custom_model + hf_repo_id = get_hf_repo_id_for_custom_model(custom_id) + if not hf_repo_id: + raise ValueError(f"Custom model '{custom_id}' not found") + return hf_repo_id + hf_model_map = { "1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-Base", "0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-Base", @@ -71,6 +81,7 @@ def _get_model_path(self, model_size: str) -> str: raise ValueError(f"Unknown model size: {model_size}") return hf_model_map[model_size] + def _is_model_cached(self, model_size: str) -> bool: """ diff --git a/backend/build_binary.py b/backend/build_binary.py index 73f21d23..3c1d4082 100644 --- a/backend/build_binary.py +++ b/backend/build_binary.py @@ -48,6 +48,7 @@ def build_server(): '--hidden-import', 'backend.utils.cache', '--hidden-import', 'backend.utils.progress', '--hidden-import', 'backend.utils.hf_progress', + '--hidden-import', 'backend.custom_models', # @modified AJ - Kamyab (Ankit Jain) '--hidden-import', 'backend.utils.validation', '--hidden-import', 'torch', '--hidden-import', 'transformers', diff --git a/backend/custom_models.py b/backend/custom_models.py new file mode 100644 index 00000000..4e2ea243 --- /dev/null +++ b/backend/custom_models.py @@ -0,0 +1,165 @@ +""" +Custom voice model management module. + +Handles adding, removing, and listing user-defined HuggingFace TTS models. +Models are persisted in a JSON config file in the data directory. + +@author AJ - Kamyab (Ankit Jain) +""" + +import json +import re +from datetime import datetime +from pathlib import Path +from typing import List, Optional + +from . 
import config + + +def _get_config_path() -> Path: + """Get path to the custom models JSON config file.""" + return config.get_data_dir() / "custom_models.json" + + +def _load_config() -> dict: + """Load custom models config from disk.""" + path = _get_config_path() + if not path.exists(): + return {"models": []} + try: + with open(path, "r") as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + return {"models": []} + + +def _save_config(data: dict) -> None: + """Save custom models config to disk.""" + path = _get_config_path() + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(data, f, indent=2, default=str) + + +def _generate_id(hf_repo_id: str) -> str: + """Generate a slug ID from a HuggingFace repo ID. + + Example: 'AryanNsc/IND-QWENTTS-V1' -> 'aryannsc-ind-qwentts-v1' + """ + slug = hf_repo_id.lower().replace("/", "-") + slug = re.sub(r"[^a-z0-9-]", "-", slug) + slug = re.sub(r"-+", "-", slug).strip("-") + return slug + + +def list_custom_models() -> List[dict]: + """List all custom models. + + Returns: + List of custom model dicts + """ + data = _load_config() + return data.get("models", []) + + +def get_custom_model(model_id: str) -> Optional[dict]: + """Get a single custom model by ID. + + Args: + model_id: Custom model ID (slug) + + Returns: + Model dict or None if not found + """ + models = list_custom_models() + for model in models: + if model["id"] == model_id: + return model + return None + + +def add_custom_model(hf_repo_id: str, display_name: str) -> dict: + """Add a new custom model. + + Args: + hf_repo_id: HuggingFace repo ID (e.g. 
'AryanNsc/IND-QWENTTS-V1') + display_name: User-friendly display name + + Returns: + Created model dict + + Raises: + ValueError: If model already exists or inputs are invalid + """ + hf_repo_id = hf_repo_id.strip() + display_name = display_name.strip() + + if not hf_repo_id: + raise ValueError("HuggingFace repo ID is required") + if not display_name: + raise ValueError("Display name is required") + if "/" not in hf_repo_id: + raise ValueError("HuggingFace repo ID must be in format 'owner/model-name'") + + model_id = _generate_id(hf_repo_id) + + data = _load_config() + models = data.get("models", []) + + # Check for duplicates + for existing in models: + if existing["id"] == model_id: + raise ValueError(f"Model '{hf_repo_id}' already exists") + if existing["hf_repo_id"] == hf_repo_id: + raise ValueError(f"Model with repo ID '{hf_repo_id}' already exists") + + model = { + "id": model_id, + "display_name": display_name, + "hf_repo_id": hf_repo_id, + "added_at": datetime.utcnow().isoformat() + "Z", + } + + models.append(model) + data["models"] = models + _save_config(data) + + return model + + +def remove_custom_model(model_id: str) -> bool: + """Remove a custom model by ID. + + Args: + model_id: Custom model ID (slug) + + Returns: + True if removed, False if not found + """ + data = _load_config() + models = data.get("models", []) + + original_count = len(models) + models = [m for m in models if m["id"] != model_id] + + if len(models) == original_count: + return False + + data["models"] = models + _save_config(data) + return True + + +def get_hf_repo_id_for_custom_model(model_id: str) -> Optional[str]: + """Get the HuggingFace repo ID for a custom model. 
+ + Args: + model_id: Custom model ID (slug, without 'custom:' prefix) + + Returns: + HuggingFace repo ID or None if not found + """ + model = get_custom_model(model_id) + if model: + return model["hf_repo_id"] + return None diff --git a/backend/main.py b/backend/main.py index e218d237..aac5e8b6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -41,7 +41,7 @@ def _safe_content_disposition(disposition_type: str, filename: str) -> str: ) -from . import database, models, profiles, history, tts, transcribe, config, export_import, channels, stories, __version__ +from . import database, models, profiles, history, tts, transcribe, config, export_import, channels, stories, custom_models, __version__ from .database import get_db, Generation as DBGeneration, VoiceProfile as DBVoiceProfile from .utils.progress import get_progress_manager from .utils.tasks import get_task_manager @@ -615,7 +615,10 @@ async def generate_speech( if not tts_model._is_model_cached(model_size): # Model is not fully cached — kick off a background download and tell # the client to retry once it's ready. 
- model_name = f"qwen-tts-{model_size}" + if model_size.startswith("custom:"): + model_name = model_size # Use the full custom:slug as the tracking name + else: + model_name = f"qwen-tts-{model_size}" async def download_model_background(): try: @@ -1509,6 +1512,116 @@ def check_whisper_loaded(model_size: str): loaded=loaded, )) + + # ==== Add custom models to the status list ==== + custom_model_list = custom_models.list_custom_models() + for cm in custom_model_list: + model_name = f"custom:{cm['id']}" + hf_repo_id = cm["hf_repo_id"] + + try: + downloaded = False + size_mb = None + + # Check if custom model is cached (same logic as built-in models) + if cache_info: + for repo in cache_info.repos: + if repo.repo_id == hf_repo_id: + has_model_weights = False + for rev in repo.revisions: + for f in rev.files: + fname = f.file_name.lower() + if fname.endswith(('.safetensors', '.bin', '.pt', '.pth', '.npz')): + has_model_weights = True + break + if has_model_weights: + break + + has_incomplete = False + try: + cache_dir_path = hf_constants.HF_HUB_CACHE + blobs_dir = Path(cache_dir_path) / ("models--" + hf_repo_id.replace("/", "--")) / "blobs" + if blobs_dir.exists(): + has_incomplete = any(blobs_dir.glob("*.incomplete")) + except Exception: + pass + + if has_model_weights and not has_incomplete: + downloaded = True + try: + total_size = sum(revision.size_on_disk for revision in repo.revisions) + size_mb = total_size / (1024 * 1024) + except Exception: + pass + break + + # Fallback cache check + if not downloaded: + try: + cache_dir_path = hf_constants.HF_HUB_CACHE + repo_cache = Path(cache_dir_path) / ("models--" + hf_repo_id.replace("/", "--")) + if repo_cache.exists(): + blobs_dir = repo_cache / "blobs" + has_incomplete = blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")) + if not has_incomplete: + snapshots_dir = repo_cache / "snapshots" + has_model_files = False + if snapshots_dir.exists(): + has_model_files = ( + any(snapshots_dir.rglob("*.bin")) or + 
any(snapshots_dir.rglob("*.safetensors")) or + any(snapshots_dir.rglob("*.pt")) or + any(snapshots_dir.rglob("*.pth")) or + any(snapshots_dir.rglob("*.npz")) + ) + if has_model_files: + downloaded = True + try: + total_size = sum( + f.stat().st_size for f in repo_cache.rglob("*") + if f.is_file() and not f.name.endswith('.incomplete') + ) + size_mb = total_size / (1024 * 1024) + except Exception: + pass + except Exception: + pass + + # Check if loaded + loaded = False + try: + tts_model = tts.get_tts_model() + loaded = tts_model.is_loaded() and getattr(tts_model, '_current_model_size', None) == model_name + except Exception: + pass + + # Check if downloading + is_downloading = model_name in active_download_names or hf_repo_id in active_download_repos + if is_downloading: + downloaded = False + size_mb = None + + statuses.append(models.ModelStatus( + model_name=model_name, + display_name=cm["display_name"], + downloaded=downloaded, + downloading=is_downloading, + size_mb=size_mb, + loaded=loaded, + is_custom=True, + )) + except Exception: + is_downloading = model_name in active_download_names + statuses.append(models.ModelStatus( + model_name=model_name, + display_name=cm["display_name"], + downloaded=False, + downloading=is_downloading, + size_mb=None, + loaded=False, + is_custom=True, + )) + return models.ModelStatusListResponse(models=statuses) @@ -1520,6 +1633,7 @@ async def trigger_model_download(request: models.ModelDownloadRequest): task_manager = get_task_manager() progress_manager = get_progress_manager() + # Built-in model configs model_configs = { "qwen-tts-1.7B": { "model_size": "1.7B", @@ -1547,10 +1661,22 @@ async def trigger_model_download(request: models.ModelDownloadRequest): }, } - if request.model_name not in model_configs: + # Handle custom models (custom:slug format) + if request.model_name.startswith("custom:"): + custom_id = request.model_name[len("custom:"):] + cm = custom_models.get_custom_model(custom_id) + if not cm: + raise 
HTTPException(status_code=400, detail=f"Custom model '{custom_id}' not found") + + model_size = request.model_name # Pass full "custom:slug" to load_model + config = { + "model_size": model_size, + "load_func": lambda: tts.get_tts_model().load_model(model_size), + } + elif request.model_name not in model_configs: raise HTTPException(status_code=400, detail=f"Unknown model: {request.model_name}") - - config = model_configs[request.model_name] + else: + config = model_configs[request.model_name] async def download_in_background(): """Download model in background without blocking the HTTP request.""" @@ -1593,7 +1719,36 @@ async def delete_model(model_name: str): import os from huggingface_hub import constants as hf_constants - # Map model names to HuggingFace repo IDs + # Handle custom models (custom:slug format) + if model_name.startswith("custom:"): + custom_id = model_name[len("custom:"):] + cm = custom_models.get_custom_model(custom_id) + if not cm: + raise HTTPException(status_code=400, detail=f"Custom model '{custom_id}' not found") + hf_repo_id = cm["hf_repo_id"] + + try: + # Unload if this custom model is loaded + tts_model = tts.get_tts_model() + if tts_model.is_loaded() and getattr(tts_model, '_current_model_size', None) == model_name: + tts.unload_tts_model() + + # Delete from HF cache + cache_dir = hf_constants.HF_HUB_CACHE + repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--")) + + if not repo_cache_dir.exists(): + raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache") + + shutil.rmtree(repo_cache_dir) + return {"message": f"Model {model_name} deleted successfully"} + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}") + + # Map built-in model names to HuggingFace repo IDs model_configs = { "qwen-tts-1.7B": { "hf_repo_id": "Qwen/Qwen3-TTS-12Hz-1.7B-Base", @@ -1669,6 +1824,74 @@ async def delete_model(model_name: 
str): raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}") +# ============================================ +# ============================================ +# CUSTOM MODEL MANAGEMENT +# ============================================ +# These endpoints manage user-defined HuggingFace TTS models. +# Models are stored in data/custom_models.json and identified +# by a slug ID (e.g. "aryannsc-ind-qwentts-v1") derived from +# the HuggingFace repo path. +# +# Adding a custom model only registers it in the config. +# It must be separately downloaded via /models/download +# with model_name="custom:<slug>" before it can be used. +# +# @author AJ - Kamyab (Ankit Jain) +# ============================================ + +@app.get("/custom-models", response_model=models.CustomModelListResponse) +async def list_custom_models_endpoint(): + """List all registered custom models and their metadata.""" + items = custom_models.list_custom_models() + return models.CustomModelListResponse( + models=[models.CustomModelResponse(**m) for m in items] + ) + + +@app.post("/custom-models", response_model=models.CustomModelResponse) +async def add_custom_model_endpoint(data: models.CustomModelCreate): + """ + Register a new custom HuggingFace TTS model. + + Validates the repo ID format (must contain '/') and checks for duplicates. + The model is NOT downloaded — use POST /models/download with + model_name="custom:<slug>" to fetch model weights from HuggingFace. 
+ """ + try: + model = custom_models.add_custom_model( + hf_repo_id=data.hf_repo_id, + display_name=data.display_name, + ) + return models.CustomModelResponse(**model) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@app.get("/custom-models/{model_id}", response_model=models.CustomModelResponse) +async def get_custom_model_endpoint(model_id: str): + """Get a single custom model's metadata by its slug ID.""" + model = custom_models.get_custom_model(model_id) + if not model: + raise HTTPException(status_code=404, detail=f"Custom model '{model_id}' not found") + return models.CustomModelResponse(**model) + + +@app.delete("/custom-models/{model_id}") +async def delete_custom_model_endpoint(model_id: str): + """ + Remove a custom model from the config. + + This only removes the registration — cached HuggingFace model files + are NOT deleted. Use DELETE /models/custom:<slug> to also clear the + HF cache. + """ + success = custom_models.remove_custom_model(model_id) + if not success: + raise HTTPException(status_code=404, detail=f"Custom model '{model_id}' not found") + return {"message": f"Custom model '{model_id}' removed successfully"} + + +@app.post("/cache/clear") async def clear_cache(): """Clear all voice prompt caches (memory and disk).""" diff --git a/backend/models.py b/backend/models.py index 59e45405..6560039c 100644 --- a/backend/models.py +++ b/backend/models.py @@ -55,7 +55,7 @@ class GenerationRequest(BaseModel): text: str = Field(..., min_length=1, max_length=5000) language: str = Field(default="en", pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it)$") seed: Optional[int] = Field(None, ge=0) - model_size: Optional[str] = Field(default="1.7B", pattern="^(1\\.7B|0\\.6B)$") + model_size: Optional[str] = Field(default="1.7B") instruct: Optional[str] = Field(None, max_length=500) @@ -137,6 +137,7 @@ class ModelStatus(BaseModel): downloading: bool = False # True if download is in progress size_mb: Optional[float] = None loaded: bool = 
False + is_custom: bool = False # True for user-added custom models — @modified AJ - Kamyab (Ankit Jain) class ModelStatusListResponse(BaseModel): @@ -164,6 +165,25 @@ class ActiveGenerationTask(BaseModel): started_at: datetime +class CustomModelCreate(BaseModel): + """Request model for adding a custom model.""" + hf_repo_id: str = Field(..., min_length=3, max_length=200) + display_name: str = Field(..., min_length=1, max_length=100) + + +class CustomModelResponse(BaseModel): + """Response model for a custom model.""" + id: str + hf_repo_id: str + display_name: str + added_at: str + + +class CustomModelListResponse(BaseModel): + """Response model for custom model list.""" + models: List[CustomModelResponse] + + class ActiveTasksResponse(BaseModel): """Response model for active tasks.""" downloads: List[ActiveDownloadTask] diff --git a/backend/voicebox-server.spec b/backend/voicebox-server.spec index feccfae0..3719edc0 100644 --- a/backend/voicebox-server.spec +++ b/backend/voicebox-server.spec @@ -1,29 +1,28 @@ # -*- mode: python ; coding: utf-8 -*- from PyInstaller.utils.hooks import collect_data_files from PyInstaller.utils.hooks import collect_submodules +from PyInstaller.utils.hooks import collect_all from PyInstaller.utils.hooks import copy_metadata datas = [] -hiddenimports = ['backend', 'backend.main', 'backend.config', 'backend.database', 'backend.models', 'backend.profiles', 'backend.history', 'backend.tts', 'backend.transcribe', 'backend.platform_detect', 'backend.backends', 'backend.backends.pytorch_backend', 'backend.utils.audio', 'backend.utils.cache', 'backend.utils.progress', 'backend.utils.hf_progress', 'backend.utils.validation', 'torch', 'transformers', 'fastapi', 'uvicorn', 'sqlalchemy', 'librosa', 'soundfile', 'qwen_tts', 'qwen_tts.inference', 'qwen_tts.inference.qwen3_tts_model', 'qwen_tts.inference.qwen3_tts_tokenizer', 'qwen_tts.core', 'qwen_tts.cli', 'pkg_resources.extern', 'backend.backends.mlx_backend', 'mlx', 'mlx.core', 'mlx.nn', 
'mlx_audio', 'mlx_audio.tts', 'mlx_audio.stt'] +binaries = [] +hiddenimports = ['backend', 'backend.main', 'backend.config', 'backend.database', 'backend.models', 'backend.profiles', 'backend.history', 'backend.tts', 'backend.transcribe', 'backend.platform_detect', 'backend.backends', 'backend.backends.pytorch_backend', 'backend.utils.audio', 'backend.utils.cache', 'backend.utils.progress', 'backend.utils.hf_progress', 'backend.custom_models', 'backend.utils.validation', 'torch', 'transformers', 'fastapi', 'uvicorn', 'sqlalchemy', 'librosa', 'soundfile', 'qwen_tts', 'qwen_tts.inference', 'qwen_tts.inference.qwen3_tts_model', 'qwen_tts.inference.qwen3_tts_tokenizer', 'qwen_tts.core', 'qwen_tts.cli', 'pkg_resources.extern', 'backend.backends.mlx_backend', 'mlx', 'mlx.core', 'mlx.nn', 'mlx_audio', 'mlx_audio.tts', 'mlx_audio.stt'] datas += collect_data_files('qwen_tts') -# Use collect_all (not collect_data_files) so native .dylib and .metallib -# files are bundled as binaries, not data. Without this, MLX raises OSError -# when loading Metal shaders inside the PyInstaller bundle. 
-from PyInstaller.utils.hooks import collect_all as _collect_all -_mlx_datas, _mlx_bins, _mlx_hidden = _collect_all('mlx') -_mlxa_datas, _mlxa_bins, _mlxa_hidden = _collect_all('mlx_audio') -datas += _mlx_datas + _mlxa_datas datas += copy_metadata('qwen-tts') hiddenimports += collect_submodules('qwen_tts') hiddenimports += collect_submodules('jaraco') hiddenimports += collect_submodules('mlx') hiddenimports += collect_submodules('mlx_audio') +tmp_ret = collect_all('mlx') +datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] +tmp_ret = collect_all('mlx_audio') +datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] a = Analysis( ['server.py'], pathex=[], - binaries=_mlx_bins + _mlxa_bins, + binaries=binaries, datas=datas, hiddenimports=hiddenimports, hookspath=[], diff --git a/bun.lock b/bun.lock index 9e08a825..d271b5c6 100644 --- a/bun.lock +++ b/bun.lock @@ -13,7 +13,7 @@ }, "app": { "name": "@voicebox/app", - "version": "0.1.11", + "version": "0.1.13", "dependencies": { "@dnd-kit/core": "^6.3.1", "@dnd-kit/sortable": "^10.0.0", @@ -68,7 +68,7 @@ }, "landing": { "name": "@voicebox/landing", - "version": "0.1.11", + "version": "0.1.13", "dependencies": { "@radix-ui/react-separator": "^1.1.8", "@radix-ui/react-slot": "^1.2.4", @@ -93,7 +93,7 @@ }, "tauri": { "name": "@voicebox/tauri", - "version": "0.1.11", + "version": "0.1.13", "dependencies": { "@tauri-apps/api": "^2.0.0", "@tauri-apps/plugin-dialog": "^2.0.0", @@ -116,7 +116,7 @@ }, "web": { "name": "@voicebox/web", - "version": "0.1.11", + "version": "0.1.13", "dependencies": { "@tanstack/react-query": "^5.0.0", "react": "^18.3.0", @@ -125,6 +125,7 @@ "zustand": "^4.5.0", }, "devDependencies": { + "@tailwindcss/vite": "^4.0.0", "@types/react": "^18.3.0", "@types/react-dom": "^18.3.0", "@typescript-eslint/eslint-plugin": "^7.0.0", diff --git a/data/custom_models.json b/data/custom_models.json new file mode 100644 index 00000000..28884d93 --- /dev/null +++ 
b/data/custom_models.json @@ -0,0 +1,10 @@ +{ + "models": [ + { + "id": "aryannsc-ind-qwentts-v1", + "display_name": "IND Qwen tts1", + "hf_repo_id": "AryanNsc/IND-QWENTTS-V1", + "added_at": "2026-03-01T18:38:22.160558Z" + } + ] +} \ No newline at end of file diff --git a/tauri/src-tauri/Cargo.lock b/tauri/src-tauri/Cargo.lock index 35b15188..450b9a35 100644 --- a/tauri/src-tauri/Cargo.lock +++ b/tauri/src-tauri/Cargo.lock @@ -5041,7 +5041,7 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "voicebox" -version = "0.1.12" +version = "0.1.13" dependencies = [ "base64 0.22.1", "core-foundation-sys", diff --git a/tauri/src-tauri/build.rs b/tauri/src-tauri/build.rs index ea612597..3e5090e9 100644 --- a/tauri/src-tauri/build.rs +++ b/tauri/src-tauri/build.rs @@ -63,17 +63,18 @@ fn main() { match output { Ok(output) => { + // @modified AJ - Kamyab (Ankit Jain) — Graceful fallback when full Xcode is not installed if !output.status.success() { eprintln!("actool stderr: {}", String::from_utf8_lossy(&output.stderr)); eprintln!("actool stdout: {}", String::from_utf8_lossy(&output.stdout)); - panic!("actool failed to compile icon"); + println!("cargo:warning=actool failed to compile icon (full Xcode may be required). Continuing without custom icon."); + } else { + println!("Successfully compiled icon to {}", gen_dir); } - println!("Successfully compiled icon to {}", gen_dir); } Err(e) => { eprintln!("Failed to execute xcrun actool: {}", e); - eprintln!("Make sure you have Xcode Command Line Tools installed"); - panic!("Icon compilation failed"); + println!("cargo:warning=Could not run actool (full Xcode may be required). Continuing without custom icon."); } } } else {