@@ -30,17 +30,17 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
3030 assert isinstance (pretrained , bool )
3131
3232 if config ['model' ] == 'UNet' :
33- model = models .unet .UNet (config ['model' ]['num_classes' ])
33+ model = models .unet .UNet (config [config [ 'model' ] ]['num_classes' ])
3434 elif config ['model' ] == 'Proposed' :
35- model = models .proposed .Proposed (config ['model' ]['num_classes' ])
35+ model = models .proposed .Proposed (config [config [ 'model' ] ]['num_classes' ])
3636 elif config ['model' ] == 'Backbone' :
37- model = models .backbone .Backbone (config ['model' ]['num_classes' ])
37+ model = models .backbone .Backbone (config [config [ 'model' ] ]['num_classes' ])
3838 else :
3939 raise NameError ('Wrong model_name.' )
4040
4141 if pretrained :
42- if os .path .exists (config ['model' ]['pretrained_weights' ]):
43- model .load_state_dict (torch .load (config ['model' ]['pretrained_weights' ]))
42+ if os .path .exists (config [config [ 'model' ] ]['pretrained_weights' ]):
43+ model .load_state_dict (torch .load (config [config [ 'model' ] ]['pretrained_weights' ]))
4444 else :
4545 print ('FileNotFound: pretrained_weights (' + config ['model' ] + ')' )
4646 return model
@@ -77,7 +77,7 @@ def set_cityscapes(self):
7777 target_type = 'semantic' ,
7878 transforms = self .transforms )
7979 trainloader = torch .utils .data .DataLoader (trainset ,
80- batch_size = self .config ['model' ]['batch_size' ],
80+ batch_size = self .config [self . config [ 'model' ] ]['batch_size' ],
8181 shuffle = True ,
8282 num_workers = self .config ['dataset' ]['num_workers' ],
8383 pin_memory = True )
@@ -88,7 +88,7 @@ def set_cityscapes(self):
8888 transform = self .transform ,
8989 target_transform = self .target_transform )
9090 testloader = torch .utils .data .DataLoader (testset ,
91- batch_size = self .config ['model' ]['batch_size' ],
91+ batch_size = self .config [self . config [ 'model' ] ]['batch_size' ],
9292 shuffle = False ,
9393 num_workers = self .config ['dataset' ]['num_workers' ])
9494
0 commit comments