Skip to content

Commit e5fa4e9

Browse files
committed
utils: Implement get_scheduler()
1 parent 7509bad commit e5fa4e9

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ UNet:
1616
weight_decay: 0.00001
1717
scheduler:
1818
name: ReduceLROnPlateau
19+
patience: 5
1920
min_lr: 0.0001
2021
pretrained_weights: weights/UNet_best.pth
2122

@@ -29,6 +30,7 @@ Backbone:
2930
weight_decay: 0.00001
3031
scheduler:
3132
name: ReduceLROnPlateau
33+
patience: 5
3234
min_lr: 0.0001
3335
pretrained_weights: weights/Backbone_best.pth
3436

@@ -42,5 +44,6 @@ Proposed:
4244
weight_decay: 0.00001
4345
scheduler:
4446
name: ReduceLROnPlateau
47+
patience: 5
4548
min_lr: 0.0001
4649
pretrained_weights: weights/Proposed_best.pth

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
# 3. Loss function, optimizer, lr scheduler
2525
criterion = nn.CrossEntropyLoss()
26-
optimizer = torch.optim.Adam(model.parameters(), lr=config[config['model']]['lr'])
27-
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=0.0001)
26+
optimizer = utils.get_optimizer(config, model)
27+
scheduler = utils.get_scheduler(config, optimizer)
2828

2929
# 4. Tensorboard
3030
writer = torch.utils.tensorboard.SummaryWriter(os.path.join('runs', config['model']))

utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
4444
return model
4545

4646

47-
def get_optimizer(config: dict, model: torch.nn.Module):
47+
def get_optimizer(config: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
4848
cfg_optim: dict = config[config['model']]['optimizer']
4949

5050
if cfg_optim['name'] == 'SGD':
@@ -58,6 +58,18 @@ def get_optimizer(config: dict, model: torch.nn.Module):
5858
return optimizer
5959

6060

61+
def get_scheduler(config: dict, optimizer: torch.optim.Optimizer):
    """Build the learning-rate scheduler named in the active model's config.

    :param config: full config dict; ``config[config['model']]['scheduler']``
                   must hold the scheduler ``name`` and its keyword settings
                   (``patience``, ``min_lr`` for ReduceLROnPlateau).
    :param optimizer: optimizer whose learning rate the scheduler will adjust.
    :return: the constructed scheduler instance.
    :raises NameError: if the configured scheduler name is not supported.
    """
    cfg_scheduler: dict = config[config['model']]['scheduler']

    if cfg_scheduler['name'] == 'ReduceLROnPlateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=cfg_scheduler['patience'],
            min_lr=cfg_scheduler['min_lr'],
        )
    # Fail loudly and include the offending value so a config typo is
    # immediately visible in the traceback.
    raise NameError(f"Wrong scheduler name: {cfg_scheduler['name']!r}")
71+
72+
6173
class Cityscapes:
6274
def __init__(self, config: dict):
6375
self.config = config

0 commit comments

Comments
 (0)