58 changes: 58 additions & 0 deletions skyrl-tx/tests/layers/test_connectors.py
@@ -0,0 +1,58 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax import nnx


@pytest.fixture(scope="module")
def mesh():
    return jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)


@pytest.mark.parametrize("expansion_rate", [1, 2, 4])
def test_connector_shapes(mesh, expansion_rate: int):
"""Test that Connector produces correct output shapes."""
with jax.set_mesh(mesh):
from tx.layers.connectors import Connector

hidden_dim = 64
batch, seq = 2, 8

conn = Connector(hidden_dim, expansion_rate, dtype=jnp.float32, rngs=nnx.Rngs(0))

x = jnp.ones((batch, seq, expansion_rate, hidden_dim))
pre_out = conn.pre(x)
post_out = conn.post(x, pre_out)

assert pre_out.shape == (batch, seq, hidden_dim)
assert post_out.shape == (batch, seq, expansion_rate, hidden_dim)


@pytest.mark.parametrize("expansion_rate", [1, 2, 4])
def test_connector_identity_initialization(mesh, expansion_rate: int):
"""Test that Connector with identity initialization behaves like residual connection."""
with jax.set_mesh(mesh):
from tx.layers.connectors import Connector
from tx.layers.util import sinkhorn_knopp

hidden_dim = 64
n = expansion_rate

conn = Connector(hidden_dim, n, dtype=jnp.float32, rngs=nnx.Rngs(0))

# Verify H_pre = 1/n
_, _, _, _, _, _, b_pre, _, _ = conn._get_params()
h_pre = jax.nn.sigmoid(b_pre)
assert np.allclose(h_pre, 1.0 / n, atol=1e-5)

# Verify H_post = 1
_, _, _, _, _, _, _, b_post, _ = conn._get_params()
h_post = 2.0 * jax.nn.sigmoid(b_post)
assert np.allclose(h_post, 1.0, atol=1e-6)

# Verify M = I
_, _, _, _, _, _, _, _, b_res = conn._get_params()
M = sinkhorn_knopp(b_res)
assert np.allclose(M, jnp.eye(n), atol=1e-3)

92 changes: 92 additions & 0 deletions skyrl-tx/tx/layers/connectors.py
@@ -0,0 +1,92 @@
from flax import nnx
import jax
from jax import numpy as jnp

from tx.layers.util import Param, sinkhorn_knopp
from tx.layers.layernorm import RMSNorm


class Connector(nnx.Module):
"""
Implementation of Manifold constrained HyperConnections (https://arxiv.org/pdf/2512.24880)

Weights initialized with identity mapping; Default behaviour equates to residual networks.
"""

    def __init__(
        self,
        hidden_dim: int,
        expansion_rate: int,
        *,
        trainable: bool = False,
Contributor review comment (severity: medium):
The trainable parameter is defined but it is not used anywhere in the Connector class. This could be misleading for developers using this module. Consider removing it from the method signature, and also the assignment self.trainable = trainable on line 27, to improve code clarity.
        sinkhorn_iters: int = 20,
        eps: float = 1e-5,
        dtype: jnp.dtype,
        rngs: nnx.Rngs,
    ) -> None:
        self.hidden_dim = hidden_dim
        self.expansion_rate = expansion_rate
        self.trainable = trainable
        self.sinkhorn_iters = sinkhorn_iters
        self.eps = eps
        n = expansion_rate
        C = hidden_dim

        self.input_norm = RMSNorm(n * C, eps=eps, elementwise_affine=False, dtype=dtype, rngs=rngs)
        self.output_norm = RMSNorm(hidden_dim, eps=eps, elementwise_affine=trainable, dtype=dtype, rngs=rngs)

        self.phi_pre = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs)
        self.phi_post = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs)
        self.phi_res = Param(n * C, n * n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs)

        # Initialize biases for identity-like behavior:
        # H_pre = 1/n (uniform aggregation), H_post = 1 (full distribution), M = I (identity mixing)

        # H_pre = sigmoid(b_pre) = 1/n => b_pre = inv_sigmoid(1/n)
        target_h_pre = jnp.array(1.0 / n, dtype=dtype)
        # Clamp away from both 0 and 1 so the inverse sigmoid stays finite (e.g. when n == 1).
        clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6)
        inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped)
        self.b_pre = nnx.Param(jnp.full((n,), inv_sigmoid, dtype=dtype))

        # H_post = 2 * sigmoid(b_post) = 1 => b_post = 0
        self.b_post = nnx.Param(jnp.zeros((n,), dtype=dtype))

        # M = sinkhorn(exp(b_res)) = I => b_res = large diagonal matrix
        self.b_res = nnx.Param(10.0 * jnp.eye(n, dtype=dtype))

        self.alpha_pre = nnx.Param(jnp.array(0.0, dtype=dtype))
        self.alpha_post = nnx.Param(jnp.array(0.0, dtype=dtype))
        self.alpha_res = nnx.Param(jnp.array(0.0, dtype=dtype))

    def _get_params(self):
        """Get all connector params, with stop_gradient applied if not trainable."""
        sg = (lambda x: x) if self.trainable else jax.lax.stop_gradient
        return (
            sg(self.alpha_pre[...]), sg(self.alpha_post[...]), sg(self.alpha_res[...]),
            sg(self.phi_pre[...]), sg(self.phi_post[...]), sg(self.phi_res[...]),
            sg(self.b_pre[...]), sg(self.b_post[...]), sg(self.b_res[...]),
        )

    def pre(self, x: jax.Array) -> jax.Array:
        *batch_dims, n, C = x.shape

        x_norm = self.input_norm(x.reshape(*batch_dims, n * C))

        (alpha_pre, alpha_post, alpha_res, phi_pre, phi_post, phi_res,
         b_pre, b_post, b_res) = self._get_params()

        tilde_H_pre = alpha_pre * (x_norm @ phi_pre) + b_pre
        tilde_H_post = alpha_post * (x_norm @ phi_post) + b_post
        tilde_H_res = alpha_res * (x_norm @ phi_res).reshape(*batch_dims, n, n) + b_res

        H_pre = jax.nn.sigmoid(tilde_H_pre)
        self.H_post = 2.0 * jax.nn.sigmoid(tilde_H_post)
        self.M = sinkhorn_knopp(tilde_H_res, self.sinkhorn_iters, self.eps)

        x_agg = (H_pre[..., None] * x).sum(axis=-2)
        return self.output_norm(x_agg)

    def post(self, residual: jax.Array, output: jax.Array) -> jax.Array:
        y_dist = self.H_post[..., None] * output[..., None, :]
        x_mixed = self.M @ residual
        return x_mixed + y_dist
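
For orientation, here is a minimal sketch of how the pre/post pair is meant to wrap a sublayer, mirroring the decoder-layer usage further down in this diff (the mesh setup follows the test above; the sizes and the doubled-output stand-in for attention/MLP are illustrative assumptions, not part of the change):

import jax
import jax.numpy as jnp
from flax import nnx

mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)
with jax.set_mesh(mesh):
    from tx.layers.connectors import Connector

    batch, seq, hidden_dim, n = 2, 8, 64, 4  # illustrative sizes
    conn = Connector(hidden_dim, n, dtype=jnp.float32, rngs=nnx.Rngs(0))

    # The expanded residual stream carries an extra axis: (batch, seq, n, hidden_dim).
    residual = jnp.ones((batch, seq, n, hidden_dim))

    # pre() aggregates the n streams into a single (batch, seq, hidden_dim) sublayer input.
    sublayer_in = conn.pre(residual)

    # Stand-in for attention or the MLP: any (batch, seq, hidden_dim) -> same-shape map.
    sublayer_out = sublayer_in * 2.0

    # post() redistributes the sublayer output across the n streams and mixes the residual with M.
    new_residual = conn.post(residual, sublayer_out)
    assert new_residual.shape == (batch, seq, n, hidden_dim)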
16 changes: 11 additions & 5 deletions skyrl-tx/tx/layers/layernorm.py
@@ -13,12 +13,18 @@ class RMSNorm(nnx.Module):
    Reference: https://arxiv.org/abs/1910.07467
    """

    def __init__(self, size: int, *, eps: float = 1e-6, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
    def __init__(
        self, size: int, *, eps: float = 1e-6, elementwise_affine: bool = True, dtype: jnp.dtype, rngs: nnx.Rngs
    ) -> None:
        self.eps = eps
        self.weight = Param(
            size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
        )
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = Param(
                size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
            )

    def __call__(self, x: jax.Array) -> jax.Array:
        rms = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
        return self.weight * x / rms
        if self.elementwise_affine:
            return self.weight * x / rms
        return x / rms
12 changes: 12 additions & 0 deletions skyrl-tx/tx/layers/util.py
@@ -63,6 +63,18 @@ def Param(*shape: int, dtype: jnp.dtype, kernel_init: nnx.Initializer, rngs: nnx
    return nnx.Param(kernel_init(rngs.param(), shape, dtype))


def sinkhorn_knopp(M: jax.Array, iters: int = 20, eps: float = 1e-5) -> jax.Array:
    """Sinkhorn-Knopp algorithm to project a matrix onto the set of doubly stochastic matrices."""
    M = jnp.exp(M)

    def step(_, M):
        M = M / (M.sum(axis=-1, keepdims=True) + eps)
        M = M / (M.sum(axis=-2, keepdims=True) + eps)
        return M

    return lax.fori_loop(0, iters, step, M)
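
As a quick sanity check on this projection (an informal sketch, assuming a random 4x4 input): after the alternating row/column normalizations, both the row sums and the column sums of the result should be close to 1.

import jax
from jax import numpy as jnp
from tx.layers.util import sinkhorn_knopp

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 4))  # arbitrary example matrix
M = sinkhorn_knopp(logits, iters=50)
print(M.sum(axis=-1))  # each row sums to ~1
print(M.sum(axis=-2))  # each column sums to ~1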


def prepare_routing(
    tokens: jax.Array, indices: jax.Array, num_groups: int, adapter_indices: jax.Array | None = None
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]:
37 changes: 28 additions & 9 deletions skyrl-tx/tx/models/deepseekv3.py
@@ -7,6 +7,7 @@
from tx.layers.rotary_embedding import get_rope
from tx.layers.util import Param, prepare_routing, shard_map_ep
from tx.layers.layernorm import RMSNorm
from tx.layers.connectors import Connector
from tx.models.configs import DeepseekV3Config
from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput
from tx.utils.generator import GeneratorMixin, KVCache
@@ -417,17 +418,28 @@

class DeepseekV3DecoderLayer(nnx.Module):

    def __init__(self, config: DeepseekV3Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
    def __init__(
        self,
        config: DeepseekV3Config,
        layer_idx: int,
        *,
        dtype: jnp.dtype,
        rngs: nnx.Rngs,
        expansion_rate: int = 1,
    ) -> None:
        self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs)
        self.layer_idx = layer_idx
        self.num_layers = config.num_hidden_layers
        self.expansion_rate = expansion_rate

        # Use dense MLP for initial layers, MoE for the rest
        if layer_idx >= config.first_k_dense_replace:
            self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs)
        else:
            self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs)

        self.attn_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs)
        self.mlp_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs)

    def __call__(
        self,
        hidden_states: jax.Array,
@@ -437,21 +449,28 @@ def __call__(
        adapter_indices: jax.Array | None = None,
        kv_cache: tuple[jax.Array, jax.Array] | None = None,
    ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
        n = self.expansion_rate
        if self.layer_idx == 0:
            hidden_states = jnp.repeat(hidden_states[..., None, :], n, axis=-2)

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.attn_connector.pre(hidden_states)
        hidden_states, updated_cache = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            positions=positions,
            adapter_indices=adapter_indices,
            kv_cache=kv_cache,
        )
        hidden_states = residual + hidden_states
        hidden_states = self.attn_connector.post(residual, hidden_states)

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp_connector.pre(hidden_states)
        mlp_output = self.mlp(hidden_states, adapter_indices=adapter_indices)
        hidden_states = residual + mlp_output
        hidden_states = self.mlp_connector.post(residual, mlp_output)

        if self.layer_idx == self.num_layers - 1:
            hidden_states = hidden_states.sum(axis=-2)

        return hidden_states, updated_cache
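
To summarize how the expansion axis flows through a layer, a shape sketch follows (n = expansion_rate, C = hidden size; comments only, for illustration):

# Shape flow with expansion_rate n (illustrative summary of the code above):
#   layer 0 input:                (batch, seq, C)
#   after jnp.repeat:             (batch, seq, n, C)   # n copies of the residual stream
#   connector.pre(...):           (batch, seq, C)      # aggregated input for attention / MLP
#   connector.post(residual, y):  (batch, seq, n, C)   # output redistributed, residual mixed by M
#   last layer, sum(axis=-2):     (batch, seq, C)      # collapse back to a single stream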

@@ -500,7 +519,7 @@ def __call__(

        for layer_idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
                all_hidden_states.append(hidden_states.squeeze())
Contributor review comment (severity: high):
hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.

A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.

Suggested change:
-                all_hidden_states.append(hidden_states.squeeze())
+                all_hidden_states.append(hidden_states.mean(axis=-2))

            hidden_states, (k, v) = layer(
                hidden_states,
3 changes: 3 additions & 0 deletions skyrl-tx/tx/utils/models.py
@@ -113,6 +113,9 @@ def load_safetensors(
        # Skip LoRA parameters if requested
        if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path):
            continue
        # Skip connector parameters
        if any("connector" in str(p) for p in path):
            continue
        if "experts" in path:
            tensors[key] = np.stack(
                [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0