Reorganize parallel programming lectures and improve content flow
Major restructuring of parallelization-related content across lectures to
improve pedagogical flow and consolidate related material.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
lectures/jax_intro.md (1 addition, 154 deletions)
@@ -513,167 +513,14 @@ plt.show()

We defer further exploration of automatic differentiation with JAX until {doc}`jax:autodiff`.

## Writing vectorized code

Writing fast JAX code requires shifting repetitive tasks from loops to array processing operations, so that the JAX compiler can easily understand the whole operation and generate more efficient machine code.

This procedure is called **vectorization** or **array programming**, and will be
familiar to anyone who has used NumPy or MATLAB.

In most ways, vectorization is the same in JAX as it is in NumPy.

But there are also some differences, which we highlight here.

As a running example, consider the function

$$
f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2}
$$

Suppose that we want to evaluate this function on a square grid of $x$ and $y$ points and then plot it.

To clarify, here is the slow `for` loop version.

```{code-cell} ipython3
@jax.jit
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

n = 80
x = jnp.linspace(-2, 2, n)
y = x

z_loops = np.empty((n, n))
```

```{code-cell} ipython3
with qe.Timer():
    for i in range(n):
        for j in range(n):
            z_loops[i, j] = f(x[i], y[j])
```

Even for this very small grid, the run time is extremely slow.

(Notice that we used a NumPy array for `z_loops` because we wanted to write to it.)

+++

OK, so how can we do the same operation in vectorized form?

If you are new to vectorization, you might guess that we can simply write

```{code-cell} ipython3
z_bad = f(x, y)
```

But this gives us the wrong result: applied to the two flat arrays, `f` simply acts elementwise and returns another flat array, rather than evaluating `f` at every $(x_i, y_j)$ pair.

```{code-cell} ipython3
z_bad.shape
```

Here is what we actually wanted:

```{code-cell} ipython3
z_loops.shape
```

To get the right shape and the correct nested for loop calculation, we can use a `meshgrid` operation designed for this purpose:

```{code-cell} ipython3
x_mesh, y_mesh = jnp.meshgrid(x, y)
```

Now we get what we want and the execution time is very fast.

```{code-cell} ipython3
with qe.Timer():
    z_mesh = f(x_mesh, y_mesh).block_until_ready()
```

Let's run again to eliminate compile time.

```{code-cell} ipython3
with qe.Timer():
    z_mesh = f(x_mesh, y_mesh).block_until_ready()
```

Let's confirm that we got the right answer.

```{code-cell} ipython3
jnp.allclose(z_mesh, z_loops)
```

Now we can set up a serious grid and run the same calculation (on the larger grid) in a short amount of time.

```{code-cell} ipython3
n = 6000
x = jnp.linspace(-2, 2, n)
y = x
x_mesh, y_mesh = jnp.meshgrid(x, y)
```

```{code-cell} ipython3
with qe.Timer():
    z_mesh = f(x_mesh, y_mesh).block_until_ready()
```

But there is one problem here: the mesh grids use a lot of memory.

```{code-cell} ipython3
x_mesh.nbytes + y_mesh.nbytes
```

By comparison, the flat array `x` is just

```{code-cell} ipython3
x.nbytes  # and y is just a pointer to x
```

This extra memory usage can be a big problem in actual research calculations.

So let's try a different approach using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).

+++

First we vectorize `f` in `y`.

```{code-cell} ipython3
f_vec_y = jax.vmap(f, in_axes=(None, 0))
```

In the line above, `(None, 0)` indicates that we are vectorizing in the second argument, which is `y`.

Next, we vectorize in the first argument, which is `x`.

```{code-cell} ipython3
f_vec = jax.vmap(f_vec_y, in_axes=(0, None))
```

With this construction, we can now call the function $f$ on flat (low memory) arrays.

```{code-cell} ipython3
with qe.Timer():
    z_vmap = f_vec(x, y).block_until_ready()
```

The execution time is essentially the same as the mesh operation but we are using much less memory.

And we produce the correct answer:

```{code-cell} ipython3
jnp.allclose(z_vmap, z_mesh)
```

## Exercises

```{exercise-start}
:label: jax_intro_ex2
```

In the Exercise section of {doc}`a lecture on Numba <numba>`, we used Monte Carlo to price a European call option.

The code was accelerated by Numba-based multithreading.
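
As a rough starting point (a sketch only: the lognormal dynamics and the parameter values below are simplified placeholders, not the exercise's actual specification), a JAX version of such a pricer might look like this:

```python
import jax
import jax.numpy as jnp

# Placeholder parameters, not the exercise's values
S0, K, T, r, σ = 100.0, 105.0, 1.0, 0.05, 0.2
n_paths = 100_000

@jax.jit
def mc_call_price(key):
    # Draw terminal prices from a simple lognormal model
    Z = jax.random.normal(key, (n_paths,))
    S_T = S0 * jnp.exp((r - 0.5 * σ**2) * T + σ * jnp.sqrt(T) * Z)
    # Discounted average payoff of the call option
    return jnp.exp(-r * T) * jnp.maximum(S_T - K, 0.0).mean()

mc_call_price(jax.random.PRNGKey(0))
```

Because the whole computation is expressed as array operations, the same pattern parallelizes naturally on a GPU.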

But this is not a restriction for scientific libraries like NumPy and Numba.

Functions imported from these libraries and JIT-compiled code run in low level
execution environments where Python's legacy restrictions don't apply.
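
As a small illustration (a sketch, not part of the lecture text, assuming NumPy is installed), the thread pool below runs several large matrix products concurrently; the heavy work happens in compiled code that releases the GIL, so the threads can genuinely run on separate cores.

```python
from concurrent.futures import ThreadPoolExecutor
import numpy as np

def heavy_task(seed):
    # The matrix product runs in compiled (BLAS) code, which releases the GIL
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((2000, 2000))
    return (A @ A).trace()

# Four tasks executing concurrently within a single process
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(heavy_task, range(4)))
```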

#### Advantages and Disadvantages

Multithreading is more lightweight because most system and memory resources
are shared by the threads.

In addition, the fact that multiple threads all access a shared pool of memory
is extremely convenient for numerical programming.

On the other hand, multiprocessing is more flexible and can be distributed
across clusters.

For the great majority of what we do in these lectures, multithreading will
suffice.
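
To see why the shared pool of memory is so convenient, consider this sketch (again illustrative, not part of the lecture text): each thread writes directly into its own row of one preallocated array, with no copying or message passing, whereas with multiprocessing each worker would operate on its own copy of the data.

```python
from concurrent.futures import ThreadPoolExecutor
import numpy as np

n = 1_000
grid = np.linspace(0, 2 * np.pi, n)
out = np.empty((4, n))            # one array, visible to every thread

def fill_row(i):
    # Each thread writes straight into the shared array
    out[i, :] = np.sin((i + 1) * grid)

with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(fill_row, range(4)))
```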

### Hardware Accelerators

While CPUs with multiple cores have become standard for parallel computing, a more dramatic shift has occurred with the rise of specialized hardware accelerators.

These accelerators are designed specifically for the kinds of highly parallel computations that arise in scientific computing, machine learning, and data science.

#### GPUs and TPUs

The two most important types of hardware accelerators are

* **GPUs** (Graphics Processing Units) and
* **TPUs** (Tensor Processing Units).

GPUs were originally designed for rendering graphics, which requires performing the same operation on many pixels simultaneously.

Scientists and engineers realized that this same architecture --- many simple processors working in parallel --- is ideal for scientific computing tasks such as

* matrix operations,
* numerical simulation,
* solving partial differential equations and
* training machine learning models.

TPUs are a more recent development, designed by Google specifically for machine learning workloads.

Like GPUs, TPUs excel at performing massive numbers of matrix operations in parallel.

#### Why GPUs Matter for Scientific Computing

The performance gains from using GPUs can be dramatic.

A modern GPU can contain thousands of small processing cores, compared to the 8-64 cores typically found in CPUs.

When a problem can be expressed as many independent operations on arrays of data, GPUs can be orders of magnitude faster than CPUs.

This is particularly relevant for scientific computing because many algorithms in

* linear algebra,
* optimization,
* Monte Carlo simulation and
* numerical methods for differential equations

naturally map onto the parallel architecture of GPUs.
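
As an illustration (a sketch, assuming the `jax` library used elsewhere in this lecture series is installed), the array code below is identical whether it runs on a CPU or a GPU; JAX places the arrays and the matrix product on the best device it finds.

```python
import jax

# Create a large random matrix on the default device (GPU if one is available)
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (4000, 4000))

# A big matrix product: exactly the kind of workload GPUs are built for
B = (A @ A.T).block_until_ready()

print(jax.default_backend())   # 'gpu' or 'tpu' if an accelerator is present, else 'cpu'
```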

#### Single GPUs vs GPU Servers

There are two common ways to access GPU resources:

**Single GPU Systems**

Many workstations and laptops now come with capable GPUs, or can be equipped with them.
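
If you want to check what your own machine provides (a sketch, assuming `jax` is installed), you can ask JAX to list the devices it can see:

```python
import jax

# Lists the available devices; JAX falls back to the CPU if no GPU is present
print(jax.devices())
```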