From 0f2ccef8f1cff18a13a1bc168d5303646482f1e0 Mon Sep 17 00:00:00 2001 From: A-lex-Ra Date: Thu, 3 Oct 2024 20:26:16 +0300 Subject: [PATCH] fix error of NoneType _synthetic_init_grid --- gridworld/env.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/gridworld/env.py b/gridworld/env.py index 0858a87..d4dce45 100644 --- a/gridworld/env.py +++ b/gridworld/env.py @@ -229,7 +229,13 @@ def reset(self): # blocks to remove have negative ids. '', target_grid=self._task.target_grid - self._synthetic_init_grid ) - self._synthetic_task.reset() + else: + self._synthetic_task = Task( + # create a synthetic task with only diff blocks. + # blocks to remove have negative ids. + '', target_grid=self._task.target_grid + ) + self._synthetic_task.reset() for block in set(self.world.placed): self.world.remove_block(block) @@ -286,8 +292,11 @@ def step(self, action): obs['dialog'] = self._task.chat if self.vector_state: obs['grid'] = self.grid.copy().astype(np.int32) - obs['agentPos'] = np.array([x, y, z, pitch, yaw], dtype=np.float32) - synthetic_grid = self.grid - self._synthetic_init_grid + obs['agentPos'] = np.array([x, y, z, pitch, yaw], dtype=np.float32) + if self._synthetic_init_grid is not None: + synthetic_grid = self.grid - self._synthetic_init_grid + else: + synthetic_grid = self.grid right_placement, wrong_placement, done = self._synthetic_task.step_intersection(synthetic_grid) done = done or (self.step_no == self.max_steps) if right_placement == 0: