-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
83 lines (69 loc) · 2.54 KB
/
run.py
File metadata and controls
83 lines (69 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Entry point: run research pipeline.
Usage:
python run.py --config configs/gkx_linear.yaml --verbose
python run.py --config configs/gkx_linear.yaml --override train.max_train_rows=0 experiment.run_id=gkx_linear_full
"""
import argparse
import os
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parent
# Work from the directory that contains this script, so relative paths
# (e.g. ``--config configs/...``) resolve against it no matter where the
# script was launched from.
os.chdir(ROOT)
# Prepend the in-repo ``src`` directory to the import path (once), so packages
# under it — presumably the ``research`` package used by main() — import
# without being installed.
if str(ROOT.joinpath("src")) not in sys.path:
    sys.path.insert(0, str(ROOT.joinpath("src")))
def _parse_overrides(override_strs):
"""Parse dotted key=value overrides into nested dict.
Examples:
train.max_train_rows=0 -> {"train": {"max_train_rows": 0}}
experiment.run_id=foo -> {"experiment": {"run_id": "foo"}}
"""
if not override_strs:
return {}
result = {}
for s in override_strs:
if "=" not in s:
print(f"Invalid override (missing =): {s}", file=sys.stderr)
continue
key, val = s.split("=", 1)
# Auto-cast value
if val.lower() in ("true",):
val = True
elif val.lower() in ("false",):
val = False
elif val.lower() in ("none", "null"):
val = None
else:
try:
val = int(val)
except ValueError:
try:
val = float(val)
except ValueError:
pass # keep as string
parts = key.split(".")
d = result
for p in parts[:-1]:
d = d.setdefault(p, {})
d[parts[-1]] = val
return result
def main():
    """Parse command-line arguments and run the research pipeline.

    The config file is ``--config`` if given, otherwise
    ``ROOT/configs/research_default.yaml``. ``--override`` entries are parsed
    with ``_parse_overrides``, and everything is handed to
    ``research.pipeline.run_pipeline``.

    Returns:
        Whatever ``run_pipeline`` returns. NOTE(review): when run as a script
        this value reaches ``sys.exit``; confirm it is an int or ``None``,
        since any other truthy object is printed and the process exits 1.

    Raises:
        SystemExit: status 1 when the config path is not an existing regular
            file (argparse also exits on malformed arguments).
    """
    parser = argparse.ArgumentParser(description="Run research pipeline")
    parser.add_argument("--config", "-c", type=Path, default=None, help="Config YAML")
    # NOTE(review): --verbose is parsed but never read below or forwarded to
    # run_pipeline, so it currently has no effect. Confirm how it should be wired.
    parser.add_argument("--verbose", "-v", action="store_true", help="Log to console")
    parser.add_argument("--stages", nargs="*", help="Stages to run (default: from config)")
    parser.add_argument(
        "--override", "-O", nargs="*", default=[],
        help="Config overrides as dotted.key=value (e.g. train.max_train_rows=0)",
    )
    args = parser.parse_args()

    # A relative --config resolves against the current directory, which the
    # module sets to ROOT via os.chdir at import time.
    config_path = args.config or ROOT / "configs" / "research_default.yaml"
    # is_file() rather than exists(): a directory passes exists() and would
    # only fail later, inside the pipeline.
    if not config_path.is_file():
        print(f"Config not found: {config_path}", file=sys.stderr)
        sys.exit(1)

    overrides = _parse_overrides(args.override)

    # Deferred import: the pipeline package is only loaded after the
    # arguments and config path have been validated.
    from research.pipeline import run_pipeline

    # An empty ``--stages`` (or none at all) means "use the config's stages".
    stages = args.stages or None
    return run_pipeline(config_path=config_path, config_overrides=overrides or None, stages=stages)
if __name__ == "__main__":
    # Falsy results from main() (including None) become exit status 0; any
    # other value is handed to sys.exit unchanged.
    # NOTE(review): a non-int truthy value is printed and yields status 1.
    exit_status = main()
    sys.exit(exit_status if exit_status else 0)