-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·26 lines (23 loc) · 988 Bytes
/
train.py
File metadata and controls
executable file
·26 lines (23 loc) · 988 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import ray
import argparse
from runner import Runner
from data import data_utils
def build_parser():
    """Build the command-line parser for the training entry point.

    Returns:
        argparse.ArgumentParser: parser accepting a positional ``task`` plus
        the optional ``--num_gpus`` (int, default 1), ``--results_path``,
        ``--checkpoint`` and ``--buffer`` (all str, default ``None``) flags.
    """
    parser = argparse.ArgumentParser(description='Train a Runner on a task.')
    parser.add_argument('task', type=str,
                        help='Task to train on')
    parser.add_argument('--num_gpus', type=int, default=1,
                        help='Number of GPUs to use during training')
    parser.add_argument('--results_path', type=str, default=None,
                        help='Path to save results/logs to while training. Defaults to timestamp if not specified')
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Path to the checkpoint to load')
    parser.add_argument('--buffer', type=str, default=None,
                        help='Path to the replay buffer to load')
    return parser


def main(argv=None):
    """Parse CLI arguments, train on the requested task, then shut Ray down.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``, which
            is argparse's default.
    """
    # Parse before touching Ray: argparse exits on bad input, and at that
    # point there is nothing to clean up.
    args = build_parser().parse_args(argv)
    try:
        task_config = data_utils.getTaskConfig(
            args.task, args.num_gpus, results_path=args.results_path)
        runner = Runner(task_config,
                        checkpoint=args.checkpoint,
                        replay_buffer=args.buffer)
        runner.train()
    finally:
        # Release Ray resources even if config loading or training raises.
        # NOTE(review): ray.init() is not visible in this file; presumably it
        # happens inside data_utils.getTaskConfig or Runner. Confirm.
        ray.shutdown()


if __name__ == '__main__':
    main()