Skip to content

Commit e9ecb55

Browse files
committed
proposed: Remove BN layers on initial conv
1 parent e47b7b4 commit e9ecb55

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

models/proposed.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(self, num_classes: int):
8686
super(Proposed, self).__init__()
8787
resnet34 = torchvision.models.resnet34(pretrained=True)
8888

89-
self.initial_conv = self.double_conv(3, 64)
89+
self.initial_conv = self.double_conv(3, 64, batch_norm=False)
9090
self.encode1 = resnet34.layer1 # 64
9191
self.encode2 = resnet34.layer2 # 128
9292
self.encode3 = resnet34.layer3 # 256
@@ -104,15 +104,23 @@ def __init__(self, num_classes: int):
104104

105105
self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)
106106

107-
def double_conv(self, in_channels: int, out_channels: int):
108-
return nn.Sequential(
109-
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
110-
nn.BatchNorm2d(out_channels),
111-
nn.ReLU(inplace=True),
112-
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
113-
nn.BatchNorm2d(out_channels),
114-
nn.ReLU(inplace=True)
115-
)
107+
def double_conv(self, in_channels: int, out_channels: int, batch_norm=True):
108+
if batch_norm:
109+
return nn.Sequential(
110+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
111+
nn.BatchNorm2d(out_channels),
112+
nn.ReLU(inplace=True),
113+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
114+
nn.BatchNorm2d(out_channels),
115+
nn.ReLU(inplace=True)
116+
)
117+
else:
118+
return nn.Sequential(
119+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
120+
nn.ReLU(inplace=True),
121+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
122+
nn.ReLU(inplace=True)
123+
)
116124

117125
def make_layer(self, in_channels, out_channels, num_blocks):
118126
layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)]

0 commit comments

Comments
 (0)