
Commit aa78e66

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 290bef1 commit aa78e66

2 files changed: +47 additions, −25 deletions

tests/jax/test_distributed_fused_attn.py

Lines changed: 6 additions & 2 deletions
@@ -618,9 +618,13 @@ def test_context_parallel_alltoall_attn(
         load_balanced,
     ):
         if data_shape[2] % (mesh_shape[1] * mesh_shape[2] * kv_groups) != 0:
-            pytest.skip("Skipping test as num_heads is not divisible by cp_size * tp_size * kv_groups")
+            pytest.skip(
+                "Skipping test as num_heads is not divisible by cp_size * tp_size * kv_groups"
+            )
         if load_balanced:
-            pytest.skip("Load balanced causal attention is not yet supported with all-to-all strategy")
+            pytest.skip(
+                "Load balanced causal attention is not yet supported with all-to-all strategy"
+            )
         self.impl_test_context_parallel_attn(
             device_count,
             mesh_shape,
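
The skip condition above guards the head count: with the all-to-all strategy, attention heads are scattered across the context-parallel group, and tensor parallelism plus KV grouping subdivide them further, so num_heads must split evenly. A minimal standalone sketch of that guard (the helper name and the concrete sizes are illustrative, not taken from the test suite):

# Minimal sketch of the divisibility guard; names and sizes are hypothetical.
def heads_divisible(num_heads: int, cp_size: int, tp_size: int, kv_groups: int) -> bool:
    """True when every (cp, tp, kv-group) shard receives a whole number of heads."""
    return num_heads % (cp_size * tp_size * kv_groups) == 0

assert heads_divisible(32, cp_size=2, tp_size=2, kv_groups=4)
assert not heads_divisible(12, cp_size=2, tp_size=2, kv_groups=4)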

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 41 additions & 23 deletions
@@ -1678,7 +1678,7 @@ def impl(
             cp_size = get_mesh_axis_size(config.cp_axis, mesh)
             assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
             assert kv_heads % cp_size == 0, "kv_heads must be divisible by cp_size"
-
+
             # Load balanced causal attention is not yet supported for all-to-all strategy
             if config.context_parallel_load_balanced:
                 raise NotImplementedError(
@@ -1689,8 +1689,12 @@ def impl(
             q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
             k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
             # For KVPACKED layout, v is empty placeholder
-            v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
-
+            v_ = (
+                v
+                if v.shape[0] == 0
+                else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+            )
+
             output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                 q_,
                 k_,
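
The reformatted conditional keeps the existing KVPACKED convention: when K and V arrive packed in a single tensor, the separate v argument is an empty placeholder and must not be routed through the all-to-all. A self-contained sketch of that pattern (function and variable names are hypothetical, not the module's API):

import jax.numpy as jnp

def maybe_all_to_all(t, transform):
    # Leave the empty KVPACKED placeholder untouched; transform real tensors only.
    return t if t.shape[0] == 0 else transform(t)

v_placeholder = jnp.zeros((0,))            # KVPACKED layout: v carries no data
v_separate = jnp.zeros((2, 128, 16, 64))   # separate layout: [b, s, h, d]

assert maybe_all_to_all(v_placeholder, lambda t: t + 1.0).shape == (0,)
assert maybe_all_to_all(v_separate, lambda t: t + 1.0).shape == (2, 128, 16, 64)
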
@@ -1707,7 +1711,7 @@ def impl(
                 _kv_segment_pos,
                 config=helper.get_step_config(),
             )
-
+
             # Apply all-to-all to transform from heads-sharded to seq-sharded (scatter in seq dimension)
             # output is always [b, s, h/cp, d] -> heads_dim=2
             output = helper.all_to_all(output, False, seq_dim=1, heads_dim=2)
@@ -1742,23 +1746,24 @@ def partition(config, mesh, arg_infos, result_infos):
         dk_sharding = result_infos[1].sharding
         dv_sharding = result_infos[2].sharding
         dbias_sharding = result_infos[3].sharding
-
+
         # For AllToAll context parallel, output and doutput need to be seq-sharded
         # to match the forward output sharding (before they get transformed to heads-sharded)
         arg_shardings = list([arg_i.sharding for arg_i in arg_infos])
         # arg_infos: [q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, ...]
         # output is at index 6, doutput is at index 7
         # They should have the same sharding as the forward output (seq-sharded on cp axis)
         output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, config.cp_axis, None, None))
-        softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, config.cp_axis, None))
+        softmax_aux_seq_sharding = NamedSharding(
+            mesh, PartitionSpec(None, None, config.cp_axis, None)
+        )
         arg_shardings[4] = softmax_aux_seq_sharding  # softmax_aux [b, h, s/cp, 1]
         arg_shardings[6] = output_seq_sharding  # output [b, s/cp, h, d]
         arg_shardings[7] = output_seq_sharding  # doutput [b, s/cp, h, d]
         arg_shardings = tuple(arg_shardings)
-
+
         out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
-
-
+
         def impl(
             q,
             k,
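
For reference, the seq-sharded specs constructed above look like this in isolation. The sketch assumes a trivial one-device mesh whose axis is literally named "cp" so it runs anywhere; in the primitive the axis name comes from config.cp_axis and spans the context-parallel group:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ("cp",))
# output / doutput are [b, s/cp, h, d]: only the sequence axis (position 1) is split.
output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, "cp", None, None))
# softmax_aux is [b, h, s/cp, 1]: its sequence axis sits at position 2.
softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, "cp", None))
print(output_seq_sharding)
print(softmax_aux_seq_sharding)
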
@@ -1785,18 +1790,22 @@ def impl(
             cp_size = get_mesh_axis_size(config.cp_axis, mesh)
             assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
             assert k_heads % cp_size == 0, "k_heads must be divisible by cp_size"
-
+
             # Load balanced causal attention is not yet supported for all-to-all strategy
             if config.context_parallel_load_balanced:
                 raise NotImplementedError(
                     "context_parallel_load_balanced is not supported with all-to-all strategy"
                 )
-
+
             # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension)
             q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
             k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
             # For KVPACKED layout, v is empty placeholder
-            v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+            v_ = (
+                v
+                if v.shape[0] == 0
+                else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+            )
             # doutput is always separate [b, s, h, d], so heads_dim=2
             doutput_ = helper.all_to_all(doutput, True, seq_dim=1, heads_dim=2)
             # output has the same shape as doutput, needs the same transformation
@@ -1824,15 +1833,19 @@ def impl(
                 _kv_segment_pos,
                 config=helper.get_step_config(),
             )
-
+
             # Apply all-to-all to gradients to restore original sharding (scatter in seq dimension)
             # Gradients have the same shape as inputs, so use same heads_dim
             dq_heads_dim, dk_heads_dim, dv_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
-
+
             dq_ = helper.all_to_all(dq, False, seq_dim=1, heads_dim=dq_heads_dim)
             dk_ = helper.all_to_all(dk, False, seq_dim=1, heads_dim=dk_heads_dim)
             # For KVPACKED layout, dv is empty placeholder
-            dv_ = dv if dv.shape[0] == 0 else helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim)
+            dv_ = (
+                dv
+                if dv.shape[0] == 0
+                else helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim)
+            )
 
             return dq_, dk_, dv_, dbias
 
@@ -1870,19 +1883,19 @@ def check_supported(self):
     def get_qkv_heads_dims(self, seq_dim=1):
         """
         Determines the heads dimension indices for Q, K, V tensors based on QKV layout.
-
+
         The heads dimension position depends on the QKV packing format:
         - QKVPacked: All tensors packed together with dimension [qkv=3, heads, dim]
         - KVPacked: Q is separate, K and V are packed with dimension [kv=2, heads, dim]
         - Separate: All tensors are separate with dimension [heads, dim]
-
+
         Args:
             seq_dim: The sequence dimension position (default 1 for BSHD format)
-
+
         Returns:
             Tuple of (q_heads_dim, k_heads_dim, v_heads_dim) indicating the position
             of the heads dimension for each tensor.
-
+
         Examples for BSHD layout (seq_dim=1):
             QKVPacked: Q=[b, s, 3, h, d] -> returns (3, 3, 3)
             KVPacked: Q=[b, s, h, d], K=[b, s, 2, h, d], V=[b, s, 2, h, d] -> returns (2, 3, 3)
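
The documented mapping can be re-derived in a few lines. The sketch below is an illustrative standalone function, not the method itself: the layout strings stand in for the real qkv_layout handling, and the fully-separate result, (2, 2, 2) for BSHD, is inferred from the layout description rather than shown in this hunk.

def qkv_heads_dims(layout: str, seq_dim: int = 1):
    # Heads-dimension index for Q, K, V in BSHD-style tensors, per the docstring above.
    if layout == "qkv_packed":   # [b, s, 3, h, d] for all three tensors
        return (seq_dim + 2,) * 3
    if layout == "kv_packed":    # Q: [b, s, h, d]; K/V packed: [b, s, 2, h, d]
        return (seq_dim + 1, seq_dim + 2, seq_dim + 2)
    return (seq_dim + 1,) * 3    # fully separate: [b, s, h, d] each

assert qkv_heads_dims("qkv_packed") == (3, 3, 3)
assert qkv_heads_dims("kv_packed") == (2, 3, 3)
assert qkv_heads_dims("separate") == (2, 2, 2)
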
@@ -1943,21 +1956,26 @@ def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2):
         # Reshape: insert cp_size at split_axis
         assert shape[split_axis] % cp_size == 0
         x = x.reshape(
-            *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1:]
+            *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1 :]
         )
 
         # Adjust concat_axis if needed (only for before_attn case)
         adjusted_concat_axis = concat_axis + 1 if needs_adjustment else concat_axis
 
         # Perform all-to-all
         x = lax_paral_op(
-            x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh,
-            split_axis=split_axis, concat_axis=adjusted_concat_axis, tiled=True
+            x,
+            lax.all_to_all,
+            self.config.cp_axis,
+            mesh=self.mesh,
+            split_axis=split_axis,
+            concat_axis=adjusted_concat_axis,
+            tiled=True,
         )
 
         # Merge the two dimensions created by all-to-all at split_axis
         new_shape = list(x.shape)
-        new_shape[split_axis:split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]]
+        new_shape[split_axis : split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]]
         return x.reshape(new_shape)
 
     def get_step_config(self) -> _FusedAttnConfig:
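
The shape bookkeeping around the collective is easy to follow in isolation: split the scattered axis into cp_size chunks, let the tiled lax.all_to_all exchange those chunks across the context-parallel group, then merge the two dimensions back into one. The sketch below reproduces only the local reshapes with plain jax.numpy; the collective is left as a comment, so without it the merge simply undoes the split, and all sizes are made up for illustration:

import jax.numpy as jnp

b, s, h, d, cp_size = 4, 1024, 16, 64, 2
x = jnp.zeros((b, s // cp_size, h, d))   # local seq-sharded shard: [b, s/cp, h, d]
split_axis = 2                           # before attention, the heads axis gets split

# Insert cp_size at split_axis: [b, s/cp, cp, h/cp, d]
shape = x.shape
x = x.reshape(*shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1 :])

# Here the primitive calls lax.all_to_all(..., split_axis=..., concat_axis=..., tiled=True)
# through lax_paral_op to exchange the cp chunks across ranks (seq-sharded -> heads-sharded).

# Merge the two dimensions created at split_axis back into one.
new_shape = list(x.shape)
new_shape[split_axis : split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]]
x = x.reshape(new_shape)
assert x.shape == (b, s // cp_size, h, d)  # without the collective, shapes are unchanged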
