File tree Expand file tree Collapse file tree 3 files changed +18
-3
lines changed
Expand file tree Collapse file tree 3 files changed +18
-3
lines changed Original file line number Diff line number Diff line change 1616 weight_decay : 0.00001
1717 scheduler :
1818 name : ReduceLROnPlateau
19+ patience : 5
1920 min_lr : 0.0001
2021 pretrained_weights : weights/UNet_best.pth
2122
@@ -29,6 +30,7 @@ Backbone:
2930 weight_decay : 0.00001
3031 scheduler :
3132 name : ReduceLROnPlateau
33+ patience : 5
3234 min_lr : 0.0001
3335 pretrained_weights : weights/Backbone_best.pth
3436
@@ -42,5 +44,6 @@ Proposed:
4244 weight_decay : 0.00001
4345 scheduler :
4446 name : ReduceLROnPlateau
47+ patience : 5
4548 min_lr : 0.0001
4649 pretrained_weights : weights/Proposed_best.pth
Original file line number Diff line number Diff line change 2323
2424 # 3. Loss function, optimizer, lr scheduler
2525 criterion = nn .CrossEntropyLoss ()
26- optimizer = torch . optim . Adam ( model . parameters (), lr = config [ config [ ' model' ]][ 'lr' ] )
27- scheduler = torch . optim . lr_scheduler . ReduceLROnPlateau ( optimizer , min_lr = 0.0001 )
26+ optimizer = utils . get_optimizer ( config , model )
27+ scheduler = utils . get_scheduler ( config , optimizer )
2828
2929 # 4. Tensorboard
3030 writer = torch .utils .tensorboard .SummaryWriter (os .path .join ('runs' , config ['model' ]))
Original file line number Diff line number Diff line change @@ -44,7 +44,7 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
4444 return model
4545
4646
47- def get_optimizer (config : dict , model : torch .nn .Module ):
47+ def get_optimizer (config : dict , model : torch .nn .Module ) -> torch . optim . Optimizer :
4848 cfg_optim : dict = config [config ['model' ]]['optimizer' ]
4949
5050 if cfg_optim ['name' ] == 'SGD' :
@@ -58,6 +58,18 @@ def get_optimizer(config: dict, model: torch.nn.Module):
5858 return optimizer
5959
6060
61+ def get_scheduler (config : dict , optimizer : torch .optim .Optimizer ):
62+ cfg_scheduler : dict = config [config ['model' ]]['scheduler' ]
63+
64+ if cfg_scheduler ['name' ] == 'ReduceLROnPlateau' :
65+ scheduler = torch .optim .lr_scheduler .ReduceLROnPlateau (optimizer , patience = cfg_scheduler ['patience' ],
66+ min_lr = cfg_scheduler ['min_lr' ])
67+ else :
68+ raise NameError ('Wrong scheduler name.' )
69+
70+ return scheduler
71+
72+
6173class Cityscapes :
6274 def __init__ (self , config : dict ):
6375 self .config = config
You can’t perform that action at this time.
0 commit comments