-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathtrain.py
More file actions
48 lines (38 loc) · 1.88 KB
/
train.py
File metadata and controls
48 lines (38 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Adopted from https://github.com/guandeh17/Self-Forcing
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
from omegaconf import OmegaConf
import wandb
from trainer import ScoreDistillationTrainer
def main():
    """Parse the command line, assemble the run config, and launch training.

    The user config given by ``--config_path`` is merged over
    ``configs/default_config.yaml`` (a path relative to the current working
    directory), the command-line switches are copied onto the merged config,
    and the trainer selected by ``config.trainer`` is built and run.

    Raises:
        ValueError: If ``config.trainer`` is not a trainer this script knows
            how to construct.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--no_save", action="store_true")
    parser.add_argument("--no_visualize", action="store_true")
    parser.add_argument(
        "--logdir", type=str, default="",
        help="Path to the directory to save logs",
    )
    parser.add_argument(
        "--wandb-save-dir", type=str, default="",
        help="Path to the directory to save wandb logs",
    )
    parser.add_argument("--disable-wandb", action="store_true")
    parser.add_argument(
        "--no-auto-resume", action="store_true",
        help="Disable auto resume from latest checkpoint in logdir",
    )
    parser.add_argument(
        "--no-one-logger", action="store_true",
        help="Disable One Logger (enabled by default)",
    )
    args = parser.parse_args()

    config = OmegaConf.load(args.config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    # The user config is passed last so its values override the defaults.
    config = OmegaConf.merge(default_config, config)
    config.no_save = args.no_save
    config.no_visualize = args.no_visualize

    # The run is named after the directory that contains the config file,
    # not after the config file itself. This is an empty string when
    # config_path has no directory component.
    config.config_name = os.path.basename(os.path.dirname(args.config_path))
    config.logdir = args.logdir
    config.wandb_save_dir = args.wandb_save_dir
    config.disable_wandb = args.disable_wandb
    # Both features default to on; the --no-* flags switch them off.
    config.auto_resume = not args.no_auto_resume
    config.use_one_logger = not args.no_one_logger

    if config.trainer == "score_distillation":
        trainer = ScoreDistillationTrainer(config)
    else:
        # Fail with a clear message instead of hitting an unbound `trainer`
        # variable on the next line.
        raise ValueError(
            f"Unsupported trainer {config.trainer!r}; "
            "expected 'score_distillation'."
        )
    trainer.train()

    wandb.finish()
# Start training only when this file is executed as a script, so importing
# the module does not kick off a run.
if __name__ == "__main__":
    main()