Skip to content

Commit 04acd25

Browse files
committed
utils: Add function
1 parent 6c5de35 commit 04acd25

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

utils/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import configparser
22
import os
33

4+
import matplotlib.colors
45
import matplotlib.pyplot as plt
6+
import numpy as np
57
import torch
68
import torch.utils.data
79
import torchvision
@@ -31,6 +33,7 @@ def load_config():
3133
return model_name, config
3234

3335

36+
# 모델 불러오기
3437
def get_model(model_name: str, num_classes: int, pretrained: str = None) -> torch.nn.Module:
3538
if model_name == 'UNet':
3639
model = models.unet.UNet(num_classes)
@@ -47,6 +50,43 @@ def get_model(model_name: str, num_classes: int, pretrained: str = None) -> torc
4750
return model
4851

4952

53+
# Cityscapes 데이터셋 라벨 색상 불러오기
54+
def get_cityscapes_label_colormap(short=False):
55+
colormap = np.zeros((20, 3), dtype=np.uint8)
56+
if not short:
57+
colormap[0] = [0, 0, 0]
58+
colormap[1] = [128, 64, 128]
59+
colormap[2] = [244, 35, 232]
60+
colormap[3] = [70, 70, 70]
61+
colormap[4] = [102, 102, 156]
62+
colormap[5] = [190, 153, 153]
63+
colormap[6] = [153, 153, 153]
64+
colormap[7] = [250, 170, 30]
65+
colormap[8] = [220, 220, 0]
66+
colormap[9] = [107, 142, 35]
67+
colormap[10] = [152, 251, 152]
68+
colormap[11] = [70, 130, 180]
69+
colormap[12] = [220, 20, 60]
70+
colormap[13] = [255, 0, 0]
71+
colormap[14] = [0, 0, 142]
72+
colormap[15] = [0, 0, 70]
73+
colormap[16] = [0, 60, 100]
74+
colormap[17] = [0, 80, 100]
75+
colormap[18] = [0, 0, 230]
76+
colormap[19] = [119, 11, 32]
77+
else:
78+
colormap[0] = [0, 0, 0]
79+
colormap[1] = [128, 64, 128]
80+
colormap[2] = [70, 70, 70]
81+
colormap[3] = [250, 170, 30]
82+
colormap[4] = [107, 142, 35]
83+
colormap[5] = [70, 130, 180]
84+
colormap[6] = [220, 20, 60]
85+
colormap[7] = [0, 0, 142]
86+
87+
matplotlib.colors.ListedColormap(colormap)
88+
89+
5090
# Cityscapes 데이터셋 설정
5191
def init_cityscapes_dataset(config: dict):
5292
transform = torchvision.transforms.Compose([

0 commit comments

Comments
 (0)