This notebook implements a simple 2D flow matching model using JAX, Equinox and Diffrax. It trains a neural network to learn the velocity field that transports samples from a standard Gaussian distribution to a target checkerboard distribution along an affine probability path.
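As a rough illustration of the objective described above, the sketch below assumes the common linear schedule x_t = t·x_1 + (1 - t)·x_0 (an affine path with α_t = t and σ_t = 1 - t). Its conditional target velocity is x_1 - x_0, and the network v(x_t, t) is regressed onto it with a mean squared error. The helper names (`sample_checkerboard`, `affine_path`, `flow_matching_loss`) and the 4x4 checkerboard on [-2, 2)^2 are hypothetical and not taken from the notebook.

```python
import jax
import jax.numpy as jnp


def sample_checkerboard(key, n):
    """Hypothetical helper: draw n points from a 4x4 checkerboard on [-2, 2)^2."""
    kx, ku, ks = jax.random.split(key, 3)
    x = jax.random.uniform(kx, (n,), minval=-2.0, maxval=2.0)
    # Start in [0, 1) or [-2, -1), chosen at random for each point
    y = jax.random.uniform(ku, (n,)) - 2.0 * jax.random.randint(ks, (n,), 0, 2)
    # Shift every other column up by 1 so occupied squares alternate
    y = y + jnp.floor(x) % 2
    return jnp.stack([x, y], axis=-1)


def affine_path(x0, x1, t):
    """Assumed linear affine path: alpha_t = t, sigma_t = 1 - t."""
    t = t[:, None]  # broadcast one time per sample over the 2 coordinates
    xt = t * x1 + (1.0 - t) * x0
    ut = x1 - x0  # time derivative of xt, i.e. the conditional target velocity
    return xt, ut


def flow_matching_loss(velocity_fn, key, batch_size):
    """Conditional flow matching loss: regress velocity_fn(x_t, t) onto x1 - x0."""
    k0, k1, kt = jax.random.split(key, 3)
    x0 = jax.random.normal(k0, (batch_size, 2))  # source: standard Gaussian
    x1 = sample_checkerboard(k1, batch_size)  # target: checkerboard
    t = jax.random.uniform(kt, (batch_size,))  # t ~ U[0, 1)
    xt, ut = affine_path(x0, x1, t)
    vt = velocity_fn(xt, t)
    return jnp.mean(jnp.sum((vt - ut) ** 2, axis=-1))
```

For example, `flow_matching_loss(lambda x, t: jnp.zeros_like(x), jax.random.PRNGKey(0), 256)` evaluates the loss of a trivial zero-velocity model; in training, `velocity_fn` would be the neural network.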
This notebook requires the following JAX ecosystem packages:
- JAX: Core library for high-performance numerical computing
- Equinox: Neural network library built on JAX
- Diffrax: Differential equation solver for JAX
- jaxtyping: Type annotations for JAX arrays
- Optax: Gradient processing and optimization library
Additional standard packages: NumPy and Matplotlib.
This implementation can be run using Docker or Apptainer containers for reproducible results.
Docker (VSCode Dev Containers):
- Install VSCode Dev Containers Extension: First, install the Dev Containers extension in VSCode.
- Open the Repository in the Dev Container: Click the "Reopen in Container" button in the pop-up that appears once you open the repository in VSCode. Alternatively, open the command palette in VSCode by pressing Ctrl+Shift+P (Windows/Linux) or Cmd+Shift+P (Mac) and type "Dev Containers: Reopen in Container".
Apptainer (VSCode Remote Tunnels):
- Install VSCode Remote Tunnels Extension: First, install the Remote Tunnels extension in VSCode.
- Launch the container: To open a tunnel that connects your local VSCode to the container on the cluster, run:
apptainer run --nv --writable-tmpfs oras://ghcr.io/marvinsxtr/jax-flow-matching:latest-sif code tunnel
In VSCode, press Ctrl+Shift+P (Windows/Linux) or Cmd+Shift+P (Mac), type "connect to tunnel", select GitHub and then select the tunnel name of your node on the cluster. Your IDE is now connected to the cluster.
This implementation is based on concepts from the following sources:
