Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions mlx_video/models/ltx_2/video_vae/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,20 @@ def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
if pad_d > 0 or pad_h > 0 or pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])

# Skip connection: space-to-depth on input, then group mean
x_in = self._space_to_depth(x)
# Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
b2, c2, d2, h2, w2 = x_in.shape
x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w)

# Conv branch: apply conv then space-to-depth
x_conv = self.conv(x, causal=causal)
x_conv = self._space_to_depth(x_conv)

# Skip connection: space-to-depth on input, then group mean.
# Derive the output channel count from the conv result (x_conv.shape[1]) rather than
# from self.out_channels, to handle weight/config mismatches
# (e.g. LTX-2.3 converted weights where conv out_channels != out_channels // multiplier).
actual_out_channels = x_conv.shape[1]
x_in = self._space_to_depth(x)
b2, c2, d2, h2, w2 = x_in.shape
actual_group_size = c2 // actual_out_channels
x_in = mx.reshape(x_in, (b2, actual_out_channels, actual_group_size, d2, h2, w2))
x_in = mx.mean(x_in, axis=2) # (b, actual_out_channels, d, h, w)

# Add skip connection
return x_conv + x_in

Expand Down