
Commit 9a9cb8b

jstac authored and claude committed
Improve and simplify jax_intro and numpy_vs_numba_vs_jax lectures
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 762d0fb commit 9a9cb8b

File tree

2 files changed: +109 −196 lines changed

lectures/jax_intro.md

Lines changed: 63 additions & 113 deletions
@@ -416,15 +416,34 @@ directly into the graph-theoretic representations supported by JAX.

 Random number generation in JAX differs significantly from the patterns found in NumPy or MATLAB.

-At first you might find the syntax rather verbose.
-
-But the syntax and semantics are necessary to maintain the functional programming style we just discussed.
-
-Moreover, full control of random state is essential for parallel programming,
-such as when we want to run independent experiments along multiple threads.
+### NumPy / MATLAB Approach
+
+In NumPy / MATLAB, generation works by maintaining hidden global state.
+
+```{code-cell} ipython3
+np.random.seed(42)
+print(np.random.randn(2))
+```
+
+Each time we call a random function, the hidden state is updated:
+
+```{code-cell} ipython3
+print(np.random.randn(2))
+```
+
+This function is *not pure* because:
+
+* It's non-deterministic: same inputs, different outputs
+* It has side effects: it modifies the global random number generator state
+
+This is dangerous under parallelization --- we must carefully control what
+happens in each thread!
+
+### JAX

-### Random number generation

 In JAX, the state of the random number generator is controlled explicitly.

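As a runnable preview of the key-based pattern the diff introduces below (a minimal sketch, assuming `jax` is installed; `jax.random.key`, `split`, and `normal` are the standard APIs):

```python
import jax

# Explicit PRNG state: a key object replaces NumPy's hidden global state
key = jax.random.key(42)
key, subkey = jax.random.split(key)  # derive an independent subkey

x = jax.random.normal(subkey, (2,))
y = jax.random.normal(subkey, (2,))  # same subkey -> identical draw
```

Because the draw is a pure function of the key, `x` and `y` are identical here; fresh draws require fresh subkeys.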
@@ -547,105 +566,30 @@ def gen_random_matrices(key, n=2, k=3):
         key, subkey = jax.random.split(key)
         A = jax.random.uniform(subkey, (n, n))
         matrices.append(A)
-        print(A)
     return matrices
 ```

 ```{code-cell} ipython3
 seed = 42
 key = jax.random.key(seed)
-matrices = gen_random_matrices(key)
-```
-
-We can also use `fold_in` when iterating in a loop:
-
-```{code-cell} ipython3
-def gen_random_matrices(key, n=2, k=3):
-    matrices = []
-    for i in range(k):
-        step_key = jax.random.fold_in(key, i)
-        A = jax.random.uniform(step_key, (n, n))
-        matrices.append(A)
-        print(A)
-    return matrices
-```
-
-```{code-cell} ipython3
-key = jax.random.key(seed)
-matrices = gen_random_matrices(key)
-```
-
-
-### Why explicit random state?
-
-Why does JAX require this somewhat verbose approach to random number generation?
-
-One reason is to maintain pure functions.
-
-Let's see how random number generation relates to pure functions by comparing NumPy and JAX.
-
-#### NumPy's approach
-
-In NumPy's legacy random number generation API (which mimics MATLAB), generation
-works by maintaining hidden global state.
-
-Each time we call a random function, this state is updated:
-
-```{code-cell} ipython3
-np.random.seed(42)
-print(np.random.randn())  # Updates state of random number generator
-print(np.random.randn())  # Updates state of random number generator
+gen_random_matrices(key)
 ```

-Each call returns a different value, even though we're calling the same function with the same inputs (no arguments).
-
-This function is *not pure* because:
-
-* It's non-deterministic: same inputs (none, in this case) give different outputs
-* It has side effects: it modifies the global random number generator state
-
-#### JAX's approach
-
-As we saw above, JAX takes a different approach, making randomness explicit through keys.
-
-For example,
-
-```{code-cell} ipython3
-def random_sum_jax(key):
-    key1, key2 = jax.random.split(key)
-    x = jax.random.normal(key1)
-    y = jax.random.normal(key2)
-    return x + y
-```
-
-With the same key, we always get the same result:
-
-```{code-cell} ipython3
-key = jax.random.key(42)
-random_sum_jax(key)
-```
-
-```{code-cell} ipython3
-random_sum_jax(key)
-```
+This function is *pure*:

-To get new draws we need to supply a new key.
+* Deterministic: same inputs, same output
+* No side effects: no hidden state is modified

-The function `random_sum_jax` is pure because:
-
-* It's deterministic: same key always produces same output
-* No side effects: no hidden state is modified
+
+### Benefits

 The explicitness of JAX brings significant benefits:

 * Reproducibility: Easy to reproduce results by reusing keys
-* Parallelization: Each thread can have its own key without conflicts
-* Debugging: No hidden state makes code easier to reason about
+* Parallelization: Control what happens on separate threads
+* Debugging: No hidden state makes code easier to test
 * JIT compatibility: The compiler can optimize pure functions more aggressively

-The last point is expanded on in the next section.
-

 ## JIT Compilation

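A self-contained sketch of the key-splitting loop from the hunk above, runnable as-is (assuming `jax` is installed); reproducibility follows directly from purity:

```python
import jax

def gen_random_matrices(key, n=2, k=3):
    matrices = []
    for _ in range(k):
        key, subkey = jax.random.split(key)  # fresh subkey per draw
        matrices.append(jax.random.uniform(subkey, (n, n)))
    return matrices

# Same seed, same matrices --- the function has no hidden state
m1 = gen_random_matrices(jax.random.key(42))
m2 = gen_random_matrices(jax.random.key(42))
```

Here `m1` and `m2` contain identical arrays, which is exactly the reproducibility benefit listed above.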
@@ -655,17 +599,20 @@ efficient machine code that varies with both task size and hardware.

 We saw the power of JAX's JIT compiler combined with parallel hardware when we
 {ref}`above <jax_speed>`, when we applied `cos` to a large array.

-Let's try the same thing with a more complex function:
+Here we study JIT compilation for more complex functions.
+
+
+### With NumPy
+
+We'll try first with NumPy, using

 ```{code-cell}
 def f(x):
     y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
     return y
 ```

-### With NumPy
-
-We'll try first with NumPy
+Let's run with large `x`:

 ```{code-cell}
 n = 50_000_000
@@ -678,11 +625,20 @@
     y = f(x)
 ```

+NumPy uses an **eager** execution model:
+
+* Each operation is executed immediately as it is encountered, materializing its
+  result before the next operation begins.
+
-### With JAX
+Disadvantages:

-Now let's try again with JAX.
+* Minimal parallelization
+* Heavy memory footprint --- produces many intermediate arrays
+* Lots of memory read/write
+
+
+### With JAX

 As a first pass, we replace `np` with `jnp` throughout:

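The eager model described above can be seen directly in a small-scale sketch (the lecture itself uses 50 million elements): each sub-expression in `f` materializes a full temporary array before the next operation runs.

```python
import numpy as np

def f(x):
    # Eager evaluation: cos, sqrt, abs, sin, and each power/arithmetic op
    # below allocates and fills its own intermediate array of size len(x)
    return np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2

x = np.linspace(0, 1, 1_000)  # small n for illustration
y = f(x)
```

At large `n`, those temporaries dominate memory traffic, which is the "heavy memory footprint" disadvantage listed in the hunk above.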
@@ -716,14 +672,15 @@

 The outcome is similar to the `cos` example --- JAX is faster, especially on the
 second run after JIT compilation.

-However, with JAX, we have another trick up our sleeve --- we can JIT-compile
-the entire function, not just individual operations.
+But we are still using eager execution --- lots of memory use and read/write.


 ### Compiling the Whole Function

-The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array
-operations into a single optimized kernel.
+Fortunately, with JAX, we have another trick up our sleeve --- we can JIT-compile
+the entire function, not just individual operations.
+
+The compiler fuses all array operations into a single optimized kernel.

 Let's try this with the function `f`:

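A minimal runnable sketch of the whole-function compilation step described above (small array for illustration; `jax.jit` and `jax.block_until_ready` are the standard APIs):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2

f_jit = jax.jit(f)                    # compile the whole function as one kernel
x = jnp.linspace(0, 1, 1_000)
y = jax.block_until_ready(f_jit(x))   # first call triggers compilation
# Later calls with the same input shape/dtype reuse the cached compiled code
```

The compiled version returns the same values as the uncompiled one; the gain is in how the operations are fused and scheduled.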
@@ -747,11 +704,11 @@
     jax.block_until_ready(y);
 ```

-The runtime has improved again --- now because we fused all the operations,
-allowing the compiler to optimize more aggressively.
+The runtime has improved again --- now because we fused all the operations:

-For example, the compiler can eliminate multiple calls to the hardware
-accelerator and the creation of a number of intermediate arrays.
+* Aggressive optimization based on the entire computational sequence
+* Eliminates multiple calls to the hardware accelerator
+* No creation of intermediate arrays

 Incidentally, a more common syntax when targeting a function for the JIT
 compiler is
@@ -777,16 +734,12 @@ subsequent calls with the same input shapes and types reuse the cached
 compiled code and run at full speed.

-
 ### Compiling non-pure functions

-Now that we've seen how powerful JIT compilation can be, it's important to
-understand its relationship with pure functions.
-
 While JAX will not usually throw errors when compiling impure functions,
-execution becomes unpredictable.
+execution becomes unpredictable!

-Here's an illustration of this fact, using global variables:
+Here's an illustration of this fact:

 ```{code-cell} ipython3
 a = 1  # global
@@ -871,16 +824,13 @@ for row in X:

 However, Python loops are slow and cannot be efficiently compiled or
 parallelized by JAX.

-Using `vmap` keeps the computation on the accelerator and composes with other
-JAX transformations like `jit` and `grad`:
+With `vmap`, we can avoid loops and keep the computation on the accelerator:

 ```{code-cell} ipython3
-batch_mm_diff = jax.vmap(mm_diff)
-batch_mm_diff(X)
+batch_mm_diff = jax.vmap(mm_diff)  # Create a new "vectorized" version
+batch_mm_diff(X)                   # Apply to each row of X
 ```

-The function `mm_diff` was written for a single array, and `vmap` automatically
-lifted it to operate row-wise over a matrix --- no loops, no reshaping.
-

 ### Combining transformations

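As a runnable illustration of the `vmap` pattern in the final hunk: the lecture's actual `mm_diff` definition is not shown in this diff, so the max-minus-min body below is a hypothetical stand-in consistent with the name.

```python
import jax
import jax.numpy as jnp

def mm_diff(x):
    # Hypothetical per-row statistic (the real definition is outside this diff)
    return jnp.max(x) - jnp.min(x)

batch_mm_diff = jax.vmap(mm_diff)   # vectorize over the leading axis
X = jnp.array([[1.0, 3.0, 2.0],
               [0.0, 5.0, 1.0]])
batch_mm_diff(X)  # one value per row
```

`jax.vmap` maps the scalar-valued function over the rows of `X` without a Python loop, and composes with `jax.jit` if further speed is needed.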