Description
It seems that the method `is_valid_transition` of `OutOfGraphReplayBuffer` does not check whether the stacked frames come from a previous, truncated trajectory. In that case the index is invalid.
It only checks whether:
- the stacked images come from another, terminated trajectory:
  `if self.get_terminal_stack(index)[:-1].any():`
- the following observations (up to `update_horizon` steps ahead) come from the same trajectory and are not truncated:
  `if i in self.episode_end_indices and not self._store['terminal'][i]:`
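The two checks can be modelled outside Dopamine to see what they do and do not cover. The following is a minimal sketch, not Dopamine's code. `is_valid_current` is a hypothetical standalone function, and `terminal` is a plain list standing in for `self._store['terminal']`. The cursor and range checks that `is_valid_transition` also performs are omitted. `modulo_range` is assumed to behave like the helper of the same name in `circular_replay_buffer.py`.

```python
def modulo_range(start, length, modulo):
    """Yield `length` consecutive indices starting at `start`, wrapping at `modulo`."""
    for i in range(length):
        yield (start + i) % modulo


def is_valid_current(index, terminal, episode_end_indices,
                     stack_size, update_horizon, capacity):
    """Hypothetical model of the two checks described above (cursor checks omitted)."""
    # Check 1: a terminal flag in the earlier frames of the stack means the
    # state would mix two trajectories.
    stack = [terminal[j % capacity] for j in range(index - stack_size + 1, index + 1)]
    if any(stack[:-1]):
        return False
    # Check 2: a truncated episode end within the next `update_horizon` steps
    # means the n-step transition would run into another trajectory.
    for i in modulo_range(index, update_horizon, capacity):
        if i in episode_end_indices and not terminal[i]:
            return False
    return True


# Layout of the example below (capacity 10, stack_size 2):
# 0 = padding, 1 = obs [1] (truncated), 2 = padding, 3 = obs [2].
terminal = [0] * 10
episode_end_indices = {1}
print(is_valid_current(1, terminal, episode_end_indices, 2, 1, 10))  # False
print(is_valid_current(2, terminal, episode_end_indices, 2, 1, 10))  # True, yet the state is [1, 0]
```

Nothing in either check looks backwards for a truncation, so a state whose earlier frames end a truncated trajectory slips through.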
Here is a simple example where this is problematic:

```python
import numpy as np
from dopamine.replay_memory.circular_replay_buffer import OutOfGraphReplayBuffer

replay_buffer = OutOfGraphReplayBuffer(
    observation_shape=(1,), stack_size=2, replay_capacity=10, batch_size=1)
# First trajectory: a single step that is truncated (episode_end, not terminal).
replay_buffer.add(np.array([1]), 1, 1, False, episode_end=True)
# Second trajectory: a single step.
replay_buffer.add(np.array([2]), 2, 2, False)
print(replay_buffer._store["observation"][:4])
print(replay_buffer.sample_transition_batch())
```

```
>>> [[0], [1], [0], [2]]  # there is no valid index to sample.
>>> (array([[[1, 0]]], dtype=uint8), array([0], dtype=int32), array([0.], dtype=float32), array([[[0, 2]]], dtype=uint8), array([2], dtype=int32), array([2.], dtype=float32), array([0], dtype=uint8), array([2], dtype=int32))
```

Here, index 2 is considered valid even though it is not. The state `array([[[1, 0]]])` is composed of an observation from the previous trajectory, `[1]`, and a zero-padding frame from the new trajectory, `[0]`. Index 2 passes both checks for two reasons. First, index 1 is truncated rather than terminated. Second, the truncation loop only looks forward from index 2.
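To fix this bug, the truncation loop in `is_valid_transition`

```python
for i in modulo_range(index, self._update_horizon, self._replay_capacity):
```

could start at the first stacked frame instead of at `index`:

```python
for i in modulo_range(index - self._stack_size + 1,
                      self._update_horizon + self._stack_size - 1,
                      self._replay_capacity):
```

The length must also grow by `self._stack_size - 1`. With only `self._update_horizon` indices, the window would stop `self._stack_size - 1` steps short. Truncations at the end of the update horizon would then no longer be caught (with `update_horizon=1`, that is `index` itself). In the example above, index 1 would then become valid.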
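A minimal sketch of the patched check follows, in a hypothetical standalone form. It is not Dopamine's code: `is_valid_fixed` is a made-up name, `terminal` stands in for `self._store['terminal']`, and the cursor and range checks are omitted. The sketch assumes the truncation loop starts at `index - stack_size + 1` and spans `update_horizon + stack_size - 1` indices, so that `index` and its update horizon are still covered. Under that assumption, neither candidate index of the example is accepted. An index whose stack and horizon stay inside one trajectory is still accepted.

```python
def modulo_range(start, length, modulo):
    """Yield `length` consecutive indices starting at `start`, wrapping at `modulo`."""
    for i in range(length):
        yield (start + i) % modulo


def is_valid_fixed(index, terminal, episode_end_indices,
                   stack_size, update_horizon, capacity):
    """Hypothetical model of is_valid_transition with the proposed fix (cursor checks omitted)."""
    # Terminal flags in the earlier stacked frames: unchanged check.
    stack = [terminal[j % capacity] for j in range(index - stack_size + 1, index + 1)]
    if any(stack[:-1]):
        return False
    # Truncated episode ends are now looked for in the stacked frames
    # (index - stack_size + 1 .. index - 1) and in the update horizon
    # (index .. index + update_horizon - 1).
    for i in modulo_range(index - stack_size + 1,
                          update_horizon + stack_size - 1, capacity):
        if i in episode_end_indices and not terminal[i]:
            return False
    return True


# Layout of the example: 0 = padding, 1 = obs [1] (truncated), 2 = padding, 3 = obs [2].
terminal = [0] * 10
episode_end_indices = {1}
print([is_valid_fixed(i, terminal, episode_end_indices, 2, 1, 10) for i in (1, 2)])
# [False, False]

# A single trajectory at indices 1..4, truncated at 4: index 2 stays valid.
print(is_valid_fixed(2, [0] * 10, {4}, 2, 1, 10))  # True
```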