Skip to content

Commit 444bd45

Browse files
committed
unet: Refactor to Improve readability
1 parent aec199c commit 444bd45

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

models/unet.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ class UNet(nn.Module):
99
def __init__(self, num_classes: int):
1010
super(UNet, self).__init__()
1111

12+
# Encoder
1213
self.encode1 = self.double_conv(3, 64)
1314
self.encode2 = self.double_conv(64, 128)
1415
self.encode3 = self.double_conv(128, 256)
1516
self.encode4 = self.double_conv(256, 512)
1617
self.encode_end = self.double_conv(512, 1024)
1718

19+
# Decoder
1820
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
1921
self.decode4 = self.double_conv(1024, 512)
2022

@@ -27,16 +29,9 @@ def __init__(self, num_classes: int):
2729
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
2830
self.decode1 = self.double_conv(128, 64)
2931

32+
# Classifier
3033
self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)
3134

32-
def double_conv(self, in_channels: int, out_channels: int):
33-
return nn.Sequential(
34-
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
35-
nn.ReLU(inplace=True),
36-
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
37-
nn.ReLU(inplace=True)
38-
)
39-
4035
def forward(self, x):
4136
# Encoder
4237
encode1 = self.encode1(x)
@@ -55,6 +50,14 @@ def forward(self, x):
5550
out = self.classifier(out)
5651
return out
5752

53+
def double_conv(self, in_channels: int, out_channels: int):
54+
return nn.Sequential(
55+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
56+
nn.ReLU(inplace=True),
57+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
58+
nn.ReLU(inplace=True)
59+
)
60+
5861

5962
if __name__ == '__main__':
6063
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

0 commit comments

Comments
 (0)