Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,66 @@ def test_context_parallel_ring_attn_shardy(
use_scan_ring=True,
)

@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
[
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(
QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"
),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
pytest.param(
QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"
),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
def test_context_parallel_alltoall_attn_shardy(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
if data_shape[2] % (mesh_shape[1] * mesh_shape[2] * kv_groups) != 0:
pytest.skip(
"Skipping test as num_heads is not divisible by cp_size * tp_size * kv_groups"
)
if load_balanced:
pytest.skip(
"Load balanced causal attention is not yet supported with all-to-all strategy"
)
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced=load_balanced,
cp_strategy=CPStrategy.ALL_TO_ALL,
use_shardy=True,
)


REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
"L0": [[]],
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,13 @@ class CPStrategy(Enum):
DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
ALL_GATHER: All-gather/reduce scatter implementation.
RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
ALL_TO_ALL: All-to-all implementation.
"""

DEFAULT = 0
ALL_GATHER = 1
RING = 2
ALL_TO_ALL = 3


class ReorderStrategy(Enum):
Expand Down Expand Up @@ -1149,6 +1151,8 @@ def fused_attn(
time and memory consumption is proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]):
Sliding window size.
context_parallel_strategy (CPStrategy):
The strategy of context parallel. Options: DEFAULT, ALL_GATHER, RING, ALL_TO_ALL.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Expand Down
Loading
Loading