1414import models .unet
1515
1616
17- # 설정 불러오기
1817def load_config ():
1918 with open ('config.yaml' ) as f :
2019 config = yaml .load (f , Loader = yaml .FullLoader )
@@ -25,7 +24,6 @@ def load_config():
2524 return config
2625
2726
28- # 모델 불러오기
2927def get_model (config : dict , pretrained = False ) -> torch .nn .Module :
3028 assert isinstance (pretrained , bool )
3129
@@ -36,7 +34,7 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
3634 elif config ['model' ] == 'Proposed' :
3735 model = models .proposed .Proposed (config [config ['model' ]]['num_classes' ])
3836 else :
39- raise NameError ('Wrong model_name .' )
37+ raise NameError ('Wrong model name .' )
4038
4139 if pretrained :
4240 if os .path .exists (config [config ['model' ]]['pretrained_weights' ]):
@@ -46,6 +44,20 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
4644 return model
4745
4846
47+ def get_optimizer (config : dict , model : torch .nn .Module ):
48+ cfg_optim : dict = config [config ['model' ]]['optimizer' ]
49+
50+ if cfg_optim ['name' ] == 'SGD' :
51+ optimizer = torch .optim .SGD (model .parameters (), lr = cfg_optim ['lr' ],
52+ momentum = cfg_optim ['momentum' ], weight_decay = cfg_optim ['weight_decay' ])
53+ elif cfg_optim ['name' ] == 'Adam' :
54+ optimizer = torch .optim .Adam (model .parameters (), lr = cfg_optim ['lr' ], weight_decay = cfg_optim ['weight_decay' ])
55+ else :
56+ raise NameError ('Wrong optimizer name.' )
57+
58+ return optimizer
59+
60+
4961class Cityscapes :
5062 def __init__ (self , config : dict ):
5163 self .config = config
0 commit comments