Skip to content

Commit 7cda001

Browse files
committed
train: Save weights of last epoch
1 parent 434bfce commit 7cda001

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@
6565
writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
6666
scheduler.step(val_loss)
6767

68+
# 가장 마지막 epoch의 모델을 저장
69+
os.makedirs('weights', exist_ok=True)
70+
torch.save(model.state_dict(), os.path.join('weights', '{}_last.pth'.format(config['model_name'])))
71+
6872
# Best mIoU를 가진 모델을 저장
6973
if miou > prev_miou:
70-
os.makedirs('weights', exist_ok=True)
71-
torch.save(model.state_dict(),
72-
os.path.join('weights', '{}_best.pth'.format(config['model_name'])))
74+
torch.save(model.state_dict(), os.path.join('weights', '{}_best.pth'.format(config['model_name'])))
7375
prev_miou = miou
7476
writer.close()

0 commit comments

Comments
 (0)