@@ -15,6 +15,9 @@ kernelspec:
 
 This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
 
+```{include} _admonition/gpu.md
+```
+
 JAX is a high-performance scientific computing library that provides
 
 * a [NumPy](https://en.wikipedia.org/wiki/NumPy)-like interface that can automatically parallelize across CPUs and GPUs,
@@ -33,9 +36,19 @@ In addition to what's in Anaconda, this lecture will need the following librarie
 !pip install jax quantecon
 ```
 
-```{include} _admonition/gpu.md
+We'll use the following imports
+
+```{code-cell} ipython3
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+import quantecon as qe
 ```
 
+Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface.
+
+
 ## JAX as a NumPy Replacement
 
 One of the attractive features of JAX is that, whenever possible, its array
@@ -47,17 +60,6 @@ Let's look at the similarities and differences between JAX and NumPy.
 
 ### Similarities
 
-We'll use the following imports
-
-```{code-cell} ipython3
-import jax
-import jax.numpy as jnp
-import matplotlib.pyplot as plt
-import numpy as np
-import quantecon as qe
-```
-
-Notice that we import `jax.numpy as jnp`, which provides a NumPy-like interface.
 
 Here are some standard array operations using `jnp`:
 
@@ -73,10 +75,6 @@ print(a)
 print(jnp.sum(a))
 ```
 
-```{code-cell} ipython3
-print(jnp.mean(a))
-```
-
 ```{code-cell} ipython3
 print(jnp.dot(a, a))
 ```
@@ -91,30 +89,12 @@
 type(a)
 ```
 
-Even scalar-valued maps on arrays return JAX arrays.
+Even scalar-valued maps on arrays return JAX arrays rather than scalars!
 
 ```{code-cell} ipython3
 jnp.sum(a)
 ```
 
-Operations on higher dimensional arrays are also similar to NumPy:
-
-```{code-cell} ipython3
-A = jnp.ones((2, 2))
-B = jnp.identity(2)
-A @ B
-```
-
-JAX's array interface also provides the `linalg` subpackage:
-
-```{code-cell} ipython3
-jnp.linalg.inv(B)  # Inverse of identity is identity
-```
-
-```{code-cell} ipython3
-eigvals, eigvecs = jnp.linalg.eigh(B)  # Computes eigenvalues and eigenvectors
-eigvals
-```
 
 
 ### Differences
@@ -137,13 +117,15 @@ Let's try with NumPy
 
 ```{code-cell}
 with qe.Timer():
+    # First NumPy timing
     y = np.cos(x)
 ```
 
 And one more time.
 
 ```{code-cell}
 with qe.Timer():
+    # Second NumPy timing
     y = np.cos(x)
 ```
 
@@ -165,7 +147,9 @@ Let's time the same procedure.
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = jnp.cos(x)
+    # Hold the interpreter until the array operation finishes
     jax.block_until_ready(y);
 ```
 
@@ -184,7 +168,9 @@ And let's time it again.
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
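The comments added in this hunk point at JAX's asynchronous dispatch: array operations return before the work finishes, so honest timings need `jax.block_until_ready`. A minimal standalone sketch of the pattern (array size invented):

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0, 10, 1_000_000)
y = jnp.cos(x)             # returns quickly; the computation may still be running
jax.block_until_ready(y)   # block until the result is actually available
```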
@@ -212,14 +198,18 @@ x = jnp.linspace(0, 10, n + 1)
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = jnp.cos(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
@@ -294,7 +284,7 @@ functional programming style, which we discuss below.
 
 #### A workaround
 
-We note that JAX does provide a version of in-place array modification
+We note that JAX does provide an alternative to in-place array modification
 using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
 
 ```{code-cell} ipython3
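As a quick illustration of the `at` method referenced here (a self-contained sketch; the array values are invented):

```python
import jax.numpy as jnp

a = jnp.linspace(0, 1, 3)   # JAX arrays are immutable,
# so a[0] = 1.0 would raise a TypeError; the `at` method
# instead returns a modified copy and leaves `a` unchanged
b = a.at[0].set(1.0)
print(a)  # original unchanged
print(b)  # copy with the first element replaced
```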
@@ -387,11 +377,26 @@ This pure version makes all dependencies explicit through function arguments, an
 
 ### Why Functional Programming?
 
-JAX represents functions as computational graphs, which are then compiled or transformed (e.g., differentiated)
+At QuantEcon we love pure functions because they
+
+* Help testing: each function can operate in isolation
+* Promote deterministic behavior and hence reproducibility
+* Prevent bugs that arise from mutating shared state
+
+The JAX compiler loves pure functions and functional programming because
+
+* Data dependencies are explicit, which helps with optimizing complex computations
+* Pure functions are easier to differentiate (autodiff)
+* Pure functions are easier to parallelize and optimize (don't depend on shared mutable state)
+
+Another way to think of this is as follows:
+
+JAX represents functions as computational graphs, which are then compiled or
+transformed (e.g., differentiated).
 
 These computational graphs describe how a given set of inputs is transformed into an output.
 
-They are pure by construction.
+JAX's computational graphs are pure by construction.
 
 JAX uses a functional programming style so that user-built functions map
 directly into the graph-theoretic representations supported by JAX.
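To make the pure-vs-impure contrast in this hunk concrete, here is a toy sketch (not from the lecture; the names are made up):

```python
import jax
import jax.numpy as jnp

# Impure: depends on and mutates state outside the function
state = {"total": 0.0}

def impure_add(x):
    state["total"] += float(jnp.sum(x))  # side effect on shared state
    return state["total"]

# Pure: every dependency enters through the arguments and every result
# leaves through the return value, so JAX can trace and compile it safely
@jax.jit
def pure_add(total, x):
    return total + jnp.sum(x)

new_total = pure_add(0.0, jnp.ones(3))
```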
@@ -520,8 +525,8 @@ plt.tight_layout()
 plt.show()
 ```
 
-This syntax will seem unusual for a NumPy or Matlab user --- but will make a lot
-of sense when we progress to parallel programming.
+This syntax will seem unusual for a NumPy or Matlab user --- but will make more
+sense when we get to parallel programming.
 
 The function below produces `k` (quasi-) independent random `n x n` matrices using `split`.
 
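For readers of the diff, the `split`-based pattern that sentence refers to looks roughly like this (an illustrative sketch with invented sizes, not the lecture's own function):

```python
import jax

key = jax.random.PRNGKey(42)        # a single PRNG state
subkeys = jax.random.split(key, 3)  # three quasi-independent keys

# one independent 2 x 2 standard-normal draw per subkey
matrices = [jax.random.normal(k, (2, 2)) for k in subkeys]
```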
@@ -664,6 +669,7 @@ x = np.linspace(0, 10, n)
 
 ```{code-cell}
 with qe.Timer():
+    # Time NumPy code
     y = f(x)
 ```
 
@@ -679,27 +685,31 @@ As a first pass, we replace `np` with `jnp` throughout:
 def f(x):
     y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2
     return y
-```
 
-Now let's time it.
 
-```{code-cell}
 x = jnp.linspace(0, 10, n)
 ```
 
+Now let's time it.
+
 ```{code-cell}
 with qe.Timer():
+    # First call
     y = f(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 ```{code-cell}
 with qe.Timer():
+    # Second call
     y = f(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
-The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the
+second run after JIT compilation.
 
 However, with JAX, we have another trick up our sleeve --- we can JIT-compile
 the *entire* function, not just individual operations.
@@ -718,13 +728,17 @@ f_jax = jax.jit(f)
 
 ```{code-cell}
 with qe.Timer():
+    # First run
     y = f_jax(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
 ```{code-cell}
 with qe.Timer():
+    # Second run
     y = f_jax(x)
+    # Hold interpreter
     jax.block_until_ready(y);
 ```
 
@@ -734,7 +748,6 @@ allowing the compiler to optimize more aggressively.
 For example, the compiler can eliminate multiple calls to the hardware
 accelerator and the creation of a number of intermediate arrays.
 
-
 Incidentally, a more common syntax when targeting a function for the JIT
 compiler is
 
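The "more common syntax" is presumably the decorator form; a self-contained sketch reusing the `f` from this diff:

```python
import jax
import jax.numpy as jnp

@jax.jit   # equivalent to f = jax.jit(f)
def f(x):
    return jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2

y = f(jnp.linspace(0, 10, 100))
```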
@@ -811,21 +824,6 @@ f(x)
 Moral of the story: write pure functions when using JAX!
 
 
-### Summary
-
-Now we can see why both developers and compilers benefit from pure functions.
-
-We love pure functions because they
-
-* Help testing: each function can operate in isolation
-* Promote deterministic behavior and hence reproducibility
-* Prevent bugs that arise from mutating shared state
-
-The compiler loves pure functions and functional programming because
-
-* Data dependencies are explicit, which helps with optimizing complex computations
-* Pure functions are easier to differentiate (autodiff)
-* Pure functions are easier to parallelize and optimize (don't depend on shared mutable state)
 
 
 ## Vectorization with `vmap`
@@ -838,18 +836,18 @@ This avoids the need to manually write vectorized code or use explicit loops.
 
 ### A simple example
 
-Suppose we have a function that computes summary statistics for a single array:
+Suppose we have a function that computes the difference between the mean and the median of an array of numbers:
 
 ```{code-cell} ipython3
-def summary(x):
-    return jnp.mean(x), jnp.median(x)
+def mm_diff(x):
+    return jnp.mean(x) - jnp.median(x)
 ```
 
 We can apply it to a single vector:
 
 ```{code-cell} ipython3
 x = jnp.array([1.0, 2.0, 5.0])
-summary(x)
+mm_diff(x)
 ```
 
 Now suppose we have a matrix and want to compute these statistics for each row.
@@ -862,7 +860,7 @@ X = jnp.array([[1.0, 2.0, 5.0],
                [1.0, 8.0, 9.0]])
 
 for row in X:
-    print(summary(row))
+    print(mm_diff(row))
 ```
 
 However, Python loops are slow and cannot be efficiently compiled or
@@ -872,11 +870,11 @@ Using `vmap` keeps the computation on the accelerator and composes with other
 JAX transformations like `jit` and `grad`:
 
 ```{code-cell} ipython3
-batch_summary = jax.vmap(summary)
-batch_summary(X)
+batch_mm_diff = jax.vmap(mm_diff)
+batch_mm_diff(X)
 ```
 
-The function `summary` was written for a single array, and `vmap` automatically
+The function `mm_diff` was written for a single array, and `vmap` automatically
 lifted it to operate row-wise over a matrix --- no loops, no reshaping.
 
 ### Combining transformations
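Beyond the default row-wise mapping, `vmap` can map over chosen axes of multiple arguments via `in_axes` (not used in this hunk; a brief sketch with invented names):

```python
import jax
import jax.numpy as jnp

def scale(x, c):
    return c * x

X = jnp.ones((3, 4))             # batch of three vectors
cs = jnp.array([1.0, 2.0, 3.0])  # one coefficient per vector

# map over axis 0 of both arguments simultaneously
batched_scale = jax.vmap(scale, in_axes=(0, 0))
Y = batched_scale(X, cs)         # Y[i] == cs[i] * X[i]; shape (3, 4)
```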
@@ -886,8 +884,8 @@ One of JAX's strengths is that transformations compose naturally.
 
 For example, we can JIT-compile a vectorized function:
 ```{code-cell} ipython3
-fast_batch_summary = jax.jit(jax.vmap(summary))
-fast_batch_summary(X)
+fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
+fast_batch_mm_diff(X)
 ```
 
 This composition of `jit`, `vmap`, and (as we'll see next) `grad` is central to
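As a taste of the `grad` transformation mentioned in that closing sentence (a minimal sketch, not part of this diff):

```python
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(w**2)

# transformations compose: differentiate, then JIT-compile the gradient
grad_loss = jax.jit(jax.grad(loss))
g = grad_loss(jnp.array([1.0, 2.0]))  # gradient of sum(w**2) is 2 * w
```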