@@ -983,7 +983,7 @@ def merge_imgs(imgs, original_image_shape):
983983 return merged_imgs
984984
985985
986- def download_model (modelname , target_dir ):
986+ def download_model (modelname ):
987987 """
988988 Downloads a specific pretained model.
989989 This code is adapted from DeepLabCut with permission from MWMathis
@@ -995,24 +995,11 @@ def download_model(modelname, target_dir):
995995 def show_progress (count , block_size , total_size ):
996996 pbar .update (block_size )
997997
998- def tarfilenamecutting (tarf ):
999- """' auxfun to extract folder path
1000- ie. /xyz-trainsetxyshufflez/
1001- """
1002- for memberid , member in enumerate (tarf .getmembers ()):
1003- if memberid == 0 :
1004- parent = str (member .path )
1005- l = len (parent ) + 1
1006- if member .path .startswith (parent ):
1007- member .path = member .path [l :]
1008- yield member
1009-
1010- #TODO: fix error in line 1021;
1011- cellseg3d_path = os .path .split (importlib .util .find_spec ("napari-cellseg3d" ).origin )[0 ]
1012- json_path = os .path .join (cellseg3d_path , "models" , "pretrained" , "pretrained_model_urls.json" )
998+ cellseg3d_path = os .path .split (importlib .util .find_spec ("napari_cellseg3d" ).origin )[0 ]
999+ pretrained_folder_path = os .path .join (cellseg3d_path , "models" , "pretrained" )
1000+ json_path = os .path .join (pretrained_folder_path , "pretrained_model_urls.json" )
10131001 with open (json_path ) as f :
10141002 neturls = json .load (f )
1015-
10161003 if modelname in neturls .keys ():
10171004 url = neturls [modelname ]
10181005 response = urllib .request .urlopen (url )
@@ -1021,12 +1008,9 @@ def tarfilenamecutting(tarf):
10211008 pbar = tqdm (unit = "B" , total = total_size , position = 0 )
10221009 filename , _ = urllib .request .urlretrieve (url , reporthook = show_progress )
10231010 with tarfile .open (filename , mode = "r:gz" ) as tar :
1024- tar .extractall (target_dir , members = tarfilenamecutting (tar ))
1011+ tar .extractall (pretrained_folder_path )
1012+ return pretrained_folder_path
10251013 else :
1026- models = [
1027- fn
1028- for fn in neturls .keys ()
1029- if "VNet_" not in fn and "SegResNet" not in fn and "TRAILMAP_" not in fn
1030- ]
1031- print ("Model does not exist: " , modelname )
1032- #print("Pick one of the following: ", models)
1014+ raise ValueError (
1015+ f"Unknown model. `modelname` should be one of { ', ' .join (neturls )} "
1016+ )
0 commit comments