We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 069b357 commit 1624575Copy full SHA for 1624575
demo.py
@@ -39,10 +39,11 @@
39
with torch.no_grad():
40
masks_pred = model(images)
41
masks_pred = F.log_softmax(masks_pred, dim=1)
42
- masks_pred = torch.argmax(masks_pred, dim=1, keepdim=True)
+ masks_pred = torch.argmax(masks_pred, dim=1)
43
44
# 1 배치단위 처리
45
+ assert masks.shape[0] == masks_pred.shape[0]
46
for i in range(masks.shape[0]):
- plt.imsave(os.path.join(result_dir, image_names[step]), masks_pred[i].cpu().squeeze())
47
+ plt.imsave(os.path.join(result_dir, image_names[step]), masks_pred[i].cpu())
48
plt.imsave(os.path.join(groundtruth_dir, image_names[step]), masks[i].squeeze())
49
step += 1
0 commit comments