diff --git a/frontend/src/components/DownloadDialog.tsx b/frontend/src/components/DownloadDialog.tsx
index fd1ab768c..0dfec0033 100644
--- a/frontend/src/components/DownloadDialog.tsx
+++ b/frontend/src/components/DownloadDialog.tsx
@@ -16,6 +16,10 @@ interface DownloadDialogProps {
   pipelineId: PipelineId;
   onClose: () => void;
   onDownload: () => void;
+  vaeNeedsDownload?: {
+    vaeType: string;
+    modelName: string;
+  } | null;
 }

 export function DownloadDialog({
@@ -23,6 +27,7 @@ export function DownloadDialog({
   pipelineId,
   onClose,
   onDownload,
+  vaeNeedsDownload,
 }: DownloadDialogProps) {
   const pipelineInfo = PIPELINES[pipelineId];
   if (!pipelineInfo) return null;
@@ -31,13 +36,22 @@ export function DownloadDialog({
       !isOpen && onClose()}>
-          Download Models
+
+          {vaeNeedsDownload ? "Download VAE Model" : "Download Models"}
+
-          This pipeline requires model weights to be downloaded.
+          {vaeNeedsDownload ? (
+            <>
+              The selected VAE model ({vaeNeedsDownload.vaeType}) is missing
+              and needs to be downloaded.
+
+          ) : (
+            <>This pipeline requires model weights to be downloaded.
+          )}
-        {pipelineInfo.estimatedVram && (
+        {!vaeNeedsDownload && pipelineInfo.estimatedVram && (
             Estimated GPU VRAM Requirement:
diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts
index 05c431fcf..a6f716b8c 100644
--- a/frontend/src/lib/api.ts
+++ b/frontend/src/lib/api.ts
@@ -307,3 +307,47 @@ export const getPipelineSchemas =
   const result = await response.json();
   return result;
 };
+
+export const checkVaeStatus = async (
+  vaeType: string,
+  modelName: string = "Wan2.1-T2V-1.3B"
+): Promise<{ downloaded: boolean }> => {
+  const response = await fetch(
+    `/api/v1/vae/status?vae_type=${encodeURIComponent(vaeType)}&model_name=${encodeURIComponent(modelName)}`,
+    {
+      method: "GET",
+      headers: { "Content-Type": "application/json" },
+    }
+  );
+
+  if (!response.ok) {
+    const errorText = await response.text();
+    throw new Error(
+      `VAE status check failed: ${response.status} ${response.statusText}: ${errorText}`
+    );
+  }
+
+  const result = await response.json();
+  return result;
+};
+
+export const downloadVae = async (
+  vaeType: string,
+  modelName: string = "Wan2.1-T2V-1.3B"
+): Promise<{ message: string }> => {
+  const response = await fetch("/api/v1/vae/download", {
+    method: "POST",
+    headers: { "Content-Type": "application/json" },
+    body: JSON.stringify({ vae_type: vaeType, model_name: modelName }),
+  });
+
+  if (!response.ok) {
+    const errorText = await response.text();
+    throw new Error(
+      `VAE download failed: ${response.status} ${response.statusText}: ${errorText}`
+    );
+  }
+
+  const result = await response.json();
+  return result;
+};
diff --git a/frontend/src/pages/StreamPage.tsx b/frontend/src/pages/StreamPage.tsx
index f020349a7..55b75a12d 100644
--- a/frontend/src/pages/StreamPage.tsx
+++ b/frontend/src/pages/StreamPage.tsx
@@ -24,7 +24,12 @@ import type {
   LoraMergeStrategy,
 } from "../types";
 import type { PromptItem, PromptTransition } from "../lib/api";
-import { checkModelStatus, downloadPipelineModels } from "../lib/api";
+import {
+  checkModelStatus,
+  checkVaeStatus,
+  downloadPipelineModels,
+  downloadVae,
+} from "../lib/api";
 import { sendLoRAScaleUpdates } from "../utils/loraHelpers";

 // Delay before resetting video reinitialization flag (ms)
@@ -96,6 +101,10 @@ export function StreamPage() {
   const [pipelineNeedsModels, setPipelineNeedsModels] = useState(
     null
   );
+  const [vaeNeedsDownload, setVaeNeedsDownload] = useState<{
+    vaeType: string;
+    modelName: string;
+  } | null>(null);

   // Ref to access timeline functions
   const timelineRef = useRef<{
@@ -274,6 +283,15 @@ export function StreamPage() {
   };

   const handleDownloadModels = async () => {
+    // Check if we need to download VAE first
+    if (vaeNeedsDownload) {
+      await handleDownloadVae();
+    } else if (pipelineNeedsModels) {
+      await handleDownloadPipelineModels();
+    }
+  };
+
+  const handleDownloadPipelineModels = async () => {
     if (!pipelineNeedsModels) return;

     setIsDownloading(true);
@@ -354,9 +372,74 @@ export function StreamPage() {
     }
   };

+  const handleDownloadVae = async () => {
+    if (!vaeNeedsDownload) return;
+
+    setIsDownloading(true);
+    setShowDownloadDialog(false);
+
+    try {
+      await downloadVae(vaeNeedsDownload.vaeType, vaeNeedsDownload.modelName);
+
+      // Start polling to check when download is complete
+      const checkDownloadComplete = async () => {
+        try {
+          const status = await checkVaeStatus(
+            vaeNeedsDownload.vaeType,
+            vaeNeedsDownload.modelName
+          );
+          if (status.downloaded) {
+            setIsDownloading(false);
+            setVaeNeedsDownload(null);
+
+            // After VAE download, check if pipeline models are also needed
+            const pipelineIdToUse = pipelineNeedsModels || settings.pipelineId;
+            const pipelineInfo = PIPELINES[pipelineIdToUse];
+            if (pipelineInfo?.requiresModels) {
+              try {
+                const pipelineStatus = await checkModelStatus(pipelineIdToUse);
+                if (!pipelineStatus.downloaded) {
+                  // Still need pipeline models, show dialog for that
+                  setPipelineNeedsModels(pipelineIdToUse);
+                  setShowDownloadDialog(true);
+                  return;
+                }
+              } catch (error) {
+                console.error("Error checking model status:", error);
+              }
+            }
+
+            // All downloads complete, start the stream
+            setTimeout(async () => {
+              const started = await handleStartStream();
+              if (started && timelinePlayPauseRef.current) {
+                setTimeout(() => {
+                  timelinePlayPauseRef.current?.();
+                }, 2000);
+              }
+            }, 100);
+          } else {
+            // Check again in 2 seconds
+            setTimeout(checkDownloadComplete, 2000);
+          }
+        } catch (error) {
+          console.error("Error checking VAE download status:", error);
+          setIsDownloading(false);
+        }
+      };
+
+      // Start checking for completion
+      setTimeout(checkDownloadComplete, 5000);
+    } catch (error) {
+      console.error("Error downloading VAE:", error);
+      setIsDownloading(false);
+    }
+  };
+
   const handleDialogClose = () => {
     setShowDownloadDialog(false);
     setPipelineNeedsModels(null);
+    setVaeNeedsDownload(null);

     // When user cancels, no stream or timeline has started yet, so nothing to clean up
     // Just close the dialog and return early without any state changes
@@ -569,6 +652,24 @@ export function StreamPage() {
       }
     }

+    // Check if VAE is needed but not downloaded
+    // Default to "wan" VAE type (backend will handle VAE selection)
+    // NOTE: support for other VAE types will be added later, e.g. const vaeType = settings.vaeType ?? "wan";
+    const vaeType = "wan";
+    try {
+      const vaeStatus = await checkVaeStatus(vaeType);
+      if (!vaeStatus.downloaded) {
+        // Show download dialog for VAE (use pipeline ID for dialog, but track VAE separately)
+        setVaeNeedsDownload({ vaeType, modelName: "Wan2.1-T2V-1.3B" });
+        setPipelineNeedsModels(pipelineIdToUse);
+        setShowDownloadDialog(true);
+        return false; // Stream did not start
+      }
+    } catch (error) {
+      console.error("Error checking VAE status:", error);
+      // Continue anyway if check fails
+    }
+
     // Always load pipeline with current parameters - backend will handle the rest
     console.log(`Loading ${pipelineIdToUse} pipeline...`);
@@ -959,6 +1060,7 @@ export function StreamPage() {
             pipelineId={pipelineNeedsModels as PipelineId}
             onClose={handleDialogClose}
             onDownload={handleDownloadModels}
+            vaeNeedsDownload={vaeNeedsDownload}
           />
         )}
diff --git a/src/scope/core/pipelines/wan2_1/vae/__init__.py b/src/scope/core/pipelines/wan2_1/vae/__init__.py
index c643470ae..ff5e75130 100644
--- a/src/scope/core/pipelines/wan2_1/vae/__init__.py
+++ b/src/scope/core/pipelines/wan2_1/vae/__init__.py
@@ -16,8 +16,39 @@
 vae = create_vae(model_dir="wan_models", vae_path="/path/to/custom_vae.pth")
 """

+from dataclasses import dataclass
+
 from .wan import WanVAEWrapper

+
+@dataclass(frozen=True)
+class VAEMetadata:
+    """Metadata for a VAE type (filenames, download sources)."""
+
+    filename: str
+    download_repo: str | None = None  # None = bundled with main model repo
+    download_file: str | None = None  # None = no separate download needed
+
+
+# Single source of truth for VAE metadata
+VAE_METADATA: dict[str, VAEMetadata] = {
+    "wan": VAEMetadata(
+        filename="Wan2.1_VAE.pth",
+        download_repo="Wan-AI/Wan2.1-T2V-1.3B",
+        download_file="Wan2.1_VAE.pth",
+    ),
+    "lightvae": VAEMetadata(
+        filename="lightvaew2_1.pth",
+        download_repo="lightx2v/Autoencoders",
+        download_file="lightvaew2_1.pth",
+    ),
+    "tae": VAEMetadata(
+        filename="taew2_1.pth",
+        download_repo="lightx2v/Autoencoders",
+        download_file="taew2_1.pth",
+    ),
+}
+
 # Registry mapping type names to VAE classes
 # UI dropdowns will use these keys
 VAE_REGISTRY: dict[str, type] = {
@@ -69,8 +100,10 @@ def list_vae_types() -> list[str]:

 __all__ = [
     "WanVAEWrapper",
-    "create_vae",
-    "list_vae_types",
+    "VAEMetadata",
+    "VAE_METADATA",
     "VAE_REGISTRY",
     "DEFAULT_VAE_TYPE",
+    "create_vae",
+    "list_vae_types",
 ]
download_repo="lightx2v/Autoencoders", + download_file="taew2_1.pth", + ), +} + # Registry mapping type names to VAE classes # UI dropdowns will use these keys VAE_REGISTRY: dict[str, type] = { @@ -69,8 +100,10 @@ def list_vae_types() -> list[str]: __all__ = [ "WanVAEWrapper", - "create_vae", - "list_vae_types", + "VAEMetadata", + "VAE_METADATA", "VAE_REGISTRY", "DEFAULT_VAE_TYPE", + "create_vae", + "list_vae_types", ] diff --git a/src/scope/server/app.py b/src/scope/server/app.py index ceb2ba305..8cacd2e19 100644 --- a/src/scope/server/app.py +++ b/src/scope/server/app.py @@ -367,6 +367,15 @@ class DownloadModelsRequest(BaseModel): pipeline_id: str +class VaeStatusResponse(BaseModel): + downloaded: bool + + +class DownloadVaeRequest(BaseModel): + vae_type: str + model_name: str = "Wan2.1-T2V-1.3B" + + class LoRAFileInfo(BaseModel): """Metadata for an available LoRA file on disk.""" @@ -452,6 +461,46 @@ def download_in_background(): raise HTTPException(status_code=500, detail=str(e)) from e +@app.get("/api/v1/vae/status", response_model=VaeStatusResponse) +async def get_vae_status(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B"): + """Check if a VAE file is downloaded.""" + try: + from .models_config import vae_file_exists + + downloaded = vae_file_exists(vae_type, model_name) + return VaeStatusResponse(downloaded=downloaded) + except Exception as e: + logger.error(f"Error checking VAE status: {e}") + raise HTTPException(status_code=500, detail=str(e)) from e + + +@app.post("/api/v1/vae/download") +async def download_vae(request: DownloadVaeRequest): + """Download a specific VAE file.""" + try: + if not request.vae_type: + raise HTTPException(status_code=400, detail="vae_type is required") + + # Download in a background thread to avoid blocking + import threading + + from .download_models import download_vae as download_vae_func + + def download_in_background(): + download_vae_func(request.vae_type, request.model_name) + + thread = threading.Thread(target=download_in_background) + thread.daemon = True + thread.start() + + return { + "message": f"VAE download started for {request.vae_type} (model: {request.model_name})" + } + except Exception as e: + logger.error(f"Error starting VAE download: {e}") + raise HTTPException(status_code=500, detail=str(e)) from e + + @app.get("/api/v1/hardware/info", response_model=HardwareInfoResponse) async def get_hardware_info(): """Get hardware information including available VRAM.""" diff --git a/src/scope/server/download_models.py b/src/scope/server/download_models.py index d2d6482ee..cd70a6a51 100644 --- a/src/scope/server/download_models.py +++ b/src/scope/server/download_models.py @@ -65,6 +65,41 @@ def download_hf_single_file(repo_id: str, filename: str, local_dir: Path) -> Non print(f"[OK] Downloaded file '{filename}' from '{repo_id}' to: {out_path}") +def download_vae(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B") -> None: + """ + Download VAE weights for a specific VAE type. + + Args: + vae_type: VAE type (e.g., "wan", "lightvae", "tae") + model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B") + """ + from scope.core.pipelines.wan2_1.vae import VAE_METADATA + + metadata = VAE_METADATA.get(vae_type) + if metadata is None: + available = list(VAE_METADATA.keys()) + raise ValueError(f"Unknown VAE type: {vae_type}. 
diff --git a/src/scope/server/download_models.py b/src/scope/server/download_models.py
index d2d6482ee..cd70a6a51 100644
--- a/src/scope/server/download_models.py
+++ b/src/scope/server/download_models.py
@@ -65,6 +65,41 @@ def download_hf_single_file(repo_id: str, filename: str, local_dir: Path) -> None:
     print(f"[OK] Downloaded file '{filename}' from '{repo_id}' to: {out_path}")

+
+def download_vae(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B") -> None:
+    """
+    Download VAE weights for a specific VAE type.
+
+    Args:
+        vae_type: VAE type (e.g., "wan", "lightvae", "tae")
+        model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B")
+    """
+    from scope.core.pipelines.wan2_1.vae import VAE_METADATA
+
+    metadata = VAE_METADATA.get(vae_type)
+    if metadata is None:
+        available = list(VAE_METADATA.keys())
+        raise ValueError(f"Unknown VAE type: {vae_type}. Available: {available}")
+
+    if metadata.download_repo is None:
+        # VAE is bundled with the main model repo; nothing to download separately
+        print(f"[INFO] {vae_type} VAE is bundled with {model_name} repo download")
+        return
+
+    models_root = ensure_models_dir()
+    download_hf_single_file(
+        metadata.download_repo, metadata.download_file, models_root / model_name
+    )
+
+
+def download_downloadable_vaes(model_name: str = "Wan2.1-T2V-1.3B") -> None:
+    """Download all VAEs that require separate downloads."""
+    from scope.core.pipelines.wan2_1.vae import VAE_METADATA
+
+    for vae_type, metadata in VAE_METADATA.items():
+        if metadata.download_repo is not None:
+            download_vae(vae_type, model_name)
+

 def download_required_models():
     """Download required models if they are not already present."""
     if models_are_downloaded():
@@ -111,6 +146,9 @@ def download_streamdiffusionv2_pipeline() -> None:
         allow_patterns=["wan_causal_dmd_v2v/model.pt"],
     )

+    # 4) Download additional VAE variants
+    download_downloadable_vaes()
+

 def download_longlive_pipeline() -> None:
     """Download models for the LongLive pipeline."""
@@ -139,6 +177,9 @@ def download_longlive_pipeline() -> None:
     # 3) HF repo download for LongLive-1.3B
     download_hf_repo_excluding(longlive_repo, longlive_dst, ignore_patterns=[])

+    # 4) Download additional VAE variants
+    download_downloadable_vaes()
+

 def download_krea_realtime_video_pipeline() -> None:
     """
@@ -183,6 +224,9 @@ def download_krea_realtime_video_pipeline() -> None:
         wan_video_comfy_repo, wan_video_comfy_file, wan_video_comfy_dst
     )

+    # 5) Download additional VAE variants
+    download_downloadable_vaes()
+

 def download_models(pipeline_id: str | None = None) -> None:
     """
diff --git a/src/scope/server/models_config.py b/src/scope/server/models_config.py
index 6a967e4b1..dcea6552b 100644
--- a/src/scope/server/models_config.py
+++ b/src/scope/server/models_config.py
@@ -73,6 +73,42 @@ def get_model_file_path(relative_path: str) -> Path:
     return models_dir / relative_path

+
+def get_vae_file_path(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B") -> Path:
+    """
+    Get the path to a VAE file based on VAE type.
+
+    Args:
+        vae_type: VAE type (e.g., "wan", "lightvae", "tae")
+        model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B")
+
+    Returns:
+        Path: Absolute path to the VAE file
+    """
+    from scope.core.pipelines.wan2_1.vae import VAE_METADATA
+
+    metadata = VAE_METADATA.get(vae_type)
+    if metadata is None:
+        available = list(VAE_METADATA.keys())
+        raise ValueError(f"Unknown VAE type: {vae_type}. Available: {available}")
+
+    return get_models_dir() / model_name / metadata.filename
+
+
+def vae_file_exists(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B") -> bool:
+    """
+    Check if a VAE file exists.
+
+    Args:
+        vae_type: VAE type ("wan", "lightvae", or "tae")
+        model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B")
+
+    Returns:
+        bool: True if the VAE file exists, False otherwise
+    """
+    vae_path = get_vae_file_path(vae_type, model_name)
+    return vae_path.exists()
+

 def get_required_model_files(pipeline_id: str | None = None) -> list[Path]:
     """
     Get the list of required model files that should exist for a given pipeline.
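And a minimal sketch of how the new server-side helpers compose, assuming the scope package is importable and the models directory is already configured; the loop below is illustrative, not part of the diff:

from scope.server.download_models import download_vae
from scope.server.models_config import get_vae_file_path, vae_file_exists

for vae_type in ("wan", "lightvae", "tae"):
    if not vae_file_exists(vae_type):  # model_name defaults to "Wan2.1-T2V-1.3B"
        # Fetches the file listed in VAE_METADATA for this type; types whose
        # download_repo is None only print an informational message
        download_vae(vae_type)
    print(vae_type, "->", get_vae_file_path(vae_type))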