feat: Autocast #3878
base: main
Conversation
| "targets_to_exclude": targets_to_exclude, | ||
| "data_max": data_max, | ||
| "max_depth_of_reduction": max_depth_of_reduction, | ||
| "intermediate_node_outputs": intermediate_node_outputs, |
intermediate_node_outputs is used as calibration data for the autocast ruleset. Since the compilation settings are printed to the terminal, consider whether it would be better to save this data to a file and pass in the filename instead.
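A rough sketch of that suggestion (the file name and data below are illustrative, not from the PR): dump the calibration tensors once with torch.save, then pass only the path so the printed compilation settings stay readable.

import torch

# Illustrative calibration data; in the PR this would be the captured intermediate node outputs.
intermediate_node_outputs = {"node_0": torch.randn(4, 8)}
torch.save(intermediate_node_outputs, "autocast_calibration.pt")  # hypothetical file name

# Later, during compilation, the settings would carry only the path and reload the dict:
reference_data = torch.load("autocast_calibration.pt")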
    enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
    low_precision_type: Optional[
        Union[torch.dtype, dtype]
    ] = _defaults.LOW_PRECISION_TYPE,
    nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE,
    targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE,
    data_max: float = _defaults.DATA_MAX,
    max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION,
Before merging, these args should also be added to the other compile functions in this file.
]:
    # GEMM: A (M, K) @ B (K, N) = C (M, N)
    self.reduction_depth = input_0_dims[-1]
# TODO: Add more reduction ops here
Should any more reduction targets be added?
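For reference, a hedged sketch of how reduction depth might be computed for a couple more candidate targets (the helper function and the convolution convention below are assumptions, not code from the PR):

import torch

def reduction_depth(target, input_0_dims, weight_dims=None):
    """Length of the contracted dimension, i.e. how many values are accumulated per output element."""
    if target in (torch.ops.aten.mm.default, torch.ops.aten.bmm.default):
        # A (..., M, K) @ B (..., K, N) = C (..., M, N): K accumulations per output.
        return input_0_dims[-1]
    if target == torch.ops.aten.convolution.default and weight_dims is not None:
        # Weight (C_out, C_in, *kernel): C_in * prod(kernel) accumulations per output.
        depth = weight_dims[1]
        for k in weight_dims[2:]:
            depth *= k
        return depth
    return 0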
assert (
    contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
This precision check was removed because, after autocasting, if the first layer runs in fp16 but the original input is fp32, input_dtypes becomes fp16 while contiguous_inputs remains fp32.
The check was removed from the other runtimes as well.
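A minimal standalone sketch of the mismatch scenario described above (not the runtime code): after autocast the engine expects fp16 while the caller still passes fp32, so a hard assert would fail; casting is one possible alternative.

import torch

engine_input_dtype = torch.float16      # dtype the autocast-compiled engine expects
user_input = torch.randn(4, 8)          # caller still provides fp32
if user_input.dtype != engine_input_dtype:
    user_input = user_input.to(engine_input_dtype)  # cast instead of asserting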
# nodes = list(gm.graph.nodes)
# # insert enter autocast node in the beginning of the graph
# with gm.graph.inserting_before(nodes[0]):
#     enter_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._enter_autocast, args=("cuda", torch.float16, True, True))
#     enter_autocast_node.meta.update(getattr(nodes[0], "meta", {}))

# # insert exit autocast node before the return node, assuming the return node is the last node
# with gm.graph.inserting_before(nodes[-1]):
#     exit_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._exit_autocast, args=(enter_autocast_node,))
#     exit_autocast_node.meta.update(getattr(nodes[-1], "meta", {}))

# gm = clean_up_graph_after_modifications(gm)
# gm, new_signature = replace_autocast_with_hop_pass(gm, None)
# logger.debug("Graph after replace_autocast_with_hop_pass:\n%s", gm.graph)
If PyTorch autocast is used to wrap the whole model, PyTorch controls the precision of each node according to its documentation, and I did not find a way to customize that behavior based on our ruleset.
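A minimal illustration of that behavior (assumes a CUDA device is available): under a whole-model torch.autocast region, PyTorch picks the precision of each op from its own cast lists, so a custom ruleset cannot override the per-op choice.

import torch

lin = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = lin(x)        # matmul-like ops autocast to fp16
    z = torch.log(y)  # log is on autocast's fp32 list, so it runs in fp32
print(y.dtype, z.dtype)  # torch.float16 torch.float32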
Can you also update the documentation at https://github.com/pytorch/TensorRT/blob/main/docsrc/user_guide/mixed_precision.rst?
auto expected_type =
    util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
TORCHTRT_CHECK(
    inputs[i].dtype() == expected_type,
Is this not necessary now?
# model = MyModule().cuda().eval()
# inputs = (torch.randn((8, 8), device="cuda"),
#           torch.randn((8, 8), device="cuda"),
#           torch.randn((8, 8), device="cuda"),
#           torch.randn((8, 8), device="cuda"),)

# model = AutocastExample().cuda().eval()
# inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
#           torch.randn((1,), dtype=torch.float16, device="cuda"),)
remove this
class MyModule(torch.nn.Module):
    def forward(self, a_float32, b_float32, c_float32, d_float32):
        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)
        return g_float16


class AutocastExample(nn.Module):
    def __init__(self):
        super(AutocastExample, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
        )
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 8 * 8, 10)

    def forward(self, x, y):
        out = self.pool1(self.relu1(self.conv1(x)))  # fp16
        x = self.pool2(self.relu2(self.conv2(out)))  # fp16
        x = self.flatten(x)
        with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
            x = self.fc1(x)  # fp32
            with torch.autocast(x.device.type, enabled=False):
                x = torch.sub(x.half(), y)  # fp16
                out2 = torch.add(x, x)  # fp16
            with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
                out2 = torch.log(out2)  # fp32
        return x, out, out2
It seems these modules are not being used. Consider removing them.
| """Check if a node should be skipped based on the rule. | ||
| Args: | ||
| node: The ONNX node to check. |
ONNX node -> torch.fx.Node
| """ | ||
|
|
||
| @abc.abstractmethod | ||
| def _check_inner(self, node): |
Consider modifying the name?
class DepthOfReductionRule(NodeRuleBase):
    """Rule for keeping nodes with high depth of reduction in high precision."""
Could you add more detailed information about this class, with an example, in the docstring?
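Something like the following might work as a starting point (a sketch only; the base class is omitted here and the example threshold values are made up):

class DepthOfReductionRule:
    """Keep nodes whose reduction depth exceeds ``max_depth_of_reduction`` in high precision.

    The reduction depth of an op is the length of its contracted (accumulated)
    dimension. For a GEMM ``A (M, K) @ B (K, N) = C (M, N)`` the reduction depth
    is ``K``: each output element accumulates ``K`` products, so a large ``K``
    amplifies low-precision rounding error.

    Example:
        With ``max_depth_of_reduction=1024``, a matmul with ``K=4096`` stays in
        fp32, while a matmul with ``K=256`` may still be cast to fp16.
    """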
# nodes = list(gm.graph.nodes)
# # insert enter autocast node in the beginning of the graph
# with gm.graph.inserting_before(nodes[0]):
#     enter_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._enter_autocast, args=("cuda", torch.float16, True, True))
#     enter_autocast_node.meta.update(getattr(nodes[0], "meta", {}))

# # insert exit autocast node before the return node, assuming the return node is the last node
# with gm.graph.inserting_before(nodes[-1]):
#     exit_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._exit_autocast, args=(enter_autocast_node,))
#     exit_autocast_node.meta.update(getattr(nodes[-1], "meta", {}))

# gm = clean_up_graph_after_modifications(gm)
# gm, new_signature = replace_autocast_with_hop_pass(gm, None)
# logger.debug("Graph after replace_autocast_with_hop_pass:\n%s", gm.graph)
Consider removing the commented-out code.
nodes_to_exclude = settings.nodes_to_exclude
targets_to_exclude = settings.targets_to_exclude
data_max = settings.data_max
max_depth_of_reduction = settings.max_depth_of_reduction
reference_data: dict[str, torch.Tensor] = settings.intermediate_node_outputs
What do you think of tying these settings together, e.g. settings.autocast_settings.nodes_to_exclude?
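One possible shape for that grouping (a sketch; the class name, fields, and defaults are illustrative, not the PR's API):

from dataclasses import dataclass, field
from typing import Any, Collection, Optional

import torch

@dataclass
class AutocastSettings:
    enable_autocast: bool = False
    low_precision_type: Optional[torch.dtype] = None
    nodes_to_exclude: Collection[str] = field(default_factory=set)
    targets_to_exclude: Collection[Any] = field(default_factory=set)
    data_max: float = float("inf")
    max_depth_of_reduction: Optional[int] = None
    intermediate_node_outputs: dict = field(default_factory=dict)

# Accesses would then read e.g. settings.autocast_settings.nodes_to_exclude.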
):
    continue

def _cast_all_tensor_args_to_dtype(arg: Any, dtype: torch.dtype) -> Any:
Move this function outside the loop?
class DumpInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
    """Dump intermediate outputs of each node"""

    def run_node(self, n: torch.fx.Node) -> Any:
        if (
            n.op == "call_function"
            and n.target != torch.ops.higher_order.wrap_with_autocast
        ):
            out = super().run_node(n)
            if not isinstance(out, torch.Tensor):
                raise ValueError(
                    f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
                )
            intermediate_node_outputs[n.name] = out
            return out
        return super().run_node(n)


def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
    """Materialize an Input object to a tensor"""
    if isinstance(x, Input):
        return x.torch_tensor
    return x


with torch.no_grad():
    mat_args = tuple(_materialize(a) for a in arg_inputs)
    mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()}
    DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs)
This could be a generally useful utility. Consider moving it to utils.py.
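A sketch of how it might look as a reusable helper in utils.py (the function name and signature are hypothetical, not part of the PR):

from typing import Any, Dict

import torch

def capture_intermediate_outputs(
    gm: torch.fx.GraphModule, *args: Any, **kwargs: Any
) -> Dict[str, torch.Tensor]:
    """Run ``gm`` once and record the tensor output of every call_function node."""
    outputs: Dict[str, torch.Tensor] = {}

    class _DumpInterpreter(torch.fx.Interpreter):
        def run_node(self, n: torch.fx.Node) -> Any:
            out = super().run_node(n)
            if n.op == "call_function" and isinstance(out, torch.Tensor):
                outputs[n.name] = out
            return out

    with torch.no_grad():
        _DumpInterpreter(gm).run(*args, **kwargs)
    return outputs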
Description
Weak typing behavior in TensorRT is deprecated. However, it is a good way to maximize performance. Therefore, we want to create a similar PyTorch-native system to use with Torch-TensorRT that recovers some of this behavior.
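A hedged usage sketch of the new flags, based on the compile signature shown in the diff above (exact behavior and defaults may change before merge; assumes a CUDA device):

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda().eval()
inputs = (torch.randn(4, 8, device="cuda"),)

trt_model = torch_tensorrt.compile(
    model,
    inputs=inputs,
    enable_autocast=True,                 # turn on the rule-based autocast pass
    low_precision_type=torch.float16,     # target low precision
    nodes_to_exclude=set(),               # node names to keep in high precision
    max_depth_of_reduction=None,          # optionally keep deep reductions in high precision
)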
Fixes #3869
Type of change
Checklist: