
Commit 5dc6ff7 (1 parent: 199ba1e)

updating readme for v1.2 (#188)

* readme: refresh installation instructions, etc
* readme: add notes on gradient implementation, stacked transforms

1 file changed: README.md (155 additions, 62 deletions)
[![GitHub Tests](https://github.com/flatironinstitute/jax-finufft/actions/workflows/tests.yml/badge.svg)](https://github.com/flatironinstitute/jax-finufft/actions/workflows/tests.yml)
[![Jenkins Tests](https://jenkins.flatironinstitute.org/buildStatus/icon?job=jax-finufft%2Fmain&subject=Jenkins%20Tests)](https://jenkins.flatironinstitute.org/job/jax-finufft/job/main/)

This package provides a [JAX](https://github.com/google/jax) interface to the
[Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT)
library](https://github.com/flatironinstitute/finufft). Take a look at the
[FINUFFT docs](https://finufft.readthedocs.io) for all the necessary
definitions, conventions, and more information about the algorithms and their […]
[…] are supported in 1, 2, and 3 dimensions on the CPU and GPU.

All of these functions support forward, reverse, and higher-order differentiation,
as well as batching using `vmap`.

The [FINUFFT plan interface](https://finufft.readthedocs.io/en/latest/c.html#guru-plan-interface)
is not directly exposed, although within a given jax-finufft call, plans are reused where possible,
and transforms sharing the same non-uniform points are stacked/vectorized. All of the tuning options
one can set in the plan interface are available through the `opts` argument of the jax-finufft API
(see [Advanced Usage](#advanced-usage)).

## Installation

The easiest way to install jax-finufft is from a pre-compiled binary on
PyPI or conda-forge. Only CPU binaries are currently available on PyPI, while
conda-forge has both CPU and GPU binaries. If you want GPU support without using
conda, you can install jax-finufft from source as detailed below. This is also
useful when you want to build finufft optimized for your hardware.

Currently only `jax<0.8` is supported.

### Install binary from PyPI

> [!NOTE]
> Only the CPU-enabled build of jax-finufft is available as a binary wheel on
> PyPI. For a GPU-enabled build, you'll need to build from source as described
> below or use conda-forge.

To install a binary wheel from [PyPI](https://pypi.org/project/jax-finufft/)
using [uv](https://docs.astral.sh/uv/), run the following command in a venv:

```bash
uv pip install jax-finufft
```

To install with `pip` instead of `uv`, simply drop `uv` from that command.

### Install binary from conda-forge

To install a CPU build using [mamba](https://github.com/mamba-org/mamba) (or
[conda](https://docs.conda.io)), run:

```bash
mamba install -c conda-forge jax-finufft
```

To install a GPU-enabled build, run:

```bash
mamba install -c conda-forge 'jax-finufft=*=cuda*'
```

Make note of the installed package version, like `conda-forge/linux-64::jax-finufft-1.1.0-cuda129py312h8ad7275_1`.
The `cuda129` substring indicates the package was built for CUDA 12.9. Your
NVIDIA driver will need to support this version of CUDA. Only one CUDA
build per major CUDA version is provided at present.

### Install from source

#### Dependencies
[…]

```bash
# […]
mamba activate jax-finufft
```

<details>
<summary>Install GPU dependencies with mamba or conda</summary>

```bash
mamba create -n gpu-jax-finufft -c conda-forge python fftw cxx-compiler jax 'jaxlib=*=*cuda*'
mamba activate gpu-jax-finufft
mamba install cuda libcufft-static -c nvidia
export CMAKE_PREFIX_PATH=$CONDA_PREFIX:$CMAKE_PREFIX_PATH
```
</details>

<details>
<summary>Install GPU dependencies using Flatiron module system</summary>

```bash
ml modules/2.4 \
    gcc \
    python \
    uv \
    fftw \
    cuda/12.8 \
    cudnn/9

export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=80;90;120 -DJAX_FINUFFT_USE_CUDA=ON"
```
</details>

Other ways of installing JAX are given on the JAX website; the
["local CUDA" install methods](https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-locally-harder)
are preferred for jax-finufft as this ensures the CUDA extensions are compiled
with the same Toolkit version as the CUDA runtime. However, in theory, this is not required
as long as both JAX and jax-finufft use CUDA with the same major version.

#### Notes on CUDA versions

While jax-finufft may build with a wide range of CUDA
versions, the resulting binaries may not be compatible with JAX (resulting in
odd runtime errors, like failed cuDNN or cuBLAS initialization). For the greatest
chance of success, we recommend building with the same version as JAX was built with.
To discover that, one can look at the requirements in [JAX's `build` directory](https://github.com/jax-ml/jax/tree/main/build)
(be sure to select the git tag for your version of JAX). Similarly, when installing from PyPI, we encourage using
`jax[cuda12-local]` so JAX and jax-finufft use the same CUDA libraries.
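
If you're unsure which CUDA version your installed JAX actually sees, JAX can report it at runtime. A minimal sketch, using JAX's built-in `jax.print_environment_info` helper (available in recent JAX releases):

```python
# Sketch: print the versions JAX sees at runtime, to compare against the CUDA
# Toolkit you plan to build jax-finufft with. Assumes a recent JAX release.
import jax

jax.print_environment_info()  # jax/jaxlib versions plus GPU/CUDA details, if any
print(jax.devices())          # e.g. [CudaDevice(id=0)] once the GPU backend loads
```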

Depending on how challenging the installation is, users might want to run jax-finufft in a container. The [`.devcontainer`](./.devcontainer) directory is a good starting point for this.

[…]

There are several important CMake variables that control aspects of the jax-finufft […]

- **`CMAKE_CUDA_ARCHITECTURES`** [default `native`]: the target GPU architecture. `native` means the GPU arch of the build system.
- **`FINUFFT_ARCH_FLAGS`** [default `-march=native`]: the target CPU architecture. The default is the native CPU arch of the build system.

Each of these can be set as `-Ccmake.define.NAME=VALUE` arguments to `pip install` or `uv pip install`. For example,
to build with GPU support from the repo root, run:

```bash
uv pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON .
```

Use multiple `-C` arguments to set multiple variables. The `-C` argument will work with any of the source installation methods (e.g. PyPI source dist, GitHub, `pip install`, `uv pip install`, `uv sync`, etc.).

Build options can also be set with the `CMAKE_ARGS` environment variable. For example: […]

By default, jax-finufft will build for the GPU of the build machine. If you need
a different compute capability, such as 8.0 for Ampere, set `CMAKE_CUDA_ARCHITECTURES` as a CMake define:

```bash
uv pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON -Ccmake.define.CMAKE_CUDA_ARCHITECTURES=80 .
```

`CMAKE_CUDA_ARCHITECTURES` also takes a semicolon-separated list.

[…] The values are also listed on the [NVIDIA website](https://developer.nvidia.com/ […]

In some cases, you may also need the following at runtime:

```bash
export LD_LIBRARY_PATH="$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
```

If `CUDA_HOME` isn't set, you'll need to replace it with the path to your CUDA
installation in the above line, often something like `/usr/local/cuda`.

#### Install source from PyPI
The source code for all released versions of jax-finufft is available on PyPI,
and this can be installed using:

```bash
uv pip install jax-finufft --no-binary jax-finufft
```

#### Install source from GitHub

[…]

```bash
# […]
cd jax-finufft
```

> […]
> you can run `git submodule update --init --recursive` in your local copy to
> checkout the submodule after the initial clone.

After cloning the repository, you can install the local copy using the uv ["project interface"](https://docs.astral.sh/uv/guides/projects/):

```bash
uv sync
```

or using the pip interface:

```bash
uv pip install -e .
```

where the optional `-e` flag runs an "editable" install.

As yet another alternative, the latest development version from GitHub can be
installed directly (i.e. without cloning first) with

```bash
uv pip install git+https://github.com/flatironinstitute/jax-finufft.git
```

## Usage
[…] transforms). If you're already familiar with the [Python
interface](https://finufft.readthedocs.io/en/latest/python.html) to FINUFFT,
_please note that the function signatures here are different_!

For example, here's how you can do a 1-dimensional type 1 transform:

```python
import numpy as np

from jax_finufft import nufft1

M = 100000
N = 200000

rng = np.random.default_rng(123)
x = 2 * np.pi * rng.random(M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)
```

Note that `eps` and `iflag` are optional, and that (for good reason, we
promise!) the order of the positional arguments is reversed from the `finufft`
Python package.

The syntax for a 2- or 3-dimensional transform is:

```python
f = nufft1((Nx, Ny), c, x, y) # 2D
# […]
```
[…]

```python
# […]
f = nufft3(c, x, y, z, s, t, u) # 3D
```

All of these functions support batching using `vmap`, and forward and reverse
mode differentiation.
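
As a quick illustration, here's a minimal sketch (following the 1D example above, with smaller illustrative sizes) of batching with `jax.vmap` and taking a reverse-mode gradient with `jax.grad`:

```python
# Sketch: batch a 1D type 1 transform over strengths with vmap, then
# differentiate a real-valued loss with respect to the non-uniform points.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

from jax_finufft import nufft1

jax.config.update("jax_enable_x64", True)  # the arrays below are float64/complex128

M, N, S = 1000, 2000, 8
rng = np.random.default_rng(0)
x = 2 * np.pi * rng.random(M)
c = rng.standard_normal((S, M)) + 1j * rng.standard_normal((S, M))

# Map over the leading axis of c while reusing the same points x
f = jax.vmap(partial(nufft1, N), in_axes=(0, None))(c, x)  # shape (S, N)

# Reverse-mode gradient of a scalar loss with respect to the points
loss = lambda x: jnp.sum(jnp.abs(nufft1(N, c[0], x)) ** 2)
g = jax.grad(loss)(x)
```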

298+
### Stacked Transforms and Broadcasting
299+
300+
A "stacked", or "vectorized", finufft transform is one where the same non-uniform points are reused for multiple sets of source strengths. In the JAX interface, this is achieved by broadcasting. In the following example, only one finufft plan is created and one `setpts` call made, with a stack of 32 source strengths:
301+
302+
```python
303+
import numpy as np
304+
305+
from jax_finufft import nufft1
306+
307+
M = 100000
308+
N = 200000
309+
S = 32
310+
311+
rng = np.random.default_rng(123)
312+
x = 2 * np.pi * rng.random(M)
313+
c = rng.standard_normal((S, M)) + 1j * rng.standard_normal((S, M))
314+
f = nufft1(N, c, x)
315+
```
316+
317+
To verify that a stacked transform is being used, see [Inspecting the finufft calls](#inspecting-the-finufft-calls).
318+
319+
Note that the broadcasting occurs because an implicit axis of length 1 is inserted in the second-to-last dimension of `x`. Currently, this is the only style of broadcasting that is supported when the strengths and points have unequal numbers of non-core dimensions. For other styles of broadcasting, insert axes of length 1 into the inputs. Any broadcast axes (even non-consecutive ones) are grouped and stacked in the transform.
320+
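
For instance, continuing the example above, making that implicit axis explicit should give the same result; a small sketch (the `f_explicit` name is just for illustration):

```python
# Sketch: write the implicit length-1 axis out explicitly.
# x[None, :] has shape (1, M) and broadcasts against c of shape (S, M),
# matching the implicit insertion described above.
f_explicit = nufft1(N, c, x[None, :])
np.testing.assert_allclose(f_explicit, f)
```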

Matched, but not broadcast, axes will be executed as separate transforms, each with its own `setpts` call (but a single shared plan). In the following example (which continues from the previous), 1 plan is created and 4 `setpts` and 4 `execute` calls are made, each executing a stack of 32 transforms:

```python
P = 4

x = 2 * np.pi * rng.random((P, 1, M))
c = rng.standard_normal((P, S, M)) + 1j * rng.standard_normal((P, S, M))
f = nufft1(N, c, x)
```

## Selecting a platform

If you compiled jax-finufft with GPU support, you can force it to use a particular
backend by setting the environment variable `JAX_PLATFORMS=cpu` or `JAX_PLATFORMS=cuda`.
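
The same selection can be made from Python before JAX initializes its backend; a sketch, assuming your JAX version exposes the `jax_platforms` configuration flag that mirrors this environment variable:

```python
# Sketch: equivalent to setting JAX_PLATFORMS in the environment.
# This must run before any other JAX call initializes the backend.
import jax

jax.config.update("jax_platforms", "cpu")  # or "cuda"
```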

## Advanced usage

### Options

The tuning parameters for the library can be set using the `opts` parameter to
`nufft1`, `nufft2`, and `nufft3`. For example, to explicitly set the CPU [up-sampling
factor](https://finufft.readthedocs.io/en/latest/opts.html) that FINUFFT should […]
```python
# […]
nufft1(N, c, x, opts=opts)
```

The corresponding option for the GPU is `gpu_upsampfac`. In fact, all options
for the GPU are prefixed with `gpu_`, with the exception of `modeord`.
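
For example, a sketch of the GPU analogue of the CPU up-sampling example above (assuming the `options.Opts` constructor accepts GPU options by these names, as in the debug example below):

```python
# Sketch: the GPU analogue of the CPU up-sampling example above.
# Assumes options.Opts accepts the gpu_-prefixed option names.
from jax_finufft import nufft1, options

opts = options.Opts(gpu_upsampfac=2.0)
f = nufft1(N, c, x, opts=opts)
```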

One complication here is that the [vector-Jacobian
product](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff) […]
```python
opts = options.NestedOpts(
    # […]
)
```

For descriptions of the options, see these pages in the FINUFFT docs:

- CPU: https://finufft.readthedocs.io/en/latest/opts.html
- GPU: https://finufft.readthedocs.io/en/latest/c_gpu.html#options-for-gpu-code

### Inspecting the finufft calls

When evaluating a single NUFFT, it's fairly obvious that jax-finufft will execute one
finufft transform under the hood. However, when evaluating a stacked NUFFT, or taking
the gradients of a NUFFT, the sequence of calls may be less obvious. One way to inspect
exactly which finufft calls are being made is to enable finufft's debug output by
passing `opts=Opts(debug=True)` or `opts=Opts(gpu_debug=True)`.

For example, taking the [Stacked Transforms](#stacked-transforms-and-broadcasting) example and enabling
debug output, we see the following:

```python-repl
>>> f = nufft1(N, c, x, eps=1e-6, iflag=1, opts=Opts(debug=True))
[FINUFFT_PLAN_T] new plan: FINUFFT version 2.4.1 .................
[FINUFFT_PLAN_T] 1d1: (ms,mt,mu)=(200000,1,1) (nf1,nf2,nf3)=(400000,1,1)
                 ntrans=32 nthr=16 batchSize=16 spread_thread=2
[FINUFFT_PLAN_T] kernel fser (ns=7): 0.000765 s
[FINUFFT_PLAN_T] fwBatch 0.05GB alloc: 0.00703 s
[FINUFFT_PLAN_T] FFT plan (mode 64, nthr=16): 0.00892 s
[setpts] sort (didSort=1): 0.00327 s
[execute] start ntrans=32 (2 batches, bsize=16)...
[execute] done. tot spread: 0.0236 s
                tot FFT: 0.0164 s
                tot deconvolve: 0.00191 s
```

Evidently, we are creating a single plan with 32 transforms, and finufft has chosen to
batch them into two sets of 16. `setpts` is only called once, as is `execute`, as we
would expect for a stacked transform.

## Notes on the Implementation of the Gradients

The NUFFT gradients are implemented as [Jacobian-vector products](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html#jacobian-vector-products-jvps-aka-forward-mode-autodiff) (JVP, i.e. forward-mode autodiff), with associated transpose rules that implement the vector-Jacobian product (VJP, reverse mode). These are found in [`ops.py`](./src/jax_finufft/ops.py), in the `jvp` and `transpose` functions.

The JVP of a D-dimensional type 1 or 2 NUFFT requires D transforms of the same type in D dimensions (considering just the gradients with respect to the non-uniform locations). Each transform is weighted by the frequencies (as an overall scaling for type 1, and at the Fourier strength level for type 2). These transforms are fully stacked, and finufft plans are reused where possible.

Furthermore, the JAX `jvp` evaluates the function in addition to its JVP, so 1 more transform is necessary. This transform is not stacked with the JVP transforms. Likewise, 1 more is needed when the gradient with respect to the source or Fourier strengths is requested. However, this transform is stacked with the JVP.

In reverse mode, the VJP of a type 1 NUFFT requires type 2 transforms, and type 2 requires type 1. In either case, the function evaluation returned under JAX's `vjp` still requires an NUFFT of the original type (which cannot be stacked with the VJP transforms, as they are of a different type).
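
Concretely, both modes are driven through the standard JAX entry points; a minimal sketch for a 1D type 1 transform (shapes and seeds are illustrative):

```python
# Sketch: forward (jvp) and reverse (vjp) differentiation through a 1D
# type 1 transform, with respect to the non-uniform points x.
import jax
import numpy as np

from jax_finufft import nufft1

jax.config.update("jax_enable_x64", True)

M, N = 1000, 2000
rng = np.random.default_rng(0)
x = 2 * np.pi * rng.random(M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
dx = rng.standard_normal(M)  # tangent vector for the points

# Forward mode: evaluates the function and its JVP together
f, df = jax.jvp(lambda x: nufft1(N, c, x), (x,), (dx,))

# Reverse mode: as described above, the VJP of a type 1 transform is
# built from type 2 transforms under the hood
f, vjp_fun = jax.vjp(lambda x: nufft1(N, c, x), x)
(x_bar,) = vjp_fun(np.ones(N, dtype=np.complex128))  # cotangent on the output
```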

For type 3, the JVP requires `2*D` type 3 transforms of dimension D to evaluate the gradients with respect to both the source and target locations. The strengths of each transform are weighted by the source or target locations. The source and target transforms are stacked separately. As with types 1 and 2, the strengths gradient transform is stacked with the source-location transforms, and the function evaluation transform is not stacked.

The VJP of a type 3 NUFFT also uses type 3 NUFFTs, but with the source and target points swapped.

In all of the above, whenever a user requests [stacked transforms via broadcasting](#stacked-transforms-and-broadcasting), this does not introduce new plans or finufft calls; the stacks simply get deeper. New sets of non-uniform points necessarily introduce new `setpts` calls and new executions, but not new plans.

To see all of the stacking behavior in action, take a look at [Inspecting the finufft calls](#inspecting-the-finufft-calls).

## Similar libraries

- [finufft](https://finufft.readthedocs.io/en/latest/python.html): The
  "official" Python bindings to FINUFFT. A good choice if you're not already
  using JAX and if you don't need to differentiate through your transform.
- A list of other finufft binding libraries (e.g. for Julia, TensorFlow, PyTorch) is maintained at https://finufft.readthedocs.io/en/latest/users.html#other-wrappers-to-cu-finufft

## License & attribution

[…]
