From dd73a3805a501a85a7990ba2f812c721520f7608 Mon Sep 17 00:00:00 2001
From: Simeon Manolov
Date: Thu, 14 Sep 2023 13:20:40 +0300
Subject: [PATCH 1/4] support for SB3 callbacks in adversarial training

---
 .../algorithms/adversarial/common.py        | 33 +++++++++++------
 src/imitation/scripts/train_adversarial.py  | 28 ++++++++++++--
 tests/algorithms/test_adversarial.py        | 37 +++++++++++++++++++
 3 files changed, 82 insertions(+), 16 deletions(-)

diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py
index ece30b011..da153d4dc 100644
--- a/src/imitation/algorithms/adversarial/common.py
+++ b/src/imitation/algorithms/adversarial/common.py
@@ -2,7 +2,7 @@
 import abc
 import dataclasses
 import logging
-from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
+from typing import Iterable, Iterator, Mapping, Optional, Type, overload
 
 import numpy as np
 import torch as th
@@ -15,6 +15,8 @@
     policies,
     vec_env,
 )
+from stable_baselines3.common.type_aliases import MaybeCallback
+from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback
 from stable_baselines3.sac import policies as sac_policies
 from torch.nn import functional as F
 
@@ -392,6 +394,7 @@ def train_gen(
         self,
         total_timesteps: Optional[int] = None,
         learn_kwargs: Optional[Mapping] = None,
+        callback: MaybeCallback = None,
     ) -> None:
         """Trains the generator to maximize the discriminator loss.
 
@@ -404,17 +407,27 @@ def train_gen(
                 `self.gen_train_timesteps`.
             learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
                 method.
+            callback: additional callback(s) passed to the generator's `learn` method.
         """
         if total_timesteps is None:
             total_timesteps = self.gen_train_timesteps
         if learn_kwargs is None:
             learn_kwargs = {}
 
+        callbacks = [self.gen_callback]
+
+        if isinstance(callback, list):
+            callbacks.extend(callback)
+        elif isinstance(callback, BaseCallback):
+            callbacks.append(callback)
+        elif callback is not None:
+            callbacks.append(ConvertCallback(callback))
+
         with self.logger.accumulate_means("gen"):
             self.gen_algo.learn(
                 total_timesteps=total_timesteps,
                 reset_num_timesteps=False,
-                callback=self.gen_callback,
+                callback=callbacks,
                 **learn_kwargs,
             )
             self._global_step += 1
@@ -427,12 +440,12 @@ def train_gen(
     def train(
         self,
         total_timesteps: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: MaybeCallback = None,
     ) -> None:
         """Alternates between training the generator and discriminator.
 
-        Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
-        a call to `train_disc`, and finally a call to `callback(round)`.
+        Every "round" consists of a call to
+        `train_gen(self.gen_train_timesteps, callback)`, then a call to `train_disc`.
 
         Training ends once an additional "round" would cause the number of
         transitions sampled from the environment to exceed `total_timesteps`.
 
         Args:
             total_timesteps: An upper bound on the number of transitions to sample
                 from the environment during training.
-            callback: A function called at the end of every round which takes in a
-                single argument, the round number. Round numbers are in
-                `range(total_timesteps // self.gen_train_timesteps)`.
+            callback: callback(s) passed to the generator's `learn` method.
         """
         n_rounds = total_timesteps // self.gen_train_timesteps
         assert n_rounds >= 1, (
             "No updates (need at least "
             f"{self.gen_train_timesteps} timesteps, have only "
             f"total_timesteps={total_timesteps})!"
         )
-        for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
-            self.train_gen(self.gen_train_timesteps)
+        for _r in tqdm.tqdm(range(0, n_rounds), desc="round"):
+            self.train_gen(self.gen_train_timesteps, callback=callback)
             for _ in range(self.n_disc_updates_per_round):
                 with networks.training(self.reward_train):
                     # switch to training mode (affects dropout, normalization)
                     self.train_disc()
-            if callback:
-                callback(r)
             self.logger.dump(self._global_step)
 
     @overload
diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py
index 26c8d7bcf..8de7d22b7 100644
--- a/src/imitation/scripts/train_adversarial.py
+++ b/src/imitation/scripts/train_adversarial.py
@@ -8,6 +8,7 @@
 import sacred.commands
 import torch as th
 from sacred.observers import FileStorageObserver
+from stable_baselines3.common.callbacks import BaseCallback
 
 from imitation.algorithms.adversarial import airl as airl_algo
 from imitation.algorithms.adversarial import common
@@ -22,6 +23,28 @@
 logger = logging.getLogger("imitation.scripts.train_adversarial")
 
 
+class CheckpointCallback(BaseCallback):
+    def __init__(
+        self,
+        trainer: common.AdversarialTrainer,
+        log_dir: pathlib.Path,
+        interval: int
+    ):
+        super().__init__(self)
+        self.trainer = trainer
+        self.log_dir = log_dir
+        self.interval = interval
+        self.round_num = 0
+
+    def _on_step(self) -> bool:
+        return True
+
+    def _on_training_end(self) -> None:
+        self.round_num += 1
+        if self.interval > 0 and self.round_num % self.interval == 0:
+            save(self.trainer, self.log_dir / "checkpoints" / f"{self.round_num:05d}")
+
+
 def save(trainer: common.AdversarialTrainer, save_path: pathlib.Path):
     """Save discriminator and generator."""
     # We implement this here and not in Trainer since we do not want to actually
@@ -153,10 +176,7 @@ def train_adversarial(
         **algorithm_kwargs,
     )
 
-    def callback(round_num: int, /) -> None:
-        if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
-            save(trainer, log_dir / "checkpoints" / f"{round_num:05d}")
-
+    callback = CheckpointCallback(trainer, log_dir, checkpoint_interval)
     trainer.train(total_timesteps, callback)
 
     imit_stats = policy_evaluation.eval_policy(trainer.policy, trainer.venv_train)
diff --git a/tests/algorithms/test_adversarial.py b/tests/algorithms/test_adversarial.py
index d3609efaa..5153e98db 100644
--- a/tests/algorithms/test_adversarial.py
+++ b/tests/algorithms/test_adversarial.py
@@ -10,6 +10,7 @@
 import stable_baselines3
 import torch as th
 from stable_baselines3.common import policies
+from stable_baselines3.common.callbacks import BaseCallback
 from torch.utils import data as th_data
 
 from imitation.algorithms.adversarial import airl, common, gail
@@ -464,3 +465,39 @@ def test_regression_gail_with_sac(
         reward_net=reward_net,
     )
     gail_trainer.train(8)
+
+
+def test_gen_callback(trainer: common.AdversarialTrainer):
+    learner = stable_baselines3.PPO("MlpPolicy", env=trainer.venv)
+
+    def make_fn_callback(calls, key):
+        def cb(_a, _b):
+            calls[key] += 1
+        return cb
+
+    class SB3Callback(BaseCallback):
+        def __init__(self, calls, key):
+            super().__init__()
+            self.calls = calls
+            self.key = key
+
+        def _on_step(self):
+            self.calls[self.key] += 1
+            return True
+
+    n_steps = trainer.gen_train_timesteps * 2
+    calls = {"fn": 0, "sb3": 0, "list.0": 0, "list.1": 0}
+
+    trainer.train(n_steps, callback=make_fn_callback(calls, "fn"))
+    trainer.train(n_steps, callback=SB3Callback(calls, "sb3"))
+    trainer.train(n_steps, callback=[
+        SB3Callback(calls, "list.0"),
+        SB3Callback(calls, "list.1")
+    ])
+
+    # Env steps for off-policy algos (DQN) may exceed `total_timesteps`,
+    # so we check if the callback was called *at least* that many times.
+    assert calls["fn"] >= n_steps
+    assert calls["sb3"] >= n_steps
+    assert calls["list.0"] >= n_steps
+    assert calls["list.1"] >= n_steps

From 6e622c53f2e2ac8e532645bceb31071f05877bb1 Mon Sep 17 00:00:00 2001
From: Simeon Manolov
Date: Thu, 14 Sep 2023 15:54:56 +0300
Subject: [PATCH 2/4] fix lint errors

---
 src/imitation/scripts/train_adversarial.py |  5 ++++-
 tests/algorithms/test_adversarial.py       | 11 +++++------
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py
index 8de7d22b7..9e1da2d5b 100644
--- a/src/imitation/scripts/train_adversarial.py
+++ b/src/imitation/scripts/train_adversarial.py
@@ -24,12 +24,15 @@
 
 
 class CheckpointCallback(BaseCallback):
+    """A callback for calling `save` at regular intervals."""
+
     def __init__(
         self,
         trainer: common.AdversarialTrainer,
         log_dir: pathlib.Path,
-        interval: int
+        interval: int,
     ):
+        """Creates a new CheckpointCallback."""
         super().__init__(self)
         self.trainer = trainer
         self.log_dir = log_dir
diff --git a/tests/algorithms/test_adversarial.py b/tests/algorithms/test_adversarial.py
index 5153e98db..118a8b49f 100644
--- a/tests/algorithms/test_adversarial.py
+++ b/tests/algorithms/test_adversarial.py
@@ -468,11 +468,10 @@ def test_regression_gail_with_sac(
 
 def test_gen_callback(trainer: common.AdversarialTrainer):
-    learner = stable_baselines3.PPO("MlpPolicy", env=trainer.venv)
-
     def make_fn_callback(calls, key):
         def cb(_a, _b):
             calls[key] += 1
+
         return cb
 
     class SB3Callback(BaseCallback):
         def __init__(self, calls, key):
@@ -490,10 +489,10 @@ def _on_step(self):
 
     trainer.train(n_steps, callback=make_fn_callback(calls, "fn"))
     trainer.train(n_steps, callback=SB3Callback(calls, "sb3"))
-    trainer.train(n_steps, callback=[
-        SB3Callback(calls, "list.0"),
-        SB3Callback(calls, "list.1")
-    ])
+    trainer.train(
+        n_steps,
+        callback=[SB3Callback(calls, "list.0"), SB3Callback(calls, "list.1")],
+    )
 
     # Env steps for off-policy algos (DQN) may exceed `total_timesteps`,
     # so we check if the callback was called *at least* that many times.
From 8ef5e0012d85031ca9850f42ca63d70574f05831 Mon Sep 17 00:00:00 2001
From: Simeon Manolov
Date: Tue, 19 Sep 2023 13:33:50 +0300
Subject: [PATCH 3/4] fix mypy errors

---
 src/imitation/algorithms/adversarial/common.py | 7 +++++--
 src/imitation/scripts/train_adversarial.py     | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py
index da153d4dc..4dfd554a0 100644
--- a/src/imitation/algorithms/adversarial/common.py
+++ b/src/imitation/algorithms/adversarial/common.py
@@ -2,7 +2,7 @@
 import abc
 import dataclasses
 import logging
-from typing import Iterable, Iterator, Mapping, Optional, Type, overload
+from typing import Iterable, Iterator, Mapping, Optional, Type, List, overload
 
 import numpy as np
 import torch as th
@@ -414,7 +414,10 @@ def train_gen(
         if learn_kwargs is None:
             learn_kwargs = {}
 
-        callbacks = [self.gen_callback]
+        callbacks: List[BaseCallback] = []
+
+        if self.gen_callback:
+            callbacks.append(self.gen_callback)
 
         if isinstance(callback, list):
             callbacks.extend(callback)
diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py
index 9e1da2d5b..90ba777ad 100644
--- a/src/imitation/scripts/train_adversarial.py
+++ b/src/imitation/scripts/train_adversarial.py
@@ -33,7 +33,7 @@ def __init__(
         interval: int,
     ):
         """Creates a new CheckpointCallback."""
-        super().__init__(self)
+        super().__init__()
         self.trainer = trainer
         self.log_dir = log_dir
         self.interval = interval

From 159fd8ef0c60facaa85b9ce3ac8051ddbfefb4f1 Mon Sep 17 00:00:00 2001
From: Simeon Manolov
Date: Wed, 20 Sep 2023 00:37:51 +0300
Subject: [PATCH 4/4] fix isort errors

---
 src/imitation/algorithms/adversarial/common.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py
index 4dfd554a0..8049cca97 100644
--- a/src/imitation/algorithms/adversarial/common.py
+++ b/src/imitation/algorithms/adversarial/common.py
@@ -2,7 +2,7 @@
 import abc
 import dataclasses
 import logging
-from typing import Iterable, Iterator, Mapping, Optional, Type, List, overload
+from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload
 
 import numpy as np
 import torch as th
@@ -15,8 +15,8 @@
     policies,
     vec_env,
 )
-from stable_baselines3.common.type_aliases import MaybeCallback
 from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback
+from stable_baselines3.common.type_aliases import MaybeCallback
 from stable_baselines3.sac import policies as sac_policies
 from torch.nn import functional as F
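
For reference, below is a minimal usage sketch of the callback support added by this series. It is not part of the patches: it assumes an already-constructed `AdversarialTrainer` (named `trainer` here), and `StepCounter` / `count_fn` are hypothetical names used only for illustration.

# Hypothetical usage sketch (not part of the patch series). Assumes
# `trainer` is an AdversarialTrainer (e.g. GAIL) built in the usual way.
from stable_baselines3.common.callbacks import BaseCallback


class StepCounter(BaseCallback):
    """Counts generator environment steps across all training rounds."""

    def __init__(self):
        super().__init__()
        self.steps_seen = 0

    def _on_step(self) -> bool:
        self.steps_seen += 1
        return True  # returning False would abort generator training early


fn_calls = {"n": 0}


def count_fn(locals_, globals_):
    # Bare-function form; the patched `train_gen` wraps it in SB3's
    # `ConvertCallback` before passing it to `gen_algo.learn`.
    fn_calls["n"] += 1
    return True


# The patched `train` accepts any `MaybeCallback` form and forwards it to
# the generator's `learn` call in every round:
trainer.train(total_timesteps=4096, callback=StepCounter())                   # BaseCallback
trainer.train(total_timesteps=4096, callback=count_fn)                        # plain function
trainer.train(total_timesteps=4096, callback=[StepCounter(), StepCounter()])  # list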