Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_inf_values import ReplaceInfValues
from .resolve_debug_handle import ResolveDebugHandle
from .seq_mse import SeqMSE
from .tag_quant_io import TagQuantIO

Expand Down Expand Up @@ -96,6 +97,7 @@
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceInfValues,
ResolveDebugHandle,
SeqMSE,
TagQuantIO,
]
5 changes: 5 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceInfValues,
ResolveDebugHandle,
TagQuantIO,
)
from executorch.backends.qualcomm._passes.utils import (
Expand Down Expand Up @@ -107,6 +108,7 @@ def get_capture_program_passes():
(Remove0DTensor, True),
(RemoveRedundancy, True),
(TagQuantIO, False),
(ResolveDebugHandle, True),
]

passes = OrderedDict()
Expand Down Expand Up @@ -175,6 +177,9 @@ def get_to_edge_transform_passes(
if "edge_program" in kwargs:
kwargs["edge_program"] = exported_program
self.add_pass(p(**kwargs))
assert isinstance(
self.passes[-1], ResolveDebugHandle
), "Please ensure ResolveDebugHandle is the last executed edge pass."
return self.passes

def transform_for_to_edge_pipeline(
Expand Down
47 changes: 47 additions & 0 deletions backends/qualcomm/_passes/resolve_debug_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator

import torch
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
from executorch.exir.pass_base import ExportPass, PassResult


class ResolveDebugHandle(ExportPass):
"""
Caution: This pass is executed as the last of the edge_passes.
For any passes executed during qnn_preprocess, users will need to handle debug_handle ID themselves.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should add a comment here, highlighting that conditional branch like torch.cond is not supported right now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Gasoonjia,
Thanks for reviewing the PR.
I do have a comment below in line 31 for mentioning about conditional branch.
Do you think the comment looks good to you?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @winskuo-quic, it would be great to highlight that at the beginning of the class. User will first and sometimes only read the class-level comment and it will be great if we can warn them potential issue here.

Description: During passes transformation, some passes might be copying some node's meta when creating a new node,
which means multiple nodes might be sharing the same debug_handle ID while it shouldn't for QNN's scenario.
We want each call function to have its own debug_handle so we can compare all decomposed node's accuracy.
This is critical as Intermediate Debugger uses debug handle as key.
debug_handle ID must be resolved so each op gets its own set of debug_handle ID and intermediate output.
"""

def __init__(self):
super(ResolveDebugHandle, self).__init__()

def call(self, graph_module: torch.fx.GraphModule):
handle_counter = 1
visited = set()
# TODO: Migrate to bfs tracing if torch.cond is introduced to QNN.
for node in graph_module.graph.nodes:
# Assume node is traversed in topological order, adding a check here to be safe.
# For ops like topk, getitem node shares same handle as topk node. This should align with original ExecuTorch behavior.
if node.target == operator.getitem:
source_node = node.args[0]
assert (
source_node.name in visited
), "Graph is not traversed in topological order, unexpected behavior."
node.meta[DEBUG_HANDLE_KEY] = source_node.meta[DEBUG_HANDLE_KEY]
elif node.op == "call_function":
node.meta[DEBUG_HANDLE_KEY] = handle_counter
handle_counter += 1
visited.add(node.name)

graph_module.recompile()
return PassResult(graph_module, True)
4 changes: 4 additions & 0 deletions backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def get_passes_dependency_for_capture_program():
RecomposePixelUnshuffle,
RecomposeRmsNorm,
RemoveRedundancy,
ResolveDebugHandle,
TagQuantIO,
)

Expand Down Expand Up @@ -110,6 +111,9 @@ def get_passes_dependency_for_capture_program():
RecomposePixelUnshuffle: [RemoveRedundancy],
RecomposeRmsNorm: [RemoveRedundancy],
TagQuantIO: [LayoutTransform],
ResolveDebugHandle: [
TagQuantIO
], # IMPORTANT: Please always ensure ResolveDebugHandle is the last executed pass.
}


Expand Down
21 changes: 17 additions & 4 deletions backends/qualcomm/qnn_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
CompileSpec,
PreprocessResult,
)
from executorch.exir.backend.utils import DelegateMappingBuilder
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
from torch.export.exported_program import ExportedProgram

DEFAULT_DEBUG_HANDLE = 65535
Expand Down Expand Up @@ -138,7 +140,7 @@ def preprocess(
)

@staticmethod
def preprocess_multimethod(
def preprocess_multimethod( # noqa: C901
edge_programs: Dict[str, List[ExportedProgram]],
compile_specs: Dict[str, List[List[CompileSpec]]],
) -> PreprocessResult:
Expand All @@ -161,8 +163,9 @@ def preprocess_multimethod(
qnn_manager = get_current_qnn_manager(
option.backend_options.backend_type, compile_spec
)
debug_handle_builder = DelegateMappingBuilder(generated_identifiers=False)
for i in range(num_sub_graphs):
# e.g. 2 methods (x, y) with 3 partitions
# e.g. 2 methods (x, y) with 3 subgraphs(partitions)
# > context_binary_0: [x.subgraph_0, y.subgraph_0]
# > context_binary_1: [x.subgraph_1, y.subgraph_1]
# > context_binary_2: [x.subgraph_2, y.subgraph_2]
Expand All @@ -176,6 +179,13 @@ def preprocess_multimethod(
option.op_package_options.op_package_infos,
option.use_mha2sha,
)
if qnn_manager.IsTensorDump():
for node in programs[i].graph.nodes:
if handle_id := node.meta.get(DEBUG_HANDLE_KEY):
debug_handle_builder.insert_delegate_mapping_entry(
handles=handle_id,
identifier=node.name,
)
if isinstance(py_op_wrappers, bytes):
ctx_binary_list.append(py_op_wrappers)
else:
Expand Down Expand Up @@ -204,13 +214,16 @@ def preprocess_multimethod(
all_processed_results[key].append(
PreprocessResult(
processed_bytes=bytes(qnn_context_binary),
debug_handle_map={},
debug_handle_map=debug_handle_builder.get_delegate_mapping(),
)
)
elif len(ctx_binary_list) == len(edge_programs.values()):
for i, key in enumerate(edge_programs.keys()):
all_processed_results[key].append(
PreprocessResult(processed_bytes=ctx_binary_list[i])
PreprocessResult(
processed_bytes=ctx_binary_list[i],
debug_handle_map=debug_handle_builder.get_delegate_mapping(),
)
)
else:
raise RuntimeError("Hybrid compilation is not supported")
Expand Down
59 changes: 58 additions & 1 deletion backends/qualcomm/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@
InsertReshapeForReduceOps,
RemoveRedundancy,
)

from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.tests.models import TopKandIndex
from executorch.backends.qualcomm.utils.utils import (
generate_htp_compiler_spec,
generate_qnn_executorch_compiler_spec,
to_edge_transform_and_lower_to_qnn,
)
from executorch.exir import to_edge
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
from executorch.exir.dialects._ops import ops as exir_ops


Expand Down Expand Up @@ -150,6 +157,56 @@ def test_mha_to_sha(self):
f"Output {i} mismatch: got {out}, expected {ref}",
)

def test_resolve_debug_handle(self):
name_handle_map = {
"aten_topk_default": 1,
"getitem": 1,
"getitem_1": 1,
"aten_view_copy_default": 2,
"aten_index_tensor": 3,
"aten_add_tensor": 4,
}
module = TopKandIndex() # noqa: F405
sample_input = (torch.randn(3, 10),)

backend_options = generate_htp_compiler_spec(use_fp16=False)
compiler_spec = generate_qnn_executorch_compiler_spec(
soc_model=QcomChipset.SM8650, # Random soc_model
backend_options=backend_options,
dump_intermediate_outputs=True,
)

edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
module,
sample_input,
compiler_spec,
generate_etrecord=True,
)
exec_prog_mgr = edge_prog_mgr.to_executorch()
etrecord = exec_prog_mgr.get_etrecord()
debug_handle_size = len(etrecord._debug_handle_map["forward"][0])
self.assertEqual(
len(name_handle_map),
debug_handle_size,
f"Number of handles does not match, expecting: {len(name_handle_map)}, but get: {debug_handle_size}",
)
after_edge_pass_ep = etrecord.graph_map["edge_after_transform/forward"]

for node in after_edge_pass_ep.graph.nodes:
if node.name in name_handle_map:
expected_handle = name_handle_map.pop(node.name)
node_handle = node.meta[DEBUG_HANDLE_KEY]
self.assertEqual(
expected_handle,
node_handle,
f"{node.name} is expecting a handle id {expected_handle}, but got {node_handle}.",
)
self.assertEqual(
len(name_handle_map),
0,
f"Following nodes did not find a match in the graph: {name_handle_map.keys()}",
)


if __name__ == "__main__":
unittest.main()
Loading