Skip to content

Commit 9941222

Browse files
authored
Merge pull request #6 from AdaptiveMotorControlLab/mwm/download_models
Mwm/download models
2 parents 31b5fd7 + 9beb8a1 commit 9941222

File tree

11 files changed

+62
-100
lines changed

11 files changed

+62
-100
lines changed

docs/res/guides/custom_model_template.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ To add a custom model, you will need a **.py** file with the following structure
1717

1818
def get_weights_file():
1919
return "weights_file.pth" # name of the weights file for the model,
20-
# which should be in *napari_cellseg3d/models/saved_weights*
20+
# which should be in *napari_cellseg3d/models/pretrained*
2121

2222

2323
def get_output(model, input):
@@ -35,5 +35,3 @@ To add a custom model, you will need a **.py** file with the following structure
3535
def ModelClass(x1,x2...):
3636
# your Pytorch model here...
3737
return results # should return as [C, N, D,H,W]
38-
39-

napari_cellseg3d/models/TRAILMAP_MS.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import torch
22
from torch import nn
3+
from napari_cellseg3d import utils
4+
import os
35

46

57
def get_weights_file():
6-
# return "TMP_TEST_40e.pth"
7-
return "TRAILMAP_DFl_best.pth"
8+
# model additionally trained on Mathis/Wyss mesoSPIM data
9+
target_dir = utils.download_model("TRAILMAP_MS")
10+
return os.path.join(target_dir, "TRAILMAP_MS_best_metric_epoch_26.pth")
811

912

1013
def get_net():

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from monai.networks.nets import SegResNetVAE
2+
from napari_cellseg3d import utils
3+
import os
24

35

46
def get_net():
57
return SegResNetVAE
68

79

810
def get_weights_file():
9-
return "SegResNet.pth"
11+
target_dir = utils.download_model("SegResNet")
12+
return os.path.join(target_dir, "SegResNet.pth")
1013

1114

1215
def get_output(model, input):

napari_cellseg3d/models/model_TRAILMAP.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from napari_cellseg3d.models.unet.model import UNet3D
2+
from napari_cellseg3d import utils
3+
import os
24

35

46
def get_weights_file():
5-
# return "TMP_TEST_40e.pth"
6-
return "trailmaptorchpretrained.pth"
7+
# original model from Liqun Luo lab, transfered to pytorch
8+
target_dir = utils.download_model("TRAILMAP")
9+
return os.path.join(target_dir, "TRAILMAP_PyTorch.pth")
710

811

912
def get_net():

napari_cellseg3d/models/model_VNet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from monai.inferers import sliding_window_inference
22
from monai.networks.nets import VNet
3+
from napari_cellseg3d import utils
4+
import os
35

46

57
def get_net():
68
return VNet()
79

810

911
def get_weights_file():
10-
# return "dice_VNet.pth"
11-
return "VNet_40e.pth"
12+
target_dir = utils.download_model("VNet")
13+
return os.path.join(target_dir, "VNet_40e.pth")
1214

1315

1416
def get_output(model, input):
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz",
3+
"TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz",
4+
"SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz",
5+
"VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz"
6+
}

napari_cellseg3d/models/pretrained/pretrained_model_urls.yaml

Lines changed: 0 additions & 6 deletions
This file was deleted.

napari_cellseg3d/plugin_crop.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,19 +72,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):
7272

7373
self.build()
7474

75-
###########################################
76-
if utils.ENABLE_TEST_MODE():
77-
# TODO : remove/disable once done
78-
if self.as_folder:
79-
self.image_path = "C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_png/sample"
80-
if self.crop_label_choice.isChecked():
81-
self.label_path = "C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_png/sample_labels"
82-
else:
83-
self.image_path = "C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif/volumes/images.tif"
84-
if self.crop_label_choice.isChecked():
85-
self.label_path = "C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif/labels/testing_im.tif"
86-
87-
###########################################
75+
8876

8977
def toggle_label_path(self):
9078
if self.crop_label_choice.isChecked():

napari_cellseg3d/plugin_metrics.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,23 +67,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):
6767

6868
self.build()
6969

70-
######################################
71-
# TODO test remove
72-
import glob
73-
import os
74-
75-
if utils.ENABLE_TEST_MODE():
76-
ground_directory = "C:/Users/Cyril/Desktop/Proj_bachelor/data/cropped_visual/train/lab"
77-
# ground_directory = "C:/Users/Cyril/Desktop/test/labels"
78-
pred_directory = "C:/Users/Cyril/Desktop/test/pred"
79-
# pred_directory = "C:/Users/Cyril/Desktop/test"
80-
self.images_filepaths = sorted(
81-
glob.glob(os.path.join(ground_directory, "*.tif"))
82-
)
83-
self.labels_filepaths = sorted(
84-
glob.glob(os.path.join(pred_directory, "*.tif"))
85-
)
86-
###############################################################################
70+
8771

8872
def build(self):
8973
"""Builds the layout of the widget."""

napari_cellseg3d/plugin_model_training.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -132,44 +132,7 @@ def __init__(
132132

133133
self.save_as_zip = False
134134
"""Whether to zip results folder once done. Creates a zipped copy of the results folder."""
135-
######################
136-
######################
137-
######################
138-
# TEST TODO REMOVE
139-
import glob
140-
141-
if utils.ENABLE_TEST_MODE():
142-
directory = os.path.dirname(os.path.realpath(__file__)) + str(
143-
Path("/models/dataset/volumes")
144-
)
145-
self.data_path = directory
146-
147-
lab_directory = os.path.dirname(os.path.realpath(__file__)) + str(
148-
Path("/models/dataset/lab_sem")
149-
)
150-
self.label_path = lab_directory
151-
152-
self.images_filepaths = sorted(
153-
glob.glob(os.path.join(directory, "*.tif"))
154-
)
155-
156-
self.labels_filepaths = sorted(
157-
glob.glob(os.path.join(lab_directory, "*.tif"))
158-
)
159-
160-
if results_path == "":
161-
self.results_path = "C:/Users/Cyril/Desktop/test/models"
162-
else:
163-
self.results_path = results_path
164-
165-
if data_path != "":
166-
self.data_path = data_path
167-
168-
if label_path != "":
169-
self.label_path = label_path
170-
#######################
171-
#######################
172-
#######################
135+
173136

174137
# recover default values
175138
self.num_samples = samples

0 commit comments

Comments
 (0)