Skip to content

Commit eded42e

Browse files
committed
train: best val loss 모델도 저장
1 parent 02846c2 commit eded42e

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

train.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
# 5. Train and evaluate
3434
log_loss = tqdm.tqdm(total=0, position=2, bar_format='{desc}', leave=False)
3535
prev_miou = 0.0
36+
prev_val_loss = 0.0
3637
for epoch in tqdm.tqdm(range(config['epoch']), desc='Epoch'):
3738
model.train()
3839

@@ -73,4 +74,9 @@
7374
if miou > prev_miou:
7475
torch.save(model.state_dict(), os.path.join('weights', '{}_best.pth'.format(config['model_name'])))
7576
prev_miou = miou
77+
78+
# Best val_loss를 가진 모델을 저장
79+
if val_loss > prev_val_loss:
80+
torch.save(model.state_dict(), os.path.join('weights', '{}_val_best.pth'.format(config['model_name'])))
81+
prev_val_loss = val_loss
7682
writer.close()

0 commit comments

Comments
 (0)