Skip to content

Commit f367d7f

Browse files
committed
utils: Edit config dict
1. Add dataset_root, model_name. 2. Change interface of function
1 parent 52933d9 commit f367d7f

File tree

5 files changed

+21
-16
lines changed

5 files changed

+21
-16
lines changed

demo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111

1212
if __name__ == '__main__':
1313
# 0. Load config
14-
model_name, config = utils.utils.load_config()
15-
print('Activated model: {}'.format(model_name))
14+
config = utils.utils.load_config()
15+
print('Activated model: {}'.format(config['model_name']))
1616

1717
# 1. Dataset
1818
_, _, testset, testloader = utils.utils.init_cityscapes_dataset(config)
1919

2020
# 2. Model
2121
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22-
model = utils.utils.get_model(model_name, config['num_classes'], config['pretrained_weights']).to(device)
22+
model = utils.utils.get_model(config['model_name'], config['num_classes'], config['pretrained_weights']).to(device)
2323

2424
# 이미지 이름 불러오기
2525
image_names = []
@@ -32,7 +32,7 @@
3232

3333
# 예측 결과 저장
3434
step = 0
35-
result_dir = os.path.join('demo', model_name.lower())
35+
result_dir = os.path.join('demo', config['model_name'].lower())
3636
groundtruth_dir = os.path.join('demo', 'groundtruth')
3737
os.makedirs(result_dir, exist_ok=True)
3838
os.makedirs(groundtruth_dir, exist_ok=True)

eval.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,22 +87,22 @@ def evaluate(model, testloader, num_classes: int, device):
8787

8888
if __name__ == '__main__':
8989
# 0. Load config
90-
model_name, config = utils.utils.load_config()
91-
print('Activated model: {}'.format(model_name))
90+
config = utils.utils.load_config()
91+
print('Activated model: {}'.format(config['model_name']))
9292

9393
# 1. Dataset
9494
_, _, testset, testloader = utils.utils.init_cityscapes_dataset(config)
9595

9696
# 2. Model
9797
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
98-
model = utils.utils.get_model(model_name, config['num_classes'], config['pretrained_weights']).to(device)
98+
model = utils.utils.get_model(config['model_name'], config['num_classes'], config['pretrained_weights']).to(device)
9999

100100
# 모델 평가
101101
val_loss, iou, miou, fps = evaluate(model, testloader, config['num_classes'], device)
102102

103103
# 평가 결과를 csv 파일로 저장
104104
os.makedirs('result', exist_ok=True)
105-
with open(os.path.join('result', '{}.csv'.format(model_name)), mode='w') as f:
105+
with open(os.path.join('result', '{}.csv'.format(config['model_name'])), mode='w') as f:
106106
writer = csv.writer(f, delimiter=',', lineterminator='\n')
107107

108108
writer.writerow(['Class Number', 'Class Name', 'IoU'])

models/models.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
[dataset]
2+
root = ../../data/cityscapes
3+
14
[model]
25
activate_model = UNet
36

train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,23 @@
1111

1212
if __name__ == '__main__':
1313
# 0. Load config
14-
model_name, config = utils.utils.load_config()
15-
print('Activated model: {}'.format(model_name))
14+
config = utils.utils.load_config()
15+
print('Activated model: {}'.format(config['model_name']))
1616

1717
# 1. Dataset
1818
trainset, trainloader, testset, testloader = utils.utils.init_cityscapes_dataset(config)
1919

2020
# 2. Model
2121
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22-
model = utils.utils.get_model(model_name, config['num_classes']).to(device)
22+
model = utils.utils.get_model(config['model_name'], config['num_classes']).to(device)
2323

2424
# 3. Loss function, optimizer, lr scheduler
2525
criterion = nn.CrossEntropyLoss()
2626
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
2727
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=0.0001)
2828

2929
# 4. Tensorboard
30-
writer = torch.utils.tensorboard.SummaryWriter(os.path.join('runs', model_name))
30+
writer = torch.utils.tensorboard.SummaryWriter(os.path.join('runs', config['model_name']))
3131
writer.add_graph(model, trainloader.__iter__().__next__()[0].to(device))
3232

3333
# 5. Train and evaluate
@@ -69,6 +69,6 @@
6969
if miou > prev_miou:
7070
os.makedirs('weights', exist_ok=True)
7171
torch.save(model.state_dict(),
72-
os.path.join('weights', '{}_best.pth'.format(model_name)))
72+
os.path.join('weights', '{}_best.pth'.format(config['model_name'])))
7373
prev_miou = miou
7474
writer.close()

utils/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def load_config():
2020

2121
model_name = parser['model']['activate_model']
2222
config = {
23+
'dataset_root': parser['dataset']['root'],
24+
'model_name': model_name,
2325
'batch_size': parser.getint(model_name, 'batch_size'),
2426
'epoch': parser.getint(model_name, 'epoch'),
2527
'image_size': (int(parser[model_name]['image_size'].split('x')[1]),
@@ -29,7 +31,7 @@ def load_config():
2931
'num_workers': parser.getint(model_name, 'num_workers'),
3032
'pretrained_weights': parser[model_name]['pretrained_weights'],
3133
}
32-
return model_name, config
34+
return config
3335

3436

3537
# 모델 불러오기
@@ -98,7 +100,7 @@ def init_cityscapes_dataset(config: dict):
98100
torchvision.transforms.Resize(config['image_size'], interpolation=0),
99101
torchvision.transforms.ToTensor(),
100102
])
101-
trainset = utils.datasets.Cityscapes(root='../../data/cityscapes',
103+
trainset = utils.datasets.Cityscapes(root=config['dataset_root'],
102104
split='train',
103105
mode='fine',
104106
target_type='semantic',
@@ -109,7 +111,7 @@ def init_cityscapes_dataset(config: dict):
109111
shuffle=True,
110112
num_workers=config['num_workers'],
111113
pin_memory=True)
112-
testset = utils.datasets.Cityscapes(root='../../data/cityscapes',
114+
testset = utils.datasets.Cityscapes(root=config['dataset_root'],
113115
split='val',
114116
mode='fine',
115117
target_type='semantic',

0 commit comments

Comments
 (0)