diff --git a/gridworld/env.py b/gridworld/env.py index 0858a87..d4dce45 100644 --- a/gridworld/env.py +++ b/gridworld/env.py @@ -229,7 +229,13 @@ def reset(self): # blocks to remove have negative ids. '', target_grid=self._task.target_grid - self._synthetic_init_grid ) - self._synthetic_task.reset() + else: + self._synthetic_task = Task( + # create a synthetic task with only diff blocks. + # blocks to remove have negative ids. + '', target_grid=self._task.target_grid + ) + self._synthetic_task.reset() for block in set(self.world.placed): self.world.remove_block(block) @@ -286,8 +292,11 @@ def step(self, action): obs['dialog'] = self._task.chat if self.vector_state: obs['grid'] = self.grid.copy().astype(np.int32) - obs['agentPos'] = np.array([x, y, z, pitch, yaw], dtype=np.float32) - synthetic_grid = self.grid - self._synthetic_init_grid + obs['agentPos'] = np.array([x, y, z, pitch, yaw], dtype=np.float32) + if self._synthetic_init_grid is not None: + synthetic_grid = self.grid - self._synthetic_init_grid + else: + synthetic_grid = self.grid right_placement, wrong_placement, done = self._synthetic_task.step_intersection(synthetic_grid) done = done or (self.step_no == self.max_steps) if right_placement == 0: