From 7653b1c3335efbb21c363f5096b8096d06a544d6 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Feb 2026 16:08:03 -0800 Subject: [PATCH 01/20] [tx] Implement stacked layers --- skyrl-tx/tx/layers/lora.py | 27 +++---- skyrl-tx/tx/models/qwen3.py | 49 +++++++------ skyrl-tx/tx/utils/models.py | 141 ++++++++++++++++++++++++------------ 3 files changed, 131 insertions(+), 86 deletions(-) diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py index 776b7af59..8d4bd01a1 100644 --- a/skyrl-tx/tx/layers/lora.py +++ b/skyrl-tx/tx/layers/lora.py @@ -2,7 +2,7 @@ import jax from jax import numpy as jnp -from tx.utils.models import filter_lora +from tx.utils.models import filter_lora, get_adapter_idx from tx.layers.util import Param, prepare_routing, ragged_dot from tx.models.types import ModelForCausalLM from tx.tinker.types import LoraConfig @@ -345,21 +345,19 @@ def init_adapter(path, value): if not filter_lora(lora_config, normalized_path): effective_rank = 0 + idx = get_adapter_idx(path, adapter_index) key_name = path[-2].key if key_name == "lora_ranks": - return value.at[adapter_index].set(effective_rank) + return value.at[idx].set(effective_rank) if key_name == "lora_scaling": - # Set scaling to 0.0 if rank is 0 - return value.at[adapter_index].set(lora_config.alpha / effective_rank if effective_rank > 0 else 0.0) + scaling = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0 + return value.at[idx].set(scaling) if key_name == "lora_A": - # Reinitialize with he_uniform, then zero columns beyond rank - shape = value[adapter_index].shape - new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype) + new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype) new_A = new_A.at[..., effective_rank:].set(0.0) - return value.at[adapter_index].set(new_A) + return value.at[idx].set(new_A) if key_name == "lora_B": - # Explicitly zero lora_B - return value.at[adapter_index].set(0.0) + return value.at[idx].set(0.0) return value updated_state = jax.tree.map_with_path(init_adapter, state) @@ -376,11 +374,10 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int): def clear_adapter(path, value): key = path[-2].key - if key == "lora_ranks": - return value.at[adapter_index].set(0) - if key in ("lora_scaling", "lora_A", "lora_B"): - return value.at[adapter_index].set(0.0) - return value + if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"): + return value + idx = get_adapter_idx(path, adapter_index) + return value.at[idx].set(0 if key == "lora_ranks" else 0.0) updated_state = jax.tree.map_with_path(clear_adapter, state) nnx.update(model, updated_state) diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 1348cac09..9e5332716 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -3,15 +3,16 @@ from jax import numpy as jnp from jax.sharding import get_abstract_mesh +from tx.layers.attention import dot_product_attention +from tx.layers.layernorm import RMSNorm from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear -from tx.layers.util import prepare_routing, shard_map_ep from tx.layers.rotary_embedding import apply_rope -from tx.utils.logits_processor import LogitsProcessorMixin, LMHead -from tx.layers.layernorm import RMSNorm -from tx.layers.attention import dot_product_attention +from tx.layers.stacked import StackedDecoderLayers +from tx.layers.util import prepare_routing, shard_map_ep from tx.models.configs import Qwen3Config from tx.models.types import 
CausalLMOutput, ModelForCausalLM, ModelOutput from tx.utils.generator import GeneratorMixin, KVCache +from tx.utils.logits_processor import LogitsProcessorMixin, LMHead class Qwen3Attention(nnx.Module): @@ -329,9 +330,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> embedding_init=nnx.initializers.normal(), rngs=rngs, ) - self.layers = nnx.List( - [Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)] - ) + + def create_layer(rngs: nnx.Rngs) -> Qwen3DecoderLayer: + return Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) + + self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs) def __call__( @@ -343,28 +346,24 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> ModelOutput: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices) - all_hidden_states: list[jax.Array] = [] - updated_keys, updated_values = [], [] - - for layer_idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states.append(hidden_states) - - hidden_states, (k, v) = layer( - hidden_states, - attention_mask=attention_mask, - positions=positions, - adapter_indices=adapter_indices, - kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]), - ) - updated_keys.append(k) - updated_values.append(v) + + hidden_states, all_hidden_states, new_kv_cache = self.layers( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + output_hidden_states=output_hidden_states, + gradient_checkpointing=self.config.gradient_checkpointing, + is_training=is_training, + ) hidden_states = self.norm(hidden_states) if output_hidden_states: @@ -372,7 +371,7 @@ def __call__( return ModelOutput( last_hidden_state=hidden_states, - kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask), + kv_cache=new_kv_cache, hidden_states=all_hidden_states if output_hidden_states else None, ) @@ -417,6 +416,7 @@ def __call__( output_hidden_states: bool | None = None, adapter_indices: jax.Array | None = None, kv_cache: KVCache | None = None, + is_training: bool = False, ) -> CausalLMOutput: if positions is None: positions = jnp.arange(attention_mask.shape[1])[None, :] @@ -428,6 +428,7 @@ def __call__( output_hidden_states=output_hidden_states, adapter_indices=adapter_indices, kv_cache=kv_cache, + is_training=is_training, ) return CausalLMOutput( diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 6e840febf..47b724227 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -76,19 +76,49 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: raise ValueError(f"None of the architectures {config.architectures} is currently supported.") -def get_param_key(path: tuple, prefix: str = "") -> str: - "Get the safetensors key for a given model path." 
- if path[-1] in {"embedding", "kernel"}: - path = (*path[:-1], "weight") - elif path[-1] in {"lora_A", "lora_B"}: - path = (*path, "weight") - return prefix + ".".join(map(str, path)) +def is_stacked_path(path: tuple) -> bool: + """Check if a parameter path is for stacked layers (contains '_stacked').""" + return any((p.key if hasattr(p, "key") else str(p)) == "_stacked" for p in path) -def get_expert_key(path: tuple, expert_idx: int) -> str: - "Get the safetensors key for an expert weight model path." - path = tuple(s if s != "experts" else f"experts.{expert_idx}" for s in path) - return ".".join(map(str, path)) +def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: + """Return index tuple for accessing an adapter. Stacked: [:, idx], non-stacked: [idx].""" + return (slice(None), adapter_index) if is_stacked_path(path) else (adapter_index,) + + +def _path_to_hf_key(path: tuple, layer_idx: int | None = None) -> str: + """Convert param path to HuggingFace key.""" + parts = [] + for p in path: + key = p.key if hasattr(p, "key") else str(p) + if key == "_stacked": + continue + if key == "layers" and layer_idx is not None: + parts.append(f"layers.{layer_idx}") + elif key in ("kernel", "embedding"): + parts.append("weight") + elif key in ("lora_A", "lora_B"): + parts.extend([key, "weight"]) + elif key == "experts": + parts.append("experts") # expert idx added separately + else: + parts.append(key) + return ".".join(parts) + + +def _load_layer_tensor(tensors: dict, path: tuple, shape: tuple, num_experts: int | None, layer_idx: int | None) -> np.ndarray: + """Load tensor for one layer, handling experts and transpose.""" + key = _path_to_hf_key(path, layer_idx) + path_str = str(path) + + if "experts" in path_str and num_experts: + tensor = np.stack([tensors[key.replace(".experts.", f".experts.{i}.")].T for i in range(num_experts)], axis=0) + else: + tensor = tensors[key] if "embed_tokens" in path_str else tensors[key].T + + if any(p in path_str for p in ("q_proj", "k_proj", "v_proj", "o_proj")): + tensor = tensor.reshape(shape) + return tensor def load_safetensors( @@ -104,29 +134,47 @@ def load_safetensors( tensors.update(safetensors.numpy.load_file(file)) tensors = {k.removeprefix(prefix): v for k, v in tensors.items()} + num_experts = config.get_num_experts() model_params = nnx.to_flat_state(nnx.state(model)) updates = [] + for path, param in model_params: if filter_fn is not None and not filter_fn(path): continue - key = get_param_key(path) - # Skip LoRA parameters if requested - if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): + path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] + if skip_lora and any(k in path_keys for k in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")): continue - if "experts" in path: - tensors[key] = np.stack( - [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0 - ) + + if is_stacked_path(path): + num_layers = param.shape[0] + tensor = np.stack([_load_layer_tensor(tensors, path, param.shape[1:], num_experts, i) for i in range(num_layers)]) else: - tensors[key] = tensors[key] if "embed_tokens" in path else tensors[key].T - if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: - tensors[key] = tensors[key].reshape(param.shape) - assert param.shape == tensors[key].shape, f"shape mismatch for {key}" - sharded_tensor = jax.device_put(tensors[key].astype(param.dtype), param.sharding) - updates.append((path, sharded_tensor)) + tensor = 
_load_layer_tensor(tensors, path, param.shape, num_experts, None) + + assert param.shape == tensor.shape, f"Shape mismatch for {path}: {param.shape} vs {tensor.shape}" + updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding))) + nnx.update(model, nnx.from_flat_state(updates)) +def _save_layer_tensor(tensors: dict, path: tuple, param: np.ndarray, num_experts: int | None, layer_idx: int | None, prefix: str) -> None: + """Save tensor for one layer, handling experts and transpose.""" + key = prefix + _path_to_hf_key(path, layer_idx) + path_str = str(path) + + if "experts" in path_str and num_experts: + for i in range(num_experts): + tensors[key.replace(".experts.", f".experts.{i}.")] = param[i].T + return + + if any(p in path_str for p in ("q_proj", "k_proj", "v_proj")): + param = param.reshape(param.shape[0], -1) + elif "o_proj" in path_str: + param = param.reshape(-1, param.shape[-1]) + + tensors[key] = param if "embed_tokens" in path_str else param.T + + def save_safetensors( config: ModelConfig, model: nnx.Module, @@ -134,23 +182,22 @@ def save_safetensors( prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: + num_experts = config.get_num_experts() model_params = nnx.to_flat_state(nnx.state(model)) tensors = {} + for path, param in model_params: - if "rngs" in path: + path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] + if "rngs" in path_keys: continue if filter_fn is not None and not filter_fn(path): continue - key = get_param_key(path, prefix=prefix) - if "experts" in path: - for i in range(config.get_num_experts()): - tensors[get_expert_key(path, i)] = param[i, :, :].T - continue - if "q_proj" in path or "k_proj" in path or "v_proj" in path: - param = param.reshape(param.shape[0], -1) - elif "o_proj" in path: - param = param.reshape(-1, param.shape[-1]) - tensors[key] = param if "embed_tokens" in path else param.T + + if is_stacked_path(path): + for i in range(param.shape[0]): + _save_layer_tensor(tensors, path, param[i], num_experts, i, prefix) + else: + _save_layer_tensor(tensors, path, param, num_experts, None, prefix) # In multi-host mode, gather all shards and only save from rank 0 if jax.process_count() > 1: @@ -252,13 +299,13 @@ def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: "Helper function to extract the adapter parameters for a specific adapter index." def extract_state(path: tuple, p: jnp.ndarray): - if path[-2].key not in {"lora_A", "lora_B"}: + key = path[-2].key + if key not in {"lora_A", "lora_B"}: return p - assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" - if path[-2].key == "lora_A": - return p[adapter_index, ..., :, :rank] - if path[-2].key == "lora_B": - return p[adapter_index, ..., :rank, :] + idx = get_adapter_idx(path, adapter_index) + if key == "lora_A": + return p[idx + (..., slice(None, rank))] + return p[idx + (..., slice(None, rank), slice(None))] return jax.tree.map_with_path(extract_state, lora_params) @@ -271,13 +318,13 @@ def insert_adapter_state( "Helper function to insert the adapter parameters for a specific adapter index (inverse of extract_adapter_state)." 
def insert_state(path: tuple, p: jax.Array, new: jax.Array): - if path[-2].key not in {"lora_A", "lora_B"}: + key = path[-2].key + if key not in {"lora_A", "lora_B"}: return new - assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" - if path[-2].key == "lora_A": - return p.at[adapter_index, ..., :, :rank].set(new) - elif path[-2].key == "lora_B": - return p.at[adapter_index, ..., :rank, :].set(new) + idx = get_adapter_idx(path, adapter_index) + if key == "lora_A": + return p.at[idx + (..., slice(None, rank))].set(new) + return p.at[idx + (..., slice(None, rank), slice(None))].set(new) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated) From 320681c1c85939afc7dc02040e87f75589c94169 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Feb 2026 16:08:26 -0800 Subject: [PATCH 02/20] add file --- skyrl-tx/tx/layers/stacked.py | 109 ++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 skyrl-tx/tx/layers/stacked.py diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py new file mode 100644 index 000000000..6551d88d4 --- /dev/null +++ b/skyrl-tx/tx/layers/stacked.py @@ -0,0 +1,109 @@ +"""StackedDecoderLayers module for efficient transformer layer stacking.""" + +from typing import Callable + +from flax import nnx +import jax +import jax.numpy as jnp + +from tx.utils.generator import KVCache + + +class StackedDecoderLayers(nnx.Module): + """Decoder layers with stacked weights created via nnx.vmap. + + Parameters are stored in stacked format (num_layers, ...). The forward pass + uses jax.lax.scan for training/prefill and Python loops for decode. + """ + + def __init__( + self, + create_layer_fn: Callable[[nnx.Rngs], nnx.Module], + num_layers: int, + rngs: nnx.Rngs, + ): + self.num_layers = num_layers + + @nnx.split_rngs(splits=num_layers) + @nnx.vmap(in_axes=(0,), out_axes=0, transform_metadata={nnx.PARTITION_NAME: None}) + def vmapped_create(rngs: nnx.Rngs) -> nnx.Module: + return create_layer_fn(rngs) + + self._stacked = vmapped_create(rngs) + + def __call__( + self, + hidden_states: jax.Array, + *, + attention_mask: jax.Array, + positions: jax.Array, + adapter_indices: jax.Array | None, + kv_cache: KVCache | None, + output_hidden_states: bool, + gradient_checkpointing: bool, + is_training: bool = False, + ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: + """Forward pass through all layers. + + Uses scan for training/prefill, Python loop for decode. 
+ + Returns: + (final_hidden_states, all_hidden_states, kv_cache) + """ + graphdef, state = nnx.split(self._stacked) + + # Decode mode: use Python loop + if kv_cache is not None: + all_hidden_states = [] + new_keys, new_values = [], [] + + for i in range(self.num_layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + layer = nnx.merge(graphdef, jax.tree.map(lambda x, i=i: x[i], state)) + hidden_states, (k, v) = layer( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=(kv_cache.keys[i], kv_cache.values[i]), + ) + new_keys.append(k) + new_values.append(v) + + return hidden_states, all_hidden_states, KVCache( + keys=new_keys, + values=new_values, + cache_position=kv_cache.cache_position + positions.shape[1], + ) + + # Training/prefill mode: use scan + def body_fn(hs, layer_params): + layer = nnx.merge(graphdef, layer_params) + new_hs, (k, v) = layer( + hs, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=None, + ) + if is_training: + k = v = None + return new_hs, (new_hs if output_hidden_states else None, k, v) + + if gradient_checkpointing: + body_fn = jax.checkpoint(body_fn) + + final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, state) + + all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] + + if is_training: + return final_hs, all_hidden_states, None + + # Prefill: convert stacked arrays to list and create KVCache + keys_list = [all_keys[i] for i in range(self.num_layers)] + values_list = [all_values[i] for i in range(self.num_layers)] + cache_position = attention_mask.sum(axis=1).astype(jnp.int32) + return final_hs, all_hidden_states, KVCache(keys=keys_list, values=values_list, cache_position=cache_position) From 169ec5a3077624e3b50d11d512ecdd05bb2ac88f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Feb 2026 17:03:31 -0800 Subject: [PATCH 03/20] fix --- skyrl-tx/tx/tinker/backends/jax.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 99ae33327..27c50a59b 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -45,6 +45,7 @@ from tx.utils.models import ( get_dtype, get_model_class, + get_adapter_idx, load_safetensors, load_lora_checkpoint, save_lora_checkpoint, @@ -126,15 +127,20 @@ def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "Accumulated def get_mean(self, adapter_index: jax.Array) -> nnx.State: """Compute mean gradients for a specific adapter, with zeros for all other adapters.""" count = self.counts[adapter_index] - return jax.tree.map( - lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)), - self.grad_sum, - ) + def _select_mean(path, g): + idx = get_adapter_idx(path, adapter_index) + return jnp.zeros_like(g).at[idx].set(g[idx] / count.astype(g.dtype)) + + return jax.tree.map_with_path(_select_mean, self.grad_sum) def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients": """Reset gradients and count for a specific adapter.""" + def _reset(path, g): + idx = get_adapter_idx(path, adapter_index) + return g.at[idx].set(0.0) + return AccumulatedGradients( - grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum), + grad_sum=jax.tree.map_with_path(_reset, self.grad_sum), counts=self.counts.at[adapter_index].set(0), ) 
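For reference, a minimal standalone sketch of the adapter-indexing convention these patches introduce: get_adapter_idx returns [idx] for non-stacked LoRA parameters and [:, idx] for stacked ones, so the same index tuple drives reads, .at[...] updates, and the gradient-accumulator reset in the hunk above. Shapes and paths below are hypothetical (the real paths are jax/nnx key paths, not plain string tuples).

import jax.numpy as jnp

def get_adapter_idx(path: tuple, adapter_index: int) -> tuple:
    # Same rule as the helper in tx.utils.models: stacked params have "_stacked" in their path.
    return (slice(None), adapter_index) if "_stacked" in str(path) else (adapter_index,)

non_stacked = jnp.zeros((8, 16, 4))    # (num_adapters, in_dim, rank)
stacked = jnp.zeros((2, 8, 16, 4))     # (num_layers, num_adapters, in_dim, rank)

idx = get_adapter_idx(("lm_head", "lora_A", "value"), 3)
assert non_stacked[idx].shape == (16, 4)

idx = get_adapter_idx(("layers", "_stacked", "lora_A", "value"), 3)
assert stacked[idx].shape == (2, 16, 4)   # one adapter, sliced across all layers

# The same tuple works for in-place updates, e.g. zeroing one adapter's accumulated gradients:
stacked = stacked.at[idx].set(0.0)
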
From 872fcf080d295f4993b93191976e134174cdba24 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Feb 2026 18:00:53 -0800 Subject: [PATCH 04/20] update --- skyrl-tx/tests/models/test_qwen3.py | 5 ++++- skyrl-tx/tx/layers/stacked.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index 55a779c9e..7c06b757d 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -246,7 +246,8 @@ def test_qwen3_lora(): ) # Load layer LoRA weights - for i, layer in enumerate(model.model.layers): + for i in range(len(model.model.layers)): + layer = model.model.layers[i] hf_layer = hf_model.base_model.model.model.layers[i] for module, projections in [ ("mlp", ["gate_proj", "up_proj", "down_proj"]), @@ -262,6 +263,8 @@ def test_qwen3_lora(): scaling=lora_config.lora_alpha / lora_config.r, rank=lora_config.r, ) + # Write back the modified layer to update stacked weights + model.model.layers[i] = layer # Use different adapter indices for each input adapter_indices = jnp.arange(len(lora_adapters), dtype=jnp.int32) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 6551d88d4..da07405ce 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -31,6 +31,37 @@ def vmapped_create(rngs: nnx.Rngs) -> nnx.Module: self._stacked = vmapped_create(rngs) + def __len__(self) -> int: + """Return the number of layers.""" + return self.num_layers + + def __getitem__(self, index: int) -> nnx.Module: + """Get individual layer by index (for testing/weight loading).""" + if index < 0 or index >= self.num_layers: + raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") + graphdef, state = nnx.split(self._stacked) + layer_state = jax.tree.map(lambda x: x[index], state) + return nnx.merge(graphdef, layer_state) + + def __iter__(self): + """Iterate over individual layers (for testing/weight loading).""" + for i in range(self.num_layers): + yield self[i] + + def __setitem__(self, index: int, layer: nnx.Module): + """Update stacked state from a modified layer (for testing/weight loading).""" + if index < 0 or index >= self.num_layers: + raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") + graphdef, state = nnx.split(self._stacked) + _, layer_state = nnx.split(layer) + new_state = jax.tree.map( + lambda s, l: s.replace(s[...].at[index].set(l[...])), + state, + layer_state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + self._stacked = nnx.merge(graphdef, new_state) + def __call__( self, hidden_states: jax.Array, From e3c3ecd68c3c7c251e0ada244d36a87e52b9af03 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Feb 2026 18:58:12 -0800 Subject: [PATCH 05/20] update --- skyrl-tx/tx/layers/stacked.py | 4 + skyrl-tx/tx/utils/models.py | 163 ++++++++++++++-------------------- 2 files changed, 72 insertions(+), 95 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index da07405ce..0862e1c42 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -48,6 +48,10 @@ def __iter__(self): for i in range(self.num_layers): yield self[i] + @property + def is_stacked(self) -> bool: + return True + def __setitem__(self, index: int, layer: nnx.Module): """Update stacked state from a modified layer (for testing/weight loading).""" if index < 0 or index >= self.num_layers: diff --git a/skyrl-tx/tx/utils/models.py 
b/skyrl-tx/tx/utils/models.py index 47b724227..44ba6acef 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -76,9 +76,24 @@ def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: raise ValueError(f"None of the architectures {config.architectures} is currently supported.") +def get_param_key(path: tuple, prefix: str = "") -> str: + "Get the safetensors key for a given model path." + if path[-1] in {"embedding", "kernel"}: + path = (*path[:-1], "weight") + elif path[-1] in {"lora_A", "lora_B"}: + path = (*path, "weight") + return prefix + ".".join(map(str, path)) + + +def get_expert_key(path: tuple, expert_idx: int) -> str: + "Get the safetensors key for an expert weight model path." + path = tuple(s if s != "experts" else f"experts.{expert_idx}" for s in path) + return ".".join(map(str, path)) + + def is_stacked_path(path: tuple) -> bool: """Check if a parameter path is for stacked layers (contains '_stacked').""" - return any((p.key if hasattr(p, "key") else str(p)) == "_stacked" for p in path) + return "_stacked" in str(path) def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: @@ -86,41 +101,6 @@ def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: return (slice(None), adapter_index) if is_stacked_path(path) else (adapter_index,) -def _path_to_hf_key(path: tuple, layer_idx: int | None = None) -> str: - """Convert param path to HuggingFace key.""" - parts = [] - for p in path: - key = p.key if hasattr(p, "key") else str(p) - if key == "_stacked": - continue - if key == "layers" and layer_idx is not None: - parts.append(f"layers.{layer_idx}") - elif key in ("kernel", "embedding"): - parts.append("weight") - elif key in ("lora_A", "lora_B"): - parts.extend([key, "weight"]) - elif key == "experts": - parts.append("experts") # expert idx added separately - else: - parts.append(key) - return ".".join(parts) - - -def _load_layer_tensor(tensors: dict, path: tuple, shape: tuple, num_experts: int | None, layer_idx: int | None) -> np.ndarray: - """Load tensor for one layer, handling experts and transpose.""" - key = _path_to_hf_key(path, layer_idx) - path_str = str(path) - - if "experts" in path_str and num_experts: - tensor = np.stack([tensors[key.replace(".experts.", f".experts.{i}.")].T for i in range(num_experts)], axis=0) - else: - tensor = tensors[key] if "embed_tokens" in path_str else tensors[key].T - - if any(p in path_str for p in ("q_proj", "k_proj", "v_proj", "o_proj")): - tensor = tensor.reshape(shape) - return tensor - - def load_safetensors( checkpoint_dir: str | os.PathLike, config: ModelConfig, @@ -134,45 +114,37 @@ def load_safetensors( tensors.update(safetensors.numpy.load_file(file)) tensors = {k.removeprefix(prefix): v for k, v in tensors.items()} - num_experts = config.get_num_experts() - model_params = nnx.to_flat_state(nnx.state(model)) - updates = [] - - for path, param in model_params: - if filter_fn is not None and not filter_fn(path): - continue - path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] - if skip_lora and any(k in path_keys for k in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")): - continue - - if is_stacked_path(path): - num_layers = param.shape[0] - tensor = np.stack([_load_layer_tensor(tensors, path, param.shape[1:], num_experts, i) for i in range(num_layers)]) + def load_params(module: nnx.Module, key_prefix: str): + updates = [] + for path, param in nnx.to_flat_state(nnx.state(module)): + if filter_fn is not None and not filter_fn(path): + continue + key = key_prefix 
+ get_param_key(path) + if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): + continue + if "experts" in path: + tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0) + else: + tensor = tensors[key] if "embed_tokens" in key else tensors[key].T + if len(path) >= 2 and path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: + tensor = tensor.reshape(param.shape) + assert param.shape == tensor.shape, f"shape mismatch for {key}" + updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding))) + nnx.update(module, nnx.from_flat_state(updates)) + + def load_recursive(module: nnx.Module, key_prefix: str): + if getattr(module, "is_stacked", False): + for i in range(len(module)): + layer = module[i] + load_params(layer, f"{key_prefix}{i}.") + module[i] = layer + elif children := list(nnx.iter_children(module)): + for name, child in children: + load_recursive(child, f"{key_prefix}{name}.") else: - tensor = _load_layer_tensor(tensors, path, param.shape, num_experts, None) - - assert param.shape == tensor.shape, f"Shape mismatch for {path}: {param.shape} vs {tensor.shape}" - updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding))) - - nnx.update(model, nnx.from_flat_state(updates)) - + load_params(module, key_prefix) -def _save_layer_tensor(tensors: dict, path: tuple, param: np.ndarray, num_experts: int | None, layer_idx: int | None, prefix: str) -> None: - """Save tensor for one layer, handling experts and transpose.""" - key = prefix + _path_to_hf_key(path, layer_idx) - path_str = str(path) - - if "experts" in path_str and num_experts: - for i in range(num_experts): - tensors[key.replace(".experts.", f".experts.{i}.")] = param[i].T - return - - if any(p in path_str for p in ("q_proj", "k_proj", "v_proj")): - param = param.reshape(param.shape[0], -1) - elif "o_proj" in path_str: - param = param.reshape(-1, param.shape[-1]) - - tensors[key] = param if "embed_tokens" in path_str else param.T + load_recursive(model, "") def save_safetensors( @@ -182,22 +154,23 @@ def save_safetensors( prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: - num_experts = config.get_num_experts() model_params = nnx.to_flat_state(nnx.state(model)) tensors = {} - for path, param in model_params: - path_keys = [p.key if hasattr(p, "key") else str(p) for p in path] - if "rngs" in path_keys: + if "rngs" in path: continue if filter_fn is not None and not filter_fn(path): continue - - if is_stacked_path(path): - for i in range(param.shape[0]): - _save_layer_tensor(tensors, path, param[i], num_experts, i, prefix) - else: - _save_layer_tensor(tensors, path, param, num_experts, None, prefix) + key = get_param_key(path, prefix=prefix) + if "experts" in path: + for i in range(config.get_num_experts()): + tensors[get_expert_key(path, i)] = param[i, :, :].T + continue + if "q_proj" in path or "k_proj" in path or "v_proj" in path: + param = param.reshape(param.shape[0], -1) + elif "o_proj" in path: + param = param.reshape(-1, param.shape[-1]) + tensors[key] = param if "embed_tokens" in path else param.T # In multi-host mode, gather all shards and only save from rank 0 if jax.process_count() > 1: @@ -299,13 +272,13 @@ def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: "Helper function to extract the adapter parameters for a specific adapter index." 
def extract_state(path: tuple, p: jnp.ndarray): - key = path[-2].key - if key not in {"lora_A", "lora_B"}: + if path[-2].key not in {"lora_A", "lora_B"}: return p - idx = get_adapter_idx(path, adapter_index) - if key == "lora_A": - return p[idx + (..., slice(None, rank))] - return p[idx + (..., slice(None, rank), slice(None))] + assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" + if path[-2].key == "lora_A": + return p[adapter_index, ..., :, :rank] + if path[-2].key == "lora_B": + return p[adapter_index, ..., :rank, :] return jax.tree.map_with_path(extract_state, lora_params) @@ -318,13 +291,13 @@ def insert_adapter_state( "Helper function to insert the adapter parameters for a specific adapter index (inverse of extract_adapter_state)." def insert_state(path: tuple, p: jax.Array, new: jax.Array): - key = path[-2].key - if key not in {"lora_A", "lora_B"}: + if path[-2].key not in {"lora_A", "lora_B"}: return new - idx = get_adapter_idx(path, adapter_index) - if key == "lora_A": - return p.at[idx + (..., slice(None, rank))].set(new) - return p.at[idx + (..., slice(None, rank), slice(None))].set(new) + assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" + if path[-2].key == "lora_A": + return p.at[adapter_index, ..., :, :rank].set(new) + elif path[-2].key == "lora_B": + return p.at[adapter_index, ..., :rank, :].set(new) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated) From 2336a08dd19f2bb39c4f6ab77b6bae3705f6c2d8 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Feb 2026 19:48:34 -0800 Subject: [PATCH 06/20] fix ruff --- skyrl-tx/tx/layers/stacked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 0862e1c42..cff22d52c 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -59,7 +59,7 @@ def __setitem__(self, index: int, layer: nnx.Module): graphdef, state = nnx.split(self._stacked) _, layer_state = nnx.split(layer) new_state = jax.tree.map( - lambda s, l: s.replace(s[...].at[index].set(l[...])), + lambda s, lv: s.replace(s[...].at[index].set(lv[...])), state, layer_state, is_leaf=lambda x: isinstance(x, nnx.Variable), From 526efa28815d6a3e9d50d5722047445eb50d3a73 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Feb 2026 21:55:52 -0800 Subject: [PATCH 07/20] update --- skyrl-tx/tx/layers/stacked.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index cff22d52c..70101eec9 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -9,6 +9,23 @@ from tx.utils.generator import KVCache +class ArrayRef(nnx.Variable): + """A Variable providing a view into an indexed slice of a parent Variable.""" + + def __init__(self, parent: nnx.Variable, idx: int): + super().__init__(parent[idx]) + self.set_metadata("_parent", parent) + self.set_metadata("_idx", idx) + + def __getitem__(self, key): + parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") + return parent[idx] if key is Ellipsis else parent[idx][key] + + @property + def shape(self): + return self.get_metadata("_parent")[self.get_metadata("_idx")].shape + + class StackedDecoderLayers(nnx.Module): """Decoder layers with stacked weights created via nnx.vmap. 
@@ -36,11 +53,15 @@ def __len__(self) -> int: return self.num_layers def __getitem__(self, index: int) -> nnx.Module: - """Get individual layer by index (for testing/weight loading).""" + """Get view into layer at index (stays synced with stacked state).""" if index < 0 or index >= self.num_layers: raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") graphdef, state = nnx.split(self._stacked) - layer_state = jax.tree.map(lambda x: x[index], state) + layer_state = jax.tree.map( + lambda x: ArrayRef(x, index), + state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) return nnx.merge(graphdef, layer_state) def __iter__(self): @@ -59,7 +80,7 @@ def __setitem__(self, index: int, layer: nnx.Module): graphdef, state = nnx.split(self._stacked) _, layer_state = nnx.split(layer) new_state = jax.tree.map( - lambda s, lv: s.replace(s[...].at[index].set(lv[...])), + lambda s, lv: s.replace(s[...].at[index].set(lv.get_raw_value())), state, layer_state, is_leaf=lambda x: isinstance(x, nnx.Variable), From 26d9a435c04a5a675b7158ffcae6c60e7d937b21 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 3 Feb 2026 22:38:58 -0800 Subject: [PATCH 08/20] update --- skyrl-tx/tx/layers/stacked.py | 20 ++++++-------------- skyrl-tx/tx/utils/models.py | 4 +--- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 70101eec9..9bf581bd5 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -21,6 +21,12 @@ def __getitem__(self, key): parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") return parent[idx] if key is Ellipsis else parent[idx][key] + def set_raw_value(self, value, **kwargs): + """Write through to parent when value is set.""" + parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") + parent[...] 
= parent[...].at[idx].set(value) + super().set_raw_value(value, **kwargs) + @property def shape(self): return self.get_metadata("_parent")[self.get_metadata("_idx")].shape @@ -73,20 +79,6 @@ def __iter__(self): def is_stacked(self) -> bool: return True - def __setitem__(self, index: int, layer: nnx.Module): - """Update stacked state from a modified layer (for testing/weight loading).""" - if index < 0 or index >= self.num_layers: - raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") - graphdef, state = nnx.split(self._stacked) - _, layer_state = nnx.split(layer) - new_state = jax.tree.map( - lambda s, lv: s.replace(s[...].at[index].set(lv.get_raw_value())), - state, - layer_state, - is_leaf=lambda x: isinstance(x, nnx.Variable), - ) - self._stacked = nnx.merge(graphdef, new_state) - def __call__( self, hidden_states: jax.Array, diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 44ba6acef..dc60b34e3 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -135,9 +135,7 @@ def load_params(module: nnx.Module, key_prefix: str): def load_recursive(module: nnx.Module, key_prefix: str): if getattr(module, "is_stacked", False): for i in range(len(module)): - layer = module[i] - load_params(layer, f"{key_prefix}{i}.") - module[i] = layer + load_params(module[i], f"{key_prefix}{i}.") # ArrayRef writes through elif children := list(nnx.iter_children(module)): for name, child in children: load_recursive(child, f"{key_prefix}{name}.") From 3751008dbd143524f9cd2a8c400d7479055a8fa3 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 00:34:44 -0800 Subject: [PATCH 09/20] update --- skyrl-tx/tests/models/test_qwen3_lora_training.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tests/models/test_qwen3_lora_training.py b/skyrl-tx/tests/models/test_qwen3_lora_training.py index 85f5f3bda..8273dfead 100644 --- a/skyrl-tx/tests/models/test_qwen3_lora_training.py +++ b/skyrl-tx/tests/models/test_qwen3_lora_training.py @@ -7,7 +7,7 @@ from tx.models.configs import Qwen3Config from tx.models.qwen3 import Qwen3ForCausalLM -from tx.utils.models import get_dtype, load_safetensors +from tx.utils.models import get_dtype, load_safetensors, get_adapter_idx from tx.layers.lora import init_lora_adapter from tx.tinker.types import LoraConfig @@ -47,15 +47,16 @@ def loss_fn(model, input_ids, target_ids, attention_mask): # Helper to extract adapter params at specific index def get_adapter_params(params, adapter_idx): - return jax.tree.map(lambda p: p[adapter_idx].copy(), params) + return jax.tree.map_with_path(lambda path, p: p[get_adapter_idx(path, adapter_idx)].copy(), params) # Helper to extract out-of-rank params for an adapter def get_out_of_rank_params(params, adapter_idx, rank): def slice_param(path, p): + idx = get_adapter_idx(path, adapter_idx) if "lora_A" in str(path): - return p[adapter_idx, :, rank:].copy() + return p[idx][..., :, rank:].copy() elif "lora_B" in str(path): - return p[adapter_idx, rank:, :].copy() + return p[idx][..., rank:, :].copy() return p return jax.tree.map_with_path(slice_param, params) From 3f2879d18404ef9415d728e7542f430800f9689a Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 00:53:49 -0800 Subject: [PATCH 10/20] update --- skyrl-tx/tx/layers/stacked.py | 20 +++++++++++++++ skyrl-tx/tx/utils/models.py | 46 ++++++++++++++--------------------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py 
b/skyrl-tx/tx/layers/stacked.py index 9bf581bd5..1ce53ee99 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -155,3 +155,23 @@ def body_fn(hs, layer_params): values_list = [all_values[i] for i in range(self.num_layers)] cache_position = attention_mask.sum(axis=1).astype(jnp.int32) return final_hs, all_hidden_states, KVCache(keys=keys_list, values=values_list, cache_position=cache_position) + + +def unstack_state(module: nnx.Module) -> nnx.GraphState: + """Transform stacked layer state to unstacked ArrayRef views. + + Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc. + Each entry is an ArrayRef that writes through to the original stacked variable. + """ + expanded = [] + for path, var in nnx.to_flat_state(nnx.state(module)): + if "_stacked" not in path: + expanded.append((path, var)) + continue + + idx = path.index("_stacked") + for i in range(var[...].shape[0]): + new_path = path[:idx] + (str(i),) + path[idx + 1 :] + expanded.append((new_path, ArrayRef(var, i))) + + return nnx.from_flat_state(expanded) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index dc60b34e3..fb3dabed3 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -109,40 +109,30 @@ def load_safetensors( prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: + from tx.layers.stacked import unstack_state + tensors = {} for file in Path(checkpoint_dir).glob("*.safetensors"): tensors.update(safetensors.numpy.load_file(file)) tensors = {k.removeprefix(prefix): v for k, v in tensors.items()} - def load_params(module: nnx.Module, key_prefix: str): - updates = [] - for path, param in nnx.to_flat_state(nnx.state(module)): - if filter_fn is not None and not filter_fn(path): - continue - key = key_prefix + get_param_key(path) - if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): - continue - if "experts" in path: - tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0) - else: - tensor = tensors[key] if "embed_tokens" in key else tensors[key].T - if len(path) >= 2 and path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: - tensor = tensor.reshape(param.shape) - assert param.shape == tensor.shape, f"shape mismatch for {key}" - updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding))) - nnx.update(module, nnx.from_flat_state(updates)) - - def load_recursive(module: nnx.Module, key_prefix: str): - if getattr(module, "is_stacked", False): - for i in range(len(module)): - load_params(module[i], f"{key_prefix}{i}.") # ArrayRef writes through - elif children := list(nnx.iter_children(module)): - for name, child in children: - load_recursive(child, f"{key_prefix}{name}.") + # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths + # (layers.0.xxx) with ArrayRef write-through, matching checkpoint key format + for path, param in nnx.to_flat_state(unstack_state(model)): + if filter_fn is not None and not filter_fn(path): + continue + key = get_param_key(path) + if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): + continue + if "experts" in path: + tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0) else: - load_params(module, key_prefix) - - load_recursive(model, "") + tensor = tensors[key] if "embed_tokens" in key else tensors[key].T + if 
path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: + tensor = tensor.reshape(param.shape) + assert param.shape == tensor.shape, f"shape mismatch for {key}" + # ArrayRef.set_raw_value writes through to the stacked parent variable + param.set_raw_value(jax.device_put(tensor.astype(param.dtype), param.sharding)) def save_safetensors( From 52fcccf6659ebbf6de674293c3b8d9f391a879b4 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 01:40:06 -0800 Subject: [PATCH 11/20] update --- skyrl-tx/tests/models/test_qwen3.py | 5 +---- skyrl-tx/tx/tinker/backends/jax.py | 8 +++----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py index 7c06b757d..55a779c9e 100644 --- a/skyrl-tx/tests/models/test_qwen3.py +++ b/skyrl-tx/tests/models/test_qwen3.py @@ -246,8 +246,7 @@ def test_qwen3_lora(): ) # Load layer LoRA weights - for i in range(len(model.model.layers)): - layer = model.model.layers[i] + for i, layer in enumerate(model.model.layers): hf_layer = hf_model.base_model.model.model.layers[i] for module, projections in [ ("mlp", ["gate_proj", "up_proj", "down_proj"]), @@ -263,8 +262,6 @@ def test_qwen3_lora(): scaling=lora_config.lora_alpha / lora_config.r, rank=lora_config.r, ) - # Write back the modified layer to update stacked weights - model.model.layers[i] = layer # Use different adapter indices for each input adapter_indices = jnp.arange(len(lora_adapters), dtype=jnp.int32) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 27c50a59b..b54948a79 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -135,12 +135,10 @@ def _select_mean(path, g): def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients": """Reset gradients and count for a specific adapter.""" - def _reset(path, g): - idx = get_adapter_idx(path, adapter_index) - return g.at[idx].set(0.0) - return AccumulatedGradients( - grad_sum=jax.tree.map_with_path(_reset, self.grad_sum), + grad_sum=jax.tree.map_with_path( + lambda path, g: g.at[get_adapter_idx(path, adapter_index)].set(0.0), self.grad_sum + ), counts=self.counts.at[adapter_index].set(0), ) From 851dd0f937c782ff81a131250a1e0f767b8d29fe Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 02:08:52 -0800 Subject: [PATCH 12/20] update --- skyrl-tx/tx/utils/models.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index fb3dabed3..1efec2bc0 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -142,7 +142,11 @@ def save_safetensors( prefix: str = "", filter_fn: Callable[[tuple], bool] | None = None, ) -> None: - model_params = nnx.to_flat_state(nnx.state(model)) + from tx.layers.stacked import unstack_state + + # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths + # (layers.0.xxx) matching the checkpoint key format used by HuggingFace + model_params = nnx.to_flat_state(unstack_state(model)) tensors = {} for path, param in model_params: if "rngs" in path: @@ -260,13 +264,13 @@ def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: "Helper function to extract the adapter parameters for a specific adapter index." 
def extract_state(path: tuple, p: jnp.ndarray): - if path[-2].key not in {"lora_A", "lora_B"}: + key = path[-2].key + if key not in {"lora_A", "lora_B"}: return p - assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" - if path[-2].key == "lora_A": - return p[adapter_index, ..., :, :rank] - if path[-2].key == "lora_B": - return p[adapter_index, ..., :rank, :] + idx = get_adapter_idx(path, adapter_index) + if key == "lora_A": + return p[idx][..., :, :rank] + return p[idx][..., :rank, :] return jax.tree.map_with_path(extract_state, lora_params) @@ -279,13 +283,13 @@ def insert_adapter_state( "Helper function to insert the adapter parameters for a specific adapter index (inverse of extract_adapter_state)." def insert_state(path: tuple, p: jax.Array, new: jax.Array): - if path[-2].key not in {"lora_A", "lora_B"}: + key = path[-2].key + if key not in {"lora_A", "lora_B"}: return new - assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" - if path[-2].key == "lora_A": - return p.at[adapter_index, ..., :, :rank].set(new) - elif path[-2].key == "lora_B": - return p.at[adapter_index, ..., :rank, :].set(new) + idx = get_adapter_idx(path, adapter_index) + if key == "lora_A": + return p.at[*idx, ..., :, :rank].set(new) + return p.at[*idx, ..., :rank, :].set(new) updated = jax.tree.map_with_path(insert_state, lora_params, new_params) nnx.update(lora_params, updated) From c2cdecbcee15da05f1cc3fefe19b850ea1988444 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 02:15:01 -0800 Subject: [PATCH 13/20] cleanup --- skyrl-tx/tx/layers/stacked.py | 4 ---- skyrl-tx/tx/utils/models.py | 7 +------ 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 1ce53ee99..17a8b5d4c 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -75,10 +75,6 @@ def __iter__(self): for i in range(self.num_layers): yield self[i] - @property - def is_stacked(self) -> bool: - return True - def __call__( self, hidden_states: jax.Array, diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index 1efec2bc0..32db1f6f4 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -91,14 +91,9 @@ def get_expert_key(path: tuple, expert_idx: int) -> str: return ".".join(map(str, path)) -def is_stacked_path(path: tuple) -> bool: - """Check if a parameter path is for stacked layers (contains '_stacked').""" - return "_stacked" in str(path) - - def get_adapter_idx(path: tuple, adapter_index: int) -> tuple: """Return index tuple for accessing an adapter. 
Stacked: [:, idx], non-stacked: [idx].""" - return (slice(None), adapter_index) if is_stacked_path(path) else (adapter_index,) + return (slice(None), adapter_index) if "_stacked" in str(path) else (adapter_index,) def load_safetensors( From 7cfa89858b1abbc587525053c2b19c1cc81d330a Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 02:18:34 -0800 Subject: [PATCH 14/20] update --- skyrl-tx/tx/layers/stacked.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 17a8b5d4c..113e10b15 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -146,11 +146,11 @@ def body_fn(hs, layer_params): if is_training: return final_hs, all_hidden_states, None - # Prefill: convert stacked arrays to list and create KVCache - keys_list = [all_keys[i] for i in range(self.num_layers)] - values_list = [all_values[i] for i in range(self.num_layers)] - cache_position = attention_mask.sum(axis=1).astype(jnp.int32) - return final_hs, all_hidden_states, KVCache(keys=keys_list, values=values_list, cache_position=cache_position) + return final_hs, all_hidden_states, KVCache( + keys=[all_keys[i] for i in range(self.num_layers)], + values=[all_values[i] for i in range(self.num_layers)], + cache_position=attention_mask.sum(axis=1).astype(jnp.int32), + ) def unstack_state(module: nnx.Module) -> nnx.GraphState: From e2352d3b618113e361e0831cb8f81e8efba54cb5 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 10:27:48 -0800 Subject: [PATCH 15/20] fix test --- skyrl-tx/tx/layers/util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0030c604d..5bec8f3ce 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -97,8 +97,9 @@ def shard_map_ep(module: nnx.Module, func, *args): """ graphdef, state = nnx.split(module) # Extract only 'ep' dims from PartitionSpecs, replacing others with None + # Also strip the leading dimension from PartitionSpec, which corresponds to the weight stacking dimension state_specs = jax.tree.map( - lambda s: PartitionSpec(*(p if p == "ep" else None for p in s)) if isinstance(s, PartitionSpec) else s, + lambda s: PartitionSpec(*(p if p == "ep" else None for p in s[1:])) if isinstance(s, PartitionSpec) else s, nnx.get_partition_spec(state), is_leaf=lambda x: isinstance(x, PartitionSpec), ) From de2229ecd316267631f1f89b1635c8b3ae191128 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 10:30:06 -0800 Subject: [PATCH 16/20] black --- skyrl-tx/tx/layers/stacked.py | 24 ++++++++++++++++-------- skyrl-tx/tx/tinker/backends/jax.py | 1 + 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 113e10b15..bbb756e9d 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -116,10 +116,14 @@ def __call__( new_keys.append(k) new_values.append(v) - return hidden_states, all_hidden_states, KVCache( - keys=new_keys, - values=new_values, - cache_position=kv_cache.cache_position + positions.shape[1], + return ( + hidden_states, + all_hidden_states, + KVCache( + keys=new_keys, + values=new_values, + cache_position=kv_cache.cache_position + positions.shape[1], + ), ) # Training/prefill mode: use scan @@ -146,10 +150,14 @@ def body_fn(hs, layer_params): if is_training: return final_hs, all_hidden_states, None - return final_hs, all_hidden_states, KVCache( - 
keys=[all_keys[i] for i in range(self.num_layers)], - values=[all_values[i] for i in range(self.num_layers)], - cache_position=attention_mask.sum(axis=1).astype(jnp.int32), + return ( + final_hs, + all_hidden_states, + KVCache( + keys=[all_keys[i] for i in range(self.num_layers)], + values=[all_values[i] for i in range(self.num_layers)], + cache_position=attention_mask.sum(axis=1).astype(jnp.int32), + ), ) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index b54948a79..d2757c9bd 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -127,6 +127,7 @@ def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "Accumulated def get_mean(self, adapter_index: jax.Array) -> nnx.State: """Compute mean gradients for a specific adapter, with zeros for all other adapters.""" count = self.counts[adapter_index] + def _select_mean(path, g): idx = get_adapter_idx(path, adapter_index) return jnp.zeros_like(g).at[idx].set(g[idx] / count.astype(g.dtype)) From ca212c193f022d2b4a6ef5db682debbf5faa38d9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 4 Feb 2026 23:45:14 -0800 Subject: [PATCH 17/20] update --- skyrl-tx/tx/layers/util.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 5bec8f3ce..272f09b0c 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -96,12 +96,17 @@ def shard_map_ep(module: nnx.Module, func, *args): *args: Arguments to pass to func (replicated across shards). """ graphdef, state = nnx.split(module) - # Extract only 'ep' dims from PartitionSpecs, replacing others with None - # Also strip the leading dimension from PartitionSpec, which corresponds to the weight stacking dimension - state_specs = jax.tree.map( - lambda s: PartitionSpec(*(p if p == "ep" else None for p in s[1:])) if isinstance(s, PartitionSpec) else s, - nnx.get_partition_spec(state), - is_leaf=lambda x: isinstance(x, PartitionSpec), + + def to_ep_spec(path, s): + if not isinstance(s, PartitionSpec): + return s + # Strip leading stacking dimension if path contains "_stacked" + dims = s[1:] if "_stacked" in str(path) else s + # Extract only 'ep' dims from PartitionSpecs, replacing others with None + return PartitionSpec(*(p if p == "ep" else None for p in dims)) + + state_specs = jax.tree_util.tree_map_with_path( + to_ep_spec, nnx.get_partition_spec(state), is_leaf=lambda x: isinstance(x, PartitionSpec) ) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) From 27e05236a8ea17cb190390a1dcd45638820c066c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 5 Feb 2026 00:17:54 -0800 Subject: [PATCH 18/20] update --- skyrl-tx/tx/layers/stacked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index bbb756e9d..487243690 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -59,7 +59,7 @@ def __len__(self) -> int: return self.num_layers def __getitem__(self, index: int) -> nnx.Module: - """Get view into layer at index (stays synced with stacked state).""" + """Get view into layer at index. 
Only for tests and weight loading.""" if index < 0 or index >= self.num_layers: raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") graphdef, state = nnx.split(self._stacked) From 268c0a7625a321853a83ce6d3fcda0af45d5c026 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 5 Feb 2026 12:44:36 -0800 Subject: [PATCH 19/20] simplify --- skyrl-tx/tx/layers/stacked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 487243690..2e6eebbbd 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -19,7 +19,7 @@ def __init__(self, parent: nnx.Variable, idx: int): def __getitem__(self, key): parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") - return parent[idx] if key is Ellipsis else parent[idx][key] + return parent[idx][key] def set_raw_value(self, value, **kwargs): """Write through to parent when value is set.""" From f9c7b7edcbef07372d46b36e9f040b6b02e75752 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 5 Feb 2026 12:57:11 -0800 Subject: [PATCH 20/20] simplify --- skyrl-tx/tx/layers/stacked.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py index 2e6eebbbd..71cd662e4 100644 --- a/skyrl-tx/tx/layers/stacked.py +++ b/skyrl-tx/tx/layers/stacked.py @@ -171,11 +171,9 @@ def unstack_state(module: nnx.Module) -> nnx.GraphState: for path, var in nnx.to_flat_state(nnx.state(module)): if "_stacked" not in path: expanded.append((path, var)) - continue - - idx = path.index("_stacked") - for i in range(var[...].shape[0]): - new_path = path[:idx] + (str(i),) + path[idx + 1 :] - expanded.append((new_path, ArrayRef(var, i))) - + else: + idx = path.index("_stacked") + for i in range(var[...].shape[0]): + new_path = path[:idx] + (str(i),) + path[idx + 1 :] + expanded.append((new_path, ArrayRef(var, i))) return nnx.from_flat_state(expanded)
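
As a recap of the execution pattern this series builds on, here is a minimal standalone sketch (not tx code) of scanning over stacked layer weights: parameters carry a leading (num_layers, ...) axis, jax.lax.scan applies one layer body per slice while carrying the hidden states, and jax.checkpoint wraps the body when gradient checkpointing is enabled. The toy linear layers below stand in for the real decoder blocks that StackedDecoderLayers merges from the sliced state.

import jax
import jax.numpy as jnp

num_layers, hidden = 4, 8
# Hypothetical stacked parameters: one (hidden, hidden) matrix per layer.
stacked_w = jax.random.normal(jax.random.PRNGKey(0), (num_layers, hidden, hidden))

def body_fn(hs, w):
    # One toy "layer": residual plus nonlinearity; the real body merges a full
    # Qwen3DecoderLayer from the per-layer slice of the stacked state.
    new_hs = hs + jnp.tanh(hs @ w)
    return new_hs, new_hs          # (carry, per-layer output)

x = jnp.ones((2, hidden))          # (batch, hidden)
final_hs, per_layer_hs = jax.lax.scan(body_fn, x, stacked_w)
assert final_hs.shape == (2, hidden) and per_layer_hs.shape == (num_layers, 2, hidden)

# With gradient checkpointing, the body is rematerialized during backprop:
final_ckpt, _ = jax.lax.scan(jax.checkpoint(body_fn), x, stacked_w)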