-
Notifications
You must be signed in to change notification settings - Fork 247
[tx] Implement stacked weights #1018
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pcmoritz
wants to merge
18
commits into
NovaSky-AI:main
Choose a base branch
from
pcmoritz:tx-stacked-layers
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+277
−79
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
7653b1c
[tx] Implement stacked layers
pcmoritz 320681c
add file
pcmoritz 169ec5a
fix
pcmoritz 872fcf0
update
pcmoritz e3c3ecd
update
pcmoritz 2336a08
fix ruff
pcmoritz 526efa2
update
pcmoritz 26d9a43
update
pcmoritz 3751008
update
pcmoritz 3f2879d
update
pcmoritz 52fcccf
update
pcmoritz 851dd0f
update
pcmoritz c2cdecb
cleanup
pcmoritz 7cfa898
update
pcmoritz e2352d3
fix test
pcmoritz de2229e
black
pcmoritz ca212c1
update
pcmoritz 27e0523
update
pcmoritz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,181 @@ | ||
| """StackedDecoderLayers module for efficient transformer layer stacking.""" | ||
|
|
||
| from typing import Callable | ||
|
|
||
| from flax import nnx | ||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| from tx.utils.generator import KVCache | ||
|
|
||
|
|
||
| class ArrayRef(nnx.Variable): | ||
| """A Variable providing a view into an indexed slice of a parent Variable.""" | ||
|
|
||
| def __init__(self, parent: nnx.Variable, idx: int): | ||
| super().__init__(parent[idx]) | ||
| self.set_metadata("_parent", parent) | ||
| self.set_metadata("_idx", idx) | ||
|
|
||
| def __getitem__(self, key): | ||
| parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") | ||
| return parent[idx] if key is Ellipsis else parent[idx][key] | ||
|
|
||
| def set_raw_value(self, value, **kwargs): | ||
| """Write through to parent when value is set.""" | ||
| parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx") | ||
| parent[...] = parent[...].at[idx].set(value) | ||
| super().set_raw_value(value, **kwargs) | ||
|
|
||
| @property | ||
| def shape(self): | ||
| return self.get_metadata("_parent")[self.get_metadata("_idx")].shape | ||
|
|
||
|
|
||
| class StackedDecoderLayers(nnx.Module): | ||
| """Decoder layers with stacked weights created via nnx.vmap. | ||
|
|
||
| Parameters are stored in stacked format (num_layers, ...). The forward pass | ||
| uses jax.lax.scan for training/prefill and Python loops for decode. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| create_layer_fn: Callable[[nnx.Rngs], nnx.Module], | ||
| num_layers: int, | ||
| rngs: nnx.Rngs, | ||
| ): | ||
| self.num_layers = num_layers | ||
|
|
||
| @nnx.split_rngs(splits=num_layers) | ||
| @nnx.vmap(in_axes=(0,), out_axes=0, transform_metadata={nnx.PARTITION_NAME: None}) | ||
| def vmapped_create(rngs: nnx.Rngs) -> nnx.Module: | ||
| return create_layer_fn(rngs) | ||
|
|
||
| self._stacked = vmapped_create(rngs) | ||
|
|
||
| def __len__(self) -> int: | ||
| """Return the number of layers.""" | ||
| return self.num_layers | ||
|
|
||
| def __getitem__(self, index: int) -> nnx.Module: | ||
| """Get view into layer at index. Only for tests and weight loading.""" | ||
| if index < 0 or index >= self.num_layers: | ||
| raise IndexError(f"Layer index {index} out of range [0, {self.num_layers})") | ||
| graphdef, state = nnx.split(self._stacked) | ||
| layer_state = jax.tree.map( | ||
| lambda x: ArrayRef(x, index), | ||
| state, | ||
| is_leaf=lambda x: isinstance(x, nnx.Variable), | ||
| ) | ||
| return nnx.merge(graphdef, layer_state) | ||
|
|
||
| def __iter__(self): | ||
| """Iterate over individual layers (for testing/weight loading).""" | ||
| for i in range(self.num_layers): | ||
| yield self[i] | ||
|
|
||
| def __call__( | ||
| self, | ||
| hidden_states: jax.Array, | ||
| *, | ||
| attention_mask: jax.Array, | ||
| positions: jax.Array, | ||
| adapter_indices: jax.Array | None, | ||
| kv_cache: KVCache | None, | ||
| output_hidden_states: bool, | ||
| gradient_checkpointing: bool, | ||
| is_training: bool = False, | ||
| ) -> tuple[jax.Array, list[jax.Array], KVCache | None]: | ||
| """Forward pass through all layers. | ||
|
|
||
| Uses scan for training/prefill, Python loop for decode. | ||
|
|
||
| Returns: | ||
| (final_hidden_states, all_hidden_states, kv_cache) | ||
| """ | ||
| graphdef, state = nnx.split(self._stacked) | ||
|
|
||
| # Decode mode: use Python loop | ||
| if kv_cache is not None: | ||
| all_hidden_states = [] | ||
| new_keys, new_values = [], [] | ||
|
|
||
| for i in range(self.num_layers): | ||
| if output_hidden_states: | ||
| all_hidden_states.append(hidden_states) | ||
|
|
||
| layer = nnx.merge(graphdef, jax.tree.map(lambda x, i=i: x[i], state)) | ||
| hidden_states, (k, v) = layer( | ||
| hidden_states, | ||
| attention_mask=attention_mask, | ||
| positions=positions, | ||
| adapter_indices=adapter_indices, | ||
| kv_cache=(kv_cache.keys[i], kv_cache.values[i]), | ||
| ) | ||
| new_keys.append(k) | ||
| new_values.append(v) | ||
|
|
||
| return ( | ||
| hidden_states, | ||
| all_hidden_states, | ||
| KVCache( | ||
| keys=new_keys, | ||
| values=new_values, | ||
| cache_position=kv_cache.cache_position + positions.shape[1], | ||
| ), | ||
| ) | ||
|
|
||
| # Training/prefill mode: use scan | ||
| def body_fn(hs, layer_params): | ||
| layer = nnx.merge(graphdef, layer_params) | ||
| new_hs, (k, v) = layer( | ||
| hs, | ||
| attention_mask=attention_mask, | ||
| positions=positions, | ||
| adapter_indices=adapter_indices, | ||
| kv_cache=None, | ||
| ) | ||
| if is_training: | ||
| k = v = None | ||
| return new_hs, (new_hs if output_hidden_states else None, k, v) | ||
|
|
||
| if gradient_checkpointing: | ||
| body_fn = jax.checkpoint(body_fn) | ||
|
|
||
| final_hs, (all_hs, all_keys, all_values) = jax.lax.scan(body_fn, hidden_states, state) | ||
|
|
||
| all_hidden_states = [hidden_states] + list(all_hs[:-1]) if output_hidden_states else [] | ||
|
|
||
| if is_training: | ||
| return final_hs, all_hidden_states, None | ||
|
|
||
| return ( | ||
| final_hs, | ||
| all_hidden_states, | ||
| KVCache( | ||
| keys=[all_keys[i] for i in range(self.num_layers)], | ||
| values=[all_values[i] for i in range(self.num_layers)], | ||
| cache_position=attention_mask.sum(axis=1).astype(jnp.int32), | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| def unstack_state(module: nnx.Module) -> nnx.GraphState: | ||
| """Transform stacked layer state to unstacked ArrayRef views. | ||
|
|
||
| Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc. | ||
| Each entry is an ArrayRef that writes through to the original stacked variable. | ||
| """ | ||
| expanded = [] | ||
| for path, var in nnx.to_flat_state(nnx.state(module)): | ||
| if "_stacked" not in path: | ||
| expanded.append((path, var)) | ||
| continue | ||
|
|
||
| idx = path.index("_stacked") | ||
| for i in range(var[...].shape[0]): | ||
| new_path = path[:idx] + (str(i),) + path[idx + 1 :] | ||
| expanded.append((new_path, ArrayRef(var, i))) | ||
|
|
||
| return nnx.from_flat_state(expanded) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably the easiest way to implement DeepSeekV3 is to implement
DualStackedDecoderLayerswhich has twoStackedDecoderLayersas members and the same interface asStackedDecoderLayers(modulo the constructor which can take twocreate_layer_fnfunctions and takes their respective numbers as arguments). This could be a separate PR.