-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
88 lines (69 loc) · 2.33 KB
/
main.py
File metadata and controls
88 lines (69 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import os
import multiprocessing as mp
import pprint
import yaml
from src.utils.distributed import init_distributed
from src.train import main as app_main
# Command-line interface for the launcher. Parsed once in the
# __main__ guard below; defaults mirror a single-GPU local run.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--fname",
    type=str,
    help="name of config file to load",
    default="configs.yaml",
)
parser.add_argument(
    "--devices",
    type=str,
    nargs="+",
    default=["cuda:0"],
    help="which devices to use on local machine",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=None,
    help="override batch_size from config file",
)
def process_main(rank, fname, world_size, devices, batch_size_override=None):
    """Entry point for one spawned training worker.

    Pins the process to its assigned GPU, loads the YAML config, applies
    command-line overrides, records launch metadata under ``_tracking``,
    initializes the distributed process group, and hands off to the trainer.

    :param rank: index of this worker in ``[0, world_size)``
    :param fname: path to the YAML config file
    :param world_size: total number of workers being launched
    :param devices: device strings, e.g. ``["cuda:0", "cuda:1"]``
    :param batch_size_override: if not None, replaces ``data.batch_size``
        from the config file
    """
    # Restrict this process to its single GPU. Must happen before any
    # CUDA context is created (i.e. before torch touches the device).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1])

    import logging

    logging.basicConfig()
    logger = logging.getLogger()
    # Only rank 0 logs at INFO to avoid duplicated output across workers.
    logger.setLevel(logging.INFO if rank == 0 else logging.ERROR)

    logger.info(f"called-params {fname}")

    # -- load script params
    # NOTE(review): FullLoader can construct some Python objects from
    # tagged YAML; acceptable for trusted local configs, but never point
    # this at untrusted input.
    with open(fname, "r") as y_file:
        params = yaml.load(y_file, Loader=yaml.FullLoader)
        logger.info("loaded params...")

    # -- apply command-line overrides
    if batch_size_override is not None:
        params["data"]["batch_size"] = batch_size_override
        logger.info(f"overriding batch_size to {batch_size_override}")

    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(params)

    # Record how this run was launched so it can be reproduced later.
    params.setdefault("_tracking", {})
    params["_tracking"].update(
        {
            "config_path": os.path.abspath(fname),
            "launcher": "main.py",
            "devices": devices,
            "batch_size_override": batch_size_override,
        }
    )

    # init_distributed may remap rank/world_size (e.g. single-process
    # fallback); trust its return values from here on.
    world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
    logger.info(f"Running... (rank: {rank}/{world_size})")
    app_main(args=params)
if __name__ == "__main__":
    args = parser.parse_args()
    num_gpus = len(args.devices)

    # "spawn" is required for CUDA workers: a forked child would inherit
    # an already-initialized CUDA context from the parent.
    mp.set_start_method("spawn")

    # Launch one worker per device, then join them all so the parent
    # blocks until training finishes instead of exiting immediately
    # (the original started the processes without ever joining them).
    processes = []
    for rank in range(num_gpus):
        p = mp.Process(
            target=process_main,
            args=(rank, args.fname, num_gpus, args.devices, args.batch_size),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()