hyperlax is a unified JAX-based framework for high-throughput reinforcement learning, designed to benchmark and optimize both classical and quantum machine learning models. It accelerates research by enabling massively parallel hyperparameter execution, transforming the traditional "one experiment, one process" paradigm into a vectorized "many experiments, one process" workflow.
By leveraging jax.vmap and jax.pmap across hyperparameter configurations, hyperlax allows for direct, fair, and efficient performance comparisons between different model families (e.g., MLP vs. PQC) on the same hardware, speeding up the research cycle.
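To illustrate the core idea, here is a minimal, self-contained sketch (not hyperlax's actual API) of vectorizing a training step over a batch of hyperparameter configurations with `jax.vmap`: a toy SGD step on a quadratic loss, where the learning rate is the per-configuration hyperparameter.

```python
import jax
import jax.numpy as jnp

# Toy "training step": one SGD update on a quadratic loss.
# `lr` is a per-configuration hyperparameter (hypothetical example,
# not hyperlax's actual training step).
def train_step(params, lr):
    grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)
    return params - lr * grads

params = jnp.ones(3)                   # shared initial parameters
lrs = jnp.array([0.1, 0.2, 0.3, 0.4])  # 4 hyperparameter configurations

# Vectorize the step over the hyperparameter axis: a single compiled
# computation advances all 4 configurations at once.
batched_step = jax.vmap(train_step, in_axes=(None, 0))
new_params = jax.jit(batched_step)(params, lrs)
print(new_params.shape)  # (4, 3)
```

This "many experiments, one process" pattern is what lets a single GPU kernel launch serve an entire hyperparameter sweep.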
We developed and tested with Python 3.10.
```shell
conda create --name hyperlax python=3.10  # or: python3.10 -m venv .venv
conda activate hyperlax                   # or: source .venv/bin/activate
```

Install the JAX version we developed and tested with, matching your hardware. For CUDA:
```shell
pip install -U "jax[cuda12_pip]==0.4.28" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

NOTE: If you get a "Segmentation Fault" with the GPU setup, we recommend installing a JAX wheel built for CUDA >= 12.6 (see the JAX docs).
For CPU-only:
```shell
pip install -U "jax[cpu]==0.4.28"
```

Install hyperlax and its dependencies:
```shell
git clone git@github.com:dfki-ric-quantum/hyperlax-quantum.git
cd hyperlax-quantum
pip install -e '.'  # or: pip install -e '.[dev]'
```

This setup is primarily for running large-scale experiments on a high-performance computing (HPC) cluster.
```shell
cd containers
singularity build hyperlax.sif singularity.def
```

Note on Singularity installation: official installation guides may list outdated dependencies for modern Linux distributions (e.g., Ubuntu 22.04+). If you encounter issues (e.g., regarding `fuse`), we recommend consulting the SingularityCE GitHub Releases page for the latest packages and platform-specific instructions.
If you want the pre-built Singularity image we used for development and testing, you can download it from zenodo.org/records/17426400.
Run a quick, small-scale hyperparameter sweep for a classical PPO agent with an MLP policy on the Pendulum environment. This command will train 4 different hyperparameter configurations in a single vectorized run.
```shell
conda activate hyperlax
source envs/hyperlax_setup.sh
python hyperlax/cli.py sweep-hp-samples \
    --algo-and-network-config ppo_mlp_no_network_search \
    --env-config gymnax.pendulum \
    --run-length-modifier quick \
    --num-samples 4 \
    --hparam-batch-size 4 \
    --log-level INFO \
    --output-root ./results/quickstart_classical
```

After the run completes, you'll find all results, logs, and metrics in the `results/quickstart_classical/ppo_mlp/gymnax_pendulum` directory.
NOTE: Use the `--run-length-modifier long` option for longer, more realistic training sessions.
Run a single experiment using a Parametrized Quantum Circuit (PQC) as the policy network for the SAC algorithm.
```shell
python hyperlax/cli.py run-single-hp \
    --algo-and-network-config sac_drpqc \
    --env-config gymnax.pendulum \
    --run-length-modifier quick \
    --log-level INFO \
    --output-root ./results/quickstart_quantum
```

The main entry point is `hyperlax/cli.py`.
Runs one experiment using default hyperparameters.
```shell
python hyperlax/cli.py run-single-hp \
    --algo-and-network-config sac_mlp \
    --env-config gymnax.cartpole \
    --log-level INFO
```

Generates samples and runs them in batches.
```shell
python hyperlax/cli.py sweep-hp-samples \
    --algo-and-network-config ppo_mlp_no_network_search \
    --env-config gymnax.pendulum \
    --num-samples 64 \
    --hparam-batch-size 16 \
    --log-level INFO
```

Use the `--sequential True` flag to run configurations one-by-one for comparison.
Uses Optuna to go beyond random sampling and search for the top-performing hyperparameters.
```shell
python hyperlax/cli.py optuna-hp-search \
    --algo-and-network-config ppo_mlp \
    --env-config gymnax.pendulum \
    --n-trials 100
```

Executes a benchmark defined in `hyperlax/configs/benchmark/`, comparing multiple algorithms, environments, and sweep modes.
```shell
python hyperlax/cli.py run-benchmark \
    --algos "ppo_mlp" "ppo_drpqc" \
    --envs "gymnax.pendulum" "gymnax.cartpole" \
    --num-samples-per-run 16
```

Or use the pre-defined config:
```shell
python hyperlax/cli.py run-benchmark --base-config ppo_mlp_vs_drpqc
```

Post-processes the output of a benchmark run to generate summary plots.
```shell
python hyperlax/cli.py plot-benchmark --results-dir-to-plot ./results_bench
```

Generates samples and saves them to a CSV file without running experiments.
```shell
python hyperlax/cli.py generate-hp-samples \
    --algo-and-network-config ppo_mlp \
    --num-samples 16 \
    --output-file ./ppo_mlp_samples.csv
```

`hyperlax/configs/algo/benchmarked/01_static_qmc_sampling_S64` is the directory containing the algorithm configurations used for benchmarking. We sample 64 hyperparameter sets using the QMC method (no search). The benchmark data are stored under `benchmark_results`.
The hyperparameter distributions are chosen so that algorithms sharing the same parameters also share the same distributions. This provides an unbiased setup (as much as possible) and allows us to assess hyperparameter sensitivity.
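As an illustration of this shared-distribution idea, here is a hedged, stdlib-only sketch (the dictionary layout and helper names are hypothetical, not hyperlax's actual config schema): algorithms that share a hyperparameter draw it from the same distribution, so performance differences cannot stem from differing search spaces.

```python
import math
import random

# Hypothetical shared search-space spec: any algorithm using one of these
# hyperparameters samples it from the same distribution.
SHARED_DISTRIBUTIONS = {
    "learning_rate": ("log_uniform", 1e-5, 1e-2),
    "gamma":         ("uniform", 0.9, 0.999),
}

def sample(dist, u):
    """Map a unit-interval draw u (e.g., from a Sobol sequence) into the distribution."""
    kind, lo, hi = dist
    if kind == "log_uniform":
        return math.exp(math.log(lo) + u * (math.log(hi) - math.log(lo)))
    return lo + u * (hi - lo)

rng = random.Random(0)
ppo_cfg = {k: sample(d, rng.random()) for k, d in SHARED_DISTRIBUTIONS.items()}
sac_cfg = {k: sample(d, rng.random()) for k, d in SHARED_DISTRIBUTIONS.items()}
```

Feeding low-discrepancy (e.g., Sobol) points instead of `rng.random()` into `sample` yields the QMC sampling used for the static benchmark.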
We also ran experiments involving hyperparameter search (to be released soon).
Feel free to try and beat the current results!
Note that the full benchmark run takes about one month on a single GPU (e.g., A100). The slowest configuration is ppo-drpqc on reacher, which alone takes around 8 days, whereas the classical models finish in just a few hours. Therefore, you’ll need to strategize how to distribute the quantum models (especially ppo variants) across your cluster setup to optimize runtime.
```shell
python hyperlax/cli.py run-benchmark \
    --algos benchmarked.01_static_qmc_sampling_S64.{dqn,ppo,sac}_{mlp,tmlp,drpqc} \
    --envs gymnax.{cartpole,pendulum,reacher} brax.inverted_double_pendulum \
    --num-samples-per-run 64 \
    --sweep-modes "sequential" \
    --run-length-modifier long \
    --sampling-method "qmc_sobol" \
    --log-level "INFO" \
    --output-root "./results_benchmark_reproduce"
```

Note: `--sweep-modes "sequential"` is required here because these HP configs include architecture choices such as `n_layers`, which have no batched support.

- Crazy Long Runs
  `ppo_drpqc` on inverted_double_pendulum and reacher (due to the increased observation dimension)
- Very Long Runs
  `ppo_drpqc` on cartpole and pendulum
- Long Runs
  `{dqn,sac}_drpqc`
- Medium Runs
  `{ppo,dqn,sac}_tmlp`
- Light Runs
  `{ppo,dqn,sac}_mlp`
If you are not performing architecture searches or sampling, you can use `--sweep-modes "batched"` to run hyperparameter batches in parallel. All algorithms are vectorized for efficient single-GPU utilization, enabling multiple configurations to run simultaneously.
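Conceptually, a batched sweep packs several hyperparameter configurations into a pytree of arrays and advances them together in one vmapped step. A minimal, illustrative sketch (the update rule and field names are toy examples, not hyperlax's real batching machinery):

```python
import jax
import jax.numpy as jnp

# Toy TD-style update whose hyperparameters arrive as a dict (pytree).
def td_update(value, hp):
    target = hp["reward"] + hp["gamma"] * value
    return value + hp["lr"] * (target - value)

# Two hyperparameter configurations, stacked leaf-wise along axis 0.
hps = {
    "lr":     jnp.array([0.1, 0.5]),
    "gamma":  jnp.array([0.9, 0.99]),
    "reward": jnp.array([1.0, 1.0]),
}
values = jnp.zeros(2)

# vmap maps over the leading axis of every leaf, so one call updates
# both configurations in parallel.
new_values = jax.vmap(td_update)(values, hps)
```

Because `vmap` maps over pytree leaves, adding another hyperparameter to the batch is just another stacked array, with no change to the update function.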
See the `hyperlax/examples` directory for more usage examples.
Adding a new algorithm or network (classical or quantum) involves five steps:

- Create Configuration: Define dataclasses for your algorithm's config, network architectures, and hyperparameters in `hyperlax/configs/`.
- Implement Core Logic: Write the core algorithm update step with vectorized hyperparameters (see how the existing algorithm implementations achieve this) and the loss functions in `hyperlax/algo/`. For custom networks, add the Flax module in `hyperlax/network/`.
- Implement the `AlgorithmInterface`: Create a `setup_my_algo.py` file that provides the necessary functions (network builder, optimizer builder, etc.) and packages them into an `AlgorithmInterface` dataclass.
- Create a Recipe: Add a recipe file like `hyperlax/configs/algo/my_algo.py` that provides `get_base_config` and `get_base_hyperparam_distributions` functions.
- Run it!: Your new algorithm is now available via the CLI, e.g., `--algo-and-network-config my_algo`.
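The steps above can be sketched as a single recipe module. This is a hedged illustration only: the actual fields of hyperlax's `AlgorithmInterface` and the exact return types of `get_base_config` / `get_base_hyperparam_distributions` may differ, so treat every field name here as an assumption.

```python
from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical stand-in for hyperlax's AlgorithmInterface; the real
# dataclass likely carries more builders (loss fn, update step, etc.).
@dataclass(frozen=True)
class AlgorithmInterface:
    build_network: Callable[..., Any]
    build_optimizer: Callable[..., Any]

def get_base_config() -> dict:
    # Default (non-swept) settings for the new algorithm.
    return {"total_timesteps": 100_000, "num_envs": 8}

def get_base_hyperparam_distributions() -> dict:
    # Search-space definition consumed by the samplers.
    return {"learning_rate": ("log_uniform", 1e-5, 1e-2)}

# Package the builders so the CLI can discover and wire them up.
interface = AlgorithmInterface(
    build_network=lambda cfg: f"network({cfg})",
    build_optimizer=lambda cfg: f"optimizer({cfg})",
)
```

The frozen dataclass mirrors the framework's immutable-configuration principle: once built, the interface is read-only.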
- Unified Benchmarking: Provide a single, consistent platform to fairly evaluate and compare the performance and data efficiency of classical, quantum, and tensor network models for reinforcement learning.
- Maximize Hardware Throughput: Minimize wall-clock time for research by fully utilizing available hardware, especially multiple GPUs on a cluster setup.
- Configuration as Code: Experiment configurations, including hyperparameter search spaces, are version-controllable, readable, and strongly-typed Python code.
- Vectorize Everything Possible: We aggressively apply `vmap` not just to environments but also to distinct model architectures and training hyperparameters.
- Immutable and Functional: We adhere to JAX's functional programming paradigm. State is explicitly passed and returned, and configurations are treated as immutable.
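The functional-state principle looks like this in practice: a minimal sketch (illustrative, not hyperlax's actual `TrainState`) where the training state is an immutable container and the update function returns a fresh state instead of mutating anything.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# Immutable training state: NamedTuples are registered JAX pytrees,
# so they pass through jit/vmap transparently.
class TrainState(NamedTuple):
    params: jnp.ndarray
    step: jnp.ndarray

@jax.jit
def update(state: TrainState) -> TrainState:
    grads = jax.grad(lambda p: jnp.sum(p ** 2))(state.params)
    # No in-place mutation: construct and return a new state.
    return TrainState(params=state.params - 0.1 * grads,
                      step=state.step + 1)

state = TrainState(params=jnp.ones(2), step=jnp.asarray(0))
state = update(state)
```

Because state flows explicitly through pure functions, the whole loop can be jitted, vmapped, or scanned without hidden side effects.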
hyperlax extends ideas from prior JAX-based RL systems and quantum ML benchmarks:
- purejaxrl demonstrated fully JIT-compiled RL loops to keep environment rollouts on-device.
- Stoix introduced modular multi-device abstractions for distributed training.
- hyperlax takes these principles further by vectorizing across hyperparameter configurations, enabling batched, parallel experimentation in a single compiled computation.
- qml-benchmarks provided our baseline quantum model (i.e., Data-Reuploading Parameterized Quantum Circuit).
- gymnax and brax supplied fast, JAX-native environments crucial for large-scale, differentiable RL benchmarks.
Thanks to the broader open-source research community for advancing transparent, reproducible, and scalable machine learning tools.
```shell
pytest tests/
```

- Vectorized/batched hyperparameter computation is supported for algorithmic hyperparameters, both scalar (e.g., learning rates) and structural (e.g., rollout length), but not for function-approximation parameters (e.g., hidden dimensions). An experimental vectorized MLP implementation is available as a reference (see `parametric_torso.py`).
- In `dqn-drpqc_gymnax.acrobot`, 9 out of 64 Acrobot samples trigger JAX's `XlaRuntimeError: INTERNAL: ptxas exited`, indicating a possible synchronization issue in the multi-GPU setup. This issue was not further investigated.
- `tmlp` models are highly sensitive to specific learning rates, which can cause gradient explosions.
- The current parameterized quantum circuit implementation, combined with the JAX-backed PennyLane version used, results in long JIT compilation times, likely due to non-JAX-compatible components in PennyLane. Interested researchers may explore the PennyLaneAI/catalyst project for potential JIT and execution improvements, though this has not been tested here.
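The restriction on function-approximation parameters follows from how `vmap` works: every element along the mapped axis must share one shape, so weight matrices of different hidden sizes cannot be stacked and mapped over. A small illustrative example:

```python
import jax
import jax.numpy as jnp

# A toy one-layer forward pass.
def forward(w, x):
    return jnp.tanh(w @ x)

x = jnp.ones(4)

# Same hidden size (8) for every configuration: stacking and vmapping works.
ws = jnp.stack([jnp.ones((8, 4)), jnp.ones((8, 4))])
out = jax.vmap(forward, in_axes=(0, None))(ws, x)
print(out.shape)  # (2, 8)

# jnp.stack([jnp.ones((8, 4)), jnp.ones((16, 4))]) would raise instead:
# differing hidden sizes produce ragged shapes that vmap cannot map over.
```

This is why hidden dimensions stay outside the batched sweep while scalar hyperparameters like learning rates batch cleanly.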
See CONTRIBUTING for details.
```bibtex
@software{bolat_hyperlax_quantum_2025,
  author  = {Bolat, Ugur},
  doi     = {10.5281/zenodo.17426400},
  month   = {10},
  title   = {{Benchmarking Classical and Quantum Reinforcement Learning Algorithms with JAX}},
  url     = {https://github.com/dfki-ric-quantum/hyperlax-quantum},
  version = {0.0.1},
  year    = {2025}
}
```

Semantic versioning must be used: the major version number is incremented when the API changes in a backwards-incompatible way, the minor version when new functionality is added in a backwards-compatible manner, and the patch version for bugfixes, documentation, etc.
Licensed under the BSD 3-clause license, see LICENSE for details.
This work was funded by the German Ministry of Economic Affairs and Climate Action (BMWK) and the German Aerospace Center (DLR) in project QuBER-KI (grants: 50RA2207A, 50RA2207B).
