diff --git a/src/streamdiffusion/acceleration/tensorrt/models/controlnet_models.py b/src/streamdiffusion/acceleration/tensorrt/models/controlnet_models.py
index 6d887fd4..17640f3a 100644
--- a/src/streamdiffusion/acceleration/tensorrt/models/controlnet_models.py
+++ b/src/streamdiffusion/acceleration/tensorrt/models/controlnet_models.py
@@ -240,6 +240,53 @@ def get_output_names(self):
                 "down_block_04", "down_block_05", "down_block_06", "down_block_07", "down_block_08", "mid_block"]
 
+    def get_dynamic_axes(self) -> Dict[str, Dict[int, str]]:
+        """Get dynamic axes configuration for variable input shapes"""
+        return {
+            "sample": {0: "B", 2: "H", 3: "W"},
+            "encoder_hidden_states": {0: "B"},
+            "timestep": {0: "B"},
+            "controlnet_cond": {0: "B", 2: "H_ctrl", 3: "W_ctrl"},
+            "text_embeds": {0: "B"},
+            "time_ids": {0: "B"},
+            **{f"down_block_{i:02d}": {0: "B", 2: "H", 3: "W"} for i in range(9)},
+            "mid_block": {0: "B", 2: "H", 3: "W"}
+        }
+
+    def get_input_profile(self, batch_size, image_height, image_width,
+                          static_batch, static_shape):
+        """Override to provide SDXL-specific input profiles including text_embeds and time_ids"""
+        # Get base profiles from parent class
+        profile = super().get_input_profile(batch_size, image_height, image_width,
+                                            static_batch, static_shape)
+
+        # Add SDXL-specific input profiles with dynamic batch dimension
+        min_batch = batch_size if static_batch else self.min_batch
+        max_batch = batch_size if static_batch else self.max_batch
+
+        # conditioning_scale is a scalar (empty shape)
+        profile["conditioning_scale"] = [
+            (),  # min
+            (),  # opt
+            (),  # max
+        ]
+
+        # text_embeds has shape (batch, 1280)
+        profile["text_embeds"] = [
+            (min_batch, 1280),   # min
+            (batch_size, 1280),  # opt
+            (max_batch, 1280),   # max
+        ]
+
+        # time_ids has shape (batch, 6)
+        profile["time_ids"] = [
+            (min_batch, 6),   # min
+            (batch_size, 6),  # opt
+            (max_batch, 6),   # max
+        ]
+
+        return profile
+
 
 def create_controlnet_model(model_type: str = "sd15",
                             unet=None,
                             model_path: str = "",
diff --git a/src/streamdiffusion/modules/controlnet_module.py b/src/streamdiffusion/modules/controlnet_module.py
index 90d44a8f..9b904881 100644
--- a/src/streamdiffusion/modules/controlnet_module.py
+++ b/src/streamdiffusion/modules/controlnet_module.py
@@ -98,7 +98,7 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
         preproc = None
         if cfg.preprocessor:
             from streamdiffusion.preprocessing.processors import get_preprocessor
-            preproc = get_preprocessor(cfg.preprocessor, pipeline_ref=self._stream, normalization_context='controlnet')
+            preproc = get_preprocessor(cfg.preprocessor, pipeline_ref=self._stream, normalization_context='controlnet', params=cfg.preprocessor_params)
         # Apply provided parameters to the preprocessor instance
         if cfg.preprocessor_params:
             params = cfg.preprocessor_params or {}
diff --git a/src/streamdiffusion/preprocessing/processors/__init__.py b/src/streamdiffusion/preprocessing/processors/__init__.py
index 3e0e36e9..19b66106 100644
--- a/src/streamdiffusion/preprocessing/processors/__init__.py
+++ b/src/streamdiffusion/preprocessing/processors/__init__.py
@@ -111,7 +111,7 @@ def get_preprocessor_class(name: str) -> type:
     return _preprocessor_registry[name]
 
 
-def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context: str = 'controlnet') -> BasePreprocessor:
+def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context: str = 'controlnet', params: Any = None) -> BasePreprocessor:
     """
     Get a preprocessor by name
 
@@ -135,9 +135,9 @@ def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context:
     if hasattr(processor_class, 'requires_sync_processing') and processor_class.requires_sync_processing:
         if pipeline_ref is None:
             raise ValueError(f"Processor '{name}' requires a pipeline_ref")
-        return processor_class(pipeline_ref=pipeline_ref, normalization_context=normalization_context, _registry_name=name)
+        return processor_class(pipeline_ref=pipeline_ref, normalization_context=normalization_context, _registry_name=name, **(params or {}))
     else:
-        return processor_class(normalization_context=normalization_context, _registry_name=name)
+        return processor_class(normalization_context=normalization_context, _registry_name=name, **(params or {}))
 
 
 def register_preprocessor(name: str, preprocessor_class):
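A note on the shape triples above: every entry in the dict returned by get_input_profile() is a [min, opt, max] shape triple, which TensorRT consumes as an optimization profile (opt is the shape the engine is tuned for; min and max bound the dynamic range). A minimal sketch of how such a dict maps onto the builder API; the helper name and the surrounding build scaffolding are assumptions for illustration, not code from this patch:

    import tensorrt as trt

    def apply_input_profile(builder, config, input_profile: dict) -> None:
        # Each entry is [min_shape, opt_shape, max_shape]; scalar inputs
        # such as conditioning_scale use the empty shape ().
        profile = builder.create_optimization_profile()
        for name, (min_shape, opt_shape, max_shape) in input_profile.items():
            profile.set_shape(name, min_shape, opt_shape, max_shape)
        config.add_optimization_profile(profile)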
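The controlnet_module.py and __init__.py changes plumb cfg.preprocessor_params into the processor constructor as keyword arguments (previously the parameters were only applied to the instance after construction), so options that are validated in __init__, such as engine_path below, now work from config. A usage sketch; the registry name and parameter values are assumptions for illustration:

    from streamdiffusion.preprocessing.processors import get_preprocessor

    def make_temporalnet_preproc(stream):
        # 'stream' is an existing StreamDiffusion pipeline instance.
        return get_preprocessor(
            "temporal_net_tensorrt",             # assumed registry name
            pipeline_ref=stream,                 # required for sync processors
            normalization_context="controlnet",
            params={
                "engine_path": "./models/temporal_net/raft.engine",  # hypothetical path
                "height": 512,
                "width": 512,
            },
        )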
diff --git a/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py b/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
index e893bb4a..8932976e 100644
--- a/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
+++ b/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
@@ -66,15 +66,36 @@ def activate(self):
         self.context = self.engine.create_execution_context()
         self._cuda_stream = torch.cuda.current_stream().cuda_stream
 
-    def allocate_buffers(self, device="cuda"):
-        """Allocate input/output buffers"""
+    def allocate_buffers(self, device="cuda", input_shape=None):
+        """
+        Allocate input/output buffers
+
+        Args:
+            device: Device to allocate tensors on
+            input_shape: Shape for input tensors (B, C, H, W). Required for engines with dynamic shapes.
+        """
         for idx in range(self.engine.num_io_tensors):
             name = self.engine.get_tensor_name(idx)
             shape = self.context.get_tensor_shape(name)
             dtype = trt.nptype(self.engine.get_tensor_dtype(name))
             if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                # For dynamic shapes, use provided input_shape
+                if input_shape is not None and any(dim == -1 for dim in shape):
+                    shape = input_shape
                 self.context.set_input_shape(name, shape)
+                # Update shape after setting it
+                shape = self.context.get_tensor_shape(name)
+            else:
+                # For output tensors, get shape after input shapes are set
+                shape = self.context.get_tensor_shape(name)
+
+            # Verify shape has no dynamic dimensions
+            if any(dim == -1 for dim in shape):
+                raise RuntimeError(
+                    f"Tensor '{name}' still has dynamic dimensions {shape} after setting input shapes. "
+                    f"Please provide input_shape parameter to allocate_buffers()."
+                )
             tensor = torch.empty(
                 tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]
@@ -85,6 +106,37 @@ def infer(self, feed_dict, stream=None):
         """Run inference with optional stream parameter"""
         if stream is None:
             stream = self._cuda_stream
+
+        # Check if we need to update tensor shapes for dynamic dimensions
+        need_realloc = False
+        for name, buf in feed_dict.items():
+            if name in self.tensors:
+                if self.tensors[name].shape != buf.shape:
+                    need_realloc = True
+                    break
+
+        # Reallocate buffers if input shape changed
+        if need_realloc:
+            # Update input shapes
+            for name, buf in feed_dict.items():
+                # Check if this tensor is an input tensor
+                try:
+                    if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                        self.context.set_input_shape(name, buf.shape)
+                except Exception:
+                    # Tensor name might not be in the engine; skip it
+                    pass
+
+            # Reallocate all tensors with new shapes
+            for idx in range(self.engine.num_io_tensors):
+                name = self.engine.get_tensor_name(idx)
+                shape = self.context.get_tensor_shape(name)
+                dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+
+                tensor = torch.empty(
+                    tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]
+                ).to(device=self.tensors[name].device)
+                self.tensors[name] = tensor
 
         # Copy input data to tensors
         for name, buf in feed_dict.items():
@@ -104,47 +156,52 @@ class TemporalNetTensorRTPreprocessor(PipelineAwareProcessor):
     """
-    TensorRT-accelerated TemporalNet preprocessor for temporal consistency using optical flow.
+    TensorRT-accelerated TemporalNet preprocessor for temporal consistency using optical flow visualization.
 
-    This preprocessor uses TensorRT to accelerate RAFT optical flow computation, providing
-    significant speedup over the standard PyTorch implementation.
+    This preprocessor uses TensorRT to accelerate RAFT optical flow computation and creates a 6-channel
+    control tensor by concatenating the previous input frame (RGB) with a colorized optical flow
+    visualization (RGB) computed between the previous and current input frames.
+
+    Output: [prev_input_RGB, flow_RGB(prev_input → current_input)]
     """
 
     @classmethod
     def get_preprocessor_metadata(cls):
         return {
             "display_name": "TemporalNet TensorRT",
-            "description": "TensorRT-accelerated optical flow computation for temporal consistency in video generation.",
+            "description": "TensorRT-accelerated optical flow visualization for temporal consistency. Outputs [prev_input_RGB, flow_RGB].",
             "parameters": {
+                "engine_path": {
+                    "type": "str",
+                    "default": None,
+                    "description": "Path to pre-built TensorRT engine file. Use compile_raft_tensorrt.py to build one."
+                },
                 "flow_strength": {
                     "type": "float",
                     "default": 1.0,
                     "range": [0.0, 2.0],
                     "step": 0.1,
-                    "description": "Strength of optical flow warping (1.0 = normal, higher = more warping)"
+                    "description": "Strength multiplier for optical flow visualization (1.0 = normal, higher = more pronounced flow)"
+                },
+                "height": {
+                    "type": "int",
+                    "default": 512,
+                    "range": [256, 1024],
+                    "step": 64,
+                    "description": "Height for optical flow computation (must be within engine's height range)"
                 },
-                "detect_resolution": {
+                "width": {
                     "type": "int",
                     "default": 512,
                     "range": [256, 1024],
                     "step": 64,
-                    "description": "Resolution for optical flow computation (affects quality vs speed)"
+                    "description": "Width for optical flow computation (must be within engine's width range)"
                 },
                 "output_format": {
                     "type": "str",
                     "default": "concat",
                     "options": ["concat", "warped_only"],
-                    "description": "Output format: 'concat' for 6-channel (current+warped), 'warped_only' for 3-channel warped frame"
-                },
-                "enable_tensorrt": {
-                    "type": "bool",
-                    "default": True,
-                    "description": "Use TensorRT acceleration for optical flow computation"
-                },
-                "force_rebuild": {
-                    "type": "bool",
-                    "default": False,
-                    "description": "Force rebuild TensorRT engine even if it exists"
+                    "description": "Output format: 'concat' for 6-channel (prev_input+flow_RGB), 'warped_only' for 3-channel flow RGB only"
                 }
             },
             "use_cases": ["High-performance video generation", "Real-time temporal consistency", "GPU-optimized motion control"]
@@ -152,26 +209,26 @@ def get_preprocessor_metadata(cls):
 
     def __init__(self,
                  pipeline_ref: Any,
-                 image_resolution: int = 512,
+                 engine_path: str = None,
+                 height: int = 512,
+                 width: int = 512,
                  flow_strength: float = 1.0,
-                 detect_resolution: int = 512,
                  output_format: str = "concat",
-                 enable_tensorrt: bool = True,
-                 force_rebuild: bool = False,
                  **kwargs):
         """
         Initialize TensorRT TemporalNet preprocessor
 
         Args:
             pipeline_ref: Reference to the StreamDiffusion pipeline instance (required)
-            image_resolution: Output image resolution
-            flow_strength: Strength of optical flow warping
-            detect_resolution: Resolution for optical flow computation
-            output_format: "concat" for 6-channel TemporalNetV2, "warped_only" for 3-channel
-            enable_tensorrt: Use TensorRT acceleration
-            force_rebuild: Force rebuild TensorRT engine
+            engine_path: Path to pre-built TensorRT engine file (required).
+                         Build one using: python -m streamdiffusion.tools.compile_raft_tensorrt
+            height: Height for optical flow computation (must be within engine's height range)
+            width: Width for optical flow computation (must be within engine's width range)
+            flow_strength: Strength multiplier for optical flow visualization
+            output_format: "concat" for 6-channel [prev_input+flow_RGB], "warped_only" for 3-channel flow RGB only
             **kwargs: Additional parameters passed to BasePreprocessor
         """
+
         if not TORCHVISION_AVAILABLE:
             raise ImportError(
                 "torchvision is required for TemporalNet preprocessing. "
@@ -181,31 +238,42 @@ def __init__(self,
         if not TENSORRT_AVAILABLE:
             raise ImportError(
                 "TensorRT and polygraphy are required for TensorRT acceleration. "
-                "Install them with: pip install tensorrt polygraphy"
+                "Install them with: python -m streamdiffusion.tools.install-tensorrt"
+            )
+        if engine_path is None:
+            raise ValueError(
+                "engine_path is required for TemporalNetTensorRTPreprocessor. "
+                "Build a TensorRT engine using:\n"
+                "  python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution 512x512 --max_resolution 1024x1024 --output_dir ./models/temporal_net\n"
+                "Then pass the engine path to this preprocessor."
             )
 
         super().__init__(
             pipeline_ref=pipeline_ref,
-            image_resolution=image_resolution,
+            height=height,
+            width=width,
+            engine_path=engine_path,
             flow_strength=flow_strength,
-            detect_resolution=detect_resolution,
             output_format=output_format,
-            enable_tensorrt=enable_tensorrt,
-            force_rebuild=force_rebuild,
             **kwargs
         )
 
         self.flow_strength = max(0.0, min(2.0, flow_strength))
-        self.detect_resolution = detect_resolution
-        self.enable_tensorrt = enable_tensorrt and TENSORRT_AVAILABLE
-        self.force_rebuild = force_rebuild
+        self.height = height
+        self.width = width
         self._first_frame = True
 
-        # Model paths
-        self.models_dir = Path("models") / "temporal_net"
-        self.models_dir.mkdir(parents=True, exist_ok=True)
-        self.onnx_path = self.models_dir / "raft_small.onnx"
-        self.engine_path = self.models_dir / f"raft_small_{trt.__version__ if TENSORRT_AVAILABLE else 'notrt'}_{detect_resolution}.trt"
+        # Store previous input frame for flow computation
+        self.prev_input = None
+
+        # Engine path
+        self.engine_path = Path(engine_path)
+        if not self.engine_path.exists():
+            raise FileNotFoundError(
+                f"TensorRT engine not found at: {self.engine_path}\n"
+                f"Build one using:\n"
+                f"  python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution {height}x{width} --max_resolution {height}x{width} --output_dir {self.engine_path.parent}"
+            )
 
         # Model state
         self.trt_engine = None
@@ -214,151 +282,31 @@ def __init__(self,
         self._grid_cache = {}
         self._tensor_cache = {}
 
-        # Initialize TensorRT engine
-        self._ensure_model_ready()
-
-    def _ensure_model_ready(self):
-        """Ensure TensorRT engine is ready"""
-        if not self.enable_tensorrt:
-            raise RuntimeError("TemporalNetTensorRTPreprocessor requires TensorRT acceleration. Use the standard TemporalNetPreprocessor for PyTorch fallback.")
-        self._setup_tensorrt()
-
-    def _load_raft_for_export(self):
-        """Load RAFT model temporarily for ONNX export only"""
-        logger.info("_load_raft_for_export: Loading RAFT Small model for ONNX export")
-        raft_model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False)
-        raft_model = raft_model.to(device=self.device)
-        raft_model.eval()
-        return raft_model
-
-    def _setup_tensorrt(self):
-        """Setup TensorRT engine"""
-        # Export to ONNX first if needed
-        if not self.onnx_path.exists() or self.force_rebuild:
-            self._export_to_onnx()
-
-        # Build/load TensorRT engine
+        # Load TensorRT engine
         self._load_tensorrt_engine()
 
-    def _export_to_onnx(self):
-        """Export RAFT model to ONNX format"""
-        logger.info(f"_export_to_onnx: Exporting RAFT model to ONNX: {self.onnx_path}")
-
-        # Load PyTorch model temporarily for export
-        raft_model = self._load_raft_for_export()
-
-        # Create dummy inputs for export
-        dummy_frame1 = torch.randn(1, 3, self.detect_resolution, self.detect_resolution).to(self.device)
-        dummy_frame2 = torch.randn(1, 3, self.detect_resolution, self.detect_resolution).to(self.device)
-
-        # Apply RAFT preprocessing if available
-        weights = Raft_Small_Weights.DEFAULT
-        if hasattr(weights, 'transforms') and weights.transforms is not None:
-            transforms = weights.transforms()
-            dummy_frame1, dummy_frame2 = transforms(dummy_frame1, dummy_frame2)
-
-        dynamic_axes = {
-            "frame1": {0: "batch_size"},
-            "frame2": {0: "batch_size"},
-            "flow": {0: "batch_size"},
-        }
-
-        with torch.no_grad():
-            torch.onnx.export(
-                raft_model,
-                (dummy_frame1, dummy_frame2),
-                str(self.onnx_path),
-                verbose=False,
-                input_names=['frame1', 'frame2'],
-                output_names=['flow'],
-                opset_version=17,
-                export_params=True,
-                dynamic_axes=dynamic_axes,
-            )
-
-        # Clean up the temporary model
-        del raft_model
-        torch.cuda.empty_cache()
-
-        logger.info(f"_export_to_onnx: Successfully exported ONNX model to {self.onnx_path}")
-
     def _load_tensorrt_engine(self):
-        """Load or build TensorRT engine"""
-        if self.engine_path.exists() and not self.force_rebuild:
-            logger.info(f"_load_tensorrt_engine: Loading existing TensorRT engine: {self.engine_path}")
-            self._load_existing_engine()
-        else:
-            logger.info("_load_tensorrt_engine: Building new TensorRT engine")
-            self._build_tensorrt_engine()
-
-    def _load_existing_engine(self):
-        """Load existing TensorRT engine"""
+        """Load pre-built TensorRT engine"""
+        logger.info(f"_load_tensorrt_engine: Loading TensorRT engine: {self.engine_path}")
         try:
             self.trt_engine = TensorRTEngine(str(self.engine_path))
             self.trt_engine.load()
             self.trt_engine.activate()
-            self.trt_engine.allocate_buffers(device=self.device)
-            logger.info(f"_load_existing_engine: TensorRT engine loaded successfully from {self.engine_path}")
-        except Exception as e:
-            logger.error(f"_load_existing_engine: Failed to load TensorRT engine: {e}")
-            self.trt_engine = None
-            raise RuntimeError(f"Failed to load TensorRT engine: {e}")
-
-    def _build_tensorrt_engine(self):
-        """Build TensorRT engine from ONNX model"""
-        if not self.onnx_path.exists():
-            logger.error("TemporalNetTensorRTPreprocessor._build_tensorrt_engine: ONNX model not found")
-            return
-
-        logger.info("_build_tensorrt_engine: Building TensorRT engine... this may take several minutes")
-
-        try:
-            # Create builder and network
-            builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
-            network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-            parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING))
-
-            # Parse ONNX model
-            with open(self.onnx_path, 'rb') as model:
-                if not parser.parse(model.read()):
-                    logger.error("_build_tensorrt_engine: Failed to parse ONNX model")
-                    for error in range(parser.num_errors):
-                        logger.error(f"_build_tensorrt_engine: {parser.get_error(error)}")
-                    return
-            # Configure builder
-            config = builder.create_builder_config()
-            config.set_flag(trt.BuilderFlag.FP16)  # Enable FP16 for better performance
+            # For dynamic shapes, provide the input shape based on image dimensions
+            input_shape = (1, 3, self.height, self.width)
+            self.trt_engine.allocate_buffers(device=self.device, input_shape=input_shape)
 
-            # Set optimization profile for dynamic shapes
-            profile = builder.create_optimization_profile()
-            min_shape = (1, 3, self.detect_resolution, self.detect_resolution)
-            opt_shape = (1, 3, self.detect_resolution, self.detect_resolution)
-            max_shape = (1, 3, self.detect_resolution, self.detect_resolution)
-
-            profile.set_shape("frame1", min_shape, opt_shape, max_shape)
-            profile.set_shape("frame2", min_shape, opt_shape, max_shape)
-            config.add_optimization_profile(profile)
-
-            # Build engine
-            engine = builder.build_serialized_network(network, config)
-
-            if engine is None:
-                logger.error("_build_tensorrt_engine: Failed to build TensorRT engine")
-                return
-
-            # Save engine
-            with open(self.engine_path, 'wb') as f:
-                f.write(engine)
-
-            # Load the built engine
-            self._load_existing_engine()
-            logger.info(f"_build_tensorrt_engine: Successfully built and saved TensorRT engine: {self.engine_path}")
-
+            logger.info(f"_load_tensorrt_engine: TensorRT engine loaded successfully from {self.engine_path}")
+            logger.info(f"_load_tensorrt_engine: Using resolution: {self.height}x{self.width}")
         except Exception as e:
-            logger.error(f"_build_tensorrt_engine: Failed to build TensorRT engine: {e}")
+            logger.error(f"_load_tensorrt_engine: Failed to load TensorRT engine: {e}")
             self.trt_engine = None
-            raise RuntimeError(f"Failed to build TensorRT engine: {e}")
+            raise RuntimeError(
+                f"Failed to load TensorRT engine from {self.engine_path}: {e}\n"
+                f"Make sure the engine was built with a resolution range that includes {self.height}x{self.width}.\n"
+                f"For example: python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution 512x512 --max_resolution 1024x1024"
            )
@@ -379,49 +327,38 @@ def _process_core(self, image: Image.Image) -> Image.Image:
 
     def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
         """
-        Process using TensorRT-accelerated optical flow warping (GPU-optimized path)
+        Process using TensorRT-accelerated optical flow computation (GPU-optimized path)
 
         Args:
             tensor: Current input tensor
 
         Returns:
-            Warped previous frame tensor for temporal guidance
+            Concatenated tensor: [prev_input_RGB, flow_RGB] for temporal guidance
         """
-        # Check if we have a pipeline reference and previous output
-        if (self.pipeline_ref is not None and
-            hasattr(self.pipeline_ref, 'prev_image_result') and
-            self.pipeline_ref.prev_image_result is not None and
-            not self._first_frame):
-
-            prev_output = self.pipeline_ref.prev_image_result
-
-            # Convert from VAE output format [-1, 1] to [0, 1]
-            prev_output = (prev_output / 2.0 + 0.5).clamp(0, 1)
-
-            # Normalize input tensor
-            input_tensor = tensor
-            if input_tensor.max() > 1.0:
-                input_tensor = input_tensor / 255.0
-
-            # Ensure consistent format
-            if prev_output.dim() == 4 and prev_output.shape[0] == 1:
-                prev_output = prev_output[0]
-            if input_tensor.dim() == 4 and input_tensor.shape[0] == 1:
-                input_tensor = input_tensor[0]
-
+        # Normalize input tensor
+        input_tensor = tensor
+        if input_tensor.max() > 1.0:
+            input_tensor = input_tensor / 255.0
+
+        # Ensure consistent format
+        if input_tensor.dim() == 4 and input_tensor.shape[0] == 1:
+            input_tensor = input_tensor[0]
+
+        # Check if we have a previous input frame
+        if self.prev_input is not None and not self._first_frame:
             try:
-                # Compute optical flow and warp on GPU using TensorRT
-                warped_tensor = self._compute_and_warp_tensor(input_tensor, prev_output)
+                # Compute optical flow between prev_input -> current_input
+                flow_rgb_tensor = self._compute_flow_to_rgb_tensor(self.prev_input, input_tensor)
 
                 # Check output format
                 output_format = self.params.get('output_format', 'concat')
                 if output_format == "concat":
-                    # Concatenate current frame + warped frame for TemporalNet2 (6 channels)
-                    result_tensor = self._concatenate_frames_tensor(input_tensor, warped_tensor)
+                    # Concatenate prev_input + flow_RGB for TemporalNet2 (6 channels)
+                    result_tensor = self._concatenate_frames_tensor(self.prev_input, flow_rgb_tensor)
                 else:
-                    # Return only warped frame (3 channels)
-                    result_tensor = warped_tensor
+                    # Return only flow RGB (3 channels)
                    result_tensor = flow_rgb_tensor
 
                 # Ensure correct output format
                 if result_tensor.dim() == 3:
@@ -432,19 +369,19 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
                 logger.error(f"_process_tensor_core: TensorRT optical flow failed: {e}")
                 output_format = self.params.get('output_format', 'concat')
                 if output_format == "concat":
-                    # Create 6-channel fallback by concatenating current frame with itself
-                    result_tensor = self._concatenate_frames_tensor(input_tensor, input_tensor)
+                    # Create 6-channel fallback by concatenating prev_input with itself
+                    result_tensor = self._concatenate_frames_tensor(self.prev_input, self.prev_input)
                     if result_tensor.dim() == 3:
                         result_tensor = result_tensor.unsqueeze(0)
                     result = result_tensor.to(device=self.device, dtype=self.dtype)
                 else:
-                    # Create 6-channel fallback by concatenating current frame with itself
-                    result_tensor = self._concatenate_frames_tensor(input_tensor, input_tensor)
+                    # Fallback: return prev_input as 3-channel
+                    result_tensor = self.prev_input
                     if result_tensor.dim() == 3:
                         result_tensor = result_tensor.unsqueeze(0)
                     result = result_tensor.to(device=self.device, dtype=self.dtype)
         else:
-            # First frame or no previous output available
+            # First frame or no previous input available
             self._first_frame = False
             if tensor.dim() == 3:
                 tensor = tensor.unsqueeze(0)
@@ -452,85 +389,96 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
             # Handle 6-channel output for first frame
             output_format = self.params.get('output_format', 'concat')
             if output_format == "concat":
-                # For first frame, duplicate current frame to create 6-channel output
+                # For first frame, concatenate current frame with zeros (no flow)
                 if tensor.dim() == 4 and tensor.shape[0] == 1:
                     current_tensor = tensor[0]
                 else:
                     current_tensor = tensor
-                result_tensor = self._concatenate_frames_tensor(current_tensor, current_tensor)
+
+                # Create zero tensor for flow (same shape as current_tensor)
+                zero_flow = torch.zeros_like(current_tensor, device=self.device, dtype=current_tensor.dtype)
+
+                result_tensor = self._concatenate_frames_tensor(current_tensor, zero_flow)
                 if result_tensor.dim() == 3:
                     result_tensor = result_tensor.unsqueeze(0)
                 result = result_tensor.to(device=self.device, dtype=self.dtype)
             else:
-                # Create 6-channel fallback by concatenating current frame with itself
+                # Return zeros as 3-channel (no flow for first frame)
                 if tensor.dim() == 4 and tensor.shape[0] == 1:
                     current_tensor = tensor[0]
                 else:
                     current_tensor = tensor
-                result_tensor = self._concatenate_frames_tensor(current_tensor, current_tensor)
+                result_tensor = torch.zeros_like(current_tensor, device=self.device, dtype=current_tensor.dtype)
                 if result_tensor.dim() == 3:
                     result_tensor = result_tensor.unsqueeze(0)
                 result = result_tensor.to(device=self.device, dtype=self.dtype)
 
+        # Store current input as previous for next frame
+        self.prev_input = input_tensor.clone()
+
         return result
 
-    def _compute_and_warp_tensor(self, current_tensor: torch.Tensor, prev_tensor: torch.Tensor) -> torch.Tensor:
+    def _compute_flow_to_rgb_tensor(self, prev_input_tensor: torch.Tensor, current_input_tensor: torch.Tensor) -> torch.Tensor:
         """
-        Compute optical flow using TensorRT and warp previous tensor
+        Compute optical flow between prev_input -> current_input and convert to RGB visualization
 
         Args:
-            current_tensor: Current input frame tensor (CHW format, [0,1]) on GPU
-            prev_tensor: Previous pipeline output tensor (CHW format, [0,1]) on GPU
+            prev_input_tensor: Previous input frame tensor (CHW format, [0,1]) on GPU
+            current_input_tensor: Current input frame tensor (CHW format, [0,1]) on GPU
 
         Returns:
-            Warped previous frame tensor on GPU
+            Flow visualization as RGB tensor (CHW format, [0,1]) on GPU
        """
         target_width, target_height = self.get_target_dimensions()
 
         # Convert to float32 for TensorRT processing
-        current_tensor = current_tensor.to(device=self.device, dtype=torch.float32)
-        prev_tensor = prev_tensor.to(device=self.device, dtype=torch.float32)
+        prev_tensor = prev_input_tensor.to(device=self.device, dtype=torch.float32)
+        current_tensor = current_input_tensor.to(device=self.device, dtype=torch.float32)
 
         # Resize for flow computation if needed (keep on GPU)
-        if current_tensor.shape[-1] != self.detect_resolution or current_tensor.shape[-2] != self.detect_resolution:
-            current_resized = F.interpolate(
-                current_tensor.unsqueeze(0),
-                size=(self.detect_resolution, self.detect_resolution),
-                mode='bilinear',
-                align_corners=False
-            ).squeeze(0)
+        if current_tensor.shape[-1] != self.width or current_tensor.shape[-2] != self.height:
             prev_resized = F.interpolate(
                 prev_tensor.unsqueeze(0),
-                size=(self.detect_resolution, self.detect_resolution),
+                size=(self.height, self.width),
                 mode='bilinear',
                 align_corners=False
             ).squeeze(0)
+            current_resized = F.interpolate(
+                current_tensor.unsqueeze(0),
+                size=(self.height, self.width),
+                mode='bilinear',
+                align_corners=False
+            ).squeeze(0)
         else:
-            current_resized = current_tensor
             prev_resized = prev_tensor
+            current_resized = current_tensor
 
-        # Compute optical flow using TensorRT
-        flow = self._compute_optical_flow_tensorrt(current_resized, prev_resized)
+        # Compute optical flow using TensorRT: prev_input -> current_input
+        flow = self._compute_optical_flow_tensorrt(prev_resized, current_resized)
 
         # Apply flow strength scaling (GPU operation)
         flow_strength = self.params.get('flow_strength', 1.0)
         if flow_strength != 1.0:
             flow = flow * flow_strength
 
-        # Warp previous frame using flow (GPU operation)
-        warped_frame = self._warp_frame_tensor(prev_resized, flow)
+        # Convert flow to RGB visualization using torchvision's flow_to_image
+        # flow_to_image expects (2, H, W) and returns (3, H, W) in range [0, 255]
+        flow_rgb = flow_to_image(flow)  # Returns uint8 tensor [0, 255]
+
+        # Convert to float [0, 1] range
+        flow_rgb = flow_rgb.float() / 255.0
 
         # Resize back to target resolution if needed (keep on GPU)
-        if warped_frame.shape[-1] != target_width or warped_frame.shape[-2] != target_height:
-            warped_frame = F.interpolate(
-                warped_frame.unsqueeze(0),
+        if flow_rgb.shape[-1] != target_width or flow_rgb.shape[-2] != target_height:
+            flow_rgb = F.interpolate(
+                flow_rgb.unsqueeze(0),
                 size=(target_height, target_width),
                 mode='bilinear',
                 align_corners=False
             ).squeeze(0)
 
         # Convert to processor's dtype only at the very end
-        result = warped_frame.to(dtype=self.dtype)
+        result = flow_rgb.to(dtype=self.dtype)
 
         return result
@@ -671,6 +619,7 @@ def reset(self):
         Reset the preprocessor state (useful for new sequences)
         """
         self._first_frame = True
+        self.prev_input = None
         # Clear caches to free memory
         self._grid_cache.clear()
         self._tensor_cache.clear()
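The allocate_buffers()/infer() changes above follow the usual TensorRT dynamic-shape pattern: bind concrete shapes for every input on the execution context first, then query the remaining tensor shapes, which only become concrete once all inputs are set. A condensed sketch of that pattern, assuming a loaded engine/context pair; the helper name is illustrative and dtype handling is elided (the real class maps TensorRT dtypes to torch dtypes):

    import tensorrt as trt
    import torch

    def bind_dynamic_buffers(engine, context, input_shapes: dict, device="cuda"):
        # Pass 1: set concrete shapes for all input tensors.
        for idx in range(engine.num_io_tensors):
            name = engine.get_tensor_name(idx)
            if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                context.set_input_shape(name, input_shapes[name])
        # Pass 2: output shapes resolve only after every input is bound.
        tensors = {}
        for idx in range(engine.num_io_tensors):
            name = engine.get_tensor_name(idx)
            shape = tuple(context.get_tensor_shape(name))
            if any(dim < 0 for dim in shape):
                raise RuntimeError(f"Tensor '{name}' is still dynamic: {shape}")
            tensors[name] = torch.empty(shape, device=device)  # dtype elided
        return tensors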
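On the output layout: torchvision.utils.flow_to_image does the colorizing; it takes a float flow field of shape (2, H, W) (or batched (N, 2, H, W)) and returns a uint8 RGB image in [0, 255]. A minimal sketch of the 6-channel concatenation described in the docstring above (the function name is illustrative):

    import torch
    from torchvision.utils import flow_to_image

    def make_control_tensor(prev_rgb: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """prev_rgb: (3, H, W) in [0, 1]; flow: (2, H, W) float, in pixels."""
        flow_rgb = flow_to_image(flow).float() / 255.0  # (3, H, W) in [0, 1]
        return torch.cat([prev_rgb, flow_rgb], dim=0)   # (6, H, W)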
diff --git a/src/streamdiffusion/tools/compile_raft_tensorrt.py b/src/streamdiffusion/tools/compile_raft_tensorrt.py
new file mode 100644
index 00000000..8dec3a76
--- /dev/null
+++ b/src/streamdiffusion/tools/compile_raft_tensorrt.py
@@ -0,0 +1,301 @@
+import torch
+import logging
+from pathlib import Path
+from typing import Optional
+import fire
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+try:
+    import tensorrt as trt
+    TENSORRT_AVAILABLE = True
+except ImportError:
+    TENSORRT_AVAILABLE = False
+    logger.error("TensorRT not available. Please install it first.")
+
+try:
+    from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
+    TORCHVISION_AVAILABLE = True
+except ImportError:
+    TORCHVISION_AVAILABLE = False
+    logger.error("torchvision not available. Please install it first.")
+
+
+def export_raft_to_onnx(
+    onnx_path: Path,
+    min_height: int = 512,
+    min_width: int = 512,
+    max_height: int = 512,
+    max_width: int = 512,
+    device: str = "cuda"
+) -> bool:
+    """
+    Export RAFT model to ONNX format
+
+    Args:
+        onnx_path: Path to save the ONNX model
+        min_height: Minimum input height for the model
+        min_width: Minimum input width for the model
+        max_height: Maximum input height for the model
+        max_width: Maximum input width for the model
+        device: Device to use for export
+
+    Returns:
+        True if successful, False otherwise
+    """
+    if not TORCHVISION_AVAILABLE:
+        logger.error("torchvision is required but not installed")
+        return False
+
+    logger.info(f"Exporting RAFT model to ONNX: {onnx_path}")
+    logger.info(f"Resolution range: {min_height}x{min_width} - {max_height}x{max_width}")
+
+    try:
+        # Load RAFT model
+        logger.info("Loading RAFT Small model...")
+        raft_model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=True)
+        raft_model = raft_model.to(device=device)
+        raft_model.eval()
+
+        # Create dummy inputs using max resolution for export
+        dummy_frame1 = torch.randn(1, 3, max_height, max_width).to(device)
+        dummy_frame2 = torch.randn(1, 3, max_height, max_width).to(device)
+
+        # Apply RAFT preprocessing if available
+        weights = Raft_Small_Weights.DEFAULT
+        if hasattr(weights, 'transforms') and weights.transforms is not None:
+            transforms = weights.transforms()
+            dummy_frame1, dummy_frame2 = transforms(dummy_frame1, dummy_frame2)
+
+        # Make batch, height, and width dimensions dynamic
+        dynamic_axes = {
+            "frame1": {0: "batch_size", 2: "height", 3: "width"},
+            "frame2": {0: "batch_size", 2: "height", 3: "width"},
+            "flow": {0: "batch_size", 2: "height", 3: "width"},
+        }
+
+        logger.info("Exporting to ONNX...")
+        with torch.no_grad():
+            torch.onnx.export(
+                raft_model,
+                (dummy_frame1, dummy_frame2),
+                str(onnx_path),
+                verbose=False,
+                input_names=['frame1', 'frame2'],
+                output_names=['flow'],
+                opset_version=17,
+                export_params=True,
+                dynamic_axes=dynamic_axes,
+            )
+
+        del raft_model
+        torch.cuda.empty_cache()
+
+        logger.info(f"Successfully exported ONNX model to {onnx_path}")
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to export ONNX model: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def build_tensorrt_engine(
+    onnx_path: Path,
+    engine_path: Path,
+    min_height: int = 512,
+    min_width: int = 512,
+    max_height: int = 512,
+    max_width: int = 512,
+    fp16: bool = True,
+    workspace_size_gb: int = 4
+) -> bool:
+    """
+    Build TensorRT engine from ONNX model
+
+    Args:
+        onnx_path: Path to the ONNX model
+        engine_path: Path to save the TensorRT engine
+        min_height: Minimum input height for optimization
+        min_width: Minimum input width for optimization
+        max_height: Maximum input height for optimization
+        max_width: Maximum input width for optimization
+        fp16: Enable FP16 precision mode
+        workspace_size_gb: Maximum workspace size in GB
+
+    Returns:
+        True if successful, False otherwise
+    """
+    if not TENSORRT_AVAILABLE:
+        logger.error("TensorRT is required but not installed")
+        return False
+
+    if not onnx_path.exists():
+        logger.error(f"ONNX model not found: {onnx_path}")
+        return False
+
+    logger.info(f"Building TensorRT engine from ONNX model: {onnx_path}")
+    logger.info(f"Output path: {engine_path}")
+    logger.info(f"Resolution range: {min_height}x{min_width} - {max_height}x{max_width}")
+    logger.info(f"FP16 mode: {fp16}")
+    logger.info("This may take several minutes...")
+
+    try:
+        builder = trt.Builder(trt.Logger(trt.Logger.INFO))
+        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+        parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING))
+
+        logger.info("Parsing ONNX model...")
+        with open(onnx_path, 'rb') as model:
+            if not parser.parse(model.read()):
+                logger.error("Failed to parse ONNX model")
+                for error in range(parser.num_errors):
+                    logger.error(f"Parser error: {parser.get_error(error)}")
+                return False
+
+        logger.info("Configuring TensorRT builder...")
+        config = builder.create_builder_config()
+
+        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size_gb * (1 << 30))
+
+        if fp16:
+            config.set_flag(trt.BuilderFlag.FP16)
+            logger.info("FP16 mode enabled")
+
+        # Calculate optimal resolution (middle point)
+        opt_height = (min_height + max_height) // 2
+        opt_width = (min_width + max_width) // 2
+
+        profile = builder.create_optimization_profile()
+        min_shape = (1, 3, min_height, min_width)
+        opt_shape = (1, 3, opt_height, opt_width)
+        max_shape = (1, 3, max_height, max_width)
+
+        profile.set_shape("frame1", min_shape, opt_shape, max_shape)
+        profile.set_shape("frame2", min_shape, opt_shape, max_shape)
+        config.add_optimization_profile(profile)
+
+        logger.info("Building TensorRT engine... (this will take a while)")
+        engine = builder.build_serialized_network(network, config)
+
+        if engine is None:
+            logger.error("Failed to build TensorRT engine")
+            return False
+
+        logger.info(f"Saving engine to {engine_path}")
+        engine_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(engine_path, 'wb') as f:
+            f.write(engine)
+
+        logger.info(f"Successfully built and saved TensorRT engine: {engine_path}")
+        logger.info(f"Engine size: {engine_path.stat().st_size / (1024*1024):.2f} MB")
+
+        # Delete ONNX file after successful engine creation
+        try:
+            if onnx_path.exists():
+                onnx_path.unlink()
+                logger.info(f"Deleted ONNX file: {onnx_path}")
+        except Exception as e:
+            logger.warning(f"Failed to delete ONNX file: {e}")
+
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to build TensorRT engine: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def compile_raft(
+    min_resolution: str = "512x512",
+    max_resolution: str = "512x512",
+    output_dir: str = "./models/temporal_net",
+    device: str = "cuda",
+    fp16: bool = True,
+    workspace_size_gb: int = 4,
+    force_rebuild: bool = False
+):
+    """
+    Main function to compile RAFT model to TensorRT engine
+
+    Args:
+        min_resolution: Minimum input resolution as "HxW" (e.g., "512x512") (default: "512x512")
+        max_resolution: Maximum input resolution as "HxW" (e.g., "1024x1024") (default: "512x512")
+        output_dir: Directory to save the models (default: ./models/temporal_net)
+        device: Device to use for export (default: cuda)
+        fp16: Enable FP16 precision mode (default: True)
+        workspace_size_gb: Maximum workspace size in GB (default: 4)
+        force_rebuild: Force rebuild even if engine exists (default: False)
+    """
+    if not TENSORRT_AVAILABLE:
+        logger.error("TensorRT is not available. Please install it first using:")
+        logger.error("  python -m streamdiffusion.tools.install-tensorrt")
+        return
+
+    if not TORCHVISION_AVAILABLE:
+        logger.error("torchvision is not available. Please install it first using:")
+        logger.error("  pip install torchvision")
+        return
+
+    # Parse resolution strings
+    try:
+        min_height, min_width = map(int, min_resolution.split('x'))
+    except ValueError:
+        logger.error(f"Invalid min_resolution format: {min_resolution}. Expected format: HxW (e.g., 512x512)")
+        return
+
+    try:
+        max_height, max_width = map(int, max_resolution.split('x'))
+    except ValueError:
+        logger.error(f"Invalid max_resolution format: {max_resolution}. Expected format: HxW (e.g., 1024x1024)")
+        return
+
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    # Add resolution suffix to filenames
+    onnx_path = output_path / f"raft_small_min_{min_resolution}_max_{max_resolution}.onnx"
+    engine_path = output_path / f"raft_small_min_{min_resolution}_max_{max_resolution}.engine"
+
+    logger.info("="*80)
+    logger.info("RAFT TensorRT Compilation")
+    logger.info("="*80)
+    logger.info(f"Output directory: {output_path.absolute()}")
+    logger.info(f"Resolution range: {min_resolution} - {max_resolution}")
+    logger.info(f"ONNX path: {onnx_path}")
+    logger.info(f"Engine path: {engine_path}")
+    logger.info("="*80)
+
+    if engine_path.exists() and not force_rebuild:
+        logger.info(f"TensorRT engine already exists: {engine_path}")
+        logger.info("Use --force_rebuild to rebuild it")
+        return
+
+    if not onnx_path.exists() or force_rebuild:
+        logger.info("\n[Step 1/2] Exporting RAFT to ONNX...")
+        if not export_raft_to_onnx(onnx_path, min_height, min_width, max_height, max_width, device):
+            logger.error("Failed to export ONNX model")
+            return
+    else:
+        logger.info(f"\n[Step 1/2] ONNX model already exists: {onnx_path}")
+
+    logger.info("\n[Step 2/2] Building TensorRT engine...")
+    if not build_tensorrt_engine(onnx_path, engine_path, min_height, min_width, max_height, max_width, fp16, workspace_size_gb):
+        logger.error("Failed to build TensorRT engine")
+        return
+
+    logger.info("\n" + "="*80)
+    logger.info("✓ Compilation completed successfully!")
+    logger.info("="*80)
+    logger.info(f"Engine path: {engine_path.absolute()}")
+    logger.info("\nYou can now use this engine in TemporalNetTensorRTPreprocessor:")
+    logger.info(f'    engine_path="{engine_path.absolute()}"')
+    logger.info("="*80)
+
+
+if __name__ == "__main__":
+    fire.Fire(compile_raft)
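The compiler can also be driven from Python instead of fire's command line; the keyword arguments mirror compile_raft's signature above (the resolution range shown is just an example):

    from streamdiffusion.tools.compile_raft_tensorrt import compile_raft

    # Build a dynamic-shape RAFT engine covering 512x512 through 1024x1024.
    compile_raft(
        min_resolution="512x512",
        max_resolution="1024x1024",
        output_dir="./models/temporal_net",
        fp16=True,
        workspace_size_gb=4,
    )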