Commits (23)
efe83f5  re-init (kijai, Sep 30, 2025)
460ce7f  Update model_multitalk.py (kijai, Oct 3, 2025)
6f6db12  whitespace... (kijai, Oct 3, 2025)
00c069d  Update model_multitalk.py (kijai, Oct 3, 2025)
57567bd  remove print (kijai, Oct 3, 2025)
9c5022e  this is redundant (kijai, Oct 10, 2025)
d0dce6b  Merge remote-tracking branch 'upstream/master' into multitalk (kijai, Oct 10, 2025)
7842a5c  remove import (kijai, Oct 10, 2025)
99dc959  Merge remote-tracking branch 'upstream/master' into multitalk (kijai, Oct 19, 2025)
4cbc1a6  Merge remote-tracking branch 'upstream/master' into multitalk (kijai, Oct 23, 2025)
f5d53f2  Restore preview functionality (kijai, Oct 23, 2025)
897ffeb  Merge remote-tracking branch 'upstream/master' into multitalk (kijai, Oct 30, 2025)
25063f2  Merge remote-tracking branch 'upstream/master' into multitalk (kijai, Nov 3, 2025)
6bfce54  Move block_idx to transformer_options (kijai, Nov 3, 2025)
8d62661  Remove LoopingSamplerCustomAdvanced (kijai, Nov 3, 2025)
d53e629  Remove looping functionality, keep extension functionality (kijai, Nov 3, 2025)
3ae78a4  Update model_multitalk.py (kijai, Nov 3, 2025)
7237c36  Merge remote-tracking branch 'upstream/master' into multitalk (kijai, Nov 5, 2025)
fb099a4  Handle ref_attn_mask with separate patch to avoid having to always re… (kijai, Nov 5, 2025)
b4d3f4e  Merge remote-tracking branch 'upstream/master' into multitalk (kijai, Nov 26, 2025)
af4d412  Chunk attention map calculation for multiple speakers to reduce peak … (kijai, Nov 26, 2025)
64a9841  Update model_multitalk.py (kijai, Nov 26, 2025)
c36323f  Merge remote-tracking branch 'upstream/master' into multitalk (kijai, Dec 2, 2025)
comfy/ldm/wan/model.py (19 changes: 18 additions & 1 deletion)
@@ -62,6 +62,8 @@ def forward(self, x, freqs, transformer_options={}):
x(Tensor): Shape [B, L, num_heads, C / num_heads]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
patches = transformer_options.get("patches", {})

b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

def qkv_fn_q(x):
@@ -86,6 +88,10 @@ def qkv_fn_k(x):
transformer_options=transformer_options,
)

if "attn1_patch" in patches:
for p in patches["attn1_patch"]:
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})

x = self.o(x)
return x

@@ -225,6 +231,8 @@ def forward(
"""
# assert e.dtype == torch.float32

patches = transformer_options.get("patches", {})

if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
@@ -242,6 +250,11 @@

# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)

if "attn2_patch" in patches:
for p in patches["attn2_patch"]:
x = p({"x": x, "transformer_options": transformer_options})

y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -488,7 +501,7 @@ def __init__(self,
self.blocks = nn.ModuleList([
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
- for _ in range(num_layers)
+ for i in range(num_layers)
])

# head
@@ -541,6 +554,7 @@ def forward_orig(
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)

# time embeddings
@@ -569,6 +583,7 @@ def forward_orig(
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
transformer_options["block_idx"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -735,6 +750,7 @@ def forward_orig(
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)

# time embeddings
@@ -764,6 +780,7 @@ def forward_orig(
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
transformer_options["block_idx"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
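Note: the hooks added in this diff are driven by populating transformer_options["patches"]. The following is a minimal, hypothetical sketch of an attn1_patch callable consistent with the call sites above; the name make_ref_attn_patch, the strength parameter, and the zero-residual blending are illustrative assumptions, not the actual MultiTalk patch. Only the argument keys ("x", "q", "k", "transformer_options") and the transformer_options entries ("patches", "block_idx", "grid_sizes") come from the code in this diff.

import torch

def make_ref_attn_patch(strength: float = 1.0):
    # Hypothetical attn1_patch factory; the blending logic below is a placeholder.
    def patch(args):
        x = args["x"]                        # attention output, before the self.o projection
        q, k = args["q"], args["k"]          # query/key tensors from WanSelfAttention
        opts = args["transformer_options"]
        block_idx = opts.get("block_idx")    # set per block in forward_orig
        grid_sizes = opts.get("grid_sizes")  # latent grid shape stored after patch_embedding
        # A real patch would compute an audio/reference-conditioned residual from
        # q/k (and block_idx/grid_sizes) here; a zero residual keeps this sketch a no-op.
        residual = torch.zeros_like(x)
        return x + strength * residual
    return patch

# Registration sketch: each hook key maps to a list, so multiple patches can be chained.
transformer_options = {
    "patches": {
        "attn1_patch": [make_ref_attn_patch(0.5)],
    }
}

An attn2_patch entry works the same way, except its callable only receives "x" and "transformer_options".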