Conversation

pggPL (Collaborator) commented Oct 15, 2025

Description

This PR adds all-to-all (A2A) context-parallel (CP) attention support for JAX.
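For reference, here is a minimal sketch of the A2A (Ulysses-style) pattern this feature implements: activations arrive sequence-sharded across the CP mesh axis, one all-to-all trades sequence shards for head shards so each rank can attend over the full sequence with a subset of heads, and a second all-to-all restores sequence sharding afterwards. The mesh axis name, shapes, and helper below are illustrative assumptions, not the code added by this PR.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Illustrative only: a 1-D mesh whose axis "cp" is the context-parallel group.
mesh = Mesh(np.array(jax.devices()[:2]), axis_names=("cp",))

def a2a_attn_block(x):
    # x arrives sequence-sharded per device: [b, s/cp, h, d]
    # All-to-all #1: split heads, concatenate sequence -> heads-sharded [b, s, h/cp, d],
    # so each rank sees the full sequence for its subset of heads.
    x = jax.lax.all_to_all(x, "cp", split_axis=2, concat_axis=1, tiled=True)
    # ... run attention over the full sequence here ...
    # All-to-all #2: split sequence, concatenate heads -> back to [b, s/cp, h, d].
    return jax.lax.all_to_all(x, "cp", split_axis=1, concat_axis=2, tiled=True)

a2a_attn = shard_map(
    a2a_attn_block, mesh=mesh,
    in_specs=P(None, "cp", None, None), out_specs=P(None, "cp", None, None),
)

x = jnp.zeros((2, 8, 4, 16))  # [b, s, h, d]; s and h divisible by the CP size
y = a2a_attn(x)
assert y.shape == x.shape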

Before
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    1.97s | avg:   0.16s
test_autocast_with_mesh_resource                             |   1x |    0.00s | avg:   0.00s
test_context_parallel_allgather_attn                         | 160x |  612.61s | avg:   3.83s
test_context_parallel_allgather_attn_shardy                  |  20x |   90.95s | avg:   4.55s
test_context_parallel_ring_attn                              | 640x | 1042.37s | avg:   1.63s
test_context_parallel_ring_attn_shardy                       |  20x |   37.74s | avg:   1.89s
test_cross_attn                                              |   6x |   31.82s | avg:   5.30s
test_distributed_gemm                                        |   6x |    6.10s | avg:   1.02s
test_layernorm                                               | 144x |   81.39s | avg:   0.57s
test_layernorm_mlp_grad                                      | 240x |  301.51s | avg:   1.26s
test_layernorm_mlp_grad_shardy                               | 240x |  293.58s | avg:   1.22s
test_layernorm_mlp_layer                                     |  48x |   21.58s | avg:   0.45s
test_layernorm_mlp_layer_fp8                                 | 192x |   81.58s | avg:   0.42s
test_layernorm_mlp_layer_fp8_shardy                          | 192x |   91.23s | avg:   0.48s
test_layernorm_mlp_layer_shardy                              |  48x |   25.98s | avg:   0.54s
test_rmsnorm                                                 |  72x |   29.43s | avg:   0.41s
test_self_attn                                               |  18x |   89.75s | avg:   4.99s
test_self_attn_shardy                                        |   6x |   17.32s | avg:   2.89s
test_softmax                                                 | 288x |  185.44s | avg:   0.64s
test_softmax_gspmd                                           |  24x |   13.07s | avg:   0.54s
test_te_distributed_dense_grad                               |   6x |    5.12s | avg:   0.85s
================================================================================
TOTAL RUNTIME                                                |      | 3060.56s |
================================================================================

After
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    1.95s | avg:   0.16s
test_autocast_with_mesh_resource                             |   1x |    0.00s | avg:   0.00s
test_context_parallel_allgather_attn                         | 160x |  573.10s | avg:   3.58s
test_context_parallel_allgather_attn_shardy                  |  20x |   84.57s | avg:   4.23s
test_context_parallel_alltoall_attn                          | 128x |  140.05s | avg:   1.09s
test_context_parallel_ring_attn                              | 640x | 1001.07s | avg:   1.56s
test_context_parallel_ring_attn_shardy                       |  20x |   35.29s | avg:   1.76s
test_cross_attn                                              |   6x |   30.53s | avg:   5.09s
test_distributed_gemm                                        |   6x |    5.71s | avg:   0.95s
test_layernorm                                               | 144x |   79.83s | avg:   0.55s
test_layernorm_mlp_grad                                      | 240x |  300.15s | avg:   1.25s
test_layernorm_mlp_grad_shardy                               | 240x |  316.08s | avg:   1.32s
test_layernorm_mlp_layer                                     |  48x |   23.87s | avg:   0.50s
test_layernorm_mlp_layer_fp8                                 | 192x |   83.67s | avg:   0.44s
test_layernorm_mlp_layer_fp8_shardy                          | 192x |  101.28s | avg:   0.53s
test_layernorm_mlp_layer_shardy                              |  48x |   28.21s | avg:   0.59s
test_rmsnorm                                                 |  72x |   27.90s | avg:   0.39s
test_self_attn                                               |  18x |   87.16s | avg:   4.84s
test_self_attn_shardy                                        |   6x |   16.33s | avg:   2.72s
test_softmax                                                 | 288x |  177.94s | avg:   0.62s
test_softmax_gspmd                                           |  24x |   12.08s | avg:   0.50s
test_te_distributed_dense_grad                               |   6x |    4.66s | avg:   0.78s
================================================================================
TOTAL RUNTIME                                                |      | 3131.45s |
================================================================================

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL added 2 commits October 15, 2025 10:56
pre-commit-ci bot and others added 3 commits October 15, 2025 11:07
pggPL requested a review from phu0ngng October 16, 2025 13:40
pre-commit-ci bot and others added 3 commits October 16, 2025 15:42
pggPL marked this pull request as ready for review October 16, 2025 17:32
pggPL requested a review from KshitijLakhani October 16, 2025 17:33
pggPL (Collaborator, Author) commented Oct 17, 2025

/te-ci jax L1

  dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
  dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
- arg_shardings = [arg_i.sharding for arg_i in arg_infos]
+ arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
Collaborator: Why do we need this change from [...] to list(...)?

pggPL (Collaborator, Author): done
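For context, the two spellings in question build the same list; a trivial, self-contained illustration (the values are placeholders, not the PR's data):

items = [1, 2, 3]
a = [x.bit_length() for x in items]        # list comprehension
b = list(x.bit_length() for x in items)    # list() over a generator expression
assert a == b                              # identical contents; the comprehension is the more idiomatic form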

# Apply all-to-all to transform from heads-sharded to seq-sharded (scatter in seq dimension)
# output is always [b, s, h/cp, d] -> heads_dim=2
output = helper.all_to_all(output, False, seq_dim=1, heads_dim=2)
# softmax_aux has shape [b, h/cp, s, 1] with heads at dim 1, seq at dim 2
Collaborator: Thanks for the clear comments!
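For readers following along, a rough sketch of what a helper like this plausibly reduces to, assuming it wraps jax.lax.all_to_all over the CP mesh axis; the before_attn flag, axis name, and shapes are guesses inferred from the snippet above, not the PR's actual implementation.

import jax

def all_to_all(x, before_attn, *, seq_dim, heads_dim, cp_axis="cp"):
    # Hypothetical reconstruction (not the PR's code).
    # before_attn=True : seq-sharded -> heads-sharded (split heads, concat seq)
    # before_attn=False: heads-sharded -> seq-sharded (split seq, concat heads)
    if before_attn:
        return jax.lax.all_to_all(
            x, cp_axis, split_axis=heads_dim, concat_axis=seq_dim, tiled=True)
    return jax.lax.all_to_all(
        x, cp_axis, split_axis=seq_dim, concat_axis=heads_dim, tiled=True)

# output      [b, s, h/cp, d] -> all_to_all(output, False, seq_dim=1, heads_dim=2)      -> [b, s/cp, h, d]
# softmax_aux [b, h/cp, s, 1] -> all_to_all(softmax_aux, False, seq_dim=2, heads_dim=1) -> [b, h, s/cp, 1]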

pggPL added 3 commits October 23, 2025 16:35
pggPL (Collaborator, Author) commented Oct 23, 2025

/te-ci jax L1
