Commit f940a63 ("misc"), parent a897eb8

1 file changed: lectures/numba.md (+107, −92 lines)

@@ -42,35 +42,25 @@ import matplotlib.pyplot as plt
```

## Overview

In an {doc}`earlier lecture <need_for_speed>` we discussed vectorization,
which can improve execution speed by sending array processing operations in batch to efficient low-level code.

However, as {ref}`discussed in that lecture <numba-p_c_vectorization>`,
traditional vectorization schemes have weaknesses:

* Highly memory-intensive for compound array operations
* Ineffective or impossible for some algorithms
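
To make the first weakness concrete, here is a small NumPy-only sketch (the expression and array size are our own illustration, not from the lecture): a compound vectorized expression allocates a temporary array for each intermediate result, while an equivalent element-by-element loop needs no intermediates.

```python
import numpy as np

x = np.linspace(0, 1, 100_000)

# Compound vectorized expression: NumPy allocates temporary arrays
# for x**2, 2 * x**2, 3 * x and their sum before np.cos is applied
y_vec = np.cos(2 * x**2 + 3 * x)

# Equivalent fused loop: touches each element once, with no temporaries.
# In pure Python this loop is slow, but it is exactly the kind of loop
# a JIT compiler can turn into fast machine code.
y_loop = np.empty_like(x)
for i in range(x.shape[0]):
    y_loop[i] = np.cos(2 * x[i]**2 + 3 * x[i])

assert np.allclose(y_vec, y_loop)
```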

One way to circumvent these problems is by using [Numba](https://numba.pydata.org/), a
**just in time (JIT) compiler** for Python.

Numba compiles functions to native machine code instructions at runtime.

When it succeeds, the result is performance comparable to compiled C or Fortran.

In addition, Numba can do useful tricks such as {ref}`multithreading`.

This lecture introduces the core ideas.

@@ -80,6 +70,16 @@

```{index} single: Python; Numba
```

```{note}
Some readers might be curious about the relationship between Numba and [Julia](https://julialang.org/),
which contains its own JIT compiler. While the two compilers are similar in
many ways, Numba is less ambitious, attempting only to compile a small subset of
the Python language. Although this might sound like a deficiency, it is also a
strength: the more restrictive nature of Numba makes it easy to use well and
good at what it does.
```

(quad_map_eg)=
### An Example

@@ -93,16 +93,14 @@
$$
x_{t+1} = \alpha x_t (1 - x_t)
$$

In what follows we set $\alpha = 4$.

#### Base Version

Here's the plot of a typical trajectory, starting from $x_0 = 0.1$, with $t$ on the x-axis

```{code-cell} ipython3
def qm(x0, n, α=4.0):
    x = np.empty(n+1)
    x[0] = x0
    for t in range(n):

@@ -117,103 +115,119 @@ ax.set_ylabel('$x_{t}$', fontsize = 12)
plt.show()
```

Let's see how long this takes to run for large $n$

```{code-cell} ipython3
n = 10_000_000

with qe.Timer() as timer1:
    # Time Python base version
    x = qm(0.1, int(n))
```

#### Acceleration via Numba

To speed the function `qm` up using Numba, we first import the `jit` function

```{code-cell} ipython3
from numba import jit
```

Now we apply it to `qm`, producing a new function:

```{code-cell} ipython3
qm_numba = jit(qm)
```

The function `qm_numba` is a version of `qm` that is "targeted" for
JIT-compilation.

We will explain what this means momentarily.

Let's time this new version:

```{code-cell} ipython3
with qe.Timer() as timer2:
    # Time jitted version
    x = qm_numba(0.1, int(n))
```

This is a large speed gain.

In fact, the next time and all subsequent times it runs even faster as the
function has been compiled and is in memory:

(qm_numba_result)=

```{code-cell} ipython3
with qe.Timer() as timer3:
    # Second run
    x = qm_numba(0.1, int(n))
```

Here's the speed gain

```{code-cell} ipython3
timer1.elapsed / timer3.elapsed
```

This is a big boost for a small modification to our original code.

Let's discuss how this works.

### How and When it Works

Numba attempts to generate fast machine code using the infrastructure provided
by the [LLVM Project](https://llvm.org/).

It does this by inferring type information on the fly.

(See our {doc}`earlier lecture <need_for_speed>` on scientific computing for a discussion of types.)

The basic idea is this:

* Python is very flexible and hence we could call the function `qm` with many types.
* e.g., `x0` could be a NumPy array or a list, `n` could be an integer or a float, etc.
* This makes it very difficult to generate efficient machine code *ahead of time* (i.e., before runtime).
* However, when we do actually *call* the function, say by running `qm(0.5, 10)`,
  the types of `x0`, `α` and `n` are determined.
* Moreover, the types of *other variables* in `qm` *can be inferred once the input types are known*.
* So the strategy of Numba and other JIT compilers is to *wait until the function is called*, and then compile.

That is called "just-in-time" compilation.

Note that, if you make the call `qm_numba(0.5, 10)` and then follow it with `qm_numba(0.9, 20)`, compilation only takes place on the first call.

This is because compiled code is cached and reused as required.

This is why, in the code above, the second run of `qm_numba` is faster.

```{admonition} Remark
In practice, rather than writing `qm_numba = jit(qm)`, we typically use
*decorator* syntax and put `@jit` before the function definition. This is
equivalent to adding `qm = jit(qm)` after the definition.
```

## Sharp Bits

Numba is relatively easy to use but not always seamless.

Let's review some of the issues users run into.

### Typing

Successful type inference is the key to JIT compilation.

In an ideal setting, Numba can infer all necessary type information.

When Numba *cannot* infer all type information, it will raise an error.

For example, in the setting below, Numba is unable to determine the type of the
function `g` when compiling `iterate`

```{code-cell} ipython3
@jit

@@ -234,7 +248,7 @@
    print(e)
```

In the present case, we can fix this easily by compiling `g`.

```{code-cell} ipython3
@jit

@@ -244,28 +258,16 @@ def g(x):
iterate(g, 0.5, 100)
```

In other cases, such as when we want to use functions from external libraries
such as `SciPy`, there might not be any easy workaround.

### Global Variables

Another thing to be careful about when using Numba is handling of global
variables.

For example, consider the following code

```{code-cell} ipython3
a = 1

@@ -284,9 +286,10 @@ print(add_a(10))
```

Notice that changing the global had no effect on the value returned by the
function 😱.

When Numba compiles machine code for functions, it treats global variables as
constants to ensure type stability.

To avoid this, pass values as function arguments rather than relying on globals.

@@ -320,15 +323,11 @@ Here's the code:

```{code-cell} ipython3
@jit
def update(w, r=0.1, s=0.3, v1=0.1, v2=1.0):
    " Updates household wealth. "
    # Draw shocks
    R = np.exp(v1 * np.random.randn()) * (1 + r)
    y = np.exp(v2 * np.random.randn())
    # Update wealth
    w = R * s * w + y
    return w

@@ -343,7 +342,7 @@ T = 100
w = np.empty(T)
w[0] = 5
for t in range(T-1):
    w[t+1] = update(w[t])

ax.plot(w)
ax.set_xlabel('$t$', fontsize=12)

@@ -365,21 +364,30 @@ Here's the code:

```{code-cell} ipython3
@jit
def compute_long_run_median(w0=1, T=1000, num_reps=50_000):
    obs = np.empty(num_reps)
    # For each household
    for i in range(num_reps):
        # Set the initial condition and run forward in time
        w = w0
        for t in range(T):
            w = update(w)
        # Record the final value
        obs[i] = w
    # Take the median of all final values
    return np.median(obs)
```

Let's see how fast this runs:

```{code-cell} ipython3
with qe.Timer():
    # Warm up
    compute_long_run_median()
```

```{code-cell} ipython3
with qe.Timer():
    # Second run
    compute_long_run_median()
```

@@ -391,22 +399,29 @@ To do so, we add the `parallel=True` flag and change `range` to `prange`:
from numba import prange

@jit(parallel=True)
def compute_long_run_median_parallel(
    w0=1, T=1000, num_reps=50_000
):
    obs = np.empty(num_reps)
    for i in prange(num_reps):  # Parallelize over households
        w = w0
        for t in range(T):
            w = update(w)
        obs[i] = w
    return np.median(obs)
```

Let's look at the timing:

```{code-cell} ipython3
with qe.Timer():
    # Warm up
    compute_long_run_median_parallel()
```

```{code-cell} ipython3
with qe.Timer():
    # Second run
    compute_long_run_median_parallel()
```