Motivation
As far as I know (see also #194), it is currently not possible to selectively reset only the terminated envs in xla mode, because:
- The envs.reset() method doesn't accept an env_ids argument, unlike in sync mode.
- Even if it did (step_env() and send() do seem to accept env_ids), it is hard to build a dynamically shaped env_ids array in xla mode under jax. (I ran some preliminary experiments to verify this.)
As a result, incorrect transitions appear: (term_state, any_action, rew -> init_state), as also pointed out in #194.
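To illustrate the dynamic-shape point, here is a minimal JAX-only sketch (it does not use envpool). The helper names `done_env_ids_dynamic` and `done_env_ids_padded` are made up for this example. Extracting the indices of finished envs works eagerly but fails under `jax.jit`, because the output length depends on the data. Passing a static `size` to `jnp.nonzero` avoids the error, but only by padding the result.

```python
import jax
import jax.numpy as jnp


def done_env_ids_dynamic(done):
    # Output length equals the number of True entries, so the shape is data-dependent.
    return jnp.nonzero(done)[0]


def done_env_ids_padded(done):
    # Static-shape variant: always returns done.shape[0] ids, padded with -1.
    return jnp.nonzero(done, size=done.shape[0], fill_value=-1)[0]


done = jnp.array([False, True, False, True])

# Eager execution is fine: the concrete values are known.
eager_ids = done_env_ids_dynamic(done)

# Under jit the shape cannot be determined at trace time.
try:
    jax.jit(done_env_ids_dynamic)(done)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# The padded variant traces, but the caller has to handle the -1 entries.
padded_ids = jax.jit(done_env_ids_padded)(done)
```

The padded ids would still need special handling on the receiving side, which is part of why a fixed-size boolean mask looks simpler here.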
Solution
I think we could try adding a masked_env_ids argument to the reset() and send() methods, so that we could do something like this:
import jax.numpy as jnp

obs, _ = envs.reset()
handle, recv, send, step = envs.xla()
while True:
    handle, (obs, rew, terms, truncs, info) = step(handle, some_acts)
    # proposed masked auto-resetting (masked_env_ids does not exist yet)
    auto_reset_masks = jnp.logical_or(terms, truncs)
    _obs, _ = envs.reset(masked_env_ids=auto_reset_masks)
    # the mask may need extra trailing axes to broadcast against obs
    obs = jnp.where(auto_reset_masks, _obs, obs)
In xla mode, masked_env_ids would have a static shape of (env_nums,). Only the envs whose mask entry is True would be reset, and dummy obs would be returned for the False ones.
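The merge step of the proposal can be sketched in plain JAX. This assumes a masked reset would return a full-size batch with dummy rows for the envs that were not reset. `merge_reset_obs` is a hypothetical helper and `reset_obs` is a stand-in array, not the output of any existing envpool call. The mask is reshaped so that it broadcasts over observations of any rank.

```python
import jax
import jax.numpy as jnp


@jax.jit
def merge_reset_obs(reset_mask, reset_obs, obs):
    # Expand the (num_envs,) mask to (num_envs, 1, ..., 1) so it broadcasts
    # against obs of shape (num_envs, *obs_shape).
    mask = reset_mask.reshape(reset_mask.shape + (1,) * (obs.ndim - 1))
    # Take the fresh observation where the env was reset, keep the old one elsewhere.
    return jnp.where(mask, reset_obs, obs)


terms = jnp.array([False, True, False])
truncs = jnp.array([False, False, True])
obs = jnp.ones((3, 2))
# Stand-in for what the proposed masked reset would return (dummy rows included).
reset_obs = jnp.zeros((3, 2))

merged = merge_reset_obs(jnp.logical_or(terms, truncs), reset_obs, obs)
```

All shapes here are static, so the function traces under `jax.jit` without the dynamic-shape problem described in the motivation.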
Alternative Methods
Currently, I work around this by overwriting the wrong transitions with the previous correct ones. This should not make a significant difference for most algorithms.
But if the proposed solution is sound, I think it would be cleaner to have it.
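One way such a workaround could look is sketched below. This is an assumption about the general idea, not the exact code I use. A transition is treated as invalid when the previous step of the same env was terminal, and it is replaced by that env's previous transition. `patch_transition` and the dict layout of a transition are hypothetical.

```python
import jax
import jax.numpy as jnp


def patch_transition(prev_done, prev_transition, transition):
    """Replace transitions that start from a terminal state with the previous ones."""

    def select(prev_leaf, leaf):
        # Broadcast the (num_envs,) flag over the trailing dims of each field.
        mask = prev_done.reshape(prev_done.shape + (1,) * (leaf.ndim - 1))
        return jnp.where(mask, prev_leaf, leaf)

    # Apply the same per-env selection to every field of the transition.
    return jax.tree_util.tree_map(select, prev_transition, transition)


prev_done = jnp.array([False, True])
prev_transition = {"obs": jnp.array([[1.0], [2.0]]), "rew": jnp.array([0.5, 1.0])}
transition = {"obs": jnp.array([[3.0], [4.0]]), "rew": jnp.array([0.25, 2.0])}

patched = patch_transition(prev_done, prev_transition, transition)
```

The duplicated transition slightly over-weights the last valid step of each episode, which is the trade-off accepted by this workaround.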
Additional context
Unfortunately, I'm not an expert in C++, and I'm not sure whether the proposed solution, despite being simple, would work as expected.
But based on my understanding, it should be implementable as long as we perform it in the C++ processes.
Checklist