From 91662eed6931e6fccf384ae7e8821833c9080d13 Mon Sep 17 00:00:00 2001 From: Eduardo Pignatelli Date: Thu, 29 Dec 2022 18:54:41 +0000 Subject: [PATCH] Deprecate `Catch._observation` for `Catch.get_observation` --- bsuite/environments/catch.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/bsuite/environments/catch.py b/bsuite/environments/catch.py index 5f88f4f0..4d046999 100644 --- a/bsuite/environments/catch.py +++ b/bsuite/environments/catch.py @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================ """Catch reinforcement learning environment.""" +import warnings from typing import Optional @@ -65,6 +66,13 @@ def __init__(self, self._total_regret = 0. self.bsuite_num_episodes = sweep.NUM_EPISODES + def _get_observation(self): + self._board.fill(0.) + self._board[self._ball_y, self._ball_x] = 1. + self._board[self._paddle_y, self._paddle_x] = 1. + + return self._board.copy() + def _reset(self) -> dm_env.TimeStep: """Returns the first `TimeStep` of a new episode.""" self._reset_next_step = False @@ -107,11 +115,10 @@ def action_spec(self) -> specs.DiscreteArray: dtype=np.int, num_values=len(_ACTIONS), name="action") def _observation(self) -> np.ndarray: - self._board.fill(0.) - self._board[self._ball_y, self._ball_x] = 1. - self._board[self._paddle_y, self._paddle_x] = 1. - - return self._board.copy() + warnings.warn( + "Deprecated method `_observation`, use `_get_observation` instead." + ) + return self._get_observation() def bsuite_info(self): return dict(total_regret=self._total_regret)