From 7a91e7cc9bce196e54e51f4dd089740fc480d27d Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Mon, 2 Mar 2026 14:38:31 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - Debugger Converge Phase 1: Introduce debug handle --- backends/qualcomm/_passes/__init__.py | 2 + backends/qualcomm/_passes/qnn_pass_manager.py | 5 ++ .../qualcomm/_passes/resolve_debug_handle.py | 47 +++++++++++++++ backends/qualcomm/_passes/utils.py | 4 ++ backends/qualcomm/qnn_preprocess.py | 21 +++++-- backends/qualcomm/tests/test_passes.py | 59 ++++++++++++++++++- 6 files changed, 133 insertions(+), 5 deletions(-) create mode 100644 backends/qualcomm/_passes/resolve_debug_handle.py diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 83e4e9bad37..414d4aa7965 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -48,6 +48,7 @@ from .remove_redundancy import RemoveRedundancy from .replace_arange_args import ReplaceArangeArgs from .replace_inf_values import ReplaceInfValues +from .resolve_debug_handle import ResolveDebugHandle from .seq_mse import SeqMSE from .tag_quant_io import TagQuantIO @@ -96,6 +97,7 @@ RemoveRedundancy, ReplaceArangeArgs, ReplaceInfValues, + ResolveDebugHandle, SeqMSE, TagQuantIO, ] diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 5f4168c1770..a0b6d5e0781 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -53,6 +53,7 @@ RemoveRedundancy, ReplaceArangeArgs, ReplaceInfValues, + ResolveDebugHandle, TagQuantIO, ) from executorch.backends.qualcomm._passes.utils import ( @@ -107,6 +108,7 @@ def get_capture_program_passes(): (Remove0DTensor, True), (RemoveRedundancy, True), (TagQuantIO, False), + (ResolveDebugHandle, True), ] passes = OrderedDict() @@ -175,6 +177,9 @@ def get_to_edge_transform_passes( if "edge_program" in kwargs: kwargs["edge_program"] = exported_program self.add_pass(p(**kwargs)) + assert isinstance( + self.passes[-1], ResolveDebugHandle + ), "Please ensure ResolveDebugHandle is the last executed edge pass." return self.passes def transform_for_to_edge_pipeline( diff --git a/backends/qualcomm/_passes/resolve_debug_handle.py b/backends/qualcomm/_passes/resolve_debug_handle.py new file mode 100644 index 00000000000..332c71ad0fe --- /dev/null +++ b/backends/qualcomm/_passes/resolve_debug_handle.py @@ -0,0 +1,47 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import operator + +import torch +from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY +from executorch.exir.pass_base import ExportPass, PassResult + + +class ResolveDebugHandle(ExportPass): + """ + Caution: This pass is executed as the last of the edge_passes. + For any passes executed during qnn_preprocess, users will need to handle debug_handle ID themselves. + + Description: During passes transformation, some passes might be copying some node's meta when creating a new node, + which means multiple nodes might be sharing the same debug_handle ID while it shouldn't for QNN's scenario. + We want each call function to have its own debug_handle so we can compare all decomposed node's accuracy. + This is critical as Intermediate Debugger uses debug handle as key. + debug_handle ID must be resolved so each op gets its own set of debug_handle ID and intermediate output. + """ + + def __init__(self): + super(ResolveDebugHandle, self).__init__() + + def call(self, graph_module: torch.fx.GraphModule): + handle_counter = 1 + visited = set() + # TODO: Migrate to bfs tracing if torch.cond is introduced to QNN. + for node in graph_module.graph.nodes: + # Assume node is traversed in topological order, adding a check here to be safe. + # For ops like topk, getitem node shares same handle as topk node. This should align with original ExecuTorch behavior. + if node.target == operator.getitem: + source_node = node.args[0] + assert ( + source_node.name in visited + ), "Graph is not traversed in topological order, unexpected behavior." + node.meta[DEBUG_HANDLE_KEY] = source_node.meta[DEBUG_HANDLE_KEY] + elif node.op == "call_function": + node.meta[DEBUG_HANDLE_KEY] = handle_counter + handle_counter += 1 + visited.add(node.name) + + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 72749a29544..de5070dc812 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -79,6 +79,7 @@ def get_passes_dependency_for_capture_program(): RecomposePixelUnshuffle, RecomposeRmsNorm, RemoveRedundancy, + ResolveDebugHandle, TagQuantIO, ) @@ -110,6 +111,9 @@ def get_passes_dependency_for_capture_program(): RecomposePixelUnshuffle: [RemoveRedundancy], RecomposeRmsNorm: [RemoveRedundancy], TagQuantIO: [LayoutTransform], + ResolveDebugHandle: [ + TagQuantIO + ], # IMPORTANT: Please always ensure ResolveDebugHandle is the last executed pass. } diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index c0351b01ed6..f423288640c 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -28,6 +28,8 @@ CompileSpec, PreprocessResult, ) +from executorch.exir.backend.utils import DelegateMappingBuilder +from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY from torch.export.exported_program import ExportedProgram DEFAULT_DEBUG_HANDLE = 65535 @@ -138,7 +140,7 @@ def preprocess( ) @staticmethod - def preprocess_multimethod( + def preprocess_multimethod( # noqa: C901 edge_programs: Dict[str, List[ExportedProgram]], compile_specs: Dict[str, List[List[CompileSpec]]], ) -> PreprocessResult: @@ -161,8 +163,9 @@ def preprocess_multimethod( qnn_manager = get_current_qnn_manager( option.backend_options.backend_type, compile_spec ) + debug_handle_builder = DelegateMappingBuilder(generated_identifiers=False) for i in range(num_sub_graphs): - # e.g. 2 methods (x, y) with 3 partitions + # e.g. 2 methods (x, y) with 3 subgraphs(partitions) # > context_binary_0: [x.subgraph_0, y.subgraph_0] # > context_binary_1: [x.subgraph_1, y.subgraph_1] # > context_binary_2: [x.subgraph_2, y.subgraph_2] @@ -176,6 +179,13 @@ def preprocess_multimethod( option.op_package_options.op_package_infos, option.use_mha2sha, ) + if qnn_manager.IsTensorDump(): + for node in programs[i].graph.nodes: + if handle_id := node.meta.get(DEBUG_HANDLE_KEY): + debug_handle_builder.insert_delegate_mapping_entry( + handles=handle_id, + identifier=node.name, + ) if isinstance(py_op_wrappers, bytes): ctx_binary_list.append(py_op_wrappers) else: @@ -204,13 +214,16 @@ def preprocess_multimethod( all_processed_results[key].append( PreprocessResult( processed_bytes=bytes(qnn_context_binary), - debug_handle_map={}, + debug_handle_map=debug_handle_builder.get_delegate_mapping(), ) ) elif len(ctx_binary_list) == len(edge_programs.values()): for i, key in enumerate(edge_programs.keys()): all_processed_results[key].append( - PreprocessResult(processed_bytes=ctx_binary_list[i]) + PreprocessResult( + processed_bytes=ctx_binary_list[i], + debug_handle_map=debug_handle_builder.get_delegate_mapping(), + ) ) else: raise RuntimeError("Hybrid compilation is not supported") diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py index 46c25c66dc3..89120980613 100644 --- a/backends/qualcomm/tests/test_passes.py +++ b/backends/qualcomm/tests/test_passes.py @@ -7,8 +7,15 @@ InsertReshapeForReduceOps, RemoveRedundancy, ) - +from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset +from executorch.backends.qualcomm.tests.models import TopKandIndex +from executorch.backends.qualcomm.utils.utils import ( + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + to_edge_transform_and_lower_to_qnn, +) from executorch.exir import to_edge +from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY from executorch.exir.dialects._ops import ops as exir_ops @@ -150,6 +157,56 @@ def test_mha_to_sha(self): f"Output {i} mismatch: got {out}, expected {ref}", ) + def test_resolve_debug_handle(self): + name_handle_map = { + "aten_topk_default": 1, + "getitem": 1, + "getitem_1": 1, + "aten_view_copy_default": 2, + "aten_index_tensor": 3, + "aten_add_tensor": 4, + } + module = TopKandIndex() # noqa: F405 + sample_input = (torch.randn(3, 10),) + + backend_options = generate_htp_compiler_spec(use_fp16=False) + compiler_spec = generate_qnn_executorch_compiler_spec( + soc_model=QcomChipset.SM8650, # Random soc_model + backend_options=backend_options, + dump_intermediate_outputs=True, + ) + + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, + sample_input, + compiler_spec, + generate_etrecord=True, + ) + exec_prog_mgr = edge_prog_mgr.to_executorch() + etrecord = exec_prog_mgr.get_etrecord() + debug_handle_size = len(etrecord._debug_handle_map["forward"][0]) + self.assertEqual( + len(name_handle_map), + debug_handle_size, + f"Number of handles does not match, expecting: {len(name_handle_map)}, but get: {debug_handle_size}", + ) + after_edge_pass_ep = etrecord.graph_map["edge_after_transform/forward"] + + for node in after_edge_pass_ep.graph.nodes: + if node.name in name_handle_map: + expected_handle = name_handle_map.pop(node.name) + node_handle = node.meta[DEBUG_HANDLE_KEY] + self.assertEqual( + expected_handle, + node_handle, + f"{node.name} is expecting a handle id {expected_handle}, but got {node_handle}.", + ) + self.assertEqual( + len(name_handle_map), + 0, + f"Following nodes did not find a match in the graph: {name_handle_map.keys()}", + ) + if __name__ == "__main__": unittest.main()