@@ -35,7 +35,7 @@ def __init__(self, in_channels: int, out_channels: int):
3535 nn .BatchNorm2d (out_channels ),
3636 nn .ReLU (inplace = True )
3737 )
38- # 5번 branch = AdaptiveAvgPool2d → 1x1 convolution → BatchNorm → ReLu
38+ # 5번 branch = Global Average Pooling → 1x1 convolution → BatchNorm → ReLu
3939 self .branch5 = nn .Sequential (
4040 nn .AdaptiveAvgPool2d (1 ),
4141 nn .Conv2d (in_channels , out_channels , kernel_size = 1 ),
@@ -62,15 +62,18 @@ def forward(self, x):
6262class Proposed (nn .Module ):
6363 def __init__ (self , num_classes : int ):
6464 super (Proposed , self ).__init__ ()
65+ # Backbone
6566 resnet34 = torchvision .models .resnet34 (pretrained = True )
66-
6767 self .initial_conv = self .double_conv (3 , 64 )
6868 self .encode1 = resnet34 .layer1 # 64
69- self .encode2 = resnet34 .layer2 # 128
70- self .encode3 = resnet34 .layer3 # 256
71- self .encode4 = resnet34 .layer4 # 512
69+ self .encode2 = resnet34 .layer2 # 128, 1/2
70+ self .encode3 = resnet34 .layer3 # 256, 1/4
71+ self .encode4 = resnet34 .layer4 # 512, 1/8
72+
73+ # ASPP
7274 self .aspp = ASPP (512 , 512 )
7375
76+ # Decoder
7477 self .upconv3 = nn .ConvTranspose2d (512 , 256 , kernel_size = 2 , stride = 2 )
7578 self .decode3 = self .double_conv (512 , 256 )
7679
@@ -80,18 +83,9 @@ def __init__(self, num_classes: int):
8083 self .upconv1 = nn .ConvTranspose2d (128 , 64 , kernel_size = 2 , stride = 2 )
8184 self .decode1 = self .double_conv (128 , 64 )
8285
86+ # Classifier
8387 self .classifier = nn .Conv2d (64 , num_classes , kernel_size = 1 )
8488
85- def double_conv (self , in_channels : int , out_channels : int ):
86- return nn .Sequential (
87- nn .Conv2d (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
88- nn .BatchNorm2d (out_channels ),
89- nn .ReLU (inplace = True ),
90- nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
91- nn .BatchNorm2d (out_channels ),
92- nn .ReLU (inplace = True )
93- )
94-
9589 def forward (self , x ):
9690 # Encoder
9791 encode1 = self .encode1 (self .initial_conv (x ))
@@ -108,6 +102,22 @@ def forward(self, x):
108102 out = self .classifier (out )
109103 return out
110104
105+ def double_conv (self , in_channels : int , out_channels : int ):
106+ return nn .Sequential (
107+ nn .Conv2d (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
108+ nn .BatchNorm2d (out_channels ),
109+ nn .ReLU (inplace = True ),
110+ nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
111+ nn .BatchNorm2d (out_channels ),
112+ nn .ReLU (inplace = True )
113+ )
114+
115+ def make_channel_adjuster (self , in_channels : int , out_channels : int ):
116+ return nn .Sequential (
117+ nn .Conv2d (in_channels , out_channels , kernel_size = 1 ),
118+ nn .Conv2d (out_channels , out_channels , kernel_size = 1 )
119+ )
120+
111121
112122if __name__ == '__main__' :
113123 device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
0 commit comments