Skip to content

Commit 50e2080

Browse files
committed
models: Implement load_backbone function
1 parent 4cd4e78 commit 50e2080

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

models/backbone.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import torch.nn as nn
24
import torch.nn.functional as F
35
import torch.utils.tensorboard
@@ -51,6 +53,16 @@ def make_initial_conv(self, in_channels: int, out_channels: int):
5153
)
5254

5355

56+
def load_backbone(num_classes: int, pretrained=False):
57+
model = Backbone(num_classes)
58+
if pretrained:
59+
if os.path.exists('weights/Backbone_best.pth'):
60+
model.load_state_dict(torch.load('weights/Backbone_best.pth'))
61+
else:
62+
print('FileNotFound: pretrained_weights (Backbone)')
63+
return model
64+
65+
5466
if __name__ == '__main__':
5567
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5668
model = Backbone(8).to(device)

models/proposed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.utils.tensorboard
55
import torchsummary
66

7-
import utils
7+
import models.backbone
88

99

1010
# ASPP(Atrous Spatial Pyramid Pooling) Module
@@ -59,7 +59,7 @@ class Proposed(nn.Module):
5959
def __init__(self, num_classes: int):
6060
super(Proposed, self).__init__()
6161
# Backbone
62-
backbone = utils.get_model('Backbone', num_classes, 'weights/Backbone_best.pth')
62+
backbone = models.backbone.load_backbone(num_classes, pretrained=True)
6363
self.initial_conv = backbone.initial_conv
6464
self.encode1 = backbone.layer1 # 64
6565
self.encode2 = backbone.layer2 # 128, 1/2

0 commit comments

Comments
 (0)