diff --git a/docs/PORTING-GUIDE.md b/docs/PORTING-GUIDE.md index 5c44f94..e17ab3e 100644 --- a/docs/PORTING-GUIDE.md +++ b/docs/PORTING-GUIDE.md @@ -847,7 +847,7 @@ python scripts/video/video_quality.py output.mp4 --chunk-size 32 python scripts/video/compare_videos.py reference.mp4 output.mp4 --diff-video diff.mp4 ``` -### Model-specific diagnostics (`scripts/helios/`) +### Model-specific diagnostics (`mlx_video/models/helios/scripts/`) | Script | Purpose | |--------|---------| diff --git a/examples/poodles_helios.gif b/examples/poodles_helios.gif new file mode 100644 index 0000000..0ded8b7 Binary files /dev/null and b/examples/poodles_helios.gif differ diff --git a/mlx_video/convert_helios.py b/mlx_video/convert_helios.py new file mode 100644 index 0000000..57b30ed --- /dev/null +++ b/mlx_video/convert_helios.py @@ -0,0 +1,717 @@ +"""Weight conversion for Helios models (HuggingFace diffusers → MLX).""" + +import gc +import json +import logging +import math +from pathlib import Path +from typing import Dict + +import mlx.core as mx +import mlx.utils + +logger = logging.getLogger(__name__) + + +def sanitize_helios_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Convert HuggingFace Helios diffusers weight keys to MLX structure. 
+ + HF diffusers keys → MLX keys: + patch_embedding (Conv3d) → patch_embedding (Linear, reshaped) + patch_short/mid/long (Conv3d) → patch_short/mid/long (Linear, reshaped) + condition_embedder.time_embedder.linear_1/2 → time_embedding_0/1 + condition_embedder.time_proj → time_projection + condition_embedder.text_embedder.linear_1/2 → text_embedding_0/1 + blocks.{i}.attn1.to_q/k/v → blocks.{i}.self_attn.q/k/v + blocks.{i}.attn1.to_out.0 → blocks.{i}.self_attn.o + blocks.{i}.attn1.norm_q/k → blocks.{i}.self_attn.norm_q/k + blocks.{i}.attn2.to_q/k/v → blocks.{i}.cross_attn.q/k/v + blocks.{i}.attn2.to_out.0 → blocks.{i}.cross_attn.o + blocks.{i}.attn2.norm_q/k → blocks.{i}.cross_attn.norm_q/k + blocks.{i}.ffn.net.0.proj → blocks.{i}.ffn.fc1 + blocks.{i}.ffn.net.2 → blocks.{i}.ffn.fc2 + blocks.{i}.norm1/2/3 → blocks.{i}.norm1/2/3 + blocks.{i}.scale_shift_table → blocks.{i}.scale_shift_table + norm_out.norm → output_norm + norm_out.scale_shift_table → output_norm_table + proj_out → proj_out + """ + sanitized = {} + consumed = set() + + for key, value in weights.items(): + new_key = key + + # Conv3d patch embeddings: [O, I, D, H, W] → Linear [O, I*D*H*W] + if key in ("patch_embedding.weight", "patch_short.weight", + "patch_mid.weight", "patch_long.weight"): + if value.ndim == 5: + value = value.reshape(value.shape[0], -1) + sanitized[key] = value + consumed.add(key) + continue + if key in ("patch_embedding.bias", "patch_short.bias", + "patch_mid.bias", "patch_long.bias"): + sanitized[key] = value + consumed.add(key) + continue + + # condition_embedder.time_embedder → time_embedding + if key.startswith("condition_embedder.time_embedder.linear_1."): + suffix = key.split("condition_embedder.time_embedder.linear_1.")[-1] + new_key = f"time_embedding_0.{suffix}" + sanitized[new_key] = value + consumed.add(key) + continue + if key.startswith("condition_embedder.time_embedder.linear_2."): + suffix = key.split("condition_embedder.time_embedder.linear_2.")[-1] + new_key = 
f"time_embedding_1.{suffix}" + sanitized[new_key] = value + consumed.add(key) + continue + + # condition_embedder.time_proj → time_projection + if key.startswith("condition_embedder.time_proj."): + suffix = key.split("condition_embedder.time_proj.")[-1] + new_key = f"time_projection.{suffix}" + sanitized[new_key] = value + consumed.add(key) + continue + + # condition_embedder.text_embedder → text_embedding + if key.startswith("condition_embedder.text_embedder.linear_1."): + suffix = key.split("condition_embedder.text_embedder.linear_1.")[-1] + new_key = f"text_embedding_0.{suffix}" + sanitized[new_key] = value + consumed.add(key) + continue + if key.startswith("condition_embedder.text_embedder.linear_2."): + suffix = key.split("condition_embedder.text_embedder.linear_2.")[-1] + new_key = f"text_embedding_1.{suffix}" + sanitized[new_key] = value + consumed.add(key) + continue + + # norm_out.norm → output_norm + if key.startswith("norm_out.norm."): + suffix = key.split("norm_out.norm.")[-1] + new_key = f"output_norm.{suffix}" + sanitized[new_key] = value + consumed.add(key) + continue + + # norm_out.scale_shift_table → output_norm_table + if key == "norm_out.scale_shift_table": + sanitized["output_norm_table"] = value + consumed.add(key) + continue + + # Attention: attn1 → self_attn, attn2 → cross_attn + new_key = new_key.replace(".attn1.", ".self_attn.") + new_key = new_key.replace(".attn2.", ".cross_attn.") + + # to_q/k/v → q/k/v, to_out.0 → o + new_key = new_key.replace(".to_q.", ".q.") + new_key = new_key.replace(".to_k.", ".k.") + new_key = new_key.replace(".to_v.", ".v.") + new_key = new_key.replace(".to_out.0.", ".o.") + + # FFN: net.0.proj → fc1, net.2 → fc2 + new_key = new_key.replace(".ffn.net.0.proj.", ".ffn.fc1.") + new_key = new_key.replace(".ffn.net.2.", ".ffn.fc2.") + + # Skip timesteps_proj (no trainable weights) and dropout layers + if "timesteps_proj" in key or "to_out.1." 
in key: + consumed.add(key) + continue + + # Skip rope buffers (computed in model) + if key.startswith("rope."): + consumed.add(key) + continue + + # Skip unused keys + if "norm_added_q" in key or "norm_added_k" in key: + consumed.add(key) + continue + if "add_k_proj" in key or "add_v_proj" in key: + consumed.add(key) + continue + + sanitized[new_key] = value + consumed.add(key) + + unconsumed = set(weights.keys()) - consumed + if unconsumed: + logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed)) + + return sanitized + + +def _quantize_predicate(path: str, module) -> bool: + """Return True for layers that should be quantized. + + Targets heavyweight Linear layers in attention and FFN blocks. + Skips embeddings, norms, head, and modulation (small, precision-sensitive). + """ + if not hasattr(module, "to_quantized"): + return False + quantize_patterns = ( + ".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o", + ".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o", + ".ffn.fc1", ".ffn.fc2", + ) + return any(path.endswith(p) for p in quantize_patterns) + + +def sanitize_helios_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Convert HuggingFace UMT5 encoder weight keys to MLX T5Encoder structure. 
+ + HF UMT5 keys → MLX keys: + shared.weight / encoder.embed_tokens.weight → token_embedding.weight + encoder.final_layer_norm.weight → norm.weight + encoder.block.{i}.layer.0.SelfAttention.q.weight → blocks.{i}.attn.q.weight + encoder.block.{i}.layer.0.SelfAttention.k.weight → blocks.{i}.attn.k.weight + encoder.block.{i}.layer.0.SelfAttention.v.weight → blocks.{i}.attn.v.weight + encoder.block.{i}.layer.0.SelfAttention.o.weight → blocks.{i}.attn.o.weight + encoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight + → blocks.{i}.pos_embedding.embedding.weight + encoder.block.{i}.layer.0.layer_norm.weight → blocks.{i}.norm1.weight + encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight → blocks.{i}.ffn.gate_proj.weight + encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight → blocks.{i}.ffn.fc1.weight + encoder.block.{i}.layer.1.DenseReluDense.wo.weight → blocks.{i}.ffn.fc2.weight + encoder.block.{i}.layer.1.layer_norm.weight → blocks.{i}.norm2.weight + """ + sanitized = {} + + for key, value in weights.items(): + # Token embedding + if key in ("shared.weight", "encoder.embed_tokens.weight"): + sanitized["token_embedding.weight"] = value + continue + + # Final layer norm + if key == "encoder.final_layer_norm.weight": + sanitized["norm.weight"] = value + continue + + # Skip decoder keys if present + if key.startswith("decoder.") or key.startswith("lm_head."): + continue + + # Block-level mappings: encoder.block.{i}.layer.{j}.* + if key.startswith("encoder.block."): + new_key = key + # encoder.block.{i} → blocks.{i} + new_key = new_key.replace("encoder.block.", "blocks.") + + # Self-attention (layer.0.SelfAttention) + new_key = new_key.replace(".layer.0.SelfAttention.", ".attn.") + new_key = new_key.replace( + ".attn.relative_attention_bias.weight", + ".pos_embedding.embedding.weight", + ) + + # Layer norms + new_key = new_key.replace(".layer.0.layer_norm.", ".norm1.") + new_key = new_key.replace(".layer.1.layer_norm.", ".norm2.") + + # FFN 
(layer.1.DenseReluDense) + new_key = new_key.replace(".layer.1.DenseReluDense.wi_0.", ".ffn.gate_proj.") + new_key = new_key.replace(".layer.1.DenseReluDense.wi_1.", ".ffn.fc1.") + new_key = new_key.replace(".layer.1.DenseReluDense.wo.", ".ffn.fc2.") + + sanitized[new_key] = value + continue + + logger.warning("Unhandled T5 key: %s", key) + + return sanitized + + +def sanitize_helios_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Convert HF diffusers AutoencoderKLWan keys to MLX WanVAE keys. + + Handles key renaming and Conv3d/Conv2d weight transpositions. + + Key mapping: + post_quant_conv → conv2 + quant_conv → conv1 + decoder.conv_in → decoder.conv1 + decoder.conv_out → decoder.head.2 + decoder.norm_out → decoder.head.0 + decoder.mid_block.resnets.{i} → decoder.middle.{i*2} + decoder.mid_block.attentions.0 → decoder.middle.1 + decoder.up_blocks.{b}.resnets.{r} → decoder.upsamples.{flat} + decoder.up_blocks.{b}.upsamplers.0 → decoder.upsamples.{flat} + Within resnets: norm1→residual.0, conv1→residual.2, + norm2→residual.3, conv2→residual.6, + conv_shortcut→shortcut + encoder.* keys are skipped (decoder-only loading) + """ + # The WanVAE decoder.upsamples is a flat Sequential. 
+ # For dim_mult=[1,2,4,4], num_res_blocks=2 (→3 resnets/block): + # up_block 0: upsamples 0-2 (resnets) + 3 (upsampler with time_conv) + # up_block 1: upsamples 4-6 (resnets) + 7 (upsampler with time_conv) + # up_block 2: upsamples 8-10 (resnets) + 11 (upsampler, no time_conv) + # up_block 3: upsamples 12-14 (resnets only, no upsampler) + block_offsets = {0: 0, 1: 4, 2: 8, 3: 12} + upsampler_offsets = {0: 3, 1: 7, 2: 11} + + # Resnet sub-key mapping + resnet_map = { + "norm1": "residual.0", + "conv1": "residual.2", + "norm2": "residual.3", + "conv2": "residual.6", + "conv_shortcut": "shortcut", + } + + sanitized = {} + latents_mean = None + latents_std = None + + for key, value in weights.items(): + # Transpose Conv3d [O, I, D, H, W] → MLX [O, D, H, W, I] + if "weight" in key and value.ndim == 5: + value = mx.transpose(value, (0, 2, 3, 4, 1)) + # Transpose Conv2d [O, I, H, W] → MLX [O, H, W, I] + if "weight" in key and value.ndim == 4: + value = mx.transpose(value, (0, 2, 3, 1)) + + # Skip encoder keys + if key.startswith("encoder."): + continue + + # Top-level convolutions + if key.startswith("post_quant_conv."): + new_key = key.replace("post_quant_conv.", "conv2.") + sanitized[new_key] = value + continue + if key.startswith("quant_conv."): + new_key = key.replace("quant_conv.", "conv1.") + sanitized[new_key] = value + continue + + # Decoder conv_in → conv1 + if key.startswith("decoder.conv_in."): + new_key = key.replace("decoder.conv_in.", "decoder.conv1.") + sanitized[new_key] = value + continue + + # Decoder conv_out → head.2 + if key.startswith("decoder.conv_out."): + new_key = key.replace("decoder.conv_out.", "decoder.head.2.") + sanitized[new_key] = value + continue + + # Decoder norm_out → head.0 + if key.startswith("decoder.norm_out."): + new_key = key.replace("decoder.norm_out.", "decoder.head.0.") + sanitized[new_key] = value + continue + + # Mid block resnets + if key.startswith("decoder.mid_block.resnets."): + rest = 
key[len("decoder.mid_block.resnets."):] + # rest = "{i}.{subkey}" + parts = rest.split(".", 1) + resnet_idx = int(parts[0]) + subkey = parts[1] # e.g. "norm1.gamma" or "conv1.weight" + mid_idx = resnet_idx * 2 # resnets 0→middle.0, 1→middle.2 + + sub_prefix = subkey.split(".")[0] + if sub_prefix in resnet_map: + mapped = resnet_map[sub_prefix] + sub_rest = subkey[len(sub_prefix):] + new_key = f"decoder.middle.{mid_idx}.{mapped}{sub_rest}" + else: + new_key = f"decoder.middle.{mid_idx}.{subkey}" + sanitized[new_key] = value + continue + + # Mid block attention + if key.startswith("decoder.mid_block.attentions.0."): + rest = key[len("decoder.mid_block.attentions.0."):] + new_key = f"decoder.middle.1.{rest}" + sanitized[new_key] = value + continue + + # Up blocks + if key.startswith("decoder.up_blocks."): + rest = key[len("decoder.up_blocks."):] + # Parse block index + parts = rest.split(".", 1) + block_idx = int(parts[0]) + sub = parts[1] + + if sub.startswith("resnets."): + resnet_rest = sub[len("resnets."):] + rparts = resnet_rest.split(".", 1) + resnet_idx = int(rparts[0]) + subkey = rparts[1] + flat_idx = block_offsets[block_idx] + resnet_idx + + sub_prefix = subkey.split(".")[0] + if sub_prefix in resnet_map: + mapped = resnet_map[sub_prefix] + sub_rest = subkey[len(sub_prefix):] + new_key = f"decoder.upsamples.{flat_idx}.{mapped}{sub_rest}" + else: + new_key = f"decoder.upsamples.{flat_idx}.{subkey}" + sanitized[new_key] = value + continue + + if sub.startswith("upsamplers.0."): + upsampler_rest = sub[len("upsamplers.0."):] + flat_idx = upsampler_offsets[block_idx] + new_key = f"decoder.upsamples.{flat_idx}.{upsampler_rest}" + sanitized[new_key] = value + continue + + logger.debug("Skipped VAE key: %s", key) + + # Add latent statistics as buffers + # WanVAE expects 'mean', 'std', 'inv_std' as registered buffers + # Read from config if available, otherwise use zeros + return sanitized + + +def convert_helios_checkpoint( + checkpoint_dir: str, + output_dir: str, 
+ dtype: str = "bfloat16", + quantize: bool = False, + bits: int = 4, + group_size: int = 64, +): + """Convert a HuggingFace Helios checkpoint to MLX format. + + Expected structure (HuggingFace diffusers format): + checkpoint_dir/ + transformer/ + diffusion_pytorch_model*.safetensors + config.json + text_encoder/ + model*.safetensors + vae/ + diffusion_pytorch_model*.safetensors + tokenizer/ + ... + + Args: + checkpoint_dir: Path to HF Helios checkpoint + output_dir: Path to output MLX model directory + dtype: Target dtype for transformer/text-encoder weights + quantize: Whether to quantize the transformer + bits: Quantization bits (4 or 8) + group_size: Quantization group size + """ + from mlx_video.convert_wan import ( + load_safetensors_weights, + ) + + checkpoint_dir = Path(checkpoint_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + dtype_map = { + "float16": mx.float16, + "float32": mx.float32, + "bfloat16": mx.bfloat16, + } + target_dtype = dtype_map.get(dtype, mx.bfloat16) + + # 1. 
Convert transformer + transformer_dir = checkpoint_dir / "transformer" + if transformer_dir.exists(): + print("Converting Helios transformer...") + weights = load_safetensors_weights(str(transformer_dir)) + weights = sanitize_helios_transformer_weights(weights) + weights = {k: v.astype(target_dtype) for k, v in weights.items()} + out_path = output_dir / "model.safetensors" + mx.save_safetensors(str(out_path), weights) + print(f" Saved {len(weights)} weight tensors to {out_path}") + del weights + gc.collect() + else: + # Try loading safetensors directly from checkpoint_dir + print("Converting Helios transformer (flat directory)...") + weights = load_safetensors_weights(str(checkpoint_dir)) + if weights: + weights = sanitize_helios_transformer_weights(weights) + weights = {k: v.astype(target_dtype) for k, v in weights.items()} + out_path = output_dir / "model.safetensors" + mx.save_safetensors(str(out_path), weights) + print(f" Saved {len(weights)} weight tensors to {out_path}") + del weights + gc.collect() + + # 2. 
Save config + from mlx_video.models.helios.config import HeliosModelConfig + + config = HeliosModelConfig.helios_distilled() + config_path = output_dir / "config.json" + + # Try to read source config for any overrides + src_config_path = transformer_dir / "config.json" if transformer_dir.exists() else None + if src_config_path and src_config_path.exists(): + with open(src_config_path) as f: + src_cfg = json.load(f) + # Could update config from src_cfg if needed + print(f" Source config found: {src_config_path}") + + config_dict = { + "model_type": "helios_distilled", + "dim": config.dim, + "ffn_dim": config.ffn_dim, + "num_heads": config.num_heads, + "num_layers": config.num_layers, + "patch_size": list(config.patch_size), + "in_dim": config.in_dim, + "out_dim": config.out_dim, + "text_dim": config.text_dim, + "freq_dim": config.freq_dim, + "text_len": config.text_len, + "rope_dim": list(config.rope_dim), + "rope_theta": config.rope_theta, + "history_sizes": config.history_sizes, + "num_latent_frames_per_chunk": config.num_latent_frames_per_chunk, + "vae_z_dim": config.vae_z_dim, + "vae_stride": list(config.vae_stride), + "shift": config.shift, + "sample_fps": config.sample_fps, + } + with open(config_path, "w") as f: + json.dump(config_dict, f, indent=2) + print(f" Saved config to {config_path}") + + # 3. Convert text encoder (T5) + t5_dir = checkpoint_dir / "text_encoder" + if t5_dir.exists(): + print("Converting T5 text encoder...") + weights = load_safetensors_weights(str(t5_dir)) + weights = sanitize_helios_t5_weights(weights) + weights = {k: v.astype(target_dtype) for k, v in weights.items()} + out_path = output_dir / "t5_encoder.safetensors" + mx.save_safetensors(str(out_path), weights) + print(f" Saved {len(weights)} weight tensors to {out_path}") + del weights + gc.collect() + + # 4. 
Convert VAE + vae_dir = checkpoint_dir / "vae" + if vae_dir.exists(): + print("Converting VAE...") + weights = load_safetensors_weights(str(vae_dir)) + weights = sanitize_helios_vae_weights(weights) + # VAE in float32 for quality + weights = {k: v.astype(mx.float32) for k, v in weights.items()} + out_path = output_dir / "vae.safetensors" + mx.save_safetensors(str(out_path), weights) + print(f" Saved {len(weights)} weight tensors to {out_path} (float32)") + del weights + gc.collect() + + # 5. Quantize if requested + if quantize: + print(f"\nQuantizing transformer ({bits}-bit, group_size={group_size})...") + _quantize_saved_model(output_dir, config, bits, group_size) + + print(f"\nConversion complete! Output: {output_dir}") + + +def _quantize_saved_model( + output_dir: Path, + config, + bits: int, + group_size: int, + source_dir: Path | None = None, +): + """Load saved bf16 model, quantize, and re-save. + + Args: + output_dir: Directory to save quantized weights + config: HeliosModelConfig + bits: Quantization bits + group_size: Quantization group size + source_dir: If provided, load weights from here instead of output_dir + """ + import mlx.nn as nn + from mlx_video.models.helios.transformer import HeliosModel + + load_dir = source_dir if source_dir is not None else output_dir + model_path = load_dir / "model.safetensors" + if not model_path.exists(): + print(" No model.safetensors found, skipping quantization") + return + + model = HeliosModel(config) + weights = mx.load(str(model_path)) + model.load_weights(list(weights.items()), strict=False) + mx.eval(model.parameters()) + del weights + gc.collect() + mx.clear_cache() + + nn.quantize( + model, + group_size=group_size, + bits=bits, + class_predicate=lambda path, m: _quantize_predicate(path, m), + ) + + weights_dict = dict(mlx.utils.tree_flatten(model.parameters())) + + # Validate for corruption + bad_keys = [] + for k, v in weights_dict.items(): + if k.endswith(".bias") and not k.endswith(".biases"): + mx.eval(v) 
+ if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item(): + bad_keys.append(k) + if bad_keys: + raise RuntimeError( + f"Quantization produced corrupted weights: " + f"{len(bad_keys)} bias tensors contain NaN/Inf" + ) + + save_path = output_dir / "model.safetensors" + output_dir.mkdir(parents=True, exist_ok=True) + mx.save_safetensors(str(save_path), weights_dict) + n_quantized = sum(1 for k in weights_dict if ".scales" in k) + print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved") + + del model, weights_dict + gc.collect() + mx.clear_cache() + + # Update config with quantization metadata + config_path = output_dir / "config.json" + with open(config_path) as f: + cfg = json.load(f) + cfg["quantization"] = {"group_size": group_size, "bits": bits} + with open(config_path, "w") as f: + json.dump(cfg, f, indent=2) + print(f" Updated config.json with quantization metadata") + + +def quantize_mlx_model( + mlx_model_dir: str, + output_dir: str, + bits: int = 4, + group_size: int = 64, +): + """Quantize an already-converted MLX Helios model. + + Args: + mlx_model_dir: Path to existing MLX model directory + output_dir: Output directory for quantized model + bits: Quantization bits (4 or 8) + group_size: Quantization group size + """ + import shutil + + src = Path(mlx_model_dir) + dst = Path(output_dir) + + config_path = src / "config.json" + if not config_path.exists(): + raise FileNotFoundError(f"No config.json found in {src}") + + with open(config_path) as f: + cfg = json.load(f) + + if cfg.get("quantization"): + raise ValueError( + f"Model at {src} is already quantized " + f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source." 
+ ) + + from mlx_video.models.helios.config import HeliosModelConfig + + config_dict = {k: v for k, v in cfg.items() if k in HeliosModelConfig.__dataclass_fields__} + for key in ("patch_size", "history_sizes", "rope_dim"): + if key in config_dict and isinstance(config_dict[key], list): + config_dict[key] = tuple(config_dict[key]) + config = HeliosModelConfig(**config_dict) + + # Copy non-transformer files to output dir (skip large model weights) + transformer_files = {"model.safetensors"} + if dst.resolve() != src.resolve(): + dst.mkdir(parents=True, exist_ok=True) + for f in src.iterdir(): + if f.is_file() and f.name not in transformer_files: + shutil.copy2(f, dst / f.name) + print(f"Copied non-transformer files from {src} to {dst}") + + print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...") + _quantize_saved_model(dst, config, bits=bits, group_size=group_size, source_dir=src) + + print(f"\nQuantization complete! Output: {dst}") + + +if __name__ == "__main__": + import argparse + import shutil + + parser = argparse.ArgumentParser(description="Convert Helios weights to MLX format") + parser.add_argument( + "--checkpoint-dir", + type=str, + required=True, + help="Path to HF Helios checkpoint (or existing MLX model when using --quantize-only)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="helios_mlx_model", + help="Output path for MLX model", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float16", "float32", "bfloat16"], + default="bfloat16", + help="Target dtype", + ) + parser.add_argument( + "--quantize", + action="store_true", + help="Quantize transformer weights for faster inference", + ) + parser.add_argument( + "--quantize-only", + action="store_true", + help="Quantize an already-converted MLX model (skips HF conversion)", + ) + parser.add_argument( + "--bits", + type=int, + choices=[4, 8], + default=4, + help="Quantization bits (default: 4)", + ) + parser.add_argument( + "--group-size", + 
type=int, + choices=[32, 64, 128], + default=64, + help="Quantization group size (default: 64)", + ) + args = parser.parse_args() + + if args.quantize_only: + quantize_mlx_model( + args.checkpoint_dir, args.output_dir, + bits=args.bits, group_size=args.group_size, + ) + else: + convert_helios_checkpoint( + args.checkpoint_dir, + args.output_dir, + dtype=args.dtype, + quantize=args.quantize, + bits=args.bits, + group_size=args.group_size, + ) diff --git a/mlx_video/generate_helios.py b/mlx_video/generate_helios.py new file mode 100644 index 0000000..55c54c0 --- /dev/null +++ b/mlx_video/generate_helios.py @@ -0,0 +1,761 @@ +"""Helios Text-to-Video generation pipeline for MLX. + +Autoregressive chunk-based video generation with multi-scale history memory. +Supports the Helios-Distilled model (x0-prediction, no CFG, 2-3 steps/chunk). +""" + +import argparse +import gc +import json +import math +import random +import sys +import time +import warnings +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +import cv2 +from tqdm import tqdm + +from mlx_video.models.helios.loading import ( + _clean_text, + encode_text, + load_helios_model, + load_t5_encoder, + load_vae_decoder, +) +from mlx_video.models.wan.postprocess import save_video +from mlx_video.generate_wan import Colors + + +def sample_block_noise( + batch_size: int, + channels: int, + num_frames: int, + height: int, + width: int, + patch_size: tuple[int, int, int], + gamma: float, +) -> mx.array: + """Generate structured per-patch noise using correlated multivariate normal. + + This reduces block artifacts by ensuring spatially adjacent latents within + each patch have correlated noise values. + + Returns: + Noise tensor of shape (channels, num_frames, height, width). 
+ """ + _, ph, pw = patch_size + block_size = ph * pw + + # Build covariance matrix: I*(1+gamma) - ones*gamma + eps*I + cov = np.eye(block_size, dtype=np.float64) * (1 + gamma) - np.ones((block_size, block_size), dtype=np.float64) * gamma + cov += np.eye(block_size, dtype=np.float64) * 1e-6 + L = np.linalg.cholesky(cov) + + # Sample standard normal and transform + block_count = batch_size * channels * num_frames * (height // ph) * (width // pw) + z = np.random.randn(block_count, block_size) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + samples = (z @ L.T).astype(np.float32) + + # Reshape to spatial layout + samples = samples.reshape(batch_size, channels, num_frames, height // ph, width // pw, ph, pw) + samples = samples.transpose(0, 1, 2, 3, 5, 4, 6) # interleave patches + samples = samples.reshape(batch_size, channels, num_frames, height, width) + + # Return as (C, F, H, W) — drop batch dim since we always use batch=1 + return mx.array(samples[0]) + + +def _spatial_reshape(x: mx.array, num_frames: int, channels: int) -> mx.array: + """Reshape (C, F, H, W) → (F, C, H, W) for spatial operations.""" + return x.transpose(1, 0, 2, 3) # (F, C, H, W) + + +def _spatial_unreshape( + x: mx.array, num_frames: int, channels: int, h: int, w: int +) -> mx.array: + """Reshape (F, C, H, W) → (C, F, H, W).""" + return x.transpose(1, 0, 2, 3) # (C, F, H, W) + + +def _bilinear_downsample_2d(x: mx.array, target_h: int, target_w: int) -> mx.array: + """Bilinear interpolation downsample matching F.interpolate(mode='bilinear'). + + For 2× integer downsampling, PyTorch's bilinear interpolation with + align_corners=False samples at the centers of the output cells using a + triangular (tent) filter over the 2×2 input neighborhood. With a scale + factor of exactly 0.5 this reduces to a weighted average: + [1/4, 1/4, 1/4, 1/4] — i.e. the same as area averaging. + + Input: (F, C, H, W). 
+ """ + F, C, H, W = x.shape + scale_h = H // target_h + scale_w = W // target_w + x = x.reshape(F, C, target_h, scale_h, target_w, scale_w) + x = x.mean(axis=(3, 5)) + return x + + +def _nearest_upsample_2d(x: mx.array, target_h: int, target_w: int) -> mx.array: + """Nearest-neighbor 2x upsample. Input: (F, C, H, W).""" + F, C, H, W = x.shape + scale_h = target_h // H + scale_w = target_w // W + # Repeat along spatial dims + x = mx.repeat(x, scale_h, axis=2) + x = mx.repeat(x, scale_w, axis=3) + return x + + +def _downsample_history(hist: mx.array, factor: int) -> mx.array: + """Downsample history latents spatially by factor. Input: (C, F, H, W).""" + C, F, H, W = hist.shape + target_h = H // factor + target_w = W // factor + hist = hist.reshape(C, F, target_h, factor, target_w, factor) + hist = hist.mean(axis=(3, 5)) + return hist + + +def _debug_stats(name: str, x: mx.array) -> str: + """Return a compact stats string for a tensor.""" + x_f = x.astype(mx.float32) + return ( + f"{name}: shape={list(x.shape)} dtype={x.dtype} " + f"mean={x_f.mean().item():.6f} std={x_f.std().item():.6f} " + f"min={x_f.min().item():.6f} max={x_f.max().item():.6f}" + ) + + +def generate_video( + model_dir: str, + prompt: str, + width: int = 640, + height: int = 384, + num_frames: int = 99, + pyramid_steps: list[int] | None = None, + seed: int = -1, + output_path: str = "output_helios.mp4", + tiling: str = "auto", + amplify_first_chunk: bool = True, + guidance_scale: float = 1.0, + negative_prompt: str = "", + chunk_blend: int = 0, + crossfade_frames: int = 0, + anti_drifting: bool = False, + anti_drift_blend: float = 0.5, + debug: bool = False, + no_compile: bool = False, +): + """Generate video using Helios autoregressive pipeline with pyramid denoising. 
+ + Args: + model_dir: Path to converted MLX model directory + prompt: Text prompt + width: Video width (must be divisible by 16) + height: Video height (must be divisible by 16) + num_frames: Number of frames (auto-rounded to multiple of 33) + pyramid_steps: Steps per pyramid stage (default: [2, 2, 2] for distilled) + seed: Random seed (-1 for random) + output_path: Output video path + tiling: VAE tiling mode: auto, none, default, aggressive, conservative + amplify_first_chunk: Double steps for first chunk (recommended for distilled model) + guidance_scale: CFG guidance scale (1.0 = no CFG, 5.0 = default) + negative_prompt: Negative prompt for CFG (empty string = unconditional) + chunk_blend: Number of latent frames to blend at chunk boundaries (0 to disable) + crossfade_frames: Number of pixel frames to cross-fade between chunks (0 to disable) + anti_drifting: Enable adaptive anti-drifting for temporal consistency + anti_drift_blend: How much to normalize history toward EMA (0=off, 0.5=half, 1.0=full) + no_compile: If True, skip mx.compile on models (useful for debugging) + """ + from mlx_video.models.helios.config import HeliosModelConfig + + if pyramid_steps is None: + pyramid_steps = [2, 2, 2] + + model_dir = Path(model_dir) + t1 = time.time() + + # Load config + config_path = model_dir / "config.json" + quantization = None + if config_path.exists(): + with open(config_path) as f: + config_dict = json.load(f) + quantization = config_dict.pop("quantization", None) + for key in ("patch_size", "vae_stride", "rope_dim", "history_sizes", "stage_range"): + if key in config_dict and isinstance(config_dict[key], list): + config_dict[key] = tuple(config_dict[key]) if key in ("patch_size", "vae_stride", "rope_dim") else config_dict[key] + config = HeliosModelConfig(**{ + k: v for k, v in config_dict.items() + if k in HeliosModelConfig.__dataclass_fields__ + }) + else: + config = HeliosModelConfig.helios_distilled() + + # Frame and dimension alignment + vae_stride_t, 
vae_stride_h, vae_stride_w = config.vae_stride + frames_per_chunk = 33 # (num_latent_frames_per_chunk - 1) * vae_stride_t + 1 + num_latent_per_chunk = config.num_latent_frames_per_chunk # 9 + + # Round num_frames to nearest multiple of frames_per_chunk + num_chunks = max(1, (num_frames + frames_per_chunk - 1) // frames_per_chunk) + num_frames = num_chunks * frames_per_chunk + total_latent_frames = num_chunks * num_latent_per_chunk + + # Align spatial dimensions for pyramid: need latent H,W divisible by + # 2^(stages-1) * patch = 4*2 = 8, so pixel dims by 8*vae_stride = 64 + num_stages = len(pyramid_steps) + pyramid_factor = 2 ** (num_stages - 1) # 4 for 3-stage + align_h = config.patch_size[1] * pyramid_factor * vae_stride_h # 2*4*8 = 64 + align_w = config.patch_size[2] * pyramid_factor * vae_stride_w # 2*4*8 = 64 + height = ((height + align_h - 1) // align_h) * align_h + width = ((width + align_w - 1) // align_w) * align_w + + h_latent = height // vae_stride_h + w_latent = width // vae_stride_w + + if seed < 0: + seed = random.randint(0, 2**32 - 1) + mx.random.seed(seed) + + print(f"\n{Colors.CYAN}Helios Video Generation{Colors.RESET}") + print(f" Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}") + print(f" Resolution: {width}x{height}, {num_frames} frames ({num_chunks} chunks)") + print(f" Pyramid steps: {pyramid_steps} ({sum(pyramid_steps)} total/chunk), Seed: {seed}, Guidance: {guidance_scale}") + if quantization: + print(f" Quantization: {quantization['bits']}-bit, group_size={quantization['group_size']}") + + # 1. 
Load T5 text encoder and encode prompt + print(f"\n{Colors.BLUE}Loading text encoder...{Colors.RESET}") + t2 = time.time() + t5_path = model_dir / "t5_encoder.safetensors" + tokenizer_path = model_dir / "tokenizer" + + # Try to find tokenizer + if tokenizer_path.exists(): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) + else: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + + encoder = load_t5_encoder(t5_path, config) + context = encode_text(encoder, tokenizer, prompt, text_len=config.text_len) + mx.eval(context) + + do_cfg = guidance_scale > 1.0 + negative_context = None + if do_cfg: + negative_context = encode_text(encoder, tokenizer, negative_prompt, text_len=config.text_len) + mx.eval(negative_context) + + print(f"{Colors.DIM} T5 encode: {time.time() - t2:.1f}s, tokens: {context.shape[0]}{', CFG enabled' if do_cfg else ''}{Colors.RESET}") + + del encoder + gc.collect() + mx.clear_cache() + + # 2. Load transformer + print(f"\n{Colors.BLUE}Loading Helios transformer...{Colors.RESET}") + t3 = time.time() + model_path = model_dir / "model.safetensors" + model = load_helios_model(model_path, config, quantization=quantization) + print(f"{Colors.DIM} Model load: {time.time() - t3:.1f}s{Colors.RESET}") + + # 3. 
Pre-compute text embeddings and cross-attention KV caches + context_embedded = model.embed_text([context]) + mx.eval(context_embedded) + cross_kv_caches = model.prepare_cross_kv(context_embedded) + mx.eval(*[v for kv in cross_kv_caches for v in kv]) + + negative_context_embedded = None + negative_cross_kv_caches = None + if do_cfg: + negative_context_embedded = model.embed_text([negative_context]) + mx.eval(negative_context_embedded) + negative_cross_kv_caches = model.prepare_cross_kv(negative_context_embedded) + mx.eval(*[v for kv in negative_cross_kv_caches for v in kv]) + + print(f"{Colors.DIM} Text embedding + KV cache: ready{Colors.RESET}") + + # Compile model for faster inference via kernel fusion + if not no_compile: + model._compiled = mx.compile(model) + + # 4. History setup (keep_first_frame=True matching reference) + history_sizes = config.history_sizes # [16, 2, 1] + num_history_frames = sum(history_sizes) # 19 latent frames of history + history_latents = mx.zeros((config.in_dim, num_history_frames, h_latent, w_latent)) + + # Frame indices with prefix: [prefix | history_long | history_mid | history_1x | current] + # Reference uses keep_first_frame=True which adds a prefix frame to short history + total_indices = 1 + sum(history_sizes) + num_latent_per_chunk # +1 for prefix + indices = mx.arange(total_indices) + idx_prefix = indices[:1] # [0] + idx_long = indices[1:1 + history_sizes[0]] # [1..16] + idx_mid = indices[1 + history_sizes[0]:1 + history_sizes[0] + history_sizes[1]] # [17..18] + idx_1x = indices[1 + history_sizes[0] + history_sizes[1]:1 + sum(history_sizes)] # [19] + idx_short = mx.concatenate([idx_prefix, idx_1x]) # [0, 19] + idx_current = indices[1 + sum(history_sizes):] # [20..28] + + # 5. 
Initialize scheduler + from mlx_video.models.helios.scheduler import HeliosScheduler + + scheduler = HeliosScheduler( + num_train_timesteps=1000, + shift=1.0, + stages=3, + gamma=1 / 3, + use_dynamic_shifting=True, + ) + + total_steps = sum(pyramid_steps) * num_chunks + print(f"\n{Colors.BLUE}Generating {num_chunks} chunks ({sum(pyramid_steps)} steps/chunk, 3-stage pyramid)...{Colors.RESET}") + all_latent_chunks = [] + total_generated = 0 + image_latents_prefix = None # Set after first chunk for keep_first_frame + + # Adaptive anti-drifting: EMA of per-channel latent statistics + drift_global_mean = None + drift_global_var = None + drift_rho = 0.9 # EMA momentum + + for chunk_idx in range(num_chunks): + t_chunk = time.time() + is_first = chunk_idx == 0 + + # Prepare history from accumulated latents (keep_first_frame=True) + hist_long, hist_mid, hist_1x = mx.split( + history_latents[:, -num_history_frames:], + [history_sizes[0], history_sizes[0] + history_sizes[1]], + axis=1, + ) + + # Prefix is zero for first chunk (no image conditioning), otherwise first frame + if is_first: + latents_prefix = mx.zeros((config.in_dim, 1, h_latent, w_latent)) + else: + latents_prefix = image_latents_prefix + + # Short history = prefix + 1x history (2 frames) + hist_short = mx.concatenate([latents_prefix, hist_1x], axis=1) + + # Initialize noise for this chunk at full resolution + noise = mx.random.normal((config.in_dim, num_latent_per_chunk, h_latent, w_latent)) + + # Downsample to 1/4 resolution (2 halvings for 3-stage pyramid) + cur_h, cur_w = h_latent, w_latent + latents = _spatial_reshape(noise, num_latent_per_chunk, config.in_dim) + for _ in range(scheduler.stages - 1): + cur_h //= 2 + cur_w //= 2 + latents = _bilinear_downsample_2d(latents, cur_h, cur_w) * 2 + latents = _spatial_unreshape(latents, num_latent_per_chunk, config.in_dim, cur_h, cur_w) + + # Track per-stage start points for DMD re-noising + start_point_list = [latents] + + if debug: + mx.eval(latents) + 
print(f"\n[DEBUG] Chunk {chunk_idx}: initial noise → 1/4 res") + print(f" {_debug_stats('start_point[0]', latents)}") + + is_amplified = amplify_first_chunk and is_first + total_steps = sum(s * 2 if is_amplified else s for s in pyramid_steps) + pbar = tqdm( + total=total_steps, + desc=f" Chunk {chunk_idx + 1}/{num_chunks}", + leave=True, + bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) + + for i_s in range(scheduler.stages): + # Compute image_seq_len at current resolution for dynamic shift + image_seq_len = ( + num_latent_per_chunk * cur_h * cur_w + // math.prod(config.patch_size) + ) + + scheduler.set_timesteps( + pyramid_steps[i_s], + stage_index=i_s, + image_seq_len=image_seq_len, + is_amplify_first_chunk=(amplify_first_chunk and is_first), + ) + timesteps = scheduler.timesteps + + if debug: + mx.eval(latents) + print(f"\n[DEBUG] Stage {i_s}: res={cur_h}x{cur_w}, seq_len={image_seq_len}") + print(f" sigmas: {[f'{s:.6f}' for s in scheduler.sigmas.tolist()]}") + print(f" timesteps: {[f'{t:.1f}' for t in timesteps.tolist()]}") + print(f" {_debug_stats('latents_in', latents)}") + + if i_s > 0: + # Upsample 2x with nearest-neighbor + cur_h *= 2 + cur_w *= 2 + latents = _spatial_reshape(latents, num_latent_per_chunk, config.in_dim) + latents = _nearest_upsample_2d(latents, cur_h, cur_w) + latents = _spatial_unreshape(latents, num_latent_per_chunk, config.in_dim, cur_h, cur_w) + + # Alpha/beta noise mixing to reduce block artifacts + ori_sigma = 1 - scheduler.ori_start_sigmas[i_s] + gamma = scheduler.gamma + alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) + beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) + + block_noise = sample_block_noise( + 1, config.in_dim, num_latent_per_chunk, cur_h, cur_w, + config.patch_size, gamma, + ) + latents = alpha * latents + beta * block_noise + start_point_list.append(latents) + + if debug: + mx.eval(latents) + print(f" After upsample+mix: alpha={alpha:.4f} 
beta={beta:.4f} ori_sigma={ori_sigma:.4f}") + print(f" {_debug_stats('start_point[' + str(i_s) + ']', latents)}") + + # History is always passed at full resolution — the Conv3d + # patchifiers handle the spatial mismatch between history and + # current latents since they are concatenated in sequence dim. + h_short, h_mid, h_long = hist_short, hist_mid, hist_long + + # Scale frame indices to match current spatial resolution + cur_idx = idx_current # [20..28] with prefix offset + cur_idx_short = idx_short + cur_idx_mid = idx_mid + cur_idx_long = idx_long + + _call = getattr(model, '_compiled', model) + + # Pre-convert to Python lists to avoid .item()/.float() sync points + timestep_list = [int(t) for t in timesteps.tolist()] + sigma_list = scheduler.sigmas.tolist() + + # History doesn't change within a stage — cast once outside the loop + h_short_bf16 = h_short.astype(mx.bfloat16) + h_mid_bf16 = h_mid.astype(mx.bfloat16) + h_long_bf16 = h_long.astype(mx.bfloat16) + + for idx, t_val in enumerate(timestep_list): + timestep = mx.array(t_val, dtype=mx.int32) + noise_pred = _call( + latents=latents.astype(mx.bfloat16), + timestep=timestep, + encoder_hidden_states=context_embedded, + frame_indices=cur_idx, + history_short=h_short_bf16, + history_mid=h_mid_bf16, + history_long=h_long_bf16, + history_short_indices=cur_idx_short, + history_mid_indices=cur_idx_mid, + history_long_indices=cur_idx_long, + cross_kv_caches=cross_kv_caches, + ) + + if debug: + mx.eval(noise_pred) + print(f"\n [Step {idx}] t={t_val} sigma={sigma_list[idx]:.6f}") + print(f" {_debug_stats('model_in', latents)}") + print(f" {_debug_stats('noise_pred', noise_pred)}") + + if do_cfg: + noise_uncond = _call( + latents=latents.astype(mx.bfloat16), + timestep=timestep, + encoder_hidden_states=negative_context_embedded, + frame_indices=cur_idx, + history_short=h_short_bf16, + history_mid=h_mid_bf16, + history_long=h_long_bf16, + history_short_indices=cur_idx_short, + history_mid_indices=cur_idx_mid, + 
history_long_indices=cur_idx_long, + cross_kv_caches=negative_cross_kv_caches, + ) + mx.eval(noise_uncond) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + sigma_next = sigma_list[idx + 1] if idx < len(timestep_list) - 1 else None + latents = scheduler.step_dmd( + model_output=noise_pred, + sample=latents, + cur_step=idx, + noisy_start=start_point_list[i_s], + sigma_t=sigma_list[idx], + sigma_next=sigma_next, + ) + mx.eval(latents) + + if debug: + print(f" {_debug_stats('latents_out', latents)}") + + pbar.update(1) + + if debug: + mx.eval(latents) + print(f"\n[DEBUG] Stage {i_s} complete:") + print(f" {_debug_stats('stage_output', latents)}") + + pbar.close() + mx.eval(latents) + + # Adaptive anti-drifting: normalize history latent statistics to prevent + # color/style drift between chunks. Clean latents are kept for decoding; + # only the history copy is normalized toward the running EMA. + history_latents_chunk = latents # default: same as output + if anti_drifting and num_chunks > 1: + lat_f32 = latents.astype(mx.float32) + # Per-channel stats: latents is [C, F, H, W] + cur_mean = mx.mean(lat_f32, axis=(1, 2, 3)) # [C] + cur_var = mx.var(lat_f32, axis=(1, 2, 3)) # [C] + mx.eval(cur_mean, cur_var) + + if drift_global_mean is None: + drift_global_mean = cur_mean + drift_global_var = cur_var + else: + # Update EMA BEFORE detection (matching reference order) + drift_global_mean = drift_rho * drift_global_mean + (1 - drift_rho) * cur_mean + drift_global_var = drift_rho * drift_global_var + (1 - drift_rho) * cur_var + + # Detect drift: L2 norm of deviation from updated EMA + mean_drift = float(mx.sqrt(mx.sum((cur_mean - drift_global_mean) ** 2)).item()) + var_drift = float(mx.sqrt(mx.sum((cur_var - drift_global_var) ** 2)).item()) + has_drift = mean_drift > 0.15 and var_drift > 0.15 + + if has_drift and chunk_idx < num_chunks - 1: + # Normalize history copy toward EMA (deterministic, no noise) + # Per-channel: shift mean and scale 
variance + cur_mean_4d = cur_mean[:, None, None, None] + cur_std_4d = mx.sqrt(mx.maximum(cur_var, mx.array(1e-8)))[:, None, None, None] + global_mean_4d = drift_global_mean[:, None, None, None] + global_std_4d = mx.sqrt(mx.maximum(drift_global_var, mx.array(1e-8)))[:, None, None, None] + + # Standardize, then rescale to target stats + normalized = (latents - cur_mean_4d) / cur_std_4d * global_std_4d + global_mean_4d + # Blend: 0 = keep raw, 1 = fully normalize to EMA + history_latents_chunk = (1 - anti_drift_blend) * latents + anti_drift_blend * normalized + history_latents_chunk = history_latents_chunk.astype(latents.dtype) + mx.eval(history_latents_chunk) + print(f"{Colors.DIM} ⚠ Drift detected (mean={mean_drift:.3f}, var={var_drift:.3f}), normalized history{Colors.RESET}") + elif debug: + print(f" [drift] mean={mean_drift:.3f}, var={var_drift:.3f}, threshold=0.15") + + all_latent_chunks.append(latents) # clean latents for decoding + + # Update history: use potentially normalized chunk for conditioning + total_generated += num_latent_per_chunk + history_latents = mx.concatenate([history_latents, history_latents_chunk], axis=1) + + # After first chunk, save first frame as prefix for subsequent chunks + if is_first and image_latents_prefix is None: + image_latents_prefix = latents[:, 0:1, :, :] + + chunk_time = time.time() - t_chunk + step_count = sum(pyramid_steps) + print(f"{Colors.DIM} Chunk {chunk_idx + 1}/{num_chunks} done: {chunk_time:.1f}s ({chunk_time / step_count:.2f}s/step){Colors.RESET}") + + # Free transformer + del model + gc.collect() + mx.clear_cache() + + # 6. 
VAE decode + print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}") + t4 = time.time() + vae_path = model_dir / "vae.safetensors" + vae = load_vae_decoder(vae_path, config) + + # Select tiling config + from mlx_video.models.ltx.video_vae.tiling import TilingConfig + + if tiling == "none": + tiling_config = None + elif tiling == "auto": + tiling_config = TilingConfig.auto(height, width, frames_per_chunk) + elif tiling == "default": + tiling_config = TilingConfig.default() + elif tiling == "aggressive": + tiling_config = TilingConfig.aggressive() + elif tiling == "conservative": + tiling_config = TilingConfig.conservative() + else: + tiling_config = TilingConfig.auto(height, width, frames_per_chunk) + + # Optional: smooth chunk boundaries in latent space (off by default). + # When enabled, blends first N latent frames of each new chunk toward + # the previous chunk's last frame to reduce quality discontinuity. + if chunk_blend > 0 and num_chunks > 1: + blend_n = min(chunk_blend, num_latent_per_chunk - 1) + for b in range(1, num_chunks): + ref_np = np.array(all_latent_chunks[b - 1][:, -1]) # [C, H, W] + chunk_np = np.array(all_latent_chunks[b]) # [C, F, H, W] + for k in range(min(blend_n, chunk_np.shape[1])): + target = chunk_np[:, k] + ref_weight = 0.4 * (blend_n - k) / blend_n + blended = (1 - ref_weight) * target + ref_weight * ref_np + for c in range(blended.shape[0]): + blended[c] += target[c].mean() - blended[c].mean() + chunk_np[:, k] = blended + all_latent_chunks[b] = mx.array(chunk_np) + print(f"{Colors.DIM} Applied chunk boundary blend ({blend_n} latent frames){Colors.RESET}") + + # Decode each chunk independently (matching reference behavior). + # Per-chunk decoding avoids cross-chunk VAE temporal convolution artifacts + # that occur when the quality discontinuity at boundaries hits the causal conv. 
+ video_chunks = [] + for ci, chunk_latents in enumerate(all_latent_chunks): + z = chunk_latents[None, :, :, :, :] # [1, C, 9, H_lat, W_lat] + if tiling_config is not None: + chunk_video = vae.decode_tiled(z, tiling_config) + else: + chunk_video = vae.decode(z) + mx.eval(chunk_video) + + chunk_np = np.array(chunk_video[0]) # [3, T_decoded, H, W] + # Trim VAE warmup frames (causal padding produces stride_t-1 garbage at start) + valid = (num_latent_per_chunk - 1) * vae_stride_t + 1 # 33 + trim = chunk_np.shape[1] - valid + if trim > 0: + chunk_np = chunk_np[:, trim:] + # Drop first pixel frame: it's the overlap/conditioning frame from history + # (distorted duplicate of previous chunk's last frame). 33→32 = exact 2s at 16fps. + chunk_np = chunk_np[:, 1:] + video_chunks.append(chunk_np) + + del chunk_video, z + gc.collect() + mx.clear_cache() + + print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") + + # Correct brightness/contrast discontinuity at chunk boundaries caused by VAE + # causal padding warmup. Two-stage correction: + # 1. Spatially-varying brightness: match low-frequency (blurred) brightness per + # channel to the previous chunk's last frame, fixing the "face darkens while + # background brightens" effect from the VAE's spatial redistribution. + # 2. Per-channel contrast: scale std dev to match, fixing the ~7% contrast drop. 
+ if len(video_chunks) > 1: + blend_n = 6 # frames over which to ramp correction + blur_size = 16 # downscale factor for low-frequency brightness map + for i in range(1, len(video_chunks)): + ref_frame = video_chunks[i - 1][:, -1] # [3, H, W] + _, fh, fw = ref_frame.shape + # Pre-compute low-frequency brightness map of reference + small_h, small_w = max(fh // blur_size, 1), max(fw // blur_size, 1) + ref_lf = np.zeros((3, small_h, small_w), dtype=np.float32) + for c in range(3): + ref_lf[c] = cv2.resize(ref_frame[c], (small_w, small_h), interpolation=cv2.INTER_AREA) + # Per-channel global stats + ref_std = ref_frame.std(axis=(1, 2), keepdims=True) + + for k in range(min(blend_n, video_chunks[i].shape[1])): + frame = video_chunks[i][:, k] + ramp = 1.0 - k / blend_n # 1.0 → 0.0 + + # Stage 1: spatially-varying brightness correction + for c in range(3): + cur_lf = cv2.resize(frame[c], (small_w, small_h), interpolation=cv2.INTER_AREA) + diff_small = ref_lf[c] - cur_lf + diff_full = cv2.resize(diff_small, (fw, fh), interpolation=cv2.INTER_LINEAR) + frame[c] = frame[c] + ramp * diff_full + + # Stage 2: per-channel contrast correction + cur_std = frame.std(axis=(1, 2), keepdims=True) + cur_std = np.maximum(cur_std, 1e-6) + target_std = cur_std + ramp * (ref_std - cur_std) + cur_mean = frame.mean(axis=(1, 2), keepdims=True) + video_chunks[i][:, k] = (frame - cur_mean) * (target_std / cur_std) + cur_mean + + # Pixel-space cross-fade at chunk boundaries to smooth transitions. + # Unlike latent-space blending, this is clean — no grid artifacts since + # the VAE decode has already resolved block noise patterns. 
+ if crossfade_frames > 0 and len(video_chunks) > 1: + cf = min(crossfade_frames, video_chunks[0].shape[1] - 1) + for i in range(1, len(video_chunks)): + for k in range(cf): + # Linear ramp: weight 1→0 for previous chunk, 0→1 for current + w = (k + 1) / (cf + 1) + video_chunks[i][:, k] = (1 - w) * video_chunks[i - 1][:, -(cf - k)] + w * video_chunks[i][:, k] + print(f"{Colors.DIM} Applied pixel cross-fade ({cf} frames at each boundary){Colors.RESET}") + + # Concatenate pixel frames from all chunks + video = np.concatenate(video_chunks, axis=1) # [3, T_total, H, W] + + video = (video + 1.0) / 2.0 + video = np.clip(video * 255.0, 0, 255).astype(np.uint8) + video = video.transpose(1, 2, 3, 0) # [T, H, W, 3] + + # Trim to requested frame count + video = video[:num_frames] + + save_video(video, output_path, fps=config.sample_fps) + print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}") + print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}") + + +def main(): + parser = argparse.ArgumentParser(description="Helios Text-to-Video Generation (MLX)") + parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory") + parser.add_argument("--prompt", type=str, required=True, help="Text prompt") + parser.add_argument("--width", type=int, default=640, help="Video width") + parser.add_argument("--height", type=int, default=384, help="Video height") + parser.add_argument("--num-frames", type=int, default=99, help="Number of frames (auto-rounded to multiple of 33)") + parser.add_argument( + "--pyramid-steps", type=int, nargs="+", default=[2, 2, 2], + help="Steps per pyramid stage (default: 2 2 2 for distilled, total 6 forward passes)", + ) + parser.add_argument("--amplify-first-chunk", action="store_true", default=True, help="Double steps for first chunk (default: on, recommended for distilled)") + parser.add_argument("--no-amplify-first-chunk", action="store_false", dest="amplify_first_chunk", help="Disable 
first chunk amplification") + parser.add_argument("--seed", type=int, default=-1, help="Random seed") + parser.add_argument("--output-path", type=str, default="output_helios.mp4", help="Output video path") + parser.add_argument( + "--tiling", type=str, default="auto", + choices=["auto", "none", "default", "aggressive", "conservative"], + help="VAE tiling mode for memory efficiency", + ) + parser.add_argument("--guidance-scale", type=float, default=1.0, help="CFG guidance scale (1.0 = no CFG, default for distilled)") + parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt for CFG") + parser.add_argument("--chunk-blend", type=int, default=0, help="Latent frames to blend at chunk boundaries (0=off, default=0)") + parser.add_argument("--crossfade-frames", type=int, default=0, help="Pixel frames to cross-fade between chunks (0=off, default=0)") + parser.add_argument("--anti-drifting", action="store_true", help="Enable adaptive anti-drifting for temporal consistency between chunks") + parser.add_argument("--anti-drift-blend", type=float, default=0.5, help="How much to normalize history toward EMA stats (0=off, 0.5=half, 1.0=full; default=0.5)") + parser.add_argument("--debug", action="store_true", help="Print per-step latent statistics for debugging") + parser.add_argument("--no-compile", action="store_true", help="Disable mx.compile on models (for debugging)") + args = parser.parse_args() + + generate_video( + model_dir=args.model_dir, + prompt=args.prompt, + width=args.width, + height=args.height, + num_frames=args.num_frames, + pyramid_steps=args.pyramid_steps, + seed=args.seed, + output_path=args.output_path, + tiling=args.tiling, + amplify_first_chunk=args.amplify_first_chunk, + guidance_scale=args.guidance_scale, + negative_prompt=args.negative_prompt, + chunk_blend=args.chunk_blend, + crossfade_frames=args.crossfade_frames, + anti_drifting=args.anti_drifting, + anti_drift_blend=args.anti_drift_blend, + debug=args.debug, + 
no_compile=args.no_compile, + ) + + +if __name__ == "__main__": + main() diff --git a/mlx_video/models/helios/README.md b/mlx_video/models/helios/README.md new file mode 100644 index 0000000..dbdc0d7 --- /dev/null +++ b/mlx_video/models/helios/README.md @@ -0,0 +1,226 @@ +# Helios — Text-to-Video Generation on Apple Silicon + +Helios is a 14B-parameter autoregressive video generation model that produces minute-scale, temporally coherent video. This implementation targets the **Helios-Distilled** variant for text-to-video generation on Apple Silicon via MLX. + +- Arxiv: https://arxiv.org/abs/2603.04379 +- https://pku-yuangroup.github.io/Helios-Page/ +- https://github.com/PKU-YuanGroup/Helios + +### Step 1: Download Weights + +Download the original PyTorch checkpoint from HuggingFace using the `huggingface-cli` tool (install with `pip install huggingface_hub`): + +```bash +huggingface-cli download BestWishYsh/Helios-Distilled --local-dir ./Helios-Distilled +``` + +### Step 2: Convert to MLX Format + +Convert the PyTorch checkpoint to MLX format: + +```bash +python -m mlx_video.convert_helios \ + --checkpoint-dir ./Helios-Distilled \ + --output-dir ./Helios-Distilled-MLX +``` + +#### Quantization (Reduced Memory) + +Quantize the transformer weights to reduce memory usage. 
With 4-bit quantization (~7 GB, fits 16 GB Macs):
+
+```bash
+python -m mlx_video.convert_helios \
+    --checkpoint-dir ./Helios-Distilled \
+    --output-dir ./Helios-Distilled-MLX-Q4 \
+    --quantize --bits 4
+```
+
+You can also quantize an already-converted MLX model without re-converting from PyTorch:
+
+```bash
+python -m mlx_video.convert_helios \
+    --checkpoint-dir ./Helios-Distilled-MLX \
+    --output-dir ./Helios-Distilled-MLX-Q4 \
+    --quantize-only --bits 4
+```
+
+#### Conversion Options
+
+| Option | Default | Description |
+|--------|---------|-------------|
+| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
+| `--output-dir` | `helios_mlx_model` | Output path for MLX model |
+| `--quantize` | off | Quantize transformer weights for reduced memory |
+| `--bits` | `4` | Quantization bits: `4` or `8` |
+| `--quantize-only` | off | Quantize an existing MLX model (skip PyTorch conversion) |
+
+### Step 3: Generate Video
+
+```bash
+python -m mlx_video.generate_helios \
+    --model-dir ./Helios-Distilled-MLX \
+    --prompt "A golden retriever running through a sunlit meadow" \
+    --output-path my_video.mp4
+```
+
+```bash
+python -m mlx_video.generate_helios \
+    --model-dir ./Helios-Distilled-MLX \
+    --num-frames 999 \
+    --seed 2391784614 \
+    --prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset"
+```
+
+### Enjoy the poodles
+
+![Poodles](../../../examples/poodles_helios.gif)
+
+Gif downsampled for file size reasons:
+```bash
+ffmpeg -i poodles_helios.mp4 -vf "fps=10,scale=480:-1:flags=lanczos,palettegen=max_colors=32" poodles_helios_palette.png
+ffmpeg -i poodles_helios.mp4 -i poodles_helios_palette.png -filter_complex "fps=6,scale=260:-1:flags=lanczos[x];[x][1:v]paletteuse=dither=bayer:bayer_scale=1" poodles_helios.gif
+```
+
+#### Generation Options
+
+| Option | Default | Description |
+|--------|---------|-------------|
+| `--model-dir` | (required) | 
Path to converted MLX model directory |
+| `--prompt` | (required) | Text prompt describing the video |
+| `--width` | `640` | Video width in pixels (must be divisible by 64) |
+| `--height` | `384` | Video height in pixels (must be divisible by 64) |
+| `--num-frames` | `99` | Number of output frames (auto-rounded to multiple of 33) |
+| `--pyramid-steps` | `2 2 2` | Steps per pyramid stage (3-stage progressive denoising) |
+| `--amplify-first-chunk` | on | Double steps for first chunk (better quality; disable with `--no-amplify-first-chunk`) |
+| `--guidance-scale` | `1.0` | CFG guidance scale (`1.0` = no guidance, the distilled default) |
+| `--negative-prompt` | `""` | Negative prompt for classifier-free guidance |
+| `--seed` | `-1` (random) | Random seed for reproducibility |
+| `--output-path` | `output_helios.mp4` | Output video file path |
+| `--tiling` | `auto` | VAE tiling mode: `auto`, `none`, `default`, `aggressive`, `conservative` |
+
+## How It Works
+
+### Autoregressive Chunked Generation
+
+Unlike single-pass models, Helios generates video **autoregressively in 33-frame chunks** (9 latent frames each). This enables minute-scale video with temporal coherence:
+
+```
+Chunk 1: [frames 1-33] → denoise from noise
+Chunk 2: [frames 34-66] → denoise with history from chunk 1
+Chunk 3: [frames 67-99] → denoise with history from chunks 1-2
+...
+```
+
+For a 99-frame video at 16 fps, this produces ~6 seconds of video across 3 chunks.
+
+### Multi-Scale History Memory
+
+Each chunk beyond the first receives context from prior chunks via three Conv3d patch embeddings at different temporal/spatial scales:
+
+| Scale | Kernel | Latent Frames | Purpose |
+|---|---|---|---|
+| **Long** | 4×8×8 | 16 | Coarse global context |
+| **Mid** | 2×4×4 | 2 | Medium-term motion |
+| **Short** | 1×2×2 | 1 | Fine local detail |
+
+Total history: 19 latent tokens prepended to the current chunk's 9 tokens, giving the model a 28-token sequence that sees both broad context and recent detail.
+
+### Pipeline Steps
+
+1. 
**Text encoding** — UMT5-XXL encodes the prompt (shared with Wan) +2. **Per-chunk 3-stage pyramid denoising**: + - Sample Gaussian noise for 9 latent frames + - **Downsample** noise to 1/4 spatial resolution + - **Stage 0** (quarter res): Denoise 2 steps — very fast (16× fewer tokens) + - **Upsample** 2×, mix in structured block noise (alpha/beta correction) + - **Stage 1** (half res): Denoise 2 steps — fast (4× fewer tokens) + - **Upsample** 2×, mix block noise + - **Stage 2** (full res): Denoise 2 steps — full quality + - Prepend multi-scale history tokens (if not first chunk) + - Extract current-chunk latents; update history buffer +3. **VAE decoding** — AutoencoderKLWan decodes latents to RGB (shared with Wan, supports tiled decoding) +4. **Video output** — Frames saved as MP4 via OpenCV + +### Pyramid Denoising + +The 3-stage pyramid dramatically speeds up generation by performing most denoising at reduced spatial resolution: + +``` +Stage 0: ████░░░░░░░░░░░░ (1/4 res, 2 steps) — 16× fewer tokens +Stage 1: ████████░░░░░░░░ (1/2 res, 2 steps) — 4× fewer tokens +Stage 2: ████████████████ (full res, 2 steps) — final refinement +``` + +Customize with `--pyramid-steps`: +- `--pyramid-steps 2 2 2` — default, 6 total forward passes (fastest) +- `--pyramid-steps 4 4 4` — 12 passes (higher quality) +- `--pyramid-steps 2 2 4` — more refinement at full resolution + +Use `--amplify-first-chunk` to double the steps for the first chunk, which typically has the biggest impact on overall quality. 
+
+## Architecture
+
+Helios shares 95% of its architecture with Wan 2.1:
+
+| Component | Details |
+|---|---|
+| Transformer | 40 layers, dim=5120, 40 heads, head_dim=128 |
+| FFN | SiLU-gated, dim=13824 |
+| Patch embedding | (1, 2, 2) — 1 temporal, 2×2 spatial |
+| RoPE | 3-way factorized (44, 42, 42), θ=10000 |
+| Modulation | 6-vector AdaLN-Zero (shift/scale/gate × 2) |
+| VAE | AutoencoderKLWan, stride (4, 8, 8), z_dim=16 |
+| Text encoder | UMT5-XXL, dim=4096, 512 token context |
+
+**Helios-specific additions:**
+- Restricted self-attention (history tokens attend only among themselves)
+- Zero-timestep embedding for history tokens
+- Multi-scale history Conv3d patching (short/mid/long)
+
+## Frame Count Constraints
+
+- Output frames are auto-rounded to multiples of **33** (the chunk size)
+- Each chunk produces 33 pixel frames = 9 latent frames
+- The VAE temporal stride is 4, with formula: `latent_frames = (pixel_frames - 1) / 4 + 1`
+
+Examples:
+- `--num-frames 33` → 1 chunk, ~2s at 16fps
+- `--num-frames 99` → 3 chunks, ~6s at 16fps
+- `--num-frames 231` → 7 chunks, ~14s at 16fps
+
+## Resolution Guide
+
+Height and width must be divisible by 64. Recommended resolutions:
+
+| Resolution | Aspect Ratio | VRAM (bf16) | VRAM (4-bit) |
+|---|---|---|---|
+| 384 × 640 | 3:5 | ~28 GB | ~7 GB |
+| 384 × 384 | 1:1 | ~24 GB | ~6 GB |
+| 256 × 448 | 4:7 | ~20 GB | ~5 GB |
+
+> **Note — Resolution sensitivity**: The model was only trained at 384×640. Using non-default resolutions (even valid multiples of 64, like 640×384 portrait) causes obvious frame jumps in the reference pipeline too. This is a known upstream limitation ([PKU-YuanGroup/Helios#2](https://github.com/PKU-YuanGroup/Helios/issues/2)). Residual zoom with complex prompts at the default resolution is also an inherent model behavior. 
+ +## Memory Tips + +- Use `--tiling aggressive` for lower VRAM usage during VAE decoding +- Use 4-bit quantization (`--quantize --bits 4` during conversion) to reduce model size from ~28 GB to ~7 GB +- Shorter videos (fewer chunks) require less peak memory for history +- Smaller resolutions significantly reduce memory (quadratic in spatial dimensions) + +## File Structure + +``` +mlx_video/models/helios/ +├── __init__.py +├── README.md ← you are here +├── config.py ← HeliosModelConfig dataclass +├── rope.py ← 3-way factorized RoPE (44,42,42) +├── attention.py ← Self-attention (with history restriction) + cross-attention +├── scheduler.py ← DMD flow-matching scheduler with 3-stage pyramid support +├── transformer.py ← HeliosTransformerBlock + HeliosModel backbone +└── loading.py ← Weight loading (reuses Wan's T5/VAE loaders) + +mlx_video/ +├── convert_helios.py ← HF diffusers → MLX weight conversion +└── generate_helios.py ← CLI generation pipeline +``` diff --git a/mlx_video/models/helios/__init__.py b/mlx_video/models/helios/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlx_video/models/helios/attention.py b/mlx_video/models/helios/attention.py new file mode 100644 index 0000000..bfd1944 --- /dev/null +++ b/mlx_video/models/helios/attention.py @@ -0,0 +1,239 @@ +import mlx.core as mx +import mlx.nn as nn + +def _linear_dtype(layer) -> mx.Dtype: + """Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers.""" + inner = getattr(layer, "linear", layer) + if isinstance(inner, nn.QuantizedLinear): + return inner.scales.dtype + return inner.weight.dtype + + +class HeliosRMSNorm(nn.Module): + """RMS normalization with learnable scale.""" + + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = mx.ones((dim,)) + + def __call__(self, x: mx.array) -> mx.array: + return mx.fast.rms_norm(x, self.weight, self.eps) + + +class HeliosLayerNorm(nn.Module): + """LayerNorm computed 
in float32, with optional affine.""" + + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False): + super().__init__() + self.eps = eps + self.elementwise_affine = elementwise_affine + if elementwise_affine: + self.weight = mx.ones((dim,)) + self.bias = mx.zeros((dim,)) + + def __call__(self, x: mx.array) -> mx.array: + if self.elementwise_affine: + return mx.fast.layer_norm(x, self.weight, self.bias, self.eps) + else: + return mx.fast.layer_norm(x, None, None, self.eps) + + +class HeliosSelfAttention(nn.Module): + """Self-attention with QK normalization, 3-way RoPE, and history restriction. + + When restrict_self_attn=True, the input sequence is split into history + tokens (from previous chunks) and current tokens. History tokens attend + only to other history tokens, while current tokens attend to all tokens. + """ + + def __init__( + self, + dim: int, + num_heads: int, + qk_norm: bool = True, + eps: float = 1e-6, + restrict_self_attn: bool = False, + ): + super().__init__() + assert dim % num_heads == 0 + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.restrict_self_attn = restrict_self_attn + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + + self.norm_q = HeliosRMSNorm(dim, eps=eps) if qk_norm else None + self.norm_k = HeliosRMSNorm(dim, eps=eps) if qk_norm else None + + def __call__( + self, + x: mx.array, + frame_indices: mx.array, + grid_size: tuple, + freqs: tuple, + rope_cos_sin: tuple | None = None, + original_context_length: int = 0, + ) -> mx.array: + b, s, _ = x.shape + n, d = self.num_heads, self.head_dim + history_seq_len = s - original_context_length + + w_dtype = _linear_dtype(self.q) + x_w = x.astype(w_dtype) + + q = self.q(x_w) + k = self.k(x_w) + if self.norm_q is not None: + q = self.norm_q(q) + if self.norm_k is not None: + k = self.norm_k(k) + + if 
self.restrict_self_attn and history_seq_len > 0: + # Single V projection on full sequence, then split + v = self.v(x_w).reshape(b, s, n, d) + v_hist = v[:, :history_seq_len] + v_curr = v[:, history_seq_len:] + + # Reshape Q/K to multi-head for full sequence + q = q.reshape(b, s, n, d) + k = k.reshape(b, s, n, d) + + # Apply RoPE to full (history + current) sequence + if rope_cos_sin is not None: + cos_f, sin_f = rope_cos_sin + half_d = d // 2 + + q_seq = q.astype(mx.float32).reshape(b, s, n, half_d, 2) + q_real, q_imag = q_seq[..., 0], q_seq[..., 1] + q = mx.stack([q_real * cos_f - q_imag * sin_f, + q_real * sin_f + q_imag * cos_f], axis=-1).reshape(b, s, n, d) + + k_seq = k.astype(mx.float32).reshape(b, s, n, half_d, 2) + k_real, k_imag = k_seq[..., 0], k_seq[..., 1] + k = mx.stack([k_real * cos_f - k_imag * sin_f, + k_real * sin_f + k_imag * cos_f], axis=-1).reshape(b, s, n, d) + + # Split into history and current after RoPE + q_hist = q[:, :history_seq_len].astype(w_dtype) + q_curr = q[:, history_seq_len:].astype(w_dtype) + k_hist = k[:, :history_seq_len].astype(w_dtype) + k_curr = k[:, history_seq_len:].astype(w_dtype) + + # History self-attention: history attends to history only + q_h = q_hist.transpose(0, 2, 1, 3) + k_h = k_hist.transpose(0, 2, 1, 3) + v_h = v_hist.transpose(0, 2, 1, 3) + hist_out = mx.fast.scaled_dot_product_attention( + q_h, k_h, v_h, scale=self.scale + ) + hist_out = hist_out.transpose(0, 2, 1, 3).reshape(b, history_seq_len, -1) + + # Current self-attention: current attends to history + current + k_all = mx.concatenate([k_hist, k_curr], axis=1).transpose(0, 2, 1, 3) + v_all = mx.concatenate([v_hist, v_curr], axis=1).transpose(0, 2, 1, 3) + q_c = q_curr.transpose(0, 2, 1, 3) + curr_out = mx.fast.scaled_dot_product_attention( + q_c, k_all, v_all, scale=self.scale + ) + curr_out = curr_out.transpose(0, 2, 1, 3).reshape(b, original_context_length, -1) + + out = mx.concatenate([hist_out, curr_out], axis=1) + else: + # Standard self-attention 
(no history) + q = q.reshape(b, s, n, d) + k = k.reshape(b, s, n, d) + v = self.v(x_w).reshape(b, s, n, d) + + if rope_cos_sin is not None: + cos_f, sin_f = rope_cos_sin + q_seq = q.astype(mx.float32).reshape(b, s, n, d // 2, 2) + q_real, q_imag = q_seq[..., 0], q_seq[..., 1] + q = mx.stack([q_real * cos_f - q_imag * sin_f, + q_real * sin_f + q_imag * cos_f], axis=-1).reshape(b, s, n, d) + + k_seq = k.astype(mx.float32).reshape(b, s, n, d // 2, 2) + k_real, k_imag = k_seq[..., 0], k_seq[..., 1] + k = mx.stack([k_real * cos_f - k_imag * sin_f, + k_real * sin_f + k_imag * cos_f], axis=-1).reshape(b, s, n, d) + + q = q.astype(w_dtype).transpose(0, 2, 1, 3) + k = k.astype(w_dtype).transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale) + out = out.transpose(0, 2, 1, 3).reshape(b, s, -1) + + return self.o(out) + + +class HeliosCrossAttention(nn.Module): + """Cross-attention: Q from hidden states, K/V from text context.""" + + def __init__( + self, + dim: int, + num_heads: int, + qk_norm: bool = True, + eps: float = 1e-6, + ): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + + self.norm_q = HeliosRMSNorm(dim, eps=eps) if qk_norm else None + self.norm_k = HeliosRMSNorm(dim, eps=eps) if qk_norm else None + + def prepare_kv(self, context: mx.array) -> tuple: + """Pre-compute K and V projections for caching.""" + b = context.shape[0] + n, d = self.num_heads, self.head_dim + w_dtype = _linear_dtype(self.k) + ctx = context.astype(w_dtype) + k = self.k(ctx) + if self.norm_k is not None: + k = self.norm_k(k) + k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3) + v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3) + return k, v + + def __call__( + self, + x: mx.array, + context: 
mx.array, + kv_cache: tuple | None = None, + ) -> mx.array: + b = x.shape[0] + n, d = self.num_heads, self.head_dim + + w_dtype = _linear_dtype(self.q) + q = self.q(x.astype(w_dtype)) + if self.norm_q is not None: + q = self.norm_q(q) + q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3) + + if kv_cache is not None: + k, v = kv_cache + else: + ctx = context.astype(w_dtype) + k = self.k(ctx) + if self.norm_k is not None: + k = self.norm_k(k) + k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3) + v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3) + + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale) + out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d) + return self.o(out) diff --git a/mlx_video/models/helios/config.py b/mlx_video/models/helios/config.py new file mode 100644 index 0000000..3227dfb --- /dev/null +++ b/mlx_video/models/helios/config.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Tuple, Union + +from mlx_video.models.ltx.config import BaseModelConfig + + +@dataclass +class HeliosModelConfig(BaseModelConfig): + """Configuration for Helios video generation models.""" + + # Transformer architecture (identical to Wan 14B) + dim: int = 5120 + ffn_dim: int = 13824 + num_heads: int = 40 + num_layers: int = 40 + patch_size: Tuple[int, int, int] = (1, 2, 2) + in_dim: int = 16 + out_dim: int = 16 + text_dim: int = 4096 + freq_dim: int = 256 + text_len: int = 512 + eps: float = 1e-6 + qk_norm: bool = True + cross_attn_norm: bool = True + + # RoPE + rope_dim: Tuple[int, int, int] = (44, 42, 42) + rope_theta: float = 10000.0 + + # Helios-specific: multi-scale history memory + history_sizes: List[int] = field(default_factory=lambda: [16, 2, 1]) + num_latent_frames_per_chunk: int = 9 + has_multi_term_memory_patch: bool = True + zero_history_timestep: bool = True + + # VAE (identical to Wan — AutoencoderKLWan) + vae_stride: Tuple[int, int, int] = (4, 8, 8) + vae_z_dim: int = 16 + + # T5 text 
encoder (identical to Wan — UMT5-XXL) + t5_vocab_size: int = 256384 + t5_dim: int = 4096 + t5_dim_attn: int = 4096 + t5_dim_ffn: int = 10240 + t5_num_heads: int = 64 + t5_num_layers: int = 24 + t5_num_buckets: int = 32 + + # Scheduler + num_train_timesteps: int = 1000 + shift: float = 1.0 + stages: int = 3 + stage_range: List[float] = field(default_factory=lambda: [0, 1 / 3, 2 / 3, 1]) + gamma: float = 1 / 3 + + # Inference defaults + sample_fps: int = 24 + frame_num: int = 99 + + @property + def head_dim(self) -> int: + return self.dim // self.num_heads + + @classmethod + def helios_distilled(cls) -> "HeliosModelConfig": + """Helios-Distilled: x0-prediction, no CFG, DMD scheduler, 2-3 steps.""" + return cls( + shift=1.0, + ) diff --git a/mlx_video/models/helios/docs/DIAGNOSTICS.md b/mlx_video/models/helios/docs/DIAGNOSTICS.md new file mode 100644 index 0000000..6149b84 --- /dev/null +++ b/mlx_video/models/helios/docs/DIAGNOSTICS.md @@ -0,0 +1,724 @@ +# Helios Diagnostics & Engineering Notes + +Technical reference for the Helios (distilled) video generation pipeline in mlx-video. +Covers all findings from the bring-up, verified behaviors, resolved bugs, open problems, +and things to watch out for during future development. + +--- + +## Table of Contents + +- [Architecture Overview](#architecture-overview) +- [Verified Components](#verified-components) +- [Bug History & Resolutions](#bug-history--resolutions) +- [Open Problems](#open-problems) +- [Things to Watch Out For](#things-to-watch-out-for) +- [Key Constants & Formulas](#key-constants--formulas) +- [Diagnostic Recipes](#diagnostic-recipes) + +--- + +## Architecture Overview + +Helios is a 14B-parameter DiT for autoregressive video generation. It shares ~95% of its +architecture with Wan (same VAE, same T5 encoder, same dim/heads/layers). 
Key Helios-specific +additions: + +| Component | Description | +|-----------|-------------| +| **Autoregressive chunking** | 33-frame chunks (9 latent frames), each chunk conditioned on history from prior chunks | +| **Multi-scale history** | Short (1×), Mid (2× downsampled), Long (4× downsampled) history via Conv3d patchifiers | +| **3-stage pyramid denoising** | Denoise at 1/4 → 1/2 → full resolution for efficiency | +| **DMD scheduler** | x0-prediction with re-noising (distilled model uses 2+2+2 steps) | +| **Block noise** | Structured per-patch noise via correlated multivariate normal | + +### Pipeline flow (distilled, 3-stage pyramid) + +``` +Full-res noise → bilinear↓2 * 2 → bilinear↓2 * 2 → [1/4 res latents] + +Stage 0 (1/4 res): 2 DMD steps → denoised₀ +Stage 1 (1/2 res): nearest↑2(denoised₀) → α·up + β·block_noise → 2 DMD steps → denoised₁ +Stage 2 (full res): nearest↑2(denoised₁) → α·up + β·block_noise → 2 DMD steps → final + +VAE decode → video frames +``` + +### File layout (~2500 lines total) + +``` +mlx_video/generate_helios.py # Pipeline orchestration (~700 lines) +mlx_video/models/helios/ + config.py # HeliosModelConfig dataclass (69 lines) + transformer.py # 14B DiT backbone (511 lines) + attention.py # Self/cross attention with history (270 lines) + rope.py # 3-way factorized RoPE (215 lines) + scheduler.py # DMD + Euler schedulers (264 lines) + loading.py # Weight loading wrappers (51 lines) +mlx_video/convert_helios.py # HF→MLX weight conversion +tests/test_helios.py # 46 tests (554 lines) +mlx_video/models/helios/scripts/ + analyze_boundaries.py # Boundary quality analysis (compare videos) + run_reference.py # Run PyTorch reference pipeline on MPS + compare_pipelines.py # Compare scheduler/pipeline mechanics + compare_models.py # Cross-framework model output comparison +``` + +--- + +## Verified Components + +These components have been numerically verified against the reference PyTorch implementation +and can be considered correct. 
If output quality issues arise, look elsewhere first. + +### 1. Transformer model ✅ + +**Verification**: Fed identical random inputs (latents, encoder_hidden_states, timestep=500) +to both MLX and reference PyTorch implementations. + +| Metric | Value | +|--------|-------| +| Mean abs diff | 0.004190 | +| Correlation | **0.999773** | +| Per-channel means | Match to 3 decimal places | + +The model produces correct flow predictions. Color issues are in the pipeline, not the model. + +### 2. VAE decoder ✅ + +- All weight key mappings verified exact (0.000000 max diff per key) +- Decoder output correlation 0.999+ with reference +- **Temporal offset**: First `stride_t - 1 = 3` frames are warmup garbage from causal + padding. The pipeline trims these before saving. + +### 3. Scheduler (DMD) ✅ + +Verified against reference PyTorch scheduler with identical inputs. Both produce: + +| Parameter | Stage 0 (1/4 res) | Stage 1 (1/2 res) | Stage 2 (full res) | +|-----------|-------------------|-------------------|-------------------| +| seq_len | 540 | 540 | 2160+ | +| mu (shift) | 0.5481 | 0.5481 | 0.8223 | +| sigmas | [0.998, 0.354, 0.0] | [0.998, 0.354, 0.0] | [0.999, 0.451, 0.0] | +| timesteps | [998.5, 834.0] | [742.6, 512.5] | [385.2, 174.8] | + +Alpha/beta blending coefficients: +- Stage 1: α=0.6001, β=0.6926 (α²+β²=0.84) +- Stage 2: α=0.7498, β=0.4333 (α²+β²=0.75) + +### 4. 
Other verified components + +- **T5 text encoder**: Reused from Wan, works with sanitized HF UMT5 keys +- **RoPE**: 3-way factorized (44,42,42) split, pad+center_downsample for history +- **Bilinear downsample × 2**: Matches reference `F.interpolate(bilinear) * 2` +- **Nearest upsample**: Matches reference `F.interpolate(nearest)` +- **Block noise**: Mathematically equivalent to reference +- **History ordering**: [long | mid | short | current] matches reference +- **Video encoding**: imageio with libx264 (no color space issues) + +--- + +## Bug History & Resolutions + +### Bug 1: Timestep projection permutation + +**Symptom**: Garbage output, model crash. +**Root cause**: Reference permutes `timestep_proj` from `(B,6,L,dim)` → `(B,L,6,dim)` before +passing to blocks. Our code had the wrong axis order. +**Fix**: Added `.transpose(0, 2, 1, 3)` in `HeliosModel.__call__`. + +### Bug 2: T5 encoder key mismatch + +**Symptom**: `ValueError: Received 242 parameters not in model` +**Root cause**: HuggingFace UMT5 weight keys don't match MLX T5Encoder keys +(e.g., `encoder.block.0.layer.0.SelfAttention.q.weight` vs `encoder.layers.0.self_attn.q_proj.weight`). +**Fix**: Added `_sanitize_helios_t5_weights()` with complete HF→MLX key mapping. + +### Bug 3: RoPE reshape crash + +**Symptom**: `ValueError: Cannot reshape array of size 88 into shape (1,1,1,22,2)` +**Root cause**: RoPE frequency computation assumed fixed spatial dimensions. With pyramid +denoising, dimensions change per stage. +**Fix**: Rewrote `rope.py` with 5D compute + downsample approach (`_rope_compute_5d`, +`_rope_pad_and_downsample`). + +### Bug 4: Grey/uniform output + +**Symptom**: All pixels ~128 (mid-grey), no content visible. +**Root cause**: Two bugs: +1. Wrong `_time_shift` formula: Was `mu*t/(mu+(1-mu)*t)`, correct is `mu*t/(1+(mu-1)*t)` +2. VAE weight keys not mapped: decoder was using random weights. +**Fix**: Corrected formula + added `sanitize_helios_vae_weights()`. 
+ +### Bug 5: Multi-chunk noise (chunks 2+ were random noise) + +**Symptom**: First chunk had content, subsequent chunks were pure noise. +**Root cause**: Code was downsampling history to match each pyramid stage's resolution. +The reference passes **full-resolution** history at **all** pyramid stages — the model's +Conv3d patchifiers handle the spatial mismatch. +**Fix**: Removed `_downsample_history` calls. + +### Bug 6: Video color space (macOS) + +**Symptom**: Colors appeared wrong in some players. +**Root cause**: OpenCV's `mp4v` codec on macOS uses a YUV color matrix that some players +interpret differently. +**Fix**: Switched to imageio + libx264 for video encoding. + +### Bug 7: Color distortion — solid red/yellow (pyramid-specific) + +**Symptom**: Output heavily biased toward a single color. R≈224, G≈100-197, B≈28-48. +Single-stage denoising produced correct colors; pyramid denoising did not. + +**Investigation** (9 controlled experiments): + +| Experiment | Result | +|-----------|--------| +| 1-stage, 8 steps, full res | ✅ Balanced colors | +| 3-stage pyramid, 2+2+2 | ❌ Red-biased | +| 3-stage pyramid, 8+8+8 | ❌ Still red-biased (more steps ≠ better) | +| Pyramid with zero block noise | ❌ Still red-biased (noise not the cause) | +| 2 steps at full res (stage 2 sigmas) | ✅ Balanced | +| 2 steps at full res (stage 0 sigmas) | ✅ Balanced | +| 3-stage at full res (no spatial scaling, with blend) | ❌ Wrong colors | +| 3-stage at full res, pure noise start_point | ✅ Balanced | + +**Root cause**: DMD re-noising cascades per-channel mean bias across pyramid stages. + +The formula `prev = (1-σ_next)·x0 + σ_next·start_point` re-injects the blended signal's +mean at each step. Channel means grow monotonically through stages: + +``` +Stage 0 → ch0: -0.25, ch2: +0.36 +Stage 1 → ch0: -0.49, ch2: +0.82 +Stage 2 → ch0: -0.83, ch2: +1.41 ← ~4× amplification +``` + +More steps per stage make this WORSE (4+4+4 gives ch2=+2.09). 
+ +**Fix applied** (commit `c5acde72`): Normalize the start_point per-channel to zero mean +and unit std for stages > 0. This preserves spatial structure (which patches are high/low) +while breaking the mean cascade: + +```python +sp_mean = mx.mean(sp, axis=keepdim_axes, keepdims=True) +sp_std = mx.clip(sp.std(axis=keepdim_axes, keepdims=True), a_min=1e-6, a_max=None) +start_point_list.append((sp - sp_mean) / sp_std) +``` + +Result: R=206,G=107,B=43 → **R=152,G=111,B=75** (balanced warm tones for beach prompt). + +**REVERTED**: See Bug 9 below — this normalization was found to be the cause of the +pure noise output. The reference implementation does NOT normalize start_point. +Mild per-channel mean growth across stages is the expected behavior. + +### Bug 8: Precision mismatch (MLX vs PyTorch) + +**Symptom**: Subtle color shifts compared to reference. +**Root cause**: MLX promotes `bfloat16 × float32 → float32`, so the model was computing +in float32 instead of bfloat16 (which PyTorch uses on CUDA tensor cores). Also, the +scheduler's `step_dmd` never cast back to the original dtype. +**Fix** (commit `c5acde72`): +1. Cast latents + history to `bfloat16` before model calls +2. Return `prev_sample.astype(orig_dtype)` from `step_dmd` + +Note: This alone had minimal impact on color — the normalized start_point was the primary fix. + +### Bug 9: Pure noise output — start-point normalization breaks DMD trajectory + +**Symptom**: Output was pure noise even for the first chunk. No recognizable content. + +**Root cause**: The start_point normalization added in Bug 7's fix (commit `c5acde72`) +changed the scale of the noise tensor used in DMD re-noising. The DMD formula: +``` +prev = (1 - sigma_next) * x0_pred + sigma_next * start_point +``` +relies on `start_point` having the correct magnitude — it's the original noisy latent +at each pyramid stage, scaled by the alpha/beta blending coefficients. 
Normalizing to +unit std destroys this scale relationship, causing the denoising trajectory to diverge. + +The reference implementation (`pipeline_helios_diffusers.py` line 703) simply appends +the blended latent without any normalization: +```python +start_point_list.append(latents) +``` + +**Investigation** (systematic comparison against reference): +1. Line-by-line comparison of `generate_helios.py` vs `pipeline_helios_diffusers.py` +2. Verified scheduler sigmas, timesteps, DMD expansion/trim all match reference +3. Verified VAE denormalization is correct (WanVAE handles internally) +4. Verified block noise Cholesky approach matches reference MultivariateNormal +5. Identified start_point normalization as the only functional deviation from reference + +**Fix**: Removed the normalization, restoring `start_point_list.append(latents)` to +match the reference. + +**Debug output** (seed=42, "A calm ocean at sunset", 384×640, 33 frames): +``` +Stage 0 (1/4 res, 12×20): sigmas=[0.998, 0.354, 0.0], ts=[998.5, 834.0] + Step 0: model_out std=0.505 → latent std=0.719 + Step 1: model_out std=0.515 → latent std=0.603 +Stage 1 (1/2 res, 24×40): alpha=0.60, beta=0.69 + Step 0: model_out std=0.622 → latent std=0.552 + Step 1: model_out std=0.581 → latent std=0.548 +Stage 2 (full res, 48×80): alpha=0.75, beta=0.43 + Step 0: model_out std=0.668 → latent std=0.603 + Step 1: model_out std=0.594 → latent std=0.762 + +Output frame analysis: + R=114, G=59, B=17 (warm sunset tones ✓) + Gradient: dx=0.2, dy=0.4 (smooth, structured) + Entropy: 5.54 bits (normal range) + Frame-to-frame diff: 3.46 avg (temporally coherent ✓) +``` + +**Status**: Mean cascade still exists (mean grows -0.07 → -0.15 → -0.23 across stages) +but is mild. This appears to be inherent model behavior, not a bug. The per-channel +growth is within the VAE's normalization range and decodes to warm, plausible colors. 
+ +### Bug 10: Uniform color output — wrong zero-history timestep embedding + +**Symptom**: Output video showed near-uniform red/orange color (R=114, G=59, B=17 with +very low per-channel variance). No recognizable content despite plausible color range. + +**Root cause**: The zero-history timestep embedding was computed incorrectly. The reference +passes `timestep=0` through the sinusoidal `Timesteps()` encoder which produces +`[cos(0), sin(0)] = [1,1,...,1, 0,0,...,0]` (128 ones followed by 128 zeros). Our code +used `mx.zeros_like(t_emb)` — all zeros — which produces a completely different MLP output. + +Since history tokens make up ~81.6% of all tokens (2400 out of 2940 at 1/4 resolution), +the vast majority of tokens received wrong scale/shift/gate modulation vectors from the +`scale_shift_table`. This corrupted self-attention (history and current tokens interact), +making the transformer output effectively random. + +**Diagnosis** (block-by-block comparison against reference PyTorch): +1. Verified all inputs to the transformer match: patches, RoPE, text embeddings, time + embeddings, history patches — all cosine_sim ≈ 1.0 +2. Block 0 output diverged catastrophically: cosine_sim = -0.30 (essentially uncorrelated) +3. Traced the bug to `HeliosModel.__call__` line 459 where `t0_emb = mx.zeros_like(t_emb)` + should have been the sinusoidal encoding of timestep=0 + +**Fix** (commit `061f191b`): +```python +# Before (wrong): +t0_emb = mx.zeros_like(t_emb) + +# After (correct): +t0_emb = mx.array([0.0]) * self._inv_freq +t0_emb = mx.concatenate([mx.cos(t0_emb), mx.sin(t0_emb)], axis=-1) +``` + +**Result**: Block 0 cosine similarity: -0.30 → 0.999982. Full pipeline output now shows +recognizable structured content with high per-channel variance (R=100±100, G=81±81, B=33±37). + +### Bug 11: Scheduler step_dmd returning bfloat16 + +**Symptom**: Minor precision loss across denoising steps (contributed to warm color bias +but not the primary cause of bad output). 
+ +**Root cause**: `step_dmd()` cast the result back to `orig_dtype` (bfloat16) at the end. +The reference keeps latents in float32 between steps. Since the DMD formula +`prev = (1-σ)·x0 + σ·start_point` involves near-cancellation when σ≈1, float32 precision +is important. + +**Fix** (commit `061f191b`): Return float32 from `step_dmd()`, use `float()` for sigma +values to avoid array overhead. + +--- + +## Open Problems + +### 1. Chunk boundary quality ✅ RESOLVED + +**Status**: Fixed. Four-layer fix eliminates visible boundary artifacts. + +**Root cause (cross-fade)**: Pixel cross-fade was blending the first N frames of each new +chunk with the tail of the previous chunk. Since frames from different chunks don't +spatially align, this created blur — causing a **40% sharpness drop** at every boundary. +The reference pipeline uses **no cross-fade** at all. + +**Root cause (conditioning frame)**: The first pixel frame of each non-first chunk is a +distorted reconstruction of the previous chunk's last frame (via history conditioning). +The reference keeps it as a bridge frame, but it creates visual stutter. Dropping it gives +exactly 32 frames per chunk = 2 seconds at 16 fps. + +**Root cause (VAE warmup)**: The WanVAE decoder uses causal temporal convolutions. When +decoding each chunk independently, the first few frames lack temporal context (only zero +padding). This causes a **~7% contrast drop** in the first frames of each chunk, plus a +spatial brightness redistribution (face darkens, background brightens). + +**Fix layers** (applied in order during VAE decode): +1. **No cross-fade** (default `--crossfade-frames 0`) — matches reference +2. **First-frame trim** — drops conditioning frame from each chunk (33 → 32 frames) +3. **Spatially-varying brightness correction** — matches low-frequency per-channel brightness + between chunks via downscale/diff/upscale additive correction (6-frame ramp) +4. 
**Per-channel contrast correction** — scales std dev to match previous chunk's last frame + +**Results** (measured with `mlx_video/models/helios/scripts/analyze_boundaries.py`): +| Metric | No fix | With all fixes | +|--------|--------|----------------| +| Contrast jump | -7.0% | **-1.0%** | +| Brightness jump | +0.7% | **+0.0%** | +| Max color shift | 1.8 px | **0.3 px** | +| Frame diff ratio | 4.1× | **2.5×** | +| Center brightness shift | -0.90 | **-0.15** | +| Periphery brightness shift | +1.39 | **+0.09** | + +**Note**: VAE overlap decode (prepending previous chunk's last latent frames as temporal +context) was tested but made things **worse** (22% contrast drop). The VAE's causal +convolutions see conflicting content from different chunks and create larger artifacts. + +**Optional**: Latent-space blend (`--chunk-blend N`, default 0 = off). Generally not +recommended as it introduces its own artifacts (grid patterns, brightness shift). + +### 2. Color warmth / saturation + +**Status**: RESOLVED by Bug 10 fix. The uniform warm color was caused by the wrong +zero-history timestep embedding, not inherent model behavior. Output now shows proper +color variation matching the prompt. + +### 3. Adaptive anti-drifting for temporal consistency + +**Status**: Implemented (`--anti-drifting` flag). Prevents color/style drift across chunks in +long videos by normalizing history latent statistics toward a running average. 
+
+**How it works**:
+- Tracks per-channel latent mean/variance via EMA (momentum=0.9) across chunks
+- After each chunk: computes L2 norm of deviation from updated EMA
+- If BOTH mean drift > 0.15 AND variance drift > 0.15 AND not last chunk:
+  - Normalizes the **history copy** (not the decoded output) per-channel to match the EMA
+  - `normalized = (latents - cur_mean) / cur_std * global_std + global_mean`
+  - Controlled by `--anti-drift-blend` (0=off, 0.5=half-normalize, 1.0=full normalize)
+- Clean latents are always kept for decoding — no output quality impact
+- Only the history conditioning is adjusted, gently steering future chunks
+
+**Previous approach** (noise corruption, commit `a9fd911d`) was reverted because adding
+Gaussian noise degraded output quality and cascaded noise into subsequent chunks via history.
+
+**Usage**: `--anti-drifting` to enable, `--anti-drift-blend 0.5` (default) to control strength.
+Off by default (matching reference).
+
+### 4. Generation speed
+
+**Status**: ~14s/step at 384×640 resolution. This is limited by the full-resolution stages
+(stages 0-1 at reduced resolution are fast: ~5s/step).
+
+**Not yet explored**:
+- `mx.compile()` for the model forward pass
+- Quantization (model supports 4/8-bit via convert_helios.py)
+- Memory-efficient attention
+
+### 5. Camera jumps at chunk boundaries ✅ RESOLVED
+
+**Status**: Fixed by two changes: float32 residual connections (commit b24d60a1) and
+disabling pixel cross-fade (commit f89eeeb9).
+
+**Root cause 1 — Gradual zoom**: bfloat16 residual connections systematically truncated
+high-frequency spatial content over 40 blocks (`num_layers = 40`) × 3 residuals × 6 model
+calls per chunk. When smoothed output became history for the next chunk, the effect
+compounded → progressive zoom. Fixed by promoting residual additions to float32 (matching
+reference's `.float()`
+pattern). 
+ +**Root cause 2 — Boundary jumps**: Pixel cross-fade blended frames from different chunks +that didn't spatially align, creating visible blur/jumps at every boundary. The reference +uses no cross-fade. Disabling it matches reference boundary behavior. + +**Mitigations retained**: +- `--amplify-first-chunk` (ON by default): Doubles DMD steps for the first chunk, + providing a higher-quality anchor for subsequent chunks via history. Reference ALWAYS uses + this for distilled models. +- `--crossfade-frames 0` (OFF by default): Can be re-enabled if desired but not recommended + based on reference comparison. + +**Note — Resolution sensitivity**: The model was only trained at 384×640. Using non-default +resolutions (even valid multiples of 64, like 640×384 portrait) causes obvious frame jumps +in the reference pipeline too. This is a known upstream limitation +([PKU-YuanGroup/Helios#2](https://github.com/PKU-YuanGroup/Helios/issues/2)). Residual +zoom with complex prompts at the default resolution is also an inherent model behavior. + +### 6. Non-distilled model not supported + +Only the distilled model (DMD scheduler, 2+2+2 steps, no CFG) is implemented. The +non-distilled model uses Euler/UniPC schedulers with 20+20+20 steps and requires CFG. +The scheduler infrastructure exists (`step()` method) but the pipeline hasn't been tested. + +--- + +## Things to Watch Out For + +### Precision: bfloat16 vs float32 in MLX + +MLX type promotion rules: +``` +bfloat16 × float32 → float32 (NOT bfloat16!) +bfloat16 × float16 → float32 (promoted to higher common type) +bfloat16 × bfloat16 → bfloat16 (stays in bf16) +``` + +The reference runs the model in bfloat16 throughout (CUDA tensor cores). To match, we +**must** cast latents and history to bfloat16 before model calls. The model weights are +stored in bfloat16 (`model.safetensors`), so if inputs are also bfloat16, all computations +stay in bfloat16. 
+
+However: Empirically, the precision difference alone has minimal impact on output quality;
+the functional fixes (Bug 9's start_point revert and Bug 10's timestep embedding) were far
+more impactful.
+
+### VAE temporal offset
+
+The WanVAE's causal Conv3d layers produce `stride_t - 1 = 3` warmup frames at the start.
+These frames are garbage and must be trimmed:
+
+```python
+video = video[:, :, 3:, :, :]  # trim causal padding warmup
+```
+
+The pipeline handles this automatically but be careful if decoding latents manually.
+
+### History is always full resolution
+
+When passing history latents to the transformer, they must be at **full resolution** regardless
+of which pyramid stage is active. The model's Conv3d patchifiers (with stride 2×2×2 and 4×4×4)
+handle the downsampling internally. Passing pre-downsampled history causes noisy output.
+
+### Frame count requirements
+
+- Frames per chunk: **33** (hardcoded: `(9 - 1) * 4 + 1`)
+- Total frames must be `1 + 32*k` (e.g., 33, 65, 97, 129)
+- The pipeline automatically rounds up to the nearest valid count
+
+### Dimension alignment
+
+- Height and width must be divisible by **16** (VAE spatial compression × patch size)
+- Latent dimensions: `h_lat = h // 8`, `w_lat = w // 8`
+- For pyramid, each dimension halves twice (so full-res latent dims must be divisible by 4)
+
+### DMD re-noising formula
+
+```
+x0_pred = sample - sigma_t * flow_pred (float32, upcasted)
+prev = (1 - sigma_next) * x0_pred + sigma_next * noisy_start (if not last step)
+prev = x0_pred (if last step)
+```
+
+The `noisy_start` is stored per-stage. For stage 0, it's the initial downsampled noise.
+For stages > 0, it's the raw alpha/beta-blended latent, stored **as-is** (no normalization).
+
+**Critical**: Do NOT normalize `noisy_start`. The DMD trajectory relies on its original
+scale; normalizing it to zero-mean/unit-std (the Bug 7 fix, since reverted) breaks the
+re-noising scale relationship and produces pure noise (see Bug 9 above). Mild per-channel
+mean growth across stages is expected model behavior.
+
+### Block noise structure
+
+`sample_block_noise()` generates per-patch correlated noise using:
+```
+noise = N(0, gamma*I) + mean(patch_noise) * (1 - gamma)
+```
+where `gamma = 1/3`. 
This is NOT standard Gaussian noise. It reduces visible block +artifacts at patch boundaries. + +### Dynamic shifting + +The sigma schedule is shifted based on spatial resolution via `calculate_shift()`: +``` +mu = base_shift + (max_shift - base_shift) * (seq_len - base_seq) / (max_seq - base_seq) +``` + +This `mu` is used in `_time_shift`: `shifted_t = mu * t / (1 + (mu - 1) * t)`. + +**Important**: `image_seq_len` must be computed at **pre-upsample** resolution for each +pyramid stage (matching reference). Using post-upsample seq_len gives wrong sigmas. + +### Restrict self-attention + +Set to `False` (full attention) to match reference behavior. Setting `True` restricts +self-attention to only current-chunk tokens (excluding history), which significantly +degrades quality. + +--- + +## Key Constants & Formulas + +### Scheduler defaults + +```python +num_train_timesteps = 1000 +stages = 3 +stage_range = [0, 1/3, 2/3, 1] +gamma = 1/3 +base_shift = 0.5 +max_shift = 1.15 +base_image_seq_len = 256 +max_image_seq_len = 4096 +``` + +### `ori_start_sigmas` (per-stage starting signal coefficient) + +```python +ori_start_sigmas = {0: 0.999, 1: 0.666, 2: 0.334} +# ori_sigma (signal coeff) = 1 - ori_start_sigmas[i_s] +# Stage 0: ori_sigma = 0.001 (almost pure noise) +# Stage 1: ori_sigma = 0.334 +# Stage 2: ori_sigma = 0.666 +``` + +### Alpha/beta blending (stage transitions) + +```python +ori_sigma = 1 - scheduler.ori_start_sigmas[i_s] +gamma = 1/3 +alpha = 1 / (sqrt(1 + 1/gamma) * (1 - ori_sigma) + ori_sigma) +beta = alpha * (1 - ori_sigma) / sqrt(gamma) +# Note: alpha + beta > 1 (NOT a convex combination) — this is intentional +``` + +### Time shift formula + +```python +def _time_shift(mu, t): + return mu * t / (1 + (mu - 1) * t) +# WARNING: a common bug is mu*t/(mu+(1-mu)*t) — this is WRONG +``` + +--- + +## Diagnostic Recipes + +### Diagnostic scripts + +Consolidated diagnostic scripts live in `mlx_video/models/helios/scripts/`: + +| Script | Purpose | 
+|--------|---------| +| `analyze_boundaries.py` | Measure boundary quality (contrast, brightness, color, spatial). Compare multiple videos side-by-side. | +| `run_reference.py` | Run the PyTorch reference pipeline on MPS for ground-truth comparison. Patches float64 → float32 for MPS. | +| `compare_pipelines.py` | Compare scheduler/pipeline mechanics between MLX and PyTorch using identical dummy inputs. No model needed. | +| `compare_models.py` | Feed identical inputs to both MLX and PyTorch transformer models and compare flow predictions. | + +**Analyze chunk boundaries** (most commonly used): + +```bash +# Single video +python mlx_video/models/helios/scripts/analyze_boundaries.py /tmp/helios_output.mp4 + +# Compare before/after fix +python mlx_video/models/helios/scripts/analyze_boundaries.py \ + /tmp/before_fix.mp4 /tmp/after_fix.mp4 + +# Reference pipeline uses 33 frames/chunk (no first-frame trim) +python mlx_video/models/helios/scripts/analyze_boundaries.py --chunk-size 33 /tmp/ref.mp4 +``` + +**Run reference pipeline for comparison**: + +```bash +# Requires: pip install diffusers transformers torch accelerate +python mlx_video/models/helios/scripts/run_reference.py \ + --model-dir /path/to/Helios-Distilled \ + --prompt "A golden retriever running on a sunny beach" \ + --output /tmp/helios_ref.mp4 +``` + +**Compare pipeline mechanics** (no model weights needed): + +```bash +python mlx_video/models/helios/scripts/compare_pipelines.py \ + --helios-dir /path/to/Helios +``` + +### Check latent channel statistics per step + +Add this inside the denoising loop (after `scheduler.step_dmd`): + +```python +mx.eval(latents) +ch_means = [latents[c].mean().item() for c in range(min(4, latents.shape[0]))] +ch_stds = [latents[c].std().item() for c in range(min(4, latents.shape[0]))] +print(f"S{i_s} step{idx}: mean={ch_means} std={ch_stds}") +``` + +**What to look for**: +- Means should not grow unboundedly across stages (if they do, start_point normalization may be off) +- 
Stds should stay in 0.3–1.0 range (collapsing to <0.1 indicates degenerate output) + +### Check pixel-level output quality + +```python +import numpy as np, imageio.v3 as iio +vid = iio.imread('/tmp/output.mp4') +for fi in [0, vid.shape[0]//2, vid.shape[0]-1]: + f = vid[fi] + print(f'Frame {fi}: R={f[:,:,0].mean():.1f} G={f[:,:,1].mean():.1f} B={f[:,:,2].mean():.1f} ' + f'std=({f[:,:,0].std():.1f},{f[:,:,1].std():.1f},{f[:,:,2].std():.1f})') +``` + +**Healthy output indicators**: +- RGB means between 60–200 (not pegged to extremes) +- Per-channel std > 15 (indicates spatial diversity, not flat color) +- No single channel dominating (R≈G≈B for neutral scenes) + +### Check motion between frames + +```python +for a, b in [(0, 16), (16, 32), (32, 33)]: + diff = np.abs(vid[b].astype(float) - vid[a].astype(float)).mean() + print(f'Motion {a}→{b}: {diff:.1f}') +``` + +**Healthy values**: 10–30 for natural motion. >50 suggests major artifacts or scene breaks. +32→33 (chunk boundary) should be <30 for smooth transitions. + +### Compare scheduler values with reference + +```python +from mlx_video.models.helios.scheduler import HeliosScheduler +s = HeliosScheduler() +for stage in range(3): + s.set_timesteps(2, stage_index=stage, image_seq_len=540) + print(f"Stage {stage}: sigmas={s.sigmas.tolist()}, timesteps={s.timesteps.tolist()}") +``` + +Expected output should match the values in the [Verified Components](#3-scheduler-dmd-) table. 
+ +### Run all tests + +```bash +.venv2/bin/python3 -m pytest tests/test_helios.py -v +# Expected: 46 passed +``` + +### Quick generation test + +```bash +.venv2/bin/python3 -m mlx_video.generate_helios \ + --model-dir /path/to/Helios-Distilled-MLX \ + --prompt "A golden retriever running on a sunny beach" \ + --num-frames 33 --height 384 --width 640 \ + --output-path /tmp/test.mp4 \ + --pyramid-steps 2 2 2 --seed 42 +``` + +--- + +## Appendix: Commit History + +| Commit | Description | +|--------|-------------| +| `45c20851` | Initial Helios model with 3-stage pyramid denoising | +| `70214cea` | Fix grey output (time_shift formula + VAE key mapping) | +| `fcefee27` | Add CFG support and VAE frame trimming | +| `e61eb33b` | Fix pyramid color distortion (restrict_self_attn, float32 precision, int timestep) | +| `c5acde72` | Fix color bias (normalized start_point + bfloat16 inputs) | +| `24012f96` | Add diagnostics and engineering notes | +| `35f700be` | Fix pure noise: remove start-point normalization, add debug mode | +| `061f191b` | Fix zero-history timestep embedding and scheduler precision | +| `d6f9b4e2` | Mitigate chunk boundary blur with latent-space temporal blend | +| `38d454d0` | Switch to per-chunk VAE decoding, revert blend default to off | +| `a9fd911d` | Implement adaptive anti-drifting for temporal consistency | +| `a7c9086e` | Replace anti-drifting noise corruption with history normalization | +| `4415b746` | Amplify first chunk + pixel cross-fade for camera jumps | +| `b24d60a1` | Fix zoom bug: float32 residual connections in transformer blocks | +| `f89eeeb9` | Disable pixel cross-fade by default (matches reference) | +| `811c4deb` | Document resolution sensitivity as upstream model limitation | +| `e60688fd` | Drop first pixel frame from each chunk to remove boundary distortion | +| `a716216b` | Fix brightness jumps: global contrast correction at boundaries | +| `ff101b2a` | Upgrade to per-channel brightness and contrast matching | +| `8cb3ca96` | 
Add spatially-varying brightness correction at chunk boundaries | diff --git a/mlx_video/models/helios/loading.py b/mlx_video/models/helios/loading.py new file mode 100644 index 0000000..55dc68e --- /dev/null +++ b/mlx_video/models/helios/loading.py @@ -0,0 +1,51 @@ +"""Helios model loading utilities. + +Reuses Wan's T5 encoder and VAE since Helios uses identical components. +""" + +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn + + +def load_helios_model( + model_path: Path, + config, + quantization: dict | None = None, +): + """Load and initialize HeliosModel, with optional quantization. + + Args: + model_path: Path to model safetensors file + config: HeliosModelConfig + quantization: Optional dict with 'bits' and 'group_size' keys. + """ + from mlx_video.models.helios.transformer import HeliosModel + + model = HeliosModel(config) + + if quantization: + from mlx_video.convert_helios import _quantize_predicate + + nn.quantize( + model, + group_size=quantization["group_size"], + bits=quantization["bits"], + class_predicate=lambda path, m: _quantize_predicate(path, m), + ) + + weights = mx.load(str(model_path)) + model.load_weights(list(weights.items()), strict=False) + mx.eval(model.parameters()) + return model + + +# Reuse Wan's T5 encoder and VAE loaders since Helios uses the same components +from mlx_video.models.wan.loading import ( # noqa: E402, F401 + _clean_text, + encode_text, + load_t5_encoder, + load_vae_decoder, + load_vae_encoder, +) diff --git a/mlx_video/models/helios/rope.py b/mlx_video/models/helios/rope.py new file mode 100644 index 0000000..474e0a5 --- /dev/null +++ b/mlx_video/models/helios/rope.py @@ -0,0 +1,209 @@ +import mlx.core as mx +import numpy as np + + +def helios_rope_params( + rope_dim: tuple = (44, 42, 42), + theta: float = 10000.0, + max_seq_len: int = 1024, +) -> tuple: + """Precompute per-dimension RoPE frequency bases for Helios. 
+ + Unlike Wan which uses a single frequency table split by dimension, + Helios computes separate frequency bases for each spatial/temporal + dimension using that dimension's size as the denominator. + + Returns: + (freqs_t, freqs_h, freqs_w) each of shape [max_seq_len, d_i//2, 2] + """ + results = [] + for d in rope_dim: + base = 1.0 / np.power( + theta, np.arange(0, d, 2, dtype=np.float64) / d + ) + positions = np.arange(max_seq_len, dtype=np.float64) + freqs = positions[:, None] * base[None, :] + cos_f = np.cos(freqs).astype(np.float32) + sin_f = np.sin(freqs).astype(np.float32) + results.append(mx.array(np.stack([cos_f, sin_f], axis=-1))) + return tuple(results) + + +def _rope_compute_5d( + frame_indices: mx.array, + height: int, + width: int, + freqs: tuple, + dtype: type = mx.float32, +) -> mx.array: + """Compute RoPE frequencies as a 5D tensor [F, H, W, half_d, 2]. + + Args: + frame_indices: 1D array of temporal frame indices [F] + height: Spatial height grid size + width: Spatial width grid size + freqs: (freqs_t, freqs_h, freqs_w) from helios_rope_params() + dtype: Output dtype + + Returns: + [F, H, W, half_d, 2] where half_d = d_t + d_h + d_w + """ + freqs_t, freqs_h, freqs_w = freqs + if freqs_t.dtype != dtype: + freqs_t = freqs_t.astype(dtype) + freqs_h = freqs_h.astype(dtype) + freqs_w = freqs_w.astype(dtype) + + f = frame_indices.shape[0] + d_t, d_h, d_w = freqs_t.shape[1], freqs_h.shape[1], freqs_w.shape[1] + + ft = mx.broadcast_to( + freqs_t[frame_indices].reshape(f, 1, 1, d_t, 2), (f, height, width, d_t, 2) + ) + fh = mx.broadcast_to( + freqs_h[:height].reshape(1, height, 1, d_h, 2), (f, height, width, d_h, 2) + ) + fw = mx.broadcast_to( + freqs_w[:width].reshape(1, 1, width, d_w, 2), (f, height, width, d_w, 2) + ) + + return mx.concatenate([ft, fh, fw], axis=3) + + +def _rope_pad_and_downsample(rope_5d: mx.array, kernel: tuple) -> mx.array: + """Downsample a 5D RoPE tensor by averaging over kernel-sized blocks. 
+ + Equivalent to: pad_for_3d_conv + center_down_sample_3d (avg_pool3d) from + the reference implementation. + + Args: + rope_5d: [F, H, W, half_d, 2] + kernel: (kt, kh, kw) downsampling factors + + Returns: + [F//kt, H//kh, W//kw, half_d, 2] + """ + f, h, w, d, c = rope_5d.shape + kt, kh, kw = kernel + + # Replicate-pad to make divisible + pad_t = (kt - (f % kt)) % kt + pad_h = (kh - (h % kh)) % kh + pad_w = (kw - (w % kw)) % kw + if pad_t > 0 or pad_h > 0 or pad_w > 0: + rope_5d = mx.pad( + rope_5d, + [(0, pad_t), (0, pad_h), (0, pad_w), (0, 0), (0, 0)], + mode="edge", + ) + f, h, w = rope_5d.shape[:3] + + # Reshape and average (avg_pool3d equivalent) + rope_5d = rope_5d.reshape(f // kt, kt, h // kh, kh, w // kw, kw, d, c) + rope_5d = rope_5d.mean(axis=(1, 3, 5)) + return rope_5d + + +def _flatten_rope_5d(rope_5d: mx.array) -> tuple: + """Flatten 5D RoPE to (cos, sin) each [seq_len, 1, half_d]. + + Args: + rope_5d: [F, H, W, half_d, 2] + + Returns: + (cos_f, sin_f) each [F*H*W, 1, half_d] + """ + f, h, w, d, _ = rope_5d.shape + flat = rope_5d.reshape(f * h * w, 1, d, 2) + return flat[..., 0], flat[..., 1] + + +def helios_rope_apply( + x: mx.array, + frame_indices: mx.array, + grid_size: tuple, + freqs: tuple, + precomputed_cos_sin: tuple | None = None, +) -> mx.array: + """Apply 3-way factorized RoPE to Q or K tensor for Helios. 
+ + Args: + x: Shape [B, L, num_heads, head_dim] + frame_indices: Frame indices for this chunk, shape [F] (auto-regressive offset) + grid_size: (F, H, W) spatial/temporal grid for current chunk + freqs: (freqs_t, freqs_h, freqs_w) from helios_rope_params() + precomputed_cos_sin: Optional (cos_f, sin_f) for constant grids + """ + b, s, n, d = x.shape + half_d = d // 2 + f, h, w = grid_size + seq_len = f * h * w + + if precomputed_cos_sin is not None: + cos_f, sin_f = precomputed_cos_sin + else: + rope_5d = _rope_compute_5d(frame_indices, h, w, freqs, dtype=x.dtype) + cos_f, sin_f = _flatten_rope_5d(rope_5d) + + x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2) + x_real = x_seq[..., 0] + x_imag = x_seq[..., 1] + out_real = x_real * cos_f - x_imag * sin_f + out_imag = x_real * sin_f + x_imag * cos_f + x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape( + b, seq_len, n, d + ) + if seq_len < s: + x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1) + return x_rotated + + +def helios_rope_precompute_cos_sin( + frame_indices: mx.array, + grid_size: tuple, + freqs: tuple, + dtype: type = mx.float32, +) -> tuple: + """Precompute cos/sin for a constant grid. Call once before denoising loop. + + Args: + frame_indices: 1D array [F] of temporal frame indices + grid_size: (F, H, W) + freqs: (freqs_t, freqs_h, freqs_w) + + Returns: + (cos_f, sin_f) each [seq_len, 1, half_d] + """ + f, h, w = grid_size + rope_5d = _rope_compute_5d(frame_indices, h, w, freqs, dtype=dtype) + return _flatten_rope_5d(rope_5d) + + +def helios_rope_precompute_history( + frame_indices: mx.array, + spatial_h: int, + spatial_w: int, + freqs: tuple, + downsample_kernel: tuple | None = None, + dtype: type = mx.float32, +) -> tuple: + """Precompute cos/sin for history tokens, with optional spatial downsampling. + + This matches the reference approach: compute RoPE at the short-history + spatial resolution, then avg-pool downsample for mid/long scales. 
+ + Args: + frame_indices: 1D array [F] of temporal frame indices + spatial_h: Height at the base (short) spatial resolution + spatial_w: Width at the base (short) spatial resolution + freqs: (freqs_t, freqs_h, freqs_w) + downsample_kernel: (kt, kh, kw) for mid/long. None for short. + dtype: Output dtype + + Returns: + (cos_f, sin_f) each [seq_len, 1, half_d] + """ + rope_5d = _rope_compute_5d(frame_indices, spatial_h, spatial_w, freqs, dtype=dtype) + if downsample_kernel is not None: + rope_5d = _rope_pad_and_downsample(rope_5d, downsample_kernel) + return _flatten_rope_5d(rope_5d) diff --git a/mlx_video/models/helios/scheduler.py b/mlx_video/models/helios/scheduler.py new file mode 100644 index 0000000..2bf9bbd --- /dev/null +++ b/mlx_video/models/helios/scheduler.py @@ -0,0 +1,272 @@ +"""Helios scheduler for MLX — DMD flow-matching with 3-stage pyramid support.""" + +import math + +import mlx.core as mx +import numpy as np + + +def calculate_shift( + image_seq_len: int, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +) -> float: + """Compute dynamic shift (mu) based on spatial sequence length.""" + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + return image_seq_len * m + b + + +class HeliosScheduler: + """Flow-matching scheduler with shifted sigmas, DMD steps, and multi-stage + pyramid support. + + For the Distilled model, each pyramid stage uses DMD (Distribution Matching + Distillation) steps: predict x0 from flow, then re-noise with the original + noisy tensor for all but the last step. 
+ """ + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + stages: int = 3, + stage_range: list | None = None, + gamma: float = 1 / 3, + use_dynamic_shifting: bool = True, + base_image_seq_len: int = 256, + max_image_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, + ): + self.num_train_timesteps = num_train_timesteps + self.shift = shift + self.stages = stages + self.stage_range = stage_range or [0, 1 / 3, 2 / 3, 1] + self.gamma = gamma + self.use_dynamic_shifting = use_dynamic_shifting + self.base_image_seq_len = base_image_seq_len + self.max_image_seq_len = max_image_seq_len + self.base_shift = base_shift + self.max_shift = max_shift + + # Precompute global and per-stage schedules + self.timestep_ratios = {} + self.timesteps_per_stage = {} + self.sigmas_per_stage = {} + self.start_sigmas = {} + self.end_sigmas = {} + self.ori_start_sigmas = {} + + self._init_sigmas() + self._init_sigmas_per_stage() + + self.sigma_min = float(self.global_sigmas[-1]) + self.sigma_max = float(self.global_sigmas[0]) + + # Runtime state (set by set_timesteps) + self.sigmas = None + self.timesteps = None + self._step_index = 0 + + def _init_sigmas(self): + """Compute the global shifted sigma schedule.""" + n = self.num_train_timesteps + alphas = np.linspace(1, 1 / n, n + 1) + sigmas = 1.0 - alphas + sigmas = np.flip( + self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + )[:-1].copy() + self.global_sigmas = sigmas + self.global_timesteps = sigmas * n + + def _init_sigmas_per_stage(self): + """Compute per-stage sigma schedules with gamma correction.""" + n = self.num_train_timesteps + stage_distance = [] + + for i_s in range(self.stages): + start_idx = int(self.stage_range[i_s] * n) + start_idx = max(start_idx, 0) + end_idx = int(self.stage_range[i_s + 1] * n) + end_idx = min(end_idx, n) + + start_sigma = float(self.global_sigmas[start_idx]) + end_sigma = float(self.global_sigmas[end_idx]) if end_idx < n else 0.0 + 
self.ori_start_sigmas[i_s] = start_sigma + + if i_s != 0: + ori_sigma = 1 - start_sigma + gamma = self.gamma + corrected = ( + 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) + ) * ori_sigma + start_sigma = 1 - corrected + + stage_distance.append(start_sigma - end_sigma) + self.start_sigmas[i_s] = start_sigma + self.end_sigmas[i_s] = end_sigma + + tot_distance = sum(stage_distance) + for i_s in range(self.stages): + if i_s == 0: + start_ratio = 0.0 + else: + start_ratio = sum(stage_distance[:i_s]) / tot_distance + if i_s == self.stages - 1: + end_ratio = 1.0 - 1e-16 + else: + end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance + + self.timestep_ratios[i_s] = (start_ratio, end_ratio) + + for i_s in range(self.stages): + ratio = self.timestep_ratios[i_s] + t_max = min(float(self.global_timesteps[int(ratio[0] * n)]), 999) + t_min = float( + self.global_timesteps[min(int(ratio[1] * n), n - 1)] + ) + self.timesteps_per_stage[i_s] = np.linspace(t_max, t_min, n + 1)[:-1] + self.sigmas_per_stage[i_s] = np.linspace(0.999, 0, n + 1)[:-1] + + @staticmethod + def _time_shift(mu: float, sigma: float, t): + """Apply dynamic time shift: mu / (mu + (1/t - 1)^sigma). + + For sigma=1 (default), simplifies to: mu*t / (1 + (mu-1)*t). + """ + return mu * t / (1 + (mu - 1) * t) + + def set_timesteps( + self, + num_inference_steps: int, + stage_index: int, + image_seq_len: int | None = None, + is_amplify_first_chunk: bool = False, + ): + """Set timesteps and sigmas for a specific pyramid stage with DMD. + + For DMD, num_inference_steps is expanded (+1, or *2+1 for amplify), + then trimmed so that the final result has exactly num_inference_steps + forward passes per stage. 
+ """ + # DMD expansion + if is_amplify_first_chunk: + n_steps = num_inference_steps * 2 + 1 + else: + n_steps = num_inference_steps + 1 + + n = self.num_train_timesteps + stage_ts = self.timesteps_per_stage[stage_index] + t_max, t_min = float(stage_ts[0]), float(stage_ts[-1]) + timesteps = np.linspace(t_max, t_min, n_steps) + + stage_sigmas = self.sigmas_per_stage[stage_index] + s_max, s_min = float(stage_sigmas[0]), float(stage_sigmas[-1]) + sigmas = np.linspace(s_max, s_min, n_steps) + sigmas = np.append(sigmas, 0.0) + + # DMD trim: drop last timestep, keep [sigmas[:-2], sigmas[-1:]] + timesteps = timesteps[:-1] + sigmas = np.concatenate([sigmas[:-2], sigmas[-1:]]) + + # Dynamic shifting based on spatial resolution + if self.use_dynamic_shifting and image_seq_len is not None: + mu = calculate_shift( + image_seq_len, + self.base_image_seq_len, + self.max_image_seq_len, + self.base_shift, + self.max_shift, + ) + sigmas = self._time_shift(mu, 1.0, sigmas) + # Remap timesteps to match shifted sigmas + timesteps = ( + self.timesteps_per_stage[stage_index].min() + + sigmas[:-1] + * ( + self.timesteps_per_stage[stage_index].max() + - self.timesteps_per_stage[stage_index].min() + ) + ) + + self.timesteps = mx.array(timesteps, dtype=mx.float32) + self.sigmas = mx.array(sigmas, dtype=mx.float32) + self._step_index = 0 + + def step_dmd( + self, + model_output: mx.array, + sample: mx.array, + cur_step: int, + noisy_start: mx.array, + sigma_t: float | None = None, + sigma_next: float | None = None, + ) -> mx.array: + """DMD step: predict x0 from flow, optionally re-noise. + + Matches reference: x0 computed in float64, re-noising in float32. + Returns float32 to keep latent precision across steps (reference + returns float32 from add_noise's .type_as(noise)). 
+ + Args: + model_output: Flow prediction from model + sample: Current noisy latent (x_t) + cur_step: Current step index within this stage + noisy_start: Original noisy tensor for this stage (for re-noising) + sigma_t: Current sigma as Python float (avoids mx.array sync) + sigma_next: Next sigma as Python float (avoids mx.array sync) + + Returns: + Denoised or re-noised sample (float32) + """ + # Upcast to float32 for x0 = xt - sigma*flow (reference uses float64 + # but MLX GPU doesn't support float64; float32 is adequate here) + model_output = model_output.astype(mx.float32) + sample = sample.astype(mx.float32) + + # Use pre-extracted Python floats to avoid sync points + if sigma_t is None: + sigma_t = float(self.sigmas[cur_step]) + x0_pred = sample - sigma_t * model_output + + num_timesteps = len(self.timesteps) + if cur_step < num_timesteps - 1: + # Re-noise: blend x0_pred with original noisy tensor at next sigma + if sigma_next is None: + sigma_next = float(self.sigmas[cur_step + 1]) + noisy_start = noisy_start.astype(mx.float32) + prev_sample = (1 - sigma_next) * x0_pred + sigma_next * noisy_start + else: + prev_sample = x0_pred + + self._step_index = cur_step + 1 + return prev_sample + + def step( + self, + model_output: mx.array, + sample: mx.array, + sigma: mx.array | None = None, + sigma_next: mx.array | None = None, + ) -> mx.array: + """Euler step: x_{t-1} = x_t + (sigma_next - sigma) * v.""" + if sigma is None: + sigma = self.sigmas[self._step_index] + if sigma_next is None: + sigma_next = self.sigmas[self._step_index + 1] + + prev_sample = sample + (sigma_next - sigma) * model_output + self._step_index += 1 + return prev_sample + + def add_noise( + self, + original: mx.array, + noise: mx.array, + sigma: mx.array, + ) -> mx.array: + """Add noise at a given sigma level for flow matching.""" + return (1 - sigma) * original + sigma * noise diff --git a/mlx_video/models/helios/scripts/analyze_boundaries.py 
b/mlx_video/models/helios/scripts/analyze_boundaries.py new file mode 100644 index 0000000..5f91171 --- /dev/null +++ b/mlx_video/models/helios/scripts/analyze_boundaries.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +"""Analyze chunk boundary quality in Helios-generated videos. + +Measures brightness, contrast, color shifts, spatial distribution, and +frame-to-frame differences at chunk boundaries. Compares multiple videos +side-by-side when given multiple paths. + +This was the primary diagnostic tool used to identify and fix: +- 40% contrast drops from pixel cross-fade (→ disabled cross-fade) +- 7% contrast drops from VAE causal padding warmup (→ contrast correction) +- Per-channel color shifts at boundaries (→ per-channel matching) +- Spatial brightness redistribution (→ low-frequency spatial correction) + +Usage: + # Analyze a single video + python mlx_video/models/helios/scripts/analyze_boundaries.py /tmp/helios_output.mp4 + + # Compare multiple videos + python mlx_video/models/helios/scripts/analyze_boundaries.py \ + /tmp/helios_before.mp4 /tmp/helios_after.mp4 + + # Custom chunk size (default: 32 frames per chunk) + python mlx_video/models/helios/scripts/analyze_boundaries.py --chunk-size 33 /tmp/ref.mp4 +""" + +import argparse +import sys + +import cv2 +import numpy as np + + +def analyze_video(path, chunk_size=32): + """Analyze boundary quality metrics for a video.""" + vid = cv2.VideoCapture(path) + if not vid.isOpened(): + print(f"Error: cannot open {path}", file=sys.stderr) + return None + + frames = [] + while True: + ret, f = vid.read() + if not ret: + break + frames.append(f) + vid.release() + + n = len(frames) + if n == 0: + print(f"Error: no frames in {path}", file=sys.stderr) + return None + + # Compute per-frame statistics + means = np.zeros(n) + stds = np.zeros(n) + ch_means = np.zeros((n, 3)) + diffs = np.zeros(n - 1) + + for i, f in enumerate(frames): + gray = cv2.cvtColor(f, cv2.COLOR_BGR2GRAY).astype(np.float64) + means[i] = gray.mean() + 
stds[i] = gray.std() + ch_means[i] = [f[:, :, c].mean() for c in range(3)] + if i > 0: + prev_gray = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY).astype(np.float64) + diffs[i - 1] = np.abs(gray - prev_gray).mean() + + # Find chunk boundaries + boundaries = [] + b = chunk_size - 1 + while b < n - 1: + boundaries.append(b) + b += chunk_size + + results = { + "path": path, + "num_frames": n, + "chunk_size": chunk_size, + "boundaries": [], + } + + for b in boundaries: + if b >= n - 1: + break + + # Contrast + pre_std = stds[max(0, b - 2) : b + 1] + post_std = stds[b + 1 : min(n, b + 4)] + contrast_jump = post_std[0] - pre_std[-1] + contrast_pct = contrast_jump / max(pre_std[-1], 1e-6) * 100 + + # Brightness + bright_jump = means[b + 1] - means[b] + bright_pct = bright_jump / max(means[b], 1e-6) * 100 + + # Per-channel color shift + ch_shifts = ch_means[b + 1] - ch_means[b] # B, G, R + + # Frame diff ratio + boundary_diff = diffs[b] + window = 3 + nearby_indices = list(range(max(0, b - window), b)) + list( + range(b + 1, min(len(diffs), b + 1 + window)) + ) + nearby_avg = np.mean(diffs[nearby_indices]) if nearby_indices else 1.0 + diff_ratio = boundary_diff / max(nearby_avg, 1e-6) + + # Spatial analysis + f_pre = frames[b].astype(np.float64) + f_post = frames[b + 1].astype(np.float64) + gray_diff = cv2.cvtColor(frames[b + 1], cv2.COLOR_BGR2GRAY).astype( + np.float64 + ) - cv2.cvtColor(frames[b], cv2.COLOR_BGR2GRAY).astype(np.float64) + h, w = gray_diff.shape + ch, cw = h // 4, w // 4 + center_shift = gray_diff[ch : 3 * ch, cw : 3 * cw].mean() + periph_mask = np.ones_like(gray_diff, dtype=bool) + periph_mask[ch : 3 * ch, cw : 3 * cw] = False + periph_shift = gray_diff[periph_mask].mean() + + results["boundaries"].append( + { + "frame": b, + "contrast_pct": contrast_pct, + "bright_pct": bright_pct, + "ch_shifts_bgr": ch_shifts.tolist(), + "diff_ratio": diff_ratio, + "center_shift": center_shift, + "periph_shift": periph_shift, + "boundary_diff": boundary_diff, + 
"nearby_diff": nearby_avg, + } + ) + + # Per-chunk stats + chunk_stats = [] + for c in range(0, n, chunk_size): + end = min(c + chunk_size, n) + chunk_stats.append( + { + "frames": f"{c}-{end - 1}", + "mean_bright": means[c:end].mean(), + "mean_contrast": stds[c:end].mean(), + "first_contrast": stds[c], + "last_contrast": stds[end - 1], + } + ) + results["chunk_stats"] = chunk_stats + + return results + + +def print_results(results): + """Pretty-print analysis results.""" + print(f"\n{'=' * 70}") + print(f" {results['path']}") + print(f" {results['num_frames']} frames, chunk size = {results['chunk_size']}") + print(f"{'=' * 70}") + + for bd in results["boundaries"]: + b = bd["frame"] + print(f"\n Boundary {b}→{b + 1}:") + print(f" Contrast jump: {bd['contrast_pct']:+.1f}%") + print(f" Brightness jump: {bd['bright_pct']:+.1f}%") + print( + f" Color shift B/G/R: {bd['ch_shifts_bgr'][0]:+.1f} / " + f"{bd['ch_shifts_bgr'][1]:+.1f} / {bd['ch_shifts_bgr'][2]:+.1f}" + ) + print( + f" Frame diff: {bd['boundary_diff']:.1f} vs nearby " + f"{bd['nearby_diff']:.1f} ({bd['diff_ratio']:.1f}×)" + ) + print( + f" Spatial: center {bd['center_shift']:+.2f}, " + f"periphery {bd['periph_shift']:+.2f}" + ) + + print(f"\n Per-chunk summary:") + for cs in results["chunk_stats"]: + print( + f" Frames {cs['frames']:>7s}: brightness={cs['mean_bright']:.1f}, " + f"contrast={cs['mean_contrast']:.1f} " + f"(first={cs['first_contrast']:.1f}, last={cs['last_contrast']:.1f})" + ) + + +def print_comparison(all_results): + """Print side-by-side comparison table.""" + if len(all_results) < 2: + return + + print(f"\n{'=' * 70}") + print(" COMPARISON SUMMARY") + print(f"{'=' * 70}") + + # Header + labels = [r["path"].split("/")[-1] for r in all_results] + header = f"{'Metric':<25s}" + for label in labels: + header += f" {label:>18s}" + print(f"\n{header}") + print("-" * (25 + 20 * len(labels))) + + # For each boundary index + max_boundaries = max(len(r["boundaries"]) for r in all_results) + for bi in 
range(max_boundaries): + print(f"\n Boundary {bi + 1}:") + for metric, key, fmt in [ + ("Contrast jump", "contrast_pct", "{:+.1f}%"), + ("Brightness jump", "bright_pct", "{:+.1f}%"), + ("Frame diff ratio", "diff_ratio", "{:.1f}×"), + ("Center shift", "center_shift", "{:+.2f}"), + ("Periphery shift", "periph_shift", "{:+.2f}"), + ]: + row = f" {metric:<23s}" + for r in all_results: + if bi < len(r["boundaries"]): + val = r["boundaries"][bi][key] + row += f" {fmt.format(val):>18s}" + else: + row += f" {'N/A':>18s}" + print(row) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze chunk boundary quality in Helios videos" + ) + parser.add_argument("videos", nargs="+", help="Video file paths to analyze") + parser.add_argument( + "--chunk-size", + type=int, + default=32, + help="Frames per chunk (default: 32, use 33 for reference pipeline)", + ) + args = parser.parse_args() + + all_results = [] + for path in args.videos: + results = analyze_video(path, args.chunk_size) + if results is not None: + print_results(results) + all_results.append(results) + + if len(all_results) > 1: + print_comparison(all_results) + + +if __name__ == "__main__": + main() diff --git a/mlx_video/models/helios/scripts/compare_models.py b/mlx_video/models/helios/scripts/compare_models.py new file mode 100644 index 0000000..d8341ff --- /dev/null +++ b/mlx_video/models/helios/scripts/compare_models.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +"""Cross-framework model comparison: feed identical inputs to MLX and PyTorch models. + +Saves intermediate tensors from the MLX pipeline, loads them into the PyTorch +reference model, and compares outputs. Used to verify the MLX transformer +produces numerically equivalent flow predictions. + +Workflow: + 1. Run MLX generation with --debug to save /tmp/helios_model_inputs.npz + and /tmp/helios_mlx_output.npy + 2. 
Run this script to load inputs into PyTorch model and compare + +Requirements: + - Reference Helios weights (original PyTorch format) + - diffusers, torch + - Saved inputs from MLX debug run + +Usage: + # Step 1: Generate with debug to save inputs + python -m mlx_video.generate_helios \ + --model-dir /path/to/Helios-Distilled-MLX \ + --prompt "A beautiful sunset over the ocean" \ + --debug --num-frames 33 \ + --output-path /tmp/debug_test.mp4 + + # Step 2: Compare with PyTorch + python mlx_video/models/helios/scripts/compare_models.py \ + --model-dir /path/to/Helios-Distilled \ + --prompt "A beautiful sunset over the ocean" \ + --inputs /tmp/helios_model_inputs.npz \ + --mlx-output /tmp/helios_mlx_output.npy +""" + +import argparse +import sys + +import numpy as np +import torch + + +def main(): + parser = argparse.ArgumentParser(description="Compare MLX vs PyTorch model outputs") + parser.add_argument("--model-dir", required=True, help="Path to original Helios weights") + parser.add_argument("--prompt", required=True, help="Same prompt used for MLX debug run") + parser.add_argument("--inputs", default="/tmp/helios_model_inputs.npz", help="Saved MLX inputs") + parser.add_argument("--mlx-output", default="/tmp/helios_mlx_output.npy", help="Saved MLX output") + args = parser.parse_args() + + # Load saved MLX inputs + data = np.load(args.inputs) + print("Loaded inputs:") + for k in data.files: + print(f" {k}: shape={data[k].shape}, dtype={data[k].dtype}") + + # Load reference model + print("\nLoading reference pipeline...") + from diffusers import DiffusionPipeline + + pipe = DiffusionPipeline.from_pretrained( + args.model_dir, torch_dtype=torch.float16 + ).to("mps") + transformer = pipe.transformer + + # Convert inputs to torch tensors (MLX: [C,F,H,W] → PT: [B,C,F,H,W]) + latents_pt = torch.from_numpy(data["latents"]).unsqueeze(0).to("mps") + timestep_pt = torch.tensor([int(data["timestep"][0])], dtype=torch.int64, device="mps") + hist_short = 
torch.from_numpy(data["hist_short"]).unsqueeze(0).to("mps") + hist_mid = torch.from_numpy(data["hist_mid"]).unsqueeze(0).to("mps") + hist_long = torch.from_numpy(data["hist_long"]).unsqueeze(0).to("mps") + idx_current = torch.from_numpy(data["idx_current"]).unsqueeze(0).to("mps") + idx_short = torch.from_numpy(data["idx_short"]).unsqueeze(0).to("mps") + idx_mid = torch.from_numpy(data["idx_mid"]).unsqueeze(0).to("mps") + idx_long = torch.from_numpy(data["idx_long"]).unsqueeze(0).to("mps") + + # Encode prompt with reference text encoder + print("\nEncoding prompt with reference T5...") + prompt_embeds, _ = pipe.encode_prompt( + prompt=args.prompt, do_classifier_free_guidance=False, device="mps" + ) + print(f" prompt_embeds: {prompt_embeds.shape}") + + # Run reference forward pass + print("Running reference model...") + transformer.eval() + with torch.no_grad(): + output = transformer( + hidden_states=latents_pt.half(), + timestep=timestep_pt, + encoder_hidden_states=prompt_embeds.half(), + return_dict=False, + indices_hidden_states=idx_current, + indices_latents_history_short=idx_short, + indices_latents_history_mid=idx_mid, + indices_latents_history_long=idx_long, + latents_history_short=hist_short.half(), + latents_history_mid=hist_mid.half(), + latents_history_long=hist_long.half(), + ) + + pt_output = output[0].float().cpu().numpy().squeeze(0) + + # Load MLX output + mlx_output = np.load(args.mlx_output) + + print(f"\n{'=' * 50}") + print(f"MLX: shape={mlx_output.shape}, mean={mlx_output.mean():.6f}, std={mlx_output.std():.6f}") + print(f"PT: shape={pt_output.shape}, mean={pt_output.mean():.6f}, std={pt_output.std():.6f}") + + diff = mlx_output - pt_output + rmse = np.sqrt(np.mean(diff**2)) + mae = np.mean(np.abs(diff)) + cos_sim = np.sum(mlx_output * pt_output) / ( + np.linalg.norm(mlx_output) * np.linalg.norm(pt_output) + ) + + print(f"\nRMSE: {rmse:.6f}") + print(f"MAE: {mae:.6f}") + print(f"Cosine similarity: {cos_sim:.6f}") + print(f"Max abs diff: 
{np.abs(diff).max():.6f}") + + if cos_sim > 0.999: + print("\n✓ Models produce equivalent outputs") + elif cos_sim > 0.99: + print("\n⚠ Minor differences (likely precision-related)") + else: + print("\n✗ Significant differences — investigate!") + + # Per-channel breakdown + print(f"\nPer-channel (first 4):") + for c in range(min(4, mlx_output.shape[0])): + c_cos = np.sum(mlx_output[c] * pt_output[c]) / ( + np.linalg.norm(mlx_output[c]) * np.linalg.norm(pt_output[c]) + 1e-8 + ) + print( + f" Ch {c}: MLX mean={mlx_output[c].mean():.4f} " + f"PT mean={pt_output[c].mean():.4f} cos={c_cos:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/mlx_video/models/helios/scripts/compare_pipelines.py b/mlx_video/models/helios/scripts/compare_pipelines.py new file mode 100644 index 0000000..7e6b261 --- /dev/null +++ b/mlx_video/models/helios/scripts/compare_pipelines.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +"""Compare Helios pipeline mechanics: PyTorch reference vs MLX. + +Runs both schedulers with identical inputs and fixed dummy model outputs to +isolate pipeline logic differences (downsampling, upsampling, alpha/beta +blending, DMD stepping). No model weights needed. + +This was used to verify that the MLX scheduler produces numerically identical +results to the PyTorch reference, ruling out pipeline mechanics as a source of +output quality differences. 
+ +Requirements: + - MLX video package (this repo) + - Reference Helios repo on sys.path (--helios-dir) + - PyTorch + diffusers + +Usage: + python mlx_video/models/helios/scripts/compare_pipelines.py \ + --helios-dir /path/to/Helios + + # Custom parameters + python mlx_video/models/helios/scripts/compare_pipelines.py \ + --helios-dir /path/to/Helios \ + --seed 123 --stages 3 --steps 2 2 2 +""" + +import argparse +import math +import sys + +import numpy as np + + +def calculate_shift( + image_seq_len, + base_seq_len=256, + max_seq_len=4096, + base_shift=0.5, + max_shift=1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + return image_seq_len * m + b + + +def main(): + parser = argparse.ArgumentParser(description="Compare pipeline mechanics") + parser.add_argument("--helios-dir", required=True, help="Path to reference Helios repo") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--stages", type=int, default=3) + parser.add_argument("--steps", type=int, nargs="+", default=[2, 2, 2]) + args = parser.parse_args() + + sys.path.insert(0, args.helios_dir) + sys.path.insert(0, ".") + + import torch + import torch.nn.functional as F + import mlx.core as mx + + from helios.diffusers_version.scheduling_helios_diffusers import ( + HeliosScheduler as PTScheduler, + ) + from mlx_video.generate_helios import ( + _bilinear_downsample_2d, + _nearest_upsample_2d, + _spatial_reshape, + _spatial_unreshape, + ) + from mlx_video.models.helios.scheduler import HeliosScheduler as MLXScheduler + + C, NL, H, W = 16, 9, 48, 80 + PATCH_SIZE = (1, 2, 2) + GAMMA = 1 / 3 + STAGES = args.stages + PYRAMID_STEPS = args.steps + + # Create identical initial noise from numpy + rng = np.random.RandomState(args.seed) + noise_np = rng.randn(C, NL, H, W).astype(np.float32) + mx_latents = mx.array(noise_np) + pt_latents = torch.from_numpy(noise_np).unsqueeze(0) + + print(f"Initial noise: mean={noise_np.mean():.6f} 
std={noise_np.std():.6f}") + + # Downsample — MLX + cur_h, cur_w = H, W + mx_flat = _spatial_reshape(mx_latents, NL, C) + for _ in range(STAGES - 1): + cur_h //= 2 + cur_w //= 2 + mx_flat = _bilinear_downsample_2d(mx_flat, cur_h, cur_w) * 2 + mx_latents = _spatial_unreshape(mx_flat, NL, C, cur_h, cur_w) + mx.eval(mx_latents) + + # Downsample — PyTorch + cur_h_pt, cur_w_pt = H, W + pt_flat = pt_latents.permute(0, 2, 1, 3, 4).reshape(NL, C, H, W) + for _ in range(STAGES - 1): + cur_h_pt //= 2 + cur_w_pt //= 2 + pt_flat = F.interpolate(pt_flat, size=(cur_h_pt, cur_w_pt), mode="bilinear") * 2 + pt_latents = pt_flat.reshape(1, NL, C, cur_h_pt, cur_w_pt).permute(0, 2, 1, 3, 4) + + mx_np = np.array(mx_latents) + pt_np = pt_latents.squeeze(0).numpy() + diff = np.abs(mx_np - pt_np) + print(f"After downsample to {cur_h}×{cur_w}: diff max={diff.max():.8f} mean={diff.mean():.8f}") + + # Initialize schedulers + mlx_sched = MLXScheduler( + num_train_timesteps=1000, shift=1.0, stages=STAGES, + gamma=GAMMA, use_dynamic_shifting=True, + ) + pt_sched = PTScheduler( + num_train_timesteps=1000, shift=1.0, stages=STAGES, + gamma=GAMMA, use_dynamic_shifting=True, scheduler_type="dmd", + ) + + mx_start_points = [mx_latents] + pt_start_points = [pt_latents.clone()] + max_diff = 0.0 + + for i_s in range(STAGES): + seq_len = (NL * cur_h * cur_w) // math.prod(PATCH_SIZE) + mu = calculate_shift(seq_len) + + mlx_sched.set_timesteps(PYRAMID_STEPS[i_s], stage_index=i_s, image_seq_len=seq_len) + pt_sched.set_timesteps(PYRAMID_STEPS[i_s], i_s, device="cpu", mu=mu) + + print(f"\nStage {i_s}: {cur_h}×{cur_w}, seq_len={seq_len}") + print(f" MLX sigmas: {mlx_sched.sigmas.tolist()}") + print(f" PT sigmas: {pt_sched.sigmas.tolist()}") + + if i_s > 0: + cur_h *= 2 + cur_w *= 2 + cur_h_pt *= 2 + cur_w_pt *= 2 + + # Upsample + mx_flat = _spatial_reshape(mx_latents, NL, C) + mx_flat = _nearest_upsample_2d(mx_flat, cur_h, cur_w) + mx_latents = _spatial_unreshape(mx_flat, NL, C, cur_h, cur_w) + + pt_flat = 
pt_latents.permute(0, 2, 1, 3, 4).reshape(NL, C, cur_h_pt // 2, cur_w_pt // 2) + pt_flat = F.interpolate(pt_flat, size=(cur_h_pt, cur_w_pt), mode="nearest") + pt_latents = pt_flat.reshape(1, NL, C, cur_h_pt, cur_w_pt).permute(0, 2, 1, 3, 4) + + # Alpha/beta blending with same noise + ori_sigma = 1 - mlx_sched.ori_start_sigmas[i_s] + alpha = 1 / (math.sqrt(1 + (1 / GAMMA)) * (1 - ori_sigma) + ori_sigma) + beta = alpha * (1 - ori_sigma) / math.sqrt(GAMMA) + + noise_np2 = np.random.RandomState(args.seed + i_s * 1000).randn( + C, NL, cur_h, cur_w + ).astype(np.float32) + mx_latents = alpha * mx_latents + beta * mx.array(noise_np2) + pt_latents = alpha * pt_latents + beta * torch.from_numpy(noise_np2).unsqueeze(0) + + mx.eval(mx_latents) + mx_start_points.append(mx_latents) + pt_start_points.append(pt_latents.clone()) + + d = np.abs(np.array(mx_latents) - pt_latents.squeeze(0).numpy()) + print(f" After upsample+mix: diff max={d.max():.8f} mean={d.mean():.8f}") + + # DMD steps with fixed model output + for idx in range(len(mlx_sched.timesteps)): + t_pt = pt_sched.timesteps[idx] + + mx_pred = mx.full(mx_latents.shape, 0.05) + pt_pred = torch.full(pt_latents.shape, 0.05) + + mx_latents = mlx_sched.step_dmd(mx_pred, mx_latents, idx, mx_start_points[i_s]) + mx.eval(mx_latents) + + pt_latents = pt_sched.step( + pt_pred, t_pt, pt_latents, return_dict=False, + cur_sampling_step=idx, + dmd_noisy_tensor=pt_start_points[i_s], + dmd_sigmas=pt_sched.sigmas, + dmd_timesteps=pt_sched.timesteps, + all_timesteps=pt_sched.timesteps, + )[0] + + d = np.abs(np.array(mx_latents) - pt_latents.squeeze(0).numpy()) + max_diff = max(max_diff, d.max()) + print( + f" Step {idx}: MLX mean={mx_latents.mean().item():.6f} " + f"PT mean={pt_latents.mean():.6f} diff max={d.max():.8f}" + ) + + print(f"\n{'=' * 50}") + print(f"Maximum difference across all stages/steps: {max_diff:.8f}") + if max_diff < 1e-4: + print("✓ Pipelines are numerically equivalent") + elif max_diff < 1e-2: + print("⚠ Small 
differences (likely floating point precision)") + else: + print("✗ Significant differences detected — investigate!") + + +if __name__ == "__main__": + main() diff --git a/mlx_video/models/helios/scripts/run_reference.py b/mlx_video/models/helios/scripts/run_reference.py new file mode 100644 index 0000000..389a6e4 --- /dev/null +++ b/mlx_video/models/helios/scripts/run_reference.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +"""Run the reference Helios pipeline on MPS for comparison. + +Generates a video using the original PyTorch/diffusers Helios pipeline on Apple +MPS, with necessary float64→float32 patches for MPS compatibility. Useful for +comparing output quality against the MLX implementation. + +Requirements: + pip install diffusers transformers torch accelerate + +Usage: + python mlx_video/models/helios/scripts/run_reference.py \ + --model-dir /path/to/Helios-Distilled \ + --prompt "A golden retriever running on a sunny beach" \ + --output /tmp/helios_ref.mp4 + + # Compare against MLX output + python -m mlx_video.generate_helios \ + --model-dir /path/to/Helios-Distilled-MLX \ + --prompt "A golden retriever running on a sunny beach" \ + --output-path /tmp/helios_mlx.mp4 + python mlx_video/models/helios/scripts/analyze_boundaries.py \ + /tmp/helios_ref.mp4 /tmp/helios_mlx.mp4 +""" + +import argparse + +import cv2 +import numpy as np +import torch + + +def patch_scheduler_for_mps(): + """Patch the Helios DMD scheduler to work on MPS (no float64 support).""" + import diffusers.schedulers.scheduling_helios_dmd as sched_mod + + _orig_set_ts = sched_mod.HeliosDMDScheduler.set_timesteps + + def _patched_set_ts( + self, + num_inference_steps, + stage_index=None, + device=None, + sigmas=None, + mu=None, + is_amplify_first_chunk=False, + ): + real_device = device + _orig_set_ts( + self, + num_inference_steps, + stage_index=stage_index, + device="cpu", + sigmas=sigmas, + mu=mu, + is_amplify_first_chunk=is_amplify_first_chunk, + ) + self.timesteps = self.timesteps.float() 
+ self.sigmas = self.sigmas.float() + if real_device is not None and str(real_device) != "cpu": + self.timesteps = self.timesteps.to(real_device) + self.sigmas = self.sigmas.to(real_device) + + sched_mod.HeliosDMDScheduler.set_timesteps = _patched_set_ts + + def _patched_convert_flow(self, flow_pred, xt, timestep, sigmas, timesteps): + original_dtype = flow_pred.dtype + device = flow_pred.device + flow_pred, xt, sigmas, timesteps = ( + x.float().to(device) for x in (flow_pred, xt, sigmas, timesteps) + ) + timestep_id = torch.argmin( + (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 + ) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) + x0_pred = xt - sigma_t * flow_pred + return x0_pred.to(original_dtype) + + sched_mod.HeliosDMDScheduler.convert_flow_pred_to_x0 = _patched_convert_flow + + +def main(): + parser = argparse.ArgumentParser(description="Run Helios reference pipeline on MPS") + parser.add_argument("--model-dir", required=True, help="Path to Helios-Distilled weights") + parser.add_argument("--prompt", required=True, help="Text prompt") + parser.add_argument("--output", default="/tmp/helios_ref.mp4", help="Output video path") + parser.add_argument("--height", type=int, default=384) + parser.add_argument("--width", type=int, default=640) + parser.add_argument("--num-frames", type=int, default=99, help="Total frames (33 per chunk)") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--fps", type=int, default=16) + args = parser.parse_args() + + print("Patching scheduler for MPS compatibility...") + patch_scheduler_for_mps() + + print("Loading pipeline...") + from diffusers import DiffusionPipeline + + pipe = DiffusionPipeline.from_pretrained( + args.model_dir, + torch_dtype=torch.float16, + ).to("mps") + + generator = torch.Generator("mps").manual_seed(args.seed) + + print(f"Generating {args.num_frames} frames...") + video = pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + 
num_frames=args.num_frames, + guidance_scale=1.0, + generator=generator, + pyramid_num_inference_steps_list=[2, 2, 2], + is_amplify_first_chunk=True, + ).frames + + frames = video[0] + print(f"Got {len(frames)} frames, size: {frames[0].size}") + + out = cv2.VideoWriter( + args.output, cv2.VideoWriter_fourcc(*"mp4v"), args.fps, frames[0].size + ) + for f in frames: + out.write(cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR)) + out.release() + print(f"Saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/mlx_video/models/helios/transformer.py b/mlx_video/models/helios/transformer.py new file mode 100644 index 0000000..62d6de5 --- /dev/null +++ b/mlx_video/models/helios/transformer.py @@ -0,0 +1,526 @@ +"""Helios transformer backbone for MLX. + +Implements the Helios 14B DiT with multi-scale history memory, +restricted self-attention, and 6-vector modulation per block. +""" + +import math + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from .attention import ( + HeliosCrossAttention, + HeliosLayerNorm, + HeliosRMSNorm, + HeliosSelfAttention, + _linear_dtype, +) +from .config import HeliosModelConfig +from .rope import helios_rope_params, helios_rope_precompute_cos_sin, helios_rope_precompute_history + + +class HeliosFFN(nn.Module): + """Gated feed-forward network with GELU(tanh) activation.""" + + def __init__(self, dim: int, ffn_dim: int): + super().__init__() + self.fc1 = nn.Linear(dim, ffn_dim) + self.act = nn.GELU(approx="tanh") + self.fc2 = nn.Linear(ffn_dim, dim) + + def __call__(self, x: mx.array) -> mx.array: + x_w = x.astype(_linear_dtype(self.fc1)) + return self.fc2(self.act(self.fc1(x_w))) + + +class HeliosTransformerBlock(nn.Module): + """Helios transformer block: self-attn + cross-attn + FFN with 6-vector modulation.""" + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: bool = True, + cross_attn_norm: bool = True, + eps: float = 1e-6, + restrict_self_attn: bool = False, + ): + 
super().__init__() + + # Self-attention + self.norm1 = HeliosLayerNorm(dim, eps) + self.self_attn = HeliosSelfAttention( + dim, num_heads, qk_norm, eps, + restrict_self_attn=restrict_self_attn, + ) + + # Cross-attention + self.cross_attn = HeliosCrossAttention(dim, num_heads, qk_norm, eps) + self.norm2 = ( + HeliosLayerNorm(dim, eps, elementwise_affine=True) + if cross_attn_norm + else None + ) + + # Feed-forward + self.ffn = HeliosFFN(dim, ffn_dim) + self.norm3 = HeliosLayerNorm(dim, eps) + + # 6-vector modulation table (scale/shift/gate for self-attn and FFN) + self.scale_shift_table = ( + mx.random.normal((1, 6, dim)) * (dim**-0.5) + ).astype(mx.float32) + + # Whether to separate history from current in cross-attention + self.guidance_cross_attn = True + + def __call__( + self, + x: mx.array, + encoder_hidden_states: mx.array, + timestep_proj: mx.array, + rotary_emb: tuple | None, + original_context_length: int, + frame_indices: mx.array | None = None, + grid_size: tuple | None = None, + freqs: tuple | None = None, + cross_kv_cache: tuple | None = None, + ) -> mx.array: + w_dtype = _linear_dtype(self.self_attn.q) + history_seq_len = x.shape[1] - original_context_length + + # Compute 6-vector modulation + # timestep_proj: [B, L, 6, dim] (per-token) or [B, 6, dim] (global) + if timestep_proj.ndim == 4: + # [B, L, 6, dim] + [1, 1, 6, dim] → [B, L, 6, dim] + mod = (self.scale_shift_table[None, :, :] + timestep_proj.astype(mx.float32)).astype(w_dtype) + shift_msa = mod[:, :, 0] + scale_msa = mod[:, :, 1] + gate_msa = mod[:, :, 2] + c_shift = mod[:, :, 3] + c_scale = mod[:, :, 4] + c_gate = mod[:, :, 5] + else: + # [B, 6, dim] + [1, 6, dim] → [B, 6, dim] + mod = (self.scale_shift_table + timestep_proj.astype(mx.float32)).astype(w_dtype) + shift_msa = mod[:, 0:1] + scale_msa = mod[:, 1:2] + gate_msa = mod[:, 2:3] + c_shift = mod[:, 3:4] + c_scale = mod[:, 4:5] + c_gate = mod[:, 5:6] + + # 1. 
Self-attention + norm_x = (self.norm1(x) * (1 + scale_msa) + shift_msa).astype(w_dtype) + attn_out = self.self_attn( + norm_x, + frame_indices=frame_indices, + grid_size=grid_size, + freqs=freqs, + rope_cos_sin=rotary_emb, + original_context_length=original_context_length, + ) + # Residual in float32 to match reference (prevents systematic truncation + # of small updates in bfloat16, which compounds across chunks via history) + x = (x.astype(mx.float32) + attn_out * gate_msa).astype(w_dtype) + + # 2. Cross-attention (history tokens skip cross-attention) + if self.guidance_cross_attn and history_seq_len > 0: + history_x, current_x = x[:, :history_seq_len], x[:, history_seq_len:] + norm_current = self.norm2(current_x) if self.norm2 is not None else current_x + cross_out = self.cross_attn(norm_current, encoder_hidden_states, kv_cache=cross_kv_cache) + current_x = (current_x.astype(mx.float32) + cross_out).astype(w_dtype) + x = mx.concatenate([history_x, current_x], axis=1) + else: + norm_x = self.norm2(x) if self.norm2 is not None else x + cross_out = self.cross_attn(norm_x, encoder_hidden_states, kv_cache=cross_kv_cache) + x = (x.astype(mx.float32) + cross_out).astype(w_dtype) + + # 3. 
Feed-forward + norm_x = (self.norm3(x) * (1 + c_scale) + c_shift).astype(w_dtype) + ff_out = self.ffn(norm_x) + x = (x.astype(mx.float32) + ff_out.astype(mx.float32) * c_gate).astype(w_dtype) + + return x + + +class HeliosModel(nn.Module): + """Helios 14B diffusion backbone with multi-scale history memory.""" + + def __init__(self, config: HeliosModelConfig): + super().__init__() + self.config = config + dim = config.dim + self.dim = dim + self.num_heads = config.num_heads + self.out_dim = config.out_dim + self.patch_size = config.patch_size + self.text_len = config.text_len + self.freq_dim = config.freq_dim + + # Patch embedding (Conv3d as reshaped Linear) + patch_dim = config.in_dim * math.prod(config.patch_size) + self.patch_embedding = nn.Linear(patch_dim, dim) + self._patch_size = config.patch_size + + # Multi-scale history patches (short/mid/long Conv3d as Linear) + if config.has_multi_term_memory_patch: + self.patch_short = nn.Linear(config.in_dim * 1 * 2 * 2, dim) + self.patch_mid = nn.Linear(config.in_dim * 2 * 4 * 4, dim) + self.patch_long = nn.Linear(config.in_dim * 4 * 8 * 8, dim) + self.has_multi_term_memory_patch = config.has_multi_term_memory_patch + + # Text embedding (PixArtAlpha-style projection) + self.text_embedding_0 = nn.Linear(config.text_dim, dim) + self.text_embedding_act = nn.GELU(approx="tanh") + self.text_embedding_1 = nn.Linear(dim, dim) + + # Time embedding (sinusoidal → MLP) + self.time_embedding_0 = nn.Linear(config.freq_dim, dim) + self.time_embedding_act = nn.SiLU() + self.time_embedding_1 = nn.Linear(dim, dim) + + # Time projection for modulation (6x dim for scale/shift/gate) + self.time_projection_act = nn.SiLU() + self.time_projection = nn.Linear(dim, dim * 6) + + # Transformer blocks + self.blocks = [ + HeliosTransformerBlock( + dim=dim, + ffn_dim=config.ffn_dim, + num_heads=config.num_heads, + qk_norm=config.qk_norm, + cross_attn_norm=config.cross_attn_norm, + eps=config.eps, + restrict_self_attn=False, + ) + for _ in 
range(config.num_layers) + ] + + # Output norm and projection + self.output_norm = HeliosLayerNorm(dim, config.eps) + self.output_norm_table = ( + mx.random.normal((1, 2, dim)) * (dim**-0.5) + ).astype(mx.float32) + proj_dim = math.prod(config.patch_size) * config.out_dim + self.proj_out = nn.Linear(dim, proj_dim) + + # RoPE frequencies + self.rope_freqs = helios_rope_params( + rope_dim=config.rope_dim, + theta=config.rope_theta, + max_seq_len=1024, + ) + + # Whether to zero out history timestep embedding + self.zero_history_timestep = config.zero_history_timestep + + # Precompute sinusoidal inv_freq for time embedding + half = config.freq_dim // 2 + self._inv_freq = mx.power( + 10000.0, -mx.arange(half).astype(mx.float32) / half + ) + + # Will be computed lazily after weights are loaded + self._t0_proj: mx.array | None = None + + def _get_t0_projection(self) -> mx.array: + """Get the cached t=0 timestep projection, computing on first use. + + Must be called after weights are loaded (not in __init__) so that + the time_embedding and time_projection layers use trained weights. + """ + if self._t0_proj is None: + t0_emb = mx.array([0.0]) * self._inv_freq + t0_emb = mx.concatenate([mx.cos(t0_emb), mx.sin(t0_emb)], axis=-1)[None, :] + temb_t0 = self.time_embedding_1( + self.time_embedding_act(self.time_embedding_0(t0_emb)) + ) + tp_t0 = self.time_projection( + self.time_projection_act(temb_t0) + ).reshape(1, 6, -1) + mx.eval(tp_t0) + self._t0_proj = tp_t0 + return self._t0_proj + + def _patchify(self, x: mx.array) -> tuple: + """Convert video latent to patch embeddings. 
+ + Args: + x: Video latent [C, F, H, W] + + Returns: + (patches, grid_size): patches [1, L, dim], grid_size (F', H', W') + """ + c, f, h, w = x.shape + pt, ph, pw = self._patch_size + + # Truncate to patch-aligned dims (matches Conv3d floor-division) + f = (f // pt) * pt + h = (h // ph) * ph + w = (w // pw) * pw + x = x[:, :f, :h, :w] + + f_out = f // pt + h_out = h // ph + w_out = w // pw + + x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw) + x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw] + x = x.reshape(f_out * h_out * w_out, -1) + + patches = self.patch_embedding(x) + patches = patches.astype(_linear_dtype(self.patch_embedding)) + return patches[None, :, :], (f_out, h_out, w_out) + + def _patchify_history(self, x: mx.array, scale: str) -> mx.array: + """Patchify history latents at different scales. + + Args: + x: History latent [C, F, H, W] + scale: 'short' (1,2,2), 'mid' (2,4,4), or 'long' (4,8,8) + """ + c, f, h, w = x.shape + if scale == "short": + kernel = (1, 2, 2) + proj = self.patch_short + elif scale == "mid": + kernel = (2, 4, 4) + proj = self.patch_mid + else: + kernel = (4, 8, 8) + proj = self.patch_long + + kt, kh, kw = kernel + + # Pad to make divisible by kernel size + pad_t = (kt - (f % kt)) % kt + pad_h = (kh - (h % kh)) % kh + pad_w = (kw - (w % kw)) % kw + if pad_t > 0 or pad_h > 0 or pad_w > 0: + x = mx.pad( + x, + [(0, 0), (0, pad_t), (0, pad_h), (0, pad_w)], + mode="edge", + ) + c, f, h, w = x.shape + + f_out = f // kt + h_out = h // kh + w_out = w // kw + + x = x.reshape(c, f_out, kt, h_out, kh, w_out, kw) + x = x.transpose(1, 3, 5, 0, 2, 4, 6) + x = x.reshape(f_out * h_out * w_out, -1) + + patches = proj(x) + patches = patches.astype(_linear_dtype(proj)) + return patches[None, :, :] # [1, L, dim] + + def embed_text(self, context: list) -> mx.array: + """Precompute text embeddings.""" + model_dtype = _linear_dtype(self.patch_embedding) + context_padded = [] + for ctx in context: + pad_len = self.text_len - ctx.shape[0] 
+ if pad_len > 0: + ctx = mx.concatenate( + [ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)], + axis=0, + ) + context_padded.append(ctx) + context_batch = mx.stack(context_padded) + context_batch = self.text_embedding_1( + self.text_embedding_act(self.text_embedding_0(context_batch)) + ) + return context_batch.astype(model_dtype) + + def prepare_cross_kv(self, context: mx.array) -> list: + """Pre-compute cross-attention K/V caches.""" + return [block.cross_attn.prepare_kv(context) for block in self.blocks] + + def unpatchify(self, x: mx.array, grid_size: tuple) -> mx.array: + """Reconstruct video from patch embeddings. + + Args: + x: [B, L, out_dim * prod(patch_size)] + grid_size: (F', H', W') + + Returns: + [C, F, H, W] + """ + c = self.out_dim + pt, ph, pw = self.patch_size + f, h, w = grid_size + seq_len = f * h * w + + u = x[0, :seq_len] + u = u.reshape(f, h, w, pt, ph, pw, c) + u = u.transpose(6, 0, 3, 1, 4, 2, 5) + return u.reshape(c, f * pt, h * ph, w * pw) + + def __call__( + self, + latents: mx.array, + timestep: mx.array, + encoder_hidden_states: mx.array, + frame_indices: mx.array | None = None, + history_short: mx.array | None = None, + history_mid: mx.array | None = None, + history_long: mx.array | None = None, + history_short_indices: mx.array | None = None, + history_mid_indices: mx.array | None = None, + history_long_indices: mx.array | None = None, + cross_kv_caches: list | None = None, + ) -> mx.array: + """Forward pass through the Helios transformer. + + Args: + latents: Current chunk latent [C, F, H, W] + timestep: Scalar diffusion timestep + encoder_hidden_states: Text embeddings [B, text_len, dim] + frame_indices: Frame indices for current chunk [F'] + history_short/mid/long: History latents at different scales + history_*_indices: Frame indices for history at each scale + cross_kv_caches: Pre-computed cross-attention K/V caches + + Returns: + Predicted velocity/noise [C, F, H, W] + """ + # 1. 
Patchify current latents + hidden_states, grid_size = self._patchify(latents) + f_out, h_out, w_out = grid_size + current_seq_len = hidden_states.shape[1] + + if frame_indices is None: + frame_indices = mx.arange(f_out) + + # 2. Compute RoPE for current chunk + rope_cos_sin = helios_rope_precompute_cos_sin( + frame_indices, grid_size, self.rope_freqs, + dtype=hidden_states.dtype, + ) + + # 3. Process multi-scale history and prepend + history_seq_len = 0 + if history_short is not None and self.has_multi_term_memory_patch: + hist_s = self._patchify_history(history_short, "short") + hist_m = self._patchify_history(history_mid, "mid") + hist_l = self._patchify_history(history_long, "long") + + sh, sm, sl = hist_s.shape[1], hist_m.shape[1], hist_l.shape[1] + + # Short patch output spatial dims (kernel 1,2,2) + c_s, f_s, h_s, w_s = history_short.shape + hs_h = h_s // 2 + hs_w = w_s // 2 + hs_f = f_s # temporal stride 1 + + # RoPE for short history: compute at short output resolution + rope_hist_s = helios_rope_precompute_history( + history_short_indices, hs_h, hs_w, self.rope_freqs, + downsample_kernel=None, + dtype=hidden_states.dtype, + ) + # RoPE for mid history: compute at short resolution, downsample (2,2,2) + rope_hist_m = helios_rope_precompute_history( + history_mid_indices, hs_h, hs_w, self.rope_freqs, + downsample_kernel=(2, 2, 2), + dtype=hidden_states.dtype, + ) + # RoPE for long history: compute at short resolution, downsample (4,4,4) + rope_hist_l = helios_rope_precompute_history( + history_long_indices, hs_h, hs_w, self.rope_freqs, + downsample_kernel=(4, 4, 4), + dtype=hidden_states.dtype, + ) + + # Concatenate history: [long, mid, short, current] + hidden_states = mx.concatenate( + [hist_l, hist_m, hist_s, hidden_states], axis=1 + ) + history_seq_len = sl + sm + sh + + # Concatenate RoPE: match history ordering + all_cos = mx.concatenate([ + rope_hist_l[0], rope_hist_m[0], rope_hist_s[0], rope_cos_sin[0] + ], axis=0) + all_sin = mx.concatenate([ + 
rope_hist_l[1], rope_hist_m[1], rope_hist_s[1], rope_cos_sin[1] + ], axis=0) + rope_cos_sin = (all_cos, all_sin) + + original_context_length = current_seq_len + + # 4. Time embedding + t_emb = timestep.astype(mx.float32) * self._inv_freq + t_emb = mx.concatenate([mx.cos(t_emb), mx.sin(t_emb)], axis=-1) + if t_emb.ndim == 1: + t_emb = t_emb[None, :] + + temb = self.time_embedding_1( + self.time_embedding_act(self.time_embedding_0(t_emb)) + ) + timestep_proj = self.time_projection( + self.time_projection_act(temb) + ) + timestep_proj = timestep_proj.reshape(1, 6, -1) + + # Expand to per-token: [B, 6, L, dim] + timestep_proj_expanded = mx.broadcast_to( + timestep_proj[:, :, None, :], + (1, 6, original_context_length, self.dim), + ) + + # Zero history timestep embedding (use precomputed t=0 projection) + if self.zero_history_timestep and history_seq_len > 0: + tp_t0_expanded = mx.broadcast_to( + self._get_t0_projection()[:, :, None, :], + (1, 6, history_seq_len, self.dim), + ) + timestep_proj_expanded = mx.concatenate( + [tp_t0_expanded, timestep_proj_expanded], axis=2 + ) + + # Permute to [B, L, 6, dim] for block consumption + timestep_proj_expanded = timestep_proj_expanded.transpose(0, 2, 1, 3) + + # 5. Transformer blocks + for i, block in enumerate(self.blocks): + kv_cache = cross_kv_caches[i] if cross_kv_caches is not None else None + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj_expanded, + rotary_emb=rope_cos_sin, + original_context_length=original_context_length, + frame_indices=frame_indices, + grid_size=grid_size, + freqs=self.rope_freqs, + cross_kv_cache=kv_cache, + ) + + # 6. 
Output norm, projection & unpatchify (only current tokens) + hidden_out = hidden_states[:, -original_context_length:] + + # Output modulation: temb is [B, 1, dim], expand to [B, L, dim] + w_dtype = _linear_dtype(self.proj_out) + temb_expanded = mx.broadcast_to( + temb[:, None, :], (1, original_context_length, self.dim) + ) + # scale_shift_table: [1, 2, dim] → [1, 1, 2, dim] + # temb_expanded: [B, L, dim] → [B, L, 1, dim] + mod_out = (self.output_norm_table[None, :, :] + temb_expanded[:, :, None, :]).astype(w_dtype) + shift = mod_out[:, :, 0, :] # [B, L, dim] + scale = mod_out[:, :, 1, :] + + hidden_out = (self.output_norm(hidden_out) * (1 + scale) + shift).astype(w_dtype) + hidden_out = self.proj_out(hidden_out) + + # Unpatchify + output = self.unpatchify(hidden_out, grid_size) + return output diff --git a/pyproject.toml b/pyproject.toml index 198956d..19883ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ Issues = "https://github.com/Blaizzy/mlx-video/issues" [project.scripts] "mlx_video.generate" = "mlx_video.generate:main" "mlx_video.generate_wan" = "mlx_video.generate_wan:main" +"mlx_video.generate_helios" = "mlx_video.generate_helios:main" [tool.setuptools.packages.find] include = ["mlx_video*"] diff --git a/scripts/helios/analyze_boundaries.py b/scripts/helios/analyze_boundaries.py new file mode 100644 index 0000000..5f91171 --- /dev/null +++ b/scripts/helios/analyze_boundaries.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +"""Analyze chunk boundary quality in Helios-generated videos. + +Measures brightness, contrast, color shifts, spatial distribution, and +frame-to-frame differences at chunk boundaries. Compares multiple videos +side-by-side when given multiple paths. 
+ +This was the primary diagnostic tool used to identify and fix: +- 40% contrast drops from pixel cross-fade (→ disabled cross-fade) +- 7% contrast drops from VAE causal padding warmup (→ contrast correction) +- Per-channel color shifts at boundaries (→ per-channel matching) +- Spatial brightness redistribution (→ low-frequency spatial correction) + +Usage: + # Analyze a single video + python mlx_video/models/helios/scripts/analyze_boundaries.py /tmp/helios_output.mp4 + + # Compare multiple videos + python mlx_video/models/helios/scripts/analyze_boundaries.py \ + /tmp/helios_before.mp4 /tmp/helios_after.mp4 + + # Custom chunk size (default: 32 frames per chunk) + python mlx_video/models/helios/scripts/analyze_boundaries.py --chunk-size 33 /tmp/ref.mp4 +""" + +import argparse +import sys + +import cv2 +import numpy as np + + +def analyze_video(path, chunk_size=32): + """Analyze boundary quality metrics for a video.""" + vid = cv2.VideoCapture(path) + if not vid.isOpened(): + print(f"Error: cannot open {path}", file=sys.stderr) + return None + + frames = [] + while True: + ret, f = vid.read() + if not ret: + break + frames.append(f) + vid.release() + + n = len(frames) + if n == 0: + print(f"Error: no frames in {path}", file=sys.stderr) + return None + + # Compute per-frame statistics + means = np.zeros(n) + stds = np.zeros(n) + ch_means = np.zeros((n, 3)) + diffs = np.zeros(n - 1) + + for i, f in enumerate(frames): + gray = cv2.cvtColor(f, cv2.COLOR_BGR2GRAY).astype(np.float64) + means[i] = gray.mean() + stds[i] = gray.std() + ch_means[i] = [f[:, :, c].mean() for c in range(3)] + if i > 0: + prev_gray = cv2.cvtColor(frames[i - 1], cv2.COLOR_BGR2GRAY).astype(np.float64) + diffs[i - 1] = np.abs(gray - prev_gray).mean() + + # Find chunk boundaries + boundaries = [] + b = chunk_size - 1 + while b < n - 1: + boundaries.append(b) + b += chunk_size + + results = { + "path": path, + "num_frames": n, + "chunk_size": chunk_size, + "boundaries": [], + } + + for b in 
boundaries: + if b >= n - 1: + break + + # Contrast + pre_std = stds[max(0, b - 2) : b + 1] + post_std = stds[b + 1 : min(n, b + 4)] + contrast_jump = post_std[0] - pre_std[-1] + contrast_pct = contrast_jump / max(pre_std[-1], 1e-6) * 100 + + # Brightness + bright_jump = means[b + 1] - means[b] + bright_pct = bright_jump / max(means[b], 1e-6) * 100 + + # Per-channel color shift + ch_shifts = ch_means[b + 1] - ch_means[b] # B, G, R + + # Frame diff ratio + boundary_diff = diffs[b] + window = 3 + nearby_indices = list(range(max(0, b - window), b)) + list( + range(b + 1, min(len(diffs), b + 1 + window)) + ) + nearby_avg = np.mean(diffs[nearby_indices]) if nearby_indices else 1.0 + diff_ratio = boundary_diff / max(nearby_avg, 1e-6) + + # Spatial analysis + f_pre = frames[b].astype(np.float64) + f_post = frames[b + 1].astype(np.float64) + gray_diff = cv2.cvtColor(frames[b + 1], cv2.COLOR_BGR2GRAY).astype( + np.float64 + ) - cv2.cvtColor(frames[b], cv2.COLOR_BGR2GRAY).astype(np.float64) + h, w = gray_diff.shape + ch, cw = h // 4, w // 4 + center_shift = gray_diff[ch : 3 * ch, cw : 3 * cw].mean() + periph_mask = np.ones_like(gray_diff, dtype=bool) + periph_mask[ch : 3 * ch, cw : 3 * cw] = False + periph_shift = gray_diff[periph_mask].mean() + + results["boundaries"].append( + { + "frame": b, + "contrast_pct": contrast_pct, + "bright_pct": bright_pct, + "ch_shifts_bgr": ch_shifts.tolist(), + "diff_ratio": diff_ratio, + "center_shift": center_shift, + "periph_shift": periph_shift, + "boundary_diff": boundary_diff, + "nearby_diff": nearby_avg, + } + ) + + # Per-chunk stats + chunk_stats = [] + for c in range(0, n, chunk_size): + end = min(c + chunk_size, n) + chunk_stats.append( + { + "frames": f"{c}-{end - 1}", + "mean_bright": means[c:end].mean(), + "mean_contrast": stds[c:end].mean(), + "first_contrast": stds[c], + "last_contrast": stds[end - 1], + } + ) + results["chunk_stats"] = chunk_stats + + return results + + +def print_results(results): + """Pretty-print analysis 
results.""" + print(f"\n{'=' * 70}") + print(f" {results['path']}") + print(f" {results['num_frames']} frames, chunk size = {results['chunk_size']}") + print(f"{'=' * 70}") + + for bd in results["boundaries"]: + b = bd["frame"] + print(f"\n Boundary {b}→{b + 1}:") + print(f" Contrast jump: {bd['contrast_pct']:+.1f}%") + print(f" Brightness jump: {bd['bright_pct']:+.1f}%") + print( + f" Color shift B/G/R: {bd['ch_shifts_bgr'][0]:+.1f} / " + f"{bd['ch_shifts_bgr'][1]:+.1f} / {bd['ch_shifts_bgr'][2]:+.1f}" + ) + print( + f" Frame diff: {bd['boundary_diff']:.1f} vs nearby " + f"{bd['nearby_diff']:.1f} ({bd['diff_ratio']:.1f}×)" + ) + print( + f" Spatial: center {bd['center_shift']:+.2f}, " + f"periphery {bd['periph_shift']:+.2f}" + ) + + print(f"\n Per-chunk summary:") + for cs in results["chunk_stats"]: + print( + f" Frames {cs['frames']:>7s}: brightness={cs['mean_bright']:.1f}, " + f"contrast={cs['mean_contrast']:.1f} " + f"(first={cs['first_contrast']:.1f}, last={cs['last_contrast']:.1f})" + ) + + +def print_comparison(all_results): + """Print side-by-side comparison table.""" + if len(all_results) < 2: + return + + print(f"\n{'=' * 70}") + print(" COMPARISON SUMMARY") + print(f"{'=' * 70}") + + # Header + labels = [r["path"].split("/")[-1] for r in all_results] + header = f"{'Metric':<25s}" + for label in labels: + header += f" {label:>18s}" + print(f"\n{header}") + print("-" * (25 + 20 * len(labels))) + + # For each boundary index + max_boundaries = max(len(r["boundaries"]) for r in all_results) + for bi in range(max_boundaries): + print(f"\n Boundary {bi + 1}:") + for metric, key, fmt in [ + ("Contrast jump", "contrast_pct", "{:+.1f}%"), + ("Brightness jump", "bright_pct", "{:+.1f}%"), + ("Frame diff ratio", "diff_ratio", "{:.1f}×"), + ("Center shift", "center_shift", "{:+.2f}"), + ("Periphery shift", "periph_shift", "{:+.2f}"), + ]: + row = f" {metric:<23s}" + for r in all_results: + if bi < len(r["boundaries"]): + val = r["boundaries"][bi][key] + row += f" 
{fmt.format(val):>18s}" + else: + row += f" {'N/A':>18s}" + print(row) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze chunk boundary quality in Helios videos" + ) + parser.add_argument("videos", nargs="+", help="Video file paths to analyze") + parser.add_argument( + "--chunk-size", + type=int, + default=32, + help="Frames per chunk (default: 32, use 33 for reference pipeline)", + ) + args = parser.parse_args() + + all_results = [] + for path in args.videos: + results = analyze_video(path, args.chunk_size) + if results is not None: + print_results(results) + all_results.append(results) + + if len(all_results) > 1: + print_comparison(all_results) + + +if __name__ == "__main__": + main() diff --git a/tests/test_helios.py b/tests/test_helios.py new file mode 100644 index 0000000..76004ab --- /dev/null +++ b/tests/test_helios.py @@ -0,0 +1,554 @@ +"""Tests for Helios model configuration, scheduler, RoPE, and transformer.""" + +import math + +import mlx.core as mx +import numpy as np +import pytest + + +# --------------------------------------------------------------------------- +# Config Tests +# --------------------------------------------------------------------------- + +class TestHeliosModelConfig: + """Tests for HeliosModelConfig dataclass.""" + + def test_default_values(self): + from mlx_video.models.helios.config import HeliosModelConfig + config = HeliosModelConfig() + assert config.dim == 5120 + assert config.ffn_dim == 13824 + assert config.num_heads == 40 + assert config.num_layers == 40 + assert config.in_dim == 16 + assert config.out_dim == 16 + assert config.patch_size == (1, 2, 2) + assert config.rope_dim == (44, 42, 42) + assert config.history_sizes == [16, 2, 1] + assert config.num_latent_frames_per_chunk == 9 + assert config.vae_stride == (4, 8, 8) + assert config.vae_z_dim == 16 + assert config.text_dim == 4096 + + def test_head_dim_property(self): + from mlx_video.models.helios.config import HeliosModelConfig + config = 
class TestHeliosModelConfig:
    """HeliosModelConfig: defaults, derived properties, presets."""

    def test_default_values(self):
        from mlx_video.models.helios.config import HeliosModelConfig
        cfg = HeliosModelConfig()
        assert cfg.dim == 5120
        assert cfg.ffn_dim == 13824
        assert cfg.num_heads == 40
        assert cfg.num_layers == 40
        assert cfg.in_dim == 16
        assert cfg.out_dim == 16
        assert cfg.patch_size == (1, 2, 2)
        assert cfg.rope_dim == (44, 42, 42)
        assert cfg.history_sizes == [16, 2, 1]
        assert cfg.num_latent_frames_per_chunk == 9
        assert cfg.vae_stride == (4, 8, 8)
        assert cfg.vae_z_dim == 16
        assert cfg.text_dim == 4096

    def test_head_dim_property(self):
        from mlx_video.models.helios.config import HeliosModelConfig
        # head_dim is derived: dim / num_heads = 5120 / 40.
        assert HeliosModelConfig().head_dim == 128

    def test_distilled_preset(self):
        from mlx_video.models.helios.config import HeliosModelConfig
        cfg = HeliosModelConfig.helios_distilled()
        assert cfg.shift == 1.0
        assert cfg.dim == 5120

    def test_rope_dim_sums_to_half_head_dim(self):
        from mlx_video.models.helios.config import HeliosModelConfig
        cfg = HeliosModelConfig()
        # The three rotary axis dims must add up to head_dim:
        # 44 + 42 + 42 = 128.
        assert sum(cfg.rope_dim) == cfg.head_dim


# ---------------------------------------------------------------------------
# Scheduler Tests
# ---------------------------------------------------------------------------

class TestHeliosScheduler:
    """HeliosScheduler: sigmas, stepping, DMD path, and shifting."""

    def test_init(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        scheduler = HeliosScheduler()
        assert scheduler.num_train_timesteps == 1000
        assert scheduler.shift == 1.0
        assert scheduler.stages == 3

    def test_global_sigmas_shape(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        assert len(HeliosScheduler().global_sigmas) == 1000

    def test_set_timesteps(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        scheduler = HeliosScheduler()
        scheduler.set_timesteps(num_inference_steps=10, stage_index=0)
        assert scheduler.timesteps.shape == (10,)
        # Sigmas carry both endpoints of every interval: N + 1 entries.
        assert scheduler.sigmas.shape == (11,)

    def test_step(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        scheduler = HeliosScheduler()
        scheduler.set_timesteps(num_inference_steps=2, stage_index=0)
        sample = mx.ones((16, 4, 4, 4))
        # A zero flow prediction must preserve the sample's shape.
        stepped = scheduler.step(mx.zeros_like(sample), sample)
        assert stepped.shape == sample.shape

    def test_add_noise(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        scheduler = HeliosScheduler()
        clean = mx.ones((16, 4, 4, 4))
        noised = scheduler.add_noise(clean, mx.zeros_like(clean), mx.array(0.5))
        # Linear blend: (1 - 0.5) * 1 + 0.5 * 0 = 0.5 everywhere.
        assert mx.allclose(noised, mx.ones_like(noised) * 0.5).item()

    def test_per_stage_consistency(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        scheduler = HeliosScheduler()
        # Every stage must denoise from a high sigma down to a lower one.
        for stage in range(3):
            assert scheduler.start_sigmas[stage] >= scheduler.end_sigmas[stage]

    def test_step_dmd_last_step_returns_x0(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        scheduler = HeliosScheduler()
        scheduler.set_timesteps(num_inference_steps=2, stage_index=0)
        sample = mx.ones((16, 4, 4, 4))
        flow = mx.ones_like(sample) * 0.1
        # Final step (idx=1): DMD returns x0_pred = sample - sigma * flow.
        out = scheduler.step_dmd(
            flow, sample, cur_step=1, noisy_start=mx.zeros_like(sample)
        )
        assert out.shape == sample.shape

    def test_step_dmd_non_last_renoises(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        scheduler = HeliosScheduler()
        scheduler.set_timesteps(num_inference_steps=2, stage_index=0)
        sample = mx.ones((16, 4, 4, 4))
        noisy_start = mx.ones_like(sample) * 2.0
        # With flow=0, x0 == sample; a non-final step re-noises toward
        # noisy_start: result = (1-sigma_next)*x0 + sigma_next*noisy_start,
        # which must differ from the plain sample since noisy_start does.
        out = scheduler.step_dmd(
            mx.zeros_like(sample), sample, cur_step=0, noisy_start=noisy_start
        )
        assert out.shape == sample.shape
        assert not mx.allclose(out, sample).item()

    def test_dynamic_shifting(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler, calculate_shift
        mu = calculate_shift(1024)
        assert 0.3 < mu < 2.0  # reasonable range
        scheduler = HeliosScheduler(use_dynamic_shifting=True)
        scheduler.set_timesteps(2, stage_index=0, image_seq_len=1024)
        assert scheduler.timesteps.shape[0] == 2

    def test_amplify_first_chunk_doubles_steps(self):
        from mlx_video.models.helios.scheduler import HeliosScheduler
        scheduler = HeliosScheduler()
        scheduler.set_timesteps(2, stage_index=0, is_amplify_first_chunk=True)
        # 2 * 2 + 1 = 5 raw steps, then the DMD trim drops one → 4.
        assert scheduler.timesteps.shape[0] == 4
class TestPyramidHelpers:
    """Pyramid-denoising helper functions from generate_helios."""

    def test_sample_block_noise_shape(self):
        from mlx_video.generate_helios import sample_block_noise
        noise = sample_block_noise(1, 16, 9, 48, 80, (1, 2, 2), 1 / 3)
        assert noise.shape == (16, 9, 48, 80)

    def test_sample_block_noise_statistics(self):
        from mlx_video.generate_helios import sample_block_noise
        np.random.seed(42)
        noise = sample_block_noise(1, 16, 9, 48, 80, (1, 2, 2), 1 / 3)
        noise_np = np.array(noise)
        # Should be roughly zero-mean, unit-ish variance.
        assert abs(noise_np.mean()) < 0.1
        assert 0.5 < noise_np.std() < 2.0

    def test_bilinear_downsample(self):
        from mlx_video.generate_helios import _bilinear_downsample_2d
        x = mx.ones((9, 16, 48, 80))
        result = _bilinear_downsample_2d(x, 24, 40)
        assert result.shape == (9, 16, 24, 40)
        # Downsampling a constant field must preserve the constant.
        assert mx.allclose(result, mx.ones_like(result)).item()

    def test_nearest_upsample(self):
        from mlx_video.generate_helios import _nearest_upsample_2d
        x = mx.ones((9, 16, 24, 40))
        result = _nearest_upsample_2d(x, 48, 80)
        assert result.shape == (9, 16, 48, 80)

    def test_downsample_history(self):
        from mlx_video.generate_helios import _downsample_history
        hist = mx.ones((16, 2, 48, 80))
        result = _downsample_history(hist, 2)
        assert result.shape == (16, 2, 24, 40)

    def test_spatial_reshape_roundtrip(self):
        from mlx_video.generate_helios import _spatial_reshape, _spatial_unreshape
        x = mx.random.normal((16, 9, 48, 80))
        reshaped = _spatial_reshape(x, 9, 16)
        unreshaped = _spatial_unreshape(reshaped, 9, 16, 48, 80)
        assert mx.allclose(x, unreshaped).item()


# ---------------------------------------------------------------------------
# RoPE Tests
# ---------------------------------------------------------------------------

class TestHeliosRoPE:
    """Helios rotary position embedding tables."""

    def test_rope_params_shape(self):
        from mlx_video.models.helios.rope import helios_rope_params
        freqs = helios_rope_params(
            rope_dim=(44, 42, 42),
            theta=10000.0,
            max_seq_len=1024,
        )
        freqs_t, freqs_h, freqs_w = freqs
        # Each freq table: [max_seq_len, d_i // 2, 2] (cos/sin stacked).
        assert freqs_t.shape == (1024, 22, 2)  # 44 // 2
        assert freqs_h.shape == (1024, 21, 2)  # 42 // 2
        assert freqs_w.shape == (1024, 21, 2)  # 42 // 2

    def test_rope_precompute_shape(self):
        from mlx_video.models.helios.rope import (
            helios_rope_params,
            helios_rope_precompute_cos_sin,
        )
        freqs = helios_rope_params((44, 42, 42), 10000.0, 1024)
        frame_indices = mx.arange(9)  # 9 latent frames
        grid_size = (9, 12, 20)  # F, H, W after patchify

        cos_f, sin_f = helios_rope_precompute_cos_sin(
            frame_indices, grid_size, freqs, dtype=mx.float32,
        )
        total_patches = 9 * 12 * 20
        # The previous assertion allowed `cos_f.ndim >= 2` as an escape
        # hatch, making it near-vacuous, and never looked at sin_f.
        # Pin the leading axis to one entry per patch and require the
        # cos/sin tables to agree in shape.
        assert cos_f.shape[0] == total_patches
        assert sin_f.shape == cos_f.shape
class TestHeliosAttention:
    """Self- and cross-attention modules at toy scale."""

    def test_self_attention_no_history(self):
        from mlx_video.models.helios.attention import HeliosSelfAttention
        model_dim, n_heads = 64, 4
        attn = HeliosSelfAttention(model_dim, n_heads, qk_norm=True, eps=1e-6)
        tokens = mx.random.normal((1, 16, model_dim))
        out = attn(
            tokens,
            frame_indices=mx.arange(16),
            grid_size=(16, 1, 1),
            freqs=None,
            rope_cos_sin=None,
            original_context_length=16,
        )
        assert out.shape == (1, 16, model_dim)

    def test_cross_attention(self):
        from mlx_video.models.helios.attention import HeliosCrossAttention
        model_dim, n_heads = 64, 4
        attn = HeliosCrossAttention(model_dim, n_heads, qk_norm=True, eps=1e-6)
        tokens = mx.random.normal((1, 16, model_dim))
        context = mx.random.normal((1, 32, model_dim))
        assert attn(tokens, context).shape == (1, 16, model_dim)

    def test_cross_attention_kv_cache(self):
        from mlx_video.models.helios.attention import HeliosCrossAttention
        model_dim, n_heads = 64, 4
        attn = HeliosCrossAttention(model_dim, n_heads, qk_norm=True, eps=1e-6)
        context = mx.random.normal((1, 32, model_dim))
        kv = attn.prepare_kv(context)
        assert len(kv) == 2  # (k, v)

        tokens = mx.random.normal((1, 16, model_dim))
        assert attn(tokens, context, kv_cache=kv).shape == (1, 16, model_dim)


# ---------------------------------------------------------------------------
# Transformer Block Tests (small scale)
# ---------------------------------------------------------------------------

class TestHeliosTransformerBlock:
    """A single transformer block at toy scale."""

    def test_block_forward_no_history(self):
        from mlx_video.models.helios.transformer import HeliosTransformerBlock
        model_dim = 64
        block = HeliosTransformerBlock(
            dim=model_dim, ffn_dim=128, num_heads=4,
            qk_norm=True, cross_attn_norm=True, eps=1e-6,
        )
        tokens = mx.random.normal((1, 16, model_dim))
        context = mx.random.normal((1, 32, model_dim))
        temb = mx.random.normal((1, 16, 6, model_dim))

        out = block(
            tokens, context, temb,
            rotary_emb=None,
            original_context_length=16,
        )
        assert out.shape == (1, 16, model_dim)


# ---------------------------------------------------------------------------
# Weight Sanitization Tests
# ---------------------------------------------------------------------------

class TestHeliosWeightSanitization:
    """HF diffusers → MLX key mapping for the Helios transformer."""

    def test_patch_embedding_reshape(self):
        from mlx_video.convert_helios import sanitize_helios_transformer_weights
        # Conv3d weights [O, I, D, H, W] must flatten to Linear [O, I*D*H*W].
        hf_weights = {
            "patch_embedding.weight": mx.ones((5120, 16, 1, 2, 2)),
            "patch_embedding.bias": mx.zeros((5120,)),
        }
        mapped = sanitize_helios_transformer_weights(hf_weights)
        assert "patch_embedding.weight" in mapped
        assert mapped["patch_embedding.weight"].shape == (5120, 64)  # 16*1*2*2

    def test_condition_embedder_mapping(self):
        from mlx_video.convert_helios import sanitize_helios_transformer_weights
        hf_weights = {
            "condition_embedder.time_embedder.linear_1.weight": mx.ones((5120, 256)),
            "condition_embedder.time_embedder.linear_2.weight": mx.ones((5120, 5120)),
            "condition_embedder.time_proj.weight": mx.ones((30720, 5120)),
            "condition_embedder.text_embedder.linear_1.weight": mx.ones((5120, 4096)),
            "condition_embedder.text_embedder.linear_2.weight": mx.ones((5120, 5120)),
        }
        mapped = sanitize_helios_transformer_weights(hf_weights)
        for expected in (
            "time_embedding_0.weight",
            "time_embedding_1.weight",
            "time_projection.weight",
            "text_embedding_0.weight",
            "text_embedding_1.weight",
        ):
            assert expected in mapped

    def test_attention_key_mapping(self):
        from mlx_video.convert_helios import sanitize_helios_transformer_weights
        hf_weights = {
            "blocks.0.attn1.to_q.weight": mx.ones((5120, 5120)),
            "blocks.0.attn1.to_out.0.weight": mx.ones((5120, 5120)),
            "blocks.0.attn2.to_k.weight": mx.ones((5120, 5120)),
        }
        mapped = sanitize_helios_transformer_weights(hf_weights)
        assert "blocks.0.self_attn.q.weight" in mapped
        assert "blocks.0.self_attn.o.weight" in mapped
        assert "blocks.0.cross_attn.k.weight" in mapped

    def test_ffn_key_mapping(self):
        from mlx_video.convert_helios import sanitize_helios_transformer_weights
        hf_weights = {
            "blocks.0.ffn.net.0.proj.weight": mx.ones((13824, 5120)),
            "blocks.0.ffn.net.2.weight": mx.ones((5120, 13824)),
        }
        mapped = sanitize_helios_transformer_weights(hf_weights)
        assert "blocks.0.ffn.fc1.weight" in mapped
        assert "blocks.0.ffn.fc2.weight" in mapped

    def test_output_norm_mapping(self):
        from mlx_video.convert_helios import sanitize_helios_transformer_weights
        hf_weights = {
            "norm_out.norm.weight": mx.ones((5120,)),
            "norm_out.norm.bias": mx.zeros((5120,)),
            "norm_out.scale_shift_table": mx.ones((1, 2, 5120)),
        }
        mapped = sanitize_helios_transformer_weights(hf_weights)
        assert "output_norm.weight" in mapped
        assert "output_norm.bias" in mapped
        assert "output_norm_table" in mapped

    def test_skips_rope_buffers(self):
        from mlx_video.convert_helios import sanitize_helios_transformer_weights
        hf_weights = {
            "rope.freqs_base_t": mx.ones((22,)),
            "rope.freqs_base_y": mx.ones((21,)),
        }
        # RoPE buffers are recomputed in MLX, never loaded.
        assert len(sanitize_helios_transformer_weights(hf_weights)) == 0
"blocks.0.self_attn.q.weight" in s + assert "blocks.0.self_attn.o.weight" in s + assert "blocks.0.cross_attn.k.weight" in s + + def test_ffn_key_mapping(self): + from mlx_video.convert_helios import sanitize_helios_transformer_weights + w = { + "blocks.0.ffn.net.0.proj.weight": mx.ones((13824, 5120)), + "blocks.0.ffn.net.2.weight": mx.ones((5120, 13824)), + } + s = sanitize_helios_transformer_weights(w) + assert "blocks.0.ffn.fc1.weight" in s + assert "blocks.0.ffn.fc2.weight" in s + + def test_output_norm_mapping(self): + from mlx_video.convert_helios import sanitize_helios_transformer_weights + w = { + "norm_out.norm.weight": mx.ones((5120,)), + "norm_out.norm.bias": mx.zeros((5120,)), + "norm_out.scale_shift_table": mx.ones((1, 2, 5120)), + } + s = sanitize_helios_transformer_weights(w) + assert "output_norm.weight" in s + assert "output_norm.bias" in s + assert "output_norm_table" in s + + def test_skips_rope_buffers(self): + from mlx_video.convert_helios import sanitize_helios_transformer_weights + w = { + "rope.freqs_base_t": mx.ones((22,)), + "rope.freqs_base_y": mx.ones((21,)), + } + s = sanitize_helios_transformer_weights(w) + assert len(s) == 0 # All skipped + + +class TestHeliosT5Sanitization: + """Tests for Helios T5 (HF UMT5 → MLX) weight key mapping.""" + + def test_token_embedding(self): + from mlx_video.convert_helios import sanitize_helios_t5_weights + + w = {"shared.weight": mx.ones((100, 64))} + s = sanitize_helios_t5_weights(w) + assert "token_embedding.weight" in s + + def test_encoder_embed_tokens(self): + from mlx_video.convert_helios import sanitize_helios_t5_weights + + w = {"encoder.embed_tokens.weight": mx.ones((100, 64))} + s = sanitize_helios_t5_weights(w) + assert "token_embedding.weight" in s + + def test_final_layer_norm(self): + from mlx_video.convert_helios import sanitize_helios_t5_weights + + w = {"encoder.final_layer_norm.weight": mx.ones((64,))} + s = sanitize_helios_t5_weights(w) + assert "norm.weight" in s + + def 
test_self_attention_mapping(self): + from mlx_video.convert_helios import sanitize_helios_t5_weights + + w = { + "encoder.block.0.layer.0.SelfAttention.q.weight": mx.ones((64, 64)), + "encoder.block.0.layer.0.SelfAttention.k.weight": mx.ones((64, 64)), + "encoder.block.0.layer.0.SelfAttention.v.weight": mx.ones((64, 64)), + "encoder.block.0.layer.0.SelfAttention.o.weight": mx.ones((64, 64)), + } + s = sanitize_helios_t5_weights(w) + assert "blocks.0.attn.q.weight" in s + assert "blocks.0.attn.k.weight" in s + assert "blocks.0.attn.v.weight" in s + assert "blocks.0.attn.o.weight" in s + + def test_relative_attention_bias(self): + from mlx_video.convert_helios import sanitize_helios_t5_weights + + w = { + "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": mx.ones((32, 64)), + } + s = sanitize_helios_t5_weights(w) + assert "blocks.0.pos_embedding.embedding.weight" in s + + def test_layer_norms(self): + from mlx_video.convert_helios import sanitize_helios_t5_weights + + w = { + "encoder.block.2.layer.0.layer_norm.weight": mx.ones((64,)), + "encoder.block.2.layer.1.layer_norm.weight": mx.ones((64,)), + } + s = sanitize_helios_t5_weights(w) + assert "blocks.2.norm1.weight" in s + assert "blocks.2.norm2.weight" in s + + def test_ffn_mapping(self): + from mlx_video.convert_helios import sanitize_helios_t5_weights + + w = { + "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": mx.ones((128, 64)), + "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": mx.ones((128, 64)), + "encoder.block.1.layer.1.DenseReluDense.wo.weight": mx.ones((64, 128)), + } + s = sanitize_helios_t5_weights(w) + assert "blocks.1.ffn.gate_proj.weight" in s + assert "blocks.1.ffn.fc1.weight" in s + assert "blocks.1.ffn.fc2.weight" in s + + def test_skips_decoder_keys(self): + from mlx_video.convert_helios import sanitize_helios_t5_weights + + w = { + "decoder.block.0.layer.0.SelfAttention.q.weight": mx.ones((64, 64)), + "lm_head.weight": mx.ones((100, 64)), + } + s = 
class TestHeliosVAESanitization:
    """HF diffusers → WanVAE key mapping for the Helios VAE."""

    def test_top_level_convolutions(self):
        from mlx_video.convert_helios import sanitize_helios_vae_weights

        hf_weights = {
            "post_quant_conv.weight": mx.ones((16, 16, 1, 1, 1)),
            "post_quant_conv.bias": mx.ones((16,)),
            "quant_conv.weight": mx.ones((32, 32, 1, 1, 1)),
            "quant_conv.bias": mx.ones((32,)),
        }
        mapped = sanitize_helios_vae_weights(hf_weights)
        assert "conv2.weight" in mapped
        assert "conv2.bias" in mapped
        assert "conv1.weight" in mapped
        assert "conv1.bias" in mapped
        # Conv3d weights are transposed to channels-last layout.
        assert mapped["conv2.weight"].shape == (16, 1, 1, 1, 16)

    def test_decoder_conv_in_out(self):
        from mlx_video.convert_helios import sanitize_helios_vae_weights

        hf_weights = {
            "decoder.conv_in.weight": mx.ones((384, 16, 3, 3, 3)),
            "decoder.conv_in.bias": mx.ones((384,)),
            "decoder.conv_out.weight": mx.ones((3, 96, 3, 3, 3)),
            "decoder.conv_out.bias": mx.ones((3,)),
            "decoder.norm_out.gamma": mx.ones((96, 1, 1, 1)),
        }
        mapped = sanitize_helios_vae_weights(hf_weights)
        assert "decoder.conv1.weight" in mapped
        assert "decoder.conv1.bias" in mapped
        assert "decoder.head.2.weight" in mapped
        assert "decoder.head.2.bias" in mapped
        assert "decoder.head.0.gamma" in mapped

    def test_mid_block_mapping(self):
        from mlx_video.convert_helios import sanitize_helios_vae_weights

        hf_weights = {
            "decoder.mid_block.resnets.0.norm1.gamma": mx.ones((384, 1, 1, 1)),
            "decoder.mid_block.resnets.0.conv1.weight": mx.ones((384, 384, 3, 3, 3)),
            "decoder.mid_block.attentions.0.norm.gamma": mx.ones((384, 1, 1)),
            "decoder.mid_block.resnets.1.conv2.bias": mx.ones((384,)),
        }
        mapped = sanitize_helios_vae_weights(hf_weights)
        # mid_block interleaves resnet/attention/resnet → middle.0/1/2.
        assert "decoder.middle.0.residual.0.gamma" in mapped
        assert "decoder.middle.0.residual.2.weight" in mapped
        assert "decoder.middle.1.norm.gamma" in mapped
        assert "decoder.middle.2.residual.6.bias" in mapped

    def test_up_blocks_resnet_mapping(self):
        from mlx_video.convert_helios import sanitize_helios_vae_weights

        hf_weights = {
            "decoder.up_blocks.0.resnets.0.norm1.gamma": mx.ones((384, 1, 1, 1)),
            "decoder.up_blocks.0.resnets.1.conv2.weight": mx.ones((384, 384, 3, 3, 3)),
            "decoder.up_blocks.1.resnets.0.conv_shortcut.weight": mx.ones((384, 192, 1, 1, 1)),
        }
        mapped = sanitize_helios_vae_weights(hf_weights)
        assert "decoder.upsamples.0.residual.0.gamma" in mapped
        assert "decoder.upsamples.1.residual.6.weight" in mapped
        assert "decoder.upsamples.4.shortcut.weight" in mapped

    def test_upsampler_mapping(self):
        from mlx_video.convert_helios import sanitize_helios_vae_weights

        hf_weights = {
            "decoder.up_blocks.0.upsamplers.0.resample.1.weight": mx.ones((192, 384, 3, 3)),
            "decoder.up_blocks.0.upsamplers.0.time_conv.weight": mx.ones((768, 384, 3, 1, 1)),
        }
        mapped = sanitize_helios_vae_weights(hf_weights)
        assert "decoder.upsamples.3.resample.1.weight" in mapped
        assert "decoder.upsamples.3.time_conv.weight" in mapped

    def test_skips_encoder_keys(self):
        from mlx_video.convert_helios import sanitize_helios_vae_weights

        # Decoding only needs the decoder; all encoder weights are dropped.
        hf_weights = {
            "encoder.conv_in.weight": mx.ones((384, 3, 3, 3, 3)),
            "encoder.mid_block.resnets.0.conv1.weight": mx.ones((384, 384, 3, 3, 3)),
        }
        assert len(sanitize_helios_vae_weights(hf_weights)) == 0