Skip to content

Commit 4436683

Browse files
committed
utils: Add alert to get_model function
1 parent b1d0a1c commit 4436683

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

utils/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@ def get_model(model_name: str, num_classes: int, pretrained: str = None) -> torc
3939
else:
4040
raise NameError('Wrong model_name.')
4141

42-
if pretrained is not None and os.path.exists(pretrained):
43-
model.load_state_dict(torch.load(pretrained))
42+
if pretrained is not None:
43+
if os.path.exists(pretrained):
44+
model.load_state_dict(torch.load(pretrained))
45+
else:
46+
print('FileNotFound: pretrained_weights')
4447
return model
4548

4649

0 commit comments

Comments
 (0)