diff --git a/frontend/src/data/parameterMetadata.ts b/frontend/src/data/parameterMetadata.ts
index 431c584b2..816aa85ee 100644
--- a/frontend/src/data/parameterMetadata.ts
+++ b/frontend/src/data/parameterMetadata.ts
@@ -80,6 +80,6 @@ export const PARAMETER_METADATA: Record = {
   vaeType: {
     label: "VAE:",
     tooltip:
-      "VAE type to use for encoding/decoding. 'wan' is the full VAE with best quality. 'lightvae' is 75% pruned for faster performance but lower quality.",
+      "VAE type to use for encoding/decoding. 'wan' is the full VAE with best quality. 'lightvae' is 75% pruned for faster performance but lower quality. 'tae' is a tiny autoencoder for the fastest performance at preview quality.",
   },
 };
diff --git a/src/scope/core/pipelines/wan2_1/vae/__init__.py b/src/scope/core/pipelines/wan2_1/vae/__init__.py
index ec2cf4797..72f37c063 100644
--- a/src/scope/core/pipelines/wan2_1/vae/__init__.py
+++ b/src/scope/core/pipelines/wan2_1/vae/__init__.py
@@ -21,6 +21,7 @@
 from functools import partial
 
+from .tae import TAEWrapper
 from .wan import WanVAEWrapper
 
 # Registry mapping type names to VAE factory functions
@@ -28,6 +29,7 @@ VAE_REGISTRY: dict[str, type] = {
     "wan": WanVAEWrapper,
     "lightvae": partial(WanVAEWrapper, use_lightvae=True),
+    "tae": TAEWrapper,
 }
 
 DEFAULT_VAE_TYPE = "wan"
@@ -38,7 +40,7 @@ def create_vae(
     model_name: str = "Wan2.1-T2V-1.3B",
     vae_type: str | None = None,
     vae_path: str | None = None,
-) -> WanVAEWrapper:
+) -> WanVAEWrapper | TAEWrapper:
     """Create VAE instance by type.
 
     Args:
@@ -74,6 +76,7 @@ def list_vae_types() -> list[str]:
 __all__ = [
     "WanVAEWrapper",
+    "TAEWrapper",
     "create_vae",
     "list_vae_types",
     "VAE_REGISTRY",
diff --git a/src/scope/core/pipelines/wan2_1/vae/tae.py b/src/scope/core/pipelines/wan2_1/vae/tae.py
new file mode 100644
index 000000000..26114ee2a
--- /dev/null
+++ b/src/scope/core/pipelines/wan2_1/vae/tae.py
@@ -0,0 +1,625 @@
+"""Tiny AutoEncoder (TAE) wrapper for Wan2.1 models.
+
+TAE is a lightweight alternative VAE from the LightX2V project. Unlike WanVAE,
+it is a completely different architecture: a much smaller, faster model
+designed for quick preview encoding/decoding.
+
+Key differences from WanVAE:
+- Uses MemBlock for temporal memory (different from CausalConv3d caching)
+- Has TPool/TGrow blocks for temporal downsampling/upsampling
+- Much simpler architecture with 64 channels throughout the encoder
+- Approximately 4x temporal upscaling in decoder (TGrow blocks expand frames)
+
+Streaming mode:
+- TAE supports streaming decode via parallel processing with persistent MemBlock memory
+- Each batch is processed in parallel (fast) while memory state is maintained across batches
+- This provides both speed AND temporal continuity for smooth streaming
+- First decode call has fewer output frames due to TGrow expansion and frame trimming (3 frames)
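+
+Usage sketch (illustrative, not a tested example; assumes a taew2_1.pth
+checkpoint under wan_models/Wan2.1-T2V-1.3B):
+
+    vae = TAEWrapper(model_dir="wan_models", model_name="Wan2.1-T2V-1.3B")
+    latent = vae.encode_to_latent(pixel, use_cache=True)  # pixel: NCTHW in [-1, 1]
+    frames = vae.decode_to_pixel(latent, use_cache=True)  # frames: NTCHW in [-1, 1]
+    vae.clear_cache()  # reset streaming state before a new sequence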
+"""
+
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from safetensors.torch import load_file
+
+# Note: TAE does NOT use WAN_VAE_LATENT_MEAN/STD - it has its own latent space
+
+# Default checkpoint filename for Wan 2.1 TAE
+DEFAULT_TAE_FILENAME = "taew2_1.pth"
+
+
+def _conv(n_in: int, n_out: int, **kwargs) -> nn.Conv2d:
+    """Create a 3x3 Conv2d with padding."""
+    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+
+class _Clamp(nn.Module):
+    """Clamp activation using scaled tanh."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x / 3) * 3
+
+
+class _MemBlock(nn.Module):
+    """Memory block that combines current input with past state."""
+
+    def __init__(self, n_in: int, n_out: int, act_func: nn.Module):
+        super().__init__()
+        self.conv = nn.Sequential(
+            _conv(n_in * 2, n_out),
+            act_func,
+            _conv(n_out, n_out),
+            act_func,
+            _conv(n_out, n_out),
+        )
+        self.skip = (
+            nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+        )
+        self.act = act_func
+
+    def forward(self, x: torch.Tensor, past: torch.Tensor) -> torch.Tensor:
+        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
+
+
+class _TPool(nn.Module):
+    """Temporal pooling block that combines multiple frames."""
+
+    def __init__(self, n_f: int, stride: int):
+        super().__init__()
+        self.stride = stride
+        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        _NT, C, H, W = x.shape
+        return self.conv(x.reshape(-1, self.stride * C, H, W))
+
+
+class _TGrow(nn.Module):
+    """Temporal growth block that expands to multiple frames."""
+
+    def __init__(self, n_f: int, stride: int):
+        super().__init__()
+        self.stride = stride
+        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        _NT, C, H, W = x.shape
+        x = self.conv(x)
+        return x.reshape(-1, C, H, W)
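+
+
+# Shape sketch (illustrative): with stride=2, _TPool folds pairs of adjacent
+# frames into the channel dim and projects back, halving the frame count:
+#     (N*T, C, H, W) -> reshape -> (N*T/2, 2*C, H, W) -> 1x1 conv -> (N*T/2, C, H, W)
+# _TGrow is the mirror image: a 1x1 conv widens channels by `stride`, and the
+# reshape unfolds them into `stride` output frames per input frame:
+#     (N*T, C, H, W) -> 1x1 conv -> (N*T, stride*C, H, W) -> reshape -> (N*T*stride, C, H, W)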
+
+
+def _apply_model_parallel_streaming(
+    model: nn.Sequential,
+    x: torch.Tensor,
+    N: int,
+    initial_mem: list[torch.Tensor | None] | None = None,
+) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
+    """Apply model in parallel mode with streaming memory support.
+
+    This processes all frames in parallel (fast) while maintaining temporal
+    continuity across batches by using initial memory from the previous batch.
+
+    Args:
+        model: nn.Sequential of blocks to apply
+        x: input data reshaped to (N*T, C, H, W)
+        N: batch size (for reshaping)
+        initial_mem: Initial memory values for each MemBlock (from previous batch).
+            If None, uses zeros for first batch.
+
+    Returns:
+        Tuple of (NTCHW output tensor, list of final memory values for next batch)
+    """
+    # Count MemBlocks for memory initialization
+    num_memblocks = sum(1 for b in model if isinstance(b, _MemBlock))
+
+    # Initialize memory list if not provided
+    if initial_mem is None:
+        initial_mem = [None] * num_memblocks
+
+    # Track which MemBlock we're at
+    mem_idx = 0
+    final_mem = []
+
+    for b in model:
+        if isinstance(b, _MemBlock):
+            NT, C, H, W = x.shape
+            T = NT // N
+            _x = x.reshape(N, T, C, H, W)
+
+            # Create memory: pad with initial_mem at t=0, then shift frames
+            if initial_mem[mem_idx] is not None:
+                # Use previous batch's last frame as initial memory
+                init_frame = initial_mem[mem_idx].reshape(N, 1, C, H, W)
+                mem = torch.cat([init_frame, _x[:, :-1]], dim=1).reshape(x.shape)
+            else:
+                # First batch - use zeros
+                mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(
+                    x.shape
+                )
+
+            # Save last frame for next batch (input before processing)
+            final_mem.append(_x[:, -1:].reshape(N, C, H, W).clone())
+            mem_idx += 1
+
+            x = b(x, mem)
+        else:
+            x = b(x)
+
+    NT, C, H, W = x.shape
+    T = NT // N
+    return x.view(N, T, C, H, W), final_mem
+
+
+def _apply_model_with_memblocks(
+    model: nn.Sequential,
+    x: torch.Tensor,
+    parallel: bool = True,
+    show_progress_bar: bool = False,
+) -> torch.Tensor:
+    """Apply a sequential model with memblocks to the given input (batch mode).
+
+    Args:
+        model: nn.Sequential of blocks to apply
+        x: input data, of dimensions NTCHW
+        parallel: unused, kept for API compatibility (always uses parallel)
+        show_progress_bar: unused, kept for API compatibility
+
+    Returns:
+        NTCHW tensor of output data.
+    """
+    assert x.ndim == 5, f"_apply_model_with_memblocks: TAE expects NTCHW, got {x.ndim}D"
+    N, T, C, H, W = x.shape
+    x = x.reshape(N * T, C, H, W)
+    result, _ = _apply_model_parallel_streaming(model, x, N, initial_mem=None)
+    return result
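+
+
+# Streaming sketch (illustrative; x1/x2/x_full are hypothetical inputs whose
+# chunk lengths are compatible with the TPool/TGrow strides): chunked
+# application with carried memory is meant to line up with one parallel pass
+# over the concatenated input, because each MemBlock sees the previous chunk's
+# last frame as its t=0 memory:
+#
+#     y1, mem = _apply_model_parallel_streaming(model, x1, N, initial_mem=None)
+#     y2, mem = _apply_model_parallel_streaming(model, x2, N, initial_mem=mem)
+#     # torch.cat([y1, y2], dim=1) ~ _apply_model_with_memblocks(model, x_full)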
+
+
+class _TAEModel(nn.Module):
+    """Tiny AutoEncoder model for Wan 2.1.
+
+    This is a lightweight VAE designed for quick previews. It uses a different
+    architecture than the standard WanVAE, with MemBlocks for temporal processing.
+
+    Supports two decode modes:
+    - Batch mode (decode_video): Process all frames at once
+    - Streaming mode (stream_decode): Process frames incrementally with persistent memory
+    """
+
+    def __init__(
+        self,
+        checkpoint_path: str | None = None,
+        decoder_time_upscale: tuple[bool, bool] = (True, True),
+        decoder_space_upscale: tuple[bool, bool, bool] = (True, True, True),
+        patch_size: int = 1,
+        latent_channels: int = 16,
+    ):
+        """Initialize TAE model.
+
+        Args:
+            checkpoint_path: Path to weight file (.pth or .safetensors)
+            decoder_time_upscale: Whether temporal upsampling is enabled for each block
+            decoder_space_upscale: Whether spatial upsampling is enabled for each block
+            patch_size: Input/output pixelshuffle patch-size (1 for Wan 2.1)
+            latent_channels: Number of latent channels (16 for Wan 2.1)
+        """
+        super().__init__()
+        self.patch_size = patch_size
+        self.latent_channels = latent_channels
+        self.image_channels = 3
+
+        # Wan 2.1 uses ReLU activation
+        act_func = nn.ReLU(inplace=True)
+
+        # Encoder: 64 channels throughout, simple architecture
+        self.encoder = nn.Sequential(
+            _conv(self.image_channels * self.patch_size**2, 64),
+            act_func,
+            _TPool(64, 2),
+            _conv(64, 64, stride=2, bias=False),
+            _MemBlock(64, 64, act_func),
+            _MemBlock(64, 64, act_func),
+            _MemBlock(64, 64, act_func),
+            _TPool(64, 2),
+            _conv(64, 64, stride=2, bias=False),
+            _MemBlock(64, 64, act_func),
+            _MemBlock(64, 64, act_func),
+            _MemBlock(64, 64, act_func),
+            _TPool(64, 1),
+            _conv(64, 64, stride=2, bias=False),
+            _MemBlock(64, 64, act_func),
+            _MemBlock(64, 64, act_func),
+            _MemBlock(64, 64, act_func),
+            _conv(64, self.latent_channels),
+        )
+
+        # Decoder with configurable upscaling
+        n_f = [256, 128, 64, 64]
+        self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
+        self._decoder_time_upscale = decoder_time_upscale
+
+        self.decoder = nn.Sequential(
+            _Clamp(),
+            _conv(self.latent_channels, n_f[0]),
+            act_func,
+            _MemBlock(n_f[0], n_f[0], act_func),
+            _MemBlock(n_f[0], n_f[0], act_func),
+            _MemBlock(n_f[0], n_f[0], act_func),
+            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
+            _TGrow(n_f[0], 1),
+            _conv(n_f[0], n_f[1], bias=False),
+            _MemBlock(n_f[1], n_f[1], act_func),
+            _MemBlock(n_f[1], n_f[1], act_func),
+            _MemBlock(n_f[1], n_f[1], act_func),
+            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
+            _TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
+            _conv(n_f[1], n_f[2], bias=False),
+            _MemBlock(n_f[2], n_f[2], act_func),
+            _MemBlock(n_f[2], n_f[2], act_func),
+            _MemBlock(n_f[2], n_f[2], act_func),
+            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
+            _TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
+            _conv(n_f[2], n_f[3], bias=False),
+            act_func,
+            _conv(n_f[3], self.image_channels * self.patch_size**2),
+        )
+
+        # Streaming state for parallel streaming encode/decode
+        self._encoder_mem: list[torch.Tensor | None] | None = None
+        self._decoder_mem: list[torch.Tensor | None] | None = None
+        self._frames_output: int = 0  # Track output frames for trim handling
+
+        if checkpoint_path is not None:
+            ext = os.path.splitext(checkpoint_path)[1].lower()
+            if ext == ".pth":
+                state_dict = torch.load(
+                    checkpoint_path, map_location="cpu", weights_only=True
+                )
+            elif ext == ".safetensors":
+                state_dict = load_file(checkpoint_path, device="cpu")
+            else:
+                raise ValueError(
+                    f"_TAEModel.__init__: Unsupported checkpoint format: {ext}. "
+                    "Supported: .pth, .safetensors"
+                )
+            self.load_state_dict(self._patch_tgrow_layers(state_dict))
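+
+    # Trim arithmetic (illustrative note): with both decoder_time_upscale flags
+    # enabled, the TGrow stages expand T latent frames by 1 * 2 * 2 = 4x, and
+    # frames_to_trim = 2**2 - 1 = 3 warmup frames are dropped, so a fresh
+    # sequence decodes T latent frames to 4*T - 3 output frames (matching
+    # WanVAE's frame mapping).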
+
+    def _patch_tgrow_layers(self, sd: dict) -> dict:
+        """Patch TGrow layers to use a smaller kernel if needed.
+
+        Args:
+            sd: state dict to patch
+
+        Returns:
+            Patched state dict
+        """
+        new_sd = self.state_dict()
+        for i, layer in enumerate(self.decoder):
+            if isinstance(layer, _TGrow):
+                key = f"decoder.{i}.conv.weight"
+                if sd[key].shape[0] > new_sd[key].shape[0]:
+                    # Take the last-timestep output channels
+                    sd[key] = sd[key][-new_sd[key].shape[0] :]
+        return sd
+
+    def clear_decode_state(self):
+        """Clear decoder streaming state for a new sequence."""
+        self._decoder_mem = None
+        self._frames_output = 0
+
+    def clear_encode_state(self):
+        """Clear encoder streaming state for a new sequence."""
+        self._encoder_mem = None
+
+    def stream_encode(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Encode frames in streaming mode with persistent memory.
+
+        This uses parallel processing within each batch for speed, while maintaining
+        MemBlock memory across batches for smooth temporal continuity at chunk
+        boundaries.
+
+        Unlike encode_video, this maintains state across calls.
+        Call clear_encode_state() before a new sequence.
+
+        Args:
+            x: input NTCHW RGB (C=3) tensor with values in [0, 1]
+
+        Returns:
+            NTCHW latent tensor with approximately Gaussian values
+        """
+        if self.patch_size > 1:
+            x = F.pixel_unshuffle(x, self.patch_size)
+        if x.shape[1] % 4 != 0:
+            # Pad at end to multiple of 4
+            n_pad = 4 - x.shape[1] % 4
+            padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+            x = torch.cat([x, padding], 1)
+
+        N, T, C, H, W = x.shape
+        x_flat = x.reshape(N * T, C, H, W)
+
+        result, self._encoder_mem = _apply_model_parallel_streaming(
+            self.encoder,
+            x_flat,
+            N,
+            initial_mem=self._encoder_mem,
+        )
+
+        return result
+
+    def encode_video(
+        self,
+        x: torch.Tensor,
+        parallel: bool = True,
+        show_progress_bar: bool = False,
+    ) -> torch.Tensor:
+        """Encode a sequence of frames.
+
+        Args:
+            x: input NTCHW RGB (C=3) tensor with values in [0, 1]
+            parallel: unused, kept for API compatibility (frames are always
+                processed in parallel)
+            show_progress_bar: unused, kept for API compatibility
+
+        Returns:
+            NTCHW latent tensor with approximately Gaussian values
+        """
+        if self.patch_size > 1:
+            x = F.pixel_unshuffle(x, self.patch_size)
+        if x.shape[1] % 4 != 0:
+            # Pad at end to multiple of 4
+            n_pad = 4 - x.shape[1] % 4
+            padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+            x = torch.cat([x, padding], 1)
+        return _apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
+
+    def decode_video(
+        self,
+        x: torch.Tensor,
+        parallel: bool = True,
+        show_progress_bar: bool = False,
+    ) -> torch.Tensor:
+        """Decode a sequence of frames (batch mode).
+
+        Args:
+            x: input NTCHW latent tensor with approximately Gaussian values
+            parallel: unused, kept for API compatibility (frames are always
+                processed in parallel)
+            show_progress_bar: unused, kept for API compatibility
+
+        Returns:
+            NTCHW RGB tensor with values clamped to [0, 1]
+        """
+        x = _apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
+        x = x.clamp_(0, 1)
+        if self.patch_size > 1:
+            x = F.pixel_shuffle(x, self.patch_size)
+        return x[:, self.frames_to_trim :]
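+
+    # Decode sketch (illustrative; `z` is a hypothetical NTCHW latent): for a
+    # fresh sequence, chunked streaming decode is intended to line up with one
+    # batch decode:
+    #
+    #     model.clear_decode_state()
+    #     a = model.stream_decode(z[:, :2])  # warmup chunk, 3 frames trimmed
+    #     b = model.stream_decode(z[:, 2:])
+    #     # torch.cat([a, b], dim=1) ~ model.decode_video(z)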
+
+    def stream_decode(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Decode frames in streaming mode with persistent memory.
+
+        This uses parallel processing within each batch for speed, while maintaining
+        MemBlock memory across batches for smooth temporal continuity.
+
+        On the first batch, frames are processed sequentially (first frame separately,
+        then remaining frames) to match WanVAE warmup behavior for better temporal
+        consistency.
+
+        Unlike decode_video, this maintains state across calls.
+        Call clear_decode_state() before a new sequence.
+
+        Args:
+            x: input NTCHW latent tensor (typically 1-4 frames at a time)
+
+        Returns:
+            NTCHW RGB tensor with values in [0, 1].
+            First call returns fewer frames due to temporal trim.
+        """
+        N, T, C, H, W = x.shape
+
+        # First batch warmup: process first frame separately, then remaining frames
+        # This matches WanVAE's warmup behavior for better temporal consistency
+        if self._frames_output == 0:
+            # Clear decoder memory state for first batch
+            self._decoder_mem = None
+
+            # Process first frame separately
+            first_frame = x[:, :1, :, :, :]  # [N, 1, C, H, W]
+            first_flat = first_frame.reshape(N * 1, C, H, W)
+            first_result, first_mem = _apply_model_parallel_streaming(
+                self.decoder,
+                first_flat,
+                N,
+                initial_mem=None,  # Use zeros for first frame
+            )
+
+            # Process remaining frames if any
+            if T > 1:
+                remaining_frames = x[:, 1:, :, :, :]  # [N, T-1, C, H, W]
+                remaining_flat = remaining_frames.reshape(N * (T - 1), C, H, W)
+                remaining_result, self._decoder_mem = _apply_model_parallel_streaming(
+                    self.decoder,
+                    remaining_flat,
+                    N,
+                    initial_mem=first_mem,  # Use memory from first frame
+                )
+                # Concatenate first frame and remaining frames
+                result = torch.cat([first_result, remaining_result], dim=1)
+            else:
+                # Only one frame
+                result = first_result
+                self._decoder_mem = first_mem
+        else:
+            # Subsequent batches: use parallel processing with persistent memory
+            x_flat = x.reshape(N * T, C, H, W)
+            result, self._decoder_mem = _apply_model_parallel_streaming(
+                self.decoder,
+                x_flat,
+                N,
+                initial_mem=self._decoder_mem,
+            )
+
+        result = result.clamp_(0, 1)
+
+        if self.patch_size > 1:
+            result = F.pixel_shuffle(result, self.patch_size)
+
+        # Handle temporal trim - only trim on first output
+        if self._frames_output == 0 and result.shape[1] > self.frames_to_trim:
+            result = result[:, self.frames_to_trim :]
+
+        self._frames_output += result.shape[1]
+
+        return result
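+
+
+# Sanity sketch (illustrative numbers): with the default config, a fresh
+# sequence of 4 latent frames at 16x60x104 decodes to 13 RGB frames at
+# 480x832: the three 2x Upsample stages give 8x spatial upscale, the TGrow
+# strides (1, 2, 2) give 4x temporal upscale (16 frames), and trimming
+# frames_to_trim = 3 leaves 13.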
+
+
+class TAEWrapper(nn.Module):
+    """TAE wrapper with interface matching WanVAEWrapper.
+
+    This provides a consistent interface for the Tiny AutoEncoder that matches
+    the WanVAEWrapper's encode_to_latent/decode_to_pixel/clear_cache API.
+
+    Note: TAE is a lightweight preview encoder with its own latent space. It does
+    NOT use WanVAE's normalization constants - TAE produces approximately Gaussian
+    latents directly. Quality may be lower than WanVAE but encoding/decoding is faster.
+
+    Streaming mode (use_cache=True):
+        TAE maintains persistent MemBlock memory for smooth frame-by-frame streaming.
+        This is faster than batch mode for real-time applications since it processes
+        smaller chunks while maintaining temporal continuity.
+
+    Batch mode (use_cache=False):
+        Processes all frames at once without persistent state. Good for one-shot
+        encoding/decoding of complete videos.
+
+    Args:
+        model_dir: Base directory containing model files
+        model_name: Model subdirectory name (e.g., "Wan2.1-T2V-1.3B")
+        vae_path: Explicit path to TAE checkpoint (overrides model_dir/model_name)
+    """
+
+    def __init__(
+        self,
+        model_dir: str = "wan_models",
+        model_name: str = "Wan2.1-T2V-1.3B",
+        vae_path: str | None = None,
+    ):
+        super().__init__()
+
+        # Determine checkpoint path
+        if vae_path is None:
+            vae_path = os.path.join(model_dir, model_name, DEFAULT_TAE_FILENAME)
+
+        self.z_dim = 16
+
+        # Create TAE model
+        self.model = (
+            _TAEModel(
+                checkpoint_path=vae_path,
+                patch_size=1,
+                latent_channels=self.z_dim,
+            )
+            .eval()
+            .requires_grad_(False)
+        )
+
+        # Track first-call state separately for streaming encode and decode.
+        # A single shared flag would let the first decode call suppress the
+        # encoder's first-call reset (and vice versa).
+        self._first_encode = True
+        self._first_decode = True
+
+    def encode_to_latent(
+        self,
+        pixel: torch.Tensor,
+        use_cache: bool = True,
+        feat_cache: list | None = None,
+    ) -> torch.Tensor:
+        """Encode video pixels to latents.
+
+        Args:
+            pixel: Input video tensor [batch, channels, frames, height, width]
+            use_cache: If True, use streaming encode with persistent memory.
+                If False, use batch encode (no persistent state).
+            feat_cache: Unused (kept for interface compatibility with WanVAEWrapper)
+
+        Returns:
+            Latent tensor [batch, frames, channels, height, width]
+
+        Note:
+            TAE produces approximately Gaussian latents directly without additional
+            normalization. The latent space is similar to but not identical to WanVAE.
+
+            In streaming mode (use_cache=True), TAE maintains MemBlock state across
+            calls for smooth temporal continuity at chunk boundaries.
+        """
+        # [batch, channels, frames, h, w] -> [batch, frames, channels, h, w] for TAE
+        pixel_ntchw = pixel.permute(0, 2, 1, 3, 4)
+
+        # Scale from [-1, 1] to [0, 1] range expected by TAE
+        pixel_ntchw = (pixel_ntchw + 1) / 2
+
+        if use_cache:
+            # Streaming mode - use parallel processing with persistent memory
+            if self._first_encode:
+                self.model.clear_encode_state()
+                self._first_encode = False
+
+            latent = self.model.stream_encode(pixel_ntchw)
+        else:
+            # Batch mode - no persistent state
+            latent = self.model.encode_video(
+                pixel_ntchw, parallel=True, show_progress_bar=False
+            )
+
+        # Return in [batch, frames, channels, h, w] format
+        return latent
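+
+    # Layout sketch (illustrative shapes): encode_to_latent consumes NCTHW
+    # pixels in [-1, 1] and returns NTCHW latents; with the encoder's 4x
+    # temporal and 8x spatial reduction, e.g.:
+    #     pixel (1, 3, 8, 480, 832) -> latent (1, 2, 16, 60, 104)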
+ """ + if use_cache: + # Streaming mode - use parallel processing with persistent memory + if self._first_batch: + self.model.clear_decode_state() + self._first_batch = False + + output = self.model.stream_decode(latent) + else: + # Batch mode - no persistent state + output = self.model.decode_video( + latent, parallel=True, show_progress_bar=False + ) + + # Scale from [0, 1] to [-1, 1] range + output = output * 2 - 1 + output = output.clamp_(-1, 1) + + # Return in [batch, frames, channels, h, w] format + return output + + def clear_cache(self): + """Clear state for next sequence.""" + self._first_batch = True + self.model.clear_encode_state() + self.model.clear_decode_state() diff --git a/src/scope/server/schema.py b/src/scope/server/schema.py index 31f1adccd..df2ad9467 100644 --- a/src/scope/server/schema.py +++ b/src/scope/server/schema.py @@ -14,7 +14,7 @@ from scope.core.pipelines.wan2_1.vae import DEFAULT_VAE_TYPE # VAE type literal based on available VAE types -VaeType = Literal["wan", "lightvae"] +VaeType = Literal["wan", "lightvae", "tae"] class HealthResponse(BaseModel):