
Conversation

Collaborator

@zewenli98 zewenli98 commented Oct 28, 2025

Description

Weak typing behavior in TensorRT is deprecated; however, it is a good way to maximize performance. Therefore, we want to create a similar PyTorch-native system for use with Torch-TensorRT that recovers some of this behavior.

Fixes #3869

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@zewenli98 zewenli98 self-assigned this Oct 28, 2025
@meta-cla meta-cla bot added the cla signed label Oct 28, 2025
@github-actions github-actions bot added the component: lowering, component: conversion, component: core, component: api [Python], component: runtime, and component: dynamo labels Oct 28, 2025
@github-actions github-actions bot requested a review from apbose October 28, 2025 05:16
@zewenli98 zewenli98 removed the request for review from apbose October 28, 2025 05:16
@github-actions github-actions bot removed the component: conversion label Oct 29, 2025
"targets_to_exclude": targets_to_exclude,
"data_max": data_max,
"max_depth_of_reduction": max_depth_of_reduction,
"intermediate_node_outputs": intermediate_node_outputs,
Collaborator Author

@zewenli98 zewenli98 Oct 29, 2025


intermediate_node_outputs is used as calibration data for the autocast ruleset. Since the compilation settings are printed to the terminal, consider whether it would be better to save the data to a file and just pass in the filename.
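
A minimal sketch of the file-based alternative, assuming the calibration data is a plain dict[str, torch.Tensor]; the file name and the shape of the setting are hypothetical:

import torch

# Producer side: persist the captured outputs once instead of carrying the
# tensors inside CompilationSettings.
intermediate_node_outputs = {"relu_1": torch.randn(8, 8)}  # stand-in for real captured data
torch.save(intermediate_node_outputs, "autocast_calibration.pt")

# Consumer side: the setting would only hold the path, so printing the
# settings in the terminal stays readable.
reference_data: dict[str, torch.Tensor] = torch.load("autocast_calibration.pt")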

Comment on lines +437 to +444
enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
low_precision_type: Optional[
Union[torch.dtype, dtype]
] = _defaults.LOW_PRECISION_TYPE,
nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE,
targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE,
data_max: float = _defaults.DATA_MAX,
max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION,
Collaborator Author


Before merging, these args should be added to other compile functions in this file.
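
For context, a hedged usage sketch of these settings through the user-facing compile API. The keyword names mirror this diff and are not available in released Torch-TensorRT; the values are illustrative rather than defaults, and the trailing comments describe my understanding of each knob.

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(8, 64, device="cuda")]

trt_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enable_autocast=True,
    low_precision_type=torch.float16,
    nodes_to_exclude=set(),       # node names to keep in high precision
    targets_to_exclude=set(),     # aten targets to keep in high precision
    data_max=512.0,               # magnitude threshold above which nodes stay in fp32
    max_depth_of_reduction=1024,  # reduction-depth threshold above which nodes stay in fp32
)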

]:
# GEMM: A (M, K) @ B (K, N) = C (M, N)
self.reduction_depth = input_0_dims[-1]
# TODO: Add more reduction ops here
Collaborator Author


Should any more reduction targets be added?
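
If more targets are added, a standalone sketch of how their reduction depth could be derived; the helper name and the reliance on node.meta["val"] shapes are assumptions, not part of this PR:

import math

import torch
from torch.fx import Node


def _depth_of_reduction(node: Node) -> int:
    """Illustrative only: reduction depth for a few candidate targets."""
    if node.target in (torch.ops.aten.mm.default, torch.ops.aten.bmm.default):
        # (..., M, K) @ (..., K, N): each output element accumulates K products.
        return int(node.args[0].meta["val"].shape[-1])
    if node.target == torch.ops.aten.convolution.default:
        # Each output element accumulates (C_in / groups) * prod(kernel) products.
        weight_shape = node.args[1].meta["val"].shape  # (C_out, C_in/groups, *kernel)
        return int(math.prod(weight_shape[1:]))
    return 0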

Comment on lines -374 to -377
assert (
contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

Collaborator Author

@zewenli98 zewenli98 Oct 29, 2025


This precision check was removed because, after autocasting, if the first layer runs in fp16 while the original input is fp32, input_dtypes becomes fp16 but contiguous_inputs stays fp32.

The same check was removed from the other runtimes as well.
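
A tiny self-contained illustration of the mismatch the old assert would now trip on; the dtypes and shapes are illustrative:

import torch

# The user keeps passing the model's original fp32 input...
user_input = torch.randn(8, 8, device="cuda")  # torch.float32
# ...but after autocast the first engine layer runs in fp16, so the recorded
# expected input dtype becomes fp16.
engine_expected_dtype = torch.float16

# The removed check compared these directly and would fail here, which is now
# an expected situation after autocast rather than a user error.
print(user_input.dtype == engine_expected_dtype)  # False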

Comment on lines +33 to +46
# nodes = list(gm.graph.nodes)
# # insert enter autocast node in the beginning of the graph
# with gm.graph.inserting_before(nodes[0]):
# enter_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._enter_autocast, args=("cuda", torch.float16, True, True))
# enter_autocast_node.meta.update(getattr(nodes[0], "meta", {}))

# # insert exit autocast node before the return node, assuming the return node is the last node
# with gm.graph.inserting_before(nodes[-1]):
# exit_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._exit_autocast, args=(enter_autocast_node,))
# exit_autocast_node.meta.update(getattr(nodes[-1], "meta", {}))

# gm = clean_up_graph_after_modifications(gm)
# gm, new_signature = replace_autocast_with_hop_pass(gm, None)
# logger.debug("Graph after replace_autocast_with_hop_pass:\n%s", gm.graph)
Collaborator Author


If we use PyTorch autocast to wrap the whole model, PyTorch controls the precision of each node according to its own autocast policy (per the docs), and I didn't find a way to customize that based on our ruleset.
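
For reference, a minimal repro of that behavior: under torch.autocast, per-op precision follows PyTorch's built-in policy (torch.mm is cast to fp16 while torch.log is promoted to fp32), with no hook for injecting a custom ruleset.

import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    mm_out = torch.mm(a, b)      # autocast policy: runs in fp16
    log_out = torch.log(mm_out)  # autocast policy: promoted to fp32

print(mm_out.dtype, log_out.dtype)  # torch.float16 torch.float32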

Collaborator

@peri044 peri044 left a comment


Comment on lines -110 to -113
auto expected_type =
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
TORCHTRT_CHECK(
inputs[i].dtype() == expected_type,
Collaborator


Is this not necessary now?

Comment on lines +65 to +73
# model = MyModule().cuda().eval()
# inputs = (torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),
# torch.randn((8, 8), device="cuda"),)

# model = AutocastExample().cuda().eval()
# inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
# torch.randn((1,), dtype=torch.float16, device="cuda"),)
Collaborator


remove this

Comment on lines +7 to +49
class MyModule(torch.nn.Module):
def forward(self, a_float32, b_float32, c_float32, d_float32):
with torch.autocast(device_type="cuda"):
e_float16 = torch.mm(a_float32, b_float32)
with torch.autocast(device_type="cuda", enabled=False):
# Calls e_float16.float() to ensure float32 execution
# (necessary because e_float16 was created in an autocasted region)
f_float32 = torch.mm(c_float32, e_float16.float())

# No manual casts are required when re-entering the autocast-enabled region.
# torch.mm again runs in float16 and produces float16 output, regardless of input types.
g_float16 = torch.mm(d_float32, f_float32)
return g_float16


class AutocastExample(nn.Module):
def __init__(self):
super(AutocastExample, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(
in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(16 * 8 * 8, 10)

def forward(self, x, y):
out = self.pool1(self.relu1(self.conv1(x))) # fp16
x = self.pool2(self.relu2(self.conv2(out))) # fp16
x = self.flatten(x)
with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
x = self.fc1(x) # fp32
with torch.autocast(x.device.type, enabled=False):
x = torch.sub(x.half(), y) # fp16
out2 = torch.add(x, x) # fp16
with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
out2 = torch.log(out2) # fp32
return x, out, out2
Collaborator


Seems like these modules are not being used. Consider removing them.

"""Check if a node should be skipped based on the rule.
Args:
node: The ONNX node to check.
Collaborator


ONNX node -> torch.fx.Node

"""

@abc.abstractmethod
def _check_inner(self, node):
Collaborator


Consider modifying the name?



class DepthOfReductionRule(NodeRuleBase):
"""Rule for keeping nodes with high depth of reduction in high precision."""
Collaborator


Could you add more detailed information about this class, with an example, in the docstring?
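
A sketch of what such a docstring could look like; the threshold and shapes are illustrative:

class DepthOfReductionRule(NodeRuleBase):
    """Keep nodes with a large depth of reduction in high precision.

    The reduction depth is the number of values accumulated into a single
    output element; the deeper the reduction, the larger the potential
    low-precision accumulation error.

    Example:
        For C = A @ B with A of shape (M, K) and B of shape (K, N), each
        output element sums K products, so the reduction depth is K. With
        max_depth_of_reduction=1024, a matmul with K=4096 would be kept in
        fp32, while one with K=256 could be autocast to fp16.
    """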

Comment on lines +33 to +46
# nodes = list(gm.graph.nodes)
# # insert enter autocast node in the beginning of the graph
# with gm.graph.inserting_before(nodes[0]):
# enter_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._enter_autocast, args=("cuda", torch.float16, True, True))
# enter_autocast_node.meta.update(getattr(nodes[0], "meta", {}))

# # insert exit autocast node before the return node, assuming the return node is the last node
# with gm.graph.inserting_before(nodes[-1]):
# exit_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._exit_autocast, args=(enter_autocast_node,))
# exit_autocast_node.meta.update(getattr(nodes[-1], "meta", {}))

# gm = clean_up_graph_after_modifications(gm)
# gm, new_signature = replace_autocast_with_hop_pass(gm, None)
# logger.debug("Graph after replace_autocast_with_hop_pass:\n%s", gm.graph)
Collaborator


Consider removing the commented-out code.

Comment on lines +55 to +59
nodes_to_exclude = settings.nodes_to_exclude
targets_to_exclude = settings.targets_to_exclude
data_max = settings.data_max
max_depth_of_reduction = settings.max_depth_of_reduction
reference_data: dict[str, torch.Tensor] = settings.intermediate_node_outputs
Collaborator


What do you think of tying these settings together, e.g. settings.autocast_settings.nodes_to_exclude?
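
A rough sketch of that grouping; the class name, field defaults, and nesting are hypothetical:

from dataclasses import dataclass, field
from typing import Collection, Optional

import torch
from torch.fx.node import Target


@dataclass
class AutocastSettings:
    """Hypothetical container for the autocast-related knobs."""

    enable_autocast: bool = False
    low_precision_type: Optional[torch.dtype] = torch.float16
    nodes_to_exclude: Collection[str] = field(default_factory=set)
    targets_to_exclude: Collection[Target] = field(default_factory=set)
    data_max: float = 512.0
    max_depth_of_reduction: Optional[int] = None
    intermediate_node_outputs: dict[str, torch.Tensor] = field(default_factory=dict)


# The lowering pass would then read e.g.:
# nodes_to_exclude = settings.autocast_settings.nodes_to_exclude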

):
continue

def _cast_all_tensor_args_to_dtype(arg: Any, dtype: torch.dtype) -> Any:
Collaborator


Move this function outside the loop?

Comment on lines +660 to +686
class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Dump intermediate outputs of each node"""

def run_node(self, n: torch.fx.Node) -> Any:
if (
n.op == "call_function"
and n.target != torch.ops.higher_order.wrap_with_autocast
):
out = super().run_node(n)
if not isinstance(out, torch.Tensor):
raise ValueError(
f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
)
intermediate_node_outputs[n.name] = out
return out
return super().run_node(n)

def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
"""Materialize an Input object to a tensor"""
if isinstance(x, Input):
return x.torch_tensor
return x

with torch.no_grad():
mat_args = tuple(_materialize(a) for a in arg_inputs)
mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()}
DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs)
Collaborator


This could be a generally useful utility. Consider moving it to utils.py.
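
One possible shape for that utility if it moves to utils.py; the function name and exact signature are suggestions, not existing API:

import torch
from torch.export import ExportedProgram


def dump_intermediate_node_outputs(
    exported_program: ExportedProgram,
    *args: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Run the exported graph eagerly and record every call_function tensor output."""
    outputs: dict[str, torch.Tensor] = {}

    class _DumpInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
        def run_node(self, n: torch.fx.Node):
            out = super().run_node(n)
            if n.op == "call_function" and isinstance(out, torch.Tensor):
                outputs[n.name] = out
            return out

    with torch.no_grad():
        _DumpInterpreter(exported_program.module()).run(*args)
    return outputs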

