@kmalik22 kmalik22 commented Sep 8, 2025

Summary

  • Adds tensor-parallel (megatron-style) sharding for the Attention and MLP layers (see the sketch after this list)
  • Adds a unit test for the above
  • Adds scripts to test, benchmark, and profile the forward pass of the MLP and Attention blocks across different shapes, with both default and megatron sharding
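
For context on the first bullet, here is a minimal, self-contained sketch of the megatron-style MLP sharding pattern (column-parallel first matmul, row-parallel second matmul) in plain JAX. This is not the code from this PR; the mesh axis name, shapes, and initialization are illustrative only.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices; "model" is the tensor-parallel axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

d_model, hidden, batch, seq = 256, 1024, 8, 32   # illustrative shapes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# Column-parallel first projection, row-parallel second projection.
w1 = jax.device_put(jax.random.normal(k1, (d_model, hidden)) * d_model**-0.5,
                    NamedSharding(mesh, P(None, "model")))
w2 = jax.device_put(jax.random.normal(k2, (hidden, d_model)) * hidden**-0.5,
                    NamedSharding(mesh, P("model", None)))
x = jax.device_put(jax.random.normal(k3, (batch, seq, d_model)),
                   NamedSharding(mesh, P()))     # activations replicated

@jax.jit
def mlp(x, w1, w2):
    h = jax.nn.gelu(x @ w1)   # sharded over the hidden dim, no communication
    y = h @ w2                # per-device partial sums over the hidden dim
    # Forcing a replicated output makes XLA insert a single all-reduce, so the
    # comm volume is O(batch * seq * d_model) rather than O(num_parameters).
    return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P()))

print(mlp(x, w1, w2).sharding)
```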

Test

uv run -m pytest src/openpi/models/megatron_sharding_test.py -v

Scripts

All scripts are in scripts/kmalik2; an example invocation follows the list.

  • test_feedforward_inference.py: Test, profile, and benchmark timing for the MLP block with default and megatron sharding
  • test_attention_inference.py: Test, profile, and benchmark timing for the Attention block with default and megatron sharding
  • profile_all.py: Wrapper script to generate profiles for a few different configurations
  • timing_sweep.py: Wrapper script to run feedforward inference for a few different default and megatron configs and dump timing information to a CSV
  • timing_sweep_attention.py: Wrapper script to run attention inference for a few different default and megatron configs and dump timing information to a CSV
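
For example, to run the MLP timing sweep (assuming the script runs with its built-in defaults; check the script itself for the available shape and sharding options):

uv run scripts/kmalik2/timing_sweep.py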

Summary of MLP benchmark results

For small values of BT (batch × sequence length), megatron sharding is faster; for larger values of BT, the default sharding wins.
The graph below shows forward latency for fixed values of model_dim, hidden_dim, num_shards, and batch_size; only the sequence length varies.

SCR-20250907-owkh (forward latency vs. sequence length)

The same data, shown as a speedup:
SCR-20250907-owmc
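
My reading of why the crossover appears (not derived from the scripts, just comm-volume counting with collective constant factors ignored): megatron tensor parallelism all-reduces the block's output activations, while the default/fsdp sharding all-gathers the block's weights each forward pass. A rough per-layer count with illustrative shapes:

```python
# Rough per-layer comm volume in elements (shapes are illustrative, not from the PR).
def megatron_mlp_comm(batch, seq, d_model):
    # all-reduce of the MLP block's output activations
    return batch * seq * d_model

def fsdp_mlp_comm(d_model, hidden):
    # all-gather of the two MLP weight matrices
    return d_model * hidden + hidden * d_model

for seq in (512, 4096, 32768):
    megatron_wins = megatron_mlp_comm(1, seq, 2048) < fsdp_mlp_comm(2048, 8192)
    print(seq, megatron_wins)   # True, True, False: crossover as seq grows
```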

Summary of Attention benchmark results

As with the MLP, megatron comms scale as O(num_activations) while fsdp comms scale as O(num_parameters).
If BT is smaller than roughly 4·d_model, megatron sharding helps; otherwise stick with fsdp.
SCR-20250908-jxuv

Speedup
SCR-20250908-jxwc
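
A quick sanity check of the BT < 4·d_model rule of thumb, counting comm elements the same way as above (shapes are illustrative, not taken from the benchmark):

```python
B, T, D = 1, 512, 2048        # illustrative shapes
megatron = B * T * D          # all-reduce of the attention block's output activations
fsdp = 4 * D * D              # all-gather of the Q, K, V, O projection weights
print(megatron < fsdp, B * T < 4 * D)   # (True, True): megatron comms are smaller here
```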
