Skip to content

Conversation

@tanmaysachan
Copy link
Contributor

@tanmaysachan tanmaysachan commented Feb 2, 2026

Addresses #952

This PR is a general implementation of Hyper connections.

This is supposed to be an extension like Lora, where the default case mimics a standard residual connection with identity mappings.

Default case - Trainable is false. Expansion rate is 1.

  1. H_res is a single value matrix [1]
  2. H_pre and H_post are vectors of [1, 1, 1, ...] that result in no-op matmuls

For expansion rate > 1

  1. H_res is initialized as identity of size nxn (n is the expansion rate)
  2. H_pre is [1/n, 1/n, ...]
  3. H_post is [1, 1, 1, ...]

Todos

  • simplify rms integration - added elementwise_affine as a flag
  • Benchmark/ensure no regression for expansion_rate = 1 - minimal difference in step time when expansion rate is 1 and untrainable.

Future work

  • Fine tune on custom data with mHC + LoRA to see perf gains

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in tx/layers/connectors.py where the Connector module is defined, and in tx/models/deepseekv3.py to integrate it into the decoder layers.

My review found a couple of issues:

  • An unused trainable parameter in the Connector class which should be removed for clarity.
  • A bug in DeepseekV3Model when handling intermediate hidden states for expansion_rate > 1, where squeeze() is used incorrectly.

Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.

for layer_idx, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states.append(hidden_states)
all_hidden_states.append(hidden_states.squeeze())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.

A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.

Suggested change
all_hidden_states.append(hidden_states.squeeze())
all_hidden_states.append(hidden_states.mean(axis=-2))

hidden_dim: int,
expansion_rate: int,
*,
trainable: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The trainable parameter is defined but it is not used anywhere in the Connector class. This could be misleading for developers using this module. Consider removing it from the method signature, and also the assignment self.trainable = trainable on line 27, to improve code clarity.

@pcmoritz pcmoritz added the tx label Feb 2, 2026
self.eps = eps
self.weight = Param(
size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporary, testing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants