-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcli.py
More file actions
73 lines (56 loc) · 1.83 KB
/
cli.py
File metadata and controls
73 lines (56 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Ignite training cli."""
import json
import os
import shutil
from pathlib import Path
# from typing import Any, Dict, Optional, Union
from typing import Optional
import torch
import typer
from alignn.config import TrainingConfig
from alignn.profile import profile_dgl
from alignn.train import train_dgl
def cli(
    config: Optional[Path] = typer.Argument(None),
    progress: bool = False,
    checkpoint_dir: Path = Path("/tmp/models"),
    store_outputs: bool = False,
    tensorboard: bool = False,
    profile: bool = False,
):
    """ALIGNN training cli.

    config: path to json config file (conform to TrainingConfig). When
        omitted, a small default TrainingConfig (epochs=10, n_train=32,
        n_val=32, batch_size=16) is used and outputs are written to the
        current working directory; otherwise outputs are written next to
        the config file.
    progress: enable tqdm console logging
    checkpoint_dir: directory handed to train_dgl; after training, every
        ``*.pt`` file found there is copied into the model directory
    store_outputs: forwarded to train_dgl as ``store_outputs``
    tensorboard: enable tensorboard logging
    profile: run profiling script for one epoch instead of training

    Raises:
        FileNotFoundError: if ``config`` is given but is not an existing file.
    """
    if config is None:
        # Must be a Path (not the str from os.getcwd()) so that the
        # ``model_dir / name`` joins below work.
        model_dir = Path.cwd()
        config = TrainingConfig(epochs=10, n_train=32, n_val=32, batch_size=16)
    else:
        config_path = Path(config)
        if not config_path.is_file():
            # Fail early instead of passing a bare Path on to training.
            raise FileNotFoundError(f"config file not found: {config_path}")
        model_dir = config_path.parent
        with open(config_path, "r", encoding="utf-8") as f:
            config = TrainingConfig(**json.load(f))

    if profile:
        profile_dgl(config)
        return

    checkpoint_dir = Path(checkpoint_dir)
    hist = train_dgl(
        config,
        progress=progress,
        checkpoint_dir=checkpoint_dir,
        store_outputs=store_outputs,
        log_tensorboard=tensorboard,
    )

    torch.save(hist, model_dir / "metrics.pt")
    with open(model_dir / "fullconfig.json", "w", encoding="utf-8") as f:
        json.dump(json.loads(config.json()), f, indent=2)

    # Copy temporary checkpoint data into model_dir.
    for checkpoint in checkpoint_dir.glob("*.pt"):
        destination = model_dir / checkpoint.name
        # When checkpoint_dir and model_dir are the same directory,
        # shutil.copy would raise SameFileError; skip self-copies.
        if checkpoint.resolve() != destination.resolve():
            shutil.copy(checkpoint, destination)
if __name__ == "__main__":
    # When executed as a script, hand cli() to typer.run so its parameters
    # are exposed as command-line arguments/options.
    typer.run(cli)