@@ -9,12 +9,14 @@ class UNet(nn.Module):
99 def __init__ (self , num_classes : int ):
1010 super (UNet , self ).__init__ ()
1111
12+ # Encoder
1213 self .encode1 = self .double_conv (3 , 64 )
1314 self .encode2 = self .double_conv (64 , 128 )
1415 self .encode3 = self .double_conv (128 , 256 )
1516 self .encode4 = self .double_conv (256 , 512 )
1617 self .encode_end = self .double_conv (512 , 1024 )
1718
19+ # Decoder
1820 self .upconv4 = nn .ConvTranspose2d (1024 , 512 , kernel_size = 2 , stride = 2 )
1921 self .decode4 = self .double_conv (1024 , 512 )
2022
@@ -27,16 +29,9 @@ def __init__(self, num_classes: int):
2729 self .upconv1 = nn .ConvTranspose2d (128 , 64 , kernel_size = 2 , stride = 2 )
2830 self .decode1 = self .double_conv (128 , 64 )
2931
32+ # Classifier
3033 self .classifier = nn .Conv2d (64 , num_classes , kernel_size = 1 )
3134
32- def double_conv (self , in_channels : int , out_channels : int ):
33- return nn .Sequential (
34- nn .Conv2d (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
35- nn .ReLU (inplace = True ),
36- nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
37- nn .ReLU (inplace = True )
38- )
39-
4035 def forward (self , x ):
4136 # Encoder
4237 encode1 = self .encode1 (x )
@@ -55,6 +50,14 @@ def forward(self, x):
5550 out = self .classifier (out )
5651 return out
5752
53+ def double_conv (self , in_channels : int , out_channels : int ):
54+ return nn .Sequential (
55+ nn .Conv2d (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
56+ nn .ReLU (inplace = True ),
57+ nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 ),
58+ nn .ReLU (inplace = True )
59+ )
60+
5861
5962if __name__ == '__main__' :
6063 device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
0 commit comments