@@ -66,7 +66,7 @@ def display(img):
######################################################################
# ``crop`` is not handled effectively out-of-the-box by
# ``torch.compile``: ``torch.compile`` induces a
- # `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
+ # `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
# on functions it is unable to handle and graph breaks are bad for performance.
# The following code demonstrates this by raising an error
# (``torch.compile`` with ``fullgraph=True`` raises an error if a
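As a rough illustration of the behavior this hunk describes: a function ``torch.compile`` cannot trace graph-breaks under default settings and therefore errors under ``fullgraph=True``. The ``tolist``-based ``untraceable`` helper below is an invented stand-in, not code from the tutorial.

import torch

def untraceable(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for ``crop``: .tolist() leaves the Tensor world,
    # which torch.compile cannot trace through.
    values = x.tolist()
    return torch.tensor([v * 2 for v in values])

compiled = torch.compile(untraceable, fullgraph=True)
try:
    compiled(torch.arange(4, dtype=torch.float32))
except Exception as exc:
    # With fullgraph=True the graph break surfaces as an error instead of
    # silently falling back to eager mode.
    print("fullgraph=True raised:", type(exc).__name__)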
@@ -85,9 +85,9 @@ def f(img):
#
# 1. wrap the function into a PyTorch custom operator.
# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.
- #    Given the metadata (e.g. shapes)
- #    of the input Tensors, this function says how to compute the metadata
- #    of the output Tensor(s).
+ #    Given some ``FakeTensors`` inputs (dummy Tensors that don't have storage),
+ #    this function should return dummy Tensors of your choice with the correct
+ #    Tensor metadata (shape/strides/``dtype``/device).


from typing import Sequence
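The two steps listed in this hunk can be sketched end to end as follows; the operator name ``mylib::numpy_double`` and its body are invented for illustration and assume the PyTorch 2.4+ ``torch.library.custom_op`` / ``register_fake`` APIs.

import torch
from torch import Tensor

# Step 1: wrap the non-traceable function into a custom operator.
# ("mylib::numpy_double" is a hypothetical name for this sketch.)
@torch.library.custom_op("mylib::numpy_double", mutates_args=())
def numpy_double(x: Tensor) -> Tensor:
    # Stand-in for a black-box routine torch.compile cannot trace.
    return torch.from_numpy(x.detach().cpu().numpy() * 2).to(x.device)

# Step 2: add a FakeTensor (meta) kernel: given FakeTensor inputs, return
# dummy Tensors with the correct shape/strides/dtype/device.
@numpy_double.register_fake
def _(x: Tensor) -> Tensor:
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def g(x):
    return numpy_double(x) + 1

print(g(torch.randn(3)))  # compiles without a graph break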
@@ -130,6 +130,11 @@ def f(img):
# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
# has led to) silent incorrectness when composed with ``torch.compile``.
#
+ # If you don't need training support, there is no need to use
+ # ``torch.library.register_autograd``.
+ # If you end up training with a ``custom_op`` that doesn't have an autograd
+ # registration, we'll raise an error message.
+ #
# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
# derivation as an exercise to the reader). Let's first wrap ``paste`` into a
# custom operator:
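For the general shape of an autograd registration (independent of the ``crop``/``paste`` pair the tutorial builds), a minimal sketch might look like the following; ``mylib::numpy_sin`` is an invented example operator, not the tutorial's.

import numpy as np
import torch
from torch import Tensor

# "mylib::numpy_sin" is a hypothetical operator used only for this sketch.
@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: Tensor) -> Tensor:
    return torch.from_numpy(np.sin(x.detach().cpu().numpy())).to(x.device)

@numpy_sin.register_fake
def _(x: Tensor) -> Tensor:
    return torch.empty_like(x)

def setup_context(ctx, inputs, output):
    # Save whatever the backward formula needs.
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad_output):
    # d/dx sin(x) = cos(x)
    (x,) = ctx.saved_tensors
    return grad_output * x.cos()

torch.library.register_autograd(
    "mylib::numpy_sin", backward, setup_context=setup_context
)

x = torch.randn(3, requires_grad=True)
numpy_sin(x).sum().backward()
print(torch.allclose(x.grad, x.cos()))  # True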
@@ -203,7 +208,7 @@ def setup_context(ctx, inputs, output):
######################################################################
# Mutable Python Custom operators
# -------------------------------
- # You can also wrap a Python function that mutates its inputs into a custom
+ # You can also wrap a Python function that mutates its inputs into a custom
# operator.
# Functions that mutate inputs are common because that is how many low-level
# kernels are written; for example, a kernel that computes ``sin`` may take in
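A small sketch of such an out-variant kernel wrapped as a mutable custom operator (the name ``mylib::sin_out`` is invented for illustration; the mutated argument is declared via ``mutates_args``):

import numpy as np
import torch
from torch import Tensor

# "mylib::sin_out" is a hypothetical out-variant kernel; the argument it
# mutates is declared in ``mutates_args``.
@torch.library.custom_op("mylib::sin_out", mutates_args={"out"})
def sin_out(x: Tensor, out: Tensor) -> None:
    out_np = out.numpy()             # shares memory with ``out``
    np.sin(x.numpy(), out=out_np)    # write the result in place

x = torch.randn(3)
out = torch.empty(3)
sin_out(x, out)
print(torch.allclose(out, torch.sin(x)))  # True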