From a8cd1db78ad90fa69f30ab6764e14f5c1ba94ca7 Mon Sep 17 00:00:00 2001 From: Norbert Schmidt Date: Tue, 31 Mar 2026 14:00:17 +0200 Subject: [PATCH] Fix VAE encoder broadcast error for LTX-2.3 I2V The SpaceToDepthDownsample skip connection computed channel counts from self.out_channels, but LTX-2.3 converted weights (prince-canuma/LTX-2.3-distilled) have conv output channels that don't match out_channels // multiplier for the last downsample block (128 vs 256), causing a broadcast shape error: ValueError: Shapes (1,1024,1,8,12) and (1,2048,1,8,12) cannot be broadcast This only affects I2V (image-to-video) since T2V doesn't use the VAE encoder. Fix: derive the skip connection channel count from the actual conv output shape instead of the configured out_channels. This is backwards-compatible with LTX-2 weights where the values already match. Co-Authored-By: Claude Opus 4.6 (1M context) --- mlx_video/models/ltx_2/video_vae/sampling.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mlx_video/models/ltx_2/video_vae/sampling.py b/mlx_video/models/ltx_2/video_vae/sampling.py index 7e351ba..424e22e 100644 --- a/mlx_video/models/ltx_2/video_vae/sampling.py +++ b/mlx_video/models/ltx_2/video_vae/sampling.py @@ -88,17 +88,20 @@ def __call__(self, x: mx.array, causal: bool = True) -> mx.array: if pad_d > 0 or pad_h > 0 or pad_w > 0: x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)]) - # Skip connection: space-to-depth on input, then group mean - x_in = self._space_to_depth(x) - # Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w) - b2, c2, d2, h2, w2 = x_in.shape - x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2)) - x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w) - # Conv branch: apply conv then space-to-depth x_conv = self.conv(x, causal=causal) x_conv = self._space_to_depth(x_conv) + # Skip connection: space-to-depth on input, then group mean + # Derive output channels from conv result to handle weight/config mismatches + # (e.g. LTX-2.3 converted weights where conv out_channels != out_channels // multiplier) + actual_out_channels = x_conv.shape[1] + x_in = self._space_to_depth(x) + b2, c2, d2, h2, w2 = x_in.shape + actual_group_size = c2 // actual_out_channels + x_in = mx.reshape(x_in, (b2, actual_out_channels, actual_group_size, d2, h2, w2)) + x_in = mx.mean(x_in, axis=2) # (b, actual_out_channels, d, h, w) + # Add skip connection return x_conv + x_in