Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 111 additions & 63 deletions backends/cortex_m/quantizer/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict

from dataclasses import dataclass
from typing import Iterator, List, Optional

from executorch.backends.arm.quantizer.quantization_annotator import _is_large_scalar

from executorch.backends.cortex_m.quantizer.pattern_checkers import PatternCheck
from executorch.backends.cortex_m.quantizer.quantization_configs import (
CortexMQuantizationConfig,
QuantizationConfig,
)
from torch._ops import OpOverload
from torch.fx import Node
Expand All @@ -30,73 +32,57 @@ class PatternMatcher:
Attributes:
support_dict: A dictionary mapping patterns (tuples of operator overloads) to
PatternCheck instances that validate the patterns.
filter_fn: A global filter applied over all nodes to exclude them from matching.
support_dict_name: An optional name for the support dict, used for logging.
"""

Q_PATTERN_MATCHED_KEY = "quantizer_matched"
REJECT_PREVIOUSLY_ANNOTATED = "Tried annotating already quantized node."
REJECT_FILTERED_OUT = "Node filtered out by global filter."
REJECT_UNSUPPORTED_PATTERN = (
"Tried annotating unsupported configuration of operators"
)
REJECT_UNSUPPORTED_QCONFIG = "Tried annotating unsupported quantization config"
REJECT_LARGE_SCALAR = "Tried annotating a large constant scalar value that is not supported for quantization."

def __init__(
self,
support_dict: dict[tuple[OpOverload, ...], PatternCheck],
filter_fn=lambda node: False,
support_dict_name: str | None = None,
):
self.support_dict = support_dict
self.filter_fn = filter_fn
self.support_dict_name = support_dict_name

self.patterns_by_first = defaultdict(list)
for p in sorted(support_dict.keys(), key=len, reverse=True):
self.patterns_by_first[p[0]].append(p)
self.max_pattern_len = max(
(len(pattern) for pattern in support_dict.keys()), default=0
)

def check_node(
self, node: Optional[Node], target: OpOverload
) -> tuple[bool, Optional[str]]:
"""
Return true if the node is a valid match for the given target.
"""
if node is None:
return False, None
if not node.target == target:
return False, None
if node.meta.get(self.Q_PATTERN_MATCHED_KEY, False):
return False, self.REJECT_PREVIOUSLY_ANNOTATED
if self.filter_fn(node):
return False, self.REJECT_FILTERED_OUT

return True, None

def check_pattern(
def _validate_match(
self,
node: Optional[Node],
pattern: List[OpOverload],
match: List[Node],
quantization_config: CortexMQuantizationConfig,
) -> Optional[PatternMatchResult]:
"""
Returns a PatternMatchResult when the pattern structurally matches, with
status indicating accept/reject. Returns None if there is no match.
"""
match: List[Node] = []

for pattern_target in pattern:
node_ok, rejection_reason = self.check_node(node, pattern_target)
if not node_ok:
if rejection_reason is None:
return None
return PatternMatchResult([node], False, rejection_reason)
# Reject match if it contains a node that has already been matched as part of another pattern.
if any(node.meta.get(self.Q_PATTERN_MATCHED_KEY, False) for node in match):
return PatternMatchResult(match, False, self.REJECT_PREVIOUSLY_ANNOTATED)

match.append(node)
node = list(node.users)[0] if len(node.users) > 0 else None
# Reject match if it contains a node that has an input which is too large to be quantized
if any(_is_large_scalar(node, node.graph.owning_module) for node in match):
return PatternMatchResult(match, False, self.REJECT_LARGE_SCALAR)

if all(node.op in ("placeholder", "output") for node in match):
# Accept matches of length 1 that are just placeholders or outputs
for node in match:
node.meta[self.Q_PATTERN_MATCHED_KEY] = True
return PatternMatchResult(match, True)

key = tuple([n.target for n in match])
pattern_checker = self.support_dict.get(key, None)
if pattern_checker:

if pattern_checker is not None:
pattern_ok = pattern_checker.check_pattern(match)
if not pattern_ok:
return PatternMatchResult(match, False, self.REJECT_UNSUPPORTED_PATTERN)
Expand All @@ -107,8 +93,68 @@ def check_pattern(
if not qconfig_ok:
return PatternMatchResult(match, False, self.REJECT_UNSUPPORTED_QCONFIG)

for node in match:
node.meta[self.Q_PATTERN_MATCHED_KEY] = True
return PatternMatchResult(match, True)

def _get_match(self, node_queue: List[Node]) -> List[Node]:
"""
Returns the longest pattern match starting at the front of the queue.
"""
if node_queue[0].op in ("placeholder", "output"):
return [node_queue[0]]

pattern_key = tuple(n.target for n in node_queue)
while pattern_key:
if pattern_key in self.support_dict:
return node_queue[: len(pattern_key)]
else:
pattern_key = pattern_key[:-1]

return []

def _get_matches(
self, node_queue: List[Node], quantization_config: QuantizationConfig
) -> List[PatternMatchResult]:
"""
Returns the longest accepted match starting at the first node of the queue as well as longer rejected matches.
"""
matches = []
accepted = False
max_match_length = len(node_queue)

while max_match_length > 0 and not accepted:
match = self._get_match(node_queue[:max_match_length])
max_match_length = (
len(match) - 1
) # Look for shorter matches in the next iter if no accepted match found

if match:
validated_match = self._validate_match(match, quantization_config)
accepted = validated_match.accepted
matches.append(validated_match)

return matches

def _dequeue_and_get_matches(
self, node_queue: List[Node], quantization_config: QuantizationConfig
) -> List[PatternMatchResult]:
"""
Dequeues the longest accepted match starting at the first node of the queue, and returns all potential matches that were checked (rejected ones). If no match is found, simply dequeues the first node and returns an empty list.
"""
potential_matches = self._get_matches(node_queue, quantization_config)
accepted_matches = [m for m in potential_matches if m.accepted]
assert (
len(accepted_matches) <= 1
), "_get_matches should only accept the longest possible match, but multiple accepted matches were found."

if len(accepted_matches) == 0:
node_queue.pop(0)
else:
del node_queue[: len(accepted_matches[0].pattern)]

return potential_matches

def find_pattern_matches(
self, nodes: Iterator[Node], quantization_config: CortexMQuantizationConfig
) -> Iterator[PatternMatchResult]:
Expand All @@ -122,28 +168,30 @@ def find_pattern_matches(
already been matched.
"""

for node in nodes:
if node.meta.get(self.Q_PATTERN_MATCHED_KEY, False):
yield PatternMatchResult(
[node], False, self.REJECT_PREVIOUSLY_ANNOTATED
) # Reject already matched nodes
continue
if node.op == "placeholder" or node.op == "output":
node.meta[self.Q_PATTERN_MATCHED_KEY] = True
yield PatternMatchResult(
[node], True
) # Always accept placeholders and outputs

for pattern in self.patterns_by_first.get(node.target, []):
match_or_none = self.check_pattern(node, pattern, quantization_config)
if match_or_none is None:
continue # No match, try next pattern
if match_or_none.accepted:
for _ in range(len(match_or_none.pattern) - 1):
next(nodes) # Fast-forward iterator to skip matched nodes
for matched_node in match_or_none.pattern:
matched_node.meta[self.Q_PATTERN_MATCHED_KEY] = True
yield match_or_none # Accepted pattern found, break to skip checking remaining patterns
break
else:
yield match_or_none # Rejected pattern found, keep searching
node = next(nodes, None)
node_queue = []
while node is not None:
potential_matches = []
node_queue.append(node)
next_node = next(nodes, None)
node_users = list(node.users)

# If there is a fork or gap in the nodes iterator, empty the queue
if (len(node_users) != 1) or (node_users[0] != next_node):
while node_queue:
new_matches = self._dequeue_and_get_matches(
node_queue, quantization_config
)
potential_matches.extend(new_matches)

# When the queue reaches the max length, search for a match starting at the front of the queue
elif len(node_queue) >= self.max_pattern_len:
potential_matches = self._dequeue_and_get_matches(
node_queue, quantization_config
)

# Report all pattern matches, also rejected ones for debugging purposes
for match in potential_matches:
yield match

node = next_node
Loading
Loading