Commit 11e7d82

jstac and claude authored

Add lax.fori_loop example and improve jax_intro clarity (#530)

- Add lax.fori_loop example to sequential operations section, alongside existing lax.scan version
- Update summaries and recommendations to reference both approaches
- Improve jax_intro: reorganize intro text, fix block_until_ready style, add Size Experiment subheading, expand immutability explanation, add cross-reference labels
- Remove unused import

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 95378b8 commit 11e7d82

File tree

2 files changed

+102
-52
lines changed


lectures/jax_intro.md

Lines changed: 38 additions & 28 deletions
@@ -46,20 +46,21 @@ import numpy as np
 import quantecon as qe
 ```
 
-Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface.
 
 
 ## JAX as a NumPy Replacement
 
-One of the attractive features of JAX is that, whenever possible, its array
-processing operations conform to the NumPy API.
-
-This means that, in many cases, we can use JAX as a drop-in NumPy replacement.
-
 Let's look at the similarities and differences between JAX and NumPy.
 
 ### Similarities
 
+Above we import `jax.numpy as jnp`, which provides a NumPy-like interface to
+array operations.
+
+One of the attractive features of JAX is that, whenever possible, this interface
+conforms to the NumPy API.
+
+As a result, we can often use JAX as a drop-in NumPy replacement.
 
 Here are some standard array operations using `jnp`:
 

@@ -79,7 +80,7 @@ print(jnp.sum(a))
 print(jnp.dot(a, a))
 ```
 
-However, the array object `a` is not a NumPy array:
+It should be remembered, however, that the array object `a` is not a NumPy array:
 
 ```{code-cell} ipython3
 a
@@ -104,11 +105,13 @@ Let's now look at some differences between JAX and NumPy array operations.
 (jax_speed)=
 #### Speed!
 
-Let's say we want to evaluate the cosine function at many points.
+One major difference is that JAX is faster --- and sometimes much faster.
+
+To illustrate, suppose that we want to evaluate the cosine function at many points.
 
 ```{code-cell}
 n = 50_000_000
-x = np.linspace(0, 10, n)
+x = np.linspace(0, 10, n)  # NumPy array
 ```
 
 ##### With NumPy
@@ -150,28 +153,24 @@ with qe.Timer():
     # First run
     y = jnp.cos(x)
     # Hold the interpreter until the array operation finishes
-    jax.block_until_ready(y);
+    y.block_until_ready()
 ```
 
 ```{note}
-Here, in order to measure actual speed, we use the `block_until_ready` method
-to hold the interpreter until the results of the computation are returned.
-
-This is necessary because JAX uses asynchronous dispatch, which
+Above, the `block_until_ready` method
+holds the interpreter until the results of the computation are returned.
+This is necessary for timing execution because JAX uses asynchronous dispatch, which
 allows the Python interpreter to run ahead of numerical computations.
-
-For non-timed code, you can drop the line containing `block_until_ready`.
 ```
 
-And let's time it again.
-
+Now let's time it again.
 
 ```{code-cell}
 with qe.Timer():
     # Second run
     y = jnp.cos(x)
     # Hold interpreter
-    jax.block_until_ready(y);
+    y.block_until_ready()
 ```
 
 On a GPU, this code runs much faster than its NumPy equivalent.
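The effect of asynchronous dispatch can be seen in a small standalone sketch using the standard `time` module (a hypothetical harness, not the lecture's `qe.Timer` code; exact numbers are machine-dependent):

```python
import time
import jax.numpy as jnp

x = jnp.linspace(0, 10, 1_000_000)
jnp.cos(x).block_until_ready()  # warm up: compile before timing

# Without blocking, JAX returns control to Python as soon as the
# work is dispatched, so this measures dispatch overhead only.
t0 = time.perf_counter()
y = jnp.cos(x)
dispatch_time = time.perf_counter() - t0

# Blocking until the result is ready measures the full computation.
t0 = time.perf_counter()
y = jnp.cos(x)
y.block_until_ready()
full_time = time.perf_counter() - t0
```

On an accelerator the gap between the two measurements is typically large, since the dispatch-only timing returns before any numerical work completes.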
@@ -190,7 +189,11 @@ being used (as well as the data type).
 The size matters for generating optimized code because efficient parallelization
 requires matching the size of the task to the available hardware.
 
-We can verify the claim that JAX specializes on array size by changing the input size and watching the runtimes.
+
+#### Size Experiment
+
+We can verify the claim that JAX specializes on array size by changing the input
+size and watching the runtimes.
 
 ```{code-cell}
 x = jnp.linspace(0, 10, n + 1)
@@ -201,7 +204,7 @@ with qe.Timer():
     # First run
     y = jnp.cos(x)
     # Hold interpreter
-    jax.block_until_ready(y);
+    y.block_until_ready()
 ```
 
 
@@ -210,7 +213,7 @@ with qe.Timer():
     # Second run
     y = jnp.cos(x)
     # Hold interpreter
-    jax.block_until_ready(y);
+    y.block_until_ready()
 ```
 
 The run time increases and then falls again (this will be more obvious on the GPU).
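Shape specialization can also be observed directly by counting traces (a hypothetical `trace_count` helper, not from the lecture): the Python body of a jitted function runs only while JAX traces it, which happens once per new input shape.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1          # executes only during tracing
    return jnp.cos(x)

f(jnp.zeros(10)).block_until_ready()   # new shape: traced and compiled
f(jnp.zeros(10)).block_until_ready()   # same shape: cached, no retrace
f(jnp.zeros(20)).block_until_ready()   # new shape: traced again
print(trace_count)  # 2
```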
@@ -263,7 +266,8 @@ a[0] = 1
 a
 ```
 
-In JAX this fails!
+In JAX this fails 😱.
+
 
 ```{code-cell} ipython3
 a = jnp.linspace(0, 1, 3)
@@ -278,14 +282,19 @@ except Exception as e:
 
 ```
 
-The designers of JAX chose to make arrays immutable because JAX uses a
-functional programming style, which we discuss below.
+The designers of JAX chose to make arrays immutable because
+
+1. JAX uses a *functional programming style* and
+2. functional programming typically avoids mutable data.
+
+We discuss these ideas {ref}`below <jax_func>`.
 
 
-#### A workaround
+(jax_at_workaround)=
+#### A Workaround
 
-We note that JAX does provide an alternative to in-place array modification
-using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
+JAX does provide a direct alternative to in-place array modification
+via the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
 
 ```{code-cell} ipython3
 a = jnp.linspace(0, 1, 3)
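The `at` method's behavior can be sketched as follows: `at[...].set` returns a modified copy and leaves the original array untouched.

```python
import jax.numpy as jnp

a = jnp.linspace(0, 1, 3)   # [0.0, 0.5, 1.0]
b = a.at[0].set(1.0)        # functional update: returns a new array
print(a[0])  # 0.0 -- original unchanged
print(b[0])  # 1.0
```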
@@ -308,6 +317,7 @@ Hence, for the most part, we try to avoid this syntax.
 (Although it can in fact be efficient inside JIT-compiled functions -- but let's put this aside for now.)
 
 
+(jax_func)=
 ## Functional Programming
 
 From JAX's documentation:

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 64 additions & 24 deletions
@@ -54,7 +54,6 @@ tags: [hide-output]
 We will use the following imports.
 
 ```{code-cell} ipython3
-import random
 from functools import partial
 
 import numpy as np
@@ -483,18 +482,67 @@ Notice that the second run is significantly faster after JIT compilation complet
 
 Numba's compilation is typically quite fast, and the resulting code performance is excellent for sequential operations like this one.
 
+
 ### JAX Version
 
-Now let's create a JAX version using `lax.scan`:
+Now let's create a JAX version using `at[t].set` style syntax, which, as
+{ref}`discussed in the JAX lecture <jax_at_workaround>`, provides a workaround for immutable arrays.
 
-(We'll hold `n` static because it affects array size and hence JAX wants to
-specialize on its value in the compiled code.)
+We'll apply a `lax.fori_loop`, which is a version of a for loop that can be compiled by XLA.
 
 ```{code-cell} ipython3
 cpu = jax.devices("cpu")[0]
 
-@partial(jax.jit, static_argnames=('n',), device=cpu)
-def qm_jax(x0, n, α=4.0):
+@partial(jax.jit, static_argnames=("n",), device=cpu)
+def qm_jax_fori(x0, n, α=4.0):
+
+    x = jnp.empty(n + 1).at[0].set(x0)
+
+    def update(t, x):
+        return x.at[t + 1].set(α * x[t] * (1 - x[t]))
+
+    x = lax.fori_loop(0, n, update, x)
+    return x
+
+```
+
+* We hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code.
+* We pin to the CPU via `device=cpu` because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism.
+
+Although `at[t].set` appears to create a new array at each step, inside a JIT-compiled function the compiler detects that the old array is no longer needed and performs the update in place.
+
+Let's time it with the same parameters:
+
+```{code-cell} ipython3
+with qe.Timer():
+    # First run
+    x_jax = qm_jax_fori(0.1, n)
+    # Hold interpreter
+    x_jax.block_until_ready()
+```
+
+Let's run it again to eliminate compilation overhead:
+
+```{code-cell} ipython3
+with qe.Timer():
+    # Second run
+    x_jax = qm_jax_fori(0.1, n)
+    # Hold interpreter
+    x_jax.block_until_ready()
+```
+
+JAX is also quite efficient for this sequential operation.
+
+
+There's another way we can implement the loop that uses `lax.scan`.
+
+This alternative is arguably more in line with JAX's functional approach ---
+although the syntax is difficult to remember.
+
+
+```{code-cell} ipython3
+@partial(jax.jit, static_argnames=("n",), device=cpu)
+def qm_jax_scan(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
         return x_new, x_new
@@ -505,20 +553,12 @@ def qm_jax(x0, n, α=4.0):
 
 This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.
 
-```{note}
-We specify `device=cpu` in the `jax.jit` decorator because this computation
-consists of many small sequential operations, leaving little opportunity for the
-GPU to exploit parallelism. As a result, kernel-launch overhead tends to
-dominate on the GPU, making the CPU a better
-fit.
-```
-
 Let's time it with the same parameters:
 
 ```{code-cell} ipython3
 with qe.Timer():
     # First run
-    x_jax = qm_jax(0.1, n)
+    x_jax = qm_jax_scan(0.1, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
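For readers new to `lax.scan`, its carry/accumulate semantics can be sketched with a toy counter (not the lecture's model): `update` maps a carry to a new carry plus a value that gets stacked into the output array.

```python
import jax.numpy as jnp
from jax import lax

def update(carry, _):
    new = carry + 1.0
    return new, new          # (next carry, value appended to output)

# xs=None with length=5 runs the step five times.
final, ys = lax.scan(update, 0.0, None, length=5)
print(final)  # 5.0
print(ys)     # [1. 2. 3. 4. 5.]
```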
@@ -528,13 +568,11 @@ Let's run it again to eliminate compilation overhead:
 ```{code-cell} ipython3
 with qe.Timer():
     # Second run
-    x_jax = qm_jax(0.1, n)
+    x_jax = qm_jax_scan(0.1, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
 
-JAX is also quite efficient for this sequential operation.
-
 Both JAX and Numba deliver strong performance after compilation.
 
 
@@ -547,9 +585,11 @@ array and fill it element by element using a standard Python loop.
 
 This is exactly how most programmers think about the algorithm.
 
-The JAX version, on the other hand, requires using `lax.scan`, which is significantly less intuitive.
+The JAX versions, on the other hand, require either `lax.fori_loop` or
+`lax.scan`, both of which are less intuitive than a standard Python loop.
 
-Additionally, JAX's immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba.
+While JAX's `at[t].set` syntax does allow element-wise updates, the overall code
+remains harder to read than the Numba equivalent.
 
 For this type of sequential operation, Numba is the clear winner in terms of
 code clarity and ease of implementation.
@@ -575,12 +615,12 @@ For **sequential operations**, Numba has clear advantages.
 The code is natural and readable --- just a Python loop with a decorator ---
 and performance is excellent.
 
-JAX can handle sequential problems via `lax.scan`, but the syntax is less
-intuitive.
+JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but
+the syntax is less intuitive.
 
 ```{note}
-One important advantage of `lax.scan` is that it supports automatic
-differentiation through the loop, which Numba cannot do.
+One important advantage of `lax.fori_loop` and `lax.scan` is that they
+support automatic differentiation through the loop, which Numba cannot do.
 If you need to differentiate through a sequential computation (e.g., computing
 sensitivities of a trajectory to model parameters), JAX is the better choice
 despite the less natural syntax.
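As a sketch of this advantage (hypothetical code, not part of the diff), here is a gradient through a `lax.fori_loop` with a static trip count: the sensitivity of the final logistic-map state to the parameter α.

```python
import jax
from jax import lax

def final_state(α, x0=0.1, n=10):
    # Iterate the logistic map n times; n is a static Python int,
    # so reverse-mode differentiation through the loop is supported.
    def update(t, x):
        return α * x * (1 - x)
    return lax.fori_loop(0, n, update, x0)

dfinal_dα = jax.grad(final_state)   # derivative with respect to α
print(dfinal_dα(3.0))
```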
