diff --git a/lectures/_static/lecture_specific/need_for_speed/dgx.png b/lectures/_static/lecture_specific/need_for_speed/dgx.png new file mode 100644 index 00000000..fd9d633e Binary files /dev/null and b/lectures/_static/lecture_specific/need_for_speed/dgx.png differ diff --git a/lectures/_static/lecture_specific/need_for_speed/geforce.png b/lectures/_static/lecture_specific/need_for_speed/geforce.png new file mode 100644 index 00000000..1cc62bec Binary files /dev/null and b/lectures/_static/lecture_specific/need_for_speed/geforce.png differ diff --git a/lectures/_toc.yml b/lectures/_toc.yml index 60348f49..97c429c4 100644 --- a/lectures/_toc.yml +++ b/lectures/_toc.yml @@ -23,8 +23,8 @@ parts: numbered: true chapters: - file: numba - - file: parallelization - file: jax_intro + - file: numpy_vs_numba_vs_jax - caption: Working with Data numbered: true chapters: diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 531200e4..0d890d8f 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -11,7 +11,7 @@ kernelspec: name: python3 --- -# An Introduction to JAX +# JAX In addition to what's in Anaconda, this lecture will need the following libraries: @@ -26,35 +26,41 @@ This lecture provides a short introduction to [Google JAX](https://github.com/ja Here we are focused on using JAX on the CPU, rather than on accelerators such as GPUs or TPUs. -This means we will only see a small amount of the possible benefits from using -JAX. +This means we will only see a small amount of the possible benefits from using JAX. -At the same time, JAX computing on the CPU is a good place to start, since the -JAX just-in-time compiler seamlessly handles transitions across different -hardware platforms. +However, JAX seamlessly handles transitions across different hardware platforms. -(In other words, if you do want to shift to using GPUs, you will almost never -need to modify your code.) +As a result, if you run this code on a machine with a GPU and a GPU-aware +version of JAX installed, your code will be automatically accelerated and you +will receive the full benefits. -For a discusson of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html). +For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html). ## JAX as a NumPy Replacement -One way to use JAX is as a plug-in NumPy replacement. Let's look at the -similarities and differences. +One of the attractive features of JAX is that, whenever possible, it conforms to +the NumPy API for array operations. -### Similarities +This means that, to a large extent, we can use JAX is as a drop-in NumPy replacement. + +Let's look at the similarities and differences between JAX and NumPy. 
+### Similarities -The following import is standard, replacing `import numpy as np`: +We'll use the following imports ```{code-cell} ipython3 import jax -import jax.numpy as jnp import quantecon as qe ``` +In addition, we replace `import numpy as np` with + +```{code-cell} ipython3 +import jax.numpy as jnp +``` + Now we can use `jnp` in place of `np` for the usual array operations: ```{code-cell} ipython3 @@ -101,28 +107,30 @@ B = jnp.identity(2) A @ B ``` -```{code-cell} ipython3 -from jax.numpy import linalg -``` +JAX's array interface also provides the `linalg` subpackage: ```{code-cell} ipython3 -linalg.inv(B) # Inverse of identity is identity +jnp.linalg.inv(B) # Inverse of identity is identity ``` ```{code-cell} ipython3 -linalg.eigh(B) # Computes eigenvalues and eigenvectors +jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors ``` + ### Differences +Let's now look at some differences between JAX and NumPy array operations. + +#### Precision -One difference between NumPy and JAX is that JAX uses 32 bit floats by default. +One difference between NumPy and JAX is that JAX uses 32 bit floats by default. This is because JAX is often used for GPU computing, and most GPU computations use 32 bit floats. Using 32 bit floats can lead to significant speed gains with small loss of precision. -However, for some calculations precision matters. +However, for some calculations precision matters. In these cases 64 bit floats can be enforced via the command @@ -136,7 +144,9 @@ Let's check this works: jnp.ones(3) ``` -As a NumPy replacement, a more significant difference is that arrays are treated as **immutable**. +#### Immutability + +As a NumPy replacement, a more significant difference is that arrays are treated as **immutable**. For example, with NumPy we can write @@ -170,29 +180,28 @@ In line with immutability, JAX does not support inplace operations: ```{code-cell} ipython3 a = np.array((2, 1)) -a.sort() +a.sort() # Unlike NumPy, does not mutate a a ``` ```{code-cell} ipython3 a = jnp.array((2, 1)) -a_new = a.sort() +a_new = a.sort() # Instead, the sort method returns a new sorted array a, a_new ``` The designers of JAX chose to make arrays immutable because JAX uses a -functional programming style. More on this below. +*functional programming style*. + +This design choice has important implications, which we explore next! -However, JAX provides a functionally pure equivalent of in-place array modification +#### A workaround + +We note that JAX does provide a version of in-place array modification using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html). ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) -id(a) -``` - -```{code-cell} ipython3 -a ``` Applying `at[0].set(1)` returns a new copy of `a` with the first element set to 1 @@ -202,470 +211,581 @@ a = a.at[0].set(1) a ``` -Inspecting the identifier of `a` shows that it has been reassigned +Obviously, there are downsides to using `at`: -```{code-cell} ipython3 -id(a) -``` +* The syntax is cumbersome and +* we want to avoid creating fresh arrays in memory every time we change a single value! -## Random Numbers +Hence, for the most part, we try to avoid this syntax. -Random numbers are also a bit different in JAX, relative to NumPy. Typically, in JAX, the state of the random number generator needs to be controlled explicitly. +(Although it can in fact be efficient inside JIT-compiled functions -- but let's put this aside for now.) 
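
As a quick check that `at` leaves the original array untouched, note that nothing changes unless we rebind the name ourselves (a minimal sketch, repeating the array from above):

```{code-cell} ipython3
a = jnp.linspace(0, 1, 3)
b = a.at[0].set(1)   # returns a modified copy
a, b                 # the original array a is unchanged
```
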
-```{code-cell} ipython3 -import jax.random as random -``` -First we produce a key, which seeds the random number generator. +## Functional Programming + +From JAX's documentation: + +*When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has "una anima di pura programmazione funzionale".* + +In other words, JAX assumes a functional programming style. + +### Pure functions + +The major implication is that JAX functions should be pure. + +**Pure functions** have the following characteristics: + +1. *Deterministic* +2. *No side effects* + +**Deterministic** means + +* Same input $\implies$ same output +* Outputs do not depend on global state + +In particular, pure functions will always return the same result if invoked with the same inputs. + +**No side effects** means that the function + +* Won't change global state +* Won't modify data passed to the function (immutable data) + + + +### Examples + +Here's an example of a *non-pure* function ```{code-cell} ipython3 -key = random.PRNGKey(1) +tax_rate = 0.1 +prices = [10.0, 20.0] + +def add_tax(prices): + for i, price in enumerate(prices): + prices[i] = price * (1 + tax_rate) + print('Post-tax prices: ', prices) + return prices ``` +This function fails to be pure because + +* side effects --- it modifies the global variable `prices` +* non-deterministic --- a change to the global variable `tax_rate` will modify + function outputs, even with the same input array `prices`. + +Here's a *pure* version + ```{code-cell} ipython3 -type(key) +tax_rate = 0.1 +prices = (10.0, 20.0) + +def add_tax_pure(prices, tax_rate): + new_prices = [price * (1 + tax_rate) for price in prices] + return new_prices ``` +This pure version makes all dependencies explicit through function arguments, and doesn't modify any external state. + +Now that we understand what pure functions are, let's explore how JAX's approach to random numbers maintains this purity. + + +## Random numbers + +Random numbers are rather different in JAX, compared to what you find in NumPy +or Matlab. + +At first you might find the syntax rather verbose. + +But you will soon realize that the syntax and semantics are necessary in order +to maintain the functional programming style we just discussed. + +Moreover, full control of random state +essential for parallel programming, such as when we want to run independent experiments along multiple threads. + + +### Random number generation + +In JAX, the state of the random number generator is controlled explicitly. + +First we produce a key, which seeds the random number generator. 
+ ```{code-cell} ipython3 -print(key) +seed = 1234 +key = jax.random.PRNGKey(seed) ``` Now we can use the key to generate some random numbers: ```{code-cell} ipython3 -x = random.normal(key, (3, 3)) +x = jax.random.normal(key, (3, 3)) x ``` If we use the same key again, we initialize at the same seed, so the random numbers are the same: ```{code-cell} ipython3 -random.normal(key, (3, 3)) +jax.random.normal(key, (3, 3)) ``` -To produce a (quasi-) independent draw, best practice is to "split" the existing key: +To produce a (quasi-) independent draw, one option is to "split" the existing key: ```{code-cell} ipython3 -key, subkey = random.split(key) +key, subkey = jax.random.split(key) ``` ```{code-cell} ipython3 -random.normal(key, (3, 3)) +jax.random.normal(key, (3, 3)) ``` ```{code-cell} ipython3 -random.normal(subkey, (3, 3)) +jax.random.normal(subkey, (3, 3)) ``` -The function below produces `k` (quasi-) independent random `n x n` matrices using this procedure. +This syntax will seem unusual for a NumPy or Matlab user --- but will make a lot +of sense when we progress to parallel programming. + +The function below produces `k` (quasi-) independent random `n x n` matrices using `split`. ```{code-cell} ipython3 -def gen_random_matrices(key, n, k): +def gen_random_matrices(key, n=2, k=3): matrices = [] for _ in range(k): - key, subkey = random.split(key) - matrices.append(random.uniform(subkey, (n, n))) + key, subkey = jax.random.split(key) + A = jax.random.uniform(subkey, (n, n)) + matrices.append(A) + print(A) return matrices ``` ```{code-cell} ipython3 -matrices = gen_random_matrices(key, 2, 2) -for A in matrices: - print(A) +seed = 42 +key = jax.random.PRNGKey(seed) +matrices = gen_random_matrices(key) ``` -To get a one-dimensional array of normal random draws, we can either use `(len, )` for the shape, as in +We can also use `fold_in` when iterating in a loop: ```{code-cell} ipython3 -random.normal(key, (5, )) +def gen_random_matrices(key, n=2, k=3): + matrices = [] + for i in range(k): + step_key = jax.random.fold_in(key, i) + A = jax.random.uniform(step_key, (n, n)) + matrices.append(A) + print(A) + return matrices ``` -or simply use `5` as the shape argument: - ```{code-cell} ipython3 -random.normal(key, 5) +key = jax.random.PRNGKey(seed) +matrices = gen_random_matrices(key) ``` -## JIT compilation -The JAX just-in-time (JIT) compiler accelerates logic within functions by fusing linear -algebra operations into a single optimized kernel. +### Why explicit random state? -### A first example +Why does JAX require this somewhat verbose approach to random number generation? -To see the JIT compiler in action, consider the following function. +One reason is to maintain pure functions. -```{code-cell} ipython3 -def f(x): - a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5 - return jnp.sum(a) -``` +Let's see how random number generation relates to pure functions by comparing NumPy and JAX. -Let's build an array to call the function on. +#### NumPy's approach -```{code-cell} ipython3 -n = 50_000_000 -x = jnp.ones(n) -``` +In NumPy, random number generation works by maintaining hidden global state. -How long does the function take to execute? 
+Each time we call a random function, this state is updated: ```{code-cell} ipython3 -with qe.Timer(): - f(x).block_until_ready() +np.random.seed(42) +print(np.random.randn()) # Updates state of random number generator +print(np.random.randn()) # Updates state of random number generator ``` -```{note} -Here, in order to measure actual speed, we use the `block_until_ready()` method -to hold the interpreter until the results of the computation are returned. -This is necessary because JAX uses asynchronous dispatch, which -allows the Python interpreter to run ahead of numerical computations. +Each call returns a different value, even though we're calling the same function with the same inputs (no arguments). -``` +This function is *not pure* because: -If we run it a second time it becomes faster again: +* It's non-deterministic: same inputs (none, in this case) give different outputs +* It has side effects: it modifies the global random number generator state -```{code-cell} ipython3 -with qe.Timer(): - f(x).block_until_ready() -``` -This is because the built in functions like `jnp.cos` are JIT compiled and the -first run includes compile time. +#### JAX's approach -Why would JAX want to JIT-compile built in functions like `jnp.cos` instead of -just providing pre-compiled versions, like NumPy? +As we saw above, JAX takes a different approach, making randomness explicit through keys. -The reason is that the JIT compiler can specialize on the *size* of the array -being used, which is helpful for parallelization. +For example, -For example, in running the code above, the JIT compiler produced a version of `jnp.cos` that is -specialized to floating point arrays of size `n = 50_000_000`. +```{code-cell} ipython3 +def random_sum_jax(key): + key1, key2 = jax.random.split(key) + x = jax.random.normal(key1) + y = jax.random.normal(key2) + return x + y +``` -We can check this by calling `f` with a new array of different size. +With the same key, we always get the same result: ```{code-cell} ipython3 -m = 50_000_001 -y = jnp.ones(m) +key = jax.random.PRNGKey(42) +random_sum_jax(key) ``` ```{code-cell} ipython3 -with qe.Timer(): - f(y).block_until_ready() +random_sum_jax(key) ``` -Notice that the execution time increases, because now new versions of -the built-ins like `jnp.cos` are being compiled, specialized to the new array -size. +To get new draws we need to supply a new key. -If we run again, the code is dispatched to the correct compiled version and we -get faster execution. +The function `random_sum_jax` is pure because: -```{code-cell} ipython3 -with qe.Timer(): - f(y).block_until_ready() -``` +* It's deterministic: same key always produces same output +* No side effects: no hidden state is modified -The compiled versions for the previous array size are still available in memory -too, and the following call is dispatched to the correct compiled code. +The explicitness of JAX brings significant benefits: -```{code-cell} ipython3 -with qe.Timer(): - f(x).block_until_ready() -``` +* Reproducibility: Easy to reproduce results by reusing keys +* Parallelization: Each thread can have its own key without conflicts +* Debugging: No hidden state makes code easier to reason about +* JIT compatibility: The compiler can optimize pure functions more aggressively -### Compiling the outer function +The last point is expanded on in the next section. -We can do even better if we manually JIT-compile the outer function. 
-```{code-cell} ipython3 -f_jit = jax.jit(f) # target for JIT compilation +## JIT compilation + +The JAX just-in-time (JIT) compiler accelerates execution by generating +efficient machine code that varies with both task size and hardware. + +### A simple example + +Let's say we want to evaluate the cosine function at many points. + +```{code-cell} +n = 50_000_000 +x = np.linspace(0, 10, n) ``` -Let's run once to compile it: +#### With NumPy -```{code-cell} ipython3 -f_jit(x) +Let's try with NumPy + +```{code-cell} +with qe.Timer(): + y = np.cos(x) ``` -And now let's time it. +And one more time. -```{code-cell} ipython3 +```{code-cell} with qe.Timer(): - f_jit(x).block_until_ready() + y = np.cos(x) ``` -Note the speed gain. +Here NumPy uses a pre-built binary file, compiled from carefully written +low-level code, for applying cosine to an array of floats. -This is because the array operations are fused and no intermediate arrays are created. +This binary file ships with NumPy. +#### With JAX -Incidentally, a more common syntax when targetting a function for the JIT -compiler is +Now let's try with JAX. -```{code-cell} ipython3 -@jax.jit -def f(x): - a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5 - return jnp.sum(a) +```{code-cell} +x = jnp.linspace(0, 10, n) ``` -## Functional Programming +Let's time the same procedure. -From JAX's documentation: +```{code-cell} +with qe.Timer(): + y = jnp.cos(x) + jax.block_until_ready(y); +``` -*When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.* +```{note} +Here, in order to measure actual speed, we use the `block_until_ready` method +to hold the interpreter until the results of the computation are returned. -In other words, JAX assumes a functional programming style. +This is necessary because JAX uses asynchronous dispatch, which +allows the Python interpreter to run ahead of numerical computations. -The major implication is that JAX functions should be pure. - -A pure function will always return the same result if invoked with the same inputs. +For non-timed code, you can drop the line containing `block_until_ready`. +``` -In particular, a pure function has -* no dependence on global variables and -* no side effects +And let's time it again. -JAX will not usually throw errors when compiling impure functions but execution becomes unpredictable. +```{code-cell} +with qe.Timer(): + y = jnp.cos(x) + jax.block_until_ready(y); +``` -Here's an illustration of this fact, using global variables: +If you are running this on a GPU the code will run much faster than its NumPy +equivalent, which ran on the CPU. -```{code-cell} ipython3 -a = 1 # global +Even if you are running on a machine with many CPUs, the second JAX run should +be substantially faster with JAX. -@jax.jit -def f(x): - return a + x -``` +Also, typically, the second run is faster than the first. -```{code-cell} ipython3 -x = jnp.ones(2) -``` +(This might not be noticable on the CPU but it should definitely be noticable on +the GPU.) -```{code-cell} ipython3 -f(x) -``` +This is because even built in functions like `jnp.cos` are JIT-compiled --- and the +first run includes compile time. -In the code above, the global value `a=1` is fused into the jitted function. +Why would JAX want to JIT-compile built in functions like `jnp.cos` instead of +just providing pre-compiled versions, like NumPy? 
-Even if we change `a`, the output of `f` will not be affected --- as long as the same compiled version is called. +The reason is that the JIT compiler wants to specialize on the *size* of the array +being used (as well as the data type). -```{code-cell} ipython3 -a = 42 -``` +The size matters for generating optimized code because efficient parallelization +requires matching the size of the task to the available hardware. -```{code-cell} ipython3 -f(x) -``` +That's why JAX waits to see the size of the array before compiling --- which +requires a JIT-compiled approach instead of supplying precompiled binaries. -Changing the dimension of the input triggers a fresh compilation of the function, at which time the change in the value of `a` takes effect: -```{code-cell} ipython3 -x = jnp.ones(3) +#### Changing array sizes + +Here we change the input size and watch the runtimes. + +```{code-cell} +x = jnp.linspace(0, 10, n + 1) ``` -```{code-cell} ipython3 -f(x) +```{code-cell} +with qe.Timer(): + y = jnp.cos(x) + jax.block_until_ready(y); ``` -Moral of the story: write pure functions when using JAX! +```{code-cell} +with qe.Timer(): + y = jnp.cos(x) + jax.block_until_ready(y); +``` -## Gradients +Typically, the run time increases and then falls again (this will be more obvious on the GPU). -JAX can use automatic differentiation to compute gradients. +This is because the JIT compiler specializes on array size to exploit +parallelization --- and hence generates fresh compiled code when the array size +changes. -This can be extremely useful for optimization and solving nonlinear systems. -We will see significant applications later in this lecture series. +### Evaluating a more complicated function -For now, here's a very simple illustration involving the function +Let's try the same thing with a more complex function. -```{code-cell} ipython3 +```{code-cell} def f(x): - return (x**2) / 2 + y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2 + return y ``` -Let's take the derivative: +#### With NumPy -```{code-cell} ipython3 -f_prime = jax.grad(f) +We'll try first with NumPy + +```{code-cell} +n = 50_000_000 +x = np.linspace(0, 10, n) ``` -```{code-cell} ipython3 -f_prime(10.0) +```{code-cell} +with qe.Timer(): + y = f(x) ``` -Let's plot the function and derivative, noting that $f'(x) = x$. -```{code-cell} ipython3 -import matplotlib.pyplot as plt -fig, ax = plt.subplots() -x_grid = jnp.linspace(-4, 4, 200) -ax.plot(x_grid, f(x_grid), label="$f$") -ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") -ax.legend(loc='upper center') -plt.show() -``` +#### With JAX -We defer further exploration of automatic differentiation with JAX until {doc}`jax:autodiff`. +Now let's try again with JAX. -## Writing vectorized code +As a first pass, we replace `np` with `jnp` throughout: -Writing fast JAX code requires shifting repetitive tasks from loops to array processing operations, so that the JAX compiler can easily understand the whole operation and generate more efficient machine code. +```{code-cell} +def f(x): + y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2 + return y +``` -This procedure is called **vectorization** or **array programming**, and will be -familiar to anyone who has used NumPy or MATLAB. +Now let's time it. + +```{code-cell} +x = jnp.linspace(0, 10, n) +``` -In most ways, vectorization is the same in JAX as it is in NumPy. 
+```{code-cell} +with qe.Timer(): + y = f(x) + jax.block_until_ready(y); +``` -But there are also some differences, which we highlight here. +```{code-cell} +with qe.Timer(): + y = f(x) + jax.block_until_ready(y); +``` -As a running example, consider the function +The outcome is similar to the `cos` example --- JAX is faster, especially if you +use a GPU and especially on the second run. -$$ - f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2} -$$ +Moreover, with JAX, we have another trick up our sleeve: -Suppose that we want to evaluate this function on a square grid of $x$ and $y$ points and then plot it. -To clarify, here is the slow `for` loop version. +### Compiling the Whole Function -```{code-cell} ipython3 -@jax.jit -def f(x, y): - return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2) +The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing linear +algebra operations into a single optimized kernel. -n = 80 -x = jnp.linspace(-2, 2, n) -y = x +Let's try this with the function `f`: -z_loops = np.empty((n, n)) +```{code-cell} +f_jax = jax.jit(f) ``` -```{code-cell} ipython3 +```{code-cell} with qe.Timer(): - for i in range(n): - for j in range(n): - z_loops[i, j] = f(x[i], y[j]) + y = f_jax(x) + jax.block_until_ready(y); ``` -Even for this very small grid, the run time is extremely slow. +```{code-cell} +with qe.Timer(): + y = f_jax(x) + jax.block_until_ready(y); +``` -(Notice that we used a NumPy array for `z_loops` because we wanted to write to it.) +The runtime has improved again --- now because we fused all the operations, +allowing the compiler to optimize more aggressively. -+++ +For example, the compiler can eliminate multiple calls to the hardware +accelerator and the creation of a number of intermediate arrays. -OK, so how can we do the same operation in vectorized form? -If you are new to vectorization, you might guess that we can simply write +Incidentally, a more common syntax when targeting a function for the JIT +compiler is ```{code-cell} ipython3 -z_bad = f(x, y) +@jax.jit +def f(x): + pass # put function body here ``` -But this gives us the wrong result because JAX doesn't understand the nested for loop. +### Compiling non-pure functions -```{code-cell} ipython3 -z_bad.shape -``` +Now that we've seen how powerful JIT compilation can be, it's important to understand its relationship with pure functions. -Here is what we actually wanted: +While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable. + +Here's an illustration of this fact, using global variables: ```{code-cell} ipython3 -z_loops.shape -``` +a = 1 # global -To get the right shape and the correct nested for loop calculation, we can use a `meshgrid` operation designed for this purpose: +@jax.jit +def f(x): + return a + x +``` ```{code-cell} ipython3 -x_mesh, y_mesh = jnp.meshgrid(x, y) +x = jnp.ones(2) ``` -Now we get what we want and the execution time is very fast. - ```{code-cell} ipython3 -with qe.Timer(): - z_mesh = f(x_mesh, y_mesh).block_until_ready() +f(x) ``` -Let's run again to eliminate compile time. +In the code above, the global value `a=1` is fused into the jitted function. + +Even if we change `a`, the output of `f` will not be affected --- as long as the same compiled version is called. ```{code-cell} ipython3 -with qe.Timer(): - z_mesh = f(x_mesh, y_mesh).block_until_ready() +a = 42 ``` -Let's confirm that we got the right answer. 
- ```{code-cell} ipython3 -jnp.allclose(z_mesh, z_loops) +f(x) ``` -Now we can set up a serious grid and run the same calculation (on the larger grid) in a short amount of time. +Changing the dimension of the input triggers a fresh compilation of the function, at which time the change in the value of `a` takes effect: ```{code-cell} ipython3 -n = 6000 -x = jnp.linspace(-2, 2, n) -y = x -x_mesh, y_mesh = jnp.meshgrid(x, y) +x = jnp.ones(3) ``` ```{code-cell} ipython3 -with qe.Timer(): - z_mesh = f(x_mesh, y_mesh).block_until_ready() +f(x) ``` -But there is one problem here: the mesh grids use a lot of memory. +Moral of the story: write pure functions when using JAX! -```{code-cell} ipython3 -x_mesh.nbytes + y_mesh.nbytes -``` -By comparison, the flat array `x` is just +### Summary -```{code-cell} ipython3 -x.nbytes # and y is just a pointer to x -``` +Now we can see why both developers and compilers benefit from pure functions. -This extra memory usage can be a big problem in actual research calculations. +We love pure functions because they -So let's try a different approach using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) +* Help testing: each function can operate in isolation +* Promote deterministic behavior and hence reproducibility +* Prevent bugs that arise from mutating shared state -+++ +The compiler loves pure functions and functional programming because -First we vectorize `f` in `y`. +* Data dependencies are explicit, which helps with optimizing complex computations +* Pure functions are easier to differentiate (autodiff) +* Pure functions are easier to parallelize and optimize (don't depend on shared mutable state) -```{code-cell} ipython3 -f_vec_y = jax.vmap(f, in_axes=(None, 0)) -``` -In the line above, `(None, 0)` indicates that we are vectorizing in the second argument, which is `y`. +## Gradients -Next, we vectorize in the first argument, which is `x`. +JAX can use automatic differentiation to compute gradients. + +This can be extremely useful for optimization and solving nonlinear systems. + +We will see significant applications later in this lecture series. + +For now, here's a very simple illustration involving the function ```{code-cell} ipython3 -f_vec = jax.vmap(f_vec_y, in_axes=(0, None)) +def f(x): + return (x**2) / 2 ``` -With this construction, we can now call the function $f$ on flat (low memory) arrays. +Let's take the derivative: ```{code-cell} ipython3 -with qe.Timer(): - z_vmap = f_vec(x, y).block_until_ready() +f_prime = jax.grad(f) ``` -The execution time is essentially the same as the mesh operation but we are using much less memory. +```{code-cell} ipython3 +f_prime(10.0) +``` -And we produce the correct answer: +Let's plot the function and derivative, noting that $f'(x) = x$. ```{code-cell} ipython3 -jnp.allclose(z_vmap, z_mesh) +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +x_grid = jnp.linspace(-4, 4, 200) +ax.plot(x_grid, f(x_grid), label="$f$") +ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") +ax.legend(loc='upper center') +plt.show() ``` +We defer further exploration of automatic differentiation with JAX until {doc}`jax:autodiff`. + + ## Exercises @@ -673,7 +793,8 @@ jnp.allclose(z_vmap, z_mesh) :label: jax_intro_ex2 ``` -In the Exercise section of [a lecture on Numba and parallelization](https://python-programming.quantecon.org/parallelization.html), we used Monte Carlo to price a European call option. 
+In the Exercise section of {doc}`our lecture on Numba `, we used Monte +Carlo to price a European call option. The code was accelerated by Numba-based multithreading. @@ -717,7 +838,7 @@ def compute_call_price_jax(β=β, s = s + μ + jnp.exp(h) * Z[0, :] h = ρ * h + ν * Z[1, :] expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0)) - + return β**n * expectation ``` diff --git a/lectures/need_for_speed.md b/lectures/need_for_speed.md index dc9d7274..ef4f23ad 100644 --- a/lectures/need_for_speed.md +++ b/lectures/need_for_speed.md @@ -27,25 +27,26 @@ premature optimization is the root of all evil." -- Donald Knuth ## Overview -Python is popular for scientific computing due to factors such as +It's probably safe to say that Python is the most popular language for scientific computing. + +This is due to * the accessible and expressive nature of the language itself, * the huge range of high quality scientific libraries, * the fact that the language and libraries are open source, -* the popular [Anaconda Python distribution](https://www.anaconda.com/download), which simplifies installation and management of scientific libraries, and -* the key role that Python plays in data science, machine learning and artificial intelligence. +* the central role that Python plays in data science, machine learning and AI. -In previous lectures, we looked at some scientific Python libraries such as NumPy and Matplotlib. +In previous lectures, we used some scientific Python libraries, including NumPy and Matplotlib. However, our main focus was the core Python language, rather than the libraries. Now we turn to the scientific libraries and give them our full attention. -We'll also discuss the following topics: +In this introductory lecture, we'll discuss the following topics: -* What are the relative strengths and weaknesses of Python for scientific work? -* What are the main elements of the scientific Python ecosystem? -* How is the situation changing over time? +1. What are the main elements of the scientific Python ecosystem? +1. How do they fit together? +1. How is the situation changing over time? In addition to what's in Anaconda, this lecture will need @@ -56,27 +57,36 @@ tags: [hide-output] !pip install quantecon ``` +Let's start with some imports: + +```{code-cell} ipython +import numpy as np +import quantecon as qe +import matplotlib.pyplot as plt +import random +``` + +## Major Scientific Libraries -## Scientific Libraries +Let's briefly review Python's scientific libraries. -Let's briefly review Python's scientific libraries, starting with why we need them. -### The Role of Scientific Libraries +### Why do we need them? One reason we use scientific libraries is because they implement routines we want to use. * numerical integration, interpolation, linear algebra, root finding, etc. -For example, it's almost always better to use an existing routine for root finding than to write a new one from scratch. +For example, it's usually better to use an existing routine for root finding than to write a new one from scratch. (For standard algorithms, efficiency is maximized if the community can coordinate on a common set of implementations, written by experts and tuned by -users to be as fast and robust as possible.) +users to be as fast and robust as possible!) But this is not the only reason that we use Python's scientific libraries. -Another is that pure Python, while flexible and elegant, is not fast. +Another is that pure Python is not fast. 
So we need libraries that are designed to accelerate execution of Python code. @@ -85,7 +95,8 @@ They do this using two strategies: 1. using compilers that convert Python-like statements into fast machine code for individual threads of logic and 2. parallelizing tasks across multiple "workers" (e.g., CPUs, individual threads inside GPUs). -There are several Python libraries that can do this extremely well. +We will discuss these ideas extensively in this and the remaining lectures from +this series. ### Python's Scientific Ecosystem @@ -97,7 +108,7 @@ At QuantEcon, the scientific libraries we use most often are * [Matplotlib](https://matplotlib.org/) * [JAX](https://github.com/jax-ml/jax) * [Pandas](https://pandas.pydata.org/) -* [Numba](https://numba.pydata.org/) and +* [Numba](https://numba.pydata.org/) Here's how they fit together: @@ -112,53 +123,42 @@ Here's how they fit together: * Pandas provides types and functions for manipulating data. * Numba provides a just-in-time compiler that plays well with NumPy and helps accelerate Python code. +We will discuss all of these libraries extensively in this lecture series. -## The Need for Speed -Let's discuss execution speed and how scientific libraries can help us accelerate code. +## Pure Python is slow -Higher-level languages like Python are optimized for humans. - -This means that the programmer can leave many details to the runtime environment - -* specifying variable types -* memory allocation/deallocation, etc. +As mentioned above, one major attraction of the scientific libraries is greater execution speeds. -On one hand, compared to low-level languages, high-level languages are typically faster to write, less error-prone and easier to debug. +We will discuss how scientific libraries can help us accelerate code. -On the other hand, high-level languages are harder to optimize --- that is, to turn into fast machine code --- than languages like C or Fortran. +For this topic, it will be helpful if we understand what's driving slow execution speeds. -Indeed, the standard implementation of Python (called CPython) cannot match the speed of compiled languages such as C or Fortran. -Does that mean that we should just switch to C or Fortran for everything? +### High vs low level code -The answer is: No! +Higher-level languages like Python are optimized for humans. -There are three reasons why: - -First, for any given program, relatively few lines are ever going to be time-critical. - -Hence it is far more efficient to write most of our code in a high productivity language like Python. +This means that the programmer can leave many details to the runtime environment -Second, even for those lines of code that *are* time-critical, we can now achieve the same speed as C or Fortran using Python's scientific libraries. +* specifying variable types +* memory allocation/deallocation +* etc. -Third, in the last few years, accelerating code has become essentially -synonymous with parallelizing execution, and this task is best left to -specialized compilers. +In addition, pure Python is run by an [interpreter](https://en.wikipedia.org/wiki/Interpreter_(computing)), which executes code statement-by-statement. -Certain Python libraries have outstanding capabilities for parallelizing -scientific code -- we'll discuss this more as we go along. +This makes Python flexible, interactive, easy to write, easy to read, and relatively easy to debug. 
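
As a small illustration of the interpreter at work, the standard library's `dis` module displays the bytecode that CPython steps through for even a tiny function (just a sketch; the exact output varies across Python versions):

```{code-cell} python3
import dis

def add(a, b):
    return a + b

dis.dis(add)   # print the bytecode the interpreter executes for add
```
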
+On the other hand, the standard implementation of Python (called CPython) cannot +match the speed of compiled languages such as C or Fortran. -### Where are the Bottlenecks? -Before we do so, let's try to understand why plain vanilla Python is slower than C or Fortran. +### Where are the bottlenecks? -This will, in turn, help us figure out how to speed things up. +Why is this the case? -In reading the following, remember that the Python interpreter executes code line-by-line. -#### Dynamic Typing +#### Dynamic typing ```{index} single: Dynamic Typing ``` @@ -194,10 +194,13 @@ type of the objects on which it acts) As a result, when executing `a + b`, Python must first check the type of the objects and then call the correct operation. -This involves substantial overheads. +This involves a nontrivial overhead. + +If we repeatedly execute this expression in a tight loop, the nontrivial +overhead becomes a large overhead. -#### Static Types +#### Static types ```{index} single: Static Types ``` @@ -224,7 +227,13 @@ int main(void) { The variables `i` and `sum` are explicitly declared to be integers. -Hence, the meaning of addition here is completely unambiguous. +Moreover, when we make a statement such as `int i`, we are making a promise to the compiler +that `i` will *always* be an integer, throughout execution of the program. + +As such, the meaning of addition in the expression `sum + i` is completely unambiguous. + +There is no need for type-checking and hence no overhead. + ### Data Access @@ -241,7 +250,7 @@ Such an array is stored in a single contiguous block of memory * In modern computers, memory addresses are allocated to each byte (one byte = 8 bits). * For example, a 64 bit integer is stored in 8 bytes of memory. -* An array of $n$ such integers occupies $8n$ **consecutive** memory slots. +* An array of $n$ such integers occupies $8n$ *consecutive* memory slots. Moreover, the compiler is made aware of the data type by the programmer. @@ -266,21 +275,55 @@ This is a considerable drag on speed. In fact, it's generally true that memory traffic is a major culprit when it comes to slow execution. -Let's look at some ways around these problems. +### Summary + +Does the discussion above mean that we should just switch to C or Fortran for everything? + +The answer is: Definitely not! + +For any given program, relatively few lines are ever going to be time-critical. + +Hence it is far more efficient to write most of our code in a high productivity language like Python. + +Moreover, even for those lines of code that *are* time-critical, we can now +equal or outpace binaries compiled from C or Fortran by using Python's scientific libraries. + +On that note, we emphasize that, in the last few years, accelerating code has become essentially +synonymous with parallelization. + +This task is best left to specialized compilers! + +Certain Python libraries have outstanding capabilities for parallelizing scientific code -- we'll discuss this more as we go along. + + + + +## Accelerating Python + +In this section we look at three related techniques for accelerating Python +code. + +Here we'll focus on the fundamental ideas. +Later we'll look at specific libraries and how they implement these ideas. -## {index}`Vectorization ` + + +### {index}`Vectorization ` ```{index} single: Python; Vectorization ``` -One method for avoiding memory traffic and type checking is [array programming](https://en.wikipedia.org/wiki/Array_programming). 
+One method for avoiding memory traffic and type checking is [array +programming](https://en.wikipedia.org/wiki/Array_programming). -Economists usually refer to array programming as ``vectorization.'' +Many economists usually refer to array programming as "vectorization." -(In computer science, this term has [a slightly different meaning](https://en.wikipedia.org/wiki/Automatic_vectorization).) +```{note} +In computer science, this term has [a slightly different meaning](https://en.wikipedia.org/wiki/Automatic_vectorization). +``` The key idea is to send array processing operations in batch to pre-compiled and efficient native machine code. @@ -291,17 +334,64 @@ For example, when working in a high level language, the operation of inverting a large matrix can be subcontracted to efficient machine code that is pre-compiled for this purpose and supplied to users as part of a package. -This idea dates back to MATLAB, which uses vectorization extensively. +The core benefits are + +1. type-checking is paid *per array*, rather than per element, and +1. arrays containing elements with the same data type are efficient in terms of + memory access. + +The idea of vectorization dates back to MATLAB, which uses vectorization extensively. ```{figure} /_static/lecture_specific/need_for_speed/matlab.png ``` -Vectorization can greatly accelerate many numerical computations, as we will see -in later lectures. + + +### Vectorization vs for pure Python loops + +Let's try a quick speed comparison to illustrate how vectorization can +accelerate code. + +Here's some non-vectorized code, which uses a native Python loop to generate, +square and then sum a large number of random variables: + +```{code-cell} python3 +n = 1_000_000 +``` + +```{code-cell} python3 +with qe.Timer(): + y = 0 # Will accumulate and store sum + for i in range(n): + x = random.uniform(0, 1) + y += x**2 +``` + +The following vectorized code uses NumPy, which we'll soon investigate in depth, +to achieve the same thing. + +```{code-cell} ipython +with qe.Timer(): + x = np.random.uniform(0, 1, n) + y = np.sum(x**2) +``` + +As you can see, the second code block runs much faster. + +It breaks the loop down into three basic operations + +1. draw `n` uniforms +1. square them +1. sum them + +These are sent as batch operators to optimized machine code. + + + (numba-p_c_vectorization)= -## Beyond Vectorization +### JIT compilers At best, vectorization yields fast, simple code. @@ -315,8 +405,170 @@ producing the final calculation. Another issue is that not all algorithms can be vectorized. Because of these issues, most high performance computing is moving away from -traditional vectorization and towards the use of [just-in-time compilers](https://en.wikipedia.org/wiki/Just-in-time_compilation). +traditional vectorization and towards the use of [just-in-time +compilers](https://en.wikipedia.org/wiki/Just-in-time_compilation). + +In later lectures in this series, we will learn about how modern Python +libraries exploit just-in-time compilers to generate fast, efficient, +parallelized machine code. + + + + +## Parallelization + +The growth of CPU clock speed (i.e., the speed at which a single chain of logic +can be run) has slowed dramatically in recent years. + +Chip designers and computer programmers have responded to the slowdown by +seeking a different path to fast execution: parallelization. + +Hardware makers have increased the number of cores (physical CPUs) embedded in each machine. 
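
As an aside, you can check how many cores Python can see on your machine using the standard library (this reports logical cores, which can exceed the number of physical CPUs):

```{code-cell} python3
import os

os.cpu_count()   # number of logical CPU cores visible to Python
```
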
+ +For programmers, the challenge has been to exploit these multiple CPUs by +running many processes in parallel (i.e., simultaneously). + +This is particularly important in scientific programming, which requires handling + +* large amounts of data and +* CPU intensive simulations and other calculations. + +Below we discuss parallelization for scientific computing, with a focus on + +1. the best tools for parallelization in Python and +1. how these tools can be applied to quantitative economic problems. + + +### Parallelization on CPUs + +Let's review the two main kinds of CPU-based parallelization commonly used in +scientific computing and discuss their pros and cons. + + +#### Multiprocessing + +Multiprocessing means concurrent execution of multiple processes using more than one processor. + +In this context, a **process** is a chain of instructions (i.e., a program). + +Multiprocessing can be carried out on one machine with multiple CPUs or on a +collection of machines connected by a network. + +In the latter case, the collection of machines is usually called a +**cluster**. + +With multiprocessing, each process has its own memory space, although the +physical memory chip might be shared. + +#### Multithreading + +Multithreading is similar to multiprocessing, except that, during execution, the threads all share the same memory space. + +Native Python struggles to implement multithreading due to some [legacy design +features](https://wiki.python.org/moin/GlobalInterpreterLock). + +But this is not a restriction for scientific libraries like NumPy and Numba. + +Functions imported from these libraries and JIT-compiled code run in low level +execution environments where Python's legacy restrictions don't apply. + +#### Advantages and Disadvantages + +Multithreading is more lightweight because most system and memory resources +are shared by the threads. + +In addition, the fact that multiple threads all access a shared pool of memory +is extremely convenient for numerical programming. + +On the other hand, multiprocessing is more flexible and can be distributed +across clusters. + +For the great majority of what we do in these lectures, multithreading will +suffice. + + +### Hardware Accelerators + +While CPUs with multiple cores have become standard for parallel computing, a +more dramatic shift has occurred with the rise of specialized hardware +accelerators. + +These accelerators are designed specifically for the kinds of highly parallel +computations that arise in scientific computing, machine learning, and data +science. + +#### GPUs and TPUs + +The two most important types of hardware accelerators are + +* **GPUs** (Graphics Processing Units) and +* **TPUs** (Tensor Processing Units). + +GPUs were originally designed for rendering graphics, which requires performing +the same operation on many pixels simultaneously. + +```{figure} /_static/lecture_specific/need_for_speed/geforce.png +:scale: 40 +``` + +Scientists and engineers realized that this same architecture --- many simple +processors working in parallel --- is ideal for scientific computing tasks + +TPUs are a more recent development, designed by Google specifically for machine learning workloads. + +Like GPUs, TPUs excel at performing massive numbers of matrix operations in parallel. + + +#### Why TPUs/GPUs Matter + +The performance gains from using hardware accelerators can be dramatic. + +For example, a modern GPU can contain thousands of small processing cores, +compared to the 8-64 cores typically found in CPUs. 
+ +When a problem can be expressed as many independent operations on arrays of +data, GPUs can be orders of magnitude faster than CPUs. + +This is particularly relevant for scientific computing because many algorithms +naturally map onto the parallel architecture of GPUs. + + +### Single GPUs vs GPU Servers + +There are two common ways to access GPU resources: + +#### Single GPU Systems + +Many workstations and laptops now come with capable GPUs, or can be equipped with them. + +A single modern GPU can dramatically accelerate many scientific computing tasks. + +For individual researchers and small projects, a single GPU is often sufficient. + +Modern Python libraries like JAX, discussed extensively in this lecture series, +automatically detect and use available GPUs with minimal code changes. + + +#### Multi-GPU Servers + +For larger-scale problems, servers containing multiple GPUs (often 4-8 GPUs per server) are increasingly common. + +```{figure} /_static/lecture_specific/need_for_speed/dgx.png +:scale: 40 +``` + + +With appropriate software, computations can be distributed across multiple GPUs, either within a single server or across multiple servers. + +This enables researchers to tackle problems that would be infeasible on a single GPU or CPU. + + +### Summary + +GPU computing is becoming far more accessible, particularly from within Python. + +Some Python scientific libraries, like JAX, now support GPU acceleration with minimal changes to existing code. -In later lectures in this series, we will learn about how modern Python libraries exploit -just-in-time compilers to generate fast, efficient, parallelized machine code. +We will explore GPU computing in more detail in later lectures, applying it to a +range of economic applications. diff --git a/lectures/numba.md b/lectures/numba.md index d21cd745..b6a3cccb 100644 --- a/lectures/numba.md +++ b/lectures/numba.md @@ -67,9 +67,7 @@ The key idea is to compile functions to native machine code instructions on the When it succeeds, the compiled code is extremely fast. -Numba is specifically designed for numerical work and can also do other tricks such as [multithreading](https://en.wikipedia.org/wiki/Multithreading_%28computer_architecture%29). - -Numba will be a key part of our lectures --- especially those lectures involving dynamic programming. +Beyond speed gains from compilation, Numba is specifically designed for numerical work and can also do other tricks such as {ref}`multithreading`. This lecture introduces the main ideas. @@ -238,15 +236,13 @@ with qe.Timer(precision=4): Numba also provides several arguments for decorators to accelerate computation and cache functions -- see [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html). -In the [following lecture on parallelization](parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization. - ## Type Inference Successful type inference is a key part of JIT compilation. As you can imagine, inferring types is easier for simple Python objects (e.g., simple scalar data types such as floats and integers). -Numba also plays well with NumPy arrays. +Numba also plays well with NumPy arrays, which have well-defined types. In an ideal setting, Numba can infer all necessary type information. @@ -299,7 +295,7 @@ As mentioned above, at present Numba can only compile a subset of Python. However, that subset is ever expanding. -For example, Numba is now quite effective at compiling classes. 
+Notably, Numba is now quite effective at compiling classes. If a class is successfully compiled, then its methods act as JIT-compiled functions. @@ -405,97 +401,195 @@ ax.legend() plt.show() ``` -## Alternatives to Numba +## Dangers and Limitations + +Let's review the above and add some cautionary notes. + +### Limitations + +As we've seen, Numba needs to infer type information on +all variables to generate fast machine-level instructions. + +For simple routines, Numba infers types very well. + +For larger ones, or for routines using external libraries, it can easily fail. + +Hence, it's prudent when using Numba to focus on speeding up small, time-critical snippets of code. + +This will give you much better performance than blanketing your Python programs with `@njit` statements. + +### A Gotcha: Global Variables + +Here's another thing to be careful about when using Numba. + +Consider the following example + +```{code-cell} ipython3 +a = 1 + +@jit +def add_a(x): + return a + x + +print(add_a(10)) +``` + +```{code-cell} ipython3 +a = 2 -```{index} single: Python; Cython +print(add_a(10)) ``` -There are additional options for accelerating Python loops. +Notice that changing the global had no effect on the value returned by the +function. -Here we quickly review them. +When Numba compiles machine code for functions, it treats global variables as constants to ensure type stability. -However, we do so only for interest and completeness. +(multithreading)= +## Multithreaded Loops in Numba -If you prefer, you can safely skip this section. +In addition to JIT compilation, Numba provides powerful support for parallel computing on CPUs. -### Cython +By distributing computations across multiple CPU cores, we can achieve significant speed gains for many numerical algorithms. -Like {doc}`Numba `, [Cython](https://cython.org/) provides an approach to generating fast compiled code that can be used from Python. +The key tool for parallelization in Numba is the `prange` function, which tells Numba to execute loop iterations in parallel across available CPU cores. -As was the case with Numba, a key problem is the fact that Python is dynamically typed. +This approach to multithreading works well for a wide range of problems in scientific computing and quantitative economics. -As you'll recall, Numba solves this problem (where possible) by inferring type. +To illustrate, let's look first at a simple, single-threaded (i.e., non-parallelized) piece of code. -Cython's approach is different --- programmers add type definitions directly to their "Python" code. +The code simulates updating the wealth $w_t$ of a household via the rule -As such, the Cython language can be thought of as Python with type definitions. +$$ +w_{t+1} = R_{t+1} s w_t + y_{t+1} +$$ + +Here + +* $R$ is the gross rate of return on assets +* $s$ is the savings rate of the household and +* $y$ is labor income. -In addition to a language specification, Cython is also a language translator, transforming Cython code into optimized C and C++ code. +We model both $R$ and $y$ as independent draws from a lognormal +distribution. -Cython also takes care of building language extensions --- the wrapper code that interfaces between the resulting compiled code and Python. +Here's the code: + +```{code-cell} ipython3 +from numpy.random import randn +from numba import njit -While Cython has certain advantages, we generally find it both slower and more -cumbersome than Numba. +@njit +def h(w, r=0.1, s=0.3, v1=0.1, v2=1.0): + """ + Updates household wealth. 
+ """ -### Interfacing with Fortran via F2Py + # Draw shocks + R = np.exp(v1 * randn()) * (1 + r) + y = np.exp(v2 * randn()) -```{index} single: Python; Interfacing with Fortran + # Update wealth + w = R * s * w + y + return w ``` -If you are comfortable writing Fortran you will find it very easy to create -extension modules from Fortran code using [F2Py](https://numpy.org/doc/stable/f2py/). +Let's have a look at how wealth evolves under this rule. -F2Py is a Fortran-to-Python interface generator that is particularly simple to -use. +```{code-cell} ipython3 +fig, ax = plt.subplots() -Robert Johansson provides a [nice introduction](https://nbviewer.org/github/jrjohansson/scientific-python-lectures/blob/master/Lecture-6A-Fortran-and-C.ipynb) -to F2Py, among other things. +T = 100 +w = np.empty(T) +w[0] = 5 +for t in range(T-1): + w[t+1] = h(w[t]) -Recently, [a Jupyter cell magic for Fortran](https://nbviewer.org/github/mgaitan/fortran_magic/blob/master/documentation.ipynb) has been developed --- you might want to give it a try. +ax.plot(w) +ax.set_xlabel('$t$', fontsize=12) +ax.set_ylabel('$w_{t}$', fontsize=12) +plt.show() +``` -## Summary and Comments +Now let's suppose that we have a large population of households and we want to +know what median wealth will be. -Let's review the above and add some cautionary notes. +This is not easy to solve with pencil and paper, so we will use simulation +instead. -### Limitations +In particular, we will simulate a large number of households and then +calculate median wealth for this group. -As we've seen, Numba needs to infer type information on -all variables to generate fast machine-level instructions. +Suppose we are interested in the long-run average of this median over time. -For simple routines, Numba infers types very well. +It turns out that, for the specification that we've chosen above, we can +calculate this by taking a one-period snapshot of what has happened to median +wealth of the group at the end of a long simulation. -For larger ones, or for routines using external libraries, it can easily fail. +Moreover, provided the simulation period is long enough, initial conditions +don't matter. -Hence, it's prudent when using Numba to focus on speeding up small, time-critical snippets of code. +* This is due to something called ergodicity, which we will discuss [later on](https://python.quantecon.org/finite_markov.html#id15). -This will give you much better performance than blanketing your Python programs with `@njit` statements. +So, in summary, we are going to simulate 50,000 households by -### A Gotcha: Global Variables +1. arbitrarily setting initial wealth to 1 and +1. simulating forward in time for 1,000 periods. -Here's another thing to be careful about when using Numba. +Then we'll calculate median wealth at the end period. -Consider the following example +Here's the code: ```{code-cell} ipython3 -a = 1 +@njit +def compute_long_run_median(w0=1, T=1000, num_reps=50_000): -@jit -def add_a(x): - return a + x + obs = np.empty(num_reps) + for i in range(num_reps): + w = w0 + for t in range(T): + w = h(w) + obs[i] = w -print(add_a(10)) + return np.median(obs) ``` +Let's see how fast this runs: + ```{code-cell} ipython3 -a = 2 +with qe.Timer(): + compute_long_run_median() +``` -print(add_a(10)) +To speed this up, we're going to parallelize it via multithreading. 
+ +To do so, we add the `parallel=True` flag and change `range` to `prange`: + +```{code-cell} ipython3 +from numba import prange + +@njit(parallel=True) +def compute_long_run_median_parallel(w0=1, T=1000, num_reps=50_000): + + obs = np.empty(num_reps) + for i in prange(num_reps): + w = w0 + for t in range(T): + w = h(w) + obs[i] = w + + return np.median(obs) ``` -Notice that changing the global had no effect on the value returned by the -function. +Let's look at the timing: + +```{code-cell} ipython3 +with qe.Timer(): + compute_long_run_median_parallel() +``` + +The speed-up is significant. -When Numba compiles machine code for functions, it treats global variables as constants to ensure type stability. ## Exercises @@ -670,3 +764,189 @@ This is a nice speed improvement for one line of code! ```{solution-end} ``` + +```{exercise} +:label: numba_ex3 + +In {ref}`an earlier exercise `, we used Numba to accelerate an +effort to compute the constant $\pi$ by Monte Carlo. + +Now try adding parallelization and see if you get further speed gains. + +You should not expect huge gains here because, while there are many +independent tasks (draw point and test if in circle), each one has low +execution time. + +Generally speaking, parallelization is less effective when the individual +tasks to be parallelized are very small relative to total execution time. + +This is due to overheads associated with spreading all of these small tasks across multiple CPUs. + +Nevertheless, with suitable hardware, it is possible to get nontrivial speed gains in this exercise. + +For the size of the Monte Carlo simulation, use something substantial, such as +`n = 100_000_000`. +``` + +```{solution-start} numba_ex3 +:class: dropdown +``` + +Here is one solution: + +```{code-cell} ipython3 +from random import uniform + +@njit(parallel=True) +def calculate_pi(n=1_000_000): + count = 0 + for i in prange(n): + u, v = uniform(0, 1), uniform(0, 1) + d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2) + if d < 0.5: + count += 1 + + area_estimate = count / n + return area_estimate * 4 # dividing by radius**2 +``` + +Now let's see how fast it runs: + +```{code-cell} ipython3 +with qe.Timer(): + calculate_pi() +``` + +```{code-cell} ipython3 +with qe.Timer(): + calculate_pi() +``` + +By switching parallelization on and off (selecting `True` or +`False` in the `@njit` annotation), we can test the speed gain that +multithreading provides on top of JIT compilation. + +On our workstation, we find that parallelization increases execution speed by +a factor of 2 or 3. + +(If you are executing locally, you will get different numbers, depending mainly +on the number of CPUs on your machine.) + +```{solution-end} +``` + + +```{exercise} +:label: numba_ex4 + +In {doc}`our lecture on SciPy`, we discussed pricing a call option in a +setting where the underlying stock price had a simple and well-known +distribution. + +Here we discuss a more realistic setting. + +We recall that the price of the option obeys + +$$ +P = \beta^n \mathbb E \max\{ S_n - K, 0 \} +$$ + +where + +1. $\beta$ is a discount factor, +2. $n$ is the expiry date, +2. $K$ is the strike price and +3. $\{S_t\}$ is the price of the underlying asset at each time $t$. + +Suppose that `n, β, K = 20, 0.99, 100`. + +Assume that the stock price obeys + +$$ +\ln \frac{S_{t+1}}{S_t} = \mu + \sigma_t \xi_{t+1} +$$ + +where + +$$ + \sigma_t = \exp(h_t), + \quad + h_{t+1} = \rho h_t + \nu \eta_{t+1} +$$ + +Here $\{\xi_t\}$ and $\{\eta_t\}$ are IID and standard normal. 
+ +(This is a **stochastic volatility** model, where the volatility $\sigma_t$ +varies over time.) + +Use the defaults `μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0`. + +(Here `S0` is $S_0$ and `h0` is $h_0$.) + +By generating $M$ paths $s_0, \ldots, s_n$, compute the Monte Carlo estimate + +$$ + \hat P_M + := \beta^n \mathbb E \max\{ S_n - K, 0 \} + \approx + \frac{1}{M} \sum_{m=1}^M \max \{S_n^m - K, 0 \} +$$ + + +of the price, applying Numba and parallelization. + +``` + + +```{solution-start} numba_ex4 +:class: dropdown +``` + + +With $s_t := \ln S_t$, the price dynamics become + +$$ +s_{t+1} = s_t + \mu + \exp(h_t) \xi_{t+1} +$$ + +Using this fact, the solution can be written as follows. + + +```{code-cell} ipython3 +from numpy.random import randn +M = 10_000_000 + +n, β, K = 20, 0.99, 100 +μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0 + +@njit(parallel=True) +def compute_call_price_parallel(β=β, + μ=μ, + S0=S0, + h0=h0, + K=K, + n=n, + ρ=ρ, + ν=ν, + M=M): + current_sum = 0.0 + # For each sample path + for m in prange(M): + s = np.log(S0) + h = h0 + # Simulate forward in time + for t in range(n): + s = s + μ + np.exp(h) * randn() + h = ρ * h + ν * randn() + # And add the value max{S_n - K, 0} to current_sum + current_sum += np.maximum(np.exp(s) - K, 0) + + return β**n * current_sum / M +``` + +Try swapping between `parallel=True` and `parallel=False` and noting the run time. + +If you are on a machine with many CPUs, the difference should be significant. + +```{solution-end} +``` diff --git a/lectures/numpy.md b/lectures/numpy.md index d707ead5..ecb6a294 100644 --- a/lectures/numpy.md +++ b/lectures/numpy.md @@ -78,6 +78,8 @@ called a [numpy.ndarray](https://numpy.org/doc/stable/reference/arrays.ndarray.h NumPy arrays power a very large proportion of the scientific Python ecosystem. +### Basics + To create a NumPy array containing only zeros we use [np.zeros](https://numpy.org/doc/stable/reference/generated/numpy.zeros.html#numpy.zeros) ```{code-cell} python3 @@ -128,32 +130,33 @@ Consider the following assignment z = np.zeros(10) ``` -Here `z` is a *flat* array with no dimension --- neither row nor column vector. - -The dimension is recorded in the `shape` attribute, which is a tuple +Here `z` is a **flat** array --- neither row nor column vector. ```{code-cell} python3 z.shape ``` -Here the shape tuple has only one element, which is the length of the array (tuples with one element end with a comma). +Here the shape tuple has only one element, which is the length of the array +(tuples with one element end with a comma). -To give it dimension, we can change the `shape` attribute +To give it an additional dimension, we can change the `shape` attribute ```{code-cell} python3 -z.shape = (10, 1) +z.shape = (10, 1) # Convert flat array to column vector (two-dimensional) z ``` ```{code-cell} python3 -z = np.zeros(4) -z.shape = (2, 2) +z = np.zeros(4) # Flat array +z.shape = (2, 2) # Two-dimensional array z ``` -In the last case, to make the 2 by 2 array, we could also pass a tuple to the `zeros()` function, as +In the last case, to make the 2x2 array, we could also pass a tuple to the `zeros()` function, as in `z = np.zeros((2, 2))`. + + (creating_arrays)= ### Creating Arrays @@ -212,17 +215,9 @@ z See also `np.asarray`, which performs a similar function, but does not make a distinct copy of data already in a NumPy array. 
-```{code-cell} python3 -na = np.linspace(10, 20, 2) -na is np.asarray(na) # Does not copy NumPy arrays -``` +To read in the array data from a text file containing numeric data use `np.loadtxt` ---see [the documentation](https://numpy.org/doc/stable/reference/routines.io.html) for details. -```{code-cell} python3 -na is np.array(na) # Does make a new copy --- perhaps unnecessarily -``` -To read in the array data from a text file containing numeric data use `np.loadtxt` -or `np.genfromtxt`---see [the documentation](https://numpy.org/doc/stable/reference/routines.io.html) for details. ### Array Indexing @@ -265,8 +260,6 @@ z[0, 1] And so on. -Note that indices are still zero-based, to maintain compatibility with Python sequences. - Columns and rows can be extracted as follows ```{code-cell} python3 @@ -374,7 +367,8 @@ a.T # Equivalent to a.transpose() Another method worth knowing is `searchsorted()`. -If `z` is a nondecreasing array, then `z.searchsorted(a)` returns the index of the first element of `z` that is `>= a` +If `z` is a nondecreasing array, then `z.searchsorted(a)` returns the index of +the first element of `z` that is `>= a` ```{code-cell} python3 z = np.linspace(2, 4, 5) @@ -385,20 +379,6 @@ z z.searchsorted(2.2) ``` -Many of the methods discussed above have equivalent functions in the NumPy namespace - -```{code-cell} python3 -a = np.array((4, 3, 2, 1)) -``` - -```{code-cell} python3 -np.sum(a) -``` - -```{code-cell} python3 -np.mean(a) -``` - ## Arithmetic Operations @@ -457,8 +437,7 @@ In particular, `A * B` is *not* the matrix product, it is an element-wise produc ```{index} single: NumPy; Matrix Multiplication ``` -With Anaconda's scientific Python package based around Python 3.5 and above, -one can use the `@` symbol for matrix multiplication, as follows: +We use the `@` symbol for matrix multiplication, as follows: ```{code-cell} python3 A = np.ones((2, 2)) @@ -466,22 +445,8 @@ B = np.ones((2, 2)) A @ B ``` -(For older versions of Python and NumPy you need to use the [np.dot](https://numpy.org/doc/stable/reference/generated/numpy.dot.html) function) - -We can also use `@` to take the inner product of two flat arrays - -```{code-cell} python3 -A = np.array((1, 2)) -B = np.array((10, 20)) -A @ B -``` - -In fact, we can use `@` when one element is a Python list or tuple - -```{code-cell} python3 -A = np.array(((1, 2), (3, 4))) -A -``` +The syntax works with flat arrays --- NumPy makes an educated guess of what you +want: ```{code-cell} python3 A @ (0, 1) @@ -489,6 +454,8 @@ A @ (0, 1) Since we are post-multiplying, the tuple is treated as a column vector. + + (broadcasting)= ## Broadcasting @@ -713,19 +680,6 @@ ax.text(10.5, 7.0, '=', size=12, ha='center', va='center'); ``` -The previous broadcasting operation is equivalent to the following `for` loop - -```{code-cell} python3 - -row, column = a.shape -result = np.empty((3, 3)) -for i in range(row): - for j in range(column): - result[i, j] = a[i, j] + b[i,0] - -result -``` - In some cases, both operands will be expanded. When we have `a -> (3,)` and `b -> (3, 1)`, `a` will be expanded to `a -> (3, 3)`, and `b` will be expanded to `b -> (3, 3)`. @@ -872,71 +826,12 @@ To help us, we can use the following list of rules: - If `a -> (2, 2, 2)` and `b -> (1, 2, 2)`, then broadcasting will expand the first dimension of `b` so that `b -> (2, 2, 2)`; - If `a -> (3, 2, 2)` and `b -> (1, 1, 2)`, then broadcasting will expand `b` on all dimensions with shape 1 so that `b -> (3, 2, 2)`. 
-Here are code examples for broadcasting higher dimensional arrays - -```{code-cell} python3 -# a -> (2, 2, 2) and b -> (1, 2, 2) - -a = np.array( - [[[1, 2], - [2, 3]], - - [[2, 3], - [3, 4]]]) -print(f'the shape of array a is {a.shape}') - -b = np.array( - [[1,7], - [7,1]]) -print(f'the shape of array b is {b.shape}') - -a + b -``` - -```{code-cell} python3 -# a -> (3, 2, 2) and b -> (2,) - -a = np.array( - [[[1, 2], - [3, 4]], - - [[4, 5], - [6, 7]], - - [[7, 8], - [9, 10]]]) -print(f'the shape of array a is {a.shape}') - -b = np.array([3, 6]) -print(f'the shape of array b is {b.shape}') - -a + b -``` - * *Step 3:* After Step 1 and 2, if the two arrays still do not match, a `ValueError` will be raised. For example, suppose `a -> (2, 2, 3)` and `b -> (2, 2)` - By *Step 1*, `b` will be expanded to `b -> (1, 2, 2)`; - By *Step 2*, `b` will be expanded to `b -> (2, 2, 2)`; - We can see that they do not match each other after the first two steps. Thus, a `ValueError` will be raised -```{code-cell} python3 ---- -tags: [raises-exception] ---- -a = np.array( - [[[1, 2, 3], - [2, 3, 4]], - - [[2, 3, 4], - [3, 4, 5]]]) -print(f'the shape of array a is {a.shape}') - -b = np.array( - [[1,7], - [7,1]]) -print(f'the shape of array b is {b.shape}') -a + b -``` ## Mutability and Copying Arrays @@ -944,9 +839,17 @@ NumPy arrays are mutable data types, like Python lists. In other words, their contents can be altered (mutated) in memory after initialization. -We already saw examples above. +This is convenient but, when combined with Python's naming and reference model, +can lead to mistakes by NumPy beginners. + +In this section we review some key issues. -Here's another example: + +### Mutability + +We already saw examples of multability above. + +Here's another example of mutation of a NumPy array ```{code-cell} python3 a = np.array([42, 44]) @@ -1013,11 +916,15 @@ a Note that the change to `b` has not affected `a`. -## Additional Functionality -Let's look at some other useful things we can do with NumPy. -### Vectorized Functions + +## Additional Features + +Let's look at some other useful features of NumPy. + + +### Universal Functions ```{index} single: NumPy; Vectorized Functions ``` @@ -1038,9 +945,9 @@ for i in range(n): y[i] = np.sin(z[i]) ``` -Because they act element-wise on arrays, these functions are called *vectorized functions*. +Because they act element-wise on arrays, these functions are sometimes called **vectorized functions**. -In NumPy-speak, they are also called *ufuncs*, which stands for "universal functions". +In NumPy-speak, they are also called **ufuncs**, or **universal functions**. As we saw above, the usual arithmetic operations (`+`, `*`, etc.) also work element-wise, and combining these with the ufuncs gives a very large set of fast element-wise functions. @@ -1082,6 +989,9 @@ f(x) # Passing the same vector x as in the previous example However, this approach doesn't always obtain the same speed as a more carefully crafted vectorized function. +(Later we'll see that JAX has a powerful version of `np.vectorize` that can and usually does generate highly efficient code.) + + ### Comparisons ```{index} single: NumPy; Comparisons @@ -1172,140 +1082,41 @@ We'll cover the SciPy versions in more detail {doc}`soon `. For a comprehensive list of what's available in NumPy see [this documentation](https://numpy.org/doc/stable/reference/routines.html). 
-## Speed Comparisons +### Implicit Multithreading -```{index} single: Vectorization; Operations on Arrays -``` +[Previously](need_for_speed) we discussed the concept of parallelization via multithreading. -We mentioned in an {doc}`previous lecture ` that NumPy-based vectorization can -accelerate scientific applications. +NumPy tries to implement multithreading in much of its compiled code. -In this section we try some speed comparisons to illustrate this fact. +Let's look at an example to see this in action. -### Vectorization vs Loops +The next piece of code computes the eigenvalues of a large number of randomly +generated matrices. -Let's begin with some non-vectorized code, which uses a native Python loop to generate, -square and then sum a large number of random variables: - -```{code-cell} python3 -n = 1_000_000 -``` +It takes a few seconds to run. ```{code-cell} python3 -with qe.Timer(): - y = 0 # Will accumulate and store sum - for i in range(n): - x = random.uniform(0, 1) - y += x**2 -``` - -The following vectorized code achieves the same thing. - -```{code-cell} ipython -with qe.Timer(): - x = np.random.uniform(0, 1, n) - y = np.sum(x**2) -``` - -As you can see, the second code block runs much faster. Why? - -The second code block breaks the loop down into three basic operations - -1. draw `n` uniforms -1. square them -1. sum them - -These are sent as batch operators to optimized machine code. - -Apart from minor overheads associated with sending data back and forth, the result is C or Fortran-like speed. - -When we run batch operations on arrays like this, we say that the code is *vectorized*. - -The next section illustrates this point. - -(ufuncs)= -### Universal Functions - -```{index} single: NumPy; Universal Functions -``` - -As discussed above, many functions provided by NumPy are universal functions (ufuncs). - -By exploiting ufuncs, many operations can be vectorized, leading to faster -execution. - -For example, consider the problem of maximizing a function $f$ of two -variables $(x,y)$ over the square $[-a, a] \times [-a, a]$. - -For $f$ and $a$ let's choose - -$$ -f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2} -\quad \text{and} \quad -a = 3 -$$ - -Here's a plot of $f$ - -```{code-cell} ipython - -def f(x, y): - return np.cos(x**2 + y**2) / (1 + x**2 + y**2) - -xgrid = np.linspace(-3, 3, 50) -ygrid = xgrid -x, y = np.meshgrid(xgrid, ygrid) - -fig = plt.figure(figsize=(10, 8)) -ax = fig.add_subplot(111, projection='3d') -ax.plot_surface(x, - y, - f(x, y), - rstride=2, cstride=2, - cmap=cm.jet, - alpha=0.7, - linewidth=0.25) -ax.set_zlim(-0.5, 1.0) -ax.set_xlabel('$x$', fontsize=14) -ax.set_ylabel('$y$', fontsize=14) -plt.show() +n = 20 +m = 1000 +for i in range(n): + X = np.random.randn(m, m) + λ = np.linalg.eigvals(X) ``` -To maximize it, we're going to use a naive grid search: - -1. Evaluate $f$ for all $(x,y)$ in a grid on the square. -1. Return the maximum of observed values. - -The grid will be +Now, let's look at the output of the htop system monitor on our machine while +this code is running: -```{code-cell} python3 -grid = np.linspace(-3, 3, 1000) +```{figure} /_static/lecture_specific/parallelization/htop_parallel_npmat.png +:scale: 80 ``` -Here's a non-vectorized version that uses Python loops. - -```{code-cell} python3 -with qe.Timer(): - m = -np.inf +We can see that 4 of the 8 CPUs are running at full speed. 
- for x in grid: - for y in grid: - z = f(x, y) - if z > m: - m = z -``` +This is because NumPy's `eigvals` routine neatly splits up the tasks and +distributes them to different threads. -And here's a vectorized version - -```{code-cell} python3 -with qe.Timer(): - x, y = np.meshgrid(grid, grid) - np.max(f(x, y)) -``` -In the vectorized version, all the looping takes place in compiled code. -As you can see, the second version is *much* faster. ## Exercises diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md new file mode 100644 index 00000000..55de222b --- /dev/null +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -0,0 +1,525 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +(parallel)= +```{raw} jupyter + +``` + +# NumPy vs Numba vs JAX + +In the preceding lectures, we've discussed three core libraries for scientific +and numerical computing: + +* [NumPy](numpy) +* [Numba](numba) +* [JAX](jax_intro) + +Which one should we use in any given situation? + +This lecture addresses that question, at least partially, by discussing some use cases. + +Before getting started, we note that the first two are a natural pair: NumPy and +Numba play well together. + +JAX, on the other hand, stands alone. + +When considering each approach, we will consider not just efficiency and memory +footprint but also clarity and ease of use. + +In addition to what's in Anaconda, this lecture will need the following libraries: + +```{code-cell} ipython3 +--- +tags: [hide-output] +--- +!pip install quantecon jax +``` + +We will use the following imports. + +```{code-cell} ipython3 +import random +import numpy as np +import quantecon as qe +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d.axes3d import Axes3D +from matplotlib import cm +import jax +import jax.numpy as jnp +``` + +## Vectorized operations + +Some operations can be perfectly vectorized --- all loops are easily eliminated +and numerical operations are reduced to calculations on arrays. + +In this case, which approach is best? + +### Problem Statement + +Consider the problem of maximizing a function $f$ of two variables $(x,y)$ over +the square $[-a, a] \times [-a, a]$. + +For $f$ and $a$ let's choose + +$$ +f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2} +\quad \text{and} \quad +a = 3 +$$ + +Here's a plot of $f$ + +```{code-cell} ipython3 + +def f(x, y): + return np.cos(x**2 + y**2) / (1 + x**2 + y**2) + +xgrid = np.linspace(-3, 3, 50) +ygrid = xgrid +x, y = np.meshgrid(xgrid, ygrid) + +fig = plt.figure(figsize=(10, 8)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(x, + y, + f(x, y), + rstride=2, cstride=2, + cmap=cm.jet, + alpha=0.7, + linewidth=0.25) +ax.set_zlim(-0.5, 1.0) +ax.set_xlabel('$x$', fontsize=14) +ax.set_ylabel('$y$', fontsize=14) +plt.show() +``` + +For the sake of this exercise, we're going to use brute force for the +maximization. + +1. Evaluate $f$ for all $(x,y)$ in a grid on the square. +1. Return the maximum of observed values. + +Just to illustrate the idea, here's a non-vectorized version that uses Python loops. + +```{code-cell} ipython3 +grid = np.linspace(-3, 3, 50) +m = -np.inf +for x in grid: + for y in grid: + z = f(x, y) + if z > m: + m = z +``` + + +### NumPy vectorization + +If we switch to NumPy-style vectorization we can use a much larger grid and the +code executes relatively quickly. 
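+
+The workhorse for this is `np.meshgrid`.
+
+As a toy illustration of what it produces (our own small example with a
+3-point grid, rather than the full grid used below):
+
+```{code-cell} ipython3
+small_grid = np.linspace(-3, 3, 3)
+xs, ys = np.meshgrid(small_grid, small_grid)
+# Each output is 3x3; taken together, (xs[i, j], ys[i, j]) runs over
+# every (x, y) pair on the product grid.
+xs, ys
+```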
+ +Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` such +that `f(x, y)` generates all evaluations on the product grid. + +(This strategy dates back to Matlab.) + +```{code-cell} ipython3 +grid = np.linspace(-3, 3, 3_000) +x, y = np.meshgrid(grid, grid) + +with qe.Timer(precision=8): + np.max(f(x, y)) +``` + +In the vectorized version, all the looping takes place in compiled code. + +Moreover, NumPy uses implicit multithreading, so that at least some parallelization occurs. + +```{note} +If you have a system monitor such as htop (Linux/Mac) or perfmon +(Windows), then try running this and then observing the load on your CPUs. + +(You will probably need to bump up the grid size to see large effects.) + +The output typically shows that the operation is successfully distributed across multiple threads. +``` + +(The parallelization cannot be highly efficient because the binary is compiled +before it sees the size of the arrays `x` and `y`.) + + +### A Comparison with Numba + +Now let's see if we can achieve better performance using Numba with a simple loop. + +```{code-cell} ipython3 +import numba + +@numba.jit +def compute_max_numba(grid): + m = -np.inf + for x in grid: + for y in grid: + z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) + if z > m: + m = z + return m + +grid = np.linspace(-3, 3, 3_000) + +with qe.Timer(precision=8): + compute_max_numba(grid) +``` + +```{code-cell} ipython3 +with qe.Timer(precision=8): + compute_max_numba(grid) +``` + + +Depending on your machine, the Numba version can be a bit slower or a bit faster than NumPy. + +On one hand, NumPy combines efficient arithmetic (like Numba) with some multithreading (unlike this Numba code), which provides an advantage. + +On the other hand, the Numba routine uses much less memory, since we are only +working with a single one-dimensional grid. + + +### Parallelized Numba + +Now let's try parallelization with Numba using `prange`: + +First we parallelize just the outer loop. + +```{code-cell} ipython3 +@numba.jit(parallel=True) +def compute_max_numba_parallel(grid): + n = len(grid) + m = -np.inf + for i in numba.prange(n): + for j in range(n): + x = grid[i] + y = grid[j] + z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) + if z > m: + m = z + return m + +with qe.Timer(precision=8): + compute_max_numba_parallel(grid) +``` + + +```{code-cell} ipython3 +with qe.Timer(precision=8): + compute_max_numba_parallel(grid) +``` + +Next we parallelize both loops. + +```{code-cell} ipython3 +@numba.jit(parallel=True) +def compute_max_numba_parallel_nested(grid): + n = len(grid) + m = -np.inf + for i in numba.prange(n): + for j in numba.prange(n): + x = grid[i] + y = grid[j] + z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) + if z > m: + m = z + return m + +with qe.Timer(precision=8): + compute_max_numba_parallel_nested(grid) +``` + +```{code-cell} ipython3 +with qe.Timer(precision=8): + compute_max_numba_parallel_nested(grid) +``` + + +Depending on your machine, you might or might not see large benefits from parallelization here. + +If you have a small number of cores, the overhead of thread management and synchronization can +overwhelm the benefits of parallel execution. + +For more powerful machines and larger grid sizes, parallelization can generate +large speed gains. + + + +### Vectorized code with JAX + +In most ways, vectorization is the same in JAX as it is in NumPy. + +But there are also some differences, which we highlight here. + +Let's start with the function. 
+ + +```{code-cell} ipython3 +@jax.jit +def f(x, y): + return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2) + +``` + +As with NumPy, to get the right shape and the correct nested `for` loop +calculation, we can use a `meshgrid` operation designed for this purpose: + +```{code-cell} ipython3 +grid = jnp.linspace(-3, 3, 3_000) +x_mesh, y_mesh = np.meshgrid(grid, grid) + +with qe.Timer(precision=8): + z_mesh = f(x_mesh, y_mesh).block_until_ready() +``` + +Let's run again to eliminate compile time. + +```{code-cell} ipython3 +with qe.Timer(precision=8): + z_mesh = f(x_mesh, y_mesh).block_until_ready() +``` + +Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU. + +The compilation overhead is a one-time cost that pays off when the function is called repeatedly. + + +### JAX plus vmap + +There is one problem with both the NumPy code and the JAX code: + +While the flat arrays are low-memory + +```{code-cell} ipython3 +grid.nbytes +``` + +the mesh grids are memory intensive + +```{code-cell} ipython3 +x_mesh.nbytes + y_mesh.nbytes +``` + +This extra memory usage can be a big problem in actual research calculations. + +Fortunately, JAX admits a different approach +using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html). + +#### Version 1 + +Here's one way we can apply `vmap`. + +```{code-cell} ipython3 +# Set up f to compute f(x, y) at every x for any given y +f_vec_x = lambda y: f(grid, y) +# Vectorize this operation over all y +f_vec = jax.vmap(f_vec_x) +# Compute result at all y +z_vmap = f_vec(grid) +``` + +Let's see the timing: + +```{code-cell} ipython3 +with qe.Timer(precision=8): + z_vmap = f_vec(grid) + z_vmap.block_until_ready() +``` + +Let's check we got the right result: + + +```{code-cell} ipython3 +jnp.allclose(z_mesh, z_vmap) +``` + +The execution time is similar to as the mesh operation but we are using much +less memory. + +In addition, `vmap` allows us to break vectorization up into stages, which is +often easier to comprehend than the traditional approach. + +This will become more obvious when we tackle larger problems. + + +#### Version 2 + +Here's a more generic approach to using `vmap` that we often use in the lectures. + +First we vectorize in `y`. + +```{code-cell} ipython3 +f_vec_y = jax.vmap(f, in_axes=(None, 0)) +``` + +In the line above, `(None, 0)` indicates that we are vectorizing in the second argument, which is `y`. + +Next, we vectorize in the first argument, which is `x`. + +```{code-cell} ipython3 +f_vec = jax.vmap(f_vec_y, in_axes=(0, None)) +``` + +With this construction, we can now call $f$ directly on flat (low memory) arrays. + +```{code-cell} ipython3 +x, y = grid, grid +with qe.Timer(precision=8): + z_vmap = f_vec(x, y).block_until_ready() +``` + +Let's run it again to eliminate compilation time: + +```{code-cell} ipython3 +with qe.Timer(precision=8): + z_vmap = f_vec(x, y).block_until_ready() +``` + +Let's check we got the right result: + + +```{code-cell} ipython3 +jnp.allclose(z_mesh, z_vmap) +``` + + + +### Summary + +In our view, JAX is the winner for vectorized operations. + +It dominates NumPy both in terms of speed (via JIT-compilation and parallelization) and memory efficiency (via vmap). + +Moreover, the `vmap` approach can sometimes lead to significantly clearer code. + +While Numba is impressive, the beauty of JAX is that, with fully vectorized +operations, we can run exactly the +same code on machines with hardware accelerators and reap all the benefits +without paying extra cost. 
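+
+Finally, if you want to confirm which hardware JAX is targeting on your
+machine, one quick check (illustrative only, since the output depends on your
+installation and hardware) is to ask JAX for its default backend and device
+list:
+
+```{code-cell} ipython3
+# On a CPU-only installation this typically reports 'cpu' and a single
+# CpuDevice; with a GPU-aware install the GPU(s) will be listed instead.
+print(jax.default_backend())
+print(jax.devices())
+```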
+ 
+
+## Sequential operations
+
+Some operations are inherently sequential -- and hence difficult or impossible
+to vectorize.
+
+In this case NumPy is a poor option and we are left with the choice of Numba or
+JAX.
+
+To compare these choices, we will revisit the problem of iterating on the
+quadratic map that we saw in our {doc}`Numba lecture <numba>`.
+
+
+### Numba Version
+
+Here's the Numba version.
+
+```{code-cell} ipython3
+@numba.jit
+def qm(x0, n, α=4.0):
+    x = np.empty(n+1)
+    x[0] = x0
+    for t in range(n):
+        x[t+1] = α * x[t] * (1 - x[t])
+    return x
+```
+
+Let's generate a time series of length 10,000,000 and time the execution:
+
+```{code-cell} ipython3
+n = 10_000_000
+
+with qe.Timer(precision=8):
+    x = qm(0.1, n)
+```
+
+Let's run it again to eliminate compilation time:
+
+```{code-cell} ipython3
+with qe.Timer(precision=8):
+    x = qm(0.1, n)
+```
+
+Numba handles this sequential operation very efficiently.
+
+Notice that the second run is significantly faster after JIT compilation completes.
+
+Numba's compilation is typically quite fast, and the resulting code performance is excellent for sequential operations like this one.
+
+### JAX Version
+
+Now let's create a JAX version using `lax.scan`:
+
+(We'll hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code.)
+
+```{code-cell} ipython3
+from jax import lax
+from functools import partial
+
+@partial(jax.jit, static_argnums=(1,))
+def qm_jax(x0, n, α=4.0):
+    def update(x, t):
+        x_new = α * x * (1 - x)
+        return x_new, x_new
+
+    _, x = lax.scan(update, x0, jnp.arange(n))
+    return jnp.concatenate([jnp.array([x0]), x])
+```
+
+This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returned values `x_new` into an array.
+
+Let's time it with the same parameters:
+
+```{code-cell} ipython3
+with qe.Timer(precision=8):
+    x_jax = qm_jax(0.1, n).block_until_ready()
+```
+
+Let's run it again to eliminate compilation overhead:
+
+```{code-cell} ipython3
+with qe.Timer(precision=8):
+    x_jax = qm_jax(0.1, n).block_until_ready()
+```
+
+JAX is also efficient for this sequential operation.
+
+Both JAX and Numba deliver strong performance after compilation, with Numba
+typically (but not always) offering slightly better speeds on purely sequential
+operations.
+
+### Summary
+
+While both Numba and JAX deliver strong performance for sequential operations,
+there are significant differences in code readability and ease of use.
+
+The Numba version is straightforward and natural to read: we simply allocate an
+array and fill it element by element using a standard Python loop.
+
+This is exactly how most programmers think about the algorithm.
+
+The JAX version, on the other hand, requires using `lax.scan`, which is significantly less intuitive.
+
+Additionally, JAX's immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba.
+
+For this type of sequential operation, Numba is the clear winner in terms of
+code clarity and ease of implementation, as well as high performance. 
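+
+As a footnote to this comparison, if only the terminal value of the sequence is
+needed (rather than the whole path), the JAX code can be written more simply
+with `lax.fori_loop`.
+
+The sketch below is our own illustration (the name `qm_final` is not part of
+the lecture code), and it is not a drop-in replacement for `qm_jax`, since it
+discards the path:
+
+```{code-cell} ipython3
+from functools import partial
+from jax import lax
+
+@partial(jax.jit, static_argnums=(1,))
+def qm_final(x0, n, α=4.0):
+    # Carry only the current value forward; no path is stored
+    body = lambda t, x: α * x * (1 - x)
+    return lax.fori_loop(0, n, body, x0)
+
+qm_final(0.1, 10)
+```
+
+This pattern reads a little more like the Numba loop, but it only helps when the
+intermediate values are not required.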
diff --git a/lectures/parallelization.md b/lectures/parallelization.md deleted file mode 100644 index d0e46967..00000000 --- a/lectures/parallelization.md +++ /dev/null @@ -1,603 +0,0 @@ ---- -jupytext: - text_representation: - extension: .md - format_name: myst -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -(parallel)= -```{raw} jupyter - -``` - -# Parallelization - -In addition to what's in Anaconda, this lecture will need the following libraries: - -```{code-cell} ipython ---- -tags: [hide-output] ---- -!pip install quantecon -``` - -## Overview - -The growth of CPU clock speed (i.e., the speed at which a single chain of logic can -be run) has slowed dramatically in recent years. - -This is unlikely to change in the near future, due to inherent physical -limitations on the construction of chips and circuit boards. - -Chip designers and computer programmers have responded to the slowdown by -seeking a different path to fast execution: parallelization. - -Hardware makers have increased the number of cores (physical CPUs) embedded in each machine. - -For programmers, the challenge has been to exploit these multiple CPUs by running many processes in parallel (i.e., simultaneously). - -This is particularly important in scientific programming, which requires handling - -* large amounts of data and -* CPU intensive simulations and other calculations. - -In this lecture we discuss parallelization for scientific computing, with a focus on - -1. the best tools for parallelization in Python and -1. how these tools can be applied to quantitative economic problems. - -Let's start with some imports: - -```{code-cell} ipython -import numpy as np -import quantecon as qe -import matplotlib.pyplot as plt -``` - -## Types of Parallelization - -Large textbooks have been written on different approaches to parallelization but we will keep a tight focus on what's most useful to us. - -We will briefly review the two main kinds of parallelization commonly used in -scientific computing and discuss their pros and cons. - -### Multiprocessing - -Multiprocessing means concurrent execution of multiple processes using more than one processor. - -In this context, a **process** is a chain of instructions (i.e., a program). - -Multiprocessing can be carried out on one machine with multiple CPUs or on a -collection of machines connected by a network. - -In the latter case, the collection of machines is usually called a -**cluster**. - -With multiprocessing, each process has its own memory space, although the -physical memory chip might be shared. - -### Multithreading - -Multithreading is similar to multiprocessing, except that, during execution, the threads all share the same memory space. - -Native Python struggles to implement multithreading due to some [legacy design -features](https://wiki.python.org/moin/GlobalInterpreterLock). - -But this is not a restriction for scientific libraries like NumPy and Numba. - -Functions imported from these libraries and JIT-compiled code run in low level -execution environments where Python's legacy restrictions don't apply. - -### Advantages and Disadvantages - -Multithreading is more lightweight because most system and memory resources -are shared by the threads. - -In addition, the fact that multiple threads all access a shared pool of memory -is extremely convenient for numerical programming. - -On the other hand, multiprocessing is more flexible and can be distributed -across clusters. 
- -For the great majority of what we do in these lectures, multithreading will -suffice. - -## Implicit Multithreading in NumPy - -Actually, you have already been using multithreading in your Python code, -although you might not have realized it. - -(We are, as usual, assuming that you are running the latest version of -Anaconda Python.) - -This is because NumPy cleverly implements multithreading in a lot of its -compiled code. - -Let's look at some examples to see this in action. - -### A Matrix Operation - -The next piece of code computes the eigenvalues of a large number of randomly -generated matrices. - -It takes a few seconds to run. - -```{code-cell} python3 -n = 20 -m = 1000 -for i in range(n): - X = np.random.randn(m, m) - λ = np.linalg.eigvals(X) -``` - -Now, let's look at the output of the htop system monitor on our machine while -this code is running: - -```{figure} /_static/lecture_specific/parallelization/htop_parallel_npmat.png -:scale: 80 -``` - -We can see that 4 of the 8 CPUs are running at full speed. - -This is because NumPy's `eigvals` routine neatly splits up the tasks and -distributes them to different threads. - -### A Multithreaded Ufunc - -Over the last few years, NumPy has managed to push this kind of multithreading -out to more and more operations. - -For example, let's return to a maximization problem {ref}`discussed previously `: - -```{code-cell} python3 -def f(x, y): - return np.cos(x**2 + y**2) / (1 + x**2 + y**2) - -grid = np.linspace(-3, 3, 5000) -x, y = np.meshgrid(grid, grid) -``` - -```{code-cell} ipython3 -with qe.Timer(): - np.max(f(x, y)) -``` - -If you have a system monitor such as htop (Linux/Mac) or perfmon -(Windows), then try running this and then observing the load on your CPUs. - -(You will probably need to bump up the grid size to see large effects.) - -At least on our machine, the output shows that the operation is successfully -distributed across multiple threads. - -This is one of the reasons why the vectorized code above is fast. - -### A Comparison with Numba - -To get some basis for comparison for the last example, let's try the same -thing with Numba. - -In fact there is an easy way to do this, since Numba can also be used to -create custom {ref}`ufuncs ` with the [@vectorize](https://numba.pydata.org/numba-doc/dev/user/vectorize.html) decorator. - -```{code-cell} python3 -from numba import vectorize - -@vectorize -def f_vec(x, y): - return np.cos(x**2 + y**2) / (1 + x**2 + y**2) - -np.max(f_vec(x, y)) # Run once to compile -``` - -```{code-cell} ipython3 -with qe.Timer(): - np.max(f_vec(x, y)) -``` - -At least on our machine, the difference in the speed between the -Numba version and the vectorized NumPy version shown above is not large. - -But there's quite a bit going on here so let's try to break down what is -happening. - -Both Numba and NumPy use efficient machine code that's specialized to these -floating point operations. - -However, the code NumPy uses is, in some ways, less efficient. - -The reason is that, in NumPy, the operation `np.cos(x**2 + y**2) / (1 + -x**2 + y**2)` generates several intermediate arrays. - -For example, a new array is created when `x**2` is calculated. - -The same is true when `y**2` is calculated, and then `x**2 + y**2` and so on. - -Numba avoids creating all these intermediate arrays by compiling one -function that is specialized to the entire operation. - -But if this is true, then why isn't the Numba code faster? 
- -The reason is that NumPy makes up for its disadvantages with implicit -multithreading, as we've just discussed. - -### Multithreading a Numba Ufunc - -Can we get both of these advantages at once? - -In other words, can we pair - -* the efficiency of Numba's highly specialized JIT compiled function and -* the speed gains from parallelization obtained by NumPy's implicit - multithreading? - -It turns out that we can, by adding some type information plus `target='parallel'`. - -```{code-cell} python3 -@vectorize('float64(float64, float64)', target='parallel') -def f_vec(x, y): - return np.cos(x**2 + y**2) / (1 + x**2 + y**2) - -np.max(f_vec(x, y)) # Run once to compile -``` - -```{code-cell} ipython3 -with qe.Timer(): - np.max(f_vec(x, y)) -``` - -Now our code runs significantly faster than the NumPy version. - -## Multithreaded Loops in Numba - -We just saw one approach to parallelization in Numba, using the `parallel` -flag in `@vectorize`. - -This is neat but, it turns out, not well suited to many problems we consider. - -Fortunately, Numba provides another approach to multithreading that will work -for us almost everywhere parallelization is possible. - -To illustrate, let's look first at a simple, single-threaded (i.e., non-parallelized) piece of code. - -The code simulates updating the wealth $w_t$ of a household via the rule - -$$ -w_{t+1} = R_{t+1} s w_t + y_{t+1} -$$ - -Here - -* $R$ is the gross rate of return on assets -* $s$ is the savings rate of the household and -* $y$ is labor income. - -We model both $R$ and $y$ as independent draws from a lognormal -distribution. - -Here's the code: - -```{code-cell} ipython -from numpy.random import randn -from numba import njit - -@njit -def h(w, r=0.1, s=0.3, v1=0.1, v2=1.0): - """ - Updates household wealth. - """ - - # Draw shocks - R = np.exp(v1 * randn()) * (1 + r) - y = np.exp(v2 * randn()) - - # Update wealth - w = R * s * w + y - return w -``` - -Let's have a look at how wealth evolves under this rule. - -```{code-cell} ipython -fig, ax = plt.subplots() - -T = 100 -w = np.empty(T) -w[0] = 5 -for t in range(T-1): - w[t+1] = h(w[t]) - -ax.plot(w) -ax.set_xlabel('$t$', fontsize=12) -ax.set_ylabel('$w_{t}$', fontsize=12) -plt.show() -``` - -Now let's suppose that we have a large population of households and we want to -know what median wealth will be. - -This is not easy to solve with pencil and paper, so we will use simulation -instead. - -In particular, we will simulate a large number of households and then -calculate median wealth for this group. - -Suppose we are interested in the long-run average of this median over time. - -It turns out that, for the specification that we've chosen above, we can -calculate this by taking a one-period snapshot of what has happened to median -wealth of the group at the end of a long simulation. - -Moreover, provided the simulation period is long enough, initial conditions -don't matter. - -* This is due to something called ergodicity, which we will discuss [later on](https://python.quantecon.org/finite_markov.html#id15). - -So, in summary, we are going to simulate 50,000 households by - -1. arbitrarily setting initial wealth to 1 and -1. simulating forward in time for 1,000 periods. - -Then we'll calculate median wealth at the end period. 
- -Here's the code: - -```{code-cell} ipython -@njit -def compute_long_run_median(w0=1, T=1000, num_reps=50_000): - - obs = np.empty(num_reps) - for i in range(num_reps): - w = w0 - for t in range(T): - w = h(w) - obs[i] = w - - return np.median(obs) -``` - -Let's see how fast this runs: - -```{code-cell} ipython -with qe.Timer(): - compute_long_run_median() -``` - -To speed this up, we're going to parallelize it via multithreading. - -To do so, we add the `parallel=True` flag and change `range` to `prange`: - -```{code-cell} ipython -from numba import prange - -@njit(parallel=True) -def compute_long_run_median_parallel(w0=1, T=1000, num_reps=50_000): - - obs = np.empty(num_reps) - for i in prange(num_reps): - w = w0 - for t in range(T): - w = h(w) - obs[i] = w - - return np.median(obs) -``` - -Let's look at the timing: - -```{code-cell} ipython -with qe.Timer(): - compute_long_run_median_parallel() -``` - -The speed-up is significant. - -### A Warning - -Parallelization works well in the outer loop of the last example because the individual tasks inside the loop are independent of each other. - -If this independence fails then parallelization is often problematic. - -For example, each step inside the inner loop depends on the last step, so -independence fails, and this is why we use ordinary `range` instead of `prange`. - -When you see us using `prange` in later lectures, it is because the -independence of tasks holds true. - -When you see us using ordinary `range` in a jitted function, it is either because the speed gain from parallelization is small or because independence fails. - -## Exercises - -```{exercise} -:label: parallel_ex1 - -In {ref}`an earlier exercise `, we used Numba to accelerate an -effort to compute the constant $\pi$ by Monte Carlo. - -Now try adding parallelization and see if you get further speed gains. - -You should not expect huge gains here because, while there are many -independent tasks (draw point and test if in circle), each one has low -execution time. - -Generally speaking, parallelization is less effective when the individual -tasks to be parallelized are very small relative to total execution time. - -This is due to overheads associated with spreading all of these small tasks across multiple CPUs. - -Nevertheless, with suitable hardware, it is possible to get nontrivial speed gains in this exercise. - -For the size of the Monte Carlo simulation, use something substantial, such as -`n = 100_000_000`. -``` - -```{solution-start} parallel_ex1 -:class: dropdown -``` - -Here is one solution: - -```{code-cell} python3 -from random import uniform - -@njit(parallel=True) -def calculate_pi(n=1_000_000): - count = 0 - for i in prange(n): - u, v = uniform(0, 1), uniform(0, 1) - d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2) - if d < 0.5: - count += 1 - - area_estimate = count / n - return area_estimate * 4 # dividing by radius**2 -``` - -Now let's see how fast it runs: - -```{code-cell} ipython3 -with qe.Timer(): - calculate_pi() -``` - -```{code-cell} ipython3 -with qe.Timer(): - calculate_pi() -``` - -By switching parallelization on and off (selecting `True` or -`False` in the `@njit` annotation), we can test the speed gain that -multithreading provides on top of JIT compilation. - -On our workstation, we find that parallelization increases execution speed by -a factor of 2 or 3. - -(If you are executing locally, you will get different numbers, depending mainly -on the number of CPUs on your machine.) 
- -```{solution-end} -``` - - -```{exercise} -:label: parallel_ex2 - -In {doc}`our lecture on SciPy`, we discussed pricing a call option in a -setting where the underlying stock price had a simple and well-known -distribution. - -Here we discuss a more realistic setting. - -We recall that the price of the option obeys - -$$ -P = \beta^n \mathbb E \max\{ S_n - K, 0 \} -$$ - -where - -1. $\beta$ is a discount factor, -2. $n$ is the expiry date, -2. $K$ is the strike price and -3. $\{S_t\}$ is the price of the underlying asset at each time $t$. - -Suppose that `n, β, K = 20, 0.99, 100`. - -Assume that the stock price obeys - -$$ -\ln \frac{S_{t+1}}{S_t} = \mu + \sigma_t \xi_{t+1} -$$ - -where - -$$ - \sigma_t = \exp(h_t), - \quad - h_{t+1} = \rho h_t + \nu \eta_{t+1} -$$ - -Here $\{\xi_t\}$ and $\{\eta_t\}$ are IID and standard normal. - -(This is a **stochastic volatility** model, where the volatility $\sigma_t$ -varies over time.) - -Use the defaults `μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0`. - -(Here `S0` is $S_0$ and `h0` is $h_0$.) - -By generating $M$ paths $s_0, \ldots, s_n$, compute the Monte Carlo estimate - -$$ - \hat P_M - := \beta^n \mathbb E \max\{ S_n - K, 0 \} - \approx - \frac{1}{M} \sum_{m=1}^M \max \{S_n^m - K, 0 \} -$$ - - -of the price, applying Numba and parallelization. - -``` - - -```{solution-start} parallel_ex2 -:class: dropdown -``` - - -With $s_t := \ln S_t$, the price dynamics become - -$$ -s_{t+1} = s_t + \mu + \exp(h_t) \xi_{t+1} -$$ - -Using this fact, the solution can be written as follows. - - -```{code-cell} ipython3 -from numpy.random import randn -M = 10_000_000 - -n, β, K = 20, 0.99, 100 -μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0 - -@njit(parallel=True) -def compute_call_price_parallel(β=β, - μ=μ, - S0=S0, - h0=h0, - K=K, - n=n, - ρ=ρ, - ν=ν, - M=M): - current_sum = 0.0 - # For each sample path - for m in prange(M): - s = np.log(S0) - h = h0 - # Simulate forward in time - for t in range(n): - s = s + μ + np.exp(h) * randn() - h = ρ * h + ν * randn() - # And add the value max{S_n - K, 0} to current_sum - current_sum += np.maximum(np.exp(s) - K, 0) - - return β**n * current_sum / M -``` - -Try swapping between `parallel=True` and `parallel=False` and noting the run time. - -If you are on a machine with many CPUs, the difference should be significant. - -```{solution-end} -```