6 changes: 0 additions & 6 deletions core/runtime/execute_engine.cpp
@@ -107,12 +107,6 @@ void setup_input_tensors(
TORCHTRT_CHECK(
inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());

auto expected_type =
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
TORCHTRT_CHECK(
inputs[i].dtype() == expected_type,
Comment on lines -110 to -113 (Collaborator): Is this not necessary now?

"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());

auto dims = core::util::toDims(inputs[i].sizes());
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
103 changes: 103 additions & 0 deletions examples/dynamo/autocast_example.py
@@ -0,0 +1,103 @@
import torch
import torch.nn as nn
import torch_tensorrt
import torchvision


class MyModule(torch.nn.Module):
def forward(self, a_float32, b_float32, c_float32, d_float32):
with torch.autocast(device_type="cuda"):
e_float16 = torch.mm(a_float32, b_float32)
with torch.autocast(device_type="cuda", enabled=False):
# Calls e_float16.float() to ensure float32 execution
# (necessary because e_float16 was created in an autocasted region)
f_float32 = torch.mm(c_float32, e_float16.float())

# No manual casts are required when re-entering the autocast-enabled region.
# torch.mm again runs in float16 and produces float16 output, regardless of input types.
g_float16 = torch.mm(d_float32, f_float32)
return g_float16


class AutocastExample(nn.Module):
def __init__(self):
super(AutocastExample, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(
in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(16 * 8 * 8, 10)

def forward(self, x, y):
out = self.pool1(self.relu1(self.conv1(x))) # fp16
x = self.pool2(self.relu2(self.conv2(out))) # fp16
x = self.flatten(x)
with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
x = self.fc1(x) # fp32
with torch.autocast(x.device.type, enabled=False):
x = torch.sub(x.half(), y) # fp16
out2 = torch.add(x, x) # fp16
with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
out2 = torch.log(out2) # fp32
return x, out, out2
Comment on lines +7 to +49 (Collaborator): Seems like these modules are not being used. Consider removing them.



class MyResNet18Wrapper(torch.nn.Module):
def __init__(self, num_classes=1000, pretrained=True):
super(MyResNet18Wrapper, self).__init__()
self.resnet = torchvision.models.resnet18(
num_classes=num_classes, weights="IMAGENET1K_V1" if pretrained else None
)

def forward(self, x):
x = self.resnet(x)
return x


if __name__ == "__main__":
# model = MyModule().cuda().eval()
# inputs = (torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),)

# model = AutocastExample().cuda().eval()
# inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
# torch.randn((1,), dtype=torch.float16, device="cuda"),)
Comment on lines +65 to +73 (Collaborator): Remove this.


model = MyResNet18Wrapper().cuda().eval()
inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),)

ep = torch.export.export(model, inputs)

with torch_tensorrt.dynamo.Debugger(
"graphs",
logging_dir=".",
engine_builder_monitor=False,
):
trt_mod = torch_tensorrt.compile(
ep.module(),
arg_inputs=inputs,
min_block_size=1,
use_python_runtime=True,
##### weak typing #####
# use_explicit_typing=False,
# enabled_precisions={torch.float16},
##### strong typing + autocast #####
use_explicit_typing=True,
enable_autocast=True,
low_precision_type=torch.float16,
# nodes_to_exclude={"^conv2d$"},
targets_to_exclude={},
data_max=512,
max_depth_of_reduction=None,
)

trt_out = trt_mod(*inputs)
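
A quick way to sanity-check the compiled autocast module is to compare its output against the FP32 PyTorch baseline that the example already builds; a minimal sketch, assuming loose tolerances since most of the graph runs in FP16 (the tolerance values below are illustrative, not from this PR):

    with torch.no_grad():
        ref_out = model(*inputs)  # FP32 PyTorch baseline
    # Cast both to float32 before comparing; the TensorRT output may be FP16.
    print("max abs diff:", (trt_out.float() - ref_out.float()).abs().max().item())
    # Loose tolerances are expected because most layers execute in reduced precision.
    assert torch.allclose(trt_out.float(), ref_out.float(), rtol=5e-2, atol=5e-2)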
72 changes: 71 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -141,7 +141,7 @@ def cross_compile_for_windows(
disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
@@ -434,6 +434,14 @@ def compile(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
low_precision_type: Optional[
Union[torch.dtype, dtype]
] = _defaults.LOW_PRECISION_TYPE,
nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE,
targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE,
data_max: float = _defaults.DATA_MAX,
max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION,
Comment on lines +437 to +444 (Collaborator, Author): Before merging, these args should be added to other compile functions in this file.

**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -511,6 +519,12 @@ def compile(
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
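
For context, a minimal usage sketch of the autocast-related arguments documented above; it mirrors examples/dynamo/autocast_example.py from this PR, and the exclusion pattern and threshold values are illustrative only:

    import torch
    import torch_tensorrt
    import torchvision

    model = torchvision.models.resnet18(weights="IMAGENET1K_V1").cuda().eval()
    inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),)
    ep = torch.export.export(model, inputs)

    trt_mod = torch_tensorrt.compile(
        ep.module(),
        arg_inputs=inputs,
        min_block_size=1,
        use_python_runtime=True,
        use_explicit_typing=True,         # implied when enable_autocast=True
        enable_autocast=True,
        low_precision_type=torch.float16,
        nodes_to_exclude={"^conv2d$"},    # regex patterns for node names that must stay FP32
        data_max=512,                     # node outputs above this magnitude stay FP32
        max_depth_of_reduction=None,      # None means no limit on reduction depth
    )
    trt_out = trt_mod(*inputs)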
@@ -584,6 +598,10 @@ def compile(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)

if enable_autocast:
use_explicit_typing = True
logger.debug("Autocast is enabled, setting use_explicit_typing to True.")

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions
@@ -593,6 +611,19 @@
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
)

if low_precision_type is not None:
if not isinstance(low_precision_type, (torch.dtype, dtype)):
raise ValueError(
f"low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(low_precision_type)}"
)
if low_precision_type not in {
torch.float16,
torch.bfloat16,
} and low_precision_type not in {dtype.f16, dtype.bf16}:
raise ValueError(
f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}"
)

if use_fp32_acc:
logger.debug(
"FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
@@ -622,6 +653,38 @@ def compile(
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore

# save intermediate outputs of each node for Autocast
intermediate_node_outputs = {}
if not use_explicit_typing:

class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Dump intermediate outputs of each node"""

def run_node(self, n: torch.fx.Node) -> Any:
if (
n.op == "call_function"
and n.target != torch.ops.higher_order.wrap_with_autocast
):
out = super().run_node(n)
if not isinstance(out, torch.Tensor):
raise ValueError(
f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
)
intermediate_node_outputs[n.name] = out
return out
return super().run_node(n)

def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
"""Materialize an Input object to a tensor"""
if isinstance(x, Input):
return x.torch_tensor
return x

with torch.no_grad():
mat_args = tuple(_materialize(a) for a in arg_inputs)
mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()}
DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs)
Comment on lines +660 to +686 (Collaborator): This could be a generally useful utility. Consider moving it to utils.py.
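
A rough sketch of what that utility could look like if hoisted into utils.py, per the comment above; the function name and signature are assumptions for illustration, and keyword inputs are assumed to have been flattened into positional placeholders since torch.fx.Interpreter.run() only accepts positional arguments:

    from typing import Any, Dict

    import torch


    def dump_intermediate_node_outputs(
        gm: torch.fx.GraphModule,
        *args: Any,
    ) -> Dict[str, torch.Tensor]:
        """Run gm once and record the tensor output of every call_function node."""
        outputs: Dict[str, torch.Tensor] = {}

        class _DumpInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
            def run_node(self, n: torch.fx.Node) -> Any:
                out = super().run_node(n)
                if (
                    n.op == "call_function"
                    and n.target != torch.ops.higher_order.wrap_with_autocast
                ):
                    if not isinstance(out, torch.Tensor):
                        raise ValueError(
                            f"Expected a torch.Tensor but got {type(out)} for node {n.name}."
                        )
                    outputs[n.name] = out
                return out

        with torch.no_grad():
            _DumpInterpreter(gm).run(*args)
        return outputs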


# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
@@ -680,6 +743,13 @@ def compile(
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
"enable_autocast": enable_autocast,
"low_precision_type": low_precision_type,
"nodes_to_exclude": nodes_to_exclude,
"targets_to_exclude": targets_to_exclude,
"data_max": data_max,
"max_depth_of_reduction": max_depth_of_reduction,
"intermediate_node_outputs": intermediate_node_outputs,
Comment by @zewenli98 (Collaborator, Author), Oct 29, 2025: intermediate_node_outputs is used as calibration data for the autocast ruleset. Since the compilation settings are printed in the terminal, consider whether it would be better to save this data to a file and pass in just the filename.
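
A hedged sketch of the reviewer's suggestion: persist the dumped outputs with torch.save and route only the path through the settings. The helper names and the .pt file format are assumptions for illustration, not the PR's API:

    import tempfile
    from typing import Dict

    import torch


    def save_intermediate_node_outputs(outputs: Dict[str, torch.Tensor]) -> str:
        """Write the dumped node outputs to a temporary .pt file and return its path."""
        path = tempfile.NamedTemporaryFile(suffix=".pt", delete=False).name
        # Move tensors to CPU so the file can be reloaded on any device later.
        torch.save({name: t.detach().cpu() for name, t in outputs.items()}, path)
        return path


    def load_intermediate_node_outputs(path: str) -> Dict[str, torch.Tensor]:
        """Reload the calibration data when the autocast ruleset needs it."""
        return torch.load(path)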

}

settings = CompilationSettings(**compilation_options)
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -57,6 +57,12 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
ENABLE_AUTOCAST = False
LOW_PRECISION_TYPE = None
NODES_TO_EXCLUDE = set[str]()
TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]()
DATA_MAX = 512
MAX_DEPTH_OF_REDUCTION = None

if platform.system() == "Linux":
import pwd
26 changes: 26 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -1,17 +1,20 @@
from dataclasses import dataclass, field
from typing import Any, Collection, Optional, Set, Tuple, Union

import torch
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
DATA_MAX,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
ENABLE_AUTOCAST,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_WEIGHT_STREAMING,
@@ -21,8 +24,11 @@
IMMUTABLE_WEIGHTS,
L2_LIMIT_FOR_TILING,
LAZY_ENGINE_INIT,
LOW_PRECISION_TYPE,
MAX_AUX_STREAMS,
MAX_DEPTH_OF_REDUCTION,
MIN_BLOCK_SIZE,
NODES_TO_EXCLUDE,
NUM_AVG_TIMING_ITERS,
OFFLOAD_MODULE_TO_CPU,
OPTIMIZATION_LEVEL,
@@ -32,6 +38,7 @@
REUSE_CACHED_ENGINES,
SPARSE_WEIGHTS,
STRIP_ENGINE_WEIGHTS,
TARGETS_TO_EXCLUDE,
TILING_OPTIMIZATION_LEVEL,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
@@ -97,6 +104,13 @@ class CompilationSettings:
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}.
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -140,6 +154,17 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
enable_autocast: bool = ENABLE_AUTOCAST
low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE
nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE)
targets_to_exclude: Collection[Target] = field(
default_factory=lambda: TARGETS_TO_EXCLUDE
)
data_max: float = DATA_MAX
max_depth_of_reduction: Optional[int] = MAX_DEPTH_OF_REDUCTION
intermediate_node_outputs: dict[str, torch.Tensor] = field(
default_factory=lambda: {}
)

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -157,6 +182,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)


# If any of the following settings is changed, the engine should be rebuilt.
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
"enabled_precisions",
"max_aux_streams",
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -15,6 +15,13 @@
from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
from .repair_input_as_output import repair_input_as_output
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .rule_based_autocast import rule_based_autocast

pre_lowering_pass_list = [
remove_detach,
rule_based_autocast,
remove_assert_nodes, # rule_based_autocast might insert assert nodes
]

post_lowering_pass_list = [
remove_input_alias_fixing_clones,
@@ -27,10 +34,6 @@
complex_graph_detection,
]

pre_lowering_pass_list = [
remove_detach,
]

if not is_tegra_platform():
from .fuse_distributed_ops import fuse_distributed_ops
