From 14af138026fed9b76a37b4b88216df943948474b Mon Sep 17 00:00:00 2001 From: Rafal Leszko Date: Tue, 4 Nov 2025 10:37:09 +0000 Subject: [PATCH 1/7] Use diffusers for streamdiffusionv2 Signed-off-by: Rafal Leszko --- pipelines/streamdiffusionv2/modular_blocks.py | 390 ++++++++++++++++++ .../streamdiffusionv2/modular_config.json | 7 + .../modular_pipeline_wrapper.py | 28 ++ pipelines/streamdiffusionv2/pipeline.py | 94 +++-- 4 files changed, 487 insertions(+), 32 deletions(-) create mode 100644 pipelines/streamdiffusionv2/modular_blocks.py create mode 100644 pipelines/streamdiffusionv2/modular_config.json create mode 100644 pipelines/streamdiffusionv2/modular_pipeline_wrapper.py diff --git a/pipelines/streamdiffusionv2/modular_blocks.py b/pipelines/streamdiffusionv2/modular_blocks.py new file mode 100644 index 000000000..879819399 --- /dev/null +++ b/pipelines/streamdiffusionv2/modular_blocks.py @@ -0,0 +1,390 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import List, Optional, Union, Dict +import logging + +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from diffusers.configuration_utils import FrozenDict +from diffusers.guiders import ClassifierFreeGuidance +from diffusers.utils import is_ftfy_available, logging as diffusers_logging +from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks +from diffusers.modular_pipelines.modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, +) + +if is_ftfy_available(): + import ftfy + +logger = diffusers_logging.get_logger(__name__) + + +def basic_clean(text): + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class StreamDiffusionV2TextEncoderStep(ModularPipelineBlocks): + model_name = "StreamDiffusionV2" + + @property + def description(self) -> str: + return "Text Encoder step that generates text_embeddings to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("stream", torch.nn.Module), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="text embeddings used to guide the image generation", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="negative text embeddings used to guide the image generation", + ), + InputParam("attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + 
"prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) + and not isinstance(block_state.prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}" + ) + + @torch.no_grad() + def __call__( + self, components, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + # For streamdiffusionv2, prompt encoding is handled by prompt_blender + # This block is a placeholder to maintain compatibility with Modular Diffusers + # The actual prompt_embeds are set via conditional_dict in the pipeline + # We just need to ensure prompt_embeds is in the state + if block_state.prompt_embeds is None: + # Use the stream's text encoder if needed + if hasattr(block_state, 'prompt') and block_state.prompt is not None: + conditional_dict = components.stream.text_encoder(text_prompts=[block_state.prompt]) + block_state.prompt_embeds = conditional_dict["prompt_embeds"] + else: + # Default empty prompt + conditional_dict = components.stream.text_encoder(text_prompts=[""]) + block_state.prompt_embeds = conditional_dict["prompt_embeds"] + + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class StreamDiffusionV2PreprocessStep(ModularPipelineBlocks): + model_name = "StreamDiffusionV2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("stream", torch.nn.Module), + ] + + @property + def description(self) -> str: + return "Preprocess step that encodes input frames to latents and adds noise" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "input", + required=True, + type_hint=torch.Tensor, + description="Input video frames in BCTHW format", + ), + InputParam( + "noise_scale", + type_hint=float, + default=0.7, + description="Scale of noise to add to latents", + ), + InputParam( + "generator", + description="Random number generator for noise", + ), + InputParam( + "base_seed", + type_hint=int, + default=42, + description="Base seed for random number generation", + ), + InputParam( + "current_start", + type_hint=int, + default=0, + description="Current start position for seed offset", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="Encoded and noised latents", + ), + OutputParam( + "noisy_latents", + type_hint=torch.Tensor, + description="Noisy latents ready for denoising", + ), + OutputParam( + "current_step", + type_hint=int, + description="Current denoising step", + ), + ] + + @torch.no_grad() + def __call__( + self, components, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + + # Encode frames to latents using VAE + latents = components.stream.vae.model.stream_encode(block_state.input) + # Transpose latents + latents = latents.transpose(2, 1) + + # Create generator from seed for reproducible generation + frame_seed = block_state.base_seed + block_state.current_start + rng = 
torch.Generator(device=latents.device).manual_seed(frame_seed) + + noise = torch.randn( + latents.shape, + device=latents.device, + dtype=latents.dtype, + generator=rng, + ) + # Determine how noisy the latents should be + noisy_latents = noise * block_state.noise_scale + latents * (1 - block_state.noise_scale) + + # Determine the number of denoising steps + current_step = int(1000 * block_state.noise_scale) - 100 + + block_state.latents = latents + block_state.noisy_latents = noisy_latents + block_state.current_step = current_step + block_state.generator = rng + + self.set_block_state(state, block_state) + return components, state + + +class StreamDiffusionV2DenoiseStep(ModularPipelineBlocks): + model_name = "StreamDiffusionV2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("stream", torch.nn.Module), + ] + + @property + def description(self) -> str: + return "Denoise step that performs inference using the stream pipeline" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "noisy_latents", + required=True, + type_hint=torch.Tensor, + description="Noisy latents to denoise", + ), + InputParam( + "current_start", + required=True, + type_hint=int, + description="Current start position", + ), + InputParam( + "current_end", + required=True, + type_hint=int, + description="Current end position", + ), + InputParam( + "current_step", + required=True, + type_hint=int, + description="Current denoising step", + ), + InputParam( + "generator", + description="Random number generator", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "denoised_pred", + type_hint=torch.Tensor, + description="Denoised prediction", + ), + ] + + @torch.no_grad() + def __call__( + self, components, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + + # Use the stream's inference method + denoised_pred = components.stream.inference( + noise=block_state.noisy_latents, + current_start=block_state.current_start, + current_end=block_state.current_end, + current_step=block_state.current_step, + generator=block_state.generator, + ) + + block_state.denoised_pred = denoised_pred + self.set_block_state(state, block_state) + return components, state + + +class StreamDiffusionV2PostprocessStep(ModularPipelineBlocks): + model_name = "StreamDiffusionV2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("stream", torch.nn.Module), + ] + + @property + def description(self) -> str: + return "Postprocess step that decodes denoised latents to pixel space" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "denoised_pred", + required=True, + type_hint=torch.Tensor, + description="Denoised latents", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "output", + type_hint=torch.Tensor, + description="Decoded video frames", + ), + ] + + @torch.no_grad() + def __call__( + self, components, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + + # Decode to pixel space + output = components.stream.vae.stream_decode_to_pixel(block_state.denoised_pred) + block_state.output = output + + self.set_block_state(state, block_state) + return components, state + + +from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict + +VIDEO2VIDEO_BLOCKS = InsertableDict( + [ + ("text_encoder", 
StreamDiffusionV2TextEncoderStep), + ("preprocess", StreamDiffusionV2PreprocessStep), + ("denoise", StreamDiffusionV2DenoiseStep), + ("postprocess", StreamDiffusionV2PostprocessStep), + ] +) + +ALL_BLOCKS = { + "video2video": VIDEO2VIDEO_BLOCKS, +} + + +class StreamDiffusionV2Blocks(SequentialPipelineBlocks): + block_classes = list(VIDEO2VIDEO_BLOCKS.copy().values()) + block_names = list(VIDEO2VIDEO_BLOCKS.copy().keys()) diff --git a/pipelines/streamdiffusionv2/modular_config.json b/pipelines/streamdiffusionv2/modular_config.json new file mode 100644 index 000000000..e9c4d0c42 --- /dev/null +++ b/pipelines/streamdiffusionv2/modular_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "StreamDiffusionV2Blocks", + "_diffusers_version": "0.36.0.dev0", + "auto_map": { + "ModularPipelineBlocks": "modular_blocks.StreamDiffusionV2Blocks" + } +} diff --git a/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py b/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py new file mode 100644 index 000000000..6e9c87d80 --- /dev/null +++ b/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py @@ -0,0 +1,28 @@ +"""Wrapper to expose CausalStreamInferencePipeline as a ModularPipeline for Modular Diffusers.""" + +import torch +from diffusers.modular_pipelines import ModularPipeline + + +class StreamDiffusionV2ModularPipeline(ModularPipeline): + """Wrapper that exposes CausalStreamInferencePipeline as a ModularPipeline.""" + + def __init__(self, stream): + """ + Initialize the wrapper. + + Args: + stream: CausalStreamInferencePipeline instance + """ + self.stream = stream + self._execution_device_val = next(stream.generator.parameters()).device + + @property + def _execution_device(self): + """Return the execution device.""" + return self._execution_device_val + + @_execution_device.setter + def _execution_device(self, value): + """Set the execution device.""" + self._execution_device_val = value diff --git a/pipelines/streamdiffusionv2/pipeline.py b/pipelines/streamdiffusionv2/pipeline.py index 203d326d9..263184a58 100644 --- a/pipelines/streamdiffusionv2/pipeline.py +++ b/pipelines/streamdiffusionv2/pipeline.py @@ -4,9 +4,14 @@ import torch +from diffusers.modular_pipelines import PipelineState +from diffusers.modular_pipelines import ModularPipeline + from ..blending import PromptBlender, handle_transition_prepare from ..interface import Pipeline, Requirements from ..process import postprocess_chunk, preprocess_chunk +from .modular_blocks import StreamDiffusionV2Blocks +from .modular_pipeline_wrapper import StreamDiffusionV2ModularPipeline from .vendor.causvid.models.wan.causal_stream_inference import ( CausalStreamInferencePipeline, ) @@ -77,6 +82,10 @@ def __init__( self.current_start = 0 self.current_end = self.stream.frame_seq_length * 2 + # Initialize Modular Diffusers blocks + self.modular_blocks = StreamDiffusionV2Blocks() + self.modular_pipeline = StreamDiffusionV2ModularPipeline(self.stream) + def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: if should_prepare: logger.info("prepare: Initiating pipeline prepare for request") @@ -252,46 +261,67 @@ def __call__( if noise_controller: self._apply_motion_aware_noise_controller(input) + # Use Modular Diffusers blocks to process the input + state = PipelineState() + + # Set up state for modular blocks + state.set("input", input) + state.set("noise_scale", self.noise_scale) + state.set("base_seed", self.base_seed) + state.set("current_start", self.current_start) + state.set("current_end", self.current_end) + 
state.set("denoising_step_list", self.stream.denoising_step_list) + # Determine the number of denoising steps - # Higher noise scale -> more denoising steps, more intense changes to input - # Lower noise scale -> less denoising steps, less intense changes to input current_step = int(1000 * self.noise_scale) - 100 + state.set("current_step", current_step) + + # Execute modular blocks (returns tuple: components, state) + _, state = self.modular_blocks(self.modular_pipeline, state) + + # Get output from state + output = state.values.get("output") + if output is None: + # Fallback to original implementation if modular blocks didn't produce output + # Determine the number of denoising steps + current_step = int(1000 * self.noise_scale) - 100 + + # Encode frames to latents using VAE + latents = self.stream.vae.model.stream_encode(input) + # Transpose latents + latents = latents.transpose(2, 1) + + # Create generator from seed for reproducible generation + frame_seed = self.base_seed + self.current_start + rng = torch.Generator(device=latents.device).manual_seed(frame_seed) + + noise = torch.randn( + latents.shape, + device=latents.device, + dtype=latents.dtype, + generator=rng, + ) + noisy_latents = noise * self.noise_scale + latents * (1 - self.noise_scale) + denoised_pred = self.stream.inference( + noise=noisy_latents, + current_start=self.current_start, + current_end=self.current_end, + current_step=current_step, + generator=rng, + ) - # Encode frames to latents using VAE - latents = self.stream.vae.model.stream_encode(input) - # Transpose latents - latents = latents.transpose(2, 1) - - # Create generator from seed for reproducible generation - # Derive unique seed per chunk using current_start as offset - frame_seed = self.base_seed + self.current_start - rng = torch.Generator(device=latents.device).manual_seed(frame_seed) - - noise = torch.randn( - latents.shape, - device=latents.device, - dtype=latents.dtype, - generator=rng, - ) - # Determine how noisy the latents should be - # Higher noise scale -> noiser latents, less of inputs preserved - # Lower noise scale -> less noisy latents, more of inputs preserved - noisy_latents = noise * self.noise_scale + latents * (1 - self.noise_scale) - denoised_pred = self.stream.inference( - noise=noisy_latents, - current_start=self.current_start, - current_end=self.current_end, - current_step=current_step, - generator=rng, - ) + # Decode to pixel space + output = self.stream.vae.stream_decode_to_pixel(denoised_pred) + else: + # Ensure output is in the right format + if not isinstance(output, torch.Tensor): + output = output[0] if isinstance(output, list) else output - # # Update tracking variables for next input + # Update tracking variables for next input self.last_frame = input[:, :, [-1]] self.current_start = self.current_end self.current_end += (self.chunk_size // 4) * self.stream.frame_seq_length - # Decode to pixel space - output = self.stream.vae.stream_decode_to_pixel(denoised_pred) return postprocess_chunk(output) def _initialize_stream_caches(self): From f512e7455824463394055ab3990f763ac0559e2a Mon Sep 17 00:00:00 2001 From: Rafal Leszko Date: Tue, 4 Nov 2025 11:45:51 +0000 Subject: [PATCH 2/7] Remove fallback Signed-off-by: Rafal Leszko --- pipelines/streamdiffusionv2/pipeline.py | 82 ++++++++++--------------- 1 file changed, 31 insertions(+), 51 deletions(-) diff --git a/pipelines/streamdiffusionv2/pipeline.py b/pipelines/streamdiffusionv2/pipeline.py index 263184a58..f99396a8f 100644 --- a/pipelines/streamdiffusionv2/pipeline.py +++ 
b/pipelines/streamdiffusionv2/pipeline.py @@ -3,9 +3,7 @@ import time import torch - from diffusers.modular_pipelines import PipelineState -from diffusers.modular_pipelines import ModularPipeline from ..blending import PromptBlender, handle_transition_prepare from ..interface import Pipeline, Requirements @@ -80,11 +78,11 @@ def __init__( self.last_frame = None self.current_start = 0 - self.current_end = self.stream.frame_seq_length * 2 # Initialize Modular Diffusers blocks self.modular_blocks = StreamDiffusionV2Blocks() self.modular_pipeline = StreamDiffusionV2ModularPipeline(self.stream) + self.current_end = self.modular_pipeline.stream.frame_seq_length * 2 def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: if should_prepare: @@ -107,7 +105,7 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: # Handle prompt transition requests should_prepare_from_transition, target_prompts = handle_transition_prepare( - transition, self.prompt_blender, self.stream.text_encoder + transition, self.prompt_blender, self.modular_pipeline.stream.text_encoder ) if target_prompts: self.prompts = target_prompts @@ -138,7 +136,9 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: # We need to make sure that current_start does not shift past the max sequence length of the RoPE frequency table # When we hit the limit we reset the caches and indices # See this issue for more context https://github.com/daydreamlive/scope/issues/95 - max_current_start = MAX_ROPE_FREQ_TABLE_SEQ_LEN * self.stream.frame_seq_length + max_current_start = ( + MAX_ROPE_FREQ_TABLE_SEQ_LEN * self.modular_pipeline.stream.frame_seq_length + ) # We reset at whatever is smaller the theoretically max value or some % of it max_current_start = min( int(max_current_start * CURRENT_START_RESET_RATIO), max_current_start @@ -151,7 +151,7 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: # Update internal state before preparing pipeline if denoising_step_list is not None: self.denoising_step_list = denoising_step_list - self.stream.denoising_step_list = torch.tensor( + self.modular_pipeline.stream.denoising_step_list = torch.tensor( denoising_step_list, dtype=torch.long, device=self.device ) @@ -170,16 +170,16 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: @torch.no_grad() def _prepare_pipeline(self, prompts=None, interpolation_method="linear"): # Trigger KV + cross-attn cache re-initialization in prepare() - self.stream.kv_cache1 = None + self.modular_pipeline.stream.kv_cache1 = None # Apply prompt blending and set conditional_dict self._apply_prompt_blending(prompts, interpolation_method) - self.stream.vae.model.first_batch = True + self.modular_pipeline.stream.vae.model.first_batch = True self.last_frame = None self.current_start = 0 - self.current_end = self.stream.frame_seq_length * 2 + self.current_end = self.modular_pipeline.stream.frame_seq_length * 2 def _apply_motion_aware_noise_controller(self, input: torch.Tensor): # The prev seq is the last chunk_size frames of the current input @@ -231,11 +231,13 @@ def __call__( # Update prompt embedding for this generation call # Handles both static blending and temporal transitions next_embedding = self.prompt_blender.get_next_embedding( - self.stream.text_encoder + self.modular_pipeline.stream.text_encoder ) if next_embedding is not None: - self.stream.conditional_dict = {"prompt_embeds": next_embedding} + self.modular_pipeline.stream.conditional_dict = { + 
"prompt_embeds": next_embedding + } # Note: The caller must call prepare() before __call__() # We just need to get the expected chunk size based on current state @@ -270,7 +272,9 @@ def __call__( state.set("base_seed", self.base_seed) state.set("current_start", self.current_start) state.set("current_end", self.current_end) - state.set("denoising_step_list", self.stream.denoising_step_list) + state.set( + "denoising_step_list", self.modular_pipeline.stream.denoising_step_list + ) # Determine the number of denoising steps current_step = int(1000 * self.noise_scale) - 100 @@ -281,67 +285,43 @@ def __call__( # Get output from state output = state.values.get("output") + if output is None: - # Fallback to original implementation if modular blocks didn't produce output - # Determine the number of denoising steps - current_step = int(1000 * self.noise_scale) - 100 - - # Encode frames to latents using VAE - latents = self.stream.vae.model.stream_encode(input) - # Transpose latents - latents = latents.transpose(2, 1) - - # Create generator from seed for reproducible generation - frame_seed = self.base_seed + self.current_start - rng = torch.Generator(device=latents.device).manual_seed(frame_seed) - - noise = torch.randn( - latents.shape, - device=latents.device, - dtype=latents.dtype, - generator=rng, - ) - noisy_latents = noise * self.noise_scale + latents * (1 - self.noise_scale) - denoised_pred = self.stream.inference( - noise=noisy_latents, - current_start=self.current_start, - current_end=self.current_end, - current_step=current_step, - generator=rng, - ) + raise RuntimeError("Modular blocks did not produce output") - # Decode to pixel space - output = self.stream.vae.stream_decode_to_pixel(denoised_pred) - else: - # Ensure output is in the right format - if not isinstance(output, torch.Tensor): - output = output[0] if isinstance(output, list) else output + # Ensure output is in the right format + if not isinstance(output, torch.Tensor): + output = output[0] if isinstance(output, list) else output # Update tracking variables for next input self.last_frame = input[:, :, [-1]] self.current_start = self.current_end - self.current_end += (self.chunk_size // 4) * self.stream.frame_seq_length + self.current_end += ( + self.chunk_size // 4 + ) * self.modular_pipeline.stream.frame_seq_length return postprocess_chunk(output) def _initialize_stream_caches(self): """Initialize stream caches without overriding conditional_dict.""" noise = torch.zeros(1, 1).to(self.device, self.dtype) - saved = self.stream.conditional_dict - self.stream.prepare(noise, text_prompts=[""]) - self.stream.conditional_dict = saved + saved = self.modular_pipeline.stream.conditional_dict + self.modular_pipeline.stream.prepare(noise, text_prompts=[""]) + self.modular_pipeline.stream.conditional_dict = saved def _apply_prompt_blending(self, prompts=None, interpolation_method="linear"): """Apply weighted blending of cached prompt embeddings.""" combined_embeds = self.prompt_blender.blend( - prompts, interpolation_method, self.stream.text_encoder + prompts, interpolation_method, self.modular_pipeline.stream.text_encoder ) if combined_embeds is None: return # Set the blended embeddings on the stream - self.stream.conditional_dict = {"prompt_embeds": combined_embeds} + self.modular_pipeline.stream.conditional_dict = { + "prompt_embeds": combined_embeds + } # Initialize caches without overriding conditional_dict self._initialize_stream_caches() From f0cffe21c8ebd3622061d44f5b96e210a4b1cccb Mon Sep 17 00:00:00 2001 From: Rafal Leszko Date: 
Tue, 4 Nov 2025 12:00:42 +0000 Subject: [PATCH 3/7] Refactor Signed-off-by: Rafal Leszko --- pipelines/streamdiffusionv2/modular_blocks.py | 71 +++++++++---------- .../modular_pipeline_wrapper.py | 31 ++++++-- pipelines/streamdiffusionv2/pipeline.py | 26 ++----- 3 files changed, 66 insertions(+), 62 deletions(-) diff --git a/pipelines/streamdiffusionv2/modular_blocks.py b/pipelines/streamdiffusionv2/modular_blocks.py index 879819399..be3560f17 100644 --- a/pipelines/streamdiffusionv2/modular_blocks.py +++ b/pipelines/streamdiffusionv2/modular_blocks.py @@ -13,23 +13,22 @@ # limitations under the License. import html -from typing import List, Optional, Union, Dict -import logging import regex as re import torch -from transformers import AutoTokenizer, UMT5EncoderModel - -from diffusers.configuration_utils import FrozenDict -from diffusers.guiders import ClassifierFreeGuidance -from diffusers.utils import is_ftfy_available, logging as diffusers_logging -from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks +from diffusers.modular_pipelines import ( + ModularPipelineBlocks, + PipelineState, + SequentialPipelineBlocks, +) from diffusers.modular_pipelines.modular_pipeline_utils import ( ComponentSpec, ConfigSpec, InputParam, OutputParam, ) +from diffusers.utils import is_ftfy_available +from diffusers.utils import logging as diffusers_logging if is_ftfy_available(): import ftfy @@ -63,17 +62,17 @@ def description(self) -> str: return "Text Encoder step that generates text_embeddings to guide the video generation" @property - def expected_components(self) -> List[ComponentSpec]: + def expected_components(self) -> list[ComponentSpec]: return [ ComponentSpec("stream", torch.nn.Module), ] @property - def expected_configs(self) -> List[ConfigSpec]: + def expected_configs(self) -> list[ConfigSpec]: return [] @property - def inputs(self) -> List[InputParam]: + def inputs(self) -> list[InputParam]: return [ InputParam("prompt"), InputParam("negative_prompt"), @@ -91,7 +90,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediate_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> list[OutputParam]: return [ OutputParam( "prompt_embeds", @@ -118,9 +117,7 @@ def check_inputs(block_state): ) @torch.no_grad() - def __call__( - self, components, state: PipelineState - ) -> PipelineState: + def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) self.check_inputs(block_state) @@ -130,8 +127,10 @@ def __call__( # We just need to ensure prompt_embeds is in the state if block_state.prompt_embeds is None: # Use the stream's text encoder if needed - if hasattr(block_state, 'prompt') and block_state.prompt is not None: - conditional_dict = components.stream.text_encoder(text_prompts=[block_state.prompt]) + if hasattr(block_state, "prompt") and block_state.prompt is not None: + conditional_dict = components.stream.text_encoder( + text_prompts=[block_state.prompt] + ) block_state.prompt_embeds = conditional_dict["prompt_embeds"] else: # Default empty prompt @@ -147,7 +146,7 @@ class StreamDiffusionV2PreprocessStep(ModularPipelineBlocks): model_name = "StreamDiffusionV2" @property - def expected_components(self) -> List[ComponentSpec]: + def expected_components(self) -> list[ComponentSpec]: return [ ComponentSpec("stream", torch.nn.Module), ] @@ -157,7 +156,7 @@ def description(self) -> str: return "Preprocess step that encodes input frames to latents and adds noise" @property 
- def inputs(self) -> List[InputParam]: + def inputs(self) -> list[InputParam]: return [ InputParam( "input", @@ -190,7 +189,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediate_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> list[OutputParam]: return [ OutputParam( "latents", @@ -210,9 +209,7 @@ def intermediate_outputs(self) -> List[OutputParam]: ] @torch.no_grad() - def __call__( - self, components, state: PipelineState - ) -> PipelineState: + def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Encode frames to latents using VAE @@ -231,7 +228,9 @@ def __call__( generator=rng, ) # Determine how noisy the latents should be - noisy_latents = noise * block_state.noise_scale + latents * (1 - block_state.noise_scale) + noisy_latents = noise * block_state.noise_scale + latents * ( + 1 - block_state.noise_scale + ) # Determine the number of denoising steps current_step = int(1000 * block_state.noise_scale) - 100 @@ -249,7 +248,7 @@ class StreamDiffusionV2DenoiseStep(ModularPipelineBlocks): model_name = "StreamDiffusionV2" @property - def expected_components(self) -> List[ComponentSpec]: + def expected_components(self) -> list[ComponentSpec]: return [ ComponentSpec("stream", torch.nn.Module), ] @@ -259,7 +258,7 @@ def description(self) -> str: return "Denoise step that performs inference using the stream pipeline" @property - def inputs(self) -> List[InputParam]: + def inputs(self) -> list[InputParam]: return [ InputParam( "noisy_latents", @@ -292,7 +291,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediate_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> list[OutputParam]: return [ OutputParam( "denoised_pred", @@ -302,9 +301,7 @@ def intermediate_outputs(self) -> List[OutputParam]: ] @torch.no_grad() - def __call__( - self, components, state: PipelineState - ) -> PipelineState: + def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Use the stream's inference method @@ -325,7 +322,7 @@ class StreamDiffusionV2PostprocessStep(ModularPipelineBlocks): model_name = "StreamDiffusionV2" @property - def expected_components(self) -> List[ComponentSpec]: + def expected_components(self) -> list[ComponentSpec]: return [ ComponentSpec("stream", torch.nn.Module), ] @@ -335,7 +332,7 @@ def description(self) -> str: return "Postprocess step that decodes denoised latents to pixel space" @property - def inputs(self) -> List[InputParam]: + def inputs(self) -> list[InputParam]: return [ InputParam( "denoised_pred", @@ -346,7 +343,7 @@ def inputs(self) -> List[InputParam]: ] @property - def intermediate_outputs(self) -> List[OutputParam]: + def intermediate_outputs(self) -> list[OutputParam]: return [ OutputParam( "output", @@ -356,9 +353,7 @@ def intermediate_outputs(self) -> List[OutputParam]: ] @torch.no_grad() - def __call__( - self, components, state: PipelineState - ) -> PipelineState: + def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Decode to pixel space @@ -386,5 +381,5 @@ def __call__( class StreamDiffusionV2Blocks(SequentialPipelineBlocks): - block_classes = list(VIDEO2VIDEO_BLOCKS.copy().values()) - block_names = list(VIDEO2VIDEO_BLOCKS.copy().keys()) + block_classes = list(VIDEO2VIDEO_BLOCKS.values()) + block_names = list(VIDEO2VIDEO_BLOCKS.keys()) diff --git a/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py 
b/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py index 6e9c87d80..fc9ced765 100644 --- a/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py +++ b/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py @@ -1,21 +1,44 @@ """Wrapper to expose CausalStreamInferencePipeline as a ModularPipeline for Modular Diffusers.""" +import os +import time + import torch from diffusers.modular_pipelines import ModularPipeline +from .vendor.causvid.models.wan.causal_stream_inference import ( + CausalStreamInferencePipeline, +) + class StreamDiffusionV2ModularPipeline(ModularPipeline): """Wrapper that exposes CausalStreamInferencePipeline as a ModularPipeline.""" - def __init__(self, stream): + def __init__(self, config, device, dtype, model_dir): """ Initialize the wrapper. Args: - stream: CausalStreamInferencePipeline instance + config: Configuration dictionary for the pipeline + device: Device to run the pipeline on + dtype: Data type for the pipeline + model_dir: Directory containing the model files """ - self.stream = stream - self._execution_device_val = next(stream.generator.parameters()).device + # Create and initialize the stream pipeline + self.stream = CausalStreamInferencePipeline(config, device).to( + device=device, dtype=dtype + ) + + # Load the generator state dict + start = time.time() + state_dict = torch.load( + os.path.join(model_dir, "StreamDiffusionV2/model.pt"), + map_location="cpu", + )["generator"] + self.stream.generator.load_state_dict(state_dict, strict=True) + print(f"Loaded diffusion state dict in {time.time() - start:.3f}s") + + self._execution_device_val = next(self.stream.generator.parameters()).device @property def _execution_device(self): diff --git a/pipelines/streamdiffusionv2/pipeline.py b/pipelines/streamdiffusionv2/pipeline.py index f99396a8f..d19df3b2d 100644 --- a/pipelines/streamdiffusionv2/pipeline.py +++ b/pipelines/streamdiffusionv2/pipeline.py @@ -1,6 +1,4 @@ import logging -import os -import time import torch from diffusers.modular_pipelines import PipelineState @@ -10,9 +8,6 @@ from ..process import postprocess_chunk, preprocess_chunk from .modular_blocks import StreamDiffusionV2Blocks from .modular_pipeline_wrapper import StreamDiffusionV2ModularPipeline -from .vendor.causvid.models.wan.causal_stream_inference import ( - CausalStreamInferencePipeline, -) # https://github.com/daydreamlive/scope/blob/0cf1766186be3802bf97ce550c2c978439f22068/pipelines/streamdiffusionv2/vendor/causvid/models/wan/causal_model.py#L306 MAX_ROPE_FREQ_TABLE_SEQ_LEN = 1024 @@ -49,20 +44,9 @@ def __init__( config["height"] = self.height config["width"] = self.width - self.stream = CausalStreamInferencePipeline(config, device).to( - device=device, dtype=dtype - ) self.device = device self.dtype = dtype - start = time.time() - state_dict = torch.load( - os.path.join(config.model_dir, "StreamDiffusionV2/model.pt"), - map_location="cpu", - )["generator"] - self.stream.generator.load_state_dict(state_dict, strict=True) - print(f"Loaded diffusion state dict in {time.time() - start:.3f}s") - self.chunk_size = chunk_size self.start_chunk_size = start_chunk_size self.noise_scale = noise_scale @@ -71,6 +55,12 @@ def __init__( self.prompts = None self.denoising_step_list = None + # Initialize Modular Diffusers blocks and pipeline + self.modular_blocks = StreamDiffusionV2Blocks() + self.modular_pipeline = StreamDiffusionV2ModularPipeline( + config, device, dtype, config.model_dir + ) + # Prompt blending with cache reset callback for transitions self.prompt_blender = PromptBlender( 
device, dtype, cache_reset_callback=self._initialize_stream_caches @@ -78,10 +68,6 @@ def __init__( self.last_frame = None self.current_start = 0 - - # Initialize Modular Diffusers blocks - self.modular_blocks = StreamDiffusionV2Blocks() - self.modular_pipeline = StreamDiffusionV2ModularPipeline(self.stream) self.current_end = self.modular_pipeline.stream.frame_seq_length * 2 def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: From ec2ef47bfb866d6ba3753aa5b15c41ab05bbbf4b Mon Sep 17 00:00:00 2001 From: Rafal Leszko Date: Wed, 5 Nov 2025 08:37:20 +0000 Subject: [PATCH 4/7] Optimize Modular Diffusers pipeline Signed-off-by: Rafal Leszko --- pipelines/streamdiffusionv2/modular_blocks.py | 23 +++++++++---------- pipelines/streamdiffusionv2/pipeline.py | 8 +++++++ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/pipelines/streamdiffusionv2/modular_blocks.py b/pipelines/streamdiffusionv2/modular_blocks.py index be3560f17..fa15e082a 100644 --- a/pipelines/streamdiffusionv2/modular_blocks.py +++ b/pipelines/streamdiffusionv2/modular_blocks.py @@ -119,13 +119,13 @@ def check_inputs(block_state): @torch.no_grad() def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - self.check_inputs(block_state) # For streamdiffusionv2, prompt encoding is handled by prompt_blender - # This block is a placeholder to maintain compatibility with Modular Diffusers - # The actual prompt_embeds are set via conditional_dict in the pipeline - # We just need to ensure prompt_embeds is in the state + # The actual prompt_embeds are set via conditional_dict in the pipeline before blocks execute + # Skip all work if prompt_embeds already exist (which they should via conditional_dict) if block_state.prompt_embeds is None: + # Only check inputs if we actually need to encode + self.check_inputs(block_state) # Use the stream's text encoder if needed if hasattr(block_state, "prompt") and block_state.prompt is not None: conditional_dict = components.stream.text_encoder( @@ -136,9 +136,9 @@ def __call__(self, components, state: PipelineState) -> PipelineState: # Default empty prompt conditional_dict = components.stream.text_encoder(text_prompts=[""]) block_state.prompt_embeds = conditional_dict["prompt_embeds"] + self.set_block_state(state, block_state) + # If prompt_embeds already exists, skip state update to reduce overhead - # Add outputs - self.set_block_state(state, block_state) return components, state @@ -235,6 +235,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState: # Determine the number of denoising steps current_step = int(1000 * block_state.noise_scale) - 100 + # Update state directly without intermediate variables where possible block_state.latents = latents block_state.noisy_latents = noisy_latents block_state.current_step = current_step @@ -304,8 +305,8 @@ def intermediate_outputs(self) -> list[OutputParam]: def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - # Use the stream's inference method - denoised_pred = components.stream.inference( + # Use the stream's inference method - direct call without intermediate variable + block_state.denoised_pred = components.stream.inference( noise=block_state.noisy_latents, current_start=block_state.current_start, current_end=block_state.current_end, @@ -313,7 +314,6 @@ def __call__(self, components, state: PipelineState) -> PipelineState: generator=block_state.generator, ) - 
block_state.denoised_pred = denoised_pred self.set_block_state(state, block_state) return components, state @@ -356,9 +356,8 @@ def intermediate_outputs(self) -> list[OutputParam]: def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - # Decode to pixel space - output = components.stream.vae.stream_decode_to_pixel(block_state.denoised_pred) - block_state.output = output + # Decode to pixel space - direct assignment to reduce overhead + block_state.output = components.stream.vae.stream_decode_to_pixel(block_state.denoised_pred) self.set_block_state(state, block_state) return components, state diff --git a/pipelines/streamdiffusionv2/pipeline.py b/pipelines/streamdiffusionv2/pipeline.py index d19df3b2d..61530793c 100644 --- a/pipelines/streamdiffusionv2/pipeline.py +++ b/pipelines/streamdiffusionv2/pipeline.py @@ -266,6 +266,14 @@ def __call__( current_step = int(1000 * self.noise_scale) - 100 state.set("current_step", current_step) + # Set prompt_embeds in state if available from conditional_dict to skip text_encoder work + if ( + hasattr(self.modular_pipeline.stream, "conditional_dict") + and self.modular_pipeline.stream.conditional_dict is not None + and "prompt_embeds" in self.modular_pipeline.stream.conditional_dict + ): + state.set("prompt_embeds", self.modular_pipeline.stream.conditional_dict["prompt_embeds"]) + # Execute modular blocks (returns tuple: components, state) _, state = self.modular_blocks(self.modular_pipeline, state) From 9367eb55efc88dfebddcb7730a3ff5f42d5d1ba5 Mon Sep 17 00:00:00 2001 From: Rafal Leszko Date: Wed, 5 Nov 2025 09:27:18 +0000 Subject: [PATCH 5/7] Keep separate modules in separate files Signed-off-by: Rafal Leszko --- pipelines/streamdiffusionv2/decoders.py | 69 ++++ pipelines/streamdiffusionv2/denoise.py | 97 +++++ pipelines/streamdiffusionv2/encoders.py | 141 +++++++ pipelines/streamdiffusionv2/modular_blocks.py | 354 +----------------- pipelines/streamdiffusionv2/preprocess.py | 127 +++++++ 5 files changed, 440 insertions(+), 348 deletions(-) create mode 100644 pipelines/streamdiffusionv2/decoders.py create mode 100644 pipelines/streamdiffusionv2/denoise.py create mode 100644 pipelines/streamdiffusionv2/encoders.py create mode 100644 pipelines/streamdiffusionv2/preprocess.py diff --git a/pipelines/streamdiffusionv2/decoders.py b/pipelines/streamdiffusionv2/decoders.py new file mode 100644 index 000000000..46955c918 --- /dev/null +++ b/pipelines/streamdiffusionv2/decoders.py @@ -0,0 +1,69 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from diffusers.modular_pipelines import ( + ModularPipelineBlocks, + PipelineState, +) +from diffusers.modular_pipelines.modular_pipeline_utils import ( + ComponentSpec, + InputParam, + OutputParam, +) + + +class StreamDiffusionV2PostprocessStep(ModularPipelineBlocks): + model_name = "StreamDiffusionV2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("stream", torch.nn.Module), + ] + + @property + def description(self) -> str: + return "Postprocess step that decodes denoised latents to pixel space" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "denoised_pred", + required=True, + type_hint=torch.Tensor, + description="Denoised latents", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "output", + type_hint=torch.Tensor, + description="Decoded video frames", + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Decode to pixel space - direct assignment to reduce overhead + block_state.output = components.stream.vae.stream_decode_to_pixel(block_state.denoised_pred) + + self.set_block_state(state, block_state) + return components, state diff --git a/pipelines/streamdiffusionv2/denoise.py b/pipelines/streamdiffusionv2/denoise.py new file mode 100644 index 000000000..3129d1032 --- /dev/null +++ b/pipelines/streamdiffusionv2/denoise.py @@ -0,0 +1,97 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from diffusers.modular_pipelines import ( + ModularPipelineBlocks, + PipelineState, +) +from diffusers.modular_pipelines.modular_pipeline_utils import ( + ComponentSpec, + InputParam, + OutputParam, +) + + +class StreamDiffusionV2DenoiseStep(ModularPipelineBlocks): + model_name = "StreamDiffusionV2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("stream", torch.nn.Module), + ] + + @property + def description(self) -> str: + return "Denoise step that performs inference using the stream pipeline" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "noisy_latents", + required=True, + type_hint=torch.Tensor, + description="Noisy latents to denoise", + ), + InputParam( + "current_start", + required=True, + type_hint=int, + description="Current start position", + ), + InputParam( + "current_end", + required=True, + type_hint=int, + description="Current end position", + ), + InputParam( + "current_step", + required=True, + type_hint=int, + description="Current denoising step", + ), + InputParam( + "generator", + description="Random number generator", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "denoised_pred", + type_hint=torch.Tensor, + description="Denoised prediction", + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Use the stream's inference method - direct call without intermediate variable + block_state.denoised_pred = components.stream.inference( + noise=block_state.noisy_latents, + current_start=block_state.current_start, + current_end=block_state.current_end, + current_step=block_state.current_step, + generator=block_state.generator, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/pipelines/streamdiffusionv2/encoders.py b/pipelines/streamdiffusionv2/encoders.py new file mode 100644 index 000000000..cd52f45b9 --- /dev/null +++ b/pipelines/streamdiffusionv2/encoders.py @@ -0,0 +1,141 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import html + +import regex as re +import torch +from diffusers.modular_pipelines import ( + ModularPipelineBlocks, + PipelineState, +) +from diffusers.modular_pipelines.modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, +) +from diffusers.utils import is_ftfy_available +from diffusers.utils import logging as diffusers_logging + +if is_ftfy_available(): + import ftfy + +logger = diffusers_logging.get_logger(__name__) + + +def basic_clean(text): + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class StreamDiffusionV2TextEncoderStep(ModularPipelineBlocks): + model_name = "StreamDiffusionV2" + + @property + def description(self) -> str: + return "Text Encoder step that generates text_embeddings to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("stream", torch.nn.Module), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="text embeddings used to guide the image generation", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="negative text embeddings used to guide the image generation", + ), + InputParam("attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) + and not isinstance(block_state.prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}" + ) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # For streamdiffusionv2, prompt encoding is handled by prompt_blender + # The actual prompt_embeds are set via conditional_dict in the pipeline before blocks execute + # Skip all work if prompt_embeds already exist (which they should via conditional_dict) + if block_state.prompt_embeds is None: + # Only check inputs if we actually need to encode + self.check_inputs(block_state) + # Use the stream's text encoder if needed + if hasattr(block_state, "prompt") and block_state.prompt is not None: + conditional_dict = components.stream.text_encoder( + text_prompts=[block_state.prompt] + ) + block_state.prompt_embeds = conditional_dict["prompt_embeds"] + else: + # Default empty prompt + conditional_dict = components.stream.text_encoder(text_prompts=[""]) + block_state.prompt_embeds = conditional_dict["prompt_embeds"] + self.set_block_state(state, block_state) + # If prompt_embeds already exists, skip state update to reduce overhead + + return components, state diff --git 
a/pipelines/streamdiffusionv2/modular_blocks.py b/pipelines/streamdiffusionv2/modular_blocks.py index fa15e082a..f63ecb531 100644 --- a/pipelines/streamdiffusionv2/modular_blocks.py +++ b/pipelines/streamdiffusionv2/modular_blocks.py @@ -12,359 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import html - -import regex as re -import torch -from diffusers.modular_pipelines import ( - ModularPipelineBlocks, - PipelineState, - SequentialPipelineBlocks, -) -from diffusers.modular_pipelines.modular_pipeline_utils import ( - ComponentSpec, - ConfigSpec, - InputParam, - OutputParam, -) -from diffusers.utils import is_ftfy_available from diffusers.utils import logging as diffusers_logging +from diffusers.modular_pipelines import SequentialPipelineBlocks +from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict -if is_ftfy_available(): - import ftfy +from .preprocess import StreamDiffusionV2PreprocessStep +from .decoders import StreamDiffusionV2PostprocessStep +from .encoders import StreamDiffusionV2TextEncoderStep +from .denoise import StreamDiffusionV2DenoiseStep logger = diffusers_logging.get_logger(__name__) - -def basic_clean(text): - if is_ftfy_available(): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text - - -class StreamDiffusionV2TextEncoderStep(ModularPipelineBlocks): - model_name = "StreamDiffusionV2" - - @property - def description(self) -> str: - return "Text Encoder step that generates text_embeddings to guide the video generation" - - @property - def expected_components(self) -> list[ComponentSpec]: - return [ - ComponentSpec("stream", torch.nn.Module), - ] - - @property - def expected_configs(self) -> list[ConfigSpec]: - return [] - - @property - def inputs(self) -> list[InputParam]: - return [ - InputParam("prompt"), - InputParam("negative_prompt"), - InputParam( - "prompt_embeds", - type_hint=torch.Tensor, - description="text embeddings used to guide the image generation", - ), - InputParam( - "negative_prompt_embeds", - type_hint=torch.Tensor, - description="negative text embeddings used to guide the image generation", - ), - InputParam("attention_kwargs"), - ] - - @property - def intermediate_outputs(self) -> list[OutputParam]: - return [ - OutputParam( - "prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="denoiser_input_fields", - description="text embeddings used to guide the image generation", - ), - OutputParam( - "negative_prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="denoiser_input_fields", - description="negative text embeddings used to guide the image generation", - ), - ] - - @staticmethod - def check_inputs(block_state): - if block_state.prompt is not None and ( - not isinstance(block_state.prompt, str) - and not isinstance(block_state.prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}" - ) - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - # For streamdiffusionv2, prompt encoding is handled by prompt_blender - # The actual prompt_embeds are set via conditional_dict in the pipeline before blocks execute - # Skip all work if prompt_embeds already exist (which they should via 
conditional_dict) - if block_state.prompt_embeds is None: - # Only check inputs if we actually need to encode - self.check_inputs(block_state) - # Use the stream's text encoder if needed - if hasattr(block_state, "prompt") and block_state.prompt is not None: - conditional_dict = components.stream.text_encoder( - text_prompts=[block_state.prompt] - ) - block_state.prompt_embeds = conditional_dict["prompt_embeds"] - else: - # Default empty prompt - conditional_dict = components.stream.text_encoder(text_prompts=[""]) - block_state.prompt_embeds = conditional_dict["prompt_embeds"] - self.set_block_state(state, block_state) - # If prompt_embeds already exists, skip state update to reduce overhead - - return components, state - - -class StreamDiffusionV2PreprocessStep(ModularPipelineBlocks): - model_name = "StreamDiffusionV2" - - @property - def expected_components(self) -> list[ComponentSpec]: - return [ - ComponentSpec("stream", torch.nn.Module), - ] - - @property - def description(self) -> str: - return "Preprocess step that encodes input frames to latents and adds noise" - - @property - def inputs(self) -> list[InputParam]: - return [ - InputParam( - "input", - required=True, - type_hint=torch.Tensor, - description="Input video frames in BCTHW format", - ), - InputParam( - "noise_scale", - type_hint=float, - default=0.7, - description="Scale of noise to add to latents", - ), - InputParam( - "generator", - description="Random number generator for noise", - ), - InputParam( - "base_seed", - type_hint=int, - default=42, - description="Base seed for random number generation", - ), - InputParam( - "current_start", - type_hint=int, - default=0, - description="Current start position for seed offset", - ), - ] - - @property - def intermediate_outputs(self) -> list[OutputParam]: - return [ - OutputParam( - "latents", - type_hint=torch.Tensor, - description="Encoded and noised latents", - ), - OutputParam( - "noisy_latents", - type_hint=torch.Tensor, - description="Noisy latents ready for denoising", - ), - OutputParam( - "current_step", - type_hint=int, - description="Current denoising step", - ), - ] - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - # Encode frames to latents using VAE - latents = components.stream.vae.model.stream_encode(block_state.input) - # Transpose latents - latents = latents.transpose(2, 1) - - # Create generator from seed for reproducible generation - frame_seed = block_state.base_seed + block_state.current_start - rng = torch.Generator(device=latents.device).manual_seed(frame_seed) - - noise = torch.randn( - latents.shape, - device=latents.device, - dtype=latents.dtype, - generator=rng, - ) - # Determine how noisy the latents should be - noisy_latents = noise * block_state.noise_scale + latents * ( - 1 - block_state.noise_scale - ) - - # Determine the number of denoising steps - current_step = int(1000 * block_state.noise_scale) - 100 - - # Update state directly without intermediate variables where possible - block_state.latents = latents - block_state.noisy_latents = noisy_latents - block_state.current_step = current_step - block_state.generator = rng - - self.set_block_state(state, block_state) - return components, state - - -class StreamDiffusionV2DenoiseStep(ModularPipelineBlocks): - model_name = "StreamDiffusionV2" - - @property - def expected_components(self) -> list[ComponentSpec]: - return [ - ComponentSpec("stream", torch.nn.Module), - ] - - @property - def description(self) 
-> str: - return "Denoise step that performs inference using the stream pipeline" - - @property - def inputs(self) -> list[InputParam]: - return [ - InputParam( - "noisy_latents", - required=True, - type_hint=torch.Tensor, - description="Noisy latents to denoise", - ), - InputParam( - "current_start", - required=True, - type_hint=int, - description="Current start position", - ), - InputParam( - "current_end", - required=True, - type_hint=int, - description="Current end position", - ), - InputParam( - "current_step", - required=True, - type_hint=int, - description="Current denoising step", - ), - InputParam( - "generator", - description="Random number generator", - ), - ] - - @property - def intermediate_outputs(self) -> list[OutputParam]: - return [ - OutputParam( - "denoised_pred", - type_hint=torch.Tensor, - description="Denoised prediction", - ), - ] - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - # Use the stream's inference method - direct call without intermediate variable - block_state.denoised_pred = components.stream.inference( - noise=block_state.noisy_latents, - current_start=block_state.current_start, - current_end=block_state.current_end, - current_step=block_state.current_step, - generator=block_state.generator, - ) - - self.set_block_state(state, block_state) - return components, state - - -class StreamDiffusionV2PostprocessStep(ModularPipelineBlocks): - model_name = "StreamDiffusionV2" - - @property - def expected_components(self) -> list[ComponentSpec]: - return [ - ComponentSpec("stream", torch.nn.Module), - ] - - @property - def description(self) -> str: - return "Postprocess step that decodes denoised latents to pixel space" - - @property - def inputs(self) -> list[InputParam]: - return [ - InputParam( - "denoised_pred", - required=True, - type_hint=torch.Tensor, - description="Denoised latents", - ), - ] - - @property - def intermediate_outputs(self) -> list[OutputParam]: - return [ - OutputParam( - "output", - type_hint=torch.Tensor, - description="Decoded video frames", - ), - ] - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - # Decode to pixel space - direct assignment to reduce overhead - block_state.output = components.stream.vae.stream_decode_to_pixel(block_state.denoised_pred) - - self.set_block_state(state, block_state) - return components, state - - -from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict - VIDEO2VIDEO_BLOCKS = InsertableDict( [ ("text_encoder", StreamDiffusionV2TextEncoderStep), diff --git a/pipelines/streamdiffusionv2/preprocess.py b/pipelines/streamdiffusionv2/preprocess.py new file mode 100644 index 000000000..7b7e8ca4f --- /dev/null +++ b/pipelines/streamdiffusionv2/preprocess.py @@ -0,0 +1,127 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from diffusers.modular_pipelines import ( + ModularPipelineBlocks, + PipelineState, +) +from diffusers.modular_pipelines.modular_pipeline_utils import ( + ComponentSpec, + InputParam, + OutputParam, +) + + +class StreamDiffusionV2PreprocessStep(ModularPipelineBlocks): + model_name = "StreamDiffusionV2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("stream", torch.nn.Module), + ] + + @property + def description(self) -> str: + return "Preprocess step that encodes input frames to latents and adds noise" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "input", + required=True, + type_hint=torch.Tensor, + description="Input video frames in BCTHW format", + ), + InputParam( + "noise_scale", + type_hint=float, + default=0.7, + description="Scale of noise to add to latents", + ), + InputParam( + "generator", + description="Random number generator for noise", + ), + InputParam( + "base_seed", + type_hint=int, + default=42, + description="Base seed for random number generation", + ), + InputParam( + "current_start", + type_hint=int, + default=0, + description="Current start position for seed offset", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="Encoded and noised latents", + ), + OutputParam( + "noisy_latents", + type_hint=torch.Tensor, + description="Noisy latents ready for denoising", + ), + OutputParam( + "current_step", + type_hint=int, + description="Current denoising step", + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Encode frames to latents using VAE + latents = components.stream.vae.model.stream_encode(block_state.input) + # Transpose latents + latents = latents.transpose(2, 1) + + # Create generator from seed for reproducible generation + frame_seed = block_state.base_seed + block_state.current_start + rng = torch.Generator(device=latents.device).manual_seed(frame_seed) + + noise = torch.randn( + latents.shape, + device=latents.device, + dtype=latents.dtype, + generator=rng, + ) + # Determine how noisy the latents should be + noisy_latents = noise * block_state.noise_scale + latents * ( + 1 - block_state.noise_scale + ) + + # Determine the number of denoising steps + current_step = int(1000 * block_state.noise_scale) - 100 + + # Update state directly without intermediate variables where possible + block_state.latents = latents + block_state.noisy_latents = noisy_latents + block_state.current_step = current_step + block_state.generator = rng + + self.set_block_state(state, block_state) + return components, state From f9fa269e6098ff80f4247416cf2e27c5eab7179a Mon Sep 17 00:00:00 2001 From: Rafal Leszko Date: Wed, 5 Nov 2025 10:03:29 +0000 Subject: [PATCH 6/7] Use ComponentsManager Signed-off-by: Rafal Leszko --- .../streamdiffusionv2/components_loader.py | 108 ++++++++++++++++++ .../modular_pipeline_wrapper.py | 51 --------- pipelines/streamdiffusionv2/pipeline.py | 57 +++++---- 3 files changed, 140 insertions(+), 76 deletions(-) create mode 100644 pipelines/streamdiffusionv2/components_loader.py delete mode 100644 pipelines/streamdiffusionv2/modular_pipeline_wrapper.py diff --git a/pipelines/streamdiffusionv2/components_loader.py b/pipelines/streamdiffusionv2/components_loader.py new file mode 100644 index 000000000..4779c7921 --- /dev/null +++ 
b/pipelines/streamdiffusionv2/components_loader.py @@ -0,0 +1,108 @@ +"""Helper functions to load components using ComponentsManager.""" + +import os +import time + +import torch +from diffusers.modular_pipelines.components_manager import ComponentsManager + +from .vendor.causvid.models.wan.causal_stream_inference import ( + CausalStreamInferencePipeline, +) + + +class ComponentProvider: + """Simple wrapper to provide component access from ComponentsManager to blocks.""" + + def __init__(self, components_manager: ComponentsManager, component_name: str, collection: str = "streamdiffusionv2"): + """ + Initialize the component provider. + + Args: + components_manager: The ComponentsManager instance + component_name: Name of the component to provide + collection: Collection name for retrieving the component + """ + self.components_manager = components_manager + self.component_name = component_name + self.collection = collection + # Cache the component to avoid repeated lookups + self._component = None + + @property + def stream(self): + """Provide access to the stream component.""" + if self._component is None: + self._component = self.components_manager.get_one( + name=self.component_name, collection=self.collection + ) + return self._component + + +def load_stream_component( + config, + device, + dtype, + model_dir, + components_manager: ComponentsManager, + collection: str = "streamdiffusionv2", +) -> ComponentProvider: + """ + Load the CausalStreamInferencePipeline and add it to ComponentsManager. + + Args: + config: Configuration dictionary for the pipeline + device: Device to run the pipeline on + dtype: Data type for the pipeline + model_dir: Directory containing the model files + components_manager: ComponentsManager instance to add component to + collection: Collection name for organizing components + + Returns: + ComponentProvider: A provider that gives access to the stream component + """ + # Check if component already exists in ComponentsManager + try: + existing = components_manager.get_one(name="stream", collection=collection) + # Component exists, create provider for it + print(f"Reusing existing stream component from collection '{collection}'") + return ComponentProvider(components_manager, "stream", collection) + except Exception: + # Component doesn't exist, create and add it + pass + + # Create and initialize the stream pipeline + stream = CausalStreamInferencePipeline(config, device).to( + device=device, dtype=dtype + ) + + # Load the generator state dict + start = time.time() + model_path = os.path.join(model_dir, "StreamDiffusionV2/model.pt") + if not os.path.exists(model_path): + raise FileNotFoundError( + f"Model file not found at {model_path}. " + "Please ensure StreamDiffusionV2/model.pt exists in the model directory." 
+ ) + + state_dict_data = torch.load(model_path, map_location="cpu") + + # Handle both dict with "generator" key and direct state dict + if isinstance(state_dict_data, dict) and "generator" in state_dict_data: + state_dict = state_dict_data["generator"] + else: + state_dict = state_dict_data + + stream.generator.load_state_dict(state_dict, strict=True) + print(f"Loaded diffusion state dict in {time.time() - start:.3f}s") + + # Add component to ComponentsManager + component_id = components_manager.add( + "stream", + stream, + collection=collection, + ) + print(f"Added stream component to ComponentsManager with ID: {component_id}") + + # Create and return provider + return ComponentProvider(components_manager, "stream", collection) diff --git a/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py b/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py deleted file mode 100644 index fc9ced765..000000000 --- a/pipelines/streamdiffusionv2/modular_pipeline_wrapper.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Wrapper to expose CausalStreamInferencePipeline as a ModularPipeline for Modular Diffusers.""" - -import os -import time - -import torch -from diffusers.modular_pipelines import ModularPipeline - -from .vendor.causvid.models.wan.causal_stream_inference import ( - CausalStreamInferencePipeline, -) - - -class StreamDiffusionV2ModularPipeline(ModularPipeline): - """Wrapper that exposes CausalStreamInferencePipeline as a ModularPipeline.""" - - def __init__(self, config, device, dtype, model_dir): - """ - Initialize the wrapper. - - Args: - config: Configuration dictionary for the pipeline - device: Device to run the pipeline on - dtype: Data type for the pipeline - model_dir: Directory containing the model files - """ - # Create and initialize the stream pipeline - self.stream = CausalStreamInferencePipeline(config, device).to( - device=device, dtype=dtype - ) - - # Load the generator state dict - start = time.time() - state_dict = torch.load( - os.path.join(model_dir, "StreamDiffusionV2/model.pt"), - map_location="cpu", - )["generator"] - self.stream.generator.load_state_dict(state_dict, strict=True) - print(f"Loaded diffusion state dict in {time.time() - start:.3f}s") - - self._execution_device_val = next(self.stream.generator.parameters()).device - - @property - def _execution_device(self): - """Return the execution device.""" - return self._execution_device_val - - @_execution_device.setter - def _execution_device(self, value): - """Set the execution device.""" - self._execution_device_val = value diff --git a/pipelines/streamdiffusionv2/pipeline.py b/pipelines/streamdiffusionv2/pipeline.py index 61530793c..d3a8b619a 100644 --- a/pipelines/streamdiffusionv2/pipeline.py +++ b/pipelines/streamdiffusionv2/pipeline.py @@ -2,12 +2,13 @@ import torch from diffusers.modular_pipelines import PipelineState +from diffusers.modular_pipelines.components_manager import ComponentsManager from ..blending import PromptBlender, handle_transition_prepare from ..interface import Pipeline, Requirements from ..process import postprocess_chunk, preprocess_chunk +from .components_loader import load_stream_component, ComponentProvider from .modular_blocks import StreamDiffusionV2Blocks -from .modular_pipeline_wrapper import StreamDiffusionV2ModularPipeline # https://github.com/daydreamlive/scope/blob/0cf1766186be3802bf97ce550c2c978439f22068/pipelines/streamdiffusionv2/vendor/causvid/models/wan/causal_model.py#L306 MAX_ROPE_FREQ_TABLE_SEQ_LEN = 1024 @@ -55,10 +56,15 @@ def __init__( self.prompts = None 
self.denoising_step_list = None - # Initialize Modular Diffusers blocks and pipeline + # Initialize ComponentsManager + self.components_manager = ComponentsManager() + + # Initialize Modular Diffusers blocks self.modular_blocks = StreamDiffusionV2Blocks() - self.modular_pipeline = StreamDiffusionV2ModularPipeline( - config, device, dtype, config.model_dir + + # Load stream component using ComponentsManager + self.component_provider = load_stream_component( + config, device, dtype, config.model_dir, self.components_manager ) # Prompt blending with cache reset callback for transitions @@ -68,7 +74,7 @@ def __init__( self.last_frame = None self.current_start = 0 - self.current_end = self.modular_pipeline.stream.frame_seq_length * 2 + self.current_end = self.component_provider.stream.frame_seq_length * 2 def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: if should_prepare: @@ -91,7 +97,7 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: # Handle prompt transition requests should_prepare_from_transition, target_prompts = handle_transition_prepare( - transition, self.prompt_blender, self.modular_pipeline.stream.text_encoder + transition, self.prompt_blender, self.component_provider.stream.text_encoder ) if target_prompts: self.prompts = target_prompts @@ -123,7 +129,7 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: # When we hit the limit we reset the caches and indices # See this issue for more context https://github.com/daydreamlive/scope/issues/95 max_current_start = ( - MAX_ROPE_FREQ_TABLE_SEQ_LEN * self.modular_pipeline.stream.frame_seq_length + MAX_ROPE_FREQ_TABLE_SEQ_LEN * self.component_provider.stream.frame_seq_length ) # We reset at whatever is smaller the theoretically max value or some % of it max_current_start = min( @@ -137,7 +143,7 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: # Update internal state before preparing pipeline if denoising_step_list is not None: self.denoising_step_list = denoising_step_list - self.modular_pipeline.stream.denoising_step_list = torch.tensor( + self.component_provider.stream.denoising_step_list = torch.tensor( denoising_step_list, dtype=torch.long, device=self.device ) @@ -156,16 +162,16 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: @torch.no_grad() def _prepare_pipeline(self, prompts=None, interpolation_method="linear"): # Trigger KV + cross-attn cache re-initialization in prepare() - self.modular_pipeline.stream.kv_cache1 = None + self.component_provider.stream.kv_cache1 = None # Apply prompt blending and set conditional_dict self._apply_prompt_blending(prompts, interpolation_method) - self.modular_pipeline.stream.vae.model.first_batch = True + self.component_provider.stream.vae.model.first_batch = True self.last_frame = None self.current_start = 0 - self.current_end = self.modular_pipeline.stream.frame_seq_length * 2 + self.current_end = self.component_provider.stream.frame_seq_length * 2 def _apply_motion_aware_noise_controller(self, input: torch.Tensor): # The prev seq is the last chunk_size frames of the current input @@ -217,11 +223,11 @@ def __call__( # Update prompt embedding for this generation call # Handles both static blending and temporal transitions next_embedding = self.prompt_blender.get_next_embedding( - self.modular_pipeline.stream.text_encoder + self.component_provider.stream.text_encoder ) if next_embedding is not None: - self.modular_pipeline.stream.conditional_dict = { + 
self.component_provider.stream.conditional_dict = { "prompt_embeds": next_embedding } @@ -259,7 +265,7 @@ def __call__( state.set("current_start", self.current_start) state.set("current_end", self.current_end) state.set( - "denoising_step_list", self.modular_pipeline.stream.denoising_step_list + "denoising_step_list", self.component_provider.stream.denoising_step_list ) # Determine the number of denoising steps @@ -268,14 +274,15 @@ def __call__( # Set prompt_embeds in state if available from conditional_dict to skip text_encoder work if ( - hasattr(self.modular_pipeline.stream, "conditional_dict") - and self.modular_pipeline.stream.conditional_dict is not None - and "prompt_embeds" in self.modular_pipeline.stream.conditional_dict + hasattr(self.component_provider.stream, "conditional_dict") + and self.component_provider.stream.conditional_dict is not None + and "prompt_embeds" in self.component_provider.stream.conditional_dict ): - state.set("prompt_embeds", self.modular_pipeline.stream.conditional_dict["prompt_embeds"]) + state.set("prompt_embeds", self.component_provider.stream.conditional_dict["prompt_embeds"]) # Execute modular blocks (returns tuple: components, state) - _, state = self.modular_blocks(self.modular_pipeline, state) + # Pass component_provider which provides components.stream access + _, state = self.modular_blocks(self.component_provider, state) # Get output from state output = state.values.get("output") @@ -292,28 +299,28 @@ def __call__( self.current_start = self.current_end self.current_end += ( self.chunk_size // 4 - ) * self.modular_pipeline.stream.frame_seq_length + ) * self.component_provider.stream.frame_seq_length return postprocess_chunk(output) def _initialize_stream_caches(self): """Initialize stream caches without overriding conditional_dict.""" noise = torch.zeros(1, 1).to(self.device, self.dtype) - saved = self.modular_pipeline.stream.conditional_dict - self.modular_pipeline.stream.prepare(noise, text_prompts=[""]) - self.modular_pipeline.stream.conditional_dict = saved + saved = self.component_provider.stream.conditional_dict + self.component_provider.stream.prepare(noise, text_prompts=[""]) + self.component_provider.stream.conditional_dict = saved def _apply_prompt_blending(self, prompts=None, interpolation_method="linear"): """Apply weighted blending of cached prompt embeddings.""" combined_embeds = self.prompt_blender.blend( - prompts, interpolation_method, self.modular_pipeline.stream.text_encoder + prompts, interpolation_method, self.component_provider.stream.text_encoder ) if combined_embeds is None: return # Set the blended embeddings on the stream - self.modular_pipeline.stream.conditional_dict = { + self.component_provider.stream.conditional_dict = { "prompt_embeds": combined_embeds } From fb60cdbed52ec775e3f4c265adea40f199a8ca19 Mon Sep 17 00:00:00 2001 From: Rafal Leszko Date: Wed, 5 Nov 2025 10:16:11 +0000 Subject: [PATCH 7/7] Do not reload pipeline if only params have changed Signed-off-by: Rafal Leszko --- lib/pipeline_manager.py | 88 +++++++++++++++++++++---- pipelines/streamdiffusionv2/pipeline.py | 28 ++++++++ 2 files changed, 103 insertions(+), 13 deletions(-) diff --git a/lib/pipeline_manager.py b/lib/pipeline_manager.py index 4c783d6f1..86ecfe3d4 100644 --- a/lib/pipeline_manager.py +++ b/lib/pipeline_manager.py @@ -127,11 +127,20 @@ def _load_pipeline_sync_wrapper( ) -> bool: """Synchronous wrapper for pipeline loading with proper locking.""" with self._lock: - # If already loaded with same type and same params, return success + # 
If already loading, someone else is handling it + if self._status == PipelineStatus.LOADING: + logger.info("Pipeline already loading by another thread") + return False + + # Determine pipeline type + if pipeline_id is None: + pipeline_id = os.getenv("PIPELINE", "longlive") + # Normalize None to empty dict for comparison current_params = self._load_params or {} new_params = load_params or {} + # If already loaded with same type and same params, return success if ( self._status == PipelineStatus.LOADED and self._pipeline_id == pipeline_id @@ -142,25 +151,41 @@ def _load_pipeline_sync_wrapper( ) return True - # If a different pipeline is loaded OR same pipeline with different params, unload it first - if self._status == PipelineStatus.LOADED and ( - self._pipeline_id != pipeline_id or current_params != new_params + # If same pipeline but different params, try to update instead of reloading + if ( + self._status == PipelineStatus.LOADED + and self._pipeline_id == pipeline_id + and current_params != new_params ): + try: + logger.info( + f"Updating pipeline {pipeline_id} parameters without reloading models" + ) + updated = self._update_pipeline_params(pipeline_id, new_params) + if updated: + self._load_params = load_params + logger.info(f"Pipeline {pipeline_id} parameters updated successfully") + return True + else: + # Update failed, fall through to reload + logger.info( + f"Pipeline update not supported, reloading pipeline {pipeline_id}" + ) + self._unload_pipeline_unsafe() + except Exception as e: + logger.warning( + f"Failed to update pipeline parameters: {e}. Reloading pipeline." + ) + self._unload_pipeline_unsafe() + + # If a different pipeline is loaded, unload it first + if self._status == PipelineStatus.LOADED and self._pipeline_id != pipeline_id: self._unload_pipeline_unsafe() - # If already loading, someone else is handling it - if self._status == PipelineStatus.LOADING: - logger.info("Pipeline already loading by another thread") - return False - try: self._status = PipelineStatus.LOADING self._error_message = None - # Determine pipeline type - if pipeline_id is None: - pipeline_id = os.getenv("PIPELINE", "longlive") - logger.info(f"Loading pipeline: {pipeline_id}") # Load the pipeline synchronously (we're already in executor thread) @@ -186,6 +211,43 @@ def _load_pipeline_sync_wrapper( return False + def _update_pipeline_params(self, pipeline_id: str, load_params: dict) -> bool: + """ + Update pipeline parameters without reloading models. + + Args: + pipeline_id: ID of the pipeline to update + load_params: New load parameters + + Returns: + bool: True if update was successful, False if not supported + """ + if self._pipeline is None: + return False + + # Only streamdiffusionv2 currently supports parameter updates + if pipeline_id == "streamdiffusionv2": + try: + from pipelines.streamdiffusionv2.pipeline import StreamDiffusionV2Pipeline + + if isinstance(self._pipeline, StreamDiffusionV2Pipeline): + height = load_params.get("height") + width = load_params.get("width") + seed = load_params.get("seed") + + self._pipeline.update_params( + height=height, + width=width, + seed=seed + ) + return True + except Exception as e: + logger.error(f"Error updating streamdiffusionv2 pipeline params: {e}") + return False + + # Other pipelines don't support parameter updates yet + return False + def _unload_pipeline_unsafe(self): """Unload the current pipeline. 
Must be called with lock held.""" if self._pipeline: diff --git a/pipelines/streamdiffusionv2/pipeline.py b/pipelines/streamdiffusionv2/pipeline.py index d3a8b619a..ec03bc5a1 100644 --- a/pipelines/streamdiffusionv2/pipeline.py +++ b/pipelines/streamdiffusionv2/pipeline.py @@ -76,6 +76,34 @@ def __init__( self.current_start = 0 self.current_end = self.component_provider.stream.frame_seq_length * 2 + def update_params(self, height: int | None = None, width: int | None = None, seed: int | None = None): + """ + Update pipeline parameters without reloading model weights. + + Args: + height: New height (will be rounded to nearest multiple of SCALE_SIZE) + width: New width (will be rounded to nearest multiple of SCALE_SIZE) + seed: New seed value + """ + if height is not None or width is not None: + req_height = height if height is not None else self.height + req_width = width if width is not None else self.width + new_height = round(req_height / SCALE_SIZE) * SCALE_SIZE + new_width = round(req_width / SCALE_SIZE) * SCALE_SIZE + + if new_height != self.height or new_width != self.width: + logger.info(f"Updating resolution from {self.width}x{self.height} to {new_width}x{new_height}") + self.height = new_height + self.width = new_width + # Reset caches when resolution changes + self.last_frame = None + self.current_start = 0 + self.current_end = self.component_provider.stream.frame_seq_length * 2 + + if seed is not None and seed != self.base_seed: + logger.info(f"Updating seed from {self.base_seed} to {seed}") + self.base_seed = seed + def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements: if should_prepare: logger.info("prepare: Initiating pipeline prepare for request")
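Editor's note on the ComponentsManager change above: the wiring in components_loader.py reduces to a look-up-or-create pattern. `get_one` returns an already registered component or raises, in which case the component is built, registered with `add`, and exposed to the blocks through `ComponentProvider.stream`. Below is a minimal, self-contained sketch of that pattern; it uses a plain `torch.nn.Module` as a stand-in for `CausalStreamInferencePipeline` and a hypothetical collection name, so it only illustrates the reuse behaviour, not the real model loading.

import torch
from diffusers.modular_pipelines.components_manager import ComponentsManager


def get_or_create_stream(manager: ComponentsManager, collection: str = "demo") -> torch.nn.Module:
    """Return the registered "stream" component, creating and registering it on first use."""
    try:
        # Reuse an existing component if one was registered earlier in this collection.
        return manager.get_one(name="stream", collection=collection)
    except Exception:
        pass

    # Stand-in for CausalStreamInferencePipeline; any torch.nn.Module works for the sketch.
    stream = torch.nn.Identity()
    manager.add("stream", stream, collection=collection)
    return stream


manager = ComponentsManager()
first = get_or_create_stream(manager)
second = get_or_create_stream(manager)
print(first is second)  # expected True: the second call reuses the registered component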
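The reworked `_load_pipeline_sync_wrapper` now distinguishes three outcomes when a load request arrives: return immediately if the same pipeline is already loaded with identical params, attempt an in-place parameter update if only the params differ, and fall back to a full unload/reload otherwise. The following is a small sketch of that decision with stand-in names; the real logic lives in lib/pipeline_manager.py and additionally handles the LOADING state, locking, and error paths.

def decide_load_action(loaded_id, loaded_params, requested_id, requested_params, update_supported):
    """Return which action the manager should take for a load request (sketch only)."""
    loaded_params = loaded_params or {}
    requested_params = requested_params or {}

    if loaded_id != requested_id:
        return "unload_and_reload"   # different pipeline type: full reload
    if loaded_params == requested_params:
        return "noop"                # same pipeline, same params: nothing to do
    if update_supported:
        return "update_params"       # same pipeline, new params: update in place
    return "unload_and_reload"       # update not supported: reload as before


print(decide_load_action("streamdiffusionv2", {"height": 512}, "streamdiffusionv2", {"height": 512}, True))  # noop
print(decide_load_action("streamdiffusionv2", {"height": 512}, "streamdiffusionv2", {"height": 576}, True))  # update_params
print(decide_load_action("longlive", {}, "streamdiffusionv2", {}, True))                                     # unload_and_reload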
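`update_params` rounds requested dimensions to the model's spatial granularity before deciding whether the resolution actually changed, and only a real resolution change resets the frame caches (a new seed is stored without resetting anything). The rounding itself is simple; here is a sketch that assumes `SCALE_SIZE` is 64 — the actual constant is defined elsewhere in pipeline.py and may differ.

SCALE_SIZE = 64  # assumed value for this sketch; the pipeline defines its own constant


def round_resolution(height: int, width: int, scale: int = SCALE_SIZE) -> tuple[int, int]:
    """Round a requested resolution to the nearest multiple of `scale`."""
    return round(height / scale) * scale, round(width / scale) * scale


print(round_resolution(500, 900))  # (512, 896) with scale=64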