From aa656bf62dae5c96117cc2b1d9e1949636cd4c28 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 Oct 2025 10:56:39 +0000 Subject: [PATCH 01/12] init Signed-off-by: Pawel Gadzinski --- tests/jax/test_distributed_fused_attn.py | 43 +++ transformer_engine/jax/attention.py | 2 + .../jax/cpp_extensions/attention.py | 247 ++++++++++++++++++ transformer_engine/jax/flax/transformer.py | 3 + 4 files changed, 295 insertions(+) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index ef8e370b6e..afecc30322 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -580,6 +580,49 @@ def test_context_parallel_ring_attn_shardy( use_scan_ring=True, ) + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), + ) + @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) + @pytest.mark.parametrize("kv_groups", [1, 8]) + @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) + @pytest.mark.parametrize( + "qkv_layout, attn_mask_type", + [ + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"), + ], + ) + def test_context_parallel_alltoall_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + ): + self.impl_test_context_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced=True, + cp_strategy=CPStrategy.ALL_TO_ALL, + use_shardy=False, + ) + REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = { "L0": [[]], diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 0931461627..9165a98c53 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -175,11 +175,13 @@ class CPStrategy(Enum): DEFAULT: Default strategy will choose automatically if context parallel axis is sharded. ALL_GATHER: All-gather/reduce scatter implementation. RING: Ring attention implementation (https://arxiv.org/abs/2310.01889). + ALL_TO_ALL: All-to-all implementation. """ DEFAULT = 0 ALL_GATHER = 1 RING = 2 + ALL_TO_ALL = 3 class ReorderStrategy(Enum): diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index db2537c38f..57f4cf6e01 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1629,6 +1629,249 @@ def _cross_attn_bwd( register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) +class FusedAttnCPWithAllToAllFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Attention Forward with Context Parallelism Primitive. + Like Ulysses, applying A2A to QKVO. + Refer the paper `DeepSpeed Ulysses `_. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unecessary work. 
+ is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithA2AHelper(mesh, config) + q_aval = arg_infos[0].aval if hasattr(arg_infos[0], 'aval') else arg_infos[0] + num_heads = q_aval.shape[2] + helper.check_supported(num_heads) + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + def impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + q_, k_, v_ = helper.all_to_all(q, True), helper.all_to_all(k, True), helper.all_to_all(v, True) + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( + q_, + k_, + v_, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=helper.get_step_config(), + ) + output = helper.all_to_all(output, False) + # softmax_aux has shape [b, h/cp, s, 1] with heads at dim 1, seq at dim 2 + softmax_aux = helper.all_to_all(softmax_aux, False, seq_dim=2, heads_dim=1) + return output, softmax_aux, rng_state + + return mesh, impl, out_shardings, arg_shardings + +register_primitive(FusedAttnCPWithAllToAllFwdPrimitive) + +class FusedAttnCPWithAllToAllBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Attention Backward with Context Parallelism Primitive. + Like Ulysses, applying A2A to QKVO and its derivatives. + Refer the paper `DeepSpeed Ulysses `_. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unnecessary work. 
+ is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithA2AHelper(mesh, config) + helper.check_supported() + + dq_sharding = result_infos[0].sharding + dk_sharding = result_infos[1].sharding + dv_sharding = result_infos[2].sharding + dbias_sharding = result_infos[3].sharding + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos]) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + def impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + # Apply all-to-all to inputs before backward pass + q_, k_, v_ = helper.all_to_all(q, True), helper.all_to_all(k, True), helper.all_to_all(v, True) + doutput_ = helper.all_to_all(doutput, True) + # softmax_aux has shape [b, h, s/cp, 1] with heads at dim 1, seq at dim 2 + softmax_aux_ = helper.all_to_all(softmax_aux, True, seq_dim=2, heads_dim=1) + + # Perform backward pass + dq, dk, dv, dbias = FusedAttnBwdPrimitive.impl( + q_, + k_, + v_, + bias, + softmax_aux_, + rng_state, + output, + doutput_, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=helper.get_step_config(), + ) + + # Apply all-to-all to gradients to restore original sharding + dq_ = helper.all_to_all(dq, False) + dk_ = helper.all_to_all(dk, False) + dv_ = helper.all_to_all(dv, False) + + return dq_, dk_, dv_, dbias + + return mesh, impl, out_shardings, arg_shardings + +register_primitive(FusedAttnCPWithAllToAllBwdPrimitive) + + +@dataclass(frozen=True) +class _FusedAttnCPWithA2AHelper: + """Helper class to assist with all-to-all communication for context parallel attention. + + This class provides methods for performing all-to-all communication across devices + and handles both THD and BSHD layout formats appropriately. + """ + + mesh: jax.sharding.Mesh + config: _FusedAttnConfig + + def check_supported(self, num_heads): + """Checks if the context parallel implementation is supported by the given arguments.""" + header = "Context parallel fused A2A attention" + if self.config.qkv_layout.is_thd(): + raise ValueError(f"{header} does not support THD format") + elif self.config.qkv_layout.get_qkv_format() is QKVFormat.SBHD: + raise ValueError(f"{header} does not support SBHD format") + + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + if num_heads % cp_size != 0: + raise ValueError( + f"{header} requires num_heads ({num_heads}) to be divisible by " + f"context parallel size ({cp_size})" + ) + + def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): + """ + Performs all-to-all communication for context parallelism. + + Args: + x: Input tensor + before_attn: If True, converts seq->heads dist. If False, converts heads->seq dist. + seq_dim: Position of sequence dimension (default 1 for BSHD: [b, s, h, d]) + heads_dim: Position of heads dimension (default 2 for BSHD: [b, s, h, d]) + + Returns: + Tensor after all-to-all with redistributed dimensions + + Shape transforms for BSHD (seq_dim=1, heads_dim=2): + before_attn=True: [b, s/cp, h, d] -> [b, s, h/cp, d] + before_attn=False: [b, s, h/cp, d] -> [b, s/cp, h, d] + + Shape transforms for softmax_aux (seq_dim=2, heads_dim=1): + before_attn=False: [b, h/cp, s, ...] -> [b, h, s/cp, ...] 
+ """ + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + shape = x.shape + + if before_attn: + # Input: sharded on seq, want to shard on heads + # Split heads: [..., s/cp, ..., h, ...] -> [..., s/cp, ..., cp, h/cp, ...] + x = x.reshape(*shape[:heads_dim], cp_size, shape[heads_dim] // cp_size, *shape[heads_dim+1:]) + # A2A splits cp dimension and concatenates into seq + split_axis = heads_dim + concat_axis = seq_dim + else: + # Input: sharded on heads, want to shard on seq + # Unflatten seq: [..., s, ..., h/cp, ...] -> [..., cp, s/cp, ..., h/cp, ...] + s_global = shape[seq_dim] + s_local = s_global // cp_size + new_shape = list(shape) + new_shape[seq_dim:seq_dim+1] = [cp_size, s_local] + x = x.reshape(new_shape) + # A2A splits cp dimension (at seq_dim) and concatenates into heads + split_axis = seq_dim + concat_axis = heads_dim + + # All-to-all communication + x = lax_paral_op( + x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh, + split_axis=split_axis, concat_axis=concat_axis, tiled=True, + ) + + return x + + def get_step_config(self) -> _FusedAttnConfig: + """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + return _FusedAttnConfig( + attn_bias_type=self.config.attn_bias_type, + attn_mask_type=self.config.attn_mask_type, + qkv_layout=self.config.qkv_layout, + scaling_factor=self.config.scaling_factor, + dropout_probability=self.config.dropout_probability, + is_training=self.config.is_training, + max_segments_per_seq=self.config.max_segments_per_seq, + window_size=self.config.window_size, + context_parallel_load_balanced=self.config.context_parallel_load_balanced, + cp_axis="", # No CP axis for the inner attention call + cp_striped_window_size=None, + ) + + @dataclass(frozen=True) class _FusedAttnCPWithP2PHelper: """Helper class to assist with running the P2P ring strategy for CP attention.""" @@ -2641,6 +2884,8 @@ def fused_attn_fwd( primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive else: primitive = FusedRingAttnFwdPrimitive.outer_primitive + case CPStrategy.ALL_TO_ALL: + primitive = FusedAttnCPWithAllToAllFwdPrimitive.outer_primitive seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) output, softmax_aux, rng_state = primitive.bind( @@ -2767,6 +3012,8 @@ def fused_attn_bwd( primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive else: primitive = FusedRingAttnBwdPrimitive.outer_primitive + case CPStrategy.ALL_TO_ALL: + primitive = FusedAttnCPWithAllToAllBwdPrimitive.outer_primitive seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) *qkv_grads, bias_grad = primitive.bind( diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index c95765bf3a..8c903f56f4 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -662,6 +662,9 @@ def __call__( "ALL_GATHER": CPStrategy.ALL_GATHER, "ALLGATHER": CPStrategy.ALL_GATHER, # Alternative spelling "RING": CPStrategy.RING, + "ALL_TO_ALL": CPStrategy.ALL_TO_ALL, + "ALLTOALL": CPStrategy.ALL_TO_ALL, # Alternative spelling + "A2A": CPStrategy.ALL_TO_ALL, # Short form } strategy_key = self.context_parallel_strategy.upper() From 993f9028169a3a65c1eab0c07129fa9f04f2bb87 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 Oct 2025 11:06:10 +0000 Subject: [PATCH 02/12] small fixes Signed-off-by: Pawel Gadzinski --- transformer_engine/jax/attention.py | 2 ++ transformer_engine/jax/cpp_extensions/attention.py | 4 +++- transformer_engine/jax/flax/transformer.py | 2 +- 3 files 
changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 9165a98c53..ffcbd0f011 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1049,6 +1049,8 @@ def fused_attn( time and memory consumption is proportional to `max_segments_per_seq`. window_size (Optional[Tuple[int, int]]): Sliding window size. + context_parallel_strategy (CPStrategy): + The strategy of context parallel. Options: DEFAULT, ALL_GATHER, RING, ALL_TO_ALL. context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 57f4cf6e01..8253a2d652 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1713,7 +1713,9 @@ def partition(config, mesh, arg_infos, result_infos): return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) helper = _FusedAttnCPWithA2AHelper(mesh, config) - helper.check_supported() + q_aval = arg_infos[0].aval if hasattr(arg_infos[0], 'aval') else arg_infos[0] + num_heads = q_aval.shape[2] + helper.check_supported(num_heads) dq_sharding = result_infos[0].sharding dk_sharding = result_infos[1].sharding diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 8c903f56f4..5e3f4e424e 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -511,7 +511,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. - context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING. + context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING, 3: ALL_TO_ALL. context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention. 
Optimization parameters From ff5e3a7bcb900d8ee026b603b8be4ea1d826e6f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Oct 2025 11:07:35 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_fused_attn.py | 8 +- .../jax/cpp_extensions/attention.py | 93 +++++++++++-------- 2 files changed, 62 insertions(+), 39 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index afecc30322..de4388075f 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -591,9 +591,13 @@ def test_context_parallel_ring_attn_shardy( "qkv_layout, attn_mask_type", [ pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), + pytest.param( + QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL" + ), pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"), + pytest.param( + QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK" + ), ], ) def test_context_parallel_alltoall_attn( diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 8253a2d652..13371880af 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1631,9 +1631,9 @@ def _cross_attn_bwd( class FusedAttnCPWithAllToAllFwdPrimitive(FusedAttnFwdPrimitive): """ - Fused Attention Forward with Context Parallelism Primitive. - Like Ulysses, applying A2A to QKVO. - Refer the paper `DeepSpeed Ulysses `_. + Fused Attention Forward with Context Parallelism Primitive. + Like Ulysses, applying A2A to QKVO. + Refer the paper `DeepSpeed Ulysses `_. 
""" @staticmethod @@ -1642,12 +1642,12 @@ def partition(config, mesh, arg_infos, result_infos): is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 if not is_context_parallel: return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) - + helper = _FusedAttnCPWithA2AHelper(mesh, config) - q_aval = arg_infos[0].aval if hasattr(arg_infos[0], 'aval') else arg_infos[0] + q_aval = arg_infos[0].aval if hasattr(arg_infos[0], "aval") else arg_infos[0] num_heads = q_aval.shape[2] helper.check_supported(num_heads) - + out_sharding = result_infos[0].sharding softmax_aux_sharding = result_infos[1].sharding rng_state_sharding = seed_sharding = NamedSharding( @@ -1657,6 +1657,7 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[4] = seed_sharding arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + def impl( q, k, @@ -1672,7 +1673,11 @@ def impl( _q_segment_pos, _kv_segment_pos, ): - q_, k_, v_ = helper.all_to_all(q, True), helper.all_to_all(k, True), helper.all_to_all(v, True) + q_, k_, v_ = ( + helper.all_to_all(q, True), + helper.all_to_all(k, True), + helper.all_to_all(v, True), + ) output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_, k_, @@ -1696,34 +1701,36 @@ def impl( return mesh, impl, out_shardings, arg_shardings + register_primitive(FusedAttnCPWithAllToAllFwdPrimitive) + class FusedAttnCPWithAllToAllBwdPrimitive(FusedAttnBwdPrimitive): """ - Fused Attention Backward with Context Parallelism Primitive. - Like Ulysses, applying A2A to QKVO and its derivatives. - Refer the paper `DeepSpeed Ulysses `_. + Fused Attention Backward with Context Parallelism Primitive. + Like Ulysses, applying A2A to QKVO and its derivatives. + Refer the paper `DeepSpeed Ulysses `_. """ - + @staticmethod def partition(config, mesh, arg_infos, result_infos): # Call base implementation for non-context parallel mesh to avoid unnecessary work. 
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 if not is_context_parallel: return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) - + helper = _FusedAttnCPWithA2AHelper(mesh, config) - q_aval = arg_infos[0].aval if hasattr(arg_infos[0], 'aval') else arg_infos[0] + q_aval = arg_infos[0].aval if hasattr(arg_infos[0], "aval") else arg_infos[0] num_heads = q_aval.shape[2] helper.check_supported(num_heads) - + dq_sharding = result_infos[0].sharding dk_sharding = result_infos[1].sharding dv_sharding = result_infos[2].sharding dbias_sharding = result_infos[3].sharding arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos]) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - + def impl( q, k, @@ -1743,11 +1750,15 @@ def impl( _kv_segment_pos, ): # Apply all-to-all to inputs before backward pass - q_, k_, v_ = helper.all_to_all(q, True), helper.all_to_all(k, True), helper.all_to_all(v, True) + q_, k_, v_ = ( + helper.all_to_all(q, True), + helper.all_to_all(k, True), + helper.all_to_all(v, True), + ) doutput_ = helper.all_to_all(doutput, True) # softmax_aux has shape [b, h, s/cp, 1] with heads at dim 1, seq at dim 2 softmax_aux_ = helper.all_to_all(softmax_aux, True, seq_dim=2, heads_dim=1) - + # Perform backward pass dq, dk, dv, dbias = FusedAttnBwdPrimitive.impl( q_, @@ -1768,30 +1779,31 @@ def impl( _kv_segment_pos, config=helper.get_step_config(), ) - + # Apply all-to-all to gradients to restore original sharding dq_ = helper.all_to_all(dq, False) dk_ = helper.all_to_all(dk, False) dv_ = helper.all_to_all(dv, False) - + return dq_, dk_, dv_, dbias - + return mesh, impl, out_shardings, arg_shardings + register_primitive(FusedAttnCPWithAllToAllBwdPrimitive) @dataclass(frozen=True) class _FusedAttnCPWithA2AHelper: """Helper class to assist with all-to-all communication for context parallel attention. - + This class provides methods for performing all-to-all communication across devices and handles both THD and BSHD layout formats appropriately. """ - + mesh: jax.sharding.Mesh config: _FusedAttnConfig - + def check_supported(self, num_heads): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused A2A attention" @@ -1799,41 +1811,43 @@ def check_supported(self, num_heads): raise ValueError(f"{header} does not support THD format") elif self.config.qkv_layout.get_qkv_format() is QKVFormat.SBHD: raise ValueError(f"{header} does not support SBHD format") - + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) if num_heads % cp_size != 0: raise ValueError( f"{header} requires num_heads ({num_heads}) to be divisible by " f"context parallel size ({cp_size})" ) - + def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): """ Performs all-to-all communication for context parallelism. - + Args: x: Input tensor before_attn: If True, converts seq->heads dist. If False, converts heads->seq dist. seq_dim: Position of sequence dimension (default 1 for BSHD: [b, s, h, d]) heads_dim: Position of heads dimension (default 2 for BSHD: [b, s, h, d]) - + Returns: Tensor after all-to-all with redistributed dimensions - + Shape transforms for BSHD (seq_dim=1, heads_dim=2): before_attn=True: [b, s/cp, h, d] -> [b, s, h/cp, d] before_attn=False: [b, s, h/cp, d] -> [b, s/cp, h, d] - + Shape transforms for softmax_aux (seq_dim=2, heads_dim=1): before_attn=False: [b, h/cp, s, ...] -> [b, h, s/cp, ...] 
""" cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) shape = x.shape - + if before_attn: # Input: sharded on seq, want to shard on heads # Split heads: [..., s/cp, ..., h, ...] -> [..., s/cp, ..., cp, h/cp, ...] - x = x.reshape(*shape[:heads_dim], cp_size, shape[heads_dim] // cp_size, *shape[heads_dim+1:]) + x = x.reshape( + *shape[:heads_dim], cp_size, shape[heads_dim] // cp_size, *shape[heads_dim + 1 :] + ) # A2A splits cp dimension and concatenates into seq split_axis = heads_dim concat_axis = seq_dim @@ -1843,20 +1857,25 @@ def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): s_global = shape[seq_dim] s_local = s_global // cp_size new_shape = list(shape) - new_shape[seq_dim:seq_dim+1] = [cp_size, s_local] + new_shape[seq_dim : seq_dim + 1] = [cp_size, s_local] x = x.reshape(new_shape) # A2A splits cp dimension (at seq_dim) and concatenates into heads split_axis = seq_dim concat_axis = heads_dim - + # All-to-all communication x = lax_paral_op( - x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh, - split_axis=split_axis, concat_axis=concat_axis, tiled=True, + x, + lax.all_to_all, + self.config.cp_axis, + mesh=self.mesh, + split_axis=split_axis, + concat_axis=concat_axis, + tiled=True, ) - + return x - + def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" return _FusedAttnConfig( From d439db946573b0050a996c6e3929a6a71a85c001 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 Oct 2025 11:23:25 +0000 Subject: [PATCH 04/12] fix Signed-off-by: Pawel Gadzinski --- tests/jax/test_distributed_fused_attn.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index de4388075f..09db441d73 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -600,6 +600,10 @@ def test_context_parallel_ring_attn_shardy( ), ], ) + @pytest.mark.parametrize( + "load_balanced", + [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], + ) def test_context_parallel_alltoall_attn( self, device_count, @@ -611,6 +615,7 @@ def test_context_parallel_alltoall_attn( attn_mask_type, dtype, qkv_layout, + load_balanced, ): self.impl_test_context_parallel_attn( device_count, @@ -622,7 +627,7 @@ def test_context_parallel_alltoall_attn( attn_mask_type, dtype, qkv_layout, - load_balanced=True, + load_balanced=load_balanced, cp_strategy=CPStrategy.ALL_TO_ALL, use_shardy=False, ) From 290bef1c48ca27bc67eb6b6deee8485d1c6896be Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 Oct 2025 14:49:41 +0200 Subject: [PATCH 05/12] fix Signed-off-by: Pawel Gadzinski --- tests/jax/distributed_test_base.py | 8 +- tests/jax/test_distributed_fused_attn.py | 4 + .../jax/cpp_extensions/attention.py | 230 ++++++++++++------ 3 files changed, 167 insertions(+), 75 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 4693086b83..260b5570dd 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -8,7 +8,8 @@ import pytest import jax -from jax.experimental.pjit import pjit, _UNSPECIFIED +from jax.experimental.pjit import pjit +from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED from transformer_engine.jax.sharding import MeshResource @@ -38,7 +39,7 @@ def generate_configs(): return configs -def generate_context_parallel_configs_for_attn(): +def 
generate_context_parallel_configs_for_attn(heads_divisible_by_cp_times_tp=False): """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only""" configsL1 = [] configsL2 = [] @@ -52,6 +53,9 @@ def generate_context_parallel_configs_for_attn(): if is_devices_enough(ndev): # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) if cp != 1: + if heads_divisible_by_cp_times_tp: + if num_heads % (cp * tp) != 0: + continue configsL1.append( pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") ) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 09db441d73..3b6b5a4ce4 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -617,6 +617,10 @@ def test_context_parallel_alltoall_attn( qkv_layout, load_balanced, ): + if data_shape[2] % (mesh_shape[1] * mesh_shape[2] * kv_groups) != 0: + pytest.skip("Skipping test as num_heads is not divisible by cp_size * tp_size * kv_groups") + if load_balanced: + pytest.skip("Load balanced causal attention is not yet supported with all-to-all strategy") self.impl_test_context_parallel_attn( device_count, mesh_shape, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 13371880af..de2c74852e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1644,9 +1644,7 @@ def partition(config, mesh, arg_infos, result_infos): return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) helper = _FusedAttnCPWithA2AHelper(mesh, config) - q_aval = arg_infos[0].aval if hasattr(arg_infos[0], "aval") else arg_infos[0] - num_heads = q_aval.shape[2] - helper.check_supported(num_heads) + helper.check_supported() out_sharding = result_infos[0].sharding softmax_aux_sharding = result_infos[1].sharding @@ -1673,11 +1671,26 @@ def impl( _q_segment_pos, _kv_segment_pos, ): - q_, k_, v_ = ( - helper.all_to_all(q, True), - helper.all_to_all(k, True), - helper.all_to_all(v, True), - ) + # Get heads dimensions based on QKV layout + q_heads_dim, k_heads_dim, v_heads_dim = helper.get_qkv_heads_dims(seq_dim=1) + q_heads = q.shape[q_heads_dim] + kv_heads = k.shape[k_heads_dim] + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size" + assert kv_heads % cp_size == 0, "kv_heads must be divisible by cp_size" + + # Load balanced causal attention is not yet supported for all-to-all strategy + if config.context_parallel_load_balanced: + raise NotImplementedError( + "context_parallel_load_balanced is not supported with all-to-all strategy" + ) + + # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension) + q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim) + k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim) + # For KVPACKED layout, v is empty placeholder + v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim) + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_, k_, @@ -1694,7 +1707,10 @@ def impl( _kv_segment_pos, config=helper.get_step_config(), ) - output = helper.all_to_all(output, False) + + # Apply all-to-all to transform from heads-sharded to seq-sharded (scatter in seq dimension) + # output is always [b, s, h/cp, 
d] -> heads_dim=2 + output = helper.all_to_all(output, False, seq_dim=1, heads_dim=2) # softmax_aux has shape [b, h/cp, s, 1] with heads at dim 1, seq at dim 2 softmax_aux = helper.all_to_all(softmax_aux, False, seq_dim=2, heads_dim=1) return output, softmax_aux, rng_state @@ -1720,17 +1736,29 @@ def partition(config, mesh, arg_infos, result_infos): return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) helper = _FusedAttnCPWithA2AHelper(mesh, config) - q_aval = arg_infos[0].aval if hasattr(arg_infos[0], "aval") else arg_infos[0] - num_heads = q_aval.shape[2] - helper.check_supported(num_heads) + helper.check_supported() dq_sharding = result_infos[0].sharding dk_sharding = result_infos[1].sharding dv_sharding = result_infos[2].sharding dbias_sharding = result_infos[3].sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos]) + + # For AllToAll context parallel, output and doutput need to be seq-sharded + # to match the forward output sharding (before they get transformed to heads-sharded) + arg_shardings = list([arg_i.sharding for arg_i in arg_infos]) + # arg_infos: [q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, ...] + # output is at index 6, doutput is at index 7 + # They should have the same sharding as the forward output (seq-sharded on cp axis) + output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, config.cp_axis, None, None)) + softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, config.cp_axis, None)) + arg_shardings[4] = softmax_aux_seq_sharding # softmax_aux [b, h, s/cp, 1] + arg_shardings[6] = output_seq_sharding # output [b, s/cp, h, d] + arg_shardings[7] = output_seq_sharding # doutput [b, s/cp, h, d] + arg_shardings = tuple(arg_shardings) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - + + def impl( q, k, @@ -1750,12 +1778,29 @@ def impl( _kv_segment_pos, ): # Apply all-to-all to inputs before backward pass - q_, k_, v_ = ( - helper.all_to_all(q, True), - helper.all_to_all(k, True), - helper.all_to_all(v, True), - ) - doutput_ = helper.all_to_all(doutput, True) + # Get heads dimensions based on QKV layout (same as forward) + q_heads_dim, k_heads_dim, v_heads_dim = helper.get_qkv_heads_dims(seq_dim=1) + q_heads = q.shape[q_heads_dim] + k_heads = k.shape[k_heads_dim] + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size" + assert k_heads % cp_size == 0, "k_heads must be divisible by cp_size" + + # Load balanced causal attention is not yet supported for all-to-all strategy + if config.context_parallel_load_balanced: + raise NotImplementedError( + "context_parallel_load_balanced is not supported with all-to-all strategy" + ) + + # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension) + q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim) + k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim) + # For KVPACKED layout, v is empty placeholder + v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim) + # doutput is always separate [b, s, h, d], so heads_dim=2 + doutput_ = helper.all_to_all(doutput, True, seq_dim=1, heads_dim=2) + # output has the same shape as doutput, needs the same transformation + output_ = helper.all_to_all(output, True, seq_dim=1, heads_dim=2) # softmax_aux has shape [b, h, s/cp, 1] with heads at dim 1, seq at dim 2 softmax_aux_ = helper.all_to_all(softmax_aux, True, seq_dim=2, 
heads_dim=1) @@ -1767,7 +1812,7 @@ def impl( bias, softmax_aux_, rng_state, - output, + output_, doutput_, q_seqlen, kv_seqlen, @@ -1779,11 +1824,15 @@ def impl( _kv_segment_pos, config=helper.get_step_config(), ) - - # Apply all-to-all to gradients to restore original sharding - dq_ = helper.all_to_all(dq, False) - dk_ = helper.all_to_all(dk, False) - dv_ = helper.all_to_all(dv, False) + + # Apply all-to-all to gradients to restore original sharding (scatter in seq dimension) + # Gradients have the same shape as inputs, so use same heads_dim + dq_heads_dim, dk_heads_dim, dv_heads_dim = helper.get_qkv_heads_dims(seq_dim=1) + + dq_ = helper.all_to_all(dq, False, seq_dim=1, heads_dim=dq_heads_dim) + dk_ = helper.all_to_all(dk, False, seq_dim=1, heads_dim=dk_heads_dim) + # For KVPACKED layout, dv is empty placeholder + dv_ = dv if dv.shape[0] == 0 else helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim) return dq_, dk_, dv_, dbias @@ -1795,16 +1844,22 @@ def impl( @dataclass(frozen=True) class _FusedAttnCPWithA2AHelper: - """Helper class to assist with all-to-all communication for context parallel attention. + """ + Helper class for Ulysses-style context parallelism using all-to-all communication. + + This helper manages the all-to-all communication pattern that redistributes tensors + between sequence-sharded and heads-sharded layouts. This enables context parallelism + by allowing each rank to process different parts of the sequence and heads dimensions. - This class provides methods for performing all-to-all communication across devices - and handles both THD and BSHD layout formats appropriately. + Attributes: + mesh: JAX mesh for distributed computation + config: Fused attention configuration including QKV layout information """ mesh: jax.sharding.Mesh config: _FusedAttnConfig - def check_supported(self, num_heads): + def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused A2A attention" if self.config.qkv_layout.is_thd(): @@ -1812,69 +1867,98 @@ def check_supported(self, num_heads): elif self.config.qkv_layout.get_qkv_format() is QKVFormat.SBHD: raise ValueError(f"{header} does not support SBHD format") - cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - if num_heads % cp_size != 0: - raise ValueError( - f"{header} requires num_heads ({num_heads}) to be divisible by " - f"context parallel size ({cp_size})" - ) + def get_qkv_heads_dims(self, seq_dim=1): + """ + Determines the heads dimension indices for Q, K, V tensors based on QKV layout. + + The heads dimension position depends on the QKV packing format: + - QKVPacked: All tensors packed together with dimension [qkv=3, heads, dim] + - KVPacked: Q is separate, K and V are packed with dimension [kv=2, heads, dim] + - Separate: All tensors are separate with dimension [heads, dim] + + Args: + seq_dim: The sequence dimension position (default 1 for BSHD format) + + Returns: + Tuple of (q_heads_dim, k_heads_dim, v_heads_dim) indicating the position + of the heads dimension for each tensor. 
+ + Examples for BSHD layout (seq_dim=1): + QKVPacked: Q=[b, s, 3, h, d] -> returns (3, 3, 3) + KVPacked: Q=[b, s, h, d], K=[b, s, 2, h, d], V=[b, s, 2, h, d] -> returns (2, 3, 3) + Separate: Q=[b, s, h, d], K=[b, s, h, d], V=[b, s, h, d] -> returns (2, 2, 2) + """ + if self.config.qkv_layout.is_qkvpacked(): + # QKV all packed together: [batch..., seq, 3, heads, dim] + # Heads dimension is at seq_dim + 2 for all tensors + heads_dim = seq_dim + 2 + return heads_dim, heads_dim, heads_dim + elif self.config.qkv_layout.is_kvpacked(): + # Q separate, K and V packed: Q=[batch..., seq, heads, dim] + # K/V=[batch..., seq, 2, heads, dim] + q_heads_dim = seq_dim + 1 # Q has no packing dimension + kv_heads_dim = seq_dim + 2 # K/V have packing dimension [2, heads, dim] + return q_heads_dim, kv_heads_dim, kv_heads_dim + else: # separate + # All separate: Q/K/V=[batch..., seq, heads, dim] + # Heads dimension is at seq_dim + 1 for all tensors + heads_dim = seq_dim + 1 + return heads_dim, heads_dim, heads_dim def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): """ - Performs all-to-all communication for context parallelism. + Performs all-to-all communication for context parallelism (Ulysses-style). + + Redistributes data between sequence and heads dimensions across CP ranks. Args: x: Input tensor - before_attn: If True, converts seq->heads dist. If False, converts heads->seq dist. - seq_dim: Position of sequence dimension (default 1 for BSHD: [b, s, h, d]) - heads_dim: Position of heads dimension (default 2 for BSHD: [b, s, h, d]) + before_attn: True = seq-sharded -> heads-sharded, False = heads-sharded -> seq-sharded + seq_dim: Sequence dimension position (typically 1 for BSHD) + heads_dim: Heads dimension position (depends on QKV layout) Returns: Tensor after all-to-all with redistributed dimensions - - Shape transforms for BSHD (seq_dim=1, heads_dim=2): - before_attn=True: [b, s/cp, h, d] -> [b, s, h/cp, d] - before_attn=False: [b, s, h/cp, d] -> [b, s/cp, h, d] - - Shape transforms for softmax_aux (seq_dim=2, heads_dim=1): - before_attn=False: [b, h/cp, s, ...] -> [b, h, s/cp, ...] """ + if x.shape[0] == 0: + return x + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + if before_attn: + num_heads = x.shape[heads_dim] + assert num_heads % cp_size == 0, "num_heads must be divisible by cp_size" + shape = x.shape + # Determine which dimension to split and where to concat if before_attn: - # Input: sharded on seq, want to shard on heads - # Split heads: [..., s/cp, ..., h, ...] -> [..., s/cp, ..., cp, h/cp, ...] - x = x.reshape( - *shape[:heads_dim], cp_size, shape[heads_dim] // cp_size, *shape[heads_dim + 1 :] - ) - # A2A splits cp dimension and concatenates into seq - split_axis = heads_dim - concat_axis = seq_dim + split_axis, concat_axis = heads_dim, seq_dim + # After reshape, need to adjust concat_axis if it comes after split_axis + needs_adjustment = concat_axis > split_axis else: - # Input: sharded on heads, want to shard on seq - # Unflatten seq: [..., s, ..., h/cp, ...] -> [..., cp, s/cp, ..., h/cp, ...] 
- s_global = shape[seq_dim] - s_local = s_global // cp_size - new_shape = list(shape) - new_shape[seq_dim : seq_dim + 1] = [cp_size, s_local] - x = x.reshape(new_shape) - # A2A splits cp dimension (at seq_dim) and concatenates into heads split_axis = seq_dim - concat_axis = heads_dim + concat_axis = heads_dim + 1 if heads_dim > seq_dim else heads_dim + needs_adjustment = False # Already adjusted above + + # Reshape: insert cp_size at split_axis + assert shape[split_axis] % cp_size == 0 + x = x.reshape( + *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1:] + ) + + # Adjust concat_axis if needed (only for before_attn case) + adjusted_concat_axis = concat_axis + 1 if needs_adjustment else concat_axis - # All-to-all communication + # Perform all-to-all x = lax_paral_op( - x, - lax.all_to_all, - self.config.cp_axis, - mesh=self.mesh, - split_axis=split_axis, - concat_axis=concat_axis, - tiled=True, + x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh, + split_axis=split_axis, concat_axis=adjusted_concat_axis, tiled=True ) - return x + # Merge the two dimensions created by all-to-all at split_axis + new_shape = list(x.shape) + new_shape[split_axis:split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]] + return x.reshape(new_shape) def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" From 189bad68934e9d0741a8151182a102df458d3836 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 Oct 2025 17:30:45 +0200 Subject: [PATCH 06/12] fix Signed-off-by: Pawel Gadzinski --- tests/jax/distributed_test_base.py | 5 +- .../jax/cpp_extensions/attention.py | 104 ++++++------------ 2 files changed, 32 insertions(+), 77 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 260b5570dd..1113471b22 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -39,7 +39,7 @@ def generate_configs(): return configs -def generate_context_parallel_configs_for_attn(heads_divisible_by_cp_times_tp=False): +def generate_context_parallel_configs_for_attn(): """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only""" configsL1 = [] configsL2 = [] @@ -53,9 +53,6 @@ def generate_context_parallel_configs_for_attn(heads_divisible_by_cp_times_tp=Fa if is_devices_enough(ndev): # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) if cp != 1: - if heads_divisible_by_cp_times_tp: - if num_heads % (cp * tp) != 0: - continue configsL1.append( pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") ) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index de2c74852e..168125e394 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1050,7 +1050,7 @@ def partition(config, mesh, arg_infos, result_infos): dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings = list(arg_i.sharding for arg_i in arg_infos) arg_shardings[-1] = arg_shardings[-3] arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings) @@ -1685,11 +1685,12 @@ def impl( 
"context_parallel_load_balanced is not supported with all-to-all strategy" ) - # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension) + assert config.qkv_layout in [QKVLayout.BS3HD, QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] + + # Apply all-to-all to transform from seq-sharded to heads-sharded q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim) k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim) - # For KVPACKED layout, v is empty placeholder - v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim) + v_ = helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim) output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_, @@ -1738,27 +1739,19 @@ def partition(config, mesh, arg_infos, result_infos): helper = _FusedAttnCPWithA2AHelper(mesh, config) helper.check_supported() - dq_sharding = result_infos[0].sharding - dk_sharding = result_infos[1].sharding - dv_sharding = result_infos[2].sharding - dbias_sharding = result_infos[3].sharding - - # For AllToAll context parallel, output and doutput need to be seq-sharded - # to match the forward output sharding (before they get transformed to heads-sharded) - arg_shardings = list([arg_i.sharding for arg_i in arg_infos]) - # arg_infos: [q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, ...] - # output is at index 6, doutput is at index 7 - # They should have the same sharding as the forward output (seq-sharded on cp axis) - output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, config.cp_axis, None, None)) - softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, config.cp_axis, None)) - arg_shardings[4] = softmax_aux_seq_sharding # softmax_aux [b, h, s/cp, 1] - arg_shardings[6] = output_seq_sharding # output [b, s/cp, h, d] - arg_shardings[7] = output_seq_sharding # doutput [b, s/cp, h, d] - arg_shardings = tuple(arg_shardings) + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - def impl( q, k, @@ -1792,14 +1785,12 @@ def impl( "context_parallel_load_balanced is not supported with all-to-all strategy" ) - # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension) q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim) k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim) - # For KVPACKED layout, v is empty placeholder - v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim) - # doutput is always separate [b, s, h, d], so heads_dim=2 + v_ = helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim) + + # doutput is always [b, s, h, d], so heads_dim=2 doutput_ = helper.all_to_all(doutput, True, seq_dim=1, heads_dim=2) - # output has the same shape as doutput, needs the same transformation output_ = helper.all_to_all(output, True, seq_dim=1, heads_dim=2) # softmax_aux has shape [b, h, s/cp, 1] with heads at dim 1, seq at dim 2 softmax_aux_ = helper.all_to_all(softmax_aux, True, seq_dim=2, heads_dim=1) 
@@ -1826,13 +1817,10 @@ def impl( ) # Apply all-to-all to gradients to restore original sharding (scatter in seq dimension) - # Gradients have the same shape as inputs, so use same heads_dim dq_heads_dim, dk_heads_dim, dv_heads_dim = helper.get_qkv_heads_dims(seq_dim=1) - dq_ = helper.all_to_all(dq, False, seq_dim=1, heads_dim=dq_heads_dim) dk_ = helper.all_to_all(dk, False, seq_dim=1, heads_dim=dk_heads_dim) - # For KVPACKED layout, dv is empty placeholder - dv_ = dv if dv.shape[0] == 0 else helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim) + dv_ = helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim) return dq_, dk_, dv_, dbias @@ -1846,14 +1834,6 @@ def impl( class _FusedAttnCPWithA2AHelper: """ Helper class for Ulysses-style context parallelism using all-to-all communication. - - This helper manages the all-to-all communication pattern that redistributes tensors - between sequence-sharded and heads-sharded layouts. This enables context parallelism - by allowing each rank to process different parts of the sequence and heads dimensions. - - Attributes: - mesh: JAX mesh for distributed computation - config: Fused attention configuration including QKV layout information """ mesh: jax.sharding.Mesh @@ -1870,38 +1850,20 @@ def check_supported(self): def get_qkv_heads_dims(self, seq_dim=1): """ Determines the heads dimension indices for Q, K, V tensors based on QKV layout. - - The heads dimension position depends on the QKV packing format: - - QKVPacked: All tensors packed together with dimension [qkv=3, heads, dim] - - KVPacked: Q is separate, K and V are packed with dimension [kv=2, heads, dim] - - Separate: All tensors are separate with dimension [heads, dim] - - Args: - seq_dim: The sequence dimension position (default 1 for BSHD format) - - Returns: - Tuple of (q_heads_dim, k_heads_dim, v_heads_dim) indicating the position - of the heads dimension for each tensor. - - Examples for BSHD layout (seq_dim=1): - QKVPacked: Q=[b, s, 3, h, d] -> returns (3, 3, 3) - KVPacked: Q=[b, s, h, d], K=[b, s, 2, h, d], V=[b, s, 2, h, d] -> returns (2, 3, 3) - Separate: Q=[b, s, h, d], K=[b, s, h, d], V=[b, s, h, d] -> returns (2, 2, 2) """ + if self.config.qkv_layout.is_qkvpacked(): - # QKV all packed together: [batch..., seq, 3, heads, dim] - # Heads dimension is at seq_dim + 2 for all tensors + # [batch..., seq, 3, heads, dim] heads_dim = seq_dim + 2 return heads_dim, heads_dim, heads_dim elif self.config.qkv_layout.is_kvpacked(): - # Q separate, K and V packed: Q=[batch..., seq, heads, dim] - # K/V=[batch..., seq, 2, heads, dim] - q_heads_dim = seq_dim + 1 # Q has no packing dimension - kv_heads_dim = seq_dim + 2 # K/V have packing dimension [2, heads, dim] + # Q=[batch..., seq, heads, dim] + # KV=[batch..., seq, 2, heads, dim] + q_heads_dim = seq_dim + 1 + kv_heads_dim = seq_dim + 2 return q_heads_dim, kv_heads_dim, kv_heads_dim else: # separate - # All separate: Q/K/V=[batch..., seq, heads, dim] - # Heads dimension is at seq_dim + 1 for all tensors + # Q/K/V=[batch..., seq, heads, dim] heads_dim = seq_dim + 1 return heads_dim, heads_dim, heads_dim @@ -1921,7 +1883,7 @@ def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): Tensor after all-to-all with redistributed dimensions """ if x.shape[0] == 0: - return x + return x # If tensor is empty, then no communication is performed. 
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) if before_attn: @@ -1933,26 +1895,22 @@ def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): # Determine which dimension to split and where to concat if before_attn: split_axis, concat_axis = heads_dim, seq_dim - # After reshape, need to adjust concat_axis if it comes after split_axis - needs_adjustment = concat_axis > split_axis else: split_axis = seq_dim - concat_axis = heads_dim + 1 if heads_dim > seq_dim else heads_dim - needs_adjustment = False # Already adjusted above + concat_axis = heads_dim # Reshape: insert cp_size at split_axis assert shape[split_axis] % cp_size == 0 x = x.reshape( *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1:] ) - - # Adjust concat_axis if needed (only for before_attn case) - adjusted_concat_axis = concat_axis + 1 if needs_adjustment else concat_axis + if concat_axis > split_axis: + concat_axis += 1 # Added one dimenstion before concat_axis, need to adjust it. # Perform all-to-all x = lax_paral_op( x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh, - split_axis=split_axis, concat_axis=adjusted_concat_axis, tiled=True + split_axis=split_axis, concat_axis=concat_axis, tiled=True ) # Merge the two dimensions created by all-to-all at split_axis From 98e3b3acbfbbcb223f5f654be7dbb8303b8ca4d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:42:31 +0000 Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_fused_attn.py | 8 +++- .../jax/cpp_extensions/attention.py | 39 ++++++++++++------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 3b6b5a4ce4..bb8e11debe 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -618,9 +618,13 @@ def test_context_parallel_alltoall_attn( load_balanced, ): if data_shape[2] % (mesh_shape[1] * mesh_shape[2] * kv_groups) != 0: - pytest.skip("Skipping test as num_heads is not divisible by cp_size * tp_size * kv_groups") + pytest.skip( + "Skipping test as num_heads is not divisible by cp_size * tp_size * kv_groups" + ) if load_balanced: - pytest.skip("Load balanced causal attention is not yet supported with all-to-all strategy") + pytest.skip( + "Load balanced causal attention is not yet supported with all-to-all strategy" + ) self.impl_test_context_parallel_attn( device_count, mesh_shape, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 168125e394..3c2bae65ab 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1678,20 +1678,24 @@ def impl( cp_size = get_mesh_axis_size(config.cp_axis, mesh) assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size" assert kv_heads % cp_size == 0, "kv_heads must be divisible by cp_size" - + # Load balanced causal attention is not yet supported for all-to-all strategy if config.context_parallel_load_balanced: raise NotImplementedError( "context_parallel_load_balanced is not supported with all-to-all strategy" ) - assert config.qkv_layout in [QKVLayout.BS3HD, QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] + assert config.qkv_layout in [ + QKVLayout.BS3HD, + QKVLayout.BSHD_BS2HD, + QKVLayout.BSHD_BSHD_BSHD, + 
] # Apply all-to-all to transform from seq-sharded to heads-sharded q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim) k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim) v_ = helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim) - + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( q_, k_, @@ -1708,7 +1712,7 @@ def impl( _kv_segment_pos, config=helper.get_step_config(), ) - + # Apply all-to-all to transform from heads-sharded to seq-sharded (scatter in seq dimension) # output is always [b, s, h/cp, d] -> heads_dim=2 output = helper.all_to_all(output, False, seq_dim=1, heads_dim=2) @@ -1748,10 +1752,10 @@ def partition(config, mesh, arg_infos, result_infos): dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - + def impl( q, k, @@ -1778,13 +1782,13 @@ def impl( cp_size = get_mesh_axis_size(config.cp_axis, mesh) assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size" assert k_heads % cp_size == 0, "k_heads must be divisible by cp_size" - + # Load balanced causal attention is not yet supported for all-to-all strategy if config.context_parallel_load_balanced: raise NotImplementedError( "context_parallel_load_balanced is not supported with all-to-all strategy" ) - + q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim) k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim) v_ = helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim) @@ -1815,7 +1819,7 @@ def impl( _kv_segment_pos, config=helper.get_step_config(), ) - + # Apply all-to-all to gradients to restore original sharding (scatter in seq dimension) dq_heads_dim, dk_heads_dim, dv_heads_dim = helper.get_qkv_heads_dims(seq_dim=1) dq_ = helper.all_to_all(dq, False, seq_dim=1, heads_dim=dq_heads_dim) @@ -1883,7 +1887,7 @@ def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): Tensor after all-to-all with redistributed dimensions """ if x.shape[0] == 0: - return x # If tensor is empty, then no communication is performed. + return x # If tensor is empty, then no communication is performed. cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) if before_attn: @@ -1902,20 +1906,25 @@ def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): # Reshape: insert cp_size at split_axis assert shape[split_axis] % cp_size == 0 x = x.reshape( - *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1:] + *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1 :] ) if concat_axis > split_axis: - concat_axis += 1 # Added one dimenstion before concat_axis, need to adjust it. + concat_axis += 1 # Added one dimenstion before concat_axis, need to adjust it. 
# Perform all-to-all x = lax_paral_op( - x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh, - split_axis=split_axis, concat_axis=concat_axis, tiled=True + x, + lax.all_to_all, + self.config.cp_axis, + mesh=self.mesh, + split_axis=split_axis, + concat_axis=concat_axis, + tiled=True, ) # Merge the two dimensions created by all-to-all at split_axis new_shape = list(x.shape) - new_shape[split_axis:split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]] + new_shape[split_axis : split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]] return x.reshape(new_shape) def get_step_config(self) -> _FusedAttnConfig: From 21e079f30d3e6eadb6be2958aebddb56f7650276 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 Oct 2025 17:44:17 +0200 Subject: [PATCH 08/12] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/jax/cpp_extensions/attention.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 3c2bae65ab..1f596c98cc 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1848,7 +1848,7 @@ def check_supported(self): header = "Context parallel fused A2A attention" if self.config.qkv_layout.is_thd(): raise ValueError(f"{header} does not support THD format") - elif self.config.qkv_layout.get_qkv_format() is QKVFormat.SBHD: + if self.config.qkv_layout.get_qkv_format() is QKVFormat.SBHD: raise ValueError(f"{header} does not support SBHD format") def get_qkv_heads_dims(self, seq_dim=1): @@ -1860,16 +1860,16 @@ def get_qkv_heads_dims(self, seq_dim=1): # [batch..., seq, 3, heads, dim] heads_dim = seq_dim + 2 return heads_dim, heads_dim, heads_dim - elif self.config.qkv_layout.is_kvpacked(): + if self.config.qkv_layout.is_kvpacked(): # Q=[batch..., seq, heads, dim] # KV=[batch..., seq, 2, heads, dim] q_heads_dim = seq_dim + 1 kv_heads_dim = seq_dim + 2 return q_heads_dim, kv_heads_dim, kv_heads_dim - else: # separate - # Q/K/V=[batch..., seq, heads, dim] - heads_dim = seq_dim + 1 - return heads_dim, heads_dim, heads_dim + # separate + # Q/K/V=[batch..., seq, heads, dim] + heads_dim = seq_dim + 1 + return heads_dim, heads_dim, heads_dim def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2): """ From 26a71eb8860acb21a5d584293f76f06b8a289653 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 Oct 2025 19:22:57 +0200 Subject: [PATCH 09/12] fix Signed-off-by: Pawel Gadzinski --- tests/jax/distributed_test_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 1113471b22..4693086b83 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -8,8 +8,7 @@ import pytest import jax -from jax.experimental.pjit import pjit -from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED +from jax.experimental.pjit import pjit, _UNSPECIFIED from transformer_engine.jax.sharding import MeshResource From 183093baecc085d0595335cd8c44c7c5eb24a410 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 16:35:12 +0200 Subject: [PATCH 10/12] fix Signed-off-by: Pawel Gadzinski --- tests/jax/test_distributed_fused_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index bb8e11debe..9abebd5eed 100644 --- 
a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -637,7 +637,7 @@ def test_context_parallel_alltoall_attn( qkv_layout, load_balanced=load_balanced, cp_strategy=CPStrategy.ALL_TO_ALL, - use_shardy=False, + use_shardy=True, ) From bbd2b765dd7ef90d1460728a3fc5bfc7d0f2d919 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 16:36:54 +0200 Subject: [PATCH 11/12] fix Signed-off-by: Pawel Gadzinski --- tests/jax/test_distributed_fused_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 9abebd5eed..2d5ca5082e 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -604,7 +604,7 @@ def test_context_parallel_ring_attn_shardy( "load_balanced", [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], ) - def test_context_parallel_alltoall_attn( + def test_context_parallel_alltoall_attn_shardy( self, device_count, mesh_shape, From 8e3bb3786010ed4ac90b7437949b834fb192f3af Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 16:38:12 +0200 Subject: [PATCH 12/12] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/jax/cpp_extensions/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 1f596c98cc..6faff06c86 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1050,7 +1050,7 @@ def partition(config, mesh, arg_infos, result_infos): dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = list(arg_i.sharding for arg_i in arg_infos) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings[-1] = arg_shardings[-3] arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings)
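
Appendix (illustrative only, not part of the patches above): a minimal standalone JAX sketch of the Ulysses-style all-to-all reshard that this series implements in _FusedAttnCPWithA2AHelper and the FusedAttnCPWithAllToAllFwd/Bwd primitives. It exchanges sequence-sharded BSHD tensors into head-sharded ones before attention and back afterwards, but uses plain unfused jnp attention in place of the fused kernel and omits masking, GQA, and load balancing. The mesh axis name "cp", the toy shapes, and the two-device mesh are assumptions for this example, not Transformer Engine API; it assumes at least `cp` JAX devices and num_heads divisible by the context-parallel size.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
try:
    from jax import shard_map                      # newer JAX
except ImportError:
    from jax.experimental.shard_map import shard_map

# Assumed 2-device context-parallel mesh; on CPU you can expose extra devices with
# XLA_FLAGS="--xla_force_host_platform_device_count=2".
cp = 2
mesh = Mesh(np.array(jax.devices()[:cp]), ("cp",))

def seq_to_heads(x):
    # per-device [b, s/cp, h, d] -> [b, s, h/cp, d]: gather sequence, scatter heads
    return jax.lax.all_to_all(x, "cp", split_axis=2, concat_axis=1, tiled=True)

def heads_to_seq(x):
    # per-device [b, s, h/cp, d] -> [b, s/cp, h, d]: gather heads, scatter sequence
    return jax.lax.all_to_all(x, "cp", split_axis=1, concat_axis=2, tiled=True)

def _per_shard_attn(q, k, v):
    q, k, v = seq_to_heads(q), seq_to_heads(k), seq_to_heads(v)
    # Unfused reference attention over the full sequence, local subset of heads.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / q.shape[-1] ** 0.5
    out = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)
    return heads_to_seq(out)                       # return output sequence-sharded, like QKV arrived

spec = P(None, "cp", None, None)                   # BSHD, sharded on the sequence dimension
cp_attention = jax.jit(
    shard_map(_per_shard_attn, mesh, in_specs=(spec, spec, spec), out_specs=spec)
)

b, s, h, d = 2, 8, 4, 16                           # needs s % cp == 0 and h % cp == 0
q = k = v = jnp.ones((b, s, h, d), jnp.float32)
print(cp_attention(q, k, v).shape)                 # (2, 8, 4, 16)

With the sequence dimension sharded over "cp", seq_to_heads turns a local [b, s/cp, h, d] block into [b, s, h/cp, d], so each rank attends over the full sequence for a subset of heads; that is why check_supported in the helper rejects layouts it cannot reshard and why the tests skip configurations where num_heads is not divisible by cp_size * tp_size * kv_groups.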