70 changes: 46 additions & 24 deletions UNETR/BTCV/README.md
@@ -1,19 +1,22 @@
# Model Overview

This repository contains the code for UNETR: Transformers for 3D Medical Image Segmentation [1]. UNETR is the first 3D segmentation network that uses a pure vision transformer as its encoder without relying on CNNs for feature extraction.
The code presents a volumetric (3D) multi-organ segmentation application using the BTCV challenge dataset.
![image](https://lh3.googleusercontent.com/pw/AM-JKLU2eTW17rYtCmiZP3WWC-U1HCPOHwLe6pxOfJXwv2W-00aHfsNy7jeGV1dwUq0PXFOtkqasQ2Vyhcu6xkKsPzy3wx7O6yGOTJ7ZzA01S6LSh8szbjNLfpbuGgMe6ClpiS61KGvqu71xXFnNcyvJNFjN=w1448-h496-no?authuser=0)

### Installing Dependencies

Dependencies can be installed using:
``` bash

```bash
pip install -r requirements.txt
```

### Training

A UNETR network with standard hyper-parameters for the task of multi-organ semantic segmentation (BTCV dataset) can be defined as follows:

``` bash
```bash
model = UNETR(
in_channels=1,
out_channels=14,
@@ -22,20 +25,21 @@ model = UNETR(
hidden_size=768,
mlp_dim=3072,
num_heads=12,
pos_embed='perceptron',
pos_embed='learnable',
norm_name='instance',
conv_block=True,
res_block=True,
dropout_rate=0.0)
```

The above UNETR model is used for CT images (1-channel input) and for 14-class segmentation outputs. The network expects
resampled input images with size ```(96, 96, 96)``` which will be converted into non-overlapping patches of size ```(16, 16, 16)```.
resampled input images with size `(96, 96, 96)` which will be converted into non-overlapping patches of size `(16, 16, 16)`.
The position embedding is performed using a perceptron layer. The ViT encoder follows standard hyper-parameters as introduced in [2].
The decoder uses convolutional and residual blocks as well as instance normalization. More details can be found in [1].
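As a rough sanity check on the patch arithmetic described above, the sketch below computes how many tokens the ViT encoder sees for a `(96, 96, 96)` input split into `(16, 16, 16)` patches. The helper name `patch_grid` is hypothetical and is not part of this repository or of MONAI.

```python
from math import prod


def patch_grid(img_size, patch_size):
    """Return the number of non-overlapping patches along each axis."""
    if any(i % p != 0 for i, p in zip(img_size, patch_size)):
        raise ValueError("image size must be divisible by patch size")
    return tuple(i // p for i, p in zip(img_size, patch_size))


grid = patch_grid((96, 96, 96), (16, 16, 16))
num_patches = prod(grid)                     # sequence length seen by the transformer
voxels_per_patch = prod((16, 16, 16)) * 1    # 1 input channel (CT)
print(grid, num_patches, voxels_per_patch)   # (6, 6, 6) 216 4096
```

Each flattened patch of 4096 voxels is then projected to the `hidden_size=768` embedding used in the model definition above.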

Using the default values for hyper-parameters, the following command can be used to initiate training using PyTorch native AMP package:
``` bash

```bash
python main.py
--feature_size=32
--batch_size=1
@@ -48,28 +52,30 @@ python main.py
--data_dir=/dataset/dataset0/
```

Note that you need to provide the location of your dataset directory by using ```--data_dir```.
Note that you need to provide the location of your dataset directory by using `--data_dir`.

To initiate distributed multi-gpu training, ```--distributed``` needs to be added to the training command.
To initiate distributed multi-gpu training, `--distributed` needs to be added to the training command.

To disable AMP, ```--noamp``` needs to be added to the training command.
To disable AMP, `--noamp` needs to be added to the training command.

If UNETR is used in distributed multi-gpu training, we recommend increasing the learning rate (i.e. ```--optim_lr```)
according to the number of GPUs. For instance, ```--optim_lr=4e-4``` is recommended for training with 4 GPUs.
If UNETR is used in distributed multi-gpu training, we recommend increasing the learning rate (i.e. `--optim_lr`)
according to the number of GPUs. For instance, `--optim_lr=4e-4` is recommended for training with 4 GPUs.
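The recommendation above is consistent with scaling the learning rate linearly with the number of GPUs. A minimal sketch, assuming a single-GPU base rate of `1e-4` (an assumption inferred from the `4e-4` for 4 GPUs example, not a value stated here); `scaled_lr` is a hypothetical helper, not a function in this repository:

```python
def scaled_lr(base_lr: float, num_gpus: int) -> float:
    """Linearly scale a single-GPU learning rate by the GPU count."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    return base_lr * num_gpus


# Assumed base rate of 1e-4; print the flag value for a few GPU counts.
for n in (1, 2, 4):
    print(f"--optim_lr={scaled_lr(1e-4, n):.0e}  # {n} GPU(s)")
```

Linear scaling is only a starting point; the best value may still need tuning for a given batch size and dataset.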

### Finetuning

We provide state-of-the-art pre-trained checkpoints and TorchScript models of UNETR trained on the BTCV dataset.

To use the pre-trained checkpoint, please download the weights from the following link:

https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth

Once downloaded, please place the checkpoint in the following directory or use ```--pretrained_dir``` to provide the address of where the model is placed:
Once downloaded, please place the checkpoint in the following directory or use `--pretrained_dir` to provide the address of where the model is placed:

```./pretrained_models```
`./pretrained_models`

The following command initiates finetuning using the pretrained checkpoint:
``` bash

```bash
python main.py
--batch_size=1
--logdir=unetr_pretrained
@@ -88,12 +94,13 @@ For using the pre-trained TorchScript model, please download the model from the

https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pt

Once downloaded, please place the TorchScript model in the following directory or use ```--pretrained_dir``` to provide the address of where the model is placed:
Once downloaded, please place the TorchScript model in the following directory or use `--pretrained_dir` to provide the address of where the model is placed:

```./pretrained_models```
`./pretrained_models`

The following command initiates finetuning using the TorchScript model:
``` bash

```bash
python main.py
--batch_size=1
--logdir=unetr_pretrained
@@ -108,39 +115,53 @@ python main.py
--pretrained_model_name='UNETR_model_best_acc.pt'
--resume_jit
```
Note that finetuning from the provided TorchScript model does not support AMP.

Note that finetuning from the provided TorchScript model does not support AMP.

### Testing

You can use the state-of-the-art pre-trained TorchScript model or checkpoint of UNETR to test it on your own data.

Once the pretrained weights are downloaded, using the links above, please place the TorchScript model in the following directory or
use ```--pretrained_dir``` to provide the address of where the model is placed:
use `--pretrained_dir` to provide the address of where the model is placed:

`./pretrained_models`

```./pretrained_models```
The following commands run inference (validation or mask prediction) using the provided checkpoint:

The following command runs inference using the provided checkpoint:
``` bash
```bash
python test.py
--mode='validation'
--infer_overlap=0.5
--data_dir=/dataset/dataset0/
--pretrained_dir='./pretrained_models/'
--saved_checkpoint=ckpt
```

Note that ```--infer_overlap``` determines the overlap between the sliding window patches. A higher value typically results in more accurate segmentation outputs but with the cost of longer inference time.
```bash
python test.py
--mode='predict'
--infer_overlap=0.5
--pretrained_dir='./pretrained_models/'
--saved_checkpoint=ckpt
```

Note that `--infer_overlap` determines the overlap between the sliding window patches. A higher value typically results in more accurate segmentation outputs but with the cost of longer inference time.
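To make the accuracy/time trade-off concrete, the sketch below counts how many sliding-window patches are needed for a given overlap. It is a simplified model of sliding-window inference, not MONAI's exact implementation; the `(192, 192, 192)` volume size and both helper names are assumptions chosen for illustration.

```python
import math


def windows_per_axis(image_size: int, roi_size: int, overlap: float) -> int:
    """Count sliding-window positions along one axis (simplified model)."""
    if not 0.0 <= overlap < 1.0:
        raise ValueError("overlap must be in [0, 1)")
    if image_size <= roi_size:
        return 1
    # Step between consecutive windows shrinks as the overlap grows.
    stride = max(1, int(roi_size * (1.0 - overlap)))
    return math.ceil((image_size - roi_size) / stride) + 1


def total_windows(image_shape, roi_shape, overlap: float) -> int:
    """Total number of windows (forward passes) over a 3D volume."""
    return math.prod(
        windows_per_axis(i, r, overlap) for i, r in zip(image_shape, roi_shape)
    )


for overlap in (0.0, 0.5, 0.75):
    n = total_windows((192, 192, 192), (96, 96, 96), overlap)
    print(f"overlap={overlap}: {n} windows")
```

Under these assumptions, going from no overlap to `0.5` raises the window count from 8 to 27, which is why inference time grows quickly with higher overlap values.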

If you would like to use the pretrained TorchScript model, ```--saved_checkpoint=torchscript``` should be used.
If you would like to use the pretrained TorchScript model, `--saved_checkpoint=torchscript` should be used.

### Tutorial

A tutorial for the task of multi-organ segmentation using BTCV dataset can be found in the following:

https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb

Additionally, a tutorial which leverages PyTorch Lightning can be found in the following:

https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb

## Dataset

![image](https://lh3.googleusercontent.com/pw/AM-JKLX0svvlMdcrchGAgiWWNkg40lgXYjSHsAAuRc5Frakmz2pWzSzf87JQCRgYpqFR0qAjJWPzMQLc_mmvzNjfF9QWl_1OHZ8j4c9qrbR6zQaDJWaCLArRFh0uPvk97qAa11HtYbD6HpJ-wwTCUsaPcYvM=w1724-h522-no?authuser=0)

The training data is from the [BTCV challenge dataset](https://www.synapse.org/#!Synapse:syn3193805/wiki/217752).
@@ -152,14 +173,14 @@ Under Institutional Review Board (IRB) supervision, 50 abdomen CT scans of were
- Modality: CT
- Size: 30 3D volumes (24 Training + 6 Testing)


We provide the json file that is used to train our models in the following link:

https://developer.download.nvidia.com/assets/Clara/monai/tutorials/swin_unetr_btcv_dataset_0.json

Once the json file is downloaded, please place it in the same folder as the dataset.
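Placing the json file next to the data matters because datalists of this kind typically store paths relative to the dataset folder. The sketch below assumes a Decathlon-style layout (a top-level key such as `training` holding a list of `{"image": ..., "label": ...}` entries with relative paths). That layout, the key name, and the helper `resolve_datalist` are assumptions for illustration, so check them against the downloaded file.

```python
import json
import os


def resolve_datalist(json_path: str, key: str = "training"):
    """Load one section of a datalist and join its paths onto the json's folder."""
    base_dir = os.path.dirname(os.path.abspath(json_path))
    with open(json_path, "r", encoding="utf-8") as f:
        datalist = json.load(f)
    resolved = []
    for item in datalist.get(key, []):
        # Only string values are treated as paths; other values pass through unchanged.
        resolved.append({
            name: os.path.join(base_dir, value) if isinstance(value, str) else value
            for name, value in item.items()
        })
    return resolved
```

For example, `resolve_datalist("/dataset/dataset0/dataset_0.json")` (a hypothetical file name) would return entries whose `image` and `label` paths point inside `/dataset/dataset0/`.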

## Citation

If you find this repository useful, please consider citing the UNETR paper:

```
@@ -173,6 +194,7 @@ If you find this repository useful, please consider citing UNETR paper:
```

## References

[1] Hatamizadeh, Ali, et al. "UNETR: Transformers for 3D Medical Image Segmentation", 2021. https://arxiv.org/abs/2103.10504.

[2] Dosovitskiy, Alexey, et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
3 changes: 3 additions & 0 deletions UNETR/BTCV/config.py
@@ -0,0 +1,3 @@
NIFTI_DATA_ROOT = 'data/images' # nifti image directory
NIFTI_LABEL_ROOT = 'data/labels' # nifti label directory
PREDICT_DATA_ROOT = 'data/predict' # predict image directory
126 changes: 126 additions & 0 deletions UNETR/BTCV/dataset/customDataset.py
@@ -0,0 +1,126 @@
import os
from torch.utils.data import DataLoader
from monai.data import Dataset
import monai.transforms as transforms
import torch

from config import NIFTI_DATA_ROOT, NIFTI_LABEL_ROOT, PREDICT_DATA_ROOT

def _get_collate_fn(isTrain:bool):
    def collate_fn(batch):
        '''collate function'''
        images = []
        labels = []
        if isTrain:
            for p in batch: # [ {"image": (C, H, W ,D), "label": (C, H, W ,D)} , ...]
                for i in range(len(p)): # list, RandCropByPosNegLabeld will produce multiple samples
                    images.append(p[i]['image'])
                    labels.append(p[i]['label'])
        else:
            for p in batch:
                images.append(p['image'])
                labels.append(p['label'])

        images = torch.stack(images, dim=0)
        labels = torch.stack(labels, dim=0)
        # keep images float and labels long for loss functions
        return [images.float(), labels.long()]

    return collate_fn

def getDatasetLoader(args):
    exts = (".nii", ".nii.gz")
    img_names = {f for f in os.listdir(NIFTI_DATA_ROOT) if f.endswith(exts) and os.path.isfile(os.path.join(NIFTI_DATA_ROOT, f))}
    lbl_names = {f for f in os.listdir(NIFTI_LABEL_ROOT) if f.endswith(exts) and os.path.isfile(os.path.join(NIFTI_LABEL_ROOT, f))}
    common = sorted(img_names & lbl_names)
    if not common:
        raise RuntimeError(f"No matching image/label pairs found in {NIFTI_DATA_ROOT} and {NIFTI_LABEL_ROOT}")
    dataDicts = [{"image": os.path.join(NIFTI_DATA_ROOT, f), "label": os.path.join(NIFTI_LABEL_ROOT, f)} for f in common]

    trainDicts, valDicts = _splitList(dataDicts)

    train_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.EnsureChannelFirstd(keys=["image", "label"]),
            transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
            transforms.Spacingd(
                keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest")
            ),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
            transforms.RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(args.roi_x, args.roi_y, args.roi_z),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=0),
            transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=1),
            transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=2),
            transforms.RandRotate90d(keys=["image", "label"], prob=args.RandRotate90d_prob, max_k=3),
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=args.RandScaleIntensityd_prob),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=args.RandShiftIntensityd_prob),
            transforms.ToTensord(keys=["image", "label"]),
        ]
    )

    val_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.EnsureChannelFirstd(keys=["image", "label"]),
            transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
            transforms.Spacingd(
                keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest")
            ),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
            transforms.ToTensord(keys=["image", "label"]),
        ]
    )

    trainDataset = Dataset(data=trainDicts, transform=train_transform)
    valDataset = Dataset(data=valDicts, transform=val_transform)
    trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))
    valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False))
    loader = [trainLoader, valLoader]

    return loader

def _splitList(l, trainRatio:float = 0.8):
    totalNum = len(l)
    splitIdx = int(totalNum * trainRatio)

    return l[:splitIdx], l[splitIdx :]

def getPredictLoader(args):
    dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)]
    dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName]

Comment on lines +105 to +107

🛠️ Refactor suggestion

Filter predict files to NIfTI and validate presence.

Unfiltered os.listdir may include non-NIfTI files (.DS_Store, JSON, etc.) and will break LoadImaged.

-def getPredictLoader(args):
-    dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)]
-    dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName]
+def getPredictLoader(args):
+    exts = (".nii", ".nii.gz")
+    if not os.path.isdir(PREDICT_DATA_ROOT):
+        raise FileNotFoundError(f"PREDICT_DATA_ROOT does not exist: {PREDICT_DATA_ROOT}")
+    files = sorted(
+        f for f in os.listdir(PREDICT_DATA_ROOT)
+        if f.endswith(exts) and os.path.isfile(os.path.join(PREDICT_DATA_ROOT, f))
+    )
+    if not files:
+        raise FileNotFoundError(f"No NIfTI files (.nii, .nii.gz) found in {PREDICT_DATA_ROOT}")
+    dataDicts = [{"image": os.path.join(PREDICT_DATA_ROOT, f)} for f in files]
🤖 Prompt for AI Agents
In UNETR/BTCV/dataset/customDataset.py around lines 105 to 107, the code builds
dataName from os.listdir which can include non-NIfTI files and will break
LoadImaged; change it to only include files with NIfTI extensions (e.g., .nii,
.nii.gz) using a filter (or glob) and ensure each entry is a regular file, build
dataDicts from those paths, and add a validation step that logs/raises an error
if no valid NIfTI files are found.

    preTransform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image"]),
            transforms.EnsureChannelFirstd(keys=["image"]),
            transforms.Orientationd(keys=["image"], axcodes="RAS"),
            transforms.Spacingd(
                keys=["image"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear")
            ),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            transforms.CropForegroundd(keys=["image"], source_key="image", allow_smaller=True),
            transforms.EnsureTyped(keys=["image"], track_meta=True),
        ]
    )
    valDataset = Dataset(data=dataDicts, transform=preTransform)
    valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers)

    return valLoader, preTransform