Skip to content

Commit 1624575

Browse files
committed
demo: Refactor to optimize argmax
1 parent 069b357 commit 1624575

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

demo.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@
3939
with torch.no_grad():
4040
masks_pred = model(images)
4141
masks_pred = F.log_softmax(masks_pred, dim=1)
42-
masks_pred = torch.argmax(masks_pred, dim=1, keepdim=True)
42+
masks_pred = torch.argmax(masks_pred, dim=1)
4343

4444
# 1 배치단위 처리
45+
assert masks.shape[0] == masks_pred.shape[0]
4546
for i in range(masks.shape[0]):
46-
plt.imsave(os.path.join(result_dir, image_names[step]), masks_pred[i].cpu().squeeze())
47+
plt.imsave(os.path.join(result_dir, image_names[step]), masks_pred[i].cpu())
4748
plt.imsave(os.path.join(groundtruth_dir, image_names[step]), masks[i].squeeze())
4849
step += 1

0 commit comments

Comments
 (0)