Skip to content

Commit 44a9e8c

Browse files
committed
update networks
1 parent eaeb496 commit 44a9e8c

File tree

5 files changed

+80
-23
lines changed

5 files changed

+80
-23
lines changed

pymic/io/transform3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def __call__(self, sample):
530530
image= sample['image']
531531
for chn in range(image.shape[0]):
532532
mask = np.asarray(image[chn] > self.threshold[chn], image.dtype)
533-
image[chn] = mask * image[chn]
533+
image[chn] = mask * (image[chn] - self.threshold[chn])
534534

535535
sample['image'] = image
536536
return sample

pymic/layer/convolution.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,21 @@ def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
6666

6767
assert(dim == 2 or dim == 3)
6868
if(dim == 2):
69-
self.conv = nn.Conv2d(in_channels, in_channels,
70-
kernel_size, stride, padding, dilation, groups = in_channels, bias = bias)
7169
self.conv1x1 = nn.Conv2d(in_channels, out_channels,
7270
kernel_size = 1, stride = stride, padding = 0, dilation = dilation, groups = conv_group, bias = bias)
71+
self.conv = nn.Conv2d(out_channels, out_channels,
72+
kernel_size, stride, padding, dilation, groups = out_channels, bias = bias)
7373
if(self.norm_type == 'batch_norm'):
7474
self.bn = nn.modules.BatchNorm2d(out_channels)
7575
elif(self.norm_type == 'group_norm'):
7676
self.bn = nn.GroupNorm(self.norm_group, out_channels)
7777
elif(self.norm_type is not None):
7878
raise ValueError("unsupported normalization method {0:}".format(norm_type))
79-
else:
80-
self.conv = nn.Conv3d(in_channels, in_channels,
81-
kernel_size, stride, padding, dilation, groups = in_channels, bias = bias)
79+
else:
8280
self.conv1x1 = nn.Conv3d(in_channels, out_channels,
83-
kernel_size = 1, stride = 0, padding = 0, dilation = 0, groups = conv_group, bias = bias)
81+
kernel_size = 1, stride = stride, padding = 0, dilation = dilation, groups = conv_group, bias = bias)
82+
self.conv = nn.Conv3d(out_channels, out_channels,
83+
kernel_size, stride, padding, dilation, groups = out_channels, bias = bias)
8484
if(self.norm_type == 'batch_norm'):
8585
self.bn = nn.modules.BatchNorm3d(out_channels)
8686
elif(self.norm_type == 'group_norm'):
@@ -89,8 +89,8 @@ def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
8989
raise ValueError("unsupported normalization method {0:}".format(norm_type))
9090

9191
def forward(self, x):
92-
f = self.conv(x)
93-
f = self.conv1x1(f)
92+
f = self.conv1x1(x)
93+
f = self.conv(f)
9494
if(self.norm_type is not None):
9595
f = self.bn(f)
9696
if(self.acti_func is not None):

pymic/net3d/unet2d5.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,56 @@
77
from pymic.layer.activation import get_acti_func
88
from pymic.layer.convolution import ConvolutionLayer
99
from pymic.layer.deconvolution import DeconvolutionLayer
10-
from pymic.net3d.unet2d5_ag import UNetBlock
1110

11+
class UNetBlock(nn.Module):
12+
def __init__(self, in_channels, out_chnannels,
13+
dim, resample, acti_func, acti_func_param):
14+
super(UNetBlock, self).__init__()
15+
16+
self.in_chns = in_channels
17+
self.out_chns = out_chnannels
18+
self.dim = dim
19+
self.resample = resample # resample should be 'down', 'up', or None
20+
self.acti_func = acti_func
21+
22+
self.conv1 = ConvolutionLayer(in_channels, out_chnannels, kernel_size = 3, padding=1,
23+
dim = self.dim, acti_func=get_acti_func(acti_func, acti_func_param))
24+
self.conv2 = ConvolutionLayer(out_chnannels, out_chnannels, kernel_size = 3, padding=1,
25+
dim = self.dim, acti_func=get_acti_func(acti_func, acti_func_param))
26+
if(self.resample == 'down'):
27+
if(self.dim == 2):
28+
self.resample_layer = nn.MaxPool2d(kernel_size = 2, stride = 2)
29+
else:
30+
self.resample_layer = nn.MaxPool3d(kernel_size = 2, stride = 2)
31+
elif(self.resample == 'up'):
32+
self.resample_layer = DeconvolutionLayer(out_chnannels, out_chnannels, kernel_size = 2,
33+
dim = self.dim, stride = 2, acti_func = get_acti_func(acti_func, acti_func_param))
34+
else:
35+
assert(self.resample == None)
36+
37+
def forward(self, x):
38+
x_shape = list(x.shape)
39+
if(self.dim == 2 and len(x_shape) == 5):
40+
[N, C, D, H, W] = x_shape
41+
new_shape = [N*D, C, H, W]
42+
x = torch.transpose(x, 1, 2)
43+
x = torch.reshape(x, new_shape)
44+
output = self.conv1(x)
45+
output = self.conv2(output)
46+
resample = None
47+
if(self.resample is not None):
48+
resample = self.resample_layer(output)
49+
50+
if(self.dim == 2 and len(x_shape) == 5):
51+
new_shape = [N, D] + list(output.shape)[1:]
52+
output = torch.reshape(output, new_shape)
53+
output = torch.transpose(output, 1, 2)
54+
if(resample is not None):
55+
resample_shape = list(resample.shape)
56+
new_shape = [N, D] + resample_shape[1:]
57+
resample = torch.reshape(resample, new_shape)
58+
resample = torch.transpose(resample, 1, 2)
59+
return output, resample
1260

1361
class UNet2D5(nn.Module):
1462
def __init__(self, params):

pymic/train_infer/train_infer.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
from pymic.util.parse_config import parse_config
2727

2828

29-
class TrainInferAgent():
29+
class TrainInferAgent(object):
3030
def __init__(self, config, stage = 'train'):
3131
self.config = config
3232
self.stage = stage
3333
assert(stage in ['train', 'inference', 'test'])
3434

35-
def __create_dataset(self):
35+
def create_dataset(self):
3636
root_dir = self.config['dataset']['root_dir']
3737
train_csv = self.config['dataset'].get('train_csv', None)
3838
valid_csv = self.config['dataset'].get('valid_csv', None)
@@ -74,11 +74,11 @@ def __create_dataset(self):
7474
self.test_loder = torch.utils.data.DataLoader(test_dataset,
7575
batch_size=batch_size, shuffle=False, num_workers=batch_size)
7676

77-
def __create_network(self):
77+
def create_network(self):
7878
self.net = get_network(self.config['network'])
7979
self.net.double()
8080

81-
def __create_optimizer(self):
81+
def create_optimizer(self):
8282
self.optimizer = get_optimiser(self.config['training']['optimizer'],
8383
self.net.parameters(),
8484
self.config['training'])
@@ -91,7 +91,7 @@ def __create_optimizer(self):
9191
self.config['training']['lr_gamma'],
9292
last_epoch = last_iter)
9393

94-
def __train(self):
94+
def train(self):
9595
device = torch.device(self.config['training']['device_name'])
9696
self.net.to(device)
9797

@@ -111,7 +111,7 @@ def __train(self):
111111
self.net.load_state_dict(self.checkpoint['model_state_dict'])
112112
else:
113113
self.checkpoint = None
114-
self.__create_optimizer()
114+
self.create_optimizer()
115115

116116
train_loss = 0
117117
train_dice_list = []
@@ -218,7 +218,7 @@ def __train(self):
218218
torch.save(save_dict, save_name)
219219
summ_writer.close()
220220

221-
def __infer(self):
221+
def infer(self):
222222
device = torch.device(self.config['testing']['device_name'])
223223
self.net.to(device)
224224
# laod network parameters and set the network as evaluation mode
@@ -264,6 +264,15 @@ def test_time_dropout(m):
264264
images = data['image'].double()
265265
names = data['names']
266266
print(names[0])
267+
# for debug
268+
# for i in range(images.shape[0]):
269+
# image_i = images[i][0]
270+
# label_i = images[i][0]
271+
# image_name = "temp/{0:}_image.nii.gz".format(names[0])
272+
# label_name = "temp/{0:}_label.nii.gz".format(names[0])
273+
# save_nd_array_as_image(image_i, image_name, reference_name = None)
274+
# save_nd_array_as_image(label_i, label_name, reference_name = None)
275+
# continue
267276
data['predict'] = volume_infer(images, self.net, device, class_num,
268277
mini_batch_size, mini_patch_inshape, mini_patch_outshape, mini_patch_stride)
269278

@@ -303,12 +312,12 @@ def test_time_dropout(m):
303312
print("average testing time {0:}".format(avg_time))
304313

305314
def run(self):
306-
agent.__create_dataset()
307-
agent.__create_network()
315+
self.create_dataset()
316+
self.create_network()
308317
if(self.stage == 'train'):
309-
self.__train()
318+
self.train()
310319
else:
311-
self.__infer()
320+
self.infer()
312321

313322
if __name__ == "__main__":
314323
if(len(sys.argv) < 3):

pymic/util/rename_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def rename_model_variable(input_file, output_file, input_var_list, output_var_li
1414
torch.save(checkpoint, output_file)
1515

1616
if __name__ == "__main__":
17-
input_file = '/home/disk2t/projects/dlls/training_fetal_brain/model2/unet2dres/model_15000.pt'
18-
output_file = '/home/disk2t/projects/dlls/training_fetal_brain/model2/unet2dres/model_15000_rename.pt'
17+
input_file = '/home/guotai/disk2t/projects/PyMIC/examples/prostate/model/unet3db/model_15000.pt'
18+
output_file = '/home/guotai/disk2t/projects/PyMIC/examples/prostate/model/unet3db/model_15000_rename.pt'
1919
input_var_list = ['conv.weight', 'conv.bias']
2020
output_var_list= ['conv9.weight', 'conv9.bias']
2121
rename_model_variable(input_file, output_file, input_var_list, output_var_list)

0 commit comments

Comments
 (0)