Skip to content

Commit 1c2354e

Browse files
committed
Allow training to restart from a folder (#64)
* If directory passed, get last model * Save the final model after all iterations
1 parent 3af0d12 commit 1c2354e

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

src/xvr/cli/commands/restart.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,19 @@ def restart(
3232
Restart model training from a checkpoint.
3333
"""
3434
import os
35+
from pathlib import Path
3536

3637
import torch
3738
import wandb
3839

3940
from ...model import Trainer
4041

42+
# If ckptpath is a directory, get the last saved model
43+
ckptpath = Path(ckptpath)
44+
if ckptpath.is_dir():
45+
ckptpath = sorted(ckptpath.glob("*.pth"))[-1]
46+
ckptpath = str(ckptpath)
47+
4148
# Load the config from the previous model checkpoint
4249
config = torch.load(ckptpath, weights_only=False)["config"]
4350
config["ckptpath"] = ckptpath

src/xvr/cli/commands/train.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,12 @@ def train(
339339
# Create the output directory for saving model weights
340340
Path(outpath).mkdir(parents=True, exist_ok=True)
341341

342+
# If ckptpath is a directory, get the last saved model
343+
ckptpath = Path(ckptpath)
344+
if ckptpath.is_dir():
345+
ckptpath = sorted(ckptpath.glob("*.pth"))[-1]
346+
ckptpath = str(ckptpath)
347+
342348
# Parse 6-DoF pose parameters
343349
alphamin, alphamax = r1
344350
betamin, betamax = r2

src/xvr/model/trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ def train(self, run=None):
168168
if run is not None:
169169
self._log_wandb(itr, log, imgs, masks)
170170

171+
# Save the final model
172+
self._checkpoint(itr)
173+
171174
def step(self, itr):
172175
if self.single_subject:
173176
log, imgs, masks = self._step_single_subject(itr)

0 commit comments

Comments
 (0)