# Reconstructed from a collapsed diff (newlines were stripped).
# NOTE(review): the same change-set also added the constant
#   LTX2_VIDEO = "ltx2_video"
# to src/maxdiffusion/common_types.py (after WAN2_2 = "wan2.2").

# hardware
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
attention_sharding_uniform: True
precision: 'bf16'
data_sharding: ['data', 'fsdp', 'context', 'tensor']
remat_policy: "NONE"
names_which_can_be_saved: []
names_which_can_be_offloaded: []

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'

run_name: ''
output_dir: ''
config_path: ''
save_config_to_gcs: False

frame_rate: 30
max_sequence_length: 1024
sampler: "from_checkpoint"

# Generation parameters
dataset_name: ''
dataset_save_location: ''
global_batch_size_to_train_on: 1
num_inference_steps: 40
guidance_scale: 3.0
fps: 24
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 768
num_frames: 121
decode_timestep: 0.05
decode_noise_scale: 0.025
quantization: "int8"
seed: 10

# parallelism
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
logical_axis_rules: [
    ['batch', 'data'],
    ['activation_heads', 'fsdp'],
    ['activation_batch', 'data'],
    ['activation_kv', 'tensor'],
    ['mlp', 'tensor'],
    ['embed', 'fsdp'],
    ['heads', 'tensor'],
    ['norm', 'fsdp'],
    ['conv_batch', ['data', 'fsdp']],
    ['out_channels', 'tensor'],
    ['conv_out', 'fsdp'],
    ['conv_in', 'fsdp']
  ]
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
enable_profiler: False

replicate_vae: False

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: 'Lightricks/LTX-2'
model_name: "ltx2_video"
model_type: "T2V"
unet_checkpoint: ''
checkpoint_dir: ""
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
weight_quantization_calibration_method: "absmax"
act_quantization_calibration_method: "absmax"
bwd_quantization_calibration_method: "absmax"
qwix_module_path: ".*"
jit_initializers: True
enable_single_replica_ckpt_restoring: False
# Reconstructed from a collapsed diff: additions to
# src/maxdiffusion/models/embeddings_flax.py. The hunk-context tail of a
# pre-existing __call__ (visible only partially) is intentionally omitted.


class NNXTimesteps(nnx.Module):
  """Stateless sinusoidal timestep projection (NNX port of Timesteps).

  Holds only configuration; the projection itself is delegated to
  get_sinusoidal_embeddings.
  """

  def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
    self.num_channels = num_channels          # output embedding width
    self.flip_sin_to_cos = flip_sin_to_cos    # order of sin/cos halves
    self.downscale_freq_shift = downscale_freq_shift
    self.scale = scale

  def __call__(self, timesteps: jax.Array) -> jax.Array:
    """Return sinusoidal embeddings of shape (len(timesteps), num_channels)."""
    return get_sinusoidal_embeddings(
        timesteps=timesteps,
        embedding_dim=self.num_channels,
        freq_shift=self.downscale_freq_shift,
        flip_sin_to_cos=self.flip_sin_to_cos,
        scale=self.scale,
    )


class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
  """Combined timestep (+ optional resolution / aspect-ratio) embedding.

  Mirrors diffusers' PixArtAlphaCombinedTimestepSizeEmbeddings: the timestep
  is sinusoidally projected to 256 channels and MLP-embedded to
  `embedding_dim`; when `use_additional_conditions` is set, resolution and
  aspect-ratio are embedded to `size_emb_dim` each and concatenated onto the
  timestep embedding.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      embedding_dim: int,
      size_emb_dim: int,
      use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
  ):
    self.outdim = size_emb_dim
    self.use_additional_conditions = use_additional_conditions

    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = NNXTimestepEmbedding(
        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
    )

    if use_additional_conditions:
      # Shared sinusoidal projection for both resolution and aspect ratio.
      self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
      self.resolution_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )
      self.aspect_ratio_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )

  def __call__(
      self,
      timestep: jax.Array,
      resolution: Optional[jax.Array] = None,
      aspect_ratio: Optional[jax.Array] = None,
      hidden_dtype: jnp.dtype = jnp.float32,
  ) -> jax.Array:
    """Embed `timestep` (and optionally size conditions) to a single vector.

    Raises:
      ValueError: if additional conditions are enabled but `resolution` or
        `aspect_ratio` is missing.
    """
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))

    if self.use_additional_conditions:
      if resolution is None or aspect_ratio is None:
        raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")

      resolution_emb = self.additional_condition_proj(resolution.flatten()).astype(hidden_dtype)
      resolution_emb = self.resolution_embedder(resolution_emb)
      resolution_emb = resolution_emb.reshape(timestep.shape[0], -1)

      aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).astype(hidden_dtype)
      aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb)
      aspect_ratio_emb = aspect_ratio_emb.reshape(timestep.shape[0], -1)

      conditioning = timesteps_emb + jnp.concatenate([resolution_emb, aspect_ratio_emb], axis=1)
    else:
      conditioning = timesteps_emb

    return conditioning
+""" +from typing import Optional, Tuple, Any, Dict +import jax +import jax.numpy as jnp +from flax import nnx +import flax.linen as nn + +from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed +from maxdiffusion.models.attention_flax import NNXSimpleFeedForward +from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection +from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.configuration_utils import ConfigMixin, register_to_config + + +class LTX2AdaLayerNormSingle(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + embedding_dim: int, + num_mod_params: int = 6, + use_additional_conditions: bool = False, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + ): + self.num_mod_params = num_mod_params + self.use_additional_conditions = use_additional_conditions + self.emb = NNXPixArtAlphaCombinedTimestepSizeEmbeddings( + rngs=rngs, + embedding_dim=embedding_dim, + size_emb_dim=embedding_dim // 3, + use_additional_conditions=use_additional_conditions, + dtype=dtype, + weights_dtype=weights_dtype, + ) + self.silu = nnx.silu + self.linear = nnx.Linear( + rngs=rngs, + in_features=embedding_dim, + out_features=num_mod_params * embedding_dim, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + def __call__( + self, + timestep: jax.Array, + added_cond_kwargs: Optional[Dict[str, jax.Array]] = None, + batch_size: Optional[int] = None, # Unused in JAX path usually inferred + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jax.Array, jax.Array]: + resolution = None + aspect_ratio = None + if self.use_additional_conditions: + if added_cond_kwargs is None: + raise ValueError("added_cond_kwargs must be provided when use_additional_conditions is 
True") + resolution = added_cond_kwargs.get("resolution", None) + aspect_ratio = added_cond_kwargs.get("aspect_ratio", None) + + embedded_timestep = self.emb(timestep, resolution=resolution, aspect_ratio=aspect_ratio, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTX2VideoTransformerBlock(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_dim: int, + audio_num_attention_heads: int, + audio_attention_head_dim: int, + audio_cross_attention_dim: int, + activation_fn: str = "gelu", + attention_bias: bool = True, + attention_out_bias: bool = True, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + rope_type: str = "interleaved", + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + mesh: jax.sharding.Mesh = None, + remat_policy: str = "None", + precision: jax.lax.Precision = None, + names_which_can_be_saved: list = [], + names_which_can_be_offloaded: list = [], + attention_kernel: str = "flash", + ): + self.dim = dim + self.norm_eps = norm_eps + self.norm_elementwise_affine = norm_elementwise_affine + self.attention_kernel = attention_kernel + + # 1. 
Self-Attention (video and audio) + self.norm1 = nnx.RMSNorm( + self.dim, + epsilon=self.norm_eps, + use_scale=self.norm_elementwise_affine, + rngs=rngs, + dtype=jnp.float32, + param_dtype=jnp.float32, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)), + ) + self.attn1 = LTX2Attention( + rngs=rngs, + query_dim=self.dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=0.0, + bias=attention_bias, + out_bias=attention_out_bias, + eps=norm_eps, + dtype=dtype, + mesh=mesh, + attention_kernel=self.attention_kernel, + rope_type=rope_type, + ) + + self.audio_norm1 = nnx.RMSNorm( + audio_dim, + epsilon=self.norm_eps, + use_scale=self.norm_elementwise_affine, + rngs=rngs, + dtype=jnp.float32, + param_dtype=jnp.float32, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)), + ) + self.audio_attn1 = LTX2Attention( + rngs=rngs, + query_dim=audio_dim, + heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + dropout=0.0, + bias=attention_bias, + out_bias=attention_out_bias, + eps=norm_eps, + dtype=dtype, + mesh=mesh, + attention_kernel=self.attention_kernel, + rope_type=rope_type, + ) + + # 2. 
Prompt Cross-Attention + self.norm2 = nnx.RMSNorm( + self.dim, + epsilon=self.norm_eps, + use_scale=self.norm_elementwise_affine, + rngs=rngs, + dtype=jnp.float32, + param_dtype=jnp.float32, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)), + ) + self.attn2 = LTX2Attention( + rngs=rngs, + query_dim=dim, + context_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=0.0, + bias=attention_bias, + out_bias=attention_out_bias, + eps=norm_eps, + dtype=dtype, + mesh=mesh, + attention_kernel=self.attention_kernel, + rope_type=rope_type, + ) + + self.audio_norm2 = nnx.RMSNorm( + audio_dim, + epsilon=self.norm_eps, + use_scale=self.norm_elementwise_affine, + rngs=rngs, + dtype=jnp.float32, + param_dtype=jnp.float32, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)), + ) + self.audio_attn2 = LTX2Attention( + rngs=rngs, + query_dim=audio_dim, + context_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + dropout=0.0, + bias=attention_bias, + out_bias=attention_out_bias, + eps=norm_eps, + dtype=dtype, + mesh=mesh, + attention_kernel=self.attention_kernel, + rope_type=rope_type, + ) + + # 3. 
Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + self.audio_to_video_norm = nnx.RMSNorm( + dim, + epsilon=self.norm_eps, + use_scale=self.norm_elementwise_affine, + rngs=rngs, + dtype=jnp.float32, + param_dtype=jnp.float32, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)), + ) + self.audio_to_video_attn = LTX2Attention( + rngs=rngs, + query_dim=dim, + context_dim=audio_dim, + heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + dropout=0.0, + bias=attention_bias, + out_bias=attention_out_bias, + eps=norm_eps, + dtype=dtype, + mesh=mesh, + attention_kernel=self.attention_kernel, + rope_type=rope_type, + ) + + self.video_to_audio_norm = nnx.RMSNorm( + audio_dim, + epsilon=self.norm_eps, + use_scale=self.norm_elementwise_affine, + rngs=rngs, + dtype=jnp.float32, + param_dtype=jnp.float32, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)), + ) + self.video_to_audio_attn = LTX2Attention( + rngs=rngs, + query_dim=audio_dim, + context_dim=dim, + heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + dropout=0.0, + bias=attention_bias, + out_bias=attention_out_bias, + eps=norm_eps, + dtype=dtype, + mesh=mesh, + attention_kernel=self.attention_kernel, + rope_type=rope_type, + ) + + # 4. 
Feed Forward + self.norm3 = nnx.RMSNorm( + dim, + epsilon=self.norm_eps, + use_scale=self.norm_elementwise_affine, + rngs=rngs, + dtype=jnp.float32, + param_dtype=jnp.float32, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)), + ) + self.ff = NNXSimpleFeedForward( + rngs=rngs, + dim=dim, + dim_out=dim, + activation_fn=activation_fn, + dtype=dtype, + weights_dtype=weights_dtype, + ) + + self.audio_norm3 = nnx.RMSNorm( + audio_dim, + epsilon=self.norm_eps, + use_scale=self.norm_elementwise_affine, + rngs=rngs, + dtype=jnp.float32, + param_dtype=jnp.float32, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)), + ) + self.audio_ff = NNXSimpleFeedForward( + rngs=rngs, dim=audio_dim, dim_out=audio_dim, activation_fn=activation_fn, dtype=dtype, weights_dtype=weights_dtype + ) + + key = rngs.params() + k1, k2, k3, k4 = jax.random.split(key, 4) + + self.scale_shift_table = nnx.Param( + jax.random.normal(k1, (6, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim), + kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), + ) + self.audio_scale_shift_table = nnx.Param( + jax.random.normal(k2, (6, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim), + kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), + ) + self.video_a2v_cross_attn_scale_shift_table = nnx.Param( + jax.random.normal(k3, (5, self.dim), dtype=weights_dtype), + kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), + ) + self.audio_a2v_cross_attn_scale_shift_table = nnx.Param( + jax.random.normal(k4, (5, audio_dim), dtype=weights_dtype), + kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")), + ) + + def __call__( + self, + hidden_states: jax.Array, # Video + audio_hidden_states: jax.Array, # Audio + encoder_hidden_states: jax.Array, # Context (Text) + audio_encoder_hidden_states: jax.Array, # Audio Context + # Timestep embeddings for AdaLN + temb: jax.Array, + temb_audio: 
jax.Array, + temb_ca_scale_shift: jax.Array, + temb_ca_audio_scale_shift: jax.Array, + temb_ca_gate: jax.Array, + temb_ca_audio_gate: jax.Array, + # RoPE + video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None, + audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None, + ca_video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None, + ca_audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None, + attention_mask: Optional[jax.Array] = None, + encoder_attention_mask: Optional[jax.Array] = None, + audio_encoder_attention_mask: Optional[jax.Array] = None, + a2v_cross_attention_mask: Optional[jax.Array] = None, + v2a_cross_attention_mask: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, jax.Array]: + batch_size = hidden_states.shape[0] + + axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) + hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) + audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names) + + if encoder_hidden_states is not None: + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names) + if audio_encoder_hidden_states is not None: + audio_encoder_hidden_states = jax.lax.with_sharding_constraint(audio_encoder_hidden_states, axis_names) + + # 1. 
Video and Audio Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + # Calculate Video AdaLN values + num_ada_params = self.scale_shift_table.shape[0] + # table shape: (6, dim) -> (1, 1, 6, dim) + scale_shift_table_reshaped = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + # temb shape: (batch, temb_dim) -> (batch, 1, 6, dim) + temb_reshaped = temb.reshape(batch_size, 1, num_ada_params, -1) + ada_values = scale_shift_table_reshaped + temb_reshaped + + # Diffusers Order: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp + shift_msa = ada_values[:, :, 0, :] + scale_msa = ada_values[:, :, 1, :] + gate_msa = ada_values[:, :, 2, :] + shift_mlp = ada_values[:, :, 3, :] + scale_mlp = ada_values[:, :, 4, :] + gate_mlp = ada_values[:, :, 5, :] + + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + # Calculate Audio AdaLN values + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_scale_shift_table_reshaped = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1)) + temb_audio_reshaped = temb_audio.reshape(batch_size, 1, num_audio_ada_params, -1) + audio_ada_values = audio_scale_shift_table_reshaped + temb_audio_reshaped + + audio_shift_msa = audio_ada_values[:, :, 0, :] + audio_scale_msa = audio_ada_values[:, :, 1, :] + audio_gate_msa = audio_ada_values[:, :, 2, :] + audio_shift_mlp = audio_ada_values[:, :, 3, :] + audio_scale_mlp = audio_ada_values[:, :, 4, :] + audio_gate_mlp = audio_ada_values[:, :, 5, :] + + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + 
rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. Video and Audio Cross-Attention with the text embeddings + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + rotary_emb=None, + attention_mask=audio_encoder_attention_mask, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) + + # Calculate Cross-Attention Modulation values + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + # table: (4, dim) -> (1, 1, 4, dim) + video_ca_scale_shift_table = jnp.expand_dims(video_per_layer_ca_scale_shift, axis=(0, 1)) + temb_ca_scale_shift.reshape( + batch_size, 1, 4, -1 + ) + + video_a2v_ca_scale = video_ca_scale_shift_table[:, :, 0, :] + video_a2v_ca_shift = video_ca_scale_shift_table[:, :, 1, :] + video_v2a_ca_scale = video_ca_scale_shift_table[:, :, 2, :] + video_v2a_ca_shift = video_ca_scale_shift_table[:, :, 3, :] + + # table: (1, dim) -> (1, 1, 1, dim) + a2v_gate = (jnp.expand_dims(video_per_layer_ca_gate, axis=(0, 1)) + temb_ca_gate.reshape(batch_size, 1, 1, -1))[ + :, :, 0, : + ] + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = 
self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = jnp.expand_dims( + audio_per_layer_ca_scale_shift, axis=(0, 1) + ) + temb_ca_audio_scale_shift.reshape(batch_size, 1, 4, -1) + + audio_a2v_ca_scale = audio_ca_scale_shift_table[:, :, 0, :] + audio_a2v_ca_shift = audio_ca_scale_shift_table[:, :, 1, :] + audio_v2a_ca_scale = audio_ca_scale_shift_table[:, :, 2, :] + audio_v2a_ca_shift = audio_ca_scale_shift_table[:, :, 3, :] + + v2a_gate = (jnp.expand_dims(audio_per_layer_ca_gate, axis=(0, 1)) + temb_ca_audio_gate.reshape(batch_size, 1, 1, -1))[ + :, :, 0, : + ] + + # Audio-to-Video Cross Attention: Q: Video; K,V: Audio + mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale) + video_a2v_ca_shift + mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale) + audio_a2v_ca_shift + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + rotary_emb=ca_video_rotary_emb, + k_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # Video-to-Audio Cross Attention: Q: Audio; K,V: Video + mod_norm_hidden_states_v2a = norm_hidden_states * (1 + video_v2a_ca_scale) + video_v2a_ca_shift + mod_norm_audio_hidden_states_v2a = norm_audio_hidden_states * (1 + audio_v2a_ca_scale) + audio_v2a_ca_shift + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states_v2a, + encoder_hidden_states=mod_norm_hidden_states_v2a, + rotary_emb=ca_audio_rotary_emb, + k_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. 
Feedforward + norm_hidden_states = self.norm3(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_mlp) + audio_shift_mlp + audio_ff_output = self.audio_ff(norm_audio_hidden_states) + audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp + + return hidden_states, audio_hidden_states + + +class LTX2VideoTransformer3DModel(nnx.Module, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + rngs: nnx.Rngs, + in_channels: int = 128, # Video Arguments + out_channels: Optional[int] = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: Tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, # Audio Arguments + audio_out_channels: Optional[int] = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, # Shared arguments + activation_fn: str = "gelu", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_offset: int = 1, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1000, + rope_type: str = "interleaved", + dtype: jnp.dtype = 
jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + mesh: jax.sharding.Mesh = None, + remat_policy: str = "None", + precision: jax.lax.Precision = None, + names_which_can_be_saved: list = [], + names_which_can_be_offloaded: list = [], + scan_layers: bool = True, + attention_kernel: str = "flash", + qk_norm: str = "rms_norm_across_heads", + **kwargs, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.cross_attention_dim = cross_attention_dim + self.vae_scale_factors = vae_scale_factors + self.pos_embed_max_pos = pos_embed_max_pos + self.base_height = base_height + self.base_width = base_width + self.audio_in_channels = audio_in_channels + self.audio_out_channels = audio_out_channels + self.audio_patch_size = audio_patch_size + self.audio_patch_size_t = audio_patch_size_t + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_attention_head_dim = audio_attention_head_dim + self.audio_cross_attention_dim = audio_cross_attention_dim + self.audio_scale_factor = audio_scale_factor + self.audio_pos_embed_max_pos = audio_pos_embed_max_pos + self.audio_sampling_rate = audio_sampling_rate + self.audio_hop_length = audio_hop_length + self.num_layers = num_layers + self.activation_fn = activation_fn + self.norm_elementwise_affine = norm_elementwise_affine + self.norm_eps = norm_eps + self.caption_channels = caption_channels + self.attention_bias = attention_bias + self.attention_out_bias = attention_out_bias + self.rope_theta = rope_theta + self.rope_double_precision = rope_double_precision + self.causal_offset = causal_offset + self.timestep_scale_multiplier = timestep_scale_multiplier + self.cross_attn_timestep_scale_multiplier = cross_attn_timestep_scale_multiplier + self.rope_type = rope_type + self.dtype = dtype + self.weights_dtype = weights_dtype + self.mesh 
= mesh + self.remat_policy = remat_policy + self.precision = precision + self.names_which_can_be_saved = names_which_can_be_saved + self.names_which_can_be_offloaded = names_which_can_be_offloaded + self.scan_layers = scan_layers + self.attention_kernel = attention_kernel + + _out_channels = self.out_channels or self.in_channels + _audio_out_channels = self.audio_out_channels or self.audio_in_channels + inner_dim = self.num_attention_heads * self.attention_head_dim + audio_inner_dim = self.audio_num_attention_heads * self.audio_attention_head_dim + + # 1. Patchification input projections + self.proj_in = nnx.Linear( + self.in_channels, + inner_dim, + rngs=rngs, + dtype=self.dtype, + param_dtype=self.weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + self.audio_proj_in = nnx.Linear( + self.audio_in_channels, + audio_inner_dim, + rngs=rngs, + dtype=self.dtype, + param_dtype=self.weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + # 2. Prompt embeddings + self.caption_projection = NNXPixArtAlphaTextProjection( + rngs=rngs, + in_features=self.caption_channels, + hidden_size=inner_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + ) + self.audio_caption_projection = NNXPixArtAlphaTextProjection( + rngs=rngs, + in_features=self.caption_channels, + hidden_size=audio_inner_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + ) + # 3. 
Timestep Modulation Params and Embedding + self.time_embed = LTX2AdaLayerNormSingle( + rngs=rngs, + embedding_dim=inner_dim, + num_mod_params=6, + use_additional_conditions=False, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + ) + self.audio_time_embed = LTX2AdaLayerNormSingle( + rngs=rngs, + embedding_dim=audio_inner_dim, + num_mod_params=6, + use_additional_conditions=False, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + ) + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + rngs=rngs, + embedding_dim=inner_dim, + num_mod_params=4, + use_additional_conditions=False, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + rngs=rngs, + embedding_dim=audio_inner_dim, + num_mod_params=4, + use_additional_conditions=False, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + ) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + rngs=rngs, + embedding_dim=inner_dim, + num_mod_params=1, + use_additional_conditions=False, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + ) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + rngs=rngs, + embedding_dim=audio_inner_dim, + num_mod_params=1, + use_additional_conditions=False, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + ) + + # 3. Output Layer Scale/Shift Modulation parameters + param_rng = rngs.params() + self.scale_shift_table = nnx.Param( + jax.random.normal(param_rng, (2, inner_dim), dtype=self.weights_dtype) / jnp.sqrt(inner_dim), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")), + ) + self.audio_scale_shift_table = nnx.Param( + jax.random.normal(param_rng, (2, audio_inner_dim), dtype=self.weights_dtype) / jnp.sqrt(audio_inner_dim), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")), + ) + + # 4. 
Rotary Positional Embeddings (RoPE) + self.rope = LTX2RotaryPosEmbed( + dim=inner_dim, + patch_size=self.patch_size, + patch_size_t=self.patch_size_t, + base_num_frames=self.pos_embed_max_pos, + base_height=self.base_height, + base_width=self.base_width, + scale_factors=self.vae_scale_factors, + theta=self.rope_theta, + causal_offset=self.causal_offset, + modality="video", + double_precision=self.rope_double_precision, + rope_type=self.rope_type, + num_attention_heads=self.num_attention_heads, + ) + self.audio_rope = LTX2RotaryPosEmbed( + dim=audio_inner_dim, + patch_size=self.audio_patch_size, + patch_size_t=self.audio_patch_size_t, + base_num_frames=self.audio_pos_embed_max_pos, + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, + scale_factors=[self.audio_scale_factor], + theta=self.rope_theta, + causal_offset=self.causal_offset, + modality="audio", + double_precision=self.rope_double_precision, + rope_type=self.rope_type, + num_attention_heads=self.audio_num_attention_heads, + ) + + cross_attn_pos_embed_max_pos = max(self.pos_embed_max_pos, self.audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2RotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=self.patch_size, + patch_size_t=self.patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=self.base_height, + base_width=self.base_width, + theta=self.rope_theta, + causal_offset=self.causal_offset, + modality="video", + double_precision=self.rope_double_precision, + rope_type=self.rope_type, + num_attention_heads=self.num_attention_heads, + ) + self.cross_attn_audio_rope = LTX2RotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=self.audio_patch_size, + patch_size_t=self.audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, + theta=self.rope_theta, + causal_offset=self.causal_offset, + modality="audio", + double_precision=self.rope_double_precision, + 
rope_type=self.rope_type, + num_attention_heads=self.audio_num_attention_heads, + ) + + # 5. Transformer Blocks + @nnx.split_rngs(splits=self.num_layers) + @nnx.vmap(in_axes=0, out_axes=0, axis_size=self.num_layers, transform_metadata={nnx.PARTITION_NAME: "layers"}) + def init_block(rngs): + return LTX2VideoTransformerBlock( + rngs=rngs, + dim=inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + cross_attention_dim=inner_dim, + audio_dim=audio_inner_dim, + audio_num_attention_heads=self.audio_num_attention_heads, + audio_attention_head_dim=self.audio_attention_head_dim, + audio_cross_attention_dim=audio_inner_dim, + activation_fn=self.activation_fn, + attention_bias=self.attention_bias, + attention_out_bias=self.attention_out_bias, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + rope_type=self.rope_type, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + mesh=self.mesh, + remat_policy=self.remat_policy, + precision=self.precision, + names_which_can_be_saved=self.names_which_can_be_saved, + names_which_can_be_offloaded=self.names_which_can_be_offloaded, + attention_kernel=self.attention_kernel, + ) + + if self.scan_layers: + self.transformer_blocks = init_block(rngs) + else: + blocks = [] + for _ in range(self.num_layers): + block = LTX2VideoTransformerBlock( + rngs=rngs, + dim=inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + cross_attention_dim=inner_dim, + audio_dim=audio_inner_dim, + audio_num_attention_heads=self.audio_num_attention_heads, + audio_attention_head_dim=self.audio_attention_head_dim, + audio_cross_attention_dim=audio_inner_dim, + activation_fn=self.activation_fn, + attention_bias=self.attention_bias, + attention_out_bias=self.attention_out_bias, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + rope_type=self.rope_type, + dtype=self.dtype, + 
  def __call__(
      self,
      hidden_states: jax.Array,
      audio_hidden_states: jax.Array,
      encoder_hidden_states: jax.Array,
      audio_encoder_hidden_states: jax.Array,
      timestep: jax.Array,
      audio_timestep: Optional[jax.Array] = None,
      encoder_attention_mask: Optional[jax.Array] = None,
      audio_encoder_attention_mask: Optional[jax.Array] = None,
      num_frames: Optional[int] = None,
      height: Optional[int] = None,
      width: Optional[int] = None,
      fps: float = 24.0,
      audio_num_frames: Optional[int] = None,
      video_coords: Optional[jax.Array] = None,
      audio_coords: Optional[jax.Array] = None,
      attention_kwargs: Optional[Dict[str, Any]] = None,
      return_dict: bool = True,
  ) -> Any:
    """Joint video/audio forward pass of the LTX2 transformer.

    Projects patchified video and audio latents into the model dimension,
    computes RoPE embeddings and AdaLN modulation parameters from the
    timestep(s), runs the stack of audio-video transformer blocks (scanned
    or unrolled), and applies the final scale/shift-modulated output
    projections.

    Args:
      hidden_states: video latent tokens; assumed (batch, video_seq, in_dim)
        — shape is not asserted here, TODO confirm against proj_in.
      audio_hidden_states: audio latent tokens, projected by audio_proj_in.
      encoder_hidden_states: text-prompt embeddings for video cross-attention.
      audio_encoder_hidden_states: prompt embeddings for audio cross-attention.
      timestep: diffusion timestep(s); flattened before embedding.
      audio_timestep: optional separate timestep for audio; falls back to
        `timestep` when None.
      encoder_attention_mask / audio_encoder_attention_mask: optional 2-D
        (batch, kv_len) masks; converted to additive bias for dot-product
        attention only.
      num_frames, height, width, fps: video geometry used to build RoPE
        coordinates when `video_coords` is not given.
      audio_num_frames: audio length used when `audio_coords` is not given.
      video_coords / audio_coords: optional precomputed RoPE coordinates.
      attention_kwargs: accepted for API parity; not read in this body.
      return_dict: when False return a (video, audio) tuple instead of a dict.

    Returns:
      Dict with "sample" and "audio_sample", or a tuple when
      `return_dict=False`.
    """
    # Determine timestep for audio.
    audio_timestep = audio_timestep if audio_timestep is not None else timestep

    # Only the dot-product kernel consumes an additive float mask: 2-D
    # {0,1} masks are turned into 0 / -10000 biases and broadcast over heads.
    # Other kernels (e.g. flash) presumably handle masks internally —
    # TODO confirm against the attention implementation.
    if self.attention_kernel == "dot_product":
      if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.astype(self.dtype)) * -10000.0
        encoder_attention_mask = jnp.expand_dims(encoder_attention_mask, axis=1)

      if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
        audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.astype(self.dtype)) * -10000.0
        audio_encoder_attention_mask = jnp.expand_dims(audio_encoder_attention_mask, axis=1)

    batch_size = hidden_states.shape[0]

    # 1. Prepare RoPE positional embeddings
    if video_coords is None:
      video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, fps=fps)
    if audio_coords is None:
      audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames)

    video_rotary_emb = self.rope(video_coords)
    audio_rotary_emb = self.audio_rope(audio_coords)

    # Cross-modal attention uses only the temporal coordinate axis
    # (axis index 0 of the coordinate dimension) for both modalities.
    video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :])
    audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :])

    # 2. Patchify input projections
    hidden_states = self.proj_in(hidden_states)
    audio_hidden_states = self.audio_proj_in(audio_hidden_states)

    # 3. Prepare timestep embeddings and modulation parameters
    # The cross-attention gates see the timestep on a different scale than
    # the main AdaLN embeddings; this factor converts between the two.
    timestep_cross_attn_gate_scale_factor = self.cross_attn_timestep_scale_multiplier / self.timestep_scale_multiplier

    temb, embedded_timestep = self.time_embed(
        timestep.flatten(),
        hidden_dtype=hidden_states.dtype,
    )
    # Reshape from (batch * n_t, dim) back to (batch, n_t, dim) so
    # per-token timesteps (if any) stay associated with their batch entry.
    temb = temb.reshape(batch_size, -1, temb.shape[-1])
    embedded_timestep = embedded_timestep.reshape(batch_size, -1, embedded_timestep.shape[-1])

    temb_audio, audio_embedded_timestep = self.audio_time_embed(
        audio_timestep.flatten(),
        hidden_dtype=audio_hidden_states.dtype,
    )
    temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1])
    audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])

    # Per-modality scale/shift (4 params) and a2v/v2a gates (1 param) for the
    # audio-video cross-attention; the gates use the rescaled timestep.
    video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
        timestep.flatten(),
        hidden_dtype=hidden_states.dtype,
    )
    video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
        timestep.flatten() * timestep_cross_attn_gate_scale_factor,
        hidden_dtype=hidden_states.dtype,
    )
    video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape(
        batch_size, -1, video_cross_attn_scale_shift.shape[-1]
    )
    video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])

    audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
        audio_timestep.flatten(),
        hidden_dtype=audio_hidden_states.dtype,
    )
    audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
        audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
        hidden_dtype=audio_hidden_states.dtype,
    )
    audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape(
        batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
    )
    audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])

    # 4. Prepare prompt embeddings
    encoder_hidden_states = self.caption_projection(encoder_hidden_states)
    encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])

    audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
    audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])

    # 5. Run transformer blocks
    # scan_fn carries (video, audio, rngs) through the stacked blocks; the
    # astype calls keep the carry dtype stable across scan iterations, which
    # nnx.scan requires.
    def scan_fn(carry, block):
      hidden_states, audio_hidden_states, rngs_carry = carry
      hidden_states_out, audio_hidden_states_out = block(
          hidden_states=hidden_states,
          audio_hidden_states=audio_hidden_states,
          encoder_hidden_states=encoder_hidden_states,
          audio_encoder_hidden_states=audio_encoder_hidden_states,
          temb=temb,
          temb_audio=temb_audio,
          temb_ca_scale_shift=video_cross_attn_scale_shift,
          temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
          temb_ca_gate=video_cross_attn_a2v_gate,
          temb_ca_audio_gate=audio_cross_attn_v2a_gate,
          video_rotary_emb=video_rotary_emb,
          audio_rotary_emb=audio_rotary_emb,
          ca_video_rotary_emb=video_cross_attn_rotary_emb,
          ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
          encoder_attention_mask=encoder_attention_mask,
          audio_encoder_attention_mask=audio_encoder_attention_mask,
      )
      return (
          hidden_states_out.astype(hidden_states.dtype),
          audio_hidden_states_out.astype(audio_hidden_states.dtype),
          rngs_carry,
      ), None

    if self.scan_layers:
      # Remat (gradient checkpointing) wraps the per-layer step; prevent_cse
      # is disabled under scan, matching the scanned-layer layout.
      rematted_scan_fn = self.gradient_checkpoint.apply(
          scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
      )
      carry = (hidden_states, audio_hidden_states, nnx.Rngs(0))  # Placeholder RNGs for now if not used in block
      (hidden_states, audio_hidden_states, _), _ = nnx.scan(
          rematted_scan_fn,
          length=self.num_layers,
          in_axes=(nnx.Carry, 0),
          out_axes=(nnx.Carry, 0),
          transform_metadata={nnx.PARTITION_NAME: "layers"},
      )(carry, self.transformer_blocks)
    else:
      # Unrolled path: plain Python loop over the block list.
      # NOTE(review): this branch does not apply self.gradient_checkpoint —
      # confirm whether remat is intentionally scan-only.
      for block in self.transformer_blocks:
        hidden_states, audio_hidden_states = block(
            hidden_states=hidden_states,
            audio_hidden_states=audio_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            audio_encoder_hidden_states=audio_encoder_hidden_states,
            temb=temb,
            temb_audio=temb_audio,
            temb_ca_scale_shift=video_cross_attn_scale_shift,
            temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
            temb_ca_gate=video_cross_attn_a2v_gate,
            temb_ca_audio_gate=audio_cross_attn_v2a_gate,
            video_rotary_emb=video_rotary_emb,
            audio_rotary_emb=audio_rotary_emb,
            ca_video_rotary_emb=video_cross_attn_rotary_emb,
            ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
            encoder_attention_mask=encoder_attention_mask,
            audio_encoder_attention_mask=audio_encoder_attention_mask,
        )

    # 6. Output layers
    # Learned (2, dim) scale/shift table plus the timestep embedding gives a
    # (batch, n_t, 2, dim) tensor; index 0 is shift, index 1 is scale.
    scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2)
    shift = scale_shift_values[:, :, 0, :]
    scale = scale_shift_values[:, :, 1, :]

    hidden_states = self.norm_out(hidden_states)
    hidden_states = hidden_states * (1 + scale) + shift
    output = self.proj_out(hidden_states)

    audio_scale_shift_values = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1)) + jnp.expand_dims(
        audio_embedded_timestep, axis=2
    )
    audio_shift = audio_scale_shift_values[:, :, 0, :]
    audio_scale = audio_scale_shift_values[:, :, 1, :]

    audio_hidden_states = self.audio_norm_out(audio_hidden_states)
    audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
    audio_output = self.audio_proj_out(audio_hidden_states)

    if not return_dict:
      return (output, audio_output)
    return {"sample": output, "audio_sample": audio_output}
import max_utils from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH -from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES +from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2, LTX2_VIDEO, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES -_ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2} +_ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO} _ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1} diff --git a/src/maxdiffusion/tests/ltx2/test_transformer_ltx2.py b/src/maxdiffusion/tests/ltx2/test_transformer_ltx2.py new file mode 100644 index 00000000..fcf8f282 --- /dev/null +++ b/src/maxdiffusion/tests/ltx2/test_transformer_ltx2.py @@ -0,0 +1,283 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import jax +import jax.numpy as jnp +import unittest +from absl.testing import absltest +from flax import nnx +from jax.sharding import Mesh +from flax.linen import partitioning as nn_partitioning +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import create_device_mesh +from maxdiffusion.models.ltx2.transformer_ltx2 import ( + LTX2VideoTransformerBlock, + LTX2VideoTransformer3DModel, + LTX2AdaLayerNormSingle, + LTX2RotaryPosEmbed, +) +import flax + + +flax.config.update("flax_always_shard_variable", False) + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class LTX2TransformerTest(unittest.TestCase): + + def setUp(self): + LTX2TransformerTest.dummy_data = {} + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "..", "configs", "ltx2_video.yml"), + ], + unittest=True, + ) + config = pyconfig.config + self.config = config + devices_array = create_device_mesh(config) + self.mesh = Mesh(devices_array, config.mesh_axes) + + self.batch_size = 1 + self.num_frames = 4 + self.height = 32 + self.width = 32 + self.patch_size = 1 + self.patch_size_t = 1 + + self.in_channels = 8 + self.out_channels = 8 + self.audio_in_channels = 4 + + self.seq_len = ( + (self.num_frames // self.patch_size_t) * (self.height // self.patch_size) * (self.width // self.patch_size) + ) + + self.dim = 1024 + self.num_heads = 8 + self.head_dim = 128 + self.cross_dim = 1024 # context dim + + self.audio_dim = 1024 + self.audio_num_heads = 8 + self.audio_head_dim = 128 + self.audio_cross_dim = 1024 + + def test_ltx2_rope(self): + """Tests LTX2RotaryPosEmbed output shapes and basic functionality.""" + dim = self.dim + patch_size = self.patch_size + patch_size_t = self.patch_size_t + base_num_frames = 8 + base_height = 32 + base_width = 32 + + # Video RoPE + rope = LTX2RotaryPosEmbed( + dim=dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=base_num_frames, + base_height=base_height, + base_width=base_width, + 
modality="video", + ) + ids = jnp.ones((1, 10, 3)) # (B, S, Axes) for 3D coords + cos, sin = rope(ids) + + # Check output shape + self.assertEqual(cos.shape, (1, 10, dim)) + self.assertEqual(sin.shape, (1, 10, dim)) + + def test_ltx2_rope_split(self): + """Tests LTX2RotaryPosEmbed with rope_type='split'.""" + dim = self.dim + patch_size = self.patch_size + patch_size_t = self.patch_size_t + base_num_frames = 8 + base_height = 32 + base_width = 32 + + # Video RoPE Split + rope = LTX2RotaryPosEmbed( + dim=dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=base_num_frames, + base_height=base_height, + base_width=base_width, + modality="video", + rope_type="split", + ) + ids = jnp.ones((1, 10, 3)) # (B, S, Axes) + cos, sin = rope(ids) + + # Check output shape + # Split RoPE returns [B, H, S, D//2] + # dim=1024, heads=32 => head_dim=32 => D//2 = 16 + self.assertEqual(cos.shape, (1, 32, 10, 16)) + self.assertEqual(sin.shape, (1, 32, 10, 16)) + + def test_ltx2_ada_layer_norm_single(self): + """Tests LTX2AdaLayerNormSingle initialization and execution.""" + key = jax.random.key(0) + rngs = nnx.Rngs(key) + embedding_dim = self.dim + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + layer = LTX2AdaLayerNormSingle( + rngs=rngs, embedding_dim=embedding_dim, num_mod_params=6, use_additional_conditions=False # Default + ) + + timestep = jnp.array([1.0]) + batch_size = self.batch_size + + # Forward + output, embedded_timestep = layer(timestep) + + # Expected output shape: (B, num_mod_params * embedding_dim) + # embedded_timestep shape: (B, embedding_dim) + self.assertEqual(output.shape, (batch_size, 6 * embedding_dim)) + self.assertEqual(embedded_timestep.shape, (batch_size, embedding_dim)) + + def test_ltx2_transformer_block(self): + """Tests LTX2VideoTransformerBlock with video and audio inputs.""" + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + dim = self.dim + audio_dim = self.audio_dim + cross_attention_dim = 
self.cross_dim + audio_cross_attention_dim = self.audio_cross_dim + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + block = LTX2VideoTransformerBlock( + rngs=rngs, + dim=dim, + num_attention_heads=self.num_heads, + attention_head_dim=self.head_dim, + cross_attention_dim=cross_attention_dim, + audio_dim=audio_dim, + audio_num_attention_heads=self.audio_num_heads, + audio_attention_head_dim=self.audio_head_dim, + audio_cross_attention_dim=audio_cross_attention_dim, + mesh=self.mesh, + ) + + batch_size = self.batch_size + seq_len = self.seq_len + audio_seq_len = 128 # Matching parity test + + hidden_states = jnp.zeros((batch_size, seq_len, dim)) + audio_hidden_states = jnp.zeros((batch_size, audio_seq_len, audio_dim)) + encoder_hidden_states = jnp.zeros((batch_size, 128, cross_attention_dim)) + audio_encoder_hidden_states = jnp.zeros((batch_size, 128, audio_cross_attention_dim)) + + # Mock modulation parameters + # sizes based on `transformer_ltx2.py` logic + temb_dim = 6 * dim # 6 params * dim + temb = jnp.zeros((batch_size, temb_dim)) + temb_audio = jnp.zeros((batch_size, 6 * audio_dim)) + + temb_ca_scale_shift = jnp.zeros((batch_size, 4 * dim)) + temb_ca_audio_scale_shift = jnp.zeros((batch_size, 4 * audio_dim)) + temb_ca_gate = jnp.zeros((batch_size, 1 * dim)) + temb_ca_audio_gate = jnp.zeros((batch_size, 1 * audio_dim)) + + output_hidden, output_audio = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=temb_ca_scale_shift, + temb_ca_audio_scale_shift=temb_ca_audio_scale_shift, + temb_ca_gate=temb_ca_gate, + temb_ca_audio_gate=temb_ca_audio_gate, + ) + + self.assertEqual(output_hidden.shape, hidden_states.shape) + self.assertEqual(output_audio.shape, audio_hidden_states.shape) + + def test_ltx2_transformer_model(self): + """Tests 
LTX2VideoTransformer3DModel full forward pass.""" + key = jax.random.key(0) + rngs = nnx.Rngs(key) + + in_channels = self.in_channels + out_channels = self.out_channels + audio_in_channels = self.audio_in_channels + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + model = LTX2VideoTransformer3DModel( + rngs=rngs, + in_channels=in_channels, + out_channels=out_channels, + patch_size=self.patch_size, + patch_size_t=self.patch_size_t, + num_attention_heads=self.num_heads, + attention_head_dim=self.head_dim, + cross_attention_dim=self.cross_dim, + caption_channels=32, + audio_in_channels=audio_in_channels, + audio_out_channels=audio_in_channels, + audio_num_attention_heads=self.audio_num_heads, + audio_attention_head_dim=self.audio_head_dim, + audio_cross_attention_dim=self.audio_cross_dim, + num_layers=1, + mesh=self.mesh, + attention_kernel="dot_product", + ) + + batch_size = self.batch_size + seq_len = self.seq_len + audio_seq_len = 128 + + hidden_states = jnp.zeros((batch_size, seq_len, in_channels)) + audio_hidden_states = jnp.zeros((batch_size, audio_seq_len, audio_in_channels)) + + timestep = jnp.array([1.0]) + encoder_hidden_states = jnp.zeros((batch_size, 128, 32)) # (B, L, D) match caption_channels + audio_encoder_hidden_states = jnp.zeros((batch_size, 128, 32)) + + encoder_attention_mask = jnp.ones((batch_size, 128)) + audio_encoder_attention_mask = jnp.ones((batch_size, 128)) + + output = model( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + timestep=timestep, + num_frames=self.num_frames, + height=self.height, + width=self.width, + audio_num_frames=audio_seq_len, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + return_dict=True, + ) + + self.assertEqual(output["sample"].shape, (batch_size, seq_len, out_channels)) + 
self.assertEqual(output["audio_sample"].shape, (batch_size, audio_seq_len, audio_in_channels)) + + +if __name__ == "__main__": + absltest.main()