| 
1 |  | -from typing import Any, Callable, Dict  | 
 | 1 | +from dataclasses import dataclass, field  | 
 | 2 | +from typing import Any, Callable, Dict, Optional, Sequence, Union  | 
 | 3 | +from enum import Enum, auto  | 
2 | 4 | 
 
  | 
3 |  | -from torch.fx.node import Target  | 
 | 5 | +from torch.fx.node import Target, Node, _get_qualified_name  | 
4 | 6 | from torch_tensorrt.fx.converter_registry import CONVERTERS  | 
5 | 7 | 
 
  | 
6 |  | -DYNAMO_CONVERTERS: Dict[Target, Any] = dict(CONVERTERS)  | 
 | 8 | + | 
 | 9 | +class ConverterPriority(Enum):  | 
 | 10 | +    """Enum to set a converter's priority in the registry"""  | 
 | 11 | + | 
 | 12 | +    STANDARD = auto()  | 
 | 13 | +    HIGH = auto()  | 
 | 14 | + | 
 | 15 | + | 
 | 16 | +@dataclass(frozen=True)  | 
 | 17 | +class ConverterSupport:  | 
 | 18 | +    """Class representing a converter implementation and support function  | 
 | 19 | +
  | 
 | 20 | +    Args:  | 
 | 21 | +        converter_implementation: Function which converts said node to a TRT equivalent  | 
 | 22 | +        capability_validator: Function which takes in a Node and returns a bool indicating  | 
 | 23 | +            whether that node can be supported by its companion converter. Note that  | 
 | 24 | +            this function must not modify the node or its graph  | 
 | 25 | +    """  | 
 | 26 | + | 
 | 27 | +    converter_implementation: Callable  | 
 | 28 | +    capability_validator: Callable[[Node], bool] = field(default=lambda node: True)  | 
 | 29 | + | 
 | 30 | + | 
 | 31 | +# Dictionary representing Dynamo aten-only converters  | 
 | 32 | +# Each converter maps to a sequence of at least one ConverterSupport object(s)  | 
 | 33 | +DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}  | 
7 | 34 | 
 
  | 
8 | 35 | 
 
  | 
9 | 36 | def dynamo_tensorrt_converter(  | 
10 | 37 |     key: Target,  | 
11 | 38 |     enabled: bool = True,  | 
 | 39 | +    capability_validator: Optional[Callable[[Node], bool]] = None,  | 
 | 40 | +    priority: ConverterPriority = ConverterPriority.STANDARD,  | 
12 | 41 | ) -> Callable[[Any], Any]:  | 
 | 42 | +    """Decorator for Dynamo TensorRT Converter  | 
 | 43 | +
  | 
 | 44 | +    Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry  | 
 | 45 | +
  | 
 | 46 | +    Args:  | 
 | 47 | +        key: Node target for which the converter is implemented for  | 
 | 48 | +            (for example, torch.ops.add.Tensor)  | 
 | 49 | +        enabled: Whether the converter should be enabled/cached or not  | 
 | 50 | +        capability_validator: Function which evaluates whether a node is valid for conversion  | 
 | 51 | +            by the decorated converter. See ConverterSupport for more details.  | 
 | 52 | +            Defaults to None, implying the capability_validator function is always true -  | 
 | 53 | +            this means all nodes of "key" kind can be supported by this converter  | 
 | 54 | +        priority: Converter's level of priority relative to other converters with the  | 
 | 55 | +            same target  | 
 | 56 | +    Returns:  | 
 | 57 | +        The converter being decorated  | 
 | 58 | +    """  | 
 | 59 | + | 
13 | 60 |     def register_converter(converter):  | 
14 |  | -        DYNAMO_CONVERTERS[key] = converter  | 
 | 61 | +        """Helper function to register the converter, then return it"""  | 
 | 62 | +        assert callable(converter), "Converter function must be callable"  | 
 | 63 | + | 
 | 64 | +        # If no capability_validator function is specified, use the default function - always return true  | 
 | 65 | +        if capability_validator is None:  | 
 | 66 | +            converter_support = ConverterSupport(converter_implementation=converter)  | 
 | 67 | +        else:  | 
 | 68 | +            assert callable(  | 
 | 69 | +                capability_validator  | 
 | 70 | +            ), "Argument checking function must be callable"  | 
 | 71 | +            converter_support = ConverterSupport(  | 
 | 72 | +                converter_implementation=converter,  | 
 | 73 | +                capability_validator=capability_validator,  | 
 | 74 | +            )  | 
 | 75 | + | 
 | 76 | +        # If a converter for this operator already exists, append the new converter to the list  | 
 | 77 | +        # Otherwise, start a new list  | 
 | 78 | +        if key in DYNAMO_ATEN_CONVERTERS:  | 
 | 79 | +            # High priority converters are inserted at the front of the list,  | 
 | 80 | +            # so they can be checked first by the registry  | 
 | 81 | +            if priority is ConverterPriority.HIGH:  | 
 | 82 | +                DYNAMO_ATEN_CONVERTERS[key].insert(0, converter_support)  | 
 | 83 | +            else:  | 
 | 84 | +                DYNAMO_ATEN_CONVERTERS[key].append(converter_support)  | 
 | 85 | +        else:  | 
 | 86 | +            DYNAMO_ATEN_CONVERTERS[key] = [converter_support]  | 
 | 87 | + | 
15 | 88 |         return converter  | 
16 | 89 | 
 
  | 
17 | 90 |     def disable_converter(converter):  | 
18 | 91 |         return converter  | 
19 | 92 | 
 
  | 
 | 93 | +    # Select whether to cache/enable the converter  | 
20 | 94 |     if enabled:  | 
21 | 95 |         return register_converter  | 
22 | 96 |     else:  | 
23 | 97 |         return disable_converter  | 
 | 98 | + | 
 | 99 | + | 
 | 100 | +class ConverterRegistry:  | 
 | 101 | +    """Registry for storing multiple converter dictionaries  | 
 | 102 | +
  | 
 | 103 | +    Capable of storing dictionaries with the following signature:  | 
 | 104 | +    Dict[Target, Union[Callable, Sequence[ConverterSupport]]]  | 
 | 105 | +
  | 
 | 106 | +    Also able to validate converter implementations against user-provided  | 
 | 107 | +    argument-checking functions  | 
 | 108 | +
  | 
 | 109 | +    Args:  | 
 | 110 | +        registries: List of dictionaries representing converter registries.  | 
 | 111 | +            The order of the provided dictionaries is the order in which they  | 
 | 112 | +            will be traversed. This is only significant when using non-validated  | 
 | 113 | +            methods.  | 
 | 114 | +    """  | 
 | 115 | + | 
 | 116 | +    def __init__(  | 
 | 117 | +        self,  | 
 | 118 | +        registries: Sequence[Dict[Target, Union[Callable, Sequence[ConverterSupport]]]],  | 
 | 119 | +        registry_names: Optional[Sequence[str]] = None,  | 
 | 120 | +    ):  | 
 | 121 | +        # Copy reference to each dictionary object into attribute list  | 
 | 122 | +        self.registries = [registry for registry in registries]  | 
 | 123 | + | 
 | 124 | +        if registry_names is not None:  | 
 | 125 | +            assert len(self.registries) == len(registry_names)  | 
 | 126 | +            self.registry_names = [name for name in registry_names]  | 
 | 127 | +        else:  | 
 | 128 | +            self.registry_names = [  | 
 | 129 | +                f"Registry {i + 1}" for i in range(len(self.registries))  | 
 | 130 | +            ]  | 
 | 131 | + | 
 | 132 | +        self.validate_invariants()  | 
 | 133 | + | 
 | 134 | +    def validate_invariants(self):  | 
 | 135 | +        """Validates the invariants required of the dictionaries in the registries  | 
 | 136 | +
  | 
 | 137 | +        Raises AssertionError if any invariants have been violated  | 
 | 138 | +        """  | 
 | 139 | +        # All registries must be dictionaries  | 
 | 140 | +        assert all(isinstance(elt, dict) for elt in self.registries)  | 
 | 141 | + | 
 | 142 | +        # Every dictionary in the registry must have one of two signatures:  | 
 | 143 | +        # Dict[Target, Callable] or Dict[Target, Sequence[ConverterSupport]]  | 
 | 144 | +        # Where, for the latter, the sequence must be non-empty  | 
 | 145 | +        for registry in self.registries:  | 
 | 146 | +            for converters in registry.values():  | 
 | 147 | +                if isinstance(converters, (list, tuple)):  | 
 | 148 | +                    assert (  | 
 | 149 | +                        all(isinstance(c, ConverterSupport) for c in converters)  | 
 | 150 | +                        and len(converters) > 0  | 
 | 151 | +                    )  | 
 | 152 | +                else:  | 
 | 153 | +                    assert callable(converters), "Converter function must be callable"  | 
 | 154 | + | 
 | 155 | +    def __getitem_without_validation__(self, key: Target):  | 
 | 156 | +        """Get the first-found converter in any registry  | 
 | 157 | +
  | 
 | 158 | +        Searches all registries in order and returns the first converter encountered  | 
 | 159 | +        """  | 
 | 160 | +        if isinstance(key, Node):  | 
 | 161 | +            raise KeyError(  | 
 | 162 | +                "Unvalidated accesses to the Converter registry can only be "  | 
 | 163 | +                + "made with node targets. Try accessing the registry with node.target"  | 
 | 164 | +            )  | 
 | 165 | + | 
 | 166 | +        self.validate_invariants()  | 
 | 167 | + | 
 | 168 | +        # Iterate over all registries and return the first converter found  | 
 | 169 | +        for registry in self.registries:  | 
 | 170 | +            if key in registry:  | 
 | 171 | +                converters = registry[key]  | 
 | 172 | + | 
 | 173 | +                if isinstance(converters, (list, tuple)):  | 
 | 174 | +                    return converters[0].converter_implementation  | 
 | 175 | +                else:  | 
 | 176 | +                    return converters  | 
 | 177 | + | 
 | 178 | +        raise KeyError(f"None of the converter registries have an entry for {key}")  | 
 | 179 | + | 
 | 180 | +    def __getitem__(self, node: Node):  | 
 | 181 | +        """Get the first-found validated converter in any registry  | 
 | 182 | +
  | 
 | 183 | +        Searches all registries in order and returns the first converter  | 
 | 184 | +        which passes validation on the input node  | 
 | 185 | +        """  | 
 | 186 | +        if not isinstance(node, Node):  | 
 | 187 | +            raise KeyError(  | 
 | 188 | +                "Validated accesses to the Converter registry can only be "  | 
 | 189 | +                + "made with node inputs. Try accessing the registry with a node "  | 
 | 190 | +                + "or use get_unvalidated to access without node validation."  | 
 | 191 | +            )  | 
 | 192 | + | 
 | 193 | +        self.validate_invariants()  | 
 | 194 | +        key = node.target  | 
 | 195 | + | 
 | 196 | +        # Iterate over all registries, validating the converter on the input node  | 
 | 197 | +        # If no capability_validator function is found, assume full coverage  | 
 | 198 | +        for registry in self.registries:  | 
 | 199 | +            if key in registry:  | 
 | 200 | +                converters = registry[key]  | 
 | 201 | + | 
 | 202 | +                if isinstance(converters, (list, tuple)):  | 
 | 203 | +                    for candidate in converters:  | 
 | 204 | +                        if candidate.capability_validator(node):  | 
 | 205 | +                            return candidate.converter_implementation  | 
 | 206 | +                else:  | 
 | 207 | +                    return converters  | 
 | 208 | + | 
 | 209 | +        raise KeyError(  | 
 | 210 | +            f"None of the converter registries have a validated entry for {key}, with node {node}"  | 
 | 211 | +        )  | 
 | 212 | + | 
 | 213 | +    def keys(self):  | 
 | 214 | +        """Get all unique targets across all dictionaries"""  | 
 | 215 | +        return self.unique_targets()  | 
 | 216 | + | 
 | 217 | +    def get_unvalidated(self, key: Target, value=None):  | 
 | 218 | +        """Get unvalidated converter for input target with a default return"""  | 
 | 219 | +        try:  | 
 | 220 | +            return self.__getitem_without_validation__(key)  | 
 | 221 | +        except KeyError:  | 
 | 222 | +            return value  | 
 | 223 | + | 
 | 224 | +    def get(self, node: Node, value=None):  | 
 | 225 | +        """Get validated converter for input node with a default return"""  | 
 | 226 | +        try:  | 
 | 227 | +            return self.__getitem__(node)  | 
 | 228 | +        except KeyError:  | 
 | 229 | +            return value  | 
 | 230 | + | 
 | 231 | +    def __contains__(self, key: Union[Target, Node]):  | 
 | 232 | +        """Check whether a converter for an input node or target exists"""  | 
 | 233 | +        try:  | 
 | 234 | +            # Attempt to access the item in the registry  | 
 | 235 | +            if isinstance(key, Node):  | 
 | 236 | +                self.__getitem__(key)  | 
 | 237 | +            else:  | 
 | 238 | +                self.__getitem_without_validation__(key)  | 
 | 239 | + | 
 | 240 | +            return True  | 
 | 241 | +        except KeyError:  | 
 | 242 | +            return False  | 
 | 243 | + | 
 | 244 | +    def get_all_converters_with_target(  | 
 | 245 | +        self, key: Target, return_registry_info: bool = False  | 
 | 246 | +    ):  | 
 | 247 | +        """Get all converters across all registries for the target  | 
 | 248 | +
  | 
 | 249 | +        Returns a list of all converterts having the specified target  | 
 | 250 | +        """  | 
 | 251 | +        self.validate_invariants()  | 
 | 252 | +        converters_with_target = []  | 
 | 253 | + | 
 | 254 | +        # Store count of number of registered converters per registry  | 
 | 255 | +        if return_registry_info:  | 
 | 256 | +            registry_data = {name: 0 for name in self.registry_names}  | 
 | 257 | + | 
 | 258 | +        for index, registry in enumerate(self.registries):  | 
 | 259 | +            if key in registry:  | 
 | 260 | +                converters = registry[key]  | 
 | 261 | + | 
 | 262 | +                if isinstance(converters, (list, tuple)):  | 
 | 263 | +                    converters_with_target.extend(  | 
 | 264 | +                        [c.converter_implementation for c in converters]  | 
 | 265 | +                    )  | 
 | 266 | +                    # Add converter count to registry name storage  | 
 | 267 | +                    if return_registry_info:  | 
 | 268 | +                        registry_data[self.registry_names[index]] += len(converters)  | 
 | 269 | +                else:  | 
 | 270 | +                    converters_with_target.append(converters)  | 
 | 271 | +                    # Add converter count to registry name storage  | 
 | 272 | +                    if return_registry_info:  | 
 | 273 | +                        registry_data[self.registry_names[index]] += 1  | 
 | 274 | + | 
 | 275 | +        if return_registry_info:  | 
 | 276 | +            return converters_with_target, registry_data  | 
 | 277 | +        else:  | 
 | 278 | +            return converters_with_target  | 
 | 279 | + | 
 | 280 | +    def __setitem__(self, key, value):  | 
 | 281 | +        raise AssertionError(  | 
 | 282 | +            f"Do not set registry members directly through the ConverterRegistry object. "  | 
 | 283 | +            + f"Attempted to set {key}: {value} via direct assignment to ConverterRegistry."  | 
 | 284 | +        )  | 
 | 285 | + | 
 | 286 | +    def __delitem__(self, key):  | 
 | 287 | +        raise AssertionError(  | 
 | 288 | +            f"Do not delete registry members directly through the ConverterRegistry object. "  | 
 | 289 | +            + f"Attempted to delete {key} via direct del on ConverterRegistry."  | 
 | 290 | +        )  | 
 | 291 | + | 
 | 292 | +    def __len__(self):  | 
 | 293 | +        """Returns the sum of lengths of all registries stored"""  | 
 | 294 | +        return sum(len(registry) for registry in self.registries)  | 
 | 295 | + | 
 | 296 | +    def unique_targets(self):  | 
 | 297 | +        """Returns the set of unique converter targets stored across all registries"""  | 
 | 298 | +        return set.union(*[set(registry.keys()) for registry in self.registries])  | 
 | 299 | + | 
 | 300 | +    def qualified_name_or_str(self, target: Target) -> str:  | 
 | 301 | +        """Returns string representation of an FX Node target"""  | 
 | 302 | +        if isinstance(target, str):  | 
 | 303 | +            return target  | 
 | 304 | +        else:  | 
 | 305 | +            return _get_qualified_name(target)  | 
 | 306 | + | 
 | 307 | +    def display_all_available_converters(self) -> str:  | 
 | 308 | +        """Returns a string with all converters and their source, separated by newlines"""  | 
 | 309 | +        available_converters = "Available converters in ATen registries with counts:\n"  | 
 | 310 | + | 
 | 311 | +        for target in sorted(  | 
 | 312 | +            self.unique_targets(), key=lambda target: self.qualified_name_or_str(target)  | 
 | 313 | +        ):  | 
 | 314 | +            _, registry_data = self.get_all_converters_with_target(  | 
 | 315 | +                target, return_registry_info=True  | 
 | 316 | +            )  | 
 | 317 | +            available_converters += f"Node: {self.qualified_name_or_str(target)} - Registry Presence Counts: {registry_data}\n"  | 
 | 318 | + | 
 | 319 | +        return available_converters  | 
 | 320 | + | 
 | 321 | + | 
 | 322 | +# Initialize dynamo converter registry with the FX and Dynamo aten registries  | 
 | 323 | +# Note the Dynamo registry is listed first, for precedence  | 
 | 324 | +DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry(  | 
 | 325 | +    [DYNAMO_ATEN_CONVERTERS, CONVERTERS],  | 
 | 326 | +    ["Dynamo ATen Converters Registry", "FX ATen Converters Registry"],  | 
 | 327 | +)  | 
0 commit comments