A JAX/Flax implementation for comparing Flow Matching, Diffusion, and Continuous-Time (CT) generative models. All three approaches solve the same continuous-time problem with different objective functions, so each one trains a network that predicts a different quantity.
This repository implements the three variants as follows:
- Flow Matching (FM): Predicts the denoising flow directly
  - Network learns: `dz/dt = f(z, x, t)`, where `f` directly predicts the flow field
  - Objective: Match the flow field that connects noise to data
  - Most robust approach - direct flow prediction tends to be stable across different data distributions
- Diffusion (DF): Predicts noise at each time step
  - Network learns: `noise_prediction = f(z, x, t)`, where `f` predicts the noise component
  - Objective: Predict the noise to be removed, reparameterized to avoid singularities
  - Uses noise prediction for the denoising trajectory
- Continuous-Time (CT): Predicts the target at each time step
  - Network learns: `target_prediction = f(z, x, t)`, where `f` predicts the clean target
  - Objective: Predict the target value at each time point with an SNR-weighted loss
  - Uses target prediction to guide the denoising process
All three methods solve the same continuous-time problem but parameterize it differently:
- Flow Matching: Direct flow field prediction
- Diffusion: Noise prediction → denoising trajectory
- CT: Target prediction → denoising trajectory
Each parameterization leads to different training dynamics and performance characteristics.
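The link between the three parameterizations can be made concrete under one simplifying assumption: a straight-line path `z_t = (1 - t) * z_0 + t * target`, which is the path implied by the Flow Matching loss later in this README. The sketch below is not the repository's code. The function names are hypothetical, and the actual Diffusion and CT models use a noise schedule rather than this linear path. It only shows that a target prediction or a noise prediction determines the same flow, and where each conversion becomes singular.

```python
def linear_path(z0, target, t):
    """Point on the straight line from noise z0 (t = 0) to the target (t = 1)."""
    return [(1.0 - t) * a + t * b for a, b in zip(z0, target)]


def flow_from_target_prediction(z_t, target_pred, t):
    """Flow implied by a target prediction: (target - z_t) / (1 - t). Singular at t = 1."""
    return [(y - z) / (1.0 - t) for z, y in zip(z_t, target_pred)]


def flow_from_noise_prediction(z_t, noise_pred, t):
    """Flow implied by a noise prediction: (z_t - z0) / t. Singular at t = 0."""
    return [(z - n) / t for z, n in zip(z_t, noise_pred)]


z0 = [0.5, -1.0]     # noise sample
target = [2.0, 1.0]  # clean target
t = 0.25
z_t = linear_path(z0, target, t)

# On a linear path the true flow is constant: target - z0.
true_flow = [b - a for a, b in zip(z0, target)]
flow_a = flow_from_target_prediction(z_t, target, t)  # perfect target prediction
flow_b = flow_from_noise_prediction(z_t, z0, t)       # perfect noise prediction
```

With perfect predictions, `flow_a` and `flow_b` both equal `true_flow`. The divisions by `1 - t` and `t` illustrate why noise and target parameterizations need care near the time boundaries.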
git clone https://github.com/yourusername/jax-noprop.git
cd jax-noprop
pip install -e .
Requirements:
- Python 3.8+
- JAX 0.4.0+ (with CUDA support for GPU acceleration)
- Flax 0.7.0+
- PyYAML 6.0+ (for YAML config file support)
- See `requirements.txt` for the full dependency list
Note: For GPU support, install JAX with CUDA following the official JAX installation guide.
This repository provides three main training scripts for different tasks:
- `train.py` - Regression/Classification (x → y)
- `train_gen.py` - Conditional/Unconditional Generation (y → x, or x generation)
- `train_seq.py` - Sequence Generation (for sequence data)
See the Two Moons Example README for a complete walkthrough. Quick start:
# Generate the dataset
python examples/two_moons/generate_two_moons.py
# Regression: Predict labels from coordinates
python -m src.flow_models.train \
--config_file examples/two_moons/config.yaml \
--data_path data/two_moons.pkl \
--model_type flow_matching
# Conditional Generation: Generate coordinates from labels
python -m src.flow_models.train_gen \
--config_file examples/two_moons/config.yaml \
--data_path data/two_moons.pkl \
--model_type flow_matching
# Unconditional Generation: Generate coordinates without labels
python -m src.flow_models.train_gen \
--config_file examples/two_moons/config.yaml \
--data_path data/two_moons.pkl \
--model_type flow_matching \
--unconditional
After training, results are saved to artifacts/{model_type}_{task}/{YYYYMMDD_HHMM}/:
- Regression: artifacts/{model_type}_reg/{timestamp}/
- Conditional Generation: artifacts/{model_type}_gen/{timestamp}/
- Unconditional Generation: artifacts/{model_type}_uncond_gen/{timestamp}/
Each directory contains:
- `history.pkl` or `training_results.pkl` - Training history with all loss components
- `params.pkl` or `model_params.pkl` - Trained model parameters
- `config.yaml` - Configuration used for training (human-readable, hierarchical)
- `loss_trends.png` - Loss trends plot showing flow_loss, recon_loss, reg_loss, vae_loss over epochs
- `data_visualization.png` - Data visualization plots
- `trajectories.png` - Sample trajectory visualizations
- `trajectory_diagnostics.png` - Trajectory diagnostic plots
- Generation-specific plots (conditional_generation.png, unconditional_generation.png, latent_trajectories.png)
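Because the run directories follow a fixed naming scheme, they can be located from a script. The following is a minimal sketch using only the standard library; `run_directory`, `latest_run`, and the `TASK_SUFFIX` mapping are hypothetical helpers written for this illustration, not functions provided by the repository.

```python
from datetime import datetime
from pathlib import Path

# Task suffixes as listed above (the dictionary keys are illustrative names).
TASK_SUFFIX = {
    "regression": "reg",
    "conditional": "gen",
    "unconditional": "uncond_gen",
}


def run_directory(model_type, task, when, root="artifacts"):
    """Build artifacts/{model_type}_{task}/{YYYYMMDD_HHMM} for a given start time."""
    stamp = when.strftime("%Y%m%d_%H%M")
    return Path(root) / f"{model_type}_{TASK_SUFFIX[task]}" / stamp


def latest_run(model_type, task, root="artifacts"):
    """Return the most recent run directory, or None if there are no runs."""
    parent = Path(root) / f"{model_type}_{TASK_SUFFIX[task]}"
    if not parent.is_dir():
        return None
    # YYYYMMDD_HHMM names sort chronologically when sorted as strings.
    runs = sorted(p for p in parent.iterdir() if p.is_dir())
    return runs[-1] if runs else None


example = run_directory("flow_matching", "unconditional", datetime(2025, 1, 31, 14, 5))
```

Here `example` points at `artifacts/flow_matching_uncond_gen/20250131_1405`. The pickled files inside a run may contain JAX arrays, so loading them with `pickle` is best done in an environment where JAX is installed.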
All training scripts support YAML configuration files:
# Use a YAML config file (recommended)
python -m src.flow_models.train_gen \
--config_file examples/two_moons/config.yaml \
--data_path data/two_moons.pkl \
--model_type flow_matching
# Use a custom config class
python -m src.flow_models.train_gen \
--config_file examples/two_moons/config.yaml \
--config_class examples.two_moons.config.Config \
--data_path data/two_moons.pkl \
--model_type flow_matching
Command-line arguments override values in config files.
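One common way to implement this precedence is to give every optional flag a default of `None` and copy only the flags that were actually passed over the file values. The sketch below illustrates that pattern with `argparse`. It is an assumption about how such merging can work, not the repository's implementation: the flat dictionary stands in for the hierarchical YAML config, and `build_parser` and `merge_config` are hypothetical names.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # default=None distinguishes "flag not given" from an explicit value.
    parser.add_argument("--num_epochs", type=int, default=None)
    parser.add_argument("--latent_dim", type=int, default=None)
    parser.add_argument("--noise_schedule", type=str, default=None)
    return parser


def merge_config(file_config, cli_args):
    """Return a copy of file_config with explicitly passed CLI values applied on top."""
    merged = dict(file_config)
    for name, value in vars(cli_args).items():
        if value is not None:
            merged[name] = value
    return merged


# Values as they might come from a YAML file (flattened for this example).
file_config = {"num_epochs": 50, "latent_dim": 2, "noise_schedule": "linear"}
args = build_parser().parse_args(["--num_epochs", "100", "--noise_schedule", "cosine"])
config = merge_config(file_config, args)
```

After the merge, `num_epochs` and `noise_schedule` take the command-line values, while `latent_dim` keeps the value from the file.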
# Generate samples without conditioning
python -m src.flow_models.train_gen \
--config_file examples/two_moons/config.yaml \
--data_path data/two_moons.pkl \
--model_type flow_matching \
--unconditional \
--num_epochs 100
# Override config file values via command line
python -m src.flow_models.train_gen \
--config_file examples/two_moons/config.yaml \
--data_path data/two_moons.pkl \
--model_type flow_matching \
--latent_dim 8 \
--encoder_model_type linear \
--decoder_model_type identity \
--decoder_type linear
# Override noise schedule (for diffusion/CT models)
python -m src.flow_models.train_gen \
--config_file examples/two_moons/config.yaml \
--data_path data/two_moons.pkl \
--model_type diffusion \
--noise_schedule cosine
# Available schedules: linear, cosine, sigmoid, exponential, cauchy, laplace, logistic, quadratic, polynomial
# Use dropout for first 80 epochs, then disable it
python -m src.flow_models.train_gen \
--config_file examples/two_moons/config.yaml \
--data_path data/two_moons.pkl \
--model_type flow_matching \
--num_epochs 100 \
--dropout_epochs 80
Flow Matching (Recommended for most cases)
- ✅ Most robust - Stable across different data distributions
- ✅ Simple objective - direct flow prediction
- ✅ Fast training and inference
- ✅ Works well as baseline for comparison
Diffusion
- Good for noise-focused applications
- Uses noise prediction parameterization
- Requires careful noise schedule tuning
Continuous-Time
- Good when target prediction is natural
- Uses SNR-weighted loss for training stability
- Can learn optimal noise schedules
- Flow Matching: Fastest inference, most stable training
- Diffusion: Good for applications where noise structure matters
- CT: Can adapt to data with learnable schedules
from src.flow_models.trainer_gen import GenerationTrainer
from src.flow_models.config import Config
# Load configuration from YAML or create default
config = Config.load_yaml('examples/two_moons/config.yaml')
# Or create a default config
# config = Config()
# Create trainer
trainer = GenerationTrainer(
config=config,
learning_rate=1e-3,
optimizer_name='adam',
seed=42,
unconditional=False # Set to True for unconditional generation
)
# Initialize and train
trainer.initialize(x_sample, y_sample) # x_sample can be None for unconditional generation
history = trainer.train(
x_data=x_train, # None for unconditional generation
y_data=y_train,
num_epochs=100,
batch_size=256,
validation_data=(x_val, y_val) # Optional validation data
)
# Conditional generation (supports batched inputs)
x_gen = trainer.conditional_generate(
cond_y=conditions, # conditional inputs [batch_size, input_dim] or [input_dim]
num_steps=20,
prng_key=key # Optional: if provided, samples z_0 from normal; otherwise z_0=0
)
# Unconditional generation (supports batch processing)
x_gen = trainer.unconditional_generate(
batch_shape=(100,), # number of samples - can be any batch shape tuple
num_steps=20,
prng_key=key # Required: for sampling z_0 from normal distribution
)
Batch Processing: Both predict() and sample() methods support batch processing:
- `predict()` automatically handles batched conditional inputs
- `sample()` accepts a `batch_shape` tuple (e.g., `(100,)` for 100 samples, `(10, 5)` for a 10x5 grid)
- All samples in a batch are generated efficiently in parallel
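To make `num_steps` and `batch_shape` concrete, the sketch below draws one initial latent `z_0` per batch element and integrates `dz/dt = f(z, t)` from `t = 0` to `t = 1` with fixed-size Euler steps. Everything in it is an assumption for illustration: the repository may use a different integrator, the toy flow field stands in for a trained network, the conditioning input is omitted, and plain Python loops replace the parallel JAX computation. The names `sample_z0`, `make_straight_flow`, and `euler_integrate` are hypothetical.

```python
import math
import random


def sample_z0(batch_shape, latent_dim, rng):
    """Draw one standard-normal latent vector per batch element (batch is flattened)."""
    count = math.prod(batch_shape)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(count)]


def make_straight_flow(target):
    """Toy flow field that moves the current state toward a fixed target."""
    def flow_fn(z, t):
        return [(y - zi) / (1.0 - t) for zi, y in zip(z, target)]
    return flow_fn


def euler_integrate(flow_fn, z0, num_steps):
    """Integrate dz/dt = flow_fn(z, t) from t = 0 to t = 1 with fixed-size Euler steps."""
    dt = 1.0 / num_steps
    z = list(z0)
    for step in range(num_steps):
        t = step * dt  # always below 1, so the toy flow never divides by zero
        velocity = flow_fn(z, t)
        z = [zi + dt * vi for zi, vi in zip(z, velocity)]
    return z


rng = random.Random(0)
target = [1.0, -2.0]
z0_batch = sample_z0((10, 5), latent_dim=2, rng=rng)  # 10 x 5 = 50 samples
samples = [
    euler_integrate(make_straight_flow(target), z0, num_steps=20) for z0 in z0_batch
]
```

For this toy flow every trajectory ends at `target`, whatever the starting noise. With a learned flow field, a larger `num_steps` generally trades compute for a more accurate trajectory.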
- `--config_file`: Path to YAML config file (recommended)
- `--config_class`: Optional custom config class (e.g., `examples.two_moons.config.Config`)
- `--model_type`: `flow_matching`, `diffusion`, or `ct`
- `--data_path`: Path to data file (required for most tasks)
- `--num_epochs`: Number of training epochs (default: 50)
- `--batch_size`: Batch size (default: 256)
- `--learning_rate`: Learning rate (default: 1e-3)
- `--optimizer`: `adam`, `sgd`, or `adagrad` (default: `adam`)
- `--input_shape` or `--input_dim`: Input shape/dimension
- `--output_shape` or `--output_dim`: Output shape/dimension
- `--latent_shape` or `--latent_dim`: Latent shape/dimension
Note: If no config file is provided, you must specify these shapes.
- `--encoder_model_type`: `identity`, `linear`, `mlp`, `mlp_normal`, `resnet`, `resnet_normal`
- `--decoder_model_type`: `identity`, `mlp`, `resnet`
- `--decoder_type`: `linear`, `softmax`, `none`
- `--crn_type`: CRN type (e.g., `vanilla`, `geometric`, `potential`)
- `--network_type`: Network backbone (e.g., `mlp`, `bilinear`, `convex`)
- `--hidden_dims`: Hidden layer dimensions (space-separated integers)
- `--noise_schedule`: `linear`, `cosine`, `sigmoid`, `exponential`, `cauchy`, `laplace`, `logistic`, `quadratic`, `polynomial`, `monotonic_nn`, `learnable`, `network`
- `--noise_schedule_learnable`: Make noise schedule learnable
- `--dropout_epochs`: Number of epochs to use dropout (default: all epochs)
- `--recon_weight`: Reconstruction loss weight (default: 0.0)
- `--reg_weight`: Regularization loss weight (default: 0.0)
- `--vae_weight`: VAE encoder-decoder reconstruction loss weight (default: 1.0 for generation, 0.0 for regression)
- `--use_snr_weight`: Apply SNR weighting (default: True for diffusion/CT, False for flow matching)
- `--unconditional`: Train for unconditional generation (only for `train_gen.py`)
See python -m src.flow_models.train --help, python -m src.flow_models.train_gen --help, or python -m src.flow_models.train_seq --help for full lists of options.
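As a small illustration of the `--dropout_epochs` option, the sketch below shows one plausible reading of "use dropout for the first N epochs, then disable it". It assumes zero-based epoch counting and treats a missing value as "dropout for all epochs". The helper name `use_dropout` is hypothetical, and the trainers may implement the switch differently.

```python
def use_dropout(epoch, dropout_epochs=None):
    """Return True while dropout should be active (epoch is zero-based)."""
    if dropout_epochs is None:
        return True  # default: dropout stays on for every epoch
    return epoch < dropout_epochs


# 100 training epochs with dropout enabled only for the first 80.
flags = [use_dropout(epoch, dropout_epochs=80) for epoch in range(100)]
```

In a Flax model, a flag like this would typically be passed on as the `deterministic` argument of the dropout layers, negated.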
jax-noprop/
├── src/
│   ├── flow_models/
│   │   ├── fm.py                  # Flow Matching implementation
│   │   ├── df.py                  # Diffusion implementation
│   │   ├── ct.py                  # Continuous-Time implementation
│   │   ├── config.py              # Unified Config class
│   │   ├── train.py               # Regression/classification training CLI
│   │   ├── train_gen.py           # Generation training CLI
│   │   ├── train_seq.py           # Sequence training CLI
│   │   ├── trainer.py             # Regression trainer
│   │   ├── trainer_gen.py         # Generation trainer
│   │   ├── trainer_seq.py         # Sequence trainer
│   │   └── training_utils.py      # Shared training utilities
│   ├── configs/
│   │   └── base_config.py         # BaseConfig class with YAML support
│   ├── embeddings/
│   │   └── noise_schedules.py     # Noise schedule implementations
│   └── vae/                       # Encoder/decoder architectures
├── examples/
│   └── two_moons/                 # Two moons dataset example
│       ├── config.py              # Example config class
│       ├── config.yaml            # Example YAML config
│       ├── generate_two_moons.py
│       └── README.md
├── data/                          # Dataset files
├── artifacts/                     # Training outputs
│   └── {model_type}_{task}/       # Organized by model and task
│       └── {YYYYMMDD_HHMM}/       # Timestamped runs
└── README.md
All three models (Flow Matching, Diffusion, CT) now support VAE encoder-decoder reconstruction loss:
- VAE Loss: Measures reconstruction quality by encoding targets to latent space and decoding back
- Configurable Weight: Control via the `vae_weight` parameter (default: 1.0 for generation, 0.0 for regression)
- Loss Tracking: VAE loss is tracked separately in training history and loss trends plots
Efficient batch processing for sample generation:
- `predict()`: Handles batched conditional inputs automatically
- `sample()`: Accepts flexible `batch_shape` tuples for parallel generation
- Performance: All samples in a batch are generated in parallel, significantly faster than sequential generation
Comprehensive loss component tracking:
- Flow loss, reconstruction loss, regularization loss, VAE loss tracked separately
- Loss trends plots show all components over training epochs
- Support for sequence metrics (MSE, percent variance explained) in sequence training
All models support multiple noise schedules:
- Linear, Cosine, Sigmoid: Standard fixed schedules
- Exponential, Cauchy, Laplace: Distribution-based schedules
- Quadratic, Polynomial: Power-based schedules
- Learnable: Neural network-based adaptive schedule
Schedules are parameterized to avoid singularities at boundaries.
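The idea of avoiding boundary singularities can be sketched generically. In the example below, a schedule maps `t` in [0, 1] to a signal fraction `gamma(t)`, and the signal-to-noise ratio is taken as `gamma / (1 - gamma)`, which blows up as `gamma` approaches 1. Clipping `gamma` away from 0 and 1 keeps the ratio finite. The functional forms, the SNR convention, the `EPS` margin, and all names here are assumptions made for illustration; the schedules in `src/embeddings/noise_schedules.py` may be defined and regularized differently.

```python
import math

EPS = 1e-4  # hypothetical clipping margin


def linear_signal(t):
    """Signal fraction that grows linearly from 0 (pure noise) to 1 (pure data)."""
    return t


def cosine_signal(t):
    """Cosine-style schedule: slow near both ends, fastest in the middle."""
    return math.sin(0.5 * math.pi * t) ** 2


def clipped(schedule, t, eps=EPS):
    """Evaluate a schedule and keep the result inside [eps, 1 - eps]."""
    return min(max(schedule(t), eps), 1.0 - eps)


def snr(schedule, t, eps=EPS):
    """Signal-to-noise ratio gamma / (1 - gamma), finite thanks to clipping."""
    gamma = clipped(schedule, t, eps)
    return gamma / (1.0 - gamma)


grid = [i / 10 for i in range(11)]
linear_snr = [snr(linear_signal, t) for t in grid]
cosine_snr = [snr(cosine_signal, t) for t in grid]
```

Without the clipping step, both schedules would give an SNR of exactly 0 at `t = 0` and a division by zero at `t = 1`.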
All three models support multiple loss components that can be weighted independently:
Flow Matching:
Total Loss = Flow Loss + recon_weight * Recon Loss + reg_weight * Reg Loss + vae_weight * VAE Loss + KL(z_0) Loss
Flow Loss = E[||dz/dt(z_t, x, t) - (target - z_0)||²]
Recon Loss = E[||y_pred - y||²] (optional, weighted by recon_weight)
Reg Loss = E[||dz/dt||²] (optional, weighted by reg_weight)
VAE Loss = E[||y - decode(encode(y))||²] (optional, weighted by vae_weight)
KL(z_0) Loss = KL divergence for initial latent state (for some models)
Direct flow field matching with optional reconstruction, regularization, and VAE losses.
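The Flow Loss line above can be turned into a small Monte Carlo estimate. The sketch assumes the straight-line path `z_t = (1 - t) * z_0 + t * target`, for which the regression target is `target - z_0`. It drops the conditioning input `x` and uses plain Python lists in place of JAX arrays, and `flow_matching_loss` is a hypothetical name, not the repository's function.

```python
def flow_matching_loss(flow_fn, z0_batch, target_batch, t_batch):
    """Mean over the batch of ||flow_fn(z_t, t) - (target - z_0)||^2."""
    total = 0.0
    for z0, target, t in zip(z0_batch, target_batch, t_batch):
        # Point on the straight line between the noise sample and the target.
        z_t = [(1.0 - t) * a + t * b for a, b in zip(z0, target)]
        pred = flow_fn(z_t, t)
        # Squared error against the constant flow of the linear path.
        total += sum((p - (b - a)) ** 2 for p, a, b in zip(pred, z0, target))
    return total / len(z0_batch)


def zero_flow(z, t):
    """Baseline "model" that always predicts no movement."""
    return [0.0 for _ in z]


loss = flow_matching_loss(
    zero_flow,
    z0_batch=[[0.0, 0.0]],
    target_batch=[[3.0, 4.0]],
    t_batch=[0.5],
)
```

For the zero-flow baseline the loss reduces to the squared distance between noise and target, here `3**2 + 4**2 = 25`. A trained network should drive the value well below that baseline.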
Diffusion:
Total Loss = Flow Loss + recon_weight * Recon Loss + reg_weight * Reg Loss + vae_weight * VAE Loss + KL(z_0) Loss
Flow Loss = E[SNR'(t) * ||noise_prediction - actual_noise||²]
Recon Loss = E[SNR'(t) * ||y_pred - y||²] (optional, weighted by recon_weight)
Reg Loss = E[SNR'(t) * ||dz/dt||²] (optional, weighted by reg_weight)
VAE Loss = E[||y - decode(encode(y))||²] (optional, weighted by vae_weight)
KL(z_0) Loss = KL divergence for initial latent state
SNR-weighted noise prediction with optional additional losses.
Continuous-Time:
Total Loss = Flow Loss + recon_weight * Recon Loss + reg_weight * Reg Loss + vae_weight * VAE Loss + KL(z_0) Loss
Flow Loss = E[SNR'(t) * ||target_prediction - target||²]
Recon Loss = E[SNR'(t) * ||y_pred - y||²] (optional, weighted by recon_weight)
Reg Loss = E[SNR'(t) * ||dz/dt||²] (optional, weighted by reg_weight)
VAE Loss = E[||y - decode(encode(y))||²] (optional, weighted by vae_weight)
KL(z_0) Loss = KL divergence for initial latent state
SNR-weighted target prediction with optional additional losses.
Loss Weight Configuration:
- `recon_weight`: Weight for reconstruction loss (default: 0.0)
- `reg_weight`: Weight for regularization loss (default: 0.0)
- `vae_weight`: Weight for VAE encoder-decoder reconstruction loss (default: 1.0 for generation tasks, 0.0 for regression)
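The way the weights combine the individual terms follows directly from the Total Loss formula shared by all three models. The sketch below restates it as a function, using the defaults listed above. `total_loss` and `default_vae_weight` are hypothetical names, and the individual loss values are made-up numbers for the example.

```python
def default_vae_weight(task):
    """Default stated in this README: 1.0 for generation tasks, 0.0 for regression."""
    return 1.0 if task == "generation" else 0.0


def total_loss(flow_loss, recon_loss=0.0, reg_loss=0.0, vae_loss=0.0, kl_loss=0.0,
               recon_weight=0.0, reg_weight=0.0, vae_weight=1.0):
    """Flow + recon_weight * Recon + reg_weight * Reg + vae_weight * VAE + KL(z_0)."""
    return (flow_loss
            + recon_weight * recon_loss
            + reg_weight * reg_loss
            + vae_weight * vae_loss
            + kl_loss)


components = dict(flow_loss=1.0, recon_loss=2.0, reg_loss=3.0, vae_loss=4.0, kl_loss=0.5)
gen_total = total_loss(**components, recon_weight=0.5, reg_weight=0.1,
                       vae_weight=default_vae_weight("generation"))
reg_total = total_loss(**components, recon_weight=0.5, reg_weight=0.1,
                       vae_weight=default_vae_weight("regression"))
```

With the default `recon_weight` and `reg_weight` of 0.0, the reconstruction and regularization terms drop out, leaving the flow, VAE, and KL terms.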
The repository uses a unified configuration system:
- YAML Config Files: Human-readable configuration files (recommended)
- Python Config Classes: Custom config classes that extend `BaseConfig`
- Command-Line Overrides: All config values can be overridden via command-line arguments
The unified Config class in src/flow_models/config.py works for all three model types (Flow Matching, Diffusion, CT) and all tasks (regression, generation, sequences).
Contributions welcome! Please open an issue or submit a pull request.
MIT License
If you use this code in your research, please cite:
@inproceedings{Li2025NoProp,
title={{NoProp: Training Neural Networks without Full Back-propagation or Full Forward-propagation}},
author={Qinyu Li and Yee Whye Teh and Razvan Pascanu},
booktitle={Conference on Lifelong Learning Agents (CoLLAs)},
year={2025},
url={https://arxiv.org/abs/2503.24322}
}