-
Notifications
You must be signed in to change notification settings - Fork 21
Open
Description
How to run the full ONNX backend test suite.
First, update onnx to the latest version:
pip install --upgrade onnx
Then save the script below as backend_full_test.py and run it with: python -m pytest backend_full_test.py
# filename: backend_full_test.py
import collections
import unittest
import jax
from jaxonnxruntime.backend import Backend as JaxBackend
import onnx.backend.test
from onnx.backend.test.loader import load_model_tests
# Register onnx's backend-test report plugin with pytest; pytest reads this
# module-level name when it collects the file.
pytest_plugins = ("onnx.backend.test.report",)

# Several node tests (argmax and bitshift, for instance) need
# jax_enable_x64=True, so it is switched on here, at import time.
jax.config.update("jax_enable_x64", True)
class Runner(onnx.backend.test.runner.Runner):
    """Backend test runner covering the full ONNX test suite.

    Intended to register the node, real, simple, pytorch-converted and
    pytorch-operator model tests for ``backend``.

    All initialization is delegated to the upstream ``Runner.__init__``
    rather than re-creating its private attributes (``_include_patterns``,
    ``_test_items``, ...) by hand. A hand-written copy silently misses any
    state that newer onnx releases add, which then fails with
    ``AttributeError`` once tests are registered or run.

    NOTE(review): this relies on the upstream ``Runner.__init__`` loading the
    same five test kinds listed above (and, in recent releases, keeping a
    ``_test_kwargs`` dict) -- confirm against the installed onnx version.
    Registering the tests again after ``super().__init__()`` is avoided on
    purpose, since the upstream runner is expected to reject duplicate names.
    """

    def __init__(self, backend, parent_module=None, **kwargs) -> None:
        """Create the runner.

        Args:
          backend: The ONNX backend class under test (this script passes
            ``JaxBackend``).
          parent_module: Module name handed to the upstream runner; this
            script passes ``__name__``.
          **kwargs: Forwarded unchanged to the upstream constructor, e.g.
            ``test_kwargs`` on onnx releases that accept it.
        """
        super().__init__(backend, parent_module, **kwargs)
# One runner instance for the JAX backend; its module name is this file's.
backend_test = Runner(backend=JaxBackend, parent_module=__name__)

# enable_report() is called before the generated TestCase classes are read;
# they are then published as module globals so unittest/pytest discovers them.
globals().update(backend_test.enable_report().test_cases)

if __name__ == "__main__":
    unittest.main()
You will see results for thousands of tests. The current result is:
==================================== 2116 failed, 518 passed, 1 warning in 375.93s (0:06:15) ====================================
You can use the -k filter option to run and debug a single test case.
For example: python -m unittest backend_full_test.py -k test_resnet50_cpu
Metadata
Metadata
Assignees
Labels
No labels