File tree Expand file tree Collapse file tree 3 files changed +16
-0
lines changed
Expand file tree Collapse file tree 3 files changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -32,12 +32,19 @@ def restart(
3232 Restart model training from a checkpoint.
3333 """
3434 import os
35+ from pathlib import Path
3536
3637 import torch
3738 import wandb
3839
3940 from ...model import Trainer
4041
42+ # If ckptpath is a directory, get the last saved model
43+ ckptpath = Path (ckptpath )
44+ if ckptpath .is_dir ():
45+ ckptpath = sorted (ckptpath .glob ("*.pth" ))[- 1 ]
46+ ckptpath = str (ckptpath )
47+
4148 # Load the config from the previous model checkpoint
4249 config = torch .load (ckptpath , weights_only = False )["config" ]
4350 config ["ckptpath" ] = ckptpath
Original file line number Diff line number Diff line change @@ -339,6 +339,12 @@ def train(
339339 # Create the output directory for saving model weights
340340 Path (outpath ).mkdir (parents = True , exist_ok = True )
341341
342+ # If ckptpath is a directory, get the last saved model
343+ ckptpath = Path (ckptpath )
344+ if ckptpath .is_dir ():
345+ ckptpath = sorted (ckptpath .glob ("*.pth" ))[- 1 ]
346+ ckptpath = str (ckptpath )
347+
342348 # Parse 6-DoF pose parameters
343349 alphamin , alphamax = r1
344350 betamin , betamax = r2
Original file line number Diff line number Diff line change @@ -168,6 +168,9 @@ def train(self, run=None):
168168 if run is not None :
169169 self ._log_wandb (itr , log , imgs , masks )
170170
171+ # Save the final model
172+ self ._checkpoint (itr )
173+
171174 def step (self , itr ):
172175 if self .single_subject :
173176 log , imgs , masks = self ._step_single_subject (itr )
You can’t perform that action at this time.
0 commit comments