This repo contains a CellViT-based diffusion model for medical image segmentation, plus training, sampling, evaluation and demo scripts.
- Clone the repo:

  ```bash
  git clone git@github.com:caesarchen000/cellvit_diffusion.git
  cd cellvit_diffusion
  ```

- Create and activate a Conda env (recommended):

  ```bash
  conda create -n cellvit_diffusion python=3.10 -y
  conda activate cellvit_diffusion
  ```

- Install Python dependencies. Most code lives in `MedSegDiff` and `CellViT`; install their requirements:

  ```bash
  cd MedSegDiff
  pip install -r requirements.txt
  # If you have additional libs you know you need:
  pip install opencv-python scipy scikit-learn visdom
  ```

  Adjust versions as needed for your system / CUDA.
Training expects a nested dataset under `train/` (and optionally `test/`) at the repo root:

```
train/
  <ID>/
    images/
      <ID>.png   # input image
    masks/
      *.png      # one or more instance masks for this ID
test/
  <ID>/
    images/
      <ID>.png
    masks/
      *.png      # optional, used for evaluation
```
- Masks are binary instance masks; the training pipeline merges all instance masks for an ID into a single foreground/background mask (see the sketch below).
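As a rough illustration of that merging step (a minimal sketch; the `merge_instance_masks` helper is hypothetical, not the repo's actual loader code):

```python
# Sketch: union all per-instance masks for one ID into a single binary mask.
# Assumes masks are grayscale PNGs where any nonzero pixel is foreground.
import glob
import numpy as np
from PIL import Image

def merge_instance_masks(mask_dir):
    paths = sorted(glob.glob(f"{mask_dir}/*.png"))
    merged = None
    for p in paths:
        m = np.array(Image.open(p).convert("L")) > 0
        merged = m if merged is None else (merged | m)
    return merged.astype(np.uint8)  # 1 = foreground, 0 = background

# e.g. merged = merge_instance_masks("train/<ID>/masks")
```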
The main training entrypoint lives in `MedSegDiff`:

```bash
cd MedSegDiff
conda activate cellvit_diffusion  # if not already
python scripts/segmentation_train.py \
  --data_name CUSTOM \
  --data_dir ../train \
  --model_arch cellvit \
  --image_size 256 \
  --in_ch 4 \
  --batch_size 4 \
  --lr 5e-5 \
  --diffusion_steps 6000 \
  --log_interval 100 \
  --save_interval 5000 \
  --out_dir ./results/cellvit_diffusion_merged_masks \
  --gpu_dev 0 \
  --cellvit_embed_dim 384 \
  --cellvit_depth 12 \
  --cellvit_heads 6
```

Notes:
- `--data_dir` must point to the `train/` folder described above.
- Checkpoints are written into `results/...` as `savedmodel*.pt`, and EMA checkpoints as `emasavedmodel_*.pt`.
- If you want to resume from a checkpoint:

  ```bash
  python scripts/segmentation_train.py \
    ...same args as above... \
    --resume_checkpoint ./results/cellvit_diffusion_merged_masks/savedmodel155000.pt
  ```

There is also a helper shell script `scripts/train_cellvit_diffusion.sh` you can inspect and modify.
To generate segmentation masks from a trained model:
```bash
cd MedSegDiff
conda activate cellvit_diffusion
bash scripts/sample_cellvit_diffusion_fast.sh
```

This script:

- Automatically finds the latest EMA checkpoint in `results/cellvit_diffusion_merged_masks` (see the sketch after this list).
- Runs a relatively fast sampling configuration (few diffusion steps, few ensembles) and writes predictions into something like `results/cellvit_diffusion/samples_fast_debug/`.
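Conceptually, that checkpoint lookup can be as simple as the following (a sketch only; the script's actual logic may differ):

```python
# Sketch: pick the most recently written EMA checkpoint.
import glob
import os

ckpts = glob.glob("results/cellvit_diffusion_merged_masks/emasavedmodel_*.pt")
latest = max(ckpts, key=os.path.getmtime)  # raises ValueError if none exist
print(f"Using checkpoint: {latest}")
```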
For higher-quality but slower sampling, use `scripts/sample_cellvit_diffusion.sh` and adjust:

- `DIFFUSION_STEPS`
- `NUM_ENSEMBLE`
- `THRESHOLD`
Both sampling scripts call `scripts/segmentation_sample.py`, which supports:

- `--num_samples`
- `--diffusion_steps`
- `--num_ensemble`
- `--post_process true|false`
- `--threshold 0.3` (for binarization; see the sketch after this list)
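How ensembling and the threshold interact, in a hedged sketch (illustrative only; the function and variable names below are made up):

```python
# Sketch: average several sampled probability maps, then binarize.
import numpy as np

def ensemble_and_binarize(prob_maps, threshold=0.3):
    """prob_maps: list of HxW float arrays in [0, 1] from repeated sampling."""
    mean_prob = np.mean(np.stack(prob_maps), axis=0)
    return (mean_prob >= threshold).astype(np.uint8)  # 1 = foreground

# e.g. mask = ensemble_and_binarize([run_a, run_b, run_c], threshold=0.3)
```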
To quantitatively evaluate predictions against merged ground-truth masks, use:

```bash
cd MedSegDiff
conda activate cellvit_diffusion
python scripts/evaluate_iou.py \
  --pred_dir ./results/cellvit_diffusion/samples_fast_debug \
  --data_dir ../train \
  --threshold 0.5
```

This script:
- Matches each prediction `<ID>_output*.jpg` in `pred_dir` to the corresponding ID in `data_dir`.
- Merges all instance masks for that ID.
- Computes per-sample and mean IoU, Dice, and Average Precision (AP); see the sketch after this list.
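For reference, IoU and Dice on binary masks reduce to a few lines (the standard formulas, sketched; this is not necessarily the script's exact code, and AP is omitted):

```python
# Sketch: IoU and Dice for two binary (0/1) masks of the same shape.
import numpy as np

def iou_and_dice(pred, gt):
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    total = pred.sum() + gt.sum()
    iou = inter / union if union else 1.0   # both empty counts as perfect
    dice = 2 * inter / total if total else 1.0
    return iou, dice
```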
To generate demo images showing input, merged ground truth, and prediction side-by-side:

```bash
cd MedSegDiff
conda activate cellvit_diffusion
python scripts/make_demo_triptych.py \
  --pred_dir ./results/cellvit_diffusion/samples_fast_debug \
  --data_dir ../train \
  --out_dir ./results/cellvit_diffusion/demo_triptych
```

Each output image in `demo_triptych/` will look like:
[ input RGB | merged GT mask | predicted mask ]
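The layout itself is plain horizontal concatenation; a minimal sketch (assuming all panels share the same height and width, with masks already scaled to 0-255; `make_triptych` here is hypothetical):

```python
# Sketch: stack input image, merged GT mask, and prediction side by side.
import numpy as np
from PIL import Image

def make_triptych(image, gt_mask, pred_mask, out_path):
    panels = []
    for arr in (image, gt_mask, pred_mask):
        if arr.ndim == 2:                      # promote grayscale to 3 channels
            arr = np.stack([arr] * 3, axis=-1)
        panels.append(arr)
    Image.fromarray(np.concatenate(panels, axis=1)).save(out_path)
```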
- Large folders `train/`, `test/`, and `results/` are git-ignored so they are not pushed to GitHub. Each user should create their own `train/` and `test/` folders locally.
- If you encounter CUDA OOM, reduce `--batch_size` or the image size.
- For multi-GPU or DDP, you can extend the training script to launch with `torch.distributed.launch` or `torchrun` (see the sketch below).
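If you go that route, the core change is wrapping the model in `DistributedDataParallel`; a minimal sketch, assuming a `torchrun` launch (the `wrap_model_for_ddp` helper is hypothetical and not part of the repo):

```python
# Sketch: minimal DDP wiring for a torchrun-launched training script.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model):
    dist.init_process_group(backend="nccl")    # torchrun sets the env vars
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return DDP(model.to(local_rank), device_ids=[local_rank])

# Launched with, e.g.:
#   torchrun --nproc_per_node=2 scripts/segmentation_train.py ...
```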
For any questions about configuration or extending the model, check the scripts in `MedSegDiff/scripts` and the training utilities in `MedSegDiff/guided_diffusion`.
