
[tx] Support LoRA in the unembedding layer, redux #984

Open

pcmoritz wants to merge 5 commits into NovaSky-AI:main from pcmoritz:tx-lora-unembed-2

Conversation

@pcmoritz (Collaborator)

This is a better version of #969, adding support for LoRA in the unembedding layer as well.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively adds support for LoRA in the unembedding layer, which is a great enhancement. The implementation is clean, with the logic correctly handled through a transposed flag in apply_lora and a new implementation for LoRAEmbed.T. The accompanying tests are thorough, covering both the transposed projection and ensuring consistency between the forward and transposed passes. I have one minor suggestion to make the test code more idiomatic.
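
For illustration, here is a minimal sketch, in plain JAX with stand-in arrays rather than the PR's actual LoRAEmbed module, of the consistency property such tests can check: the forward (lookup) pass and the transposed (unembedding) pass should both agree with projecting against the merged matrix E + A @ B directly. All array names below are illustrative, not the repository's real attributes.

import jax
import jax.numpy as jnp

vocab, dim, rank = 7, 4, 2
kE, kA, kB, kH = jax.random.split(jax.random.PRNGKey(0), 4)
E = jax.random.normal(kE, (vocab, dim))    # tied embedding weight
A = jax.random.normal(kA, (vocab, rank))   # LoRA A (lookup side)
B = jax.random.normal(kB, (rank, dim))     # LoRA B
tokens = jnp.array([1, 3, 5])
hidden = jax.random.normal(kH, (3, dim))

merged = E + A @ B
forward = E[tokens] + A[tokens] @ B                 # embedding lookup + LoRA
transposed = hidden @ E.T + (hidden @ B.T) @ A.T    # unembedding projection + LoRA

assert jnp.allclose(forward, merged[tokens], atol=1e-5)
assert jnp.allclose(transposed, hidden @ merged.T, atol=1e-5)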

@raulchen (Contributor) left a comment

LGTM overall. Maybe also add an e2e test at the model level, if possible.

return lambda hidden_states, adapter_indices=None: hidden_states @ self.embedding[...].T

def project(hidden_states: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
    base_out = hidden_states @ self.embedding[...].T
A contributor commented on the code above:

nits:

  • avoid capturing self.
  • rename hidden_states to something more general, since this is the general LoRAMixin class.
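
For concreteness, here is a minimal sketch of the quoted closure with both nits applied; the factory name make_transposed_projection and the argument name inputs are assumptions for illustration, not names from the PR. The embedding weight is materialised once and closed over directly instead of capturing self, and the input gets a layer-agnostic name.

import jax
import jax.numpy as jnp

def make_transposed_projection(embedding: jax.Array):
    weight = embedding  # (vocab, dim); read once, no `self` capture below

    def project(inputs: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
        # `inputs` rather than `hidden_states`: the mixin is layer-agnostic.
        base_out = inputs @ weight.T
        # ...the per-adapter LoRA delta would be added here (omitted in this sketch)...
        return base_out

    return project

project = make_transposed_projection(jnp.ones((11, 8)))
print(project(jnp.ones((3, 8))).shape)  # (3, 11)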

else:
    # x @ A @ B (or A[x] @ B for embeddings via _apply_lora_weight override)
    intermediate = self._apply_lora_weight(self.lora_A[...], x_sorted, adapter_indices_sorted, group_sizes)
    lora_output_sorted = jax.lax.ragged_dot(intermediate, self.lora_B[...], group_sizes)
A contributor commented on the code above:

Since this is a mixin class, instead of the _apply_lora_weight abstraction, I feel it'd be cleaner to handle both the lookup-based and matmul-based paths in this class, and let subclasses choose which one to use with a flag.
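
A rough sketch of what that flag-based refactor could look like, under assumed class and attribute names (LoRADeltaSketch, lora_a_is_lookup) that are not the repository's actual API: the mixin implements both the lookup-based and matmul-based LoRA-A paths itself, and a subclass picks one with a class-level flag instead of overriding a hook.

import jax
import jax.numpy as jnp

class LoRADeltaSketch:
    lora_a_is_lookup: bool = False   # embedding subclasses would set this to True

    def __init__(self, lora_A, lora_B):
        self.lora_A = lora_A         # lookup path: (adapters, vocab, r); matmul path: (adapters, in_dim, r)
        self.lora_B = lora_B         # (adapters, r, out_dim)

    def lora_delta(self, x_sorted, adapter_indices_sorted, group_sizes):
        if self.lora_a_is_lookup:
            # A[x]: gather the per-adapter LoRA-A row for each token id.
            intermediate = self.lora_A[adapter_indices_sorted, x_sorted]          # (tokens, r)
        else:
            # x @ A as a grouped (per-adapter) matmul.
            intermediate = jax.lax.ragged_dot(x_sorted, self.lora_A, group_sizes)
        # (x @ A) @ B, again grouped per adapter.
        return jax.lax.ragged_dot(intermediate, self.lora_B, group_sizes)

class EmbedDeltaSketch(LoRADeltaSketch):
    lora_a_is_lookup = True

# Tiny usage check: one adapter, three tokens.
adapters, vocab, dim, r = 1, 7, 4, 2
emb = EmbedDeltaSketch(jnp.ones((adapters, vocab, r)), jnp.ones((adapters, r, dim)))
delta = emb.lora_delta(jnp.array([0, 3, 5]),                  # token ids
                       jnp.array([0, 0, 0]),                  # adapter per token
                       jnp.array([3], dtype=jnp.int32))       # group sizes
print(delta.shape)  # (3, 4)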
