88 changes: 75 additions & 13 deletions lib/pipeline_manager.py
@@ -127,11 +127,20 @@ def _load_pipeline_sync_wrapper(
) -> bool:
"""Synchronous wrapper for pipeline loading with proper locking."""
with self._lock:
# If already loading, someone else is handling it
if self._status == PipelineStatus.LOADING:
logger.info("Pipeline already loading by another thread")
return False

# Determine pipeline type
if pipeline_id is None:
pipeline_id = os.getenv("PIPELINE", "longlive")

# Normalize None to empty dict for comparison
current_params = self._load_params or {}
new_params = load_params or {}

# If already loaded with same type and same params, return success
if (
self._status == PipelineStatus.LOADED
and self._pipeline_id == pipeline_id
@@ -142,25 +151,41 @@ def _load_pipeline_sync_wrapper(
)
return True

# If same pipeline but different params, try to update instead of reloading
if (
self._status == PipelineStatus.LOADED
and self._pipeline_id == pipeline_id
and current_params != new_params
):
try:
logger.info(
f"Updating pipeline {pipeline_id} parameters without reloading models"
)
updated = self._update_pipeline_params(pipeline_id, new_params)
if updated:
self._load_params = load_params
logger.info(f"Pipeline {pipeline_id} parameters updated successfully")
return True
else:
# Update failed, fall through to reload
logger.info(
f"Pipeline update not supported, reloading pipeline {pipeline_id}"
)
self._unload_pipeline_unsafe()
except Exception as e:
logger.warning(
f"Failed to update pipeline parameters: {e}. Reloading pipeline."
)
self._unload_pipeline_unsafe()

# If a different pipeline is loaded, unload it first
if self._status == PipelineStatus.LOADED and self._pipeline_id != pipeline_id:
self._unload_pipeline_unsafe()

try:
self._status = PipelineStatus.LOADING
self._error_message = None

logger.info(f"Loading pipeline: {pipeline_id}")

# Load the pipeline synchronously (we're already in executor thread)
@@ -186,6 +211,43 @@ def _load_pipeline_sync_wrapper(

return False

def _update_pipeline_params(self, pipeline_id: str, load_params: dict) -> bool:
"""
Update pipeline parameters without reloading models.

Args:
pipeline_id: ID of the pipeline to update
load_params: New load parameters

Returns:
bool: True if update was successful, False if not supported
"""
if self._pipeline is None:
return False

# Only streamdiffusionv2 currently supports parameter updates
if pipeline_id == "streamdiffusionv2":
try:
from pipelines.streamdiffusionv2.pipeline import StreamDiffusionV2Pipeline

if isinstance(self._pipeline, StreamDiffusionV2Pipeline):
height = load_params.get("height")
width = load_params.get("width")
seed = load_params.get("seed")

self._pipeline.update_params(
height=height,
width=width,
seed=seed
)
return True
except Exception as e:
logger.error(f"Error updating streamdiffusionv2 pipeline params: {e}")
return False

# Other pipelines don't support parameter updates yet
return False

def _unload_pipeline_unsafe(self):
"""Unload the current pipeline. Must be called with lock held."""
if self._pipeline:
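Review note: a minimal sketch of how the new update-vs-reload path behaves from a caller's perspective. The `PipelineManager()` constructor and a public `load_pipeline` entry point are assumptions for illustration; only the internal `_load_pipeline_sync_wrapper` and `_update_pipeline_params` appear in this diff.

# Hypothetical usage (names outside this diff are assumptions):
manager = PipelineManager()

# First call: nothing is loaded, so the pipeline loads from scratch.
manager.load_pipeline("streamdiffusionv2", {"height": 512, "width": 512, "seed": 42})

# Same pipeline id, same params: returns True immediately (fast path above).
manager.load_pipeline("streamdiffusionv2", {"height": 512, "width": 512, "seed": 42})

# Same pipeline id, new params: _update_pipeline_params is tried first, so
# streamdiffusionv2 updates in place instead of reloading model weights.
manager.load_pipeline("streamdiffusionv2", {"height": 768, "width": 768, "seed": 42})

# Different pipeline id: the current pipeline is unloaded, then the new one loads.
manager.load_pipeline("longlive", None)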
108 changes: 108 additions & 0 deletions pipelines/streamdiffusionv2/components_loader.py
@@ -0,0 +1,108 @@
"""Helper functions to load components using ComponentsManager."""

import os
import time

import torch
from diffusers.modular_pipelines.components_manager import ComponentsManager

from .vendor.causvid.models.wan.causal_stream_inference import (
CausalStreamInferencePipeline,
)


class ComponentProvider:
"""Simple wrapper to provide component access from ComponentsManager to blocks."""

def __init__(self, components_manager: ComponentsManager, component_name: str, collection: str = "streamdiffusionv2"):
"""
Initialize the component provider.

Args:
components_manager: The ComponentsManager instance
component_name: Name of the component to provide
collection: Collection name for retrieving the component
"""
self.components_manager = components_manager
self.component_name = component_name
self.collection = collection
# Cache the component to avoid repeated lookups
self._component = None

@property
def stream(self):
"""Provide access to the stream component."""
if self._component is None:
self._component = self.components_manager.get_one(
name=self.component_name, collection=self.collection
)
return self._component


def load_stream_component(
config,
device,
dtype,
model_dir,
components_manager: ComponentsManager,
collection: str = "streamdiffusionv2",
) -> ComponentProvider:
"""
Load the CausalStreamInferencePipeline and add it to ComponentsManager.

Args:
config: Configuration dictionary for the pipeline
device: Device to run the pipeline on
dtype: Data type for the pipeline
model_dir: Directory containing the model files
components_manager: ComponentsManager instance to add component to
collection: Collection name for organizing components

Returns:
ComponentProvider: A provider that gives access to the stream component
"""
# Check if component already exists in ComponentsManager
    try:
        components_manager.get_one(name="stream", collection=collection)
        # Component already registered; just wrap it in a provider
        print(f"Reusing existing stream component from collection '{collection}'")
        return ComponentProvider(components_manager, "stream", collection)
except Exception:
# Component doesn't exist, create and add it
pass

# Create and initialize the stream pipeline
stream = CausalStreamInferencePipeline(config, device).to(
device=device, dtype=dtype
)

# Load the generator state dict
start = time.time()
model_path = os.path.join(model_dir, "StreamDiffusionV2/model.pt")
if not os.path.exists(model_path):
raise FileNotFoundError(
f"Model file not found at {model_path}. "
"Please ensure StreamDiffusionV2/model.pt exists in the model directory."
)

state_dict_data = torch.load(model_path, map_location="cpu")

# Handle both dict with "generator" key and direct state dict
if isinstance(state_dict_data, dict) and "generator" in state_dict_data:
state_dict = state_dict_data["generator"]
else:
state_dict = state_dict_data

stream.generator.load_state_dict(state_dict, strict=True)
print(f"Loaded diffusion state dict in {time.time() - start:.3f}s")

# Add component to ComponentsManager
component_id = components_manager.add(
"stream",
stream,
collection=collection,
)
print(f"Added stream component to ComponentsManager with ID: {component_id}")

# Create and return provider
return ComponentProvider(components_manager, "stream", collection)
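Usage sketch for review, assuming a WAN-style `config` dict, a `model_dir` containing StreamDiffusionV2/model.pt, and a CUDA device; apart from `load_stream_component` and `ComponentProvider`, none of these names come from this diff.

import torch
from diffusers.modular_pipelines.components_manager import ComponentsManager

manager = ComponentsManager()

# First call builds CausalStreamInferencePipeline, loads model.pt, and
# registers the component under the "streamdiffusionv2" collection.
provider = load_stream_component(
    config,                   # assumed config dict for the vendored pipeline
    device="cuda",
    dtype=torch.bfloat16,
    model_dir="/models",      # assumed model root
    components_manager=manager,
)

# A second call with the same manager takes the get_one() fast path and
# reuses the registered component instead of reloading the state dict.
provider_again = load_stream_component(config, "cuda", torch.bfloat16, "/models", manager)

# Blocks resolve the underlying pipeline lazily (and cache it) via the provider.
stream = provider.stream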
69 changes: 69 additions & 0 deletions pipelines/streamdiffusionv2/decoders.py
@@ -0,0 +1,69 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
InputParam,
OutputParam,
)


class StreamDiffusionV2PostprocessStep(ModularPipelineBlocks):
model_name = "StreamDiffusionV2"

@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("stream", torch.nn.Module),
]

@property
def description(self) -> str:
return "Postprocess step that decodes denoised latents to pixel space"

@property
def inputs(self) -> list[InputParam]:
return [
InputParam(
"denoised_pred",
required=True,
type_hint=torch.Tensor,
description="Denoised latents",
),
]

@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam(
"output",
type_hint=torch.Tensor,
description="Decoded video frames",
),
]

@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

        # Decode denoised latents to pixel space
block_state.output = components.stream.vae.stream_decode_to_pixel(block_state.denoised_pred)

self.set_block_state(state, block_state)
return components, state
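This block's contract pairs with the denoise step in the next file: `denoised_pred` (latents) in, `output` (pixel frames) out. Stripped of the state plumbing, the body reduces to roughly the following, assuming a `provider` from components_loader.py and a `denoised_pred` latent tensor.

import torch

# Illustrative equivalent of the block body (tensor layout depends on the
# vendored WAN VAE, so treat shapes as an assumption):
with torch.no_grad():
    frames = provider.stream.vae.stream_decode_to_pixel(denoised_pred)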
97 changes: 97 additions & 0 deletions pipelines/streamdiffusionv2/denoise.py
@@ -0,0 +1,97 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
InputParam,
OutputParam,
)


class StreamDiffusionV2DenoiseStep(ModularPipelineBlocks):
model_name = "StreamDiffusionV2"

@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("stream", torch.nn.Module),
]

@property
def description(self) -> str:
return "Denoise step that performs inference using the stream pipeline"

@property
def inputs(self) -> list[InputParam]:
return [
InputParam(
"noisy_latents",
required=True,
type_hint=torch.Tensor,
description="Noisy latents to denoise",
),
InputParam(
"current_start",
required=True,
type_hint=int,
description="Current start position",
),
InputParam(
"current_end",
required=True,
type_hint=int,
description="Current end position",
),
InputParam(
"current_step",
required=True,
type_hint=int,
description="Current denoising step",
),
InputParam(
"generator",
description="Random number generator",
),
]

@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam(
"denoised_pred",
type_hint=torch.Tensor,
description="Denoised prediction",
),
]

@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

        # Run one denoising step through the stream's inference method
block_state.denoised_pred = components.stream.inference(
noise=block_state.noisy_latents,
current_start=block_state.current_start,
current_end=block_state.current_end,
current_step=block_state.current_step,
generator=block_state.generator,
)

self.set_block_state(state, block_state)
return components, state
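The two new blocks are designed to chain: `denoised_pred` is an intermediate output of the denoise step and a required input of the postprocess step. A sketch of composing them, assuming the `SequentialPipelineBlocks.from_blocks_dict` pattern from the diffusers modular-pipelines documentation.

from diffusers.modular_pipelines import SequentialPipelineBlocks

# Chain denoise -> postprocess; denoised_pred flows between the blocks
# through the shared PipelineState (composition pattern is an assumption
# based on the diffusers modular-pipelines docs, not part of this diff).
blocks = SequentialPipelineBlocks.from_blocks_dict(
    {
        "denoise": StreamDiffusionV2DenoiseStep(),
        "postprocess": StreamDiffusionV2PostprocessStep(),
    }
)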