diff --git a/dataset/default.py b/dataset/default.py index 28bb0bd..7258364 100644 --- a/dataset/default.py +++ b/dataset/default.py @@ -26,7 +26,7 @@ def __init__(self, root_dir: str): def get_data_files(self): nifti_file_names = os.listdir(self.root_dir) folder_names = [os.path.join( - self.root_dir, nifti_file_name) for nifti_file_name in nifti_file_names if nifti_file_names.endsiwth('.nii')] + self.root_dir, nifti_file_name) for nifti_file_name in nifti_file_names if nifti_file_names.endswith('.nii')] return folder_names def __len__(self): diff --git a/train/get_dataset.py b/train/get_dataset.py index 5da9338..a3c4058 100644 --- a/train/get_dataset.py +++ b/train/get_dataset.py @@ -45,4 +45,5 @@ def get_dataset(cfg): val_dataset = DEFAULTDataset( root_dir=cfg.dataset.root_dir) sampler = None + return train_dataset, val_dataset, sampler raise ValueError(f'{cfg.dataset.name} Dataset is not available') diff --git a/train/train_vqgan.py b/train/train_vqgan.py index d472321..3da2903 100644 --- a/train/train_vqgan.py +++ b/train/train_vqgan.py @@ -1,6 +1,8 @@ "Adapted from https://github.com/SongweiGe/TATS" import os +import sys +sys.path.append('/Users/dmnk/Documents/GitHub/segmentation_diffusion/medicaldiffusion') import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint from torch.utils.data import DataLoader