forked from stduhpf/ComfyUI-WanMoeKSampler
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnodes.py
More file actions
322 lines (279 loc) · 14.7 KB
/
nodes.py
File metadata and controls
322 lines (279 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import torch
import numpy as np
import math # Added for ceil equivalent
import comfy.sample
import comfy.samplers
import comfy.utils
import comfy.model_sampling
from comfy.model_sampling import ModelSamplingDiscreteFlow, CONST
import latent_preview
def _schedule_sigmas(model_sampling, scheduler, steps, denoise):
    """Return the sigma schedule that will actually be sampled for these settings.

    Intended to mirror ``comfy.samplers.KSampler.set_steps`` (verify against the
    installed ComfyUI version). With ``denoise < 1`` the schedule is computed for
    ``int(steps / denoise)`` steps, and only the last ``steps + 1`` sigmas are kept.
    Using the same schedule here keeps the model-switch step aligned with the
    step indices passed to ``comfy.sample.sample``.
    """
    if denoise is None or denoise > 0.9999:
        return comfy.samplers.calculate_sigmas(model_sampling, scheduler, steps)
    if denoise <= 0.0:
        return torch.zeros(0)
    extended = comfy.samplers.calculate_sigmas(model_sampling, scheduler, int(steps / denoise))
    return extended[-(steps + 1):]


def _find_switch_step(model_sampling, sigmas, boundary, steps):
    """Return the first step index >= 1 whose normalised timestep is below ``boundary``.

    The normalised timestep is ``model_sampling.timestep(sigma) / 1000``. Index 0
    is never a switch point. Returns ``steps`` when no sigma falls below the boundary.
    """
    for index, sigma in enumerate(sigmas.tolist()[1:], start=1):
        if model_sampling.timestep(sigma) / 1000 < boundary:
            return index
    return steps


def _build_cfg_schedule(start_cfg, total_steps, ratio):
    """Return a per-step CFG array of length ``total_steps`` (empty if ``total_steps <= 0``).

    CFG falls linearly from ``start_cfg`` to 1.0 over the first
    ``int(total_steps * ratio)`` transitions and then stays at 1.0.
    A one-step phase always uses ``start_cfg``. When the ratio covers less
    than one full step, the whole phase runs at CFG 1.0.
    """
    if total_steps <= 0:
        return np.empty(0)
    if total_steps == 1:
        return np.array([start_cfg])
    fall_steps = int(total_steps * ratio)
    if fall_steps < 1:
        return np.ones(total_steps)
    # N decaying transitions need N + 1 points (start_cfg ... 1.0), capped at the phase length.
    num_decay_points = min(fall_steps + 1, total_steps)
    decay = np.linspace(start_cfg, 1.0, num_decay_points)
    sustain = np.ones(total_steps - num_decay_points)
    return np.concatenate([decay, sustain])


def wan_ksampler(
    model_high_noise,
    model_low_noise,
    seed,
    steps,
    cfgs,
    sampler_name,
    scheduler,
    positive,
    negative,
    latent,
    boundary=0.875,
    denoise=1.0,
    disable_noise=False,
    start_step=None,
    last_step=None,
    force_full_denoise=False,
    cfg_fall_ratio_high=0.5,
    cfg_fall_ratio_low=0.5,
):
    """Sample ``latent`` with a high-noise expert followed by a low-noise expert.

    The switch step is the first step (>= 1) whose normalised timestep drops
    below ``boundary``. The high-noise model runs steps
    ``[start_step, min(last_step, switch))`` and the low-noise model runs
    ``[max(start_step, switch), last_step)``. Each step is a separate
    ``comfy.sample.sample`` call, so its CFG can follow a per-phase linear
    decay schedule (see ``_build_cfg_schedule``).

    Args:
        model_high_noise, model_low_noise: expert models. The switch point is
            computed from the high-noise model's ``model_sampling``.
        seed: noise seed (ignored for the initial noise when ``disable_noise``).
        steps: total number of steps in the schedule.
        cfgs: ``(cfg_high, cfg_low)`` starting CFG value for each phase.
        sampler_name, scheduler: forwarded to ``comfy.sample.sample``.
        positive, negative: conditioning, forwarded unchanged.
        latent: dict with ``"samples"`` and optionally ``"batch_index"`` and
            ``"noise_mask"``.
        boundary: normalised timestep at which to switch experts.
        denoise: forwarded to the sampler; also used to pick the schedule
            that the switch point is computed on.
        disable_noise: if true, no initial noise is added.
        start_step, last_step: optional step range (``None`` = full range).
        force_full_denoise: drive the final sampled step to sigma 0.
        cfg_fall_ratio_high, cfg_fall_ratio_low: fraction of each phase's
            steps over which CFG decays to 1.0.

    Returns:
        ``(out,)``, where ``out`` is a shallow copy of ``latent`` with
        ``"samples"`` replaced by the sampled latent.
    """
    latent_image = latent["samples"]

    # --- Determine switching point on the schedule that is actually sampled ---
    sampling = model_high_noise.get_model_object("model_sampling")
    sigmas = _schedule_sigmas(sampling, scheduler, steps, denoise)
    split_at_step = _find_switch_step(sampling, sigmas, boundary, steps)
    print(f"Switching model at step {split_at_step}. High-noise runs {split_at_step} steps, Low-noise runs {steps - split_at_step} steps.")

    # Clamp user-defined start/end
    start_at = 0 if start_step is None else start_step
    end_at = steps if last_step is None else min(steps, last_step)
    high_noise_end_step = min(end_at, split_at_step)
    low_noise_start_step = max(start_at, split_at_step)
    run_high = start_at < high_noise_end_step
    run_low = low_noise_start_step < end_at

    out = latent.copy()
    if not (run_high or run_low):
        # Empty step range: hand back the input samples untouched.
        out["samples"] = latent_image
        return (out,)

    # Fix an empty latent's channel layout BEFORE building noise, so the noise
    # tensor has the same shape as the latent it will be mixed into.
    first_model = model_high_noise if run_high else model_low_noise
    latent_image = comfy.sample.fix_empty_latent_channels(first_model, latent_image)
    if disable_noise:
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        batch_inds = latent.get("batch_index", None)
        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
    # comfy.sample.sample mixes in whatever noise tensor it receives (its
    # disable_noise flag is not acted on in current ComfyUI). Every step after
    # the first must therefore get zeros, or the partially denoised latent
    # would be re-noised on each call.
    zero_noise = torch.zeros_like(noise)
    noise_mask = latent.get("noise_mask", None)
    disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
    final_step = end_at - 1

    def run_phase(model, label, first_step, cfg_schedule, current, inject_noise):
        """Run ``model`` for one single-step sampler call per CFG value, starting at ``first_step``."""
        preview_callback = latent_preview.prepare_callback(model, steps)
        current = comfy.sample.fix_empty_latent_channels(model, current)
        # NOTE(review): each step is a separate sampler call with the same seed, so
        # multistep/ancestral samplers presumably cannot carry state between steps.
        for step_index, cfg_val in enumerate(cfg_schedule, start=first_step):
            # --- CFG OUTPUT FOR TERMINAL ---
            print(f"[WanMoE] Step {step_index}: {label} CFG is {cfg_val:.4f}")

            def step_callback(local_step, x0, x, _total, offset=step_index):
                # Each call sees a one-step schedule; remap onto the global step count.
                preview_callback(offset + local_step, x0, x, steps)

            current = comfy.sample.sample(
                model,
                noise if inject_noise else zero_noise,
                steps,
                float(cfg_val),
                sampler_name,
                scheduler,
                positive,
                negative,
                current,
                denoise=denoise,
                disable_noise=disable_noise or not inject_noise,
                start_step=step_index,
                last_step=step_index + 1,
                # Only the very last step of the whole run may be forced to sigma 0.
                force_full_denoise=force_full_denoise and step_index == final_step,
                noise_mask=noise_mask,
                callback=step_callback,
                disable_pbar=disable_pbar,
                seed=seed,
            )
            inject_noise = False
        return current

    # --- HIGH NOISE MODEL ---
    if run_high:
        cfg_high_schedule = _build_cfg_schedule(cfgs[0], high_noise_end_step - start_at, cfg_fall_ratio_high)
        print(f"Running high noise model for steps {start_at}–{high_noise_end_step - 1} (CFG {cfgs[0]}→1 over {cfg_fall_ratio_high*100:.0f}% of steps)")
        latent_image = run_phase(model_high_noise, "High-Noise", start_at, cfg_high_schedule, latent_image, inject_noise=True)

    # --- LOW NOISE MODEL ---
    if run_low:
        cfg_low_schedule = _build_cfg_schedule(cfgs[1], end_at - low_noise_start_step, cfg_fall_ratio_low)
        print(f"Running low noise model for steps {low_noise_start_step}–{end_at - 1} (CFG {cfgs[1]}→1 over {cfg_fall_ratio_low*100:.0f}% of steps)")
        # The initial noise was already injected if the high-noise phase ran.
        latent_image = run_phase(model_low_noise, "Low-Noise", low_noise_start_step, cfg_low_schedule, latent_image, inject_noise=not run_high)

    out["samples"] = latent_image
    return (out,)
# --- Helper for sigma shift ---
def set_shift(model, sigma_shift):
    """Return a clone of ``model`` whose ``model_sampling`` uses the given sigma shift.

    A fresh discrete-flow (CONST) sampling object is created with
    ``shift=sigma_shift`` and ``multiplier=1000``, then attached to a clone of
    ``model`` as an object patch. Neither the caller's model nor its existing
    sampling object is modified, so other nodes sharing that MODEL are
    unaffected.

    NOTE(review): this assumes both experts are flow-matching (CONST) models,
    which is what the original fallback path built. Confirm before reusing
    this helper with other architectures.
    """
    sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
    sampling_type = comfy.model_sampling.CONST

    class ModelSamplingAdvanced(sampling_base, sampling_type):
        pass

    patched = model.clone()
    model_sampling = ModelSamplingAdvanced(model.model.model_config)
    model_sampling.set_parameters(shift=sigma_shift, multiplier=1000)
    patched.add_object_patch("model_sampling", model_sampling)
    return patched
# --- Simple Node ---
class WanMoeKSampler:
    """Basic ComfyUI node for the two-expert Wan MoE sampler.

    Applies ``sigma_shift`` to both experts via ``set_shift``, then delegates
    to ``wan_ksampler`` with default noise handling and the full step range.
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"
    CATEGORY = "sampling"
    DESCRIPTION = "Dual-model sampler with independent dynamic CFG decay ratios for each model."

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs (keep key order stable; it presumably drives widget order)."""
        ksampler = comfy.samplers.KSampler
        required = {
            "model_high_noise": ("MODEL", {"tooltip": "High-noise expert model used for the early denoising phase."}),
            "model_low_noise": ("MODEL", {"tooltip": "Low-noise expert model used for the later denoising phase."}),
            "boundary": (
                "FLOAT",
                {"default": 0.875, "min": 0.0, "max": 1.0, "step": 0.001,
                 "tooltip": "Timestep boundary where the sampler switches from high-noise to low-noise model."},
            ),
            "seed": (
                "INT",
                {"default": 0, "min": 0, "max": 0xffffffffffffffff,
                 "tooltip": "Random seed for noise generation; controls reproducibility."},
            ),
            "steps": (
                "INT",
                {"default": 20, "min": 1, "max": 10000,
                 "tooltip": "Total number of denoising steps to perform."},
            ),
            "cfg_high_noise": (
                "FLOAT",
                {"default": 6.0, "min": 0.0, "max": 100.0, "step": 0.1,
                 "tooltip": "Initial CFG (Classifier-Free Guidance) scale for the high-noise model."},
            ),
            "cfg_low_noise": (
                "FLOAT",
                {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.1,
                 "tooltip": "Initial CFG scale for the low-noise model."},
            ),
            "cfg_fall_ratio_high": (
                "FLOAT",
                {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
                 "tooltip": "Fraction of high-noise model steps during which CFG linearly falls from start value to 1.0."},
            ),
            "cfg_fall_ratio_low": (
                "FLOAT",
                {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
                 "tooltip": "Fraction of low-noise model steps during which CFG linearly falls from start value to 1.0."},
            ),
            "sampler_name": (ksampler.SAMPLERS, {"tooltip": "Sampling algorithm used during denoising."}),
            "scheduler": (ksampler.SCHEDULERS, {"tooltip": "Noise schedule controlling how noise is removed per step."}),
            "sigma_shift": (
                "FLOAT",
                {"default": 12.0, "min": 0.0, "max": 100.0, "step": 0.01,
                 "tooltip": "Sigma shift factor that modifies the noise distribution for both models."},
            ),
            "positive": ("CONDITIONING", {"tooltip": "Positive prompt conditioning (what you want to see)."}),
            "negative": ("CONDITIONING", {"tooltip": "Negative prompt conditioning (what you want to avoid)."}),
            "latent_image": ("LATENT", {"tooltip": "Input latent tensor to be denoised."}),
            "denoise": (
                "FLOAT",
                {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01,
                 "tooltip": "Denoising strength; lower values retain more of the original latent structure."},
            ),
        }
        return {"required": required}

    def sample(
        self,
        model_high_noise,
        model_low_noise,
        boundary,
        seed,
        steps,
        cfg_high_noise,
        cfg_low_noise,
        cfg_fall_ratio_high,
        cfg_fall_ratio_low,
        sampler_name,
        scheduler,
        sigma_shift,
        positive,
        negative,
        latent_image,
        denoise=1.0,
    ):
        """Shift both experts' sigma schedules and run the dual-model sampler."""
        shifted_high = set_shift(model_high_noise, sigma_shift)
        shifted_low = set_shift(model_low_noise, sigma_shift)
        cfg_pair = (cfg_high_noise, cfg_low_noise)
        return wan_ksampler(
            shifted_high,
            shifted_low,
            seed,
            steps,
            cfg_pair,
            sampler_name,
            scheduler,
            positive,
            negative,
            latent_image,
            boundary=boundary,
            denoise=denoise,
            cfg_fall_ratio_high=cfg_fall_ratio_high,
            cfg_fall_ratio_low=cfg_fall_ratio_low,
        )
# --- Advanced Node ---
class WanMoeKSamplerAdvanced:
    """Advanced ComfyUI node for the two-expert Wan MoE sampler.

    Adds control over initial noise (``add_noise``), the sampled step range
    (``start_at_step`` / ``end_at_step``) and whether the final step is forced
    to a full denoise (``return_with_leftover_noise``).
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"
    CATEGORY = "sampling"
    DESCRIPTION = "Advanced version of the dual-model sampler with precise control over noise behavior and CFG decay."

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs (keep key order stable; it presumably drives widget order)."""
        ksampler = comfy.samplers.KSampler
        required = {
            "model_high_noise": ("MODEL", {"tooltip": "High-noise expert model used for early denoising."}),
            "model_low_noise": ("MODEL", {"tooltip": "Low-noise expert model used for later refinement."}),
            "boundary": (
                "FLOAT",
                {"default": 0.875, "min": 0.0, "max": 1.0, "step": 0.001,
                 "tooltip": "Boundary (t_moe) determining where to switch from high- to low-noise model."},
            ),
            "add_noise": (["enable", "disable"], {"tooltip": "Enable or disable noise addition at the start of denoising."}),
            "noise_seed": (
                "INT",
                {"default": 0, "min": 0, "max": 0xffffffffffffffff,
                 "tooltip": "Random seed for noise generation."},
            ),
            "steps": (
                "INT",
                {"default": 20, "min": 1, "max": 10000,
                 "tooltip": "Number of total denoising steps."},
            ),
            "cfg_high_noise": (
                "FLOAT",
                {"default": 6.0, "min": 0.0, "max": 100.0, "step": 0.1,
                 "tooltip": "Starting CFG scale for the high-noise model."},
            ),
            "cfg_low_noise": (
                "FLOAT",
                {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.1,
                 "tooltip": "Starting CFG scale for the low-noise model."},
            ),
            "cfg_fall_ratio_high": (
                "FLOAT",
                {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
                 "tooltip": "Fraction of high-noise steps where CFG linearly decays to 1.0."},
            ),
            "cfg_fall_ratio_low": (
                "FLOAT",
                {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
                 "tooltip": "Fraction of low-noise steps where CFG linearly decays to 1.0."},
            ),
            "sampler_name": (ksampler.SAMPLERS, {"tooltip": "Select the sampler algorithm (e.g., euler, dpmpp_2m, etc.)."}),
            "scheduler": (ksampler.SCHEDULERS, {"tooltip": "Scheduler type controlling noise sigma progression."}),
            "sigma_shift": (
                "FLOAT",
                {"default": 12.0, "min": 0.0, "max": 100.0, "step": 0.01,
                 "tooltip": "Shift applied to sigma schedule for adjusting denoising behavior."},
            ),
            "positive": ("CONDITIONING", {"tooltip": "Positive text conditioning for guiding image generation."}),
            "negative": ("CONDITIONING", {"tooltip": "Negative text conditioning for avoiding unwanted elements."}),
            "latent_image": ("LATENT", {"tooltip": "Input latent image tensor."}),
            "start_at_step": (
                "INT",
                {"default": 0, "min": 0, "max": 10000,
                 "tooltip": "Optional: start denoising from this step index."},
            ),
            "end_at_step": (
                "INT",
                {"default": 10000, "min": 0, "max": 10000,
                 "tooltip": "Optional: stop denoising at this step index."},
            ),
            "return_with_leftover_noise": (
                ["disable", "enable"],
                {"tooltip": "If enabled, retains leftover noise in the output latent instead of full denoising."},
            ),
        }
        return {"required": required}

    def sample(
        self,
        model_high_noise,
        model_low_noise,
        boundary,
        add_noise,
        noise_seed,
        steps,
        cfg_high_noise,
        cfg_low_noise,
        cfg_fall_ratio_high,
        cfg_fall_ratio_low,
        sampler_name,
        scheduler,
        sigma_shift,
        positive,
        negative,
        latent_image,
        start_at_step,
        end_at_step,
        return_with_leftover_noise,
        denoise=1.0,
    ):
        """Translate the combo-box options into flags and run the dual-model sampler."""
        shifted_high = set_shift(model_high_noise, sigma_shift)
        shifted_low = set_shift(model_low_noise, sigma_shift)
        keep_leftover_noise = return_with_leftover_noise == "enable"
        skip_initial_noise = add_noise == "disable"
        cfg_pair = (cfg_high_noise, cfg_low_noise)
        return wan_ksampler(
            shifted_high,
            shifted_low,
            noise_seed,
            steps,
            cfg_pair,
            sampler_name,
            scheduler,
            positive,
            negative,
            latent_image,
            boundary=boundary,
            denoise=denoise,
            disable_noise=skip_initial_noise,
            start_step=start_at_step,
            last_step=end_at_step,
            force_full_denoise=not keep_leftover_noise,
            cfg_fall_ratio_high=cfg_fall_ratio_high,
            cfg_fall_ratio_low=cfg_fall_ratio_low,
        )