
Commit 290bef1

fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
1 parent d439db9 commit 290bef1

3 files changed (+167 additions, -75 deletions)

tests/jax/distributed_test_base.py

Lines changed: 6 additions & 2 deletions
@@ -8,7 +8,8 @@
 import pytest
 
 import jax
-from jax.experimental.pjit import pjit, _UNSPECIFIED
+from jax.experimental.pjit import pjit
+from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED
 
 from transformer_engine.jax.sharding import MeshResource
 
@@ -38,7 +39,7 @@ def generate_configs():
     return configs
 
 
-def generate_context_parallel_configs_for_attn():
+def generate_context_parallel_configs_for_attn(heads_divisible_by_cp_times_tp=False):
     """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only"""
     configsL1 = []
     configsL2 = []
@@ -52,6 +53,9 @@ def generate_context_parallel_configs_for_attn():
        if is_devices_enough(ndev):
            # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations)
            if cp != 1:
+               if heads_divisible_by_cp_times_tp:
+                   if num_heads % (cp * tp) != 0:
+                       continue
                configsL1.append(
                    pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
                )
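As a quick illustration of the new filter (a standalone sketch; the num_heads value and the (dp, cp, tp) tuples below are made-up, not taken from the test suite):

# Hypothetical values: with num_heads = 12, a (dp, cp, tp) = (1, 2, 4) config is dropped
# because 12 % (2 * 4) != 0, while (1, 2, 2) survives because 12 % (2 * 2) == 0.
num_heads = 12
for dp, cp, tp in [(1, 2, 2), (1, 2, 4), (2, 2, 1)]:
    if num_heads % (cp * tp) != 0:
        continue  # config skipped: heads cannot be split evenly across cp * tp
    print(f"keep dp{dp}_cp{cp}_tp{tp}")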

tests/jax/test_distributed_fused_attn.py

Lines changed: 4 additions & 0 deletions
@@ -617,6 +617,10 @@ def test_context_parallel_alltoall_attn(
         qkv_layout,
         load_balanced,
     ):
+        if data_shape[2] % (mesh_shape[1] * mesh_shape[2] * kv_groups) != 0:
+            pytest.skip("Skipping test as num_heads is not divisible by cp_size * tp_size * kv_groups")
+        if load_balanced:
+            pytest.skip("Load balanced causal attention is not yet supported with all-to-all strategy")
         self.impl_test_context_parallel_attn(
             device_count,
             mesh_shape,
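For reference, the skip condition spelled out with made-up sizes (data_shape is assumed to be (batch, seqlen, heads, head_dim) and mesh_shape to be (dp, cp, tp), matching how the test indexes them):

# Hypothetical numbers: num_heads = 12, cp = 2, tp = 2, kv_groups = 2
data_shape = (2, 1024, 12, 64)
mesh_shape = (1, 2, 2)
kv_groups = 2
needs_skip = data_shape[2] % (mesh_shape[1] * mesh_shape[2] * kv_groups) != 0
print(needs_skip)  # True: 12 is not divisible by 2 * 2 * 2, so this parameterization is skipped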

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 157 additions & 73 deletions
@@ -1644,9 +1644,7 @@ def partition(config, mesh, arg_infos, result_infos):
             return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
 
         helper = _FusedAttnCPWithA2AHelper(mesh, config)
-        q_aval = arg_infos[0].aval if hasattr(arg_infos[0], "aval") else arg_infos[0]
-        num_heads = q_aval.shape[2]
-        helper.check_supported(num_heads)
+        helper.check_supported()
 
         out_sharding = result_infos[0].sharding
         softmax_aux_sharding = result_infos[1].sharding
@@ -1673,11 +1671,26 @@ def impl(
             _q_segment_pos,
             _kv_segment_pos,
         ):
-            q_, k_, v_ = (
-                helper.all_to_all(q, True),
-                helper.all_to_all(k, True),
-                helper.all_to_all(v, True),
-            )
+            # Get heads dimensions based on QKV layout
+            q_heads_dim, k_heads_dim, v_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+            q_heads = q.shape[q_heads_dim]
+            kv_heads = k.shape[k_heads_dim]
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
+            assert kv_heads % cp_size == 0, "kv_heads must be divisible by cp_size"
+
+            # Load balanced causal attention is not yet supported for all-to-all strategy
+            if config.context_parallel_load_balanced:
+                raise NotImplementedError(
+                    "context_parallel_load_balanced is not supported with all-to-all strategy"
+                )
+
+            # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension)
+            q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
+            k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
+            # For KVPACKED layout, v is empty placeholder
+            v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+
             output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                 q_,
                 k_,
@@ -1694,7 +1707,10 @@ def impl(
                 _kv_segment_pos,
                 config=helper.get_step_config(),
             )
-            output = helper.all_to_all(output, False)
+
+            # Apply all-to-all to transform from heads-sharded to seq-sharded (scatter in seq dimension)
+            # output is always [b, s, h/cp, d] -> heads_dim=2
+            output = helper.all_to_all(output, False, seq_dim=1, heads_dim=2)
             # softmax_aux has shape [b, h/cp, s, 1] with heads at dim 1, seq at dim 2
             softmax_aux = helper.all_to_all(softmax_aux, False, seq_dim=2, heads_dim=1)
             return output, softmax_aux, rng_state
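For intuition, a minimal single-process sketch of the per-rank shape flow this forward path implements (the sizes and cp value are made-up; no real all-to-all is executed, only the shape bookkeeping):

# Ulysses-style forward flow with cp = 2 for separate BSHD tensors (illustrative only)
b, s, h, d, cp = 2, 1024, 16, 64, 2
q_local_in = (b, s // cp, h, d)    # seq-sharded input held by each CP rank
q_local_attn = (b, s, h // cp, d)  # after all_to_all(..., before_attn=True): full seq, fewer heads
out_local = (b, s, h // cp, d)     # fused attention output, still heads-sharded
out_final = (b, s // cp, h, d)     # after all_to_all(..., before_attn=False): back to seq-sharded
print(q_local_in, q_local_attn, out_local, out_final)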
@@ -1720,17 +1736,29 @@ def partition(config, mesh, arg_infos, result_infos):
             return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
 
         helper = _FusedAttnCPWithA2AHelper(mesh, config)
-        q_aval = arg_infos[0].aval if hasattr(arg_infos[0], "aval") else arg_infos[0]
-        num_heads = q_aval.shape[2]
-        helper.check_supported(num_heads)
+        helper.check_supported()
 
         dq_sharding = result_infos[0].sharding
         dk_sharding = result_infos[1].sharding
         dv_sharding = result_infos[2].sharding
         dbias_sharding = result_infos[3].sharding
-        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos])
+
+        # For AllToAll context parallel, output and doutput need to be seq-sharded
+        # to match the forward output sharding (before they get transformed to heads-sharded)
+        arg_shardings = list([arg_i.sharding for arg_i in arg_infos])
+        # arg_infos: [q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, ...]
+        # output is at index 6, doutput is at index 7
+        # They should have the same sharding as the forward output (seq-sharded on cp axis)
+        output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, config.cp_axis, None, None))
+        softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, config.cp_axis, None))
+        arg_shardings[4] = softmax_aux_seq_sharding  # softmax_aux [b, h, s/cp, 1]
+        arg_shardings[6] = output_seq_sharding  # output [b, s/cp, h, d]
+        arg_shardings[7] = output_seq_sharding  # doutput [b, s/cp, h, d]
+        arg_shardings = tuple(arg_shardings)
+
         out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
-
+
+
         def impl(
             q,
             k,
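For reference, a small sketch of what the new seq-sharded shardings express. The mesh here is a toy (dp, cp, tp) = (1, 2, 1) layout with hard-coded axis names; in the actual code the cp axis name comes from config.cp_axis, so treat the names and device arrangement below as assumptions:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Needs 2 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=2
devices = np.array(jax.devices()[:2]).reshape(1, 2, 1)
mesh = Mesh(devices, axis_names=("dp", "cp", "tp"))

# output/doutput [b, s/cp, h, d]: only the sequence dim (index 1) is split over "cp"
output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, "cp", None, None))
# softmax_aux [b, h, s/cp, 1]: its sequence dim sits at index 2
softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, "cp", None))
print(output_seq_sharding, softmax_aux_seq_sharding)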
@@ -1750,12 +1778,29 @@ def impl(
             _kv_segment_pos,
         ):
             # Apply all-to-all to inputs before backward pass
-            q_, k_, v_ = (
-                helper.all_to_all(q, True),
-                helper.all_to_all(k, True),
-                helper.all_to_all(v, True),
-            )
-            doutput_ = helper.all_to_all(doutput, True)
+            # Get heads dimensions based on QKV layout (same as forward)
+            q_heads_dim, k_heads_dim, v_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+            q_heads = q.shape[q_heads_dim]
+            k_heads = k.shape[k_heads_dim]
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
+            assert k_heads % cp_size == 0, "k_heads must be divisible by cp_size"
+
+            # Load balanced causal attention is not yet supported for all-to-all strategy
+            if config.context_parallel_load_balanced:
+                raise NotImplementedError(
+                    "context_parallel_load_balanced is not supported with all-to-all strategy"
+                )
+
+            # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension)
+            q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
+            k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
+            # For KVPACKED layout, v is empty placeholder
+            v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+            # doutput is always separate [b, s, h, d], so heads_dim=2
+            doutput_ = helper.all_to_all(doutput, True, seq_dim=1, heads_dim=2)
+            # output has the same shape as doutput, needs the same transformation
+            output_ = helper.all_to_all(output, True, seq_dim=1, heads_dim=2)
             # softmax_aux has shape [b, h, s/cp, 1] with heads at dim 1, seq at dim 2
             softmax_aux_ = helper.all_to_all(softmax_aux, True, seq_dim=2, heads_dim=1)
 
@@ -1767,7 +1812,7 @@ def impl(
                 bias,
                 softmax_aux_,
                 rng_state,
-                output,
+                output_,
                 doutput_,
                 q_seqlen,
                 kv_seqlen,
@@ -1779,11 +1824,15 @@ def impl(
                 _kv_segment_pos,
                 config=helper.get_step_config(),
             )
-
-            # Apply all-to-all to gradients to restore original sharding
-            dq_ = helper.all_to_all(dq, False)
-            dk_ = helper.all_to_all(dk, False)
-            dv_ = helper.all_to_all(dv, False)
+
+            # Apply all-to-all to gradients to restore original sharding (scatter in seq dimension)
+            # Gradients have the same shape as inputs, so use same heads_dim
+            dq_heads_dim, dk_heads_dim, dv_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+
+            dq_ = helper.all_to_all(dq, False, seq_dim=1, heads_dim=dq_heads_dim)
+            dk_ = helper.all_to_all(dk, False, seq_dim=1, heads_dim=dk_heads_dim)
+            # For KVPACKED layout, dv is empty placeholder
+            dv_ = dv if dv.shape[0] == 0 else helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim)
 
             return dq_, dk_, dv_, dbias
 
@@ -1795,86 +1844,121 @@ def impl(
 
 @dataclass(frozen=True)
 class _FusedAttnCPWithA2AHelper:
-    """Helper class to assist with all-to-all communication for context parallel attention.
+    """
+    Helper class for Ulysses-style context parallelism using all-to-all communication.
+
+    This helper manages the all-to-all communication pattern that redistributes tensors
+    between sequence-sharded and heads-sharded layouts. This enables context parallelism
+    by allowing each rank to process different parts of the sequence and heads dimensions.
 
-    This class provides methods for performing all-to-all communication across devices
-    and handles both THD and BSHD layout formats appropriately.
+    Attributes:
+        mesh: JAX mesh for distributed computation
+        config: Fused attention configuration including QKV layout information
     """
 
     mesh: jax.sharding.Mesh
    config: _FusedAttnConfig
 
-    def check_supported(self, num_heads):
+    def check_supported(self):
         """Checks if the context parallel implementation is supported by the given arguments."""
         header = "Context parallel fused A2A attention"
         if self.config.qkv_layout.is_thd():
             raise ValueError(f"{header} does not support THD format")
         elif self.config.qkv_layout.get_qkv_format() is QKVFormat.SBHD:
             raise ValueError(f"{header} does not support SBHD format")
 
-        cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
-        if num_heads % cp_size != 0:
-            raise ValueError(
-                f"{header} requires num_heads ({num_heads}) to be divisible by "
-                f"context parallel size ({cp_size})"
-            )
+    def get_qkv_heads_dims(self, seq_dim=1):
+        """
+        Determines the heads dimension indices for Q, K, V tensors based on QKV layout.
+
+        The heads dimension position depends on the QKV packing format:
+        - QKVPacked: All tensors packed together with dimension [qkv=3, heads, dim]
+        - KVPacked: Q is separate, K and V are packed with dimension [kv=2, heads, dim]
+        - Separate: All tensors are separate with dimension [heads, dim]
+
+        Args:
+            seq_dim: The sequence dimension position (default 1 for BSHD format)
+
+        Returns:
+            Tuple of (q_heads_dim, k_heads_dim, v_heads_dim) indicating the position
+            of the heads dimension for each tensor.
+
+        Examples for BSHD layout (seq_dim=1):
+            QKVPacked: Q=[b, s, 3, h, d] -> returns (3, 3, 3)
+            KVPacked: Q=[b, s, h, d], K=[b, s, 2, h, d], V=[b, s, 2, h, d] -> returns (2, 3, 3)
+            Separate: Q=[b, s, h, d], K=[b, s, h, d], V=[b, s, h, d] -> returns (2, 2, 2)
+        """
+        if self.config.qkv_layout.is_qkvpacked():
+            # QKV all packed together: [batch..., seq, 3, heads, dim]
+            # Heads dimension is at seq_dim + 2 for all tensors
+            heads_dim = seq_dim + 2
+            return heads_dim, heads_dim, heads_dim
+        elif self.config.qkv_layout.is_kvpacked():
+            # Q separate, K and V packed: Q=[batch..., seq, heads, dim]
+            # K/V=[batch..., seq, 2, heads, dim]
+            q_heads_dim = seq_dim + 1  # Q has no packing dimension
+            kv_heads_dim = seq_dim + 2  # K/V have packing dimension [2, heads, dim]
+            return q_heads_dim, kv_heads_dim, kv_heads_dim
+        else:  # separate
+            # All separate: Q/K/V=[batch..., seq, heads, dim]
+            # Heads dimension is at seq_dim + 1 for all tensors
+            heads_dim = seq_dim + 1
+            return heads_dim, heads_dim, heads_dim
 
     def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2):
         """
-        Performs all-to-all communication for context parallelism.
+        Performs all-to-all communication for context parallelism (Ulysses-style).
+
+        Redistributes data between sequence and heads dimensions across CP ranks.
 
         Args:
             x: Input tensor
-            before_attn: If True, converts seq->heads dist. If False, converts heads->seq dist.
-            seq_dim: Position of sequence dimension (default 1 for BSHD: [b, s, h, d])
-            heads_dim: Position of heads dimension (default 2 for BSHD: [b, s, h, d])
+            before_attn: True = seq-sharded -> heads-sharded, False = heads-sharded -> seq-sharded
+            seq_dim: Sequence dimension position (typically 1 for BSHD)
+            heads_dim: Heads dimension position (depends on QKV layout)
 
         Returns:
             Tensor after all-to-all with redistributed dimensions
-
-        Shape transforms for BSHD (seq_dim=1, heads_dim=2):
-            before_attn=True: [b, s/cp, h, d] -> [b, s, h/cp, d]
-            before_attn=False: [b, s, h/cp, d] -> [b, s/cp, h, d]
-
-        Shape transforms for softmax_aux (seq_dim=2, heads_dim=1):
-            before_attn=False: [b, h/cp, s, ...] -> [b, h, s/cp, ...]
         """
+        if x.shape[0] == 0:
+            return x
+
         cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
+        if before_attn:
+            num_heads = x.shape[heads_dim]
+            assert num_heads % cp_size == 0, "num_heads must be divisible by cp_size"
+
         shape = x.shape
 
+        # Determine which dimension to split and where to concat
         if before_attn:
-            # Input: sharded on seq, want to shard on heads
-            # Split heads: [..., s/cp, ..., h, ...] -> [..., s/cp, ..., cp, h/cp, ...]
-            x = x.reshape(
-                *shape[:heads_dim], cp_size, shape[heads_dim] // cp_size, *shape[heads_dim + 1 :]
-            )
-            # A2A splits cp dimension and concatenates into seq
-            split_axis = heads_dim
-            concat_axis = seq_dim
+            split_axis, concat_axis = heads_dim, seq_dim
+            # After reshape, need to adjust concat_axis if it comes after split_axis
+            needs_adjustment = concat_axis > split_axis
         else:
-            # Input: sharded on heads, want to shard on seq
-            # Unflatten seq: [..., s, ..., h/cp, ...] -> [..., cp, s/cp, ..., h/cp, ...]
-            s_global = shape[seq_dim]
-            s_local = s_global // cp_size
-            new_shape = list(shape)
-            new_shape[seq_dim : seq_dim + 1] = [cp_size, s_local]
-            x = x.reshape(new_shape)
-            # A2A splits cp dimension (at seq_dim) and concatenates into heads
             split_axis = seq_dim
-            concat_axis = heads_dim
+            concat_axis = heads_dim + 1 if heads_dim > seq_dim else heads_dim
+            needs_adjustment = False  # Already adjusted above
+
+        # Reshape: insert cp_size at split_axis
+        assert shape[split_axis] % cp_size == 0
+        x = x.reshape(
+            *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1:]
+        )
+
+        # Adjust concat_axis if needed (only for before_attn case)
+        adjusted_concat_axis = concat_axis + 1 if needs_adjustment else concat_axis
 
-        # All-to-all communication
+        # Perform all-to-all
         x = lax_paral_op(
-            x,
-            lax.all_to_all,
-            self.config.cp_axis,
-            mesh=self.mesh,
-            split_axis=split_axis,
-            concat_axis=concat_axis,
-            tiled=True,
+            x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh,
+            split_axis=split_axis, concat_axis=adjusted_concat_axis, tiled=True
         )
 
-        return x
+        # Merge the two dimensions created by all-to-all at split_axis
+        new_shape = list(x.shape)
+        new_shape[split_axis:split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]]
+        return x.reshape(new_shape)
 
     def get_step_config(self) -> _FusedAttnConfig:
         """Returns a _FusedAttnConfig for single CP step call to fused attention."""
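To make the reshape/all-to-all/merge sequence in all_to_all concrete, here is a single-process walk-through of the before_attn=True case with made-up sizes. A real lax.all_to_all needs multiple devices, so step 2 only mimics the resulting local shape:

import numpy as np

# Per-rank tensor [b, s/cp, h, d] with cp = 2 (hypothetical sizes)
b, s_local, h, d, cp = 2, 8, 4, 16, 2
x = np.zeros((b, s_local, h, d))

# 1) Reshape: insert cp at split_axis=heads_dim -> [b, s/cp, cp, h/cp, d]
x = x.reshape(b, s_local, cp, h // cp, d)

# 2) Tiled all-to-all with split_axis=2, concat_axis=1: each rank keeps one of the cp
#    chunks of the heads dim and receives the matching chunk of every other rank's
#    sequence shard, so locally dim 2 shrinks by cp and dim 1 grows by cp:
#    [b, s/cp, cp, h/cp, d] -> [b, s, 1, h/cp, d]  (shape mimicked here)
x = np.zeros((b, s_local * cp, 1, h // cp, d))

# 3) Merge the two dims created at split_axis -> [b, s, h/cp, d]
new_shape = list(x.shape)
new_shape[2:4] = [x.shape[2] * x.shape[3]]
x = x.reshape(new_shape)
print(x.shape)  # (2, 16, 2, 16): full sequence, heads split across cp ranks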