[WIP] [tx] Implement stacked weights #1018
Conversation
Code Review
This pull request introduces StackedDecoderLayers to optimize the transformer forward pass using nnx.vmap and jax.lax.scan, which is a significant performance improvement for training and prefill. The changes are well-encapsulated, making the model code cleaner and more efficient. However, I've identified a critical bug in the load_safetensors utility related to how parameter paths are handled, which would prevent it from loading weights correctly for certain layer types.
skyrl-tx/tx/utils/models.py
```python
updates = []
for path, param in nnx.to_flat_state(nnx.state(module)):
    if filter_fn is not None and not filter_fn(path):
        continue
    key = key_prefix + get_param_key(path)
    if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path):
        continue
    if "experts" in path:
        tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
    else:
        tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
    if len(path) >= 2 and path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
        tensor = tensor.reshape(param.shape)
    assert param.shape == tensor.shape, f"shape mismatch for {key}"
    updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding)))
nnx.update(module, nnx.from_flat_state(updates))
```
The checks for parameter types like "lora_A" in path or "experts" in path are incorrect. The path variable is a tuple of nnx.path.PathEntry objects, not strings, so these checks will always evaluate to False. This will prevent weights for LoRA, experts, and projections from being loaded correctly.
To fix this, you should convert the path to a tuple of strings before performing these checks. This will ensure that the logic correctly identifies the parameter types and applies the appropriate loading logic.
Suggested change:
```python
updates = []
for path, param in nnx.to_flat_state(nnx.state(module)):
    path_str_tuple = tuple(map(str, path))
    if filter_fn is not None and not filter_fn(path):
        continue
    key = key_prefix + get_param_key(path)
    if skip_lora and any(p in path_str_tuple for p in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")):
        continue
    if "experts" in path_str_tuple:
        tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
    else:
        tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
    if len(path_str_tuple) >= 2 and path_str_tuple[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
        tensor = tensor.reshape(param.shape)
    assert param.shape == tensor.shape, f"shape mismatch for {key}"
    updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding)))
nnx.update(module, nnx.from_flat_state(updates))
```
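To illustrate the failure mode, here is a minimal, self-contained sketch. The `PathEntry` class below is a hypothetical stand-in for a non-string path entry (not the actual nnx type); it only shows why `in` membership tests against a tuple of such objects never match a plain string:

```python
class PathEntry:
    """Hypothetical stand-in for a non-string path entry wrapping a key."""

    def __init__(self, key):
        self.key = key

    def __str__(self):
        return str(self.key)


path = (PathEntry("layers"), PathEntry("experts"), PathEntry("kernel"))

# Membership compares the string against PathEntry objects, so it never matches:
print("experts" in path)  # False

# Converting the entries to strings first makes the check behave as intended:
path_str_tuple = tuple(map(str, path))
print("experts" in path_str_tuple)  # True
```

This is why the suggested fix computes `path_str_tuple` once at the top of the loop and runs all string membership and indexing checks against it.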
```python
        return self.num_layers

    def __getitem__(self, index: int) -> nnx.Module:
        """Get view into layer at index (stays synced with stacked state)."""
```
Add to the docstring that this should only be used for tests and weight loading
```python
        return self.get_metadata("_parent")[self.get_metadata("_idx")].shape


class StackedDecoderLayers(nnx.Module):
```
Probably the easiest way to implement DeepSeekV3 is to add a DualStackedDecoderLayers that has two StackedDecoderLayers as members and exposes the same interface as StackedDecoderLayers (modulo the constructor, which can take two create_layer_fn functions and their respective layer counts as arguments). This could be a separate PR.
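A rough sketch of that delegation pattern in plain Python, with `FakeStackedLayers` as a hypothetical stand-in (the real implementation would subclass `nnx.Module` and wrap two actual `StackedDecoderLayers`; names and signatures here are illustrative assumptions, not the repo's API):

```python
class FakeStackedLayers:
    """Hypothetical stand-in for StackedDecoderLayers."""

    def __init__(self, create_layer_fn, num_layers):
        self.layers = [create_layer_fn(i) for i in range(num_layers)]

    def __len__(self):
        return len(self.layers)

    def __getitem__(self, index):
        return self.layers[index]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class DualStackedLayers:
    """Two stacks behind the same interface as a single stack.

    Indices in [0, len(first)) resolve into the first stack, the rest
    into the second, so callers can treat the pair as one contiguous
    run of layers.
    """

    def __init__(self, create_first_fn, num_first, create_second_fn, num_second):
        self.first = FakeStackedLayers(create_first_fn, num_first)
        self.second = FakeStackedLayers(create_second_fn, num_second)

    def __len__(self):
        return len(self.first) + len(self.second)

    def __getitem__(self, index):
        if index < len(self.first):
            return self.first[index]
        return self.second[index - len(self.first)]

    def __call__(self, x):
        # Run the first stack, then feed its output into the second.
        return self.second(self.first(x))
```

For example, `DualStackedLayers(f, 3, g, 2)` behaves like a single five-layer stack: `len` returns 5, indices 3 and 4 resolve into the second stack, and calling it chains both stacks.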
This is based on all the great work that @raulchen did in #996 and #906; it also fixes the performance regression in decoding relative to the main branch.