Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/accuracy_checker/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def execute(self, idx):
command_line = self.__fill_command_line()
if command_line == '':
self.__log.error('Command line is empty')
self.__log.info(f'Start accuracy check for {idx+1} test: {self._test.model.name}')
self.__log.info(f'Start accuracy check for {idx + 1} test: {self._test.model.name}')
self.__log.info(f'Command line is : {command_line}')
self._executor.set_target_framework(self._test.framework)
command_line = self._executor.prepare_command_line(self._test, command_line)
Expand Down
78 changes: 78 additions & 0 deletions src/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
1. ncnn.
1. PaddlePaddle.
1. Spektral.
1. IREE.

## Вывод глубоких моделей с использованием Inference Engine

Expand Down Expand Up @@ -1482,6 +1483,82 @@ python inference_ncnn.py --model <model_name> \
--batch_size <batch_size>
```

## Вывод глубоких моделей с использованием IREE

#### Аргументы командной строки

Название скрипта:

```bash
inference_iree.py
```

Обязательные аргументы:

- `-m / --model` - путь до vmfb-файла, содержащего скомпилированную модель.
- `-fn / --function_name` - название функции, которая будет вызвана IREE для исполнения модели.
- `-i / --input` - путь до изображения или директории с изображениями
(расширения файлов `.jpg`, `.png`, `.bmp` и т.д.).
- `-is / --input_shape` - размеры входного тензора сети в формате
BxCxWxH, B - размер пачки, C - количество каналов изображений,
W - ширина изображений, H - высота изображений.

Опциональные аргументы:

- `-b / --batch_size` - количество изображений, которые будут обработаны
за один проход сети. По умолчанию равно `1`.
- `-ni / --number_iter` - количество прямых проходов по сети.
По умолчанию выполняется один проход по сети.
- `--time` - время выполнения инференса в секундах. Этот параметр можно
задать вместо задать вместо параметра `-ni / --number_iter`. Если
одновременно указать и `-ni / --number_iter` и `--time`,
то будет учитываться тот параметр, при котором инферес работает дольше.
- `--layout` - формат входных тензоров. По умолчанию `NHWС`.
- `--channel_swap` - порядок перестановки цветовых каналов изображения.
Загрузка изображений осуществляется в формате BGR (порядок
соответствует `(0, 1, 2)`), а большинство нейронных сетей принимают
на вход изображения в формате RGB, поэтому по умолчанию порядок
`(2, 1, 0)`.
- `--norm` - флаг необходимости нормировки изображений.
Среднее и среднеквадратическое отклонение, которые принимаются
на вход указываются в следующих двух аргументах.
- `--mean` - среднее значение интенсивности, которое вычитается
из изображений в процессе нормировки. По умолчанию
данный параметр принимает значение `0 0 0`.
- `--std` - среднеквадратическое отклонение интенсивности, на которое
делится значение интенсивности каждого пикселя входного изображения
в процессе нормировки. По умолчанию данный параметр принимает значение `1 1 1`.
- `-t / --task` - название задачи. Текущая реализация поддерживает
решение задачи классификации (`classification`). По умолчанию принимает значение `feedforward`.
- `-nt / --number_top` - количество лучших результатов, выводимых
при решении задачи классификации. По умолчанию выводится `10` наилучших
результатов.
- `-l / --labels`- путь до файла в формате JSON с перечнем меток
при решении задачи. По умолчанию принимает значение
`image_net_labels.json`, что соответствует меткам набора данных
ImageNet.
- `-d / --device` - оборудование, на котором выполняется вывод сети.
Поддерживается вывод на CPU (значение параметра `CPU`). По умолчанию принимает значение `CPU`.
- `--raw_output` - работа скрипта без логов. По умолчанию не установлен.
- `--report_path` - путь до файла с отчетом в формате `.json`.


#### Примеры запуска

**Командная строка для решения задачи классификации изображений**

```bash
python3 inference_iree.py \
-t classification -i <path_to_image>/<image_name> \
-m <path_to_model>/<model_name>.vmfb \
--function_name main_graph \
--input_shape 1 3 224 224 \
--labels <path_to_labels>/image_net_synset.txt
```

Результат выполнения: набор наиболее вероятных классов, которым принадлежит
изображение.

<!-- LINKS -->
[execution_providers]: https://onnxruntime.ai/docs/execution-providers
[gluon_modelzoo]: https://cv.gluon.ai/model_zoo/index.html
Expand All @@ -1492,3 +1569,4 @@ python inference_ncnn.py --model <model_name> \
[dgl]: https://www.dgl.ai/pages/start.html
[ogb]: https://ogb.stanford.edu/
[tensorflow-gpu]: https://www.tensorflow.org/install/pip
[iree]: https://iree.dev
290 changes: 290 additions & 0 deletions src/inference/inference_iree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
import argparse
import sys
import traceback
from pathlib import Path

import postprocessing_data as pp
from inference_tools.loop_tools import loop_inference, get_exec_time
from io_adapter import IOAdapter
from io_model_wrapper import IREEModelWrapper
from reporter.report_writer import ReportWriter
from transformer import IREETransformer
from iree_auxiliary import (load_model, create_dict_for_transformer, prepare_output, validate_cli_args)


sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
from logger_conf import configure_logger # noqa: E402

log = configure_logger()

try:
import iree.runtime as ireert # noqa: E402
except ImportError as e:
log.error(f'IREE import error: {e}')
sys.exit(1)


def cli_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--source_framework',
help='Source model framework (required for automatic conversion to MLIR)',
type=str,
choices=['onnx', 'pytorch'],
dest='source_framework')
parser.add_argument('-m', '--model',
help='Path to source framework model (.onnx, .pt),'
'to file with compiled model (.vmfb)'
'or MLIR (.mlir).',
type=str,
dest='model')
parser.add_argument('-w', '--weights',
help='Path to an .pth file with a trained weights.'
'Availiable when source_framework=pytorch ',
type=str,
dest='model_weights')
parser.add_argument('-tm', '--torch_module',
help='Torch module with model architecture.'
'Availiable when source_framework=pytorch',
type=str,
dest='torch_module')
parser.add_argument('-mn', '--model_name',
help='Model name.',
type=str,
dest='model_name')
parser.add_argument('--onnx_opset_version',
help='Path to an .onnx with a trained model.'
'Availiable when source_framework=onnx',
type=int,
dest='onnx_opset_version')
parser.add_argument('-fn', '--function_name',
help='IREE module function name to execute.',
required=True,
type=str,
dest='function_name')
parser.add_argument('-i', '--input',
help='Path to data.',
required=True,
type=str,
nargs='+',
dest='input')
parser.add_argument('-is', '--input_shape',
help='Input shape BxHxWxC, B is a batch size,'
'H is an input tensor height,'
'W is an input tensor width,'
'C is an input tensor number of channels.',
required=True,
type=int,
nargs=4,
dest='input_shape')
parser.add_argument('-b', '--batch_size',
help='Size of the processed pack.'
'Should be the same as B in input_shape argument.',
default=1,
type=int,
dest='batch_size')
parser.add_argument('-l', '--labels',
help='Labels mapping file.',
default=None,
type=str,
dest='labels')
parser.add_argument('-nt', '--number_top',
help='Number of top results.',
default=5,
type=int,
dest='number_top')
parser.add_argument('-t', '--task',
help='Task type. Default: feedforward.',
choices=['feedforward', 'classification'],
default='feedforward',
type=str,
dest='task')
parser.add_argument('-ni', '--number_iter',
help='Number of inference iterations.',
default=1,
type=int,
dest='number_iter')
parser.add_argument('--raw_output',
help='Raw output without logs.',
default=False,
type=bool,
dest='raw_output')
parser.add_argument('--time',
required=False,
default=0,
type=int,
dest='time',
help='Optional. Maximum test duration. 0 if no restrictions.')
parser.add_argument('--report_path',
type=Path,
default=Path(__file__).parent / 'iree_inference_report.json',
dest='report_path')
parser.add_argument('--layout',
help='Input layout.',
default='NHWC',
choices=['NHWC', 'NCHW'],
type=str,
dest='layout')
parser.add_argument('--norm',
help='Flag to normalize input images.',
action='store_true',
dest='norm')
parser.add_argument('--mean',
help='Mean values.',
default=[0, 0, 0],
type=float,
nargs=3,
dest='mean')
parser.add_argument('--std',
help='Standard deviation values.',
default=[1., 1., 1.],
type=float,
nargs=3,
dest='std')
parser.add_argument('--channel_swap',
help='Parameter of channel swap.',
default=[2, 1, 0],
type=int,
nargs=3,
dest='channel_swap')
parser.add_argument('-tb', '--target_backend',
help='Target backend, for example `llvm-cpu` for CPU.',
default='llvm-cpu',
type=str,
dest='target_backend')
parser.add_argument('--opt_level',
help='The optimization level of the compilation.',
type=int,
choices=[0, 1, 2, 3],
default=2)
parser.add_argument('--extra_compile_args',
help='The extra arguments for MLIR compilation.',
type=str,
nargs=argparse.REMAINDER,
default=[])
args = parser.parse_args()
validate_cli_args(args)
return args


def get_inference_function(model_context, function_name):
try:
main_module = model_context.modules.module
inference_func = main_module[function_name]
log.info(f'Using function {function_name} for inference')
return inference_func

except Exception as e:
log.error(f'Failed to get inference function: {e}')
raise


def inference_iree(inference_func, number_iter, get_slice, test_duration):
result = None
time_infer = []

if number_iter == 1:
slice_input = get_slice()
result, exec_time = infer_slice(inference_func, slice_input)
time_infer.append(exec_time)
else:
time_infer = loop_inference(number_iter, test_duration)(
inference_iteration,
)(inference_func, get_slice)['time_infer']

log.info('Inference completed')
return result, time_infer


def inference_iteration(inference_func, get_slice):
slice_input = get_slice()
_, exec_time = infer_slice(inference_func, slice_input)
return exec_time


@get_exec_time()
def infer_slice(inference_func, slice_input):
config = ireert.Config('local-task')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Такой же вопрос по 'local-task'.

device = config.device

input_buffers = ()
for input_ in slice_input:
input_buffers.append(ireert.asdevicearray(device, input_))

result = inference_func(*input_buffers)

if hasattr(result, 'to_host'):
result = result.to_host()

return result


def main():
args = cli_argument_parser()

try:
model_wrapper = IREEModelWrapper(args)
data_transformer = IREETransformer(create_dict_for_transformer(args))
io = IOAdapter.get_io_adapter(args, model_wrapper, data_transformer)

report_writer = ReportWriter()
report_writer.update_framework_info(name='IREE')
report_writer.update_configuration_setup(
batch_size=args.batch_size,
iterations_num=args.number_iter,
target_device=args.target_backend,
)

log.info('Loading model')
model_context = load_model(
model_path=args.model,
model_weights=args.model_weights,
torch_module=args.torch_module,
model_name=args.model_name,
onnx_opset_version=args.onnx_opset_version,
source_framework=args.source_framework,
input_shape=args.input_shape,
target_backend=args.target_backend,
opt_level=args.opt_level,
extra_compile_args=args.extra_compile_args,
)
inference_func = get_inference_function(model_context, args.function_name)

log.info(f'Preparing input data: {args.input}')
io.prepare_input(model_context, args.input)

log.info(f'Starting inference ({args.number_iter} iterations) on {args.target_backend}')
result, inference_time = inference_iree(
inference_func,
args.number_iter,
io.get_slice_input_iree,
args.time,
)

log.info('Computing performance metrics')
inference_result = pp.calculate_performance_metrics_sync_mode(
args.batch_size,
inference_time,
)

report_writer.update_execution_results(**inference_result)
report_writer.write_report(args.report_path)

if not args.raw_output:
if args.number_iter == 1:
try:
log.info('Converting output tensor to print results')
result = prepare_output(result, args.task)
log.info('Inference results')
io.process_output(result, log)
except Exception as ex:
log.warning(f'Error when printing inference results: {str(ex)}')

log.info(f'Performance results: {inference_result}')

except Exception:
log.error(traceback.format_exc())
sys.exit(1)


if __name__ == '__main__':
sys.exit(main() or 0)
Loading
Loading