@@ -64,7 +64,7 @@ def __init__(self, num_classes: int):
6464 super (Proposed , self ).__init__ ()
6565 resnet34 = torchvision .models .resnet34 (pretrained = True )
6666
67- self .initial_conv = self .double_conv (3 , 64 , batch_norm = False )
67+ self .initial_conv = self .double_conv (3 , 64 )
6868 self .encode1 = resnet34 .layer1 # 64
6969 self .encode2 = resnet34 .layer2 # 128
7070 self .encode3 = resnet34 .layer3 # 256
@@ -82,23 +82,15 @@ def __init__(self, num_classes: int):
8282
8383 self .classifier = nn .Conv2d (64 , num_classes , kernel_size = 1 )
8484
85- def double_conv (self , in_channels : int , out_channels : int , batch_norm = True ):
86- if batch_norm :
87- return nn .Sequential (
88- nn .Conv2d (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
89- nn .BatchNorm2d (out_channels ),
90- nn .ReLU (inplace = True ),
91- nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
92- nn .BatchNorm2d (out_channels ),
93- nn .ReLU (inplace = True )
94- )
95- else :
96- return nn .Sequential (
97- nn .Conv2d (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
98- nn .ReLU (inplace = True ),
99- nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
100- nn .ReLU (inplace = True )
101- )
85+ def double_conv (self , in_channels : int , out_channels : int ):
86+ return nn .Sequential (
87+ nn .Conv2d (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
88+ nn .BatchNorm2d (out_channels ),
89+ nn .ReLU (inplace = True ),
90+ nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
91+ nn .BatchNorm2d (out_channels ),
92+ nn .ReLU (inplace = True )
93+ )
10294
10395 def forward (self , x ):
10496 # Encoder
0 commit comments