Skip to content

Commit 32766d8

Browse files
committed
utils: Implement adaptive code to get_cityscapes_colormap()
1 parent 4f7aad9 commit 32766d8

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
image_names.append(image_name)
3030

3131
# label colormap 설정
32-
cmap = matplotlib.colors.ListedColormap(dataset.get_cityscapes_colormap(short=True))
32+
cmap = matplotlib.colors.ListedColormap(dataset.get_cityscapes_colormap())
3333

3434
# 예측 결과 저장
3535
step = 0

utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def load_config():
2626

2727
def get_model(config: dict, pretrained=False) -> torch.nn.Module:
2828
assert isinstance(pretrained, bool)
29+
assert config['dataset']['num_classes'] == 20 or config['dataset']['num_classes'] == 8
2930

3031
if config['model'] == 'UNet':
3132
model = models.unet.UNet(config['dataset']['num_classes'])
@@ -81,6 +82,7 @@ def __init__(self, config: dict):
8182

8283
self.class_names_short = ['unlabeled', 'flat', 'construction', 'object',
8384
'nature', 'sky', 'human', 'vehicle']
85+
self.num_classes = self.config['dataset']['num_classes']
8486

8587
self.transform = torchvision.transforms.Compose([
8688
torchvision.transforms.Resize(self.config['dataset']['image_size']),
@@ -124,9 +126,8 @@ def set_cityscapes(self):
124126
return trainset, trainloader, testset, testloader
125127

126128
# Cityscapes 데이터셋 라벨 색상 불러오기
127-
def get_cityscapes_colormap(self, short=False):
128-
assert isinstance(short, bool)
129-
if not short:
129+
def get_cityscapes_colormap(self):
130+
if self.num_classes == 20:
130131
colormap = np.zeros((20, 3), dtype=np.uint8)
131132
colormap[0] = [0, 0, 0]
132133
colormap[1] = [128, 64, 128]
@@ -148,7 +149,7 @@ def get_cityscapes_colormap(self, short=False):
148149
colormap[17] = [0, 80, 100]
149150
colormap[18] = [0, 0, 230]
150151
colormap[19] = [119, 11, 32]
151-
else:
152+
elif self.num_classes == 8:
152153
colormap = np.zeros((8, 3), dtype=np.uint8)
153154
colormap[0] = [0, 0, 0]
154155
colormap[1] = [128, 64, 128]
@@ -158,6 +159,8 @@ def get_cityscapes_colormap(self, short=False):
158159
colormap[5] = [70, 130, 180]
159160
colormap[6] = [220, 20, 60]
160161
colormap[7] = [0, 0, 142]
162+
else:
163+
raise ValueError('Wrong num_classes.')
161164

162165
return np.divide(colormap, 255).tolist()
163166

0 commit comments

Comments
 (0)