Skip to content

Correctly generate state dict in MultiTokenPredictionBlock #3624

Open
asolergi-nv wants to merge 2 commits intoNVIDIA:mainfrom
asolergi-nv:fix_mtp_sharded_statedict
Open

Correctly generate state dict in MultiTokenPredictionBlock #3624
asolergi-nv wants to merge 2 commits intoNVIDIA:mainfrom
asolergi-nv:fix_mtp_sharded_statedict

Conversation

@asolergi-nv
Copy link
Contributor

What does this PR do ?

MultiTokenPredictionBlock has a torch.nn.ModuleList that contains all the layers, like TransformerBlock & MambaStack. The correct way to generate the sharded state dict is by calling the sharded_state_dict method layer by layer, since the ModuleList container does not have this method & will directly call the state_dict method of each layer.

Before this fix we were generating the state_dict with some components coming from the sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) call that called the state_dict method from the ModuleList and other parts from the per layer sharded_state_dict method of each layer.

TransformerBlock state dict generation

sharded_state_dict = {}
layer_prefix = f'{prefix}layers.'
num_layers = self.config.num_layers
for layer in self.layers:
offset = get_transformer_layer_offset(
self.config, self.vp_stage, get_pg_rank(self.pg_collection.pp)
)
global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1
state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.' # module list index in TransformerBlock # pylint: disable=line-too-long
if non_homogeneous_layers:
sharded_prefix = f'{layer_prefix}{global_layer_offset}.'
sharded_pp_offset = []
else:
sharded_prefix = layer_prefix
sharded_pp_offset = [
(0, global_layer_offset, num_layers)
] # PP sharding offset for ShardedTensors
layer_sharded_state_dict = layer.sharded_state_dict(
state_dict_prefix, sharded_pp_offset, metadata
)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)

MambaStack state dict generation

sharded_state_dict = {}
layer_prefix = f'{prefix}layers.'
for local_layer_idx, layer in enumerate(self.layers):
global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1
state_dict_prefix = (
f'{layer_prefix}{local_layer_idx}.' # module list index in MambaBlock
)
sharded_prefix = f'{layer_prefix}{global_layer_offset}.'
sharded_pp_offset = []
layer_sharded_state_dict = layer.sharded_state_dict(
state_dict_prefix, sharded_pp_offset, metadata
)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
Loading

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge-conflict are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@asolergi-nv asolergi-nv requested review from a team as code owners February 26, 2026 15:50
@asolergi-nv asolergi-nv added the Final Review PR is in the "final review" stage label Feb 26, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team February 26, 2026 15:50
@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Feb 26, 2026
@asolergi-nv
Copy link
Contributor Author

/ok to test 9f28b8e

@jaredcasper
Copy link
Contributor

Is this going to cause backward compatibility issues with existing checkpoints?

@asolergi-nv
Copy link
Contributor Author

Is this going to cause backward compatibility issues with existing checkpoints?

The only difference I see in the state dict (both keys & contents) is that before this fix we were adding self_attention.core_attention._extra_state which its just a empty tensor (TransformerEngine convention of ALWAYS adding _extra_state entry despite its a empty tensor). Megatron-LM is dropping the _extra_state entry when using the sharded_state_dict method

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Final Review PR is in the "final review" stage

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants