diff --git a/lib/pipeline_manager.py b/lib/pipeline_manager.py
index 4c783d6f1..86ecfe3d4 100644
--- a/lib/pipeline_manager.py
+++ b/lib/pipeline_manager.py
@@ -127,11 +127,20 @@ def _load_pipeline_sync_wrapper(
     ) -> bool:
         """Synchronous wrapper for pipeline loading with proper locking."""
         with self._lock:
-            # If already loaded with same type and same params, return success
+            # If already loading, someone else is handling it
+            if self._status == PipelineStatus.LOADING:
+                logger.info("Pipeline already loading by another thread")
+                return False
+
+            # Determine pipeline type
+            if pipeline_id is None:
+                pipeline_id = os.getenv("PIPELINE", "longlive")
+
             # Normalize None to empty dict for comparison
             current_params = self._load_params or {}
             new_params = load_params or {}
 
+            # If already loaded with same type and same params, return success
             if (
                 self._status == PipelineStatus.LOADED
                 and self._pipeline_id == pipeline_id
@@ -142,25 +151,41 @@ def _load_pipeline_sync_wrapper(
                 )
                 return True
 
-            # If a different pipeline is loaded OR same pipeline with different params, unload it first
-            if self._status == PipelineStatus.LOADED and (
-                self._pipeline_id != pipeline_id or current_params != new_params
+            # If same pipeline but different params, try to update instead of reloading
+            if (
+                self._status == PipelineStatus.LOADED
+                and self._pipeline_id == pipeline_id
+                and current_params != new_params
             ):
+                try:
+                    logger.info(
+                        f"Updating pipeline {pipeline_id} parameters without reloading models"
+                    )
+                    updated = self._update_pipeline_params(pipeline_id, new_params)
+                    if updated:
+                        self._load_params = load_params
+                        logger.info(f"Pipeline {pipeline_id} parameters updated successfully")
+                        return True
+                    else:
+                        # Update failed, fall through to reload
+                        logger.info(
+                            f"Pipeline update not supported, reloading pipeline {pipeline_id}"
+                        )
+                        self._unload_pipeline_unsafe()
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to update pipeline parameters: {e}. Reloading pipeline."
+                    )
+                    self._unload_pipeline_unsafe()
+
+            # If a different pipeline is loaded, unload it first
+            if self._status == PipelineStatus.LOADED and self._pipeline_id != pipeline_id:
                 self._unload_pipeline_unsafe()
 
-            # If already loading, someone else is handling it
-            if self._status == PipelineStatus.LOADING:
-                logger.info("Pipeline already loading by another thread")
-                return False
-
             try:
                 self._status = PipelineStatus.LOADING
                 self._error_message = None
 
-                # Determine pipeline type
-                if pipeline_id is None:
-                    pipeline_id = os.getenv("PIPELINE", "longlive")
-
                 logger.info(f"Loading pipeline: {pipeline_id}")
 
                 # Load the pipeline synchronously (we're already in executor thread)
@@ -186,6 +211,43 @@ def _load_pipeline_sync_wrapper(
 
                 return False
 
+    def _update_pipeline_params(self, pipeline_id: str, load_params: dict) -> bool:
+        """
+        Update pipeline parameters without reloading models.
+
+        Args:
+            pipeline_id: ID of the pipeline to update
+            load_params: New load parameters
+
+        Returns:
+            bool: True if update was successful, False if not supported
+        """
+        if self._pipeline is None:
+            return False
+
+        # Only streamdiffusionv2 currently supports parameter updates
+        if pipeline_id == "streamdiffusionv2":
+            try:
+                from pipelines.streamdiffusionv2.pipeline import StreamDiffusionV2Pipeline
+
+                if isinstance(self._pipeline, StreamDiffusionV2Pipeline):
+                    height = load_params.get("height")
+                    width = load_params.get("width")
+                    seed = load_params.get("seed")
+
+                    self._pipeline.update_params(
+                        height=height,
+                        width=width,
+                        seed=seed
+                    )
+                    return True
+            except Exception as e:
+                logger.error(f"Error updating streamdiffusionv2 pipeline params: {e}")
+                return False
+
+        # Other pipelines don't support parameter updates yet
+        return False
+
     def _unload_pipeline_unsafe(self):
         """Unload the current pipeline. Must be called with lock held."""
         if self._pipeline:
diff --git a/pipelines/streamdiffusionv2/components_loader.py b/pipelines/streamdiffusionv2/components_loader.py
new file mode 100644
index 000000000..4779c7921
--- /dev/null
+++ b/pipelines/streamdiffusionv2/components_loader.py
@@ -0,0 +1,108 @@
+"""Helper functions to load components using ComponentsManager."""
+
+import os
+import time
+
+import torch
+from diffusers.modular_pipelines.components_manager import ComponentsManager
+
+from .vendor.causvid.models.wan.causal_stream_inference import (
+    CausalStreamInferencePipeline,
+)
+
+
+class ComponentProvider:
+    """Simple wrapper to provide component access from ComponentsManager to blocks."""
+
+    def __init__(self, components_manager: ComponentsManager, component_name: str, collection: str = "streamdiffusionv2"):
+        """
+        Initialize the component provider.
+
+        Args:
+            components_manager: The ComponentsManager instance
+            component_name: Name of the component to provide
+            collection: Collection name for retrieving the component
+        """
+        self.components_manager = components_manager
+        self.component_name = component_name
+        self.collection = collection
+        # Cache the component to avoid repeated lookups
+        self._component = None
+
+    @property
+    def stream(self):
+        """Provide access to the stream component."""
+        if self._component is None:
+            self._component = self.components_manager.get_one(
+                name=self.component_name, collection=self.collection
+            )
+        return self._component
+
+
+def load_stream_component(
+    config,
+    device,
+    dtype,
+    model_dir,
+    components_manager: ComponentsManager,
+    collection: str = "streamdiffusionv2",
+) -> ComponentProvider:
+    """
+    Load the CausalStreamInferencePipeline and add it to ComponentsManager.
+
+    Args:
+        config: Configuration dictionary for the pipeline
+        device: Device to run the pipeline on
+        dtype: Data type for the pipeline
+        model_dir: Directory containing the model files
+        components_manager: ComponentsManager instance to add component to
+        collection: Collection name for organizing components
+
+    Returns:
+        ComponentProvider: A provider that gives access to the stream component
+    """
+    # Check if component already exists in ComponentsManager
+    try:
+        existing = components_manager.get_one(name="stream", collection=collection)
+        # Component exists, create provider for it
+        print(f"Reusing existing stream component from collection '{collection}'")
+        return ComponentProvider(components_manager, "stream", collection)
+    except Exception:
+        # Component doesn't exist, create and add it
+        pass
+
+    # Create and initialize the stream pipeline
+    stream = CausalStreamInferencePipeline(config, device).to(
+        device=device, dtype=dtype
+    )
+
+    # Load the generator state dict
+    start = time.time()
+    model_path = os.path.join(model_dir, "StreamDiffusionV2/model.pt")
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(
+            f"Model file not found at {model_path}. "
+            "Please ensure StreamDiffusionV2/model.pt exists in the model directory."
+        )
+
+    state_dict_data = torch.load(model_path, map_location="cpu")
+
+    # Handle both dict with "generator" key and direct state dict
+    if isinstance(state_dict_data, dict) and "generator" in state_dict_data:
+        state_dict = state_dict_data["generator"]
+    else:
+        state_dict = state_dict_data
+
+    stream.generator.load_state_dict(state_dict, strict=True)
+    print(f"Loaded diffusion state dict in {time.time() - start:.3f}s")
+
+    # Add component to ComponentsManager
+    component_id = components_manager.add(
+        "stream",
+        stream,
+        collection=collection,
+    )
+    print(f"Added stream component to ComponentsManager with ID: {component_id}")
+
+    # Create and return provider
+    return ComponentProvider(components_manager, "stream", collection)
diff --git a/pipelines/streamdiffusionv2/decoders.py b/pipelines/streamdiffusionv2/decoders.py
new file mode 100644
index 000000000..46955c918
--- /dev/null
+++ b/pipelines/streamdiffusionv2/decoders.py
@@ -0,0 +1,69 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from diffusers.modular_pipelines import (
+    ModularPipelineBlocks,
+    PipelineState,
+)
+from diffusers.modular_pipelines.modular_pipeline_utils import (
+    ComponentSpec,
+    InputParam,
+    OutputParam,
+)
+
+
+class StreamDiffusionV2PostprocessStep(ModularPipelineBlocks):
+    model_name = "StreamDiffusionV2"
+
+    @property
+    def expected_components(self) -> list[ComponentSpec]:
+        return [
+            ComponentSpec("stream", torch.nn.Module),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Postprocess step that decodes denoised latents to pixel space"
+
+    @property
+    def inputs(self) -> list[InputParam]:
+        return [
+            InputParam(
+                "denoised_pred",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Denoised latents",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> list[OutputParam]:
+        return [
+            OutputParam(
+                "output",
+                type_hint=torch.Tensor,
+                description="Decoded video frames",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        # Decode to pixel space - direct assignment to reduce overhead
+        block_state.output = components.stream.vae.stream_decode_to_pixel(block_state.denoised_pred)
+
+        self.set_block_state(state, block_state)
+        return components, state
diff --git a/pipelines/streamdiffusionv2/denoise.py b/pipelines/streamdiffusionv2/denoise.py
new file mode 100644
index 000000000..3129d1032
--- /dev/null
+++ b/pipelines/streamdiffusionv2/denoise.py
@@ -0,0 +1,97 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from diffusers.modular_pipelines import (
+    ModularPipelineBlocks,
+    PipelineState,
+)
+from diffusers.modular_pipelines.modular_pipeline_utils import (
+    ComponentSpec,
+    InputParam,
+    OutputParam,
+)
+
+
+class StreamDiffusionV2DenoiseStep(ModularPipelineBlocks):
+    model_name = "StreamDiffusionV2"
+
+    @property
+    def expected_components(self) -> list[ComponentSpec]:
+        return [
+            ComponentSpec("stream", torch.nn.Module),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Denoise step that performs inference using the stream pipeline"
+
+    @property
+    def inputs(self) -> list[InputParam]:
+        return [
+            InputParam(
+                "noisy_latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Noisy latents to denoise",
+            ),
+            InputParam(
+                "current_start",
+                required=True,
+                type_hint=int,
+                description="Current start position",
+            ),
+            InputParam(
+                "current_end",
+                required=True,
+                type_hint=int,
+                description="Current end position",
+            ),
+            InputParam(
+                "current_step",
+                required=True,
+                type_hint=int,
+                description="Current denoising step",
+            ),
+            InputParam(
+                "generator",
+                description="Random number generator",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> list[OutputParam]:
+        return [
+            OutputParam(
+                "denoised_pred",
+                type_hint=torch.Tensor,
+                description="Denoised prediction",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        # Use the stream's inference method - direct call without intermediate variable
+        block_state.denoised_pred = components.stream.inference(
+            noise=block_state.noisy_latents,
+            current_start=block_state.current_start,
+            current_end=block_state.current_end,
+            current_step=block_state.current_step,
+            generator=block_state.generator,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
diff --git a/pipelines/streamdiffusionv2/encoders.py b/pipelines/streamdiffusionv2/encoders.py
new file mode 100644
index 000000000..cd52f45b9
--- /dev/null
+++ b/pipelines/streamdiffusionv2/encoders.py
@@ -0,0 +1,141 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+
+import regex as re
+import torch
+from diffusers.modular_pipelines import (
+    ModularPipelineBlocks,
+    PipelineState,
+)
+from diffusers.modular_pipelines.modular_pipeline_utils import (
+    ComponentSpec,
+    ConfigSpec,
+    InputParam,
+    OutputParam,
+)
+from diffusers.utils import is_ftfy_available
+from diffusers.utils import logging as diffusers_logging
+
+if is_ftfy_available():
+    import ftfy
+
+logger = diffusers_logging.get_logger(__name__)
+
+
+def basic_clean(text):
+    if is_ftfy_available():
+        text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r"\s+", " ", text)
+    text = text.strip()
+    return text
+
+
+def prompt_clean(text):
+    text = whitespace_clean(basic_clean(text))
+    return text
+
+
+class StreamDiffusionV2TextEncoderStep(ModularPipelineBlocks):
+    model_name = "StreamDiffusionV2"
+
+    @property
+    def description(self) -> str:
+        return "Text Encoder step that generates text_embeddings to guide the video generation"
+
+    @property
+    def expected_components(self) -> list[ComponentSpec]:
+        return [
+            ComponentSpec("stream", torch.nn.Module),
+        ]
+
+    @property
+    def expected_configs(self) -> list[ConfigSpec]:
+        return []
+
+    @property
+    def inputs(self) -> list[InputParam]:
+        return [
+            InputParam("prompt"),
+            InputParam("negative_prompt"),
+            InputParam(
+                "prompt_embeds",
+                type_hint=torch.Tensor,
+                description="text embeddings used to guide the image generation",
+            ),
+            InputParam(
+                "negative_prompt_embeds",
+                type_hint=torch.Tensor,
+                description="negative text embeddings used to guide the image generation",
+            ),
+            InputParam("attention_kwargs"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> list[OutputParam]:
+        return [
+            OutputParam(
+                "prompt_embeds",
+                type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
+                description="text embeddings used to guide the image generation",
+            ),
+            OutputParam(
+                "negative_prompt_embeds",
+                type_hint=torch.Tensor,
+                kwargs_type="denoiser_input_fields",
+                description="negative text embeddings used to guide the image generation",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(block_state):
+        if block_state.prompt is not None and (
+            not isinstance(block_state.prompt, str)
+            and not isinstance(block_state.prompt, list)
+        ):
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
+            )
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        # For streamdiffusionv2, prompt encoding is handled by prompt_blender
+        # The actual prompt_embeds are set via conditional_dict in the pipeline before blocks execute
+        # Skip all work if prompt_embeds already exist (which they should via conditional_dict)
+        if block_state.prompt_embeds is None:
+            # Only check inputs if we actually need to encode
+            self.check_inputs(block_state)
+            # Use the stream's text encoder if needed
+            if hasattr(block_state, "prompt") and block_state.prompt is not None:
+                conditional_dict = components.stream.text_encoder(
+                    text_prompts=[block_state.prompt]
+                )
+                block_state.prompt_embeds = conditional_dict["prompt_embeds"]
+            else:
+                # Default empty prompt
+                conditional_dict = components.stream.text_encoder(text_prompts=[""])
+                block_state.prompt_embeds = conditional_dict["prompt_embeds"]
+            self.set_block_state(state, block_state)
+        # If prompt_embeds already exists, skip state update to reduce overhead
+
+        return components, state
diff --git a/pipelines/streamdiffusionv2/modular_blocks.py b/pipelines/streamdiffusionv2/modular_blocks.py
new file mode 100644
index 000000000..f63ecb531
--- /dev/null
+++ b/pipelines/streamdiffusionv2/modular_blocks.py
@@ -0,0 +1,42 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from diffusers.utils import logging as diffusers_logging
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict
+
+from .preprocess import StreamDiffusionV2PreprocessStep
+from .decoders import StreamDiffusionV2PostprocessStep
+from .encoders import StreamDiffusionV2TextEncoderStep
+from .denoise import StreamDiffusionV2DenoiseStep
+
+logger = diffusers_logging.get_logger(__name__)
+
+VIDEO2VIDEO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", StreamDiffusionV2TextEncoderStep),
+        ("preprocess", StreamDiffusionV2PreprocessStep),
+        ("denoise", StreamDiffusionV2DenoiseStep),
+        ("postprocess", StreamDiffusionV2PostprocessStep),
+    ]
+)
+
+ALL_BLOCKS = {
+    "video2video": VIDEO2VIDEO_BLOCKS,
+}
+
+
+class StreamDiffusionV2Blocks(SequentialPipelineBlocks):
+    block_classes = list(VIDEO2VIDEO_BLOCKS.values())
+    block_names = list(VIDEO2VIDEO_BLOCKS.keys())
diff --git a/pipelines/streamdiffusionv2/modular_config.json b/pipelines/streamdiffusionv2/modular_config.json
new file mode 100644
index 000000000..e9c4d0c42
--- /dev/null
+++ b/pipelines/streamdiffusionv2/modular_config.json
@@ -0,0 +1,7 @@
+{
+  "_class_name": "StreamDiffusionV2Blocks",
+  "_diffusers_version": "0.36.0.dev0",
+  "auto_map": {
+    "ModularPipelineBlocks": "modular_blocks.StreamDiffusionV2Blocks"
+  }
+}
diff --git a/pipelines/streamdiffusionv2/pipeline.py b/pipelines/streamdiffusionv2/pipeline.py
index 203d326d9..ec03bc5a1 100644
--- a/pipelines/streamdiffusionv2/pipeline.py
+++ b/pipelines/streamdiffusionv2/pipeline.py
@@ -1,15 +1,14 @@
 import logging
-import os
-import time
 
 import torch
+from diffusers.modular_pipelines import PipelineState
+from diffusers.modular_pipelines.components_manager import ComponentsManager
 
 from ..blending import PromptBlender, handle_transition_prepare
 from ..interface import Pipeline, Requirements
 from ..process import postprocess_chunk, preprocess_chunk
-from .vendor.causvid.models.wan.causal_stream_inference import (
-    CausalStreamInferencePipeline,
-)
+from .components_loader import load_stream_component, ComponentProvider
+from .modular_blocks import StreamDiffusionV2Blocks
 
 # https://github.com/daydreamlive/scope/blob/0cf1766186be3802bf97ce550c2c978439f22068/pipelines/streamdiffusionv2/vendor/causvid/models/wan/causal_model.py#L306
 MAX_ROPE_FREQ_TABLE_SEQ_LEN = 1024
@@ -46,20 +45,9 @@ def __init__(
         config["height"] = self.height
         config["width"] = self.width
 
-        self.stream = CausalStreamInferencePipeline(config, device).to(
-            device=device, dtype=dtype
-        )
         self.device = device
         self.dtype = dtype
 
-        start = time.time()
-        state_dict = torch.load(
-            os.path.join(config.model_dir, "StreamDiffusionV2/model.pt"),
-            map_location="cpu",
-        )["generator"]
-        self.stream.generator.load_state_dict(state_dict, strict=True)
-        print(f"Loaded diffusion state dict in {time.time() - start:.3f}s")
-
         self.chunk_size = chunk_size
         self.start_chunk_size = start_chunk_size
         self.noise_scale = noise_scale
@@ -68,6 +56,17 @@ def __init__(
         self.prompts = None
         self.denoising_step_list = None
 
+        # Initialize ComponentsManager
+        self.components_manager = ComponentsManager()
+
+        # Initialize Modular Diffusers blocks
+        self.modular_blocks = StreamDiffusionV2Blocks()
+
+        # Load stream component using ComponentsManager
+        self.component_provider = load_stream_component(
+            config, device, dtype, config.model_dir, self.components_manager
+        )
+
         # Prompt blending with cache reset callback for transitions
         self.prompt_blender = PromptBlender(
             device, dtype, cache_reset_callback=self._initialize_stream_caches
@@ -75,7 +74,35 @@ def __init__(
 
         self.last_frame = None
         self.current_start = 0
-        self.current_end = self.stream.frame_seq_length * 2
+        self.current_end = self.component_provider.stream.frame_seq_length * 2
+
+    def update_params(self, height: int | None = None, width: int | None = None, seed: int | None = None):
+        """
+        Update pipeline parameters without reloading model weights.
+
+        Args:
+            height: New height (will be rounded to nearest multiple of SCALE_SIZE)
+            width: New width (will be rounded to nearest multiple of SCALE_SIZE)
+            seed: New seed value
+        """
+        if height is not None or width is not None:
+            req_height = height if height is not None else self.height
+            req_width = width if width is not None else self.width
+            new_height = round(req_height / SCALE_SIZE) * SCALE_SIZE
+            new_width = round(req_width / SCALE_SIZE) * SCALE_SIZE
+
+            if new_height != self.height or new_width != self.width:
+                logger.info(f"Updating resolution from {self.width}x{self.height} to {new_width}x{new_height}")
+                self.height = new_height
+                self.width = new_width
+                # Reset caches when resolution changes
+                self.last_frame = None
+                self.current_start = 0
+                self.current_end = self.component_provider.stream.frame_seq_length * 2
+
+        if seed is not None and seed != self.base_seed:
+            logger.info(f"Updating seed from {self.base_seed} to {seed}")
+            self.base_seed = seed
 
     def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements:
         if should_prepare:
@@ -98,7 +125,7 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements:
 
         # Handle prompt transition requests
         should_prepare_from_transition, target_prompts = handle_transition_prepare(
-            transition, self.prompt_blender, self.stream.text_encoder
+            transition, self.prompt_blender, self.component_provider.stream.text_encoder
         )
         if target_prompts:
             self.prompts = target_prompts
@@ -129,7 +156,9 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements:
         # We need to make sure that current_start does not shift past the max sequence length of the RoPE frequency table
         # When we hit the limit we reset the caches and indices
        # See this issue for more context https://github.com/daydreamlive/scope/issues/95
-        max_current_start = MAX_ROPE_FREQ_TABLE_SEQ_LEN * self.stream.frame_seq_length
+        max_current_start = (
+            MAX_ROPE_FREQ_TABLE_SEQ_LEN * self.component_provider.stream.frame_seq_length
+        )
         # We reset at whatever is smaller the theoretically max value or some % of it
         max_current_start = min(
             int(max_current_start * CURRENT_START_RESET_RATIO), max_current_start
@@ -142,7 +171,7 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements:
 
         # Update internal state before preparing pipeline
         if denoising_step_list is not None:
             self.denoising_step_list = denoising_step_list
-            self.stream.denoising_step_list = torch.tensor(
+            self.component_provider.stream.denoising_step_list = torch.tensor(
                 denoising_step_list, dtype=torch.long, device=self.device
             )
@@ -161,16 +190,16 @@ def prepare(self, should_prepare: bool = False, **kwargs) -> Requirements:
     @torch.no_grad()
     def _prepare_pipeline(self, prompts=None, interpolation_method="linear"):
         # Trigger KV + cross-attn cache re-initialization in prepare()
-        self.stream.kv_cache1 = None
+        self.component_provider.stream.kv_cache1 = None
 
         # Apply prompt blending and set conditional_dict
         self._apply_prompt_blending(prompts, interpolation_method)
 
-        self.stream.vae.model.first_batch = True
+        self.component_provider.stream.vae.model.first_batch = True
 
         self.last_frame = None
         self.current_start = 0
-        self.current_end = self.stream.frame_seq_length * 2
+        self.current_end = self.component_provider.stream.frame_seq_length * 2
 
     def _apply_motion_aware_noise_controller(self, input: torch.Tensor):
         # The prev seq is the last chunk_size frames of the current input
@@ -222,11 +251,13 @@ def __call__(
         # Update prompt embedding for this generation call
         # Handles both static blending and temporal transitions
         next_embedding = self.prompt_blender.get_next_embedding(
-            self.stream.text_encoder
+            self.component_provider.stream.text_encoder
         )
 
         if next_embedding is not None:
-            self.stream.conditional_dict = {"prompt_embeds": next_embedding}
+            self.component_provider.stream.conditional_dict = {
+                "prompt_embeds": next_embedding
+            }
 
         # Note: The caller must call prepare() before __call__()
         # We just need to get the expected chunk size based on current state
@@ -252,66 +283,74 @@ def __call__(
 
         if noise_controller:
             self._apply_motion_aware_noise_controller(input)
 
+        # Use Modular Diffusers blocks to process the input
+        state = PipelineState()
+
+        # Set up state for modular blocks
+        state.set("input", input)
+        state.set("noise_scale", self.noise_scale)
+        state.set("base_seed", self.base_seed)
+        state.set("current_start", self.current_start)
+        state.set("current_end", self.current_end)
+        state.set(
+            "denoising_step_list", self.component_provider.stream.denoising_step_list
+        )
+
         # Determine the number of denoising steps
-        # Higher noise scale -> more denoising steps, more intense changes to input
-        # Lower noise scale -> less denoising steps, less intense changes to input
         current_step = int(1000 * self.noise_scale) - 100
+        state.set("current_step", current_step)
 
-        # Encode frames to latents using VAE
-        latents = self.stream.vae.model.stream_encode(input)
-        # Transpose latents
-        latents = latents.transpose(2, 1)
-
-        # Create generator from seed for reproducible generation
-        # Derive unique seed per chunk using current_start as offset
-        frame_seed = self.base_seed + self.current_start
-        rng = torch.Generator(device=latents.device).manual_seed(frame_seed)
-
-        noise = torch.randn(
-            latents.shape,
-            device=latents.device,
-            dtype=latents.dtype,
-            generator=rng,
-        )
-        # Determine how noisy the latents should be
-        # Higher noise scale -> noiser latents, less of inputs preserved
-        # Lower noise scale -> less noisy latents, more of inputs preserved
-        noisy_latents = noise * self.noise_scale + latents * (1 - self.noise_scale)
-
-        denoised_pred = self.stream.inference(
-            noise=noisy_latents,
-            current_start=self.current_start,
-            current_end=self.current_end,
-            current_step=current_step,
-            generator=rng,
-        )
+        # Set prompt_embeds in state if available from conditional_dict to skip text_encoder work
+        if (
+            hasattr(self.component_provider.stream, "conditional_dict")
+            and self.component_provider.stream.conditional_dict is not None
+            and "prompt_embeds" in self.component_provider.stream.conditional_dict
+        ):
+            state.set("prompt_embeds", self.component_provider.stream.conditional_dict["prompt_embeds"])
+
+        # Execute modular blocks (returns tuple: components, state)
+        # Pass component_provider which provides components.stream access
+        _, state = self.modular_blocks(self.component_provider, state)
+
+        # Get output from state
+        output = state.values.get("output")
+
+        if output is None:
+            raise RuntimeError("Modular blocks did not produce output")
+
+        # Ensure output is in the right format
+        if not isinstance(output, torch.Tensor):
+            output = output[0] if isinstance(output, list) else output
 
-        # # Update tracking variables for next input
+        # Update tracking variables for next input
         self.last_frame = input[:, :, [-1]]
         self.current_start = self.current_end
-        self.current_end += (self.chunk_size // 4) * self.stream.frame_seq_length
+        self.current_end += (
+            self.chunk_size // 4
+        ) * self.component_provider.stream.frame_seq_length
 
-        # Decode to pixel space
-        output = self.stream.vae.stream_decode_to_pixel(denoised_pred)
         return postprocess_chunk(output)
 
     def _initialize_stream_caches(self):
         """Initialize stream caches without overriding conditional_dict."""
         noise = torch.zeros(1, 1).to(self.device, self.dtype)
-        saved = self.stream.conditional_dict
-        self.stream.prepare(noise, text_prompts=[""])
-        self.stream.conditional_dict = saved
+        saved = self.component_provider.stream.conditional_dict
+        self.component_provider.stream.prepare(noise, text_prompts=[""])
+        self.component_provider.stream.conditional_dict = saved
 
     def _apply_prompt_blending(self, prompts=None, interpolation_method="linear"):
         """Apply weighted blending of cached prompt embeddings."""
         combined_embeds = self.prompt_blender.blend(
-            prompts, interpolation_method, self.stream.text_encoder
+            prompts, interpolation_method, self.component_provider.stream.text_encoder
         )
         if combined_embeds is None:
             return
 
         # Set the blended embeddings on the stream
-        self.stream.conditional_dict = {"prompt_embeds": combined_embeds}
+        self.component_provider.stream.conditional_dict = {
+            "prompt_embeds": combined_embeds
+        }
 
         # Initialize caches without overriding conditional_dict
         self._initialize_stream_caches()
diff --git a/pipelines/streamdiffusionv2/preprocess.py b/pipelines/streamdiffusionv2/preprocess.py
new file mode 100644
index 000000000..7b7e8ca4f
--- /dev/null
+++ b/pipelines/streamdiffusionv2/preprocess.py
@@ -0,0 +1,127 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from diffusers.modular_pipelines import (
+    ModularPipelineBlocks,
+    PipelineState,
+)
+from diffusers.modular_pipelines.modular_pipeline_utils import (
+    ComponentSpec,
+    InputParam,
+    OutputParam,
+)
+
+
+class StreamDiffusionV2PreprocessStep(ModularPipelineBlocks):
+    model_name = "StreamDiffusionV2"
+
+    @property
+    def expected_components(self) -> list[ComponentSpec]:
+        return [
+            ComponentSpec("stream", torch.nn.Module),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Preprocess step that encodes input frames to latents and adds noise"
+
+    @property
+    def inputs(self) -> list[InputParam]:
+        return [
+            InputParam(
+                "input",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Input video frames in BCTHW format",
+            ),
+            InputParam(
+                "noise_scale",
+                type_hint=float,
+                default=0.7,
+                description="Scale of noise to add to latents",
+            ),
+            InputParam(
+                "generator",
+                description="Random number generator for noise",
+            ),
+            InputParam(
+                "base_seed",
+                type_hint=int,
+                default=42,
+                description="Base seed for random number generation",
+            ),
+            InputParam(
+                "current_start",
+                type_hint=int,
+                default=0,
+                description="Current start position for seed offset",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> list[OutputParam]:
+        return [
+            OutputParam(
+                "latents",
+                type_hint=torch.Tensor,
+                description="Encoded and noised latents",
+            ),
+            OutputParam(
+                "noisy_latents",
+                type_hint=torch.Tensor,
+                description="Noisy latents ready for denoising",
+            ),
+            OutputParam(
+                "current_step",
+                type_hint=int,
+                description="Current denoising step",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        # Encode frames to latents using VAE
+        latents = components.stream.vae.model.stream_encode(block_state.input)
+        # Transpose latents
+        latents = latents.transpose(2, 1)
+
+        # Create generator from seed for reproducible generation
+        frame_seed = block_state.base_seed + block_state.current_start
+        rng = torch.Generator(device=latents.device).manual_seed(frame_seed)
+
+        noise = torch.randn(
+            latents.shape,
+            device=latents.device,
+            dtype=latents.dtype,
+            generator=rng,
+        )
+        # Determine how noisy the latents should be
+        noisy_latents = noise * block_state.noise_scale + latents * (
+            1 - block_state.noise_scale
+        )
+
+        # Determine the number of denoising steps
+        current_step = int(1000 * block_state.noise_scale) - 100
+
+        # Update state directly without intermediate variables where possible
+        block_state.latents = latents
+        block_state.noisy_latents = noisy_latents
+        block_state.current_step = current_step
+        block_state.generator = rng
+
+        self.set_block_state(state, block_state)
+        return components, state