From b5b33e1bedc92d0ddf6c1fb52af760f1257896dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=97=E4=BD=91=20=E6=9E=97?= Date: Tue, 2 Sep 2025 17:08:33 +0800 Subject: [PATCH 1/6] 1. Add customDatasetLoader 2. Add predict function in test.py --- UNETR/BTCV/config.py | 3 + UNETR/BTCV/dataset/customDataset.py | 120 ++++++++++++++++++++++++++++ UNETR/BTCV/main.py | 45 ++++++----- UNETR/BTCV/networks/unetr.py | 4 +- UNETR/BTCV/requirements.txt | 12 +-- UNETR/BTCV/test.py | 103 ++++++++++++++++-------- UNETR/BTCV/trainer.py | 9 ++- UNETR/BTCV/utils/data_utils.py | 11 ++- 8 files changed, 236 insertions(+), 71 deletions(-) create mode 100644 UNETR/BTCV/config.py create mode 100644 UNETR/BTCV/dataset/customDataset.py diff --git a/UNETR/BTCV/config.py b/UNETR/BTCV/config.py new file mode 100644 index 00000000..6477737e --- /dev/null +++ b/UNETR/BTCV/config.py @@ -0,0 +1,3 @@ +NIFTI_DATA_ROOT = 'data/images' # nifti image directory +NIFTI_LABEL_ROOT = 'data/labels' # nifti label directory +PREDICT_DATA_ROOT = 'data/predict' # predict image directory \ No newline at end of file diff --git a/UNETR/BTCV/dataset/customDataset.py b/UNETR/BTCV/dataset/customDataset.py new file mode 100644 index 00000000..bfcf7a1f --- /dev/null +++ b/UNETR/BTCV/dataset/customDataset.py @@ -0,0 +1,120 @@ +import os +from torch.utils.data import DataLoader +from monai.data import Dataset +import monai.transforms as transforms +import torch + +from config import NIFTI_DATA_ROOT, NIFTI_LABEL_ROOT, PREDICT_DATA_ROOT + +def _get_collate_fn(isTrain:bool): + def collate_fn(batch): + '''collate function''' + images = [] + labels = [] + if isTrain: + for p in batch: # [ {"image": (C, H, W ,D), "label": (C, H, W ,D)} , ...] + for i in range(len(p)): # list, RandCropByPosNegLabeld will produce multiple samples + images.append(p[i]['image']) + labels.append(p[i]['label']) + else: + for p in batch: + images.append(p['image']) + labels.append(p['label']) + + images = torch.stack(images, dim=0) + labels = torch.stack(labels, dim=0) + + return [torch.Tensor(images), torch.Tensor(labels)] + + return collate_fn + +def getDatasetLoader(args): + dataName = [d for d in os.listdir(NIFTI_LABEL_ROOT)] + dataDicts = [{"image": f"{os.path.join(NIFTI_DATA_ROOT, d)}", "label": f"{os.path.join(NIFTI_LABEL_ROOT, d)}"} for d in dataName] + trainDicts, valDicts = _splitList(dataDicts) + + train_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.EnsureChannelFirstd(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd( + keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") + ), + transforms.ScaleIntensityRanged( + keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True + ), + transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), + transforms.RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(args.roi_x, args.roi_y, args.roi_z), + pos=1, + neg=1, + num_samples=4, + image_key="image", + image_threshold=0, + ), + transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=0), + transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=1), + transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=2), + transforms.RandRotate90d(keys=["image", "label"], prob=args.RandRotate90d_prob, max_k=3), + transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=args.RandScaleIntensityd_prob), + transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=args.RandShiftIntensityd_prob), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + + val_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.EnsureChannelFirstd(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd( + keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") + ), + transforms.ScaleIntensityRanged( + keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True + ), + transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + + trainDataset = Dataset(data=trainDicts, transform=train_transform) + valDataset = Dataset(data=valDicts, transform=val_transform) + trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True)) + valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False)) + loader = [trainLoader, valLoader] + + return loader + +def _splitList(l, trainRatio:float = 0.8): + totalNum = len(l) + splitIdx = int(totalNum * trainRatio) + + return l[:splitIdx], l[splitIdx :] + +def getPredictLoader(args): + dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)] + dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName] + + preTransform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image"]), + transforms.EnsureChannelFirstd(keys=["image"]), + transforms.Orientationd(keys=["image"], axcodes="RAS"), + transforms.Spacingd( + keys=["image"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear") + ), + transforms.ScaleIntensityRanged( + keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True + ), + transforms.CropForegroundd(keys=["image"], source_key="image", allow_smaller=True), + # transforms.ToTensord(keys=["image"],track_meta=True), # This transformation will transform MetaTensor to Tensor + ] + ) + valDataset = Dataset(data=dataDicts, transform=preTransform) + valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers) + + return valLoader, preTransform diff --git a/UNETR/BTCV/main.py b/UNETR/BTCV/main.py index e31b1991..3b5b80ba 100644 --- a/UNETR/BTCV/main.py +++ b/UNETR/BTCV/main.py @@ -12,17 +12,12 @@ import argparse import os from functools import partial - import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.parallel import torch.utils.data.distributed -from networks.unetr import UNETR -from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR -from trainer import run_training -from utils.data_utils import get_loader from monai.inferers import sliding_window_inference from monai.losses import DiceCELoss, DiceLoss @@ -30,19 +25,26 @@ from monai.transforms import Activations, AsDiscrete, Compose from monai.utils.enums import MetricReduction +from networks.unetr import UNETR +from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from trainer import run_training +from utils.data_utils import get_loader +from dataset.customDataset import getDatasetLoader + parser = argparse.ArgumentParser(description="UNETR segmentation pipeline") parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint") parser.add_argument("--logdir", default="test", type=str, help="directory to save the tensorboard logs") parser.add_argument( "--pretrained_dir", default="./pretrained_models/", type=str, help="pretrained checkpoint directory" ) -parser.add_argument("--data_dir", default="/dataset/dataset0/", type=str, help="dataset directory") +parser.add_argument("--btcv", action="store_true", help="Use BTCV dataset") +parser.add_argument("--data_dir", default="./dataset/dataset0/", type=str, help="dataset directory") parser.add_argument("--json_list", default="dataset_0.json", type=str, help="dataset json file") parser.add_argument( "--pretrained_model_name", default="UNETR_model_best_acc.pth", type=str, help="pretrained model name" ) -parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training") -parser.add_argument("--max_epochs", default=5000, type=int, help="max number of training epochs") +parser.add_argument("--save_checkpoint", action="store_true", default=True, help="save checkpoint during training") +parser.add_argument("--max_epochs", default=100, type=int, help="max number of training epochs") parser.add_argument("--batch_size", default=1, type=int, help="number of batch size") parser.add_argument("--sw_batch_size", default=1, type=int, help="number of sliding window batch size") parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate") @@ -50,7 +52,7 @@ parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight") parser.add_argument("--momentum", default=0.99, type=float, help="momentum") parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training") -parser.add_argument("--val_every", default=100, type=int, help="validation frequency") +parser.add_argument("--val_every", default=10, type=int, help="validation frequency") parser.add_argument("--distributed", action="store_true", help="start distributed training") parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training") parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training") @@ -58,7 +60,7 @@ parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") parser.add_argument("--workers", default=8, type=int, help="number of workers") parser.add_argument("--model_name", default="unetr", type=str, help="model name") -parser.add_argument("--pos_embed", default="perceptron", type=str, help="type of position embedding") +parser.add_argument("--pos_embed", default="learnable", type=str, help="type of position embedding") parser.add_argument("--norm_name", default="instance", type=str, help="normalization layer type in decoder") parser.add_argument("--num_heads", default=12, type=int, help="number of attention heads in ViT encoder") parser.add_argument("--mlp_dim", default=3072, type=int, help="mlp dimention in ViT encoder") @@ -73,12 +75,12 @@ parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged") parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") -parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") -parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") -parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") -parser.add_argument("--roi_x", default=96, type=int, help="roi size in x direction") -parser.add_argument("--roi_y", default=96, type=int, help="roi size in y direction") -parser.add_argument("--roi_z", default=96, type=int, help="roi size in z direction") +parser.add_argument("--space_x", default=1.0, type=float, help="spacing in x direction") +parser.add_argument("--space_y", default=1.0, type=float, help="spacing in y direction") +parser.add_argument("--space_z", default=1.0, type=float, help="spacing in z direction") +parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction") +parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction") +parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction") parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate") parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability") parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability") @@ -102,10 +104,9 @@ def main(): print("Found total gpus", args.ngpus_per_node) args.world_size = args.ngpus_per_node * args.world_size mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,)) - else: + else: main_worker(gpu=0, args=args) - def main_worker(gpu, args): if args.distributed: torch.multiprocessing.set_start_method("fork", force=True) @@ -119,7 +120,8 @@ def main_worker(gpu, args): torch.cuda.set_device(args.gpu) torch.backends.cudnn.benchmark = True args.test_mode = False - loader = get_loader(args) + loader = get_loader(args) if args.btcv else getDatasetLoader(args) + print(args.rank, " gpu", args.gpu) if args.rank == 0: print("Batch size is:", args.batch_size, "epochs", args.max_epochs) @@ -157,8 +159,8 @@ def main_worker(gpu, args): dice_loss = DiceCELoss( to_onehot_y=True, softmax=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr ) - post_label = AsDiscrete(to_onehot=True, n_classes=args.out_channels) - post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.out_channels) + post_label = AsDiscrete(to_onehot=args.out_channels) + post_pred = AsDiscrete(argmax=True, to_onehot=args.out_channels) dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN, get_not_nans=True) model_inferer = partial( sliding_window_inference, @@ -235,6 +237,5 @@ def main_worker(gpu, args): ) return accuracy - if __name__ == "__main__": main() diff --git a/UNETR/BTCV/networks/unetr.py b/UNETR/BTCV/networks/unetr.py index 5557c412..3cec0469 100644 --- a/UNETR/BTCV/networks/unetr.py +++ b/UNETR/BTCV/networks/unetr.py @@ -73,7 +73,7 @@ def __init__( if hidden_size % num_heads != 0: raise AssertionError("hidden size should be divisible by num_heads.") - if pos_embed not in ["conv", "perceptron"]: + if pos_embed not in ['sincos', 'learnable', 'none']: raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") self.num_layers = 12 @@ -93,7 +93,7 @@ def __init__( mlp_dim=mlp_dim, num_layers=self.num_layers, num_heads=num_heads, - pos_embed=pos_embed, + pos_embed_type=pos_embed, classification=self.classification, dropout_rate=dropout_rate, ) diff --git a/UNETR/BTCV/requirements.txt b/UNETR/BTCV/requirements.txt index 32677ba2..285af421 100644 --- a/UNETR/BTCV/requirements.txt +++ b/UNETR/BTCV/requirements.txt @@ -1,6 +1,6 @@ -torch==1.9.1 -monai==0.7.0 -nibabel==3.1.1 -tqdm==4.59.0 -einops==0.3.0 -tensorboardX==2.1 +monai==1.5.0 +numpy==2.3.2 +opencv_python +simpleitk==2.5.2 +tensorboardx==2.6.4 +torch \ No newline at end of file diff --git a/UNETR/BTCV/test.py b/UNETR/BTCV/test.py index 4528cd6c..7c5f8877 100644 --- a/UNETR/BTCV/test.py +++ b/UNETR/BTCV/test.py @@ -11,23 +11,28 @@ import argparse import os +from monai.transforms import Compose, Invertd, SaveImaged, AsDiscreted +from monai.inferers import sliding_window_inference +from monai.data.meta_tensor import MetaTensor import numpy as np import torch + from networks.unetr import UNETR from trainer import dice -from utils.data_utils import get_loader - -from monai.inferers import sliding_window_inference +from dataset.customDataset import getDatasetLoader, getPredictLoader parser = argparse.ArgumentParser(description="UNETR segmentation pipeline") + +parser.add_argument( + "--mode",choices=['predict', 'validation'], default="validation", type=str, help="mode for predict or validation" +) + parser.add_argument( "--pretrained_dir", default="./pretrained_models/", type=str, help="pretrained checkpoint directory" ) -parser.add_argument("--data_dir", default="/dataset/dataset0/", type=str, help="dataset directory") -parser.add_argument("--json_list", default="dataset_0.json", type=str, help="dataset json file") parser.add_argument( - "--pretrained_model_name", default="UNETR_model_best_acc.pth", type=str, help="pretrained model name" + "--pretrained_model_name", default="model_final.pt", type=str, help="pretrained model name" ) parser.add_argument( "--saved_checkpoint", default="ckpt", type=str, help="Supports torchscript or ckpt pretrained checkpoint type" @@ -37,7 +42,9 @@ parser.add_argument("--feature_size", default=16, type=int, help="feature size dimention") parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap") parser.add_argument("--in_channels", default=1, type=int, help="number of input channels") -parser.add_argument("--out_channels", default=14, type=int, help="number of output channels") +parser.add_argument("--out_channels", default=2, type=int, help="number of output channels") +parser.add_argument("--batch_size", default=1, type=int, help="number of batch size") +parser.add_argument("--sw_batch_size", default=1, type=int, help="number of sliding window batch size") parser.add_argument("--num_heads", default=12, type=int, help="number of attention heads in ViT encoder") parser.add_argument("--res_block", action="store_true", help="use residual blocks") parser.add_argument("--conv_block", action="store_true", help="use conv blocks") @@ -45,12 +52,12 @@ parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged") parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") -parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") -parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") -parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") -parser.add_argument("--roi_x", default=96, type=int, help="roi size in x direction") -parser.add_argument("--roi_y", default=96, type=int, help="roi size in y direction") -parser.add_argument("--roi_z", default=96, type=int, help="roi size in z direction") +parser.add_argument("--space_x", default=1.0, type=float, help="spacing in x direction") +parser.add_argument("--space_y", default=1.0, type=float, help="spacing in y direction") +parser.add_argument("--space_z", default=1.0, type=float, help="spacing in z direction") +parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction") +parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction") +parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction") parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate") parser.add_argument("--distributed", action="store_true", help="start distributed training") parser.add_argument("--workers", default=8, type=int, help="number of workers") @@ -58,14 +65,18 @@ parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability") parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability") parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability") -parser.add_argument("--pos_embed", default="perceptron", type=str, help="type of position embedding") +parser.add_argument("--pos_embed", default="learnable", type=str, help="type of position embedding") parser.add_argument("--norm_name", default="instance", type=str, help="normalization layer type in decoder") +def inference(inputs,model, args): + out = sliding_window_inference(inputs, (args.roi_x, args.roi_y, args.roi_z), args.sw_batch_size, model, overlap=args.infer_overlap) + prob = torch.softmax(out, 1).cpu().numpy() + predict_label = np.argmax(prob, axis=1).astype(np.uint8) + return predict_label def main(): args = parser.parse_args() args.test_mode = True - val_loader = get_loader(args) pretrained_dir = args.pretrained_dir model_name = args.pretrained_model_name device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -87,30 +98,56 @@ def main(): res_block=True, dropout_rate=args.dropout_rate, ) - model_dict = torch.load(pretrained_pth) - model.load_state_dict(model_dict) + model_dict = torch.load(pretrained_pth, weights_only=False) + model.load_state_dict(model_dict['state_dict']) model.eval() model.to(device) with torch.no_grad(): dice_list_case = [] - for i, batch in enumerate(val_loader): - val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda()) - img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1] - print("Inference on case {}".format(img_name)) - val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model, overlap=args.infer_overlap) - val_outputs = torch.softmax(val_outputs, 1).cpu().numpy() - val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8) - val_labels = val_labels.cpu().numpy()[:, 0, :, :, :] - dice_list_sub = [] - for i in range(1, 14): - organ_Dice = dice(val_outputs[0] == i, val_labels[0] == i) - dice_list_sub.append(organ_Dice) - mean_dice = np.mean(dice_list_sub) - print("Mean Organ Dice: {}".format(mean_dice)) - dice_list_case.append(mean_dice) - print("Overall Mean Dice: {}".format(np.mean(dice_list_case))) + if args.mode == 'validation': + loader = getDatasetLoader(args)[1] + for batch, label in loader: + val_inputs, val_labels = (batch.cuda(), label.cuda()) + val_outputs = inference(val_inputs, model, args) + val_labels = val_labels.cpu().numpy()[:, 0, :, :, :] + dice_list_sub = [] + for i in range(1, args.out_channels): + every_Dice = dice(val_outputs[0] == i, val_labels[0] == i) + dice_list_sub.append(every_Dice) + mean_dice = np.mean(dice_list_sub) + print("Mean Dice: {}".format(mean_dice)) + dice_list_case.append(mean_dice) + print("Overall Mean Dice: {}".format(np.mean(dice_list_case))) + elif args.mode == 'predict': # ref: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py + loader, preTransform = getPredictLoader(args) + postTransforms = Compose([ + Invertd( + keys="pred", + transform=preTransform, + orig_keys="image", # invert from history of image tag + nearest_interp=False, + to_tensor=True, + ), + AsDiscreted(keys="pred", threshold=0.5), + SaveImaged( + keys="pred", + output_dir="./output", + output_postfix="seg", + resample=False, + print_log=True, + ) + ]) + + for d in loader: + + input_data = d['image'].cuda() # (b, c, h, w, d) + predict_raw = inference(input_data, model, args) # shape: (B, H, W, D) + predict_tensor = torch.from_numpy(predict_raw.astype(np.float32)) # shape: (B, H, W, D) + + d["pred"] = MetaTensor(predict_tensor, meta=d["image"].meta) + postTransforms(d) if __name__ == "__main__": main() diff --git a/UNETR/BTCV/trainer.py b/UNETR/BTCV/trainer.py index ecc1540f..edb4d802 100644 --- a/UNETR/BTCV/trainer.py +++ b/UNETR/BTCV/trainer.py @@ -96,10 +96,11 @@ def val_epoch(model, loader, epoch, acc_func, args, model_inferer=None, post_lab start_time = time.time() with torch.no_grad(): for idx, batch_data in enumerate(loader): - if isinstance(batch_data, list): - data, target = batch_data - else: - data, target = batch_data["image"], batch_data["label"] + # if isinstance(batch_data, list): + # data, target = batch_data + # else: + # data, target = batch_data["image"], batch_data["label"] + data, target = batch_data data, target = data.cuda(args.rank), target.cuda(args.rank) with autocast(enabled=args.amp): if model_inferer is not None: diff --git a/UNETR/BTCV/utils/data_utils.py b/UNETR/BTCV/utils/data_utils.py index bcdd844e..b2a29731 100755 --- a/UNETR/BTCV/utils/data_utils.py +++ b/UNETR/BTCV/utils/data_utils.py @@ -11,13 +11,16 @@ import math import os - import numpy as np import torch +from pathlib import Path +import SimpleITK as sitk from monai import data, transforms from monai.data import load_decathlon_datalist +WORKROOT = Path(__file__).parent.parent +JPG_EXT = '.jpg' class Sampler(torch.utils.data.Sampler): def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True): @@ -72,7 +75,7 @@ def get_loader(args): train_transform = transforms.Compose( [ transforms.LoadImaged(keys=["image", "label"]), - transforms.AddChanneld(keys=["image", "label"]), + transforms.EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"), transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), transforms.Spacingd( keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") @@ -103,7 +106,7 @@ def get_loader(args): val_transform = transforms.Compose( [ transforms.LoadImaged(keys=["image", "label"]), - transforms.AddChanneld(keys=["image", "label"]), + transforms.EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"), transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), transforms.Spacingd( keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") @@ -162,4 +165,4 @@ def get_loader(args): ) loader = [train_loader, val_loader] - return loader + return loader \ No newline at end of file From 832618ef2a697e7e05d33f69860c9d5a88575a3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=97=E4=BD=91=20=E6=9E=97?= Date: Tue, 2 Sep 2025 17:12:37 +0800 Subject: [PATCH 2/6] Update README.md --- UNETR/BTCV/README.md | 68 +++++++++++++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/UNETR/BTCV/README.md b/UNETR/BTCV/README.md index 8ff6f412..ccd11c58 100644 --- a/UNETR/BTCV/README.md +++ b/UNETR/BTCV/README.md @@ -1,11 +1,14 @@ # Model Overview + This repository contains the code for UNETR: Transformers for 3D Medical Image Segmentation [1]. UNETR is the first 3D segmentation network that uses a pure vision transformer as its encoder without relying on CNNs for feature extraction. The code presents a volumetric (3D) multi-organ segmentation application using the BTCV challenge dataset. ![image](https://lh3.googleusercontent.com/pw/AM-JKLU2eTW17rYtCmiZP3WWC-U1HCPOHwLe6pxOfJXwv2W-00aHfsNy7jeGV1dwUq0PXFOtkqasQ2Vyhcu6xkKsPzy3wx7O6yGOTJ7ZzA01S6LSh8szbjNLfpbuGgMe6ClpiS61KGvqu71xXFnNcyvJNFjN=w1448-h496-no?authuser=0) ### Installing Dependencies + Dependencies can be installed using: -``` bash + +```bash pip install -r requirements.txt ``` @@ -13,7 +16,7 @@ pip install -r requirements.txt A UNETR network with standard hyper-parameters for the task of multi-organ semantic segmentation (BTCV dataset) can be defined as follows: -``` bash +```bash model = UNETR( in_channels=1, out_channels=14, @@ -30,12 +33,13 @@ model = UNETR( ``` The above UNETR model is used for CT images (1-channel input) and for 14-class segmentation outputs. The network expects -resampled input images with size ```(96, 96, 96)``` which will be converted into non-overlapping patches of size ```(16, 16, 16)```. +resampled input images with size `(96, 96, 96)` which will be converted into non-overlapping patches of size `(16, 16, 16)`. The position embedding is performed using a perceptron layer. The ViT encoder follows standard hyper-parameters as introduced in [2]. The decoder uses convolutional and residual blocks as well as instance normalization. More details can be found in [1]. Using the default values for hyper-parameters, the following command can be used to initiate training using PyTorch native AMP package: -``` bash + +```bash python main.py --feature_size=32 --batch_size=1 @@ -48,28 +52,30 @@ python main.py --data_dir=/dataset/dataset0/ ``` -Note that you need to provide the location of your dataset directory by using ```--data_dir```. +Note that you need to provide the location of your dataset directory by using `--data_dir`. -To initiate distributed multi-gpu training, ```--distributed``` needs to be added to the training command. +To initiate distributed multi-gpu training, `--distributed` needs to be added to the training command. -To disable AMP, ```--noamp``` needs to be added to the training command. +To disable AMP, `--noamp` needs to be added to the training command. -If UNETR is used in distributed multi-gpu training, we recommend increasing the learning rate (i.e. ```--optim_lr```) -according to the number of GPUs. For instance, ```--optim_lr=4e-4``` is recommended for training with 4 GPUs. +If UNETR is used in distributed multi-gpu training, we recommend increasing the learning rate (i.e. `--optim_lr`) +according to the number of GPUs. For instance, `--optim_lr=4e-4` is recommended for training with 4 GPUs. ### Finetuning + We provide state-of-the-art pre-trained checkpoints and TorchScript models of UNETR using BTCV dataset. For using the pre-trained checkpoint, please download the weights from the following directory: https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pth -Once downloaded, please place the checkpoint in the following directory or use ```--pretrained_dir``` to provide the address of where the model is placed: +Once downloaded, please place the checkpoint in the following directory or use `--pretrained_dir` to provide the address of where the model is placed: -```./pretrained_models``` +`./pretrained_models` The following command initiates finetuning using the pretrained checkpoint: -``` bash + +```bash python main.py --batch_size=1 --logdir=unetr_pretrained @@ -88,12 +94,13 @@ For using the pre-trained TorchScript model, please download the model from the https://developer.download.nvidia.com/assets/Clara/monai/research/UNETR_model_best_acc.pt -Once downloaded, please place the TorchScript model in the following directory or use ```--pretrained_dir``` to provide the address of where the model is placed: +Once downloaded, please place the TorchScript model in the following directory or use `--pretrained_dir` to provide the address of where the model is placed: -```./pretrained_models``` +`./pretrained_models` The following command initiates finetuning using the TorchScript model: -``` bash + +```bash python main.py --batch_size=1 --logdir=unetr_pretrained @@ -108,31 +115,43 @@ python main.py --pretrained_model_name='UNETR_model_best_acc.pt' --resume_jit ``` -Note that finetuning from the provided TorchScript model does not support AMP. +Note that finetuning from the provided TorchScript model does not support AMP. ### Testing + You can use the state-of-the-art pre-trained TorchScript model or checkpoint of UNETR to test it on your own data. Once the pretrained weights are downloaded, using the links above, please place the TorchScript model in the following directory or -use ```--pretrained_dir``` to provide the address of where the model is placed: +use `--pretrained_dir` to provide the address of where the model is placed: + +`./pretrained_models` -```./pretrained_models``` +The following command runs inference(validation or predict mask) using the provided checkpoint: -The following command runs inference using the provided checkpoint: -``` bash +```bash python test.py +--mode='validation' --infer_overlap=0.5 --data_dir=/dataset/dataset0/ --pretrained_dir='./pretrained_models/' --saved_checkpoint=ckpt ``` -Note that ```--infer_overlap``` determines the overlap between the sliding window patches. A higher value typically results in more accurate segmentation outputs but with the cost of longer inference time. +```bash +python test.py +--mode='predict' +--infer_overlap=0.5 +--pretrained_dir='./pretrained_models/' +--saved_checkpoint=ckpt +``` + +Note that `--infer_overlap` determines the overlap between the sliding window patches. A higher value typically results in more accurate segmentation outputs but with the cost of longer inference time. -If you would like to use the pretrained TorchScript model, ```--saved_checkpoint=torchscript``` should be used. +If you would like to use the pretrained TorchScript model, `--saved_checkpoint=torchscript` should be used. ### Tutorial + A tutorial for the task of multi-organ segmentation using BTCV dataset can be found in the following: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb @@ -140,7 +159,9 @@ https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_ Additionally, a tutorial which leverages PyTorch Lightning can be found in the following: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb + ## Dataset + ![image](https://lh3.googleusercontent.com/pw/AM-JKLX0svvlMdcrchGAgiWWNkg40lgXYjSHsAAuRc5Frakmz2pWzSzf87JQCRgYpqFR0qAjJWPzMQLc_mmvzNjfF9QWl_1OHZ8j4c9qrbR6zQaDJWaCLArRFh0uPvk97qAa11HtYbD6HpJ-wwTCUsaPcYvM=w1724-h522-no?authuser=0) The training data is from the [BTCV challenge dataset](https://www.synapse.org/#!Synapse:syn3193805/wiki/217752). @@ -152,7 +173,6 @@ Under Institutional Review Board (IRB) supervision, 50 abdomen CT scans of were - Modality: CT - Size: 30 3D volumes (24 Training + 6 Testing) - We provide the json file that is used to train our models in the following link: https://developer.download.nvidia.com/assets/Clara/monai/tutorials/swin_unetr_btcv_dataset_0.json @@ -160,6 +180,7 @@ https://developer.download.nvidia.com/assets/Clara/monai/tutorials/swin_unetr_bt Once the json file is downloaded, please place it in the same folder as the dataset. ## Citation + If you find this repository useful, please consider citing UNETR paper: ``` @@ -173,6 +194,7 @@ If you find this repository useful, please consider citing UNETR paper: ``` ## References + [1] Hatamizadeh, Ali, et al. "UNETR: Transformers for 3D Medical Image Segmentation", 2021. https://arxiv.org/abs/2103.10504. [2] Dosovitskiy, Alexey, et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale From f0b3dc46228ae86932e49f3c79eb03ca1317f526 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Sep 2025 09:21:48 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- UNETR/BTCV/config.py | 2 +- UNETR/BTCV/dataset/customDataset.py | 6 +++--- UNETR/BTCV/main.py | 4 ++-- UNETR/BTCV/requirements.txt | 2 +- UNETR/BTCV/test.py | 4 ++-- UNETR/BTCV/utils/data_utils.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/UNETR/BTCV/config.py b/UNETR/BTCV/config.py index 6477737e..f34875f4 100644 --- a/UNETR/BTCV/config.py +++ b/UNETR/BTCV/config.py @@ -1,3 +1,3 @@ NIFTI_DATA_ROOT = 'data/images' # nifti image directory NIFTI_LABEL_ROOT = 'data/labels' # nifti label directory -PREDICT_DATA_ROOT = 'data/predict' # predict image directory \ No newline at end of file +PREDICT_DATA_ROOT = 'data/predict' # predict image directory diff --git a/UNETR/BTCV/dataset/customDataset.py b/UNETR/BTCV/dataset/customDataset.py index bfcf7a1f..b5e00e08 100644 --- a/UNETR/BTCV/dataset/customDataset.py +++ b/UNETR/BTCV/dataset/customDataset.py @@ -23,9 +23,9 @@ def collate_fn(batch): images = torch.stack(images, dim=0) labels = torch.stack(labels, dim=0) - + return [torch.Tensor(images), torch.Tensor(labels)] - + return collate_fn def getDatasetLoader(args): @@ -80,7 +80,7 @@ def getDatasetLoader(args): transforms.ToTensord(keys=["image", "label"]), ] ) - + trainDataset = Dataset(data=trainDicts, transform=train_transform) valDataset = Dataset(data=valDicts, transform=val_transform) trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True)) diff --git a/UNETR/BTCV/main.py b/UNETR/BTCV/main.py index 3b5b80ba..0804ccac 100644 --- a/UNETR/BTCV/main.py +++ b/UNETR/BTCV/main.py @@ -104,7 +104,7 @@ def main(): print("Found total gpus", args.ngpus_per_node) args.world_size = args.ngpus_per_node * args.world_size mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,)) - else: + else: main_worker(gpu=0, args=args) def main_worker(gpu, args): @@ -121,7 +121,7 @@ def main_worker(gpu, args): torch.backends.cudnn.benchmark = True args.test_mode = False loader = get_loader(args) if args.btcv else getDatasetLoader(args) - + print(args.rank, " gpu", args.gpu) if args.rank == 0: print("Batch size is:", args.batch_size, "epochs", args.max_epochs) diff --git a/UNETR/BTCV/requirements.txt b/UNETR/BTCV/requirements.txt index 285af421..1b72f298 100644 --- a/UNETR/BTCV/requirements.txt +++ b/UNETR/BTCV/requirements.txt @@ -3,4 +3,4 @@ numpy==2.3.2 opencv_python simpleitk==2.5.2 tensorboardx==2.6.4 -torch \ No newline at end of file +torch diff --git a/UNETR/BTCV/test.py b/UNETR/BTCV/test.py index 7c5f8877..b9defef0 100644 --- a/UNETR/BTCV/test.py +++ b/UNETR/BTCV/test.py @@ -138,9 +138,9 @@ def main(): print_log=True, ) ]) - + for d in loader: - + input_data = d['image'].cuda() # (b, c, h, w, d) predict_raw = inference(input_data, model, args) # shape: (B, H, W, D) predict_tensor = torch.from_numpy(predict_raw.astype(np.float32)) # shape: (B, H, W, D) diff --git a/UNETR/BTCV/utils/data_utils.py b/UNETR/BTCV/utils/data_utils.py index b2a29731..c2079f0b 100755 --- a/UNETR/BTCV/utils/data_utils.py +++ b/UNETR/BTCV/utils/data_utils.py @@ -165,4 +165,4 @@ def get_loader(args): ) loader = [train_loader, val_loader] - return loader \ No newline at end of file + return loader From 9a75007002b1689f52633670d3c29b4f8debb064 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=97=E4=BD=91=20=E6=9E=97?= Date: Wed, 3 Sep 2025 09:39:50 +0800 Subject: [PATCH 4/6] fix: Resolve issue #420 in the workflow of predict - `dataset/customDataset.py` - Ensure data type of image/label match float/int required by CE/one-hot - Add transforms.EnsureTyped in preTransform pipeline - `networks/unetr.py`, `README.md` - Change default param of pos_embe to `learnable` - `requirements.txt` - Assign the range of torch version and add nibabel version - `test.py` - Refactor the code - `trainer.py` - Handle unpack data process by loader --- UNETR/BTCV/README.md | 2 +- UNETR/BTCV/dataset/customDataset.py | 6 +++--- UNETR/BTCV/networks/unetr.py | 5 ++++- UNETR/BTCV/requirements.txt | 3 ++- UNETR/BTCV/test.py | 9 ++++++--- UNETR/BTCV/trainer.py | 9 ++++----- 6 files changed, 20 insertions(+), 14 deletions(-) diff --git a/UNETR/BTCV/README.md b/UNETR/BTCV/README.md index ccd11c58..e125b504 100644 --- a/UNETR/BTCV/README.md +++ b/UNETR/BTCV/README.md @@ -25,7 +25,7 @@ model = UNETR( hidden_size=768, mlp_dim=3072, num_heads=12, - pos_embed='perceptron', + pos_embed='learnable', norm_name='instance', conv_block=True, res_block=True, diff --git a/UNETR/BTCV/dataset/customDataset.py b/UNETR/BTCV/dataset/customDataset.py index bfcf7a1f..f5a89e65 100644 --- a/UNETR/BTCV/dataset/customDataset.py +++ b/UNETR/BTCV/dataset/customDataset.py @@ -23,8 +23,8 @@ def collate_fn(batch): images = torch.stack(images, dim=0) labels = torch.stack(labels, dim=0) - - return [torch.Tensor(images), torch.Tensor(labels)] + # keep images float and labels long for loss functions + return [images.float(), labels.long()] return collate_fn @@ -111,7 +111,7 @@ def getPredictLoader(args): keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), transforms.CropForegroundd(keys=["image"], source_key="image", allow_smaller=True), - # transforms.ToTensord(keys=["image"],track_meta=True), # This transformation will transform MetaTensor to Tensor + transforms.EnsureTyped(keys=["image"], track_meta=True), ] ) valDataset = Dataset(data=dataDicts, transform=preTransform) diff --git a/UNETR/BTCV/networks/unetr.py b/UNETR/BTCV/networks/unetr.py index 3cec0469..fc1fa77c 100644 --- a/UNETR/BTCV/networks/unetr.py +++ b/UNETR/BTCV/networks/unetr.py @@ -34,7 +34,7 @@ def __init__( hidden_size: int = 768, mlp_dim: int = 3072, num_heads: int = 12, - pos_embed: str = "perceptron", + pos_embed: str = "learnable", norm_name: Union[Tuple, str] = "instance", conv_block: bool = False, res_block: bool = True, @@ -73,6 +73,9 @@ def __init__( if hidden_size % num_heads != 0: raise AssertionError("hidden size should be divisible by num_heads.") + # Backward-compat aliases + if pos_embed == "perceptron": + pos_embed = "learnable" if pos_embed not in ['sincos', 'learnable', 'none']: raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") diff --git a/UNETR/BTCV/requirements.txt b/UNETR/BTCV/requirements.txt index 285af421..19e403f5 100644 --- a/UNETR/BTCV/requirements.txt +++ b/UNETR/BTCV/requirements.txt @@ -3,4 +3,5 @@ numpy==2.3.2 opencv_python simpleitk==2.5.2 tensorboardx==2.6.4 -torch \ No newline at end of file +torch>=2.3,<2.7 +nibabel>=5.0 \ No newline at end of file diff --git a/UNETR/BTCV/test.py b/UNETR/BTCV/test.py index 7c5f8877..27481909 100644 --- a/UNETR/BTCV/test.py +++ b/UNETR/BTCV/test.py @@ -135,18 +135,21 @@ def main(): output_dir="./output", output_postfix="seg", resample=False, + output_dtype=np.uint8, print_log=True, ) ]) for d in loader: - input_data = d['image'].cuda() # (b, c, h, w, d) + # shape: (b, c, h, w, d) + input_data = (d["image"] if torch.is_tensor(d["image"]) else torch.as_tensor(d["image"])).to(device) predict_raw = inference(input_data, model, args) # shape: (B, H, W, D) predict_tensor = torch.from_numpy(predict_raw.astype(np.float32)) # shape: (B, H, W, D) - d["pred"] = MetaTensor(predict_tensor, meta=d["image"].meta) - + meta = getattr(d["image"], "meta", None) + d["pred"] = MetaTensor(predict_tensor, meta=meta) if meta is not None else predict_tensor + postTransforms(d) if __name__ == "__main__": diff --git a/UNETR/BTCV/trainer.py b/UNETR/BTCV/trainer.py index edb4d802..b12eb396 100644 --- a/UNETR/BTCV/trainer.py +++ b/UNETR/BTCV/trainer.py @@ -96,11 +96,10 @@ def val_epoch(model, loader, epoch, acc_func, args, model_inferer=None, post_lab start_time = time.time() with torch.no_grad(): for idx, batch_data in enumerate(loader): - # if isinstance(batch_data, list): - # data, target = batch_data - # else: - # data, target = batch_data["image"], batch_data["label"] - data, target = batch_data + if isinstance(batch_data, (list, tuple)): + data, target = batch_data + else: + data, target = batch_data["image"], batch_data["label"] data, target = data.cuda(args.rank), target.cuda(args.rank) with autocast(enabled=args.amp): if model_inferer is not None: From 4b366dfe0a0ea8c93232a51c854ebb5d2a25d78c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Sep 2025 02:06:41 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- UNETR/BTCV/dataset/customDataset.py | 2 +- UNETR/BTCV/requirements.txt | 1 - UNETR/BTCV/test.py | 6 +++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/UNETR/BTCV/dataset/customDataset.py b/UNETR/BTCV/dataset/customDataset.py index 73dc68ff..c87375e7 100644 --- a/UNETR/BTCV/dataset/customDataset.py +++ b/UNETR/BTCV/dataset/customDataset.py @@ -25,7 +25,7 @@ def collate_fn(batch): labels = torch.stack(labels, dim=0) # keep images float and labels long for loss functions return [images.float(), labels.long()] - + return collate_fn def getDatasetLoader(args): diff --git a/UNETR/BTCV/requirements.txt b/UNETR/BTCV/requirements.txt index 7140a6e4..041a5aeb 100644 --- a/UNETR/BTCV/requirements.txt +++ b/UNETR/BTCV/requirements.txt @@ -5,4 +5,3 @@ simpleitk==2.5.2 tensorboardx==2.6.4 torch>=2.3,<2.7 nibabel>=5.0 - diff --git a/UNETR/BTCV/test.py b/UNETR/BTCV/test.py index f09faff6..a21ca00e 100644 --- a/UNETR/BTCV/test.py +++ b/UNETR/BTCV/test.py @@ -141,8 +141,8 @@ def main(): ]) for d in loader: - - # shape: (b, c, h, w, d) + + # shape: (b, c, h, w, d) input_data = (d["image"] if torch.is_tensor(d["image"]) else torch.as_tensor(d["image"])).to(device) predict_raw = inference(input_data, model, args) # shape: (B, H, W, D) @@ -150,7 +150,7 @@ def main(): meta = getattr(d["image"], "meta", None) d["pred"] = MetaTensor(predict_tensor, meta=meta) if meta is not None else predict_tensor - + postTransforms(d) if __name__ == "__main__": From 30bc5b123db225e20849e077df9aed2fe9517697 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=97=E4=BD=91=20=E6=9E=97?= Date: Wed, 3 Sep 2025 13:54:47 +0800 Subject: [PATCH 6/6] Make pairing robust: filter by extension and intersect filenames across image/label roots --- UNETR/BTCV/dataset/customDataset.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/UNETR/BTCV/dataset/customDataset.py b/UNETR/BTCV/dataset/customDataset.py index c87375e7..964ffa33 100644 --- a/UNETR/BTCV/dataset/customDataset.py +++ b/UNETR/BTCV/dataset/customDataset.py @@ -29,8 +29,14 @@ def collate_fn(batch): return collate_fn def getDatasetLoader(args): - dataName = [d for d in os.listdir(NIFTI_LABEL_ROOT)] - dataDicts = [{"image": f"{os.path.join(NIFTI_DATA_ROOT, d)}", "label": f"{os.path.join(NIFTI_LABEL_ROOT, d)}"} for d in dataName] + exts = (".nii", ".nii.gz") + img_names = {f for f in os.listdir(NIFTI_DATA_ROOT) if f.endswith(exts) and os.path.isfile(os.path.join(NIFTI_DATA_ROOT, f))} + lbl_names = {f for f in os.listdir(NIFTI_LABEL_ROOT) if f.endswith(exts) and os.path.isfile(os.path.join(NIFTI_LABEL_ROOT, f))} + common = sorted(img_names & lbl_names) + if not common: + raise RuntimeError(f"No matching image/label pairs found in {NIFTI_DATA_ROOT} and {NIFTI_LABEL_ROOT}") + dataDicts = [{"image": os.path.join(NIFTI_DATA_ROOT, f), "label": os.path.join(NIFTI_LABEL_ROOT, f)} for f in common] + trainDicts, valDicts = _splitList(dataDicts) train_transform = transforms.Compose(