
Commit d08a73d

jstac and claude authored

Misc changes to jax lectures (#533)

* misc

* misc

* Fix call to add_tax_pure in jax_intro pure function example

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 450bafe commit d08a73d

File tree

2 files changed (+100, -74 lines)

lectures/jax_intro.md

Lines changed: 33 additions & 14 deletions
@@ -351,19 +351,20 @@ In particular, pure functions will always return the same result if invoked with
 
 
-### Examples
+### Examples -- Pure and Impure
 
-Here's an example of a *non-pure* function
+Here's an example of an *impure* function
 
 ```{code-cell} ipython3
 tax_rate = 0.1
-prices = [10.0, 20.0]
 
 def add_tax(prices):
     for i, price in enumerate(prices):
         prices[i] = price * (1 + tax_rate)
-    print('Post-tax prices: ', prices)
-    return prices
+
+prices = [10.0, 20.0]
+add_tax(prices)
+prices
 ```
 
 This function fails to be pure because
@@ -375,15 +376,22 @@ This function fails to be pure because
 
 Here's a *pure* version
 
 ```{code-cell} ipython3
-tax_rate = 0.1
-prices = (10.0, 20.0)
 
 def add_tax_pure(prices, tax_rate):
     new_prices = [price * (1 + tax_rate) for price in prices]
     return new_prices
+
+tax_rate = 0.1
+prices = (10.0, 20.0)
+after_tax_prices = add_tax_pure(prices, tax_rate)
+after_tax_prices
 ```
 
-This pure version makes all dependencies explicit through function arguments, and doesn't modify any external state.
+This is pure because
+
+* all dependencies are explicit through function arguments
+* it doesn't modify any external state
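To see the contrast concretely, here's a small standalone sketch (not part of the diff; the names mirror the lecture's) showing that the impure version mutates its argument while the pure one leaves it untouched:

```python
tax_rate = 0.1

def add_tax(prices):
    # Impure: reads the global tax_rate and mutates its argument in place
    for i, price in enumerate(prices):
        prices[i] = price * (1 + tax_rate)

def add_tax_pure(prices, tax_rate):
    # Pure: every input is an argument; returns a new list
    return [price * (1 + tax_rate) for price in prices]

prices = [10.0, 20.0]
after = add_tax_pure(prices, tax_rate)
print(prices)   # still [10.0, 20.0]
print(after)

add_tax(prices)
print(prices)   # now mutated
```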
 
 ### Why Functional Programming?

@@ -438,8 +446,8 @@ This function is *not pure* because:
 * It's non-deterministic: same inputs, different outputs
 * It has side effects: it modifies the global random number generator state
 
-Dangerous under parallelization --- must carefully control what happens in each
-thread!
+This is dangerous under parallelization --- we must carefully control what happens in each
+thread.
 
 
 ### JAX
@@ -560,7 +568,11 @@ sense when we get to parallel programming.
 The function below produces `k` (quasi-) independent random `n x n` matrices using `split`.
 
 ```{code-cell} ipython3
-def gen_random_matrices(key, n=2, k=3):
+def gen_random_matrices(
+        key,   # JAX key for random numbers
+        n=2,   # Matrices will be n x n
+        k=3    # Number of matrices to generate
+    ):
     matrices = []
     for _ in range(k):
         key, subkey = jax.random.split(key)
@@ -583,7 +595,7 @@ This function is *pure*
 
 ### Benefits
 
-The explicitness of JAX brings significant benefits:
+As mentioned above, this explicitness is valuable:
 
 * Reproducibility: Easy to reproduce results by reusing keys
 * Parallelization: Control what happens on separate threads
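As a minimal standalone sketch of the reproducibility point (hypothetical seed, not from the lecture): reusing a key reproduces a draw exactly, and `split` yields fresh subkeys as in `gen_random_matrices`:

```python
import jax

key = jax.random.PRNGKey(0)          # hypothetical seed for illustration
key, subkey = jax.random.split(key)  # fresh subkey, as in gen_random_matrices

# Reusing the same subkey reproduces the same draw exactly
x1 = jax.random.normal(subkey, (2, 2))
x2 = jax.random.normal(subkey, (2, 2))
```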
@@ -672,8 +684,14 @@ with qe.Timer():
 The outcome is similar to the `cos` example --- JAX is faster, especially on the
 second run after JIT compilation.
 
-But we are still using eager execution --- lots of memory and read/write
+This is because the individual array operations are parallelized on the GPU.
 
+But we are still using eager execution
+
+* lots of memory due to intermediate arrays
+* lots of memory reads and writes
+
+Also, many separate kernels are launched on the GPU.
 
 ### Compiling the Whole Function

@@ -708,7 +726,8 @@ The runtime has improved again --- now because we fused all the operations
 
 * Aggressive optimization based on entire computational sequence
 * Eliminates multiple calls to the hardware accelerator
-* No creation of intermediate arrays
+
+The memory footprint is also much lower --- no creation of intermediate arrays.
 
 Incidentally, a more common syntax when targeting a function for the JIT
 compiler is
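The hunk ends before the syntax itself is shown; presumably it is the decorator form, sketched here with a hypothetical stand-in function (the lecture's own function is not visible in this diff):

```python
import jax
import jax.numpy as jnp

@jax.jit                 # decorator form of jax.jit
def f(x):
    # hypothetical stand-in computation; XLA fuses these ops into one kernel
    return jnp.cos(x**2) + jnp.sin(x) - x

x = jnp.linspace(0, 10, 1_000)
y = f(x)   # first call compiles; subsequent calls reuse the cached kernel
```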

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 67 additions & 60 deletions
@@ -135,18 +135,36 @@ for x in grid:
 
 Let's switch to NumPy and use a larger grid
 
-Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` such
-that `f(x, y)` generates all evaluations on the product grid.
+```{code-cell} ipython3
+grid = np.linspace(-3, 3, 3_000)  # Large grid
+```
+
+As a first pass at vectorization we might try something like this
+
+```{code-cell} ipython3
+z = np.max(f(grid, grid))  # This is wrong!
+```
+
+The problem here is that `f(grid, grid)` doesn't replicate the nested loop.
+
+In terms of the figure above, it only computes the values of `f` along the
+diagonal.
+
+To trick NumPy into calculating `f(x, y)` on every `(x, y)` pair, we need to use `np.meshgrid`.
+
+Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y`
+such that `f(x, y)` generates all evaluations on the product grid.
 
 
 ```{code-cell} ipython3
 # Large grid
 grid = np.linspace(-3, 3, 3_000)
 
-x, y = np.meshgrid(grid, grid)  # MATLAB style meshgrid
+x_mesh, y_mesh = np.meshgrid(grid, grid)  # MATLAB style meshgrid
 
 with qe.Timer():
-    z_max_numpy = np.max(f(x, y))
+    z_max_numpy = np.max(f(x_mesh, y_mesh))  # This works
 ```
 
 In the vectorized version, all the looping takes place in compiled code.
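The diagonal point can be checked on a tiny grid. Here's a standalone sketch with a hypothetical stand-in `f` (the lecture defines its own `f` earlier in the file):

```python
import numpy as np

def f(x, y):
    # hypothetical stand-in bivariate function
    return x + 10 * y

grid = np.array([0.0, 1.0, 2.0])

diag = f(grid, grid)                  # elementwise: only the pairs (x_i, y_i)

x_mesh, y_mesh = np.meshgrid(grid, grid)
full = f(x_mesh, y_mesh)              # every (x, y) pair on the product grid

print(diag.shape, full.shape)         # (3,) vs (3, 3)
```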
@@ -159,11 +177,30 @@ The output should be close to one:
 print(f"NumPy result: {z_max_numpy:.6f}")
 ```
 
+### Memory Issues
+
+So we have the right solution in reasonable time --- but memory usage is huge.
+
+While the flat arrays are low-memory
+
+```{code-cell} ipython3
+grid.nbytes
+```
+
+the mesh grids are two-dimensional and hence very memory intensive
+
+```{code-cell} ipython3
+x_mesh.nbytes + y_mesh.nbytes
+```
+
+Moreover, NumPy's eager execution creates many intermediate arrays of the same size!
+
+This kind of memory usage can be a big problem in actual research calculations.
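The numbers behind this claim are easy to verify in a standalone script mirroring the cells above: a 3,000-point float64 grid is 24 kB, while the pair of 3,000 x 3,000 meshes comes to 144 MB:

```python
import numpy as np

grid = np.linspace(-3, 3, 3_000)
x_mesh, y_mesh = np.meshgrid(grid, grid)

print(grid.nbytes)                    # 3_000 * 8 bytes = 24_000
print(x_mesh.nbytes + y_mesh.nbytes)  # 2 * 3_000**2 * 8 bytes = 144_000_000
```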

 
 ### A Comparison with Numba
 
-Now let's see if we can achieve better performance using Numba with a simple loop.
+Let's see if we can achieve better performance using Numba with a simple loop.
 
 ```{code-cell} ipython3
 @numba.jit
@@ -194,15 +231,13 @@ with qe.Timer():
     compute_max_numba(grid)
 ```
 
-Depending on your machine, the Numba version might be either slower or faster than NumPy.
+Notice how we are using almost no memory --- we just need the one-dimensional `grid`.
 
-In most cases we find that Numba is slightly better.
+Moreover, execution speed is good.
 
-On the one hand, NumPy combines efficient arithmetic with some
-multithreading, which provides an advantage.
+On most machines, the Numba version will be somewhat faster than NumPy.
 
-On the other hand, the Numba routine uses much less memory, since we are only
-working with a single one-dimensional grid.
+The reason is efficient machine code plus fewer memory reads and writes.
 
 
 ### Parallelized Numba
@@ -301,27 +336,11 @@ The compilation overhead is a one-time cost that pays off when the function is c
 
 ### JAX plus vmap
 
-There is one problem with both the NumPy code and the JAX code above:
-
-While the flat arrays are low-memory
-
-```{code-cell} ipython3
-grid.nbytes
-```
-
-the mesh grids are memory intensive
-
-```{code-cell} ipython3
-x_mesh.nbytes + y_mesh.nbytes
-```
+Because we used `jax.jit` above, we avoided creating many intermediate arrays.
 
-This extra memory usage can be a big problem in actual research calculations.
+But we still create the big arrays `z_max`, `x_mesh`, and `y_mesh`.
 
-Fortunately, JAX admits a different approach
-using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).
-
-The idea of `vmap` is to break vectorization into stages, transforming a
-function that operates on single values into one that operates on arrays.
+Fortunately, we can avoid this by using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).
 
 Here's how we can apply it to our problem.
@@ -330,13 +349,13 @@ Here's how we can apply it to our problem.
 @jax.jit
 def compute_max_vmap(grid):
     # Construct a function that takes the max over all x for given y
-    f_vec_x_max = lambda y: jnp.max(f(grid, y))
+    compute_column_max = lambda y: jnp.max(f(grid, y))
     # Vectorize the function so we can call on all y simultaneously
-    f_vec_max = jax.vmap(f_vec_x_max)
-    # Compute the max across x at every y
-    maxes = f_vec_max(grid)
-    # Compute the max of the maxes and return
-    return jnp.max(maxes)
+    vectorized_compute_column_max = jax.vmap(compute_column_max)
+    # Compute the column max at every row
+    column_maxes = vectorized_compute_column_max(grid)
+    # Compute the max of the column maxes and return
+    return jnp.max(column_maxes)
 ```
 
 Note that we never create
@@ -345,6 +364,8 @@ Note that we never create
 * the two-dimensional grid `y_mesh` or
 * the two-dimensional array `f(x,y)`
 
+Like Numba, we just use the flat array `grid`.
+
 And because everything is under a single `@jax.jit`, the compiler can fuse
 all operations into one optimized kernel.
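For reference, the renamed version reads as follows when made standalone, with a hypothetical stand-in for the lecture's `f`:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # hypothetical stand-in for the lecture's bivariate f
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

grid = jnp.linspace(-3, 3, 1_000)

@jax.jit
def compute_max_vmap(grid):
    # Max over all x for a fixed y
    compute_column_max = lambda y: jnp.max(f(grid, y))
    # Vectorize over y without materializing a 2-D mesh
    column_maxes = jax.vmap(compute_column_max)(grid)
    return jnp.max(column_maxes)

z_max = compute_max_vmap(grid)
```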

@@ -378,18 +399,14 @@ In our view, JAX is the winner for vectorized operations.
 It dominates NumPy both in terms of speed (via JIT-compilation and
 parallelization) and memory efficiency (via vmap).
 
-Moreover, the `vmap` approach can sometimes lead to significantly clearer code.
-
-While Numba is impressive, the beauty of JAX is that, with fully vectorized
-operations, we can run exactly the same code on machines with hardware
-accelerators and reap all the benefits without extra effort.
+It also dominates Numba when run on the GPU.
 
-Moreover, JAX already knows how to effectively parallelize many common array
-operations, which is key to fast execution.
-
-For most cases encountered in economics, econometrics, and finance, it is
+```{note}
+Numba can support GPU programming through `numba.cuda`, but then we need to
+parallelize by hand. For most cases encountered in economics, econometrics, and finance, it is
 far better to hand over to the JAX compiler for efficient parallelization than to
-try to hand code these routines ourselves.
+try to hand-code these routines ourselves.
+```
 
 
 ## Sequential operations
@@ -554,8 +571,6 @@ The JAX versions, on the other hand, require either `lax.fori_loop` or
 
 While JAX's `at[t].set` syntax does allow element-wise updates, the overall code
 remains harder to read than the Numba equivalent.
 
-For this type of sequential operation, Numba is the clear winner in terms of
-code clarity and ease of implementation.
 
 
 ## Overall recommendations
573588
In addition, JAX functions are automatically differentiable, as we explore in
574589
{doc}`autodiff`.
575590

576-
For **sequential operations**, Numba has clear advantages.
591+
For **sequential operations**, Numba has nicer syntax.
577592

578593
The code is natural and readable --- just a Python loop with a decorator ---
579594
and performance is excellent.
580595

581596
JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but
582597
the syntax is less intuitive.
583598

584-
```{note}
585-
One important advantage of `lax.fori_loop` and `lax.scan` is that they
586-
support automatic differentiation through the loop, which Numba cannot do.
587-
If you need to differentiate through a sequential computation (e.g., computing
588-
sensitivities of a trajectory to model parameters), JAX is the better choice
589-
despite the less natural syntax.
590-
```
599+
On the other hand, the JAX versions support automatic differentiation.
591600

592-
In practice, many problems involve a mix of both patterns.
601+
That might be of interest if, say, we want to compute sensitivities of a
602+
trajectory to model parameters
593603

594-
A good rule of thumb: default to JAX for new projects, especially when
595-
hardware acceleration or differentiability might be useful, and reach for Numba
596-
when you have a tight sequential loop that needs to be fast and readable.
597604
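A sketch of the sensitivity point (a hypothetical AR(1)-style recursion, not from the lectures): `lax.scan` expresses the loop, and `jax.grad` differentiates straight through it, which a Numba loop cannot do:

```python
import jax
import jax.numpy as jnp
from jax import lax

def simulate(a, b, x0, n):
    # x_{t+1} = a * x_t + b, written as a scan over n steps
    def step(x, _):
        x_next = a * x + b
        return x_next, x_next        # (carry, per-step output)
    _, path = lax.scan(step, x0, None, length=n)
    return path

path = simulate(0.9, 1.0, 0.0, 5)

# Sensitivity of the final state to the parameter a, via autodiff through the loop
sensitivity = jax.grad(lambda a: simulate(a, 1.0, 0.0, 5)[-1])(0.9)
```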
