Skip to content

Commit 15ac8f9

Browse files
IREE inference_auxiliary
1 parent 4d09e21 commit 15ac8f9

File tree

8 files changed

+264
-166
lines changed

src/inference/inference_iree.py

Lines changed: 8 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import argparse
2-
import os
32
import sys
43
import traceback
5-
import tempfile
64
from pathlib import Path
75

86
import postprocessing_data as pp
@@ -11,14 +9,8 @@
119
from io_model_wrapper import IREEModelWrapper
1210
from reporter.report_writer import ReportWriter
1311
from transformer import IREETransformer
12+
from iree_auxiliary import (load_model, create_dict_for_transformer, prepare_output, validate_cli_args)
1413

15-
import numpy as np
16-
17-
sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('model_converters',
18-
'iree_converter',
19-
'iree_auxiliary')))
20-
from compiler import IREECompiler # noqa: E402
21-
from converter import IREEConverter # noqa: E402
2214

2315
sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
2416
from logger_conf import configure_logger # noqa: E402
@@ -32,20 +24,13 @@
3224
sys.exit(1)
3325

3426

35-
def validate_cli_args(args):
36-
if args.model:
37-
pass
38-
else:
39-
pass
40-
41-
4227
def cli_argument_parser():
4328
parser = argparse.ArgumentParser()
4429
parser.add_argument('-f', '--source_framework',
4530
help='Source model framework (required for automatic conversion to MLIR)',
4631
type=str,
4732
choices=['onnx', 'pytorch'],
48-
dest='source_framework')
33+
dest='source_framework')
4934
parser.add_argument('-m', '--model',
5035
help='Path to source framework model (.onnx, .pt),'
5136
'to file with compiled model (.vmfb)'
@@ -181,96 +166,6 @@ def cli_argument_parser():
181166
return args
182167

183168

184-
def convert_model_to_mlir(model_path, model_weights, torch_module, model_name, onnx_opset_version, source_framework, input_shape, output_mlir):
185-
dictionary = {
186-
'source_framework': source_framework,
187-
'model_name': model_name,
188-
'model_path': model_path,
189-
'model_weights': model_weights,
190-
'torch_module': torch_module,
191-
'onnx_opset_version': onnx_opset_version,
192-
'input_shape': input_shape,
193-
'output_mlir': output_mlir
194-
}
195-
converter = IREEConverter.get_converter(dictionary)
196-
converter.convert_to_mlir()
197-
return
198-
199-
200-
def compile_mlir(mlir_path, target_backend, opt_level, extra_compile_args):
201-
try:
202-
log.info('Starting model compilation')
203-
return IREECompiler.compile(mlir_path, target_backend, opt_level, extra_compile_args)
204-
except Exception as e:
205-
log.error(f'Failed to compile MLIR: {e}')
206-
raise
207-
208-
209-
def load_model_buffer(model_path, target_backend, opt_level, extra_compile_args):
210-
if not os.path.exists(model_path):
211-
raise FileNotFoundError(f'Model file not found: {model_path}')
212-
213-
file_type = model_path.split('.')[-1]
214-
215-
if file_type == 'mlir':
216-
if target_backend is None:
217-
raise ValueError('target_backend is required for MLIR compilation')
218-
vmfb_buffer = compile_mlir(model_path, target_backend, opt_level, extra_compile_args)
219-
elif file_type == 'vmfb':
220-
with open(model_path, 'rb') as f:
221-
vmfb_buffer = f.read()
222-
else:
223-
raise ValueError(f'The file type {file_type} is not supported. Supported types: .mlir, .vmfb')
224-
225-
log.info(f'Successfully loaded model buffer from {model_path}')
226-
return vmfb_buffer
227-
228-
229-
def create_iree_context_from_buffer(vmfb_buffer):
230-
try:
231-
config = ireert.Config('local-task')
232-
vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, vmfb_buffer)
233-
context = ireert.SystemContext(config=config)
234-
context.add_vm_module(vm_module)
235-
236-
log.info('Successfully created IREE context from buffer')
237-
return context
238-
239-
except Exception as e:
240-
log.error(f'Failed to create IREE context: {e}')
241-
raise
242-
243-
244-
def load_model(model_path, model_weights, torch_module, model_name, onnx_opset_version,
245-
source_framework, input_shape, target_backend, opt_level, extra_compile_args):
246-
is_tmp_mlir = False
247-
if model_path is None or model_path.split('.')[-1] not in ['vmfb', 'mlir']:
248-
with tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.mlir') as temp:
249-
output_mlir = temp.name
250-
convert_model_to_mlir(model_path,
251-
model_weights,
252-
torch_module,
253-
model_name,
254-
onnx_opset_version,
255-
source_framework,
256-
input_shape,
257-
output_mlir)
258-
model_path = output_mlir
259-
is_tmp_mlir = True
260-
261-
vmfb_buffer = load_model_buffer(
262-
model_path,
263-
target_backend=target_backend,
264-
opt_level=opt_level,
265-
extra_compile_args=extra_compile_args
266-
)
267-
268-
if is_tmp_mlir:
269-
os.remove(model_path)
270-
271-
return create_iree_context_from_buffer(vmfb_buffer)
272-
273-
274169
def get_inference_function(model_context, function_name):
275170
try:
276171
main_module = model_context.modules.module
@@ -293,7 +188,7 @@ def inference_iree(inference_func, number_iter, get_slice, test_duration):
293188
time_infer.append(exec_time)
294189
else:
295190
time_infer = loop_inference(number_iter, test_duration)(
296-
inference_iteration
191+
inference_iteration,
297192
)(inference_func, get_slice)['time_infer']
298193

299194
log.info('Inference completed')
@@ -311,7 +206,7 @@ def infer_slice(inference_func, slice_input):
311206
config = ireert.Config('local-task')
312207
device = config.device
313208

314-
input_buffers = list()
209+
input_buffers = ()
315210
for input_ in slice_input:
316211
input_buffers.append(ireert.asdevicearray(device, input_))
317212

@@ -323,50 +218,6 @@ def infer_slice(inference_func, slice_input):
323218
return result
324219

325220

326-
def prepare_output(result, task):
327-
if task == 'feedforward':
328-
return {}
329-
elif task == 'classification':
330-
if hasattr(result, 'to_host'):
331-
result = result.to_host()
332-
333-
# Extract tensor from dict if needed
334-
if isinstance(result, dict):
335-
result_key = next(iter(result))
336-
logits = result[result_key]
337-
output_key = result_key
338-
else:
339-
logits = np.array(result)
340-
output_key = 'output'
341-
342-
# Ensure correct shape (batch_size, num_classes)
343-
if logits.ndim == 1:
344-
logits = logits.reshape(1, -1)
345-
elif logits.ndim > 2:
346-
logits = logits.reshape(logits.shape[0], -1)
347-
348-
# Apply softmax
349-
max_logits = np.max(logits, axis=-1, keepdims=True)
350-
exp_logits = np.exp(logits - max_logits)
351-
probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
352-
353-
return {output_key: probabilities}
354-
else:
355-
raise ValueError(f'Unsupported task {task}')
356-
357-
358-
def create_dict_for_transformer(args):
359-
return {
360-
'channel_swap': getattr(args, 'channel_swap'),
361-
'mean': getattr(args, 'mean'),
362-
'std': getattr(args, 'std'),
363-
'norm': getattr(args, 'norm'),
364-
'layout': getattr(args, 'layout'),
365-
'input_shape': getattr(args, 'input_shape'),
366-
'batch_size': getattr(args, 'batch_size'),
367-
}
368-
369-
370221
def main():
371222
args = cli_argument_parser()
372223

@@ -380,7 +231,7 @@ def main():
380231
report_writer.update_configuration_setup(
381232
batch_size=args.batch_size,
382233
iterations_num=args.number_iter,
383-
target_device=args.target_backend
234+
target_device=args.target_backend,
384235
)
385236

386237
log.info('Loading model')
@@ -394,7 +245,7 @@ def main():
394245
input_shape=args.input_shape,
395246
target_backend=args.target_backend,
396247
opt_level=args.opt_level,
397-
extra_compile_args=args.extra_compile_args
248+
extra_compile_args=args.extra_compile_args,
398249
)
399250
inference_func = get_inference_function(model_context, args.function_name)
400251

@@ -406,13 +257,13 @@ def main():
406257
inference_func,
407258
args.number_iter,
408259
io.get_slice_input_iree,
409-
args.time
260+
args.time,
410261
)
411262

412263
log.info('Computing performance metrics')
413264
inference_result = pp.calculate_performance_metrics_sync_mode(
414265
args.batch_size,
415-
inference_time
266+
inference_time,
416267
)
417268

418269
report_writer.update_execution_results(**inference_result)

src/inference/io_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def get_slice_input(self, *args, **kwargs):
187187
return slice_input
188188

189189
def get_slice_input_iree(self, *args, **kwargs):
190-
slice_input = list()
190+
slice_input = ()
191191
for key in self._transformed_input:
192192
data_gen = self._transformed_input[key]
193193
slice_data = [copy.deepcopy(next(data_gen)) for _ in range(self._batch_size)]

0 commit comments

Comments (0)