Skip to content

Commit eaeb496

Browse files
committed
update conv/deconv
1 parent 9a0901d commit eaeb496

File tree

3 files changed

+71
-4
lines changed

3 files changed

+71
-4
lines changed

pymic/layer/convolution.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
2828
self.bn = nn.modules.BatchNorm2d(out_channels)
2929
elif(self.norm_type == 'group_norm'):
3030
self.bn = nn.GroupNorm(self.norm_group, out_channels)
31-
else:
31+
elif(self.norm_type is not None):
3232
raise ValueError("unsupported normalization method {0:}".format(norm_type))
3333
else:
3434
self.conv = nn.Conv3d(in_channels, out_channels,
@@ -37,7 +37,7 @@ def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
3737
self.bn = nn.modules.BatchNorm3d(out_channels)
3838
elif(self.norm_type == 'group_norm'):
3939
self.bn = nn.GroupNorm(self.norm_group, out_channels)
40-
else:
40+
elif(self.norm_type is not None):
4141
raise ValueError("unsupported normalization method {0:}".format(norm_type))
4242

4343
def forward(self, x):
@@ -74,7 +74,7 @@ def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
7474
self.bn = nn.modules.BatchNorm2d(out_channels)
7575
elif(self.norm_type == 'group_norm'):
7676
self.bn = nn.GroupNorm(self.norm_group, out_channels)
77-
else:
77+
elif(self.norm_type is not None):
7878
raise ValueError("unsupported normalization method {0:}".format(norm_type))
7979
else:
8080
self.conv = nn.Conv3d(in_channels, in_channels,
@@ -85,7 +85,7 @@ def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
8585
self.bn = nn.modules.BatchNorm3d(out_channels)
8686
elif(self.norm_type == 'group_norm'):
8787
self.bn = nn.GroupNorm(self.norm_group, out_channels)
88-
else:
88+
elif(self.norm_type is not None):
8989
raise ValueError("unsupported normalization method {0:}".format(norm_type))
9090

9191
def forward(self, x):

pymic/layer/deconvolution.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,49 @@ def forward(self, x):
4141
if(self.acti_func is not None):
4242
f = self.acti_func(f)
4343
return f
44+
45+
class DepthSeperableDeconvolutionLayer(nn.Module):
46+
"""
47+
A compose layer with the following components:
48+
convolution -> (batch_norm) -> activation -> (dropout)
49+
batch norm and dropout are optional
50+
"""
51+
def __init__(self, in_channels, out_channels, kernel_size,
52+
dim = 3, stride = 1, padding = 0, output_padding = 0,
53+
dilation =1, groups = 1, bias = True,
54+
batch_norm = True, acti_func = None):
55+
super(DepthSeperableDeconvolutionLayer, self).__init__()
56+
self.n_in_chns = in_channels
57+
self.n_out_chns = out_channels
58+
self.batch_norm = batch_norm
59+
self.acti_func = acti_func
60+
self.groups = groups
61+
assert(dim == 2 or dim == 3)
62+
if(dim == 2):
63+
self.conv1x1 = nn.Conv2d(in_channels, out_channels,
64+
kernel_size = 1, stride = 1, padding = 0, dilation = dilation,
65+
groups = self.groups, bias = bias)
66+
self.conv = nn.ConvTranspose2d(out_channels, out_channels,
67+
kernel_size, stride, padding, output_padding,
68+
groups = out_channels, bias = bias, dilation = dilation)
69+
70+
if(self.batch_norm):
71+
self.bn = nn.modules.BatchNorm2d(out_channels)
72+
else:
73+
self.conv1x1 = nn.Conv3d(in_channels, out_channels,
74+
kernel_size = 1, stride = 1, padding = 0, dilation = dilation,
75+
groups = self.groups, bias = bias)
76+
self.conv = nn.ConvTranspose3d(out_channels, out_channels,
77+
kernel_size, stride, padding, output_padding,
78+
groups = out_channels, bias = bias, dilation = dilation)
79+
if(self.batch_norm):
80+
self.bn = nn.modules.BatchNorm3d(out_channels)
81+
82+
def forward(self, x):
83+
f = self.conv1x1(x)
84+
f = self.conv(f)
85+
if(self.batch_norm):
86+
f = self.bn(f)
87+
if(self.acti_func is not None):
88+
f = self.acti_func(f)
89+
return f

pymic/util/rename_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torch
2+
3+
4+
def rename_model_variable(input_file, output_file, input_var_list, output_var_list):
5+
assert(len(input_var_list) == len(output_var_list))
6+
checkpoint = torch.load(input_file)
7+
state_dict = checkpoint['model_state_dict']
8+
for i in range(len(input_var_list)):
9+
input_var = input_var_list[i]
10+
output_var = output_var_list[i]
11+
state_dict[output_var] = state_dict[input_var]
12+
state_dict.pop(input_var)
13+
checkpoint['model_state_dict'] = state_dict
14+
torch.save(checkpoint, output_file)
15+
16+
if __name__ == "__main__":
17+
input_file = '/home/disk2t/projects/dlls/training_fetal_brain/model2/unet2dres/model_15000.pt'
18+
output_file = '/home/disk2t/projects/dlls/training_fetal_brain/model2/unet2dres/model_15000_rename.pt'
19+
input_var_list = ['conv.weight', 'conv.bias']
20+
output_var_list= ['conv9.weight', 'conv9.bias']
21+
rename_model_variable(input_file, output_file, input_var_list, output_var_list)

0 commit comments

Comments
 (0)