diff --git a/mlx_video/models/ltx_2/conditioning/latent.py b/mlx_video/models/ltx_2/conditioning/latent.py index 4f101b2..5f0edc4 100644 --- a/mlx_video/models/ltx_2/conditioning/latent.py +++ b/mlx_video/models/ltx_2/conditioning/latent.py @@ -105,6 +105,10 @@ def apply_conditioning( frame_idx = cond.frame_idx strength = cond.strength + # Normalize negative indices (e.g. -1 -> last frame) + if frame_idx < 0: + frame_idx = frame_idx % f + # Validate shapes _, cond_c, cond_f, cond_h, cond_w = cond_latent.shape if (cond_c, cond_h, cond_w) != (c, h, w): diff --git a/mlx_video/models/ltx_2/generate.py b/mlx_video/models/ltx_2/generate.py index c6c592d..d58193f 100644 --- a/mlx_video/models/ltx_2/generate.py +++ b/mlx_video/models/ltx_2/generate.py @@ -1673,6 +1673,34 @@ def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path): # ============================================================================= +def _build_i2v_conditionings( + image_latent, + image_frame_idx: int, + image_strength: float, + end_image_latent=None, + end_image_strength: float = 1.0, +): + """Build a list of VideoConditionByLatentIndex for I2V conditioning. + + Supports first-frame, last-frame, or both simultaneously. + """ + conditionings = [] + if image_latent is not None: + idx = 0 if end_image_latent is not None else image_frame_idx + conditionings.append( + VideoConditionByLatentIndex( + latent=image_latent, frame_idx=idx, strength=image_strength + ) + ) + if end_image_latent is not None: + conditionings.append( + VideoConditionByLatentIndex( + latent=end_image_latent, frame_idx=-1, strength=end_image_strength + ) + ) + return conditionings + + def generate_video( model_repo: str, text_encoder_repo: str, @@ -1697,6 +1725,8 @@ def generate_video( image: Optional[str] = None, image_strength: float = 1.0, image_frame_idx: int = 0, + end_image: Optional[str] = None, + end_image_strength: Optional[float] = None, tiling: str = "auto", stream: bool = False, audio: bool = False, @@ -1742,9 +1772,11 @@ def generate_video( enhance_prompt: Whether to enhance prompt using Gemma max_tokens: Max tokens for prompt enhancement temperature: Temperature for prompt enhancement - image: Path to conditioning image for I2V + image: Path to conditioning image for I2V (first frame by default) image_strength: Conditioning strength for I2V - image_frame_idx: Frame index to condition for I2V + image_frame_idx: Frame index to condition for I2V (ignored when end_image is set) + end_image: Path to conditioning image for the last frame (I2V end-frame control) + end_image_strength: Conditioning strength for end frame (defaults to image_strength) tiling: Tiling mode for VAE decoding stream: Stream frames to output as they're decoded audio: Enable synchronized audio generation @@ -1772,7 +1804,10 @@ def generate_video( ) num_frames = adjusted_num_frames - is_i2v = image is not None + is_i2v = image is not None or end_image is not None + has_end_image = end_image is not None + if end_image_strength is None: + end_image_strength = image_strength is_a2v = audio_file is not None if is_a2v and audio: raise ValueError( @@ -1782,6 +1817,10 @@ def generate_video( if is_a2v: audio = True mode_str = "I2V" if is_i2v else "T2V" + if has_end_image and image is not None: + mode_str = "I2V(first+last)" + elif has_end_image: + mode_str = "I2V(last)" if is_a2v: mode_str = "A2V" + ("+I2V" if is_i2v else "") elif audio: @@ -1811,9 +1850,14 @@ def generate_video( ) if is_i2v: - console.print( - f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]" - ) + if image is not None: + console.print( + f"[dim]First image: {image} (strength={image_strength}, frame={image_frame_idx})[/]" + ) + if has_end_image: + console.print( + f"[dim]Last image: {end_image} (strength={end_image_strength}, frame=-1)[/]" + ) # Always compute audio frames - PyTorch distilled pipeline unconditionally # generates audio alongside video (model was trained with joint audio-video). @@ -2045,37 +2089,38 @@ def generate_video( # Load VAE encoder for I2V stage1_image_latent = None stage2_image_latent = None + stage1_end_image_latent = None + stage2_end_image_latent = None if is_i2v: with console.status( - "[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots" + "[blue]🖼️ Loading VAE encoder and encoding image(s)...[/]", spinner="dots" ): vae_encoder = VideoEncoder.from_pretrained( model_path / "vae" / "encoder" ) s1_h, s1_w = stage1_h * 32, stage1_w * 32 - input_image = load_image( - image, height=s1_h, width=s1_w, dtype=model_dtype - ) - stage1_image_tensor = prepare_image_for_encoding( - input_image, s1_h, s1_w, dtype=model_dtype - ) - stage1_image_latent = vae_encoder(stage1_image_tensor) - mx.eval(stage1_image_latent) - s2_h, s2_w = stage2_h * 32, stage2_w * 32 - input_image = load_image( - image, height=s2_h, width=s2_w, dtype=model_dtype - ) - stage2_image_tensor = prepare_image_for_encoding( - input_image, s2_h, s2_w, dtype=model_dtype - ) - stage2_image_latent = vae_encoder(stage2_image_tensor) - mx.eval(stage2_image_latent) + + if image is not None: + input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype) + stage1_image_latent = vae_encoder(prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)) + mx.eval(stage1_image_latent) + input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype) + stage2_image_latent = vae_encoder(prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)) + mx.eval(stage2_image_latent) + + if has_end_image: + end_input = load_image(end_image, height=s1_h, width=s1_w, dtype=model_dtype) + stage1_end_image_latent = vae_encoder(prepare_image_for_encoding(end_input, s1_h, s1_w, dtype=model_dtype)) + mx.eval(stage1_end_image_latent) + end_input = load_image(end_image, height=s2_h, width=s2_w, dtype=model_dtype) + stage2_end_image_latent = vae_encoder(prepare_image_for_encoding(end_input, s2_h, s2_w, dtype=model_dtype)) + mx.eval(stage2_end_image_latent) del vae_encoder mx.clear_cache() - console.print("[green]✓[/] VAE encoder loaded and image encoded") + console.print("[green]✓[/] VAE encoder loaded and image(s) encoded") # Stage 1 console.print( @@ -2099,19 +2144,18 @@ def generate_video( # Apply I2V conditioning state1 = None - if is_i2v and stage1_image_latent is not None: + if is_i2v and (stage1_image_latent is not None or stage1_end_image_latent is not None): latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) state1 = LatentState( latent=mx.zeros(latent_shape, dtype=model_dtype), clean_latent=mx.zeros(latent_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex( - latent=stage1_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, + conditionings = _build_i2v_conditionings( + stage1_image_latent, image_frame_idx, image_strength, + stage1_end_image_latent, end_image_strength, ) - state1 = apply_conditioning(state1, [conditioning]) + state1 = apply_conditioning(state1, conditionings) noise = mx.random.normal(latent_shape, dtype=model_dtype) noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) @@ -2177,18 +2221,17 @@ def generate_video( mx.eval(positions) state2 = None - if is_i2v and stage2_image_latent is not None: + if is_i2v and (stage2_image_latent is not None or stage2_end_image_latent is not None): state2 = LatentState( latent=latents, clean_latent=mx.zeros_like(latents), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex( - latent=stage2_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, + conditionings = _build_i2v_conditionings( + stage2_image_latent, image_frame_idx, image_strength, + stage2_end_image_latent, end_image_strength, ) - state2 = apply_conditioning(state2, [conditioning]) + state2 = apply_conditioning(state2, conditionings) noise = mx.random.normal(latents.shape).astype(model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) @@ -2239,26 +2282,28 @@ def generate_video( # Load VAE encoder for I2V image_latent = None + end_image_latent = None if is_i2v: with console.status( - "[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots" + "[blue]🖼️ Loading VAE encoder and encoding image(s)...[/]", spinner="dots" ): vae_encoder = VideoEncoder.from_pretrained( model_path / "vae" / "encoder" ) - input_image = load_image( - image, height=height, width=width, dtype=model_dtype - ) - image_tensor = prepare_image_for_encoding( - input_image, height, width, dtype=model_dtype - ) - image_latent = vae_encoder(image_tensor) - mx.eval(image_latent) + if image is not None: + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + image_latent = vae_encoder(prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)) + mx.eval(image_latent) + + if has_end_image: + end_input = load_image(end_image, height=height, width=width, dtype=model_dtype) + end_image_latent = vae_encoder(prepare_image_for_encoding(end_input, height, width, dtype=model_dtype)) + mx.eval(end_image_latent) del vae_encoder mx.clear_cache() - console.print("[green]✓[/] VAE encoder loaded and image encoded") + console.print("[green]✓[/] VAE encoder loaded and image(s) encoded") # Generate sigma schedule with token-count-dependent shifting sigmas = ltx2_scheduler(steps=num_inference_steps) @@ -2290,16 +2335,17 @@ def generate_video( # Initialize latents with optional I2V conditioning video_state = None video_latent_shape = (1, 128, latent_frames, latent_h, latent_w) - if is_i2v and image_latent is not None: + if is_i2v and (image_latent is not None or end_image_latent is not None): video_state = LatentState( latent=mx.zeros(video_latent_shape, dtype=model_dtype), clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex( - latent=image_latent, frame_idx=image_frame_idx, strength=image_strength + conditionings = _build_i2v_conditionings( + image_latent, image_frame_idx, image_strength, + end_image_latent, end_image_strength, ) - video_state = apply_conditioning(video_state, [conditioning]) + video_state = apply_conditioning(video_state, conditionings) noise = mx.random.normal(video_latent_shape, dtype=model_dtype) noise_scale = sigmas[0] @@ -2357,37 +2403,38 @@ def generate_video( # Load VAE encoder for I2V stage1_image_latent = None stage2_image_latent = None + stage1_end_image_latent = None + stage2_end_image_latent = None if is_i2v: with console.status( - "[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots" + "[blue]🖼️ Loading VAE encoder and encoding image(s)...[/]", spinner="dots" ): vae_encoder = VideoEncoder.from_pretrained( model_path / "vae" / "encoder" ) s1_h, s1_w = stage1_h * 32, stage1_w * 32 - input_image = load_image( - image, height=s1_h, width=s1_w, dtype=model_dtype - ) - stage1_image_tensor = prepare_image_for_encoding( - input_image, s1_h, s1_w, dtype=model_dtype - ) - stage1_image_latent = vae_encoder(stage1_image_tensor) - mx.eval(stage1_image_latent) - s2_h, s2_w = stage2_h * 32, stage2_w * 32 - input_image = load_image( - image, height=s2_h, width=s2_w, dtype=model_dtype - ) - stage2_image_tensor = prepare_image_for_encoding( - input_image, s2_h, s2_w, dtype=model_dtype - ) - stage2_image_latent = vae_encoder(stage2_image_tensor) - mx.eval(stage2_image_latent) + + if image is not None: + input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype) + stage1_image_latent = vae_encoder(prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)) + mx.eval(stage1_image_latent) + input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype) + stage2_image_latent = vae_encoder(prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)) + mx.eval(stage2_image_latent) + + if has_end_image: + end_input = load_image(end_image, height=s1_h, width=s1_w, dtype=model_dtype) + stage1_end_image_latent = vae_encoder(prepare_image_for_encoding(end_input, s1_h, s1_w, dtype=model_dtype)) + mx.eval(stage1_end_image_latent) + end_input = load_image(end_image, height=s2_h, width=s2_w, dtype=model_dtype) + stage2_end_image_latent = vae_encoder(prepare_image_for_encoding(end_input, s2_h, s2_w, dtype=model_dtype)) + mx.eval(stage2_end_image_latent) del vae_encoder mx.clear_cache() - console.print("[green]✓[/] VAE encoder loaded and image encoded") + console.print("[green]✓[/] VAE encoder loaded and image(s) encoded") # Stage 1: Dev denoising at reduced resolution with CFG sigmas = ltx2_scheduler(steps=num_inference_steps) @@ -2419,18 +2466,17 @@ def generate_video( # Apply I2V conditioning for stage 1 state1 = None stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w) - if is_i2v and stage1_image_latent is not None: + if is_i2v and (stage1_image_latent is not None or stage1_end_image_latent is not None): state1 = LatentState( latent=mx.zeros(stage1_shape, dtype=model_dtype), clean_latent=mx.zeros(stage1_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex( - latent=stage1_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, + conditionings = _build_i2v_conditionings( + stage1_image_latent, image_frame_idx, image_strength, + stage1_end_image_latent, end_image_strength, ) - state1 = apply_conditioning(state1, [conditioning]) + state1 = apply_conditioning(state1, conditionings) noise = mx.random.normal(stage1_shape, dtype=model_dtype) noise_scale = sigmas[0] @@ -2529,18 +2575,17 @@ def generate_video( mx.eval(positions) state2 = None - if is_i2v and stage2_image_latent is not None: + if is_i2v and (stage2_image_latent is not None or stage2_end_image_latent is not None): state2 = LatentState( latent=latents, clean_latent=mx.zeros_like(latents), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex( - latent=stage2_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, + conditionings = _build_i2v_conditionings( + stage2_image_latent, image_frame_idx, image_strength, + stage2_end_image_latent, end_image_strength, ) - state2 = apply_conditioning(state2, [conditioning]) + state2 = apply_conditioning(state2, conditionings) noise = mx.random.normal(latents.shape).astype(model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) @@ -2612,33 +2657,34 @@ def generate_video( # Load VAE encoder for I2V stage1_image_latent = None stage2_image_latent = None + stage1_end_image_latent = None + stage2_end_image_latent = None if is_i2v: with console.status( - "[blue]Loading VAE encoder and encoding image...[/]", spinner="dots" + "[blue]Loading VAE encoder and encoding image(s)...[/]", spinner="dots" ): vae_encoder = VideoEncoder.from_pretrained( model_path / "vae" / "encoder" ) s1_h, s1_w = stage1_h * 32, stage1_w * 32 - input_image = load_image( - image, height=s1_h, width=s1_w, dtype=model_dtype - ) - stage1_image_tensor = prepare_image_for_encoding( - input_image, s1_h, s1_w, dtype=model_dtype - ) - stage1_image_latent = vae_encoder(stage1_image_tensor) - mx.eval(stage1_image_latent) - s2_h, s2_w = stage2_h * 32, stage2_w * 32 - input_image = load_image( - image, height=s2_h, width=s2_w, dtype=model_dtype - ) - stage2_image_tensor = prepare_image_for_encoding( - input_image, s2_h, s2_w, dtype=model_dtype - ) - stage2_image_latent = vae_encoder(stage2_image_tensor) - mx.eval(stage2_image_latent) + + if image is not None: + input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype) + stage1_image_latent = vae_encoder(prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)) + mx.eval(stage1_image_latent) + input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype) + stage2_image_latent = vae_encoder(prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)) + mx.eval(stage2_image_latent) + + if has_end_image: + end_input = load_image(end_image, height=s1_h, width=s1_w, dtype=model_dtype) + stage1_end_image_latent = vae_encoder(prepare_image_for_encoding(end_input, s1_h, s1_w, dtype=model_dtype)) + mx.eval(stage1_end_image_latent) + end_input = load_image(end_image, height=s2_h, width=s2_w, dtype=model_dtype) + stage2_end_image_latent = vae_encoder(prepare_image_for_encoding(end_input, s2_h, s2_w, dtype=model_dtype)) + mx.eval(stage2_end_image_latent) del vae_encoder mx.clear_cache() @@ -2695,18 +2741,17 @@ def generate_video( # Apply I2V conditioning for stage 1 state1 = None stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w) - if is_i2v and stage1_image_latent is not None: + if is_i2v and (stage1_image_latent is not None or stage1_end_image_latent is not None): state1 = LatentState( latent=mx.zeros(stage1_shape, dtype=model_dtype), clean_latent=mx.zeros(stage1_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex( - latent=stage1_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, + conditionings = _build_i2v_conditionings( + stage1_image_latent, image_frame_idx, image_strength, + stage1_end_image_latent, end_image_strength, ) - state1 = apply_conditioning(state1, [conditioning]) + state1 = apply_conditioning(state1, conditionings) noise = mx.random.normal(stage1_shape, dtype=model_dtype) noise_scale = sigmas[0] @@ -2796,18 +2841,17 @@ def generate_video( mx.eval(positions) state2 = None - if is_i2v and stage2_image_latent is not None: + if is_i2v and (stage2_image_latent is not None or stage2_end_image_latent is not None): state2 = LatentState( latent=latents, clean_latent=mx.zeros_like(latents), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) - conditioning = VideoConditionByLatentIndex( - latent=stage2_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, + conditionings = _build_i2v_conditionings( + stage2_image_latent, image_frame_idx, image_strength, + stage2_end_image_latent, end_image_strength, ) - state2 = apply_conditioning(state2, [conditioning]) + state2 = apply_conditioning(state2, conditionings) noise = mx.random.normal(latents.shape).astype(model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) @@ -3196,7 +3240,19 @@ def main(): "--image-frame-idx", type=int, default=0, - help="Frame index to condition for I2V", + help="Frame index to condition for I2V (ignored when --end-image is set)", + ) + parser.add_argument( + "--end-image", + type=str, + default=None, + help="Path to conditioning image for the last frame (I2V end-frame control)", + ) + parser.add_argument( + "--end-image-strength", + type=float, + default=None, + help="Conditioning strength for end frame (defaults to --image-strength)", ) parser.add_argument( "--tiling", @@ -3340,6 +3396,8 @@ def main(): image=args.image, image_strength=args.image_strength, image_frame_idx=args.image_frame_idx, + end_image=args.end_image, + end_image_strength=args.end_image_strength, tiling=args.tiling, stream=args.stream, audio=args.audio,