We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 02846c2 commit eded42eCopy full SHA for eded42e
train.py
@@ -33,6 +33,7 @@
33
# 5. Train and evaluate
34
log_loss = tqdm.tqdm(total=0, position=2, bar_format='{desc}', leave=False)
35
prev_miou = 0.0
36
+ prev_val_loss = 0.0
37
for epoch in tqdm.tqdm(range(config['epoch']), desc='Epoch'):
38
model.train()
39
@@ -73,4 +74,9 @@
73
74
if miou > prev_miou:
75
torch.save(model.state_dict(), os.path.join('weights', '{}_best.pth'.format(config['model_name'])))
76
prev_miou = miou
77
+
78
+ # Best val_loss를 가진 모델을 저장
79
+ if val_loss > prev_val_loss:
80
+ torch.save(model.state_dict(), os.path.join('weights', '{}_val_best.pth'.format(config['model_name'])))
81
+ prev_val_loss = val_loss
82
writer.close()
0 commit comments