Skip to content

Commit 919b0e2

Browse files
committed
models: Remove hard coding backbone weights path
1 parent 0a71fd7 commit 919b0e2

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

models/backbone.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def make_initial_conv(self, in_channels: int, out_channels: int):
4747
)
4848

4949

50-
def load_backbone(num_classes: int, pretrained=False):
50+
def load_backbone(num_classes: int, pretrained_weights: str = None):
5151
model = Backbone(num_classes)
52-
if pretrained:
53-
if os.path.exists('weights/Backbone_val_best.pth'):
54-
model.load_state_dict(torch.load('weights/Backbone_val_best.pth'))
52+
if pretrained_weights is not None:
53+
if os.path.exists(pretrained_weights):
54+
model.load_state_dict(torch.load(pretrained_weights))
5555
else:
5656
print('FileNotFound: pretrained_weights (Backbone)')
5757
return model

models/proposed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99

1010
class Proposed(nn.Module):
11-
def __init__(self, num_classes: int):
11+
def __init__(self, num_classes: int, backbone_pretrained_weights: str = None):
1212
super(Proposed, self).__init__()
1313
# Backbone
14-
backbone = models.backbone.load_backbone(num_classes, pretrained=True)
14+
backbone = models.backbone.load_backbone(num_classes, backbone_pretrained_weights)
1515
self.initial_conv = backbone.initial_conv
1616
self.encode1 = backbone.layer1 # 64
1717
self.encode2 = backbone.layer2 # 128, 1/2

utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
3333
elif config['model'] == 'Backbone':
3434
model = models.backbone.Backbone(config['dataset']['num_classes'])
3535
elif config['model'] == 'Proposed':
36-
model = models.proposed.Proposed(config['dataset']['num_classes'])
36+
model = models.proposed.Proposed(config['dataset']['num_classes'], config['Backbone']['pretrained_weights'])
3737
else:
3838
raise NameError('Wrong model name.')
3939

0 commit comments

Comments
 (0)