@@ -73,17 +73,21 @@ def foo(x, y):
7373
7474######################################################################
7575# Alternatively, we can decorate the function.
76+ t1 = torch .randn (10 , 10 )
77+ t2 = torch .randn (10 , 10 )
7678
7779@torch .compile
7880def opt_foo2 (x , y ):
7981 a = torch .sin (x )
8082 b = torch .cos (y )
8183 return a + b
82- print (opt_foo2 (torch . randn ( 10 , 10 ), torch . randn ( 10 , 10 ) ))
84+ print (opt_foo2 (t1 , t2 ))
8385
8486######################################################################
8587# We can also optimize ``torch.nn.Module`` instances.
8688
89+ t = torch .randn (10 , 100 )
90+
8791class MyModule (torch .nn .Module ):
8892 def __init__ (self ):
8993 super ().__init__ ()
@@ -94,7 +98,101 @@ def forward(self, x):
9498
9599mod = MyModule ()
96100opt_mod = torch .compile (mod )
97- print (opt_mod (torch .randn (10 , 100 )))
101+ print (opt_mod (t ))
102+
103+ ######################################################################
104+ # torch.compile and Nested Calls
105+ # ------------------------------
106+ # Nested function calls within the decorated function will also be compiled.
107+
108+ def nested_function (x ):
109+ return torch .sin (x )
110+
111+ @torch .compile
112+ def outer_function (x , y ):
113+ a = nested_function (x )
114+ b = torch .cos (y )
115+ return a + b
116+
117+ print (outer_function (t1 , t2 ))
118+
119+ ######################################################################
120+ # In the same fashion, when compiling a module all sub-modules and methods
121+ # within it, that are not in a skip list, are also compiled.
122+
123+ class OuterModule (torch .nn .Module ):
124+ def __init__ (self ):
125+ super ().__init__ ()
126+ self .inner_module = MyModule ()
127+ self .outer_lin = torch .nn .Linear (10 , 2 )
128+
129+ def forward (self , x ):
130+ x = self .inner_module (x )
131+ return torch .nn .functional .relu (self .outer_lin (x ))
132+
133+ outer_mod = OuterModule ()
134+ opt_outer_mod = torch .compile (outer_mod )
135+ print (opt_outer_mod (t ))
136+
137+ ######################################################################
138+ # We can also disable some functions from being compiled by using
139+ # ``torch.compiler.disable``. Suppose you want to disable the tracing on just
140+ # the ``complex_function`` function, but want to continue the tracing back in
141+ # ``complex_conjugate``. In this case, you can use
142+ # ``torch.compiler.disable(recursive=False)`` option. Otherwise, the default is
143+ # ``recursive=True``.
144+
145+ def complex_conjugate (z ):
146+ return torch .conj (z )
147+
148+ @torch .compiler .disable (recursive = False )
149+ def complex_function (real , imag ):
150+ # Assuming this function cause problems in the compilation
151+ z = torch .complex (real , imag )
152+ return complex_conjugate (z )
153+
154+ def outer_function ():
155+ real = torch .tensor ([2 , 3 ], dtype = torch .float32 )
156+ imag = torch .tensor ([4 , 5 ], dtype = torch .float32 )
157+ z = complex_function (real , imag )
158+ return torch .abs (z )
159+
160+ # Try to compile the outer_function
161+ try :
162+ opt_outer_function = torch .compile (outer_function )
163+ print (opt_outer_function ())
164+ except Exception as e :
165+ print ("Compilation of outer_function failed:" , e )
166+
167+ ######################################################################
168+ # Best Practices and Recommendations
169+ # ----------------------------------
170+ #
171+ # Behavior of ``torch.compile`` with Nested Modules and Function Calls
172+ #
173+ # When you use ``torch.compile``, the compiler will try to recursively compile
174+ # every function call inside the target function or module inside the target
175+ # function or module that is not in a skip list (such as built-ins, some functions in
176+ # the torch.* namespace).
177+ #
178+ # **Best Practices:**
179+ #
180+ # 1. **Top-Level Compilation:** One approach is to compile at the highest level
181+ # possible (i.e., when the top-level module is initialized/called) and
182+ # selectively disable compilation when encountering excessive graph breaks or
183+ # errors. If there are still many compile issues, compile individual
184+ # subcomponents instead.
185+ #
186+ # 2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
187+ # before integrating them into larger models to isolate potential issues.
188+ #
189+ # 3. **Disable Compilation Selectively:** If certain functions or sub-modules
190+ # cannot be handled by `torch.compile`, use the `torch.compiler.disable` context
191+ # managers to recursively exclude them from compilation.
192+ #
193+ # 4. **Compile Leaf Functions First:** In complex models with multiple nested
194+ # functions and modules, start by compiling the leaf functions or modules first.
195+ # For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
98196
99197######################################################################
100198# Demonstrating Speedups
0 commit comments