Before installing the library, please make sure that you have installed JAX for your hardware.
pip install mixed-precision-for-JAX

For basic usage, this README should give you everything you need to know. For deeper insights, you can read the documentation (https://data-science-in-mechanical-engineering.github.io/mixed_precision_for_JAX/) and our paper (https://www.arxiv.org/pdf/2507.03312).
This repository offers a tool for training JAX models using mixed precision, called mpx. It builds upon JMP, another mixed precision library for JAX, but extends its capabilities. JMP does not support arbitrary PyTrees and is particularly incompatible with models developed using Equinox. mpx overcomes these limitations by leveraging Equinox's flexibility to work with any PyTree.
This section summarizes the original Mixed Precision method from https://developer.nvidia.com/automatic-mixed-precision and https://arxiv.org/pdf/1710.03740. Mixed Precision training involves performing most of the computations in the forward and backward passes of a neural network using 16-bit floating-point numbers. This approach reduces GPU memory usage by roughly half compared to full precision training, allowing for larger batch sizes or the use of fewer TPUs/GPUs. Additionally, mixed precision can speed up training by decreasing memory access times and utilizing specialized half-precision tensor cores on modern hardware (if available).
A key factor in successfully applying mixed precision training is loss scaling. Due to the reduced resolution of float16, small gradients are flushed to zero, which hurts training performance. Loss scaling multiplies the loss by a factor > 1 and thereby also the gradients computed during backpropagation. Afterwards, the gradients are cast to float32 and divided by the factor to obtain the original gradients. A standard optimizer then uses these gradients to calculate the model update. The scaling factor can be chosen automatically with a simple heuristic: if the scaled gradients exceed the range of float16 (i.e., they are inf), we reduce the scaling and do not update the model; if the scaled gradients have not exceeded the range of float16 for a longer period, we increase the scaling.
Mixed Precision Training hence has the following steps:
- Initialize the model and optimizer in full precision.
- Get a batch from the dataloader.
- Cast the batch and model to half precision (e.g., float16 or bfloat16).
- Do the forward pass in half precision, except for critical operations.
- Scale the loss.
- Calculate the gradient of the scaled loss with respect to the weights.
- Cast the gradients to float32 and divide them by the scaling value.
- If the gradients are infinite, decrease the scaling; otherwise, increase the scaling every n-th step.
- If the gradients are finite, do the optimizer update and continue with step 2.
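As a quick numeric illustration of the underflow problem and its fix (plain jax.numpy, independent of mpx; the factor 2**15 is just an example value):

import jax.numpy as jnp

# A tiny gradient value is flushed to zero when cast to float16 ...
tiny_grad = 1e-8
print(jnp.asarray(tiny_grad, dtype=jnp.float16))            # 0.0 -- the value is lost

# ... but scaling it into float16's representable range preserves the information.
scale = 2.0 ** 15
scaled = jnp.asarray(tiny_grad * scale, dtype=jnp.float16)  # about 3.28e-4, representable
recovered = scaled.astype(jnp.float32) / scale              # unscale in float32
print(recovered)                                            # approximately 1e-8 again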
mpx provides important functions for steps 3-9. However, it does not provide Keras/PyTorch Lightning/Kauldron-like functionality, where you just pass a model, loss, and optimizer and call run. This is intentional: it preserves the low-level approach of JAX and lets users write their training pipeline as they prefer.
mpx provides a comprehensive set of tools for mixed precision training in JAX.
The main goal was to keep the library as flexible and as close to equinox as possible.
As a result, to update a training pipeline to work with mixed precision, one just has to:
- Update the gradient calculations from eqx.filter_grad/filter_value_and_grad to mpx.filter_grad/filter_value_and_grad.
- Do the optax optimizer call via mpx.optimizer_update.

Here are the key components:
- set_half_precision_datatype(dtype): Configure whether to use float16 or bfloat16 for half precision training.
- half_precision_datatype(): Get the currently configured half precision data type.
- cast_to_half_precision(x: PyTree): Cast all JAX arrays in a PyTree to the configured half precision type.
- cast_to_full_precision(x: PyTree): Cast all JAX arrays in a PyTree to float32.
- cast_to_float16(x: PyTree): Cast all JAX arrays in a PyTree to float16.
- cast_to_bfloat16(x: PyTree): Cast all JAX arrays in a PyTree to bfloat16.
- cast_to_float32(x: PyTree): Cast all JAX arrays in a PyTree to float32.
- force_full_precision: A decorator that ensures a function performs all calculations in float32. This is essential for maintaining numerical stability in some operations. Some critical operations in JAX, like jax.numpy.sum/mean, internally convert half precision to full precision, and common Equinox layers like equinox.nn.MultiheadAttention also force critical parts to full precision. However, this decorator can be useful for other implementations that do not do this.
- DynamicLossScaling: A class that manages dynamic loss scaling to prevent underflow in half precision training. It is syntactically equivalent to jmp.DynamicLossScaling, but it can scale arbitrary PyTrees (see the combined sketch after this list).
  - scale(x): Scale a value by the current loss scaling factor.
  - unscale(x): Remove the loss scaling factor from a value.
  - adjust(grads_finite): Update the loss scaling factor based on gradient stability.
  These methods are used internally, but they might be interesting for non-standard implementations.
- scaled(func, scaling): Decorator that applies loss scaling to a function's output.
- all_finite(tree): Check if all values in a PyTree are finite (not NaN or inf).
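The utilities above can also be combined by hand. The following is a rough sketch, with a toy PyTree and loss, of the flow that mpx.filter_grad (described next) performs for you; the exact dtype argument of set_half_precision_datatype and the functional style of adjust (returning a new object, as in jmp) are assumptions:

import equinox as eqx
import jax.numpy as jnp
import mpx

# Configure float16 as the half precision type (passing a jnp dtype here is an assumption).
mpx.set_half_precision_datatype(jnp.float16)

scaling = mpx.DynamicLossScaling(loss_scaling=mpx.FLOAT16_MAX,
                                 min_loss_scaling=jnp.ones((), dtype=jnp.float32),
                                 period=2000)

@mpx.force_full_precision
def stable_mean(x):
    # forced to run in float32, even when called from half precision code
    return jnp.mean(x)

def loss_fn(params, x):
    # toy loss written as usual; params is an arbitrary PyTree
    return stable_mean((params["w"] @ x) ** 2)

params = mpx.cast_to_half_precision({"w": jnp.ones((8, 8))})
x = mpx.cast_to_half_precision(jnp.ones((8,)))

# Manual mixed precision gradient step (roughly what mpx.filter_grad does internally):
grads = eqx.filter_grad(mpx.scaled(loss_fn, scaling))(params, x)  # gradients of the scaled loss
grads = mpx.cast_to_full_precision(grads)
grads_finite = mpx.all_finite(grads)    # False if the scaling was too large (inf/NaN)
scaling = scaling.adjust(grads_finite)  # shrink on overflow, grow after `period` finite steps
grads = scaling.unscale(grads)          # divide by the factor to recover the true gradients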
mpx provides function decorators for gradient calculations that summarize steps 3-9 in one function call. They have the same meaning and syntax as the corresponding decorators of equinox. This means that, for an existing training pipeline, one can simply replace the calls to equinox.filter_grad/filter_value_and_grad with mpx.filter_grad/filter_value_and_grad.
- filter_grad(func, scaling: loss_scaling.DynamicLossScaling, has_aux=False, use_mixed_precision=True): Transformation that computes the gradient of func with respect to its first argument using mixed precision with scaling, similar to equinox.filter_grad. The transformed function works as follows:
  - If use_mixed_precision is True:
    - Casts all input arguments to half precision (float16/bfloat16).
  - Scales the function's output by scaling.
  - Computes gradients using equinox.filter_grad.
  - If use_mixed_precision is True:
    - Casts the gradients back to full precision (float32).
    - Checks if the gradients are finite.
    - Updates scaling based on whether the gradients are inf or not.
    - Unscales the gradients by dividing them by scaling.
  - Returns a tuple containing:
    - The updated scaling object.
    - A boolean indicating if the gradients are finite (needed for the optimizer step, see below).
    - The computed gradients.
    - Auxiliary values (if has_aux=True).
- filter_value_and_grad(func, scaling): Decorator that works like filter_grad, except that it also returns the value of func (it is the first element of the returned tuple, as in the training example below).
The gradient transformations might return gradients that are infinite. In this case, the pipeline needs to skip the model update. For this, mpx provides the following function:
optimizer_update(model: PyTree, optimizer: optax.GradientTransformation, optimizer_state: PyTree, grads: PyTree, grads_finite: Bool): Apply optimizer updates only when the gradients are finite. Works with arbitrary optax optimizers.
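Putting the gradient transformation and the optimizer update together, a minimal single-device training step could look like the sketch below. The toy linear model, loss, and data are placeholders, and the ordering of the returned values follows the full training example further down (here without an auxiliary output):

import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import mpx

key = jax.random.PRNGKey(0)
model = eqx.nn.Linear(8, 1, key=key)
optimizer = optax.adamw(1e-3)
optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))
scaling = mpx.DynamicLossScaling(loss_scaling=mpx.FLOAT16_MAX,
                                 min_loss_scaling=jnp.ones((), dtype=jnp.float32),
                                 period=2000)

def loss_fn(model, x, y):
    # ordinary Equinox/JAX loss, no mixed precision specific code needed
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

x, y = jnp.ones((32, 8)), jnp.zeros((32, 1))

# One training step: gradients in mixed precision, update skipped if they overflowed.
loss, scaling, grads_finite, grads = mpx.filter_value_and_grad(
    loss_fn, scaling=scaling)(model, x, y)
model, optimizer_state = mpx.optimizer_update(
    model, optimizer, optimizer_state, grads, grads_finite)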
The following provides a small example that trains a vision transformer on CIFAR-100 and presents all the important features of mpx. For details, please see examples/train_vit.py.
This example does not go into the details of the neural network part, but focuses only on the mpx-relevant parts.
The example was tested on an RTX 4070. Without mixed precision, the training crashes with a batch size of 256; with mixed precision, it runs, demonstrating that mixed precision training via mpx effectively reduces GPU memory usage. The training speed itself does not change dramatically, as the RTX 4070 does not have higher throughput for half precision operations.
First install JAX for your hardware. Then, install all dependencies via
pip install -r examples/requirements.txt

Then you can run the example via

python -m examples.train_vit

ATTENTION: The script downloads CIFAR-100.

The loss scaling has to be initialized during the instantiation of the datasets, models, etc. Typically, the initial value is set to the maximum value of float16.
loss_scaling = mpx.DynamicLossScaling(loss_scaling=mpx.FLOAT16_MAX,
                                      min_loss_scaling=jnp.ones((), dtype=jnp.float32),
                                      period=2000)

The loss_scaling object then must be passed to the training pipeline.
The most important part is the training step. mpx makes transforming your training step into mixed precision very easy. As you can see, the only changes you have to make are to replace the call to eqx.filter_value_and_grad with mpx.filter_value_and_grad and to call the optimizer via mpx.optimizer_update afterwards. Also, do not forget to return loss_scaling from your step function, because loss_scaling is updated.
@eqx.filter_jit
def make_step(model: eqx.Module,
              optimizer: any,
              optimizer_state: PyTree,
              batch: dict,
              batch_sharding: jax.sharding.NamedSharding,
              replicated_sharding: jax.sharding.NamedSharding,
              loss_scaling: mpx.DynamicLossScaling,
              train_mixed_precision: bool,
              weight_regularization: Float,
              key: PRNGKeyArray
              ) -> tuple[eqx.Module, PyTree, mpx.DynamicLossScaling, Float]:
    batch = eqx.filter_shard(batch, batch_sharding)
    model = eqx.filter_shard(model, replicated_sharding)
    optimizer_state = eqx.filter_shard(optimizer_state, replicated_sharding)

    if train_mixed_precision:
        # this is the critical part
        (loss_value, _), loss_scaling, grads_finite, grads = mpx.filter_value_and_grad(
            batched_loss_acc_wrapper, scaling=loss_scaling, has_aux=True)(
            model, batch, batch_sharding, replicated_sharding, key, weight_regularization)
        model, optimizer_state = mpx.optimizer_update(
            model, optimizer, optimizer_state, grads, grads_finite)
    else:
        (loss_value, _), grads = eqx.filter_value_and_grad(
            batched_loss_acc_wrapper, has_aux=True)(
            model, batch, batch_sharding, replicated_sharding, key)
        # optimizer step
        updates, optimizer_state = optimizer.update(
            grads, optimizer_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)

    model = eqx.filter_shard(model, replicated_sharding)
    optimizer_state = eqx.filter_shard(optimizer_state, replicated_sharding)
    loss_scaling = eqx.filter_shard(loss_scaling, replicated_sharding)

    # return loss_scaling as it is changed
    return model, optimizer_state, loss_scaling, loss_value

Through the transformation via mpx.filter_value_and_grad, we can write our loss function as we normally do when using JAX/Equinox.
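For illustration, such a loss function could look like the sketch below. This is not the actual function from examples/train_vit.py; the batch layout, the model call signature, and dropping the sharding and weight regularization logic are simplifying assumptions (the usual jax, jax.numpy as jnp, and optax imports are assumed):

def batched_loss_acc_wrapper(model, batch, batch_sharding, replicated_sharding, key,
                             weight_regularization=0.0):
    # Hypothetical, simplified stand-in: plain Equinox/JAX code, nothing mixed precision specific.
    # Sharding and weight regularization are omitted for brevity.
    images, labels = batch["image"], batch["label"]             # assumed batch layout
    keys = jax.random.split(key, images.shape[0])
    logits = jax.vmap(model)(images, keys)                      # assumed model signature
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)  # auxiliary output
    return loss, accuracy                                       # matches has_aux=True above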
To cite this repository, please cite our paper:
@ARTICLE{2025arXiv250703312G,
author = {{Gr{\"a}fe}, Alexander and {Trimpe}, Sebastian},
title = "{MPX: Mixed Precision Training for JAX}",
journal = {arXiv e-prints},
year = 2025,
doi = {10.48550/arXiv.2507.03312},
}
We want to thank Patrick Kidger for providing Equinox and Google DeepMind for providing JMP, which was the basis for this implementation.
The authors gratefully acknowledge the computing time provided to them at the NHR Center NHR4CES at RWTH Aachen University (project number p0021919). This is funded by the Federal Ministry of Education and Research, and the state governments participating on the basis of the resolutions of the GWK for national high performance computing at universities (www.nhr-verein.de/unsere-partner).