Skip to content

Computing the gradient for system identification raises `jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: batch in most minor dimension` #623

@jongyaoY

Description

@jongyaoY

Description

I am trying to find the system parameters (damping, mass, etc.) that minimize the MSE between real and simulated trajectories. When I compute the gradient of the loss function, I get `jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: batch in most minor dimension`. Here is a minimal example that reproduces the error:

import functools

import jax
import jax.numpy as jp
import mujoco
import numpy as np
from brax.io.mjcf import load_model
from brax.mjx import pipeline
from mujoco import MjModel

# Debugging configuration for this reproduction script.
# Make JAX raise an error when a computation produces NaNs.
jax.config.update("jax_debug_nans", True)
# Uncomment to run everything eagerly (useful to check whether the failure
# only appears under jit):
# jax.config.update("jax_disable_jit", True)
# Show full, unfiltered tracebacks. ``jax_traceback_filtering`` is an enum
# option ("off", "tracebackhide", "remove_frames", ..., "auto"). The boolean
# ``False`` is not one of its values, and recent JAX releases reject it with
# a ValueError, so the explicit "off" value is used to disable filtering.
jax.config.update("jax_traceback_filtering", "off")


# MJCF model used by the reproduction.
# - ``link`` is a box with two joints (a hinge about z and a slide along x),
#   driven by one motor on the hinge.
# - ``obstacle`` declares no joints.
# The XML declaration has to be the very first thing in the document, so the
# string starts with it directly (no leading newline). It is a plain string,
# not an f-string, because nothing is interpolated.
mujoco_xml = """<?xml version="1.0" encoding="UTF-8"?>
<mujoco model="sysid">
    <compiler angle="radian"/>
    <option timestep="0.002"/>
    <custom>
        <numeric data="2" name="max_contact_points"/>
        <numeric data="2" name="max_geom_pairs"/>
    </custom>

    <worldbody>
        <body name="link" pos="0 0.5 0">
            <inertial pos="0 0 0"
                      mass="0.1"
                      diaginertia="0.08396 0.00125 0.08396"/>
            <joint name="hinge_z"
                   type="hinge"
                   axis="0 0 1"
                   pos="0 -0.5 0"
                   damping="2."/>
            <joint name="slide_x"
                   type="slide"
                   axis="1 0 0"
                   damping="100."/>

            <geom type="box"
                  size="0.05 0.5 0.05"
                  pos="0 0 0"
                  rgba="0.8 0.3 0.3 1"/>
        </body>

        <body name="obstacle" pos="0.8 0 0">
            <geom type="box"
                  size="0.1 0.1 0.1"
                  pos="0 0 0"
                  rgba="0.2 0.2 0.8 1"/>
        </body>
    </worldbody>

    <actuator>
        <motor name="m1"
               joint="hinge_z"
               ctrlrange="-1 1"/>
    </actuator>
</mujoco>
"""


@functools.partial(jax.jit, static_argnames=("body_id",))
def modify_params(params, sys, *, body_id=1):
    """Return a copy of ``sys`` with the identified parameters written in.

    Args:
        params: length-2 array. ``params[0]`` replaces ``sys.dof_damping[0]``
            (the first DoF, ``hinge_z`` in ``mujoco_xml``). ``params[1]``
            replaces ``sys.body_mass[body_id]``.
        sys: the system to update. It must expose ``dof_damping`` and
            ``body_mass`` arrays and a functional ``replace`` method.
        body_id: index of the body whose mass is identified. Static under
            jit. Defaults to 1, the ``link`` body of ``mujoco_xml``: index 0
            is the world body and index 2 is ``obstacle``. The previously
            hard-coded index 2 pointed at ``obstacle``, which has no joints,
            so changing its mass cannot affect the simulated motion of
            ``link``.

    Returns:
        The updated system. ``sys`` itself is not modified.
    """
    sys = sys.replace(
        dof_damping=sys.dof_damping.at[0].set(params[0]),
        body_mass=sys.body_mass.at[body_id].set(params[1]),
    )
    return sys


def simulate(
    sys,
    init_qpos,
    init_qvel,
    actions,
):
    """Roll out one trajectory with the brax MJX pipeline.

    Starts from ``(init_qpos, init_qvel)``. Scans ``actions`` over its
    leading axis and calls ``pipeline.step`` once per entry.

    Returns:
        ``(pred_qpos, pred_qvel)``: the stacked ``q`` and ``qd`` recorded
        *after* each step. Entry ``k`` is the state reached after applying
        ``actions[k]``; the initial state is not included.
    """

    def advance(state, ctrl):
        # To sub-step the physics several times per control, wrap the
        # ``pipeline.step`` call below in an inner ``jax.lax.scan``.
        new_state = pipeline.step(sys, state, ctrl)
        return new_state, (new_state.q, new_state.qd)

    start = pipeline.init(sys, init_qpos, init_qvel)
    _, (qpos_traj, qvel_traj) = jax.lax.scan(advance, start, actions)
    return qpos_traj, qvel_traj


def simulate_batch(
    sys,
    init_qpos_batch,
    init_qvel_batch,
    actions_batch,
):
    """Vectorised ``simulate`` over a batch of rollouts.

    ``sys`` is shared across the batch (not mapped). The initial positions,
    initial velocities and action sequences are mapped over their leading
    axis.
    """
    batched_rollout = jax.vmap(simulate, in_axes=(None, 0, 0, 0))
    return batched_rollout(sys, init_qpos_batch, init_qvel_batch, actions_batch)


def mse(y_batched, y_pred_batched):
    """Mean over the leading axis of half the squared error of each row.

    Each pair of rows ``(y, y_pred)`` contributes
    ``inner(y - y_pred, y - y_pred) / 2``. The result is averaged over axis 0.
    """

    def half_sq_norm(target, prediction):
        residual = target - prediction
        return jp.inner(residual, residual) / 2.0

    per_row = jax.vmap(half_sq_norm)(y_batched, y_pred_batched)
    return jp.mean(per_row, axis=0)


def log_param_transform(fn):
    """Reparameterise ``fn`` so that it takes log-space parameters.

    The returned function receives ``u`` and calls ``fn(exp(u), ...)``,
    forwarding all other positional and keyword arguments unchanged.
    Optimising over ``u`` keeps the values passed to ``fn`` strictly
    positive.
    """

    @functools.wraps(fn)
    def in_log_space(u, *args, **kwargs):
        return fn(jp.exp(u), *args, **kwargs)

    return in_log_space


@log_param_transform
def loss_fn(
    params,
    data,
    origin_sys,
    modify_fn,
    simulate_fn,
):
    """Scaled trajectory-matching loss for system identification.

    Because of ``log_param_transform``, callers pass the log of the
    parameters. Inside this body ``params`` holds the physical values.

    Args:
        params: physical parameters forwarded to ``modify_fn``.
        data: ``(gt_qpos, gt_qvel, actions)`` batched over the leading axis,
            laid out as ``simulate_batch`` returns them: ``gt_qpos[:, k]`` /
            ``gt_qvel[:, k]`` is the state reached *after* applying
            ``actions[:, k]``. At least two recorded steps are required.
        origin_sys: baseline system that ``modify_fn`` updates.
        modify_fn: ``(params, sys) -> sys``, writes ``params`` into the
            system.
        simulate_fn: batched rollout,
            ``(sys, qpos0, qvel0, actions) -> (qpos, qvel)``.

    Returns:
        ``100 *`` the batch mean of ``mse`` between recorded and predicted
        positions.
    """
    sys = modify_fn(params, origin_sys)
    gt_qpos, gt_qvel, actions = data
    # The recorded trajectory does not contain the true initial state: index 0
    # is already the state after actions[:, 0]. Restart from that state and
    # replay only the remaining actions, so that prediction k lines up with
    # recorded state k + 1. Replaying every action from index 0 would apply
    # actions[:, 0] twice and shift the prediction by one step.
    init_qpos = gt_qpos[:, 0]
    init_qvel = gt_qvel[:, 0]
    pred_qpos, _ = simulate_fn(
        sys,
        init_qpos,
        init_qvel,
        actions[:, 1:],
    )
    loss = jax.vmap(mse)(gt_qpos[:, 1:], pred_qpos)
    # Fixed scale factor on the mean loss (presumably chosen to give the plain
    # gradient step in the main loop usable gradient magnitudes).
    loss = 100.0 * jp.mean(loss)
    return loss


# Parameters used to generate the synthetic "ground-truth" trajectories.
# modify_params writes GT_PARAMS[0] into sys.dof_damping[0] and GT_PARAMS[1]
# into a body mass.
GT_PARAMS = jp.asarray([200.0, 0.3], dtype=jp.float32)

# Experiment / optimiser settings.
batch_size = 10  # number of rollouts per batch
ep_len = 20  # control steps per rollout
max_iter = 100  # gradient-descent iterations
lr = 1.0  # step size; the update is applied to the log-parameters

    
if __name__ == "__main__":
    # --- Build the MuJoCo model and configure it for differentiable MJX. ---
    mj_model = MjModel.from_xml_string(mujoco_xml)

    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
    # iterations > 1 will raise ValueError: Reverse-mode differentiation does
    # not work for lax.while_loop or lax.fori_loop with dynamic start/stop
    # values. See:
    # https://github.com/google-deepmind/mujoco/issues/1182#issuecomment-1823411911
    mj_model.opt.iterations = 1
    mj_model.opt.ls_iterations = 4
    mj_model.opt.integrator = mujoco.mjtIntegrator.mjINT_IMPLICITFAST
    # for fast implicit integrator need to set these to zero
    mj_model.opt.density = 0.0
    mj_model.opt.viscosity = 0.0
    mj_model.opt.wind = np.zeros(3, dtype=np.float32)
    # mjx doesn't support elliptic cones yet
    mj_model.opt.cone = mujoco.mjtCone.mjCONE_PYRAMIDAL

    mj_model.dof_armature = 1e-3
    sys = load_model(mj_model)

    # --- Generate "ground-truth" trajectories with the reference parameters. ---
    sys_gt = modify_params(GT_PARAMS, sys)
    init_qpos_batch = jp.zeros((batch_size, sys.q_size()), dtype=jp.float32)
    init_qvel_batch = jp.zeros((batch_size, sys.qd_size()), dtype=jp.float32)
    key = jax.random.PRNGKey(0)
    action_scale = 1.
    actions_batch = action_scale*jax.random.uniform(
        key,
        (batch_size, ep_len, sys.act_size()),
        minval=-1.0,
        maxval=1.0,
        dtype=jp.float32,
    )
    simulate_batch_jit = jax.jit(simulate_batch)
    qpos_batch, qvel_batch = simulate_batch_jit(
        sys_gt,
        init_qpos_batch,
        init_qvel_batch,
        actions_batch,
    )

    # --- Fit the parameters by gradient descent in log space. ---
    x0 = jp.array([250, 0.5], dtype=jp.float32)
    eps = 1e-8

    # Clamp non-positive initial guesses so the log below stays finite.
    x0 = jp.where(x0 <= 0, eps, x0)
    u0 = jp.log(x0)

    loss_fn_packed = functools.partial(
        loss_fn,
        modify_fn=modify_params,
        origin_sys=sys,
        simulate_fn=simulate_batch_jit,
    )

    # NOTE(review): the reported "UNIMPLEMENTED: batch in most minor
    # dimension" failure is raised while compiling/running this jitted
    # gradient. The issue links it to JAX issue 16991 (jaxlib 0.4.13 on GPU).
    grad_fn = jax.jit(jax.value_and_grad(loss_fn_packed, argnums=0))
    u = u0
    for i in range(max_iter):
        # value_and_grad already returns the loss together with the gradient,
        # so one call per iteration is enough. The extra call whose result
        # was discarded doubled the forward+backward work. To check whether
        # a failure is jit-specific, wrap the call in
        # ``with jax.disable_jit():``.
        loss, grad = grad_fn(u, (qpos_batch, qvel_batch, actions_batch))
        u = u - lr * grad
        print(f"Iter {i}  loss: {loss}, params: {jp.exp(u)}")

The error does not occur if I disable `jit` around `grad_fn`, remove the `jax.vmap` in `simulate_batch`, or remove one of the two joints. It seems to be related to jax/issues#16991.

Version and system info

Python 3.8.10; Ubuntu 20.04.4; GPU: NVIDIA Quadro RTX 4000

brax==0.10.4
mujoco==3.2.3
mujoco-mjx==3.2.3
jax==0.4.13
jaxlib==0.4.13+cuda11.cudnn86

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions