@@ -221,8 +221,16 @@ def __init__(
221221 self .use_output_allocator_outputs = False
222222 self .device = torch .cuda .current_device ()
223223 self .cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
224+ self .requires_unique_output = False
224225 if self .serialized_engine is not None and not self .settings .lazy_engine_init :
225226 self .setup_engine ()
227+ self .is_shape_inference_io = [
228+ self .engine .is_shape_inference_io (input_name )
229+ for input_name in self .input_names
230+ ]
231+
def set_requires_unique_output(self, requires_unique_output: bool) -> None:
    """Record whether every execution must hand back freshly created outputs.

    The value is stored on ``self.requires_unique_output``. The standard
    execution path reads this flag: when it is true, the output tensors are
    re-created on each run instead of reusing the cached
    ``self.output_tensors``.

    Args:
        requires_unique_output: New value for the flag.
    """
    self.requires_unique_output = requires_unique_output
226234
def get_streamable_device_memory_budget(self) -> Any:
    """Return the ``streamable_weights_size`` reported by ``self.engine``.

    The value is passed through untouched. ``self.engine`` must already be
    set when this is called.
    """
    engine = self.engine
    return engine.streamable_weights_size
@@ -269,10 +277,10 @@ def setup_engine(self) -> None:
269277 # otherwise, use the caller stream and disable stream synchronization
270278 self ._caller_stream = torch .cuda .current_stream ()
271279 if self ._caller_stream == torch .cuda .default_stream ():
272- self ._engine_stream = torch .cuda .Stream ()
280+ self ._engine_stream : torch . cuda . Stream = torch .cuda .Stream ()
273281 self .sync_stream = True
274282 else :
275- self ._engine_stream = self ._caller_stream
283+ self ._engine_stream : torch . cuda . Stream = self ._caller_stream
276284 self .sync_stream = False
277285
278286 self .initialized = True
@@ -396,7 +404,7 @@ def setup_input_tensors(
396404
397405 # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
398406 # as per TensorRT requirements
399- if self .engine . is_shape_inference_io ( input_name ) :
407+ if self .is_shape_inference_io [ i ] :
400408 # Shape tensor inputs are cast to int64 explicitly
401409 # Currently Torch CPU pointers are not working; numpy pointers are used instead
402410 # to refer to underlying memory
@@ -500,7 +508,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
500508 raise ValueError (
501509 "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
502510 )
503- if self .output_tensors is None :
511+ if self .output_tensors is None or self . requires_unique_output :
504512 self .output_tensors = self .create_output_tensors ()
505513 outputs = self .output_tensors
506514
0 commit comments