@@ -135,18 +135,36 @@ for x in grid:
135135
136136Let's switch to NumPy and use a larger grid.
137137
138- Here we use ` np.meshgrid ` to create two-dimensional input grids ` x ` and ` y ` such
139- that ` f(x, y) ` generates all evaluations on the product grid.
138+ ```{code-cell} ipython3
139+ grid = np.linspace(-3, 3, 3_000) # Large grid
140+ ```
141+
142+ As a first pass of vectorization we might try something like this
143+
144+ ```{code-cell} ipython3
145+ # Large grid
146+ z = np.max(f(grid, grid)) # This is wrong!
147+ ```
148+
149+ The problem here is that `f(grid, grid)` does not replicate the nested loop over all pairs.
150+
151+ In terms of the figure above, it only computes the values of ` f ` along the
152+ diagonal.
153+
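To see the shape difference concretely, here is a small self-contained sketch; the `f` below is a hypothetical stand-in, since the lecture's own definition lies outside this excerpt:

```python
import numpy as np

# Hypothetical stand-in for the lecture's bivariate function f
def f(x, y):
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)

grid = np.linspace(-3, 3, 5)   # tiny grid, just for illustration

# Passing the flat array twice pairs grid[i] only with grid[i]
diag = f(grid, grid)
print(diag.shape)              # (5,) -- the diagonal only

# meshgrid expands the inputs so f is evaluated on every (x, y) pair
x, y = np.meshgrid(grid, grid)
z = f(x, y)
print(z.shape)                 # (5, 5) -- the full product grid
```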
154+ To trick NumPy into calculating `f(x, y)` on every `(x, y)` pair, we need to use `np.meshgrid`.
155+ 
156+ Here we use `np.meshgrid` to create two-dimensional input grids `x_mesh` and `y_mesh`
157+ such that `f(x_mesh, y_mesh)` generates all evaluations on the product grid.
140158
141159
142160```{code-cell} ipython3
143161# Large grid
144162grid = np.linspace(-3, 3, 3_000)
145163
146- x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
164+ x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid
147165
148166with qe.Timer():
149- z_max_numpy = np.max(f(x, y))
167+ z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works
150168```
151169
152170In the vectorized version, all the looping takes place in compiled code.
@@ -159,11 +177,30 @@ The output should be close to one:
159177print(f"NumPy result: {z_max_numpy:.6f}")
160178```
161179
180+ ### Memory Issues
181+
182+ So we have the right solution in reasonable time --- but memory usage is huge.
183+
184+ While the flat arrays are low-memory
185+
186+ ```{code-cell} ipython3
187+ grid.nbytes
188+ ```
189+
190+ the mesh grids are two-dimensional and hence very memory intensive
191+
192+ ```{code-cell} ipython3
193+ x_mesh.nbytes + y_mesh.nbytes
194+ ```
195+
196+ Moreover, NumPy's eager execution creates many intermediate arrays of the same size!
197+
198+ This kind of memory usage can be a big problem in actual research calculations.
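As a quick check on the arithmetic (assuming NumPy's default 8-byte `float64` dtype), the byte counts can be reproduced directly:

```python
import numpy as np

n = 3_000
grid = np.linspace(-3, 3, n)

# The flat grid holds n floats at 8 bytes each: 24 KB
print(grid.nbytes)                    # 24000

# Each mesh grid holds n * n floats: about 72 MB apiece
x_mesh, y_mesh = np.meshgrid(grid, grid)
print(x_mesh.nbytes + y_mesh.nbytes)  # 144000000
```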
162199
163200
164201### A Comparison with Numba
165202
166- Now let 's see if we can achieve better performance using Numba with a simple loop.
203+ Let's see if we can achieve better performance using Numba with a simple loop.
167204
168205```{code-cell} ipython3
169206@numba.jit
@@ -194,15 +231,13 @@ with qe.Timer():
194231 compute_max_numba(grid)
195232```
196233
197- Depending on your machine, the Numba version might be either slower or faster than NumPy.
234+ Notice how we are using almost no memory --- we just need the one-dimensional `grid` array.
198235
199- In most cases we find that Numba is slightly better .
236+ Moreover, execution speed is good.
200237
201- On the one hand, NumPy combines efficient arithmetic with some
202- multithreading, which provides an advantage.
238+ On most machines, the Numba version will be somewhat faster than NumPy.
203239
204- On the other hand, the Numba routine uses much less memory, since we are only
205- working with a single one-dimensional grid.
240+ The reason is that Numba compiles the loop to efficient machine code while reading and writing far less memory.
206241
207242
208243### Parallelized Numba
@@ -301,27 +336,11 @@ The compilation overhead is a one-time cost that pays off when the function is c
301336
302337### JAX plus vmap
303338
304- There is one problem with both the NumPy code and the JAX code above:
305-
306- While the flat arrays are low-memory
307-
308- ``` {code-cell} ipython3
309- grid.nbytes
310- ```
311-
312- the mesh grids are memory intensive
313-
314- ``` {code-cell} ipython3
315- x_mesh.nbytes + y_mesh.nbytes
316- ```
339+ Because we used `jax.jit` above, we avoided creating many intermediate arrays.
317340
318- This extra memory usage can be a big problem in actual research calculations .
341+ But we still create the big arrays `z_max`, `x_mesh`, and `y_mesh`.
319342
320- Fortunately, JAX admits a different approach
321- using [ jax.vmap] ( https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html ) .
322-
323- The idea of ` vmap ` is to break vectorization into stages, transforming a
324- function that operates on single values into one that operates on arrays.
343+ Fortunately, we can avoid this by using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html).
325344
326345Here's how we can apply it to our problem.
327346
@@ -330,13 +349,13 @@ Here's how we can apply it to our problem.
330349@jax.jit
331350def compute_max_vmap(grid):
332351 # Construct a function that takes the max over all x for given y
333- f_vec_x_max = lambda y: jnp.max(f(grid, y))
352+ compute_column_max = lambda y: jnp.max(f(grid, y))
334353 # Vectorize the function so we can call on all y simultaneously
335- f_vec_max = jax.vmap(f_vec_x_max )
336- # Compute the max across x at every y
337- maxes = f_vec_max (grid)
338- # Compute the max of the maxes and return
339- return jnp.max(maxes )
354+ vectorized_compute_column_max = jax.vmap(compute_column_max)
355+ # Compute the column max at every row
356+ column_maxes = vectorized_compute_column_max(grid)
357+ # Compute the max of the column maxes and return
358+ return jnp.max(column_maxes)
340359```
341360
342361Note that we never create
@@ -345,6 +364,8 @@ Note that we never create
345364* the two-dimensional grid `y_mesh` or
346365* the two-dimensional array `f(x, y)`
347366
367+ Like Numba, we just use the flat array `grid`.
368+
348369And because everything is under a single `@jax.jit`, the compiler can fuse
349370all operations into one optimized kernel.
350371
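For readers meeting `jax.vmap` for the first time, here is a toy standalone example, unrelated to the lecture's `f`, showing how it lifts a scalar function over an array axis:

```python
import jax
import jax.numpy as jnp

# A scalar function of two arguments
def g(x, y):
    return (x - y) ** 2

# Map g over the leading axis of its first argument only
g_row = jax.vmap(g, in_axes=(0, None))

xs = jnp.arange(3.0)       # [0., 1., 2.]
print(g_row(xs, 1.0))      # [1., 0., 1.]
```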
@@ -378,18 +399,14 @@ In our view, JAX is the winner for vectorized operations.
378399It dominates NumPy both in terms of speed (via JIT-compilation and
379400parallelization) and memory efficiency (via vmap).
380401
381- Moreover, the ` vmap ` approach can sometimes lead to significantly clearer code.
382-
383- While Numba is impressive, the beauty of JAX is that, with fully vectorized
384- operations, we can run exactly the same code on machines with hardware
385- accelerators and reap all the benefits without extra effort.
402+ It also dominates Numba when run on the GPU.
386403
387- Moreover, JAX already knows how to effectively parallelize many common array
388- operations, which is key to fast execution.
389-
390- For most cases encountered in economics, econometrics, and finance, it is
404+ ```{note}
405+ Numba supports GPU programming through `numba.cuda`, but then we need to
406+ parallelize by hand. For most cases encountered in economics, econometrics, and finance, it is
391407far better to hand over to the JAX compiler for efficient parallelization than to
392- try to hand code these routines ourselves.
408+ try to hand-code these routines ourselves.
409+ ```
393410
394411
395412## Sequential operations
@@ -554,8 +571,6 @@ The JAX versions, on the other hand, require either `lax.fori_loop` or
554571While JAX's `at[t].set` syntax does allow element-wise updates, the overall code
555572remains harder to read than the Numba equivalent.
556573
557- For this type of sequential operation, Numba is the clear winner in terms of
558- code clarity and ease of implementation.
559574
560575
561576## Overall recommendations
@@ -573,25 +588,17 @@ than traditional meshgrid-based vectorization.
573588In addition, JAX functions are automatically differentiable, as we explore in
574589{doc}`autodiff`.
575590
576- For ** sequential operations** , Numba has clear advantages .
591+ For **sequential operations**, Numba has nicer syntax.
577592
578593The code is natural and readable --- just a Python loop with a decorator ---
579594and performance is excellent.
580595
581596JAX can handle sequential problems via `lax.fori_loop` or `lax.scan`, but
582597the syntax is less intuitive.
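As an illustration of that syntax, here is a minimal `lax.scan` sketch computing a cumulative sum, a toy recursion rather than the lecture's example:

```python
import jax.numpy as jnp
from jax import lax

def cumulative_sum(x):
    # carry is the running total; scan threads it through the sequence
    # and collects the second return value at each step
    def step(carry, xt):
        total = carry + xt
        return total, total
    _, out = lax.scan(step, 0.0, x)
    return out

print(cumulative_sum(jnp.arange(1.0, 5.0)))   # [ 1.  3.  6. 10.]
```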
583598
584- ``` {note}
585- One important advantage of `lax.fori_loop` and `lax.scan` is that they
586- support automatic differentiation through the loop, which Numba cannot do.
587- If you need to differentiate through a sequential computation (e.g., computing
588- sensitivities of a trajectory to model parameters), JAX is the better choice
589- despite the less natural syntax.
590- ```
599+ On the other hand, the JAX versions support automatic differentiation.
591600
592- In practice, many problems involve a mix of both patterns.
601+ That might be of interest if, say, we want to compute sensitivities of a
602+ trajectory to model parameters.
593603
594- A good rule of thumb: default to JAX for new projects, especially when
595- hardware acceleration or differentiability might be useful, and reach for Numba
596- when you have a tight sequential loop that needs to be fast and readable.
597604