[tx] General implementation of trainable Hyper Connections #1008
base: main
Conversation
Code Review
This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in tx/layers/connectors.py where the Connector module is defined, and in tx/models/deepseekv3.py to integrate it into the decoder layers.
My review found a couple of issues:
- An unused trainable parameter in the Connector class, which should be removed for clarity.
- A bug in DeepseekV3Model when handling intermediate hidden states for expansion_rate > 1, where squeeze() is used incorrectly.
Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.
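To make the pre/post pattern concrete, here is a minimal sketch of a static hyper-connection wrapper in Flax NNX. The class name, the pre/post split, and the identity-style initialization are illustrative assumptions, not the PR's actual API:

```python
import jax.numpy as jnp
from flax import nnx


class HyperConnection(nnx.Module):
    """Sketch: mixes n parallel residual streams into a block input and
    scatters the block output back. With expansion_rate=1 and the
    initialization below, this reduces to a plain residual connection."""

    def __init__(self, expansion_rate: int):
        n = expansion_rate
        # Width connection: mixes the n streams into the block input.
        # Initialized to select stream 0, mimicking the standard residual path.
        self.width = nnx.Param(jnp.zeros((n,)).at[0].set(1.0))
        # Depth connection: scales how the block output is added back to each stream.
        self.depth = nnx.Param(jnp.zeros((n,)).at[0].set(1.0))

    def pre(self, h):  # h: (..., n, hidden)
        # Weighted mix over the stream axis -> block input of shape (..., hidden).
        return jnp.einsum("...nc,n->...c", h, self.width.value)

    def post(self, h, block_out):  # block_out: (..., hidden)
        # Add the block output into each stream, weighted per stream.
        return h + self.depth.value[:, None] * block_out[..., None, :]
```

A decoder layer would then wrap its existing blocks, e.g. h = conn.post(h, mlp(conn.pre(h))), leaving the attention and MLP code untouched.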
In tx/models/deepseekv3.py:

```diff
  for layer_idx, layer in enumerate(self.layers):
      if output_hidden_states:
-         all_hidden_states.append(hidden_states)
+         all_hidden_states.append(hidden_states.squeeze())
```
hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.
A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.
Suggested change:

```diff
- all_hidden_states.append(hidden_states.squeeze())
+ all_hidden_states.append(hidden_states.mean(axis=-2))
```
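A quick shape check (hypothetical shapes) showing why squeeze() is a no-op here while mean(axis=-2) aggregates the streams:

```python
import jax.numpy as jnp

h = jnp.ones((2, 8, 4, 64))    # (batch, seq, n=4 streams, hidden)
print(h.squeeze().shape)       # (2, 8, 4, 64): no size-1 axes, so unchanged
print(h.mean(axis=-2).shape)   # (2, 8, 64): aggregated over the stream axis
```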
In tx/layers/connectors.py, the Connector signature containing the unused trainable parameter flagged in the summary above:

```python
hidden_dim: int,
expansion_rate: int,
*,
trainable: bool = False,
```
skyrl-tx/tx/layers/layernorm.py (Outdated)
```diff
  self.eps = eps
  self.weight = Param(
-     size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
+     size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
```
Temporary, for testing.
https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html
PyTorch also initializes the weight to one by default.
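For reference, a minimal RMSNorm sketch with the ones initialization under discussion. This is simplified from the repo's actual layernorm.py; the plain nnx.Param here stands in for the repo's Param helper and its partitioning setup:

```python
import jax.numpy as jnp
from flax import nnx


class RMSNorm(nnx.Module):
    def __init__(self, size: int, eps: float = 1e-6):
        self.eps = eps
        # Scale initialized to ones, matching PyTorch's RMSNorm default.
        self.weight = nnx.Param(jnp.ones((size,)))

    def __call__(self, x):
        # Root-mean-square normalization over the feature axis, then scale.
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return self.weight.value * (x / rms)
```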
Addresses #952
This PR is a general implementation of Hyper Connections.
This is intended to be an extension like LoRA, where the default case mimics a standard residual connection with identity mappings.
Default case: trainable=False, expansion_rate=1.
For expansion_rate > 1, the hidden state is expanded into multiple parallel streams (a hedged sketch of the implied data flow follows).
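A sketch of that data flow, inferred from the review discussion; the stream count n and the final mean aggregation are assumptions, not confirmed details of the PR:

```python
import jax.numpy as jnp

n = 4                                               # expansion rate (assumed)
h = jnp.ones((2, 8, 64))                            # (batch, seq, hidden)
streams = jnp.repeat(h[..., None, :], n, axis=-2)   # (batch, seq, n, hidden)
# ... decoder layers read from and write to the n streams via hyper connections ...
out = streams.mean(axis=-2)                         # back to (batch, seq, hidden)
```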
Todos
Future work