diff --git a/backends/cortex_m/quantizer/pattern_matcher.py b/backends/cortex_m/quantizer/pattern_matcher.py index d75c249765f..7ed848d5e83 100644 --- a/backends/cortex_m/quantizer/pattern_matcher.py +++ b/backends/cortex_m/quantizer/pattern_matcher.py @@ -3,14 +3,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from collections import defaultdict from dataclasses import dataclass from typing import Iterator, List, Optional +from executorch.backends.arm.quantizer.quantization_annotator import _is_large_scalar + from executorch.backends.cortex_m.quantizer.pattern_checkers import PatternCheck from executorch.backends.cortex_m.quantizer.quantization_configs import ( CortexMQuantizationConfig, + QuantizationConfig, ) from torch._ops import OpOverload from torch.fx import Node @@ -30,73 +32,57 @@ class PatternMatcher: Attributes: support_dict: A dictionary mapping patterns (tuples of operator overloads) to PatternCheck instances that validate the patterns. - filter_fn: A global filter applied over all nodes to exclude them from matching. + support_dict_name: An optional name for the support dict, used for logging. """ Q_PATTERN_MATCHED_KEY = "quantizer_matched" REJECT_PREVIOUSLY_ANNOTATED = "Tried annotating already quantized node." - REJECT_FILTERED_OUT = "Node filtered out by global filter." REJECT_UNSUPPORTED_PATTERN = ( "Tried annotating unsupported configuration of operators" ) REJECT_UNSUPPORTED_QCONFIG = "Tried annotating unsupported quantization config" + REJECT_LARGE_SCALAR = "Tried annotating a large constant scalar value that is not supported for quantization." 
def __init__( self, support_dict: dict[tuple[OpOverload, ...], PatternCheck], - filter_fn=lambda node: False, support_dict_name: str | None = None, ): self.support_dict = support_dict - self.filter_fn = filter_fn self.support_dict_name = support_dict_name - self.patterns_by_first = defaultdict(list) - for p in sorted(support_dict.keys(), key=len, reverse=True): - self.patterns_by_first[p[0]].append(p) + self.max_pattern_len = max( + (len(pattern) for pattern in support_dict.keys()), default=0 + ) - def check_node( - self, node: Optional[Node], target: OpOverload - ) -> tuple[bool, Optional[str]]: - """ - Return true if the node is a valid match for the given target. - """ - if node is None: - return False, None - if not node.target == target: - return False, None - if node.meta.get(self.Q_PATTERN_MATCHED_KEY, False): - return False, self.REJECT_PREVIOUSLY_ANNOTATED - if self.filter_fn(node): - return False, self.REJECT_FILTERED_OUT - - return True, None - - def check_pattern( + def _validate_match( self, - node: Optional[Node], - pattern: List[OpOverload], + match: List[Node], quantization_config: CortexMQuantizationConfig, ) -> Optional[PatternMatchResult]: """ Returns a PatternMatchResult when the pattern structurally matches, with status indicating accept/reject. Returns None if there is no match. """ - match: List[Node] = [] - for pattern_target in pattern: - node_ok, rejection_reason = self.check_node(node, pattern_target) - if not node_ok: - if rejection_reason is None: - return None - return PatternMatchResult([node], False, rejection_reason) + # Reject match if it contains a node that has already been matched as part of another pattern. 
+ if any(node.meta.get(self.Q_PATTERN_MATCHED_KEY, False) for node in match): + return PatternMatchResult(match, False, self.REJECT_PREVIOUSLY_ANNOTATED) - match.append(node) - node = list(node.users)[0] if len(node.users) > 0 else None + # Reject match if it contains a node that has an input which is too large to be quantized + if any(_is_large_scalar(node, node.graph.owning_module) for node in match): + return PatternMatchResult(match, False, self.REJECT_LARGE_SCALAR) + + if all(node.op in ("placeholder", "output") for node in match): + # Accept matches of length 1 that are just placeholders or outputs + for node in match: + node.meta[self.Q_PATTERN_MATCHED_KEY] = True + return PatternMatchResult(match, True) key = tuple([n.target for n in match]) pattern_checker = self.support_dict.get(key, None) - if pattern_checker: + + if pattern_checker is not None: pattern_ok = pattern_checker.check_pattern(match) if not pattern_ok: return PatternMatchResult(match, False, self.REJECT_UNSUPPORTED_PATTERN) @@ -107,8 +93,68 @@ def check_pattern( if not qconfig_ok: return PatternMatchResult(match, False, self.REJECT_UNSUPPORTED_QCONFIG) + for node in match: + node.meta[self.Q_PATTERN_MATCHED_KEY] = True return PatternMatchResult(match, True) + def _get_match(self, node_queue: List[Node]) -> List[Node]: + """ + Returns the longest pattern match starting at the front of the queue. + """ + if node_queue[0].op in ("placeholder", "output"): + return [node_queue[0]] + + pattern_key = tuple(n.target for n in node_queue) + while pattern_key: + if pattern_key in self.support_dict: + return node_queue[: len(pattern_key)] + else: + pattern_key = pattern_key[:-1] + + return [] + + def _get_matches( + self, node_queue: List[Node], quantization_config: QuantizationConfig + ) -> List[PatternMatchResult]: + """ + Returns the longest accepted match starting at the first node of the queue as well as longer rejected matches. 
+ """ + matches = [] + accepted = False + max_match_length = len(node_queue) + + while max_match_length > 0 and not accepted: + match = self._get_match(node_queue[:max_match_length]) + max_match_length = ( + len(match) - 1 + ) # Look for shorter matches in the next iter if no accepted match found + + if match: + validated_match = self._validate_match(match, quantization_config) + accepted = validated_match.accepted + matches.append(validated_match) + + return matches + + def _dequeue_and_get_matches( + self, node_queue: List[Node], quantization_config: QuantizationConfig + ) -> List[PatternMatchResult]: + """ + Dequeues the longest accepted match starting at the first node of the queue, and returns all potential matches that were checked (rejected ones). If no match is found, simply dequeues the first node and returns an empty list. + """ + potential_matches = self._get_matches(node_queue, quantization_config) + accepted_matches = [m for m in potential_matches if m.accepted] + assert ( + len(accepted_matches) <= 1 + ), "_get_matches should only accept the longest possible match, but multiple accepted matches were found." + + if len(accepted_matches) == 0: + node_queue.pop(0) + else: + del node_queue[: len(accepted_matches[0].pattern)] + + return potential_matches + def find_pattern_matches( self, nodes: Iterator[Node], quantization_config: CortexMQuantizationConfig ) -> Iterator[PatternMatchResult]: @@ -122,28 +168,30 @@ def find_pattern_matches( already been matched. 
""" - for node in nodes: - if node.meta.get(self.Q_PATTERN_MATCHED_KEY, False): - yield PatternMatchResult( - [node], False, self.REJECT_PREVIOUSLY_ANNOTATED - ) # Reject already matched nodes - continue - if node.op == "placeholder" or node.op == "output": - node.meta[self.Q_PATTERN_MATCHED_KEY] = True - yield PatternMatchResult( - [node], True - ) # Always accept placeholders and outputs - - for pattern in self.patterns_by_first.get(node.target, []): - match_or_none = self.check_pattern(node, pattern, quantization_config) - if match_or_none is None: - continue # No match, try next pattern - if match_or_none.accepted: - for _ in range(len(match_or_none.pattern) - 1): - next(nodes) # Fast-forward iterator to skip matched nodes - for matched_node in match_or_none.pattern: - matched_node.meta[self.Q_PATTERN_MATCHED_KEY] = True - yield match_or_none # Accepted pattern found, break to skip checking remaining patterns - break - else: - yield match_or_none # Rejected pattern found, keep searching + node = next(nodes, None) + node_queue = [] + while node is not None: + potential_matches = [] + node_queue.append(node) + next_node = next(nodes, None) + node_users = list(node.users) + + # If there is a fork or gap in the nodes iterator, empty the queue + if (len(node_users) != 1) or (node_users[0] != next_node): + while node_queue: + new_matches = self._dequeue_and_get_matches( + node_queue, quantization_config + ) + potential_matches.extend(new_matches) + + # When que reach the max length, search for match starting at the front of the queue + elif len(node_queue) >= self.max_pattern_len: + potential_matches = self._dequeue_and_get_matches( + node_queue, quantization_config + ) + + # Report all pattern matches, also rejected ones for debugging purposes + for match in potential_matches: + yield match + + node = next_node diff --git a/backends/cortex_m/test/misc/test_pattern_matcher.py b/backends/cortex_m/test/misc/test_pattern_matcher.py index a7567ba43c7..cc563f68c0f 100644 
class _FullModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        # 1e20 is deliberately chosen to trip the large-scalar rejection path.
        return torch.full(x.shape, 1e20, dtype=x.dtype, device=x.device)


def _export_full_graph_module():
    return export(_FullModule(), (torch.ones(2, 2),)).graph_module


def _node_iter_for_targets(graph_module, targets):
    return node_finders.NodeTargetNodeFinder(targets).find_nodes(graph_module)


def _node_iter(graph_module):
    return _node_iter_for_targets(
        graph_module, [torch.ops.aten.add.Tensor, torch.ops.aten.relu.default]
    )


def test_pattern_checker_can_reject_match():
    """Test basic pattern rejection capability"""
    graph_module = _export_two_op_graph_module()
    support = {
        (torch.ops.aten.add.Tensor, torch.ops.aten.relu.default): _AlwaysFailCheck,
    }
    matcher = PatternMatcher(support)

    matches = list(
        matcher.find_pattern_matches(_node_iter(graph_module), _dummy_qconfig())
    )

    assert len(matches) == 1
    assert not matches[0].accepted
    assert matches[0].message == PatternMatcher.REJECT_UNSUPPORTED_PATTERN


def test_rejects_longer_match_then_accepts_shorter_match():
    """Test that a shorter match is accepted if a longer match is rejected by
    the pattern checker and both are reported."""
    graph_module = _export_two_op_graph_module()
    support = {
        (torch.ops.aten.add.Tensor, torch.ops.aten.relu.default): _AlwaysFailCheck,
        (torch.ops.aten.add.Tensor,): _AlwaysPassCheck,
    }
    matcher = PatternMatcher(support)

    matches = list(
        matcher.find_pattern_matches(_node_iter(graph_module), _dummy_qconfig())
    )

    assert len(matches) == 2
    assert not matches[0].accepted
    assert matches[0].message == PatternMatcher.REJECT_UNSUPPORTED_PATTERN
    assert matches[1].accepted
    assert [n.target for n in matches[1].pattern] == [torch.ops.aten.add.Tensor]


def _get_node_by_target(graph_module, target):
    return next(n for n in graph_module.graph.nodes if n.target == target)


def _get_output_node(graph_module):
    return next(n for n in graph_module.graph.nodes if n.op == "output")


def test_missing_second_node_matches_first_node_pattern():
    """Test that a pattern going outside the selected nodes is not matched."""
    graph_module = _export_two_op_graph_module()
    support = {
        (torch.ops.aten.add.Tensor,): _AlwaysPassCheck,
        (torch.ops.aten.add.Tensor, torch.ops.aten.relu.default): _AlwaysPassCheck,
    }
    matcher = PatternMatcher(support)

    add_node = _get_node_by_target(graph_module, torch.ops.aten.add.Tensor)

    matches = list(matcher.find_pattern_matches(iter([add_node]), _dummy_qconfig()))

    assert len(matches) == 1
    assert matches[0].accepted
    assert [n.target for n in matches[0].pattern] == [torch.ops.aten.add.Tensor]
    assert add_node.meta[PatternMatcher.Q_PATTERN_MATCHED_KEY]


def test_missing_second_node_with_below_node_matches_first_node_pattern():
    """Similar to test_missing_second_node_matches_first_node_pattern but with
    an additional node below the matched node to ensure that the presence of
    additional nodes does not interfere with matching."""
    graph_module = _export_two_op_graph_module()
    support = {
        (torch.ops.aten.add.Tensor,): _AlwaysPassCheck,
        (torch.ops.aten.add.Tensor, torch.ops.aten.relu.default): _AlwaysPassCheck,
    }
    matcher = PatternMatcher(support)

    add_node = _get_node_by_target(graph_module, torch.ops.aten.add.Tensor)
    relu_node = _get_node_by_target(graph_module, torch.ops.aten.relu.default)
    output_node = _get_output_node(graph_module)

    matches = list(
        matcher.find_pattern_matches(iter([add_node, output_node]), _dummy_qconfig())
    )

    assert len(matches) == 2
    assert matches[0].accepted
    assert [n.target for n in matches[0].pattern] == [torch.ops.aten.add.Tensor]
    assert add_node.meta[PatternMatcher.Q_PATTERN_MATCHED_KEY]
    assert not relu_node.meta.get(PatternMatcher.Q_PATTERN_MATCHED_KEY, False)
    assert matches[1].accepted
    assert matches[1].pattern == [output_node]
    assert output_node.meta[PatternMatcher.Q_PATTERN_MATCHED_KEY]


def test_rejects_large_scalar_match():
    """Tests that patterns with large scalar constants are rejected regardless
    of pattern checker."""
    graph_module = _export_full_graph_module()
    support = {
        (torch.ops.aten.full.default,): _AlwaysPassCheck,
    }
    matcher = PatternMatcher(support)

    matches = list(
        matcher.find_pattern_matches(
            _node_iter_for_targets(graph_module, [torch.ops.aten.full.default]),
            _dummy_qconfig(),
        )
    )

    assert len(matches) == 1
    assert not matches[0].accepted
    assert matches[0].message == PatternMatcher.REJECT_LARGE_SCALAR


def test_accept_none_nodechecker():
    """Tests that patterns with None as the pattern checker are accepted."""
    graph_module = _export_two_op_graph_module()
    support = {
        (torch.ops.aten.add.Tensor, torch.ops.aten.relu.default): None,
    }
    matcher = PatternMatcher(support)

    matches = list(
        matcher.find_pattern_matches(_node_iter(graph_module), _dummy_qconfig())
    )

    assert len(matches) == 1
    assert matches[0].accepted


def test_reject_reported_once():
    """Tests that the pattern matcher reports a rejected pattern only once."""
    graph_module = _export_two_op_graph_module()
    support = {
        (torch.ops.aten.add.Tensor,): _AlwaysFailCheck,
        (
            torch.ops.aten.add.Tensor,
            torch.ops.aten.relu.default,
            torch.ops.aten.mul.Tensor,
        ): _AlwaysFailCheck,
    }
    matcher = PatternMatcher(support)

    matches = list(
        matcher.find_pattern_matches(_node_iter(graph_module), _dummy_qconfig())
    )

    assert len(matches) == 1
    assert not matches[0].accepted
    # Strengthened: the single reported rejection must come from the pattern
    # checker, not from any other reject path.
    assert matches[0].message == PatternMatcher.REJECT_UNSUPPORTED_PATTERN