135 changes: 135 additions & 0 deletions skyrl-tx/tests/layers/test_lora.py
@@ -0,0 +1,135 @@
import jax
import jax.numpy as jnp
from flax import nnx

from tx.layers.lora import LoRAEmbed


def test_lora_embed_transposed():
"""Test that LoRAEmbed.T correctly applies LoRA adapters with scaling."""
vocab_size = 100
features = 32
max_lora_adapters = 2
max_lora_rank = 4
batch_size = 2
seq_len = 5

# Use realistic alpha/rank scaling values (e.g., alpha=16, rank=4 -> scale=4.0)
lora_scaling_val = jnp.array([4.0, 2.0], dtype=jnp.float32)

mesh = jax.make_mesh(
(1,),
("dp",),
axis_types=(jax.sharding.AxisType.Auto,),
)
with jax.set_mesh(mesh):
embed = LoRAEmbed(
num_embeddings=vocab_size,
features=features,
max_lora_adapters=max_lora_adapters,
max_lora_rank=max_lora_rank,
dtype=jnp.float32,
embedding_init=nnx.with_partitioning(nnx.initializers.normal(0.02), (None, None)),
rngs=nnx.Rngs(0),
)

# Set known LoRA weights for testing
# lora_A: (adapters, vocab_size, rank)
# lora_B: (adapters, rank, features)
lora_A_val = jnp.ones((max_lora_adapters, vocab_size, max_lora_rank)) * 0.1
lora_B_val = jnp.ones((max_lora_adapters, max_lora_rank, features)) * 0.2
embed.lora_A[...] = lora_A_val
embed.lora_B[...] = lora_B_val
embed.lora_scaling[...] = lora_scaling_val

# Test input
hidden_states = jax.random.normal(jax.random.key(42), (batch_size, seq_len, features))
adapter_indices = jnp.array([0, 1], dtype=jnp.int32)

# Get the transposed projection callable
project = embed.T

# Output without LoRA (adapter_indices=None)
base_output = project(hidden_states, adapter_indices=None)
expected_base = hidden_states @ embed.embedding[...].T
assert jnp.allclose(base_output, expected_base), "Base output without LoRA should match"

# Output with LoRA
lora_output = project(hidden_states, adapter_indices=adapter_indices)

# Verify the math: lora_contribution = (hidden_states @ B.T @ A.T) * scale
# For each sample in batch, use its adapter's weights and scaling
for i in range(batch_size):
adapter_idx = adapter_indices[i]
h = hidden_states[i] # (seq_len, features)
lora_B_T = lora_B_val[adapter_idx].T # (features, rank)
lora_A_T = lora_A_val[adapter_idx].T # (rank, vocab_size)
scale = lora_scaling_val[adapter_idx]
expected_lora_contribution = (h @ lora_B_T @ lora_A_T) * scale # (seq_len, vocab_size)
expected_total = expected_base[i] + expected_lora_contribution

assert jnp.allclose(lora_output[i], expected_total, atol=1e-5), f"LoRA math incorrect for batch {i}"


def test_lora_embed_forward_and_transposed_consistency():
"""Test that forward and transposed LoRA use the same weights and scaling correctly."""
vocab_size = 50
features = 16
max_lora_adapters = 1
max_lora_rank = 4
batch_size = 1
seq_len = 3

# Use a non-trivial scaling value to ensure it's properly tested
lora_scaling_val = jnp.array([2.5], dtype=jnp.float32)

mesh = jax.make_mesh(
(1,),
("dp",),
axis_types=(jax.sharding.AxisType.Auto,),
)
with jax.set_mesh(mesh):
embed = LoRAEmbed(
num_embeddings=vocab_size,
features=features,
max_lora_adapters=max_lora_adapters,
max_lora_rank=max_lora_rank,
dtype=jnp.float32,
embedding_init=nnx.with_partitioning(nnx.initializers.normal(0.02), (None, None)),
rngs=nnx.Rngs(0),
)

# Set LoRA weights and scaling
lora_A_val = jax.random.normal(jax.random.key(1), (max_lora_adapters, vocab_size, max_lora_rank)) * 0.1
lora_B_val = jax.random.normal(jax.random.key(2), (max_lora_adapters, max_lora_rank, features)) * 0.1
embed.lora_A[...] = lora_A_val
embed.lora_B[...] = lora_B_val
embed.lora_scaling[...] = lora_scaling_val

adapter_indices = jnp.array([0], dtype=jnp.int32)
scale = lora_scaling_val[0]

# Forward pass: token_ids -> embeddings
token_ids = jnp.array([[5, 10, 15]], dtype=jnp.int32)
forward_output = embed(token_ids, adapter_indices=adapter_indices)

# Expected forward: base_embed + (A[token_ids] @ B) * scale
base_embed = embed.embedding[...][token_ids] # (1, 3, features)
lora_A_lookup = lora_A_val[0, token_ids[0], :] # (3, rank)
forward_lora_contribution = (lora_A_lookup @ lora_B_val[0]) * scale # (3, features)
expected_forward = base_embed + forward_lora_contribution

assert jnp.allclose(forward_output, expected_forward, atol=1e-5), "Forward LoRA incorrect"

# Transposed pass: hidden_states -> logits
hidden_states = jax.random.normal(jax.random.key(3), (batch_size, seq_len, features))
transposed_output = embed.T(hidden_states, adapter_indices=adapter_indices)

# Expected transposed: hidden @ embed.T + (hidden @ B.T @ A.T) * scale
base_transposed = hidden_states @ embed.embedding[...].T
lora_B_T = lora_B_val[0].T # (features, rank)
lora_A_T = lora_A_val[0].T # (rank, vocab_size)
transposed_lora_contribution = (hidden_states @ lora_B_T @ lora_A_T) * scale
expected_transposed = base_transposed + transposed_lora_contribution

assert jnp.allclose(transposed_output, expected_transposed, atol=1e-5), "Transposed LoRA incorrect"
29 changes: 19 additions & 10 deletions skyrl-tx/tx/layers/lora.py
@@ -81,6 +81,8 @@ def apply_lora(
x: jax.Array,
base_output: jax.Array,
adapter_indices: jax.Array | None,
*,
transposed: bool = False,
) -> jax.Array:
if self.max_lora_adapters == 0 or adapter_indices is None:
return base_output
@@ -99,9 +101,17 @@
x_flat, adapter_indices_expanded, self.max_lora_adapters, adapter_indices=adapter_indices_expanded
)

# Apply LoRA: x @ A @ B (or A[x] @ B for embeddings)
intermediate = self._apply_lora_weight(self.lora_A[...], x_sorted, adapter_indices_sorted, group_sizes)
lora_output_sorted = jax.lax.ragged_dot(intermediate, self.lora_B[...], group_sizes)
# Apply LoRA computation
if transposed:
# x @ B.T @ A.T (always linear matmul - can't lookup with continuous hidden states)
lora_B_T = self.lora_B[...].transpose((0, 2, 1))
lora_A_T = self.lora_A[...].transpose((0, 2, 1))
intermediate = jax.lax.ragged_dot(x_sorted, lora_B_T, group_sizes)
lora_output_sorted = jax.lax.ragged_dot(intermediate, lora_A_T, group_sizes)
else:
# x @ A @ B (or A[x] @ B for embeddings via _apply_lora_weight override)
intermediate = self._apply_lora_weight(self.lora_A[...], x_sorted, adapter_indices_sorted, group_sizes)
lora_output_sorted = jax.lax.ragged_dot(intermediate, self.lora_B[...], group_sizes)
Contributor:

Since this is a Mixin class, instead of the _apply_lora_weight abstraction, I feel it'd be cleaner to handle both lookup-based and matmul-based paths in this class, and let subclasses choose which one to use with a flag.
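For illustration, one shape the suggested refactor could take, written here as a standalone helper rather than a mixin method. The function name, the `use_lookup` flag, and the argument names mirror the PR but are assumptions, and the sort/unsort plumbing from `apply_lora` is omitted:

```python
import jax


def lora_intermediate(
    lora_A: jax.Array,                  # (adapters, in_dim_or_vocab, rank)
    x_sorted: jax.Array,                # int token ids (m,) if use_lookup, else activations (m, in_dim)
    adapter_indices_sorted: jax.Array,  # (m,) adapter id per row, already grouped by adapter
    group_sizes: jax.Array,             # (adapters,) number of rows belonging to each adapter
    *,
    use_lookup: bool,
) -> jax.Array:
    """First LoRA stage: A[x] for embedding layers, x @ A for linear layers."""
    if use_lookup:
        # Embedding-style path: gather the A row for each (adapter, token) pair -> (m, rank)
        return lora_A[adapter_indices_sorted, x_sorted]
    # Linear-style path: grouped matmul with one A matrix per adapter group -> (m, rank)
    return jax.lax.ragged_dot(x_sorted, lora_A, group_sizes)
```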


# Unsort, reshape, scale
lora_output = lora_output_sorted[unsort_indices].reshape(batch_size, seq_len, -1)
@@ -169,8 +179,12 @@ def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
@property
def T(self):
"""Return a callable that projects hidden states back to vocabulary space."""
# TODO: Apply lora adapters here as well
return lambda hidden_states, adapter_indices=None: hidden_states @ self.embedding[...].T

def project(hidden_states: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
base_out = hidden_states @ self.embedding[...].T
Contributor:

nits:

  • avoid capturing self.
  • rename hidden_states to something more general, since this is the general LoRAMixin class.
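A possible revision of `project` along those lines, shown as a fragment of the property from the PR (illustrative only: the embedding value and the bound `apply_lora` are read once outside the closure so the inner function does not touch `self` directly, and the argument is renamed to the generic `x`; reading the embedding eagerly assumes `T` is re-accessed on each call):

```python
@property
def T(self):
    """Return a callable that projects inputs back to vocabulary space."""
    embedding = self.embedding[...]  # materialize the weight once, outside the closure
    apply_lora = self.apply_lora     # bound method; the closure itself no longer references self

    def project(x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
        base_out = x @ embedding.T
        return apply_lora(x, base_out, adapter_indices, transposed=True)

    return project
```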

return self.apply_lora(hidden_states, base_out, adapter_indices, transposed=True)

return project


class LoRALinear(LoRAMixin, nnx.Linear):
@@ -323,11 +337,6 @@ def init_lora_adapter(model: ModelForCausalLM, adapter_index: int, lora_config:
adapter_index: Index of the adapter to initialize
lora_config: LoraConfig object containing rank, alpha, seed, and training flags
"""
if lora_config.train_unembed and getattr(model.config, "tie_word_embeddings", False):
raise ValueError(
"train_unembed=True is incompatible with tie_word_embeddings=True. "
"Tied embeddings use embed_tokens.T which does not support LoRA."
)
rngs = nnx.Rngs(lora_config.seed)
state = nnx.state(model)
