Skip to content

Commit c85f5fe

Browse files
committed
demo: Implement colormap of label
1 parent 9281c97 commit c85f5fe

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

demo.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
import matplotlib.colors
34
import matplotlib.pyplot as plt
45
import torch.nn.functional as F
56
import torch.utils.data
@@ -26,6 +27,9 @@
2627
image_name = image_path.replace('\\', '/').split('/')[-1]
2728
image_names.append(image_name)
2829

30+
# label colormap 설정
31+
cmap = matplotlib.colors.ListedColormap(utils.utils.get_cityscapes_label_colormap(short=True))
32+
2933
# 예측 결과 저장
3034
step = 0
3135
result_dir = os.path.join('demo', model_name.lower())
@@ -48,6 +52,6 @@
4852
# 1 배치단위 처리
4953
assert masks.shape[0] == masks_pred.shape[0]
5054
for i in range(masks.shape[0]):
51-
plt.imsave(os.path.join(result_dir, image_names[step]), masks_pred[i].cpu())
52-
plt.imsave(os.path.join(groundtruth_dir, image_names[step]), masks[i])
55+
plt.imsave(os.path.join(result_dir, image_names[step]), masks_pred[i].cpu(), cmap=cmap)
56+
plt.imsave(os.path.join(groundtruth_dir, image_names[step]), masks[i], cmap=cmap)
5357
step += 1

utils/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import configparser
22
import os
33

4-
import matplotlib.colors
54
import matplotlib.pyplot as plt
65
import numpy as np
76
import torch
@@ -52,8 +51,8 @@ def get_model(model_name: str, num_classes: int, pretrained: str = None) -> torc
5251

5352
# Cityscapes 데이터셋 라벨 색상 불러오기
5453
def get_cityscapes_label_colormap(short=False):
55-
colormap = np.zeros((20, 3), dtype=np.uint8)
5654
if not short:
55+
colormap = np.zeros((20, 3), dtype=np.uint8)
5756
colormap[0] = [0, 0, 0]
5857
colormap[1] = [128, 64, 128]
5958
colormap[2] = [244, 35, 232]
@@ -75,6 +74,7 @@ def get_cityscapes_label_colormap(short=False):
7574
colormap[18] = [0, 0, 230]
7675
colormap[19] = [119, 11, 32]
7776
else:
77+
colormap = np.zeros((8, 3), dtype=np.uint8)
7878
colormap[0] = [0, 0, 0]
7979
colormap[1] = [128, 64, 128]
8080
colormap[2] = [70, 70, 70]
@@ -84,7 +84,7 @@ def get_cityscapes_label_colormap(short=False):
8484
colormap[6] = [220, 20, 60]
8585
colormap[7] = [0, 0, 142]
8686

87-
matplotlib.colors.ListedColormap(colormap)
87+
return np.divide(colormap, 255).tolist()
8888

8989

9090
# Cityscapes 데이터셋 설정

0 commit comments

Comments
 (0)