"""
Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Tuple

from flax import nnx
import jax.numpy as jnp

from ... import common_types
from ..attention_flax import NNXAttentionOp

Array = common_types.Array
Mesh = common_types.Mesh
DType = common_types.DType


def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
  """Applies interleaved RoPE to ``x``.

  Logic matches the LTX-2 PyTorch implementation: neighboring feature pairs
  (x1, x2) are rotated as (x1*cos - x2*sin, x2*cos + x1*sin), i.e. the
  rotated companion tensor is built by stacking [-x2, x1].

  Args:
    x: Input tensor [..., D].
    freqs: Tuple of (cos, sin), each broadcasting against [..., D].

  Returns:
    Rotated tensor with the same shape and dtype as ``x``.
  """
  cos, sin = freqs

  # 1. Pair neighbors: [..., D] -> [..., D//2, 2]
  #    (corresponds to "rearrange(..., (d r) -> ... d r, r=2)").
  x_pairs = x.reshape(*x.shape[:-1], -1, 2)

  # 2. Split into the two components of each pair.
  x_real, x_imag = x_pairs[..., 0], x_pairs[..., 1]

  # 3. Rotate: (x1, x2) -> (-x2, x1), flattened back to [..., D].
  x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape)

  # 4. Apply frequencies in float32 for numerical stability, then cast back.
  out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin
  return out.astype(x.dtype)


def apply_split_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
  """Applies split (half/half) RoPE to ``x``.

  Logic matches Diffusers' ``apply_split_rotary_emb``: the feature dimension
  is split into two contiguous halves (not interleaved pairs) and the halves
  are rotated against each other.

  Args:
    x: Input tensor. If ``x`` is 3-D [B, S, D] while ``freqs`` are 4-D
      ([B, H, S, D//2], as produced by ``LTX2RotaryPosEmbed`` in split
      mode), ``x`` is temporarily reshaped to [B, H, S, D/H] and restored
      on the way out.
    freqs: Tuple of (cos, sin).

  Returns:
    Rotated tensor with the same shape and dtype as ``x``.
  """
  cos, sin = freqs

  x_dtype = x.dtype
  original_shape = x.shape

  # Bring x to per-head layout [B, H, S, head_dim] when the frequencies are
  # per-head but x is flat [B, S, inner_dim].
  needed_reshape = False
  if x.ndim != 4 and cos.ndim == 4:
    batch = x.shape[0]
    num_heads, seq_len = cos.shape[1], cos.shape[2]
    x = x.reshape(batch, seq_len, num_heads, -1).transpose(0, 2, 1, 3)
    needed_reshape = True

  last_dim = x.shape[-1]
  half = last_dim // 2

  # Split the feature dim into two halves: [..., 2, half].
  split_x = x.reshape(*x.shape[:-1], 2, half)
  first_x = split_x[..., 0, :]
  second_x = split_x[..., 1, :]

  # Rotation applied half-wise; cos/sin broadcast directly over [..., half].
  # (Equivalent to the original expand_dims/squeeze dance, without the
  # redundant axis juggling.)
  out_first = first_x * cos - second_x * sin
  out_second = second_x * cos + first_x * sin

  out = jnp.stack([out_first, out_second], axis=-2)
  out = out.reshape(*out.shape[:-2], last_dim)

  if needed_reshape:
    # Restore the original flat [B, S, inner_dim] layout.
    out = out.transpose(0, 2, 1, 3).reshape(original_shape)

  return out.astype(x_dtype)


class LTX2RotaryPosEmbed(nnx.Module):
  """Video and audio rotary positional embeddings (RoPE) for LTX-2.0.

  Matches the logic of ``LTX2AudioVideoRotaryPosEmbed`` from Diffusers:
  per-patch start/end coordinates are averaged to a midpoint, normalized
  by the base extents, mapped to [-1, 1], multiplied against a geometric
  frequency ladder, and emitted as (cos, sin) tables in either
  "interleaved" or "split" layout.
  """

  def __init__(
      self,
      dim: int,
      patch_size: int = 1,
      patch_size_t: int = 1,
      base_num_frames: int = 20,
      base_height: int = 2048,
      base_width: int = 2048,
      sampling_rate: int = 16000,
      hop_length: int = 160,
      scale_factors: Tuple[int, ...] = (8, 32, 32),
      theta: float = 10000.0,
      causal_offset: int = 1,
      modality: str = "video",
      double_precision: bool = True,
      rope_type: str = "interleaved",
      num_attention_heads: int = 32,
  ):
    self.dim = dim
    self.patch_size = patch_size
    self.patch_size_t = patch_size_t
    self.base_num_frames = base_num_frames
    self.base_height = base_height
    self.base_width = base_width
    self.sampling_rate = sampling_rate
    self.hop_length = hop_length
    self.scale_factors = scale_factors
    self.theta = theta
    self.causal_offset = causal_offset
    self.modality = modality
    self.double_precision = double_precision
    self.rope_type = rope_type
    self.num_attention_heads = num_attention_heads

    if self.rope_type not in ["interleaved", "split"]:
      raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")

    if self.modality not in ["video", "audio"]:
      raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.")

    # Latent frames per second of audio: mel frames / temporal scale factor.
    self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0])

  def prepare_video_coords(
      self,
      batch_size: int,
      num_frames: int,
      height: int,
      width: int,
      fps: float = 24.0,
  ) -> Array:
    """Builds per-patch pixel-space (start, end) coordinates for video.

    Returns:
      Array of shape [B, 3, num_patches, 2] where dim 1 is (frame, h, w)
      and the last dim holds (start, end). The frame row is causally
      clamped and converted to seconds via ``fps``.
    """
    # 1. Generate grid coordinates for each spatiotemporal dimension.
    grid_f = jnp.arange(0, num_frames, self.patch_size_t, dtype=jnp.float32)
    grid_h = jnp.arange(0, height, self.patch_size, dtype=jnp.float32)
    grid_w = jnp.arange(0, width, self.patch_size, dtype=jnp.float32)

    # indexing='ij' ensures (frames, height, width) order.
    grid = jnp.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
    grid = jnp.stack(grid, axis=0)  # [3, N_F, N_H, N_W]

    # 2. Get patch boundaries (end = start + patch size per axis).
    patch_size_arr = jnp.array((self.patch_size_t, self.patch_size, self.patch_size), dtype=grid.dtype)
    patch_size_delta = patch_size_arr.reshape(3, 1, 1, 1)
    patch_ends = grid + patch_size_delta

    # Combine start and end coordinates.
    # NOTE: the original implementation computed an intermediate
    # [B, num_patches, 3, 2] layout here and immediately overwrote it;
    # that dead computation has been removed — only the [B, 3, num_patches, 2]
    # layout is ever consumed below.
    latent_coords = jnp.stack([grid, patch_ends], axis=-1)  # [3, N_F, N_H, N_W, 2]
    latent_coords = latent_coords.reshape(3, -1, 2)  # [3, num_patches, 2]
    latent_coords = jnp.expand_dims(latent_coords, 0)  # [1, 3, num_patches, 2]
    latent_coords = jnp.tile(latent_coords, (batch_size, 1, 1, 1))  # [B, 3, num_patches, 2]

    # 3. Calculate pixel-space coords from latent coords.
    scale_tensor = jnp.array(self.scale_factors, dtype=latent_coords.dtype)
    scale_tensor = scale_tensor.reshape(1, 3, 1, 1)
    pixel_coords = latent_coords * scale_tensor

    # Causal clamp on the frame axis (dim 1 index 0 of (F, H, W)), then
    # convert frame coordinates to seconds.
    frame_coords = pixel_coords[:, 0, ...]
    frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], 0)
    pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)

    return pixel_coords

  def prepare_audio_coords(
      self,
      batch_size: int,
      num_frames: int,
      shift: int = 0,
  ) -> Array:
    """Builds per-patch (start, end) timestamps in seconds for audio.

    Returns:
      Array of shape [B, 1, num_patches, 2].
    """
    # 1. Generate coordinates in the frame (time) dimension.
    grid_f = jnp.arange(shift, num_frames + shift, self.patch_size_t, dtype=jnp.float32)

    # 2. Start timestamps: latent frame -> mel frame (causally clamped) -> seconds.
    audio_scale_factor = self.scale_factors[0]
    grid_start_mel = grid_f * audio_scale_factor
    grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, 0)
    grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate

    # 3. End timestamps, same pipeline shifted by one temporal patch.
    grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
    grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, 0)
    grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate

    # Stack to [num_patches, 2], then broadcast to [B, 1, num_patches, 2].
    audio_coords = jnp.stack([grid_start_s, grid_end_s], axis=-1)
    audio_coords = jnp.expand_dims(audio_coords, 0)
    audio_coords = jnp.tile(audio_coords, (batch_size, 1, 1))
    audio_coords = jnp.expand_dims(audio_coords, 1)

    return audio_coords

  def prepare_coords(self, *args, **kwargs):
    """Dispatches to the coordinate builder for the configured modality."""
    if self.modality == "video":
      return self.prepare_video_coords(*args, **kwargs)
    elif self.modality == "audio":
      return self.prepare_audio_coords(*args, **kwargs)
    return None

  def __call__(self, coords: Array) -> Tuple[Array, Array]:
    """Converts coordinates into (cos, sin) RoPE tables.

    Args:
      coords: Either [B, num_pos_dims, num_patches, 2] (start/end pairs,
        as produced by ``prepare_coords``) or [B, num_patches, num_pos_dims]
        (raw grid coordinates).

    Returns:
      Tuple (cos, sin). Interleaved mode: [B, num_patches, dim].
      Split mode: [B, num_heads, num_patches, dim_head // 2].
    """
    # Handle both coordinate layouts.
    if coords.ndim == 4:
      num_pos_dims = coords.shape[1]
      # 1. Midpoint of each patch's (start, end) span.
      coords_start = coords[..., 0]
      coords_end = coords[..., 1]
      coords = (coords_start + coords_end) / 2.0  # [B, num_pos_dims, num_patches]
      # Standardize layout: [B, num_patches, num_pos_dims].
      grid = coords.transpose(0, 2, 1)
    elif coords.ndim == 3:
      num_pos_dims = coords.shape[-1]
      grid = coords  # Already [B, num_patches, num_pos_dims].
    else:
      raise ValueError(f"coords must be 3D or 4D, got {coords.ndim}D")

    # 2. Normalize each axis by its base extent into [0, 1].
    if self.modality == "video":
      max_positions = jnp.array((self.base_num_frames, self.base_height, self.base_width), dtype=coords.dtype)
    elif self.modality == "audio":
      max_positions = jnp.array((self.base_num_frames,), dtype=coords.dtype)

    max_positions = max_positions[:num_pos_dims]
    # Reshape to broadcast with [B, num_patches, num_pos_dims].
    max_positions = max_positions.reshape(1, 1, num_pos_dims)
    grid = grid / max_positions

    num_rope_elems = num_pos_dims * 2

    # 3. Geometric frequency ladder theta**linspace(0, 1, steps) * pi/2.
    freqs_dtype = jnp.float64 if self.double_precision else jnp.float32
    steps = self.dim // num_rope_elems
    pow_indices = jnp.power(self.theta, jnp.linspace(0.0, 1.0, steps, dtype=freqs_dtype))
    base_freqs = (pow_indices * jnp.pi / 2.0).astype(jnp.float32)  # [steps]

    # 4. Outer product of positions (mapped [0,1] -> [-1,1]) and frequencies.
    scaled_grid = grid * 2.0 - 1.0  # [B, num_patches, num_pos_dims]
    # [B, num_patches, num_pos_dims, 1] * [steps] -> [B, num_patches, num_pos_dims, steps]
    freqs = jnp.expand_dims(scaled_grid, -1) * base_freqs

    # CRITICAL: transpose the last two dims to exactly match Diffusers'
    # flattening order before the reshape.
    freqs = jnp.swapaxes(freqs, -1, -2)  # [B, num_patches, steps, num_pos_dims]
    freqs = freqs.reshape(*freqs.shape[:2], -1)  # [B, num_patches, dim // 2]

    # 5. Cos/Sin tables, laid out per rope_type.
    cos_freqs = jnp.cos(freqs)
    sin_freqs = jnp.sin(freqs)

    if self.rope_type == "interleaved":
      # repeat_interleave: [c1, c2] -> [c1, c1, c2, c2].
      cos_freqs = jnp.repeat(cos_freqs, 2, axis=-1)
      sin_freqs = jnp.repeat(sin_freqs, 2, axis=-1)

      # Pad at the front (cos=1, sin=0 => identity rotation) if dim does not
      # divide evenly among the rope elements.
      if self.dim % num_rope_elems != 0:
        curr_dim = cos_freqs.shape[-1]
        pad_amt = self.dim - curr_dim
        if pad_amt > 0:
          cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_amt), dtype=cos_freqs.dtype)
          sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_amt), dtype=sin_freqs.dtype)
          cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
          sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)

    elif self.rope_type == "split":
      # Split layout targets dim // 2 features, padded the same way.
      curr_dim = cos_freqs.shape[-1]
      expected_dim = self.dim // 2
      pad_size = expected_dim - curr_dim

      if pad_size > 0:
        cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_size), dtype=cos_freqs.dtype)
        sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_size), dtype=sin_freqs.dtype)
        cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
        sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)

      # Fold into per-head layout [B, H, S, dim_head // 2].
      # NOTE(review): assumes (dim // 2) is divisible by num_attention_heads;
      # verify against model configs.
      b = cos_freqs.shape[0]
      s = cos_freqs.shape[1]
      h = self.num_attention_heads

      cos_freqs = cos_freqs.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
      sin_freqs = sin_freqs.reshape(b, s, h, -1).transpose(0, 2, 1, 3)

    return cos_freqs, sin_freqs


class LTX2Attention(nnx.Module):
  """LTX-2 multi-head attention (self or cross) with optional RoPE.

  Projections use bias (LTX-2 convention); Q/K are RMS-normalized over the
  full inner dimension (not per-head) before RoPE is applied.
  """

  def __init__(
      self,
      query_dim: int,
      heads: int,
      dim_head: int,
      context_dim: Optional[int] = None,
      dropout: float = 0.0,
      bias: bool = True,  # LTX-2 uses bias=True for projections
      out_bias: bool = True,
      rngs: nnx.Rngs = None,
      mesh: Mesh = None,
      eps: float = 1e-6,
      dtype: DType = jnp.float32,
      attention_kernel: str = "flash",
      rope_type: str = "interleaved",
  ):
    self.heads = heads
    self.rope_type = rope_type
    self.dim_head = dim_head
    self.inner_dim = dim_head * heads
    self.dropout_rate = dropout

    # 1. Partitioned initializers (logical axes).
    # Q/K/V kernels: [in_features (embed), out_features (heads)].
    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
    # Q/K/V biases: [out_features (heads)].
    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))

    # Out kernel: [in_features (heads), out_features (embed)].
    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
    # Out bias: [out_features (embed)].
    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))

    # Norm scales.
    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))

    # 2. Projections.
    self.to_q = nnx.Linear(
        query_dim,
        self.inner_dim,
        use_bias=bias,
        kernel_init=qkv_kernel_init,
        bias_init=qkv_bias_init,
        rngs=rngs,
        dtype=dtype,
    )

    # Handle self- vs cross-attention input dims.
    kv_dim = context_dim if context_dim is not None else query_dim
    self.to_k = nnx.Linear(
        kv_dim, self.inner_dim, use_bias=bias, kernel_init=qkv_kernel_init, bias_init=qkv_bias_init, rngs=rngs, dtype=dtype
    )
    self.to_v = nnx.Linear(
        kv_dim, self.inner_dim, use_bias=bias, kernel_init=qkv_kernel_init, bias_init=qkv_bias_init, rngs=rngs, dtype=dtype
    )

    # 3. Normalization (applied to the full inner_dim, NOT per-head),
    # kept in float32 for parity with the PyTorch reference.
    self.norm_q = nnx.RMSNorm(
        self.inner_dim,
        epsilon=eps,
        dtype=jnp.float32,
        param_dtype=jnp.float32,
        use_scale=True,
        scale_init=norm_scale_init,
        rngs=rngs,
    )
    self.norm_k = nnx.RMSNorm(
        self.inner_dim,
        epsilon=eps,
        dtype=jnp.float32,
        param_dtype=jnp.float32,
        use_scale=True,
        scale_init=norm_scale_init,
        rngs=rngs,
    )

    # 4. Output projection and optional dropout.
    self.to_out = nnx.Linear(
        self.inner_dim,
        query_dim,
        use_bias=out_bias,
        kernel_init=out_kernel_init,
        bias_init=out_bias_init,
        rngs=rngs,
        dtype=dtype,
    )

    if self.dropout_rate > 0:
      self.dropout_layer = nnx.Dropout(self.dropout_rate, rngs=rngs)
    else:
      self.dropout_layer = None

    self.attention_op = NNXAttentionOp(
        mesh=mesh,
        attention_kernel=attention_kernel,
        scale=dim_head**-0.5,
        heads=heads,
        dim_head=dim_head,
        dtype=dtype,
    )

  def __call__(
      self,
      hidden_states: Array,
      encoder_hidden_states: Optional[Array] = None,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
      k_rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:
    """Runs attention over ``hidden_states`` (cross-attending to
    ``encoder_hidden_states`` when given).

    RoPE handling: ``rotary_emb`` is always applied to the query; the key
    receives ``k_rotary_emb`` when provided, otherwise ``rotary_emb`` but
    only in the self-attention case (cross-attention keys come from a
    different sequence and must not reuse the query frequencies).
    """
    # Determine context (self or cross).
    context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

    # 1. Project.
    query = self.to_q(hidden_states)
    key = self.to_k(context)
    value = self.to_v(context)

    # 2. Norm (full inner dimension).
    query = self.norm_q(query)
    key = self.norm_k(key)

    # 3. Apply RoPE to [B, S, InnerDim] tensors.
    if rotary_emb is not None:
      if self.rope_type == "split":
        # Split RoPE: freqs are [B, H, S, D//2]; apply_split_rotary_emb
        # handles the per-head reshape of query/key internally.
        query = apply_split_rotary_emb(query, rotary_emb)

        if k_rotary_emb is not None:
          key = apply_split_rotary_emb(key, k_rotary_emb)
        elif encoder_hidden_states is None:
          key = apply_split_rotary_emb(key, rotary_emb)

      else:
        # Interleaved (default).
        query = apply_rotary_emb(query, rotary_emb)
        if k_rotary_emb is not None:
          key = apply_rotary_emb(key, k_rotary_emb)
        elif encoder_hidden_states is None:
          key = apply_rotary_emb(key, rotary_emb)

    # 4. Attention. NNXAttentionOp expects flattened [B, S, InnerDim] input
    # for the flash kernel.
    attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)

    # 5. Output projection (+ optional dropout, matching the original order).
    hidden_states = self.to_out(attn_output)

    if self.dropout_layer is not None:
      hidden_states = self.dropout_layer(hidden_states)

    return hidden_states
+""" + +import unittest +import torch +import numpy as np +import jax +import jax.numpy as jnp +from flax import nnx +import pandas as pd +from jax.sharding import Mesh +from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed + +# ========================================== +# 1. PyTorch Reference Implementations +# ========================================== + + +class PytorchLTX2RotaryPosEmbed(torch.nn.Module): + """ + Exact mathematical replica of Diffusers LTX2AudioVideoRotaryPosEmbed.forward + stripped down for testing the core RoPE frequency generation logic. + """ + + def __init__( + self, dim: int, theta: float = 10000.0, base_dims=(20, 2048, 2048), rope_type="interleaved", num_attention_heads=32 + ): + super().__init__() + self.dim = dim + self.theta = theta + self.base_dims = base_dims + self.rope_type = rope_type + self.num_attention_heads = num_attention_heads + self.double_precision = True + + def forward(self, ids): + # Test passes ids as [Batch, Sequence, NumAxes] + num_axes = ids.shape[-1] + + # 1. Scale by max_positions -> [B, S, num_axes] + max_pos = torch.tensor(self.base_dims[:num_axes], dtype=torch.float32, device=ids.device) + grid = ids / max_pos.view(1, 1, num_axes) + + # 2. Map to [-1, 1] + scaled_grid = grid * 2.0 - 1.0 + + # 3. Base Frequencies + num_rope_elems = num_axes * 2 + dim_per_axis = self.dim // num_rope_elems + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=dim_per_axis, dtype=freqs_dtype, device=ids.device), + ) + base_freqs = (pow_indices * (torch.pi / 2.0)).to(dtype=torch.float32) # [steps] + + # 4. 
Outer Product & Transpose (Diffusers specific logic) + # grid: [B, S, num_axes, 1] * base_freqs: [steps] -> [B, S, num_axes, steps] + freqs = scaled_grid.unsqueeze(-1) * base_freqs + # Transpose last two dims: [B, S, steps, num_axes] + freqs = freqs.transpose(-1, -2) + # Flatten: [B, S, steps * num_axes] + emb = freqs.flatten(2) + + cos = torch.cos(emb) + sin = torch.sin(emb) + + if self.rope_type == "interleaved": + # Interleave: [c1, c2] -> [c1, c1, c2, c2] + cos = torch.repeat_interleave(cos, 2, dim=-1) + sin = torch.repeat_interleave(sin, 2, dim=-1) + + if self.dim % num_rope_elems != 0: + pad_amt = self.dim - cos.shape[-1] + cos_padding = torch.ones_like(cos[..., :pad_amt]) + sin_padding = torch.zeros_like(sin[..., :pad_amt]) + cos = torch.cat([cos_padding, cos], dim=-1) + sin = torch.cat([sin_padding, sin], dim=-1) + + elif self.rope_type == "split": + pad_size = (self.dim // 2) - cos.shape[-1] + if pad_size > 0: + cos_padding = torch.ones_like(cos[..., :pad_size]) + sin_padding = torch.zeros_like(sin[..., :pad_size]) + cos = torch.cat([cos_padding, cos], dim=-1) + sin = torch.cat([sin_padding, sin], dim=-1) + + b, s, _ = cos.shape + cos = cos.view(b, s, self.num_attention_heads, -1).transpose(1, 2) + sin = sin.view(b, s, self.num_attention_heads, -1).transpose(1, 2) + + return cos, sin + + +def apply_rotary_emb_pt(x, cos, sin): + """ + Standard PyTorch Interleaved RoPE application. + Dimension-agnostic: Works for [B, S, D] or [B, H, S, D]. + """ + # 1. Reshape last dim to pairs: [..., D] -> [..., D//2, 2] + shape = x.shape + x_reshaped = x.view(*shape[:-1], -1, 2) + + # 2. Rotate: [-x2, x1] + x1, x2 = x_reshaped.unbind(-1) + x_rotated = torch.stack((-x2, x1), dim=-1).view(*shape) + + # 3. 
Apply Frequencies (Float32 for parity) + orig_dtype = x.dtype + x_f32 = x.to(torch.float32) + rot_f32 = x_rotated.to(torch.float32) + cos_f32 = cos.to(torch.float32) + sin_f32 = sin.to(torch.float32) + + out = x_f32 * cos_f32 + rot_f32 * sin_f32 + return out.to(orig_dtype) + + +class PytorchLTX2Attention(torch.nn.Module): + + def __init__(self, query_dim, context_dim, heads, dim_head): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.dim_head = dim_head + + self.q_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6) + self.k_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6) + self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) + self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) + self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) + self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) + + def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None): + q = self.to_q(x) + ctx = x if context is None else context + k = self.to_k(ctx) + v = self.to_v(ctx) + + # Keep raw projections for test_layer_wise_stats + q_raw, k_raw = q, k + + q_normed = self.q_norm(q) + k_normed = self.k_norm(k) + + if q_rope is not None: + q_cos, q_sin = q_rope + q_normed = apply_rotary_emb_pt(q_normed, q_cos, q_sin) + + if k_rope is not None: + k_cos, k_sin = k_rope + k_normed = apply_rotary_emb_pt(k_normed, k_cos, k_sin) + + # Split Heads for Attention + b, s_q, _ = q_normed.shape + _, s_kv, _ = k_normed.shape + q_h = q_normed.view(b, s_q, self.heads, self.dim_head).transpose(1, 2) + k_h = k_normed.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2) + v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2) + + out = torch.nn.functional.scaled_dot_product_attention(q_h, k_h, v_h, attn_mask=mask, dropout_p=0.0) + out = out.transpose(1, 2).reshape(b, s_q, -1) + return self.to_out(out), (q_raw, k_raw, v, q_normed, k_normed, out) + + +# 
# ==========================================
# 2. JAX Test Suite
# ==========================================


class LTX2AttentionTest(unittest.TestCase):
  """Parity tests between the JAX LTX2Attention/RoPE and the PyTorch refs."""

  def setUp(self):
    # S=128 is preferred for TPU Flash Attention block sizes.
    self.B, self.S, self.D = 1, 128, 512
    self.heads = 4
    self.dim_head = 128
    self.context_dim = 512

    torch.manual_seed(0)
    # FIX: also seed NumPy — several tests draw np.random inputs and compare
    # against fixed tolerances; without this the suite was nondeterministic.
    np.random.seed(0)
    self.rng = nnx.Rngs(0)
    self.np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)

  def _init_and_sync_models(self, dtype=jnp.bfloat16):
    """Builds matched PyTorch/JAX models and copies PT weights into JAX."""
    pt_dtype = torch.float32 if dtype == jnp.float32 else torch.bfloat16
    pt_model = PytorchLTX2Attention(self.D, self.context_dim, self.heads, self.dim_head)
    pt_model.to(device="cpu", dtype=pt_dtype)
    pt_model.eval()

    jax_model = LTX2Attention(
        query_dim=self.D,
        heads=self.heads,
        dim_head=self.dim_head,
        context_dim=self.context_dim,
        rngs=self.rng,
        attention_kernel="dot_product",
        dtype=dtype,
    )

    def to_jax_dtype(arr):
      return jnp.array(arr).astype(dtype)

    def copy_linear(jax_layer, pt_layer):
      # PT Linear weight is [out, in]; JAX kernel is [in, out] -> transpose.
      w_pt = pt_layer.weight.detach().float().numpy().T
      b_pt = pt_layer.bias.detach().float().numpy()
      jax_layer.kernel[...] = to_jax_dtype(w_pt)
      jax_layer.bias[...] = to_jax_dtype(b_pt)

    def copy_norm(jax_layer, pt_layer):
      w_pt = pt_layer.weight.detach().float().numpy()
      jax_layer.scale[...] = to_jax_dtype(w_pt)

    copy_linear(jax_model.to_q, pt_model.to_q)
    copy_linear(jax_model.to_k, pt_model.to_k)
    copy_linear(jax_model.to_v, pt_model.to_v)
    copy_linear(jax_model.to_out, pt_model.to_out[0])
    copy_norm(jax_model.norm_q, pt_model.q_norm)
    copy_norm(jax_model.norm_k, pt_model.k_norm)

    return pt_model, jax_model

  def test_shapes(self):
    model = LTX2Attention(64, 4, 16, 64, rngs=self.rng, attention_kernel="dot_product")

    x_vid = jnp.zeros((1, 128, 64))
    out_vid = model(x_vid)
    self.assertEqual(out_vid.shape, (1, 128, 64))

    x_aud = jnp.zeros((1, 32, 64))
    out_cross = model(x_vid, encoder_hidden_states=x_aud)
    self.assertEqual(out_cross.shape, (1, 128, 64))
    print("\n[PASS] Shape Tests Passed.")

  def test_rope_frequency_parity(self):
    dim = 60
    rope_pt = PytorchLTX2RotaryPosEmbed(dim=dim)
    rope_pt.double_precision = False
    rope_jax = LTX2RotaryPosEmbed(dim=dim, double_precision=False)

    np_ids = np.random.randint(0, 100, (2, 16, 3)).astype(np.float32)

    # 1. PyTorch generation and BF16 cast.
    pt_cos, pt_sin = rope_pt(torch.from_numpy(np_ids))
    pt_cos = pt_cos.to(torch.bfloat16)
    pt_sin = pt_sin.to(torch.bfloat16)

    # 2. JAX generation and BF16 cast.
    jax_cos, jax_sin = rope_jax(jnp.array(np_ids))
    jax_cos = jax_cos.astype(jnp.bfloat16)
    jax_sin = jax_sin.astype(jnp.bfloat16)

    # Note: tolerance (5e-2) accounts for JAX XLA fast-math approximations
    # combined with the bfloat16 truncation.
    # We cast to float32 at the very end because NumPy testing doesn't natively support bfloat16.
    np.testing.assert_allclose(pt_cos.float().numpy(), np.array(jax_cos, dtype=np.float32), rtol=0, atol=5e-2)
    np.testing.assert_allclose(pt_sin.float().numpy(), np.array(jax_sin, dtype=np.float32), rtol=0, atol=5e-2)
    print("[PASS] RoPE Frequency Parity (BF16) Verified.")

  def test_parity_bf16_strict(self):
    pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)

    pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
    jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)

    with torch.no_grad():
      pt_out, _ = pt_model(pt_in)

    jax_out = jax_model(jax_in)

    pt_res = pt_out.float().numpy()
    jax_res = np.array(jax_out, dtype=np.float32)

    np.testing.assert_allclose(pt_res, jax_res, atol=2e-2, rtol=1e-2, err_msg="BF16 Parity Failed")
    print("\n[PASS] BF16 Strict Parity Test passed.")

  def test_layer_wise_stats(self):
    # Diagnostic only: prints per-layer stats for PT vs JAX; no assertions.
    pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)

    pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
    jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)

    with torch.no_grad():
      pt_out, (pt_q, pt_k, pt_v, pt_qn, pt_kn, pt_attn) = pt_model(pt_in)

    jax_q = jax_model.to_q(jax_in)
    jax_k = jax_model.to_k(jax_in)
    jax_v = jax_model.to_v(jax_in)
    jax_qn = jax_model.norm_q(jax_q)
    jax_kn = jax_model.norm_k(jax_k)

    jax_attn = jax_model.attention_op.apply_attention(jax_qn, jax_kn, jax_v)
    jax_out = jax_model.to_out(jax_attn)

    stats = []

    def add_stat(name, pt_t, jax_t):
      if isinstance(pt_t, torch.Tensor):
        pt_val = pt_t.float().numpy()
      else:
        pt_val = pt_t
      jax_val = np.array(jax_t, dtype=np.float32)
      stats.append({
          "Layer": name,
          "PT Max": f"{pt_val.max():.4f}",
          "JAX Max": f"{jax_val.max():.4f}",
          "PT Mean": f"{pt_val.mean():.4f}",
          "JAX Mean": f"{jax_val.mean():.4f}",
          "PT Min": f"{pt_val.min():.4f}",
          "JAX Min": f"{jax_val.min():.4f}",
          "Diff (L1)": f"{np.abs(pt_val - jax_val).mean():.6f}",
      })

    add_stat("Query Proj", pt_q, jax_q)
    add_stat("Key Proj", pt_k, jax_k)
    add_stat("Value Proj", pt_v, jax_v)
    add_stat("Query Norm", pt_qn, jax_qn)
    add_stat("Key Norm", pt_kn, jax_kn)
    add_stat("Attn Output", pt_attn, jax_attn)
    add_stat("Final Output", pt_out, jax_out)

    df = pd.DataFrame(stats)
    print("\n[DIAGNOSTIC] Layer-wise Stats:")
    print(df.to_string(index=False))

  def test_cross_attn_rope_integration(self):
    S_Q, S_KV = 16, 20
    pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)

    np_x = np.random.randn(self.B, S_Q, self.D).astype(np.float32)
    np_ctx = np.random.randn(self.B, S_KV, self.D).astype(np.float32)

    inner_dim = self.heads * self.dim_head
    rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=inner_dim)  # Gen [B, S, InnerDim]

    ids_q = torch.randint(0, 100, (self.B, S_Q, 1))
    ids_k = torch.randint(0, 100, (self.B, S_KV, 1))

    q_cos_pt, q_sin_pt = rope_gen_pt(ids_q.float())
    k_cos_pt, k_sin_pt = rope_gen_pt(ids_k.float())

    with torch.no_grad():
      pt_out, _ = pt_model(
          torch.from_numpy(np_x), context=torch.from_numpy(np_ctx), q_rope=(q_cos_pt, q_sin_pt), k_rope=(k_cos_pt, k_sin_pt)
      )

    jax_q_rope = (jnp.array(q_cos_pt.numpy()), jnp.array(q_sin_pt.numpy()))
    jax_k_rope = (jnp.array(k_cos_pt.numpy()), jnp.array(k_sin_pt.numpy()))

    jax_out = jax_model(
        jnp.array(np_x), encoder_hidden_states=jnp.array(np_ctx), rotary_emb=jax_q_rope, k_rotary_emb=jax_k_rope
    )

    diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
    print(f"\n[Cross-Attn + RoPE] Max Diff: {diff:.6f}")
    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=5e-3)
    print("[PASS] Cross-Attention with RoPE Parity Verified.")

  def test_attention_mask_parity(self):
    S_flash = 512
    np_x = np.random.randn(self.B, S_flash, self.D).astype(np.float32)
    pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)

    devices = jax.devices()
    mesh = Mesh(np.array(devices).reshape(1, -1), ("data", "context"))

    # Switch the already-synced model onto the flash kernel with a real mesh.
    jax_model.attention_op.attention_kernel = "flash"
    jax_model.attention_op.mesh = mesh

    # PT uses an additive mask; JAX side takes a multiplicative 0/1 pattern.
    mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32)
    pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]
    jax_mask_multiplicative = jnp.array(mask_pattern_np)

    with torch.no_grad():
      pt_out, _ = pt_model(torch.from_numpy(np_x), mask=pt_mask_additive)

    with mesh:
      jax_out = jax_model(jnp.array(np_x), attention_mask=jax_mask_multiplicative)

    diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
    print(f"\n[Mask Parity] Max Diff (Flash): {diff:.6f}")
    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=5e-3)
    print("[PASS] Attention Mask Parity Verified.")


if __name__ == "__main__":
  unittest.main()