Skip to content

Commit 6ebdf4c

Browse files
committed
eval: Refactor to optimize argmax
1 parent 1624575 commit 6ebdf4c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def evaluate(model, testloader, num_classes: int, device):
6868

6969
# Segmentation map 만들기
7070
masks_pred = F.log_softmax(masks_pred, dim=1)
71-
masks_pred = torch.argmax(masks_pred, dim=1, keepdim=True)
71+
masks_pred = torch.argmax(masks_pred, dim=1)
7272

7373
# 혼동행렬 업데이트
7474
metrics.update_matrix(masks, masks_pred)
@@ -79,7 +79,7 @@ def evaluate(model, testloader, num_classes: int, device):
7979
# 평균 validation loss 계산
8080
val_loss /= len(testloader.dataset)
8181

82-
# 추론 시간과 fps를 계산 (추론 시간: ms)
82+
# 추론 시간과 fps를 계산 (추론 시간 단위: sec)
8383
inference_time /= len(testloader.dataset)
8484
fps = 1 / inference_time
8585

0 commit comments

Comments
 (0)