From 4e449f519c64bf0768daad3b82f7c16bd9f813cc Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Wed, 11 Feb 2026 13:39:24 +0100 Subject: [PATCH] Arm backend: Format docs for remaining files in backends/arm/test/ Change-Id: I6d07cc2f917a719bf945e47315ef383f93ac9450 Signed-off-by: Sebastian Larsson --- .lintrunner.toml | 1 - backends/arm/test/common.py | 42 ++++++----- backends/arm/test/conftest.py | 21 +++--- backends/arm/test/misc/test_dim_order.py | 19 ++--- .../test/misc/test_extract_io_params_tosa.py | 6 +- .../test/misc/test_non_persistent_buffers.py | 16 ++--- ...test_partition_decomposed_quantized_ops.py | 13 +++- .../arm/test/misc/test_qat_training_loop.py | 10 ++- .../arm/test/misc/test_quant_custom_meta.py | 14 ++-- backends/arm/test/misc/test_shared_qspecs.py | 56 ++++++++++----- backends/arm/test/misc/test_tosa_spec.py | 4 +- .../stable_diffusion_module_test_configs.py | 3 +- .../arm/test/models/test_nn_functional.py | 1 - backends/arm/test/models/test_nn_modules.py | 1 - .../arm/test/models/test_torch_functions.py | 1 - backends/arm/test/ops/test_div_tensor_mode.py | 4 +- backends/arm/test/ops/test_max_pool1d.py | 9 ++- backends/arm/test/ops/test_scalars.py | 27 ++++--- backends/arm/test/ops/test_sdpa.py | 18 +++-- ...st_decompose_int16_activation_conv_pass.py | 48 +++++++------ .../passes/test_decompose_softmax_pass.py | 5 +- .../quantizer/test_partial_quantization.py | 4 +- .../test/quantizer/test_preserve_kwargs.py | 4 +- backends/arm/test/runner_utils.py | 38 ++++++---- .../arm/test/test_memory_allocator_log.py | 6 +- .../arm/test/tester/analyze_output_utils.py | 6 +- backends/arm/test/tester/arm_tester.py | 47 +++++++----- backends/arm/test/tester/test_pipeline.py | 71 +++++++++++-------- 28 files changed, 287 insertions(+), 208 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 84d2305ec15..2d279b16d64 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -511,7 +511,6 @@ include_patterns = ['backends/arm/**/*.py'] exclude_patterns = [ 'third-party/**', '**/third-party/**', - 'backends/arm/test/**', ] command = [ 'python','-m','lintrunner_adapters','run','docformatter_linter','--config=pyproject.toml','--','@{{PATHSFILE}}' diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index c68d75bede0..15354c5005a 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -27,8 +27,8 @@ def get_time_formatted_path(path: str, log_prefix: str) -> str: - """ - Returns the log path with the current time appended to it. Used for debugging. + """Returns the log path with the current time appended to it. Used for + debugging. Args: path: The path to the folder where the log file will be stored. @@ -36,6 +36,7 @@ def get_time_formatted_path(path: str, log_prefix: str) -> str: Example output: './my_log_folder/test_INT_artifact_28-Nov-14:14:38.log' + """ return str( Path(path) / f"{log_prefix}_{datetime.now().strftime('%d-%b-%H:%M:%S')}.log" @@ -43,8 +44,7 @@ def get_time_formatted_path(path: str, log_prefix: str) -> str: def maybe_get_tosa_collate_path() -> str | None: - """ - Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the + """Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the path to the where to store the current tests if it is set. """ tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH") @@ -161,8 +161,8 @@ def get_vgf_compile_spec( custom_path: Optional[str] = None, tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, ) -> VgfCompileSpec: - """Get the ArmCompileSpec for the default VGF tests, to modify - the compile spec before calling .build() to finalize it. + """Get the ArmCompileSpec for the default VGF tests, to modify the compile + spec before calling .build() to finalize it. """ if not custom_path: @@ -198,7 +198,9 @@ def get_vgf_compile_spec( raises=FileNotFoundError, reason="Did not find Corstone-300 FVP or executor_runner on path", ) -"""Xfails a test if Corsone300 FVP is not installed, or if the executor runner is not built""" +"""Xfails a test if Corsone300 FVP is not installed, or if the executor runner +is not built. +""" XfailIfNoCorstone320 = pytest.mark.xfail( condition=not ( @@ -207,21 +209,23 @@ def get_vgf_compile_spec( raises=FileNotFoundError, reason="Did not find Corstone-320 FVP or executor_runner on path", ) -"""Xfails a test if Corsone320 FVP is not installed, or if the executor runner is not built""" +"""Xfails a test if Corsone320 FVP is not installed, or if the executor runner +is not built. +""" SkipIfNoModelConverter = pytest.mark.skipif( # type: ignore[call-arg] condition=not (model_converter_installed()), raises=FileNotFoundError, reason="Did not find model-converter on path", ) -"""Skips a test if model-converter is not installed""" +"""Skips a test if model-converter is not installed.""" XfailfNoVKMLEmulationLayer = pytest.mark.xfail( condition=not (vkml_emulation_layer_installed()), raises=TypeError, reason="VKML environment is not set properly or executor_runner path is misused", ) -"""Xfails a test if VKML Emulation Layer is not installed""" +"""Xfails a test if VKML Emulation Layer is not installed.""" xfail_type = str | tuple[str, type[Exception]] @@ -238,12 +242,14 @@ def parametrize( strict: bool = True, flakies: dict[str, int] | None = None, ) -> Decorator: - """ - Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality - - test_data is expected as a dict of (id, test_data) pairs - - alllows to specifiy a dict of (id, failure_reason) pairs to mark specific tests as xfail. - Failure_reason can be str, type[Exception], or tuple[str, type[Exception]]. - Strings set the reason for failure, the exception type sets expected error. + """Custom version of pytest.mark.parametrize with some syntatic sugar and + added xfail functionality. + + - test_data is expected as a dict of (id, test_data) pairs + - alllows to specifiy a dict of (id, failure_reason) pairs to mark specific tests as xfail. + Failure_reason can be str, type[Exception], or tuple[str, type[Exception]]. + Strings set the reason for failure, the exception type sets expected error. + """ if xfails is None: xfails = {} @@ -253,7 +259,9 @@ def parametrize( flakies = {} def decorator_func(func: Callable[_P, _R]) -> Callable[_P, _R]: - """Test data is transformed from a dict of (id, data) pairs to a list of pytest params to work with the native pytests parametrize function""" + """Test data is transformed from a dict of (id, data) pairs to a list of + pytest params to work with the native pytests parametrize function. + """ pytest_testsuite = [] for id, test_parameters in test_data.items(): if id in flakies: diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py index b52091b833a..351d6de7a09 100644 --- a/backends/arm/test/conftest.py +++ b/backends/arm/test/conftest.py @@ -64,10 +64,10 @@ def pytest_sessionfinish(session, exitstatus): @pytest.fixture(autouse=True) def set_random_seed(): - """ - Control random numbers in Arm test suite. Default behavior is to use a fixed - seed (0), which ensures reproducible tests. Use the env variable ARM_TEST_SEED - to set a custom seed, or set it to RANDOM for random seed behavior. + """Control random numbers in Arm test suite. Default behavior is to use a + fixed seed (0), which ensures reproducible tests. Use the env variable + ARM_TEST_SEED to set a custom seed, or set it to RANDOM for random seed + behavior. Examples: As default use fixed seed (0) for reproducible tests @@ -76,6 +76,7 @@ def set_random_seed(): ARM_TEST_SEED=RANDOM pytest --config-file=/dev/null --verbose -s --color=yes backends/arm/test/ops/test_avg_pool.py -k Rerun with a specific seed ARM_TEST_SEED=3478246 pytest --config-file=/dev/null --verbose -s --color=yes backends/arm/test/ops/test_avg_pool.py -k + """ import torch @@ -100,12 +101,12 @@ def set_random_seed(): def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: - """ - Returns whether an option is successfully enabled, i.e. if the flag was + """Returns whether an option is successfully enabled, i.e. if the flag was given to pytest and the necessary requirements are available. - The optional parameter 'fail_if_not_enabled' makes the function raise - a RuntimeError instead of returning False. + The optional parameter 'fail_if_not_enabled' makes the function raise a + RuntimeError instead of returning False. + """ if hasattr(pytest, "_test_options") and option in pytest._test_options and pytest._test_options[option]: # type: ignore[attr-defined] @@ -118,11 +119,11 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: def get_option(option: str) -> Any | None: - """ - Returns the value of an pytest option if it is set, otherwise None. + """Returns the value of an pytest option if it is set, otherwise None. Args: option (str): The option to check for. + """ if option in pytest._test_options: # type: ignore[attr-defined] return pytest._test_options[option] # type: ignore[attr-defined] diff --git a/backends/arm/test/misc/test_dim_order.py b/backends/arm/test/misc/test_dim_order.py index 14e12461652..4c36fcd9e89 100644 --- a/backends/arm/test/misc/test_dim_order.py +++ b/backends/arm/test/misc/test_dim_order.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -21,9 +21,8 @@ class ChannelsLastInput(torch.nn.Module): - """ - Test a complex case with (channels last, channels first) input, - and (channels first, channels last) output. + """Test a complex case with (channels last, channels first) input, and + (channels first, channels last) output. """ inputs: input_t1 = ( @@ -39,9 +38,7 @@ def forward(self, x, y): class ChannelsFirstOutput(torch.nn.Module): - """ - Test coverting to channels_first inside the delegate. - """ + """Test coverting to channels_first inside the delegate.""" inputs: input_t1 = ( torch.arange(1, 25, dtype=torch.float32) @@ -55,9 +52,7 @@ def forward(self, x): class ChannelsLastOutput(torch.nn.Module): - """ - Test changing of dim_order inside the delegate. - """ + """Test changing of dim_order inside the delegate.""" inputs: input_t1 = (torch.arange(1, 9, dtype=torch.float32).reshape((1, 2, 2, 2)),) @@ -68,8 +63,8 @@ def forward(self, x): class ChannelsLastInsidePartition(torch.nn.Module): - """ - Test dim_order changes inside the partiton, but no dim_order changes at input/output. + """Test dim_order changes inside the partiton, but no dim_order changes at + input/output. """ inputs: input_t1 = (torch.randn((1, 2, 3, 3)),) diff --git a/backends/arm/test/misc/test_extract_io_params_tosa.py b/backends/arm/test/misc/test_extract_io_params_tosa.py index 229970b2be0..cd1a6e37d43 100644 --- a/backends/arm/test/misc/test_extract_io_params_tosa.py +++ b/backends/arm/test/misc/test_extract_io_params_tosa.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -45,9 +45,7 @@ def test_roundtrip_extracts_io_params_tosa_INT( quantizer_cls, partitioner_cls, ): - """ - Validates that IO quantization parameters round-trip for both flows. - """ + """Validates that IO quantization parameters round-trip for both flows.""" example_inputs = ( torch.ones(1, 5), torch.full((1, 5), 2.0), diff --git a/backends/arm/test/misc/test_non_persistent_buffers.py b/backends/arm/test/misc/test_non_persistent_buffers.py index 374eb0a57d6..50648e5867d 100644 --- a/backends/arm/test/misc/test_non_persistent_buffers.py +++ b/backends/arm/test/misc/test_non_persistent_buffers.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -14,9 +14,7 @@ class NonPersistentBuffer(nn.Module): - """ - Min code version registering a non-persistent input buffer. - """ + """Min code version registering a non-persistent input buffer.""" def __init__(self): super().__init__() @@ -33,17 +31,15 @@ def forward(self, x): @parametrize("test_data", test_input) def test_non_persistent_buffer_tosa_FP(test_data: input_t): - """ - Test validates Arm backend handling of non-persistent buffers - and ensures that there are no asserts or errors when they are used. + """Test validates Arm backend handling of non-persistent buffers and ensures + that there are no asserts or errors when they are used. """ TosaPipelineFP[input_t](NonPersistentBuffer(), test_data, "").run() @parametrize("test_data", test_input) def test_non_persistent_buffer_tosa_INT(test_data: input_t): - """ - Test validates Arm backend handling of non-persistent buffers - and ensures that there are no asserts or errors when they are used. + """Test validates Arm backend handling of non-persistent buffers and ensures + that there are no asserts or errors when they are used. """ TosaPipelineINT[input_t](NonPersistentBuffer(), test_data, "").run() diff --git a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py index 0514ad5e280..974f3dd3349 100644 --- a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py +++ b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,7 +48,11 @@ class SoftplusModule(torch.nn.Module): - """Module containing an addition followed by a Softplus. Softplus is currently not supported by TosaBackend.""" + """Module containing an addition followed by a Softplus. + + Softplus is currently not supported by TosaBackend. + + """ def __init__(self): super().__init__() @@ -59,8 +63,11 @@ def forward(self, x: torch.Tensor): class LinearResidualModule(torch.nn.Module): - """Module containing a residual and a linear layer followed by GELU and a Dropout. + """Module containing a residual and a linear layer followed by GELU and a + Dropout. + GELU is currently not supported by TosaBackend nor TosaQuantizer. + """ def __init__( diff --git a/backends/arm/test/misc/test_qat_training_loop.py b/backends/arm/test/misc/test_qat_training_loop.py index 425849b8564..6a45ac09489 100644 --- a/backends/arm/test/misc/test_qat_training_loop.py +++ b/backends/arm/test/misc/test_qat_training_loop.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,8 +46,12 @@ def evaluate_model(model, inputs, expected_outputs): def test_qat_training_loop_tosa_INT(): """Test the QAT training loop with a simple MLP model. - This function creates a simple MLP model, prepares it for QAT, runs a training loop, - and evaluates the quantized model to make sure everything works as expected.""" + + This function creates a simple MLP model, prepares it for QAT, runs a + training loop, and evaluates the quantized model to make sure everything + works as expected. + + """ model = MLP() logger.info("Starting training loop test") diff --git a/backends/arm/test/misc/test_quant_custom_meta.py b/backends/arm/test/misc/test_quant_custom_meta.py index 59156c2fd57..cd9964f4511 100644 --- a/backends/arm/test/misc/test_quant_custom_meta.py +++ b/backends/arm/test/misc/test_quant_custom_meta.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -19,11 +19,13 @@ def forward(self, x, y): @pytest.mark.parametrize("fp_extension", [True, False]) def test_qdq_squeezed_fp_op_tosa_INT_FP(fp_extension: bool): - """Test that a float operation surrounded by quantize-dequantize pairs - is correctly handled by the partitioner and the TOSA backend. + """Test that a float operation surrounded by quantize-dequantize pairs is + correctly handled by the partitioner and the TOSA backend. + Pattern: q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q |_____unquantized_____| + """ aten_op = "torch.ops.aten.add.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" @@ -67,11 +69,13 @@ def forward(self, x, y): @pytest.mark.parametrize("fp_extension", [True, False]) def test_quantized_to_float_transition_tosa_INT_FP(fp_extension: bool): - """Test that a model executing quantized ops followed by float ops - is correctly handled by the partitioner and the TOSA backend. + """Test that a model executing quantized ops followed by float ops is + correctly handled by the partitioner and the TOSA backend. + Pattern: q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv |___unquantized___| + """ aten_op = "torch.ops.aten.add.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" diff --git a/backends/arm/test/misc/test_shared_qspecs.py b/backends/arm/test/misc/test_shared_qspecs.py index 7a27727e4a6..8324209f5b2 100644 --- a/backends/arm/test/misc/test_shared_qspecs.py +++ b/backends/arm/test/misc/test_shared_qspecs.py @@ -34,8 +34,8 @@ def forward(self, x, y): def _get_quantizer() -> TOSAQuantizer: - """ - Returns a TOSAQuantizer configured for int8 quantization with SubOp unquantized. + """Returns a TOSAQuantizer configured for int8 quantization with SubOp + unquantized. """ quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) quantizer.set_global(_QUANT_CONFIG_INT8) @@ -114,7 +114,9 @@ def forward(self, x): class SharedQspecInputForkNonShared(torch.nn.Module): - """Shared qspec cluster with an input fork with both inputs as non-shared qspecs.""" + """Shared qspec cluster with an input fork with both inputs as non-shared + qspecs. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 4}, @@ -137,7 +139,9 @@ def forward(self, x, y): class SharedQspecInputForkShared(torch.nn.Module): - """Shared qspec cluster with an input fork with both inputs as shared qspecs.""" + """Shared qspec cluster with an input fork with both inputs as shared + qspecs. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 5}, @@ -162,7 +166,9 @@ def forward(self, x, y): class SharedQspecInputForkXShared(torch.nn.Module): - """Shared qspec cluster with an input fork with left input as shared qspec.""" + """Shared qspec cluster with an input fork with left input as shared + qspec. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 4}, @@ -186,7 +192,9 @@ def forward(self, x, y): class SharedQspecInputForkYShared(torch.nn.Module): - """Shared qspec cluster with an input fork with right input as shared qspec.""" + """Shared qspec cluster with an input fork with right input as shared + qspec. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 5}, @@ -210,7 +218,9 @@ def forward(self, x, y): class SharedQspecInputForkXConstant(torch.nn.Module): - """Shared qspec cluster with an input fork with left input as global constant.""" + """Shared qspec cluster with an input fork with left input as global + constant. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 2}, @@ -233,7 +243,9 @@ def forward(self, x): class SharedQspecInputForkYConstant(torch.nn.Module): - """Shared qspec cluster with an input fork with left input as local constant.""" + """Shared qspec cluster with an input fork with left input as local + constant. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 2}, @@ -255,7 +267,9 @@ def forward(self, x): class SharedQspecOutputForkNonShared(torch.nn.Module): - """Shared qspec cluster with an output fork with both outputs as non-shared qspecs.""" + """Shared qspec cluster with an output fork with both outputs as non-shared + qspecs. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 3}, @@ -282,7 +296,9 @@ def forward(self, x): class SharedQspecOutputForkShared(torch.nn.Module): - """Shared qspec cluster with an output fork with both outputs as shared qspecs.""" + """Shared qspec cluster with an output fork with both outputs as shared + qspecs. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 4}, @@ -307,7 +323,9 @@ def forward(self, x): class SharedQspecManyForks(torch.nn.Module): - """Shared qspec cluster with a number of forks to test more complex structures.""" + """Shared qspec cluster with a number of forks to test more complex + structures. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 6}, @@ -334,7 +352,9 @@ def forward(self, x): class SharedQspecSurroundedQuantizedOp(torch.nn.Module): - """An annotated int8 surrounded by a shared qspec cluster forcing input/output qparams to be equal.""" + """An annotated int8 surrounded by a shared qspec cluster forcing + input/output qparams to be equal. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 4}, @@ -360,7 +380,7 @@ def forward(self, x): class SharedQspecSurroundedQuantizedOpConstant(torch.nn.Module): - """ """ + """""" qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 5}, @@ -389,7 +409,7 @@ def forward(self, x): class SharedQspecSub(torch.nn.Module): - """A shared qspec node with float input""" + """A shared qspec node with float input.""" qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 2}, @@ -416,7 +436,9 @@ def forward(self, x, y): class SharedQspecCompetingQspecs(torch.nn.Module): - """A shared qspec node with per-channel/per-tensor annotated nodes as inputs""" + """A shared qspec node with per-channel/per-tensor annotated nodes as + inputs. + """ qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 3}, @@ -456,7 +478,7 @@ def forward(self, x): class SharedQspecNoQspecs(torch.nn.Module): - """A shared qspec node with float input/outputs""" + """A shared qspec node with float input/outputs.""" qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 2}, @@ -500,7 +522,7 @@ def forward(self, x): class MixedMaximumInt8Int16(torch.nn.Module): - """A shared qspec node with int16/int8 inputs""" + """A shared qspec node with int16/int8 inputs.""" qspecs = { "quantized_decomposed.quantize_per_tensor.default": {None: 6}, diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py index dafe1169394..7170c7ed4d1 100644 --- a/backends/arm/test/misc/test_tosa_spec.py +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -71,7 +71,7 @@ class TestTosaSpecification(unittest.TestCase): - """Tests the TOSA specification class""" + """Tests the TOSA specification class.""" @parameterized.expand(test_valid_strings) # type: ignore[misc] def test_version_string_no_target(self, version_string: str, expected_type): @@ -125,7 +125,7 @@ def test_supports_new_1_1_extensions_no_target(self): class TestTosaSpecMapping(unittest.TestCase): - """Tests the TosaSpecMapping class""" + """Tests the TosaSpecMapping class.""" def test_mapping_no_target(self): mapping = TosaSpecMapping() diff --git a/backends/arm/test/models/stable_diffusion/stable_diffusion_module_test_configs.py b/backends/arm/test/models/stable_diffusion/stable_diffusion_module_test_configs.py index 86e945311c7..89f6e7c13ed 100644 --- a/backends/arm/test/models/stable_diffusion/stable_diffusion_module_test_configs.py +++ b/backends/arm/test/models/stable_diffusion/stable_diffusion_module_test_configs.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,7 +18,6 @@ from transformers import CLIPTextConfig, T5Config - """ This file defines test configs used to initialize Stable Diffusion module tests. Module tests in the same directory will import these configs. diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index ff159e7f8ea..14201ed54e4 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -2,7 +2,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Tests 10 popular torch.nn.functional not tested in other ways or training related. diff --git a/backends/arm/test/models/test_nn_modules.py b/backends/arm/test/models/test_nn_modules.py index 733e1fc8986..c3aaa61799b 100644 --- a/backends/arm/test/models/test_nn_modules.py +++ b/backends/arm/test/models/test_nn_modules.py @@ -2,7 +2,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Tests 10 popular nn modules not tested in other ways or training-related. - Embedding diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 03271ce1246..0ca8d3ac091 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -2,7 +2,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Tests 10 popular torch ops, not tested in other ways, training related or requiring randomness. diff --git a/backends/arm/test/ops/test_div_tensor_mode.py b/backends/arm/test/ops/test_div_tensor_mode.py index 88fe151c69f..d9d058fccc6 100644 --- a/backends/arm/test/ops/test_div_tensor_mode.py +++ b/backends/arm/test/ops/test_div_tensor_mode.py @@ -19,7 +19,9 @@ class DivTensorModeFloat(torch.nn.Module): - """torch.div(x, y, rounding_mode=mode) with mode in {None, "floor", "trunc"}.""" + """torch.div(x, y, rounding_mode=mode) with mode in {None, "floor", + "trunc"}. + """ aten_ops = ["aten.div.Tensor_mode"] aten_ops_int = ["aten.mul.Tensor", "aten.reciprocal.default"] diff --git a/backends/arm/test/ops/test_max_pool1d.py b/backends/arm/test/ops/test_max_pool1d.py index 4d75cd8529a..4c7e9006555 100644 --- a/backends/arm/test/ops/test_max_pool1d.py +++ b/backends/arm/test/ops/test_max_pool1d.py @@ -4,15 +4,14 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -""" -Tests for the max_pool1d operation. +"""Tests for the max_pool1d operation. In PyTorch, max_pool1d may be decomposed internally into a sequence of operations (e.g., unsqueeze -> max_pool2d_with_indices -> getitem -> squeeze), but this test focuses on ensuring that the max_pool1d aten op is correctly -lowered/quantized and delegated to the expected edge dialect op on the -Arm backend (U55/U85). +lowered/quantized and delegated to the expected edge dialect op on the Arm +backend (U55/U85). + """ from typing import Callable, Tuple diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index 9854772f849..d1fbefad745 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -2,20 +2,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - - -from typing import Tuple - -import pytest - -import torch - -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.test_pipeline import ( - TosaPipelineFP, - TosaPipelineINT, -) - """Summary of non-working cases. FP: @@ -32,6 +18,19 @@ Sub or inplace-sub with an integer input. """ + +from typing import Tuple + +import pytest + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, +) + input_t1 = Tuple[torch.Tensor, torch.scalar_tensor] # Input x, Input y diff --git a/backends/arm/test/ops/test_sdpa.py b/backends/arm/test/ops/test_sdpa.py index 4ed68daca04..0ffab82edbc 100644 --- a/backends/arm/test/ops/test_sdpa.py +++ b/backends/arm/test/ops/test_sdpa.py @@ -106,9 +106,13 @@ def test_sdpa_vgf_quant(test_case: test_case_t): @common.parametrize("test_case", test_suite) def test_sdpa_u55_INT(test_case: test_case_t): - """Verify SDPA compiles on U55. _safe_softmax from SDPA is skipped by - DecomposeSoftmaxPass (skip_safe_softmax=True for U55) and runs on CPU, - avoiding REDUCE_MAX which fails Vela compilation.""" + """Verify SDPA compiles on U55. + + _safe_softmax from SDPA is skipped by DecomposeSoftmaxPass + (skip_safe_softmax=True for U55) and runs on CPU, avoiding REDUCE_MAX which + fails Vela compilation. + + """ model, test_input = test_case() pipeline = EthosU55PipelineINT[input_t](model, test_input, [], []) pipeline.pop_stage("check.quant_nodes") @@ -120,8 +124,12 @@ def test_sdpa_u55_INT(test_case: test_case_t): @common.parametrize("test_case", test_suite) @common.XfailIfNoCorstone320 def test_sdpa_u85_INT(test_case: test_case_t): - """Verify SDPA compiles on U85. _safe_softmax is decomposed with stable - softmax (including amax/REDUCE_MAX) which is supported on U85.""" + """Verify SDPA compiles on U85. + + _safe_softmax is decomposed with stable softmax (including amax/REDUCE_MAX) + which is supported on U85. + + """ model, test_input = test_case() pipeline = EthosU85PipelineINT[input_t](model, test_input, [], []) pipeline.pop_stage("check.quant_nodes") diff --git a/backends/arm/test/passes/test_decompose_int16_activation_conv_pass.py b/backends/arm/test/passes/test_decompose_int16_activation_conv_pass.py index 834bc7f7663..dd3f742cf84 100644 --- a/backends/arm/test/passes/test_decompose_int16_activation_conv_pass.py +++ b/backends/arm/test/passes/test_decompose_int16_activation_conv_pass.py @@ -1,16 +1,15 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -""" -Tests for DecomposeConvWithInt16ActivationPass. +"""Tests for DecomposeConvWithInt16ActivationPass. This pass decomposes convolution with int16 activation and bias into: - A convolution without bias - A rescale to int32 - An add with the reshaped bias - A rescale back to the output dtype + """ from typing import Tuple @@ -107,9 +106,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def test_decompose_int16_conv_pass_fp32_no_decomposition() -> None: - """ - Test that DecomposeConvWithInt16ActivationPass does NOT decompose +def test_decompose_conv_with_int16_activation_no_target_fp32_no_decomposition() -> None: + """Test that DecomposeConvWithInt16ActivationPass does NOT decompose convolution when using FP32 (no quantization). """ module = Conv2dWithBias() @@ -137,11 +135,12 @@ def test_decompose_int16_conv_pass_fp32_no_decomposition() -> None: exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default" -def test_conv2d_int16_e2e_tosa_single_conv() -> None: - """ - End-to-end test for conv2d with INT16 quantization using TOSA pipeline. - This validates the full lowering path including the decomposition pass - for a single convolution with bias. +def test_decompose_conv_with_int16_activation_tosa_INT_single_conv() -> None: + """End-to-end test for conv2d with INT16 quantization using TOSA pipeline. + + This validates the full lowering path including the decomposition pass for a + single convolution with bias. + """ module = Conv2dWithBias() pipeline = TosaPipelineINT[input_t]( @@ -154,11 +153,12 @@ def test_conv2d_int16_e2e_tosa_single_conv() -> None: pipeline.run() -def test_conv2d_int16_e2e_tosa_multiple_convs() -> None: - """ - End-to-end test for conv2d with INT16 quantization using TOSA pipeline. - This validates the full lowering path including the decomposition pass - for multiple convolutions with bias. +def test_decompose_conv_with_int16_activation_tosa_INT_multiple_convs() -> None: + """End-to-end test for conv2d with INT16 quantization using TOSA pipeline. + + This validates the full lowering path including the decomposition pass for + multiple convolutions with bias. + """ module = Conv2dMultipleConvs() pipeline = TosaPipelineINT[input_t]( @@ -171,10 +171,11 @@ def test_conv2d_int16_e2e_tosa_multiple_convs() -> None: pipeline.run() -def test_conv2d_int16_e2e_tosa_without_bias() -> None: - """ - End-to-end test for conv2d without bias with INT16 quantization. +def test_decompose_conv_with_int16_activation_tosa_INT_without_bias() -> None: + """End-to-end test for conv2d without bias with INT16 quantization. + This validates that convolutions without bias don't get decomposed. + """ module = Conv2dWithoutBias() pipeline = TosaPipelineINT[input_t]( @@ -187,10 +188,11 @@ def test_conv2d_int16_e2e_tosa_without_bias() -> None: pipeline.run() -def test_conv2d_int8_e2e_tosa() -> None: - """ - End-to-end test for conv2d with INT8 quantization using TOSA pipeline. +def test_decompose_conv_with_int16_activation_tosa_INT_int8() -> None: + """End-to-end test for conv2d with INT8 quantization using TOSA pipeline. + This validates that INT8 activations don't trigger the decomposition. + """ module = Conv2dWithBias() pipeline = TosaPipelineINT[input_t]( diff --git a/backends/arm/test/passes/test_decompose_softmax_pass.py b/backends/arm/test/passes/test_decompose_softmax_pass.py index 8e6b52d78bc..f208f07b6fd 100644 --- a/backends/arm/test/passes/test_decompose_softmax_pass.py +++ b/backends/arm/test/passes/test_decompose_softmax_pass.py @@ -107,8 +107,9 @@ def __init__(self, **kwargs): def test_decompose_softmax_tosa_FP_skip_safe_softmax(): - """Verify skip_safe_softmax=True still decomposes regular softmax - using the stable algorithm (with amax and sub).""" + """Verify skip_safe_softmax=True still decomposes regular softmax using the + stable algorithm (with amax and sub). + """ module = Softmax() pipeline = PassPipeline[input_t]( module, diff --git a/backends/arm/test/quantizer/test_partial_quantization.py b/backends/arm/test/quantizer/test_partial_quantization.py index 5fe53d375ce..f8ba1d8d8d5 100644 --- a/backends/arm/test/quantizer/test_partial_quantization.py +++ b/backends/arm/test/quantizer/test_partial_quantization.py @@ -158,8 +158,8 @@ def test_disallow_tfa_for_two_skipped_modules_no_target(): def test_disallow_tfa_with_global_none_and_one_quantized_module_no_target(): """Ensure that with a global None quantization config, only the linear - module (with its own quantization config) is quantized, and that the - other nodes have `disallow_tfa` set. + module (with its own quantization config) is quantized, and that the other + nodes have `disallow_tfa` set. """ graph_after_quant_stage = _run_quantization_pipeline( diff --git a/backends/arm/test/quantizer/test_preserve_kwargs.py b/backends/arm/test/quantizer/test_preserve_kwargs.py index 8bfa912720f..f2ba835a481 100644 --- a/backends/arm/test/quantizer/test_preserve_kwargs.py +++ b/backends/arm/test/quantizer/test_preserve_kwargs.py @@ -17,7 +17,9 @@ class FullLike(torch.nn.Module): - """Since full_like is replaced with full, we only need to test on reference model, not FVP.""" + """Since full_like is replaced with full, we only need to test on reference + model, not FVP. + """ test_parameters = { "full_like_int_val": lambda: (torch.randn(2, 2, 2, 2) * 50, 3), diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 733783f2a09..273c35de56d 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -91,13 +91,13 @@ def __init__( def get_input_names(program: ExportedProgram) -> list[str]: - """ - Get a list[str] with the names of the inputs to this model. + """Get a list[str] with the names of the inputs to this model. Args: program (ExportedProgram): The program to get input names from. Returns: A list of strings with the names of the model input. + """ return [spec.arg.name for spec in program.graph_signature.input_specs] @@ -105,12 +105,14 @@ def get_input_names(program: ExportedProgram) -> list[str]: def get_input_quantization_params( program: ExportedProgram, ) -> list[QuantizationParams]: - """ - Get input QuantizationParams in a program, maximum one per input to the program. + """Get input QuantizationParams in a program, maximum one per input to the + program. + Args: program (ExportedProgram): The program to get input quantization parameters from. Returns: list[QuantizationParams]: The found quantization parameters. + """ quant_params = [] @@ -142,8 +144,8 @@ def get_input_quantization_params( def get_output_quantization_params( output_node: Node, ) -> dict[Node, QuantizationParams | None]: - """ - Get output QuantizationParams from a program. + """Get output QuantizationParams from a program. + Args: output_nodes (list(Node)): A list of output nodes to get output quantization parameters from. Returns: @@ -151,6 +153,7 @@ def get_output_quantization_params( If no quantization parameters were found, the entry is None. Raises: RuntimeError if no output quantization parameters are found. + """ quant_params: dict[Node, QuantizationParams | None] = {} for node in output_node.args[0]: # type: ignore[union-attr] @@ -207,7 +210,9 @@ def numpy_to_torch_tensor(array: np.ndarray, output_node: Node) -> torch.Tensor: class TosaReferenceModelDispatch(TorchFunctionMode): - """A context manager for executing call_delegate nodes using the reference model""" + """A context manager for executing call_delegate nodes using the reference + model. + """ def __init__(self): self.ran_tosa_dispatch = False @@ -376,6 +381,7 @@ def run_corstone( timeout: int = 120, # s ) -> list[torch.Tensor]: """Executes an inference of the exported_program on FVP. + Returns a list of tensors with the output. Args: `executorch_program_manager`: The executorch program to run. @@ -392,6 +398,7 @@ def run_corstone( Relies on the output tensors from the exported program to figure out the shape and dtype of the buffer that was output from the FVP. + """ exported_program = executorch_program_manager.exported_program() intermediate_path = Path(intermediate_path) @@ -552,7 +559,8 @@ def save_bytes( input_name: str, quant_param: Optional[QuantizationParams] = None, ) -> str: - """Serializes and saves 'data' in byte format, possibly quantizing it before. + """Serializes and saves 'data' in byte format, possibly quantizing it + before. Parameters: path: the directory where to save the data. @@ -561,6 +569,7 @@ def save_bytes( quant_param: the parameters to use for quantization. Returns: the full file path of the output. + """ data_np = prep_data_for_save(data, input_name, quant_param) file_path = os.path.join(path, input_name + ".bin") @@ -572,11 +581,11 @@ def save_bytes( def _run_cmd(cmd: List[str], check=True) -> subprocess.CompletedProcess[bytes]: - """ - Run a command and check for errors. + """Run a command and check for errors. Args: cmd (List[str]): The command to run as a list. + """ try: result = subprocess.run( # nosec B603 - cmd constructed from trusted inputs @@ -599,6 +608,7 @@ def _run_flatc(args: List[str]) -> None: If a resource matching _FLATC_RESOURCE_NAME exists, uses that executable. Otherwise, expects the `flatc` tool to be available on the system path. + """ flatc_resource = _resources.files(arm_test_package).joinpath(_FLATC_RESOURCE_NAME) if flatc_resource.is_file(): @@ -623,9 +633,11 @@ def _run_flatc(args: List[str]) -> None: def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: - """ - This function is used to dump the TOSA flatbuffer to a human readable - format, using flatc. It is used for debugging purposes. + """This function is used to dump the TOSA flatbuffer to a human readable + format, using flatc. + + It is used for debugging purposes. + """ tmp = tempfile.mkdtemp() diff --git a/backends/arm/test/test_memory_allocator_log.py b/backends/arm/test/test_memory_allocator_log.py index 3853b60b7f6..87b32391806 100644 --- a/backends/arm/test/test_memory_allocator_log.py +++ b/backends/arm/test/test_memory_allocator_log.py @@ -1,15 +1,15 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -Check log files for memory metrics and compare them against thresholds. +"""Check log files for memory metrics and compare them against thresholds. Usage example: python3 test_memory_allocator_log.py \ --log path/to/log.txt \ --require "Total SRAM used" "<= 310 KiB" \ --require "method_allocator_input" "<= 4 B" + """ import argparse diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index 6ba08fd4785..67b6a2cf19f 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -301,8 +301,8 @@ def dump_error_output( rtol: float = 1e-03, qtol: float = 0, ) -> None: - """ - Prints Quantization info and error tolerances, and saves the differing tensors to disc. + """Prints Quantization info and error tolerances, and saves the differing + tensors to disc. """ # Capture assertion error and print more info banner = "=" * 40 + "TOSA debug info" + "=" * 40 @@ -336,7 +336,7 @@ def dump_error_output( if __name__ == "__main__": - """This is expected to produce the example output of print_diff""" + """This is expected to produce the example output of print_diff.""" torch.manual_seed(0) a = torch.rand(3, 3, 2, 2) * 0.01 b = a.clone().detach() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 2336ccc6233..cbd682eefae 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -482,10 +482,10 @@ def run_method_and_compare_outputs( error_callbacks: Optional[Sequence[Callable[..., None]]] = None, run_eager_mode: bool = False, ): - """ - Compares the run_artifact output of 'stage' with the output of a reference stage. - If the model is quantized, the reference stage is the Quantize stage output. - Otherwise, the reference stage is the initial pytorch module. + """Compares the run_artifact output of 'stage' with the output of a + reference stage. If the model is quantized, the reference stage is the + Quantize stage output. Otherwise, the reference stage is the initial + pytorch module. Asserts that the outputs are equal (within tolerances). Returns self to allow the function to be run in a test chain. @@ -495,6 +495,7 @@ def run_method_and_compare_outputs( The default is the latest run stage. inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data. The default is random data. + """ # backward-compatible ordering (accept inputs as the first positional argument) @@ -730,8 +731,7 @@ def check_quantization_annotation( input_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None, output_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None, ): - """ - Check the quantization annotations in the graph of a quantized model. + """Check the quantization annotations in the graph of a quantized model. Args: quantization_annotations: A dictionary mapping operator names to a dictionary of @@ -743,6 +743,7 @@ def check_quantization_annotation( If None, the check is skipped. Returns self for daisy-chaining. + """ if not self.is_quantized(): raise RuntimeError( @@ -783,13 +784,13 @@ def dump_operator_distribution( print_table: bool = True, include_dtypes: bool = True, ): - """Dump the distribution of operators in the current stage. - In the partition stage, additional information is included such as the number of - delegates and the distribution of TOSA operators. - Set parameter print_table to False to dump in a parseable format. - + """Dump the distribution of operators in the current stage. In the + partition stage, additional information is included such as the number + of delegates and the distribution of TOSA operators. Set parameter + print_table to False to dump in a parseable format. Returns self for daisy-chaining. + """ line = "#" * 10 to_print = f"\n{line} {self.cur} Operator Distribution {line}\n" @@ -868,10 +869,12 @@ def dump_operator_distribution( def dump_dtype_distribution( self, path_to_dump: Optional[str] = None, print_table: bool = True ): - """Dump a the distributions of dtypes of nodes and placeholders in the current stage. - Set parameter print_table to False to dump in a parseable format. + """Dump a the distributions of dtypes of nodes and placeholders in the + current stage. Set parameter print_table to False to dump in a parseable + format. Returns self for daisy-chaining. + """ line = "#" * 10 @@ -915,11 +918,12 @@ def run_transform_for_annotation_pipeline( """Run transform_for_annotation_pipeline on exported program to ensure passes do not break the initial model before quantization. - There are caveats to this however. As we register buffers to the graph modules - the resulting exported graph can fail. Use this only to compare numerical correctness - in eager mode. + There are caveats to this however. As we register buffers to the graph + modules the resulting exported graph can fail. Use this only to compare + numerical correctness in eager mode. Returns exported program with passes applied. + """ if stage is None: @@ -1038,7 +1042,10 @@ def _get_dtype_distribution( graph: Graph, tosa_spec: TosaSpecification ) -> tuple[Counter[str], Counter[str]]: """Counts the occurences of placeholder and call_function dtypes in a graph. - The result is a tuple of Counters (placeholder_distribution, call_function_distribution) + + The result is a tuple of Counters (placeholder_distribution, + call_function_distribution) + """ placeholder_dtypes: list[str] = [] call_function_dtypes: list[str] = [] @@ -1054,7 +1061,9 @@ def _get_dtype_distribution( def _get_operator_distribution(graph: Graph) -> List[Tuple[str, int]]: """Counts the occurences of operator names in a graph. + The result is a sorted list [('operator name':'number of nodes')] + """ return sorted( Counter( @@ -1069,7 +1078,9 @@ def _get_operator_distribution(graph: Graph) -> List[Tuple[str, int]]: def _get_operator_dtype_distribution(graph: Graph) -> List[Tuple[Tuple[str, str], int]]: """Counts the occurences of operator names and dtype pairs in a graph. + The result is a sorted list[(('operator name','dtype'),'number of nodes')] + """ target_dtype_pairs = [] for node in graph.nodes: @@ -1104,7 +1115,9 @@ def _get_tosa_operator_distribution( ) -> list[Tuple[str, int]] | list[Tuple[Tuple[str, str], int]]: """Counts the occurences of operator names of all lowered modules containing a TOSA flatbuffer. + The result is a string with the operator distribution or an error message. + """ id = 0 unknown_dtype_str = "UNKNOWN" diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 5420875bbda..4e060919738 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -51,7 +51,11 @@ logger = logging.getLogger(__name__) T = TypeVar("T", bound=Tuple[Any, ...]) -""" Generic type used for test data in the pipeline. Depends on which type the operator expects.""" +"""Generic type used for test data in the pipeline. + +Depends on which type the operator expects. + +""" def _has_quantizable_inputs(test_data: T) -> bool: @@ -87,10 +91,10 @@ def update(self, *args, **kwargs): class BasePipeline(Generic[T]): - """ - The BasePipeline defines a list of stages to be applied to a torch.nn.module for lowering it - in the Arm backend. To be inherited and adjusted for particular targets. Importantly, the - pipeline list can be modified before running the pipeline to support various pipeline extensions + """The BasePipeline defines a list of stages to be applied to a + torch.nn.module for lowering it in the Arm backend. To be inherited and + adjusted for particular targets. Importantly, the pipeline list can be + modified before running the pipeline to support various pipeline extensions and debugging usecases. Attributes: @@ -104,6 +108,7 @@ class BasePipeline(Generic[T]): tester.to_edge_transform_and_lower() or tester.to_edge().check(exir_ops).partition() + """ @staticmethod @@ -157,14 +162,15 @@ def __init__( self.add_stage(self.tester.to_executorch) def add_stage(self, func: Callable, *args, **kwargs): - """ - Adds a stage defined by a function with args and kwargs. By default appends to the pipeline. - For stages which may be added multiple times to a pipeline, s.a. checks and debug stages, - a suffix is appended with a dot to make sure every id is unique, e.g. check becomes check.0 + """Adds a stage defined by a function with args and kwargs. By default + appends to the pipeline. For stages which may be added multiple times to + a pipeline, s.a. checks and debug stages, a suffix is appended with a + dot to make sure every id is unique, e.g. check becomes check.0. Special kwargs: pos : specifies position in pipeline to add stage at. suffix : specifies a custom suffix to identify non unique stages, instead of a number. + """ pipeline_length = len(self._stages) @@ -238,7 +244,7 @@ def quantizer(self) -> TOSAQuantizer: ) def pop_stage(self, identifier: int | str): - """Removes and returns the stage at postion pos""" + """Removes and returns the stage at postion pos.""" if isinstance(identifier, int): stage = self._stages.pop(identifier) elif isinstance(identifier, str): @@ -369,8 +375,8 @@ def run(self): class TosaPipelineINT(TOSAPipeline, Generic[T]): - """ - Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model. + """Lowers a graph to INT TOSA spec (with quantization) and tests it with the + TOSA reference model. Attributes: module: The module which the pipeline is applied to. @@ -395,6 +401,7 @@ class TosaPipelineINT(TOSAPipeline, Generic[T]): tosa_version: TOSA version string to target. tosa_extensions: Optional list of TOSA extensions. epsilon: Epsilon used in quantization configuration. + """ def __init__( @@ -515,8 +522,8 @@ def __init__( class TosaPipelineFP(TOSAPipeline, Generic[T]): - """ - Lowers a graph to FP TOSA spec and tests it with the TOSA reference model. + """Lowers a graph to FP TOSA spec and tests it with the TOSA reference + model. Attributes: module: The module which the pipeline is applied to. @@ -532,6 +539,7 @@ class TosaPipelineFP(TOSAPipeline, Generic[T]): options. use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module. custom_path : Path to dump intermediate artifacts such as tosa and pte to. + """ def __init__( @@ -679,8 +687,8 @@ def __init__( class EthosU55PipelineINT(EthosUPipelineINTBase, Generic[T]): - """ - Lowers a graph to u55 INT TOSA spec and tests it on the Corstone300 FVP, if run_on_fvp is true. + """Lowers a graph to u55 INT TOSA spec and tests it on the Corstone300 FVP, + if run_on_fvp is true. Attributes: module: The module which the pipeline is applied to. @@ -692,6 +700,7 @@ class EthosU55PipelineINT(EthosUPipelineINTBase, Generic[T]): run_on_fvp: Set to true to test the pte file on a fvp simulator. use_edge_to_transform_and_lower: Selects between two possible ways of lowering the module. custom_path : Path to dump intermediate artifacts such as tosa and pte to. + """ def __init__( @@ -735,8 +744,8 @@ def __init__( class EthosU85PipelineINT(EthosUPipelineINTBase, Generic[T]): - """ - Lowers a graph to u85 INT TOSA spec and tests it on the Corstone320 FVP, if run_on_fvp is true. + """Lowers a graph to u85 INT TOSA spec and tests it on the Corstone320 FVP, + if run_on_fvp is true. Attributes: module: The module which the pipeline is applied to. @@ -748,6 +757,7 @@ class EthosU85PipelineINT(EthosUPipelineINTBase, Generic[T]): run_on_fvp: Set to true to test the pte file on a fvp simulator. use_edge_to_transform_and_lower: Selects between two possible ways of lowering the module. custom_path : Path to dump intermediate artifacts such as tosa and pte to. + """ def __init__( @@ -791,8 +801,8 @@ def __init__( class PassPipeline(TOSAPipeline, Generic[T]): - """ - Runs single passes directly on an edge_program and checks operators before/after. + """Runs single passes directly on an edge_program and checks operators + before/after. Attributes: module: The module which the pipeline is applied to. @@ -811,6 +821,7 @@ class PassPipeline(TOSAPipeline, Generic[T]): Passes are run in order pass_list -> pass_functions -> passes_with_exported_program. See arm_tester.RunPasses() for more information. + """ def __init__( @@ -889,8 +900,8 @@ def run(self): class TransformAnnotationPassPipeline(TOSAPipeline, Generic[T]): - """ - Runs transform_for_annotation_pipeline passes directly on an exported program and checks output. + """Runs transform_for_annotation_pipeline passes directly on an exported + program and checks output. Attributes: module: The module which the pipeline is applied to. @@ -945,9 +956,8 @@ def __init__( class QuantizationPipeline(TOSAPipeline, Generic[T]): - """ - Runs quantization and checks that appropriate nodes are annotated with an expected - quantization-spec. + """Runs quantization and checks that appropriate nodes are annotated with an + expected quantization-spec. Attributes: module: The module which the pipeline is applied to. @@ -1002,9 +1012,8 @@ def __init__( class OpNotSupportedPipeline(TOSAPipeline, Generic[T]): - """ - Runs the partitioner on a module and checks that ops are not delegated to test - SupportedTOSAOperatorChecks. + """Runs the partitioner on a module and checks that ops are not delegated to + test SupportedTOSAOperatorChecks. Attributes: module: The module which the pipeline is applied to. @@ -1014,6 +1023,7 @@ class OpNotSupportedPipeline(TOSAPipeline, Generic[T]): non_delegated_ops : Exir ops expected not to be delegated. n_expected_delegates : Number of delegate calls (0 in the usual case). custom_path : Path to dump intermediate artifacts such as tosa and pte to. + """ def __init__( @@ -1069,8 +1079,8 @@ def __init__( class VgfPipeline(BasePipeline, Generic[T]): - """ - Lowers a graph based on TOSA spec (with or without quantization) and converts TOSA to VFG. + """Lowers a graph based on TOSA spec (with or without quantization) and + converts TOSA to VFG. Attributes: module: The module which the pipeline is applied to. @@ -1089,6 +1099,7 @@ class VgfPipeline(BasePipeline, Generic[T]): use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module. custom_path : Path to dump intermediate artifacts such as tosa and pte to. + """ def __init__(