diff --git a/DAE/BTCV_Finetune/utils/data_utils.py b/DAE/BTCV_Finetune/utils/data_utils.py index 6d3a1a37..84867b09 100644 --- a/DAE/BTCV_Finetune/utils/data_utils.py +++ b/DAE/BTCV_Finetune/utils/data_utils.py @@ -80,7 +80,7 @@ def get_loader(args): transforms.ScaleIntensityRanged( keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), - transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), transforms.RandCropByPosNegLabeld( keys=["image", "label"], label_key="label", @@ -111,7 +111,7 @@ def get_loader(args): transforms.ScaleIntensityRanged( keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), - transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), transforms.ToTensord(keys=["image", "label"]), ] ) diff --git a/DAE/Pretrain_full_contrast/data/data_pretrain.py b/DAE/Pretrain_full_contrast/data/data_pretrain.py index 8a3a817c..4340515d 100644 --- a/DAE/Pretrain_full_contrast/data/data_pretrain.py +++ b/DAE/Pretrain_full_contrast/data/data_pretrain.py @@ -21,7 +21,6 @@ AsChannelFirstd, AsDiscrete, Compose, - CropForegroundd, LoadImaged, NormalizeIntensityd, Orientationd, diff --git a/SwinMM/WORD/models/swin_unetr.py b/SwinMM/WORD/models/swin_unetr.py index be03918b..83b51c87 100644 --- a/SwinMM/WORD/models/swin_unetr.py +++ b/SwinMM/WORD/models/swin_unetr.py @@ -38,7 +38,6 @@ def __init__( """ super().__init__( - img_size, *args, num_heads=num_heads, feature_size=feature_size, diff --git a/SwinUNETR/BRATS21/README.md b/SwinUNETR/BRATS21/README.md index a3614305..d3d9048a 100644 --- a/SwinUNETR/BRATS21/README.md +++ b/SwinUNETR/BRATS21/README.md @@ -106,8 +106,7 @@ Mean Dice refers to average Dice of WT, ET and TC tumor semantic classes. A Swin UNETR network with standard hyper-parameters for brain tumor semantic segmentation (BraTS dataset) is be defined as: ``` bash -model = SwinUNETR(img_size=(128,128,128), - in_channels=4, +model = SwinUNETR(in_channels=4, out_channels=3, feature_size=48, use_checkpoint=True, diff --git a/SwinUNETR/BRATS21/main.py b/SwinUNETR/BRATS21/main.py index 3b96b7ce..f4d51af9 100644 --- a/SwinUNETR/BRATS21/main.py +++ b/SwinUNETR/BRATS21/main.py @@ -127,7 +127,6 @@ def main_worker(gpu, args): pretrained_pth = os.path.join(pretrained_dir, model_name) model = SwinUNETR( - img_size=(args.roi_x, args.roi_y, args.roi_z), in_channels=args.in_channels, out_channels=args.out_channels, feature_size=args.feature_size, diff --git a/SwinUNETR/BRATS21/test.py b/SwinUNETR/BRATS21/test.py index dcf4cf99..8f8f29ff 100644 --- a/SwinUNETR/BRATS21/test.py +++ b/SwinUNETR/BRATS21/test.py @@ -70,7 +70,6 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") pretrained_pth = os.path.join(pretrained_dir, model_name) model = SwinUNETR( - img_size=128, in_channels=args.in_channels, out_channels=args.out_channels, feature_size=args.feature_size, diff --git a/SwinUNETR/BRATS21/utils/data_utils.py b/SwinUNETR/BRATS21/utils/data_utils.py index 64865ce6..999df6a6 100755 --- a/SwinUNETR/BRATS21/utils/data_utils.py +++ b/SwinUNETR/BRATS21/utils/data_utils.py @@ -99,7 +99,7 @@ def get_loader(args): transforms.LoadImaged(keys=["image", "label"]), transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), transforms.CropForegroundd( - keys=["image", "label"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z] + keys=["image", "label"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z], allow_smaller=True ), transforms.RandSpatialCropd( keys=["image", "label"], roi_size=[args.roi_x, args.roi_y, args.roi_z], random_size=False diff --git a/SwinUNETR/BTCV/README.md b/SwinUNETR/BTCV/README.md index 087a914c..98e02611 100644 --- a/SwinUNETR/BTCV/README.md +++ b/SwinUNETR/BTCV/README.md @@ -84,8 +84,7 @@ Once the json file is downloaded, please place it in the same folder as the data A Swin UNETR network with standard hyper-parameters for multi-organ semantic segmentation (BTCV dataset) is be defined as: ``` bash -model = SwinUNETR(img_size=(96,96,96), - in_channels=1, +model = SwinUNETR(in_channels=1, out_channels=14, feature_size=48, use_checkpoint=True, diff --git a/SwinUNETR/BTCV/main.py b/SwinUNETR/BTCV/main.py index 7654d51d..5ef77243 100644 --- a/SwinUNETR/BTCV/main.py +++ b/SwinUNETR/BTCV/main.py @@ -127,7 +127,6 @@ def main_worker(gpu, args): pretrained_dir = args.pretrained_dir model = SwinUNETR( - img_size=(args.roi_x, args.roi_y, args.roi_z), in_channels=args.in_channels, out_channels=args.out_channels, feature_size=args.feature_size, diff --git a/SwinUNETR/BTCV/test.py b/SwinUNETR/BTCV/test.py index a4c45167..ad570536 100644 --- a/SwinUNETR/BTCV/test.py +++ b/SwinUNETR/BTCV/test.py @@ -71,7 +71,6 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") pretrained_pth = os.path.join(pretrained_dir, model_name) model = SwinUNETR( - img_size=96, in_channels=args.in_channels, out_channels=args.out_channels, feature_size=args.feature_size, diff --git a/SwinUNETR/BTCV/utils/data_utils.py b/SwinUNETR/BTCV/utils/data_utils.py index 6df4c43e..e6d3fd32 100755 --- a/SwinUNETR/BTCV/utils/data_utils.py +++ b/SwinUNETR/BTCV/utils/data_utils.py @@ -80,7 +80,7 @@ def get_loader(args): transforms.ScaleIntensityRanged( keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), - transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), transforms.RandCropByPosNegLabeld( keys=["image", "label"], label_key="label", @@ -111,7 +111,7 @@ def get_loader(args): transforms.ScaleIntensityRanged( keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), - transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), transforms.ToTensord(keys=["image", "label"]), ] ) diff --git a/SwinUNETR/Pretrain/utils/data_utils.py b/SwinUNETR/Pretrain/utils/data_utils.py index 0264c851..4d7095f2 100644 --- a/SwinUNETR/Pretrain/utils/data_utils.py +++ b/SwinUNETR/Pretrain/utils/data_utils.py @@ -78,7 +78,7 @@ def get_loader(args): keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), SpatialPadd(keys="image", spatial_size=[args.roi_x, args.roi_y, args.roi_z]), - CropForegroundd(keys=["image"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z]), + CropForegroundd(keys=["image"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z], allow_smaller=True), RandSpatialCropSamplesd( keys=["image"], roi_size=[args.roi_x, args.roi_y, args.roi_z], @@ -98,7 +98,7 @@ def get_loader(args): keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), SpatialPadd(keys="image", spatial_size=[args.roi_x, args.roi_y, args.roi_z]), - CropForegroundd(keys=["image"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z]), + CropForegroundd(keys=["image"], source_key="image", k_divisible=[args.roi_x, args.roi_y, args.roi_z], allow_smaller=True), RandSpatialCropSamplesd( keys=["image"], roi_size=[args.roi_x, args.roi_y, args.roi_z], diff --git a/UNETR/BTCV/utils/data_utils.py b/UNETR/BTCV/utils/data_utils.py index 026f88f4..bcdd844e 100755 --- a/UNETR/BTCV/utils/data_utils.py +++ b/UNETR/BTCV/utils/data_utils.py @@ -80,7 +80,7 @@ def get_loader(args): transforms.ScaleIntensityRanged( keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), - transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), transforms.RandCropByPosNegLabeld( keys=["image", "label"], label_key="label", @@ -111,7 +111,7 @@ def get_loader(args): transforms.ScaleIntensityRanged( keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), - transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), transforms.ToTensord(keys=["image", "label"]), ] ) diff --git a/auto3dseg/algorithm_templates/dints/scripts/algo.py b/auto3dseg/algorithm_templates/dints/scripts/algo.py index f3afb50d..646e6ba3 100644 --- a/auto3dseg/algorithm_templates/dints/scripts/algo.py +++ b/auto3dseg/algorithm_templates/dints/scripts/algo.py @@ -158,6 +158,7 @@ def fill_template_config(self, data_stats_file, output_path, **kwargs): "source_key": "@image_key", "start_coord_key": None, "end_coord_key": None, + "allow_smaller": True, }, ], } @@ -174,7 +175,7 @@ def fill_template_config(self, data_stats_file, output_path, **kwargs): "b_max": 1.0, "clip": True, }, - {"_target_": "CropForegroundd", "keys": "@image_key", "source_key": "@image_key"}, + {"_target_": "CropForegroundd", "keys": "@image_key", "source_key": "@image_key", "allow_smaller": True}, ], } diff --git a/auto3dseg/algorithm_templates/swinunetr/configs/network.yaml b/auto3dseg/algorithm_templates/swinunetr/configs/network.yaml index 888862a2..a3ace8df 100644 --- a/auto3dseg/algorithm_templates/swinunetr/configs/network.yaml +++ b/auto3dseg/algorithm_templates/swinunetr/configs/network.yaml @@ -1,7 +1,6 @@ network: _target_: SwinUNETR feature_size: 48 - img_size: 96 in_channels: "@input_channels" out_channels: "@output_classes" spatial_dims: 3 diff --git a/auto3dseg/algorithm_templates/swinunetr/scripts/algo.py b/auto3dseg/algorithm_templates/swinunetr/scripts/algo.py index ec80eb5d..1512ff62 100644 --- a/auto3dseg/algorithm_templates/swinunetr/scripts/algo.py +++ b/auto3dseg/algorithm_templates/swinunetr/scripts/algo.py @@ -168,6 +168,7 @@ def fill_template_config(self, data_stats_file, output_path, **kwargs): "source_key": "@image_key", "start_coord_key": None, "end_coord_key": None, + "allow_smaller": True, }, ], } @@ -183,7 +184,7 @@ def fill_template_config(self, data_stats_file, output_path, **kwargs): "b_max": 1.0, "clip": True, }, - {"_target_": "CropForegroundd", "keys": "@image_key", "source_key": "@image_key"}, + {"_target_": "CropForegroundd", "keys": "@image_key", "source_key": "@image_key", "allow_smaller": True}, ], } mr_intensity_transform = {