11import configparser
22import os
33
4+ import matplotlib .colors
45import matplotlib .pyplot as plt
6+ import numpy as np
57import torch
68import torch .utils .data
79import torchvision
@@ -31,6 +33,7 @@ def load_config():
3133 return model_name , config
3234
3335
36+ # 모델 불러오기
3437def get_model (model_name : str , num_classes : int , pretrained : str = None ) -> torch .nn .Module :
3538 if model_name == 'UNet' :
3639 model = models .unet .UNet (num_classes )
@@ -47,6 +50,43 @@ def get_model(model_name: str, num_classes: int, pretrained: str = None) -> torc
4750 return model
4851
4952
53+ # Cityscapes 데이터셋 라벨 색상 불러오기
54+ def get_cityscapes_label_colormap (short = False ):
55+ colormap = np .zeros ((20 , 3 ), dtype = np .uint8 )
56+ if not short :
57+ colormap [0 ] = [0 , 0 , 0 ]
58+ colormap [1 ] = [128 , 64 , 128 ]
59+ colormap [2 ] = [244 , 35 , 232 ]
60+ colormap [3 ] = [70 , 70 , 70 ]
61+ colormap [4 ] = [102 , 102 , 156 ]
62+ colormap [5 ] = [190 , 153 , 153 ]
63+ colormap [6 ] = [153 , 153 , 153 ]
64+ colormap [7 ] = [250 , 170 , 30 ]
65+ colormap [8 ] = [220 , 220 , 0 ]
66+ colormap [9 ] = [107 , 142 , 35 ]
67+ colormap [10 ] = [152 , 251 , 152 ]
68+ colormap [11 ] = [70 , 130 , 180 ]
69+ colormap [12 ] = [220 , 20 , 60 ]
70+ colormap [13 ] = [255 , 0 , 0 ]
71+ colormap [14 ] = [0 , 0 , 142 ]
72+ colormap [15 ] = [0 , 0 , 70 ]
73+ colormap [16 ] = [0 , 60 , 100 ]
74+ colormap [17 ] = [0 , 80 , 100 ]
75+ colormap [18 ] = [0 , 0 , 230 ]
76+ colormap [19 ] = [119 , 11 , 32 ]
77+ else :
78+ colormap [0 ] = [0 , 0 , 0 ]
79+ colormap [1 ] = [128 , 64 , 128 ]
80+ colormap [2 ] = [70 , 70 , 70 ]
81+ colormap [3 ] = [250 , 170 , 30 ]
82+ colormap [4 ] = [107 , 142 , 35 ]
83+ colormap [5 ] = [70 , 130 , 180 ]
84+ colormap [6 ] = [220 , 20 , 60 ]
85+ colormap [7 ] = [0 , 0 , 142 ]
86+
87+ matplotlib .colors .ListedColormap (colormap )
88+
89+
5090# Cityscapes 데이터셋 설정
5191def init_cityscapes_dataset (config : dict ):
5292 transform = torchvision .transforms .Compose ([
0 commit comments