Skip to content

Commit b4ae622

Browse files
committed
clean up train_dagger.py
1 parent 2cce751 commit b4ae622

File tree

1 file changed

+0
-40
lines changed

1 file changed

+0
-40
lines changed

mujoco_playground/experimental/train_dagger.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,15 @@
2525
from datetime import datetime
2626
import functools
2727
import json
28-
from pathlib import Path
2928

3029
from brax.io import model
3130
from brax.training.agents.bc import networks as bc_networks
3231
from brax.training.agents.bc import train as bc_fast
33-
# np.set_printoptions(precision=3, suppress=True, linewidth=100)
3432
from etils import epath
3533
from flax import linen
3634
from flax.training import orbax_utils
37-
# from IPython.display import clear_output
3835
import jax
3936
from jax import numpy as jp
40-
# from IPython.display import clear_output, display
4137
from orbax import checkpoint as ocp
4238
import typer
4339

@@ -48,14 +44,6 @@
4844

4945
app = typer.Typer(pretty_exceptions_enable=False)
5046

51-
metrics_keys = [
52-
"success",
53-
"peg_insertion",
54-
"drop",
55-
"final_grasp",
56-
"robot_target_qpos",
57-
]
58-
5947

6048
@app.command()
6149
def main(
@@ -128,28 +116,10 @@ def main(
128116

129117
teacher_inference_fn = distillation.make_teacher_policy()
130118

131-
SUFFIX = None
132-
FINETUNE_PATH = None
133-
134119
# Generate unique experiment name.
135120
now = datetime.now()
136121
timestamp = now.strftime("%Y%m%d-%H%M%S")
137122
exp_name = f"{env_name}-{timestamp}"
138-
if SUFFIX is not None:
139-
exp_name += f"-{SUFFIX}"
140-
print(f"Experiment name: {exp_name}")
141-
142-
# Possibly restore from the latest checkpoint.
143-
if FINETUNE_PATH is not None:
144-
FINETUNE_PATH = epath.Path(FINETUNE_PATH)
145-
latest_ckpts = list(FINETUNE_PATH.glob("*"))
146-
latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()]
147-
latest_ckpts.sort(key=lambda x: int(x.name))
148-
latest_ckpt = latest_ckpts[-1]
149-
restore_checkpoint_path = latest_ckpt
150-
print(f"Restoring from: {restore_checkpoint_path}")
151-
else:
152-
restore_checkpoint_path = None
153123

154124
ckpt_path = epath.Path("logs").resolve() / exp_name
155125
ckpt_path.mkdir(parents=True, exist_ok=True)
@@ -190,16 +160,6 @@ def policy_params_fn(current_step, make_policy, params):
190160
path = ckpt_path / f"{current_step}"
191161
orbax_checkpointer.save(path, params, force=True, save_args=save_args)
192162

193-
m_path = Path(__file__)
194-
m_path = m_path.parent.parent # private_mujoco_playground
195-
m_path = (
196-
m_path
197-
/ "mujoco_playground"
198-
/ "external_deps"
199-
/ "mujoco_menagerie"
200-
/ "aloha"
201-
)
202-
203163
def progress(epoch, metrics: dict):
204164
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
205165
if epoch == 0 and num_evals > 0:

0 commit comments

Comments
 (0)