Skip to content

Commit 9006a19

Browse files
authored
Merge pull request #10 from AdaptiveMotorControlLab/bugfix/local-installation
Bug fix in local installation & minor fixes in launch_review
2 parents 9941222 + d4080ee commit 9006a19

20 files changed

+110
-92
lines changed

napari_cellseg3d/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import Union
21
from typing import Optional
2+
from typing import Union
33

44
from qtpy.QtCore import Qt
55
from qtpy.QtCore import QUrl

napari_cellseg3d/launch_review.py

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1+
import os
12
from pathlib import Path
23

34
import matplotlib.pyplot as plt
45
import numpy as np
56
from magicgui import magicgui
6-
from matplotlib.backends.backend_qt5agg import (
7-
FigureCanvasQTAgg as FigureCanvas,
8-
)
7+
from matplotlib.backends.backend_qt5agg import \
8+
FigureCanvasQTAgg as FigureCanvas
99
from matplotlib.figure import Figure
10+
from monai.transforms import Zoom
1011
from qtpy.QtWidgets import QSizePolicy
1112
from scipy import ndimage
12-
from monai.transforms import Zoom
1313
from tifffile import imwrite
1414

1515
from napari_cellseg3d import utils
@@ -147,6 +147,7 @@ def launch_review(
147147

148148
layer = view1.layers[0]
149149
layer1 = view1.layers[1]
150+
150151
# if not as_folder:
151152
# r_path = os.path.dirname(r_path)
152153

@@ -164,20 +165,19 @@ def file_widget(
164165
dirname = Path(r_path)
165166
# def saver():
166167
out_dir = file_widget.dirname.value
168+
167169
# print("The directory is:", out_dir)
168170

169171
def quicksave():
170172
if not as_folder:
171173
if viewer.layers["labels"] is not None:
172-
time = utils.get_date_time()
173-
name = str(out_dir) + "/labels_reviewed_" + time + ".tif"
174+
name = os.path.join(str(out_dir), "labels_reviewed.tif")
174175
dat = viewer.layers["labels"].data
175176
imwrite(name, data=dat)
176177

177178
else:
178179
if viewer.layers["labels"] is not None:
179-
time = utils.get_date_time()
180-
dir_name = str(out_dir) + "/labels_reviewed_" + time
180+
dir_name = os.path.join(str(out_dir), "labels_reviewed")
181181
dat = viewer.layers["labels"].data
182182
utils.save_stack(dat, dir_name, filetype=filetype)
183183

@@ -206,17 +206,17 @@ def quicksave():
206206
xy_axes = canvas.figure.add_subplot(3, 1, 1)
207207
canvas.figure.suptitle("Shift-click on image for plot \n", fontsize=8)
208208
xy_axes.imshow(np.zeros((100, 100), np.int16))
209-
xy_axes.scatter(50, 50, s=10, c="red", alpha=0.25)
209+
xy_axes.scatter(50, 50, s=10, c="green", alpha=0.25)
210210
xy_axes.set_xlabel("x axis")
211211
xy_axes.set_ylabel("y axis")
212212
yz_axes = canvas.figure.add_subplot(3, 1, 2)
213213
yz_axes.imshow(np.zeros((100, 100), np.int16))
214-
yz_axes.scatter(50, 50, s=10, c="red", alpha=0.25)
214+
yz_axes.scatter(50, 50, s=10, c="green", alpha=0.25)
215215
yz_axes.set_xlabel("y axis")
216216
yz_axes.set_ylabel("z axis")
217217
zx_axes = canvas.figure.add_subplot(3, 1, 3)
218218
zx_axes.imshow(np.zeros((100, 100), np.int16))
219-
zx_axes.scatter(50, 50, s=10, c="red", alpha=0.25)
219+
zx_axes.scatter(50, 50, s=10, c="green", alpha=0.25)
220220
zx_axes.set_xlabel("x axis")
221221
zx_axes.set_ylabel("z axis")
222222

@@ -234,17 +234,33 @@ def update_canvas_canvas(viewer, event):
234234

235235
if "shift" in event.modifiers:
236236
try:
237-
m_point = np.round(viewer.cursor.position).astype(int)
238-
print(m_point)
239-
240-
crop_big = crop_img(
241-
[m_point[0], m_point[1], m_point[2]],
237+
cursor_position = np.round(viewer.cursor.position).astype(int)
238+
print(cursor_position)
239+
240+
cropped_volume = crop_volume_around_point(
241+
[
242+
cursor_position[0],
243+
cursor_position[1],
244+
cursor_position[2],
245+
],
242246
viewer.layers["volume"],
243247
)
244248

245-
xy_axes.imshow(crop_big[50], cmap="inferno", vmin=200, vmax=2000)
246-
yz_axes.imshow(crop_big.transpose(1, 0, 2)[50], cmap="inferno", vmin=200, vmax=2000)
247-
zx_axes.imshow(crop_big.transpose(2, 0, 1)[50], cmap="inferno", vmin=200, vmax=2000)
249+
xy_axes.imshow(
250+
cropped_volume[50], cmap="inferno", vmin=200, vmax=2000
251+
)
252+
yz_axes.imshow(
253+
cropped_volume.transpose(1, 0, 2)[50],
254+
cmap="inferno",
255+
vmin=200,
256+
vmax=2000,
257+
)
258+
zx_axes.imshow(
259+
cropped_volume.transpose(2, 0, 1)[50],
260+
cmap="inferno",
261+
vmin=200,
262+
vmax=2000,
263+
)
248264
canvas.draw_idle()
249265
except Exception as e:
250266
print(e)
@@ -262,40 +278,49 @@ def update_button(axis_event):
262278

263279
view1.dims.events.current_step.connect(update_button)
264280

265-
def crop_img(points, layer):
281+
def crop_volume_around_point(points, layer):
266282

267-
if zoom_factor != [1,1,1]:
268-
im = np.array(layer.data, dtype=np.int16)
269-
image = Zoom(
283+
if zoom_factor != [1, 1, 1]:
284+
vol = np.array(layer.data, dtype=np.int16)
285+
volume = Zoom(
270286
zoom_factor,
271287
keep_size=False,
272288
padding_mode="empty",
273-
)(np.expand_dims(im, axis=0))
274-
image = image[0]
289+
)(np.expand_dims(vol, axis=0))
290+
volume = volume[0]
275291
# image = ndimage.zoom(layer.data, zoom_factor, mode="nearest") # cleaner but much slower...
276-
else :
277-
image = layer.data
278-
279-
min_vals = [x - 50 for x in points]
280-
max_vals = [x + 50 for x in points]
281-
yohaku_minus = [n if n < 0 else 0 for n in min_vals]
282-
yohaku_plus = [
283-
x - image.shape[i] if image.shape[i] < x else 0
284-
for i, x in enumerate(max_vals)
292+
else:
293+
volume = layer.data
294+
295+
min_coordinates = [point - 50 for point in points]
296+
max_coordinates = [point + 50 for point in points]
297+
inferior_bound = [
298+
min_coordinate if min_coordinate < 0 else 0
299+
for min_coordinate in min_coordinates
300+
]
301+
superior_bound = [
302+
max_coordinate - volume.shape[i]
303+
if volume.shape[i] < max_coordinate
304+
else 0
305+
for i, max_coordinate in enumerate(max_coordinates)
285306
]
307+
286308
crop_slice = tuple(
287-
slice(np.maximum(0, n), x) for n, x in zip(min_vals, max_vals)
309+
slice(np.maximum(0, min_coordinate), max_coordinate)
310+
for min_coordinate, max_coordinate in zip(
311+
min_coordinates, max_coordinates
312+
)
288313
)
289314

290315
if as_folder:
291-
crop_temp = image[crop_slice].persist().compute()
316+
crop_temp = volume[crop_slice].persist().compute()
292317
else:
293-
294318
crop_temp = layer.data[crop_slice]
295-
cropped_img = np.zeros((100, 100, 100), np.int16)
296-
cropped_img[
297-
-yohaku_minus[0] : 100 - yohaku_plus[0],
298-
-yohaku_minus[1] : 100 - yohaku_plus[1],
299-
-yohaku_minus[2] : 100 - yohaku_plus[2],
319+
320+
cropped_volume = np.zeros((100, 100, 100), np.int16)
321+
cropped_volume[
322+
-inferior_bound[0] : 100 - superior_bound[0],
323+
-inferior_bound[1] : 100 - superior_bound[1],
324+
-inferior_bound[2] : 100 - superior_bound[2],
300325
] = crop_temp
301-
return cropped_img
326+
return cropped_volume

napari_cellseg3d/model_framework.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import napari
55
import torch
6-
76
# Qt
87
from qtpy.QtWidgets import QLineEdit
98
from qtpy.QtWidgets import QProgressBar

napari_cellseg3d/model_instance_seg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
from __future__ import print_function
33

44
import numpy as np
5-
from skimage.measure import label
6-
75
# from skimage.measure import marching_cubes
86
# from skimage.measure import mesh_surface_area
7+
from skimage.measure import label
98
from skimage.measure import regionprops
109
from skimage.morphology import remove_small_objects
1110
from skimage.segmentation import watershed

napari_cellseg3d/model_workers.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import torch
7-
87
# MONAI
98
from monai.data import CacheDataset
109
from monai.data import DataLoader
@@ -30,17 +29,14 @@
3029
from monai.transforms import SpatialPadd
3130
from monai.transforms import Zoom
3231
from monai.utils import set_determinism
33-
3432
# threads
3533
from napari.qt.threading import GeneratorWorker
3634
from napari.qt.threading import WorkerBaseSignals
37-
3835
# Qt
3936
from qtpy.QtCore import Signal
4037
from tifffile import imwrite
4138

4239
from napari_cellseg3d import utils
43-
4440
# local
4541
from napari_cellseg3d.model_instance_seg import binary_connected
4642
from napari_cellseg3d.model_instance_seg import binary_watershed

napari_cellseg3d/models/TRAILMAP_MS.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import os
2+
13
import torch
24
from torch import nn
5+
36
from napari_cellseg3d import utils
4-
import os
57

68

79
def get_weights_file():

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import os
2+
13
from monai.networks.nets import SegResNetVAE
4+
25
from napari_cellseg3d import utils
3-
import os
46

57

68
def get_net():

napari_cellseg3d/models/model_TRAILMAP.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from napari_cellseg3d.models.unet.model import UNet3D
2-
from napari_cellseg3d import utils
31
import os
42

3+
from napari_cellseg3d import utils
4+
from napari_cellseg3d.models.unet.model import UNet3D
5+
56

67
def get_weights_file():
78
# original model from Liqun Luo lab, transfered to pytorch

napari_cellseg3d/models/model_VNet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import os
2+
13
from monai.inferers import sliding_window_inference
24
from monai.networks.nets import VNet
5+
36
from napari_cellseg3d import utils
4-
import os
57

68

79
def get_net():

napari_cellseg3d/plugin_crop.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from magicgui import magicgui
77
from magicgui.widgets import Container
88
from magicgui.widgets import Slider
9-
109
# Qt
1110
from qtpy.QtWidgets import QSizePolicy
1211
from tifffile import imwrite
@@ -72,8 +71,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):
7271

7372
self.build()
7473

75-
76-
7774
def toggle_label_path(self):
7875
if self.crop_label_choice.isChecked():
7976
self.lbl_label.setVisible(True)

0 commit comments

Comments
 (0)