@@ -200,7 +200,6 @@ def refit_from_dict(self, refit_weights, is_fp16):
200200 trt_datatype = trt .DataType .HALF
201201
202202 # trt.Weight and trt.TensorLocation
203- refit_weights [trt_weight_name ] = refit_weights [trt_weight_name ].cpu ()
204203 trt_wt_tensor = trt .Weights (
205204 trt_datatype ,
206205 refit_weights [trt_weight_name ].data_ptr (),
@@ -213,15 +212,16 @@ def refit_from_dict(self, refit_weights, is_fp16):
213212 )
214213
215214 # apply refit
216- # refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location)
217- refitter .set_named_weights (trt_weight_name , trt_wt_tensor )
215+ refitter .set_named_weights (trt_weight_name , trt_wt_tensor , trt_wt_location )
218216 refitted_weights .add (trt_weight_name )
219217
220218 assert set (refitted_weights ) == set (refit_weights .keys ())
221219 if not refitter .refit_cuda_engine ():
222220 print ("Error: failed to refit new weights." )
223221 exit (0 )
224222
223+ print (f"[I] Total refitted weights { len (refitted_weights )} ." )
224+
225225 def build (
226226 self ,
227227 onnx_path ,
@@ -240,14 +240,18 @@ def build(
240240 for _p , i_profile in zip (p , input_profile ):
241241 for name , dims in i_profile .items ():
242242 assert len (dims ) == 3
243- _p .add (name , min = dims [0 ], opt = dims [1 ], max = dims [2 ])
243+ _p .add (name , min = dims [0 ], opt = dims [1 ], max = dims [2 ])
244244
245245 config_kwargs = {}
246246 if not enable_all_tactics :
247247 config_kwargs ["tactic_sources" ] = []
248248
249249 network = network_from_onnx_path (
250- onnx_path , flags = [trt .OnnxParserFlag .NATIVE_INSTANCENORM ]
250+ onnx_path ,
251+ flags = [
252+ trt .OnnxParserFlag .NATIVE_INSTANCENORM ,
253+ trt .NetworkDefinitionCreationFlag .STRONGLY_TYPED ,
254+ ],
251255 )
252256 if update_output_names :
253257 print (f"Updating network outputs to { update_output_names } " )
@@ -257,7 +261,6 @@ def build(
257261 config = builder .create_builder_config ()
258262 config .progress_monitor = TQDMProgressMonitor ()
259263
260- config .set_flag (trt .BuilderFlag .STRICT_TYPES )
261264 config .set_flag (trt .BuilderFlag .FP16 ) if fp16 else None
262265 config .set_flag (trt .BuilderFlag .REFIT ) if enable_refit else None
263266
@@ -305,53 +308,52 @@ def load(self):
305308 print (f"Loading TensorRT engine: { self .engine_path } " )
306309 self .engine = engine_from_bytes (bytes_from_path (self .engine_path ))
307310
308- def activate (self , reuse_device_memory = None ):
311+ def activate (self , reuse_device_memory = False ):
309312 if reuse_device_memory :
310313 self .context = self .engine .create_execution_context_without_device_memory ()
311- # self.context.device_memory = reuse_device_memory
312314 else :
313315 self .context = self .engine .create_execution_context ()
314316
315317 def allocate_buffers (self , shape_dict = None , device = "cuda" , additional_shapes = None ):
316318 nvtx .range_push ("allocate_buffers" )
317- for idx in range (self .engine .num_io_tensors ):
318- binding = self .engine [idx ]
319- if shape_dict and binding in shape_dict :
320- shape = shape_dict [binding ].shape
321- elif additional_shapes and binding in additional_shapes :
322- shape = additional_shapes [binding ]
319+ for binding in range (self .engine .num_io_tensors ):
320+ name = self .engine .get_tensor_name (binding )
321+
322+ if shape_dict and name in shape_dict :
323+ shape = shape_dict [name ].shape
324+ elif additional_shapes and name in additional_shapes :
325+ shape = additional_shapes [name ]
323326 else :
324- shape = self .context .get_binding_shape (idx )
325- dtype = trt .nptype (self .engine .get_binding_dtype (binding ))
326- if self .engine .binding_is_input (binding ):
327- self .context .set_binding_shape (idx , shape )
327+ shape = self .context .get_tensor_shape (name )
328+
329+ dtype = trt .nptype (self .engine .get_tensor_dtype (name ))
330+ if self .engine .get_tensor_mode (name ) == trt .TensorIOMode .INPUT :
331+ self .context .set_input_shape (name , shape )
328332 tensor = torch .zeros (
329333 tuple (shape ), dtype = numpy_to_torch_dtype_dict [dtype ]
330334 ).to (device = device )
331- self .tensors [binding ] = tensor
335+ self .tensors [name ] = tensor
332336 nvtx .range_pop ()
333337
334338 def infer (self , feed_dict , stream , use_cuda_graph = False ):
335- nvtx . range_push ( "set_tensors" )
339+
336340 for name , buf in feed_dict .items ():
337341 self .tensors [name ].copy_ (buf )
338342
339343 for name , tensor in self .tensors .items ():
340344 self .context .set_tensor_address (name , tensor .data_ptr ())
341- nvtx .range_pop ()
342- nvtx .range_push ("execute" )
345+
343346 noerror = self .context .execute_async_v3 (stream )
344347 if not noerror :
345- raise ValueError ("ERROR: inference failed." )
346- nvtx . range_pop ()
348+ raise ValueError (f"ERROR: inference failed." )
349+
347350 return self .tensors
348351
349352 def __str__ (self ):
350353 out = ""
351354 for opt_profile in range (self .engine .num_optimization_profiles ):
352- for binding_idx in range (self .engine .num_bindings ):
353- name = self .engine .get_binding_name ( binding_idx )
355+ for binding in range (self .engine .num_io_tensors ):
356+ name = self .engine .get_tensor_name ( binding )
354357 shape = self .engine .get_profile_shape (opt_profile , name )
355358 out += f"\t { name } = { shape } \n "
356359 return out
357-
0 commit comments