Hi TorchRL devs,
I currently have the following setup (simplified for clarity):
sampler = SliceSampler(
    slice_len=4,
    end_key=None,
    traj_key=("collector", "traj_ids"),
    truncated_key=None,
    strict_length=True,
)
...some code...
frames_per_batch = num_envs * num_steps_per_env
collector = SyncDataCollectorWrapper(
    create_env_fn=env,
    policy=actor_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    init_random_frames=init_random_frames,
    exploration_type=ExplorationType.RANDOM,
    device=self.device,
)
...some code...
data = next(collector_iter)
self.replay_buffer.extend(data.reshape(-1))
batch = self.replay_buffer.sample()
# RuntimeError: Did not find a single trajectory with sufficient length (length range: 1 - 1 / required=4))

After spending some time investigating, I think the problem is that SliceSampler expects the traj_key entries to be laid out differently from how they are actually stored. Say we have num_envs=2 and num_steps_per_env=1. SliceSampler expects the data to be stored episodically, e.g. traj_key=[0,0,0,...,0,1,1,1,...,1], whereas in reality the data is stored in collection order with the environments interleaved: traj_key=[0,1,0,1,...,0,1].
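To make the layout I am describing concrete, here is a minimal pure-Python sketch. It is not TorchRL code and does not claim to reproduce how SliceSampler works internally. It rests on two assumptions of mine: each environment contributes one ongoing trajectory whose id equals the environment index, and a "trajectory" in the flat buffer is a run of consecutive entries sharing the same id. The helper names (`flatten_batch`, `buffer_traj_ids`, `contiguous_run_lengths`) are hypothetical.

```python
from itertools import groupby


def flatten_batch(num_envs, num_steps):
    """Traj ids of one [num_envs, num_steps] batch after a row-major reshape(-1).

    Assumption: env i always reports trajectory id i.
    """
    return [env_id for env_id in range(num_envs) for _ in range(num_steps)]


def buffer_traj_ids(num_envs, num_steps, num_batches):
    """Traj ids in the flat buffer after extending it with several batches."""
    ids = []
    for _ in range(num_batches):
        ids.extend(flatten_batch(num_envs, num_steps))
    return ids


def contiguous_run_lengths(traj_ids):
    """Length of every run of consecutive, identical trajectory ids."""
    return [len(list(group)) for _, group in groupby(traj_ids)]


# My case: 2 envs, 1 step per env per batch -> ids alternate in the buffer.
interleaved = buffer_traj_ids(num_envs=2, num_steps=1, num_batches=4)
print(interleaved)                          # [0, 1, 0, 1, 0, 1, 0, 1]
print(contiguous_run_lengths(interleaved))  # [1, 1, 1, 1, 1, 1, 1, 1]

# With 4 steps per env per batch, each run is long enough for slice_len=4.
longer = buffer_traj_ids(num_envs=2, num_steps=4, num_batches=2)
print(contiguous_run_lengths(longer))       # [4, 4, 4, 4]
```

Under these assumptions every contiguous run in my buffer has length 1, which would match the "length range: 1 - 1" in the error message, while collecting at least slice_len steps per environment per batch would give runs of sufficient length.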
Did I do something wrong here? Is there a way to work around this, or is it a bug that needs a patch?
My torchrl version is 0.8. Let me know if I need to provide more info. Thanks in advance!