Commit 4119c20

jstac and claude committed
Reorganize parallel programming lectures and improve content flow
Major restructuring of parallelization-related content across lectures to improve pedagogical flow and consolidate related material.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 5b5e342 commit 4119c20

File tree

9 files changed: +864 -809 lines
Two binary image files changed (432 KB and 1.6 MB).

lectures/_toc.yml

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ parts:
   numbered: true
   chapters:
   - file: numba
-  - file: parallelization
+  - file: numpy_vs_numba_vs_jax
   - file: jax_intro
 - caption: Working with Data
   numbered: true

lectures/jax_intro.md

Lines changed: 1 addition & 154 deletions

@@ -513,167 +513,14 @@ plt.show()

 We defer further exploration of automatic differentiation with JAX until {doc}`jax:autodiff`.

-## Writing vectorized code

-Writing fast JAX code requires shifting repetitive tasks from loops to array processing operations, so that the JAX compiler can easily understand the whole operation and generate more efficient machine code.

-This procedure is called **vectorization** or **array programming**, and will be
-familiar to anyone who has used NumPy or MATLAB.

-In most ways, vectorization is the same in JAX as it is in NumPy.

-But there are also some differences, which we highlight here.

-As a running example, consider the function

-$$
-f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2}
-$$

-Suppose that we want to evaluate this function on a square grid of $x$ and $y$ points and then plot it.

-To clarify, here is the slow `for` loop version.

-```{code-cell} ipython3
-@jax.jit
-def f(x, y):
-    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)
-
-n = 80
-x = jnp.linspace(-2, 2, n)
-y = x
-
-z_loops = np.empty((n, n))
-```

-```{code-cell} ipython3
-with qe.Timer():
-    for i in range(n):
-        for j in range(n):
-            z_loops[i, j] = f(x[i], y[j])
-```

-Even for this very small grid, the run time is extremely slow.

-(Notice that we used a NumPy array for `z_loops` because we wanted to write to it.)

-+++

-OK, so how can we do the same operation in vectorized form?

-If you are new to vectorization, you might guess that we can simply write

-```{code-cell} ipython3
-z_bad = f(x, y)
-```

-But this gives us the wrong result because JAX doesn't understand the nested for loop.

-```{code-cell} ipython3
-z_bad.shape
-```

-Here is what we actually wanted:

-```{code-cell} ipython3
-z_loops.shape
-```

-To get the right shape and the correct nested for loop calculation, we can use a `meshgrid` operation designed for this purpose:

-```{code-cell} ipython3
-x_mesh, y_mesh = jnp.meshgrid(x, y)
-```

-Now we get what we want and the execution time is very fast.

-```{code-cell} ipython3
-with qe.Timer():
-    z_mesh = f(x_mesh, y_mesh).block_until_ready()
-```

-Let's run again to eliminate compile time.

-```{code-cell} ipython3
-with qe.Timer():
-    z_mesh = f(x_mesh, y_mesh).block_until_ready()
-```

-Let's confirm that we got the right answer.

-```{code-cell} ipython3
-jnp.allclose(z_mesh, z_loops)
-```

-Now we can set up a serious grid and run the same calculation (on the larger grid) in a short amount of time.

-```{code-cell} ipython3
-n = 6000
-x = jnp.linspace(-2, 2, n)
-y = x
-x_mesh, y_mesh = jnp.meshgrid(x, y)
-```

-```{code-cell} ipython3
-with qe.Timer():
-    z_mesh = f(x_mesh, y_mesh).block_until_ready()
-```

-But there is one problem here: the mesh grids use a lot of memory.

-```{code-cell} ipython3
-x_mesh.nbytes + y_mesh.nbytes
-```

-By comparison, the flat array `x` is just

-```{code-cell} ipython3
-x.nbytes # and y is just a pointer to x
-```

-This extra memory usage can be a big problem in actual research calculations.

-So let's try a different approach using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html)

-+++

-First we vectorize `f` in `y`.

-```{code-cell} ipython3
-f_vec_y = jax.vmap(f, in_axes=(None, 0))
-```

-In the line above, `(None, 0)` indicates that we are vectorizing in the second argument, which is `y`.

-Next, we vectorize in the first argument, which is `x`.

-```{code-cell} ipython3
-f_vec = jax.vmap(f_vec_y, in_axes=(0, None))
-```

-With this construction, we can now call the function $f$ on flat (low memory) arrays.

-```{code-cell} ipython3
-with qe.Timer():
-    z_vmap = f_vec(x, y).block_until_ready()
-```

-The execution time is essentially the same as the mesh operation but we are using much less memory.

-And we produce the correct answer:

-```{code-cell} ipython3
-jnp.allclose(z_vmap, z_mesh)
-```

 ## Exercises


 ```{exercise-start}
 :label: jax_intro_ex2
 ```

-In the Exercise section of [a lecture on Numba and parallelization](https://python-programming.quantecon.org/parallelization.html), we used Monte Carlo to price a European call option.
+In the Exercise section of {doc}`a lecture on Numba <numba>`, we used Monte Carlo to price a European call option.

 The code was accelerated by Numba-based multithreading.

lectures/need_for_speed.md

Lines changed: 168 additions & 0 deletions

@@ -320,3 +320,171 @@ traditional vectorization and towards the use of [just-in-time compilers](https:

In later lectures in this series, we will learn about how modern Python libraries exploit
just-in-time compilers to generate fast, efficient, parallelized machine code.

## Parallelization

The growth of CPU clock speed (i.e., the speed at which a single chain of logic can
be run) has slowed dramatically in recent years.

This is unlikely to change in the near future, due to inherent physical
limitations on the construction of chips and circuit boards.

Chip designers and computer programmers have responded to the slowdown by
seeking a different path to fast execution: parallelization.

Hardware makers have increased the number of cores (physical CPUs) embedded in each machine.

For programmers, the challenge has been to exploit these multiple CPUs by running many processes in parallel (i.e., simultaneously).

This is particularly important in scientific programming, which requires handling

* large amounts of data and
* CPU-intensive simulations and other calculations.

In this lecture we discuss parallelization for scientific computing, with a focus on

1. the best tools for parallelization in Python and
1. how these tools can be applied to quantitative economic problems.

Let's start with some imports:

```{code-cell} ipython
import numpy as np
import quantecon as qe
import matplotlib.pyplot as plt
```

### Parallelization on CPUs

Large textbooks have been written on different approaches to parallelization, but we will keep a tight focus on what's most useful to us.

We will briefly review the two main kinds of CPU-based parallelization commonly used in
scientific computing and discuss their pros and cons.

#### Multiprocessing

Multiprocessing means concurrent execution of multiple processes using more than one processor.

In this context, a **process** is a chain of instructions (i.e., a program).

Multiprocessing can be carried out on one machine with multiple CPUs or on a
collection of machines connected by a network.

In the latter case, the collection of machines is usually called a
**cluster**.

With multiprocessing, each process has its own memory space, although the
physical memory chip might be shared.

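As a minimal sketch of the idea (not part of the lecture code), the snippet below spreads a CPU-bound task over several worker processes with the standard-library `concurrent.futures` module; the function `count_primes` and its inputs are invented purely for illustration.

```python
from concurrent.futures import ProcessPoolExecutor

def count_primes(n):
    "Count primes below n by trial division (deliberately CPU-bound)."
    count = 0
    for k in range(2, n):
        if all(k % d != 0 for d in range(2, int(k**0.5) + 1)):
            count += 1
    return count

if __name__ == "__main__":
    inputs = [50_000, 60_000, 70_000, 80_000]
    # Each input is handled by a separate worker process with its own memory space.
    with ProcessPoolExecutor() as pool:
        results = list(pool.map(count_primes, inputs))
    print(results)
```

Because the workers are separate processes, nothing is shared implicitly: inputs and results are serialized and passed between them, which reflects the memory separation just described.
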
#### Multithreading

Multithreading is similar to multiprocessing, except that, during execution, the threads all share the same memory space.

Native Python struggles to implement multithreading due to some [legacy design
features](https://wiki.python.org/moin/GlobalInterpreterLock).

But this is not a restriction for scientific libraries like NumPy and Numba.

Functions imported from these libraries and JIT-compiled code run in low-level
execution environments where Python's legacy restrictions don't apply.

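For a rough illustration of such multithreading (a sketch, not the lecture's own example), the following uses Numba's `prange` to split a loop across CPU threads; the function `row_means` and the array shape are arbitrary.

```python
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def row_means(x):
    "Mean of each row, with rows distributed across threads."
    n, m = x.shape
    out = np.empty(n)
    for i in prange(n):  # iterations are split across the available cores
        s = 0.0
        for j in range(m):
            s += x[i, j]
        out[i] = s / m
    return out

x = np.random.randn(2_000, 5_000)
print(row_means(x)[:5])
```

All threads read from the same array `x` and write to the same output array `out`, taking advantage of the shared memory space mentioned above.
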
#### Advantages and Disadvantages

Multithreading is more lightweight because most system and memory resources
are shared by the threads.

In addition, the fact that multiple threads all access a shared pool of memory
is extremely convenient for numerical programming.

On the other hand, multiprocessing is more flexible and can be distributed
across clusters.

For the great majority of what we do in these lectures, multithreading will
suffice.

### Hardware Accelerators

While CPUs with multiple cores have become standard for parallel computing, a more dramatic shift has occurred with the rise of specialized hardware accelerators.

These accelerators are designed specifically for the kinds of highly parallel computations that arise in scientific computing, machine learning, and data science.

#### GPUs and TPUs

The two most important types of hardware accelerators are

* **GPUs** (Graphics Processing Units) and
* **TPUs** (Tensor Processing Units).

GPUs were originally designed for rendering graphics, which requires performing the same operation on many pixels simultaneously.

Scientists and engineers realized that this same architecture --- many simple processors working in parallel --- is ideal for scientific computing tasks such as

* matrix operations,
* numerical simulation,
* solving partial differential equations and
* training machine learning models.

TPUs are a more recent development, designed by Google specifically for machine learning workloads.

Like GPUs, TPUs excel at performing massive numbers of matrix operations in parallel.

#### Why GPUs Matter for Scientific Computing

The performance gains from using GPUs can be dramatic.

A modern GPU can contain thousands of small processing cores, compared to the 8-64 cores typically found in CPUs.

When a problem can be expressed as many independent operations on arrays of data, GPUs can be orders of magnitude faster than CPUs.

This is particularly relevant for scientific computing because many algorithms in

* linear algebra,
* optimization,
* Monte Carlo simulation and
* numerical methods for differential equations

naturally map onto the parallel architecture of GPUs.

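To make this concrete, here is a small sketch (with an arbitrary matrix size) of the kind of array workload that maps well onto such hardware, written with JAX; the same code runs unchanged on a CPU or, when one is available, a GPU.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(a):
    # One dense linear-algebra step: a matrix product followed by a nonlinearity.
    return jnp.tanh(a @ a)

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2_000, 2_000))

# The compiled code targets whatever backend JAX detects (CPU, GPU, or TPU).
result = step(a).block_until_ready()
print(result.shape)
```

Nothing in the code refers to the hardware; the JIT compiler generates kernels for whichever backend is present.
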
#### Single GPUs vs GPU Servers

There are two common ways to access GPU resources:

**Single GPU Systems**

Many workstations and laptops now come with capable GPUs, or can be equipped with them.

```{figure} /_static/lecture_specific/need_for_speed/geforce.png
:scale: 40
```

A single modern GPU can dramatically accelerate many scientific computing tasks.

For individual researchers and small projects, a single GPU is often sufficient.

Python libraries like JAX, PyTorch, and TensorFlow can automatically detect and use available GPUs with minimal code changes.

**Multi-GPU Servers**

For larger-scale problems, servers containing multiple GPUs (often 4-8 GPUs per server) are increasingly common.

```{figure} /_static/lecture_specific/need_for_speed/dgx.png
:scale: 23
```

These can be located

* in local compute clusters,
* in university or national lab computing facilities, or
* in cloud computing platforms (AWS, Google Cloud, Azure, etc.).

With appropriate software, computations can be distributed across multiple GPUs, either within a single server or across multiple servers.

This enables researchers to tackle problems that would be infeasible on a single GPU or CPU.

#### GPU Programming in Python

The good news for Python users is that many scientific libraries now support GPU acceleration with minimal changes to existing code.

For example, JAX code that runs on CPUs can often run on GPUs simply by ensuring the data is placed on the GPU device.

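As a small illustrative sketch of that device-placement point (using only standard JAX calls and an arbitrary array), the following lists the devices JAX has detected and places an array on the first one; on a CPU-only machine it simply reports the CPU backend.

```python
import jax
import jax.numpy as jnp

# Accelerators JAX has detected (GPUs/TPUs if present, otherwise CPUs).
print(jax.devices())

x = jnp.linspace(0.0, 1.0, 1_000_000)

# Explicitly place the data on the first available device;
# subsequent operations on x_dev execute there.
x_dev = jax.device_put(x, jax.devices()[0])
print(jnp.sum(x_dev))
```
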
We will explore GPU computing in more detail in later lectures, particularly when we discuss JAX.
