diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index ef8e370b6e..2d5ca5082e 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -580,6 +580,66 @@ def test_context_parallel_ring_attn_shardy(
             use_scan_ring=True,
         )
 
+    @pytest_parametrize_wrapper(
+        "device_count,mesh_shape,mesh_axes,mesh_resource",
+        generate_context_parallel_configs_for_attn(),
+    )
+    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
+    @pytest.mark.parametrize("kv_groups", [1, 8])
+    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
+    @pytest.mark.parametrize(
+        "qkv_layout, attn_mask_type",
+        [
+            pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
+            pytest.param(
+                QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"
+            ),
+            pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
+            pytest.param(
+                QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"
+            ),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "load_balanced",
+        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
+    )
+    def test_context_parallel_alltoall_attn_shardy(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        kv_groups,
+        attn_mask_type,
+        dtype,
+        qkv_layout,
+        load_balanced,
+    ):
+        if data_shape[2] % (mesh_shape[1] * mesh_shape[2] * kv_groups) != 0:
+            pytest.skip(
+                "Skipping test as num_heads is not divisible by cp_size * tp_size * kv_groups"
+            )
+        if load_balanced:
+            pytest.skip(
+                "Load balanced causal attention is not yet supported with all-to-all strategy"
+            )
+        self.impl_test_context_parallel_attn(
+            device_count,
+            mesh_shape,
+            mesh_axes,
+            mesh_resource,
+            data_shape,
+            kv_groups,
+            attn_mask_type,
+            dtype,
+            qkv_layout,
+            load_balanced=load_balanced,
+            cp_strategy=CPStrategy.ALL_TO_ALL,
+            use_shardy=True,
+        )
+
 
 REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
     "L0": [[]],
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 1ce44a2b93..fe8c5abf25 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -175,11 +175,13 @@ class CPStrategy(Enum):
     DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
     ALL_GATHER: All-gather/reduce scatter implementation.
     RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
+    ALL_TO_ALL: All-to-all implementation.
     """
 
     DEFAULT = 0
     ALL_GATHER = 1
     RING = 2
+    ALL_TO_ALL = 3
 
 
 class ReorderStrategy(Enum):
@@ -1149,6 +1151,8 @@ def fused_attn(
             time and memory consumption is proportional to `max_segments_per_seq`.
         window_size (Optional[Tuple[int, int]]):
             Sliding window size.
+        context_parallel_strategy (CPStrategy):
+            The strategy of context parallel. Options: DEFAULT, ALL_GATHER, RING, ALL_TO_ALL.
         context_parallel_causal_load_balanced (bool):
             Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
         context_parallel_axis (str): The name of the context parallel axis.
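
A usage sketch (illustrative, not part of the patch): it shows how the new strategy could be selected, based only on what this change adds — the `CPStrategy.ALL_TO_ALL` member, the `context_parallel_strategy` argument documented above, and the string spellings accepted by the flax `DotProductAttention` mapping further down. All other attention arguments, sharding, and mesh setup are omitted here.

```python
from transformer_engine.jax.attention import CPStrategy

# Functional API: pass the enum, e.g. fused_attn(..., context_parallel_strategy=CPStrategy.ALL_TO_ALL)
strategy = CPStrategy.ALL_TO_ALL
assert strategy.value == 3

# Flax module API: DotProductAttention(..., context_parallel_strategy="ALL_TO_ALL");
# "ALLTOALL" and "A2A" are accepted as alternative spellings (normalized with .upper()).

# As the new test above enforces, the head count must be divisible by
# cp_size * tp_size * kv_groups, and load-balanced causal reordering is not yet supported.
num_heads, cp_size, tp_size, kv_groups = 16, 2, 2, 1
assert num_heads % (cp_size * tp_size * kv_groups) == 0
```
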
diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index db2537c38f..6faff06c86 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -1629,6 +1629,321 @@ def _cross_attn_bwd(
 register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)
 
 
+class FusedAttnCPWithAllToAllFwdPrimitive(FusedAttnFwdPrimitive):
+    """
+    Fused Attention Forward with Context Parallelism Primitive.
+    Like Ulysses, applying A2A to QKVO.
+    Refer to the paper `DeepSpeed Ulysses `_.
+    """
+
+    @staticmethod
+    def partition(config, mesh, arg_infos, result_infos):
+        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
+        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
+        if not is_context_parallel:
+            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
+
+        helper = _FusedAttnCPWithA2AHelper(mesh, config)
+        helper.check_supported()
+
+        out_sharding = result_infos[0].sharding
+        softmax_aux_sharding = result_infos[1].sharding
+        rng_state_sharding = seed_sharding = NamedSharding(
+            mesh, PartitionSpec(get_all_mesh_axes(), None)
+        )
+        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
+        arg_shardings[4] = seed_sharding
+        arg_shardings = tuple(arg_shardings)
+        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
+
+        def impl(
+            q,
+            k,
+            v,
+            bias,
+            seed,
+            q_seqlen,
+            kv_seqlen,
+            q_seq_offsets,
+            k_seq_offsets,
+            _q_segment_ids,
+            _kv_segment_ids,
+            _q_segment_pos,
+            _kv_segment_pos,
+        ):
+            # Get heads dimensions based on QKV layout
+            q_heads_dim, k_heads_dim, v_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+            q_heads = q.shape[q_heads_dim]
+            kv_heads = k.shape[k_heads_dim]
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
+            assert kv_heads % cp_size == 0, "kv_heads must be divisible by cp_size"
+
+            # Load balanced causal attention is not yet supported for all-to-all strategy
+            if config.context_parallel_load_balanced:
+                raise NotImplementedError(
+                    "context_parallel_load_balanced is not supported with all-to-all strategy"
+                )
+
+            assert config.qkv_layout in [
+                QKVLayout.BS3HD,
+                QKVLayout.BSHD_BS2HD,
+                QKVLayout.BSHD_BSHD_BSHD,
+            ]
+
+            # Apply all-to-all to transform from seq-sharded to heads-sharded
+            q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
+            k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
+            v_ = helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+
+            output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
+                q_,
+                k_,
+                v_,
+                bias,
+                seed,
+                q_seqlen,
+                kv_seqlen,
+                q_seq_offsets,
+                k_seq_offsets,
+                _q_segment_ids,
+                _kv_segment_ids,
+                _q_segment_pos,
+                _kv_segment_pos,
+                config=helper.get_step_config(),
+            )
+
+            # Apply all-to-all to transform from heads-sharded to seq-sharded (scatter in seq dimension)
+            # output is always [b, s, h/cp, d] -> heads_dim=2
+            output = helper.all_to_all(output, False, seq_dim=1, heads_dim=2)
+            # softmax_aux has shape [b, h/cp, s, 1] with heads at dim 1, seq at dim 2
+            softmax_aux = helper.all_to_all(softmax_aux, False, seq_dim=2, heads_dim=1)
+            return output, softmax_aux, rng_state
+
+        return mesh, impl, out_shardings, arg_shardings
+
+
+register_primitive(FusedAttnCPWithAllToAllFwdPrimitive)
+
+
+class FusedAttnCPWithAllToAllBwdPrimitive(FusedAttnBwdPrimitive):
+    """
+    Fused Attention Backward with Context Parallelism Primitive.
+    Like Ulysses, applying A2A to QKVO and its derivatives.
+    Refer to the paper `DeepSpeed Ulysses `_.
+    """
+
+    @staticmethod
+    def partition(config, mesh, arg_infos, result_infos):
+        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
+        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
+        if not is_context_parallel:
+            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
+
+        helper = _FusedAttnCPWithA2AHelper(mesh, config)
+        helper.check_supported()
+
+        del result_infos
+        q_spec = get_padded_spec(arg_infos[0])
+        k_spec = get_padded_spec(arg_infos[1])
+        v_spec = get_padded_spec(arg_infos[2])
+        bias_spec = get_padded_spec(arg_infos[3])
+        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
+        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
+        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
+        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
+
+        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
+
+        def impl(
+            q,
+            k,
+            v,
+            bias,
+            softmax_aux,
+            rng_state,
+            output,
+            doutput,
+            q_seqlen,
+            kv_seqlen,
+            q_seq_offsets,
+            k_seq_offsets,
+            _q_segment_ids,
+            _kv_segment_ids,
+            _q_segment_pos,
+            _kv_segment_pos,
+        ):
+            # Apply all-to-all to inputs before backward pass
+            # Get heads dimensions based on QKV layout (same as forward)
+            q_heads_dim, k_heads_dim, v_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+            q_heads = q.shape[q_heads_dim]
+            k_heads = k.shape[k_heads_dim]
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
+            assert k_heads % cp_size == 0, "k_heads must be divisible by cp_size"
+
+            # Load balanced causal attention is not yet supported for all-to-all strategy
+            if config.context_parallel_load_balanced:
+                raise NotImplementedError(
+                    "context_parallel_load_balanced is not supported with all-to-all strategy"
+                )
+
+            q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
+            k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
+            v_ = helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+
+            # doutput is always [b, s, h, d], so heads_dim=2
+            doutput_ = helper.all_to_all(doutput, True, seq_dim=1, heads_dim=2)
+            output_ = helper.all_to_all(output, True, seq_dim=1, heads_dim=2)
+            # softmax_aux has shape [b, h, s/cp, 1] with heads at dim 1, seq at dim 2
+            softmax_aux_ = helper.all_to_all(softmax_aux, True, seq_dim=2, heads_dim=1)
+
+            # Perform backward pass
+            dq, dk, dv, dbias = FusedAttnBwdPrimitive.impl(
+                q_,
+                k_,
+                v_,
+                bias,
+                softmax_aux_,
+                rng_state,
+                output_,
+                doutput_,
+                q_seqlen,
+                kv_seqlen,
+                q_seq_offsets,
+                k_seq_offsets,
+                _q_segment_ids,
+                _kv_segment_ids,
+                _q_segment_pos,
+                _kv_segment_pos,
+                config=helper.get_step_config(),
+            )
+
+            # Apply all-to-all to gradients to restore original sharding (scatter in seq dimension)
+            dq_heads_dim, dk_heads_dim, dv_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+            dq_ = helper.all_to_all(dq, False, seq_dim=1, heads_dim=dq_heads_dim)
+            dk_ = helper.all_to_all(dk, False, seq_dim=1, heads_dim=dk_heads_dim)
+            dv_ = helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim)
+
+            return dq_, dk_, dv_, dbias
+
+        return mesh, impl, out_shardings, arg_shardings
+
+
+register_primitive(FusedAttnCPWithAllToAllBwdPrimitive)
+
+
+@dataclass(frozen=True)
+class _FusedAttnCPWithA2AHelper:
+    """
+    Helper class for Ulysses-style context parallelism using all-to-all communication.
+    """
+
+    mesh: jax.sharding.Mesh
+    config: _FusedAttnConfig
+
+    def check_supported(self):
+        """Checks if the context parallel implementation is supported by the given arguments."""
+        header = "Context parallel fused A2A attention"
+        if self.config.qkv_layout.is_thd():
+            raise ValueError(f"{header} does not support THD format")
+        if self.config.qkv_layout.get_qkv_format() is QKVFormat.SBHD:
+            raise ValueError(f"{header} does not support SBHD format")
+
+    def get_qkv_heads_dims(self, seq_dim=1):
+        """
+        Determines the heads dimension indices for Q, K, V tensors based on QKV layout.
+        """
+
+        if self.config.qkv_layout.is_qkvpacked():
+            # [batch..., seq, 3, heads, dim]
+            heads_dim = seq_dim + 2
+            return heads_dim, heads_dim, heads_dim
+        if self.config.qkv_layout.is_kvpacked():
+            # Q=[batch..., seq, heads, dim]
+            # KV=[batch..., seq, 2, heads, dim]
+            q_heads_dim = seq_dim + 1
+            kv_heads_dim = seq_dim + 2
+            return q_heads_dim, kv_heads_dim, kv_heads_dim
+        # separate
+        # Q/K/V=[batch..., seq, heads, dim]
+        heads_dim = seq_dim + 1
+        return heads_dim, heads_dim, heads_dim
+
+    def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2):
+        """
+        Performs all-to-all communication for context parallelism (Ulysses-style).
+
+        Redistributes data between sequence and heads dimensions across CP ranks.
+
+        Args:
+            x: Input tensor
+            before_attn: True = seq-sharded -> heads-sharded, False = heads-sharded -> seq-sharded
+            seq_dim: Sequence dimension position (typically 1 for BSHD)
+            heads_dim: Heads dimension position (depends on QKV layout)
+
+        Returns:
+            Tensor after all-to-all with redistributed dimensions
+        """
+        if x.shape[0] == 0:
+            return x  # If the tensor is empty, no communication is performed.
+
+        cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
+        if before_attn:
+            num_heads = x.shape[heads_dim]
+            assert num_heads % cp_size == 0, "num_heads must be divisible by cp_size"
+
+        shape = x.shape
+
+        # Determine which dimension to split and where to concat
+        if before_attn:
+            split_axis, concat_axis = heads_dim, seq_dim
+        else:
+            split_axis, concat_axis = seq_dim, heads_dim
+
+        # Reshape: insert cp_size at split_axis
+        assert shape[split_axis] % cp_size == 0
+        x = x.reshape(
+            *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1 :]
+        )
+        if concat_axis > split_axis:
+            concat_axis += 1  # Added one dimension before concat_axis, so adjust it.
+
+        # Perform all-to-all
+        x = lax_paral_op(
+            x,
+            lax.all_to_all,
+            self.config.cp_axis,
+            mesh=self.mesh,
+            split_axis=split_axis,
+            concat_axis=concat_axis,
+            tiled=True,
+        )
+
+        # Merge the two dimensions created by all-to-all at split_axis
+        new_shape = list(x.shape)
+        new_shape[split_axis : split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]]
+        return x.reshape(new_shape)
+
+    def get_step_config(self) -> _FusedAttnConfig:
+        """Returns a _FusedAttnConfig for a single CP step call to fused attention."""
+        return _FusedAttnConfig(
+            attn_bias_type=self.config.attn_bias_type,
+            attn_mask_type=self.config.attn_mask_type,
+            qkv_layout=self.config.qkv_layout,
+            scaling_factor=self.config.scaling_factor,
+            dropout_probability=self.config.dropout_probability,
+            is_training=self.config.is_training,
+            max_segments_per_seq=self.config.max_segments_per_seq,
+            window_size=self.config.window_size,
+            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
+            cp_axis="",  # No CP axis for the inner attention call
+            cp_striped_window_size=None,
+        )
+
+
 @dataclass(frozen=True)
 class _FusedAttnCPWithP2PHelper:
     """Helper class to assist with running the P2P ring strategy for CP attention."""
@@ -2641,6 +2956,8 @@ def fused_attn_fwd(
                 primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
             else:
                 primitive = FusedRingAttnFwdPrimitive.outer_primitive
+        case CPStrategy.ALL_TO_ALL:
+            primitive = FusedAttnCPWithAllToAllFwdPrimitive.outer_primitive
 
     seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
     output, softmax_aux, rng_state = primitive.bind(
@@ -2767,6 +3084,8 @@ def fused_attn_bwd(
                 primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
             else:
                 primitive = FusedRingAttnBwdPrimitive.outer_primitive
+        case CPStrategy.ALL_TO_ALL:
+            primitive = FusedAttnCPWithAllToAllBwdPrimitive.outer_primitive
 
     seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
     *qkv_grads, bias_grad = primitive.bind(
diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index 1eafed4131..7ec41e200a 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -511,7 +511,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     context_parallel_causal_load_balanced (bool):
         Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
     context_parallel_axis (str): The name of the context parallel axis.
-    context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
+    context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING, 3: ALL_TO_ALL.
     context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
 
     Optimization parameters
@@ -662,6 +662,9 @@ def __call__(
             "ALL_GATHER": CPStrategy.ALL_GATHER,
             "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
             "RING": CPStrategy.RING,
+            "ALL_TO_ALL": CPStrategy.ALL_TO_ALL,
+            "ALLTOALL": CPStrategy.ALL_TO_ALL,  # Alternative spelling
+            "A2A": CPStrategy.ALL_TO_ALL,  # Short form
         }
         strategy_key = self.context_parallel_strategy.upper()
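
To illustrate the data movement that `_FusedAttnCPWithA2AHelper.all_to_all` performs, here is a minimal standalone sketch of the Ulysses-style reshard for a BSHD tensor. It uses `jax.experimental.shard_map` with `lax.all_to_all(..., tiled=True)` directly rather than the `lax_paral_op` custom-partitioning path above; the mesh layout and axis name are illustrative assumptions, and only the seq-sharded to heads-sharded direction (`before_attn=True`) is shown.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

cp = jax.device_count()                       # assume all local devices form one "cp" axis
mesh = Mesh(np.array(jax.devices()), ("cp",))
b, s, h, d = 2, 8 * cp, 4 * cp, 16            # seq and heads both divisible by cp
x = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)

def seq_to_heads(local):
    # local: [b, s/cp, h, d] on each rank. tiled=True splits the heads axis into
    # cp chunks (one per rank) and concatenates the received sequence chunks,
    # leaving each rank with [b, s, h/cp, d] -- the layout the inner fused
    # attention call sees. Reversing split_axis and concat_axis undoes the
    # exchange after attention (the before_attn=False case in the helper).
    return lax.all_to_all(local, "cp", split_axis=2, concat_axis=1, tiled=True)

y = shard_map(
    seq_to_heads,
    mesh,
    in_specs=P(None, "cp", None, None),       # sequence-sharded input
    out_specs=P(None, None, "cp", None),      # heads-sharded output
)(x)
assert y.shape == (b, s, h, d)                # global shape unchanged; only the sharding moved
```
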