From 95b415bd81deb1285b27bbc821b5a147914fcb5b Mon Sep 17 00:00:00 2001
From: "Jerry.Bai"
Date: Wed, 2 Jul 2025 18:07:12 -0700
Subject: [PATCH] Add ResidualFCBlock to layers and add options in
 EncodingNetwork, ActorDistributionNetwork, and CriticNetwork to use it
 instead of (joint) FC layers

---
 alf/layers.py                               | 73 +++++++++++++++++++++
 alf/networks/actor_distribution_networks.py | 16 +++++
 alf/networks/critic_networks.py             | 48 +++++++++-----
 alf/networks/encoding_networks.py           | 33 +++++++++-
 4 files changed, 154 insertions(+), 16 deletions(-)

diff --git a/alf/layers.py b/alf/layers.py
index b1fd2b94e..8dc512503 100644
--- a/alf/layers.py
+++ b/alf/layers.py
@@ -2507,6 +2507,79 @@ def _conv_transpose_2d(in_channels,
         bias=bias)
 
 
+@alf.configurable
+class ResidualFCBlock(nn.Module):
+    r"""Residual block built from FC layers.
+
+    This is the residual feedforward block used in the following paper as a
+    replacement for plain MLP layers:
+
+    ::
+
+        Lee et al. "SimBa: Simplicity Bias for Scaling up Parameters in Deep
+        Reinforcement Learning", arXiv:2410.09754
+
+    The block is defined as
+
+    :math:`x_{out} = x_{in} + \mathrm{MLP}(\mathrm{LayerNorm}(x_{in}))`
+
+    where :math:`\mathrm{MLP}` is a two-layer MLP with one hidden layer.
+    """
+
+    def __init__(self,
+                 input_size: int,
+                 output_size: int,
+                 hidden_size: Optional[int] = None,
+                 use_bias: bool = True,
+                 use_output_ln: bool = False,
+                 activation: Callable = torch.relu_,
+                 kernel_initializer: Callable[[Tensor],
+                                              None] = nn.init.kaiming_normal_,
+                 bias_init_value: float = 0.0):
+        """
+        Args:
+            input_size (int): input size
+            output_size (int): output size; must equal ``input_size`` because
+                of the residual connection.
+            hidden_size (int): size of the hidden layer. If None, ``input_size``
+                is used.
+            use_bias (bool): whether to use bias for the FC layers.
+            use_output_ln (bool): whether to apply a LayerNorm to the output of
+                the block.
+            activation (Callable): activation for the hidden layer.
+            kernel_initializer (Callable): initializer for the FC layer kernels.
+            bias_init_value (float): a constant for the initial FC bias value.
+        """
+        super().__init__()
+        assert output_size == input_size, (
+            "The residual connection requires output_size == input_size, got "
+            f"{output_size} vs. {input_size}")
+        self._use_output_ln = use_output_ln
+        if hidden_size is None:
+            hidden_size = input_size
+        fc1 = FC(input_size,
+                 hidden_size,
+                 activation=activation,
+                 use_bias=use_bias,
+                 kernel_initializer=kernel_initializer,
+                 bias_init_value=bias_init_value)
+        fc2 = FC(hidden_size,
+                 output_size,
+                 use_bias=use_bias,
+                 kernel_initializer=kernel_initializer,
+                 bias_init_value=bias_init_value)
+        self._core_layers = nn.Sequential(fc1, fc2)
+        self._ln = nn.LayerNorm(input_size)
+        if use_output_ln:
+            self._out_ln = nn.LayerNorm(output_size)
+
+    def reset_parameters(self):
+        self._ln.reset_parameters()
+        for layer in self._core_layers:
+            layer.reset_parameters()
+        if self._use_output_ln:
+            self._out_ln.reset_parameters()
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        core_inputs = self._ln(inputs)
+        core = self._core_layers(core_inputs)
+        outputs = core + inputs
+        if self._use_output_ln:
+            outputs = self._out_ln(outputs)
+        return outputs
+
+
 @alf.configurable(whitelist=[
     'with_batch_normalization', 'bn_ctor', 'weight_opt_args', 'activation'
 ])
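
For reference, a minimal usage sketch of the new block (the sizes below are
arbitrary; only ``alf.layers.ResidualFCBlock`` itself comes from this patch)::

    import torch
    import alf.layers as layers

    # input_size must equal output_size so that the residual add is valid.
    block = layers.ResidualFCBlock(
        input_size=256, output_size=256, hidden_size=512, use_output_ln=True)
    x = torch.randn(32, 256)  # [batch_size, input_size]
    y = block(x)              # x + MLP(LayerNorm(x)), then an output LayerNorm
    assert y.shape == (32, 256)
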
+ """ + super().__init__() + self._use_output_ln = use_output_ln + if hidden_size is None: + hidden_size = input_size + fc1 = FC(input_size, + hidden_size, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_init_value=bias_init_value) + fc2 = FC(hidden_size, + output_size, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_init_value=bias_init_value) + self._core_layers = nn.Sequential(fc1, fc2) + self._ln = nn.LayerNorm(input_size) + if use_output_ln: + self._out_ln = nn.LayerNorm(output_size) + + def reset_parameters(self): + self._ln.reset_parameters() + for layer in self._core_layers: + layer.reset_parameters() + if self._use_output_ln: + self._out_ln.reset_parameters() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + core_inputs = self._ln(inputs) + core = self._core_layers(core_inputs) + outputs = core + inputs + if self._use_output_ln: + outputs = self._out_ln(outputs) + return outputs + + @alf.configurable(whitelist=[ 'with_batch_normalization', 'bn_ctor', 'weight_opt_args', 'activation' ]) diff --git a/alf/networks/actor_distribution_networks.py b/alf/networks/actor_distribution_networks.py index ee3bd9eb6..225e90f88 100644 --- a/alf/networks/actor_distribution_networks.py +++ b/alf/networks/actor_distribution_networks.py @@ -138,6 +138,10 @@ def __init__(self, preprocessing_combiner=None, conv_layer_params=None, fc_layer_params=None, + use_residual_fc_block=False, + num_residual_fc_blocks=1, + residual_fc_block_hidden_size=None, + residual_fc_block_use_output_ln=True, activation=torch.relu_, kernel_initializer=None, use_fc_bn=False, @@ -173,6 +177,14 @@ def __init__(self, where ``padding`` is optional. fc_layer_params (tuple[int]): a tuple of integers representing hidden FC layer sizes. + use_residual_fc_block (bool): whether to use residual block instead of + FC layers. + num_residual_fc_blocks (int): number of residual FC blocks, only valid + if use_residual_fc_block is True. + residual_fc_block_hidden_size (int): hidden size of residual FC blocks, + only valid if use_residual_fc_block is True. + residual_fc_block_use_output_ln (bool): whether to use layer norm for + the output of residual FC block, only valid if use_residual_fc_block. activation (nn.functional): activation used for hidden layers. kernel_initializer (Callable): initializer for all the layers excluding the projection net. 
diff --git a/alf/networks/critic_networks.py b/alf/networks/critic_networks.py
index e663e0651..4bb58f39a 100644
--- a/alf/networks/critic_networks.py
+++ b/alf/networks/critic_networks.py
@@ -79,6 +79,10 @@ def __init__(self,
                  action_fc_layer_params=None,
                  observation_action_combiner=None,
                  joint_fc_layer_params=None,
+                 joint_use_residual_fc_block=False,
+                 joint_num_residual_fc_blocks=1,
+                 joint_residual_fc_block_hidden_size=None,
+                 joint_residual_fc_block_use_output_ln=True,
                  activation=torch.relu_,
                  kernel_initializer=None,
                  use_fc_bn=False,
@@ -124,6 +128,14 @@ def __init__(self,
             joint_fc_layer_params (tuple[int]): a tuple of integers representing
                 hidden FC layer sizes FC layers after merging observations and
                 actions.
+            joint_use_residual_fc_block (bool): whether to use ``ResidualFCBlock`` s
+                instead of plain FC layers after merging observations and actions;
+                if True, ``joint_fc_layer_params`` is ignored.
+            joint_num_residual_fc_blocks (int): number of joint residual FC blocks;
+                only effective if ``joint_use_residual_fc_block`` is True.
+            joint_residual_fc_block_hidden_size (int): hidden size of the joint
+                residual FC blocks; if None, the block input size is used.
+            joint_residual_fc_block_use_output_ln (bool): whether to apply a layer
+                norm to the output of each joint residual FC block.
             activation (nn.functional): activation used for hidden layers. The
                 last layer will not be activated.
             kernel_initializer (Callable): initializer for all the layers but
@@ -184,21 +196,27 @@ def __init__(self,
         if observation_action_combiner is None:
             observation_action_combiner = alf.layers.NestConcat(dim=-1)
 
-        super().__init__(input_tensor_spec=input_tensor_spec,
-                         output_tensor_spec=output_tensor_spec,
-                         input_preprocessors=(obs_encoder, action_encoder),
-                         preprocessing_combiner=observation_action_combiner,
-                         fc_layer_params=joint_fc_layer_params,
-                         activation=activation,
-                         kernel_initializer=kernel_initializer,
-                         use_fc_bn=use_fc_bn,
-                         use_fc_ln=use_fc_ln,
-                         last_layer_size=output_tensor_spec.numel,
-                         last_activation=last_layer_activation,
-                         last_kernel_initializer=last_kernel_initializer,
-                         last_use_fc_bn=last_use_fc_bn,
-                         last_use_fc_ln=last_use_fc_ln,
-                         name=name)
+        super().__init__(
+            input_tensor_spec=input_tensor_spec,
+            output_tensor_spec=output_tensor_spec,
+            input_preprocessors=(obs_encoder, action_encoder),
+            preprocessing_combiner=observation_action_combiner,
+            fc_layer_params=joint_fc_layer_params,
+            use_residual_fc_block=joint_use_residual_fc_block,
+            num_residual_fc_blocks=joint_num_residual_fc_blocks,
+            residual_fc_block_hidden_size=joint_residual_fc_block_hidden_size,
+            residual_fc_block_use_output_ln=
+            joint_residual_fc_block_use_output_ln,
+            activation=activation,
+            kernel_initializer=kernel_initializer,
+            use_fc_bn=use_fc_bn,
+            use_fc_ln=use_fc_ln,
+            last_layer_size=output_tensor_spec.numel,
+            last_activation=last_layer_activation,
+            last_kernel_initializer=last_kernel_initializer,
+            last_use_fc_bn=last_use_fc_bn,
+            last_use_fc_ln=last_use_fc_ln,
+            name=name)
         self._use_naive_parallel_network = use_naive_parallel_network
 
     def make_parallel(self, n):
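
A corresponding sketch for the critic, where only the joint trunk (after
observations and actions are combined) uses residual blocks; the specs, layer
sizes, and per-input encoder settings are illustrative assumptions::

    import alf
    from alf.networks import CriticNetwork

    obs_spec = alf.TensorSpec((11, ))
    action_spec = alf.BoundedTensorSpec((3, ), minimum=-1.0, maximum=1.0)

    # The per-input encoders stay plain FC layers; the joint trunk becomes
    # FC(512 -> 256) followed by two ResidualFCBlock(256, 256).
    critic = CriticNetwork(
        input_tensor_spec=(obs_spec, action_spec),
        observation_fc_layer_params=(256, ),
        action_fc_layer_params=(256, ),
        joint_use_residual_fc_block=True,
        joint_num_residual_fc_blocks=2,
        joint_residual_fc_block_hidden_size=256)
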
diff --git a/alf/networks/encoding_networks.py b/alf/networks/encoding_networks.py
index 2f689a45b..d7f5d4560 100644
--- a/alf/networks/encoding_networks.py
+++ b/alf/networks/encoding_networks.py
@@ -619,6 +619,10 @@ def __init__(self,
                  preprocessing_combiner=None,
                  conv_layer_params=None,
                  fc_layer_params=None,
+                 use_residual_fc_block=False,
+                 num_residual_fc_blocks=1,
+                 residual_fc_block_hidden_size=None,
+                 residual_fc_block_use_output_ln=True,
                  activation=torch.relu_,
                  kernel_initializer=None,
                  use_fc_bn=False,
@@ -668,6 +672,14 @@ def __init__(self,
                 where ``padding`` is optional.
             fc_layer_params (tuple[int]): a tuple of integers representing FC
                 layer sizes.
+            use_residual_fc_block (bool): whether to use ``ResidualFCBlock`` s
+                instead of plain FC layers; if True, ``fc_layer_params`` is ignored.
+            num_residual_fc_blocks (int): number of residual FC blocks; only
+                effective if ``use_residual_fc_block`` is True.
+            residual_fc_block_hidden_size (int): hidden size of the residual FC
+                blocks; if None, the block input size is used.
+            residual_fc_block_use_output_ln (bool): whether to apply a layer norm
+                to the output of each residual FC block.
             activation (nn.functional): activation used for all the layers but
                 the last layer.
             kernel_initializer (Callable): initializer for all the layers but
@@ -766,7 +778,7 @@ def __init__(self,
                     f"The input shape {spec.shape} should be like (N, )"
                     "or (N, D, ).")
 
-        if fc_layer_params is None:
+        if fc_layer_params is None or use_residual_fc_block:
             fc_layer_params = []
         else:
             assert isinstance(fc_layer_params, tuple)
@@ -790,6 +802,25 @@ def __init__(self,
                                   kernel_initializer=kernel_initializer))
                 input_size = size
 
+        if use_residual_fc_block:
+            if residual_fc_block_hidden_size is None:
+                residual_fc_block_hidden_size = input_size
+            nets.append(
+                fc_layer_ctor(input_size,
+                              residual_fc_block_hidden_size,
+                              activation=activation,
+                              use_bn=use_fc_bn,
+                              use_ln=use_fc_ln,
+                              kernel_initializer=kernel_initializer))
+            input_size = residual_fc_block_hidden_size
+            for _ in range(num_residual_fc_blocks):
+                nets.append(
+                    layers.ResidualFCBlock(
+                        input_size,
+                        residual_fc_block_hidden_size,
+                        use_output_ln=residual_fc_block_use_output_ln))
+                input_size = residual_fc_block_hidden_size
+
         if last_layer_size is not None or last_activation is not None:
             assert last_layer_size is not None and last_activation is not None, \
                 "Both last_layer_size and last_activation need to be specified!"
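
With ``use_residual_fc_block=True``, the EncodingNetwork trunk becomes a single
FC projection to ``residual_fc_block_hidden_size`` followed by
``num_residual_fc_blocks`` residual blocks, and ``fc_layer_params`` is ignored.
A small sketch (the input size and hidden size are illustrative assumptions)::

    import alf
    from alf.networks import EncodingNetwork

    spec = alf.TensorSpec((17, ))

    # Trunk: FC(17 -> 512, relu) -> ResidualFCBlock(512, 512) x 3.
    net = EncodingNetwork(
        input_tensor_spec=spec,
        use_residual_fc_block=True,
        num_residual_fc_blocks=3,
        residual_fc_block_hidden_size=512,
        residual_fc_block_use_output_ln=True)

    x = spec.zeros(outer_dims=(4, ))  # [batch_size, 17]
    y, _ = net(x)                     # ALF networks return (output, state)
    # y.shape == (4, 512)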