diff --git a/skyrl-tx/tests/layers/test_connectors.py b/skyrl-tx/tests/layers/test_connectors.py new file mode 100644 index 000000000..806607fda --- /dev/null +++ b/skyrl-tx/tests/layers/test_connectors.py @@ -0,0 +1,58 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from flax import nnx + + +@pytest.fixture(scope="module") +def mesh(): + return jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + + +@pytest.mark.parametrize("expansion_rate", [1, 2, 4]) +def test_connector_shapes(mesh, expansion_rate: int): + """Test that Connector produces correct output shapes.""" + with jax.set_mesh(mesh): + from tx.layers.connectors import Connector + + hidden_dim = 64 + batch, seq = 2, 8 + + conn = Connector(hidden_dim, expansion_rate, dtype=jnp.float32, rngs=nnx.Rngs(0)) + + x = jnp.ones((batch, seq, expansion_rate, hidden_dim)) + pre_out = conn.pre(x) + post_out = conn.post(x, pre_out) + + assert pre_out.shape == (batch, seq, hidden_dim) + assert post_out.shape == (batch, seq, expansion_rate, hidden_dim) + + +@pytest.mark.parametrize("expansion_rate", [1, 2, 4]) +def test_connector_identity_initialization(mesh, expansion_rate: int): + """Test that Connector with identity initialization behaves like residual connection.""" + with jax.set_mesh(mesh): + from tx.layers.connectors import Connector + from tx.layers.util import sinkhorn_knopp + + hidden_dim = 64 + n = expansion_rate + + conn = Connector(hidden_dim, n, dtype=jnp.float32, rngs=nnx.Rngs(0)) + + # Verify H_pre = 1/n + _, _, _, _, _, _, b_pre, _, _ = conn._get_params() + h_pre = jax.nn.sigmoid(b_pre) + assert np.allclose(h_pre, 1.0 / n, atol=1e-5) + + # Verify H_post = 1 + _, _, _, _, _, _, _, b_post, _ = conn._get_params() + h_post = 2.0 * jax.nn.sigmoid(b_post) + assert np.allclose(h_post, 1.0, atol=1e-6) + + # Verify M = I + _, _, _, _, _, _, _, _, b_res = conn._get_params() + M = sinkhorn_knopp(b_res) + assert np.allclose(M, jnp.eye(n), atol=1e-3) + diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py new file mode 100644 index 000000000..1093e7d3b --- /dev/null +++ b/skyrl-tx/tx/layers/connectors.py @@ -0,0 +1,92 @@ +from flax import nnx +import jax +from jax import numpy as jnp + +from tx.layers.util import Param, sinkhorn_knopp +from tx.layers.layernorm import RMSNorm + + +class Connector(nnx.Module): + """ + Implementation of Manifold constrained HyperConnections (https://arxiv.org/pdf/2512.24880) + + Weights initialized with identity mapping; Default behaviour equates to residual networks. + """ + + def __init__( + self, + hidden_dim: int, + expansion_rate: int, + *, + trainable: bool = False, + sinkhorn_iters: int = 20, + eps: float = 1e-5, + dtype: jnp.dtype, + rngs: nnx.Rngs, + ) -> None: + self.hidden_dim = hidden_dim + self.expansion_rate = expansion_rate + self.trainable = trainable + self.sinkhorn_iters = sinkhorn_iters + self.eps = eps + n = expansion_rate + C = hidden_dim + + self.input_norm = RMSNorm(n * C, eps=eps, elementwise_affine=False, dtype=dtype, rngs=rngs) + self.output_norm = RMSNorm(hidden_dim, eps=eps, elementwise_affine=trainable, dtype=dtype, rngs=rngs) + + self.phi_pre = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) + self.phi_post = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) + self.phi_res = Param(n * C, n * n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) + + # Initialize biases for identity-like behavior: + # H_pre = 1/n (uniform aggregation), H_post = 1 (full distribution), M = I (identity mixing) + + # H_pre = sigmoid(b_pre) = 1/n => b_pre = inv_sigmoid(1/n) + target_h_pre = jnp.array(1.0 / n, dtype=dtype) + clamped = jnp.clip(target_h_pre, 1e-6, 1.0) + inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped) + self.b_pre = nnx.Param(jnp.full((n,), inv_sigmoid, dtype=dtype)) + + # H_post = 2 * sigmoid(b_post) = 1 => b_post = 0 + self.b_post = nnx.Param(jnp.zeros((n,), dtype=dtype)) + + # M = sinkhorn(exp(b_res)) = I => b_res = large diagonal matrix + self.b_res = nnx.Param(10.0 * jnp.eye(n, dtype=dtype)) + + self.alpha_pre = nnx.Param(jnp.array(0.0, dtype=dtype)) + self.alpha_post = nnx.Param(jnp.array(0.0, dtype=dtype)) + self.alpha_res = nnx.Param(jnp.array(0.0, dtype=dtype)) + + def _get_params(self): + """Get all connector params, with stop_gradient applied if not trainable.""" + sg = (lambda x: x) if self.trainable else jax.lax.stop_gradient + return ( + sg(self.alpha_pre[...]), sg(self.alpha_post[...]), sg(self.alpha_res[...]), + sg(self.phi_pre[...]), sg(self.phi_post[...]), sg(self.phi_res[...]), + sg(self.b_pre[...]), sg(self.b_post[...]), sg(self.b_res[...]), + ) + + def pre(self, x: jax.Array) -> jax.Array: + *batch_dims, n, C = x.shape + + x_norm = self.input_norm(x.reshape(*batch_dims, n * C)) + + (alpha_pre, alpha_post, alpha_res, phi_pre, phi_post, phi_res, + b_pre, b_post, b_res) = self._get_params() + + tilde_H_pre = alpha_pre * (x_norm @ phi_pre) + b_pre + tilde_H_post = alpha_post * (x_norm @ phi_post) + b_post + tilde_H_res = alpha_res * (x_norm @ phi_res).reshape(*batch_dims, n, n) + b_res + + H_pre = jax.nn.sigmoid(tilde_H_pre) + self.H_post = 2.0 * jax.nn.sigmoid(tilde_H_post) + self.M = sinkhorn_knopp(tilde_H_res, self.sinkhorn_iters, self.eps) + + x_agg = (H_pre[..., None] * x).sum(axis=-2) + return self.output_norm(x_agg) + + def post(self, residual: jax.Array, output: jax.Array) -> jax.Array: + y_dist = self.H_post[..., None] * output[..., None, :] + x_mixed = self.M @ residual + return x_mixed + y_dist diff --git a/skyrl-tx/tx/layers/layernorm.py b/skyrl-tx/tx/layers/layernorm.py index e061d6d09..ca751b1cf 100644 --- a/skyrl-tx/tx/layers/layernorm.py +++ b/skyrl-tx/tx/layers/layernorm.py @@ -13,12 +13,18 @@ class RMSNorm(nnx.Module): Reference: https://arxiv.org/abs/1910.07467 """ - def __init__(self, size: int, *, eps: float = 1e-6, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: + def __init__( + self, size: int, *, eps: float = 1e-6, elementwise_affine: bool = True, dtype: jnp.dtype, rngs: nnx.Rngs + ) -> None: self.eps = eps - self.weight = Param( - size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs - ) + self.elementwise_affine = elementwise_affine + if elementwise_affine: + self.weight = Param( + size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs + ) def __call__(self, x: jax.Array) -> jax.Array: rms = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps) - return self.weight * x / rms + if self.elementwise_affine: + return self.weight * x / rms + return x / rms diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0030c604d..238de2474 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -63,6 +63,18 @@ def Param(*shape: int, dtype: jnp.dtype, kernel_init: nnx.Initializer, rngs: nnx return nnx.Param(kernel_init(rngs.param(), shape, dtype)) +def sinkhorn_knopp(M: jax.Array, iters: int = 20, eps: float = 1e-5) -> jax.Array: + """Sinkhorn-Knopp algorithm to project a matrix onto the set of doubly stochastic matrices.""" + M = jnp.exp(M) + + def step(_, M): + M = M / (M.sum(axis=-1, keepdims=True) + eps) + M = M / (M.sum(axis=-2, keepdims=True) + eps) + return M + + return lax.fori_loop(0, iters, step, M) + + def prepare_routing( tokens: jax.Array, indices: jax.Array, num_groups: int, adapter_indices: jax.Array | None = None ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]: diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 9232832d1..aa1c138fb 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -7,6 +7,7 @@ from tx.layers.rotary_embedding import get_rope from tx.layers.util import Param, prepare_routing, shard_map_ep from tx.layers.layernorm import RMSNorm +from tx.layers.connectors import Connector from tx.models.configs import DeepseekV3Config from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache @@ -417,17 +418,28 @@ def __call__( class DeepseekV3DecoderLayer(nnx.Module): - def __init__(self, config: DeepseekV3Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None: - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) + def __init__( + self, + config: DeepseekV3Config, + layer_idx: int, + *, + dtype: jnp.dtype, + rngs: nnx.Rngs, + expansion_rate: int = 1, + ) -> None: self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs) + self.layer_idx = layer_idx + self.num_layers = config.num_hidden_layers + self.expansion_rate = expansion_rate - # Use dense MLP for initial layers, MoE for the rest if layer_idx >= config.first_k_dense_replace: self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs) else: self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs) + self.attn_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs) + self.mlp_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs) + def __call__( self, hidden_states: jax.Array, @@ -437,8 +449,12 @@ def __call__( adapter_indices: jax.Array | None = None, kv_cache: tuple[jax.Array, jax.Array] | None = None, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + n = self.expansion_rate + if self.layer_idx == 0: + hidden_states = jnp.repeat(hidden_states[..., None, :], n, axis=-2) + residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.attn_connector.pre(hidden_states) hidden_states, updated_cache = self.self_attn( hidden_states, attention_mask=attention_mask, @@ -446,12 +462,15 @@ def __call__( adapter_indices=adapter_indices, kv_cache=kv_cache, ) - hidden_states = residual + hidden_states + hidden_states = self.attn_connector.post(residual, hidden_states) residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp_connector.pre(hidden_states) mlp_output = self.mlp(hidden_states, adapter_indices=adapter_indices) - hidden_states = residual + mlp_output + hidden_states = self.mlp_connector.post(residual, mlp_output) + + if self.layer_idx == self.num_layers - 1: + hidden_states = hidden_states.sum(axis=-2) return hidden_states, updated_cache @@ -500,7 +519,7 @@ def __call__( for layer_idx, layer in enumerate(self.layers): if output_hidden_states: - all_hidden_states.append(hidden_states) + all_hidden_states.append(hidden_states.squeeze()) hidden_states, (k, v) = layer( hidden_states, diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 6e840febf..170262381 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -113,6 +113,9 @@ def load_safetensors( # Skip LoRA parameters if requested if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): continue + # Skip connector parameters + if any("connector" in str(p) for p in path): + continue if "experts" in path: tensors[key] = np.stack( [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0