@@ -1678,7 +1678,7 @@ def impl(
             cp_size = get_mesh_axis_size(config.cp_axis, mesh)
             assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
             assert kv_heads % cp_size == 0, "kv_heads must be divisible by cp_size"
-
+
             # Load balanced causal attention is not yet supported for all-to-all strategy
             if config.context_parallel_load_balanced:
                 raise NotImplementedError(
@@ -1689,8 +1689,12 @@ def impl(
             q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
             k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
             # For KVPACKED layout, v is empty placeholder
-            v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
-
+            v_ = (
+                v
+                if v.shape[0] == 0
+                else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+            )
+
             output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                 q_,
                 k_,
@@ -1707,7 +1711,7 @@ def impl(
                 _kv_segment_pos,
                 config=helper.get_step_config(),
             )
-
+
             # Apply all-to-all to transform from heads-sharded to seq-sharded (scatter in seq dimension)
             # output is always [b, s, h/cp, d] -> heads_dim=2
             output = helper.all_to_all(output, False, seq_dim=1, heads_dim=2)
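To make the resharding above concrete, the following single-process sketch mimics what `helper.all_to_all(..., before_attn=True, ...)` achieves across the context-parallel group: each device starts with one sequence chunk of all heads ([b, s/cp, h, d]) and ends with the full sequence for its own slice of the heads ([b, s, h/cp, d]). The toy shapes and the NumPy loop are illustrative assumptions only; the real exchange is a `lax.all_to_all` collective inside the helper, not this simulation.

```python
import numpy as np


def simulate_all_to_all_before_attn(shards, cp_size):
    """Single-process sketch of the seq-sharded -> heads-sharded exchange.

    shards[j] holds sequence chunk j with shape [b, s/cp, h, d]. Afterwards,
    "device" i holds the full sequence but only heads chunk i: [b, s, h/cp, d].
    """
    out = []
    for i in range(cp_size):
        # Device i receives heads-chunk i from every sequence chunk j and
        # concatenates the pieces along the sequence dimension.
        pieces = [np.split(shards[j], cp_size, axis=2)[i] for j in range(cp_size)]
        out.append(np.concatenate(pieces, axis=1))
    return out


# Toy sizes (made up for illustration): b=1, s=8, h=4, d=2, cp_size=2
b, s, h, d, cp = 1, 8, 4, 2, 2
full = np.arange(b * s * h * d).reshape(b, s, h, d)
shards = np.split(full, cp, axis=1)  # seq-sharded inputs, each [b, s/cp, h, d]
resharded = simulate_all_to_all_before_attn(shards, cp)
assert resharded[0].shape == (b, s, h // cp, d)  # full sequence, h/cp heads
```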
@@ -1742,23 +1746,24 @@ def partition(config, mesh, arg_infos, result_infos):
         dk_sharding = result_infos[1].sharding
         dv_sharding = result_infos[2].sharding
         dbias_sharding = result_infos[3].sharding
-
+
         # For AllToAll context parallel, output and doutput need to be seq-sharded
         # to match the forward output sharding (before they get transformed to heads-sharded)
         arg_shardings = list([arg_i.sharding for arg_i in arg_infos])
         # arg_infos: [q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, ...]
         # output is at index 6, doutput is at index 7
         # They should have the same sharding as the forward output (seq-sharded on cp axis)
         output_seq_sharding = NamedSharding(mesh, PartitionSpec(None, config.cp_axis, None, None))
-        softmax_aux_seq_sharding = NamedSharding(mesh, PartitionSpec(None, None, config.cp_axis, None))
+        softmax_aux_seq_sharding = NamedSharding(
+            mesh, PartitionSpec(None, None, config.cp_axis, None)
+        )
         arg_shardings[4] = softmax_aux_seq_sharding  # softmax_aux [b, h, s/cp, 1]
         arg_shardings[6] = output_seq_sharding  # output [b, s/cp, h, d]
         arg_shardings[7] = output_seq_sharding  # doutput [b, s/cp, h, d]
         arg_shardings = tuple(arg_shardings)
-
+
         out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
-
-
+
         def impl(
             q,
             k,
@@ -1785,18 +1790,22 @@ def impl(
             cp_size = get_mesh_axis_size(config.cp_axis, mesh)
             assert q_heads % cp_size == 0, "q_heads must be divisible by cp_size"
             assert k_heads % cp_size == 0, "k_heads must be divisible by cp_size"
-
+
             # Load balanced causal attention is not yet supported for all-to-all strategy
             if config.context_parallel_load_balanced:
                 raise NotImplementedError(
                     "context_parallel_load_balanced is not supported with all-to-all strategy"
                 )
-
+
             # Apply all-to-all to transform from seq-sharded to heads-sharded (gather in seq dimension)
             q_ = helper.all_to_all(q, True, seq_dim=1, heads_dim=q_heads_dim)
             k_ = helper.all_to_all(k, True, seq_dim=1, heads_dim=k_heads_dim)
             # For KVPACKED layout, v is empty placeholder
-            v_ = v if v.shape[0] == 0 else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+            v_ = (
+                v
+                if v.shape[0] == 0
+                else helper.all_to_all(v, True, seq_dim=1, heads_dim=v_heads_dim)
+            )
             # doutput is always separate [b, s, h, d], so heads_dim=2
             doutput_ = helper.all_to_all(doutput, True, seq_dim=1, heads_dim=2)
             # output has the same shape as doutput, needs the same transformation
@@ -1824,15 +1833,19 @@ def impl(
                 _kv_segment_pos,
                 config=helper.get_step_config(),
             )
-
+
             # Apply all-to-all to gradients to restore original sharding (scatter in seq dimension)
             # Gradients have the same shape as inputs, so use same heads_dim
             dq_heads_dim, dk_heads_dim, dv_heads_dim = helper.get_qkv_heads_dims(seq_dim=1)
-
+
             dq_ = helper.all_to_all(dq, False, seq_dim=1, heads_dim=dq_heads_dim)
             dk_ = helper.all_to_all(dk, False, seq_dim=1, heads_dim=dk_heads_dim)
             # For KVPACKED layout, dv is empty placeholder
-            dv_ = dv if dv.shape[0] == 0 else helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim)
+            dv_ = (
+                dv
+                if dv.shape[0] == 0
+                else helper.all_to_all(dv, False, seq_dim=1, heads_dim=dv_heads_dim)
+            )
 
             return dq_, dk_, dv_, dbias
 
@@ -1870,19 +1883,19 @@ def check_supported(self):
     def get_qkv_heads_dims(self, seq_dim=1):
         """
         Determines the heads dimension indices for Q, K, V tensors based on QKV layout.
-
+
         The heads dimension position depends on the QKV packing format:
         - QKVPacked: All tensors packed together with dimension [qkv=3, heads, dim]
         - KVPacked: Q is separate, K and V are packed with dimension [kv=2, heads, dim]
         - Separate: All tensors are separate with dimension [heads, dim]
-
+
         Args:
             seq_dim: The sequence dimension position (default 1 for BSHD format)
-
+
         Returns:
             Tuple of (q_heads_dim, k_heads_dim, v_heads_dim) indicating the position
             of the heads dimension for each tensor.
-
+
         Examples for BSHD layout (seq_dim=1):
             QKVPacked: Q=[b, s, 3, h, d] -> returns (3, 3, 3)
             KVPacked: Q=[b, s, h, d], K=[b, s, 2, h, d], V=[b, s, 2, h, d] -> returns (2, 3, 3)
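A hypothetical reimplementation of the mapping documented above could look like the sketch below. The string-based layout dispatch and the function name are assumptions made for illustration; the real helper presumably reads the layout from its config rather than taking it as an argument.

```python
def get_qkv_heads_dims_sketch(qkv_layout: str, seq_dim: int = 1):
    """Hypothetical mapping from QKV layout to heads-dimension indices (BSHD)."""
    if qkv_layout == "QKVPacked":
        # Q/K/V packed as [b, s, 3, h, d]: heads sit two axes after seq
        return (seq_dim + 2, seq_dim + 2, seq_dim + 2)
    if qkv_layout == "KVPacked":
        # Q is [b, s, h, d]; K and V are packed as [b, s, 2, h, d]
        return (seq_dim + 1, seq_dim + 2, seq_dim + 2)
    # Separate: Q, K, V are each [b, s, h, d]
    return (seq_dim + 1, seq_dim + 1, seq_dim + 1)


assert get_qkv_heads_dims_sketch("QKVPacked") == (3, 3, 3)
assert get_qkv_heads_dims_sketch("KVPacked") == (2, 3, 3)
assert get_qkv_heads_dims_sketch("Separate") == (2, 2, 2)
```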
@@ -1943,21 +1956,26 @@ def all_to_all(self, x, before_attn=True, seq_dim=1, heads_dim=2):
         # Reshape: insert cp_size at split_axis
         assert shape[split_axis] % cp_size == 0
         x = x.reshape(
-            *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1 :]
+            *shape[:split_axis], cp_size, shape[split_axis] // cp_size, *shape[split_axis + 1 :]
         )
 
         # Adjust concat_axis if needed (only for before_attn case)
         adjusted_concat_axis = concat_axis + 1 if needs_adjustment else concat_axis
 
         # Perform all-to-all
         x = lax_paral_op(
-            x, lax.all_to_all, self.config.cp_axis, mesh=self.mesh,
-            split_axis=split_axis, concat_axis=adjusted_concat_axis, tiled=True
+            x,
+            lax.all_to_all,
+            self.config.cp_axis,
+            mesh=self.mesh,
+            split_axis=split_axis,
+            concat_axis=adjusted_concat_axis,
+            tiled=True,
         )
 
         # Merge the two dimensions created by all-to-all at split_axis
         new_shape = list(x.shape)
-        new_shape[split_axis : split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]]
+        new_shape[split_axis : split_axis + 2] = [x.shape[split_axis] * x.shape[split_axis + 1]]
         return x.reshape(new_shape)
 
     def get_step_config(self) -> _FusedAttnConfig:
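The net effect of the reshape / all-to-all / merge sequence on the per-device shape is that the `split_axis` dimension shrinks by `cp_size` while the `concat_axis` dimension grows by it. The trace below is a communication-free sketch of that bookkeeping; it writes the index adjustment as the general `concat_axis > split_axis` condition, whereas the method above derives it from `before_attn`, and all the concrete sizes are made up for illustration.

```python
def a2a_shape_trace(shape, cp_size, split_axis, concat_axis):
    """Sketch of the per-device shape through reshape -> tiled all-to-all -> merge."""
    assert shape[split_axis] % cp_size == 0
    # 1) Insert cp_size at split_axis: [..., n, ...] -> [..., cp, n/cp, ...]
    staged = shape[:split_axis] + (cp_size, shape[split_axis] // cp_size) + shape[split_axis + 1 :]
    # 2) Tiled all-to-all over the cp axis: the cp_size-sized dim is scattered down to 1
    #    and concat_axis is gathered cp_size-fold (its index shifts by 1 if it follows split_axis).
    adj = concat_axis + 1 if concat_axis > split_axis else concat_axis
    after = list(staged)
    after[split_axis] = 1
    after[adj] *= cp_size
    # 3) Merge the two dims created at split_axis back into one
    return tuple(after[:split_axis] + [after[split_axis] * after[split_axis + 1]] + after[split_axis + 2 :])


# Before attention: per-device [b, s/cp, h, d], split over heads, gather over seq
assert a2a_shape_trace((2, 512, 16, 64), cp_size=4, split_axis=2, concat_axis=1) == (2, 2048, 4, 64)
# After attention: per-device [b, s, h/cp, d], split over seq, gather over heads
assert a2a_shape_trace((2, 2048, 4, 64), cp_size=4, split_axis=1, concat_axis=2) == (2, 512, 16, 64)
```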