import torch
import torch.nn as nn
class ConvolutionLayer(nn.Module):
    """
    A composite layer with the following components:
    convolution -> (batch_norm / group_norm) -> activation -> (dropout)
    normalization and dropout are optional
    """
    def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
            stride = 1, padding = 0, dilation = 1, conv_group = 1, bias = True,
            norm_type = 'batch_norm', norm_group = 1, acti_func = None):
        super(ConvolutionLayer, self).__init__()
        self.n_in_chns  = in_channels
        self.n_out_chns = out_channels
        self.norm_type  = norm_type
        self.norm_group = norm_group
        self.acti_func  = acti_func

        assert(dim == 2 or dim == 3)
        if(dim == 2):
            self.conv = nn.Conv2d(in_channels, out_channels,
                kernel_size, stride, padding, dilation, conv_group, bias)
            if(self.norm_type == 'batch_norm'):
                self.bn = nn.BatchNorm2d(out_channels)
            elif(self.norm_type == 'group_norm'):
                self.bn = nn.GroupNorm(self.norm_group, out_channels)
            # norm_type = None disables normalization, matching the check in forward()
            elif(self.norm_type is not None):
                raise ValueError("unsupported normalization method {0:}".format(norm_type))
        else:
            self.conv = nn.Conv3d(in_channels, out_channels,
                kernel_size, stride, padding, dilation, conv_group, bias)
            if(self.norm_type == 'batch_norm'):
                self.bn = nn.BatchNorm3d(out_channels)
            elif(self.norm_type == 'group_norm'):
                self.bn = nn.GroupNorm(self.norm_group, out_channels)
            elif(self.norm_type is not None):
                raise ValueError("unsupported normalization method {0:}".format(norm_type))

    def forward(self, x):
        f = self.conv(x)
        if(self.norm_type is not None):
            f = self.bn(f)
        if(self.acti_func is not None):
            f = self.acti_func(f)
        return f
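
# A minimal usage sketch for ConvolutionLayer; the shapes, channel counts and
# the LeakyReLU activation are illustrative assumptions, not from the original file.
def _demo_convolution_layer():
    layer = ConvolutionLayer(in_channels = 4, out_channels = 16, kernel_size = 3,
        dim = 3, padding = 1, norm_type = 'group_norm', norm_group = 4,
        acti_func = nn.LeakyReLU())
    x = torch.randn(2, 4, 32, 32, 32)   # (batch, channels, D, H, W)
    y = layer(x)                        # torch.Size([2, 16, 32, 32, 32])
    return y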

class DepthSeperableConvolutionLayer(nn.Module):
    """
    A composite layer with the following components:
    depthwise convolution -> pointwise (1x1) convolution
        -> (batch_norm / group_norm) -> activation -> (dropout)
    normalization and dropout are optional
    """
    def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
            stride = 1, padding = 0, dilation = 1, conv_group = 1, bias = True,
            norm_type = 'batch_norm', norm_group = 1, acti_func = None):
        super(DepthSeperableConvolutionLayer, self).__init__()
        self.n_in_chns  = in_channels
        self.n_out_chns = out_channels
        self.norm_type  = norm_type
        self.norm_group = norm_group
        self.acti_func  = acti_func

        assert(dim == 2 or dim == 3)
        if(dim == 2):
            # depthwise convolution: one spatial filter per input channel
            self.conv = nn.Conv2d(in_channels, in_channels,
                kernel_size, stride, padding, dilation, groups = in_channels, bias = bias)
            # pointwise convolution mixes channels; stride and dilation stay at 1
            # because the depthwise convolution above already applies them
            self.conv1x1 = nn.Conv2d(in_channels, out_channels,
                kernel_size = 1, stride = 1, padding = 0, dilation = 1,
                groups = conv_group, bias = bias)
            if(self.norm_type == 'batch_norm'):
                self.bn = nn.BatchNorm2d(out_channels)
            elif(self.norm_type == 'group_norm'):
                self.bn = nn.GroupNorm(self.norm_group, out_channels)
            elif(self.norm_type is not None):
                raise ValueError("unsupported normalization method {0:}".format(norm_type))
        else:
            self.conv = nn.Conv3d(in_channels, in_channels,
                kernel_size, stride, padding, dilation, groups = in_channels, bias = bias)
            self.conv1x1 = nn.Conv3d(in_channels, out_channels,
                kernel_size = 1, stride = 1, padding = 0, dilation = 1,
                groups = conv_group, bias = bias)
            if(self.norm_type == 'batch_norm'):
                self.bn = nn.BatchNorm3d(out_channels)
            elif(self.norm_type == 'group_norm'):
                self.bn = nn.GroupNorm(self.norm_group, out_channels)
            elif(self.norm_type is not None):
                raise ValueError("unsupported normalization method {0:}".format(norm_type))

    def forward(self, x):
        f = self.conv(x)
        f = self.conv1x1(f)
        if(self.norm_type is not None):
            f = self.bn(f)
        if(self.acti_func is not None):
            f = self.acti_func(f)
        return f
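
# The depthwise/pointwise factorization is what saves parameters: a standard
# convolution needs about in_channels * out_channels * k^d weights, while the
# depthwise plus pointwise pair needs in_channels * k^d + in_channels * out_channels.
# The demo below uses assumed shapes and activation, not from the original file.
if __name__ == '__main__':
    layer = DepthSeperableConvolutionLayer(in_channels = 16, out_channels = 32,
        kernel_size = 3, dim = 2, padding = 1, norm_type = 'batch_norm',
        acti_func = nn.ReLU())
    x = torch.randn(2, 16, 64, 64)
    y = layer(x)
    print(y.shape)   # torch.Size([2, 32, 64, 64])
    # weights (ignoring biases): 16*3*3 + 16*32 = 656, versus
    # 16*32*3*3 = 4608 for a standard 3x3 convolution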