diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..6ea4d01d03 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -107,12 +107,6 @@ void setup_input_tensors( TORCHTRT_CHECK( inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); - auto expected_type = - util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); - TORCHTRT_CHECK( - inputs[i].dtype() == expected_type, - "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); - auto dims = core::util::toDims(inputs[i].sizes()); auto shape = core::util::toVec(dims); LOG_DEBUG("Input Name: " << name << " Shape: " << dims); diff --git a/examples/dynamo/autocast_example.py b/examples/dynamo/autocast_example.py new file mode 100644 index 0000000000..9feeb6d751 --- /dev/null +++ b/examples/dynamo/autocast_example.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn +import torch_tensorrt +import torchvision + + +class MyModule(torch.nn.Module): + def forward(self, a_float32, b_float32, c_float32, d_float32): + with torch.autocast(device_type="cuda"): + e_float16 = torch.mm(a_float32, b_float32) + with torch.autocast(device_type="cuda", enabled=False): + # Calls e_float16.float() to ensure float32 execution + # (necessary because e_float16 was created in an autocasted region) + f_float32 = torch.mm(c_float32, e_float16.float()) + + # No manual casts are required when re-entering the autocast-enabled region. + # torch.mm again runs in float16 and produces float16 output, regardless of input types. + g_float16 = torch.mm(d_float32, f_float32) + return g_float16 + + +class AutocastExample(nn.Module): + def __init__(self): + super(AutocastExample, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1 + ) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d( + in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(16 * 8 * 8, 10) + + def forward(self, x, y): + out = self.pool1(self.relu1(self.conv1(x))) # fp16 + x = self.pool2(self.relu2(self.conv2(out))) # fp16 + x = self.flatten(x) + with torch.autocast(x.device.type, enabled=True, dtype=torch.float32): + x = self.fc1(x) # fp32 + with torch.autocast(x.device.type, enabled=False): + x = torch.sub(x.half(), y) # fp16 + out2 = torch.add(x, x) # fp16 + with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): + out2 = torch.log(out2) # fp32 + return x, out, out2 + + +class MyResNet18Wrapper(torch.nn.Module): + def __init__(self, num_classes=1000, pretrained=True): + super(MyResNet18Wrapper, self).__init__() + self.resnet = torchvision.models.resnet18( + num_classes=num_classes, weights="IMAGENET1K_V1" if pretrained else None + ) + + def forward(self, x): + x = self.resnet(x) + return x + + +if __name__ == "__main__": + # model = MyModule().cuda().eval() + # inputs = (torch.randn((8, 8), device="cuda"), + # torch.randn((8, 8), device="cuda"), + # torch.randn((8, 8), device="cuda"), + # torch.randn((8, 8), device="cuda"),) + + # model = AutocastExample().cuda().eval() + # inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"), + # torch.randn((1,), dtype=torch.float16, device="cuda"),) + + model = 
MyResNet18Wrapper().cuda().eval() + inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),) + + ep = torch.export.export(model, inputs) + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=".", + engine_builder_monitor=False, + ): + trt_mod = torch_tensorrt.compile( + ep.module(), + arg_inputs=inputs, + min_block_size=1, + use_python_runtime=True, + ##### weak typing ##### + # use_explicit_typing=False, + # enabled_precisions={torch.float16}, + ##### strong typing + autocast ##### + use_explicit_typing=True, + enable_autocast=True, + low_precision_type=torch.float16, + # nodes_to_exclude={"^conv2d$"}, + targets_to_exclude={}, + data_max=512, + max_depth_of_reduction=None, + ) + + trt_out = trt_mod(*inputs) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8ad938032..511d215335 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -141,7 +141,7 @@ def cross_compile_for_windows( disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. - enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels + enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT @@ -434,6 +434,14 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + enable_autocast: bool = _defaults.ENABLE_AUTOCAST, + low_precision_type: Optional[ + Union[torch.dtype, dtype] + ] = _defaults.LOW_PRECISION_TYPE, + nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE, + targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE, + data_max: float = _defaults.DATA_MAX, + max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -511,6 +519,12 @@ def compile( l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. + low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. 
+ nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -584,6 +598,10 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) + if enable_autocast: + use_explicit_typing = True + logger.debug("Autocast is enabled, setting use_explicit_typing to True.") + if use_explicit_typing: if len(enabled_precisions) != 1 or not any( x in enabled_precisions @@ -593,6 +611,19 @@ def compile( f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True" ) + if low_precision_type is not None: + if not isinstance(low_precision_type, (torch.dtype, dtype)): + raise ValueError( + f"low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(low_precision_type)}" + ) + if low_precision_type not in { + torch.float16, + torch.bfloat16, + } and low_precision_type not in {dtype.f16, dtype.bf16}: + raise ValueError( + f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}" + ) + if use_fp32_acc: logger.debug( "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \ @@ -622,6 +653,38 @@ def compile( if not isinstance(arg_inputs, collections.abc.Sequence): arg_inputs = [arg_inputs] # type: ignore + # save intermediate outputs of each node for Autocast + intermediate_node_outputs = {} + if not use_explicit_typing: + + class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc] + """Dump intermediate outputs of each node""" + + def run_node(self, n: torch.fx.Node) -> Any: + if ( + n.op == "call_function" + and n.target != torch.ops.higher_order.wrap_with_autocast + ): + out = super().run_node(n) + if not isinstance(out, torch.Tensor): + raise ValueError( + f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}." 
+ ) + intermediate_node_outputs[n.name] = out + return out + return super().run_node(n) + + def _materialize(x: Input | torch.Tensor) -> torch.Tensor: + """Materialize an Input object to a tensor""" + if isinstance(x, Input): + return x.torch_tensor + return x + + with torch.no_grad(): + mat_args = tuple(_materialize(a) for a in arg_inputs) + mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()} + DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs) + # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) @@ -680,6 +743,13 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "enable_autocast": enable_autocast, + "low_precision_type": low_precision_type, + "nodes_to_exclude": nodes_to_exclude, + "targets_to_exclude": targets_to_exclude, + "data_max": data_max, + "max_depth_of_reduction": max_depth_of_reduction, + "intermediate_node_outputs": intermediate_node_outputs, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de970ecd81..e69cda70c7 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -57,6 +57,12 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +ENABLE_AUTOCAST = False +LOW_PRECISION_TYPE = None +NODES_TO_EXCLUDE = set[str]() +TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]() +DATA_MAX = 512 +MAX_DEPTH_OF_REDUCTION = None if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..e406bba615 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,17 +1,20 @@ from dataclasses import dataclass, field from typing import Any, Collection, Optional, Set, Tuple, Union +import torch from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, + DATA_MAX, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, + ENABLE_AUTOCAST, ENABLE_CROSS_COMPILE_FOR_WINDOWS, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENABLE_WEIGHT_STREAMING, @@ -21,8 +24,11 @@ IMMUTABLE_WEIGHTS, L2_LIMIT_FOR_TILING, LAZY_ENGINE_INIT, + LOW_PRECISION_TYPE, MAX_AUX_STREAMS, + MAX_DEPTH_OF_REDUCTION, MIN_BLOCK_SIZE, + NODES_TO_EXCLUDE, NUM_AVG_TIMING_ITERS, OFFLOAD_MODULE_TO_CPU, OPTIMIZATION_LEVEL, @@ -32,6 +38,7 @@ REUSE_CACHED_ENGINES, SPARSE_WEIGHTS, STRIP_ENGINE_WEIGHTS, + TARGETS_TO_EXCLUDE, TILING_OPTIMIZATION_LEVEL, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, @@ -97,6 +104,13 @@ class CompilationSettings: tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_autocast (bool): Whether to enable autocast. 
If enabled, use_explicit_typing will be set to True. + low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. + intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}. """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -140,6 +154,17 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + enable_autocast: bool = ENABLE_AUTOCAST + low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE + nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE) + targets_to_exclude: Collection[Target] = field( + default_factory=lambda: TARGETS_TO_EXCLUDE + ) + data_max: float = DATA_MAX + max_depth_of_reduction: Optional[int] = MAX_DEPTH_OF_REDUCTION + intermediate_node_outputs: dict[str, torch.Tensor] = field( + default_factory=lambda: {} + ) def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -157,6 +182,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state) +# If any of the following setting is changed, the engine should be rebuilt. _SETTINGS_TO_BE_ENGINE_INVARIANT = ( "enabled_precisions", "max_aux_streams", diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index e5183668ae..1499e670bd 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -15,6 +15,13 @@ from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices +from .rule_based_autocast import rule_based_autocast + +pre_lowering_pass_list = [ + remove_detach, + rule_based_autocast, + remove_assert_nodes, # rule_based_autocast might insert assert nodes +] post_lowering_pass_list = [ remove_input_alias_fixing_clones, @@ -27,10 +34,6 @@ complex_graph_detection, ] -pre_lowering_pass_list = [ - remove_detach, -] - if not is_tegra_platform(): from .fuse_distributed_ops import fuse_distributed_ops diff --git a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py new file mode 100644 index 0000000000..72bb376291 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py @@ -0,0 +1,300 @@ +# Borrowed from ModelOpt AutoCast's nodeclassifier.py, modified to fit Torch-TensorRT's needs. 
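+#
+# Overview: NodeClassifier walks the call_function nodes of an FX graph and applies a set
+# of "block" rules (node-name regex, excluded ATen targets, reference-data I/O range, and
+# depth of reduction). A node that matches any rule is kept in high precision (FP32); all
+# remaining nodes are reported as candidates for the low-precision dtype.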
+import abc +import logging +import operator +import re +from typing import Collection, Optional + +import torch + +logger = logging.getLogger(__name__) + + +class NodeRuleBase: + """Base class for node classification rules. + + This class defines the interface for rules that determine whether a node + should be kept in high precision or converted to low precision. + """ + + @abc.abstractmethod + def _check_inner(self, node): + """Implement this method to check if node conversion should be skipped based on rule criteria.""" + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes.""" + logger.info(f"Skipping node {node.name}: {self.__class__.__name__}") + + def check(self, node): + """Check if a node should be skipped based on the rule. + + Args: + node: The ONNX node to check. + + Returns: + bool: True if the node should be kept in high precision, False otherwise. + """ + result = self._check_inner(node) + if result: + self._log_skipped(node) + return True + return False + + +class DisabledNodeNameRegexRule(NodeRuleBase): + """Rule for keeping nodes with matching names in high precision.""" + + def __init__(self, disabled_node_name_regex): + """Initialize the rule. + + Args: + disabled_node_name_regex: List of regex patterns for node names to keep in high precision. + """ + self.disabled_node_name_regex = disabled_node_name_regex + + def _check_inner(self, node): + return any( + re.match(regex, node.name) for regex in self.disabled_node_name_regex + ) + + +class DisabledTargets(NodeRuleBase): + """Rule for keeping nodes with specific operation types in high precision.""" + + def __init__(self, targets_to_exclude): + """Initialize the rule. + + Args: + targets_to_exclude: List of operation types to keep in high precision. + """ + self.targets_to_exclude = targets_to_exclude + + def _check_inner(self, node): + return node.target in self.targets_to_exclude + + +class IORangeRule(NodeRuleBase): + """Rule for keeping nodes with out-of-range inputs/outputs in high precision.""" + + def __init__(self, data_max, reference_data): + """Initialize the rule. + + Args: + data_max: Maximum absolute value allowed for node I/O. + reference_data: Reference data for checking I/O ranges. + """ + self.data_max = data_max + self.reference_data = reference_data + self.output_data = None + + def _check_inner(self, node): + def is_io_out_of_range(node): + tensor_name = node.name + if tensor_name not in self.reference_data: + logger.debug( + f"Node {node.name}: Tensor {tensor_name} not found in reference data. Skipping I/O range check." + ) + return False + ref_data = self.reference_data[tensor_name] + if ref_data.numel() == 0: + logger.debug( + f"Node {node.name}: Tensor {tensor_name} has 0 elements. Skipping I/O range check." 
+ ) + return False + logger.debug( + f"Node {node.name}: reference data: min={ref_data.min()}, max={ref_data.max()}" + ) + if torch.any(torch.abs(ref_data) > self.data_max): + self.output_data = ref_data + return True + + if self.reference_data: + for in_node in node.all_input_nodes: + if is_io_out_of_range(in_node): + return True + for out_node in list(node.users): + if is_io_out_of_range(out_node): + return True + return False + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes with I/O range violations.""" + if self.output_data is not None: + logger.info( + f"Skipping node {node.name}: reference IO out of range: min={torch.min(self.output_data)}, " + f"max={torch.max(self.output_data)}, range=[{-self.data_max}, {self.data_max}]" + ) + else: + super()._log_skipped(node, **kwargs) + + +class DepthOfReductionRule(NodeRuleBase): + """Rule for keeping nodes with high depth of reduction in high precision.""" + + def __init__(self, max_depth_of_reduction, reference_data): + """Initialize the rule. + + Args: + max_depth_of_reduction: Maximum depth of reduction allowed in low precision. + reference_data: Reference data for checking I/O ranges. + """ + self.max_depth_of_reduction = max_depth_of_reduction + self.reference_data = reference_data + self.reduction_depth = 0 + + def _get_tensor_shape(self, tensor_name): + """Get tensor shape from reference data.""" + if tensor_name in self.reference_data: + return self.reference_data[tensor_name].shape + return None + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes with depth of reduction violations.""" + if self.reduction_depth > 0: + logger.info( + f"Skipping node {node.name}: depth of reduction {self.reduction_depth} exceeds " + f"{self.max_depth_of_reduction}." + ) + else: + super()._log_skipped(node, **kwargs) + + def _check_inner(self, node): + # All reduction ops rely on shape of input[0] + input_0_dims = ( + self._get_tensor_shape(node.all_input_nodes[0].name) + if len(node.all_input_nodes) > 0 + else None + ) + if input_0_dims is None: + return False + self.reduction_depth = 0 + if node.target in [ + torch.ops.aten.scaled_dot_product_attention.default, + ]: + # Attention: input (batch_size, sequence_length, hidden_size) + # or (batch_size, kv_num_heads, total_sequence_length, head_size) + assert len(input_0_dims) == 3 or len(input_0_dims) == 4 + hidden_size = ( + input_0_dims[2] + if len(input_0_dims) == 3 + else input_0_dims[1] * input_0_dims[3] + ) + self.reduction_depth = hidden_size + elif node.target in [ + torch.ops.aten.convolution.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv3d.default, + ]: + # Conv: input (N x C x D1 x D2 ... x Dn) + # weight (out_channels, in_channels, kD1, kD2, ... 
kDn) + # Reduction depth = in_channels * kernel_volume + weight_shape = ( + self._get_tensor_shape(node.all_input_nodes[1].name) + if len(node.all_input_nodes) > 1 + else None + ) + if weight_shape is None: + return False + in_channels = weight_shape[1] + kernel_volume = torch.prod(weight_shape[2:]) + self.reduction_depth = in_channels * kernel_volume + elif node.target in [ + torch.ops.aten.matmul, + torch.ops.aten.matmul.default, + torch.ops.aten.dot.default, + torch.ops.aten.mm.default, + torch.ops.aten.mv.default, + torch.ops.aten.bmm.default, + ]: + # GEMM: A (M, K) @ B (K, N) = C (M, N) + self.reduction_depth = input_0_dims[-1] + # TODO: Add more reduction ops here + return self.reduction_depth > self.max_depth_of_reduction + + +class NodeClassifier: + """Main class for classifying nodes into high and low precision groups.""" + + def __init__( + self, + nodes, + nodes_to_exclude: Collection[str] | None = None, + targets_to_exclude: Collection[torch.fx.node.Target] | None = None, + custom_rule: NodeRuleBase | None = None, + data_max: float | None = 1000.0, + max_depth_of_reduction: int | None = None, + ): + """Initialize the node classifier. + + Args: + nodes: The nodes to classify. + nodes_to_exclude: Collection of regex patterns for node names to keep in high precision. + targets_to_exclude: Collection of targets to keep in high precision. + custom_rule: Optional custom classification rule. + data_max: Maximum absolute value allowed for node I/O. + max_depth_of_reduction: Maximum depth of reduction allowed in low precision. + """ + self.nodes = nodes + self.nodes_to_exclude = nodes_to_exclude + self.targets_to_exclude = targets_to_exclude + self.custom_rule = custom_rule + self.data_max = data_max + self.max_depth_of_reduction = max_depth_of_reduction + + def _gen_block_node_rules(self, reference_data): + """Generate list of rules for blocking nodes from precision conversion. + + Args: + reference_data: Reference data for checking I/O ranges. + + Returns: + list[NodeRuleBase]: List of rules to apply. + """ + block_node_rules: list[NodeRuleBase] = [] + if self.nodes_to_exclude: + block_node_rules.append(DisabledNodeNameRegexRule(self.nodes_to_exclude)) + if self.targets_to_exclude: + block_node_rules.append(DisabledTargets(self.targets_to_exclude)) + if reference_data: + block_node_rules.append(IORangeRule(self.data_max, reference_data)) + if self.max_depth_of_reduction is not None: + block_node_rules.append( + DepthOfReductionRule( + self.max_depth_of_reduction, + reference_data, + ) + ) + if self.custom_rule: + block_node_rules.append(self.custom_rule) + return block_node_rules + + def run( + self, ref_outputs_dict: Optional[dict[str, torch.Tensor]] = None + ) -> tuple[list[str], list[str]]: + """Run node classification. + + Args: + ref_outputs_dict: Optional tensors' reference data. + + Returns: + tuple: Lists of node names (low_precision_nodes, high_precision_nodes). 
+ """ + block_node_rules = self._gen_block_node_rules(ref_outputs_dict) + low_precision_nodes = [] + high_precision_nodes = [] + for node in self.nodes: + if node.op == "call_function": + if ( + node.target == torch.ops.higher_order.wrap_with_autocast + or node.target == operator.getitem + ): + continue + # If any condition is met - node will be executed in high precision + if any(rule.check(node) for rule in block_node_rules): + high_precision_nodes.append(node.name) + else: + low_precision_nodes.append(node.name) + logger.debug(f"Low Precision Nodes: {low_precision_nodes}") + logger.debug(f"High Precision Nodes: {high_precision_nodes}") + return low_precision_nodes, high_precision_nodes diff --git a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py new file mode 100644 index 0000000000..6a824a6a90 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py @@ -0,0 +1,129 @@ +import logging +import operator +from typing import Any + +import torch +from torch._export.passes.replace_autocast_with_hop_pass import ( + replace_autocast_with_hop_pass, +) +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo._settings import CompilationSettings + +from .nodeclassifier import NodeClassifier +from .pass_utils import clean_up_graph_after_modifications + +logger = logging.getLogger(__name__) + + +def is_tensor_node(n: torch.fx.Node) -> bool: + val = n.meta.get("val", None) + if hasattr(val, "dtype"): + return True + return False + + +def rule_based_autocast( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Rule-based autocast""" + if not settings.enable_autocast: + logger.debug("Autocast is not enabled, skipping rule-based autocast.") + return gm + + # nodes = list(gm.graph.nodes) + # # insert enter autocast node in the beginning of the graph + # with gm.graph.inserting_before(nodes[0]): + # enter_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._enter_autocast, args=("cuda", torch.float16, True, True)) + # enter_autocast_node.meta.update(getattr(nodes[0], "meta", {})) + + # # insert exit autocast node before the return node, assuming the return node is the last node + # with gm.graph.inserting_before(nodes[-1]): + # exit_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._exit_autocast, args=(enter_autocast_node,)) + # exit_autocast_node.meta.update(getattr(nodes[-1], "meta", {})) + + # gm = clean_up_graph_after_modifications(gm) + # gm, new_signature = replace_autocast_with_hop_pass(gm, None) + # logger.debug("Graph after replace_autocast_with_hop_pass:\n%s", gm.graph) + + # get config from settings + low_precision_type = settings.low_precision_type + if low_precision_type is None: + return gm + if isinstance(low_precision_type, dtype): + low_precision_type = low_precision_type.to(torch.dtype) + high_precision_type = torch.float32 + nodes_to_exclude = settings.nodes_to_exclude + targets_to_exclude = settings.targets_to_exclude + data_max = settings.data_max + max_depth_of_reduction = settings.max_depth_of_reduction + reference_data: dict[str, torch.Tensor] = settings.intermediate_node_outputs + + node_classifier = NodeClassifier( + gm.graph.nodes, + nodes_to_exclude=nodes_to_exclude, + targets_to_exclude=targets_to_exclude, + data_max=data_max, + max_depth_of_reduction=max_depth_of_reduction, + ) + low_precision_nodes, high_precision_nodes = node_classifier.run(reference_data) + + for node in list(gm.graph.nodes): 
+ if node.op == "call_function": + if ( + node.target == torch.ops.higher_order.wrap_with_autocast + or node.target == operator.getitem + ): + continue + + def _cast_all_tensor_args_to_dtype(arg: Any, dtype: torch.dtype) -> Any: + """Cast all tensor args to the given dtype + + Args: + arg: The argument to cast + dtype: The dtype to cast to + + Returns: + The casted argument + """ + if isinstance(arg, torch.fx.Node) and is_tensor_node(arg): + val = arg.meta.get("val", None) + with gm.graph.inserting_before(node): + cast = gm.graph.call_function( + torch.ops.aten.to.dtype, args=(arg, dtype) + ) + + if isinstance(val, torch.Tensor): + arg.meta["val"] = val.to(dtype) + cast.meta.update(arg.meta) + return cast + elif isinstance(arg, (tuple, list)): + return type(arg)( + _cast_all_tensor_args_to_dtype(a, dtype) for a in arg + ) + elif isinstance(arg, dict): + return { + k: _cast_all_tensor_args_to_dtype(v, dtype) + for k, v in arg.items() + } + else: + return arg + + if node.name in low_precision_nodes: + node.args = _cast_all_tensor_args_to_dtype( + node.args, low_precision_type + ) + node.kwargs = _cast_all_tensor_args_to_dtype( + node.kwargs, low_precision_type + ) + elif node.name in high_precision_nodes: + node.args = _cast_all_tensor_args_to_dtype( + node.args, high_precision_type + ) + node.kwargs = _cast_all_tensor_args_to_dtype( + node.kwargs, high_precision_type + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug("Graph after Autocast based on the rules:\n%s", gm.graph) + + return gm diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 9e54fbac3d..24166eb895 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -154,10 +154,6 @@ def forward( + contiguous_inputs[i + 1 :] ) - assert ( - contiguous_inputs[i].dtype == inputs[i].dtype - ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}." - if need_cudagraphs_record: # If cudagraphs is enabled, this memory is reserved for future cudagraph runs # Clone is required to avoid re-using user-provided GPU memory diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d18a5674e0..0eb5ebbbca 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -275,10 +275,6 @@ def setup_engine(self) -> None: len(self.input_names) + len(self.output_names) ) - self.input_dtypes = [ - dtype._from(self.engine.get_tensor_dtype(input_name)) - for input_name in self.input_names - ] self.input_shapes = [ self.engine.get_tensor_shape(input_name) for input_name in self.input_names ] @@ -371,10 +367,6 @@ def setup_input_tensors( + contiguous_inputs[i + 1 :] ) - assert ( - contiguous_inputs[i].dtype == self.input_dtypes[i] - ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - if need_cudagraphs_record: # If cudagraphs is enabled, this memory is reserved for future cudagraph runs # Clone is required to avoid re-using user-provided GPU memory
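As a supplementary illustration, here is a minimal sketch of how the depth-of-reduction rule above reasons about two common cases, using hypothetical shapes and a hypothetical max_depth_of_reduction threshold: a convolution reduces over in_channels times the kernel volume, while a matmul reduces over the shared K dimension read from the first input's last axis; any node whose depth exceeds the threshold stays in FP32.

import torch

max_depth_of_reduction = 64  # hypothetical threshold

# Conv2d: weight shape (out_channels, in_channels, kH, kW) = (16, 8, 3, 3)
weight_shape = torch.Size([16, 8, 3, 3])
conv_depth = weight_shape[1] * weight_shape[2] * weight_shape[3]  # in_channels * kernel volume = 72
print(conv_depth > max_depth_of_reduction)  # True -> the conv node stays in FP32

# mm: A (M, K) @ B (K, N); reduction depth is K, read from A's last dimension
a_shape = torch.Size([32, 48])
mm_depth = a_shape[-1]  # 48
print(mm_depth > max_depth_of_reduction)  # False -> the mm node may run in the low-precision dtype

In the pass itself these shapes come from the intermediate_node_outputs recorded by DumpInterpreter during compile().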