
RAPID: Retrieval Augmented Training of Differentially Private Diffusion Models

This is the official implementation of RAPID: Retrieval Augmented Training of Differentially Private Diffusion Models.
The code is based on a public implementation of Latent Diffusion Models and a public implementation of Differentially Private Latent Diffusion Models.

Environment setup:

conda env create -f environment.yaml
conda activate RAPID

Model Training

Training the Autoencoder

python main.py --base <AE config file path> -t --gpus 1

Training the Diffusion Models

python main.py --base <DM config file path> -t --gpus 1

Private Model Finetuning

python main.py --base <Finetune config file path> -t --gpus 0, --accelerator gpu

Feature Extractor Training

python train_feature_extractor.py --config <DM config file path> --ckpt <checkpoint path> --output <network output path> --epoch 50
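The trained feature extractor is later passed to the sampling scripts via --netpath, presumably to retrieve relevant examples by feature similarity. A minimal nearest-neighbor lookup is sketched below; the use of cosine similarity and the function shape are assumptions for illustration, not the repository's exact retrieval method.

```python
import numpy as np

def retrieve_neighbors(query_feats, db_feats, k=5):
    """Return indices of the k most similar database features per query.

    query_feats: (Q, D) array of query feature vectors.
    db_feats:    (N, D) array of database feature vectors.
    Similarity is cosine similarity (an assumption for this sketch).
    """
    q = query_feats / np.linalg.norm(query_feats, axis=1, keepdims=True)
    d = db_feats / np.linalg.norm(db_feats, axis=1, keepdims=True)
    sims = q @ d.T                       # (Q, N) cosine similarities
    return np.argsort(-sims, axis=1)[:, :k]
```

For example, a query aligned with the first database vector retrieves index 0 first.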

Sampling

Conditional Sampling

python conditional_sampling.py --config <DM config file path> --private_config <DM config file path> --ckpt <checkpoint path> \
 --private_ckpt <checkpoint path> --netpath <path to the feature extractor> --output <network output path> 

Unconditional Sampling

python unconditional_sampling.py --config <DM config file path> --private_config <DM config file path> --ckpt <checkpoint path> \
 --private_ckpt <checkpoint path> --netpath <path to the feature extractor> --output <network output path> 

FID Evaluation

python FID_test.py --sample_path <path to generated samples> --train_stats_path <path to generated statistics on the reference set>
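FID_test.py compares generated samples against precomputed statistics on a reference set. The standard FID is the Fréchet distance between two Gaussians fitted to feature activations (mean and covariance); the core formula can be sketched as follows. This is the textbook computation, not necessarily the script's exact code path.

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):
    ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```

Identical statistics give a distance of 0; lower FID means the generated distribution is closer to the reference set.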

Diversity Evaluation

python Diversity_test.py --sample_path <path to generated samples> --data_config <config file path>
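The exact diversity metric used by Diversity_test.py is not specified here. One common proxy is the average pairwise distance between samples in some feature space, sketched below purely as an illustration of the idea, not as the script's actual metric.

```python
import numpy as np

def mean_pairwise_distance(feats):
    """Average Euclidean distance over all ordered pairs of feature vectors.

    feats: (N, D) array; larger values indicate more spread-out (diverse) samples.
    """
    n = len(feats)
    # (N, N) matrix of pairwise distances; the diagonal is zero.
    dists = np.linalg.norm(feats[:, None, :] - feats[None, :, :], axis=-1)
    return float(dists.sum() / (n * (n - 1)))
```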

Downstream Classification Accuracy

For MNIST, to compute downstream performance with a standard CNN, the command is:

python CNN_downstream.py --sample_path <path to generated samples> --epoch 10
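The downstream protocol trains a classifier on the generated samples and evaluates it on real held-out data. The sketch below uses a multinomial logistic regression as a stand-in for the CNN, just to make the train-on-synthetic / test-on-real evaluation loop concrete; the actual script presumably trains a convolutional network.

```python
import numpy as np

def train_softmax(X, y, n_classes, lr=0.1, epochs=200):
    """Multinomial logistic regression by gradient descent (stand-in classifier).

    X: (N, D) training features (here, generated samples); y: (N,) integer labels.
    """
    W = np.zeros((X.shape[1], n_classes))
    Y = np.eye(n_classes)[y]                      # one-hot labels
    for _ in range(epochs):
        logits = X @ W
        logits -= logits.max(axis=1, keepdims=True)   # stabilize softmax
        P = np.exp(logits)
        P /= P.sum(axis=1, keepdims=True)
        W -= lr * X.T @ (P - Y) / len(X)              # cross-entropy gradient step
    return W

def accuracy(W, X, y):
    """Fraction of correct argmax predictions, e.g. on a real test set."""
    return float((np.argmax(X @ W, axis=1) == y).mean())
```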

Acknowledgement

We built and tested our project on top of Latent Diffusion Models and Differentially Private Latent Diffusion Models. Many thanks to the authors who made their work publicly accessible!
