9 changes: 5 additions & 4 deletions skyrl-tx/tests/models/test_qwen3_lora_training.py
@@ -7,7 +7,7 @@

from tx.models.configs import Qwen3Config
from tx.models.qwen3 import Qwen3ForCausalLM
from tx.utils.models import get_dtype, load_safetensors
from tx.utils.models import get_dtype, load_safetensors, get_adapter_idx
from tx.layers.lora import init_lora_adapter
from tx.tinker.types import LoraConfig

@@ -47,15 +47,16 @@ def loss_fn(model, input_ids, target_ids, attention_mask):

# Helper to extract adapter params at specific index
def get_adapter_params(params, adapter_idx):
return jax.tree.map(lambda p: p[adapter_idx].copy(), params)
return jax.tree.map_with_path(lambda path, p: p[get_adapter_idx(path, adapter_idx)].copy(), params)

# Helper to extract out-of-rank params for an adapter
def get_out_of_rank_params(params, adapter_idx, rank):
def slice_param(path, p):
idx = get_adapter_idx(path, adapter_idx)
if "lora_A" in str(path):
return p[adapter_idx, :, rank:].copy()
return p[idx][..., :, rank:].copy()
elif "lora_B" in str(path):
return p[adapter_idx, rank:, :].copy()
return p[idx][..., rank:, :].copy()
return p

return jax.tree.map_with_path(slice_param, params)
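The helper `get_adapter_idx` is imported from `tx.utils.models`, but its implementation sits outside this diff. The sketch below is only a guess at its shape, assuming that stacked parameters carry a leading `(num_layers, ...)` axis and that their tree path contains `"_stacked"` (the same marker used in the `tx/layers/util.py` change further down); the real helper may differ.

```python
# Illustrative sketch only: the real get_adapter_idx lives in tx.utils.models and is
# not shown in this diff. It is assumed that stacked LoRA parameters have shape
# (num_layers, num_adapters, ...) and that "_stacked" appears in their tree path.
def get_adapter_idx(path, adapter_idx: int):
    """Return an index selecting one adapter, preserving a leading layer axis if present."""
    if "_stacked" in str(path):
        # Stacked parameter: keep all layers, select one adapter slot.
        return (slice(None), adapter_idx)
    # Per-layer parameter: shape (num_adapters, ...), a plain integer index suffices.
    return adapter_idx
```

With this guess, `p[get_adapter_idx(path, 3)]` selects adapter 3 whether `p` is per-layer or stacked, which is consistent with how the test above slices `p[idx][..., :, rank:]`.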
27 changes: 12 additions & 15 deletions skyrl-tx/tx/layers/lora.py
@@ -2,7 +2,7 @@
import jax
from jax import numpy as jnp

from tx.utils.models import filter_lora
from tx.utils.models import filter_lora, get_adapter_idx
from tx.layers.util import Param, prepare_routing, ragged_dot
from tx.models.types import ModelForCausalLM
from tx.tinker.types import LoraConfig
@@ -345,21 +345,19 @@ def init_adapter(path, value):
if not filter_lora(lora_config, normalized_path):
effective_rank = 0

idx = get_adapter_idx(path, adapter_index)
key_name = path[-2].key
if key_name == "lora_ranks":
return value.at[adapter_index].set(effective_rank)
return value.at[idx].set(effective_rank)
if key_name == "lora_scaling":
# Set scaling to 0.0 if rank is 0
return value.at[adapter_index].set(lora_config.alpha / effective_rank if effective_rank > 0 else 0.0)
scaling = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0
return value.at[idx].set(scaling)
if key_name == "lora_A":
# Reinitialize with he_uniform, then zero columns beyond rank
shape = value[adapter_index].shape
new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype)
new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype)
new_A = new_A.at[..., effective_rank:].set(0.0)
return value.at[adapter_index].set(new_A)
return value.at[idx].set(new_A)
if key_name == "lora_B":
# Explicitly zero lora_B
return value.at[adapter_index].set(0.0)
return value.at[idx].set(0.0)
return value

updated_state = jax.tree.map_with_path(init_adapter, state)
@@ -376,11 +374,10 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int):

def clear_adapter(path, value):
key = path[-2].key
if key == "lora_ranks":
return value.at[adapter_index].set(0)
if key in ("lora_scaling", "lora_A", "lora_B"):
return value.at[adapter_index].set(0.0)
return value
if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"):
return value
idx = get_adapter_idx(path, adapter_index)
return value.at[idx].set(0 if key == "lora_ranks" else 0.0)

updated_state = jax.tree.map_with_path(clear_adapter, state)
nnx.update(model, updated_state)
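The `clear_adapter` branch above resets one adapter slot with a single functional `.at[idx].set(...)` update. A self-contained JAX illustration of that mechanic (the array name and shapes are invented; only the indexing pattern mirrors the code above):

```python
import jax.numpy as jnp

# Toy stacked LoRA B matrix: (num_layers=4, num_adapters=8, rank=16, out_dim=32).
lora_B = jnp.ones((4, 8, 16, 32))

# Per-layer parameter: clearing adapter 3 uses a plain integer index.
per_layer = lora_B[0]                    # (8, 16, 32)
per_layer = per_layer.at[3].set(0.0)     # zeros per_layer[3]

# Stacked parameter: the index keeps the layer axis, clearing adapter 3 in every layer.
idx = (slice(None), 3)                   # the kind of index get_adapter_idx would return
lora_B = lora_B.at[idx].set(0.0)         # zeros lora_B[:, 3] across all 4 layers
assert bool((lora_B[:, 3] == 0.0).all())
```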
181 changes: 181 additions & 0 deletions skyrl-tx/tx/layers/stacked.py
@@ -0,0 +1,181 @@
"""StackedDecoderLayers module for efficient transformer layer stacking."""

from typing import Callable

from flax import nnx
import jax
import jax.numpy as jnp

from tx.utils.generator import KVCache


class ArrayRef(nnx.Variable):
"""A Variable providing a view into an indexed slice of a parent Variable."""

def __init__(self, parent: nnx.Variable, idx: int):
super().__init__(parent[idx])
self.set_metadata("_parent", parent)
self.set_metadata("_idx", idx)

def __getitem__(self, key):
parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx")
return parent[idx] if key is Ellipsis else parent[idx][key]

def set_raw_value(self, value, **kwargs):
"""Write through to parent when value is set."""
parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx")
parent[...] = parent[...].at[idx].set(value)
super().set_raw_value(value, **kwargs)

@property
def shape(self):
return self.get_metadata("_parent")[self.get_metadata("_idx")].shape


class StackedDecoderLayers(nnx.Module):
[Review comment from Collaborator Author]: Probably the easiest way to implement DeepSeekV3 is to implement DualStackedDecoderLayers, which has two StackedDecoderLayers as members and the same interface as StackedDecoderLayers (modulo the constructor, which takes two create_layer_fn functions and their respective layer counts as arguments). This could be a separate PR.

"""Decoder layers with stacked weights created via nnx.vmap.

Parameters are stored in stacked format (num_layers, ...). The forward pass
uses jax.lax.scan for training/prefill and Python loops for decode.
"""

def __init__(
self,
create_layer_fn: Callable[[nnx.Rngs], nnx.Module],
num_layers: int,
rngs: nnx.Rngs,
):
self.num_layers = num_layers

@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0, transform_metadata={nnx.PARTITION_NAME: None})
def vmapped_create(rngs: nnx.Rngs) -> nnx.Module:
return create_layer_fn(rngs)

self._stacked = vmapped_create(rngs)

def __len__(self) -> int:
"""Return the number of layers."""
return self.num_layers

def __getitem__(self, index: int) -> nnx.Module:
"""Get view into layer at index. Only for tests and weight loading."""
if index < 0 or index >= self.num_layers:
raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})")
graphdef, state = nnx.split(self._stacked)
layer_state = jax.tree.map(
lambda x: ArrayRef(x, index),
state,
is_leaf=lambda x: isinstance(x, nnx.Variable),
)
return nnx.merge(graphdef, layer_state)

def __iter__(self):
"""Iterate over individual layers (for testing/weight loading)."""
for i in range(self.num_layers):
yield self[i]

def __call__(
self,
hidden_states: jax.Array,
*,
attention_mask: jax.Array,
positions: jax.Array,
adapter_indices: jax.Array | None,
kv_cache: KVCache | None,
output_hidden_states: bool,
gradient_checkpointing: bool,
is_training: bool = False,
) -> tuple[jax.Array, list[jax.Array], KVCache | None]:
"""Forward pass through all layers.

Uses scan for training/prefill, Python loop for decode.

Returns:
(final_hidden_states, all_hidden_states, kv_cache)
"""
graphdef, state = nnx.split(self._stacked)

# Decode mode: use Python loop
if kv_cache is not None:
all_hidden_states = []
new_keys, new_values = [], []

for i in range(self.num_layers):
if output_hidden_states:
all_hidden_states.append(hidden_states)

layer = nnx.merge(graphdef, jax.tree.map(lambda x, i=i: x[i], state))
hidden_states, (k, v) = layer(
hidden_states,
attention_mask=attention_mask,
positions=positions,
adapter_indices=adapter_indices,
kv_cache=(kv_cache.keys[i], kv_cache.values[i]),
)
new_keys.append(k)
new_values.append(v)

return (
hidden_states,
all_hidden_states,
KVCache(
keys=new_keys,
values=new_values,
cache_position=kv_cache.cache_position + positions.shape[1],
),
)

# Training/prefill mode: use scan
def body_fn(hs, layer_params):
layer = nnx.merge(graphdef, layer_params)
new_hs, (k, v) = layer(
hs,
attention_mask=attention_mask,
positions=positions,
adapter_indices=adapter_indices,
kv_cache=None,
)
if is_training:
k = v = None
return new_hs, (new_hs if output_hidden_states else None, k, v)

if gradient_checkpointing:
body_fn = jax.checkpoint(body_fn)

final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, state)

all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else []

if is_training:
return final_hs, all_hidden_states, None

return (
final_hs,
all_hidden_states,
KVCache(
keys=[all_keys[i] for i in range(self.num_layers)],
values=[all_values[i] for i in range(self.num_layers)],
cache_position=attention_mask.sum(axis=1).astype(jnp.int32),
),
)


def unstack_state(module: nnx.Module) -> nnx.GraphState:
"""Transform stacked layer state to unstacked ArrayRef views.

Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc.
Each entry is an ArrayRef that writes through to the original stacked variable.
"""
expanded = []
for path, var in nnx.to_flat_state(nnx.state(module)):
if "_stacked" not in path:
expanded.append((path, var))
continue

idx = path.index("_stacked")
for i in range(var[...].shape[0]):
new_path = path[:idx] + (str(i),) + path[idx + 1 :]
expanded.append((new_path, ArrayRef(var, i)))

return nnx.from_flat_state(expanded)
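To make the new API concrete, here is a minimal usage sketch. ToyLayer, ToyModel, and the sizes are invented for illustration; only StackedDecoderLayers, the ArrayRef view behavior, and unstack_state come from the file above, and the flax nnx calls are assumed to behave as they do in that file.

```python
from flax import nnx

from tx.layers.stacked import StackedDecoderLayers, unstack_state


class ToyLayer(nnx.Module):
    """Stand-in for a real decoder layer; only exercises the stacking API."""

    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(16, 16, rngs=rngs)


class ToyModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.layers = StackedDecoderLayers(ToyLayer, num_layers=4, rngs=rngs)


model = ToyModel(nnx.Rngs(0))

# Each weight is stored once with a leading (num_layers, ...) axis; indexing returns
# an ArrayRef view into one row of that stack.
kernel0 = model.layers[0].linear.kernel[...]
print(kernel0.shape)  # (16, 16)

# Iteration yields per-layer views, which is what weight-loading code can use.
for layer in model.layers:
    _ = layer.linear.kernel[...]

# unstack_state rewrites `layers._stacked.xxx` paths into `layers.0.xxx`, `layers.1.xxx`, ...
for path, _ in nnx.to_flat_state(unstack_state(model)):
    print(path)  # e.g. ('layers', '0', 'linear', 'kernel')
```

The review comment on StackedDecoderLayers suggests a DualStackedDecoderLayers for DeepSeekV3. A rough skeleton of that idea, strictly as a sketch (nothing below exists in this PR, and the forward pass is deliberately left out):

```python
class DualStackedDecoderLayers(nnx.Module):
    """Skeleton of the suggested dual stack: two layer groups behind one interface."""

    def __init__(self, create_first_fn, num_first: int, create_second_fn, num_second: int, rngs: nnx.Rngs):
        # Two independently stacked groups, e.g. dense layers followed by MoE layers.
        self.first = StackedDecoderLayers(create_first_fn, num_first, rngs)
        self.second = StackedDecoderLayers(create_second_fn, num_second, rngs)

    def __len__(self) -> int:
        return len(self.first) + len(self.second)

    def __getitem__(self, index: int) -> nnx.Module:
        # Route a global layer index into the right group.
        if index < len(self.first):
            return self.first[index]
        return self.second[index - len(self.first)]

    # __call__ would chain the two groups and split the KV cache between them; it is
    # omitted here because this PR does not specify that behavior.
```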
16 changes: 11 additions & 5 deletions skyrl-tx/tx/layers/util.py
@@ -96,11 +96,17 @@ def shard_map_ep(module: nnx.Module, func, *args):
*args: Arguments to pass to func (replicated across shards).
"""
graphdef, state = nnx.split(module)
# Extract only 'ep' dims from PartitionSpecs, replacing others with None
state_specs = jax.tree.map(
lambda s: PartitionSpec(*(p if p == "ep" else None for p in s)) if isinstance(s, PartitionSpec) else s,
nnx.get_partition_spec(state),
is_leaf=lambda x: isinstance(x, PartitionSpec),

def to_ep_spec(path, s):
if not isinstance(s, PartitionSpec):
return s
# Strip leading stacking dimension if path contains "_stacked"
dims = s[1:] if "_stacked" in str(path) else s
# Extract only 'ep' dims from PartitionSpecs, replacing others with None
return PartitionSpec(*(p if p == "ep" else None for p in dims))

state_specs = jax.tree_util.tree_map_with_path(
to_ep_spec, nnx.get_partition_spec(state), is_leaf=lambda x: isinstance(x, PartitionSpec)
)
in_specs = (state_specs,) + (PartitionSpec(),) * len(args)

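With stacked layers, an expert weight gains a leading num_layers axis, so its PartitionSpec gains a leading entry too; `to_ep_spec` drops that entry before building the shard_map specs. A small illustration (the concrete spec and shape are made up; the stripping logic mirrors the code above):

```python
from jax.sharding import PartitionSpec

# Hypothetical stacked expert weight of shape (num_layers, num_experts, d_in, d_out)
# with the expert axis sharded over the "ep" mesh axis.
stacked_spec = PartitionSpec(None, "ep", None, None)

# For a path containing "_stacked", to_ep_spec drops the leading layer entry and
# keeps only the "ep" dimension, replicating everything else.
dims = stacked_spec[1:]                       # ("ep", None, None)
ep_only = PartitionSpec(*(p if p == "ep" else None for p in dims))
print(ep_only)                                # PartitionSpec('ep', None, None)
```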
49 changes: 25 additions & 24 deletions skyrl-tx/tx/models/qwen3.py
@@ -3,15 +3,16 @@
from jax import numpy as jnp
from jax.sharding import get_abstract_mesh

from tx.layers.attention import dot_product_attention
from tx.layers.layernorm import RMSNorm
from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear
from tx.layers.util import prepare_routing, shard_map_ep
from tx.layers.rotary_embedding import apply_rope
from tx.utils.logits_processor import LogitsProcessorMixin, LMHead
from tx.layers.layernorm import RMSNorm
from tx.layers.attention import dot_product_attention
from tx.layers.stacked import StackedDecoderLayers
from tx.layers.util import prepare_routing, shard_map_ep
from tx.models.configs import Qwen3Config
from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput
from tx.utils.generator import GeneratorMixin, KVCache
from tx.utils.logits_processor import LogitsProcessorMixin, LMHead


class Qwen3Attention(nnx.Module):
@@ -329,9 +330,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
embedding_init=nnx.initializers.normal(),
rngs=rngs,
)
self.layers = nnx.List(
[Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)]
)

def create_layer(rngs: nnx.Rngs) -> Qwen3DecoderLayer:
return Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs)

self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)

def __call__(
@@ -343,36 +346,32 @@ def __call__(
output_hidden_states: bool | None = None,
adapter_indices: jax.Array | None = None,
kv_cache: KVCache | None = None,
is_training: bool = False,
) -> ModelOutput:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices)
all_hidden_states: list[jax.Array] = []
updated_keys, updated_values = [], []

for layer_idx, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states.append(hidden_states)

hidden_states, (k, v) = layer(
hidden_states,
attention_mask=attention_mask,
positions=positions,
adapter_indices=adapter_indices,
kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]),
)
updated_keys.append(k)
updated_values.append(v)

hidden_states, all_hidden_states, new_kv_cache = self.layers(
hidden_states,
attention_mask=attention_mask,
positions=positions,
adapter_indices=adapter_indices,
kv_cache=kv_cache,
output_hidden_states=output_hidden_states,
gradient_checkpointing=self.config.gradient_checkpointing,
is_training=is_training,
)

hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states.append(hidden_states)

return ModelOutput(
last_hidden_state=hidden_states,
kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask),
kv_cache=new_kv_cache,
hidden_states=all_hidden_states if output_hidden_states else None,
)

@@ -417,6 +416,7 @@ def __call__(
output_hidden_states: bool | None = None,
adapter_indices: jax.Array | None = None,
kv_cache: KVCache | None = None,
is_training: bool = False,
) -> CausalLMOutput:
if positions is None:
positions = jnp.arange(attention_mask.shape[1])[None, :]
@@ -428,6 +428,7 @@
output_hidden_states=output_hidden_states,
adapter_indices=adapter_indices,
kv_cache=kv_cache,
is_training=is_training,
)

return CausalLMOutput(
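The new `is_training` flag lets the scanned layers drop per-layer keys and values, so no KV cache is materialized during training, while generation-time callers keep the default. A hedged usage sketch (model construction is omitted, and `input_ids` as the first positional argument is an assumption, since that part of the signature is outside the visible hunk):

```python
import jax.numpy as jnp

# Assuming `model` is an already constructed Qwen3ForCausalLM.
input_ids = jnp.ones((1, 8), dtype=jnp.int32)
attention_mask = jnp.ones((1, 8), dtype=jnp.int32)

# Training step: is_training=True makes the scan body discard per-layer keys/values,
# so no KV cache is built or returned.
train_out = model(input_ids, attention_mask=attention_mask, is_training=True)

# Prefill for generation: leave the default (False) so the layers return a populated KVCache.
prefill_out = model(input_ids, attention_mask=attention_mask)
```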