@@ -173,6 +173,8 @@ def __init__(
173173 self .cudagraph : Optional [torch .cuda .CUDAGraph ] = None
174174 self ._caller_stream : Optional [torch .cuda .Stream ] = None
175175 self ._engine_stream : Optional [torch .cuda .Stream ] = None
176+ self .output_tensors : Optional [List [torch .Tensor ]] = None
177+ self .sync_stream = True
176178
177179 # TODO: Make the below a Dictionary {shape: cudagraph}
178180 self .shape_key : Optional [str ] = None
@@ -263,12 +265,16 @@ def setup_engine(self) -> None:
263265 assert (
264266 self .target_platform == Platform .current_platform ()
265267 ), f"TensorRT engine was not built to target current platform (target: { self .target_platform } , current: { Platform .current_platform ()} )"
268+ # Stream handling: if the caller stream is the PyTorch default stream, create a new engine stream and keep stream synchronization enabled;
269+ # otherwise, run the engine on the caller stream and disable stream synchronization.
266270 self ._caller_stream = torch .cuda .current_stream ()
267- if (
268- self ._engine_stream == torch .cuda .default_stream ()
269- or self ._engine_stream is None
270- ):
271+ if self ._caller_stream == torch .cuda .default_stream ():
271272 self ._engine_stream = torch .cuda .Stream ()
273+ self .sync_stream = True
274+ else :
275+ self ._engine_stream = self ._caller_stream
276+ self .sync_stream = False
277+
272278 self .initialized = True
273279 runtime = trt .Runtime (TRT_LOGGER )
274280 self .engine = runtime .deserialize_cuda_engine (self .serialized_engine )
@@ -489,15 +495,14 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
489495 if can_use_pre_allocated_outputs :
490496 outputs = self .pre_allocated_outputs
491497 else :
492- # self.output_shapes = [
493- # tuple(self.context.get_tensor_shape(output_name))
494- # for output_name in self.output_names
495- # ]
498+
496499 if DYNAMIC_DIM in self .output_shapes :
497500 raise ValueError (
498501 "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
499502 )
500- outputs = self .create_output_tensors ()
503+ if self .output_tensors is None :
504+ self .output_tensors = self .create_output_tensors ()
505+ outputs = self .output_tensors
501506
502507 for o , output_name in enumerate (self .output_names ):
503508 if need_cudagraphs_record :
@@ -520,37 +525,38 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
520525 else nullcontext ()
521526 ):
522527
523- self ._engine_stream .wait_stream (self ._caller_stream )
528+ if self .sync_stream :
529+ self ._engine_stream .wait_stream (self ._caller_stream )
524530
525- # with torch.cuda.stream(self._engine_stream):
526- # if self.cudagraphs_enabled:
527- # if need_cudagraphs_record:
528- # self.cudagraph = torch.cuda.CUDAGraph()
531+ if self .cudagraphs_enabled :
532+ if need_cudagraphs_record :
533+ self .cudagraph = torch .cuda .CUDAGraph ()
529534
530- # if self.profiling_enabled:
531- # self.cudagraph.enable_debug_mode()
535+ if self .profiling_enabled :
536+ self .cudagraph .enable_debug_mode ()
532537
533- # with torch.cuda.graph(
534- # self.cudagraph, stream=self._engine_stream
535- # ):
536- # self.context.execute_async_v3(
537- # self._engine_stream.cuda_stream
538- # )
538+ with torch .cuda .graph (
539+ self .cudagraph , stream = self ._engine_stream
540+ ):
541+ self .context .execute_async_v3 (
542+ self ._engine_stream .cuda_stream
543+ )
539544
540- # if self.profiling_enabled:
541- # import tempfile
545+ if self .profiling_enabled :
546+ import tempfile
542547
543- # with tempfile.TemporaryDirectory() as tmpdir:
544- # self.cudagraph.debug_dump(
545- # f"{tempdir }/{self.name}_cudagraph.dot"
546- # )
548+ with tempfile .TemporaryDirectory () as tmpdir :
549+ self .cudagraph .debug_dump (
550+ f"{ tmpdir } /{ self .name } _cudagraph.dot"
551+ )
547552
548- # self.cudagraph.replay() # type: ignore
553+ self .cudagraph .replay () # type: ignore
549554
550- # else:
551- self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
555+ else :
556+ self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
552557
553- self ._caller_stream .wait_stream (self ._engine_stream )
558+ if self .sync_stream :
559+ self ._caller_stream .wait_stream (self ._engine_stream )
554560
555561 if self .use_pre_allocated_outputs :
556562 self .pre_allocated_outputs = self .create_output_tensors ()
@@ -753,13 +759,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
753759 # Representation of input shapes to a given model
754760 # Shapes are concatenated as so:
755761 # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
756- tensor_inputs = []
757- for t in inputs :
758- if not isinstance (t , torch .Tensor ):
759- return True
760- tensor_inputs .append (t )
762+ if not all (isinstance (t , torch .Tensor ) for t in inputs ):
763+ return True
764+
761765 new_shape_key = "" .join (
762- str (tuple (t .shape )).replace (" " , "" ) for t in tensor_inputs
766+ str (tuple (t .shape )).replace (" " , "" )
767+ for t in inputs
768+ if isinstance (t , torch .Tensor )
763769 )
764770
765771 # If the new shape key differs from the existing one,
0 commit comments