@@ -26,6 +26,7 @@ def load_config():
2626
2727def get_model (config : dict , pretrained = False ) -> torch .nn .Module :
2828 assert isinstance (pretrained , bool )
29+ assert config ['dataset' ]['num_classes' ] == 20 or config ['dataset' ]['num_classes' ] == 8
2930
3031 if config ['model' ] == 'UNet' :
3132 model = models .unet .UNet (config ['dataset' ]['num_classes' ])
@@ -81,6 +82,7 @@ def __init__(self, config: dict):
8182
8283 self .class_names_short = ['unlabeled' , 'flat' , 'construction' , 'object' ,
8384 'nature' , 'sky' , 'human' , 'vehicle' ]
85+ self .num_classes = self .config ['dataset' ]['num_classes' ]
8486
8587 self .transform = torchvision .transforms .Compose ([
8688 torchvision .transforms .Resize (self .config ['dataset' ]['image_size' ]),
@@ -124,9 +126,8 @@ def set_cityscapes(self):
124126 return trainset , trainloader , testset , testloader
125127
126128 # Cityscapes 데이터셋 라벨 색상 불러오기
127- def get_cityscapes_colormap (self , short = False ):
128- assert isinstance (short , bool )
129- if not short :
129+ def get_cityscapes_colormap (self ):
130+ if self .num_classes == 20 :
130131 colormap = np .zeros ((20 , 3 ), dtype = np .uint8 )
131132 colormap [0 ] = [0 , 0 , 0 ]
132133 colormap [1 ] = [128 , 64 , 128 ]
@@ -148,7 +149,7 @@ def get_cityscapes_colormap(self, short=False):
148149 colormap [17 ] = [0 , 80 , 100 ]
149150 colormap [18 ] = [0 , 0 , 230 ]
150151 colormap [19 ] = [119 , 11 , 32 ]
151- else :
152+ elif self . num_classes == 8 :
152153 colormap = np .zeros ((8 , 3 ), dtype = np .uint8 )
153154 colormap [0 ] = [0 , 0 , 0 ]
154155 colormap [1 ] = [128 , 64 , 128 ]
@@ -158,6 +159,8 @@ def get_cityscapes_colormap(self, short=False):
158159 colormap [5 ] = [70 , 130 , 180 ]
159160 colormap [6 ] = [220 , 20 , 60 ]
160161 colormap [7 ] = [0 , 0 , 142 ]
162+ else :
163+ raise ValueError ('Wrong num_classes.' )
161164
162165 return np .divide (colormap , 255 ).tolist ()
163166
0 commit comments