import argparse
-import os
import sys
import traceback
-import tempfile
from pathlib import Path

import postprocessing_data as pp
from io_model_wrapper import IREEModelWrapper
from reporter.report_writer import ReportWriter
from transformer import IREETransformer
+from iree_auxiliary import (load_model, create_dict_for_transformer, prepare_output, validate_cli_args)

-import numpy as np
-
-sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('model_converters',
-                                                                 'iree_converter',
-                                                                 'iree_auxiliary')))
-from compiler import IREECompiler  # noqa: E402
-from converter import IREEConverter  # noqa: E402

sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
from logger_conf import configure_logger  # noqa: E402
    sys.exit(1)


-def validate_cli_args(args):
-    if args.model:
-        pass
-    else:
-        pass
-
-
def cli_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--source_framework',
                        help='Source model framework (required for automatic conversion to MLIR)',
                        type=str,
                        choices=['onnx', 'pytorch'],
-                        dest='source_framework')
+                        dest='source_framework')
    parser.add_argument('-m', '--model',
                        help='Path to source framework model (.onnx, .pt), '
                             'to file with compiled model (.vmfb)'
@@ -181,96 +166,6 @@ def cli_argument_parser():
    return args


-def convert_model_to_mlir(model_path, model_weights, torch_module, model_name, onnx_opset_version, source_framework, input_shape, output_mlir):
-    dictionary = {
-        'source_framework': source_framework,
-        'model_name': model_name,
-        'model_path': model_path,
-        'model_weights': model_weights,
-        'torch_module': torch_module,
-        'onnx_opset_version': onnx_opset_version,
-        'input_shape': input_shape,
-        'output_mlir': output_mlir
-    }
-    converter = IREEConverter.get_converter(dictionary)
-    converter.convert_to_mlir()
-    return
-
-
-def compile_mlir(mlir_path, target_backend, opt_level, extra_compile_args):
-    try:
-        log.info('Starting model compilation')
-        return IREECompiler.compile(mlir_path, target_backend, opt_level, extra_compile_args)
-    except Exception as e:
-        log.error(f'Failed to compile MLIR: {e}')
-        raise
-
-
-def load_model_buffer(model_path, target_backend, opt_level, extra_compile_args):
-    if not os.path.exists(model_path):
-        raise FileNotFoundError(f'Model file not found: {model_path}')
-
-    file_type = model_path.split('.')[-1]
-
-    if file_type == 'mlir':
-        if target_backend is None:
-            raise ValueError('target_backend is required for MLIR compilation')
-        vmfb_buffer = compile_mlir(model_path, target_backend, opt_level, extra_compile_args)
-    elif file_type == 'vmfb':
-        with open(model_path, 'rb') as f:
-            vmfb_buffer = f.read()
-    else:
-        raise ValueError(f'The file type {file_type} is not supported. Supported types: .mlir, .vmfb')
-
-    log.info(f'Successfully loaded model buffer from {model_path}')
-    return vmfb_buffer
-
-
-def create_iree_context_from_buffer(vmfb_buffer):
-    try:
-        config = ireert.Config('local-task')
-        vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, vmfb_buffer)
-        context = ireert.SystemContext(config=config)
-        context.add_vm_module(vm_module)
-
-        log.info('Successfully created IREE context from buffer')
-        return context
-
-    except Exception as e:
-        log.error(f'Failed to create IREE context: {e}')
-        raise
-
-
-def load_model(model_path, model_weights, torch_module, model_name, onnx_opset_version,
-               source_framework, input_shape, target_backend, opt_level, extra_compile_args):
-    is_tmp_mlir = False
-    if model_path is None or model_path.split('.')[-1] not in ['vmfb', 'mlir']:
-        with tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.mlir') as temp:
-            output_mlir = temp.name
-        convert_model_to_mlir(model_path,
-                              model_weights,
-                              torch_module,
-                              model_name,
-                              onnx_opset_version,
-                              source_framework,
-                              input_shape,
-                              output_mlir)
-        model_path = output_mlir
-        is_tmp_mlir = True
-
-    vmfb_buffer = load_model_buffer(
-        model_path,
-        target_backend=target_backend,
-        opt_level=opt_level,
-        extra_compile_args=extra_compile_args
-    )
-
-    if is_tmp_mlir:
-        os.remove(model_path)
-
-    return create_iree_context_from_buffer(vmfb_buffer)
-
-
def get_inference_function(model_context, function_name):
    try:
        main_module = model_context.modules.module
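For orientation: the lookup above presumably reduces to attribute access on the bound VM module. A minimal sketch under that assumption (resolve_entry_point and the 'main' default are hypothetical, not part of the script):

def resolve_entry_point(context, name='main'):
    # The VM module registered when the IREE SystemContext was created.
    module = context.modules.module
    # Bound functions are exposed as attributes; this raises AttributeError
    # if the compiled module exports no function with the given name.
    return getattr(module, name)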
@@ -293,7 +188,7 @@ def inference_iree(inference_func, number_iter, get_slice, test_duration):
        time_infer.append(exec_time)
    else:
        time_infer = loop_inference(number_iter, test_duration)(
-            inference_iteration
+            inference_iteration,
        )(inference_func, get_slice)['time_infer']

    log.info('Inference completed')
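The curried call above decorates inference_iteration and pulls the per-iteration timings out of the returned dict. The real loop_inference lives in the repo's utils; the following is only a sketch of the shape implied by that call site, and the stopping rule is an assumption:

import time

def loop_inference(number_iter, test_duration):
    # Decorator factory: loop_inference(n, d)(fn)(*args) -> {'time_infer': [...]}.
    def decorate(iteration_fn):
        def run(*args):
            times = []
            start = time.time()
            # Assumed rule: at least number_iter runs, continuing while the
            # requested test_duration (seconds) has not yet elapsed.
            while len(times) < number_iter or (test_duration and time.time() - start < test_duration):
                t0 = time.time()
                iteration_fn(*args)
                times.append(time.time() - t0)
            return {'time_infer': times}
        return run
    return decorate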
@@ -311,7 +206,7 @@ def infer_slice(inference_func, slice_input):
    config = ireert.Config('local-task')
    device = config.device

-    input_buffers = list()
+    input_buffers = []
    for input_ in slice_input:
        input_buffers.append(ireert.asdevicearray(device, input_))

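Host arrays are staged onto the device before the call. A standalone round trip with the same 'local-task' driver used above (shape and dtype are arbitrary placeholders):

import numpy as np
import iree.runtime as ireert

config = ireert.Config('local-task')
arr = np.ones((1, 3, 224, 224), dtype=np.float32)
buf = ireert.asdevicearray(config.device, arr)  # host -> device staging
assert np.allclose(buf.to_host(), arr)          # device -> host copy matches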
@@ -323,50 +218,6 @@ def infer_slice(inference_func, slice_input):
    return result


-def prepare_output(result, task):
-    if task == 'feedforward':
-        return {}
-    elif task == 'classification':
-        if hasattr(result, 'to_host'):
-            result = result.to_host()
-
-        # Extract tensor from dict if needed
-        if isinstance(result, dict):
-            result_key = next(iter(result))
-            logits = result[result_key]
-            output_key = result_key
-        else:
-            logits = np.array(result)
-            output_key = 'output'
-
-        # Ensure correct shape (batch_size, num_classes)
-        if logits.ndim == 1:
-            logits = logits.reshape(1, -1)
-        elif logits.ndim > 2:
-            logits = logits.reshape(logits.shape[0], -1)
-
-        # Apply softmax
-        max_logits = np.max(logits, axis=-1, keepdims=True)
-        exp_logits = np.exp(logits - max_logits)
-        probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
-
-        return {output_key: probabilities}
-    else:
-        raise ValueError(f'Unsupported task {task}')
-
-
-def create_dict_for_transformer(args):
-    return {
-        'channel_swap': getattr(args, 'channel_swap'),
-        'mean': getattr(args, 'mean'),
-        'std': getattr(args, 'std'),
-        'norm': getattr(args, 'norm'),
-        'layout': getattr(args, 'layout'),
-        'input_shape': getattr(args, 'input_shape'),
-        'batch_size': getattr(args, 'batch_size'),
-    }
-
-
def main():
    args = cli_argument_parser()

@@ -380,7 +231,7 @@ def main():
    report_writer.update_configuration_setup(
        batch_size=args.batch_size,
        iterations_num=args.number_iter,
-        target_device=args.target_backend
+        target_device=args.target_backend,
    )

    log.info('Loading model')
@@ -394,7 +245,7 @@ def main():
        input_shape=args.input_shape,
        target_backend=args.target_backend,
        opt_level=args.opt_level,
-        extra_compile_args=args.extra_compile_args
+        extra_compile_args=args.extra_compile_args,
    )
    inference_func = get_inference_function(model_context, args.function_name)

@@ -406,13 +257,13 @@ def main():
        inference_func,
        args.number_iter,
        io.get_slice_input_iree,
-        args.time
+        args.time,
    )

    log.info('Computing performance metrics')
    inference_result = pp.calculate_performance_metrics_sync_mode(
        args.batch_size,
-        inference_time
+        inference_time,
    )

    report_writer.update_execution_results(**inference_result)
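prepare_output, now in iree_auxiliary, keeps the max-subtraction form of softmax so large logits cannot overflow np.exp. A tiny standalone check with made-up values:

import numpy as np

logits = np.array([[1000.0, 1001.0]])                      # naive np.exp(logits) would overflow
shifted = logits - np.max(logits, axis=-1, keepdims=True)  # largest entry becomes 0
probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
print(probs)                                               # [[0.26894142 0.73105858]]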