Skip to content

Commit 9281c97

Browse files
committed
Proposed: Add BN to initial conv
1 parent c176533 commit 9281c97

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

models/proposed.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, num_classes: int):
6464
super(Proposed, self).__init__()
6565
resnet34 = torchvision.models.resnet34(pretrained=True)
6666

67-
self.initial_conv = self.double_conv(3, 64, batch_norm=False)
67+
self.initial_conv = self.double_conv(3, 64)
6868
self.encode1 = resnet34.layer1 # 64
6969
self.encode2 = resnet34.layer2 # 128
7070
self.encode3 = resnet34.layer3 # 256
@@ -82,23 +82,15 @@ def __init__(self, num_classes: int):
8282

8383
self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)
8484

85-
def double_conv(self, in_channels: int, out_channels: int, batch_norm=True):
86-
if batch_norm:
87-
return nn.Sequential(
88-
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
89-
nn.BatchNorm2d(out_channels),
90-
nn.ReLU(inplace=True),
91-
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
92-
nn.BatchNorm2d(out_channels),
93-
nn.ReLU(inplace=True)
94-
)
95-
else:
96-
return nn.Sequential(
97-
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
98-
nn.ReLU(inplace=True),
99-
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
100-
nn.ReLU(inplace=True)
101-
)
85+
def double_conv(self, in_channels: int, out_channels: int):
86+
return nn.Sequential(
87+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
88+
nn.BatchNorm2d(out_channels),
89+
nn.ReLU(inplace=True),
90+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
91+
nn.BatchNorm2d(out_channels),
92+
nn.ReLU(inplace=True)
93+
)
10294

10395
def forward(self, x):
10496
# Encoder

0 commit comments

Comments
 (0)