From d4159f5869a87fefbbe7e31640d0cad97d25433e Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Mon, 26 Sep 2022 19:02:50 +0530 Subject: [PATCH 1/6] Add high level changes to the algorithm --- .../algorithms/adversarial/common.py | 45 ++++-- .../policies/replay_buffer_wrapper.py | 140 +++++++++++++++++- 2 files changed, 169 insertions(+), 16 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index eee6937e4..a78831eba 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -16,6 +16,7 @@ from imitation.algorithms import base from imitation.data import buffer, rollout, types, wrappers +from imitation.policies import replay_buffer_wrapper from imitation.rewards import reward_nets, reward_wrapper from imitation.util import logger, networks, util @@ -357,6 +358,29 @@ def train_disc( return train_stats + def collect_rollouts( + self, + total_timesteps: Optional[int] = None, + callback: MaybeCallback = None, + ): + """Collect rollouts. + + Args: + total_timesteps: The number of transitions to sample from + `self.venv_train` during training. By default, + `self.gen_train_timesteps`. + """ + if total_timesteps is None: + total_timesteps = self.gen_train_timesteps + + # NOTE (Taufeeque): call setup_learn or not? + self.gen_algo.collect_rollouts( + self.gen_algo.env, + callback, + self.gen_algo.rollout_buffer, + n_rollout_steps=total_timesteps, + ) + def train_gen( self, total_timesteps: Optional[int] = None, @@ -368,24 +392,20 @@ def train_gen( discriminator training) with `self.disc_batch_size` transitions. Args: - total_timesteps: The number of transitions to sample from - `self.venv_train` during training. By default, - `self.gen_train_timesteps`. learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()` method. """ - if total_timesteps is None: - total_timesteps = self.gen_train_timesteps if learn_kwargs is None: learn_kwargs = {} with self.logger.accumulate_means("gen"): - self.gen_algo.learn( - total_timesteps=total_timesteps, - reset_num_timesteps=False, - callback=self.gen_callback, - **learn_kwargs, - ) + # self.gen_algo.learn( + # total_timesteps=total_timesteps, + # reset_num_timesteps=False, + # callback=self.gen_callback, + # **learn_kwargs, + # ) + self.gen_algo.train() self._global_step += 1 gen_trajs, ep_lens = self.venv_buffering.pop_trajectories() @@ -420,11 +440,12 @@ def train( f"total_timesteps={total_timesteps})!" 
) for r in tqdm.tqdm(range(0, n_rounds), desc="round"): - self.train_gen(self.gen_train_timesteps) + self.collect_rollouts(self.gen_train_timesteps) for _ in range(self.n_disc_updates_per_round): with networks.training(self.reward_train): # switch to training mode (affects dropout, normalization) self.train_disc() + self.train_gen() if callback: callback(r) self.logger.dump(self._global_step) diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index 9bb011063..c33a23322 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -5,14 +5,39 @@ import numpy as np from gym import spaces -from stable_baselines3.common.buffers import BaseBuffer, ReplayBuffer -from stable_baselines3.common.type_aliases import ReplayBufferSamples +from stable_baselines3.common.buffers import BaseBuffer, ReplayBuffer, RolloutBuffer +from stable_baselines3.common.type_aliases import ( + ReplayBufferSamples, + RolloutBufferSamples, +) from imitation.rewards.reward_function import RewardFn from imitation.util import util -def _samples_to_reward_fn_input( +class RolloutBufferMod(RolloutBuffer): + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + super().__init__( + buffer_size, + observation_space, + action_space, + device, + gae_lambda, + gamma, + n_envs, + ) + + +def _replay_samples_to_reward_fn_input( samples: ReplayBufferSamples, ) -> Mapping[str, np.ndarray]: """Convert a sample from a replay buffer to a numpy array.""" @@ -24,6 +49,18 @@ def _samples_to_reward_fn_input( ) +def _rollout_samples_to_reward_fn_input( + buffer: RolloutBuffer, +) -> Mapping[str, np.ndarray]: + """Convert a sample from a rollout buffer to a numpy array.""" + return dict( + state=buffer.observations, + action=buffer.actions, + next_state=None, + done=None, + ) + + class ReplayBufferRewardWrapper(BaseBuffer): """Relabel the rewards in transitions sampled from a ReplayBuffer.""" @@ -81,7 +118,7 @@ def full(self, full: bool): def sample(self, *args, **kwargs): samples = self.replay_buffer.sample(*args, **kwargs) - rewards = self.reward_fn(**_samples_to_reward_fn_input(samples)) + rewards = self.reward_fn(**_replay_samples_to_reward_fn_input(samples)) shape = samples.rewards.shape device = samples.rewards.device rewards_th = util.safe_to_tensor(rewards).reshape(shape).to(device) @@ -102,3 +139,98 @@ def _get_samples(self): "_get_samples() is intentionally not implemented." "This method should not be called.", ) + + +class RolloutBufferRewardWrapper(BaseBuffer): + """Relabel the rewards in transitions sampled from a RolloutBuffer.""" + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + *, + rollout_buffer_class: Type[RolloutBuffer], + reward_fn: RewardFn, + **kwargs, + ): + """Builds RolloutBufferRewardWrapper. + + Args: + buffer_size: Max number of elements in the buffer + observation_space: Observation space + action_space: Action space + replay_buffer_class: Class of the replay buffer. + reward_fn: Reward function for reward relabeling. + **kwargs: keyword arguments for ReplayBuffer. 
+ """ + # Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of + # DictReplayBuffer because the current RewardFn only takes in NumPy array-based + # inputs, and GAIL/AIRL is the only use case for RolloutBuffer relabeling. See: + # https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194 + assert rollout_buffer_class is RolloutBuffer, "only RolloutBuffer is supported" + assert not isinstance(observation_space, spaces.Dict) + self.rollout_buffer = rollout_buffer_class( + buffer_size, + observation_space, + action_space, + **kwargs, + ) + self.reward_fn = reward_fn + _base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]} + super().__init__(buffer_size, observation_space, action_space, **_base_kwargs) + + @property + def pos(self) -> int: + return self.rollout_buffer.pos + + @pos.setter + def pos(self, pos: int): + self.rollout_buffer.pos = pos + + @property + def full(self) -> bool: + return self.rollout_buffer.full + + @full.setter + def full(self, full: bool): + self.rollout_buffer.full = full + + # def sample(self, *args, **kwargs): + # samples = self.rollout_buffer.sample(*args, **kwargs) + # rewards = self.reward_fn(**_replay_samples_to_reward_fn_input(samples)) + # shape = samples.rewards.shape + # device = samples.rewards.device + # rewards_th = util.safe_to_tensor(rewards).reshape(shape).to(device) + + # return RolloutBufferSamples( + # samples.observations, + # samples.actions, + # samples.next_observations, + # samples.dones, + # rewards_th, + # ) + + def get(self, *args, **kwargs): + + rewards = self.reward_fn(**_rollout_samples_to_reward_fn_input(samples)) + shape = samples.rewards.shape + device = samples.rewards.device + rewards_th = util.safe_to_tensor(rewards).reshape(shape).to(device) + + return RolloutBufferSamples( + samples.observations, + samples.actions, + samples.next_observations, + samples.dones, + rewards_th, + ) + + def add(self, *args, **kwargs): + self.rollout_buffer.add(*args, **kwargs) + + def _get_samples(self): + raise NotImplementedError( + "_get_samples() is intentionally not implemented." 
+ "This method should not be called.", + ) From 8623e36e295b5587909d3f3bdde76d553ee8524f Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Tue, 27 Sep 2022 02:20:08 +0530 Subject: [PATCH 2/6] Add minor changes --- src/imitation/algorithms/adversarial/common.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index a78831eba..580f82b6c 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -380,6 +380,10 @@ def collect_rollouts( self.gen_algo.rollout_buffer, n_rollout_steps=total_timesteps, ) + gen_trajs, ep_lens = self.venv_buffering.pop_trajectories() + self._check_fixed_horizon(ep_lens) + gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs) + self._gen_replay_buffer.store(gen_samples) def train_gen( self, @@ -408,11 +412,6 @@ def train_gen( self.gen_algo.train() self._global_step += 1 - gen_trajs, ep_lens = self.venv_buffering.pop_trajectories() - self._check_fixed_horizon(ep_lens) - gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs) - self._gen_replay_buffer.store(gen_samples) - def train( self, total_timesteps: int, From 34b52ff6d0ab4778fc33581b0979765daf832a3d Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Thu, 29 Sep 2022 01:14:08 +0530 Subject: [PATCH 3/6] Add hacky workaround to implement reference paper's adversarial algo --- .../algorithms/adversarial/common.py | 77 ++++++++++++++++--- .../policies/replay_buffer_wrapper.py | 76 +++++------------- 2 files changed, 86 insertions(+), 67 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 580f82b6c..787bbb935 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -10,7 +10,10 @@ import torch as th import torch.utils.tensorboard as thboard import tqdm -from stable_baselines3.common import base_class, policies, vec_env +from stable_baselines3.common import base_class, on_policy_algorithm, policies +from stable_baselines3.common import utils as sb3_utils +from stable_baselines3.common import vec_env +from stable_baselines3.common.type_aliases import MaybeCallback from stable_baselines3.sac import policies as sac_policies from torch.nn import functional as F @@ -232,6 +235,22 @@ def __init__( else: self.gen_train_timesteps = gen_train_timesteps + if type(self.gen_algo) is on_policy_algorithm.OnPolicyAlgorithm: + rollout_buffer = self.gen_algo.rollout_buffer + self.gen_algo.rollout_buffer = ( + replay_buffer_wrapper.RolloutBufferRewardWrapper( + buffer_size=self.gen_train_timesteps // rollout_buffer.n_envs, + observation_space=rollout_buffer.observation_space, + action_space=rollout_buffer.action_space, + rollout_buffer_class=rollout_buffer.__class__, + reward_fn=self.reward_train.predict_processed, + device=rollout_buffer.device, + gae_lambda=rollout_buffer.gae_lambda, + gamma=rollout_buffer.gamma, + n_envs=rollout_buffer.n_envs, + ) + ) + if gen_replay_buffer_capacity is None: gen_replay_buffer_capacity = self.gen_train_timesteps self._gen_replay_buffer = buffer.ReplayBuffer( @@ -362,6 +381,7 @@ def collect_rollouts( self, total_timesteps: Optional[int] = None, callback: MaybeCallback = None, + learn_kwargs: Optional[Mapping] = None, ): """Collect rollouts. 
@@ -369,39 +389,74 @@ def collect_rollouts( total_timesteps: The number of transitions to sample from `self.venv_train` during training. By default, `self.gen_train_timesteps`. + callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()` + method. """ + if learn_kwargs is None: + learn_kwargs = {} + if total_timesteps is None: total_timesteps = self.gen_train_timesteps + # total timesteps should be per env + total_timesteps = total_timesteps // self.gen_algo.n_envs # NOTE (Taufeeque): call setup_learn or not? + if "eval_env" not in learn_kwargs: + total_timesteps, callback = self.gen_algo._setup_learn( + total_timesteps, + eval_env=None, + callback=callback, + **learn_kwargs, + ) + else: + total_timesteps, callback = self.gen_algo._setup_learn( + total_timesteps, + callback=callback, + **learn_kwargs, + ) + callback.on_training_start(locals(), globals()) self.gen_algo.collect_rollouts( self.gen_algo.env, callback, self.gen_algo.rollout_buffer, n_rollout_steps=total_timesteps, ) + if ( + len(self.gen_algo.ep_info_buffer) > 0 + and len(self.gen_algo.ep_info_buffer[0]) > 0 + ): + self.logger.record( + "rollout/ep_rew_mean", + sb3_utils.safe_mean( + [ep_info["r"] for ep_info in self.gen_algo.ep_info_buffer] + ), + ) + self.logger.record( + "rollout/ep_len_mean", + sb3_utils.safe_mean( + [ep_info["l"] for ep_info in self.gen_algo.ep_info_buffer] + ), + ) + self.logger.record( + "time/total_timesteps", self.gen_algo.num_timesteps, exclude="tensorboard" + ) + gen_trajs, ep_lens = self.venv_buffering.pop_trajectories() self._check_fixed_horizon(ep_lens) gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs) self._gen_replay_buffer.store(gen_samples) + callback.on_training_end() def train_gen( self, - total_timesteps: Optional[int] = None, - learn_kwargs: Optional[Mapping] = None, ) -> None: """Trains the generator to maximize the discriminator loss. After the end of training populates the generator replay buffer (used in discriminator training) with `self.disc_batch_size` transitions. - - Args: - learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()` - method. """ - if learn_kwargs is None: - learn_kwargs = {} - with self.logger.accumulate_means("gen"): # self.gen_algo.learn( # total_timesteps=total_timesteps, @@ -439,7 +494,7 @@ def train( f"total_timesteps={total_timesteps})!" 
) for r in tqdm.tqdm(range(0, n_rounds), desc="round"): - self.collect_rollouts(self.gen_train_timesteps) + self.collect_rollouts(self.gen_train_timesteps, self.gen_callback) for _ in range(self.n_disc_updates_per_round): with networks.training(self.reward_train): # switch to training mode (affects dropout, normalization) diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index c33a23322..df34eb335 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -4,39 +4,15 @@ from typing import Mapping, Type import numpy as np +import torch as th from gym import spaces from stable_baselines3.common.buffers import BaseBuffer, ReplayBuffer, RolloutBuffer -from stable_baselines3.common.type_aliases import ( - ReplayBufferSamples, - RolloutBufferSamples, -) +from stable_baselines3.common.type_aliases import ReplayBufferSamples from imitation.rewards.reward_function import RewardFn from imitation.util import util -class RolloutBufferMod(RolloutBuffer): - def __init__( - self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - device: Union[th.device, str] = "auto", - gae_lambda: float = 1, - gamma: float = 0.99, - n_envs: int = 1, - ): - super().__init__( - buffer_size, - observation_space, - action_space, - device, - gae_lambda, - gamma, - n_envs, - ) - - def _replay_samples_to_reward_fn_input( samples: ReplayBufferSamples, ) -> Mapping[str, np.ndarray]: @@ -160,12 +136,12 @@ def __init__( buffer_size: Max number of elements in the buffer observation_space: Observation space action_space: Action space - replay_buffer_class: Class of the replay buffer. + rollout_buffer_class: Class of the rollout buffer. reward_fn: Reward function for reward relabeling. - **kwargs: keyword arguments for ReplayBuffer. + **kwargs: keyword arguments for RolloutBuffer. """ - # Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of - # DictReplayBuffer because the current RewardFn only takes in NumPy array-based + # Note(yawen-d): we directly inherit RolloutBuffer and leave out the case of + # DictRolloutBuffer because the current RewardFn only takes in NumPy array-based # inputs, and GAIL/AIRL is the only use case for RolloutBuffer relabeling. 
See: # https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194 assert rollout_buffer_class is RolloutBuffer, "only RolloutBuffer is supported" @@ -196,35 +172,16 @@ def full(self) -> bool: def full(self, full: bool): self.rollout_buffer.full = full - # def sample(self, *args, **kwargs): - # samples = self.rollout_buffer.sample(*args, **kwargs) - # rewards = self.reward_fn(**_replay_samples_to_reward_fn_input(samples)) - # shape = samples.rewards.shape - # device = samples.rewards.device - # rewards_th = util.safe_to_tensor(rewards).reshape(shape).to(device) - - # return RolloutBufferSamples( - # samples.observations, - # samples.actions, - # samples.next_observations, - # samples.dones, - # rewards_th, - # ) + def reset(self): + self.rollout_buffer.reset() def get(self, *args, **kwargs): - - rewards = self.reward_fn(**_rollout_samples_to_reward_fn_input(samples)) - shape = samples.rewards.shape - device = samples.rewards.device - rewards_th = util.safe_to_tensor(rewards).reshape(shape).to(device) - - return RolloutBufferSamples( - samples.observations, - samples.actions, - samples.next_observations, - samples.dones, - rewards_th, + self.rollout_buffer.rewards = self.reward_fn( + **_rollout_samples_to_reward_fn_input(self.rollout_buffer), ) + self.rollout_buffer.compute_returns_and_advantage(self.last_values, self.dones) + + return self.rollout_buffer.get(*args, **kwargs) def add(self, *args, **kwargs): self.rollout_buffer.add(*args, **kwargs) @@ -234,3 +191,10 @@ def _get_samples(self): "_get_samples() is intentionally not implemented." "This method should not be called.", ) + + def compute_returns_and_advantage( + self, last_values: th.Tensor, dones: np.ndarray + ) -> None: + self.last_values = last_values + self.last_dones = dones + self.rollout_buffer.compute_returns_and_advantage(last_values, dones) From e783e2f9ea5d62495e73522e019538c0ca602b95 Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Sat, 19 Nov 2022 16:38:23 +0530 Subject: [PATCH 4/6] Fix bug and add support for Off Policy RL --- .../algorithms/adversarial/common.py | 103 +++++++++++++----- .../policies/replay_buffer_wrapper.py | 5 +- 2 files changed, 79 insertions(+), 29 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 2e3e51e0a..55e6610a8 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -254,7 +254,10 @@ def __init__( else: self.gen_train_timesteps = gen_train_timesteps - if type(self.gen_algo) is on_policy_algorithm.OnPolicyAlgorithm: + self.is_gen_on_policy = isinstance( + self.gen_algo, on_policy_algorithm.OnPolicyAlgorithm + ) + if self.is_gen_on_policy: rollout_buffer = self.gen_algo.rollout_buffer self.gen_algo.rollout_buffer = ( replay_buffer_wrapper.RolloutBufferRewardWrapper( @@ -269,6 +272,19 @@ def __init__( n_envs=rollout_buffer.n_envs, ) ) + else: + replay_buffer = self.gen_algo.replay_buffer + self.gen_algo.replay_buffer = ( + replay_buffer_wrapper.ReplayBufferRewardWrapper( + buffer_size=self.gen_train_timesteps, + observation_space=replay_buffer.observation_space, + action_space=replay_buffer.action_space, + replay_buffer_class=replay_buffer.__class__, + reward_fn=self.reward_train.predict_processed, + device=replay_buffer.device, + n_envs=replay_buffer.n_envs, + ) + ) if gen_replay_buffer_capacity is None: gen_replay_buffer_capacity = self.gen_train_timesteps @@ -446,40 +462,61 @@ def collect_rollouts( **learn_kwargs, 
) callback.on_training_start(locals(), globals()) - self.gen_algo.collect_rollouts( - self.gen_algo.env, - callback, - self.gen_algo.rollout_buffer, - n_rollout_steps=total_timesteps, - ) - if ( - len(self.gen_algo.ep_info_buffer) > 0 - and len(self.gen_algo.ep_info_buffer[0]) > 0 - ): - self.logger.record( - "rollout/ep_rew_mean", - sb3_utils.safe_mean( - [ep_info["r"] for ep_info in self.gen_algo.ep_info_buffer] - ), + if self.is_gen_on_policy: + self.gen_algo.collect_rollouts( + self.gen_algo.env, + callback, + self.gen_algo.rollout_buffer, + n_rollout_steps=total_timesteps, ) - self.logger.record( - "rollout/ep_len_mean", - sb3_utils.safe_mean( - [ep_info["l"] for ep_info in self.gen_algo.ep_info_buffer] - ), + rollouts = None + else: + self.gen_algo.train_freq = total_timesteps + self.gen_algo._convert_train_freq() + rollouts = self.gen_algo.collect_rollouts( + self.gen_algo.env, + train_freq=self.gen_algo.train_freq, + action_noise=self.gen_algo.action_noise, + callback=callback, + learning_starts=self.gen_algo.learning_starts, + replay_buffer=self.gen_algo.replay_buffer, ) - self.logger.record( - "time/total_timesteps", self.gen_algo.num_timesteps, exclude="tensorboard" - ) + + if self.is_gen_on_policy: + if ( + len(self.gen_algo.ep_info_buffer) > 0 + and len(self.gen_algo.ep_info_buffer[0]) > 0 + ): + self.logger.record( + "rollout/ep_rew_mean", + sb3_utils.safe_mean( + [ep_info["r"] for ep_info in self.gen_algo.ep_info_buffer] + ), + ) + self.logger.record( + "rollout/ep_len_mean", + sb3_utils.safe_mean( + [ep_info["l"] for ep_info in self.gen_algo.ep_info_buffer] + ), + ) + self.logger.record( + "time/total_timesteps", + self.gen_algo.num_timesteps, + exclude="tensorboard", + ) + else: + self.gen_algo._dump_logs() gen_trajs, ep_lens = self.venv_buffering.pop_trajectories() self._check_fixed_horizon(ep_lens) gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs) self._gen_replay_buffer.store(gen_samples) callback.on_training_end() + return rollouts def train_gen( self, + rollouts, ) -> None: """Trains the generator to maximize the discriminator loss. @@ -493,7 +530,17 @@ def train_gen( # callback=self.gen_callback, # **learn_kwargs, # ) - self.gen_algo.train() + if self.is_gen_on_policy: + self.gen_algo.train() + else: + if self.gen_algo.gradient_steps >= 0: + gradient_steps = self.gen_algo.gradient_steps + else: + gradient_steps = rollouts.episode_timesteps + self.gen_algo.train( + batch_size=self.gen_algo.batch_size, + gradient_steps=gradient_steps, + ) self._global_step += 1 def train( @@ -523,12 +570,14 @@ def train( f"total_timesteps={total_timesteps})!" 
) for r in tqdm.tqdm(range(0, n_rounds), desc="round"): - self.collect_rollouts(self.gen_train_timesteps, self.gen_callback) + rollouts = self.collect_rollouts( + self.gen_train_timesteps, self.gen_callback + ) for _ in range(self.n_disc_updates_per_round): with networks.training(self.reward_train): # switch to training mode (affects dropout, normalization) self.train_disc() - self.train_gen() + self.train_gen(rollouts) if callback: callback(r) self.logger.dump(self._global_step) diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index 0190db89a..f66834d04 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -28,11 +28,12 @@ def _rollout_samples_to_reward_fn_input( buffer: RolloutBuffer, ) -> Mapping[str, np.ndarray]: """Convert a sample from a rollout buffer to a numpy array.""" + shape = buffer.observations.shape return dict( state=buffer.observations, action=buffer.actions, - next_state=None, - done=None, + next_state=np.full(shape, np.nan), + done=np.full(shape, np.nan), ) From 4008b620bf155a5e8d71316d51e883c61b07a20a Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Sun, 20 Nov 2022 18:50:09 +0530 Subject: [PATCH 5/6] Add changes --- .../algorithms/adversarial/common.py | 1 + .../policies/replay_buffer_wrapper.py | 68 ++++++++++++++++--- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 55e6610a8..0beb96187 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -581,6 +581,7 @@ def train( if callback: callback(r) self.logger.dump(self._global_step) + print("FINNNNNNNNNNNNNNNNNNNNNNNNNNNN") @overload def _torchify_array(self, ndarray: np.ndarray) -> th.Tensor: diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index f66834d04..e6372aa00 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -28,12 +28,16 @@ def _rollout_samples_to_reward_fn_input( buffer: RolloutBuffer, ) -> Mapping[str, np.ndarray]: """Convert a sample from a rollout buffer to a numpy array.""" - shape = buffer.observations.shape + obs_shape = buffer.observations.shape + done_shape = buffer.episode_starts.shape + print("obs:", obs_shape) + print("action shape:", buffer.actions.shape) + print("dones shape:", done_shape) return dict( state=buffer.observations, action=buffer.actions, - next_state=np.full(shape, np.nan), - done=np.full(shape, np.nan), + next_state=np.full(obs_shape, np.nan), + done=np.full(done_shape, np.nan), ) @@ -145,6 +149,7 @@ def __init__( # inputs, and GAIL/AIRL is the only use case for RolloutBuffer relabeling. 
See: # https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194 assert rollout_buffer_class is RolloutBuffer, "only RolloutBuffer is supported" + print("HELLLLLLLLLLLLLLLLLLLLLLLLLLOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO") assert not isinstance(observation_space, spaces.Dict) self.rollout_buffer = rollout_buffer_class( buffer_size, @@ -160,6 +165,34 @@ def __init__( def pos(self) -> int: return self.rollout_buffer.pos + @property + def values(self): + return self.rollout_buffer.values + + @property + def observations(self): + return self.rollout_buffer.observations + + @property + def actions(self): + return self.rollout_buffer.actions + + @property + def log_probs(self): + return self.rollout_buffer.log_probs + + @property + def advantages(self): + return self.rollout_buffer.advantages + + @property + def rewards(self): + return self.rollout_buffer.rewards + + @property + def returns(self): + return self.rollout_buffer.returns + @pos.setter def pos(self, pos: int): self.rollout_buffer.pos = pos @@ -176,12 +209,29 @@ def reset(self): self.rollout_buffer.reset() def get(self, *args, **kwargs): - self.rollout_buffer.rewards = self.reward_fn( - **_rollout_samples_to_reward_fn_input(self.rollout_buffer), - ) - self.rollout_buffer.compute_returns_and_advantage(self.last_values, self.dones) - - return self.rollout_buffer.get(*args, **kwargs) + print("getting") + if not self.rollout_buffer.generator_ready: + print("os:", self.rollout_buffer.observations.shape) + input_dict = _rollout_samples_to_reward_fn_input(self.rollout_buffer) + print("os:", self.rollout_buffer.observations.shape) + rewards = np.zeros_like(self.rollout_buffer.rewards) + for i in range(self.buffer_size): + rewards[i] = self.reward_fn(**{k: v[i] for k, v in input_dict.items()}) + print("rewards i:", rewards[i]) + print({k: v[i] for k, v in input_dict.items()}) + + print("os:", self.rollout_buffer.observations.shape) + self.rollout_buffer.rewards = rewards + print("rewards:", rewards) + print("values:", self.values) + print("lv:", self.last_values) + self.rollout_buffer.compute_returns_and_advantage( + self.last_values, self.last_dones + ) + print("os:", self.rollout_buffer.observations.shape) + ret = self.rollout_buffer.get(*args, **kwargs) + print("os:", self.rollout_buffer.observations.shape) + return ret def add(self, *args, **kwargs): self.rollout_buffer.add(*args, **kwargs) From 9e4fdd44e0bb333286b591f0c439a348aed54d2b Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Tue, 22 Nov 2022 01:37:15 +0530 Subject: [PATCH 6/6] Add changes for onpolicy & offpolicy --- .../algorithms/adversarial/common.py | 12 +++------ .../policies/replay_buffer_wrapper.py | 25 ++++--------------- src/imitation/rewards/reward_nets.py | 2 +- 3 files changed, 10 insertions(+), 29 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 0beb96187..c6210a8a2 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -8,12 +8,9 @@ import torch as th import torch.utils.tensorboard as thboard import tqdm -from stable_baselines3.common import ( - base_class, - on_policy_algorithm, - policies, - type_aliases, -) +from stable_baselines3.common import base_class +from stable_baselines3.common import buffers as sb3_buffers +from stable_baselines3.common import on_policy_algorithm, policies, type_aliases from stable_baselines3.common import utils as sb3_utils from stable_baselines3.common 
import vec_env from stable_baselines3.sac import policies as sac_policies @@ -279,7 +276,7 @@ def __init__( buffer_size=self.gen_train_timesteps, observation_space=replay_buffer.observation_space, action_space=replay_buffer.action_space, - replay_buffer_class=replay_buffer.__class__, + replay_buffer_class=sb3_buffers.ReplayBuffer, reward_fn=self.reward_train.predict_processed, device=replay_buffer.device, n_envs=replay_buffer.n_envs, @@ -581,7 +578,6 @@ def train( if callback: callback(r) self.logger.dump(self._global_step) - print("FINNNNNNNNNNNNNNNNNNNNNNNNNNNN") @overload def _torchify_array(self, ndarray: np.ndarray) -> th.Tensor: diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index e6372aa00..535aff1a7 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -28,16 +28,11 @@ def _rollout_samples_to_reward_fn_input( buffer: RolloutBuffer, ) -> Mapping[str, np.ndarray]: """Convert a sample from a rollout buffer to a numpy array.""" - obs_shape = buffer.observations.shape - done_shape = buffer.episode_starts.shape - print("obs:", obs_shape) - print("action shape:", buffer.actions.shape) - print("dones shape:", done_shape) return dict( state=buffer.observations, action=buffer.actions, - next_state=np.full(obs_shape, np.nan), - done=np.full(done_shape, np.nan), + next_state=buffer.next_observations, + done=buffer.dones, ) @@ -68,7 +63,9 @@ def __init__( # DictReplayBuffer because the current RewardFn only takes in NumPy array-based # inputs, and SAC is the only use case for ReplayBuffer relabeling. See: # https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194 - assert replay_buffer_class is ReplayBuffer, "only ReplayBuffer is supported" + assert ( + replay_buffer_class is ReplayBuffer + ), f"only ReplayBuffer is supported: given {replay_buffer_class}" assert not isinstance(observation_space, spaces.Dict) self.replay_buffer = replay_buffer_class( buffer_size, @@ -149,7 +146,6 @@ def __init__( # inputs, and GAIL/AIRL is the only use case for RolloutBuffer relabeling. 
See: # https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194 assert rollout_buffer_class is RolloutBuffer, "only RolloutBuffer is supported" - print("HELLLLLLLLLLLLLLLLLLLLLLLLLLOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO") assert not isinstance(observation_space, spaces.Dict) self.rollout_buffer = rollout_buffer_class( buffer_size, @@ -209,28 +205,17 @@ def reset(self): self.rollout_buffer.reset() def get(self, *args, **kwargs): - print("getting") if not self.rollout_buffer.generator_ready: - print("os:", self.rollout_buffer.observations.shape) input_dict = _rollout_samples_to_reward_fn_input(self.rollout_buffer) - print("os:", self.rollout_buffer.observations.shape) rewards = np.zeros_like(self.rollout_buffer.rewards) for i in range(self.buffer_size): rewards[i] = self.reward_fn(**{k: v[i] for k, v in input_dict.items()}) - print("rewards i:", rewards[i]) - print({k: v[i] for k, v in input_dict.items()}) - print("os:", self.rollout_buffer.observations.shape) self.rollout_buffer.rewards = rewards - print("rewards:", rewards) - print("values:", self.values) - print("lv:", self.last_values) self.rollout_buffer.compute_returns_and_advantage( self.last_values, self.last_dones ) - print("os:", self.rollout_buffer.observations.shape) ret = self.rollout_buffer.get(*args, **kwargs) - print("os:", self.rollout_buffer.observations.shape) return ret def add(self, *args, **kwargs): diff --git a/src/imitation/rewards/reward_nets.py b/src/imitation/rewards/reward_nets.py index 4e6c747e3..00a634af6 100644 --- a/src/imitation/rewards/reward_nets.py +++ b/src/imitation/rewards/reward_nets.py @@ -723,7 +723,7 @@ def forward( # series of remaining potential shapings can lead to reward shaping # that does not preserve the optimal policy if the episodes have variable # length! - new_shaping = (1 - done.float()) * new_shaping_output + new_shaping = (1 - done.float().flatten()) * new_shaping_output final_rew = ( base_reward_net_output + self.discount_factor * new_shaping
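
---
Reviewer note (not part of the patches): the core technique introduced above is reward relabeling of an on-policy rollout buffer before returns/advantages are computed, via `RolloutBufferRewardWrapper.get()`. Below is a minimal, standalone sketch of that relabel-then-GAE step, assuming plain NumPy arrays in SB3's (T, n_envs, ...) layout rather than a real `stable_baselines3` `RolloutBuffer`, a generic `reward_fn` with the `(state, action, next_state, done)` signature used by imitation's `RewardFn`, and NaN placeholders for `next_state`/`done` as in the patch. The helper name and default `gamma`/`gae_lambda` values are illustrative, not taken from the library.

# Illustrative sketch only; see hedges in the note above.
import numpy as np

def relabel_and_compute_returns(observations, actions, values, episode_starts,
                                last_values, last_dones, reward_fn,
                                gamma=0.99, gae_lambda=0.95):
    """Replace stored rewards with reward_fn outputs, then run GAE(lambda).

    Shapes (mirroring what the wrapper reads off the wrapped buffer):
      observations:   (T, n_envs, *obs_shape)
      actions:        (T, n_envs, *act_shape)
      values:         (T, n_envs)
      episode_starts: (T, n_envs)
      last_values, last_dones: (n_envs,)
    """
    T, n_envs = values.shape

    # Step 1: relabel rewards one timestep at a time, as the final get() does.
    # next_state/done are not stored per transition here, so pass placeholders.
    rewards = np.zeros((T, n_envs), dtype=np.float32)
    for t in range(T):
        rewards[t] = reward_fn(
            state=observations[t],
            action=actions[t],
            next_state=np.full_like(observations[t], np.nan),
            done=np.full((n_envs,), np.nan),
        )

    # Step 2: standard GAE(lambda) backward pass over the relabeled rewards,
    # following the same recursion SB3 uses in compute_returns_and_advantage.
    advantages = np.zeros_like(rewards)
    last_gae = np.zeros(n_envs, dtype=np.float32)
    for t in reversed(range(T)):
        if t == T - 1:
            next_non_terminal = 1.0 - last_dones
            next_values = last_values
        else:
            next_non_terminal = 1.0 - episode_starts[t + 1]
            next_values = values[t + 1]
        delta = rewards[t] + gamma * next_values * next_non_terminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return rewards, advantages, returns

In the patches, the same effect is obtained by having the wrapper capture `last_values`/`last_dones` in `compute_returns_and_advantage()`, overwrite `rollout_buffer.rewards` inside `get()` while the buffer is not yet generator-ready, and then delegate the actual GAE computation and minibatch generation back to the wrapped `RolloutBuffer`.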