From 01d5c41007651fb277199dff90028713b453c696 Mon Sep 17 00:00:00 2001 From: Clemens Schwarke Date: Mon, 10 Nov 2025 18:28:57 +0100 Subject: [PATCH 1/2] restructure rollout storage for clarity --- pyproject.toml | 2 +- rsl_rl/algorithms/distillation.py | 28 ++-------- rsl_rl/algorithms/ppo.py | 51 +++++++++--------- rsl_rl/runners/distillation_runner.py | 17 +++--- rsl_rl/runners/on_policy_runner.py | 17 +++--- rsl_rl/storage/rollout_storage.py | 75 +++++++++++---------------- 6 files changed, 76 insertions(+), 114 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e4d9e8eb..dadeb635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,5 +56,5 @@ reportMissingImports = "none" # This is required to ignore type checks of modules with stubs missing. reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers reportGeneralTypeIssues = "none" # -> usage of literal MISSING in dataclasses -reportOptionalMemberAccess = "warning" +reportOptionalMemberAccess = "none" reportPrivateUsage = "warning" diff --git a/rsl_rl/algorithms/distillation.py b/rsl_rl/algorithms/distillation.py index 5da039a9..2bdb4baf 100644 --- a/rsl_rl/algorithms/distillation.py +++ b/rsl_rl/algorithms/distillation.py @@ -21,6 +21,7 @@ class Distillation: def __init__( self, policy: StudentTeacher | StudentTeacherRecurrent, + storage: RolloutStorage, num_learning_epochs: int = 1, gradient_length: int = 15, learning_rate: float = 1e-3, @@ -46,12 +47,12 @@ def __init__( # Distillation components self.policy = policy self.policy.to(self.device) - self.storage = None # Initialized later - # Initialize the optimizer + # Create the optimizer self.optimizer = resolve_optimizer(optimizer)(self.policy.parameters(), lr=learning_rate) - # Initialize the transition + # Add storage + self.storage = storage self.transition = RolloutStorage.Transition() self.last_hidden_states = (None, None) @@ -73,24 +74,6 @@ def __init__( self.num_updates = 0 - def init_storage( - self, - training_type: str, - num_envs: int, - num_transitions_per_env: int, - obs: TensorDict, - actions_shape: tuple[int], - ) -> None: - # Create rollout storage - self.storage = RolloutStorage( - training_type, - num_envs, - num_transitions_per_env, - obs, - actions_shape, - self.device, - ) - def act(self, obs: TensorDict) -> torch.Tensor: # Compute the actions self.transition.actions = self.policy.act(obs).detach() @@ -104,12 +87,11 @@ def process_env_step( ) -> None: # Update the normalizers self.policy.update_normalization(obs) - # Record the rewards and dones self.transition.rewards = rewards self.transition.dones = dones # Record the transition - self.storage.add_transitions(self.transition) + self.storage.add_transition(self.transition) self.transition.clear() self.policy.reset(dones) diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py index 1479c06a..1779b5db 100644 --- a/rsl_rl/algorithms/ppo.py +++ b/rsl_rl/algorithms/ppo.py @@ -26,6 +26,7 @@ class PPO: def __init__( self, policy: ActorCritic | ActorCriticRecurrent, + storage: RolloutStorage, num_learning_epochs: int = 5, num_mini_batches: int = 4, clip_param: float = 0.2, @@ -38,8 +39,8 @@ def __init__( use_clipped_value_loss: bool = True, schedule: str = "adaptive", desired_kl: float = 0.01, - device: str = "cpu", normalize_advantage_per_mini_batch: bool = False, + device: str = "cpu", # RND parameters rnd_cfg: dict | None = None, # Symmetry parameters @@ -100,11 +101,11 @@ def __init__( self.policy = policy self.policy.to(self.device) - # Create optimizer + # Create 
the optimizer self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate) - # Create rollout storage - self.storage: RolloutStorage | None = None + # Add storage + self.storage = storage self.transition = RolloutStorage.Transition() # PPO parameters @@ -122,24 +123,6 @@ def __init__( self.learning_rate = learning_rate self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch - def init_storage( - self, - training_type: str, - num_envs: int, - num_transitions_per_env: int, - obs: TensorDict, - actions_shape: tuple[int] | list[int], - ) -> None: - # Create rollout storage - self.storage = RolloutStorage( - training_type, - num_envs, - num_transitions_per_env, - obs, - actions_shape, - self.device, - ) - def act(self, obs: TensorDict) -> torch.Tensor: if self.policy.is_recurrent: self.transition.hidden_states = self.policy.get_hidden_states() @@ -180,16 +163,32 @@ def process_env_step( ) # Record the transition - self.storage.add_transitions(self.transition) + self.storage.add_transition(self.transition) self.transition.clear() self.policy.reset(dones) def compute_returns(self, obs: TensorDict) -> None: + st = self.storage # Compute value for the last step last_values = self.policy.evaluate(obs).detach() - self.storage.compute_returns( - last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch - ) + # Compute returns and advantages + advantage = 0 + for step in reversed(range(st.num_transitions_per_env)): + # If we are at the last step, bootstrap the return value + next_values = last_values if step == st.num_transitions_per_env - 1 else st.values[step + 1] + # 1 if we are not in a terminal state, 0 otherwise + next_is_not_terminal = 1.0 - st.dones[step].float() + # TD error: r_t + gamma * V(s_{t+1}) - V(s_t) + delta = st.rewards[step] + next_is_not_terminal * self.gamma * next_values - st.values[step] + # Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1}) + advantage = delta + next_is_not_terminal * self.gamma * self.lam * advantage + # Return: R_t = A(s_t, a_t) + V(s_t) + st.returns[step] = advantage + st.values[step] + # Compute the advantages + st.advantages = st.returns - st.values + # Normalize the advantages if per minibatch normalization is not used + if not self.normalize_advantage_per_mini_batch: + st.advantages = (st.advantages - st.advantages.mean()) / (st.advantages.std() + 1e-8) def update(self) -> dict[str, float]: mean_value_loss = 0 diff --git a/rsl_rl/runners/distillation_runner.py b/rsl_rl/runners/distillation_runner.py index 6f9c502b..d83da744 100644 --- a/rsl_rl/runners/distillation_runner.py +++ b/rsl_rl/runners/distillation_runner.py @@ -16,6 +16,7 @@ from rsl_rl.env import VecEnv from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent from rsl_rl.runners import OnPolicyRunner +from rsl_rl.storages import RolloutStorage from rsl_rl.utils import resolve_obs_groups, store_code_state @@ -158,19 +159,15 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation: obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device) + # Initialize the storage + storage = RolloutStorage( + "distillation", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device + ) + # Initialize the algorithm alg_class = eval(self.alg_cfg.pop("class_name")) alg: Distillation = alg_class( - student_teacher, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg - ) - - # Initialize the storage - alg.init_storage( - 
"distillation", - self.env.num_envs, - self.num_steps_per_env, - obs, - [self.env.num_actions], + student_teacher, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg ) return alg diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 46a9b524..a5864e52 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -17,6 +17,7 @@ from rsl_rl.algorithms import PPO from rsl_rl.env import VecEnv from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config +from rsl_rl.storage import RolloutStorage from rsl_rl.utils import resolve_obs_groups, store_code_state @@ -418,17 +419,15 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO: obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device) + # Initialize the storage + storage = RolloutStorage( + "rl", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device + ) + # Initialize the algorithm alg_class = eval(self.alg_cfg.pop("class_name")) - alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg) - - # Initialize the storage - alg.init_storage( - "rl", - self.env.num_envs, - self.num_steps_per_env, - obs, - [self.env.num_actions], + alg: PPO = alg_class( + actor_critic, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg ) return alg diff --git a/rsl_rl/storage/rollout_storage.py b/rsl_rl/storage/rollout_storage.py index 539a57bb..37c88efa 100644 --- a/rsl_rl/storage/rollout_storage.py +++ b/rsl_rl/storage/rollout_storage.py @@ -14,7 +14,15 @@ class RolloutStorage: + """Storage for the data collected during a rollout. + + The rollout storage is populated by adding transitions during the rollout phase. It then returns a generator for + learning, depending on the algorithm and the policy architecture. + """ + class Transition: + """Storage for a single state transition.""" + def __init__(self) -> None: self.observations: TensorDict | None = None self.actions: torch.Tensor | None = None @@ -75,7 +83,7 @@ def __init__( # Counter for the number of transitions stored self.step = 0 - def add_transitions(self, transition: Transition) -> None: + def add_transition(self, transition: Transition) -> None: # Check if the transition is valid if self.step >= self.num_transitions_per_env: raise OverflowError("Rollout buffer overflow! 
You should call clear() before adding new transitions.") @@ -103,53 +111,9 @@ def add_transitions(self, transition: Transition) -> None: # Increment the counter self.step += 1 - def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None: - if hidden_states == (None, None): - return - # Make a tuple out of GRU hidden states to match the LSTM format - hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) - hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) - # Initialize hidden states if needed - if self.saved_hidden_state_a is None: - self.saved_hidden_state_a = [ - torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device) - for i in range(len(hidden_state_a)) - ] - self.saved_hidden_state_c = [ - torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device) - for i in range(len(hidden_state_c)) - ] - # Copy the states - for i in range(len(hidden_state_a)): - self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i]) - self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i]) - def clear(self) -> None: self.step = 0 - def compute_returns( - self, last_values: torch.Tensor, gamma: float, lam: float, normalize_advantage: bool = True - ) -> None: - advantage = 0 - for step in reversed(range(self.num_transitions_per_env)): - # If we are at the last step, bootstrap the return value - next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1] - # 1 if we are not in a terminal state, 0 otherwise - next_is_not_terminal = 1.0 - self.dones[step].float() - # TD error: r_t + gamma * V(s_{t+1}) - V(s_t) - delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step] - # Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1}) - advantage = delta + next_is_not_terminal * gamma * lam * advantage - # Return: R_t = A(s_t, a_t) + V(s_t) - self.returns[step] = advantage + self.values[step] - - # Compute the advantages - self.advantages = self.returns - self.values - # Normalize the advantages if flag is set - # Note: This is to prevent double normalization (i.e. 
if per minibatch normalization is used)
-        if normalize_advantage:
-            self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
-
     # For distillation
     def generator(self) -> Generator:
         if self.training_type != "distillation":
@@ -289,3 +253,24 @@ def recurrent_mini_batch_generator(self, num_mini_batches: int, num_epochs: int
                 )

                 first_traj = last_traj
+
+    def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None:
+        if hidden_states == (None, None):
+            return
+        # Make a tuple out of GRU hidden states to match the LSTM format
+        hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
+        hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
+        # Initialize hidden states if needed
+        if self.saved_hidden_state_a is None:
+            self.saved_hidden_state_a = [
+                torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device)
+                for i in range(len(hidden_state_a))
+            ]
+            self.saved_hidden_state_c = [
+                torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device)
+                for i in range(len(hidden_state_c))
+            ]
+        # Copy the states
+        for i in range(len(hidden_state_a)):
+            self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i])
+            self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i])

From 864cc1d0f8a022cf8da29c30922a2be54c45314a Mon Sep 17 00:00:00 2001
From: ClemensSchwarke
Date: Thu, 13 Nov 2025 10:44:41 +0100
Subject: [PATCH 2/2] fix import

---
 rsl_rl/runners/distillation_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rsl_rl/runners/distillation_runner.py b/rsl_rl/runners/distillation_runner.py
index d83da744..912bee5f 100644
--- a/rsl_rl/runners/distillation_runner.py
+++ b/rsl_rl/runners/distillation_runner.py
@@ -16,7 +16,7 @@
 from rsl_rl.env import VecEnv
 from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
 from rsl_rl.runners import OnPolicyRunner
-from rsl_rl.storages import RolloutStorage
+from rsl_rl.storage import RolloutStorage
 from rsl_rl.utils import resolve_obs_groups, store_code_state
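
The GAE loop that this patch moves from RolloutStorage.compute_returns() into PPO.compute_returns() can be read in isolation as a pure tensor computation. The sketch below reproduces the same recursion on dummy data so it can be inspected or unit-tested outside the class; the function name, the tensor shapes [num_steps, num_envs, 1], and the random inputs are illustrative assumptions, not part of rsl_rl.

import torch


def compute_gae_returns(
    rewards: torch.Tensor,      # [num_steps, num_envs, 1]
    values: torch.Tensor,       # [num_steps, num_envs, 1]
    dones: torch.Tensor,        # [num_steps, num_envs, 1], 1.0 where an episode ended
    last_values: torch.Tensor,  # [num_envs, 1], bootstrap value V(s_T) for the final step
    gamma: float = 0.99,
    lam: float = 0.95,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Backward GAE recursion, mirroring the loop inlined into PPO.compute_returns above."""
    num_steps = rewards.shape[0]
    returns = torch.zeros_like(values)
    advantage = torch.zeros_like(last_values)
    for step in reversed(range(num_steps)):
        # Bootstrap with V(s_T) at the last step, otherwise use the stored next value
        next_values = last_values if step == num_steps - 1 else values[step + 1]
        # 1 if the state is not terminal, 0 otherwise
        next_is_not_terminal = 1.0 - dones[step]
        # TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[step] + next_is_not_terminal * gamma * next_values - values[step]
        # Advantage: A_t = delta_t + gamma * lambda * A_{t+1}, cut off at terminal states
        advantage = delta + next_is_not_terminal * gamma * lam * advantage
        # Return: R_t = A_t + V(s_t)
        returns[step] = advantage + values[step]
    advantages = returns - values
    # Global normalization, as done when per-mini-batch normalization is disabled
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages


# Toy rollout matching the storage layout
num_steps, num_envs = 24, 8
rewards = torch.randn(num_steps, num_envs, 1)
values = torch.randn(num_steps, num_envs, 1)
dones = (torch.rand(num_steps, num_envs, 1) < 0.05).float()
last_values = torch.randn(num_envs, 1)
returns, advantages = compute_gae_returns(rewards, values, dones, last_values)
print(returns.shape, advantages.mean().item(), advantages.std().item())

Keeping this loop in the algorithm rather than in the storage means RolloutStorage only holds and serves data, which is the intent of the restructuring in this patch.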