22
33import io
44import logging
5- from typing import Any , List , Optional , Sequence
5+ from typing import Any , List , NamedTuple , Optional , Sequence
66
77import torch
88from torch_tensorrt ._enums import dtype
99from torch_tensorrt ._features import ENABLED_FEATURES
1010from torch_tensorrt ._Input import Input
1111from torch_tensorrt .dynamo ._engine_cache import BaseEngineCache
1212from torch_tensorrt .dynamo ._settings import CompilationSettings
13- from torch_tensorrt .dynamo .conversion ._TRTInterpreter import (
14- TRTInterpreter ,
15- TRTInterpreterResult ,
16- )
13+ from torch_tensorrt .dynamo .conversion ._TRTInterpreter import TRTInterpreter
1714from torch_tensorrt .dynamo .runtime import PythonTorchTensorRTModule , TorchTensorRTModule
1815from torch_tensorrt .dynamo .utils import (
1916 get_cpu_memory_usage ,
2421logger = logging .getLogger (__name__ )
2522
2623
24+ class SerializedInterpreterResult (NamedTuple ):
25+ serialized_engine : bytes
26+ input_names : Sequence [str ]
27+ output_names : Sequence [str ]
28+ weight_name_map : Optional [dict [Any , Any ]]
29+ requires_output_allocator : bool
30+
31+
2732def infer_module_output_dtypes (
2833 module : torch .fx .GraphModule ,
2934 truncate_double : bool = False ,
@@ -34,7 +39,7 @@ def infer_module_output_dtypes(
3439 """
3540 outputs = [node for node in module .graph .nodes if node .op == "output" ]
3641 outputs = outputs [0 ].args
37- return get_output_dtypes (outputs , truncate_double )
42+ return get_output_dtypes (outputs , truncate_double ) # type: ignore
3843
3944
4045def interpret_module_to_result (
@@ -44,7 +49,7 @@ def interpret_module_to_result(
4449 arg_inputs : Optional [Sequence [Input ]] = None ,
4550 kwarg_inputs : Optional [dict [str , Any ]] = None ,
4651 engine_cache : Optional [BaseEngineCache ] = None ,
47- ) -> TRTInterpreterResult :
52+ ) -> SerializedInterpreterResult :
4853 """Interpret an FX module to a TRTInterpreterResult
4954 Args:
5055 module: FX GraphModule to interpret
@@ -84,16 +89,18 @@ def interpret_module_to_result(
8489 with io .BytesIO () as engine_bytes :
8590 engine_bytes .write (serialized_engine )
8691 serialized_engine = engine_bytes .getvalue ()
87-
88- interpreter_result = TRTInterpreterResult (
89- engine = serialized_engine ,
92+ logger .debug (
93+ f"CPU memory usage after serializing engine: { get_cpu_memory_usage ()} MB"
94+ )
95+ serialized_interpreter_result = SerializedInterpreterResult (
96+ serialized_engine = serialized_engine ,
9097 input_names = interpreter_result .input_names ,
9198 output_names = interpreter_result .output_names ,
9299 weight_name_map = interpreter_result .weight_name_map ,
93100 requires_output_allocator = interpreter_result .requires_output_allocator ,
94101 )
95102
96- return interpreter_result
103+ return serialized_interpreter_result
97104
98105
99106def convert_module (
@@ -132,7 +139,7 @@ def convert_module(
132139 )
133140
134141 return rt_cls (
135- serialized_engine = interpreter_result .engine ,
142+ serialized_engine = interpreter_result .serialized_engine ,
136143 input_binding_names = list (interpreter_result .input_names ),
137144 output_binding_names = list (interpreter_result .output_names ),
138145 name = name ,
0 commit comments