Skip to content

Commit 7509bad

Browse files
committed
utils: Implement get_optimizer()
1 parent 35a6656 commit 7509bad

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Backbone:
2626
optimizer:
2727
name: Adam
2828
lr: 0.001
29+
weight_decay: 0.00001
2930
scheduler:
3031
name: ReduceLROnPlateau
3132
min_lr: 0.0001
@@ -38,6 +39,7 @@ Proposed:
3839
optimizer:
3940
name: Adam
4041
lr: 0.001
42+
weight_decay: 0.00001
4143
scheduler:
4244
name: ReduceLROnPlateau
4345
min_lr: 0.0001

utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import models.unet
1515

1616

17-
# 설정 불러오기
1817
def load_config():
1918
with open('config.yaml') as f:
2019
config = yaml.load(f, Loader=yaml.FullLoader)
@@ -25,7 +24,6 @@ def load_config():
2524
return config
2625

2726

28-
# 모델 불러오기
2927
def get_model(config: dict, pretrained=False) -> torch.nn.Module:
3028
assert isinstance(pretrained, bool)
3129

@@ -36,7 +34,7 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
3634
elif config['model'] == 'Proposed':
3735
model = models.proposed.Proposed(config[config['model']]['num_classes'])
3836
else:
39-
raise NameError('Wrong model_name.')
37+
raise NameError('Wrong model name.')
4038

4139
if pretrained:
4240
if os.path.exists(config[config['model']]['pretrained_weights']):
@@ -46,6 +44,20 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
4644
return model
4745

4846

47+
def get_optimizer(config: dict, model: torch.nn.Module):
48+
cfg_optim: dict = config[config['model']]['optimizer']
49+
50+
if cfg_optim['name'] == 'SGD':
51+
optimizer = torch.optim.SGD(model.parameters(), lr=cfg_optim['lr'],
52+
momentum=cfg_optim['momentum'], weight_decay=cfg_optim['weight_decay'])
53+
elif cfg_optim['name'] == 'Adam':
54+
optimizer = torch.optim.Adam(model.parameters(), lr=cfg_optim['lr'], weight_decay=cfg_optim['weight_decay'])
55+
else:
56+
raise NameError('Wrong optimizer name.')
57+
58+
return optimizer
59+
60+
4961
class Cityscapes:
5062
def __init__(self, config: dict):
5163
self.config = config

0 commit comments

Comments
 (0)