[BUG] SliceSampler doesn't work as expected when collecting data from parallel environment? #3194

@briannnyee

Hi TorchRL devs,

I currently have the following setup (simplified for clarity):

sampler = SliceSampler(
    slice_len=4,
    end_key=None,
    traj_key=("collector", "traj_ids"),
    truncated_key=None,
    strict_length=True,
)

...some code...

frames_per_batch = num_envs * num_steps_per_env
collector = SyncDataCollectorWrapper(
    create_env_fn=env,
    policy=actor_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    init_random_frames=init_random_frames,
    exploration_type=ExplorationType.RANDOM,
    device=self.device,
)

...some code...

data = next(collector_iter)
self.replay_buffer.extend(data.reshape(-1))

batch = self.replay_buffer.sample()
# RuntimeError: Did not find a single trajectory with sufficient length (length range: 1 - 1 / required=4))

After spending some time investigating this, I suspect the problem is that SliceSampler expects a different layout of traj_key than the one that is actually stored. Say we have num_envs=2 and num_steps_per_env=1. SliceSampler seems to expect the data to be stored episodically, with each trajectory in one contiguous block, e.g. traj_key=[0,0,0,...,0,1,1,1,...,1]. In reality, the data is stored in collection order, interleaved across environments: traj_key=[0,1,0,1,...,0,1].
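To make the layout difference concrete, here is a minimal pure-Python sketch. It is not TorchRL's implementation; the helper `contiguous_run_lengths` is hypothetical, and it assumes the sampler treats each contiguous run of equal trajectory ids as one trajectory. Under that assumption, the interleaved layout only ever yields runs of length 1, which would match the "length range: 1 - 1" in the error above.

```python
from itertools import groupby


def contiguous_run_lengths(traj_ids):
    """Lengths of consecutive runs of equal trajectory ids, in storage order."""
    return [len(list(group)) for _, group in groupby(traj_ids)]


# Same numbers as in the example: 2 envs, 1 step per env per batch.
num_envs, num_batches = 2, 4

# Layout after extending the buffer with one flattened batch at a time.
interleaved = [env for _ in range(num_batches) for env in range(num_envs)]

# Layout where each trajectory occupies one contiguous block.
episodic = [env for env in range(num_envs) for _ in range(num_batches)]

interleaved_runs = contiguous_run_lengths(interleaved)
episodic_runs = contiguous_run_lengths(episodic)

print(interleaved, interleaved_runs)
print(episodic, episodic_runs)
```

With slice_len=4, only the episodic layout contains a run long enough to slice.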

Did I do something wrong here? If not, is there a workaround, or is this a bug that needs a patch?

My torchrl version is 0.8. Let me know if I need to provide more info. Thanks in advance!
