|
7 | 7 | import models.backbone |
8 | 8 |
|
9 | 9 |
|
# ASPP (Atrous Spatial Pyramid Pooling) module: runs the input through four
# parallel convolutional branches with increasing dilation plus a global
# image-pooling branch, then fuses the five feature maps with a 1x1 conv.
class ASPP(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        """Build the five parallel branches and the fusion convolution.

        Args:
            in_channels: channels of the incoming feature map.
            out_channels: channels produced by every branch; the fused
                output also has this many channels.
        """
        super(ASPP, self).__init__()

        def atrous_branch(rate: int) -> nn.Sequential:
            # 3x3 conv whose padding equals its dilation, so the spatial
            # size of the feature map is preserved.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                          padding=rate, dilation=rate),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        # Plain 1x1 projection (no dilation).
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Atrous branches with growing receptive fields.
        self.branch2 = atrous_branch(3)
        self.branch3 = atrous_branch(6)
        self.branch4 = atrous_branch(9)
        # Image-level context: global average pool followed by a 1x1 conv.
        self.branch5 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Fuses the concatenated branch outputs back down to out_channels.
        self.final_conv = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply every branch to ``x`` and return the fused feature map."""
        height, width = x.size()[2], x.size()[3]
        features = [self.branch1(x), self.branch2(x),
                    self.branch3(x), self.branch4(x)]
        # The pooled branch collapses to 1x1 spatially; upsample it back
        # to the input's spatial size before concatenation.
        pooled = F.interpolate(self.branch5(x), size=(height, width),
                               mode="bilinear", align_corners=False)
        features.append(pooled)
        return self.final_conv(torch.cat(features, dim=1))
57 | | - |
58 | 10 | class Proposed(nn.Module): |
59 | 11 | def __init__(self, num_classes: int): |
60 | 12 | super(Proposed, self).__init__() |
@@ -115,6 +67,54 @@ def make_channel_adjuster(self, in_channels: int, out_channels: int): |
115 | 67 | ) |
116 | 68 |
|
117 | 69 |
|
# ASPP (Atrous Spatial Pyramid Pooling) module: runs the input through four
# parallel convolutional branches with increasing dilation plus a global
# image-pooling branch, then fuses the five feature maps with a 1x1 conv.
class ASPP(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        """Build the five parallel branches and the fusion convolution.

        Args:
            in_channels: channels of the incoming feature map.
            out_channels: channels produced by every branch; the fused
                output also has this many channels.
        """
        super(ASPP, self).__init__()

        def atrous_branch(rate: int) -> nn.Sequential:
            # 3x3 conv whose padding equals its dilation, so the spatial
            # size of the feature map is preserved.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3,
                          padding=rate, dilation=rate),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        # Plain 1x1 projection (no dilation).
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Atrous branches with growing receptive fields.
        self.branch2 = atrous_branch(3)
        self.branch3 = atrous_branch(6)
        self.branch4 = atrous_branch(9)
        # Image-level context: global average pool followed by a 1x1 conv.
        self.branch5 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Fuses the concatenated branch outputs back down to out_channels.
        self.final_conv = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply every branch to ``x`` and return the fused feature map."""
        height, width = x.size()[2], x.size()[3]
        features = [self.branch1(x), self.branch2(x),
                    self.branch3(x), self.branch4(x)]
        # The pooled branch collapses to 1x1 spatially; upsample it back
        # to the input's spatial size before concatenation.
        pooled = F.interpolate(self.branch5(x), size=(height, width),
                               mode="bilinear", align_corners=False)
        features.append(pooled)
        return self.final_conv(torch.cat(features, dim=1))
118 | 118 | if __name__ == '__main__': |
119 | 119 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
120 | 120 | model = Proposed(20).to(device) |
|
0 commit comments