
Commit 8d73de3

jstac and claude authored
Improve jax_intro lecture: consolidate imports, clarify explanations (#527)
Move imports to top, restructure functional programming discussion, add inline comments to timing blocks, and streamline examples. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent be6eeae commit 8d73de3

File tree

1 file changed

+68
-70
lines changed


lectures/jax_intro.md

Lines changed: 68 additions & 70 deletions
@@ -15,6 +15,9 @@ kernelspec:
 
 This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
 
+```{include} _admonition/gpu.md
+```
+
 JAX is a high-performance scientific computing library that provides
 
 * a [NumPy](https://en.wikipedia.org/wiki/NumPy)-like interface that can automatically parallelize across CPUs and GPUs,
@@ -33,9 +36,19 @@ In addition to what's in Anaconda, this lecture will need the following libraries
 !pip install jax quantecon
 ```
 
-```{include} _admonition/gpu.md
+We'll use the following imports
+
+```{code-cell} ipython3
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+import quantecon as qe
 ```
 
+Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface.
+
+
 ## JAX as a NumPy Replacement
 
 One of the attractive features of JAX is that, whenever possible, its array
@@ -47,17 +60,6 @@ Let's look at the similarities and differences between JAX and NumPy.
 
 ### Similarities
 
-We'll use the following imports
-
-```{code-cell} ipython3
-import jax
-import jax.numpy as jnp
-import matplotlib.pyplot as plt
-import numpy as np
-import quantecon as qe
-```
-
-Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface.
 
 Here are some standard array operations using `jnp`:

@@ -73,10 +75,6 @@ print(a)
 print(jnp.sum(a))
 ```
 
-```{code-cell} ipython3
-print(jnp.mean(a))
-```
-
 ```{code-cell} ipython3
 print(jnp.dot(a, a))
 ```
@@ -91,30 +89,12 @@ a
 type(a)
 ```
 
-Even scalar-valued maps on arrays return JAX arrays.
+Even scalar-valued maps on arrays return JAX arrays rather than scalars!
 
 ```{code-cell} ipython3
 jnp.sum(a)
 ```
 
-Operations on higher dimensional arrays are also similar to NumPy:
-
-```{code-cell} ipython3
-A = jnp.ones((2, 2))
-B = jnp.identity(2)
-A @ B
-```
-
-JAX's array interface also provides the `linalg` subpackage:
-
-```{code-cell} ipython3
-jnp.linalg.inv(B)  # Inverse of identity is identity
-```
-
-```{code-cell} ipython3
-eigvals, eigvecs = jnp.linalg.eigh(B)  # Computes eigenvalues and eigenvectors
-eigvals
-```
 
 
 ### Differences
@@ -137,13 +117,15 @@ Let's try with NumPy
 
 ```{code-cell}
 with qe.Timer():
+    # First NumPy timing
     y = np.cos(x)
 ```
 
 And one more time.
 
 ```{code-cell}
 with qe.Timer():
+    # Second NumPy timing
     y = np.cos(x)
 ```

@@ -165,7 +147,9 @@ Let's time the same procedure.
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = jnp.cos(x)
+    # Hold the interpreter until the array operation finishes
     jax.block_until_ready(y);
 ```

@@ -184,7 +168,9 @@ And let's time it again.
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
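Since the timing cells in these hunks all pair `qe.Timer` with `jax.block_until_ready`, here is a minimal standalone sketch of the idea (hypothetical array size; JAX dispatches work asynchronously, so the timer must wait for the result before stopping):

```python
import time

import jax
import jax.numpy as jnp

x = jnp.linspace(0, 10, 1_000_000)  # hypothetical problem size

start = time.perf_counter()
y = jnp.cos(x)
# Without this, the timer can stop before the computation has finished:
jax.block_until_ready(y)
elapsed = time.perf_counter() - start

print(f"elapsed: {elapsed:.6f} seconds")
```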

@@ -212,14 +198,18 @@ x = jnp.linspace(0, 10, n + 1)
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```

@@ -294,7 +284,7 @@ functional programming style, which we discuss below.
 
 #### A workaround
 
-We note that JAX does provide a version of in-place array modification
+We note that JAX does provide an alternative to in-place array modification
 using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
 
 ```{code-cell} ipython3
@@ -387,11 +377,26 @@ This pure version makes all dependencies explicit through function arguments, an
 
 ### Why Functional Programming?
 
-JAX represents functions as computational graphs, which are then compiled or transformed (e.g., differentiated)
+At QuantEcon we love pure functions because they
+
+* Help testing: each function can operate in isolation
+* Promote deterministic behavior and hence reproducibility
+* Prevent bugs that arise from mutating shared state
+
+The JAX compiler loves pure functions and functional programming because
+
+* Data dependencies are explicit, which helps with optimizing complex computations
+* Pure functions are easier to differentiate (autodiff)
+* Pure functions are easier to parallelize and optimize (they don't depend on shared mutable state)
+
+Another way to think of this is as follows:
+
+JAX represents functions as computational graphs, which are then compiled or
+transformed (e.g., differentiated).
 
 These computational graphs describe how a given set of inputs is transformed into an output.
 
-They are pure by construction.
+JAX's computational graphs are pure by construction.
 
 JAX uses a functional programming style so that user-built functions map
 directly into the graph-theoretic representations supported by JAX.
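To illustrate why the compiler cares about purity, a small sketch (hypothetical names `impure`/`pure`, not from the lecture) of a function that reads mutable global state under `jax.jit` — the global is baked in at trace time, so later mutations are silently ignored:

```python
import jax
import jax.numpy as jnp

scale = 2.0  # global, mutable state

@jax.jit
def impure(x):
    return scale * x  # reads the global: its value is captured at trace time

x = jnp.ones(3)
first = impure(x)   # traced and compiled with scale == 2.0
scale = 10.0
second = impure(x)  # cached compilation still uses the old value

# A pure version passes all dependencies in explicitly:
def pure(x, scale):
    return scale * x
```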
@@ -520,8 +525,8 @@ plt.tight_layout()
 plt.show()
 ```
 
-This syntax will seem unusual for a NumPy or Matlab user --- but will make a lot
-of sense when we progress to parallel programming.
+This syntax will seem unusual for a NumPy or Matlab user --- but will make more
+sense when we get to parallel programming.
 
 The function below produces `k` (quasi-) independent random `n x n` matrices using `split`.
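A hedged sketch of how such a function might use `split` (hypothetical helper `random_matrices` and sizes, not the lecture's exact code):

```python
import jax
import jax.numpy as jnp

def random_matrices(key, k, n):
    # Split the key into k independent subkeys, one per matrix
    keys = jax.random.split(key, k)
    return jnp.stack([jax.random.normal(kk, (n, n)) for kk in keys])

key = jax.random.PRNGKey(0)
mats = random_matrices(key, k=3, n=2)
print(mats.shape)
```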

@@ -664,6 +669,7 @@ x = np.linspace(0, 10, n)
 
 ```{code-cell}
 with qe.Timer():
+    # Time NumPy code
     y = f(x)
 ```

@@ -679,27 +685,31 @@ As a first pass, we replace `np` with `jnp` throughout:
 def f(x):
     y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2
     return y
-```
 
-Now let's time it.
 
-```{code-cell}
 x = jnp.linspace(0, 10, n)
 ```
 
+Now let's time it.
+
 ```{code-cell}
 with qe.Timer():
+    # First call
     y = f(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 ```{code-cell}
 with qe.Timer():
+    # Second call
     y = f(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
-The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the
+second run after JIT compilation.
 
 However, with JAX, we have another trick up our sleeve --- we can JIT-compile
 the *entire* function, not just individual operations.
@@ -718,13 +728,17 @@ f_jax = jax.jit(f)
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = f_jax(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = f_jax(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```

@@ -734,7 +748,6 @@ allowing the compiler to optimize more aggressively.
 For example, the compiler can eliminate multiple calls to the hardware
 accelerator and the creation of a number of intermediate arrays.
 
-
 Incidentally, a more common syntax when targeting a function for the JIT
 compiler is
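The "more common syntax" the context lines refer to is presumably the decorator form; a minimal sketch (function body copied from the earlier hunk, placement illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit  # decorator form: equivalent to f = jax.jit(f)
def f(x):
    return jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2

y = f(jnp.linspace(0, 10, 50))
print(y.shape)
```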

@@ -811,21 +824,6 @@ f(x)
 Moral of the story: write pure functions when using JAX!
 
 
-### Summary
-
-Now we can see why both developers and compilers benefit from pure functions.
-
-We love pure functions because they
-
-* Help testing: each function can operate in isolation
-* Promote deterministic behavior and hence reproducibility
-* Prevent bugs that arise from mutating shared state
-
-The compiler loves pure functions and functional programming because
-
-* Data dependencies are explicit, which helps with optimizing complex computations
-* Pure functions are easier to differentiate (autodiff)
-* Pure functions are easier to parallelize and optimize (don't depend on shared mutable state)
 
 
 ## Vectorization with `vmap`
@@ -838,18 +836,18 @@ This avoids the need to manually write vectorized code or use explicit loops.
 
 ### A simple example
 
-Suppose we have a function that computes summary statistics for a single array:
+Suppose we have a function that computes the difference between mean and median for an array of numbers.
 
 ```{code-cell} ipython3
-def summary(x):
-    return jnp.mean(x), jnp.median(x)
+def mm_diff(x):
+    return jnp.mean(x) - jnp.median(x)
 ```
 
 We can apply it to a single vector:
 
 ```{code-cell} ipython3
 x = jnp.array([1.0, 2.0, 5.0])
-summary(x)
+mm_diff(x)
 ```
 
 Now suppose we have a matrix and want to compute these statistics for each row.
@@ -862,7 +860,7 @@ X = jnp.array([[1.0, 2.0, 5.0],
                [1.0, 8.0, 9.0]])
 
 for row in X:
-    print(summary(row))
+    print(mm_diff(row))
 ```
 
 However, Python loops are slow and cannot be efficiently compiled or
@@ -872,11 +870,11 @@ Using `vmap` keeps the computation on the accelerator and composes with other
 JAX transformations like `jit` and `grad`:
 
 ```{code-cell} ipython3
-batch_summary = jax.vmap(summary)
-batch_summary(X)
+batch_mm_diff = jax.vmap(mm_diff)
+batch_mm_diff(X)
 ```
 
-The function `summary` was written for a single array, and `vmap` automatically
+The function `mm_diff` was written for a single array, and `vmap` automatically
 lifted it to operate row-wise over a matrix --- no loops, no reshaping.
 
 ### Combining transformations
@@ -886,8 +884,8 @@ One of JAX's strengths is that transformations compose naturally.
 For example, we can JIT-compile a vectorized function:
 
 ```{code-cell} ipython3
-fast_batch_summary = jax.jit(jax.vmap(summary))
-fast_batch_summary(X)
+fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
+fast_batch_mm_diff(X)
 ```
 
 This composition of `jit`, `vmap`, and (as we'll see next) `grad` is central to
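A small sketch of that three-way composition (hypothetical `loss` function, not from the lecture):

```python
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x**2)  # scalar-valued, so grad applies

# gradient of a scalar function, vectorized over rows, then compiled
grad_rows = jax.jit(jax.vmap(jax.grad(loss)))

X = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
print(grad_rows(X))  # each row's gradient is 2 * row
```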

0 commit comments
