Skip to content

Commit 9beb8a1

Browse files
committed
Fix model download function
1 parent 04e0af2 commit 9beb8a1

File tree

5 files changed

+19
-41
lines changed

5 files changed

+19
-41
lines changed

napari_cellseg3d/models/TRAILMAP_MS.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
from napari_cellseg3d import utils
44
import os
55

6-
modelname = "TRAILMAP_MS"
7-
target_dir = os.path.join("models","pretrained")
86

97
def get_weights_file():
10-
utils.download_model(modelname, target_dir)
11-
return "TRAILMAP_MS_best_metric_epoch_26.pth" #model additionally trained on Mathis/Wyss mesoSPIM data
8+
# model additionally trained on Mathis/Wyss mesoSPIM data
9+
target_dir = utils.download_model("TRAILMAP_MS")
10+
return os.path.join(target_dir, "TRAILMAP_MS_best_metric_epoch_26.pth")
1211

1312

1413
def get_net():

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,14 @@
22
from napari_cellseg3d import utils
33
import os
44

5-
modelname = "SegResNet"
6-
target_dir = os.path.join("models","pretrained")
75

86
def get_net():
97
return SegResNetVAE
108

119

1210
def get_weights_file():
13-
utils.download_model(modelname, target_dir)
14-
return "SegResNet.pth"
11+
target_dir = utils.download_model("SegResNet")
12+
return os.path.join(target_dir, "SegResNet.pth")
1513

1614

1715
def get_output(model, input):

napari_cellseg3d/models/model_TRAILMAP.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
from napari_cellseg3d import utils
33
import os
44

5-
modelname = "TRAILMAP"
6-
target_dir = os.path.join("models","pretrained")
75

86
def get_weights_file():
9-
utils.download_model(modelname, target_dir)
10-
return "TRAILMAP_PyTorch.pth" #original model from Liqun Luo lab, transfered to pytorch
7+
# original model from Liqun Luo lab, transfered to pytorch
8+
target_dir = utils.download_model("TRAILMAP")
9+
return os.path.join(target_dir, "TRAILMAP_PyTorch.pth")
1110

1211

1312
def get_net():

napari_cellseg3d/models/model_VNet.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33
from napari_cellseg3d import utils
44
import os
55

6-
modelname = "VNet"
7-
target_dir = os.path.join("models","pretrained")
86

97
def get_net():
108
return VNet()
119

1210

1311
def get_weights_file():
14-
utils.download_model(modelname, target_dir)
15-
return "VNet_40e.pth"
12+
target_dir = utils.download_model("VNet")
13+
return os.path.join(target_dir, "VNet_40e.pth")
1614

1715

1816
def get_output(model, input):

napari_cellseg3d/utils.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ def merge_imgs(imgs, original_image_shape):
983983
return merged_imgs
984984

985985

986-
def download_model(modelname, target_dir):
986+
def download_model(modelname):
987987
"""
988988
Downloads a specific pretained model.
989989
This code is adapted from DeepLabCut with permission from MWMathis
@@ -995,24 +995,11 @@ def download_model(modelname, target_dir):
995995
def show_progress(count, block_size, total_size):
996996
pbar.update(block_size)
997997

998-
def tarfilenamecutting(tarf):
999-
"""' auxfun to extract folder path
1000-
ie. /xyz-trainsetxyshufflez/
1001-
"""
1002-
for memberid, member in enumerate(tarf.getmembers()):
1003-
if memberid == 0:
1004-
parent = str(member.path)
1005-
l = len(parent) + 1
1006-
if member.path.startswith(parent):
1007-
member.path = member.path[l:]
1008-
yield member
1009-
1010-
#TODO: fix error in line 1021;
1011-
cellseg3d_path = os.path.split(importlib.util.find_spec("napari-cellseg3d").origin)[0]
1012-
json_path = os.path.join(cellseg3d_path, "models", "pretrained", "pretrained_model_urls.json")
998+
cellseg3d_path = os.path.split(importlib.util.find_spec("napari_cellseg3d").origin)[0]
999+
pretrained_folder_path = os.path.join(cellseg3d_path, "models", "pretrained")
1000+
json_path = os.path.join(pretrained_folder_path, "pretrained_model_urls.json")
10131001
with open(json_path) as f:
10141002
neturls = json.load(f)
1015-
10161003
if modelname in neturls.keys():
10171004
url = neturls[modelname]
10181005
response = urllib.request.urlopen(url)
@@ -1021,12 +1008,9 @@ def tarfilenamecutting(tarf):
10211008
pbar = tqdm(unit="B", total=total_size, position=0)
10221009
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
10231010
with tarfile.open(filename, mode="r:gz") as tar:
1024-
tar.extractall(target_dir, members=tarfilenamecutting(tar))
1011+
tar.extractall(pretrained_folder_path)
1012+
return pretrained_folder_path
10251013
else:
1026-
models = [
1027-
fn
1028-
for fn in neturls.keys()
1029-
if "VNet_" not in fn and "SegResNet" not in fn and "TRAILMAP_" not in fn
1030-
]
1031-
print("Model does not exist: ", modelname)
1032-
#print("Pick one of the following: ", models)
1014+
raise ValueError(
1015+
f"Unknown model. `modelname` should be one of {', '.join(neturls)}"
1016+
)

0 commit comments

Comments
 (0)