@@ -1644,9 +1644,7 @@ def partition(config, mesh, arg_infos, result_infos):
             return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
 
         helper = _FusedAttnCPWithA2AHelper(mesh, config)
-        q_aval = arg_infos[0].aval if hasattr(arg_infos[0], "aval") else arg_infos[0]
-        num_heads = q_aval.shape[2]
-        helper.check_supported(num_heads)
+        helper.check_supported()
 
         out_sharding = result_infos[0].sharding
         softmax_aux_sharding = result_infos[1].sharding
@@ -1673,11 +1671,26 @@ def impl(
             _q_segment_pos,
             _kv_segment_pos,
         ):
-            q_, k_, v_ = (
-                helper.all_to_all(q, True),
-                helper.all_to_all(k, True),
-                helper.all_to_all(v, True),
-            )
+            # Get heads dimensions based on QKV layout
+            q_heads_dim, k_heads_dim, v_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+            q_heads = q.shape[q_heads_dim]
+            kv_heads = k.shape[k_heads_dim]
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
+            assert kv_heads % cp_size == 0, "kv_heads must be divisible by cp_size"
+
+            # Load-balanced causal attention is not yet supported for the all-to-all strategy
+            if config.context_parallel_load_balanced:
+                raise NotImplementedError(
+                    "context_parallel_load_balanced is not supported with all-to-all strategy"
+                )
+
+            # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension)
+            q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
+            k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
+            # For KVPACKED layout, v is an empty placeholder
+            v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+
             output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                 q_,
                 k_,
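
The hunk above trades sequence sharding for heads sharding with a tiled all-to-all. Below is a minimal, self-contained sketch of that collective for the plain separate-BSHD case; the mesh construction, the `cp` axis name, and the toy shapes are illustrative assumptions, not part of this patch:

```python
# Standalone sketch of the seq-sharded -> heads-sharded redistribution,
# i.e. per-rank [b, s/cp, h, d] -> [b, s, h/cp, d].
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

cp = jax.device_count()                        # toy context-parallel size
mesh = Mesh(np.array(jax.devices()), ("cp",))

b, s, h, d = 2, 4 * cp, 2 * cp, 16             # s and h both divisible by cp
q = jnp.arange(b * s * h * d, dtype=jnp.float32).reshape(b, s, h, d)

def seq_to_heads(x):
    # Local view per rank: [b, s/cp, h, d]. The tiled all-to-all splits the
    # heads axis (2) into cp chunks and concatenates the received chunks
    # along the sequence axis (1): [b, s/cp, h, d] -> [b, s, h/cp, d].
    return lax.all_to_all(x, "cp", split_axis=2, concat_axis=1, tiled=True)

q_heads_sharded = shard_map(
    seq_to_heads,
    mesh=mesh,
    in_specs=P(None, "cp", None, None),        # enters sequence-sharded
    out_specs=P(None, None, "cp", None),       # leaves heads-sharded
    check_rep=False,
)(q)
# Globally this is only a resharding: same array, different placement.
assert q_heads_sharded.shape == (b, s, h, d)
```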
@@ -1694,7 +1707,10 @@ def impl(
                 _kv_segment_pos,
                 config=helper.get_step_config(),
             )
-            output = helper.all_to_all(output, False)
+
+            # Apply all-to-all to transform from heads-sharded to seq-sharded (scatter in seq dimension)
+            # output is always [b, s, h/cp, d] -> heads_dim=2
+            output = helper.all_to_all(output, False, seq_dim=1, heads_dim=2)
             # softmax_aux has shape [b, h/cp, s, 1] with heads at dim 1, seq at dim 2
             softmax_aux = helper.all_to_all(softmax_aux, False, seq_dim=2, heads_dim=1)
             return output, softmax_aux, rng_state
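
For the reverse direction used here, the same collective runs with the split/concat axes swapped. `out_local` and `aux_local` below are hypothetical per-rank views inside the same `shard_map` region as the sketch above:

```python
# output:      [b, s, h/cp, d] -> [b, s/cp, h, d]   (seq_dim=1, heads_dim=2)
out_local = lax.all_to_all(out_local, "cp", split_axis=1, concat_axis=2, tiled=True)
# softmax_aux: [b, h/cp, s, 1] -> [b, h, s/cp, 1]   (seq_dim=2, heads_dim=1)
aux_local = lax.all_to_all(aux_local, "cp", split_axis=2, concat_axis=1, tiled=True)
```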
@@ -1720,17 +1736,29 @@ def partition(config, mesh, arg_infos, result_infos):
             return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
 
         helper = _FusedAttnCPWithA2AHelper(mesh, config)
-        q_aval = arg_infos[0].aval if hasattr(arg_infos[0], "aval") else arg_infos[0]
-        num_heads = q_aval.shape[2]
-        helper.check_supported(num_heads)
+        helper.check_supported()
 
         dq_sharding = result_infos[0].sharding
         dk_sharding = result_infos[1].sharding
         dv_sharding = result_infos[2].sharding
         dbias_sharding = result_infos[3].sharding
-        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos])
+
+        # For AllToAll context parallel, output and doutput need to be seq-sharded
+        # to match the forward output sharding (before they get transformed to heads-sharded)
+        arg_shardings = list([arg_i.sharding for arg_i in arg_infos])
+        # arg_infos: [q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, ...]
+        # output is at index 6, doutput is at index 7
+        # They should have the same sharding as the forward output (seq-sharded on cp axis)
+        output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, config.cp_axis, None, None))
+        softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, config.cp_axis, None))
+        arg_shardings[4] = softmax_aux_seq_sharding  # softmax_aux [b, h, s/cp, 1]
+        arg_shardings[6] = output_seq_sharding  # output [b, s/cp, h, d]
+        arg_shardings[7] = output_seq_sharding  # doutput [b, s/cp, h, d]
+        arg_shardings = tuple(arg_shardings)
+
         out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
-
+
+
         def impl(
             q,
             k,
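
The forced argument shardings above are ordinary `NamedSharding`s built from `PartitionSpec`s over the context-parallel mesh axis. A small sketch of what they express; the one-axis mesh and the `"cp"` axis name are assumptions for the demo:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("cp",))
# output / doutput keep the forward layout [b, s/cp, h, d]: shard dim 1 (sequence)
output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, "cp", None, None))
# softmax_aux keeps [b, h, s/cp, 1]: its sequence dimension is dim 2
softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, "cp", None))

cp = mesh.devices.size
out = jnp.zeros((2, 8 * cp, 4, 16))
out = jax.device_put(out, output_seq_sharding)   # each device holds one sequence slice
```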
@@ -1750,12 +1778,29 @@ def impl(
             _kv_segment_pos,
         ):
             # Apply all-to-all to inputs before backward pass
-            q_, k_, v_ = (
-                helper.all_to_all(q, True),
-                helper.all_to_all(k, True),
-                helper.all_to_all(v, True),
-            )
-            doutput_ = helper.all_to_all(doutput, True)
+            # Get heads dimensions based on QKV layout (same as forward)
+            q_heads_dim, k_heads_dim, v_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+            q_heads = q.shape[q_heads_dim]
+            k_heads = k.shape[k_heads_dim]
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
+            assert k_heads % cp_size == 0, "k_heads must be divisible by cp_size"
+
+            # Load-balanced causal attention is not yet supported for the all-to-all strategy
+            if config.context_parallel_load_balanced:
+                raise NotImplementedError(
+                    "context_parallel_load_balanced is not supported with all-to-all strategy"
+                )
+
+            # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension)
+            q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
+            k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
+            # For KVPACKED layout, v is an empty placeholder
+            v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+            # doutput is always a separate [b, s, h, d] tensor, so heads_dim=2
+            doutput_ = helper.all_to_all(doutput, True, seq_dim=1, heads_dim=2)
+            # output has the same shape as doutput and needs the same transformation
+            output_ = helper.all_to_all(output, True, seq_dim=1, heads_dim=2)
             # softmax_aux has shape [b, h, s/cp, 1] with heads at dim 1, seq at dim 2
             softmax_aux_ = helper.all_to_all(softmax_aux, True, seq_dim=2, heads_dim=1)
 
@@ -1767,7 +1812,7 @@ def impl(
                 bias,
                 softmax_aux_,
                 rng_state,
-                output,
+                output_,
                 doutput_,
                 q_seqlen,
                 kv_seqlen,
@@ -1779,11 +1824,15 @@ def impl(
                 _kv_segment_pos,
                 config=helper.get_step_config(),
             )
-
-            # Apply all-to-all to gradients to restore original sharding
-            dq_ = helper.all_to_all(dq, False)
-            dk_ = helper.all_to_all(dk, False)
-            dv_ = helper.all_to_all(dv, False)
+
+            # Apply all-to-all to gradients to restore original sharding (scatter in seq dimension)
+            # Gradients have the same shape as the inputs, so use the same heads_dim
+            dq_heads_dim, dk_heads_dim, dv_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
+
+            dq_ = helper.all_to_all(dq, False, seq_dim=1, heads_dim=dq_heads_dim)
+            dk_ = helper.all_to_all(dk, False, seq_dim=1, heads_dim=dk_heads_dim)
+            # For KVPACKED layout, dv is an empty placeholder
+            dv_ = dv if dv.shape[0] == 0 else helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim)
 
             return dq_, dk_, dv_, dbias
 
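
A quick way to see why the backward pass can reuse the same helper for gradients: the two directions of the tiled all-to-all are exact inverses, so whatever enters sequence-sharded comes back sequence-sharded. A per-rank sketch, using the same assumed `"cp"` axis as the earlier snippets:

```python
def a2a_round_trip(x):                 # x: per-rank [b, s/cp, h, d]
    y = lax.all_to_all(x, "cp", split_axis=2, concat_axis=1, tiled=True)  # [b, s, h/cp, d]
    z = lax.all_to_all(y, "cp", split_axis=1, concat_axis=2, tiled=True)  # [b, s/cp, h, d]
    return z                           # identical to x: the round trip is the identity
```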
@@ -1795,86 +1844,121 @@ def impl(
 
 @dataclass(frozen=True)
 class _FusedAttnCPWithA2AHelper:
-    """Helper class to assist with all-to-all communication for context parallel attention.
+    """
+    Helper class for Ulysses-style context parallelism using all-to-all communication.
+
+    This helper manages the all-to-all communication pattern that redistributes tensors
+    between sequence-sharded and heads-sharded layouts. This enables context parallelism
+    by allowing each rank to process different parts of the sequence and heads dimensions.
 
-    This class provides methods for performing all-to-all communication across devices
-    and handles both THD and BSHD layout formats appropriately.
+    Attributes:
+        mesh: JAX mesh for distributed computation
+        config: Fused attention configuration including QKV layout information
     """
 
     mesh: jax.sharding.Mesh
     config: _FusedAttnConfig
 
-    def check_supported(self, num_heads):
+    def check_supported(self):
         """Checks if the context parallel implementation is supported by the given arguments."""
         header = "Context parallel fused A2A attention"
         if self.config.qkv_layout.is_thd():
             raise ValueError(f"{header} does not support THD format")
         elif self.config.qkv_layout.get_qkv_format() is QKVFormat.SBHD:
             raise ValueError(f"{header} does not support SBHD format")
 
-        cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
-        if num_heads % cp_size != 0:
-            raise ValueError(
-                f"{header} requires num_heads ({num_heads}) to be divisible by "
-                f"context parallel size ({cp_size})"
-            )
+    def get_qkv_heads_dims(self, seq_dim=1):
+        """
+        Determines the heads dimension indices for Q, K, V tensors based on the QKV layout.
+
+        The heads dimension position depends on the QKV packing format:
+        - QKVPacked: all tensors packed together, with trailing dimensions [qkv=3, heads, dim]
+        - KVPacked: Q is separate; K and V are packed, with trailing dimensions [kv=2, heads, dim]
+        - Separate: all tensors are separate, with trailing dimensions [heads, dim]
+
+        Args:
+            seq_dim: The sequence dimension position (default 1 for BSHD format)
+
+        Returns:
+            Tuple of (q_heads_dim, k_heads_dim, v_heads_dim) indicating the position
+            of the heads dimension for each tensor.
+
+        Examples for BSHD layout (seq_dim=1):
+            QKVPacked: QKV=[b, s, 3, h, d] -> returns (3, 3, 3)
+            KVPacked:  Q=[b, s, h, d], KV=[b, s, 2, h, d] -> returns (2, 3, 3)
+            Separate:  Q=[b, s, h, d], K=[b, s, h, d], V=[b, s, h, d] -> returns (2, 2, 2)
+        """
+        if self.config.qkv_layout.is_qkvpacked():
+            # QKV all packed together: [batch..., seq, 3, heads, dim]
+            # Heads dimension is at seq_dim + 2 for all tensors
+            heads_dim = seq_dim + 2
+            return heads_dim, heads_dim, heads_dim
+        elif self.config.qkv_layout.is_kvpacked():
+            # Q separate, K and V packed: Q=[batch..., seq, heads, dim]
+            # K/V=[batch..., seq, 2, heads, dim]
+            q_heads_dim = seq_dim + 1  # Q has no packing dimension
+            kv_heads_dim = seq_dim + 2  # K/V have a packing dimension [2, heads, dim]
+            return q_heads_dim, kv_heads_dim, kv_heads_dim
+        else:  # separate
+            # All separate: Q/K/V=[batch..., seq, heads, dim]
+            # Heads dimension is at seq_dim + 1 for all tensors
+            heads_dim = seq_dim + 1
+            return heads_dim, heads_dim, heads_dim
 
     def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2):
         """
-        Performs all-to-all communication for context parallelism.
+        Performs all-to-all communication for context parallelism (Ulysses-style).
+
+        Redistributes data between sequence and heads dimensions across CP ranks.
 
         Args:
             x: Input tensor
-            before_attn: If True, converts seq->heads dist. If False, converts heads->seq dist.
-            seq_dim: Position of sequence dimension (default 1 for BSHD: [b, s, h, d])
-            heads_dim: Position of heads dimension (default 2 for BSHD: [b, s, h, d])
+            before_attn: True = seq-sharded -> heads-sharded, False = heads-sharded -> seq-sharded
+            seq_dim: Sequence dimension position (typically 1 for BSHD)
+            heads_dim: Heads dimension position (depends on QKV layout)
 
         Returns:
             Tensor after all-to-all with redistributed dimensions
-
-        Shape transforms for BSHD (seq_dim=1, heads_dim=2):
-            before_attn=True: [b, s/cp, h, d] -> [b, s, h/cp, d]
-            before_attn=False: [b, s, h/cp, d] -> [b, s/cp, h, d]
-
-        Shape transforms for softmax_aux (seq_dim=2, heads_dim=1):
-            before_attn=False: [b, h/cp, s, ...] -> [b, h, s/cp, ...]
         """
+        if x.shape[0] == 0:
+            return x
+
         cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
+        if before_attn:
+            num_heads = x.shape[heads_dim]
+            assert num_heads % cp_size == 0, "num_heads must be divisible by cp_size"
+
         shape = x.shape
 
+        # Determine which dimension to split and where to concat
         if before_attn:
-            # Input: sharded on seq, want to shard on heads
-            # Split heads: [..., s/cp, ..., h, ...] -> [..., s/cp, ..., cp, h/cp, ...]
-            x = x.reshape(
-                *shape[:heads_dim], cp_size, shape[heads_dim] // cp_size, *shape[heads_dim + 1 :]
-            )
-            # A2A splits cp dimension and concatenates into seq
-            split_axis = heads_dim
-            concat_axis = seq_dim
+            split_axis, concat_axis = heads_dim, seq_dim
+            # After reshape, need to adjust concat_axis if it comes after split_axis
+            needs_adjustment = concat_axis > split_axis
         else:
-            # Input: sharded on heads, want to shard on seq
-            # Unflatten seq: [..., s, ..., h/cp, ...] -> [..., cp, s/cp, ..., h/cp, ...]
-            s_global = shape[seq_dim]
-            s_local = s_global // cp_size
-            new_shape = list(shape)
-            new_shape[seq_dim : seq_dim + 1] = [cp_size, s_local]
-            x = x.reshape(new_shape)
-            # A2A splits cp dimension (at seq_dim) and concatenates into heads
             split_axis = seq_dim
-            concat_axis = heads_dim
+            concat_axis = heads_dim + 1 if heads_dim > seq_dim else heads_dim
+            needs_adjustment = False  # Already adjusted above
+
+        # Reshape: insert cp_size at split_axis
+        assert shape[split_axis] % cp_size == 0
+        x = x.reshape(
+            *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1 :]
+        )
+
+        # Adjust concat_axis if needed (only for the before_attn case)
+        adjusted_concat_axis = concat_axis + 1 if needs_adjustment else concat_axis
 
-        # All-to-all communication
+        # Perform all-to-all
         x = lax_paral_op(
-            x,
-            lax.all_to_all,
-            self.config.cp_axis,
-            mesh=self.mesh,
-            split_axis=split_axis,
-            concat_axis=concat_axis,
-            tiled=True,
+            x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh,
+            split_axis=split_axis, concat_axis=adjusted_concat_axis, tiled=True
         )
 
-        return x
+        # Merge the two dimensions created by all-to-all at split_axis
+        new_shape = list(x.shape)
+        new_shape[split_axis : split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]]
+        return x.reshape(new_shape)
 
     def get_step_config(self) -> _FusedAttnConfig:
         """Returns a _FusedAttnConfig for single CP step call to fused attention."""