Skip to content

Commit e082175

Fix pprint error, add description of attention configuration params
1 parent 85cec45 commit e082175

File tree

10 files changed: +82, -17 lines

src/maxdiffusion/configs/base14.yml

Lines changed: 9 additions & 2 deletions
@@ -50,8 +50,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
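The mask_padding_tokens behaviour described in the new comments can be pictured with a small, self-contained sketch. This is an illustration only: the function name and the convention that padding tokens carry segment id 0 are assumptions, not MaxDiffusion's actual splash-attention code. When segment ids are passed, positions attend only within their own segment, so padding is ignored; when they are omitted, every position may attend to every other, which is cheaper on VPU-bound hardware such as Trillium but lets attention see the padding.

# Hypothetical illustration of how segment ids mask padding in attention.
# Names and the "padding = segment id 0" convention are assumptions here.
import numpy as np

def mask_from_segment_ids(q_segment_ids, kv_segment_ids):
    # A query position may attend to a key position only when both carry
    # the same non-zero segment id; padding tokens are given id 0.
    q = q_segment_ids[:, :, None]    # [batch, q_len, 1]
    kv = kv_segment_ids[:, None, :]  # [batch, 1, kv_len]
    return (q == kv) & (q != 0)

# One sequence with 4 real tokens followed by 2 padding tokens.
segment_ids = np.array([[1, 1, 1, 1, 0, 0]])
print(mask_from_segment_ids(segment_ids, segment_ids).astype(int))
# With mask_padding_tokens: False, no segment ids are passed, so nothing is
# masked: faster, but attention also covers the padding positions.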

src/maxdiffusion/configs/base21.yml

Lines changed: 9 additions & 2 deletions
@@ -49,8 +49,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True

 flash_block_sizes: {}
 # GroupNorm groups

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 9 additions & 2 deletions
@@ -50,8 +50,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True

 flash_block_sizes: {}
 # to override default block sizes for flash attention

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 9 additions & 2 deletions
@@ -63,8 +63,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True

 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 9 additions & 2 deletions
@@ -63,8 +63,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True

 #flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 9 additions & 2 deletions
@@ -62,8 +62,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {
 "block_q" : 256,
 "block_kv_compute" : 256,

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 9 additions & 2 deletions
@@ -61,8 +61,15 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 4096
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True
 dropout: 0.1

 flash_block_sizes: {

src/maxdiffusion/configs/base_xl.yml

Lines changed: 9 additions & 2 deletions
@@ -50,8 +50,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
-mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
-attention_sharding_uniform: True # same sequence sharding rules applied for q in both (self and cross attention)
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 9 additions & 0 deletions
@@ -48,6 +48,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+# in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
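The two attention_sharding_uniform strategies described in the config comments amount to a choice of how the query tensor is partitioned. The sketch below is an illustration only: the mesh axis names ("data", "fsdp") and the [batch, heads, seq, head_dim] layout are assumptions for the example, not MaxDiffusion's actual sharding rules.

# Hypothetical sketch of the two q-sharding strategies; axis names and the
# [batch, heads, seq, head_dim] layout are assumptions, not library code.
from jax.sharding import PartitionSpec as P

def q_partition_spec(attention_sharding_uniform: bool, is_self_attention: bool):
    if attention_sharding_uniform:
        # Strategy 1: shard the sequence axis of q the same way in both
        # self attention and cross attention.
        return P("data", None, "fsdp", None)
    if is_self_attention:
        # Strategy 2: shard heads across devices for self attention ...
        return P("data", "fsdp", None, None)
    # ... while cross-attention q keeps sequence sharding.
    return P("data", None, "fsdp", None)

# Specs chosen for self attention under each strategy.
print(q_partition_spec(True, True))   # PartitionSpec('data', None, 'fsdp', None)
print(q_partition_spec(False, True))  # PartitionSpec('data', 'fsdp', None, None)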

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 import os
 import datetime
 import functools
-from pprint import pprint
+import pprint
 import numpy as np
 import threading
 from concurrent.futures import ThreadPoolExecutor
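This one-line change is the "Fix pprint error" part of the commit message. A plausible reading, since the actual call site is not shown in this diff and is assumed here: the trainer uses module-style calls such as pprint.pprint(...), which break when only the function was imported.

# Hypothetical reconstruction of the error; the call site in wan_trainer.py
# is an assumption, not part of this diff.
from pprint import pprint   # old import: the name "pprint" is the function
# pprint.pprint(cfg)        # -> AttributeError: 'function' object has no attribute 'pprint'

import pprint               # new import: the name "pprint" is the module
pprint.pprint({"attention": "flash", "mask_padding_tokens": True})  # works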
