|
25 | 25 | from datetime import datetime |
26 | 26 | import functools |
27 | 27 | import json |
28 | | -from pathlib import Path |
29 | 28 |
|
30 | 29 | from brax.io import model |
31 | 30 | from brax.training.agents.bc import networks as bc_networks |
32 | 31 | from brax.training.agents.bc import train as bc_fast |
33 | | -# np.set_printoptions(precision=3, suppress=True, linewidth=100) |
34 | 32 | from etils import epath |
35 | 33 | from flax import linen |
36 | 34 | from flax.training import orbax_utils |
37 | | -# from IPython.display import clear_output |
38 | 35 | import jax |
39 | 36 | from jax import numpy as jp |
40 | | -# from IPython.display import clear_output, display |
41 | 37 | from orbax import checkpoint as ocp |
42 | 38 | import typer |
43 | 39 |
|
|
48 | 44 |
|
49 | 45 | app = typer.Typer(pretty_exceptions_enable=False) |
50 | 46 |
|
51 | | -metrics_keys = [ |
52 | | - "success", |
53 | | - "peg_insertion", |
54 | | - "drop", |
55 | | - "final_grasp", |
56 | | - "robot_target_qpos", |
57 | | -] |
58 | | - |
59 | 47 |
|
60 | 48 | @app.command() |
61 | 49 | def main( |
@@ -128,28 +116,10 @@ def main( |
128 | 116 |
|
129 | 117 | teacher_inference_fn = distillation.make_teacher_policy() |
130 | 118 |
|
131 | | - SUFFIX = None |
132 | | - FINETUNE_PATH = None |
133 | | - |
134 | 119 | # Generate unique experiment name. |
135 | 120 | now = datetime.now() |
136 | 121 | timestamp = now.strftime("%Y%m%d-%H%M%S") |
137 | 122 | exp_name = f"{env_name}-{timestamp}" |
138 | | - if SUFFIX is not None: |
139 | | - exp_name += f"-{SUFFIX}" |
140 | | - print(f"Experiment name: {exp_name}") |
141 | | - |
142 | | - # Possibly restore from the latest checkpoint. |
143 | | - if FINETUNE_PATH is not None: |
144 | | - FINETUNE_PATH = epath.Path(FINETUNE_PATH) |
145 | | - latest_ckpts = list(FINETUNE_PATH.glob("*")) |
146 | | - latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()] |
147 | | - latest_ckpts.sort(key=lambda x: int(x.name)) |
148 | | - latest_ckpt = latest_ckpts[-1] |
149 | | - restore_checkpoint_path = latest_ckpt |
150 | | - print(f"Restoring from: {restore_checkpoint_path}") |
151 | | - else: |
152 | | - restore_checkpoint_path = None |
153 | 123 |
|
154 | 124 | ckpt_path = epath.Path("logs").resolve() / exp_name |
155 | 125 | ckpt_path.mkdir(parents=True, exist_ok=True) |
@@ -190,16 +160,6 @@ def policy_params_fn(current_step, make_policy, params): |
190 | 160 | path = ckpt_path / f"{current_step}" |
191 | 161 | orbax_checkpointer.save(path, params, force=True, save_args=save_args) |
192 | 162 |
|
193 | | - m_path = Path(__file__) |
194 | | - m_path = m_path.parent.parent # private_mujoco_playground |
195 | | - m_path = ( |
196 | | - m_path |
197 | | - / "mujoco_playground" |
198 | | - / "external_deps" |
199 | | - / "mujoco_menagerie" |
200 | | - / "aloha" |
201 | | - ) |
202 | | - |
203 | 163 | def progress(epoch, metrics: dict): |
204 | 164 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
205 | 165 | if epoch == 0 and num_evals > 0: |
|
0 commit comments