diff --git a/skyrl-tx/tests/models/test_qwen3_lora_training.py b/skyrl-tx/tests/models/test_qwen3_lora_training.py
index 85f5f3bda..8273dfead 100644
--- a/skyrl-tx/tests/models/test_qwen3_lora_training.py
+++ b/skyrl-tx/tests/models/test_qwen3_lora_training.py
@@ -7,7 +7,7 @@
 from tx.models.configs import Qwen3Config
 from tx.models.qwen3 import Qwen3ForCausalLM
-from tx.utils.models import get_dtype, load_safetensors
+from tx.utils.models import get_dtype, load_safetensors, get_adapter_idx
 from tx.layers.lora import init_lora_adapter
 from tx.tinker.types import LoraConfig
@@ -47,15 +47,16 @@ def loss_fn(model, input_ids, target_ids, attention_mask):
 # Helper to extract adapter params at specific index
 def get_adapter_params(params, adapter_idx):
-    return jax.tree.map(lambda p: p[adapter_idx].copy(), params)
+    return jax.tree.map_with_path(lambda path, p: p[get_adapter_idx(path, adapter_idx)].copy(), params)
 
 
 # Helper to extract out-of-rank params for an adapter
 def get_out_of_rank_params(params, adapter_idx, rank):
     def slice_param(path, p):
+        idx = get_adapter_idx(path, adapter_idx)
         if "lora_A" in str(path):
-            return p[adapter_idx, :, rank:].copy()
+            return p[idx][..., :, rank:].copy()
         elif "lora_B" in str(path):
-            return p[adapter_idx, rank:, :].copy()
+            return p[idx][..., rank:, :].copy()
         return p
 
     return jax.tree.map_with_path(slice_param, params)
diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py
index 776b7af59..8d4bd01a1 100644
--- a/skyrl-tx/tx/layers/lora.py
+++ b/skyrl-tx/tx/layers/lora.py
@@ -2,7 +2,7 @@
 import jax
 from jax import numpy as jnp
 
-from tx.utils.models import filter_lora
+from tx.utils.models import filter_lora, get_adapter_idx
 from tx.layers.util import Param, prepare_routing, ragged_dot
 from tx.models.types import ModelForCausalLM
 from tx.tinker.types import LoraConfig
@@ -345,21 +345,19 @@ def init_adapter(path, value):
         if not filter_lora(lora_config, normalized_path):
             effective_rank = 0
 
+        idx = get_adapter_idx(path, adapter_index)
         key_name = path[-2].key
         if key_name == "lora_ranks":
-            return value.at[adapter_index].set(effective_rank)
+            return value.at[idx].set(effective_rank)
         if key_name == "lora_scaling":
-            # Set scaling to 0.0 if rank is 0
-            return value.at[adapter_index].set(lora_config.alpha / effective_rank if effective_rank > 0 else 0.0)
+            scaling = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0
+            return value.at[idx].set(scaling)
         if key_name == "lora_A":
-            # Reinitialize with he_uniform, then zero columns beyond rank
-            shape = value[adapter_index].shape
-            new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype)
+            new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype)
             new_A = new_A.at[..., effective_rank:].set(0.0)
-            return value.at[adapter_index].set(new_A)
+            return value.at[idx].set(new_A)
         if key_name == "lora_B":
-            # Explicitly zero lora_B
-            return value.at[adapter_index].set(0.0)
+            return value.at[idx].set(0.0)
         return value
 
     updated_state = jax.tree.map_with_path(init_adapter, state)
@@ -376,11 +374,10 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int):
 
     def clear_adapter(path, value):
         key = path[-2].key
-        if key == "lora_ranks":
-            return value.at[adapter_index].set(0)
-        if key in ("lora_scaling", "lora_A", "lora_B"):
-            return value.at[adapter_index].set(0.0)
-        return value
+        if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"):
+            return value
+        idx = get_adapter_idx(path, adapter_index)
+        return value.at[idx].set(0 if key == "lora_ranks" else 0.0)
 
     updated_state = jax.tree.map_with_path(clear_adapter, state)
     nnx.update(model, updated_state)
diff --git a/skyrl-tx/tx/layers/stacked.py b/skyrl-tx/tx/layers/stacked.py
new file mode 100644
index 000000000..487243690
--- /dev/null
+++ b/skyrl-tx/tx/layers/stacked.py
@@ -0,0 +1,181 @@
+"""StackedDecoderLayers module for efficient transformer layer stacking."""
+
+from typing import Callable
+
+from flax import nnx
+import jax
+import jax.numpy as jnp
+
+from tx.utils.generator import KVCache
+
+
+class ArrayRef(nnx.Variable):
+    """A Variable providing a view into an indexed slice of a parent Variable."""
+
+    def __init__(self, parent: nnx.Variable, idx: int):
+        super().__init__(parent[idx])
+        self.set_metadata("_parent", parent)
+        self.set_metadata("_idx", idx)
+
+    def __getitem__(self, key):
+        parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx")
+        return parent[idx] if key is Ellipsis else parent[idx][key]
+
+    def set_raw_value(self, value, **kwargs):
+        """Write through to parent when value is set."""
+        parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx")
+        parent[...] = parent[...].at[idx].set(value)
+        super().set_raw_value(value, **kwargs)
+
+    @property
+    def shape(self):
+        return self.get_metadata("_parent")[self.get_metadata("_idx")].shape
+
+
+class StackedDecoderLayers(nnx.Module):
+    """Decoder layers with stacked weights created via nnx.vmap.
+
+    Parameters are stored in stacked format (num_layers, ...). The forward pass
+    uses jax.lax.scan for training/prefill and Python loops for decode.
+    """
+
+    def __init__(
+        self,
+        create_layer_fn: Callable[[nnx.Rngs], nnx.Module],
+        num_layers: int,
+        rngs: nnx.Rngs,
+    ):
+        self.num_layers = num_layers
+
+        @nnx.split_rngs(splits=num_layers)
+        @nnx.vmap(in_axes=(0,), out_axes=0, transform_metadata={nnx.PARTITION_NAME: None})
+        def vmapped_create(rngs: nnx.Rngs) -> nnx.Module:
+            return create_layer_fn(rngs)
+
+        self._stacked = vmapped_create(rngs)
+
+    def __len__(self) -> int:
+        """Return the number of layers."""
+        return self.num_layers
+
+    def __getitem__(self, index: int) -> nnx.Module:
+        """Get view into layer at index. Only for tests and weight loading."""
+        if index < 0 or index >= self.num_layers:
+            raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})")
+        graphdef, state = nnx.split(self._stacked)
+        layer_state = jax.tree.map(
+            lambda x: ArrayRef(x, index),
+            state,
+            is_leaf=lambda x: isinstance(x, nnx.Variable),
+        )
+        return nnx.merge(graphdef, layer_state)
+
+    def __iter__(self):
+        """Iterate over individual layers (for testing/weight loading)."""
+        for i in range(self.num_layers):
+            yield self[i]
+
+    def __call__(
+        self,
+        hidden_states: jax.Array,
+        *,
+        attention_mask: jax.Array,
+        positions: jax.Array,
+        adapter_indices: jax.Array | None,
+        kv_cache: KVCache | None,
+        output_hidden_states: bool,
+        gradient_checkpointing: bool,
+        is_training: bool = False,
+    ) -> tuple[jax.Array, list[jax.Array], KVCache | None]:
+        """Forward pass through all layers.
+
+        Uses scan for training/prefill, Python loop for decode.
+
+        Returns:
+            (final_hidden_states, all_hidden_states, kv_cache)
+        """
+        graphdef, state = nnx.split(self._stacked)
+
+        # Decode mode: use Python loop
+        if kv_cache is not None:
+            all_hidden_states = []
+            new_keys, new_values = [], []
+
+            for i in range(self.num_layers):
+                if output_hidden_states:
+                    all_hidden_states.append(hidden_states)
+
+                layer = nnx.merge(graphdef, jax.tree.map(lambda x, i=i: x[i], state))
+                hidden_states, (k, v) = layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    positions=positions,
+                    adapter_indices=adapter_indices,
+                    kv_cache=(kv_cache.keys[i], kv_cache.values[i]),
+                )
+                new_keys.append(k)
+                new_values.append(v)
+
+            return (
+                hidden_states,
+                all_hidden_states,
+                KVCache(
+                    keys=new_keys,
+                    values=new_values,
+                    cache_position=kv_cache.cache_position + positions.shape[1],
+                ),
+            )
+
+        # Training/prefill mode: use scan
+        def body_fn(hs, layer_params):
+            layer = nnx.merge(graphdef, layer_params)
+            new_hs, (k, v) = layer(
+                hs,
+                attention_mask=attention_mask,
+                positions=positions,
+                adapter_indices=adapter_indices,
+                kv_cache=None,
+            )
+            if is_training:
+                k = v = None
+            return new_hs, (new_hs if output_hidden_states else None, k, v)
+
+        if gradient_checkpointing:
+            body_fn = jax.checkpoint(body_fn)
+
+        final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, state)
+
+        all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else []
+
+        if is_training:
+            return final_hs, all_hidden_states, None
+
+        return (
+            final_hs,
+            all_hidden_states,
+            KVCache(
+                keys=[all_keys[i] for i in range(self.num_layers)],
+                values=[all_values[i] for i in range(self.num_layers)],
+                cache_position=attention_mask.sum(axis=1).astype(jnp.int32),
+            ),
+        )
+
+
+def unstack_state(module: nnx.Module) -> nnx.GraphState:
+    """Transform stacked layer state to unstacked ArrayRef views.
+
+    Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc.
+    Each entry is an ArrayRef that writes through to the original stacked variable.
+    """
+    expanded = []
+    for path, var in nnx.to_flat_state(nnx.state(module)):
+        if "_stacked" not in path:
+            expanded.append((path, var))
+            continue
+
+        idx = path.index("_stacked")
+        for i in range(var[...].shape[0]):
+            new_path = path[:idx] + (str(i),) + path[idx + 1 :]
+            expanded.append((new_path, ArrayRef(var, i)))
+
+    return nnx.from_flat_state(expanded)
diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py
index 0030c604d..272f09b0c 100644
--- a/skyrl-tx/tx/layers/util.py
+++ b/skyrl-tx/tx/layers/util.py
@@ -96,11 +96,17 @@ def shard_map_ep(module: nnx.Module, func, *args):
         *args: Arguments to pass to func (replicated across shards).
     """
     graphdef, state = nnx.split(module)
-    # Extract only 'ep' dims from PartitionSpecs, replacing others with None
-    state_specs = jax.tree.map(
-        lambda s: PartitionSpec(*(p if p == "ep" else None for p in s)) if isinstance(s, PartitionSpec) else s,
-        nnx.get_partition_spec(state),
-        is_leaf=lambda x: isinstance(x, PartitionSpec),
+
+    def to_ep_spec(path, s):
+        if not isinstance(s, PartitionSpec):
+            return s
+        # Strip leading stacking dimension if path contains "_stacked"
+        dims = s[1:] if "_stacked" in str(path) else s
+        # Extract only 'ep' dims from PartitionSpecs, replacing others with None
+        return PartitionSpec(*(p if p == "ep" else None for p in dims))
+
+    state_specs = jax.tree_util.tree_map_with_path(
+        to_ep_spec, nnx.get_partition_spec(state), is_leaf=lambda x: isinstance(x, PartitionSpec)
     )
 
     in_specs = (state_specs,) + (PartitionSpec(),) * len(args)
diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py
index 1348cac09..9e5332716 100644
--- a/skyrl-tx/tx/models/qwen3.py
+++ b/skyrl-tx/tx/models/qwen3.py
@@ -3,15 +3,16 @@
 from jax import numpy as jnp
 from jax.sharding import get_abstract_mesh
 
+from tx.layers.attention import dot_product_attention
+from tx.layers.layernorm import RMSNorm
 from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear
-from tx.layers.util import prepare_routing, shard_map_ep
 from tx.layers.rotary_embedding import apply_rope
-from tx.utils.logits_processor import LogitsProcessorMixin, LMHead
-from tx.layers.layernorm import RMSNorm
-from tx.layers.attention import dot_product_attention
+from tx.layers.stacked import StackedDecoderLayers
+from tx.layers.util import prepare_routing, shard_map_ep
 from tx.models.configs import Qwen3Config
 from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput
 from tx.utils.generator import GeneratorMixin, KVCache
+from tx.utils.logits_processor import LogitsProcessorMixin, LMHead
 
 
 class Qwen3Attention(nnx.Module):
@@ -329,9 +330,11 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
-        self.layers = nnx.List(
-            [Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)]
-        )
+
+        def create_layer(rngs: nnx.Rngs) -> Qwen3DecoderLayer:
+            return Qwen3DecoderLayer(config, dtype=dtype, rngs=rngs)
+
+        self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
 
     def __call__(
@@ -343,28 +346,24 @@
         self,
         input_ids: jax.Array,
         *,
         attention_mask: jax.Array,
         positions: jax.Array | None = None,
         output_hidden_states: bool | None = None,
         adapter_indices: jax.Array | None = None,
         kv_cache: KVCache | None = None,
+        is_training: bool = False,
     ) -> ModelOutput:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices)
-        all_hidden_states: list[jax.Array] = []
-        updated_keys, updated_values = [], []
-
-        for layer_idx, layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states.append(hidden_states)
-
-            hidden_states, (k, v) = layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                positions=positions,
-                adapter_indices=adapter_indices,
-                kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]),
-            )
-            updated_keys.append(k)
-            updated_values.append(v)
+
+        hidden_states, all_hidden_states, new_kv_cache = self.layers(
+            hidden_states,
+            attention_mask=attention_mask,
+            positions=positions,
+            adapter_indices=adapter_indices,
+            kv_cache=kv_cache,
+            output_hidden_states=output_hidden_states,
+            gradient_checkpointing=self.config.gradient_checkpointing,
+            is_training=is_training,
+        )
 
         hidden_states = self.norm(hidden_states)
         if output_hidden_states:
@@ -372,7 +371,7 @@
 
         return ModelOutput(
             last_hidden_state=hidden_states,
-            kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask),
+            kv_cache=new_kv_cache,
             hidden_states=all_hidden_states if output_hidden_states else None,
         )
@@ -417,6 +416,7 @@ def __call__(
         output_hidden_states: bool | None = None,
         adapter_indices: jax.Array | None = None,
         kv_cache: KVCache | None = None,
+        is_training: bool = False,
     ) -> CausalLMOutput:
         if positions is None:
             positions = jnp.arange(attention_mask.shape[1])[None, :]
@@ -428,6 +428,7 @@
             output_hidden_states=output_hidden_states,
             adapter_indices=adapter_indices,
             kv_cache=kv_cache,
+            is_training=is_training,
         )
 
         return CausalLMOutput(
diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py
index 99ae33327..d2757c9bd 100644
--- a/skyrl-tx/tx/tinker/backends/jax.py
+++ b/skyrl-tx/tx/tinker/backends/jax.py
@@ -45,6 +45,7 @@ from tx.utils.models import (
     get_dtype,
     get_model_class,
+    get_adapter_idx,
     load_safetensors,
     load_lora_checkpoint,
     save_lora_checkpoint,
@@ -126,15 +127,19 @@ def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "Accumulated
 
     def get_mean(self, adapter_index: jax.Array) -> nnx.State:
         """Compute mean gradients for a specific adapter, with zeros for all other adapters."""
         count = self.counts[adapter_index]
-        return jax.tree.map(
-            lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)),
-            self.grad_sum,
-        )
+
+        def _select_mean(path, g):
+            idx = get_adapter_idx(path, adapter_index)
+            return jnp.zeros_like(g).at[idx].set(g[idx] / count.astype(g.dtype))
+
+        return jax.tree.map_with_path(_select_mean, self.grad_sum)
 
     def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients":
         """Reset gradients and count for a specific adapter."""
         return AccumulatedGradients(
-            grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum),
+            grad_sum=jax.tree.map_with_path(
+                lambda path, g: g.at[get_adapter_idx(path, adapter_index)].set(0.0), self.grad_sum
+            ),
             counts=self.counts.at[adapter_index].set(0),
         )
diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py
index 6e840febf..32db1f6f4 100644
--- a/skyrl-tx/tx/utils/models.py
+++ b/skyrl-tx/tx/utils/models.py
@@ -91,6 +91,11 @@ def get_expert_key(path: tuple, expert_idx: int) -> str:
     return ".".join(map(str, path))
 
 
+def get_adapter_idx(path: tuple, adapter_index: int) -> tuple:
+    """Return index tuple for accessing an adapter. Stacked: [:, idx], non-stacked: [idx]."""
+    return (slice(None), adapter_index) if "_stacked" in str(path) else (adapter_index,)
+
+
 def load_safetensors(
     checkpoint_dir: str | os.PathLike,
     config: ModelConfig,
@@ -99,32 +104,30 @@
     prefix: str = "",
     filter_fn: Callable[[tuple], bool] | None = None,
 ) -> None:
+    from tx.layers.stacked import unstack_state
+
     tensors = {}
     for file in Path(checkpoint_dir).glob("*.safetensors"):
         tensors.update(safetensors.numpy.load_file(file))
     tensors = {k.removeprefix(prefix): v for k, v in tensors.items()}
 
-    model_params = nnx.to_flat_state(nnx.state(model))
-    updates = []
-    for path, param in model_params:
+    # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths
+    # (layers.0.xxx) with ArrayRef write-through, matching checkpoint key format
+    for path, param in nnx.to_flat_state(unstack_state(model)):
         if filter_fn is not None and not filter_fn(path):
             continue
         key = get_param_key(path)
-        # Skip LoRA parameters if requested
        if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path):
             continue
         if "experts" in path:
-            tensors[key] = np.stack(
-                [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0
-            )
+            tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
         else:
-            tensors[key] = tensors[key] if "embed_tokens" in path else tensors[key].T
+            tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
             if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
-                tensors[key] = tensors[key].reshape(param.shape)
-        assert param.shape == tensors[key].shape, f"shape mismatch for {key}"
-        sharded_tensor = jax.device_put(tensors[key].astype(param.dtype), param.sharding)
-        updates.append((path, sharded_tensor))
-    nnx.update(model, nnx.from_flat_state(updates))
+                tensor = tensor.reshape(param.shape)
+        assert param.shape == tensor.shape, f"shape mismatch for {key}"
+        # ArrayRef.set_raw_value writes through to the stacked parent variable
+        param.set_raw_value(jax.device_put(tensor.astype(param.dtype), param.sharding))
 
 
 def save_safetensors(
@@ -134,7 +137,11 @@
     prefix: str = "",
     filter_fn: Callable[[tuple], bool] | None = None,
 ) -> None:
-    model_params = nnx.to_flat_state(nnx.state(model))
+    from tx.layers.stacked import unstack_state
+
+    # unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths
+    # (layers.0.xxx) matching the checkpoint key format used by HuggingFace
+    model_params = nnx.to_flat_state(unstack_state(model))
     tensors = {}
     for path, param in model_params:
         if "rngs" in path:
@@ -252,13 +259,13 @@ def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank:
     "Helper function to extract the adapter parameters for a specific adapter index."
 
     def extract_state(path: tuple, p: jnp.ndarray):
-        if path[-2].key not in {"lora_A", "lora_B"}:
+        key = path[-2].key
+        if key not in {"lora_A", "lora_B"}:
             return p
-        assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}"
-        if path[-2].key == "lora_A":
-            return p[adapter_index, ..., :, :rank]
-        if path[-2].key == "lora_B":
-            return p[adapter_index, ..., :rank, :]
+        idx = get_adapter_idx(path, adapter_index)
+        if key == "lora_A":
+            return p[idx][..., :, :rank]
+        return p[idx][..., :rank, :]
 
     return jax.tree.map_with_path(extract_state, lora_params)
@@ -271,13 +278,13 @@ def insert_adapter_state(
     "Helper function to insert the adapter parameters for a specific adapter index (inverse of extract_adapter_state)."
 
     def insert_state(path: tuple, p: jax.Array, new: jax.Array):
-        if path[-2].key not in {"lora_A", "lora_B"}:
+        key = path[-2].key
+        if key not in {"lora_A", "lora_B"}:
             return new
-        assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}"
-        if path[-2].key == "lora_A":
-            return p.at[adapter_index, ..., :, :rank].set(new)
-        elif path[-2].key == "lora_B":
-            return p.at[adapter_index, ..., :rank, :].set(new)
+        idx = get_adapter_idx(path, adapter_index)
+        if key == "lora_A":
+            return p.at[*idx, ..., :, :rank].set(new)
+        return p.at[*idx, ..., :rank, :].set(new)
 
     updated = jax.tree.map_with_path(insert_state, lora_params, new_params)
     nnx.update(lora_params, updated)
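For reference (not part of the patch): a minimal standalone sketch of the indexing convention that get_adapter_idx encodes. It uses plain NumPy with invented paths and shapes, whereas the real helper operates on flax nnx state paths; the point is only that stacked parameters carry a leading num_layers axis, so per-adapter access becomes a [:, adapter_index] slice across all layers.

# Illustrative sketch only: invented paths and shapes, NumPy instead of nnx state.
import numpy as np


def get_adapter_idx(path: tuple, adapter_index: int) -> tuple:
    # Stacked params have a leading num_layers axis, so the adapter axis moves
    # from position 0 to position 1 and is addressed as [:, adapter_index].
    return (slice(None), adapter_index) if "_stacked" in str(path) else (adapter_index,)


num_layers, max_adapters, in_features, rank = 4, 8, 32, 16

flat_lora_A = np.zeros((max_adapters, in_features, rank))                 # per-layer module
stacked_lora_A = np.zeros((num_layers, max_adapters, in_features, rank))  # stacked layers

flat_path = ("q_proj", "lora_A", "value")
stacked_path = ("layers", "_stacked", "q_proj", "lora_A", "value")

# Non-stacked: selects one adapter's (in_features, rank) block.
assert flat_lora_A[get_adapter_idx(flat_path, 3)].shape == (in_features, rank)
# Stacked: selects the same adapter across every layer at once.
assert stacked_lora_A[get_adapter_idx(stacked_path, 3)].shape == (num_layers, in_features, rank)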