diff --git a/mlx_video/models/ltx_2/video_vae/sampling.py b/mlx_video/models/ltx_2/video_vae/sampling.py
index 7e351ba..424e22e 100644
--- a/mlx_video/models/ltx_2/video_vae/sampling.py
+++ b/mlx_video/models/ltx_2/video_vae/sampling.py
@@ -88,17 +88,20 @@ def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
         if pad_d > 0 or pad_h > 0 or pad_w > 0:
             x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
 
-        # Skip connection: space-to-depth on input, then group mean
-        x_in = self._space_to_depth(x)
-        # Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
-        b2, c2, d2, h2, w2 = x_in.shape
-        x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
-        x_in = mx.mean(x_in, axis=2)  # (b, out_channels, d, h, w)
-
         # Conv branch: apply conv then space-to-depth
         x_conv = self.conv(x, causal=causal)
         x_conv = self._space_to_depth(x_conv)
 
+        # Skip connection: space-to-depth on input, then group mean
+        # Derive output channels from conv result to handle weight/config mismatches
+        # (e.g. LTX-2.3 converted weights where conv out_channels != out_channels // multiplier)
+        actual_out_channels = x_conv.shape[1]
+        x_in = self._space_to_depth(x)
+        b2, c2, d2, h2, w2 = x_in.shape
+        actual_group_size = c2 // actual_out_channels
+        x_in = mx.reshape(x_in, (b2, actual_out_channels, actual_group_size, d2, h2, w2))
+        x_in = mx.mean(x_in, axis=2)  # (b, actual_out_channels, d, h, w)
+
         # Add skip connection
         return x_conv + x_in
 