128 commits
8913626
feat: add chunked lm_head for memory-efficient logprobs computation
raulchen Jan 13, 2026
9726415
fix: fallback to non-chunked loss when train_unembed=True or chunk_si…
raulchen Jan 14, 2026
3fa6d2d
add tests
raulchen Jan 20, 2026
801f1e9
checkpoint
raulchen Jan 20, 2026
07469ff
deprecation warning
raulchen Jan 20, 2026
30f083a
lint
raulchen Jan 21, 2026
f318cbb
feat: add per-layer gradient checkpointing
raulchen Jan 13, 2026
a763fce
feat: use fori_loop for gradient checkpointing to enable XLA buffer r…
raulchen Jan 13, 2026
3676aae
fix: use attention_mask instead of seq_lengths in model forward
raulchen Jan 20, 2026
cb083ae
fix: pass is_training=True to enable gradient checkpointing
raulchen Jan 20, 2026
c368f23
feat: use scan instead of fori_loop to support output_hidden_states
raulchen Jan 20, 2026
9ef7e17
perf: return None from scan when output_hidden_states=False to save m…
raulchen Jan 20, 2026
03f64fb
fix: exclude last layer output from all_hidden_states to match non-ch…
raulchen Jan 20, 2026
94a5a56
test: add gradient checkpointing tests
raulchen Jan 20, 2026
9ec6b17
fix
raulchen Jan 21, 2026
f3cda4f
lint
raulchen Jan 21, 2026
5cf1c66
fix: add guard for empty layers in checkpointed forward
raulchen Jan 21, 2026
561d308
Merge main into chunked-lm-head
raulchen Jan 21, 2026
cb0e72e
Unify logprobs computation in LogitsProcessor
raulchen Jan 21, 2026
dc6f2a4
fix: restore skip_prompt_logits parameter (separate from skip_logits)
raulchen Jan 21, 2026
1e4b246
docs: add LogitsProcessor design document
raulchen Jan 21, 2026
5e2d937
refactor: implement LogitsProcessor design
raulchen Jan 22, 2026
7f9a762
refactor: encapsulate LogitsProcessor in CausalLMBase
raulchen Jan 22, 2026
6cbe1cb
inline logits processor
raulchen Jan 22, 2026
cd2fd4e
refactor: runtime train_unembed check with per-adapter mask
raulchen Jan 22, 2026
f9cb177
refactor: explicit CausalLMBase.__init__ for lm_head
raulchen Jan 22, 2026
929d96b
remove doc
raulchen Jan 22, 2026
b1254c6
rename test_logits_processor to test_compute_logits
raulchen Jan 22, 2026
1ad1612
fix: DummyModel calls CausalLMBase.__init__
raulchen Jan 22, 2026
9e396a3
refactor: remove ModelForCausalLM Protocol, use CausalLMBase
raulchen Jan 22, 2026
3451144
refactor: move config to CausalLMBase.__init__
raulchen Jan 22, 2026
4a63a2b
fix: lm_head type is Callable, not LoRALinear
raulchen Jan 22, 2026
2789a48
Revert: remove chunked logprobs (to be submitted in separate PR)
raulchen Jan 22, 2026
e149112
refactor: split test_models_common into focused tests
raulchen Jan 22, 2026
9575da3
lint
raulchen Jan 22, 2026
36a6961
address comments
raulchen Jan 22, 2026
f6ed3fb
fix: pass adapter_indices to compute_logprobs for prompt logprobs
raulchen Jan 22, 2026
d635429
use mixin
raulchen Jan 22, 2026
a81c27f
feat: add chunked cross-entropy loss computation
raulchen Jan 22, 2026
38175fe
fix
raulchen Jan 22, 2026
5241683
fix
raulchen Jan 22, 2026
ab68bd7
refine tests
raulchen Jan 22, 2026
8b5b02d
address comments
raulchen Jan 22, 2026
53a2466
Merge refactor-logits-compute into chunked-lm-head
raulchen Jan 22, 2026
10ff606
fix: use float32 and per-model tolerances in test_compute_logits
raulchen Jan 22, 2026
418fb3b
Merge branch 'refactor-logits-compute' into chunked-lm-head
raulchen Jan 23, 2026
0781e20
fix: use float32 and per-model tolerances in test_compute_logits
raulchen Jan 22, 2026
7616066
Merge refactor-logits-compute with simplified tolerance
raulchen Jan 23, 2026
ff949df
remove comment
raulchen Jan 23, 2026
fccbbba
Merge refactor-logits-compute into chunked-lm-head
raulchen Jan 23, 2026
1bde686
remove comment
raulchen Jan 23, 2026
42ef8f0
lint
raulchen Jan 23, 2026
8831bf2
lint
raulchen Jan 23, 2026
07b7be7
empty
raulchen Jan 23, 2026
d55e04c
refactor: use lm_head() in chunked path to support LoRA
raulchen Jan 23, 2026
006d412
cleanup: remove _train_unembed_mask and simplify chunked lm_head
raulchen Jan 23, 2026
b2f8eba
refactor: compute adapter indices on-the-fly in chunked path
raulchen Jan 23, 2026
a82cd53
fix: load one model at a time in test_compute_logits to avoid OOM
raulchen Jan 23, 2026
345d5c1
lint
raulchen Jan 23, 2026
d61878f
Merge refactor-logits-compute into chunked-lm-head
raulchen Jan 23, 2026
2f78bab
fix: add missing config args and restore test_chunked_logprobs
raulchen Jan 23, 2026
074a6cc
merge: incorporate main branch changes
raulchen Jan 23, 2026
e0cb768
test: load one model at a time in test_chunked_logprobs
raulchen Jan 23, 2026
9d90795
test: load one backend at a time in test_mixed_train_unembed_adapters
raulchen Jan 23, 2026
d5a2133
inherit
raulchen Jan 23, 2026
4e39b49
test: add unit tests for chunked logprobs edge cases
raulchen Jan 23, 2026
0925010
lint
raulchen Jan 23, 2026
fa93a01
default values
raulchen Jan 23, 2026
445a4c8
empty
raulchen Jan 23, 2026
7bea8fe
Merge chunked-lm-head into per-layer-checkpointing
raulchen Jan 23, 2026
1eca137
minor cleanup
raulchen Jan 23, 2026
0ef5ea3
refactor: extract forward layer utilities to reduce duplication
raulchen Jan 23, 2026
3ce5eab
Merge main into per-layer-checkpointing
raulchen Jan 26, 2026
572a697
fix: remove unused new_cache_position variable
raulchen Jan 26, 2026
2c5b3a7
remove comments
raulchen Jan 26, 2026
246c2af
fix
raulchen Jan 26, 2026
159dc82
remove comment
raulchen Jan 26, 2026
58527c7
unify forward_layers
raulchen Jan 26, 2026
53316f7
model.train()
raulchen Jan 26, 2026
aea4dae
Merge branch 'main' into per-layer-checkpointing
raulchen Jan 27, 2026
113bd92
stack weights
raulchen Jan 29, 2026
6ebf1b9
remove duplication
raulchen Jan 29, 2026
dbe5114
remove duplication
raulchen Jan 29, 2026
15b4086
load model twice
raulchen Jan 29, 2026
a3adadd
type hints
raulchen Jan 29, 2026
52a3aff
Merge branch 'main' into per-layer-checkpointing
raulchen Jan 29, 2026
a552dfc
fix
raulchen Jan 29, 2026
42f9a14
Merge per-layer-checkpointing into stack-weights
raulchen Jan 29, 2026
687f2a5
minor fixes
raulchen Jan 29, 2026
55a42e6
simplify and optimize forward_layers
raulchen Jan 29, 2026
38509ce
skip skyrl-train
raulchen Jan 29, 2026
6d4d17d
simplify models.py
raulchen Jan 30, 2026
846aa96
clean up lora.py
raulchen Jan 30, 2026
5217343
fix tests/utils
raulchen Jan 30, 2026
6bf3cae
Update tests and load_safetensors for stacked layer format
raulchen Jan 30, 2026
801458b
Add workarounds for non-stacked DeepSeekV3 layers
raulchen Jan 30, 2026
e7bab93
Revert "Add workarounds for non-stacked DeepSeekV3 layers"
raulchen Jan 30, 2026
c18747d
Implement split stacked layers for DeepSeekV3
raulchen Jan 30, 2026
650c926
Remove unused train/eval methods from all models
raulchen Jan 30, 2026
c669b34
Remove .train()/.eval() calls no longer needed
raulchen Jan 30, 2026
68f82df
Fix outdated test name and improve dtype cast comment
raulchen Jan 30, 2026
301f7dc
Refactor: remove unused code and consolidate stacked path utilities
raulchen Jan 30, 2026
6abe6e7
Fix tinker tests for stacked layer access
raulchen Jan 30, 2026
4cdd7dc
lint
raulchen Jan 30, 2026
3fd1420
Fix AccumulatedGradients indexing for stacked layer params
raulchen Jan 30, 2026
acb98fd
revert pyproject
raulchen Jan 30, 2026
8cbebe5
Merge branch 'main' into stack-weights
raulchen Jan 30, 2026
8cfe622
Refactor: extract _lora_slice helper to reduce duplication
raulchen Jan 30, 2026
a8a3e52
Add tests for stacked layer utilities
raulchen Jan 30, 2026
e3ed933
Add mlp type annotation to DeepseekV3DecoderLayer base class
raulchen Jan 30, 2026
7d5bf5b
Fix Qwen3 MoE softmax ordering to match HuggingFace
raulchen Jan 30, 2026
3651dec
Address PR review feedback
raulchen Jan 30, 2026
b6a6f95
Add get_adapter_idx to consolidate stacked/non-stacked indexing
raulchen Jan 30, 2026
6f8e486
Revert "Fix Qwen3 MoE softmax ordering to match HuggingFace"
raulchen Jan 31, 2026
2f2f765
Remove redundant _is_stacked_layer_param function
raulchen Jan 31, 2026
ab1a7c9
Use KVCache.split() and concatenate() in DeepseekV3
raulchen Jan 31, 2026
9635e4d
lint
raulchen Jan 31, 2026
1bf80be
fix
raulchen Jan 31, 2026
3abaa7c
skip kv cache for training
raulchen Jan 31, 2026
209f959
Fix shard_map_ep PartitionSpec length mismatch for extracted layers
raulchen Jan 31, 2026
23f2484
Merge branch 'main' into stack-weights
pcmoritz Jan 31, 2026
5122c2c
remove closure
pcmoritz Jan 31, 2026
2c0c3e9
Fix create_stacked_layers to avoid vmap memory overhead
raulchen Feb 3, 2026
40f99d4
Optimize create_stacked_layers to avoid 2x peak memory
raulchen Feb 4, 2026
bceff5f
Use KV cache as scan carry for buffer donation
raulchen Feb 5, 2026
08ec23a
Simplify create_stacked_layers while preserving memory efficiency
raulchen Feb 5, 2026
993d6de
Merge branch 'main' into stack-weights
raulchen Feb 5, 2026
98d5429
Sync NNX sharding metadata after stacking layers
raulchen Feb 5, 2026
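
Note on the "chunked lm_head" commits above: the goal is to compute per-token logprobs without ever materializing the full (batch, seq, vocab) logits tensor. The sketch below is only an illustration of that idea, not code from this PR; every name in it is hypothetical, and the real implementation may chunk differently (e.g. via lax.scan) and route through the model's lm_head/LoRA path.

import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_w, targets, chunk_size=128):
    """Hypothetical sketch: hidden (batch, seq, dim), lm_head_w (dim, vocab), targets (batch, seq)."""
    outs = []
    seq = hidden.shape[1]
    for start in range(0, seq, chunk_size):
        h = hidden[:, start:start + chunk_size]              # (batch, chunk, dim)
        logits = h @ lm_head_w                               # only one chunk of logits alive at a time
        log_z = jax.nn.logsumexp(logits, axis=-1)
        tgt = targets[:, start:start + chunk_size]
        tgt_logits = jnp.take_along_axis(logits, tgt[..., None], axis=-1)[..., 0]
        outs.append(tgt_logits - log_z)                      # log p(target token) per position
    return jnp.concatenate(outs, axis=1)

# Matches the naive full-logits computation on toy data:
h = jax.random.normal(jax.random.key(0), (2, 10, 8))
w = jax.random.normal(jax.random.key(1), (8, 32))
t = jax.random.randint(jax.random.key(2), (2, 10), 0, 32)
ref = jnp.take_along_axis(jax.nn.log_softmax(h @ w), t[..., None], axis=-1)[..., 0]
assert jnp.allclose(chunked_logprobs(h, w, t, chunk_size=4), ref, atol=1e-5)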
81 changes: 81 additions & 0 deletions skyrl-tx/tests/models/lora_test_utils.py
@@ -0,0 +1,81 @@
"""Shared test utilities for LoRA training tests."""

import jax
import jax.numpy as jnp

from tx.utils.models import get_adapter_idx


def get_adapter_params(params, adapter_idx: int):
"""Extract adapter params at a specific index.

Decoder layer LoRA params have shape (num_layers, num_adapters, ...).
Embed tokens LoRA params have shape (num_adapters, ...).
"""

def extract(path, p):
idx = get_adapter_idx(path, adapter_idx)
return p[idx].copy()

return jax.tree.map_with_path(extract, params)


def _slice_out_of_rank(params, adapter_idx: int, get_rank):
"""Extract out-of-rank params using a rank function.

Args:
params: LoRA parameters tree.
adapter_idx: Adapter index to extract.
get_rank: Function (path) -> int returning effective rank for that path.
"""

def slice_param(path, p):
path_str = str(path)
if "lora_A" not in path_str and "lora_B" not in path_str:
return p
rank = get_rank(path)
idx = get_adapter_idx(path, adapter_idx)
if "lora_A" in path_str:
return p[idx + (..., slice(rank, None))].copy()
return p[idx + (..., slice(rank, None), slice(None))].copy()

return jax.tree.map_with_path(slice_param, params)


def get_out_of_rank_params(params, adapter_idx: int, rank: int):
"""Extract out-of-rank params for an adapter."""
return _slice_out_of_rank(params, adapter_idx, lambda _: rank)


def verify_params_unchanged(initial_params, final_params, error_msg_prefix: str):
"""Verify that params haven't changed between initial and final state."""
for (path, initial), (_, final) in zip(
jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params)
):
assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}"


def _is_routed_expert_path(path) -> bool:
"""Check if path is for routed experts (not shared_experts)."""
keys = []
for p in path:
if hasattr(p, "key"):
keys.append(str(p.key))
elif hasattr(p, "name"):
keys.append(str(p.name))
for i, key in enumerate(keys):
if key == "experts" and i > 0 and keys[i - 1] == "mlp":
return True
return False


def get_moe_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int):
"""Extract out-of-rank params for MoE models.

For routed experts, uses effective rank = max(1, rank // num_experts).
"""

def get_rank(path):
return max(1, rank // num_experts) if _is_routed_expert_path(path) else rank

return _slice_out_of_rank(params, adapter_idx, get_rank)
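
For readers skimming the diff, here is a tiny illustration of the two param layouts the docstring above describes (not part of the PR; the indexing shown is a stand-in for what tx.utils.models.get_adapter_idx resolves from the param path):

import jax.numpy as jnp

num_layers, num_adapters, in_dim, max_rank = 4, 3, 8, 16

# Decoder-layer lora_A is stacked across layers -> (num_layers, num_adapters, in_dim, max_rank)
stacked_lora_A = jnp.zeros((num_layers, num_adapters, in_dim, max_rank))
# Embedding lora_A is not stacked -> (num_adapters, in_dim, max_rank)
embed_lora_A = jnp.zeros((num_adapters, in_dim, max_rank))

adapter_idx = 2
# Selecting one adapter keeps the layer axis for stacked params...
assert stacked_lora_A[:, adapter_idx].shape == (num_layers, in_dim, max_rank)
# ...and goes straight to the per-adapter weights for embed params.
assert embed_lora_A[adapter_idx].shape == (in_dim, max_rank)

# "Out-of-rank" checks slice beyond an adapter's effective rank; for a rank-4
# adapter, the tail of the max_rank axis must keep its initial values.
rank = 4
out_of_rank_A = stacked_lora_A[:, adapter_idx, ..., rank:]
assert out_of_rank_A.shape == (num_layers, in_dim, max_rank - rank)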
48 changes: 48 additions & 0 deletions skyrl-tx/tests/models/test_deepseekv3.py
@@ -186,3 +186,51 @@ def test_deepseekv3_moe_layer_lora(ep: int, tp: int):
output_merged = moe_layer_merged(x_sample)

assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3)


def test_deepseekv3_gradient_checkpointing():
"""Test that gradient checkpointing produces identical outputs for DeepSeekV3.

DeepSeekV3 has split stacking (dense_layers + moe_layers), so this tests
that gradient checkpointing works correctly with heterogeneous layer types.
"""
model_name = "yujiepan/deepseek-v3-tiny-random"
base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True)

batch_size, seq_len = 2, 8
mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)

results = {}
for use_checkpointing in [False, True]:
config = DeepseekV3Config(
base_config,
max_lora_adapters=1,
max_lora_rank=1,
shard_attention_heads=True,
gradient_checkpointing=use_checkpointing,
)
with jax.set_mesh(mesh):
model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))

input_ids = jax.random.randint(jax.random.key(42), (batch_size, seq_len), 0, config.vocab_size)
attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)

out = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
logits = model.compute_logits(out.last_hidden_state)

results[use_checkpointing] = {
"logits": np.array(logits),
"hidden_states": [np.array(hs) for hs in out.hidden_states],
"kv_cache_shape": out.kv_cache.keys.shape,
}

# Verify outputs match
np.testing.assert_allclose(results[False]["logits"], results[True]["logits"], rtol=1e-4, atol=1e-6)

# Verify hidden states match
assert len(results[False]["hidden_states"]) == len(results[True]["hidden_states"])
for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(results[False]["hidden_states"], results[True]["hidden_states"])):
np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}")

# Verify KV cache shape is correct (num_layers, batch, seq, heads, dim)
assert results[True]["kv_cache_shape"][0] == config.num_hidden_layers
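
The test above only asserts that checkpointed and non-checkpointed forwards agree. As background, here is a minimal sketch of the general pattern the per-layer-checkpointing and stacked-weights commits describe: scanning over stacked per-layer weights and wrapping the layer step in jax.checkpoint so activations are recomputed during backprop rather than stored. This is an illustration with hypothetical names, not the repo's implementation.

import jax
import jax.numpy as jnp

def layer_fn(hidden, layer_w):
    # Stand-in for one transformer block.
    return hidden + jnp.tanh(hidden @ layer_w)

def forward_layers(hidden, stacked_w, gradient_checkpointing):
    """stacked_w: (num_layers, dim, dim) -- one leading layer axis, as in the stacked-weights commits."""
    step = jax.checkpoint(layer_fn) if gradient_checkpointing else layer_fn

    def body(carry, layer_w):
        return step(carry, layer_w), None

    hidden, _ = jax.lax.scan(body, hidden, stacked_w)
    return hidden

x = jax.random.normal(jax.random.key(0), (2, 8, 16))
w = jax.random.normal(jax.random.key(1), (4, 16, 16))
loss = lambda weights, ckpt: forward_layers(x, weights, ckpt).sum()
# Outputs and gradients are identical; only peak activation memory differs.
assert jnp.allclose(forward_layers(x, w, False), forward_layers(x, w, True))
assert jnp.allclose(jax.grad(loss)(w, False), jax.grad(loss)(w, True))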
75 changes: 13 additions & 62 deletions skyrl-tx/tests/models/test_deepseekv3_lora_training.py
@@ -11,42 +11,11 @@
from tx.layers.lora import init_lora_adapter
from tx.tinker.types import LoraConfig


def _is_routed_expert_path(path) -> bool:
"""Disambiguate shared_experts and experts"""
keys = []
for p in path:
if hasattr(p, "key"):
keys.append(str(p.key))
elif hasattr(p, "name"):
keys.append(str(p.name))

for i, key in enumerate(keys):
if key == "experts" and i > 0 and keys[i - 1] == "mlp":
return True
return False


def _get_out_of_rank_params(params, adapter_idx: int, rank: int, num_experts: int):
"""Extract out-of-rank params, using effective rank for routed expert layers."""

def slice_param(path, p):
path_str = str(path)

if _is_routed_expert_path(path):
effective_rank = max(1, rank // num_experts)
else:
effective_rank = rank

if "lora_A" in path_str:
# lora_A shape: [adapters, ..., max_rank] - slice last dim
return p[adapter_idx, ..., effective_rank:].copy()
elif "lora_B" in path_str:
# lora_B shape: [adapters, ..., max_rank, out] - slice second-to-last dim
return p[adapter_idx, ..., effective_rank:, :].copy()
return p

return jax.tree.map_with_path(slice_param, params)
from tests.models.lora_test_utils import (
get_adapter_params,
get_moe_out_of_rank_params,
verify_params_unchanged,
)


def test_lora_training_moe_rank_normalized():
@@ -85,15 +54,12 @@ def loss_fn(model, input_ids, target_ids, attention_mask):

graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...)

def get_adapter_params(params, adapter_idx):
return jax.tree.map(lambda p: p[adapter_idx].copy(), params)

num_experts = config.n_routed_experts

# Save initial states
initial_adapter_2_params = get_adapter_params(lora_params, 2)
initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts)
initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts)
initial_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts)
initial_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts)

initial_loss = None

@@ -116,24 +82,18 @@ def loss_for_lora(lora_params):

final_loss = float(loss)

def verify_params_unchanged(initial_params, final_params, error_msg_prefix):
for (path, initial), (_, final) in zip(
jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params)
):
assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}"

assert final_loss < initial_loss, f"Loss did not decrease: {initial_loss} -> {final_loss}"

# Verify unused adapter was not modified
final_adapter_2_params = get_adapter_params(lora_params, 2)
verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified")

# Verify out-of-rank params were not modified
final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts)
final_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts)
verify_params_unchanged(
initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified"
)
final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts)
final_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts)
verify_params_unchanged(
initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified"
)
@@ -172,9 +132,6 @@ def loss_fn(model, input_ids, target_ids, attention_mask):

graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...)

def get_adapter_params(params, adapter_idx):
return jax.tree.map(lambda p: p[adapter_idx].copy(), params)

num_experts = config.n_routed_experts

# Save initial states for all unused adapters
@@ -183,8 +140,8 @@ def get_adapter_params(params, adapter_idx):
initial_adapter_4_params = get_adapter_params(lora_params, 4)

# Save out-of-rank params for adapters 0 and 1
initial_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts)
initial_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts)
initial_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts)
initial_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts)

# Training loop
for step in range(10):
@@ -200,12 +157,6 @@ def loss_for_lora(lora_params):

print(f"Step {step}: loss = {float(loss):.4f}")

def verify_params_unchanged(initial_params, final_params, error_msg_prefix):
for (path, initial), (_, final) in zip(
jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params)
):
assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}"

# Verify unused adapters (2, 3, 4) were not modified
final_adapter_2_params = get_adapter_params(lora_params, 2)
verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified")
@@ -217,11 +168,11 @@ def verify_params_unchanged(initial_params, final_params, error_msg_prefix):
verify_params_unchanged(initial_adapter_4_params, final_adapter_4_params, "Adapter 4 was modified")

# Verify out-of-rank params were not modified
final_adapter_0_out_of_rank = _get_out_of_rank_params(lora_params, 0, 16, num_experts)
final_adapter_0_out_of_rank = get_moe_out_of_rank_params(lora_params, 0, 16, num_experts)
verify_params_unchanged(
initial_adapter_0_out_of_rank, final_adapter_0_out_of_rank, "Adapter 0 out-of-rank params modified"
)
final_adapter_1_out_of_rank = _get_out_of_rank_params(lora_params, 1, 8, num_experts)
final_adapter_1_out_of_rank = get_moe_out_of_rank_params(lora_params, 1, 8, num_experts)
verify_params_unchanged(
initial_adapter_1_out_of_rank, final_adapter_1_out_of_rank, "Adapter 1 out-of-rank params modified"
)
23 changes: 2 additions & 21 deletions skyrl-tx/tests/models/test_llama3_lora_training.py
@@ -11,6 +11,8 @@
from tx.layers.lora import init_lora_adapter
from tx.tinker.types import LoraConfig

from tests.models.lora_test_utils import get_adapter_params, get_out_of_rank_params, verify_params_unchanged


def test_lora_training():
base_model = "unsloth/Llama-3.2-1B"
@@ -45,21 +47,6 @@ def loss_fn(model, input_ids, target_ids, attention_mask):
# that we want to compute gradients for
graphdef, lora_params, non_lora_params = nnx.split(model, model.is_lora_param, ...)

# Helper to extract adapter params at specific index
def get_adapter_params(params, adapter_idx):
return jax.tree.map(lambda p: p[adapter_idx].copy(), params)

# Helper to extract out-of-rank params for an adapter
def get_out_of_rank_params(params, adapter_idx, rank):
def slice_param(path, p):
if "lora_A" in str(path):
return p[adapter_idx, :, rank:].copy()
elif "lora_B" in str(path):
return p[adapter_idx, rank:, :].copy()
return p

return jax.tree.map_with_path(slice_param, params)

# Save initial states
initial_adapter_2_params = get_adapter_params(lora_params, 2)
initial_adapter_0_out_of_rank = get_out_of_rank_params(lora_params, 0, 16)
@@ -79,12 +66,6 @@ def loss_for_lora(lora_params):

print(f"Step {step}: loss = {float(loss):.4f}")

def verify_params_unchanged(initial_params, final_params, error_msg_prefix):
for (path, initial), (_, final) in zip(
jax.tree.leaves_with_path(initial_params), jax.tree.leaves_with_path(final_params)
):
assert jnp.allclose(initial, final), f"{error_msg_prefix} for {path}"

# Verify adapter 2 (unused) was not modified
final_adapter_2_params = get_adapter_params(lora_params, 2)
verify_params_unchanged(initial_adapter_2_params, final_adapter_2_params, "Adapter 2 was modified")