@@ -34,7 +34,7 @@ def infer_module_output_dtypes(
3434 """
3535 outputs = [node for node in module .graph .nodes if node .op == "output" ]
3636 outputs = outputs [0 ].args
37- return get_output_dtypes (outputs , truncate_double ) # type: ignore[no-any-return]
37+ return get_output_dtypes (outputs , truncate_double )
3838
3939
4040def interpret_module_to_result (
@@ -70,6 +70,29 @@ def interpret_module_to_result(
7070 )
7171
7272 interpreter_result = interpreter .run ()
73+ # Delete the frozen parameters from the module to release CPU memory
74+ del interpreter
75+ for attr in dir (module ):
76+ if attr .startswith ("_frozen_param" ):
77+ delattr (module , attr )
78+ release_memory ()
79+ logger .debug (
80+ f"CPU memory usage after clearing frozen parameters and building memory in conversion: { get_cpu_memory_usage ()} MB"
81+ )
82+
83+ serialized_engine = interpreter_result .engine .serialize ()
84+ with io .BytesIO () as engine_bytes :
85+ engine_bytes .write (serialized_engine )
86+ serialized_engine = engine_bytes .getvalue ()
87+
88+ interpreter_result = TRTInterpreterResult (
89+ engine = serialized_engine ,
90+ input_names = interpreter_result .input_names ,
91+ output_names = interpreter_result .output_names ,
92+ weight_name_map = interpreter_result .weight_name_map ,
93+ requires_output_allocator = interpreter_result .requires_output_allocator ,
94+ )
95+
7396 return interpreter_result
7497
7598
@@ -108,22 +131,8 @@ def convert_module(
108131 "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
109132 )
110133
111- # Delete the frozen parameters from the module to release CPU memory
112- for attr in dir (module ):
113- if attr .startswith ("_frozen_param" ):
114- delattr (module , attr )
115- release_memory ()
116- logger .debug (
117- f"CPU memory usage after clearing frozen parameters and building memory in conversion: { get_cpu_memory_usage ()} MB"
118- )
119-
120- serialized_engine = interpreter_result .engine .serialize ()
121- with io .BytesIO () as engine_bytes :
122- engine_bytes .write (serialized_engine )
123- serialized_engine = engine_bytes .getvalue ()
124- breakpoint ()
125134 return rt_cls (
126- serialized_engine = serialized_engine ,
135+ serialized_engine = interpreter_result . engine ,
127136 input_binding_names = list (interpreter_result .input_names ),
128137 output_binding_names = list (interpreter_result .output_names ),
129138 name = name ,
0 commit comments