@@ -20,6 +20,8 @@ def load_config():
2020
2121 model_name = parser ['model' ]['activate_model' ]
2222 config = {
23+ 'dataset_root' : parser ['dataset' ]['root' ],
24+ 'model_name' : model_name ,
2325 'batch_size' : parser .getint (model_name , 'batch_size' ),
2426 'epoch' : parser .getint (model_name , 'epoch' ),
2527 'image_size' : (int (parser [model_name ]['image_size' ].split ('x' )[1 ]),
@@ -29,7 +31,7 @@ def load_config():
2931 'num_workers' : parser .getint (model_name , 'num_workers' ),
3032 'pretrained_weights' : parser [model_name ]['pretrained_weights' ],
3133 }
32- return model_name , config
34+ return config
3335
3436
3537# 모델 불러오기
@@ -98,7 +100,7 @@ def init_cityscapes_dataset(config: dict):
98100 torchvision .transforms .Resize (config ['image_size' ], interpolation = 0 ),
99101 torchvision .transforms .ToTensor (),
100102 ])
101- trainset = utils .datasets .Cityscapes (root = '../../data/cityscapes' ,
103+ trainset = utils .datasets .Cityscapes (root = config [ 'dataset_root' ] ,
102104 split = 'train' ,
103105 mode = 'fine' ,
104106 target_type = 'semantic' ,
@@ -109,7 +111,7 @@ def init_cityscapes_dataset(config: dict):
109111 shuffle = True ,
110112 num_workers = config ['num_workers' ],
111113 pin_memory = True )
112- testset = utils .datasets .Cityscapes (root = '../../data/cityscapes' ,
114+ testset = utils .datasets .Cityscapes (root = config [ 'dataset_root' ] ,
113115 split = 'val' ,
114116 mode = 'fine' ,
115117 target_type = 'semantic' ,
0 commit comments