@@ -416,15 +416,34 @@ directly into the graph-theoretic representations supported by JAX.
 
 Random number generation in JAX differs significantly from the patterns found in NumPy or MATLAB.
 
-At first you might find the syntax rather verbose.
 
-But the syntax and semantics are necessary to maintain the functional programming style we just discussed.
 
-Moreover, full control of random state is essential for parallel programming,
-such as when we want to run independent experiments along multiple threads.
+### NumPy / MATLAB Approach
 
+In NumPy / MATLAB, random number generation works by maintaining hidden global state.
+
+```{code-cell} ipython3
+np.random.seed(42)
+print(np.random.randn(2))
+```
+
+Each time we call a random function, the hidden state is updated:
+
+```{code-cell} ipython3
+print(np.random.randn(2))
+```
+
+The function `np.random.randn` is *not pure* because:
+
+* It's non-deterministic: the same inputs give different outputs
+* It has side effects: it modifies the global random number generator state
+
+This is dangerous under parallelization --- we must carefully control what
+happens in each thread!
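
To make the hidden state visible, note that reseeding rewinds the generator, so the exact same draws come back (a small illustrative check):

```{code-cell} ipython3
import numpy as np

np.random.seed(42)
first = np.random.randn(2)    # advances the hidden state
second = np.random.randn(2)   # different values: the state has moved on

np.random.seed(42)            # reset the hidden state
print(np.all(first == np.random.randn(2)))   # True: the same draws again
```

Two threads sharing this hidden state would interleave updates unpredictably, which is why explicit control of random state matters.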
+
+
+### JAX Approach
 
-### Random number generation
 
 In JAX, the state of the random number generator is controlled explicitly.
 
@@ -547,105 +566,30 @@ def gen_random_matrices(key, n=2, k=3):
         key, subkey = jax.random.split(key)
         A = jax.random.uniform(subkey, (n, n))
         matrices.append(A)
-        print(A)
     return matrices
 ```
 
 ```{code-cell} ipython3
 seed = 42
 key = jax.random.key(seed)
-matrices = gen_random_matrices(key)
-```
-
-We can also use `fold_in` when iterating in a loop:
-
-```{code-cell} ipython3
-def gen_random_matrices(key, n=2, k=3):
-    matrices = []
-    for i in range(k):
-        step_key = jax.random.fold_in(key, i)
-        A = jax.random.uniform(step_key, (n, n))
-        matrices.append(A)
-        print(A)
-    return matrices
-```
-
-```{code-cell} ipython3
-key = jax.random.key(seed)
-matrices = gen_random_matrices(key)
-```
-
-
-### Why explicit random state?
-
-Why does JAX require this somewhat verbose approach to random number generation?
-
-One reason is to maintain pure functions.
-
-Let's see how random number generation relates to pure functions by comparing NumPy and JAX.
-
-#### NumPy's approach
-
-In NumPy's legacy random number generation API (which mimics MATLAB), generation
-works by maintaining hidden global state.
-
-Each time we call a random function, this state is updated:
-
-```{code-cell} ipython3
-np.random.seed(42)
-print(np.random.randn())  # Updates state of random number generator
-print(np.random.randn())  # Updates state of random number generator
+gen_random_matrices(key)
 ```
 
-Each call returns a different value, even though we're calling the same function with the same inputs (no arguments).
-
-This function is *not pure* because:
-
-* It's non-deterministic: same inputs (none, in this case) give different outputs
-* It has side effects: it modifies the global random number generator state
-
-
-#### JAX's approach
-
-As we saw above, JAX takes a different approach, making randomness explicit through keys.
-
-For example,
-
-```{code-cell} ipython3
-def random_sum_jax(key):
-    key1, key2 = jax.random.split(key)
-    x = jax.random.normal(key1)
-    y = jax.random.normal(key2)
-    return x + y
-```
-
-With the same key, we always get the same result:
-
-```{code-cell} ipython3
-key = jax.random.key(42)
-random_sum_jax(key)
-```
-
-```{code-cell} ipython3
-random_sum_jax(key)
-```
+This function is *pure*:
 
-To get new draws we need to supply a new key.
+* Deterministic: same inputs, same output
+* No side effects: no hidden state is modified
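
We can confirm this directly: reusing a key reproduces the draws exactly (a minimal check using `jax.random.uniform`, independent of the function above):

```{code-cell} ipython3
import jax

key = jax.random.key(42)
x1 = jax.random.uniform(key, (2, 2))
x2 = jax.random.uniform(key, (2, 2))
print(bool((x1 == x2).all()))   # True: same key, same draws
```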
 
-The function `random_sum_jax` is pure because:
 
-* It's deterministic: same key always produces same output
-* No side effects: no hidden state is modified
+### Benefits
 
 The explicitness of JAX brings significant benefits:
 
 * Reproducibility: Easy to reproduce results by reusing keys
-* Parallelization: Each thread can have its own key without conflicts
-* Debugging: No hidden state makes code easier to reason about
+* Parallelization: Control over what happens on separate threads
+* Debugging: No hidden state makes code easier to test
 * JIT compatibility: The compiler can optimize pure functions more aggressively
 
-The last point is expanded on in the next section.
-
 
 ## JIT Compilation
 
@@ -655,17 +599,20 @@ efficient machine code that varies with both task size and hardware.
 We saw the power of JAX's JIT compiler combined with parallel hardware
 {ref}`above <jax_speed>`, when we applied `cos` to a large array.
 
-Let's try the same thing with a more complex function:
+Here we study JIT compilation for more complex functions.
+
+
+### With NumPy
+
+We'll try first with NumPy, using the following function:
 
 ```{code-cell}
 def f(x):
     y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
     return y
 ```
 
-### With NumPy
-
-We'll try first with NumPy
+Let's run it with a large `x`:
 
 ```{code-cell}
 n = 50_000_000
@@ -678,11 +625,20 @@ with qe.Timer():
     y = f(x)
 ```
 
+This is NumPy's **eager** execution model:
 
+* Each operation is executed immediately as it is encountered, materializing its
+  result before the next operation begins.
 
-### With JAX
+Disadvantages:
 
-Now let's try again with JAX.
+* Minimal parallelization
+* Heavy memory footprint --- produces many intermediate arrays
+* Lots of memory reads and writes
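
To see where the intermediate arrays come from, here is roughly how eager evaluation of `f` unfolds, step by step (an illustrative unrolling, not NumPy internals):

```{code-cell} ipython3
import numpy as np

x = np.linspace(-1, 1, 5)

# Each step materializes a full temporary array before the next begins
t1 = x**2                 # temporary for x**2
t2 = np.cos(2 * t1)       # temporary for cos(2 * x**2)
t3 = np.sqrt(np.abs(x))   # temporaries for abs(x) and sqrt(...)
t4 = 2 * np.sin(x**4)     # temporaries for x**4, sin(...), and the product
y = t2 + t3 + t4 - t1     # further temporaries for each binary operation
```

For the 50-million-element `x` used above, each float64 temporary is roughly 400 MB of extra traffic between memory and the processor.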
+
+
+
+### With JAX
 
 As a first pass, we replace `np` with `jnp` throughout:
 
@@ -716,14 +672,15 @@ with qe.Timer():
 The outcome is similar to the `cos` example --- JAX is faster, especially on the
 second run after JIT compilation.
 
-However, with JAX, we have another trick up our sleeve --- we can JIT-compile
-the entire function, not just individual operations.
+But we are still using eager execution --- lots of memory use and read/write traffic.
 
 
 ### Compiling the Whole Function
 
-The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array
-operations into a single optimized kernel.
+Fortunately, with JAX, we have another trick up our sleeve --- we can JIT-compile
+the entire function, not just individual operations.
+
+The compiler fuses all the array operations into a single optimized kernel.
 
 Let's try this with the function `f`:
 
@@ -747,11 +704,11 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-The runtime has improved again --- now because we fused all the operations,
-allowing the compiler to optimize more aggressively.
+The runtime has improved again --- now because we fused all the operations:
 
-For example, the compiler can eliminate multiple calls to the hardware
-accelerator and the creation of a number of intermediate arrays.
+* Aggressive optimization based on the entire computational sequence
+* Eliminates multiple calls to the hardware accelerator
+* No creation of intermediate arrays
 
 Incidentally, a more common syntax when targeting a function for the JIT
 compiler is
@@ -777,16 +734,12 @@ subsequent calls with the same input shapes and types reuse the cached
 compiled code and run at full speed.
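
As a sketch of this caching behavior (details vary across JAX versions and hardware), compilation is keyed on input shape and dtype, so a new shape triggers a fresh trace:

```{code-cell} ipython3
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    return jnp.cos(x) + x**2

g(jnp.ones(10))   # first call at this shape: traced and compiled
g(jnp.ones(10))   # same shape and dtype: reuses the cached executable
g(jnp.ones(20))   # new shape: compiled again, then cached separately
```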
 
 
-
 ### Compiling non-pure functions
 
-Now that we've seen how powerful JIT compilation can be, it's important to
-understand its relationship with pure functions.
-
 While JAX will not usually throw errors when compiling impure functions,
-execution becomes unpredictable.
+execution becomes unpredictable!
 
-Here's an illustration of this fact, using global variables:
+Here's an illustration of this fact:
 
 ```{code-cell} ipython3
 a = 1  # global
@@ -871,16 +824,13 @@ for row in X:
 However, Python loops are slow and cannot be efficiently compiled or
 parallelized by JAX.
 
-Using `vmap` keeps the computation on the accelerator and composes with other
-JAX transformations like `jit` and `grad`:
+With `vmap`, we can avoid loops and keep the computation on the accelerator:
 
 ```{code-cell} ipython3
-batch_mm_diff = jax.vmap(mm_diff)
-batch_mm_diff(X)
+batch_mm_diff = jax.vmap(mm_diff)  # Create a new "vectorized" version
+batch_mm_diff(X)                   # Apply mm_diff to each row of X
 ```
 
-The function `mm_diff` was written for a single array, and `vmap` automatically
-lifted it to operate row-wise over a matrix --- no loops, no reshaping.
 
 ### Combining transformations
 