
[Feature Request] A simple (and effective?) way to support cherry-picked env reset in xla mode #293

@bkkgbkjb

Motivation

AFAIK (and as discussed in #194), it is currently not possible to cherry-pick terminated envs for reset in xla mode, because:

  1. The envs.reset() method doesn't accept an env_ids argument, unlike in sync mode. And even if it did:
  2. step_env() and send() seem to accept an env_ids argument, but it's hard to generate a dynamically shaped env_ids array in xla mode with jax. (I've run some preliminary experiments to verify this.)
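
The shape problem in point 2 can be reproduced without EnvPool at all. The following is a minimal sketch in plain JAX; the function names `dynamic_ids` and `padded_ids` are made up for illustration and are not part of any library. Under `jax.jit`, `jnp.nonzero` cannot return only the indices of finished envs, because the output length depends on the data. Passing a static `size` makes it traceable, but the result is then a padded, fixed-length array rather than a true list of env ids.

```python
import jax
import jax.numpy as jnp


def dynamic_ids(done):
    # Output length equals the number of True entries, which is only
    # known at run time, so this cannot be traced under jit.
    return jnp.nonzero(done)[0]


def padded_ids(done):
    # Static-shape variant: always returns done.shape[0] indices,
    # padding the unused slots with -1.
    return jnp.nonzero(done, size=done.shape[0], fill_value=-1)[0]


done = jnp.array([False, True, False, True])

try:
    jax.jit(dynamic_ids)(done)
    dynamic_traced = True
except jax.errors.ConcretizationTypeError:
    # jit rejects the data-dependent output shape.
    dynamic_traced = False

ids = jax.jit(padded_ids)(done)
```

The padded version still has to be interpreted by whoever consumes it (here, `-1` marks "no env"), which is essentially the same information as a boolean mask of length num_envs.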

The resulting problem is that incorrect transitions appear: (term_state, any_action, rew -> init_state),
as also pointed out in #194.

Solution

I think we could try adding a masked_env_ids argument to the reset() and send() methods.

Then we could do something like this

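import jax.numpy as jnp

obs, _ = envs.reset()
handle, recv, send, step = envs.xla()

while True:
    # some_acts: placeholder for the batch of actions chosen by the policy
    handle, (obs, rew, terms, truncs, info) = step(handle, some_acts)

    # proposed masked auto-resetting (masked_env_ids does not exist yet)
    auto_reset_masks = jnp.logical_or(terms, truncs)
    _obs, _ = envs.reset(masked_env_ids=auto_reset_masks)

    # broadcast the (num_envs,) mask over the trailing obs dimensions
    mask = auto_reset_masks.reshape((-1,) + (1,) * (obs.ndim - 1))
    obs = jnp.where(mask, _obs, obs)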

in xla mode

masked_env_ids would have a static shape of (num_envs,); it would reset only the envs whose mask entry is True and return dummy obs for the envs whose entry is False.
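
To make the intended semantics concrete, here is a small self-contained sketch of just the merge step, with no EnvPool involved. `select_reset` is a hypothetical helper, and `reset_obs` is a stand-in for whatever a masked reset would return (real initial obs for masked envs, dummy values elsewhere). Every array has a static shape, so the whole thing can be jitted.

```python
import jax
import jax.numpy as jnp


def select_reset(mask, reset_obs, obs):
    # Reshape the (num_envs,) mask to (num_envs, 1, ..., 1) so that it
    # broadcasts against observations with trailing dimensions.
    m = mask.reshape(mask.shape + (1,) * (obs.ndim - mask.ndim))
    # Take the reset observation where the mask is True, keep the
    # stepped observation everywhere else.
    return jnp.where(m, reset_obs, obs)


num_envs = 4
obs = jnp.arange(num_envs * 3, dtype=jnp.float32).reshape(num_envs, 3)
reset_obs = jnp.zeros_like(obs)  # stand-in for the masked reset output

terms = jnp.array([False, True, False, False])
truncs = jnp.array([False, False, False, True])
mask = jnp.logical_or(terms, truncs)

merged = jax.jit(select_reset)(mask, reset_obs, obs)
```

Envs 1 and 3 end up with the reset observation, while envs 0 and 2 keep the observation from the step. Since the dummy rows for unmasked envs are discarded by `jnp.where`, their contents would not matter.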

Alternative Methods

Currently, I'm working around this inconvenience by overwriting the wrong transitions with the previous correct ones. This should not make a significant difference for general algorithms.
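
One possible reading of this workaround, sketched under assumptions: transitions are stored as a dict of batched arrays, and `prev_done` marks the envs whose previous step ended an episode, so their current transition starts from a terminal state. The helper `patch_transitions` and the dict layout are hypothetical and only illustrate the idea.

```python
import jax
import jax.numpy as jnp


def patch_transitions(prev_done, prev_tr, cur_tr):
    # For envs where prev_done is True, the current transition is the
    # bogus (term_state, any_action, rew -> init_state) one, so it is
    # overwritten with the previous, valid transition.
    def pick(prev, cur):
        m = prev_done.reshape((-1,) + (1,) * (cur.ndim - 1))
        return jnp.where(m, prev, cur)

    return jax.tree_util.tree_map(pick, prev_tr, cur_tr)


prev_done = jnp.array([False, True])
prev_tr = {
    "obs": jnp.array([[1.0, 1.0], [2.0, 2.0]]),
    "rew": jnp.array([0.5, 1.0]),
}
cur_tr = {
    "obs": jnp.array([[3.0, 3.0], [4.0, 4.0]]),
    "rew": jnp.array([0.0, 0.0]),
}

patched = patch_transitions(prev_done, prev_tr, cur_tr)
```

This keeps all shapes static at the cost of duplicating one valid transition per episode boundary.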

But if the proposed solution is correct, I think it would be better to have it, for the sake of elegance.

Additional context

Unfortunately, I'm not an expert in C++, and I'm not sure whether the proposed solution, despite being simple, would work as expected.
But based on my understanding, it should be implementable as long as we perform it in the C++ processes.
