Skip to content

Commit 39edf98

Browse files
committed
proposed: Refactor to improve readability
1 parent 444bd45 commit 39edf98

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

models/proposed.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self, in_channels: int, out_channels: int):
3535
nn.BatchNorm2d(out_channels),
3636
nn.ReLU(inplace=True)
3737
)
38-
# Branch 5 = AdaptiveAvgPool2d → 1x1 convolution → BatchNorm → ReLU
38+
# Branch 5 = Global Average Pooling → 1x1 convolution → BatchNorm → ReLU
3939
self.branch5 = nn.Sequential(
4040
nn.AdaptiveAvgPool2d(1),
4141
nn.Conv2d(in_channels, out_channels, kernel_size=1),
@@ -62,15 +62,18 @@ def forward(self, x):
6262
class Proposed(nn.Module):
6363
def __init__(self, num_classes: int):
6464
super(Proposed, self).__init__()
65+
# Backbone
6566
resnet34 = torchvision.models.resnet34(pretrained=True)
66-
6767
self.initial_conv = self.double_conv(3, 64)
6868
self.encode1 = resnet34.layer1 # 64
69-
self.encode2 = resnet34.layer2 # 128
70-
self.encode3 = resnet34.layer3 # 256
71-
self.encode4 = resnet34.layer4 # 512
69+
self.encode2 = resnet34.layer2 # 128, 1/2
70+
self.encode3 = resnet34.layer3 # 256, 1/4
71+
self.encode4 = resnet34.layer4 # 512, 1/8
72+
73+
# ASPP
7274
self.aspp = ASPP(512, 512)
7375

76+
# Decoder
7477
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
7578
self.decode3 = self.double_conv(512, 256)
7679

@@ -80,18 +83,9 @@ def __init__(self, num_classes: int):
8083
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
8184
self.decode1 = self.double_conv(128, 64)
8285

86+
# Classifier
8387
self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)
8488

85-
def double_conv(self, in_channels: int, out_channels: int):
86-
return nn.Sequential(
87-
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
88-
nn.BatchNorm2d(out_channels),
89-
nn.ReLU(inplace=True),
90-
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
91-
nn.BatchNorm2d(out_channels),
92-
nn.ReLU(inplace=True)
93-
)
94-
9589
def forward(self, x):
9690
# Encoder
9791
encode1 = self.encode1(self.initial_conv(x))
@@ -108,6 +102,22 @@ def forward(self, x):
108102
out = self.classifier(out)
109103
return out
110104

105+
def double_conv(self, in_channels: int, out_channels: int):
    """Build a double 3x3-convolution block: (Conv3x3 -> BatchNorm -> ReLU) x 2.

    Spatial size is preserved (stride=1, padding=1); only the channel
    count changes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.

    Returns:
        nn.Sequential containing the six layers in order.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
114+
115+
def make_channel_adjuster(self, in_channels: int, out_channels: int):
    """Build a channel-adjustment block of two stacked 1x1 convolutions.

    The first 1x1 convolution maps ``in_channels`` -> ``out_channels``;
    the second keeps ``out_channels`` unchanged. Spatial size is
    preserved (kernel_size=1, no padding needed).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Desired number of output channels.

    Returns:
        nn.Sequential of the two 1x1 convolutions.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=1),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
    )
120+
111121

112122
if __name__ == '__main__':
113123
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

0 commit comments

Comments
 (0)