Conversation

@pcmoritz (Collaborator) commented on Feb 4, 2026

This builds on all the great work that @raulchen did in #996 and #906; it also fixes the performance regression in decoding relative to the main branch.

@pcmoritz added the tx label on Feb 4, 2026
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces StackedDecoderLayers to optimize the transformer forward pass using nnx.vmap and jax.lax.scan, which is a significant performance improvement for training and prefill. The changes are well-encapsulated, making the model code cleaner and more efficient. However, I've identified a critical bug in the load_safetensors utility related to how parameter paths are handled, which would prevent it from loading weights correctly for certain layer types.
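For readers new to the pattern, here is a minimal, self-contained sketch of the vmap-create / scan-apply idea described above. This is not the PR's actual StackedDecoderLayers: Block, the dimensions, and the split/merge plumbing are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
    """Toy stand-in for a decoder layer."""
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return x + jax.nn.gelu(self.linear(x))

num_layers, dim = 4, 128

# Stack num_layers copies of Block: nnx.vmap over the constructor
# yields one module whose parameters carry a leading layer axis.
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_stacked(rngs: nnx.Rngs):
    return Block(dim, rngs=rngs)

stacked = create_stacked(nnx.Rngs(0))

# Forward pass: split into graphdef + state, then lax.scan slices one
# layer's parameters per step, so only a single layer body is traced
# instead of num_layers unrolled copies.
graphdef, state = nnx.split(stacked)

def body(x, layer_state):
    layer = nnx.merge(graphdef, layer_state)
    return layer(x), None

x = jnp.ones((8, dim))
y, _ = jax.lax.scan(body, x, state)
```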

Comment on lines 118 to 133
```python
updates = []
for path, param in nnx.to_flat_state(nnx.state(module)):
    if filter_fn is not None and not filter_fn(path):
        continue
    key = key_prefix + get_param_key(path)
    if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path):
        continue
    if "experts" in path:
        tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
    else:
        tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
    if len(path) >= 2 and path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
        tensor = tensor.reshape(param.shape)
    assert param.shape == tensor.shape, f"shape mismatch for {key}"
    updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding)))
nnx.update(module, nnx.from_flat_state(updates))
```
@gemini-code-assist bot (Contributor):

critical

The checks for parameter types like "lora_A" in path or "experts" in path are incorrect. The path variable is a tuple of nnx.path.PathEntry objects, not strings, so these checks will always evaluate to False. This will prevent weights for LoRA, experts, and projections from being loaded correctly.

To fix this, convert the path to a tuple of strings before performing these checks, so that each parameter type is correctly identified and loaded.

Suggested change

```diff
 updates = []
 for path, param in nnx.to_flat_state(nnx.state(module)):
+    path_str_tuple = tuple(map(str, path))
     if filter_fn is not None and not filter_fn(path):
         continue
     key = key_prefix + get_param_key(path)
-    if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path):
+    if skip_lora and any(p in path_str_tuple for p in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")):
         continue
-    if "experts" in path:
+    if "experts" in path_str_tuple:
         tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
     else:
         tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
-    if len(path) >= 2 and path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
+    if len(path_str_tuple) >= 2 and path_str_tuple[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
         tensor = tensor.reshape(param.shape)
     assert param.shape == tensor.shape, f"shape mismatch for {key}"
     updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding)))
 nnx.update(module, nnx.from_flat_state(updates))
```
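The failure mode the review describes is easy to reproduce in isolation. The snippet below is purely illustrative (PathEntry is a hypothetical stand-in, not the actual flax type): tuple membership compares elements with `==`, so a str never matches a non-str entry even if it prints identically, and `tuple(map(str, path))` restores the check.

```python
class PathEntry:  # hypothetical stand-in for a structured path key
    def __init__(self, name: str):
        self.name = name
    def __str__(self):
        return self.name

path = (PathEntry("layers"), PathEntry("q_proj"), PathEntry("kernel"))
print("q_proj" in path)                   # False: str != PathEntry
print("q_proj" in tuple(map(str, path)))  # True after normalizing to str
```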

```python
        return self.num_layers

    def __getitem__(self, index: int) -> nnx.Module:
        """Get view into layer at index (stays synced with stacked state)."""
```
@pcmoritz (Collaborator, Author):

Add to the docstring that this should only be used for tests and weight loading
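For example, the docstring could read (wording illustrative):

```python
def __getitem__(self, index: int) -> nnx.Module:
    """Get a view into the layer at `index` (stays synced with stacked state).

    Note: intended only for tests and weight loading; the forward pass
    should go through the scanned stack instead.
    """
```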

```python
        return self.get_metadata("_parent")[self.get_metadata("_idx")].shape


class StackedDecoderLayers(nnx.Module):
```
@pcmoritz (Collaborator, Author):

Probably the easiest way to implement DeepSeekV3 is a DualStackedDecoderLayers that holds two StackedDecoderLayers as members and exposes the same interface as StackedDecoderLayers (modulo the constructor, which would take two create_layer_fn functions along with their respective layer counts). This could be a separate PR.
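To make the suggestion concrete, here is a rough sketch of the shape such a wrapper could take. It assumes the constructor form described above plus the `__len__`/`__getitem__` interface visible in this PR; every name and signature is illustrative, not a final design.

```python
from typing import Callable
from flax import nnx

class DualStackedDecoderLayers(nnx.Module):
    """Two stacked groups run back to back (e.g. dense layers, then MoE
    layers), exposing the same interface as one StackedDecoderLayers."""

    def __init__(self, create_layer_fn_a: Callable, num_layers_a: int,
                 create_layer_fn_b: Callable, num_layers_b: int):
        # Assumes a constructor shaped like
        # StackedDecoderLayers(create_layer_fn, num_layers).
        self.stack_a = StackedDecoderLayers(create_layer_fn_a, num_layers_a)
        self.stack_b = StackedDecoderLayers(create_layer_fn_b, num_layers_b)

    def __call__(self, x, *args, **kwargs):
        # Run the two scanned stacks sequentially.
        x = self.stack_a(x, *args, **kwargs)
        return self.stack_b(x, *args, **kwargs)

    def __len__(self) -> int:
        return len(self.stack_a) + len(self.stack_b)

    def __getitem__(self, index: int) -> nnx.Module:
        # Same tests/weight-loading escape hatch, addressed by global index.
        if index < len(self.stack_a):
            return self.stack_a[index]
        return self.stack_b[index - len(self.stack_a)]
```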
