bayesianempirimancer/jax-noprop

JAX/Flax Flow Model Comparison

Python 3.8+ · JAX · Flax · License: MIT

A JAX/Flax implementation for comparing Flow Matching, Diffusion, and Continuous-Time diffusion models. These three approaches solve the same continuous-time problem using different objective functions, each resulting in networks that predict different quantities.

Overview

This repository implements three variants of continuous-time generative models that address the same fundamental problem through different parameterizations:

Three Approaches to Continuous-Time Generative Modeling

  1. Flow Matching (FM): Predicts the denoising flow directly

    • Network learns: dz/dt = f(z, x, t) where f directly predicts the flow field
    • Objective: Match the flow field that connects noise to data
    • Most robust approach - direct flow prediction tends to be stable across different data distributions
  2. Diffusion (DF): Predicts noise at each time step

    • Network learns: noise_prediction = f(z, x, t) where f predicts the noise component
    • Objective: Predict noise to be removed, reparameterized to avoid singularities
    • Uses noise prediction for denoising trajectory
  3. Continuous-Time (CT): Predicts the target at each time step

    • Network learns: target_prediction = f(z, x, t) where f predicts the clean target
    • Objective: Predict the target value at each time point with SNR-weighted loss
    • Uses target prediction to guide the denoising process

Key Insight

All three methods solve the same continuous-time problem but parameterize it differently:

  • Flow Matching: Direct flow field prediction
  • Diffusion: Noise prediction → denoising trajectory
  • CT: Target prediction → denoising trajectory

Each parameterization leads to different training dynamics and performance characteristics.
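
To make the relationship concrete, here is a minimal sketch under one simplifying assumption that this README does not confirm for every model: a straight-line path z_t = (1 - t) * z_0 + t * target between the noise sample z_0 and the clean target. Under that assumption, any one of the three predicted quantities determines the other two. All function names are hypothetical and are not part of the repository's API.

```python
# Illustrative only: assumes a straight-line path z_t = (1 - t) * z_0 + t * target.
# The repository's diffusion/CT models use noise schedules, so their exact
# conversions may differ.

def interpolate(z0, target, t):
    """Point on the assumed straight-line path at time t in [0, 1]."""
    return (1.0 - t) * z0 + t * target

def flow_from_endpoints(z0, target):
    """Flow Matching quantity: velocity dz/dt = target - z_0."""
    return target - z0

def target_from_flow(zt, flow, t):
    """CT-style quantity: clean target recovered from the velocity."""
    return zt + (1.0 - t) * flow

def noise_from_flow(zt, flow, t):
    """Diffusion-style quantity: noise sample z_0 recovered from the velocity."""
    return zt - t * flow

z0, target, t = -1.5, 2.0, 0.3
zt = interpolate(z0, target, t)
flow = flow_from_endpoints(z0, target)

# Round trip: the velocity plus the current state recovers both endpoints.
assert abs(target_from_flow(zt, flow, t) - target) < 1e-12
assert abs(noise_from_flow(zt, flow, t) - z0) < 1e-12
```

Because the conversions are exact along this path, differences between the methods in practice would come from how each loss weights errors over time, not from what information the network has.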

Installation

git clone https://github.com/bayesianempirimancer/jax-noprop.git
cd jax-noprop
pip install -e .

Requirements:

  • Python 3.8+
  • JAX 0.4.0+ (with CUDA support for GPU acceleration)
  • Flax 0.7.0+
  • PyYAML 6.0+ (for YAML config file support)
  • See requirements.txt for full dependency list

Note: For GPU support, install JAX with CUDA following the official JAX installation guide.

Quick Start

Training Scripts

This repository provides three main training scripts for different tasks:

  1. train.py - Regression/Classification (x → y)
  2. train_gen.py - Conditional/Unconditional Generation (y → x, or x alone without conditioning)
  3. train_seq.py - Sequence Generation (for sequence data)

Two Moons Dataset Example

See the Two Moons Example README (examples/two_moons/README.md) for a complete walkthrough. Quick start:

# Generate the dataset
python examples/two_moons/generate_two_moons.py

# Regression: Predict labels from coordinates
python -m src.flow_models.train \
    --config_file examples/two_moons/config.yaml \
    --data_path data/two_moons.pkl \
    --model_type flow_matching

# Conditional Generation: Generate coordinates from labels
python -m src.flow_models.train_gen \
    --config_file examples/two_moons/config.yaml \
    --data_path data/two_moons.pkl \
    --model_type flow_matching

# Unconditional Generation: Generate coordinates without labels
python -m src.flow_models.train_gen \
    --config_file examples/two_moons/config.yaml \
    --data_path data/two_moons.pkl \
    --model_type flow_matching \
    --unconditional

Output Structure

After training, results are saved to artifacts/{model_type}_{task}/{YYYYMMDD_HHMM}/:

  • Regression: artifacts/{model_type}_reg/{timestamp}/
  • Conditional Generation: artifacts/{model_type}_gen/{timestamp}/
  • Unconditional Generation: artifacts/{model_type}_uncond_gen/{timestamp}/

Each directory contains:

  • history.pkl or training_results.pkl - Training history with all loss components
  • params.pkl or model_params.pkl - Trained model parameters
  • config.yaml - Configuration used for training (human-readable, hierarchical)
  • loss_trends.png - Loss trends plot showing flow_loss, recon_loss, reg_loss, vae_loss over epochs
  • data_visualization.png - Data visualization plots
  • trajectories.png - Sample trajectory visualizations
  • trajectory_diagnostics.png - Trajectory diagnostic plots
  • Generation-specific plots (conditional_generation.png, unconditional_generation.png, latent_trajectories.png)
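
The layout above can be navigated with a few lines of standard-library Python. The sketch below is an assumption-laden illustration: the helper names `run_directory` and `first_existing` are hypothetical, and only the directory pattern and file names are taken from the description above.

```python
from datetime import datetime
from pathlib import Path

# Hypothetical helpers; only the directory pattern and file names come from
# the description of the output structure.
TASK_SUFFIX = {"regression": "reg", "conditional": "gen", "unconditional": "uncond_gen"}

def run_directory(model_type, task, when, root="artifacts"):
    """Build artifacts/{model_type}_{task}/{YYYYMMDD_HHMM}."""
    stamp = when.strftime("%Y%m%d_%H%M")
    return Path(root) / f"{model_type}_{TASK_SUFFIX[task]}" / stamp

def first_existing(run_dir, candidates):
    """Return the first candidate file that exists in run_dir, else None."""
    for name in candidates:
        path = Path(run_dir) / name
        if path.is_file():
            return path
    return None

example = run_directory("flow_matching", "unconditional", datetime(2025, 1, 31, 14, 5))
print(example.as_posix())  # artifacts/flow_matching_uncond_gen/20250131_1405

# Either file name may be present, so try both in order.
history_file = first_existing(example, ["history.pkl", "training_results.pkl"])
```

Trying both file names in order reflects the "history.pkl or training_results.pkl" wording; which one a given script writes is not stated here.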

Usage Examples

Configuration Files

All training scripts support YAML configuration files:

# Use a YAML config file (recommended)
python -m src.flow_models.train_gen \
    --config_file examples/two_moons/config.yaml \
    --data_path data/two_moons.pkl \
    --model_type flow_matching

# Use a custom config class
python -m src.flow_models.train_gen \
    --config_file examples/two_moons/config.yaml \
    --config_class examples.two_moons.config.Config \
    --data_path data/two_moons.pkl \
    --model_type flow_matching

Command-line arguments override values in config files.
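
One common way to get this override behavior, shown here as a hedged sketch and not as the repository's actual implementation, is to give every flag a default of None and copy only explicitly passed values over the file-based config. The parser, the `apply_overrides` helper, and the example values are all illustrative.

```python
import argparse

# Hypothetical sketch of "CLI overrides config": flags default to None so that
# only explicitly passed values replace what came from the config file.
def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_epochs", type=int, default=None)
    parser.add_argument("--latent_dim", type=int, default=None)
    parser.add_argument("--noise_schedule", type=str, default=None)
    return parser

def apply_overrides(file_config, cli_args):
    """Return a new dict where non-None CLI values win over file values."""
    merged = dict(file_config)
    for name, value in vars(cli_args).items():
        if value is not None:
            merged[name] = value
    return merged

# Stand-in for values parsed from a YAML file (illustrative numbers).
file_config = {"num_epochs": 50, "latent_dim": 2, "noise_schedule": "linear"}
args = build_parser().parse_args(["--num_epochs", "100", "--noise_schedule", "cosine"])
merged = apply_overrides(file_config, args)
print(merged)  # {'num_epochs': 100, 'latent_dim': 2, 'noise_schedule': 'cosine'}
```

Using None as the "not provided" marker keeps a flag's documented default from silently overwriting a value set in the config file.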

Unconditional Generation

# Generate samples without conditioning
python -m src.flow_models.train_gen \
    --config_file examples/two_moons/config.yaml \
    --data_path data/two_moons.pkl \
    --model_type flow_matching \
    --unconditional \
    --num_epochs 100

Custom Architecture

# Override config file values via command line
python -m src.flow_models.train_gen \
    --config_file examples/two_moons/config.yaml \
    --data_path data/two_moons.pkl \
    --model_type flow_matching \
    --latent_dim 8 \
    --encoder_model_type linear \
    --decoder_model_type identity \
    --decoder_type linear

Noise Schedule Selection

# Override noise schedule (for diffusion/CT models)
python -m src.flow_models.train_gen \
    --config_file examples/two_moons/config.yaml \
    --data_path data/two_moons.pkl \
    --model_type diffusion \
    --noise_schedule cosine

# Fixed schedules: linear, cosine, sigmoid, exponential, cauchy, laplace, logistic, quadratic, polynomial
# Learnable options (monotonic_nn, learnable, network) are listed under Noise Schedule Arguments

Training with Dropout Schedule

# Use dropout for first 80 epochs, then disable it
python -m src.flow_models.train_gen \
    --config_file examples/two_moons/config.yaml \
    --data_path data/two_moons.pkl \
    --model_type flow_matching \
    --num_epochs 100 \
    --dropout_epochs 80
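
The effect of --dropout_epochs can be pictured as a per-epoch switch. The sketch below assumes zero-based epoch numbering and a hypothetical helper name; how the trainer applies the flag internally is not documented here.

```python
def dropout_active(epoch, dropout_epochs=None):
    """True while dropout should be applied; epoch is assumed zero-based."""
    if dropout_epochs is None:  # no limit given: dropout stays on for all epochs
        return True
    return epoch < dropout_epochs

num_epochs, dropout_epochs = 100, 80
schedule = [dropout_active(e, dropout_epochs) for e in range(num_epochs)]
print(sum(schedule), schedule[79], schedule[80])  # 80 True False
```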

Model Comparison

When to Use Each Approach

Flow Matching (Recommended for most cases)

  • ✅ Most robust - stable across different data distributions
  • ✅ Simple objective - direct flow prediction
  • ✅ Fast training and inference
  • ✅ Works well as baseline for comparison

Diffusion

  • Good for noise-focused applications
  • Uses noise prediction parameterization
  • Requires careful noise schedule tuning

Continuous-Time

  • Good when target prediction is natural
  • Uses SNR-weighted loss for training stability
  • Can learn optimal noise schedules

Performance Characteristics

  • Flow Matching: Fastest inference, most stable training
  • Diffusion: Good for applications where noise structure matters
  • CT: Can adapt to data with learnable schedules

Python API

Training a Model

from src.flow_models.trainer_gen import GenerationTrainer
from src.flow_models.config import Config

# Load configuration from YAML or create default
config = Config.load_yaml('examples/two_moons/config.yaml')
# Or create a default config
# config = Config()

# Create trainer
trainer = GenerationTrainer(
    config=config,
    learning_rate=1e-3,
    optimizer_name='adam',
    seed=42,
    unconditional=False  # Set to True for unconditional generation
)

# Initialize and train
# x_sample, y_sample, x_train, y_train, x_val, y_val are arrays you supply (not defined here)
trainer.initialize(x_sample, y_sample)  # x_sample can be None for unconditional generation
history = trainer.train(
    x_data=x_train,  # None for unconditional generation
    y_data=y_train,
    num_epochs=100,
    batch_size=256,
    validation_data=(x_val, y_val)  # Optional validation data
)

Generating Samples

import jax

# One key per call; reusing a PRNG key would repeat the same random draws
key = jax.random.PRNGKey(0)
cond_key, uncond_key = jax.random.split(key)

# Conditional generation (supports batched inputs)
# `conditions` is an array you supply, shaped [batch_size, input_dim] or [input_dim]
x_gen = trainer.conditional_generate(
    cond_y=conditions,
    num_steps=20,
    prng_key=cond_key  # Optional: if provided, samples z_0 from normal; otherwise z_0=0
)

# Unconditional generation (supports batch processing)
x_gen = trainer.unconditional_generate(
    batch_shape=(100,),  # number of samples - can be any batch shape tuple
    num_steps=20,
    prng_key=uncond_key  # Required: for sampling z_0 from normal distribution
)

Batch Processing: Both predict() and sample() methods support batch processing:

  • predict() automatically handles batched conditional inputs
  • sample() accepts batch_shape tuple (e.g., (100,) for 100 samples, (10, 5) for 10x5 grid)
  • All samples in a batch are generated efficiently in parallel
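
The num_steps argument in the generation calls suggests fixed-step integration of the learned dynamics from t = 0 to t = 1. The sketch below uses plain Euler steps as an assumption, since the solver the trainer actually uses is not specified here, and it replaces the trained network with a toy velocity field. It loops over samples one by one for clarity, whereas the repository generates a batch in parallel.

```python
# Illustrative Euler integrator; the solver actually used by the trainer is not
# specified in this README, so treat the step rule as an assumption.

def euler_integrate(velocity_fn, z0, num_steps=20):
    """Integrate dz/dt = velocity_fn(z, t) from t = 0 to t = 1."""
    z, dt = z0, 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        z = z + dt * velocity_fn(z, t)
    return z

def make_toy_velocity(target):
    """Toy stand-in for a trained network: flow that points at a fixed target."""
    return lambda z, t: (target - z) / (1.0 - t)

# A "batch" is just several starting points z_0 integrated independently.
starts = [-2.0, 0.0, 0.5, 3.0]
samples = [euler_integrate(make_toy_velocity(1.0), z0) for z0 in starts]

# Every trajectory ends at the target, whatever its starting point.
assert all(abs(s - 1.0) < 1e-9 for s in samples)
```

With a real model, more steps generally trade inference time for integration accuracy, which is why num_steps is exposed as a parameter.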

Command-Line Arguments

Core Arguments

  • --config_file: Path to YAML config file (recommended)
  • --config_class: Optional custom config class (e.g., examples.two_moons.config.Config)
  • --model_type: flow_matching, diffusion, or ct
  • --data_path: Path to data file (required for most tasks)
  • --num_epochs: Number of training epochs (default: 50)
  • --batch_size: Batch size (default: 256)
  • --learning_rate: Learning rate (default: 1e-3)
  • --optimizer: adam, sgd, or adagrad (default: adam)

Shape Arguments

  • --input_shape or --input_dim: Input shape/dimension
  • --output_shape or --output_dim: Output shape/dimension
  • --latent_shape or --latent_dim: Latent shape/dimension

Note: If no config file is provided, you must specify these shapes.
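
As an illustration of how paired shape/dimension flags could be reconciled, the sketch below lets an explicit shape win, turns a bare dimension into a 1-D shape, and raises an error when neither is given. It assumes that shapes are passed as space-separated integers. The `resolve_shape` helper is hypothetical and the real scripts may parse these flags differently.

```python
import argparse

# Hypothetical resolution of the paired flags: an explicit shape wins, a bare
# dimension becomes a 1-D shape, and missing both is an error.
def resolve_shape(shape, dim, name):
    if shape is not None:
        return tuple(shape)
    if dim is not None:
        return (dim,)
    raise ValueError(f"--{name}_shape or --{name}_dim is required without a config file")

parser = argparse.ArgumentParser()
for name in ("input", "output", "latent"):
    parser.add_argument(f"--{name}_shape", type=int, nargs="+", default=None)
    parser.add_argument(f"--{name}_dim", type=int, default=None)

args = parser.parse_args(
    ["--input_dim", "2", "--output_shape", "28", "28", "--latent_dim", "8"]
)
shapes = {
    name: resolve_shape(getattr(args, f"{name}_shape"), getattr(args, f"{name}_dim"), name)
    for name in ("input", "output", "latent")
}
print(shapes)  # {'input': (2,), 'output': (28, 28), 'latent': (8,)}
```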

Architecture Arguments

  • --encoder_model_type: identity, linear, mlp, mlp_normal, resnet, resnet_normal
  • --decoder_model_type: identity, mlp, resnet
  • --decoder_type: linear, softmax, none
  • --crn_type: CRN type (e.g., vanilla, geometric, potential)
  • --network_type: Network backbone (e.g., mlp, bilinear, convex)
  • --hidden_dims: Hidden layer dimensions (space-separated integers)

Noise Schedule Arguments

  • --noise_schedule: linear, cosine, sigmoid, exponential, cauchy, laplace, logistic, quadratic, polynomial, monotonic_nn, learnable, network
  • --noise_schedule_learnable: Make noise schedule learnable

Training Arguments

  • --dropout_epochs: Number of epochs to use dropout (default: all epochs)
  • --recon_weight: Reconstruction loss weight (default: 0.0)
  • --reg_weight: Regularization loss weight (default: 0.0)
  • --vae_weight: VAE encoder-decoder reconstruction loss weight (default: 1.0 for generation, 0.0 for regression)
  • --use_snr_weight: Apply SNR weighting (default: True for diffusion/CT, False for flow matching)

Generation Arguments

  • --unconditional: Train for unconditional generation (only for train_gen.py)

See python -m src.flow_models.train --help, python -m src.flow_models.train_gen --help, or python -m src.flow_models.train_seq --help for full lists of options.

Project Structure

jax-noprop/
├── src/
│   ├── flow_models/
│   │   ├── fm.py              # Flow Matching implementation
│   │   ├── df.py              # Diffusion implementation
│   │   ├── ct.py              # Continuous-Time implementation
│   │   ├── config.py          # Unified Config class
│   │   ├── train.py           # Regression/classification training CLI
│   │   ├── train_gen.py       # Generation training CLI
│   │   ├── train_seq.py       # Sequence training CLI
│   │   ├── trainer.py         # Regression trainer
│   │   ├── trainer_gen.py     # Generation trainer
│   │   ├── trainer_seq.py     # Sequence trainer
│   │   └── training_utils.py  # Shared training utilities
│   ├── configs/
│   │   └── base_config.py     # BaseConfig class with YAML support
│   ├── embeddings/
│   │   └── noise_schedules.py # Noise schedule implementations
│   └── vae/                   # Encoder/decoder architectures
├── examples/
│   └── two_moons/             # Two moons dataset example
│       ├── config.py          # Example config class
│       ├── config.yaml        # Example YAML config
│       ├── generate_two_moons.py
│       └── README.md
├── data/                      # Dataset files
├── artifacts/                 # Training outputs
│   └── {model_type}_{task}/   # Organized by model and task
│       └── {YYYYMMDD_HHMM}/   # Timestamped runs
└── README.md

Recent Features

VAE Loss Support

All three models (Flow Matching, Diffusion, CT) now support VAE encoder-decoder reconstruction loss:

  • VAE Loss: Measures reconstruction quality by encoding targets to latent space and decoding back
  • Configurable Weight: Control via vae_weight parameter (default: 1.0 for generation, 0.0 for regression)
  • Loss Tracking: VAE loss is tracked separately in training history and loss trends plots

Batch Processing

Efficient batch processing for sample generation:

  • predict(): Handles batched conditional inputs automatically
  • sample(): Accepts flexible batch_shape tuples for parallel generation
  • Performance: All samples in a batch are generated in parallel, significantly faster than sequential generation

Enhanced Loss Tracking

Comprehensive loss component tracking:

  • Flow loss, reconstruction loss, regularization loss, VAE loss tracked separately
  • Loss trends plots show all components over training epochs
  • Support for sequence metrics (MSE, percent variance explained) in sequence training
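
For reference, the two sequence metrics named above are sketched here using their usual textbook definitions; whether the repository computes percent variance explained in exactly this way (for example per feature or pooled) is an assumption.

```python
def mean_squared_error(y_true, y_pred):
    """Average squared difference between targets and predictions."""
    return sum((a - b) ** 2 for a, b in zip(y_true, y_pred)) / len(y_true)

def percent_variance_explained(y_true, y_pred):
    """100 * (1 - SSE / SST); assumes y_true is not constant."""
    mean = sum(y_true) / len(y_true)
    sse = sum((a - b) ** 2 for a, b in zip(y_true, y_pred))
    sst = sum((a - mean) ** 2 for a in y_true)
    return 100.0 * (1.0 - sse / sst)

y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.0, 2.0, 3.0, 5.0]
print(mean_squared_error(y_true, y_pred))          # 0.25
print(percent_variance_explained(y_true, y_pred))  # 80.0
```

A perfect prediction scores 100, and always predicting the mean of the targets scores 0.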

Technical Details

Noise Schedules

All models support multiple noise schedules:

  • Linear, Cosine, Sigmoid: Standard fixed schedules
  • Exponential, Cauchy, Laplace, Logistic: Distribution-based schedules
  • Quadratic, Polynomial: Power-based schedules
  • Learnable (monotonic_nn, learnable, network): Neural network-based adaptive schedules

Schedules are parameterized to avoid singularities at boundaries.
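
To illustrate what avoiding boundary singularities can look like, the sketch below assumes a signal fraction alpha_bar(t) that rises from noise at t = 0 to data at t = 1, with SNR(t) = alpha_bar / (1 - alpha_bar). Keeping alpha_bar inside [eps, 1 - eps] keeps the SNR finite at both ends. The formulas, the value of eps, and the function names are illustrative assumptions and are not taken from the repository's noise_schedules.py.

```python
import math

EPS = 1e-3  # illustrative margin; the repository's actual value is not stated here

def alpha_bar_linear(t, eps=EPS):
    """Signal fraction rising linearly from eps at t=0 to 1 - eps at t=1."""
    return eps + (1.0 - 2.0 * eps) * t

def alpha_bar_cosine(t, eps=EPS):
    """Cosine-shaped signal fraction, squeezed into [eps, 1 - eps]."""
    return eps + (1.0 - 2.0 * eps) * math.sin(0.5 * math.pi * t) ** 2

def snr(alpha_bar):
    """Signal-to-noise ratio; diverges only if alpha_bar reaches exactly 1."""
    return alpha_bar / (1.0 - alpha_bar)

for schedule in (alpha_bar_linear, alpha_bar_cosine):
    values = [snr(schedule(i / 10)) for i in range(11)]
    assert all(math.isfinite(v) and v > 0 for v in values)
    assert values == sorted(values)  # SNR grows as t moves from noise to data
```

Without the margin, alpha_bar would reach 1 at t = 1 and the SNR would divide by zero there.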

Training Objectives

All three models support multiple loss components that can be weighted independently:

Flow Matching:

Total Loss = Flow Loss + recon_weight * Recon Loss + reg_weight * Reg Loss + vae_weight * VAE Loss + KL(z_0) Loss

Flow Loss = E[||dz/dt(z_t, x, t) - (target - z_0)||²]
Recon Loss = E[||y_pred - y||²]  (optional, weighted by recon_weight)
Reg Loss = E[||dz/dt||²]  (optional, weighted by reg_weight)
VAE Loss = E[||y - decode(encode(y))||²]  (optional, weighted by vae_weight)
KL(z_0) Loss = KL divergence for initial latent state (for some models)

Direct flow field matching with optional reconstruction, regularization, and VAE losses.
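
The Flow Matching objective can be sketched for scalar data as follows. This is a toy Monte Carlo estimate under the assumption of a straight-line path between z_0 and the target, with a hand-written stand-in for the network; the function names are hypothetical and the repository's JAX implementation will differ in detail.

```python
import random

def flow_matching_loss(velocity_fn, pairs, rng):
    """Monte Carlo estimate of E[(v(z_t, t) - (target - z_0))^2] for scalars."""
    total = 0.0
    for z0, target in pairs:
        t = rng.random()                      # t ~ Uniform[0, 1)
        zt = (1.0 - t) * z0 + t * target      # assumed straight-line path
        total += (velocity_fn(zt, t) - (target - z0)) ** 2
    return total / len(pairs)

def total_loss(flow, recon=0.0, reg=0.0, vae=0.0, kl=0.0,
               recon_weight=0.0, reg_weight=0.0, vae_weight=1.0):
    """Weighted sum in the same order as the Total Loss line above."""
    return flow + recon_weight * recon + reg_weight * reg + vae_weight * vae + kl

rng = random.Random(0)
pairs = [(rng.gauss(0.0, 1.0), 2.0) for _ in range(256)]  # (z_0, target)

# Toy "network" that happens to know the target is always 2.0.
oracle = lambda zt, t: (2.0 - zt) / (1.0 - t)
flow = flow_matching_loss(oracle, pairs, rng)
print(total_loss(flow, vae=0.1, vae_weight=1.0))
```

The oracle drives the flow term to numerical zero, so the printed total is dominated by the illustrative VAE term.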

Diffusion:

Total Loss = Flow Loss + recon_weight * Recon Loss + reg_weight * Reg Loss + vae_weight * VAE Loss + KL(z_0) Loss

Flow Loss = E[SNR'(t) * ||noise_prediction - actual_noise||²]
Recon Loss = E[SNR'(t) * ||y_pred - y||²]  (optional, weighted by recon_weight)
Reg Loss = E[SNR'(t) * ||dz/dt||²]  (optional, weighted by reg_weight)
VAE Loss = E[||y - decode(encode(y))||²]  (optional, weighted by vae_weight)
KL(z_0) Loss = KL divergence for initial latent state

SNR-weighted noise prediction with optional additional losses. Here SNR'(t) denotes the time derivative of the noise schedule's signal-to-noise ratio.

Continuous-Time:

Total Loss = Flow Loss + recon_weight * Recon Loss + reg_weight * Reg Loss + vae_weight * VAE Loss + KL(z_0) Loss

Flow Loss = E[SNR'(t) * ||target_prediction - target||²]
Recon Loss = E[SNR'(t) * ||y_pred - y||²]  (optional, weighted by recon_weight)
Reg Loss = E[SNR'(t) * ||dz/dt||²]  (optional, weighted by reg_weight)
VAE Loss = E[||y - decode(encode(y))||²]  (optional, weighted by vae_weight)
KL(z_0) Loss = KL divergence for initial latent state

SNR-weighted target prediction with optional additional losses.
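
The SNR'(t) weight can be made concrete with a small scalar sketch. It assumes a linear signal fraction alpha_bar(t) kept inside [EPS, 1 - EPS] and SNR(t) = alpha_bar / (1 - alpha_bar); both the schedule and the margin are illustrative choices, not the repository's defaults.

```python
EPS = 1e-3  # illustrative boundary margin (assumption)

def snr(t):
    """SNR for an assumed linear signal fraction alpha_bar(t) in [EPS, 1 - EPS]."""
    alpha_bar = EPS + (1.0 - 2.0 * EPS) * t
    return alpha_bar / (1.0 - alpha_bar)

def snr_prime(t):
    """Analytic derivative: d/dt [a / (1 - a)] = a'(t) / (1 - a)^2."""
    alpha_bar = EPS + (1.0 - 2.0 * EPS) * t
    return (1.0 - 2.0 * EPS) / (1.0 - alpha_bar) ** 2

def weighted_sq_error(prediction, target, t):
    """One sample of SNR'(t) * ||prediction - target||^2 for scalars."""
    return snr_prime(t) * (prediction - target) ** 2

# Sanity check of the derivative against a central finite difference.
t, h = 0.4, 1e-6
numeric = (snr(t + h) - snr(t - h)) / (2.0 * h)
assert abs(numeric - snr_prime(t)) < 1e-4 * snr_prime(t)

# The same prediction error costs more late in the trajectory (t near 1).
early = weighted_sq_error(0.9, 1.0, t=0.1)
late = weighted_sq_error(0.9, 1.0, t=0.9)
assert late > early
```

Under this schedule the weight grows steeply toward t = 1, so errors made close to the clean target are penalized most.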

Loss Weight Configuration:

  • recon_weight: Weight for reconstruction loss (default: 0.0)
  • reg_weight: Weight for regularization loss (default: 0.0)
  • vae_weight: Weight for VAE encoder-decoder reconstruction loss (default: 1.0 for generation tasks, 0.0 for regression)

Configuration System

The repository uses a unified configuration system:

  • YAML Config Files: Human-readable configuration files (recommended)
  • Python Config Classes: Custom config classes that extend BaseConfig
  • Command-Line Overrides: All config values can be overridden via command-line arguments

The unified Config class in src/flow_models/config.py works for all three model types (Flow Matching, Diffusion, CT) and all tasks (regression, generation, sequences).

Contributing

Contributions welcome! Please open an issue or submit a pull request.

License

MIT License

Citation

If you use this code in your research, please cite:

@inproceedings{Li2025NoProp,
  title={{NoProp: Training Neural Networks without Full Back-propagation or Full Forward-propagation}},
  author={Qinyu Li and Yee Whye Teh and Razvan Pascanu},
  booktitle={Conference on Lifelong Learning Agents (CoLLAs)},
  year={2025},
  url={https://arxiv.org/abs/2503.24322}
}
