
Commit 65c4e40

entrpn and coolkp authored
Cross self attention switch (#251)
* skip flash block sizes setting for cross attention.
* change sharding based on cross/self attention.
* update sharding rules for attn.
* lint.
* ring attention rules are added at front if not present to shard sequence on fsdp axis
* test fix
* Add dense padded attention kernel and use unsafe rng key for generation
* Update
* Ignore history
* remove file
* Flag for using segment ids and masking padding tokens in attention
* Tokamax splash attn
* Flag for using same sequence sharding for self and cross
* update requirements.txt
* Delete splash_attn_benchmark.py
* Delete padded_flash_attn.py
* Merge main
* Ruff format
* Address comments
* Fix pprint error, add description of attention configuration params

---------

Signed-off-by: Kunjan Patel <kunjanp@google.com>
Co-authored-by: Kunjan Patel <kunjan@ucla.edu>
Co-authored-by: Kunjan Patel <kunjanp@google.com>
1 parent d843dc0 commit 65c4e40

25 files changed (+497, -212 lines)

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ jobs:
       pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
   - name: PyTest
     run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-      HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
+      HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=65472" python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
 # add_pull_ready:
 #   if: github.ref != 'refs/heads/main'
 #   permissions:
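The added LIBTPU_INIT_ARGS flag raises the TPU scoped vmem limit, presumably so the splash-attention kernels exercised by the tests have more vmem to work with. A minimal sketch of reproducing this outside CI, assuming a TPU host; the value simply mirrors the workflow change above, and libtpu only reads the variable before JAX initializes the TPU backend:

# Hedged sketch: set the flag before importing jax so libtpu picks it up.
import os

os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_scoped_vmem_limit_kib=65472"

import jax  # imported after the env var so the TPU backend sees the flag

print(jax.devices())  # should list TPU devices with the raised vmem limit in effect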

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,6 @@
 __pycache__/
 *.py[cod]
 *$py.class
-
 # C extensions
 *.so
 
@@ -98,6 +97,7 @@ celerybeat-schedule
 
 # Environments
 .env
+.history
 .venv
 env/
 venv/

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ ftfy
 tensorboard>=2.17.0
 tensorboardx>=2.6.2.2
 tensorboard-plugin-profile>=2.15.2
+tokamax
 Jinja2
 scikit-image
 parameterized

src/maxdiffusion/common_types.py

Lines changed: 33 additions & 1 deletion
@@ -33,7 +33,11 @@
 BlockSizes = splash_attention_kernel.BlockSizes
 
 AxisNames = tuple[str, ...]
-
+# Physical axis names for device meshes.
+DATA = "data"
+FSDP = "fsdp"
+TENSOR = "tensor"
+# Logical axis names for model parameters and activations.
 BATCH = "activation_batch"
 LENGTH = "activation_length"
 KV_LENGTH = "activation_kv_length"

@@ -44,4 +48,32 @@
 KEEP_2 = "activation_keep_2"
 CONV_OUT = "activation_conv_out_channels"
 
+# For setting self/cross attention independently in splash kernel
+SELF_ATTN_HEAD = "activation_self_attn_heads"
+SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
+SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
+CROSS_ATTN_HEAD = "activation_cross_attn_heads"
+CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
+CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"
+
+
 WAN_MODEL = "Wan2.1"
+
+### Common axis rules for ring attention ###
+RING_ATTENTION_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, FSDP],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, FSDP],
+]
+
+SEQUENCE_PARALLEL_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, None],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, None],
+]
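These rule lists pair the new logical axis names with a mesh axis (FSDP) or None: the ring-attention rules shard both the query and key/value sequence lengths over fsdp, while the sequence-parallel rules shard only the query lengths. A minimal sketch, assuming Flax's logical partitioning helpers (how maxdiffusion actually consumes these lists is not shown in this diff), of how such rules turn a logical layout into a PartitionSpec:

# Minimal sketch, assuming Flax's logical partitioning helpers; not
# maxdiffusion's actual wiring.
import flax.linen as nn

FSDP = "fsdp"
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"

# Each rule maps a logical axis name to a mesh axis (or None to replicate).
rules = [
    (SELF_ATTN_HEAD, None),
    (SELF_ATTN_Q_LENGTH, FSDP),
]

with nn.logical_axis_rules(rules):
  # A (heads, q_length) logical layout: heads stay unsharded, the query
  # sequence is sharded over the fsdp mesh axis.
  spec = nn.logical_to_mesh_axes((SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH))

print(spec)  # PartitionSpec(None, 'fsdp')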

src/maxdiffusion/configs/base14.yml

Lines changed: 9 additions & 0 deletions
@@ -50,6 +50,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware such as Trillium this is faster.
+# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for the cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
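A hedged sketch of what mask_padding_tokens implies for the attention call: segment ids derived from a padding mask are only handed to the splash kernel when the flag is on. The helper name and the padding-mask convention (1 = real token, 0 = padding) are assumptions, not maxdiffusion's actual code.

# Hedged sketch, not maxdiffusion's actual attention wrapper.
import jax.numpy as jnp


def segment_ids_for_splash(mask_padding_tokens: bool, q_valid: jnp.ndarray, kv_valid: jnp.ndarray):
  """q_valid/kv_valid: (batch, seq) masks with 1 for real tokens, 0 for padding."""
  if not mask_padding_tokens:
    # Skip segment ids entirely: faster on VPU-bound hardware such as Trillium,
    # but attention is also spent on padding tokens.
    return None
  # Giving padding a different segment id (0) than real tokens (1) makes the
  # splash kernel mask out query/key pairs that cross the padding boundary.
  return q_valid.astype(jnp.int32), kv_valid.astype(jnp.int32)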

src/maxdiffusion/configs/base21.yml

Lines changed: 10 additions & 0 deletions
@@ -49,6 +49,16 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware such as Trillium this is faster.
+# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for the cross attention q.
+attention_sharding_uniform: True
+
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
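Purely illustrative PartitionSpecs for the two strategies described in the comments above; the (batch, heads, seq, head_dim) layout and the "data"/"fsdp" mesh axis names are assumptions, not the specs maxdiffusion actually emits.

# Illustrative only: how the two attention_sharding_uniform settings differ
# for the query tensor, assuming a (batch, heads, seq, head_dim) layout.
from jax.sharding import PartitionSpec as P


def q_spec(attention_sharding_uniform: bool, is_self_attention: bool) -> P:
  if attention_sharding_uniform or not is_self_attention:
    # Uniform strategy, and cross attention under either strategy:
    # shard the query sequence dimension.
    return P("data", None, "fsdp", None)
  # Non-uniform strategy, self attention: shard heads instead of the sequence.
  return P("data", "fsdp", None, None)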

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 10 additions & 0 deletions
@@ -50,6 +50,16 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware such as Trillium this is faster.
+# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for the cross attention q.
+attention_sharding_uniform: True
+
 flash_block_sizes: {}
 # to override default block sizes for flash attention
 # flash_block_sizes:

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 9 additions & 0 deletions
@@ -63,6 +63,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware such as Trillium this is faster.
+# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for the cross attention q.
+attention_sharding_uniform: True
 
 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 9 additions & 0 deletions
@@ -63,6 +63,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware such as Trillium this is faster.
+# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for the cross attention q.
+attention_sharding_uniform: True
 
 #flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 9 additions & 0 deletions
@@ -62,6 +62,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
+# Otherwise we do not pass segment ids; on VPU-bound hardware such as Trillium this is faster.
+# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 attention sharding strategies:
+# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
+#    is sharded for the cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {
   "block_q" : 256,
   "block_kv_compute" : 256,
