Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions frontend/src/components/DownloadDialog.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@ interface DownloadDialogProps {
pipelineId: PipelineId;
onClose: () => void;
onDownload: () => void;
vaeNeedsDownload?: {
vaeType: string;
modelName: string;
} | null;
}

export function DownloadDialog({
open,
pipelineId,
onClose,
onDownload,
vaeNeedsDownload,
}: DownloadDialogProps) {
const pipelineInfo = PIPELINES[pipelineId];
if (!pipelineInfo) return null;
Expand All @@ -31,13 +36,22 @@ export function DownloadDialog({
<Dialog open={open} onOpenChange={isOpen => !isOpen && onClose()}>
<DialogContent className="sm:max-w-md">
<DialogHeader>
<DialogTitle>Download Models</DialogTitle>
<DialogTitle>
{vaeNeedsDownload ? "Download VAE Model" : "Download Models"}
</DialogTitle>
<DialogDescription className="mt-3">
This pipeline requires model weights to be downloaded.
{vaeNeedsDownload ? (
<>
The selected VAE model ({vaeNeedsDownload.vaeType}) is missing
and needs to be downloaded.
</>
) : (
<>This pipeline requires model weights to be downloaded.</>
)}
</DialogDescription>
</DialogHeader>

{pipelineInfo.estimatedVram && (
{!vaeNeedsDownload && pipelineInfo.estimatedVram && (
<p className="text-sm text-muted-foreground mb-3">
<span className="font-semibold">
Estimated GPU VRAM Requirement:
Expand Down
44 changes: 44 additions & 0 deletions frontend/src/lib/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,47 @@ export const getPipelineSchemas =
const result = await response.json();
return result;
};

export const checkVaeStatus = async (
vaeType: string,
modelName: string = "Wan2.1-T2V-1.3B"
): Promise<{ downloaded: boolean }> => {
const response = await fetch(
`/api/v1/vae/status?vae_type=${encodeURIComponent(vaeType)}&model_name=${encodeURIComponent(modelName)}`,
{
method: "GET",
headers: { "Content-Type": "application/json" },
}
);

if (!response.ok) {
const errorText = await response.text();
throw new Error(
`VAE status check failed: ${response.status} ${response.statusText}: ${errorText}`
);
}

const result = await response.json();
return result;
};

export const downloadVae = async (
vaeType: string,
modelName: string = "Wan2.1-T2V-1.3B"
): Promise<{ message: string }> => {
const response = await fetch("/api/v1/vae/download", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ vae_type: vaeType, model_name: modelName }),
});

if (!response.ok) {
const errorText = await response.text();
throw new Error(
`VAE download failed: ${response.status} ${response.statusText}: ${errorText}`
);
}

const result = await response.json();
return result;
};
104 changes: 103 additions & 1 deletion frontend/src/pages/StreamPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ import type {
LoraMergeStrategy,
} from "../types";
import type { PromptItem, PromptTransition } from "../lib/api";
import { checkModelStatus, downloadPipelineModels } from "../lib/api";
import {
checkModelStatus,
checkVaeStatus,
downloadPipelineModels,
downloadVae,
} from "../lib/api";
import { sendLoRAScaleUpdates } from "../utils/loraHelpers";

// Delay before resetting video reinitialization flag (ms)
Expand Down Expand Up @@ -96,6 +101,10 @@ export function StreamPage() {
const [pipelineNeedsModels, setPipelineNeedsModels] = useState<string | null>(
null
);
const [vaeNeedsDownload, setVaeNeedsDownload] = useState<{
vaeType: string;
modelName: string;
} | null>(null);

// Ref to access timeline functions
const timelineRef = useRef<{
Expand Down Expand Up @@ -274,6 +283,15 @@ export function StreamPage() {
};

const handleDownloadModels = async () => {
// Check if we need to download VAE first
if (vaeNeedsDownload) {
await handleDownloadVae();
} else if (pipelineNeedsModels) {
await handleDownloadPipelineModels();
}
};

const handleDownloadPipelineModels = async () => {
if (!pipelineNeedsModels) return;

setIsDownloading(true);
Expand Down Expand Up @@ -354,9 +372,74 @@ export function StreamPage() {
}
};

const handleDownloadVae = async () => {
if (!vaeNeedsDownload) return;

setIsDownloading(true);
setShowDownloadDialog(false);

try {
await downloadVae(vaeNeedsDownload.vaeType, vaeNeedsDownload.modelName);

// Start polling to check when download is complete
const checkDownloadComplete = async () => {
try {
const status = await checkVaeStatus(
vaeNeedsDownload.vaeType,
vaeNeedsDownload.modelName
);
if (status.downloaded) {
setIsDownloading(false);
setVaeNeedsDownload(null);

// After VAE download, check if pipeline models are also needed
const pipelineIdToUse = pipelineNeedsModels || settings.pipelineId;
const pipelineInfo = PIPELINES[pipelineIdToUse];
if (pipelineInfo?.requiresModels) {
try {
const pipelineStatus = await checkModelStatus(pipelineIdToUse);
if (!pipelineStatus.downloaded) {
// Still need pipeline models, show dialog for that
setPipelineNeedsModels(pipelineIdToUse);
setShowDownloadDialog(true);
return;
}
} catch (error) {
console.error("Error checking model status:", error);
}
}

// All downloads complete, start the stream
setTimeout(async () => {
const started = await handleStartStream();
if (started && timelinePlayPauseRef.current) {
setTimeout(() => {
timelinePlayPauseRef.current?.();
}, 2000);
}
}, 100);
} else {
// Check again in 2 seconds
setTimeout(checkDownloadComplete, 2000);
}
} catch (error) {
console.error("Error checking VAE download status:", error);
setIsDownloading(false);
}
};

// Start checking for completion
setTimeout(checkDownloadComplete, 5000);
} catch (error) {
console.error("Error downloading VAE:", error);
setIsDownloading(false);
}
};

const handleDialogClose = () => {
setShowDownloadDialog(false);
setPipelineNeedsModels(null);
setVaeNeedsDownload(null);

// When user cancels, no stream or timeline has started yet, so nothing to clean up
// Just close the dialog and return early without any state changes
Expand Down Expand Up @@ -569,6 +652,24 @@ export function StreamPage() {
}
}

// Check if VAE is needed but not downloaded
// Default to "wan" VAE type (backend will handle VAE selection)
// NOTE: support for other vae types will be added later. const vaeType = settings.vaeType ?? "wan";
const vaeType = "wan";
try {
const vaeStatus = await checkVaeStatus(vaeType);
if (!vaeStatus.downloaded) {
// Show download dialog for VAE (use pipeline ID for dialog, but track VAE separately)
setVaeNeedsDownload({ vaeType, modelName: "Wan2.1-T2V-1.3B" });
setPipelineNeedsModels(pipelineIdToUse);
setShowDownloadDialog(true);
return false; // Stream did not start
}
} catch (error) {
console.error("Error checking VAE status:", error);
// Continue anyway if check fails
}

// Always load pipeline with current parameters - backend will handle the rest
console.log(`Loading ${pipelineIdToUse} pipeline...`);

Expand Down Expand Up @@ -959,6 +1060,7 @@ export function StreamPage() {
pipelineId={pipelineNeedsModels as PipelineId}
onClose={handleDialogClose}
onDownload={handleDownloadModels}
vaeNeedsDownload={vaeNeedsDownload}
/>
)}
</div>
Expand Down
37 changes: 35 additions & 2 deletions src/scope/core/pipelines/wan2_1/vae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,39 @@
vae = create_vae(model_dir="wan_models", vae_path="/path/to/custom_vae.pth")
"""

from dataclasses import dataclass

from .wan import WanVAEWrapper


@dataclass(frozen=True)
class VAEMetadata:
"""Metadata for a VAE type (filenames, download sources)."""

filename: str
download_repo: str | None = None # None = bundled with main model repo
download_file: str | None = None # None = no separate download needed


# Single source of truth for VAE metadata
VAE_METADATA: dict[str, VAEMetadata] = {
"wan": VAEMetadata(
filename="Wan2.1_VAE.pth",
download_repo="Wan-AI/Wan2.1-T2V-1.3B",
download_file="Wan2.1_VAE.pth",
),
"lightvae": VAEMetadata(
filename="lightvaew2_1.pth",
download_repo="lightx2v/Autoencoders",
download_file="lightvaew2_1.pth",
),
"tae": VAEMetadata(
filename="taew2_1.pth",
download_repo="lightx2v/Autoencoders",
download_file="taew2_1.pth",
),
}

# Registry mapping type names to VAE classes
# UI dropdowns will use these keys
VAE_REGISTRY: dict[str, type] = {
Expand Down Expand Up @@ -69,8 +100,10 @@ def list_vae_types() -> list[str]:

__all__ = [
"WanVAEWrapper",
"create_vae",
"list_vae_types",
"VAEMetadata",
"VAE_METADATA",
"VAE_REGISTRY",
"DEFAULT_VAE_TYPE",
"create_vae",
"list_vae_types",
]
49 changes: 49 additions & 0 deletions src/scope/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,15 @@ class DownloadModelsRequest(BaseModel):
pipeline_id: str


class VaeStatusResponse(BaseModel):
downloaded: bool


class DownloadVaeRequest(BaseModel):
vae_type: str
model_name: str = "Wan2.1-T2V-1.3B"


class LoRAFileInfo(BaseModel):
"""Metadata for an available LoRA file on disk."""

Expand Down Expand Up @@ -452,6 +461,46 @@ def download_in_background():
raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/api/v1/vae/status", response_model=VaeStatusResponse)
async def get_vae_status(vae_type: str, model_name: str = "Wan2.1-T2V-1.3B"):
"""Check if a VAE file is downloaded."""
try:
from .models_config import vae_file_exists

downloaded = vae_file_exists(vae_type, model_name)
return VaeStatusResponse(downloaded=downloaded)
except Exception as e:
logger.error(f"Error checking VAE status: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/api/v1/vae/download")
async def download_vae(request: DownloadVaeRequest):
"""Download a specific VAE file."""
try:
if not request.vae_type:
raise HTTPException(status_code=400, detail="vae_type is required")

# Download in a background thread to avoid blocking
import threading

from .download_models import download_vae as download_vae_func

def download_in_background():
download_vae_func(request.vae_type, request.model_name)

thread = threading.Thread(target=download_in_background)
thread.daemon = True
thread.start()

return {
"message": f"VAE download started for {request.vae_type} (model: {request.model_name})"
}
except Exception as e:
logger.error(f"Error starting VAE download: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/api/v1/hardware/info", response_model=HardwareInfoResponse)
async def get_hardware_info():
"""Get hardware information including available VRAM."""
Expand Down
Loading
Loading