Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ include_patterns = ['backends/arm/**/*.py']
exclude_patterns = [
'third-party/**',
'**/third-party/**',
'backends/arm/test/**',
]
command = [
'python','-m','lintrunner_adapters','run','docformatter_linter','--config=pyproject.toml','--','@{{PATHSFILE}}'
Expand Down
42 changes: 25 additions & 17 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,24 @@


def get_time_formatted_path(path: str, log_prefix: str) -> str:
"""
Returns the log path with the current time appended to it. Used for debugging.
"""Returns the log path with the current time appended to it. Used for
debugging.

Args:
path: The path to the folder where the log file will be stored.
log_prefix: The name of the test.

Example output:
'./my_log_folder/test_INT_artifact_28-Nov-14:14:38.log'

"""
return str(
Path(path) / f"{log_prefix}_{datetime.now().strftime('%d-%b-%H:%M:%S')}.log"
)


def maybe_get_tosa_collate_path() -> str | None:
"""
Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the
"""Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the
path to the where to store the current tests if it is set.
"""
tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH")
Expand Down Expand Up @@ -161,8 +161,8 @@ def get_vgf_compile_spec(
custom_path: Optional[str] = None,
tosa_debug_mode: VgfCompileSpec.DebugMode | None = None,
) -> VgfCompileSpec:
"""Get the ArmCompileSpec for the default VGF tests, to modify
the compile spec before calling .build() to finalize it.
"""Get the ArmCompileSpec for the default VGF tests, to modify the compile
spec before calling .build() to finalize it.
"""

if not custom_path:
Expand Down Expand Up @@ -198,7 +198,9 @@ def get_vgf_compile_spec(
raises=FileNotFoundError,
reason="Did not find Corstone-300 FVP or executor_runner on path",
)
"""Xfails a test if Corsone300 FVP is not installed, or if the executor runner is not built"""
"""Xfails a test if Corsone300 FVP is not installed, or if the executor runner
is not built.
"""

XfailIfNoCorstone320 = pytest.mark.xfail(
condition=not (
Expand All @@ -207,21 +209,23 @@ def get_vgf_compile_spec(
raises=FileNotFoundError,
reason="Did not find Corstone-320 FVP or executor_runner on path",
)
"""Xfails a test if Corsone320 FVP is not installed, or if the executor runner is not built"""
"""Xfails a test if Corsone320 FVP is not installed, or if the executor runner
is not built.
"""

SkipIfNoModelConverter = pytest.mark.skipif( # type: ignore[call-arg]
condition=not (model_converter_installed()),
raises=FileNotFoundError,
reason="Did not find model-converter on path",
)
"""Skips a test if model-converter is not installed"""
"""Skips a test if model-converter is not installed."""

XfailfNoVKMLEmulationLayer = pytest.mark.xfail(
condition=not (vkml_emulation_layer_installed()),
raises=TypeError,
reason="VKML environment is not set properly or executor_runner path is misused",
)
"""Xfails a test if VKML Emulation Layer is not installed"""
"""Xfails a test if VKML Emulation Layer is not installed."""

xfail_type = str | tuple[str, type[Exception]]

Expand All @@ -238,12 +242,14 @@ def parametrize(
strict: bool = True,
flakies: dict[str, int] | None = None,
) -> Decorator:
"""
Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality
- test_data is expected as a dict of (id, test_data) pairs
- alllows to specifiy a dict of (id, failure_reason) pairs to mark specific tests as xfail.
Failure_reason can be str, type[Exception], or tuple[str, type[Exception]].
Strings set the reason for failure, the exception type sets expected error.
"""Custom version of pytest.mark.parametrize with some syntatic sugar and
added xfail functionality.

- test_data is expected as a dict of (id, test_data) pairs
- alllows to specifiy a dict of (id, failure_reason) pairs to mark specific tests as xfail.
Failure_reason can be str, type[Exception], or tuple[str, type[Exception]].
Strings set the reason for failure, the exception type sets expected error.

"""
if xfails is None:
xfails = {}
Expand All @@ -253,7 +259,9 @@ def parametrize(
flakies = {}

def decorator_func(func: Callable[_P, _R]) -> Callable[_P, _R]:
"""Test data is transformed from a dict of (id, data) pairs to a list of pytest params to work with the native pytests parametrize function"""
"""Test data is transformed from a dict of (id, data) pairs to a list of
pytest params to work with the native pytests parametrize function.
"""
pytest_testsuite = []
for id, test_parameters in test_data.items():
if id in flakies:
Expand Down
21 changes: 11 additions & 10 deletions backends/arm/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def pytest_sessionfinish(session, exitstatus):

@pytest.fixture(autouse=True)
def set_random_seed():
"""
Control random numbers in Arm test suite. Default behavior is to use a fixed
seed (0), which ensures reproducible tests. Use the env variable ARM_TEST_SEED
to set a custom seed, or set it to RANDOM for random seed behavior.
"""Control random numbers in Arm test suite. Default behavior is to use a
fixed seed (0), which ensures reproducible tests. Use the env variable
ARM_TEST_SEED to set a custom seed, or set it to RANDOM for random seed
behavior.

Examples:
As default use fixed seed (0) for reproducible tests
Expand All @@ -76,6 +76,7 @@ def set_random_seed():
ARM_TEST_SEED=RANDOM pytest --config-file=/dev/null --verbose -s --color=yes backends/arm/test/ops/test_avg_pool.py -k <TESTCASE>
Rerun with a specific seed
ARM_TEST_SEED=3478246 pytest --config-file=/dev/null --verbose -s --color=yes backends/arm/test/ops/test_avg_pool.py -k <TESTCASE>

"""
import torch

Expand All @@ -100,12 +101,12 @@ def set_random_seed():


def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
"""
Returns whether an option is successfully enabled, i.e. if the flag was
"""Returns whether an option is successfully enabled, i.e. if the flag was
given to pytest and the necessary requirements are available.

The optional parameter 'fail_if_not_enabled' makes the function raise
a RuntimeError instead of returning False.
The optional parameter 'fail_if_not_enabled' makes the function raise a
RuntimeError instead of returning False.

"""

if hasattr(pytest, "_test_options") and option in pytest._test_options and pytest._test_options[option]: # type: ignore[attr-defined]
Expand All @@ -118,11 +119,11 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:


def get_option(option: str) -> Any | None:
"""
Returns the value of an pytest option if it is set, otherwise None.
"""Returns the value of an pytest option if it is set, otherwise None.

Args:
option (str): The option to check for.

"""
if option in pytest._test_options: # type: ignore[attr-defined]
return pytest._test_options[option] # type: ignore[attr-defined]
Expand Down
19 changes: 7 additions & 12 deletions backends/arm/test/misc/test_dim_order.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -21,9 +21,8 @@


class ChannelsLastInput(torch.nn.Module):
"""
Test a complex case with (channels last, channels first) input,
and (channels first, channels last) output.
"""Test a complex case with (channels last, channels first) input, and
(channels first, channels last) output.
"""

inputs: input_t1 = (
Expand All @@ -39,9 +38,7 @@ def forward(self, x, y):


class ChannelsFirstOutput(torch.nn.Module):
"""
Test coverting to channels_first inside the delegate.
"""
"""Test coverting to channels_first inside the delegate."""

inputs: input_t1 = (
torch.arange(1, 25, dtype=torch.float32)
Expand All @@ -55,9 +52,7 @@ def forward(self, x):


class ChannelsLastOutput(torch.nn.Module):
"""
Test changing of dim_order inside the delegate.
"""
"""Test changing of dim_order inside the delegate."""

inputs: input_t1 = (torch.arange(1, 9, dtype=torch.float32).reshape((1, 2, 2, 2)),)

Expand All @@ -68,8 +63,8 @@ def forward(self, x):


class ChannelsLastInsidePartition(torch.nn.Module):
"""
Test dim_order changes inside the partiton, but no dim_order changes at input/output.
"""Test dim_order changes inside the partiton, but no dim_order changes at
input/output.
"""

inputs: input_t1 = (torch.randn((1, 2, 3, 3)),)
Expand Down
6 changes: 2 additions & 4 deletions backends/arm/test/misc/test_extract_io_params_tosa.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -45,9 +45,7 @@ def test_roundtrip_extracts_io_params_tosa_INT(
quantizer_cls,
partitioner_cls,
):
"""
Validates that IO quantization parameters round-trip for both flows.
"""
"""Validates that IO quantization parameters round-trip for both flows."""
example_inputs = (
torch.ones(1, 5),
torch.full((1, 5), 2.0),
Expand Down
16 changes: 6 additions & 10 deletions backends/arm/test/misc/test_non_persistent_buffers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -14,9 +14,7 @@


class NonPersistentBuffer(nn.Module):
"""
Min code version registering a non-persistent input buffer.
"""
"""Min code version registering a non-persistent input buffer."""

def __init__(self):
super().__init__()
Expand All @@ -33,17 +31,15 @@ def forward(self, x):

@parametrize("test_data", test_input)
def test_non_persistent_buffer_tosa_FP(test_data: input_t):
"""
Test validates Arm backend handling of non-persistent buffers
and ensures that there are no asserts or errors when they are used.
"""Test validates Arm backend handling of non-persistent buffers and ensures
that there are no asserts or errors when they are used.
"""
TosaPipelineFP[input_t](NonPersistentBuffer(), test_data, "").run()


@parametrize("test_data", test_input)
def test_non_persistent_buffer_tosa_INT(test_data: input_t):
"""
Test validates Arm backend handling of non-persistent buffers
and ensures that there are no asserts or errors when they are used.
"""Test validates Arm backend handling of non-persistent buffers and ensures
that there are no asserts or errors when they are used.
"""
TosaPipelineINT[input_t](NonPersistentBuffer(), test_data, "").run()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -48,7 +48,11 @@


class SoftplusModule(torch.nn.Module):
"""Module containing an addition followed by a Softplus. Softplus is currently not supported by TosaBackend."""
"""Module containing an addition followed by a Softplus.

Softplus is currently not supported by TosaBackend.

"""

def __init__(self):
super().__init__()
Expand All @@ -59,8 +63,11 @@ def forward(self, x: torch.Tensor):


class LinearResidualModule(torch.nn.Module):
"""Module containing a residual and a linear layer followed by GELU and a Dropout.
"""Module containing a residual and a linear layer followed by GELU and a
Dropout.

GELU is currently not supported by TosaBackend nor TosaQuantizer.

"""

def __init__(
Expand Down
10 changes: 7 additions & 3 deletions backends/arm/test/misc/test_qat_training_loop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -46,8 +46,12 @@ def evaluate_model(model, inputs, expected_outputs):

def test_qat_training_loop_tosa_INT():
"""Test the QAT training loop with a simple MLP model.
This function creates a simple MLP model, prepares it for QAT, runs a training loop,
and evaluates the quantized model to make sure everything works as expected."""

This function creates a simple MLP model, prepares it for QAT, runs a
training loop, and evaluates the quantized model to make sure everything
works as expected.

"""

model = MLP()
logger.info("Starting training loop test")
Expand Down
14 changes: 9 additions & 5 deletions backends/arm/test/misc/test_quant_custom_meta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -19,11 +19,13 @@ def forward(self, x, y):

@pytest.mark.parametrize("fp_extension", [True, False])
def test_qdq_squeezed_fp_op_tosa_INT_FP(fp_extension: bool):
"""Test that a float operation surrounded by quantize-dequantize pairs
is correctly handled by the partitioner and the TOSA backend.
"""Test that a float operation surrounded by quantize-dequantize pairs is
correctly handled by the partitioner and the TOSA backend.

Pattern:
q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q
|_____unquantized_____|

"""
aten_op = "torch.ops.aten.add.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
Expand Down Expand Up @@ -67,11 +69,13 @@ def forward(self, x, y):

@pytest.mark.parametrize("fp_extension", [True, False])
def test_quantized_to_float_transition_tosa_INT_FP(fp_extension: bool):
"""Test that a model executing quantized ops followed by float ops
is correctly handled by the partitioner and the TOSA backend.
"""Test that a model executing quantized ops followed by float ops is
correctly handled by the partitioner and the TOSA backend.

Pattern:
q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv
|___unquantized___|

"""
aten_op = "torch.ops.aten.add.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
Expand Down
Loading
Loading