are preferred for jax-finufft, as this ensures the CUDA extensions are compiled
with the same Toolkit version as the CUDA runtime. However, in theory, this is not required
as long as both JAX and jax-finufft use CUDA with the same major version.

#### Notes on CUDA versions

While jax-finufft may build with a wide range of CUDA
versions, the resulting binaries may not be compatible with JAX (resulting in
odd runtime errors, like failed cuDNN or cuBLAS initialization). For the greatest
chance of success, we recommend building with the same CUDA version that JAX was built with.
To discover that, one can look at the requirements in [JAX's `build` directory](https://github.com/jax-ml/jax/tree/main/build)
(be sure to select the git tag for your version of JAX). Similarly, when installing from PyPI, we encourage using
`jax[cuda12-local]` so JAX and jax-finufft use the same CUDA libraries.
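
For instance, a PyPI install following that advice might look like this sketch (adjust the CUDA major version to match your system):

```bash
# pair JAX's local-CUDA extra with jax-finufft so both see the same CUDA libraries
pip install "jax[cuda12-local]" jax-finufft
```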
Depending on how challenging the installation is, users might want to run jax-finufft in a container. The [`.devcontainer`](./.devcontainer) directory is a good starting point for this.

There are several important CMake variables that control aspects of the jax-finufft build:

- **`CMAKE_CUDA_ARCHITECTURES`** [default `native`]: the target GPU architecture. `native` means the GPU arch of the build system.
- **`FINUFFT_ARCH_FLAGS`** [default `-march=native`]: the target CPU architecture. The default is the native CPU arch of the build system.

Each of these can be set as `-Ccmake.define.NAME=VALUE` arguments to `pip install` or `uv pip install`. For example,
to build with GPU support from the repo root, run:
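
A sketch of that command, assuming the GPU build is toggled by the `JAX_FINUFFT_USE_CUDA` CMake option:

```bash
# assumption: JAX_FINUFFT_USE_CUDA is the CMake switch for the CUDA extensions
pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON .
```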

Use multiple `-C` arguments to set multiple variables. The `-C` argument will work with any of the source installation methods (e.g. PyPI source dist, GitHub, `pip install`, `uv pip install`, `uv sync`, etc.).
Build options can also be set with the `CMAKE_ARGS` environment variable. For example:
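
As a sketch (reusing the hypothetical `JAX_FINUFFT_USE_CUDA` switch from above), this might look like:

```bash
CMAKE_ARGS="-DJAX_FINUFFT_USE_CUDA=ON" pip install .
```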

By default, jax-finufft will build for the GPU of the build machine. If you need
a different compute capability, such as 8.0 for Ampere, set `CMAKE_CUDA_ARCHITECTURES` as a CMake define:
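
For example, a sketch of an Ampere build from the repo root (compute capability 8.0 corresponds to the value `80`):

```bash
# set the GPU architecture alongside the (assumed) CUDA switch
pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON -Ccmake.define.CMAKE_CUDA_ARCHITECTURES=80 .
```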

If you're already familiar with the [Python
interface](https://finufft.readthedocs.io/en/latest/python.html) to FINUFFT,
_please note that the function signatures here are different_!

For example, here's how you can do a 1-dimensional type 1 transform:

```python
import numpy as np

from jax_finufft import nufft1

M = 100000
N = 200000

rng = np.random.default_rng(123)
x = 2 * np.pi * rng.random(M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)
```

Note that `eps` and `iflag` are optional, and that (for good reason, we
promise!) the order of the positional arguments is reversed from the `finufft`
Python package.

The syntax for a 2- or 3-dimensional transform is:

```python
f = nufft1((Nx, Ny), c, x, y) # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D
```

All of these functions support batching using `vmap`, and forward and reverse
mode differentiation.
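
As a quick illustration (a sketch reusing `N`, `c`, and `x` from the example above), both features work through the standard JAX transformations:

```python
import jax
import jax.numpy as jnp

# vmap over a stack of source strengths that share the same non-uniform points
c_stack = jnp.stack([c, 2 * c])
f_stack = jax.vmap(lambda ck: nufft1(N, ck, x))(c_stack)

# reverse-mode differentiation through the transform, via a real-valued loss
loss = lambda ck: jnp.sum(jnp.abs(nufft1(N, ck, x)) ** 2)
c_grad = jax.grad(loss)(jnp.asarray(c))
```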
### Stacked Transforms and Broadcasting
A "stacked", or "vectorized", finufft transform is one where the same non-uniform points are reused for multiple sets of source strengths. In the JAX interface, this is achieved by broadcasting. In the following example, only one finufft plan is created and one `setpts` call made, with a stack of 32 source strengths:
```python
import numpy as np

from jax_finufft import nufft1

M = 100000
N = 200000
S = 32

rng = np.random.default_rng(123)
x = 2 * np.pi * rng.random(M)
c = rng.standard_normal((S, M)) + 1j * rng.standard_normal((S, M))
f = nufft1(N, c, x)
```
To verify that a stacked transform is being used, see [Inspecting the finufft calls](#inspecting-the-finufft-calls).
Note that the broadcasting occurs because an implicit axis of length 1 is inserted in the second-to-last dimension of `x`. Currently, this is the only style of broadcasting that is supported when the strengths and points have unequal numbers of non-core dimensions. For other styles of broadcasting, insert axes of length 1 into the inputs. Any broadcast axes (even non-consecutive ones) are grouped and stacked in the transform.
Matched, but not broadcast, axes will be executed as separate transforms, each with their own `setpts` calls (but a single shared plan). In the following example (which continues from the previous), 1 plan is created and 4 `setpts` and 4 `execute` calls are made, each executing a stack of 32 transforms:
```python
P = 4

x = 2 * np.pi * rng.random((P, 1, M))
c = rng.standard_normal((P, S, M)) + 1j * rng.standard_normal((P, S, M))
f = nufft1(N, c, x)
```
## Selecting a platform
If you compiled jax-finufft with GPU support, you can force it to use a particular
backend by setting the environment variable `JAX_PLATFORMS=cpu` or `JAX_PLATFORMS=cuda`.
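
For example, to force the CPU path for a single run (a shell sketch; `your_script.py` is a placeholder):

```bash
JAX_PLATFORMS=cpu python your_script.py
```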
## Advanced usage
### Options
The tuning parameters for the library can be set using the `opts` parameter to
`nufft1`, `nufft2`, and `nufft3`. For example, to explicitly set the CPU [up-sampling
factor](https://finufft.readthedocs.io/en/latest/opts.html) that FINUFFT should
use:
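
A minimal sketch of what this can look like, assuming the options live in `jax_finufft.options` and the field is named `upsampfac`:

```python
from jax_finufft import nufft1, options

# assumption: options.Opts exposes the CPU up-sampling factor as `upsampfac`
opts = options.Opts(upsampfac=2.0)
f = nufft1(N, c, x, opts=opts)
```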

The corresponding option for the GPU is `gpu_upsampfac`. In fact, all options
for the GPU are prefixed with `gpu_`, with the exception of `modeord`.
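
As a sketch, the GPU analogue of the setting above would then be:

```python
# assumption: the same Opts container accepts the gpu_-prefixed fields
opts = options.Opts(gpu_upsampfac=2.0)
```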
One complication here is that the [vector-Jacobian

Evidently, we are creating a single plan with 32 transforms, and finufft has chosen to
batch them into two sets of 16. `setpts` is only called once, as is `execute`, as we
would expect for a stacked transform.
## Notes on the Implementation of the Gradients
The NUFFT gradients are implemented as [Jacobian-vector products](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html#jacobian-vector-products-jvps-aka-forward-mode-autodiff) (JVP, i.e. forward-mode autodiff), with associated transpose rules that implement the vector-Jacobian product (VJP, reverse mode). These are found in [`ops.py`](./src/jax_finufft/ops.py), in the `jvp` and `transpose` functions.
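
As a hedged, user-level illustration, these rules are exercised through the standard JAX entry points (the tangent and cotangent values here are arbitrary):

```python
import jax
import numpy as np

from jax_finufft import nufft1

# fresh 1D setup with the same shapes as the first example
rng = np.random.default_rng(0)
M, N = 100000, 200000
x = 2 * np.pi * rng.random(M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
dx = rng.standard_normal(M)  # arbitrary tangent for the non-uniform points

# forward mode: primal output plus its JVP with respect to x
f, df = jax.jvp(lambda pts: nufft1(N, c, pts), (x,), (dx,))

# reverse mode: the transpose rule supplies the VJP
f, vjp_fun = jax.vjp(lambda pts: nufft1(N, c, pts), x)
(x_bar,) = vjp_fun(f)  # the output itself serves as an arbitrary cotangent
```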
The JVP of a D-dimensional type 1 or 2 NUFFT requires D transforms of the same type in D dimensions (considering just the gradients with respect to the non-uniform locations). Each transform is weighted by the frequencies (as an overall scaling for type 1, and at the Fourier strength level for type 2). These transforms are fully stacked, and finufft plans are reused where possible.
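
For intuition, in 1D with the type 1 convention (sign depending on `iflag`),

$$f_k = \sum_j c_j e^{i k x_j}, \qquad \dot f_k = i k \sum_j c_j \dot x_j \, e^{i k x_j},$$

so the JVP with respect to the points is itself a type 1 transform (of the strengths $c_j \dot x_j$), scaled by the frequency $k$.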
Furthermore, the JAX `jvp` evaluates the function in addition to its JVP, so 1 more transform is necessary. This transform is not stacked with the JVP transforms. Likewise, 1 more is needed when the gradient with respect to the source or Fourier strengths is requested. However, this transform is stacked with the JVP.
In reverse mode, the VJP of a type 1 NUFFT requires type 2 transforms, and type 2 requires type 1. In either case, the function evaluation returned by JAX's `vjp` still requires a NUFFT of the original type (which cannot be stacked with the VJP transforms, as they are of a different type).
For type 3, the JVP requires `2*D` type 3 transforms of dimension D to evaluate the gradients with respect to both the source and target locations. The strengths of each transform are weighted by the source or target locations. The source and target transforms are stacked separately. As with types 1 and 2, the strengths-gradient transform is stacked with the source-location transforms, and the function-evaluation transform is not stacked.
The VJP of a type 3 NUFFT also uses type 3 NUFFTs, but with the source and target points swapped.
In all of the above, whenever a user requests [stacked transforms via broadcasting](#stacked-transforms-and-broadcasting), this does not introduce new plans or finufft calls; the stacks simply get deeper. New sets of non-uniform points necessarily introduce new `setpts` calls and new executions, but not new plans.
To see all of the stacking behavior in action, take a look at [Inspecting the finufft calls](#inspecting-the-finufft-calls).
## Similar libraries
- [finufft](https://finufft.readthedocs.io/en/latest/python.html): The
  "official" Python bindings to FINUFFT. A good choice if you're not already
  using JAX and if you don't need to differentiate through your transform.
- A list of other finufft binding libraries (e.g. for Julia, TensorFlow, PyTorch) is maintained at https://finufft.readthedocs.io/en/latest/users.html#other-wrappers-to-cu-finufft