Skip to content

Commit 8c16008

Browse files
committed
proposed: 모델 수정
1. 인코더 깊이를 낮춤
1 parent 1755fed commit 8c16008

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

models/proposed.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,18 @@ def __init__(self, num_classes: int):
8686
super(Proposed, self).__init__()
8787
resnet34 = torchvision.models.resnet34(pretrained=True)
8888

89-
self.encode1 = self.double_conv(3, 64)
90-
self.encode2 = resnet34.layer1 # 64
91-
self.encode3 = resnet34.layer2 # 128
92-
self.encode4 = resnet34.layer3 # 256
93-
self.encode_end = resnet34.layer4 # 512
94-
self.aspp = ASPP(512, 1024)
95-
96-
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
97-
self.decode4 = self.double_conv(512 + 256, 512)
89+
self.encode0 = self.double_conv(3, 64)
90+
self.encode1 = resnet34.layer1 # 64
91+
self.encode2 = resnet34.layer2 # 128
92+
self.encode3 = resnet34.layer3 # 256
93+
self.encode4 = resnet34.layer4 # 512
94+
self.encode_end = ASPP(512, 512)
9895

9996
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
100-
self.decode3 = self.double_conv(256 + 128, 256)
97+
self.decode3 = self.double_conv(512, 256)
10198

10299
self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
103-
self.decode2 = self.double_conv(128 + 64, 128)
100+
self.decode2 = self.double_conv(256, 128)
104101

105102
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
106103
self.decode1 = self.double_conv(128, 64)
@@ -125,16 +122,14 @@ def make_layer(self, in_channels, out_channels, num_blocks):
125122

126123
def forward(self, x):
127124
# Encoder
128-
encode1 = self.encode1(x)
129-
encode2 = self.encode2(F.max_pool2d(encode1, 2))
125+
encode1 = self.encode1(self.encode0(x))
126+
encode2 = self.encode2(encode1)
130127
encode3 = self.encode3(encode2)
131128
encode4 = self.encode4(encode3)
132129
encode_end = self.encode_end(encode4)
133-
encode_end = self.aspp(encode_end)
134130

135131
# Decoder
136-
out = self.decode4(torch.cat([self.upconv4(encode_end), encode4], dim=1))
137-
out = self.decode3(torch.cat([self.upconv3(out), encode3], dim=1))
132+
out = self.decode3(torch.cat([self.upconv3(encode_end), encode3], dim=1))
138133
out = self.decode2(torch.cat([self.upconv2(out), encode2], dim=1))
139134
out = self.decode1(torch.cat([self.upconv1(out), encode1], dim=1))
140135

0 commit comments

Comments
 (0)