Skip to content

Commit 11321c2

Browse files
committed
Fix config error
1 parent ef5f399 commit 11321c2

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def evaluate(model, testloader, num_classes: int, device):
9999
model.eval()
100100

101101
# 모델 평가
102-
val_loss, iou, miou, fps = evaluate(model, testloader, config['model']['num_classes'], device)
102+
val_loss, iou, miou, fps = evaluate(model, testloader, config[config['model']]['num_classes'], device)
103103

104104
# 평가 결과를 csv 파일로 저장
105105
os.makedirs('result', exist_ok=True)

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
# 3. Loss function, optimizer, lr scheduler
2525
criterion = nn.CrossEntropyLoss()
26-
optimizer = torch.optim.Adam(model.parameters(), lr=config['model']['lr'])
26+
optimizer = torch.optim.Adam(model.parameters(), lr=config[config['model']]['lr'])
2727
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=0.0001)
2828

2929
# 4. Tensorboard
@@ -34,7 +34,7 @@
3434
log_loss = tqdm.tqdm(total=0, position=2, bar_format='{desc}', leave=False)
3535
prev_miou = 0.0
3636
prev_val_loss = 0.0
37-
for epoch in tqdm.tqdm(range(config['model']['epoch']), desc='Epoch'):
37+
for epoch in tqdm.tqdm(range(config[config['model']]['epoch']), desc='Epoch'):
3838
model.train()
3939

4040
for batch_idx, (images, masks) in enumerate(tqdm.tqdm(trainloader, desc='Train', leave=False)):
@@ -58,7 +58,7 @@
5858
writer.add_scalar('Train Loss', loss.item(), len(trainloader) * epoch + batch_idx)
5959

6060
# 모델 평가
61-
val_loss, _, miou, _ = eval.evaluate(model, testloader, config['model']['num_classes'], device)
61+
val_loss, _, miou, _ = eval.evaluate(model, testloader, config[config['model']]['num_classes'], device)
6262
writer.add_scalar('Validation Loss', val_loss, epoch)
6363
writer.add_scalar('mIoU', miou, epoch)
6464

utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,17 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
3030
assert isinstance(pretrained, bool)
3131

3232
if config['model'] == 'UNet':
33-
model = models.unet.UNet(config['model']['num_classes'])
33+
model = models.unet.UNet(config[config['model']]['num_classes'])
3434
elif config['model'] == 'Proposed':
35-
model = models.proposed.Proposed(config['model']['num_classes'])
35+
model = models.proposed.Proposed(config[config['model']]['num_classes'])
3636
elif config['model'] == 'Backbone':
37-
model = models.backbone.Backbone(config['model']['num_classes'])
37+
model = models.backbone.Backbone(config[config['model']]['num_classes'])
3838
else:
3939
raise NameError('Wrong model_name.')
4040

4141
if pretrained:
42-
if os.path.exists(config['model']['pretrained_weights']):
43-
model.load_state_dict(torch.load(config['model']['pretrained_weights']))
42+
if os.path.exists(config[config['model']]['pretrained_weights']):
43+
model.load_state_dict(torch.load(config[config['model']]['pretrained_weights']))
4444
else:
4545
print('FileNotFound: pretrained_weights (' + config['model'] + ')')
4646
return model
@@ -77,7 +77,7 @@ def set_cityscapes(self):
7777
target_type='semantic',
7878
transforms=self.transforms)
7979
trainloader = torch.utils.data.DataLoader(trainset,
80-
batch_size=self.config['model']['batch_size'],
80+
batch_size=self.config[self.config['model']]['batch_size'],
8181
shuffle=True,
8282
num_workers=self.config['dataset']['num_workers'],
8383
pin_memory=True)
@@ -88,7 +88,7 @@ def set_cityscapes(self):
8888
transform=self.transform,
8989
target_transform=self.target_transform)
9090
testloader = torch.utils.data.DataLoader(testset,
91-
batch_size=self.config['model']['batch_size'],
91+
batch_size=self.config[self.config['model']]['batch_size'],
9292
shuffle=False,
9393
num_workers=self.config['dataset']['num_workers'])
9494

0 commit comments

Comments
 (0)