
Conversation


bilelsgh commented Dec 5, 2025

Feature overview

Implementation of the Prioritized Approximation Loss (PAL), a computationally efficient approximation of Prioritized Experience Replay (PER).
A NeurIPS 2020 paper shows that training with PER is equivalent to training with uniform experience replay and an adapted loss function.

The expected gradient of the loss (1/τ) * |δ(i)|^τ, where τ > 0, when transitions are sampled with PER priorities p(i) ∝ |δ(i)|^α, is equal (up to a constant scaling factor) to the expected gradient of the loss (1/(τ + α)) * |δ(i)|^(τ + α) under uniform sampling.
https://papers.neurips.cc/paper_files/paper/2020/file/a3bf6e4db673b6449c2f7d13ee6ec9c0-Paper.pdf

This means we can avoid maintaining a priority-ordered buffer and the associated complexity, while still following the same expected gradient.
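As a quick numerical sanity check of this equivalence (purely illustrative, not part of the PR), the two expected gradients can be compared on a batch of synthetic TD errors; they agree up to the constant N / Σ_j |δ(j)|^α:

import numpy as np

rng = np.random.default_rng(0)
alpha, tau = 0.6, 2.0               # PER priority exponent and base loss exponent
delta = rng.normal(size=10_000)     # synthetic TD errors (gradients taken w.r.t. delta)

# Expected gradient of (1/tau) * |delta|^tau when sampling i with probability p(i) ∝ |delta(i)|^alpha
p = np.abs(delta) ** alpha
p /= p.sum()
grad_per = np.sum(p * np.sign(delta) * np.abs(delta) ** (tau - 1.0))

# Expected gradient of (1/(tau + alpha)) * |delta|^(tau + alpha) under uniform sampling
grad_uniform = np.mean(np.sign(delta) * np.abs(delta) ** (tau + alpha - 1.0))

# Equal up to the constant N / sum_j |delta(j)|^alpha
scale = len(delta) / np.sum(np.abs(delta) ** alpha)
assert np.isclose(grad_per, scale * grad_uniform)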

PAL is therefore a good alternative while an efficient implementation of Prioritized Experience Replay is still pending.

PAL can also be used as a simple and computationally efficient replacement for LAP, requiring only an adjustment to the loss function used to train the Q-network.

Description

I've added a new loss function that adapts the Huber loss by incorporating the priority, as described in the referenced paper. The buffer itself performs uniform sampling (ReplayBuffer). Additionally, I implemented a PALReplayBuffer (leaving PrioritizedReplayBuffer to the Rainbow implementation (👋 @araffin) #622) to initialize the PAL parameters (following the paper) and to properly handle the case where PAL is applied within the training method.
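For intuition, here is a minimal, hedged sketch of a PAL-style loss in PyTorch. The function name pal_loss and the default alpha are illustrative and not necessarily the API added in this PR: the quadratic region of the Huber loss is kept, the linear region is replaced by |δ|^(1+α) / (1+α), and the result is rescaled by the batch mean of the LAP priorities max(|δ|^α, 1), treated as a constant.

import torch as th

def pal_loss(td_errors: th.Tensor, alpha: float = 0.4) -> th.Tensor:
    abs_td = td_errors.abs()
    # Huber-like quadratic region for |delta| <= 1, power-law region otherwise
    loss = th.where(
        abs_td <= 1.0,
        0.5 * td_errors ** 2,
        abs_td ** (1.0 + alpha) / (1.0 + alpha),
    )
    # lambda = mean of the LAP priorities max(|delta|^alpha, 1), detached so it only
    # rescales the gradient by a constant
    lam = abs_td.pow(alpha).clamp(min=1.0).mean().detach()
    return (loss / lam).mean()

In a DQN training step this would simply replace the Huber loss computed on the TD errors of a uniformly sampled batch.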

Test

PAL was evaluated on 3 environments, including 2 Atari games that were also used in the PAL paper. The results are shown below. The network architecture and the DRL hyperparameters are the same as in the PAL paper.

Layer    In                 Out
Conv2d   n_input_channels   32
Conv2d   32                 32
Conv2d   32                 64
Flatten  64                 n_flatten
Linear   n_flatten          features_dim

Atari games

For both Breakout and SpaceInvaders, the reward converges faster to a higher value.
(screenshot: Breakout and SpaceInvaders reward curves)

Classic env

PAL leads to a higher reward on LunarLander.
(screenshot: LunarLander reward curves)

It's important to remember that PER (and therefore PAL) does not necessarily lead to a better reward in every case; the benefit depends on the context and the environment in which the agent evolves.

Motivation and Context

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have opened an associated PR on the SB3-Contrib repository (if necessary)
  • I have opened an associated PR on the RL-Zoo3 repository (if necessary)
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line


bilelsgh commented Dec 5, 2025

Here is the code used for testing

import gymnasium as gym
from gymnasium import spaces
from loguru import logger

from stable_baselines3 import DQN
from stable_baselines3.common.buffers import PALReplayBuffer
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import VecFrameStack
import torch.nn as nn
import torch as th
import ale_py

class CustomCNN(BaseFeaturesExtractor):
    """
    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)

        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))


def drl():
    env_names = ['LunarLander-v3']

    for env_name in env_names:
        for buffer in [None, PALReplayBuffer]:
            logger.info(f"Training on {env_name} with buffer: {buffer}")

            log_name = f"{env_name}_classic" if not buffer else f"{env_name}_PAL"
            env = gym.make(env_name)

            model = DQN("MlpPolicy",
                        env,
                        replay_buffer_class=buffer,
                        tensorboard_log="./board",
                        verbose=0,
                        device="mps")

            model.learn(total_timesteps=300000, log_interval=4, tb_log_name=log_name)

def drl_atari():

    gym.register_envs(ale_py)
    atari_games = [
        'ALE/Breakout-v5',
        'ALE/SpaceInvaders-v5',
        'ALE/Riverraid-v5'
    ]
    policy_kwargs = dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(features_dim=512),
    )

    for game in atari_games:
        for buffer in [None, PALReplayBuffer]:
            log_name = f"{game.split('/')[-1]}_classic" if not buffer else f"{game.split('/')[-1]}_PAL"
            logger.info(f"Training on {game} with buffer: {buffer}")

            env = make_atari_env(game, n_envs=1, seed=42)
            env = VecFrameStack(env, n_stack=4)  # Stack 4 frames

            model = DQN(
                "CnnPolicy",  # CNN pour images
                env,
                replay_buffer_class=buffer,
                learning_starts=10000,
                batch_size=32,
                exploration_fraction=0.1,
                exploration_final_eps=0.05,
                tensorboard_log="./atari_board",
                verbose=0,
                device="mps",
                policy_kwargs=policy_kwargs,
            )

            model.learn(total_timesteps=300000, log_interval=4, tb_log_name=log_name)

if __name__ == '__main__':
    #drl()
    drl_atari()
