Skip to content

Commit 308400c

Browse files
committed
utils: Fix assert statement
1 parent 5edcbb3 commit 308400c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def load_config():
3838

3939
# 모델 불러오기
4040
def get_model(model_name: str, num_classes: int, pretrained: str = None) -> torch.nn.Module:
41-
assert isinstance(model_name, str) and isinstance(num_classes, int) and isinstance(pretrained, str)
41+
assert isinstance(model_name, str) and isinstance(num_classes, int)
4242

4343
if model_name == 'UNet':
4444
model = models.unet.UNet(num_classes)
@@ -50,6 +50,7 @@ def get_model(model_name: str, num_classes: int, pretrained: str = None) -> torc
5050
raise NameError('Wrong model_name.')
5151

5252
if pretrained is not None:
53+
assert isinstance(pretrained, str)
5354
if os.path.exists(pretrained):
5455
model.load_state_dict(torch.load(pretrained))
5556
else:

0 commit comments

Comments
 (0)