-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
93 lines (73 loc) · 3.4 KB
/
train.py
File metadata and controls
93 lines (73 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
This module is a training script for PyTorch models using the PyTorch Lightning and Hydra libraries.
The script uses configuration files to specify the model, data, callbacks, logger, and trainer to use.
It then trains the model on the data and logs the results to Weights & Biases (W&B) using the W&B logger.
After training, the script saves the best and latest checkpoints of the model, along with the
configuration files used for each checkpoint. Finally, it closes the W&B connection.
"""
import os
import re
import shutil
from os import listdir
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
import wandb
@hydra.main(version_base=None, config_path="conf", config_name="config")
def run(cfg: DictConfig) -> None | float:
    """Train a model from a Hydra config and manage best/latest checkpoints.

    Instantiates the datamodule, model, callbacks, W&B logger, and trainer
    from ``cfg``, runs ``trainer.fit``, and — when checkpointing is enabled —
    always saves a ``latest:...`` checkpoint and saves a ``best:...``
    checkpoint when the new metric beats every existing best for this mode.
    The Hydra run config is copied next to each checkpoint as a ``.yaml``.

    Args:
        cfg: Hydra-composed configuration with ``data``, ``model``,
            ``callbacks``, ``logger``, ``trainer``, and ``general`` sections.

    Returns:
        The value of ``cfg.general.optimizer_goal`` from the trainer's
        callback metrics (lower is treated as better), so the script can be
        used as a Hydra sweep / Optuna objective.
    """
    pl.seed_everything(seed=cfg.general.seed, workers=True)

    print("==> initializing data ...")
    datamodule = hydra.utils.instantiate(cfg.data)

    print("==> initializing model ...")
    model = hydra.utils.instantiate(cfg.model)

    print("==> initializing callbacks ...")
    callbacks = hydra.utils.instantiate(cfg.callbacks)

    print("==> initializing logger...")
    wandb_logger = hydra.utils.instantiate(cfg.logger.wandb)
    wandb_logger.watch(model, **OmegaConf.to_container(cfg.logger.watch))

    print("==> initializing trainer ...")
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        logger=wandb_logger,
        callbacks=callbacks,
    )

    print("==> start training ...")
    trainer.fit(model, datamodule)

    print("==> extract best metric")
    metric = trainer.callback_metrics[cfg.general.optimizer_goal].item()

    if cfg.trainer.enable_checkpointing:
        print("==> start checkpointing")
        # BUGFIX: listdir() raised FileNotFoundError on the very first run,
        # before any checkpoint had ever been written.
        os.makedirs("checkpoints", exist_ok=True)

        # Find the existing best and latest checkpoints for this mode.
        best_ckpts, exist_best_ckpt, latest_ckpts = [], False, []
        for ckpt in listdir("checkpoints"):
            if ckpt.startswith(f"latest:mode={cfg.model.mode}"):
                latest_ckpts.append(ckpt)
            if ckpt.startswith(f"best:mode={cfg.model.mode}"):
                exist_best_ckpt = True
                match = re.search(r"val_loss=(.*)\.(ckpt|yaml)", ckpt)
                # BUGFIX: a file matching the prefix but not the val_loss
                # pattern previously raised AttributeError on .group(1).
                if match is None:
                    continue
                best_metric = float(match.group(1))
                # Lower is better: stage superseded files for removal.
                if metric < best_metric:
                    best_ckpts.append(ckpt)

        # The path to the current config file & the checkpoint file stem.
        current_config_path = f"{cfg.general.run_dir}/.hydra/config.yaml"
        file_name = f"mode={cfg.model.mode}-val_loss={metric:.6f}"

        # Always save the latest checkpoint, then drop the superseded ones.
        trainer.save_checkpoint(f"checkpoints/latest:{file_name}.ckpt")
        shutil.copyfile(current_config_path, f"checkpoints/latest:{file_name}.yaml")
        for latest_ckpt in latest_ckpts:
            os.remove(f"checkpoints/{latest_ckpt}")

        # Save a new best checkpoint when none exists yet, or when this run
        # improved on the previous best(s). (The original duplicated the
        # save lines across two branches; merged — best_ckpts is empty in
        # the no-previous-best case, so the cleanup loop is a no-op there.)
        if not exist_best_ckpt or best_ckpts:
            trainer.save_checkpoint(f"checkpoints/best:{file_name}.ckpt")
            shutil.copyfile(current_config_path, f"checkpoints/best:{file_name}.yaml")
            for best_ckpt in best_ckpts:
                os.remove(f"checkpoints/{best_ckpt}")

    print("==> close wandb connection")
    wandb.finish()
    return metric
# Script entry point: Hydra parses the CLI overrides and composes the config
# before invoking run() (the decorator supplies cfg).
if __name__ == "__main__":
    run()