From 132ec10f5c7b47f855ee12fc9879651d9125446a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Leszko?= Date: Thu, 11 Dec 2025 09:30:22 +0100 Subject: [PATCH 1/2] Add Decart API pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rafał Leszko --- frontend/src/data/pipelines.ts | 10 + frontend/src/types/index.ts | 3 +- pyproject.toml | 1 + src/scope/core/pipelines/__init__.py | 10 + .../core/pipelines/decart_api/__init__.py | 3 + .../core/pipelines/decart_api/pipeline.py | 555 ++++++++++++++++++ src/scope/core/pipelines/decart_api/test.py | 189 ++++++ src/scope/core/pipelines/registry.py | 2 + src/scope/core/pipelines/schema.py | 27 + src/scope/server/pipeline_manager.py | 29 + src/scope/server/schema.py | 27 + uv.lock | 2 + 12 files changed, 857 insertions(+), 1 deletion(-) create mode 100644 src/scope/core/pipelines/decart_api/__init__.py create mode 100644 src/scope/core/pipelines/decart_api/pipeline.py create mode 100644 src/scope/core/pipelines/decart_api/test.py diff --git a/frontend/src/data/pipelines.ts b/frontend/src/data/pipelines.ts index 8ab68cdec..a63aa67ec 100644 --- a/frontend/src/data/pipelines.ts +++ b/frontend/src/data/pipelines.ts @@ -98,6 +98,16 @@ export const PIPELINES: Record = { supportedModes: ["video"], defaultMode: "video", }, + "decart-api": { + name: "Decart API", + about: + "Real-time video restyling using Decart's Mirage LSD API. Processes video frames through Decart's cloud-based realtime video transformation service.", + requiresModels: false, + estimatedVram: 0, // Cloud-based, no local VRAM required + // Video-only pipeline + supportedModes: ["video"], + defaultMode: "video", + }, }; export function pipelineSupportsLoRA(pipelineId: string): boolean { diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index f33c8fecc..b5aad3f67 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -3,7 +3,8 @@ export type PipelineId = | "passthrough" | "longlive" | "krea-realtime-video" - | "reward-forcing"; + | "reward-forcing" + | "decart-api"; // Input mode for pipeline operation export type InputMode = "text" | "video"; diff --git a/pyproject.toml b/pyproject.toml index 8423ca919..79d885c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "kernels>=0.10.4", "triton==3.4.0; sys_platform == 'linux'", "triton-windows==3.4.0.post21; sys_platform == 'win32'", + "pillow>=10.0.0", ] [project.scripts] diff --git a/src/scope/core/pipelines/__init__.py b/src/scope/core/pipelines/__init__.py index 9ea633032..b57fcc7a5 100644 --- a/src/scope/core/pipelines/__init__.py +++ b/src/scope/core/pipelines/__init__.py @@ -24,6 +24,10 @@ def __getattr__(name): from .passthrough.pipeline import PassthroughPipeline return PassthroughPipeline + elif name == "DecartApiPipeline": + from .decart_api.pipeline import DecartApiPipeline + + return DecartApiPipeline # Config classes elif name == "BasePipelineConfig": from .schema import BasePipelineConfig @@ -45,6 +49,10 @@ def __getattr__(name): from .schema import PassthroughConfig return PassthroughConfig + elif name == "DecartApiConfig": + from .schema import DecartApiConfig + + return DecartApiConfig raise AttributeError(f"module {__name__!r} has no attribute {name!r}") @@ -55,10 +63,12 @@ def __getattr__(name): "RewardForcingPipeline", "StreamDiffusionV2Pipeline", "PassthroughPipeline", + "DecartApiPipeline", # Config classes "BasePipelineConfig", "LongLiveConfig", "StreamDiffusionV2Config", 
"KreaRealtimeVideoConfig", "PassthroughConfig", + "DecartApiConfig", ] diff --git a/src/scope/core/pipelines/decart_api/__init__.py b/src/scope/core/pipelines/decart_api/__init__.py new file mode 100644 index 000000000..4e5024f07 --- /dev/null +++ b/src/scope/core/pipelines/decart_api/__init__.py @@ -0,0 +1,3 @@ +from .pipeline import DecartApiPipeline + +__all__ = ["DecartApiPipeline"] diff --git a/src/scope/core/pipelines/decart_api/pipeline.py b/src/scope/core/pipelines/decart_api/pipeline.py new file mode 100644 index 000000000..fbd469c10 --- /dev/null +++ b/src/scope/core/pipelines/decart_api/pipeline.py @@ -0,0 +1,555 @@ +import asyncio +import logging +import os +import queue +import threading +import time +from typing import TYPE_CHECKING + +import numpy as np +import torch +from einops import rearrange +from PIL import Image + +from ..interface import Pipeline, Requirements +from ..schema import DecartApiConfig + +if TYPE_CHECKING: + from ..schema import BasePipelineConfig + +logger = logging.getLogger(__name__) + +try: + from decart import DecartClient + from decart import models as decart_models + from decart.realtime.client import RealtimeClient + from decart.realtime.types import RealtimeConnectOptions, ModelState + from aiortc import MediaStreamTrack + from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE + from av import VideoFrame + + DECART_AVAILABLE = True + REALTIME_AVAILABLE = True +except ImportError as e: + DECART_AVAILABLE = False + REALTIME_AVAILABLE = False + logger.warning( + f"Decart SDK not available: {e}. Install with: pip install decart" + ) + + +class FrameSourceTrack(MediaStreamTrack): + """Custom MediaStreamTrack that feeds frames from a queue to Decart API.""" + + kind = "video" + + def __init__(self, frame_queue: queue.Queue, fps: int = 22, width: int = 1280, height: int = 704): + super().__init__() + self.frame_queue = frame_queue + self.fps = fps + self.width = width + self.height = height + self.frame_ptime = 1.0 / fps + self.timestamp = 0 + self.start_time = None + self.last_frame_time = None + + async def recv(self) -> VideoFrame: + """Return the next frame from the queue.""" + # Wait for a frame to be available + try: + frame_np = self.frame_queue.get(timeout=0.1) + except queue.Empty: + # If no frame available, create a black frame + # Use configured dimensions + frame_np = np.zeros( + (self.height, self.width, 3), dtype=np.uint8 + ) # Use configured dimensions + + # Convert numpy array to VideoFrame + frame = VideoFrame.from_ndarray(frame_np, format="rgb24") + + # Set timestamp + if self.start_time is None: + self.start_time = time.time() + self.last_frame_time = self.start_time + self.timestamp = 0 + else: + current_time = time.time() + time_since_last = current_time - self.last_frame_time + wait_time = self.frame_ptime - time_since_last + if wait_time > 0: + await asyncio.sleep(wait_time) + self.timestamp += int(self.frame_ptime * VIDEO_CLOCK_RATE) + self.last_frame_time = time.time() + + frame.pts = self.timestamp + frame.time_base = VIDEO_TIME_BASE + return frame + + +class DecartApiPipeline(Pipeline): + """Pipeline that processes video frames through Decart API.""" + + @classmethod + def get_config_class(cls) -> type["BasePipelineConfig"]: + return DecartApiConfig + + def __init__( + self, + config, + device: torch.device | None = None, + dtype: torch.dtype = torch.bfloat16, + ): + if not DECART_AVAILABLE: + raise ImportError( + "Decart SDK is required. 
Install with: pip install decart" + ) + + self.height = config.height + self.width = config.width + if device is not None: + self.device = device + else: + device_name = "cuda" if torch.cuda.is_available() else "cpu" + self.device = torch.device(device_name) + self.dtype = dtype + + # Get API key from environment + api_key = os.getenv("DECART_API_KEY") + if not api_key: + raise ValueError( + "DECART_API_KEY environment variable is required" + ) + + # Initialize Decart client + self.client = DecartClient(api_key=api_key) + + # Get the model from environment variable, default to mirage_v2 + model_name = os.getenv("DECART_MODEL", "mirage_v2") + logger.info(f"Using Decart model: {model_name}") + self.model = decart_models.realtime(model_name) + + # Store current prompt + self.current_prompt = None + # Track last prompt we saw in prompts parameter to detect stale values + self.last_prompts_value = None + + # WebRTC connection state + self.realtime_client = None + self.connection_established = False + self.async_loop = None + self.async_thread = None + self.shutdown_event = threading.Event() + + # Queues for frame exchange between sync and async worlds + self.input_frame_queue = queue.Queue(maxsize=10) + self.output_frame_queue = queue.Queue(maxsize=10) + + # MediaStreamTrack for feeding frames to Decart + self.local_track = None + + # Start async thread for WebRTC connection + if REALTIME_AVAILABLE: + self._start_async_thread() + # Wait for connection to establish (with timeout) + max_wait = 10.0 + wait_interval = 0.1 + waited = 0.0 + while ( + not self.connection_established + and waited < max_wait + and not self.shutdown_event.is_set() + ): + time.sleep(wait_interval) + waited += wait_interval + if self.connection_established: + logger.info("Decart connection established") + else: + logger.warning( + f"Decart connection not established after {waited}s, " + "will use passthrough mode" + ) + + logger.info("DecartApiPipeline initialized") + + def _start_async_thread(self): + """Start background thread for async WebRTC operations.""" + def run_async_loop(): + self.async_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.async_loop) + try: + self.async_loop.run_until_complete(self._async_worker()) + except Exception as e: + logger.error(f"Async worker error: {e}", exc_info=True) + + self.async_thread = threading.Thread( + target=run_async_loop, daemon=True + ) + self.async_thread.start() + + async def _async_worker(self): + """Async worker that manages WebRTC connection and frame processing.""" + try: + # Create local track for feeding frames to Decart + # Use config resolution to preserve aspect ratio (e.g., square input -> square output) + logger.info( + f"Creating FrameSourceTrack with config resolution: " + f"{self.width}x{self.height} @ {self.model.fps}fps" + ) + self.local_track = FrameSourceTrack( + self.input_frame_queue, + fps=self.model.fps, + width=self.width, + height=self.height, + ) + + # Set up callback for receiving processed frames + def on_remote_stream(remote_track: MediaStreamTrack): + """Callback when remote stream is available.""" + logger.info( + f"Remote stream received from Decart: {remote_track}" + ) + self.connection_established = True + # Start receiving frames from remote track + receive_task = asyncio.create_task( + self._receive_remote_frames(remote_track) + ) + # Store task to prevent garbage collection + if not hasattr(self, '_receive_tasks'): + self._receive_tasks = [] + self._receive_tasks.append(receive_task) + + # Create connection options + 
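+            # If a prompt is already set at connect time, it is passed as the
+            # initial ModelState so the session starts with that prompt.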
initial_state = None + if self.current_prompt: + from decart.types import Prompt + initial_state = ModelState( + prompt=Prompt(text=self.current_prompt, enrich=True) + ) + + options = RealtimeConnectOptions( + model=self.model, + on_remote_stream=on_remote_stream, + initial_state=initial_state, + ) + + # Connect to Decart realtime API + logger.info("Connecting to Decart realtime API...") + self.realtime_client = await RealtimeClient.connect( + base_url=self.client.base_url, + api_key=self.client.api_key, + local_track=self.local_track, + options=options, + ) + logger.info("Connected to Decart realtime API") + + # Keep connection alive + while not self.shutdown_event.is_set(): + await asyncio.sleep(1.0) + + except Exception as e: + logger.error(f"Error in async worker: {e}", exc_info=True) + self.connection_established = False + + async def _receive_remote_frames(self, remote_track: MediaStreamTrack): + """Receive processed frames from Decart and queue them.""" + logger.info("Starting to receive remote frames from Decart") + frame_count = 0 + try: + while not self.shutdown_event.is_set(): + try: + # Receive frame from Decart + frame = await asyncio.wait_for( + remote_track.recv(), timeout=1.0 + ) + # Convert VideoFrame to numpy array + frame_np = frame.to_ndarray(format="rgb24") + frame_count += 1 + if frame_count % 30 == 0: + logger.debug( + f"Received {frame_count} frames from Decart" + ) + # Put in output queue + try: + self.output_frame_queue.put_nowait(frame_np) + except queue.Full: + # Drop oldest frame if queue is full + try: + self.output_frame_queue.get_nowait() + except queue.Empty: + pass + self.output_frame_queue.put_nowait(frame_np) + except asyncio.TimeoutError: + logger.debug("Timeout waiting for remote frame") + continue + except Exception as e: + logger.error(f"Error receiving remote frame: {e}") + await asyncio.sleep(0.1) + except Exception as e: + logger.error(f"Error in receive_remote_frames: {e}", exc_info=True) + + def prepare(self, **kwargs) -> Requirements: + """Return input requirements for video mode.""" + # Process one frame at a time for realtime + return Requirements(input_size=1) + + def __call__( + self, + **kwargs, + ) -> torch.Tensor: + """ + Process video frames through Decart API. 
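+
+        Frames are forwarded to Decart over a WebRTC stream; when the
+        connection is not yet established (or the realtime SDK is unavailable),
+        frames are passed through unchanged.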
+ + Args: + **kwargs: Pipeline parameters including: + - video: Input video frames (list of tensors or tensor) + - prompts: Text prompt for style transformation + - transition: Optional transition dict with target_prompts + (takes precedence over prompts if provided) + + Returns: + Processed frames as tensor in THWC format [0, 1] range + """ + input_video = kwargs.get("video") + prompts = kwargs.get("prompts") + transition = kwargs.get("transition") + + if input_video is None: + raise ValueError( + "Input video cannot be None for DecartApiPipeline" + ) + + # Convert input to list of tensors if needed + if isinstance(input_video, list): + tensor_frames = input_video + else: + # Assume it's a BCTHW tensor, convert to list + # Rearrange to THWC and split into frames + if len(input_video.shape) == 5: # BCTHW + input_video = rearrange(input_video, "B C T H W -> B T C H W") + input_video = input_video.squeeze(0) # Remove batch dim -> T C H W + input_video = rearrange(input_video, "T C H W -> T H W C") + # Convert to list of (1, H, W, C) tensors + tensor_frames = [ + input_video[i].unsqueeze(0) + for i in range(input_video.shape[0]) + ] + + # Extract prompt text - simple: check transition first, then prompts + # But only update if it's actually different from current + prompt_text = None + + # Get prompt from transition.target_prompts if available + if transition is not None: + target_prompts = transition.get("target_prompts") + if target_prompts and len(target_prompts) > 0: + first_prompt = target_prompts[0] + extracted = first_prompt.get("text", "") + # Only use if different from current (prevents duplicate updates) + if extracted != self.current_prompt: + prompt_text = extracted + logger.info( + f"Using prompt from transition: '{prompt_text}' " + f"(current: '{self.current_prompt}')" + ) + + + # Update prompt if we have a new one + if prompt_text: + logger.info( + f"Prompt change detected: '{self.current_prompt}' -> " + f"'{prompt_text}'" + ) + self._update_prompt(prompt_text) + + # Process frames through Decart API + # Decart's realtime API uses WebRTC streams with async handling. 
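+        # Each frame is converted to uint8 RGB, resized to the configured
+        # resolution if needed, queued for the WebRTC sender track, and paired
+        # with a processed frame from the output queue (the input frame is
+        # reused if no processed frame arrives within the timeout).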
+ processed_frames = [] + for frame_tensor in tensor_frames: + # Convert tensor to numpy array (H, W, C) in [0, 255] range + # frame_tensor is (1, H, W, C) from the list + frame_np = frame_tensor.squeeze(0).cpu().numpy() # (H, W, C) + + # Ensure it's 3D (H, W, C) + if frame_np.ndim == 2: + # If grayscale, convert to RGB + frame_np = np.stack([frame_np] * 3, axis=-1) + elif frame_np.ndim != 3: + raise ValueError(f"Unexpected frame shape: {frame_np.shape}") + + # Normalize to [0, 255] if needed + if frame_np.max() <= 1.0: + frame_np = (frame_np * 255.0).astype(np.uint8) + else: + frame_np = frame_np.astype(np.uint8) + + # Resize if needed to match configured output resolution + # Use config resolution to preserve aspect ratio (e.g., square input -> square output) + needs_resize = ( + frame_np.shape[0] != self.height + or frame_np.shape[1] != self.width + ) + if needs_resize: + img = Image.fromarray(frame_np) + img = img.resize( + (self.width, self.height), + Image.Resampling.LANCZOS + ) + frame_np = np.array(img) + # Ensure it's still 3D after resize + if frame_np.ndim == 2: + frame_np = np.stack([frame_np] * 3, axis=-1) + + # Send frame to Decart API via WebRTC + processed_frame_np = self._process_frame_with_decart( + frame_np, prompt_text + ) + + # Convert processed frame back to tensor + # processed_frame_np is (H, W, C) + processed_frame = ( + torch.from_numpy(processed_frame_np).float() / 255.0 + ) + # processed_frame is now (H, W, C) - ensure it's 3D + if processed_frame.ndim != 3: + raise ValueError( + f"Processed frame must be 3D (H, W, C), " + f"got {processed_frame.shape}" + ) + processed_frames.append(processed_frame) + + # Stack frames and return in THWC format + # Always return (T, H, W, C) format where T is the number of frames + if len(processed_frames) == 1: + # For single frame, ensure we have time dimension + output = processed_frames[0].unsqueeze(0) # (1, H, W, C) + else: + output = torch.stack(processed_frames, dim=0) # (T, H, W, C) + + return output + + def _process_frame_with_decart( + self, frame_np: np.ndarray, prompt_text: str | None + ) -> np.ndarray: + """Process a frame through Decart API or pass through.""" + if REALTIME_AVAILABLE and self.connection_established: + # Note: Prompt updates are now handled in __call__ before + # processing frames, so we don't need to check here + + # Send frame to input queue for WebRTC stream + try: + self.input_frame_queue.put_nowait(frame_np) + except queue.Full: + # Drop oldest frame if queue is full + try: + self.input_frame_queue.get_nowait() + except queue.Empty: + pass + self.input_frame_queue.put_nowait(frame_np) + + # Get processed frame from output queue + try: + processed_frame_np = self.output_frame_queue.get(timeout=2.0) + # Resize output frame to match configured resolution + # Decart might return frames at model resolution, so ensure output matches config + if (processed_frame_np.shape[0] != self.height or + processed_frame_np.shape[1] != self.width): + img = Image.fromarray(processed_frame_np) + img = img.resize( + (self.width, self.height), + Image.Resampling.LANCZOS + ) + processed_frame_np = np.array(img) + except queue.Empty: + # Timeout - use input frame as fallback + logger.warning( + "Timeout waiting for processed frame from Decart, " + "using input frame" + ) + processed_frame_np = frame_np + else: + # Fallback: pass through if realtime not available + if not REALTIME_AVAILABLE: + logger.debug( + "Realtime API not available, passing through" + ) + elif not self.connection_established: + logger.debug( + 
"Connection not established yet, passing through" + ) + processed_frame_np = frame_np + + return processed_frame_np + + def _update_prompt(self, prompt_text: str): + """Update prompt in Decart API.""" + if not self.realtime_client: + logger.warning( + "Cannot update prompt: realtime_client is not initialized" + ) + return + if not self.async_loop: + logger.warning( + "Cannot update prompt: async_loop is not available" + ) + return + + try: + logger.info(f"Calling set_prompt with: '{prompt_text}'") + # set_prompt expects a string, not a Prompt object + # Based on the error, it calls .strip() on the prompt + future = asyncio.run_coroutine_threadsafe( + self.realtime_client.set_prompt(prompt_text), + self.async_loop, + ) + # Increase timeout to 5 seconds as API calls may take longer + future.result(timeout=5.0) + self.current_prompt = prompt_text + logger.info(f"Successfully updated prompt to: '{prompt_text}'") + except asyncio.TimeoutError: + logger.error( + f"Timeout updating prompt to '{prompt_text}' " + "(API call took > 5 seconds)" + ) + except Exception as e: + logger.error( + f"Failed to update prompt to '{prompt_text}': {e}", + exc_info=True + ) + + def __del__(self): + """Cleanup on pipeline destruction.""" + # Signal shutdown + if hasattr(self, 'shutdown_event'): + self.shutdown_event.set() + + # Cleanup async connection + if hasattr(self, 'async_loop') and self.async_loop: + if hasattr(self, 'realtime_client') and self.realtime_client: + try: + # Schedule disconnect in async loop + asyncio.run_coroutine_threadsafe( + self._disconnect_async(), self.async_loop + ) + except Exception as e: + logger.warning(f"Error disconnecting: {e}") + + # Cleanup Decart client + if hasattr(self, 'client') and self.client: + try: + # Client cleanup will happen when it goes out of scope + pass + except Exception as e: + logger.warning(f"Error during cleanup: {e}") + + async def _disconnect_async(self): + """Async cleanup method.""" + if hasattr(self, 'realtime_client') and self.realtime_client: + try: + # Disconnect realtime client + # Note: RealtimeClient might have a disconnect method + if hasattr(self.realtime_client, 'disconnect'): + await self.realtime_client.disconnect() + except Exception as e: + logger.warning(f"Error disconnecting realtime client: {e}") diff --git a/src/scope/core/pipelines/decart_api/test.py b/src/scope/core/pipelines/decart_api/test.py new file mode 100644 index 000000000..834710ba3 --- /dev/null +++ b/src/scope/core/pipelines/decart_api/test.py @@ -0,0 +1,189 @@ +import os +import time +from pathlib import Path + +import numpy as np +import torch +from diffusers.utils import export_to_video +from einops import rearrange +from omegaconf import OmegaConf + +try: + from ..video import load_video + VIDEO_LOADING_AVAILABLE = True +except Exception as e: + # Try alternative video loading with imageio + try: + import imageio + VIDEO_LOADING_AVAILABLE = True + USE_IMAGEIO = True + except ImportError: + print(f"Warning: Video loading not available: {e}") + print("Will skip video loading test") + VIDEO_LOADING_AVAILABLE = False + USE_IMAGEIO = False +else: + USE_IMAGEIO = False + +from .pipeline import DecartApiPipeline + +# Check for API key +api_key = os.getenv("DECART_API_KEY") +if not api_key: + raise ValueError( + "DECART_API_KEY environment variable is required. " + "Set it before running this test." 
+ ) + +# Create config - using same resolution as streamdiffusionv2 test +config = OmegaConf.create( + { + "height": 480, + "width": 832, + "seed": 42, + } +) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# Initialize pipeline +print("Initializing DecartApiPipeline...") +pipeline = DecartApiPipeline( + config, + device=device, + dtype=torch.bfloat16, +) +print("Pipeline initialized successfully!") + +if not VIDEO_LOADING_AVAILABLE: + print("\n=== Skipping video loading test (video loading not available) ===") + print("Please install FFmpeg and torchcodec to test with video input.") + exit(0) + +# Load the same input video as streamdiffusionv2 test +print("\n=== Loading input video ===") +video_path = Path(__file__).parent.parent / "streamdiffusionv2" / "assets" / "original.mp4" +if not video_path.exists(): + raise FileNotFoundError( + f"Input video not found at {video_path}. " + "Please ensure the streamdiffusionv2/assets/original.mp4 file exists." + ) + +if USE_IMAGEIO: + # Use imageio as fallback + import imageio + from PIL import Image + + print("Using imageio to load video...") + reader = imageio.get_reader(str(video_path)) + frames = [] + for frame in reader: + # Resize frame + img = Image.fromarray(frame) + img = img.resize((config.width, config.height), Image.Resampling.LANCZOS) + frame_resized = torch.from_numpy(np.array(img)).float() + frames.append(frame_resized) + reader.close() + + # Stack frames: T H W C -> C T H W + input_video_cthw = torch.stack(frames, dim=0) # T H W C + input_video_cthw = rearrange(input_video_cthw, "T H W C -> C T H W") + + # Convert to BCTHW and ensure [0, 255] range + input_video_bcthw = rearrange(input_video_cthw, "C T H W -> 1 C T H W") + input_video_bcthw = input_video_bcthw.clamp(0, 255) +else: + # Load video - load_video returns CTHW format, normalized to [-1, 1] + # We need to convert to [0, 255] range and BCTHW format for the pipeline + input_video_cthw = load_video( + str(video_path), + resize_hw=(config.height, config.width), + normalize=True, # Returns [-1, 1] range + ) + + # Convert from CTHW to BCTHW and denormalize to [0, 255] + input_video_bcthw = rearrange(input_video_cthw, "C T H W -> 1 C T H W") + # Denormalize from [-1, 1] to [0, 255] + input_video_bcthw = ((input_video_bcthw + 1.0) / 2.0 * 255.0).clamp(0, 255) + +_, _, num_input_frames, _, _ = input_video_bcthw.shape +print(f"Input video loaded: {num_input_frames} frames, shape={input_video_bcthw.shape}") + +# Process video in chunks (decart_api processes one frame at a time) +chunk_size = 1 # Decart API processes one frame at a time +prompts = [{"text": "a bear is walking on the grass", "weight": 100}] + +outputs = [] +latency_measures = [] +fps_measures = [] +total_input_frames = 0 +total_output_frames = 0 + +print(f"\n=== Processing {num_input_frames} frames in chunks of {chunk_size} ===") +for start_idx in range(0, num_input_frames, chunk_size): + end_idx = min(start_idx + chunk_size, num_input_frames) + chunk = input_video_bcthw[:, :, start_idx:end_idx] + + start = time.time() + # Process chunk through pipeline + output = pipeline(video=chunk, prompts=prompts) + latency = time.time() - start + + num_output_frames, _, _, _ = output.shape + fps = num_output_frames / latency if latency > 0 else 0 + + input_frames_in_chunk = end_idx - start_idx + total_input_frames += input_frames_in_chunk + total_output_frames += num_output_frames + + print( + f"Chunk [{start_idx}:{end_idx}]: " + f"Input {input_frames_in_chunk} frames 
-> " + f"Output {num_output_frames} frames, " + f"latency={latency:.3f}s, " + f"fps={fps:.2f}" + ) + + latency_measures.append(latency) + fps_measures.append(fps) + outputs.append(output.detach().cpu()) + +# Concatenate all outputs +output_video = torch.concat(outputs) +print(f"\n=== Frame Count Comparison ===") +print(f"Total input frames: {total_input_frames}") +print(f"Total output frames: {total_output_frames}") +print(f"Frame ratio (output/input): {total_output_frames / total_input_frames:.3f}") + +if abs(total_output_frames - total_input_frames) <= 1: + print("✓ Output frame count matches input frame count!") +else: + print( + f"⚠ Frame count difference: {abs(total_output_frames - total_input_frames)} frames" + ) + +print(f"\nOutput video shape: {output_video.shape}") + +# Export to video +output_path = Path(__file__).parent / "output.mp4" +output_video_np = output_video.contiguous().numpy() +export_to_video(output_video_np, output_path, fps=16) +print(f"Output video saved to: {output_path}") + +# Print statistics +print("\n=== Performance Statistics ===") +if latency_measures: + print( + f"Latency - Avg: {sum(latency_measures) / len(latency_measures):.3f}s, " + f"Max: {max(latency_measures):.3f}s, " + f"Min: {min(latency_measures):.3f}s" + ) +if fps_measures: + print( + f"FPS - Avg: {sum(fps_measures) / len(fps_measures):.2f}, " + f"Max: {max(fps_measures):.2f}, " + f"Min: {min(fps_measures):.2f}" + ) + +print("\n=== Test completed successfully! ===") diff --git a/src/scope/core/pipelines/registry.py b/src/scope/core/pipelines/registry.py index 493804be4..bffd6270f 100644 --- a/src/scope/core/pipelines/registry.py +++ b/src/scope/core/pipelines/registry.py @@ -68,6 +68,7 @@ def list_pipelines(cls) -> list[str]: def _register_pipelines(): """Register all built-in pipelines.""" # Import lazily to avoid circular imports and heavy dependencies + from .decart_api.pipeline import DecartApiPipeline from .krea_realtime_video.pipeline import KreaRealtimeVideoPipeline from .longlive.pipeline import LongLivePipeline from .passthrough.pipeline import PassthroughPipeline @@ -81,6 +82,7 @@ def _register_pipelines(): StreamDiffusionV2Pipeline, PassthroughPipeline, RewardForcingPipeline, + DecartApiPipeline, ]: config_class = pipeline_class.get_config_class() PipelineRegistry.register(config_class.pipeline_id, pipeline_class) diff --git a/src/scope/core/pipelines/schema.py b/src/scope/core/pipelines/schema.py index db11474e9..4f1736ad8 100644 --- a/src/scope/core/pipelines/schema.py +++ b/src/scope/core/pipelines/schema.py @@ -418,6 +418,32 @@ class PassthroughConfig(BasePipelineConfig): ) +class DecartApiConfig(BasePipelineConfig): + """Configuration for Decart API pipeline. + + Decart API only supports video mode - it processes video frames through Decart's realtime API. 
+ """ + + pipeline_id: ClassVar[str] = "decart-api" + pipeline_name: ClassVar[str] = "Decart API" + pipeline_description: ClassVar[str] = ( + "Real-time video restyling using Decart's Mirage LSD API" + ) + + # Mode support - video only + supported_modes: ClassVar[list[InputMode]] = ["video"] + default_mode: ClassVar[InputMode] = "video" + + # Decart API defaults - model requires specific resolution + # Mirage LSD model typically uses 512x512 or similar + height: int = Field(default=512, ge=1, description="Output height in pixels") + width: int = Field(default=512, ge=1, description="Output width in pixels") + input_size: int | None = Field( + default=1, + description="Expected input video frame count (realtime processes one frame at a time)", + ) + + # Registry of pipeline config classes PIPELINE_CONFIGS: dict[str, type[BasePipelineConfig]] = { "longlive": LongLiveConfig, @@ -425,6 +451,7 @@ class PassthroughConfig(BasePipelineConfig): "krea-realtime-video": KreaRealtimeVideoConfig, "reward-forcing": RewardForcingConfig, "passthrough": PassthroughConfig, + "decart-api": DecartApiConfig, } diff --git a/src/scope/server/pipeline_manager.py b/src/scope/server/pipeline_manager.py index 884bd43b5..1ebd8c0c7 100644 --- a/src/scope/server/pipeline_manager.py +++ b/src/scope/server/pipeline_manager.py @@ -491,6 +491,35 @@ def _load_pipeline_implementation( logger.info("RewardForcing pipeline initialized") return pipeline + elif pipeline_id == "decart-api": + from scope.core.pipelines import DecartApiPipeline + from scope.core.pipelines.schema import DecartApiConfig + + # Create config with defaults from DecartApiConfig + config_dict = { + "height": DecartApiConfig.model_fields["height"].default, + "width": DecartApiConfig.model_fields["width"].default, + "seed": DecartApiConfig.model_fields["base_seed"].default, + } + config = OmegaConf.create(config_dict) + + # Apply load parameters (resolution, seed) to config + self._apply_load_params( + config, + load_params, + default_height=512, + default_width=512, + default_seed=42, + ) + + pipeline = DecartApiPipeline( + config, + device=get_device(), + dtype=torch.bfloat16, + ) + logger.info("DecartApi pipeline initialized") + return pipeline + else: raise ValueError(f"Invalid pipeline ID: {pipeline_id}") diff --git a/src/scope/server/schema.py b/src/scope/server/schema.py index 96946cbf2..78a0a6ce9 100644 --- a/src/scope/server/schema.py +++ b/src/scope/server/schema.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from scope.core.pipelines.schema import ( + DecartApiConfig, KreaRealtimeVideoConfig, LongLiveConfig, StreamDiffusionV2Config, @@ -341,6 +342,31 @@ class KreaRealtimeVideoLoadParams(LoRAEnabledLoadParams): ) +class DecartApiLoadParams(PipelineLoadParams): + """Load parameters for Decart API pipeline. + + Defaults are derived from DecartApiConfig to ensure consistency. 
+ """ + + height: int = Field( + default=DecartApiConfig.model_fields["height"].default, + description="Target video height", + ge=64, + le=2048, + ) + width: int = Field( + default=DecartApiConfig.model_fields["width"].default, + description="Target video width", + ge=64, + le=2048, + ) + seed: int = Field( + default=DecartApiConfig.model_fields["base_seed"].default, + description="Random seed for generation (not used by Decart API but kept for consistency)", + ge=0, + ) + + class PipelineLoadRequest(BaseModel): """Pipeline load request schema.""" @@ -352,6 +378,7 @@ class PipelineLoadRequest(BaseModel): | PassthroughLoadParams | LongLiveLoadParams | KreaRealtimeVideoLoadParams + | DecartApiLoadParams | None ) = Field(default=None, description="Pipeline-specific load parameters") diff --git a/uv.lock b/uv.lock index 66f1f3709..a1873e87e 100644 --- a/uv.lock +++ b/uv.lock @@ -636,6 +636,7 @@ dependencies = [ { name = "lmdb" }, { name = "omegaconf" }, { name = "peft" }, + { name = "pillow" }, { name = "safetensors" }, { name = "sageattention", version = "2.2.0", source = { url = "https://github.com/daydreamlive/SageAttention/releases/download/v2.2.0-linux/sageattention-2.2.0-cp310-cp310-linux_x86_64.whl" }, marker = "sys_platform == 'linux'" }, { name = "sageattention", version = "2.2.0+cu128torch2.8.0.post3", source = { url = "https://github.com/woct0rdho/SageAttention/releases/download/v2.2.0-windows.post3/sageattention-2.2.0+cu128torch2.8.0.post3-cp39-abi3-win_amd64.whl" }, marker = "sys_platform == 'win32'" }, @@ -680,6 +681,7 @@ requires-dist = [ { name = "lmdb", specifier = ">=1.7.3" }, { name = "omegaconf", specifier = ">=2.3.0" }, { name = "peft", specifier = ">=0.17.1" }, + { name = "pillow", specifier = ">=10.0.0" }, { name = "safetensors", specifier = ">=0.6.2" }, { name = "sageattention", marker = "sys_platform == 'linux'", url = "https://github.com/daydreamlive/SageAttention/releases/download/v2.2.0-linux/sageattention-2.2.0-cp310-cp310-linux_x86_64.whl" }, { name = "sageattention", marker = "sys_platform == 'win32'", url = "https://github.com/woct0rdho/SageAttention/releases/download/v2.2.0-windows.post3/sageattention-2.2.0+cu128torch2.8.0.post3-cp39-abi3-win_amd64.whl" }, From 3241c9b99561e1bd9fd9695268aeeee5c137acf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Leszko?= Date: Fri, 12 Dec 2025 12:28:10 +0100 Subject: [PATCH 2/2] Update dependencies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rafał Leszko --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 79d885c6f..68c8c6550 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "omegaconf>=2.3.0", "accelerate>=1.1.1", "flash-attn==2.8.3; sys_platform == 'linux' or sys_platform == 'win32'", - "sageattention==2.2.0; sys_platform == 'linux' or sys_platform == 'win32'", + "sageattention==2.2.0; (sys_platform == 'linux' and python_version == '3.10') or sys_platform == 'win32'", "safetensors>=0.6.2", "huggingface_hub>=0.25.0", "peft>=0.17.1", @@ -54,6 +54,8 @@ dependencies = [ "triton==3.4.0; sys_platform == 'linux'", "triton-windows==3.4.0.post21; sys_platform == 'win32'", "pillow>=10.0.0", + "decart>=0.0.8", + "tenacity>=9.0.0", ] [project.scripts]