@@ -33,7 +33,11 @@
 BlockSizes = splash_attention_kernel.BlockSizes
 
 AxisNames = tuple[str, ...]
-
+# Physical axis names for device meshes.
+DATA = "data"
+FSDP = "fsdp"
+TENSOR = "tensor"
+# Logical axis names for model parameters and activations.
 BATCH = "activation_batch"
 LENGTH = "activation_length"
 KV_LENGTH = "activation_kv_length"
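The three physical axis names correspond one-to-one to device-mesh dimensions, while the logical names label parameter and activation axes. A minimal sketch of how the physical names could back a mesh; the 1 x num_devices x 1 shape is an illustrative assumption, not something this change prescribes:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

DATA, FSDP, TENSOR = "data", "fsdp", "tensor"  # mirrors the constants above

# Illustrative shape: all devices on the "fsdp" axis; "data" and "tensor"
# are trivial (size 1).
devices = mesh_utils.create_device_mesh((1, jax.device_count(), 1))
mesh = Mesh(devices, axis_names=(DATA, FSDP, TENSOR))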
@@ -44,4 +48,32 @@
 KEEP_2 = "activation_keep_2"
 CONV_OUT = "activation_conv_out_channels"
 
+# For sharding self/cross attention independently in the splash kernel.
+SELF_ATTN_HEAD = "activation_self_attn_heads"
+SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
+SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
+CROSS_ATTN_HEAD = "activation_cross_attn_heads"
+CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
+CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"
+
+
 WAN_MODEL = "Wan2.1"
+
+### Common axis rules for ring attention ###
+RING_ATTENTION_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, FSDP],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, FSDP],
+]
+
+SEQUENCE_PARALLEL_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, None],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, None],
+]
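The two rule tables differ only in how the KV sequence axis is treated: ring attention shards both query and KV length over "fsdp", while the sequence-parallel rules shard only the query length and leave KV replicated. A minimal sketch of how such a rule list resolves to a PartitionSpec, assuming the rules are consumed through Flax's logical-partitioning helpers (the consumer is not part of this diff):

import flax.linen as nn
from jax.sharding import PartitionSpec

# Mirrors the constants above so the sketch is self-contained.
FSDP = "fsdp"
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
RING_ATTENTION_AXIS_RULES = [
    [SELF_ATTN_HEAD, None],      # attention heads stay replicated
    [SELF_ATTN_Q_LENGTH, FSDP],  # query length sharded over the "fsdp" mesh axis
]

# Resolve a (heads, q_length) logical shape under the ring-attention rules.
spec = nn.logical_to_mesh_axes(
    (SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH), rules=RING_ATTENTION_AXIS_RULES
)
assert spec == PartitionSpec(None, FSDP)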