@@ -86,21 +86,18 @@ def __init__(self, num_classes: int):
8686 super (Proposed , self ).__init__ ()
8787 resnet34 = torchvision .models .resnet34 (pretrained = True )
8888
89- self .encode1 = self .double_conv (3 , 64 )
90- self .encode2 = resnet34 .layer1 # 64
91- self .encode3 = resnet34 .layer2 # 128
92- self .encode4 = resnet34 .layer3 # 256
93- self .encode_end = resnet34 .layer4 # 512
94- self .aspp = ASPP (512 , 1024 )
95-
96- self .upconv4 = nn .ConvTranspose2d (1024 , 512 , kernel_size = 2 , stride = 2 )
97- self .decode4 = self .double_conv (512 + 256 , 512 )
89+ self .encode0 = self .double_conv (3 , 64 )
90+ self .encode1 = resnet34 .layer1 # 64
91+ self .encode2 = resnet34 .layer2 # 128
92+ self .encode3 = resnet34 .layer3 # 256
93+ self .encode4 = resnet34 .layer4 # 512
94+ self .encode_end = ASPP (512 , 512 )
9895
9996 self .upconv3 = nn .ConvTranspose2d (512 , 256 , kernel_size = 2 , stride = 2 )
100- self .decode3 = self .double_conv (256 + 128 , 256 )
97+ self .decode3 = self .double_conv (512 , 256 )
10198
10299 self .upconv2 = nn .ConvTranspose2d (256 , 128 , kernel_size = 2 , stride = 2 )
103- self .decode2 = self .double_conv (128 + 64 , 128 )
100+ self .decode2 = self .double_conv (256 , 128 )
104101
105102 self .upconv1 = nn .ConvTranspose2d (128 , 64 , kernel_size = 2 , stride = 2 )
106103 self .decode1 = self .double_conv (128 , 64 )
@@ -125,16 +122,14 @@ def make_layer(self, in_channels, out_channels, num_blocks):
125122
126123 def forward (self , x ):
127124 # Encoder
128- encode1 = self .encode1 (x )
129- encode2 = self .encode2 (F . max_pool2d ( encode1 , 2 ) )
125+ encode1 = self .encode1 (self . encode0 ( x ) )
126+ encode2 = self .encode2 (encode1 )
130127 encode3 = self .encode3 (encode2 )
131128 encode4 = self .encode4 (encode3 )
132129 encode_end = self .encode_end (encode4 )
133- encode_end = self .aspp (encode_end )
134130
135131 # Decoder
136- out = self .decode4 (torch .cat ([self .upconv4 (encode_end ), encode4 ], dim = 1 ))
137- out = self .decode3 (torch .cat ([self .upconv3 (out ), encode3 ], dim = 1 ))
132+ out = self .decode3 (torch .cat ([self .upconv3 (encode_end ), encode3 ], dim = 1 ))
138133 out = self .decode2 (torch .cat ([self .upconv2 (out ), encode2 ], dim = 1 ))
139134 out = self .decode1 (torch .cat ([self .upconv1 (out ), encode1 ], dim = 1 ))
140135
0 commit comments