-
Notifications
You must be signed in to change notification settings - Fork 324
Open
Description
Description
I am trying to find the optimal system parameters (damping, mass, etc.) that minimize the MSE between real and simulated trajectories. When I try to compute the gradient of the loss function, I get `jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: batch in most minor dimension`. Here is a minimal example that reproduces the error:
import functools
import jax
import jax.numpy as jp
import mujoco
import numpy as np
from brax.io.mjcf import load_model
from brax.mjx import pipeline
from mujoco import MjModel
# Raise as soon as a NaN is produced inside a computation (helps locate where
# differentiation through the simulator blows up).
jax.config.update("jax_debug_nans", True)
# Debug toggle: uncomment to run everything eagerly instead of compiling.
# jax.config.update("jax_disable_jit", True)
# `jax_traceback_filtering` is an enum option ("off", "tracebackhide",
# "remove_frames", "auto", ...). The boolean False is not one of its
# documented values; "off" is what shows the full, unfiltered traceback.
jax.config.update("jax_traceback_filtering", "off")
# MJCF model for the sysid experiment: a "link" body with a hinge and a slide
# joint driven by one motor, plus a jointless "obstacle" body. Plain string
# literal -- there is nothing to interpolate.
mujoco_xml = """
<?xml version="1.0" encoding="UTF-8"?>
<mujoco model="sysid">
<compiler angle="radian"/>
<option timestep="0.002"/>
<custom>
<numeric data="2" name="max_contact_points"/>
<numeric data="2" name="max_geom_pairs"/>
</custom>
<worldbody>
<body name="link" pos="0 0.5 0">
<inertial pos="0 0 0"
mass="0.1"
diaginertia="0.08396 0.00125 0.08396"/>
<joint name="hinge_z"
type="hinge"
axis="0 0 1"
pos="0 -0.5 0"
damping="2."/>
<joint name="slide_x"
type="slide"
axis="1 0 0"
damping="100."/>
<geom type="box"
size="0.05 0.5 0.05"
pos="0 0 0"
rgba="0.8 0.3 0.3 1"/>
</body>
<body name="obstacle" pos="0.8 0 0">
<geom type="box"
size="0.1 0.1 0.1"
pos="0 0 0"
rgba="0.2 0.2 0.8 1"/>
</body>
</worldbody>
<actuator>
<motor name="m1"
joint="hinge_z"
ctrlrange="-1 1"/>
</actuator>
</mujoco>
"""
@jax.jit
def modify_params(params, sys):
    """Return a copy of ``sys`` with the identified parameters written in.

    Args:
      params: length-2 array. ``params[0]`` is written to ``dof_damping[0]``
        and ``params[1]`` to ``body_mass[body_id]``.
      sys: system pytree exposing ``dof_damping``, ``body_mass`` and
        ``replace`` (in this script, the Brax system from ``load_model``).

    Returns:
      The updated system. ``sys`` itself is left unchanged because
      ``.at[...].set`` is a functional update.
    """
    # Body ids follow declaration order in the MJCF: 0 = world, 1 = "link"
    # (carries both joints), 2 = "obstacle" (no joints, fixed to the world).
    # Writing to id 2 only changes a static body's mass, which does not move
    # anything, so the mass parameter would get zero gradient. Identify the
    # moving link's mass instead.
    body_id = 1
    # dof 0 belongs to the first declared joint (presumably hinge_z).
    return sys.replace(
        dof_damping=sys.dof_damping.at[0].set(params[0]),
        body_mass=sys.body_mass.at[body_id].set(params[1]),
    )
def simulate(
    sys,
    init_qpos,
    init_qvel,
    actions,
):
    """Roll out a single trajectory with the Brax MJX pipeline.

    The start state is ``pipeline.init(sys, init_qpos, init_qvel)``. Each entry
    of ``actions`` (scanned along its leading axis) is applied with exactly one
    ``pipeline.step`` call.

    Returns:
      ``(pred_qpos, pred_qvel)``: the ``q`` and ``qd`` of every post-step state,
      stacked by ``jax.lax.scan`` along a new leading axis. Entry ``t`` is the
      state reached *after* applying ``actions[t]``; the initial state itself
      is not included.
    """

    def _advance(state, ctrl):
        stepped = pipeline.step(sys, state, ctrl)
        return stepped, (stepped.q, stepped.qd)

    start = pipeline.init(sys, init_qpos, init_qvel)
    _, (pred_qpos, pred_qvel) = jax.lax.scan(_advance, start, actions)
    return pred_qpos, pred_qvel
def simulate_batch(
    sys,
    init_qpos_batch,
    init_qvel_batch,
    actions_batch,
):
    """Run ``simulate`` for a batch of rollouts that share one system.

    ``sys`` is broadcast to every rollout (``in_axes=None``). The initial
    positions, velocities and action sequences are mapped over their leading
    (batch) axis.
    """
    batched_rollout = jax.vmap(simulate, in_axes=(None, 0, 0, 0))
    return batched_rollout(sys, init_qpos_batch, init_qvel_batch, actions_batch)
def mse(y_batched, y_pred_batched):
    """Average, along axis 0, of half the squared residual norm per row.

    For each row pair, ``jp.inner(r, r) / 2`` is computed with
    ``r = y - y_pred`` (vmapped over axis 0); the per-row values are then
    averaged along axis 0. Despite the name, this is not a per-element mean:
    feature dimensions are summed, not averaged.
    """

    def _half_sq_norm(target, pred):
        residual = target - pred
        return jp.inner(residual, residual) / 2.0

    per_row = jax.vmap(_half_sq_norm)(y_batched, y_pred_batched)
    return jp.mean(per_row, axis=0)
def log_param_transform(fn):
    """Decorate ``fn`` so that its first argument is supplied in log space.

    The returned callable takes ``u`` and calls
    ``fn(jp.exp(u), *args, **kwargs)``, forwarding all remaining positional
    and keyword arguments. This lets an optimizer step freely in ``u`` while
    ``fn`` always receives strictly positive values. The wrapper keeps
    ``fn``'s metadata via ``functools.wraps``.
    """

    @functools.wraps(fn)
    def _from_log(u, *args, **kwargs):
        return fn(jp.exp(u), *args, **kwargs)

    return _from_log
@log_param_transform
def loss_fn(
    params,
    data,
    origin_sys,
    modify_fn,
    simulate_fn,
):
    """Scaled trajectory-matching loss for system identification.

    Because of ``@log_param_transform``, callers pass log-parameters ``u``;
    this body receives ``params = exp(u)``.

    Args:
      params: positive physical parameters forwarded to ``modify_fn``.
      data: ``(gt_qpos, gt_qvel, actions)``, each assumed to be indexed as
        ``[batch, time, ...]``.
      origin_sys: unmodified system that ``modify_fn`` starts from.
      modify_fn: ``(params, sys) -> sys`` writing ``params`` into the system.
      simulate_fn: batched rollout
        ``(sys, qpos0, qvel0, actions) -> (qpos, qvel)``.

    Returns:
      ``100 * mean`` over the batch of ``mse(gt_qpos[b, 1:], pred_qpos[b])``.

    Note:
      Assumes each recorded state is the one reached *after* the action at the
      same time index, which is what ``simulate`` in this file produces:
      ``gt[:, t + 1]`` results from applying ``actions[:, t + 1]`` to
      ``gt[:, t]``. Requires at least two time steps.
    """
    sys = modify_fn(params, origin_sys)
    gt_qpos, gt_qvel, actions = data
    # gt[:, 0] already reflects actions[:, 0]. Replay only the remaining
    # actions from it, so rollout step k lines up with recorded state
    # gt[:, k + 1]. Replaying every action and comparing with gt[:, :] is
    # off by one step and is nonzero even at the true parameters.
    pred_qpos, _ = simulate_fn(
        sys,
        gt_qpos[:, 0],
        gt_qvel[:, 0],
        actions[:, 1:],
    )
    per_trajectory = jax.vmap(mse)(gt_qpos[:, 1:], pred_qpos)
    # Fixed 100x scale on the batch-mean loss, kept from the original script.
    return 100.0 * jp.mean(per_trajectory)
# Ground-truth parameters that modify_params writes into the model to
# generate the target trajectories: params[0] -> dof_damping[0],
# params[1] -> a body mass.
GT_PARAMS = jp.array([200.0, 0.3], dtype=jp.float32)

# Rollout sizes: number of trajectories, and actions per trajectory.
batch_size, ep_len = 10, 20

# Plain gradient-descent settings (iterations, step size in log-parameter space).
max_iter, lr = 100, 1.0
if __name__ == "__main__":
    # --- Build the MuJoCo model with MJX-compatible, differentiable options. ---
    mj_model = MjModel.from_xml_string(mujoco_xml)
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
    # iterations > 1 raises "ValueError: Reverse-mode differentiation does not
    # work for lax.while_loop or lax.fori_loop with dynamic start/stop values".
    # See:
    # https://github.com/google-deepmind/mujoco/issues/1182#issuecomment-1823411911
    mj_model.opt.iterations = 1
    mj_model.opt.ls_iterations = 4
    mj_model.opt.integrator = mujoco.mjtIntegrator.mjINT_IMPLICITFAST
    # The implicit-fast integrator needs these set to zero.
    mj_model.opt.density = 0.0
    mj_model.opt.viscosity = 0.0
    mj_model.opt.wind = np.zeros(3, dtype=np.float32)
    # MJX does not support elliptic cones yet.
    mj_model.opt.cone = mujoco.mjtCone.mjCONE_PYRAMIDAL
    mj_model.dof_armature = 1e-3
    sys = load_model(mj_model)

    # --- Generate ground-truth trajectories with the true parameters. ---
    sys_gt = modify_params(GT_PARAMS, sys)
    init_qpos_batch = jp.zeros((batch_size, sys.q_size()), dtype=jp.float32)
    init_qvel_batch = jp.zeros((batch_size, sys.qd_size()), dtype=jp.float32)
    key = jax.random.PRNGKey(0)
    action_scale = 1.0
    actions_batch = action_scale * jax.random.uniform(
        key,
        (batch_size, ep_len, sys.act_size()),
        minval=-1.0,
        maxval=1.0,
        dtype=jp.float32,
    )
    simulate_batch_jit = jax.jit(simulate_batch)
    qpos_batch, qvel_batch = simulate_batch_jit(
        sys_gt,
        init_qpos_batch,
        init_qvel_batch,
        actions_batch,
    )

    # --- Fit parameters by gradient descent in log space. ---
    x0 = jp.array([250, 0.5], dtype=jp.float32)
    eps = 1e-8
    # Clamp to a positive floor so the log below is finite.
    x0 = jp.where(x0 <= 0, eps, x0)
    u0 = jp.log(x0)
    loss_fn_packed = functools.partial(
        loss_fn,
        modify_fn=modify_params,
        origin_sys=sys,
        simulate_fn=simulate_batch_jit,
    )
    grad_fn = jax.jit(jax.value_and_grad(loss_fn_packed, argnums=0))
    # The target data never changes, so build the tuple once.
    data = (qpos_batch, qvel_batch, actions_batch)
    u = u0
    for i in range(max_iter):
        # Debug toggle: wrap the call below in `with jax.disable_jit():` to run
        # eagerly (the reported XlaRuntimeError does not occur in that mode).
        # One value_and_grad call per iteration; a second call would only
        # recompute the same forward and backward pass.
        loss, grad = grad_fn(u, data)
        u = u - lr * grad
        print(f"Iter {i} loss: {loss}, params: {jp.exp(u)}")
    # Reporter's note: the error does not occur if jit is disabled around
    # grad_fn, if the jax.vmap in simulate_batch is removed, or if one of the
    # two joints is removed. It seems to be related to jax issue #16991.
Version and system info
python 3.8.10; Ubuntu 20.04.4; GPU: Nvidia Quadro RTX 4000
brax==0.10.4
mujoco==3.2.3
mujoco-mjx==3.2.3
jax==0.4.13
jaxlib==0.4.13+cuda11.cudnn86
Metadata
Metadata
Assignees
Labels
No labels