[tx] Support LoRA in the unembedding layer, redux#984
pcmoritz wants to merge 5 commits into NovaSky-AI:main
Conversation
Code Review
This pull request adds support for LoRA in the unembedding layer, which is a great enhancement. The implementation is clean, with the logic correctly handled through a transposed flag in apply_lora and a new implementation for LoRAEmbed.T. The accompanying tests are thorough, covering both the transposed projection itself and the consistency between the forward and transposed passes. I have one minor suggestion to make the test code more idiomatic.
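For readers skimming the thread, here is a minimal single-adapter sketch of the transposed math the review refers to. It is illustrative only: the actual code batches per-adapter weights (see the ragged_dot excerpt further down), which this toy version omits.

```python
import jax.numpy as jnp

def unembed_with_lora(hidden, embedding, lora_A, lora_B):
    # Toy shapes (single adapter): embedding [vocab, hidden],
    # lora_A [vocab, r], lora_B [r, hidden], hidden [batch, hidden].
    # The unembedding applies the adapted weight transposed:
    #   hidden @ (E + A @ B).T == hidden @ E.T + (hidden @ B.T) @ A.T
    base_logits = hidden @ embedding.T
    lora_logits = (hidden @ lora_B.T) @ lora_A.T  # keeps the rank-r bottleneck
    return base_logits + lora_logits
```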
raulchen left a comment
LGTM overall. Maybe also add an e2e test at the model level if possible.
    return lambda hidden_states, adapter_indices=None: hidden_states @ self.embedding[...].T

def project(hidden_states: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
    base_out = hidden_states @ self.embedding[...].T
nits:
- avoid capturing self.
- rename hidden_states to something more general, since this is the general LoRAMixin class (a rough sketch of both points is below).
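One possible shape for that refactor, sketched with hypothetical names (make_unembed_projection, x) rather than the PR's actual code:

```python
import jax
import jax.numpy as jnp

def make_unembed_projection(embedding_weights: jax.Array):
    # Bind the weight explicitly so the returned closure does not capture `self`.
    def project(x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
        # `x` rather than `hidden_states`, since the mixin is not specific to hidden states.
        del adapter_indices  # the base (non-LoRA) path ignores adapters
        return x @ embedding_weights.T

    return project

# Toy usage with stand-in shapes (vocab=10, hidden=4):
project = make_unembed_projection(jnp.ones((10, 4)))
logits = project(jnp.ones((2, 4)))  # shape (2, 10)
```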
else:
    # x @ A @ B (or A[x] @ B for embeddings via _apply_lora_weight override)
    intermediate = self._apply_lora_weight(self.lora_A[...], x_sorted, adapter_indices_sorted, group_sizes)
    lora_output_sorted = jax.lax.ragged_dot(intermediate, self.lora_B[...], group_sizes)
Since this is a mixin class, instead of the _apply_lora_weight abstraction, I feel it'd be cleaner to handle both the lookup-based and matmul-based paths in this class, and let subclasses choose which one to use with a flag.
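A hypothetical sketch of that flag-based design, simplified to a single adapter and without the ragged per-adapter batching the real code does; names are placeholders, not the actual tx LoRAMixin:

```python
import jax.numpy as jnp

class LoRAPathMixin:
    """Illustrative only; not the actual tx LoRAMixin."""

    # Subclasses set this instead of overriding a private helper:
    # True for embedding-style layers (A[x] @ B), False for dense layers (x @ A @ B).
    lora_a_is_lookup: bool = False

    def _lora_intermediate(self, lora_A, x):
        if self.lora_a_is_lookup:
            # Lookup path: `x` holds integer token ids, so rows of A are gathered.
            return jnp.take(lora_A, x, axis=0)
        # Matmul path: `x` holds activations.
        return x @ lora_A

class DenseLoRA(LoRAPathMixin):
    lora_a_is_lookup = False

class EmbedLoRA(LoRAPathMixin):
    lora_a_is_lookup = True
```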
This is a better version of #969, supporting LoRA in the unembedding layer as well.
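For context on the jax.lax.ragged_dot call shown in the diff excerpt above: it performs one matmul per contiguous group of rows, which is how the per-adapter lora_B matrices are applied. A toy example with made-up shapes, not the tx shapes:

```python
import jax
import jax.numpy as jnp

m, k, n, g = 6, 4, 8, 3  # 6 tokens, rank 4, hidden size 8, 3 adapters
lhs = jnp.ones((m, k))     # token rows sorted by adapter (the `intermediate` in the excerpt)
rhs = jnp.ones((g, k, n))  # one lora_B matrix per adapter
group_sizes = jnp.array([2, 3, 1], dtype=jnp.int32)  # rows 0-1 -> adapter 0, 2-4 -> adapter 1, 5 -> adapter 2

# Each contiguous block of rows in `lhs` is multiplied by its adapter's matrix in `rhs`.
out = jax.lax.ragged_dot(lhs, rhs, group_sizes)
print(out.shape)  # (6, 8)
```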