From ba8e68fd3be22cbcf51395fb8345712d4962e82c Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:45:20 -0500 Subject: [PATCH] feat: add VAE download infrastructure with metadata registry feat: add VAE download infrastructure with metadata registry Adds complete infrastructure for downloading and managing VAE models independently of pipeline downloads. This change has no effect on the current code. When the ability to select different VAE types is added, the system will check if the selected VAE is downloaded and prompt for download if missing. Backend: - Add /api/v1/vae/status endpoint to check VAE download status - Add /api/v1/vae/download endpoint to trigger VAE downloads - Add VAEMetadata dataclass and VAE_METADATA registry as single source of truth for VAE filenames and download sources - Add vae_file_exists() and get_vae_file_path() using metadata registry - Add download_vae() and download_downloadable_vaes() functions Frontend: - Add checkVaeStatus() and downloadVae() API functions - Extend DownloadDialog to show VAE-specific download prompts - Add VAE status checking before stream start in StreamPage - Add separate VAE download flow with polling for completion Prepares codebase for upcoming PRs adding additional VAE types. Signed-off-by: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> --- frontend/src/components/DownloadDialog.tsx | 20 +++- frontend/src/lib/api.ts | 44 ++++++++ frontend/src/pages/StreamPage.tsx | 104 +++++++++++++++++- .../core/pipelines/wan2_1/vae/__init__.py | 37 ++++++- src/scope/server/app.py | 49 +++++++++ src/scope/server/download_models.py | 44 ++++++++ src/scope/server/models_config.py | 36 ++++++ 7 files changed, 328 insertions(+), 6 deletions(-) diff --git a/frontend/src/components/DownloadDialog.tsx b/frontend/src/components/DownloadDialog.tsx index fd1ab768c..0dfec0033 100644 --- a/frontend/src/components/DownloadDialog.tsx +++ b/frontend/src/components/DownloadDialog.tsx @@ -16,6 +16,10 @@ interface DownloadDialogProps { pipelineId: PipelineId; onClose: () => void; onDownload: () => void; + vaeNeedsDownload?: { + vaeType: string; + modelName: string; + } | null; } export function DownloadDialog({ @@ -23,6 +27,7 @@ export function DownloadDialog({ pipelineId, onClose, onDownload, + vaeNeedsDownload, }: DownloadDialogProps) { const pipelineInfo = PIPELINES[pipelineId]; if (!pipelineInfo) return null; @@ -31,13 +36,22 @@ export function DownloadDialog({ !isOpen && onClose()}> - Download Models + + {vaeNeedsDownload ? "Download VAE Model" : "Download Models"} + - This pipeline requires model weights to be downloaded. + {vaeNeedsDownload ? ( + <> + The selected VAE model ({vaeNeedsDownload.vaeType}) is missing + and needs to be downloaded. + + ) : ( + <>This pipeline requires model weights to be downloaded. + )} - {pipelineInfo.estimatedVram && ( + {!vaeNeedsDownload && pipelineInfo.estimatedVram && (

Estimated GPU VRAM Requirement: diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 05c431fcf..a6f716b8c 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -307,3 +307,47 @@ export const getPipelineSchemas = const result = await response.json(); return result; }; + +export const checkVaeStatus = async ( + vaeType: string, + modelName: string = "Wan2.1-T2V-1.3B" +): Promise<{ downloaded: boolean }> => { + const response = await fetch( + `/api/v1/vae/status?vae_type=${encodeURIComponent(vaeType)}&model_name=${encodeURIComponent(modelName)}`, + { + method: "GET", + headers: { "Content-Type": "application/json" }, + } + ); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `VAE status check failed: ${response.status} ${response.statusText}: ${errorText}` + ); + } + + const result = await response.json(); + return result; +}; + +export const downloadVae = async ( + vaeType: string, + modelName: string = "Wan2.1-T2V-1.3B" +): Promise<{ message: string }> => { + const response = await fetch("/api/v1/vae/download", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ vae_type: vaeType, model_name: modelName }), + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `VAE download failed: ${response.status} ${response.statusText}: ${errorText}` + ); + } + + const result = await response.json(); + return result; +}; diff --git a/frontend/src/pages/StreamPage.tsx b/frontend/src/pages/StreamPage.tsx index f020349a7..55b75a12d 100644 --- a/frontend/src/pages/StreamPage.tsx +++ b/frontend/src/pages/StreamPage.tsx @@ -24,7 +24,12 @@ import type { LoraMergeStrategy, } from "../types"; import type { PromptItem, PromptTransition } from "../lib/api"; -import { checkModelStatus, downloadPipelineModels } from "../lib/api"; +import { + checkModelStatus, + checkVaeStatus, + downloadPipelineModels, + downloadVae, +} from "../lib/api"; import { sendLoRAScaleUpdates } from "../utils/loraHelpers"; // Delay before resetting video reinitialization flag (ms) @@ -96,6 +101,10 @@ export function StreamPage() { const [pipelineNeedsModels, setPipelineNeedsModels] = useState( null ); + const [vaeNeedsDownload, setVaeNeedsDownload] = useState<{ + vaeType: string; + modelName: string; + } | null>(null); // Ref to access timeline functions const timelineRef = useRef<{ @@ -274,6 +283,15 @@ export function StreamPage() { }; const handleDownloadModels = async () => { + // Check if we need to download VAE first + if (vaeNeedsDownload) { + await handleDownloadVae(); + } else if (pipelineNeedsModels) { + await handleDownloadPipelineModels(); + } + }; + + const handleDownloadPipelineModels = async () => { if (!pipelineNeedsModels) return; setIsDownloading(true); @@ -354,9 +372,74 @@ export function StreamPage() { } }; + const handleDownloadVae = async () => { + if (!vaeNeedsDownload) return; + + setIsDownloading(true); + setShowDownloadDialog(false); + + try { + await downloadVae(vaeNeedsDownload.vaeType, vaeNeedsDownload.modelName); + + // Start polling to check when download is complete + const checkDownloadComplete = async () => { + try { + const status = await checkVaeStatus( + vaeNeedsDownload.vaeType, + vaeNeedsDownload.modelName + ); + if (status.downloaded) { + setIsDownloading(false); + setVaeNeedsDownload(null); + + // After VAE download, check if pipeline models are also needed + const pipelineIdToUse = pipelineNeedsModels || settings.pipelineId; + const pipelineInfo = PIPELINES[pipelineIdToUse]; + if (pipelineInfo?.requiresModels) { + try { + const pipelineStatus = await checkModelStatus(pipelineIdToUse); + if (!pipelineStatus.downloaded) { + // Still need pipeline models, show dialog for that + setPipelineNeedsModels(pipelineIdToUse); + setShowDownloadDialog(true); + return; + } + } catch (error) { + console.error("Error checking model status:", error); + } + } + + // All downloads complete, start the stream + setTimeout(async () => { + const started = await handleStartStream(); + if (started && timelinePlayPauseRef.current) { + setTimeout(() => { + timelinePlayPauseRef.current?.(); + }, 2000); + } + }, 100); + } else { + // Check again in 2 seconds + setTimeout(checkDownloadComplete, 2000); + } + } catch (error) { + console.error("Error checking VAE download status:", error); + setIsDownloading(false); + } + }; + + // Start checking for completion + setTimeout(checkDownloadComplete, 5000); + } catch (error) { + console.error("Error downloading VAE:", error); + setIsDownloading(false); + } + }; + const handleDialogClose = () => { setShowDownloadDialog(false); setPipelineNeedsModels(null); + setVaeNeedsDownload(null); // When user cancels, no stream or timeline has started yet, so nothing to clean up // Just close the dialog and return early without any state changes @@ -569,6 +652,24 @@ export function StreamPage() { } } + // Check if VAE is needed but not downloaded + // Default to "wan" VAE type (backend will handle VAE selection) + // NOTE: support for other vae types will be added later. const vaeType = settings.vaeType ?? "wan"; + const vaeType = "wan"; + try { + const vaeStatus = await checkVaeStatus(vaeType); + if (!vaeStatus.downloaded) { + // Show download dialog for VAE (use pipeline ID for dialog, but track VAE separately) + setVaeNeedsDownload({ vaeType, modelName: "Wan2.1-T2V-1.3B" }); + setPipelineNeedsModels(pipelineIdToUse); + setShowDownloadDialog(true); + return false; // Stream did not start + } + } catch (error) { + console.error("Error checking VAE status:", error); + // Continue anyway if check fails + } + // Always load pipeline with current parameters - backend will handle the rest console.log(`Loading ${pipelineIdToUse} pipeline...`); @@ -959,6 +1060,7 @@ export function StreamPage() { pipelineId={pipelineNeedsModels as PipelineId} onClose={handleDialogClose} onDownload={handleDownloadModels} + vaeNeedsDownload={vaeNeedsDownload} /> )} diff --git a/src/scope/core/pipelines/wan2_1/vae/__init__.py b/src/scope/core/pipelines/wan2_1/vae/__init__.py index c643470ae..ff5e75130 100644 --- a/src/scope/core/pipelines/wan2_1/vae/__init__.py +++ b/src/scope/core/pipelines/wan2_1/vae/__init__.py @@ -16,8 +16,39 @@ vae = create_vae(model_dir="wan_models", vae_path="/path/to/custom_vae.pth") """ +from dataclasses import dataclass + from .wan import WanVAEWrapper + +@dataclass(frozen=True) +class VAEMetadata: + """Metadata for a VAE type (filenames, download sources).""" + + filename: str + download_repo: str | None = None # None = bundled with main model repo + download_file: str | None = None # None = no separate download needed + + +# Single source of truth for VAE metadata +VAE_METADATA: dict[str, VAEMetadata] = { + "wan": VAEMetadata( + filename="Wan2.1_VAE.pth", + download_repo="Wan-AI/Wan2.1-T2V-1.3B", + download_file="Wan2.1_VAE.pth", + ), + "lightvae": VAEMetadata( + filename="lightvaew2_1.pth", + download_repo="lightx2v/Autoencoders", + download_file="lightvaew2_1.pth", + ), + "tae": VAEMetadata( + filename="taew2_1.pth", + download_repo="lightx2v/Autoencoders", + download_file="taew2_1.pth", + ), +} + # Registry mapping type names to VAE classes # UI dropdowns will use these keys VAE_REGISTRY: dict[str, type] = { @@ -69,8 +100,10 @@ def list_vae_types() -> list[str]: __all__ = [ "WanVAEWrapper", - "create_vae", - "list_vae_types", + "VAEMetadata", + "VAE_METADATA", "VAE_REGISTRY", "DEFAULT_VAE_TYPE", + "create_vae", + "list_vae_types", ] diff --git a/src/scope/server/app.py b/src/scope/server/app.py index ceb2ba305..8cacd2e19 100644 --- a/src/scope/server/app.py +++ b/src/scope/server/app.py @@ -367,6 +367,15 @@ class DownloadModelsRequest(BaseModel): pipeline_id: str +class VaeStatusResponse(BaseModel): + downloaded: bool + + +class DownloadVaeRequest(BaseModel): + vae_type: str + model_name: str = "Wan2.1-T2V-1.3B" + + class LoRAFileInfo(BaseModel): """Metadata for an available LoRA file on disk.""" @@ -452,6 +461,46 @@ def download_in_background(): raise HTTPException(status_code=500, detail=str(e)) from e +@app.get("/api/v1/vae/status", response_model=VaeStatusResponse) +async def get_vae_status(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B"): + """Check if a VAE file is downloaded.""" + try: + from .models_config import vae_file_exists + + downloaded = vae_file_exists(vae_type, model_name) + return VaeStatusResponse(downloaded=downloaded) + except Exception as e: + logger.error(f"Error checking VAE status: {e}") + raise HTTPException(status_code=500, detail=str(e)) from e + + +@app.post("/api/v1/vae/download") +async def download_vae(request: DownloadVaeRequest): + """Download a specific VAE file.""" + try: + if not request.vae_type: + raise HTTPException(status_code=400, detail="vae_type is required") + + # Download in a background thread to avoid blocking + import threading + + from .download_models import download_vae as download_vae_func + + def download_in_background(): + download_vae_func(request.vae_type, request.model_name) + + thread = threading.Thread(target=download_in_background) + thread.daemon = True + thread.start() + + return { + "message": f"VAE download started for {request.vae_type} (model: {request.model_name})" + } + except Exception as e: + logger.error(f"Error starting VAE download: {e}") + raise HTTPException(status_code=500, detail=str(e)) from e + + @app.get("/api/v1/hardware/info", response_model=HardwareInfoResponse) async def get_hardware_info(): """Get hardware information including available VRAM.""" diff --git a/src/scope/server/download_models.py b/src/scope/server/download_models.py index d2d6482ee..cd70a6a51 100644 --- a/src/scope/server/download_models.py +++ b/src/scope/server/download_models.py @@ -65,6 +65,41 @@ def download_hf_single_file(repo_id: str, filename: str, local_dir: Path) -> Non print(f"[OK] Downloaded file '{filename}' from '{repo_id}' to: {out_path}") +def download_vae(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B") -> None: + """ + Download VAE weights for a specific VAE type. + + Args: + vae_type: VAE type (e.g., "wan", "lightvae", "tae") + model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B") + """ + from scope.core.pipelines.wan2_1.vae import VAE_METADATA + + metadata = VAE_METADATA.get(vae_type) + if metadata is None: + available = list(VAE_METADATA.keys()) + raise ValueError(f"Unknown VAE type: {vae_type}. Available: {available}") + + if metadata.download_repo is None: + # VAE is bundled with main model repo (e.g., "wan") + print(f"[INFO] {vae_type} VAE is bundled with {model_name} repo download") + return + + models_root = ensure_models_dir() + download_hf_single_file( + metadata.download_repo, metadata.download_file, models_root / model_name + ) + + +def download_downloadable_vaes(model_name: str = "Wan2.1-T2V-1.3B") -> None: + """Download all VAEs that require separate downloads.""" + from scope.core.pipelines.wan2_1.vae import VAE_METADATA + + for vae_type, metadata in VAE_METADATA.items(): + if metadata.download_repo is not None: + download_vae(vae_type, model_name) + + def download_required_models(): """Download required models if they are not already present.""" if models_are_downloaded(): @@ -111,6 +146,9 @@ def download_streamdiffusionv2_pipeline() -> None: allow_patterns=["wan_causal_dmd_v2v/model.pt"], ) + # 4) Download additional VAE variants + download_downloadable_vaes() + def download_longlive_pipeline() -> None: """Download models for the LongLive pipeline.""" @@ -139,6 +177,9 @@ def download_longlive_pipeline() -> None: # 3) HF repo download for LongLive-1.3B download_hf_repo_excluding(longlive_repo, longlive_dst, ignore_patterns=[]) + # 4) Download additional VAE variants + download_downloadable_vaes() + def download_krea_realtime_video_pipeline() -> None: """ @@ -183,6 +224,9 @@ def download_krea_realtime_video_pipeline() -> None: wan_video_comfy_repo, wan_video_comfy_file, wan_video_comfy_dst ) + # 5) Download additional VAE variants + download_downloadable_vaes() + def download_models(pipeline_id: str | None = None) -> None: """ diff --git a/src/scope/server/models_config.py b/src/scope/server/models_config.py index 6a967e4b1..dcea6552b 100644 --- a/src/scope/server/models_config.py +++ b/src/scope/server/models_config.py @@ -73,6 +73,42 @@ def get_model_file_path(relative_path: str) -> Path: return models_dir / relative_path +def get_vae_file_path(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B") -> Path: + """ + Get the path to a VAE file based on VAE type. + + Args: + vae_type: VAE type (e.g., "wan", "lightvae", "tae") + model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B") + + Returns: + Path: Absolute path to the VAE file + """ + from scope.core.pipelines.wan2_1.vae import VAE_METADATA + + metadata = VAE_METADATA.get(vae_type) + if metadata is None: + available = list(VAE_METADATA.keys()) + raise ValueError(f"Unknown VAE type: {vae_type}. Available: {available}") + + return get_models_dir() / model_name / metadata.filename + + +def vae_file_exists(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B") -> bool: + """ + Check if a VAE file exists. + + Args: + vae_type: VAE type ("wan", "lightvae", or "tae") + model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B") + + Returns: + bool: True if the VAE file exists, False otherwise + """ + vae_path = get_vae_file_path(vae_type, model_name) + return vae_path.exists() + + def get_required_model_files(pipeline_id: str | None = None) -> list[Path]: """ Get the list of required model files that should exist for a given pipeline.