From 57d188119c1561e929c9e95225bc3441774e737d Mon Sep 17 00:00:00 2001 From: tanmaysachan Date: Mon, 2 Feb 2026 22:53:05 +0530 Subject: [PATCH 1/7] Initial design --- skyrl-tx/tx/layers/connectors.py | 76 ++++++++++++++++++++++++++++++++ skyrl-tx/tx/models/deepseekv3.py | 37 ++++++++++++---- skyrl-tx/tx/utils/models.py | 3 ++ 3 files changed, 107 insertions(+), 9 deletions(-) create mode 100644 skyrl-tx/tx/layers/connectors.py diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py new file mode 100644 index 000000000..72a4b8ac3 --- /dev/null +++ b/skyrl-tx/tx/layers/connectors.py @@ -0,0 +1,76 @@ +"""Connection mechanisms for transformer layers (residual, learned connectors, etc.).""" + +from flax import nnx +import jax +from jax import numpy as jnp + +from tx.layers.util import Param +from tx.layers.layernorm import RMSNorm + + +class Connector(nnx.Module): + + def __init__( + self, + hidden_dim: int, + expansion_rate: int, + *, + trainable: bool = False, + sinkhorn_iters: int = 20, + eps: float = 1e-5, + dtype: jnp.dtype, + rngs: nnx.Rngs, + ) -> None: + self.hidden_dim = hidden_dim + self.expansion_rate = expansion_rate + self.trainable = trainable + self.sinkhorn_iters = sinkhorn_iters + self.eps = eps + n = expansion_rate + C = hidden_dim + + self.norm = RMSNorm(hidden_dim, eps=eps, dtype=dtype, rngs=rngs) + + self.phi_pre = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) + self.phi_post = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) + self.phi_res = Param(n * C, n * n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) + + self.b_pre = Param(n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init()) + self.b_post = Param(n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init()) + self.b_res = Param(n, n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init()) + + self.alpha_pre = nnx.Param(jnp.array(0.01, dtype=dtype)) + self.alpha_post = nnx.Param(jnp.array(0.01, dtype=dtype)) + self.alpha_res = nnx.Param(jnp.array(0.01, dtype=dtype)) + + def _sinkhorn_knopp(self, M: jax.Array) -> jax.Array: + M = jnp.exp(M) + for _ in range(self.sinkhorn_iters): + M = M / (M.sum(axis=-1, keepdims=True) + self.eps) + M = M / (M.sum(axis=-2, keepdims=True) + self.eps) + return M + + def pre(self, x: jax.Array) -> jax.Array: + *batch_dims, n, C = x.shape + + x_flat = x.reshape(*batch_dims, n * C) + rms = jnp.sqrt(jnp.mean(x_flat * x_flat, axis=-1, keepdims=True) + self.eps) + x_norm = x_flat / rms + + tilde_H_pre = self.alpha_pre[...] * (x_norm @ self.phi_pre[...]) + self.b_pre[...] + tilde_H_post = self.alpha_post[...] * (x_norm @ self.phi_post[...]) + self.b_post[...] + tilde_H_res = self.alpha_res[...] * (x_norm @ self.phi_res[...]).reshape(*batch_dims, n, n) + self.b_res[...] 
+ + H_pre = jax.nn.sigmoid(tilde_H_pre) + self._H_post = 2.0 * jax.nn.sigmoid(tilde_H_post) + self._M = self._sinkhorn_knopp(tilde_H_res) + + x_agg = jnp.einsum("...i,...ic->...c", H_pre, x) + x_normed = self.norm(x_agg) + + return x_normed + + def post(self, residual: jax.Array, output: jax.Array) -> jax.Array: + y_dist = self._H_post[..., None] * output[..., None, :] + x_mixed = jnp.einsum("...ij,...jc->...ic", self._M, residual) + return x_mixed + y_dist diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 9232832d1..aa1c138fb 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -7,6 +7,7 @@ from tx.layers.rotary_embedding import get_rope from tx.layers.util import Param, prepare_routing, shard_map_ep from tx.layers.layernorm import RMSNorm +from tx.layers.connectors import Connector from tx.models.configs import DeepseekV3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache @@ -417,17 +418,28 @@ def __call__( class DeepseekV3DecoderLayer(nnx.Module): - def __init__(self, config: DeepseekV3Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + def __init__( + self, + config: DeepseekV3Config, + layer_idx: int, + *, + dtype: jnp.dtype, + rngs: nnx.Rngs, + expansion_rate: int = 1, + ) -> None: self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs) + self.layer_idx = layer_idx + self.num_layers = config.num_hidden_layers + self.expansion_rate = expansion_rate - # Use dense MLP for initial layers, MoE for the rest if layer_idx >= config.first_k_dense_replace: self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs) else: self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs) + self.attn_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs) + self.mlp_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs) + def __call__( self, hidden_states: jax.Array, @@ -437,8 +449,12 @@ def __call__( adapter_indices: jax.Array | None = None, kv_cache: tuple[jax.Array, jax.Array] | None = None, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + n = self.expansion_rate + if self.layer_idx == 0: + hidden_states = jnp.repeat(hidden_states[..., None, :], n, axis=-2) + residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.attn_connector.pre(hidden_states) hidden_states, updated_cache = self.self_attn( hidden_states, attention_mask=attention_mask, @@ -446,12 +462,15 @@ def __call__( adapter_indices=adapter_indices, kv_cache=kv_cache, ) - hidden_states = residual + hidden_states + hidden_states = self.attn_connector.post(residual, hidden_states) residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp_connector.pre(hidden_states) mlp_output = self.mlp(hidden_states, adapter_indices=adapter_indices) - hidden_states = residual + mlp_output + hidden_states = self.mlp_connector.post(residual, mlp_output) + + if self.layer_idx == self.num_layers - 1: + hidden_states = hidden_states.sum(axis=-2) return hidden_states, updated_cache @@ -500,7 +519,7 @@ def __call__( for layer_idx, layer in enumerate(self.layers): if output_hidden_states: - 
all_hidden_states.append(hidden_states) + all_hidden_states.append(hidden_states.squeeze()) hidden_states, (k, v) = layer( hidden_states, diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 6e840febf..170262381 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -113,6 +113,9 @@ def load_safetensors( # Skip LoRA parameters if requested if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): continue + # Skip connector parameters + if any("connector" in str(p) for p in path): + continue if "experts" in path: tensors[key] = np.stack( [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0 From 91d5e74c2249f435bc05ee0f6c397e1607e76d13 Mon Sep 17 00:00:00 2001 From: tanmaysachan Date: Mon, 2 Feb 2026 22:54:55 +0530 Subject: [PATCH 2/7] Add comment --- skyrl-tx/tx/layers/connectors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py index 72a4b8ac3..399312492 100644 --- a/skyrl-tx/tx/layers/connectors.py +++ b/skyrl-tx/tx/layers/connectors.py @@ -9,6 +9,7 @@ class Connector(nnx.Module): + """General implementation of (m?)Hyper Connections""" def __init__( self, From 24b82d72dee5ac904a0038080d620851894ba4f5 Mon Sep 17 00:00:00 2001 From: tanmaysachan Date: Mon, 2 Feb 2026 23:58:54 +0530 Subject: [PATCH 3/7] Identity mapping for initial passthrough --- skyrl-tx/tx/layers/connectors.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py index 399312492..5c2b6c101 100644 --- a/skyrl-tx/tx/layers/connectors.py +++ b/skyrl-tx/tx/layers/connectors.py @@ -36,13 +36,25 @@ def __init__( self.phi_post = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) self.phi_res = Param(n * C, n * n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) - self.b_pre = Param(n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init()) - self.b_post = Param(n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init()) - self.b_res = Param(n, n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init()) + # Initialize biases for identity-like behavior: + # H_pre = 1/n (uniform aggregation), H_post = 1 (full distribution), M = I (identity mixing) - self.alpha_pre = nnx.Param(jnp.array(0.01, dtype=dtype)) - self.alpha_post = nnx.Param(jnp.array(0.01, dtype=dtype)) - self.alpha_res = nnx.Param(jnp.array(0.01, dtype=dtype)) + # H_pre = sigmoid(b_pre) = 1/n => b_pre = logit(1/n) + target_h_pre = jnp.array(1.0 / n, dtype=dtype) + clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6) + logit_1_over_n = jnp.log(clamped) - jnp.log(1.0 - clamped) + self.b_pre = nnx.Param(jnp.full((n,), logit_1_over_n, dtype=dtype)) + + # H_post = 2 * sigmoid(b_post) = 1 => b_post = 0 + self.b_post = nnx.Param(jnp.zeros((n,), dtype=dtype)) + + # M = sinkhorn(exp(b_res)) = I => b_res = large diagonal matrix + self.b_res = nnx.Param(10.0 * jnp.eye(n, dtype=dtype)) + + # Alpha = 0 so phi matrices don't contribute initially + self.alpha_pre = nnx.Param(jnp.array(0.0, dtype=dtype)) + self.alpha_post = nnx.Param(jnp.array(0.0, dtype=dtype)) + self.alpha_res = nnx.Param(jnp.array(0.0, dtype=dtype)) def _sinkhorn_knopp(self, M: jax.Array) -> jax.Array: M = jnp.exp(M) From 874ab0879ec9fb562870280067bcb284c4958dd9 Mon Sep 17 00:00:00 2001 From: tanmaysachan Date: Tue, 3 Feb 
2026 00:30:45 +0530 Subject: [PATCH 4/7] Add trainable flag for freezing weights --- skyrl-tx/tx/layers/connectors.py | 29 +++++++++++++++++++++++------ skyrl-tx/tx/layers/layernorm.py | 2 +- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py index 5c2b6c101..ac1f54c28 100644 --- a/skyrl-tx/tx/layers/connectors.py +++ b/skyrl-tx/tx/layers/connectors.py @@ -9,7 +9,11 @@ class Connector(nnx.Module): - """General implementation of (m?)Hyper Connections""" + """ + Implementation of Manifold constrained HyperConnections (https://arxiv.org/pdf/2512.24880) + + Weights initialized with identity mapping; Default behaviour equates to residual networks. + """ def __init__( self, @@ -51,7 +55,6 @@ def __init__( # M = sinkhorn(exp(b_res)) = I => b_res = large diagonal matrix self.b_res = nnx.Param(10.0 * jnp.eye(n, dtype=dtype)) - # Alpha = 0 so phi matrices don't contribute initially self.alpha_pre = nnx.Param(jnp.array(0.0, dtype=dtype)) self.alpha_post = nnx.Param(jnp.array(0.0, dtype=dtype)) self.alpha_res = nnx.Param(jnp.array(0.0, dtype=dtype)) @@ -63,6 +66,16 @@ def _sinkhorn_knopp(self, M: jax.Array) -> jax.Array: M = M / (M.sum(axis=-2, keepdims=True) + self.eps) return M + def _get_params(self): + """Get all connector params, with stop_gradient applied if not trainable.""" + sg = (lambda x: x) if self.trainable else jax.lax.stop_gradient + return ( + sg(self.alpha_pre[...]), sg(self.alpha_post[...]), sg(self.alpha_res[...]), + sg(self.phi_pre[...]), sg(self.phi_post[...]), sg(self.phi_res[...]), + sg(self.b_pre[...]), sg(self.b_post[...]), sg(self.b_res[...]), + sg(self.norm.weight[...]), + ) + def pre(self, x: jax.Array) -> jax.Array: *batch_dims, n, C = x.shape @@ -70,16 +83,20 @@ def pre(self, x: jax.Array) -> jax.Array: rms = jnp.sqrt(jnp.mean(x_flat * x_flat, axis=-1, keepdims=True) + self.eps) x_norm = x_flat / rms - tilde_H_pre = self.alpha_pre[...] * (x_norm @ self.phi_pre[...]) + self.b_pre[...] - tilde_H_post = self.alpha_post[...] * (x_norm @ self.phi_post[...]) + self.b_post[...] - tilde_H_res = self.alpha_res[...] * (x_norm @ self.phi_res[...]).reshape(*batch_dims, n, n) + self.b_res[...] 
+ (alpha_pre, alpha_post, alpha_res, phi_pre, phi_post, phi_res, + b_pre, b_post, b_res, norm_weight) = self._get_params() + + tilde_H_pre = alpha_pre * (x_norm @ phi_pre) + b_pre + tilde_H_post = alpha_post * (x_norm @ phi_post) + b_post + tilde_H_res = alpha_res * (x_norm @ phi_res).reshape(*batch_dims, n, n) + b_res H_pre = jax.nn.sigmoid(tilde_H_pre) self._H_post = 2.0 * jax.nn.sigmoid(tilde_H_post) self._M = self._sinkhorn_knopp(tilde_H_res) x_agg = jnp.einsum("...i,...ic->...c", H_pre, x) - x_normed = self.norm(x_agg) + rms_norm = jnp.sqrt(jnp.mean(x_agg**2, axis=-1, keepdims=True) + self.norm.eps) + x_normed = norm_weight * x_agg / rms_norm return x_normed diff --git a/skyrl-tx/tx/layers/layernorm.py b/skyrl-tx/tx/layers/layernorm.py index e061d6d09..2f95a3158 100644 --- a/skyrl-tx/tx/layers/layernorm.py +++ b/skyrl-tx/tx/layers/layernorm.py @@ -16,7 +16,7 @@ class RMSNorm(nnx.Module): def __init__(self, size: int, *, eps: float = 1e-6, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: self.eps = eps self.weight = Param( - size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs + size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs ) def __call__(self, x: jax.Array) -> jax.Array: From 975faa1d2f2083fc39894bbaabe89d83f79f5dda Mon Sep 17 00:00:00 2001 From: tanmaysachan Date: Tue, 3 Feb 2026 00:37:17 +0530 Subject: [PATCH 5/7] Stray comment --- skyrl-tx/tx/layers/connectors.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py index ac1f54c28..553e5d9d7 100644 --- a/skyrl-tx/tx/layers/connectors.py +++ b/skyrl-tx/tx/layers/connectors.py @@ -1,5 +1,3 @@ -"""Connection mechanisms for transformer layers (residual, learned connectors, etc.).""" - from flax import nnx import jax from jax import numpy as jnp From f685543e301068ae3d183a41315700410c75fee2 Mon Sep 17 00:00:00 2001 From: tanmaysachan Date: Tue, 3 Feb 2026 00:55:26 +0530 Subject: [PATCH 6/7] simplify --- skyrl-tx/tx/layers/connectors.py | 27 ++++++++++----------------- skyrl-tx/tx/layers/util.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py index 553e5d9d7..ac98b57b5 100644 --- a/skyrl-tx/tx/layers/connectors.py +++ b/skyrl-tx/tx/layers/connectors.py @@ -2,7 +2,7 @@ import jax from jax import numpy as jnp -from tx.layers.util import Param +from tx.layers.util import Param, sinkhorn_knopp from tx.layers.layernorm import RMSNorm @@ -41,11 +41,11 @@ def __init__( # Initialize biases for identity-like behavior: # H_pre = 1/n (uniform aggregation), H_post = 1 (full distribution), M = I (identity mixing) - # H_pre = sigmoid(b_pre) = 1/n => b_pre = logit(1/n) + # H_pre = sigmoid(b_pre) = 1/n => b_pre = inv_sigmoid(1/n) target_h_pre = jnp.array(1.0 / n, dtype=dtype) - clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6) - logit_1_over_n = jnp.log(clamped) - jnp.log(1.0 - clamped) - self.b_pre = nnx.Param(jnp.full((n,), logit_1_over_n, dtype=dtype)) + clamped = jnp.clip(target_h_pre, 1e-6, 1.0) + inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped) + self.b_pre = nnx.Param(jnp.full((n,), inv_sigmoid, dtype=dtype)) # H_post = 2 * sigmoid(b_post) = 1 => b_post = 0 self.b_post = nnx.Param(jnp.zeros((n,), dtype=dtype)) @@ -57,13 +57,6 @@ def __init__( self.alpha_post = nnx.Param(jnp.array(0.0, dtype=dtype)) self.alpha_res = nnx.Param(jnp.array(0.0, dtype=dtype)) - def 
_sinkhorn_knopp(self, M: jax.Array) -> jax.Array: - M = jnp.exp(M) - for _ in range(self.sinkhorn_iters): - M = M / (M.sum(axis=-1, keepdims=True) + self.eps) - M = M / (M.sum(axis=-2, keepdims=True) + self.eps) - return M - def _get_params(self): """Get all connector params, with stop_gradient applied if not trainable.""" sg = (lambda x: x) if self.trainable else jax.lax.stop_gradient @@ -89,16 +82,16 @@ def pre(self, x: jax.Array) -> jax.Array: tilde_H_res = alpha_res * (x_norm @ phi_res).reshape(*batch_dims, n, n) + b_res H_pre = jax.nn.sigmoid(tilde_H_pre) - self._H_post = 2.0 * jax.nn.sigmoid(tilde_H_post) - self._M = self._sinkhorn_knopp(tilde_H_res) + self.H_post = 2.0 * jax.nn.sigmoid(tilde_H_post) + self.M = sinkhorn_knopp(tilde_H_res, self.sinkhorn_iters, self.eps) - x_agg = jnp.einsum("...i,...ic->...c", H_pre, x) + x_agg = (H_pre[..., None] * x).sum(axis=-2) rms_norm = jnp.sqrt(jnp.mean(x_agg**2, axis=-1, keepdims=True) + self.norm.eps) x_normed = norm_weight * x_agg / rms_norm return x_normed def post(self, residual: jax.Array, output: jax.Array) -> jax.Array: - y_dist = self._H_post[..., None] * output[..., None, :] - x_mixed = jnp.einsum("...ij,...jc->...ic", self._M, residual) + y_dist = self.H_post[..., None] * output[..., None, :] + x_mixed = self.M @ residual return x_mixed + y_dist diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0030c604d..238de2474 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -63,6 +63,18 @@ def Param(*shape: int, dtype: jnp.dtype, kernel_init: nnx.Initializer, rngs: nnx return nnx.Param(kernel_init(rngs.param(), shape, dtype)) +def sinkhorn_knopp(M: jax.Array, iters: int = 20, eps: float = 1e-5) -> jax.Array: + """Sinkhorn-Knopp algorithm to project a matrix onto the set of doubly stochastic matrices.""" + M = jnp.exp(M) + + def step(_, M): + M = M / (M.sum(axis=-1, keepdims=True) + eps) + M = M / (M.sum(axis=-2, keepdims=True) + eps) + return M + + return lax.fori_loop(0, iters, step, M) + + def prepare_routing( tokens: jax.Array, indices: jax.Array, num_groups: int, adapter_indices: jax.Array | None = None ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]: From e493ae58d628cfe08e51efb6440855bc5d48c1bd Mon Sep 17 00:00:00 2001 From: tanmaysachan Date: Tue, 3 Feb 2026 09:54:30 +0530 Subject: [PATCH 7/7] Add elementwise_affine flag to RMS to match pytorch impl. 
Replace raw rms in mhc --- skyrl-tx/tests/layers/test_connectors.py | 58 ++++++++++++++++++++++++ skyrl-tx/tx/layers/connectors.py | 15 ++---- skyrl-tx/tx/layers/layernorm.py | 16 +++++-- 3 files changed, 74 insertions(+), 15 deletions(-) create mode 100644 skyrl-tx/tests/layers/test_connectors.py diff --git a/skyrl-tx/tests/layers/test_connectors.py b/skyrl-tx/tests/layers/test_connectors.py new file mode 100644 index 000000000..806607fda --- /dev/null +++ b/skyrl-tx/tests/layers/test_connectors.py @@ -0,0 +1,58 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from flax import nnx + + +@pytest.fixture(scope="module") +def mesh(): + return jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + + +@pytest.mark.parametrize("expansion_rate", [1, 2, 4]) +def test_connector_shapes(mesh, expansion_rate: int): + """Test that Connector produces correct output shapes.""" + with jax.set_mesh(mesh): + from tx.layers.connectors import Connector + + hidden_dim = 64 + batch, seq = 2, 8 + + conn = Connector(hidden_dim, expansion_rate, dtype=jnp.float32, rngs=nnx.Rngs(0)) + + x = jnp.ones((batch, seq, expansion_rate, hidden_dim)) + pre_out = conn.pre(x) + post_out = conn.post(x, pre_out) + + assert pre_out.shape == (batch, seq, hidden_dim) + assert post_out.shape == (batch, seq, expansion_rate, hidden_dim) + + +@pytest.mark.parametrize("expansion_rate", [1, 2, 4]) +def test_connector_identity_initialization(mesh, expansion_rate: int): + """Test that Connector with identity initialization behaves like residual connection.""" + with jax.set_mesh(mesh): + from tx.layers.connectors import Connector + from tx.layers.util import sinkhorn_knopp + + hidden_dim = 64 + n = expansion_rate + + conn = Connector(hidden_dim, n, dtype=jnp.float32, rngs=nnx.Rngs(0)) + + # Verify H_pre = 1/n + _, _, _, _, _, _, b_pre, _, _ = conn._get_params() + h_pre = jax.nn.sigmoid(b_pre) + assert np.allclose(h_pre, 1.0 / n, atol=1e-5) + + # Verify H_post = 1 + _, _, _, _, _, _, _, b_post, _ = conn._get_params() + h_post = 2.0 * jax.nn.sigmoid(b_post) + assert np.allclose(h_post, 1.0, atol=1e-6) + + # Verify M = I + _, _, _, _, _, _, _, _, b_res = conn._get_params() + M = sinkhorn_knopp(b_res) + assert np.allclose(M, jnp.eye(n), atol=1e-3) + diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py index ac98b57b5..1093e7d3b 100644 --- a/skyrl-tx/tx/layers/connectors.py +++ b/skyrl-tx/tx/layers/connectors.py @@ -32,7 +32,8 @@ def __init__( n = expansion_rate C = hidden_dim - self.norm = RMSNorm(hidden_dim, eps=eps, dtype=dtype, rngs=rngs) + self.input_norm = RMSNorm(n * C, eps=eps, elementwise_affine=False, dtype=dtype, rngs=rngs) + self.output_norm = RMSNorm(hidden_dim, eps=eps, elementwise_affine=trainable, dtype=dtype, rngs=rngs) self.phi_pre = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) self.phi_post = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) @@ -64,18 +65,15 @@ def _get_params(self): sg(self.alpha_pre[...]), sg(self.alpha_post[...]), sg(self.alpha_res[...]), sg(self.phi_pre[...]), sg(self.phi_post[...]), sg(self.phi_res[...]), sg(self.b_pre[...]), sg(self.b_post[...]), sg(self.b_res[...]), - sg(self.norm.weight[...]), ) def pre(self, x: jax.Array) -> jax.Array: *batch_dims, n, C = x.shape - x_flat = x.reshape(*batch_dims, n * C) - rms = jnp.sqrt(jnp.mean(x_flat * x_flat, axis=-1, keepdims=True) + self.eps) - x_norm = x_flat / rms + x_norm = 
self.input_norm(x.reshape(*batch_dims, n * C)) (alpha_pre, alpha_post, alpha_res, phi_pre, phi_post, phi_res, - b_pre, b_post, b_res, norm_weight) = self._get_params() + b_pre, b_post, b_res) = self._get_params() tilde_H_pre = alpha_pre * (x_norm @ phi_pre) + b_pre tilde_H_post = alpha_post * (x_norm @ phi_post) + b_post @@ -86,10 +84,7 @@ def pre(self, x: jax.Array) -> jax.Array: self.M = sinkhorn_knopp(tilde_H_res, self.sinkhorn_iters, self.eps) x_agg = (H_pre[..., None] * x).sum(axis=-2) - rms_norm = jnp.sqrt(jnp.mean(x_agg**2, axis=-1, keepdims=True) + self.norm.eps) - x_normed = norm_weight * x_agg / rms_norm - - return x_normed + return self.output_norm(x_agg) def post(self, residual: jax.Array, output: jax.Array) -> jax.Array: y_dist = self.H_post[..., None] * output[..., None, :] diff --git a/skyrl-tx/tx/layers/layernorm.py b/skyrl-tx/tx/layers/layernorm.py index 2f95a3158..ca751b1cf 100644 --- a/skyrl-tx/tx/layers/layernorm.py +++ b/skyrl-tx/tx/layers/layernorm.py @@ -13,12 +13,18 @@ class RMSNorm(nnx.Module): Reference: https://arxiv.org/abs/1910.07467 """ - def __init__(self, size: int, *, eps: float = 1e-6, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + def __init__( + self, size: int, *, eps: float = 1e-6, elementwise_affine: bool = True, dtype: jnp.dtype, rngs: nnx.Rngs + ) -> None: self.eps = eps - self.weight = Param( - size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs - ) + self.elementwise_affine = elementwise_affine + if elementwise_affine: + self.weight = Param( + size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs + ) def __call__(self, x: jax.Array) -> jax.Array: rms = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps) - return self.weight * x / rms + if self.elementwise_affine: + return self.weight * x / rms + return x / rms
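
For reference, a minimal usage sketch of the Connector introduced in this series, mirroring the expand / pre / sublayer / post / collapse flow wired into DeepseekV3DecoderLayer. The mesh setup copies the test fixture; toy_sublayer, the tensor shapes, and the expansion rate are illustrative assumptions, not part of the patches.

import jax
import jax.numpy as jnp
from flax import nnx

# Trivial 1x1x1 mesh, same as the fixture in tests/layers/test_connectors.py
mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)
with jax.set_mesh(mesh):
    from tx.layers.connectors import Connector

    hidden, n = 64, 4                                  # hidden size, expansion rate (illustrative)
    conn = Connector(hidden, n, dtype=jnp.float32, rngs=nnx.Rngs(0))

    def toy_sublayer(h):                               # hypothetical stand-in for self_attn / mlp
        return 0.5 * h

    x = jnp.ones((2, 8, hidden))                       # (batch, seq, hidden)
    streams = jnp.repeat(x[..., None, :], n, axis=-2)  # layer 0: expand to (batch, seq, n, hidden)

    residual = streams
    agg = conn.pre(streams)                            # H_pre-weighted sum over streams + RMSNorm -> (batch, seq, hidden)
    out = toy_sublayer(agg)                            # sublayer still sees ordinary 3-D activations
    streams = conn.post(residual, out)                 # M @ residual + H_post-scaled broadcast of the output

    y = streams.sum(axis=-2)                           # last layer: collapse back to (batch, seq, hidden)
    assert y.shape == (2, 8, hidden)

At initialization the connectors are identity-like (H_pre = 1/n, H_post = 1, M = I), so this path reduces to the usual pre-norm residual block; test_connector_identity_initialization above checks exactly that.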
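
A related numeric check (again a sketch, not part of the patches) of why b_res = 10 * I yields identity mixing after the Sinkhorn-Knopp projection added to tx/layers/util.py: the projection exponentiates first, so the diagonal (e^10) dominates the off-diagonal (e^0), and the alternating row/column normalizations converge to approximately the identity.

import jax.numpy as jnp
import numpy as np

from tx.layers.util import sinkhorn_knopp

n = 4
M = sinkhorn_knopp(10.0 * jnp.eye(n))   # exp, then row/column normalize (20 iterations by default)
print(np.round(np.asarray(M), 4))       # ~identity; off-diagonal mass on the order of 1e-4 or smaller
assert np.allclose(np.asarray(M), np.eye(n), atol=1e-3)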