* We hold `n` static because it determines the array size, so JAX must specialize the compiled code on its value.
* We pin the computation to the CPU via `device=cpu` because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism; on a GPU, kernel-launch overhead tends to dominate, making the CPU a better fit.
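
The definition under discussion appears earlier in the lecture; as a rough sketch only, a `lax.fori_loop` version of the quadratic map `x_{t+1} = 4 x_t (1 - x_t)` might look like this (the lecture's actual code, including how it passes `device=cpu`, may differ):

```python
import jax
import jax.numpy as jnp
from functools import partial

# `n` is static: it fixes the array size, so the compiled code
# is specialized on its value.
@partial(jax.jit, static_argnums=(1,))
def qm_jax_fori(x0, n):
    x = jnp.empty(n)
    x = x.at[0].set(x0)

    def body(t, x):
        # One step of the quadratic map, written as a functional update
        return x.at[t].set(4 * x[t - 1] * (1 - x[t - 1]))

    return jax.lax.fori_loop(1, n, body, x)

# Pin execution to the CPU (the lecture passes `device=cpu` to `jax.jit` instead)
with jax.default_device(jax.devices("cpu")[0]):
    x = qm_jax_fori(0.1, 10)
```
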

Although `at[t].set` appears to create a new array at each step, inside a JIT-compiled function the compiler detects that the old array is no longer needed and performs the update in place.
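
Outside of `jit`, the functional semantics are visible: `at[...].set` returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(3)
y = x.at[1].set(5.0)   # returns a new array

print(x)   # original unchanged: [0. 0. 0.]
print(y)   # updated copy: [0. 5. 0.]
```
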

Let's time it with the same parameters:

```{code-cell} ipython3
with qe.Timer():
    # First run
    x_jax = qm_jax_fori(0.1, n)
    # Hold interpreter
    x_jax.block_until_ready()
```

Let's run it again to eliminate compilation overhead:

```{code-cell} ipython3
with qe.Timer():
    # Second run
    x_jax = qm_jax_fori(0.1, n)
    # Hold interpreter
    x_jax.block_until_ready()
```

JAX is also quite efficient for this sequential operation.

There's another way to implement the loop, using `lax.scan`.

This alternative is arguably more in line with JAX's functional approach.

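
Here's a sketch of what such a version might look like (assuming, as above, the quadratic map `x_{t+1} = 4 x_t (1 - x_t)`; the exact definition may differ):

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def qm_jax_scan(x0, n):

    def update(x, _):
        x_new = 4 * x * (1 - x)
        return x_new, x_new   # (next carry, value to accumulate)

    _, x_seq = jax.lax.scan(update, x0, None, length=n - 1)
    return jnp.concatenate((jnp.atleast_1d(x0), x_seq))

x = qm_jax_scan(0.1, 10)
```
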
This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returned values `x_new` into an array.

Let's time it with the same parameters:

```{code-cell} ipython3
with qe.Timer():
    # First run
    x_jax = qm_jax_scan(0.1, n)
    # Hold interpreter
    x_jax.block_until_ready()
```

Let's run it again to eliminate compilation overhead:

```{code-cell} ipython3
with qe.Timer():
    # Second run
    x_jax = qm_jax_scan(0.1, n)
    # Hold interpreter
    x_jax.block_until_ready()
```

Both JAX and Numba deliver strong performance after compilation.

array and fill it element by element using a standard Python loop.

This is exactly how most programmers think about the algorithm.

The JAX versions, on the other hand, require either `lax.fori_loop` or
`lax.scan`, both of which are less intuitive than a standard Python loop.

While JAX's `at[t].set` syntax does allow element-wise updates, the overall code
remains harder to read than the Numba equivalent.

For this type of sequential operation, Numba is the clear winner in terms of
code clarity and ease of implementation.

For **sequential operations**, Numba has clear advantages.

The code is natural and readable --- just a Python loop with a decorator ---
and performance is excellent.

JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but
the syntax is less intuitive.

```{note}
One important advantage of `lax.fori_loop` and `lax.scan` is that they
support automatic differentiation through the loop, which Numba cannot do.

If you need to differentiate through a sequential computation (e.g., computing
sensitivities of a trajectory to model parameters), JAX is the better choice.
```
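
For instance, here is a minimal sketch of differentiating through a scanned trajectory (the function below is illustrative, not from the lecture, and again assumes the quadratic map):

```python
import jax

def trajectory_end(x0, n=5):
    # Iterate x_{t+1} = 4 x_t (1 - x_t) for n steps with lax.scan
    def update(x, _):
        x_new = 4 * x * (1 - x)
        return x_new, x_new
    x_final, _ = jax.lax.scan(update, x0, None, length=n)
    return x_final

# Sensitivity of the final state to the initial condition
dx_dx0 = jax.grad(trajectory_end)(0.1)
```

Numba can compile the equivalent loop, but it has no way to produce this derivative automatically.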
0 commit comments