Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics. Use [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics
- version {plugin['version']} (id: {plugin['id']})
+ version {plugin['version']} by {author} (id: {plugin['id']}){plugin.get('description', 'No description provided.')}
-
-
-
-
-
+ {actions_container_html}
"""
user_html = f'
{user_items_html}
'
- return f"{css}
{user_html}
"
+ return f"{css}
{instruction_html}{user_html}
"
def create_plugin_manager_ui(self):
with gr.Blocks() as plugin_blocks:
with gr.Row(equal_height=False, variant='panel'):
with gr.Column(scale=2, min_width=600):
- gr.Markdown("### Installed Plugins (Drag to reorder tabs)")
+ gr.Markdown("### Plugins Available Locally (Drag to reorder tabs)")
self.plugins_html_display = gr.HTML()
with gr.Row(elem_classes="save-buttons-container"):
self.save_plugins_button = gr.Button("Save", variant="secondary", size="sm", scale=0, elem_classes="stylish-save-btn")
self.save_and_restart_button = gr.Button("Save and Restart", variant="primary", size="sm", scale=0, elem_classes="stylish-save-btn")
+ self.refresh_catalog_button = gr.Button("Check for Updates", variant="secondary", size="sm", scale=0, elem_classes="stylish-save-btn")
with gr.Column(scale=2, min_width=300):
gr.Markdown("### Discover & Install")
@@ -310,6 +428,12 @@ def create_plugin_manager_ui(self):
self.save_plugins_button.click(fn=None, js="handleSave(false)")
self.save_and_restart_button.click(fn=None, js="handleSave(true)")
+ self.refresh_catalog_button.click(
+ fn=self._refresh_catalog,
+ inputs=[],
+ outputs=[self.plugins_html_display, self.community_plugins_html],
+ show_progress="full"
+ )
self.save_action_input.change(
fn=self._handle_save_action,
@@ -343,9 +467,42 @@ def _on_tab_select_refresh(self, evt: gr.SelectData):
community_html = self._build_community_plugins_html()
return gr.update(value=installed_html), gr.update(value=community_html)
+ def _refresh_catalog(self, progress=gr.Progress()):
+ self.app.plugin_manager.refresh_catalog(installed_only=True, use_remote=False)
+ if hasattr(self, '_community_plugins_cache'):
+ del self._community_plugins_cache
+ updates_available = self._count_available_updates()
+ if updates_available <= 0:
+ gr.Info("No Plugin Update is available")
+ elif updates_available == 1:
+ gr.Info("One Plugin Update is available")
+ else:
+ gr.Info(f"{updates_available} Plugin Updates are available")
+ return self._build_plugins_html(), self._build_community_plugins_html()
+
+ def _count_available_updates(self) -> int:
+ try:
+ plugins_info = self.app.plugin_manager.get_plugins_info()
+ remote_plugins_info = self.app.plugin_manager.get_merged_catalog_entries(use_remote=False)
+ count = 0
+ for plugin in plugins_info:
+ if plugin.get('system'):
+ continue
+ if not plugin.get('uninstallable', True):
+ continue
+ plugin_id = plugin.get('id')
+ if not plugin_id or plugin_id not in remote_plugins_info:
+ continue
+ remote_entry = remote_plugins_info[plugin_id]
+ if compare_release_metadata(remote_entry, plugin) > 0:
+ count += 1
+ return count
+ except Exception:
+ return 0
+
def _enable_plugin_after_install(self, url: str):
try:
- plugin_id = url.split('/')[-1].replace('.git', '')
+ plugin_id = plugin_id_from_url(url)
enabled_plugins = self.server_config.get("enabled_plugins", [])
if plugin_id not in enabled_plugins:
enabled_plugins.append(plugin_id)
@@ -369,7 +526,10 @@ def _save_and_restart(self, enabled_plugins: list):
with open(self.server_config_filename, "w", encoding="utf-8") as writer:
writer.write(json.dumps(self.server_config, indent=4))
gr.Info("Settings saved. Restarting application...")
- quit_application()
+ if callable(getattr(self, "quit_application", None)):
+ self.quit_application()
+ return
+ gr.Warning("Restart hook is unavailable. Please restart WanGP manually.")
def _handle_save_action(self, payload_str: str):
if not payload_str:
@@ -393,6 +553,11 @@ def _install_plugin_and_refresh(self, url, progress=gr.Progress()):
was_enabled = self._enable_plugin_after_install(url)
if was_enabled:
result_message = result_message.replace("Please enable it", "It has been auto-enabled")
+ plugin_id = plugin_id_from_url(url)
+ if plugin_id:
+ self.app.plugin_manager.record_plugin_metadata(plugin_id, url=url)
+ if hasattr(self, '_community_plugins_cache'):
+ del self._community_plugins_cache
gr.Info(result_message)
else:
gr.Warning(result_message)
@@ -445,4 +610,4 @@ def _handle_plugin_action_from_json(self, payload_str: str, progress=gr.Progress
if hasattr(self, '_community_plugins_cache'):
del self._community_plugins_cache
- return self._build_plugins_html(), self._build_community_plugins_html()
\ No newline at end of file
+ return self._build_plugins_html(), self._build_community_plugins_html()
diff --git a/Wan2GP/plugins/wan2gp-sample/plugin.py b/Wan2GP/plugins/wan2gp-sample/plugin.py
index 63e2eca0c..017f3f61d 100644
--- a/Wan2GP/plugins/wan2gp-sample/plugin.py
+++ b/Wan2GP/plugins/wan2gp-sample/plugin.py
@@ -18,9 +18,6 @@ def release_GPU(state):
class ConfigTabPlugin(WAN2GPPlugin):
def __init__(self):
super().__init__()
- self.name = PlugIn_Name
- self.version = "1.0.0"
- self.description = PlugIn_Name
def setup_ui(self):
self.request_global("get_current_model_settings")
@@ -91,4 +88,4 @@ def big_process(state):
outputs=[ self.main_tabs ]
)
-
\ No newline at end of file
+
diff --git a/Wan2GP/postprocessing/film_grain.py b/Wan2GP/postprocessing/film_grain.py
index a38b43a8b..affeaf7a4 100644
--- a/Wan2GP/postprocessing/film_grain.py
+++ b/Wan2GP/postprocessing/film_grain.py
@@ -2,10 +2,13 @@
import torch
def add_film_grain(images: torch.Tensor, grain_intensity: float = 0, saturation: float = 0.5):
- device = images.device
+ device = images.device
+ input_was_uint8 = images.dtype == torch.uint8
+ if input_was_uint8:
+ images = images.float().div_(255.0).mul_(2.0).sub_(1.0)
- images = images.permute(1, 2 ,3 ,0)
- images.add_(1.).div_(2.)
+ images = images.permute(1, 2, 3, 0)
+ images.add_(1.0).div_(2.0)
grain = torch.randn_like(images, device=device)
grain[:, :, :, 0] *= 2
grain[:, :, :, 2] *= 3
@@ -16,6 +19,8 @@ def add_film_grain(images: torch.Tensor, grain_intensity: float = 0, saturation:
# Blend the grain with the image
noised_images = images + grain_intensity * grain
noised_images.clamp_(0, 1)
- noised_images.sub_(.5).mul_(2.)
- noised_images = noised_images.permute(3, 0, 1 ,2)
+ noised_images.sub_(0.5).mul_(2.0)
+ noised_images = noised_images.permute(3, 0, 1, 2)
+ if input_was_uint8:
+ noised_images = noised_images.add(1.0).mul(127.5).clamp(0, 255).to(torch.uint8)
return noised_images
diff --git a/Wan2GP/postprocessing/mmaudio/ext/autoencoder/autoencoder.py b/Wan2GP/postprocessing/mmaudio/ext/autoencoder/autoencoder.py
index e40f3fc70..d1b79e8ac 100644
--- a/Wan2GP/postprocessing/mmaudio/ext/autoencoder/autoencoder.py
+++ b/Wan2GP/postprocessing/mmaudio/ext/autoencoder/autoencoder.py
@@ -1,13 +1,39 @@
+import os
from typing import Literal, Optional
import torch
import torch.nn as nn
+from mmgp import offload
+from shared.utils import files_locator as fl
from ..autoencoder.vae import VAE, get_my_vae
from ..bigvgan import BigVGAN
from ..bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
from ...model.utils.distributions import DiagonalGaussianDistribution
+_BIGVGAN_V2_FOLDER = "bigvgan_v2_44khz_128band_512x"
+
+
+def _resolve_bigvgan_v2_files():
+ weights_path = fl.locate_file(
+ os.path.join(_BIGVGAN_V2_FOLDER, "bigvgan_generator.pt"), error_if_none=False
+ )
+ config_path = fl.locate_file(
+ os.path.join(_BIGVGAN_V2_FOLDER, "config.json"), error_if_none=False
+ )
+ if weights_path is None or config_path is None:
+ raise FileNotFoundError(
+ f"Missing BigVGANv2 files in '{_BIGVGAN_V2_FOLDER}'. "
+ "Expected 'config.json' and 'bigvgan_generator.pt'."
+ )
+ return weights_path, config_path
+
+
+def _preprocess_bigvgan_v2_state_dict(state_dict, quantization_map=None, tied_weights_map=None):
+ if isinstance(state_dict, dict) and isinstance(state_dict.get("generator"), dict):
+ state_dict = state_dict["generator"]
+ return state_dict, quantization_map, tied_weights_map
+
class AutoEncoderModule(nn.Module):
@@ -27,9 +53,18 @@ def __init__(self,
assert vocoder_ckpt_path is not None
self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
elif mode == '44k':
- self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
- use_cuda_kernel=False)
+ vocoder_ckpt_path, vocoder_config_path = _resolve_bigvgan_v2_files()
+ self.vocoder = offload.fast_load_transformers_model(
+ vocoder_ckpt_path,
+ modelClass=BigVGANv2,
+ forcedConfigPath=vocoder_config_path,
+ preprocess_sd=_preprocess_bigvgan_v2_state_dict,
+ configKwargs={"use_cuda_kernel": False},
+ writable_tensors=False,
+ default_dtype=torch.float32,
+ )
self.vocoder.remove_weight_norm()
+ self.vocoder.eval()
else:
raise ValueError(f'Unknown mode: {mode}')
diff --git a/Wan2GP/postprocessing/mmaudio/ext/bigvgan_v2/bigvgan.py b/Wan2GP/postprocessing/mmaudio/ext/bigvgan_v2/bigvgan.py
index 96b87c20c..4432583db 100644
--- a/Wan2GP/postprocessing/mmaudio/ext/bigvgan_v2/bigvgan.py
+++ b/Wan2GP/postprocessing/mmaudio/ext/bigvgan_v2/bigvgan.py
@@ -346,6 +346,12 @@ def remove_weight_norm(self):
print("[INFO] Model already removed weight norm. Skipping!")
pass
+ @classmethod
+ def from_config(cls, config: Dict):
+ config_data = dict(config)
+ use_cuda_kernel = config_data.pop("use_cuda_kernel", False)
+ return cls(AttrDict(config_data), use_cuda_kernel=use_cuda_kernel)
+
# Additional methods for huggingface_hub support
def _save_pretrained(self, save_directory: Path) -> None:
"""Save weights and config.json from a Pytorch model to a local directory."""
diff --git a/Wan2GP/postprocessing/rife/RIFE_V4.py b/Wan2GP/postprocessing/rife/RIFE_V4.py
new file mode 100644
index 000000000..9ae0ec708
--- /dev/null
+++ b/Wan2GP/postprocessing/rife/RIFE_V4.py
@@ -0,0 +1,287 @@
+"""
+MIT License
+
+Copyright (c) 2024 Hzwer
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
+ dtype = tenInput.dtype
+ tenInput = tenInput.to(torch.float)
+ tenFlow = tenFlow.to(torch.float)
+
+ tenFlow = torch.cat(
+ [tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1
+ )
+ g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1)
+ padding_mode = "border"
+ if tenInput.device.type == "mps":
+ padding_mode = "zeros"
+ g = g.clamp(-1, 1)
+ return F.grid_sample(
+ input=tenInput,
+ grid=g,
+ mode="bilinear",
+ padding_mode=padding_mode,
+ align_corners=True,
+ ).to(dtype)
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+ return nn.Sequential(
+ nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=True,
+ ),
+ nn.LeakyReLU(0.2, True),
+ )
+
+
+class Head(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
+ self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
+ self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
+ self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)
+ self.relu = nn.LeakyReLU(0.2, True)
+
+ def forward(self, x):
+ x = x.clamp(0.0, 1.0)
+ x = self.relu(self.cnn0(x))
+ x = self.relu(self.cnn1(x))
+ x = self.relu(self.cnn2(x))
+ x = self.cnn3(x)
+ return x
+
+
+class ResConv(nn.Module):
+ def __init__(self, c, dilation=1):
+ super().__init__()
+ self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
+ self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
+ self.relu = nn.LeakyReLU(0.2, True)
+
+ def forward(self, x):
+ return self.relu(self.conv(x) * self.beta + x)
+
+
+class IFBlock(nn.Module):
+ def __init__(self, in_planes, c=64):
+ super().__init__()
+ self.conv0 = nn.Sequential(
+ conv(in_planes, c // 2, 3, 2, 1),
+ conv(c // 2, c, 3, 2, 1),
+ )
+ self.convblock = nn.Sequential(
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ )
+ self.lastconv = nn.Sequential(
+ nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1),
+ nn.PixelShuffle(2),
+ )
+
+ def forward(self, x, flow=None, scale=1):
+ x = F.interpolate(
+ x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
+ )
+ if flow is not None:
+ flow = (
+ F.interpolate(
+ flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
+ )
+ * 1.0
+ / scale
+ )
+ x = torch.cat((x, flow), 1)
+ feat = self.conv0(x)
+ feat = self.convblock(feat)
+ tmp = self.lastconv(feat)
+ tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
+ flow = tmp[:, :4] * scale
+ mask = tmp[:, 4:5]
+ feat = tmp[:, 5:]
+ return flow, mask, feat
+
+
+class IFNet(nn.Module):
+ def __init__(self, scale=1.0):
+ super().__init__()
+ self.block0 = IFBlock(7 + 8, c=192)
+ self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
+ self.block2 = IFBlock(8 + 4 + 8 + 8, c=96)
+ self.block3 = IFBlock(8 + 4 + 8 + 8, c=64)
+ self.block4 = IFBlock(8 + 4 + 8 + 8, c=32)
+ self.scaleList = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale]
+ self.blocks = [self.block0, self.block1, self.block2, self.block3, self.block4]
+
+ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
+ img0 = img0.clamp(0.0, 1.0)
+ img1 = img1.clamp(0.0, 1.0)
+
+ warped_img0 = img0
+ warped_img1 = img1
+ flow = None
+ mask = None
+ feat = None
+
+ for i in range(5):
+ if flow is None:
+ flow, mask, feat = self.blocks[i](
+ torch.cat((img0, img1, f0, f1, timestep), 1),
+ None,
+ scale=self.scaleList[i],
+ )
+ else:
+ wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
+ wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
+ fd, m0, feat = self.blocks[i](
+ torch.cat(
+ (
+ warped_img0,
+ warped_img1,
+ wf0,
+ wf1,
+ timestep,
+ mask,
+ feat,
+ ),
+ 1,
+ ),
+ flow,
+ scale=self.scaleList[i],
+ )
+ mask = m0
+ flow = flow + fd
+ warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
+ warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
+ mask = torch.sigmoid(mask)
+ return warped_img0 * mask + warped_img1 * (1 - mask)
+
+
+class Model:
+ def __init__(self):
+ self.flownet = IFNet()
+ self.encode = Head()
+ self.pad_mod = 64
+ self.supports_timestep = True
+ self._grid_cache = {}
+ self.device = None
+
+ def train(self):
+ self.flownet.train()
+ self.encode.train()
+
+ def eval(self):
+ self.flownet.eval()
+ self.encode.eval()
+
+ def to(self, device):
+ self.flownet.to(device)
+ self.encode.to(device)
+
+ def _get_grid(self, height, width, device):
+ key = (height, width, device.type, device.index)
+ cached = self._grid_cache.get(key)
+ if cached is not None:
+ return cached
+ tenFlow_div = torch.tensor(
+ [(width - 1.0) / 2.0, (height - 1.0) / 2.0],
+ dtype=torch.float32,
+ device=device,
+ )
+ tenHorizontal = (
+ torch.linspace(-1.0, 1.0, width, dtype=torch.float32, device=device)
+ .view(1, 1, 1, width)
+ .expand(1, 1, height, width)
+ )
+ tenVertical = (
+ torch.linspace(-1.0, 1.0, height, dtype=torch.float32, device=device)
+ .view(1, 1, height, 1)
+ .expand(1, 1, height, width)
+ )
+ backwarp_tenGrid = torch.cat([tenHorizontal, tenVertical], 1)
+ self._grid_cache[key] = (tenFlow_div, backwarp_tenGrid)
+ return tenFlow_div, backwarp_tenGrid
+
+ def load_model(self, path, rank=0, device="cuda"):
+ self.device = device
+ state_dict = torch.load(path, map_location=device)
+ if isinstance(state_dict, dict):
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ elif "flownet" in state_dict:
+ state_dict = state_dict["flownet"]
+ state_dict = {
+ k.replace("module.", ""): v for k, v in state_dict.items()
+ }
+ head_state = {
+ k.replace("encode.", ""): v
+ for k, v in state_dict.items()
+ if k.startswith("encode.")
+ }
+ if head_state:
+ self.encode.load_state_dict(head_state, strict=True)
+ flow_state = {
+ k: v for k, v in state_dict.items() if not k.startswith("encode.")
+ }
+ self.flownet.load_state_dict(flow_state, strict=False)
+ self.to(device)
+
+ def inference(self, img0, img1, timestep=0.5, scale=1.0):
+ if scale != 1.0:
+ self.flownet.scaleList = [
+ 16 / scale,
+ 8 / scale,
+ 4 / scale,
+ 2 / scale,
+ 1 / scale,
+ ]
+ f0 = self.encode(img0)
+ f1 = self.encode(img1)
+ height = img0.shape[2]
+ width = img0.shape[3]
+ tenFlow_div, backwarp_tenGrid = self._get_grid(height, width, img0.device)
+ timestep_tensor = torch.full(
+ (1, 1, height, width),
+ float(timestep),
+ dtype=img0.dtype,
+ device=img0.device,
+ )
+ return self.flownet(
+ img0, img1, timestep_tensor, tenFlow_div, backwarp_tenGrid, f0, f1
+ )
diff --git a/Wan2GP/postprocessing/rife/inference.py b/Wan2GP/postprocessing/rife/inference.py
index a213496a4..0fa92a43a 100644
--- a/Wan2GP/postprocessing/rife/inference.py
+++ b/Wan2GP/postprocessing/rife/inference.py
@@ -4,13 +4,18 @@
# from .model.pytorch_msssim import ssim_matlab
from .ssim import ssim_matlab
-from .RIFE_HDv3 import Model
+from .RIFE_HDv3 import Model as ModelV3
+from .RIFE_V4 import Model as ModelV4
def get_frame(frames, frame_no):
if frame_no >= frames.shape[1]:
return None
- frame = (frames[:, frame_no] + 1) /2
- frame = frame.clip(0., 1.)
+ frame = frames[:, frame_no]
+ if frame.dtype == torch.uint8:
+ frame = frame.float().div_(255.0)
+ else:
+ frame = (frame + 1) / 2
+ frame = frame.clip(0., 1.)
return frame
def add_frame(frames, frame, h, w):
@@ -29,8 +34,14 @@ def process_frames(model, device, frames, exp):
_, h, w = lastframe.shape
scale = 1
fp16 = False
+ supports_timestep = getattr(model, "supports_timestep", False)
+ pad_mod = getattr(model, "pad_mod", 32)
def make_inference(I0, I1, n):
+ if n <= 0:
+ return []
+ if supports_timestep:
+ return [model.inference(I0, I1, (i + 1) / (n + 1), scale) for i in range(n)]
middle = model.inference(I0, I1, scale)
if n == 1:
return [middle]
@@ -41,7 +52,7 @@ def make_inference(I0, I1, n):
else:
return [*first_half, *second_half]
- tmp = max(32, int(32 / scale))
+ tmp = max(pad_mod, int(pad_mod / scale))
ph = ((h - 1) // tmp + 1) * tmp
pw = ((w - 1) // tmp + 1) * tmp
padding = (0, pw - w, 0, ph - h)
@@ -83,7 +94,10 @@ def pad_image(img):
temp = frame
I1 = frame.to(device, non_blocking=True).unsqueeze(0)
I1 = pad_image(I1)
- I1 = model.inference(I0, I1, scale)
+ if supports_timestep:
+ I1 = model.inference(I0, I1, 0.5, scale)
+ else:
+ I1 = model.inference(I0, I1, scale)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
frame = I1[0][:, :h, :w]
@@ -105,15 +119,21 @@ def pad_image(img):
add_frame(output_frames, lastframe, h, w)
return torch.cat( output_frames, dim=1)
-def temporal_interpolation(model_path, frames, exp, device ="cuda"):
+def temporal_interpolation(model_path, frames, exp, device ="cuda", rife_version="v3"):
- model = Model()
+ input_was_uint8 = frames.dtype == torch.uint8
+ if rife_version == "v4":
+ model = ModelV4()
+ else:
+ model = ModelV3()
model.load_model(model_path, -1, device=device)
model.eval()
model.to(device=device)
with torch.no_grad():
- output = process_frames(model, device, frames.float(), exp)
+ output = process_frames(model, device, frames, exp)
+ if input_was_uint8:
+ output = output.add_(1.0).mul_(127.5).clamp_(0, 255).to(torch.uint8)
return output
diff --git a/Wan2GP/requirements.txt b/Wan2GP/requirements.txt
index 3310f89f9..a42a9b487 100644
--- a/Wan2GP/requirements.txt
+++ b/Wan2GP/requirements.txt
@@ -1,6 +1,7 @@
+--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-13-nightly/pypi/simple/
# Core AI stack
-diffusers==0.34.0
-transformers==4.53.1
+diffusers==0.36.0
+transformers==4.54.0 #4.53.1
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
@@ -10,6 +11,7 @@ einops
sentencepiece
open_clip_torch>=2.29.0
numpy==2.1.2
+num2words==0.5.14
# Video & media
moviepy==1.0.3
@@ -24,6 +26,7 @@ librosa==0.11.0
speechbrain==1.0.3
audio-separator==0.36.1
pyannote.audio==3.3.2
+torchcodec
# UI & interaction
gradio==5.29.0
@@ -32,18 +35,25 @@ loguru
s3tokenizer
conformer==0.3.2
spacy_pkuseg
+spacy
+gradio_rangeslider
# Vision & segmentation
opencv-python>=4.12.0.88
segment-anything
rembg[gpu]==2.0.65
-onnxruntime-gpu==1.22
+onnxruntime-gpu==1.22.0; python_version < "3.11"
+onnxruntime-gpu==1.25.0.dev20260210001; python_version >= "3.11"
decord
timm
-insightface @ https://github.com/deepbeepmeep/insightface/raw/refs/heads/master/wheels/insightface-0.7.3-cp310-cp310-win_amd64.whl ; sys_platform == "win32" and python_version == "3.10"
+insightface @ https://github.com/deepbeepmeep/insightface/releases/download/insightface/insightface-0.7.3-cp310-cp310-win_amd64.whl ; sys_platform == "win32" and python_version == "3.10"
+insightface @ https://github.com/deepbeepmeep/insightface/releases/download/insightface/insightface-0.7.3-cp311-cp311-win_amd64.whl ; sys_platform == "win32" and python_version == "3.11"
+insightface @ https://github.com/deepbeepmeep/insightface/releases/download/insightface/insightface-0.7.3-cp312-cp312-win_amd64.whl ; sys_platform == "win32" and python_version == "3.12"
insightface==0.7.3 ; sys_platform == "linux"
facexlib==0.3.0
taichi
+vector_quantize_pytorch==1.27.19
+
# chumpy wheel hosted on GitHub to avoid sdist build isolation issue
chumpy @ https://github.com/deepbeepmeep/chumpy/releases/download/v0.71/chumpy-0.71-py3-none-any.whl
smplfitter @ https://github.com/deepbeepmeep/smplfitter/releases/download/v0.2.10/smplfitter-0.2.10-py3-none-any.whl
@@ -57,9 +67,11 @@ pydantic==2.10.6
# Math & modeling
torchdiffeq>=0.2.5
tensordict>=0.6.1
-mmgp==3.6.10
-peft==0.15.0
+mmgp==3.7.4
+peft==0.17.0 # 0.15.0
+vector-quantize-pytorch
matplotlib
+gguf==0.17.1
# Utilities
ftfy
@@ -69,10 +81,9 @@ misaki
gitdb==4.0.12
gitpython==3.1.45
stringzilla==4.0.14
+xxhash
# Optional / commented out
# transformers==4.46.3 # for llamallava pre-patch
# rembg==2.0.65 # non-GPU fallback
# huggingface_hub[hf_xet] # slows down everything
-# num2words
-# spacy
diff --git a/Wan2GP/scripts/install.bat b/Wan2GP/scripts/install.bat
new file mode 100644
index 000000000..14469c6d1
--- /dev/null
+++ b/Wan2GP/scripts/install.bat
@@ -0,0 +1,56 @@
+@echo off
+cd /d "%~dp0.."
+setlocal enabledelayedexpansion
+title WanGP Installer
+
+:MENU
+cls
+echo ======================================================
+echo WAN2GP INSTALLER MENU
+echo ======================================================
+echo 1. Use 'venv' (Easiest - Comes prepackaged with python)
+echo 2. Use 'uv' (Recommended - Handles Python 3.11 better)
+echo 3. Use 'Conda'
+echo 4. No Environment (Not Recommended)
+echo 5. Exit
+echo ------------------------------------------------------
+set /p choice="Select an option (1-5): "
+
+if "%choice%"=="1" (
+ set "ENV_TYPE=venv"
+ goto START_INSTALL
+)
+
+if "%choice%"=="2" (
+ set "ENV_TYPE=uv"
+ where uv >nul 2>nul
+ if !errorlevel! neq 0 (
+ echo [!] 'uv' not found.
+        echo 1. Install 'uv' via PowerShell ^(Recommended^)
+ echo 2. Install 'uv' via Pip
+ set /p uv_choice="Select method: "
+ if "!uv_choice!"=="1" (
+ powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
+ set "PATH=!USERPROFILE!\.local\bin;!APPDATA!\uv\bin;!PATH!"
+ )
+ if "!uv_choice!"=="2" python -m pip install uv
+ )
+ goto START_INSTALL
+)
+
+if "%choice%"=="3" (
+ set "ENV_TYPE=conda"
+ goto START_INSTALL
+)
+
+if "%choice%"=="4" (
+ set "ENV_TYPE=none"
+ goto START_INSTALL
+)
+
+if "%choice%"=="5" exit
+goto MENU
+
+:START_INSTALL
+python setup.py install --env !ENV_TYPE!
+pause
\ No newline at end of file
diff --git a/Wan2GP/scripts/install.sh b/Wan2GP/scripts/install.sh
new file mode 100644
index 000000000..a2ff4c46c
--- /dev/null
+++ b/Wan2GP/scripts/install.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+cd "$(dirname "$0")/.."
+clear
+echo "======================================================"
+echo " WAN2GP INSTALLER MENU"
+echo "======================================================"
+echo "1. Use 'venv' (Easiest - Comes prepackaged)"
+echo "2. Use 'uv' (Recommended - Fast)"
+echo "3. Use 'Conda'"
+echo "4. No Environment (Not Recommended)"
+echo "5. Exit"
+echo "------------------------------------------------------"
+read -p "Select an option (1-5): " choice
+
+if [ "$choice" == "1" ]; then
+ ENV_TYPE="venv"
+elif [ "$choice" == "2" ]; then
+ ENV_TYPE="uv"
+ if ! command -v uv &> /dev/null; then
+ echo "[!] 'uv' not found."
+ echo "Installing uv..."
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+        source $HOME/.local/bin/env
+ fi
+elif [ "$choice" == "3" ]; then
+ ENV_TYPE="conda"
+elif [ "$choice" == "4" ]; then
+ ENV_TYPE="none"
+else
+ exit 0
+fi
+
+python3 setup.py install --env $ENV_TYPE
+echo "Installation complete. Run ./run.sh to start."
+read -p "Press Enter to exit..."
\ No newline at end of file
diff --git a/Wan2GP/scripts/manage.bat b/Wan2GP/scripts/manage.bat
new file mode 100644
index 000000000..dbda4ab6f
--- /dev/null
+++ b/Wan2GP/scripts/manage.bat
@@ -0,0 +1,4 @@
+@echo off
+cd /d "%~dp0.."
+python setup.py manage
+pause
\ No newline at end of file
diff --git a/Wan2GP/scripts/manage.sh b/Wan2GP/scripts/manage.sh
new file mode 100644
index 000000000..b890d3836
--- /dev/null
+++ b/Wan2GP/scripts/manage.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+cd "$(dirname "$0")/.."
+clear
+python3 setup.py manage
\ No newline at end of file
diff --git a/Wan2GP/scripts/run.bat b/Wan2GP/scripts/run.bat
new file mode 100644
index 000000000..3e29bdeff
--- /dev/null
+++ b/Wan2GP/scripts/run.bat
@@ -0,0 +1,4 @@
+@echo off
+cd /d "%~dp0.."
+python setup.py run
+pause
\ No newline at end of file
diff --git a/Wan2GP/scripts/run.sh b/Wan2GP/scripts/run.sh
new file mode 100644
index 000000000..9058fc669
--- /dev/null
+++ b/Wan2GP/scripts/run.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+cd "$(dirname "$0")/.."
+python3 setup.py run
+read -p "Press Enter to exit..."
\ No newline at end of file
diff --git a/Wan2GP/scripts/update.bat b/Wan2GP/scripts/update.bat
new file mode 100644
index 000000000..7a0525a82
--- /dev/null
+++ b/Wan2GP/scripts/update.bat
@@ -0,0 +1,39 @@
+@echo off
+cd /d "%~dp0.."
+setlocal enabledelayedexpansion
+title WanGP Update & Upgrade
+
+:MENU
+cls
+echo ======================================================
+echo WAN2GP UPDATE / UPGRADE
+echo ======================================================
+python setup.py status
+echo 1. Update (git pull + install requirements)
+echo 2. Upgrade (Upgrade Torch, Triton, Sage Attention, etc.)
+echo 3. Platform Migration (Upgrade to Py 3.11/Torch 2.10)
+echo 4. Exit
+echo ------------------------------------------------------
+set /p choice="Select an option (1-4): "
+
+if "%choice%"=="1" (
+ python setup.py update
+ pause
+ goto MENU
+)
+
+if "%choice%"=="2" (
+ python setup.py upgrade
+ pause
+ goto MENU
+)
+
+if "%choice%"=="3" (
+ echo [!] This will rebuild your environment with Python 3.11/Torch 2.10
+ python setup.py migrate
+ pause
+ goto MENU
+)
+
+if "%choice%"=="4" exit
+goto MENU
\ No newline at end of file
diff --git a/Wan2GP/scripts/update.sh b/Wan2GP/scripts/update.sh
new file mode 100644
index 000000000..7887062f3
--- /dev/null
+++ b/Wan2GP/scripts/update.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+cd "$(dirname "$0")/.."
+clear
+echo "======================================================"
+echo " WAN2GP UPDATE / UPGRADE"
+echo "======================================================"
+python3 setup.py status
+echo "1. Update (git pull + install requirements)"
+echo "2. Upgrade (Upgrade Torch, Triton, Sage Attention, etc.)"
+echo "3. Platform Migration (Upgrade to Py 3.11/Torch 2.10)"
+echo "4. Exit"
+echo "------------------------------------------------------"
+read -p "Select an option (1-4): " choice
+
+if [ "$choice" == "1" ]; then
+ python3 setup.py update
+elif [ "$choice" == "2" ]; then
+ python3 setup.py upgrade
+elif [ "$choice" == "3" ]; then
+ echo "[!] This will rebuild your environment with Python 3.11/Torch 2.10"
+ python3 setup.py migrate
+else
+ exit 0
+fi
+read -p "Press Enter to exit..."
\ No newline at end of file
diff --git a/Wan2GP/setup.py b/Wan2GP/setup.py
new file mode 100644
index 000000000..fd4e90e5c
--- /dev/null
+++ b/Wan2GP/setup.py
@@ -0,0 +1,691 @@
import argparse
import json
import os
import platform
import re
import shutil
import subprocess
import sys
import tempfile
+
# Registry / config filenames, resolved relative to the repo root (the CWD).
CONFIG_PATH = "setup_config.json"   # component/profile catalogue (required to run)
ENVS_FILE = "envs.json"             # persisted environment registry
IS_WIN = os.name == 'nt'

# Per-backend command templates.  Placeholders: {ver} = python version,
# {dir} = environment directory, {sys_py} = this interpreter's executable.
ENV_TEMPLATES = {
    "uv": {
        "create": "uv venv --python {ver} \"{dir}\"",
        "run": os.path.join("{dir}", "Scripts", "python.exe") if IS_WIN else os.path.join("{dir}", "bin", "python"),
        "install": (os.path.join("{dir}", "Scripts", "python.exe") if IS_WIN else os.path.join("{dir}", "bin", "python")) + " -m uv pip install"
    },
    "venv": {
        "create": "{sys_py} -m venv \"{dir}\"",
        "run": os.path.join("{dir}", "Scripts", "python.exe") if IS_WIN else os.path.join("{dir}", "bin", "python"),
        "install": (os.path.join("{dir}", "Scripts", "python.exe") if IS_WIN else os.path.join("{dir}", "bin", "python")) + " -m pip install"
    },
    "conda": {
        "create": "conda create -y -p \"{dir}\" python={ver}",
        "run": "conda run -p \"{dir}\" python",
        "install": "conda run -p \"{dir}\" pip install"
    },
    # "none" targets the system interpreter; nothing is created on disk.
    "none": {
        "create": "",
        "run": "python" if IS_WIN else "python3",
        "install": "pip install"
    }
}
+
# Probe script executed *inside* a managed environment (never in this process).
# It prints a single line "python=<v>||torch=<v>||..." that get_env_details()
# parses; absent packages report "Missing", broken ones "Error".
# NOTE: must remain stdlib-only and valid as a standalone `python -c` script.
VERSION_CHECK_SCRIPT = """
import sys
import importlib
import importlib.metadata

pkgs = ['torch', 'triton', 'sageattention', 'flash_attn']
res = []
try:
    res.append(f"python={sys.version.split()[0]}")
except:
    res.append("python=Unknown")

for p in pkgs:
    try:
        ver = importlib.metadata.version(p)
        res.append(f"{p}={ver}")
    except importlib.metadata.PackageNotFoundError:
        try:
            # Fallback to __version__
            m = importlib.import_module(p)
            ver = getattr(m, '__version__', 'Installed')
            res.append(f"{p}={ver}")
        except ImportError:
            res.append(f"{p}=Missing")
    except Exception:
        res.append(f"{p}=Error")
print("||".join(res))
"""
+
class EnvsManager:
    """Registry of managed Python environments, persisted to envs.json.

    Data layout: {"active": <name or None>, "envs": {name: {"type": ..., "path": ...}}}.
    Every mutating method saves the registry back to disk immediately.
    """

    def __init__(self):
        self.data = {"active": None, "envs": {}}
        self.load()

    def load(self):
        """Load the registry from disk; keep defaults if the file is absent or unreadable."""
        if os.path.exists(ENVS_FILE):
            try:
                with open(ENVS_FILE, 'r') as f:
                    self.data = json.load(f)
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit and
            # programming errors are no longer silently swallowed.
            # ValueError covers json.JSONDecodeError.
            except (OSError, ValueError):
                print(f"[!] Warning: {ENVS_FILE} corrupted. Starting fresh.")

    def save(self):
        with open(ENVS_FILE, 'w') as f:
            json.dump(self.data, f, indent=4)

    def get_active(self):
        return self.data.get("active")

    def set_active(self, name):
        """Mark *name* as the active environment (no-op with a warning if unknown)."""
        if name in self.data["envs"]:
            self.data["active"] = name
            self.save()
            print(f"[*] '{name}' is now the active environment.")
        else:
            print(f"[!] Environment '{name}' not found.")

    def add_env(self, name, type, path):
        # NOTE: 'type' shadows the builtin but is kept for caller compatibility.
        self.data["envs"][name] = {"type": type, "path": path}
        if not self.data["active"]:
            self.data["active"] = name
        self.save()

    def remove_env(self, name):
        """Deregister *name* and delete its files from disk (best effort)."""
        if name not in self.data["envs"]:
            return
        entry = self.data["envs"][name]
        path = entry["path"]

        if os.path.exists(path) and entry["type"] != "none":
            try:
                print(f"[*] Deleting directory: {path}")
                if entry["type"] == "conda":
                    # Let conda clean up its own metadata as well as the files.
                    run_cmd(f"conda env remove -p \"{path}\" -y")
                else:
                    shutil.rmtree(path)
            except Exception as e:
                # Best effort: a stale directory must not block deregistration.
                print(f"[!] Error removing directory: {e}")

        del self.data["envs"][name]
        if self.data["active"] == name:
            self.data["active"] = None
            keys = list(self.data["envs"].keys())
            if keys:
                self.data["active"] = keys[0]
                print(f"[*] Active environment switched to '{keys[0]}'.")
            else:
                print("[*] No environments left.")
        self.save()

    def list_envs(self):
        return self.data["envs"]

    def resolve_target_env(self):
        """Intelligently determine which env to use for operations."""
        envs = self.list_envs()
        if not envs:
            print("[!] No environments found. Please run install first.")
            sys.exit(1)

        active = self.get_active()

        # Single environment: no prompt needed.
        if len(envs) == 1:
            return list(envs.keys())[0]

        print("\nMultiple environments detected:")
        keys = list(envs.keys())
        for i, k in enumerate(keys):
            marker = "*" if k == active else " "
            print(f"{i+1}. [{marker}] {k} ({envs[k]['type']})")

        print(f"Default: {active}")
        choice = input("Select environment (Number) or Press Enter for Default: ").strip()

        if choice == "":
            return active
        try:
            idx = int(choice) - 1
        except ValueError:  # non-numeric input falls back to the default
            return active
        if 0 <= idx < len(keys):
            return keys[idx]
        return active
+
def load_config():
    """Load the component/profile catalogue; exit with an error if it is missing."""
    try:
        with open(CONFIG_PATH, 'r') as fh:
            return json.load(fh)
    except FileNotFoundError:
        print(f"Error: {CONFIG_PATH} not found.")
        sys.exit(1)
+
def get_gpu_info():
    """Best-effort GPU detection.

    Returns (name, vendor) where vendor is "NVIDIA", "AMD", "INTEL" or "UNKNOWN".
    Tries nvidia-smi first, then falls back to wmic (Windows) or lspci (POSIX).
    """
    try:
        name = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            encoding='utf-8',
            stderr=subprocess.DEVNULL
        ).strip()
        return name, "NVIDIA"
    except: pass  # nvidia-smi absent or failed; fall through to generic probes

    if IS_WIN:
        try:
            name = subprocess.check_output(
                "wmic path win32_VideoController get name",
                shell=True,
                encoding='utf-8',
                stderr=subprocess.DEVNULL
            )
            # First data row after the "Name" header; on multi-GPU systems only
            # the first adapter is reported.
            name = name.replace("Name", "").strip().split('\n')[0].strip()
            if "Radeon" in name or "AMD" in name: return name, "AMD"
            return name, "INTEL"  # NOTE(review): any non-AMD adapter is labelled INTEL here
        except: pass
    else:
        try:
            name = subprocess.check_output(
                "lspci | grep -i vga",
                shell=True,
                encoding='utf-8',
                stderr=subprocess.DEVNULL
            )
            # NOTE: returns the raw lspci line, not a cleaned device name.
            if "NVIDIA" in name: return name, "NVIDIA"
            if "AMD" in name or "Advanced Micro Devices" in name: return name, "AMD"
        except: pass

    return "Unknown", "UNKNOWN"
+
def get_profile_key(gpu_name, vendor):
    """Map a GPU name to an install-profile key from setup_config.json.

    NVIDIA cards are matched on the series digit of their 4-digit model number
    (e.g. "RTX 3050" -> RTX_30).  The previous substring tests ("50" in name)
    misfiled e.g. 3050/2050 cards as RTX_50.  Unknown hardware falls back to a
    conservative default.
    """
    g = gpu_name.upper()
    if vendor == "NVIDIA":
        # First standalone 4-digit model number whose leading digit is a known
        # series (1xxx..5xxx); digit lookarounds keep "RTX4090" matchable while
        # rejecting longer digit runs.
        m = re.search(r"(?<!\d)([1-5])\d{3}(?!\d)", g)
        if m:
            series = m.group(1)
            if series == "5": return "RTX_50"
            if series == "4": return "RTX_40"
            if series == "3": return "RTX_30"
            if series == "2": return "RTX_20"
            return "GTX_10"  # 1xxx (Pascal / GTX 16xx)
        if "QUADRO" in g: return "RTX_20"
        return "GTX_10"
    elif vendor == "AMD":
        if any(x in g for x in ["7600", "7700", "7800", "7900"]): return "AMD_GFX110X"
        if any(x in g for x in ["7000", "Z1", "PHOENIX"]): return "AMD_GFX1151"
        if any(x in g for x in ["8000", "STRIX", "1201"]): return "AMD_GFX1201"
        return "AMD_GFX110X"
    # Unknown vendor: assume a recent mainstream NVIDIA profile.
    return "RTX_40"
+
def get_os_key():
    """Return the platform key ("win"/"linux") used to index OS-specific config commands."""
    if IS_WIN:
        return "win"
    return "linux"
+
def resolve_cmd(cmd_entry):
    """Resolve a config command entry to a string for this OS.

    Entries are either plain strings (returned unchanged) or dicts keyed by
    platform, e.g. {"win": ..., "linux": ...}; missing keys yield None.
    """
    if not isinstance(cmd_entry, dict):
        return cmd_entry
    return cmd_entry.get(get_os_key())
+
def run_cmd(cmd, env_vars=None):
    """Run a shell command, optionally with extra environment variables.

    Raises subprocess.CalledProcessError on non-zero exit (check=True).
    SECURITY NOTE: commands run with shell=True; only pass strings built from
    trusted config/registry data, never user-controlled input.
    """
    if not cmd: return  # lets callers pass empty/None commands unconditionally

    # POSIX compound commands ("a && b") are only logged differently; execution
    # is the same shell invocation either way.
    if "&&" in cmd and not IS_WIN:
        print(f"\n>>> Running (Shell): {cmd}")
        custom_env = os.environ.copy()
        if env_vars: custom_env.update(env_vars)
        subprocess.run(cmd, shell=True, check=True, env=custom_env)
        return

    print(f"\n>>> Running: {cmd}")
    custom_env = os.environ.copy()
    if env_vars:
        for k, v in env_vars.items():
            # Echo each override so install logs show the effective environment.
            print(f" [ENV SET] {k}={v}")
            custom_env[k] = v

    subprocess.run(cmd, shell=True, check=True, env=custom_env)
+
def get_env_details(name, env_data):
    """Run VERSION_CHECK_SCRIPT inside the environment and parse its report.

    Returns a dict of component versions plus 'path'/'type' keys, or a dict
    containing 'error' if the environment cannot be queried.
    """
    env_type = env_data["type"]
    dir_name = env_data["path"]
    entry = ENV_TEMPLATES[env_type]

    script_path = None
    try:
        if env_type == "conda":
            # BUG FIX: 'conda run ... -c "<script>"' used to collapse the
            # multi-line probe script with ';', which is invalid Python for its
            # for/try blocks, so conda envs always reported an error.  Execute
            # the script from a temporary file instead.
            fd, script_path = tempfile.mkstemp(suffix=".py")
            with os.fdopen(fd, "w") as f:
                f.write(VERSION_CHECK_SCRIPT)
            cmd_base = entry['run'].format(dir=dir_name)
            output = subprocess.check_output(f"{cmd_base} \"{script_path}\"", shell=True, encoding='utf-8', stderr=subprocess.DEVNULL)
        else:
            py_exec = entry['run'].format(dir=dir_name)
            output = subprocess.check_output([py_exec, "-c", VERSION_CHECK_SCRIPT], encoding='utf-8', stderr=subprocess.DEVNULL)

        # maxsplit=1 keeps version strings that themselves contain '=' intact.
        data = dict(x.split('=', 1) for x in output.strip().split('||'))
        data['path'] = dir_name
        data['type'] = env_type
        return data
    except Exception as e:
        return {'error': str(e), 'type': env_type, 'path': dir_name}
    finally:
        if script_path and os.path.exists(script_path):
            try:
                os.remove(script_path)
            except OSError:
                pass  # leaked temp file is harmless
+
def show_status():
    """Print a version table (python/torch/triton/sage/flash) for every registered environment."""
    manager = EnvsManager()
    print("\n" + "="*95)
    print(f"{'INSTALLED ENVIRONMENTS & VERSIONS':^95}")
    print("="*95)

    envs = manager.list_envs()
    active = manager.get_active()

    if not envs:
        print(" No environments installed.")
        print("="*95)
        return

    print(f"{'NAME':<15} | {'TYPE':<6} | {'PYTHON':<8} | {'TORCH':<15} | {'TRITON':<10} | {'SAGE':<12} | {'FLASH':<12}")
    print("-" * 95)

    for name, data in envs.items():
        # Each row spawns the environment's interpreter (see get_env_details),
        # so this can take a few seconds per environment.
        details = get_env_details(name, data)
        marker = "*" if name == active else " "
        display_name = f"[{marker}] {name}"

        if 'error' in details:
            print(f"{display_name:<15} | {data['type']:<6} | [Error reading environment]")
            continue

        print(f"{display_name:<15} | {data['type']:<6} | "
              f"{details.get('python','?'):<8} | "
              f"{details.get('torch','?'):<15} | "
              f"{details.get('triton','?'):<10} | "
              f"{details.get('sageattention','?'):<12} | "
              f"{details.get('flash_attn','?'):<12}")

    print("-" * 95)
    print(f" * = Active Environment")
    print("="*95 + "\n")
+
def install_logic(env_name, env_type, env_path, py_k, torch_k, triton_k, sage_k, flash_k, kernel_list, config):
    """Create the environment (if needed) and install torch plus optional extras.

    The *_k arguments are keys into config['components']; falsy keys skip the
    corresponding component.  Raises subprocess.CalledProcessError if any
    install step fails (run_cmd uses check=True).
    """
    template = ENV_TEMPLATES[env_type]
    target_py_ver = config['components']['python'][py_k]['ver']

    print(f"\n[1/3] Preparing Environment: {env_name} ({env_type})...")

    if env_type != "none":
        create_cmd = template["create"].format(ver=target_py_ver, dir=env_path, sys_py=sys.executable)
        if create_cmd:
            run_cmd(create_cmd)

    pip = template["install"].format(dir=env_path)

    print(f"\n[2/3] Installing Torch: {config['components']['torch'][torch_k]['label']}...")
    torch_cmd = resolve_cmd(config['components']['torch'][torch_k]['cmd'])
    run_cmd(f"{pip} {torch_cmd}")

    print(f"\n[3/3] Installing Requirements & Extras...")
    run_cmd(f"{pip} -r requirements.txt")

    if triton_k:
        cmd = resolve_cmd(config['components']['triton'][triton_k]['cmd'])
        if cmd: run_cmd(f"{pip} {cmd}")

    if sage_k:
        cmd = resolve_cmd(config['components']['sage'][sage_k]['cmd'])
        # BUG FIX: resolve_cmd returns None when the config has no command for
        # this OS; the previous code crashed on cmd.startswith in that case.
        if cmd:
            if cmd.startswith(("http", "sageattention")):
                # Wheel URL or PyPI spec: plain pip install.
                run_cmd(f"{pip} {cmd}")
            elif env_type in ("venv", "uv"):
                # Source build: must run inside the activated environment.
                act = f". {env_path}/bin/activate && " if not IS_WIN else ""
                run_cmd(f"{act}{cmd}")
            # conda: source builds are intentionally skipped (unchanged behavior).

    if flash_k:
        cmd = resolve_cmd(config['components']['flash'][flash_k]['cmd'])
        if cmd: run_cmd(f"{pip} {cmd}")

    for k in kernel_list:
        if k in config['components']['kernels']:
            cmd = resolve_cmd(config['components']['kernels'][k]['cmd'])
            if cmd: run_cmd(f"{pip} {cmd}")
+
def menu(title, options, recommended_key=None):
    """Prompt the user to pick one key from *options* (an ordered dict of {'label': ...} entries).

    Returns the chosen key, or *recommended_key* on empty/invalid input.
    """
    print(f"\n--- {title} ---")
    keys = list(options.keys())
    for i, k in enumerate(keys):
        rec = " [RECOMMENDED FOR YOUR GPU]" if k == recommended_key else ""
        print(f"{i+1}. {options[k]['label']}{rec}")
    choice = input(f"Select option (Enter for Recommended): ")
    if choice == "" and recommended_key: return recommended_key
    try:
        idx = int(choice) - 1
    except ValueError:  # was a bare except hiding all errors
        return recommended_key
    # BUG FIX: "0" used to return keys[-1] (Python negative indexing) and
    # out-of-range numbers raised inside the old try; both now fall back.
    if 0 <= idx < len(keys):
        return keys[idx]
    return recommended_key
+
def do_install_interactive(env_type, config, detected_key):
    """Interactive install flow: name the environment, pick a mode, build it, set active.

    env_type: one of ENV_TEMPLATES' keys; detected_key: GPU profile key from detection.
    """
    manager = EnvsManager()
    create_wgp_config(detected_key, config)  # first run only: seed wgp_config.json

    default_name = f"env_{env_type}" if env_type != "none" else "system"
    print(f"\n--- Configuration for {env_type} ---")
    name = input(f"Enter a name for this environment (Default: {default_name}): ").strip()
    if not name: name = default_name

    cwd = os.getcwd()
    path = os.path.join(cwd, name) if env_type != "none" else ""

    # Refuse to clobber silently: the target may be a registered env or a stray directory.
    if name in manager.list_envs():
        print(f"\n[!] Warning: Environment '{name}' already exists in registry.")
        choice = input("Do you want to overwrite it? (This will delete the old folder) [y/N]: ").lower()
        if choice != 'y': return
        manager.remove_env(name)
    elif os.path.exists(path) and env_type != "none":
        print(f"\n[!] Warning: Directory '{path}' exists but is not registered.")
        choice = input("Do you want to overwrite this directory? [y/N]: ").lower()
        if choice != 'y': return
        try: shutil.rmtree(path)
        except: pass  # NOTE(review): a failed delete proceeds into a dirty directory

    print("\n--- Select Install Mode ---")
    print("1. Autoselect (Recommended - Based on your card)")
    print("2. Manual Selection (Custom versions)")
    print("3. Use Latest (Forces RTX 50 Profile)")

    mode = input("Select option (1-3) [Default: 1]: ").strip()

    if mode == "2":
        # Manual: prompt per component, seeded with the detected profile.
        base = config['gpu_profiles'][detected_key]
        py_k = menu("Python Version", config['components']['python'], base['python'])
        torch_k = menu("Torch Version", config['components']['torch'], base['torch'])
        triton_k = menu("Triton", config['components']['triton'], base['triton'])
        sage_k = menu("Sage Attention", config['components']['sage'], base['sage'])
        flash_k = menu("Flash Attention", config['components']['flash'], base['flash'])
        kernels = base['kernels']

        install_logic(name, env_type, path, py_k, torch_k, triton_k, sage_k, flash_k, kernels, config)

    elif mode == "3":
        # Bleeding edge: ignore detection and use the newest profile.
        p = config['gpu_profiles']['RTX_50']
        install_logic(name, env_type, path, p['python'], p['torch'], p['triton'], p['sage'], p.get('flash'), p['kernels'], config)
    else:
        # Default / any other input: autoselect from the detected profile.
        p = config['gpu_profiles'][detected_key]
        install_logic(name, env_type, path, p['python'], p['torch'], p['triton'], p['sage'], p.get('flash'), p['kernels'], config)

    manager.add_env(name, env_type, path)

    if len(manager.list_envs()) > 1:
        choice = input(f"\nDo you want to make '{name}' the active environment? [Y/n]: ").lower()
        if choice != 'n':
            manager.set_active(name)
    else:
        print(f"\n[*] '{name}' is the only environment, setting as active.")
        manager.set_active(name)
+
def do_manage():
    """Interactive environment manager loop: activate, delete, register, inspect."""
    manager = EnvsManager()
    while True:
        os.system('cls' if IS_WIN else 'clear')
        print("======================================================")
        print(" ENVIRONMENT MANAGER")
        print("======================================================")
        envs = manager.list_envs()
        active = manager.get_active()

        if not envs:
            print(" No environments installed.")
        else:
            for name, data in envs.items():
                status = "(Active)" if name == active else ""
                print(f" - {name:<15} [{data['type']}] {status}")

        print("------------------------------------------------------")
        print("1. Set Active Environment")
        print("2. Delete Environment")
        print("3. Add Existing Environment")
        print("4. List Environment Details")
        print("5. Return to Menu / Exit")

        choice = input("\nSelect option: ")

        if choice == "1":
            name = input("Enter name of environment to activate: ")
            manager.set_active(name)
            input("Press Enter...")
        elif choice == "2":
            name = input("Enter name of environment to DELETE: ")
            conf = input(f"Are you sure you want to delete '{name}' and its files? (y/n): ")
            if conf.lower() == 'y':
                manager.remove_env(name)
                input("Deleted. Press Enter...")
        elif choice == "3":
            # Register an environment created outside this tool (nothing is copied).
            path = input("Enter the path to the existing environment folder: ").strip()
            if not os.path.exists(path):
                print("[!] Error: Path does not exist.")
            else:
                name = input("Enter a nickname for this environment: ").strip()
                if not name: name = os.path.basename(path.rstrip(os.sep))

                print("\nSelect Environment Type:")
                print("1. venv")
                print("2. uv")
                print("3. conda")
                t_choice = input("Choice (Default 1): ")
                e_type = "uv" if t_choice == "2" else "conda" if t_choice == "3" else "venv"

                manager.add_env(name, e_type, os.path.abspath(path))
                print(f"[*] Registered '{name}' at {os.path.abspath(path)}")
            input("Press Enter...")
        elif choice == "4":
            show_status()
            input("Press Enter...")
        elif choice == "5":
            break
+
def do_migrate(config):
    """Wipe the selected environment and rebuild it on the latest (RTX_50) platform profile."""
    manager = EnvsManager()
    print("\n" + "="*60)
    print(" WAN2GP AUTOMATED PLATFORM MIGRATION (TO 3.11)")
    print("="*60)

    env_name = manager.resolve_target_env()
    env_data = manager.list_envs()[env_name]

    print(f"\nTarget Environment: {env_name} ({env_data['type']})")
    confirm = input(f"This will wipe '{env_name}' and rebuild it. Proceed? (y/n): ")
    if confirm.lower() != 'y': return

    # Migration always targets the newest component set.
    target = config['gpu_profiles']['RTX_50']

    # BUG FIX: remove_env() may hand the active slot to another environment,
    # and add_env() only claims it when none is set — so a migrated active
    # environment used to silently lose its active status.
    was_active = manager.get_active() == env_name

    manager.remove_env(env_name)

    install_logic(env_name, env_data['type'], env_data['path'],
                  target['python'], target['torch'], target['triton'],
                  target['sage'], target.get('flash'), target['kernels'], config)

    manager.add_env(env_name, env_data['type'], env_data['path'])
    if was_active and manager.get_active() != env_name:
        manager.set_active(env_name)
+
def do_upgrade(config):
    """Interactively upgrade individual components of an existing environment in place."""
    manager = EnvsManager()
    print("\n" + "="*60)
    print(" WAN2GP MANUAL COMPONENT UPGRADE")
    print("="*60)

    env_name = manager.resolve_target_env()
    env_data = manager.list_envs()[env_name]

    # Recommendations follow the currently detected GPU, not the env's history.
    gpu_name, vendor = get_gpu_info()
    rec = config['gpu_profiles'][get_profile_key(gpu_name, vendor)]

    py_k = menu("Python Version", config['components']['python'], rec['python'])
    torch_k = menu("Torch Version", config['components']['torch'], rec['torch'])
    triton_k = menu("Triton", config['components']['triton'], rec['triton'])
    sage_k = menu("Sage Attention", config['components']['sage'], rec['sage'])
    # NOTE(review): rec['flash'] will KeyError if a profile omits 'flash';
    # other call sites use rec.get('flash') — confirm config always defines it.
    flash_k = menu("Flash Attention", config['components']['flash'], rec['flash'])

    install_logic(env_name, env_data['type'], env_data['path'], py_k, torch_k, triton_k, sage_k, flash_k, rec['kernels'], config)
+
def get_system_specs():
    """Best-effort (ram_gb, vram_gb) detection; defaults to 16GB RAM / 8GB VRAM on failure."""
    ram_gb = 0
    vram_gb = 0

    if IS_WIN:
        try:
            out = subprocess.check_output(
                ["powershell", "-NoProfile", "-Command", "(Get-CimInstance Win32_ComputerSystem).TotalPhysicalMemory"],
                encoding='utf-8', stderr=subprocess.DEVNULL
            ).strip()
            if out:
                ram_gb = int(out) / (1024**3)
        except:
            # Older Windows: fall back to the deprecated wmic tool.
            try:
                out = subprocess.check_output(
                    "wmic computersystem get TotalPhysicalMemory /value",
                    shell=True, encoding='utf-8', stderr=subprocess.DEVNULL
                )
                for line in out.splitlines():
                    if "TotalPhysicalMemory=" in line:
                        ram_gb = int(line.split('=')[1]) / (1024**3)
                        break
            except:
                pass
        else:
        try:
            # /proc/meminfo reports MemTotal in kB.
            with open('/proc/meminfo', 'r') as f:
                for line in f:
                    if 'MemTotal' in line:
                        kb_val = float(line.split()[1])
                        ram_gb = kb_val / (1024**2)
                        break
        except: pass

    if ram_gb == 0:
        print("[!] Warning: Could not detect System RAM. Defaulting to 16GB.")
        ram_gb = 16

    try:
        # Only the first GPU's VRAM is considered on multi-GPU systems.
        out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
            encoding='utf-8', stderr=subprocess.DEVNULL
        ).strip()
        vram_gb = float(out.split('\n')[0]) / 1024
    except:
        # Non-NVIDIA GPUs land here too; 8GB is a conservative guess.
        print("[!] Warning: Could not detect VRAM via nvidia-smi. Defaulting to 8GB.")
        vram_gb = 8

    return ram_gb, vram_gb
+
def create_wgp_config(profile_key, config_data):
    """First-run helper: write wgp_config.json tuned to detected RAM/VRAM.

    Does nothing when the file already exists, so user edits are preserved.
    """
    WGP_CONFIG_FILE = "wgp_config.json"

    if os.path.exists(WGP_CONFIG_FILE):
        return

    print("\n[*] Auto-generating wgp_config.json based on hardware...")

    ram, vram = get_system_specs()
    print(f" Detected: {int(ram)}GB RAM / {int(vram)}GB VRAM")

    # Memory profile id: 1 = most headroom ... 5 = most constrained.
    huge_vram = vram > 22
    high_vram = vram > 11
    if ram > 60:
        pid = 1 if huge_vram else 2
    elif ram > 30:
        pid = 3 if huge_vram else (4 if high_vram else 5)
    else:
        pid = 5

    prof_settings = config_data['gpu_profiles'].get(profile_key, {})

    # Sage2 on RTX 30/40/50 profiles, Sage on RTX 20, nothing otherwise.
    if any(tag in profile_key for tag in ("50", "40", "30")):
        attn_mode = "sage2"
    elif "20" in profile_key:
        attn_mode = "sage"
    else:
        attn_mode = ""

    # Enable transformer compilation only when the profile installs Triton.
    triton_key = prof_settings.get('triton')
    compile_mode = "transformer" if triton_key and triton_key != "none" else ""

    config_out = {
        "attention_mode": attn_mode,
        "compile": compile_mode,
        "video_profile": pid,
        "image_profile": pid,
        "audio_profile": pid,
    }

    try:
        with open(WGP_CONFIG_FILE, 'w') as f:
            json.dump(config_out, f, indent=4)
        print(f" Created config with Profile {pid}, Attention: '{attn_mode}', Compile: '{compile_mode}'")
    except Exception as e:
        print(f"[!] Error writing config: {e}")
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", choices=["install", "run", "update", "migrate", "upgrade", "status", "manage"])
    parser.add_argument("--env", default="venv", help="Type of env for install (venv, uv, conda, none)")
    args = parser.parse_args()
    cfg = load_config()  # exits if setup_config.json is missing

    # Modes that need no GPU detection are handled first.
    if args.mode == "status":
        show_status()
        sys.exit(0)

    if args.mode == "manage":
        do_manage()
        sys.exit(0)

    gpu_name, vendor = get_gpu_info()
    profile_key = get_profile_key(gpu_name, vendor)
    profile = cfg['gpu_profiles'][profile_key]

    if args.mode == "install":
        print(f"Hardware Detected: {gpu_name} ({vendor})")
        do_install_interactive(args.env, cfg, profile_key)

    elif args.mode == "run":
        manager = EnvsManager()
        active = manager.get_active()
        if not active:
            print("[!] No active environment found. Run install or manage.")
            sys.exit(1)

        env_data = manager.list_envs().get(active)
        if not env_data:
            print(f"[!] Active environment '{active}' data missing from registry.")
            sys.exit(1)

        print(f"[*] Launching using active environment: {active}")

        # Optional extra CLI args for wgp.py: one token group per line,
        # '#'-prefixed lines are comments.
        extra_args = ""
        if os.path.exists("scripts/args.txt"):
            with open("scripts/args.txt", "r") as f:
                lines = [l.strip() for l in f.readlines() if l.strip() and not l.startswith("#")]
            extra_args = " ".join(lines)

        env_vars = profile.get("env", {})  # per-GPU-profile environment overrides
        cmd_fmt = ENV_TEMPLATES[env_data['type']]['run']
        cmd = f"{cmd_fmt.format(dir=env_data['path'])} wgp.py {extra_args}"
        run_cmd(cmd, env_vars=env_vars)

    elif args.mode == "update":
        # Pull the latest sources, then refresh requirements inside the chosen env.
        run_cmd("git pull")
        manager = EnvsManager()
        env_name = manager.resolve_target_env()
        env_data = manager.list_envs()[env_name]

        cmd_fmt = ENV_TEMPLATES[env_data['type']]['run']
        cmd = f"{cmd_fmt.format(dir=env_data['path'])} -m pip install -r requirements.txt"
        run_cmd(cmd)

    elif args.mode == "migrate":
        do_migrate(cfg)

    elif args.mode == "upgrade":
        do_upgrade(cfg)
\ No newline at end of file
diff --git a/Wan2GP/shared/RGB_factors.py b/Wan2GP/shared/RGB_factors.py
index c0db6129c..ec1cf16bd 100644
--- a/Wan2GP/shared/RGB_factors.py
+++ b/Wan2GP/shared/RGB_factors.py
@@ -140,6 +140,141 @@ def get_rgb_factors(model_family, model_type = None, sub_family = None):
]
latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
+ elif model_family == "ltx2":
+ # LTX2_RGB_FACTORS_START
+ latent_rgb_factors = [
+ [0.034942, 0.025665, 0.019890],
+ [-0.010039, -0.003935, -0.004114],
+ [0.012398, -0.009390, 0.002166],
+ [0.021493, 0.004525, -0.001346],
+ [0.007121, 0.021149, 0.006106],
+ [0.009438, 0.013676, 0.003909],
+ [-0.012689, -0.017296, -0.011732],
+ [-0.010145, -0.008827, -0.040893],
+ [-0.000609, 0.002179, 0.004511],
+ [0.017091, 0.010640, 0.000340],
+ [-0.009421, -0.006105, 0.000176],
+ [-0.005993, -0.002687, 0.003946],
+ [0.014557, 0.021932, 0.020061],
+ [0.008345, 0.001006, -0.004580],
+ [0.034493, 0.018170, 0.023638],
+ [-0.001190, 0.003468, 0.012947],
+ [-0.002343, -0.012411, -0.011100],
+ [-0.011259, -0.000162, -0.005704],
+ [0.035304, 0.037420, 0.045084],
+ [-0.002184, -0.026547, -0.019204],
+ [0.011548, 0.006255, 0.012846],
+ [-0.000233, -0.002465, -0.004693],
+ [-0.034094, -0.021138, -0.005063],
+ [-0.000989, -0.000375, 0.002170],
+ [0.007333, 0.015489, 0.024374],
+ [0.006371, -0.018043, -0.031470],
+ [-0.029347, -0.027023, -0.016491],
+ [-0.001422, -0.006893, -0.001695],
+ [0.005795, 0.005987, 0.007621],
+ [-0.003406, 0.002260, 0.000005],
+ [0.029885, 0.015924, 0.013307],
+ [0.010122, 0.005565, 0.003611],
+ [0.010770, 0.001935, 0.003951],
+ [0.000247, 0.007556, -0.000721],
+ [0.003422, 0.002761, 0.002203],
+ [-0.000763, -0.001199, 0.002576],
+ [0.013336, 0.018443, 0.022138],
+ [-0.011958, -0.026991, -0.033736],
+ [0.000398, -0.005840, -0.001664],
+ [0.001106, 0.001808, -0.001595],
+ [-0.003785, 0.003479, -0.001016],
+ [0.015053, 0.015471, 0.013186],
+ [0.008358, 0.007849, 0.018877],
+ [0.005239, -0.003565, -0.034033],
+ [-0.055078, -0.052277, -0.043015],
+ [-0.004671, -0.017110, -0.012464],
+ [-0.007845, -0.001926, -0.000054],
+ [0.063222, 0.078728, 0.058501],
+ [-0.021986, -0.016549, -0.020792],
+ [0.000084, -0.003974, -0.002621],
+ [-0.008120, -0.009763, -0.020849],
+ [0.008795, 0.008390, 0.004435],
+ [0.003783, -0.002978, -0.003743],
+ [0.013151, -0.004677, -0.001796],
+ [0.002784, 0.000578, 0.012341],
+ [0.032561, 0.018570, 0.015776],
+ [-0.018336, -0.014570, 0.001422],
+ [-0.016377, -0.004534, -0.008379],
+ [-0.018032, -0.018326, -0.016427],
+ [0.011842, 0.012533, 0.018174],
+ [-0.007721, -0.007604, -0.009647],
+ [0.010696, 0.002887, 0.012046],
+ [-0.022643, -0.029499, -0.022168],
+ [-0.004399, -0.003200, -0.001333],
+ [0.077523, 0.061106, 0.040685],
+ [0.016623, 0.018785, 0.015607],
+ [0.005365, 0.004560, -0.008725],
+ [0.001207, -0.001362, 0.001578],
+ [0.019177, 0.030301, 0.023738],
+ [0.000936, 0.001365, 0.005611],
+ [0.015523, 0.008943, 0.012178],
+ [-0.026548, -0.002319, 0.008205],
+ [0.015405, -0.008554, -0.012774],
+ [-0.002702, 0.012692, 0.004927],
+ [-0.001287, -0.004590, -0.003223],
+ [0.008966, 0.009299, 0.005486],
+ [-0.016937, -0.008969, -0.014726],
+ [0.005981, 0.006354, -0.006855],
+ [0.009750, 0.006882, 0.007736],
+ [0.001820, 0.004259, 0.012132],
+ [0.012835, 0.012450, 0.011795],
+ [0.003041, 0.010194, 0.013934],
+ [-0.016527, -0.032534, -0.030963],
+ [-0.015136, -0.007481, -0.009911],
+ [0.030708, 0.021832, 0.025773],
+ [-0.008353, -0.012020, -0.008660],
+ [0.018777, 0.017951, 0.006013],
+ [-0.006846, -0.006453, -0.005759],
+ [0.017944, 0.016239, 0.017806],
+ [-0.009166, -0.004829, 0.002145],
+ [0.011764, 0.010028, 0.008942],
+ [0.015022, -0.016713, -0.031551],
+ [-0.103677, -0.102297, -0.093770],
+ [-0.006865, -0.003216, -0.002682],
+ [-0.007705, 0.001121, -0.012102],
+ [0.015788, -0.003327, 0.006230],
+ [-0.005562, -0.009712, -0.008889],
+ [0.006411, 0.011945, 0.014182],
+ [-0.003523, -0.003832, -0.008597],
+ [-0.002705, -0.007006, -0.002440],
+ [0.010826, 0.021793, 0.019520],
+ [0.021403, 0.017133, 0.011349],
+ [-0.020997, -0.001073, -0.013768],
+ [-0.004439, 0.005850, 0.001262],
+ [0.008814, 0.004013, 0.004906],
+ [0.008196, 0.005846, 0.007751],
+ [0.000102, 0.001182, 0.005504],
+ [-0.007416, -0.009051, -0.006597],
+ [0.039224, 0.052929, 0.060699],
+ [0.006937, 0.009651, 0.004330],
+ [-0.013241, -0.008414, -0.010154],
+ [-0.021549, 0.012296, 0.043766],
+ [0.009432, -0.007242, 0.003204],
+ [-0.038648, -0.035801, -0.032508],
+ [0.009745, 0.013935, 0.012653],
+ [-0.024403, -0.018949, -0.025871],
+ [-0.002547, -0.010403, -0.004686],
+ [-0.000516, 0.000137, 0.000241],
+ [-0.001571, -0.000162, -0.010227],
+ [-0.023109, -0.024477, -0.019911],
+ [-0.003461, -0.001731, -0.004959],
+ [0.001869, 0.002194, 0.001378],
+ [-0.011800, -0.001712, -0.004228],
+ [-0.019423, -0.002752, 0.007851],
+ [-0.000403, -0.011006, -0.011421],
+ [0.000448, -0.002081, -0.002759],
+ [-0.004389, -0.005971, -0.034043],
+ [-0.013471, -0.013140, -0.013874],
+ ]
+ latent_rgb_factors_bias = [-0.048359, -0.119311, -0.188382]
+ # LTX2_RGB_FACTORS_END
+
elif model_family == "ltxv":
latent_channels = 128
latent_dimensions = 3
@@ -340,4 +475,4 @@ def get_rgb_factors(model_family, model_type = None, sub_family = None):
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
else:
latent_rgb_factors_bias = latent_rgb_factors = None
- return latent_rgb_factors, latent_rgb_factors_bias
\ No newline at end of file
+ return latent_rgb_factors, latent_rgb_factors_bias
diff --git a/Wan2GP/shared/attention.py b/Wan2GP/shared/attention.py
index 16cb412f3..89ee52eb9 100644
--- a/Wan2GP/shared/attention.py
+++ b/Wan2GP/shared/attention.py
@@ -4,10 +4,18 @@
from mmgp import offload
import torch.nn.functional as F
import warnings
+from importlib.metadata import version
major, minor = torch.cuda.get_device_capability(None)
bfloat16_supported = major >= 8
+try:
+ import triton
+ triton_installed = True
+except:
+ triton_installed = False
+
+
try:
from xformers.ops import memory_efficient_attention
except ImportError:
@@ -46,7 +54,12 @@ def sageattn_varlen_wrapper(
from spas_sage_attn import block_sparse_sage2_attn_cuda
except ImportError:
block_sparse_sage2_attn_cuda = None
-
+ if not triton_installed:
+ try:
+ sg2_version = version("sageattention")
+ print("Sage Attention has been detected but it won't work until Triton is installed.")
+ except ImportError:
+ pass
try:
from .sage2_core import sageattn as sageattn2, is_sage2_supported
@@ -54,6 +67,13 @@ def sageattn_varlen_wrapper(
except ImportError:
sageattn2 = None
sage2_supported = False
+ if not triton_installed:
+ try:
+ sg2_version = version("sageattention")
+ if not triton_installed: print("Sage Attention 2 has been detected but it won't work until Triton is installed.")
+ except ImportError:
+ pass
+
@torch.compiler.disable()
def sageattn2_wrapper(
qkv_list,
@@ -69,14 +89,28 @@ def sageattn2_wrapper(
try:
from sageattn import sageattn_blackwell as sageattn3
+ if not triton_installed:
+ print("Sage Attention 3 is installed but it won't be supported until Triton is installed.")
except ImportError:
sageattn3 = None
+ if not triton_installed:
+ try:
+ sg3_version = version("sageattn_blackwell")
+ print("Sage Attention 3 has been detected but it won't work until Triton is installed.")
+ except ImportError:
+ pass
if sageattn3 is None:
try:
from sageattn3 import sageattn3_blackwell as sageattn3 #word0 windows version
except ImportError:
sageattn3 = None
+ if not triton_installed:
+ try:
+ sg3_version = version("sageattn3_blackwell")
+ print("Sage Attention 3 has been detected but it won't work until Triton is installed.")
+ except ImportError:
+ pass
@torch.compiler.disable()
def sageattn3_wrapper(
@@ -166,17 +200,17 @@ def get_attention_modes():
def get_supported_attention_modes():
ret = get_attention_modes()
major, minor = torch.cuda.get_device_capability()
- if major < 10:
+ if major < 10 or not triton_installed:
if "sage3" in ret:
ret.remove("sage3")
- if not sage2_supported:
+ if not sage2_supported or not triton_installed:
if "sage2" in ret:
ret.remove("sage2")
if "radial" in ret:
ret.remove("radial")
- if major < 7:
+ if major < 7 or not triton_installed:
if "sage" in ret:
ret.remove("sage")
diff --git a/Wan2GP/shared/ffmpeg_setup.py b/Wan2GP/shared/ffmpeg_setup.py
index 5d114568b..c5c7928b7 100644
--- a/Wan2GP/shared/ffmpeg_setup.py
+++ b/Wan2GP/shared/ffmpeg_setup.py
@@ -29,16 +29,52 @@ def download_ffmpeg(bin_directory: typing.Optional[typing.Union[str, Path]] = No
bin_dir.mkdir(parents=True, exist_ok=True)
repo_root = bin_dir.parent
+ def _candidate_name(name: str) -> str:
+ if os.name == "nt" and not name.endswith(".exe"):
+ return f"{name}.exe"
+ return name
+
+ def _quarantine_root_ffmpeg():
+ root_ffmpeg = repo_root / _candidate_name("ffmpeg")
+ if not root_ffmpeg.is_file():
+ return
+ quarantine_dir = repo_root / "ffmpeg_quarantine"
+ quarantine_dir.mkdir(parents=True, exist_ok=True)
+ target_path = quarantine_dir / root_ffmpeg.name
+ if target_path.exists():
+ stem = target_path.stem
+ suffix = target_path.suffix
+ idx = 1
+ while True:
+ candidate = quarantine_dir / f"{stem}_{idx}{suffix}"
+ if not candidate.exists():
+ target_path = candidate
+ break
+ idx += 1
+ shutil.move(str(root_ffmpeg), str(target_path))
+ print(
+ f"[FFmpeg] Quarantined root binary: {root_ffmpeg} -> {target_path}. "
+ "Reason: ffmpeg.exe in the project root can be picked from CWD and break TorchCodec DLL loading on Windows. Quarantined file can be deleted if unused."
+ )
+
def _ensure_bin_dir_on_path():
current_path = os.environ.get("PATH", "")
path_parts = current_path.split(os.pathsep) if current_path else []
- dirs_to_add = []
- # Add ffmpeg_bins and repo root to PATH
+
+ def _normalize(p: str) -> str:
+ p = os.path.normpath(p)
+ return os.path.normcase(p) if os.name == "nt" else p
+
+ prioritized = []
+ seen = set()
for d in [bin_dir, repo_root]:
- if str(d) not in path_parts:
- dirs_to_add.append(str(d))
- if dirs_to_add:
- os.environ["PATH"] = os.pathsep.join(dirs_to_add + path_parts)
+ key = _normalize(str(d))
+ if key not in seen:
+ prioritized.append(str(d))
+ seen.add(key)
+
+ filtered = [p for p in path_parts if _normalize(p) not in seen]
+ os.environ["PATH"] = os.pathsep.join(prioritized + filtered)
def _ensure_library_path():
if os.name == "nt":
@@ -48,14 +84,10 @@ def _ensure_library_path():
if str(bin_dir) not in ld_parts:
os.environ["LD_LIBRARY_PATH"] = os.pathsep.join([str(bin_dir)] + ld_parts) if current_ld else str(bin_dir)
+ _quarantine_root_ffmpeg()
_ensure_bin_dir_on_path()
_ensure_library_path()
- def _candidate_name(name: str) -> str:
- if os.name == "nt" and not name.endswith(".exe"):
- return f"{name}.exe"
- return name
-
def _resolve_path(name: str) -> typing.Optional[Path]:
# Check ffmpeg_bins folder first
candidate = bin_dir / _candidate_name(name)
@@ -73,6 +105,10 @@ def _resolve_path(name: str) -> typing.Optional[Path]:
def _binary_exists(name: str) -> bool:
return _resolve_path(name) is not None
+ def _local_binary_exists(name: str) -> bool:
+ candidate = bin_dir / _candidate_name(name)
+ return candidate.exists()
+
def _libs_present() -> bool:
if os.name == "nt":
return True
@@ -93,8 +129,12 @@ def _set_env_vars():
if ffplay_path:
os.environ["FFPLAY_BINARY"] = str(ffplay_path)
- missing = [binary for binary in required_binaries if not _binary_exists(binary)]
- libs_ok = _libs_present()
+ if os.name == "nt":
+ missing = [binary for binary in required_binaries if not _local_binary_exists(binary)]
+ libs_ok = True
+ else:
+ missing = [binary for binary in required_binaries if not _binary_exists(binary)]
+ libs_ok = _libs_present()
if not missing and libs_ok:
_set_env_vars()
return
@@ -218,9 +258,14 @@ def _download_posix_build():
print(f"Failed to download FFmpeg binaries automatically: {exc}")
return
- if not all(_binary_exists(binary) for binary in required_binaries):
- print("FFmpeg binaries are still missing after download; please install them manually.")
- return
+ if os.name == "nt":
+ if not all(_local_binary_exists(binary) for binary in required_binaries):
+ print("FFmpeg binaries are still missing after download; please install them manually.")
+ return
+ else:
+ if not all(_binary_exists(binary) for binary in required_binaries):
+ print("FFmpeg binaries are still missing after download; please install them manually.")
+ return
_ensure_bin_dir_on_path()
_ensure_library_path()
diff --git a/Wan2GP/shared/gradio/audio_gallery.py b/Wan2GP/shared/gradio/audio_gallery.py
index 9f0f265e9..1fd9a7fbe 100644
--- a/Wan2GP/shared/gradio/audio_gallery.py
+++ b/Wan2GP/shared/gradio/audio_gallery.py
@@ -6,6 +6,17 @@
import uuid
+def _get_selected_idx(audio_infos, selected_idx):
+ selected_idx = int(selected_idx) if selected_idx is not None else 0
+ if selected_idx >= len(audio_infos):
+ selected_idx = len(audio_infos) - 1
+ elif selected_idx < 0:
+ selected_idx = 0
+ if len(audio_infos) == 0:
+ selected_idx = -1
+ return selected_idx
+
+
class AudioGallery:
"""
A custom Gradio component that displays an audio gallery with thumbnails.
@@ -19,7 +30,7 @@ class AudioGallery:
update_only: If True, only render the inner HTML/Audio (internal use)
"""
- def __init__(self, audio_paths=None, selected_index=0, max_thumbnails=10, height=400, label="Audio Gallery", update_only=False):
+ def __init__(self, audio_paths=None, selected_index=-1, max_thumbnails=10, height=400, label="Audio Gallery", update_only=False):
self.audio_paths = audio_paths or []
self.selected_index = selected_index
self.max_thumbnails = max_thumbnails
@@ -204,8 +215,8 @@ def update(
else:
selected_idx = 0
- if not audio_infos or selected_idx >= len(audio_infos) or selected_idx < 0:
- selected_idx = 0
+ selected_idx = _get_selected_idx(audio_infos, selected_idx)
+
# Trigger id to notify the frontend to refresh (observed via MutationObserver on HTML rerender)
refresh_id = str(uuid.uuid4())
@@ -295,7 +306,7 @@ def _render(self, update_only):
self.refresh_trigger.change(
fn=self._refresh_gallery,
inputs=[self.refresh_trigger, self.state_paths, self.state_selected],
- outputs=[self.audio_player, self.gallery_html],
+ outputs=[self.audio_player, self.gallery_html, self.state_selected],
show_progress="hidden",
)
@@ -309,11 +320,12 @@ def _select_audio(self, click_value, paths_json, current_selected):
audio_infos = self._process_audio_paths(paths)
if not audio_infos:
- return None, self._create_gallery_html([], 0), paths_json, 0, ""
+ return None, self._create_gallery_html([], 0), paths_json, -1, ""
new_index = int(click_value)
if 0 <= new_index < len(audio_infos):
selected_path = audio_infos[new_index]["path"]
+ if not os.path.exists(selected_path): selected_path = None
return (
selected_path,
self._create_gallery_html(audio_infos, new_index),
@@ -329,25 +341,23 @@ def _select_audio(self, click_value, paths_json, current_selected):
def _refresh_gallery(self, refresh_id, paths_json, selected_idx):
"""Refresh gallery based on state (programmatic)."""
if not refresh_id:
- return self._render_from_state(paths_json, selected_idx)[:2]
+ return self._render_from_state(paths_json, selected_idx)[:2] + (selected_idx,)
try:
paths = json.loads(paths_json) if paths_json else []
audio_infos = self._process_audio_paths(paths)
if not audio_infos:
- return None, self._create_gallery_html([], 0)
+ return None, self._create_gallery_html([], 0), -1
- selected_idx = int(selected_idx) if selected_idx is not None else 0
- if selected_idx >= len(audio_infos) or selected_idx < 0:
- selected_idx = 0
+ selected_idx = _get_selected_idx(audio_infos, selected_idx)
selected_path = audio_infos[selected_idx]["path"]
gallery_html_content = self._create_gallery_html(audio_infos, selected_idx)
- return selected_path, gallery_html_content
+ return selected_path, gallery_html_content, selected_idx
except Exception:
- return None, self._create_gallery_html([], 0)
+ return None, self._create_gallery_html([], 0), -1
def _get_audio_duration(self, audio_path):
"""Get audio duration in seconds. Returns formatted string."""
@@ -388,19 +398,25 @@ def _format_duration(self, seconds):
secs = int(seconds % 60)
return f"{mins}:{secs:02d}"
- def _get_file_info(self, audio_path):
+ def _get_file_info(self, audio_path, not_found=False):
"""Get file information: basename, date/time, duration."""
p = Path(audio_path)
basename = p.name
- # Get modification time
- mtime = os.path.getmtime(audio_path)
- dt = datetime.fromtimestamp(mtime)
- date_str = dt.strftime("%Y-%m-%d")
- time_str = dt.strftime("%H:%M:%S")
+ if not_found:
+ mtime = ""
+ date_str = "Deleted"
+ time_str = "00:00:00"
+ duration = "0:00"
+ else:
+ # Get modification time
+ mtime = os.path.getmtime(audio_path)
+ dt = datetime.fromtimestamp(mtime)
+ date_str = dt.strftime("%Y-%m-%d")
+ time_str = dt.strftime("%H:%M:%S")
- # Get duration
- duration = self._get_audio_duration(audio_path)
+ # Get duration
+ duration = self._get_audio_duration(audio_path)
return {
"basename": basename,
@@ -624,6 +640,8 @@ def _process_audio_paths(self, paths):
try:
if os.path.exists(path):
audio_infos.append(self._get_file_info(path))
+ else:
+ audio_infos.append(self._get_file_info(path, True))
except Exception:
continue
audio_infos = audio_infos[: self.max_thumbnails]
@@ -636,15 +654,13 @@ def _render_from_state(self, paths_json, selected_idx):
audio_infos = self._process_audio_paths(paths)
if not audio_infos:
- return None, self._create_gallery_html([], 0), paths_json, 0, ""
-
- selected_idx = int(selected_idx) if selected_idx is not None else 0
- if selected_idx >= len(audio_infos) or selected_idx < 0:
- selected_idx = 0
+ return None, self._create_gallery_html([], 0), paths_json, -1, ""
- selected_path = audio_infos[selected_idx]["path"]
+ selected_idx = _get_selected_idx(audio_infos, selected_idx)
+ selected_path = audio_infos[selected_idx]["path"]
+ if not os.path.exists(selected_path): selected_path = None
gallery_html_content = self._create_gallery_html(audio_infos, selected_idx)
return selected_path, gallery_html_content, paths_json, selected_idx, ""
except Exception:
- return None, self._create_gallery_html([], 0), paths_json, 0, ""
+ return None, self._create_gallery_html([], 0), paths_json, -1, ""
diff --git a/Wan2GP/shared/gradio/gallery.py b/Wan2GP/shared/gradio/gallery.py
index bf7ea249d..efa48dfa0 100644
--- a/Wan2GP/shared/gradio/gallery.py
+++ b/Wan2GP/shared/gradio/gallery.py
@@ -10,8 +10,8 @@
FilePath = str
ImageLike = Union["PIL.Image.Image", Any]
-IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"}
-VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"}
+IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg", ".PNG", ".JPG", ".JPEG", ".BMP", ".GIF", ".WEBP", ".TIF", ".TIFF", ".JFIF", ".PJPEG"}
+VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv", ".MP4", ".MOV", ".AVI", ".MKV", ".WEBM", ".M4V", ".MPEG", ".MPG", ".OGV" }
def get_state(state):
return state if isinstance(state, dict) else state.value
diff --git a/Wan2GP/shared/gradio/ui_styles.css b/Wan2GP/shared/gradio/ui_styles.css
index f55edf64e..ec3d90cdc 100644
--- a/Wan2GP/shared/gradio/ui_styles.css
+++ b/Wan2GP/shared/gradio/ui_styles.css
@@ -364,3 +364,14 @@ user-select: text;
overflow: visible !important;
}
.tabitem {padding-top:0px}
+.rule-row { margin-bottom: 8px; align-items: center !important; display: flex; gap: 8px; }
+.rule-card { background-color: var(--background-fill-secondary); padding: 8px 12px; border-radius: 6px; border: 1px solid var(--border-color-primary); flex-grow: 1; margin-bottom: 0 !important; }
+.rule-card p { margin-bottom: 0 !important; }
+.delete-btn {
+ min-width: 42px !important;
+ max-width: 42px !important;
+ height: 42px !important;
+ padding: 0 !important;
+ align-self: center;
+}
+#refiner-input-row { align-items: center; }
\ No newline at end of file
diff --git a/Wan2GP/shared/inpainting/lanpaint.py b/Wan2GP/shared/inpainting/lanpaint.py
index 3165e7b01..b5f9b0526 100644
--- a/Wan2GP/shared/inpainting/lanpaint.py
+++ b/Wan2GP/shared/inpainting/lanpaint.py
@@ -58,6 +58,7 @@ def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_
if n_steps is None:
n_steps = self.n_steps
out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
+ if out is None: return None
out = _pack_latents(out)
return out
def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
@@ -65,16 +66,27 @@ def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma,
cfg_BIG = 1.0
def double_denoise(latents, t):
+ latents_unpacked = latents
latents = _pack_latents(latents)
noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
- if noise_pred == None: return None, None
+ if noise_pred is None:
+ return None, None
predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
+ if predict_std is None:
+ return None, None
predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
if true_cfg_scale == cfg_BIG:
predict_big = predict_std
else:
predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
+ if predict_big is None:
+ return None, None
predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
+ if self.IS_FLUX or self.IS_FLOW:
+ # Flow/Flux models predict velocity; convert to x0 for LanPaint scoring.
+ t_broadcast = self.add_none_dims(t)
+ predict_std = latents_unpacked - t_broadcast * predict_std
+ predict_big = latents_unpacked - t_broadcast * predict_big
return predict_std, predict_big
if len(sigma.shape) == 0:
@@ -113,6 +125,7 @@ def double_denoise(latents, t):
score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
if score_func is None: return None
x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)
+ if x_t is None: return None
if IS_FLUX or IS_FLOW:
x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
@@ -168,6 +181,7 @@ def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=
def Coef_C(x_t):
x0 = self.x0_evalutation(x_t, score, sigma, args)
+ if x0 is None: return None
C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t
return C
def advance_time(x_t, v, dt, Gamma, A, C, D):
@@ -182,6 +196,8 @@ def advance_time(x_t, v, dt, Gamma, A, C, D):
#v = torch.zeros_like(x_t)
v = None
C = Coef_C(x_t)
+ if C is None:
+ return None, None
#print(torch.squeeze(dtx), torch.squeeze(dty))
x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
else:
@@ -190,6 +206,8 @@ def advance_time(x_t, v, dt, Gamma, A, C, D):
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C_new = Coef_C(x_t)
+ if C_new is None:
+ return None, None
v = v + Gamma**0.5 * ( C_new - C) *dt
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
@@ -236,5 +254,8 @@ def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
def x0_evalutation(self, x_t, score, sigma, args):
- x0 = x_t + score(x_t)
- return x0
\ No newline at end of file
+ score = score(x_t)
+ if score is None:
+ return None
+ x0 = x_t + score
+ return x0
diff --git a/Wan2GP/shared/kernels/__init__.py b/Wan2GP/shared/kernels/__init__.py
new file mode 100644
index 000000000..8b1378917
--- /dev/null
+++ b/Wan2GP/shared/kernels/__init__.py
@@ -0,0 +1 @@
+
diff --git a/Wan2GP/shared/kernels/quanto_int8_inject.py b/Wan2GP/shared/kernels/quanto_int8_inject.py
new file mode 100644
index 000000000..e95deec2b
--- /dev/null
+++ b/Wan2GP/shared/kernels/quanto_int8_inject.py
@@ -0,0 +1,804 @@
+from __future__ import annotations
+
+import importlib
+import os
+import atexit
+import traceback
+from types import SimpleNamespace
+from typing import Optional, Tuple
+
+import torch
+
+try:
+ from torch._subclasses.fake_tensor import is_fake as _torch_is_fake_tensor
+except Exception: # pragma: no cover
+ _torch_is_fake_tensor = None
+
+# Env toggles
+_ENV_ENABLE = "WAN2GP_QUANTO_INT8_KERNEL"
+_ENV_DEBUG = "WAN2GP_QUANTO_INT8_DEBUG"
+_ENV_ALLOW_RUNTIME_FALLBACK = "WAN2GP_QUANTO_INT8_ALLOW_RUNTIME_FALLBACK"
+_ENV_NATIVE_FALLBACK_MAX_M = "WAN2GP_QUANTO_INT8_NATIVE_FALLBACK_MAX_M"
+_ENV_PROFILE_SHAPES = "WAN2GP_QUANTO_INT8_PROFILE_SHAPES"
+_ENV_PROFILE_TIME = "WAN2GP_QUANTO_INT8_PROFILE_TIME"
+
+_STARTUP_PRINTED = False
+_RUNTIME_DISABLED = False
+_RUNTIME_DISABLE_REASON = ""
+_RUNTIME_DISABLE_PRINTED = False
+_TRITON_MODULE = None
+_TRITON_DIRECT_FUSED_READY = False
+_TRITON_DIRECT_SCALED_READY = False
+_KERNEL_USED_PRINTED = False
+_SHAPE_PROFILE_ON = False
+_SHAPE_COUNTS_FUSED = {}
+_SHAPE_COUNTS_SCALED = {}
+_TIME_PROFILE_ON = False
+_TIME_PROFILE_EVENTS = []
+_TIME_PROFILE_CPU_MS = 0.0
+_TIME_PROFILE_CALLS = 0
+_DEBUG_OVERRIDE: Optional[bool] = None
+
+_PATCH_STATE = SimpleNamespace(enabled=False, orig_forward=None)
+_OPS_REGISTERED = False
+_OPS_NAMESPACE = "wan2gp_int8"
+_OPS_LIBS = []
+_FUSED_LAUNCH_CACHE_MAX = 4096
+_FUSED_LAUNCH_CACHE = {}
+_FUSED_LAUNCH_CACHE_FIFO = []
+_SCALED_LAUNCH_CACHE_MAX = 4096
+_SCALED_LAUNCH_CACHE = {}
+_SCALED_LAUNCH_CACHE_FIFO = []
+_QBYTES_TENSOR_CLS = None
+_WEIGHT_QBYTES_CLS = None
+_NATIVE_FALLBACK_MAX_M = 0
+
+
+def _encode_dtype(dtype: torch.dtype) -> int:
+ if dtype == torch.float16:
+ return 1
+ if dtype == torch.float32:
+ return 2
+ return 0
+
+
+def _decode_dtype(code: int, fallback: torch.dtype = torch.bfloat16) -> torch.dtype:
+ if int(code) == 1:
+ return torch.float16
+ if int(code) == 2:
+ return torch.float32
+ return torch.bfloat16 if fallback not in (torch.bfloat16, torch.float16, torch.float32) else fallback
+
+
+def _env_flag(name: str, default: str = "1") -> bool:
+ val = os.environ.get(name, default)
+ return str(val).strip().lower() in ("1", "true", "yes", "on")
+
+
+def _env_int(name: str, default: int) -> int:
+ try:
+ return int(os.environ.get(name, str(default)))
+ except Exception:
+ return default
+
+
+def _log(msg: str) -> None:
+ print(f"[WAN2GP][INT8][quanto] {msg}")
+
+
+def _debug(msg: str) -> None:
+ if _DEBUG_OVERRIDE is None:
+ debug_on = _env_flag(_ENV_DEBUG, "0")
+ else:
+ debug_on = bool(_DEBUG_OVERRIDE)
+ if debug_on:
+ _log(msg)
+
+
+def _format_exception_detail(exc: Exception) -> str:
+ try:
+ return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)).strip()
+ except Exception:
+ return str(exc)
+
+
+def _summarize_kernel_error(exc_or_text: Exception | str, max_chars: int = 480) -> str:
+ text = str(exc_or_text)
+ lines = [ln.strip() for ln in text.replace("\r", "\n").split("\n") if ln.strip()]
+ if len(lines) == 0:
+ return "Unknown Triton kernel failure"
+ keywords = (
+ "CompilationError",
+ "shape mismatch",
+ "tl.dot",
+ "K >=",
+ "M >=",
+ "N >=",
+ "Triton",
+ "unsupported",
+ "invalid",
+ "at ",
+ )
+ picked = [ln for ln in lines if any(kw in ln for kw in keywords)]
+ if len(picked) == 0:
+ picked = [lines[-1]]
+ unique: list[str] = []
+ seen = set()
+ for ln in picked:
+ if ln in seen:
+ continue
+ seen.add(ln)
+ unique.append(ln)
+ summary = " | ".join(unique[-4:])
+ if len(summary) > max_chars:
+ summary = summary[: max_chars - 3] + "..."
+ return summary
+
+
+def set_kernel_debug(enabled: Optional[bool] = None) -> None:
+ global _DEBUG_OVERRIDE
+ _DEBUG_OVERRIDE = None if enabled is None else bool(enabled)
+
+
+def _allow_runtime_fallback() -> bool:
+ return _env_flag(_ENV_ALLOW_RUNTIME_FALLBACK, "1")
+
+
+def _startup_status(enabled: bool, detail: str) -> None:
+ global _STARTUP_PRINTED
+ if _STARTUP_PRINTED:
+ return
+ _STARTUP_PRINTED = True
+ if enabled:
+ _log("Injected int8 kernels ACTIVE (backend=triton).")
+ else:
+ _log(f"Injected int8 kernels INACTIVE. {detail}")
+
+
+def _disable_runtime(reason: str) -> None:
+ global _RUNTIME_DISABLED, _RUNTIME_DISABLE_REASON, _RUNTIME_DISABLE_PRINTED
+ _RUNTIME_DISABLED = True
+ _RUNTIME_DISABLE_REASON = _summarize_kernel_error(reason)
+ if not _RUNTIME_DISABLE_PRINTED:
+ _RUNTIME_DISABLE_PRINTED = True
+ _log(
+ "Runtime fallback to non-injected Quanto path is now active. Reason: "
+ f"{_RUNTIME_DISABLE_REASON}"
+ )
+
+
+def _init_quanto_tensor_types() -> bool:
+ global _QBYTES_TENSOR_CLS, _WEIGHT_QBYTES_CLS
+ if _QBYTES_TENSOR_CLS is not None and _WEIGHT_QBYTES_CLS is not None:
+ return True
+ try:
+ from optimum.quanto.tensor.qbytes import QBytesTensor
+ from optimum.quanto.tensor.weights.qbytes import WeightQBytesTensor
+ except Exception:
+ return False
+ _QBYTES_TENSOR_CLS = QBytesTensor
+ _WEIGHT_QBYTES_CLS = WeightQBytesTensor
+ return True
+
+
+def _refresh_triton_direct_kernel_flags() -> None:
+ global _TRITON_DIRECT_FUSED_READY, _TRITON_DIRECT_SCALED_READY
+ mod = _TRITON_MODULE
+ triton_ns = getattr(mod, "triton", None) if mod is not None else None
+ has_common = bool(mod is not None and triton_ns is not None and hasattr(triton_ns, "cdiv") and hasattr(mod, "_select_triton_int8_config"))
+ _TRITON_DIRECT_FUSED_READY = bool(has_common and hasattr(mod, "_fused_dynamic_int8_blockscale_gemm_kernel"))
+ _TRITON_DIRECT_SCALED_READY = bool(has_common and hasattr(mod, "_scaled_int8_gemm_kernel"))
+
+
+def _is_qbytes_tensor(t: torch.Tensor) -> bool:
+ if not _init_quanto_tensor_types():
+ return False
+ return isinstance(t, _QBYTES_TENSOR_CLS)
+
+
+def _is_weight_qbytes(t: torch.Tensor) -> bool:
+ if not _init_quanto_tensor_types():
+ return False
+ return isinstance(t, _WEIGHT_QBYTES_CLS)
+
+
+def _flatten_scale(scale: torch.Tensor) -> torch.Tensor:
+ if scale.ndim == 2 and scale.shape[1] == 1:
+ return scale.view(-1)
+ if scale.ndim == 1:
+ return scale
+ return scale.reshape(-1)
+
+
+def _expand_scale_to_rows(scale: torch.Tensor, rows: int, dtype: torch.dtype, device: Optional[torch.device] = None) -> torch.Tensor:
+ scale = _flatten_scale(scale)
+ if scale.numel() == 1:
+ scale = scale.reshape(1).expand(rows)
+ elif scale.numel() != rows:
+ raise RuntimeError(f"Activation scale length mismatch: expected {rows}, got {scale.numel()}")
+ if device is None:
+ return scale.contiguous().to(dtype=dtype)
+ return scale.contiguous().to(device=device, dtype=dtype, non_blocking=True)
+
+
+def _prepare_weight_scale(scale: torch.Tensor, out_features: int, device: torch.device) -> torch.Tensor:
+ flat_scale = _flatten_scale(scale)
+ if flat_scale.numel() != out_features:
+ raise RuntimeError("Weight scale length does not match output features")
+ if flat_scale.device != device:
+ flat_scale = flat_scale.to(device=device, non_blocking=True)
+ if flat_scale.dtype != torch.float32:
+ flat_scale = flat_scale.to(torch.float32)
+ if not flat_scale.is_contiguous():
+ flat_scale = flat_scale.contiguous()
+ return flat_scale
+
+
+def _cache_launch_params(cache: dict, fifo: list, max_size: int, key: tuple[int, int, int, int], params: tuple[int, int, int, int, int, int, int]) -> tuple[int, int, int, int, int, int, int]:
+ if key in cache:
+ return cache[key]
+ cache[key] = params
+ fifo.append(key)
+ if len(fifo) > max_size:
+ stale_key = fifo.pop(0)
+ cache.pop(stale_key, None)
+ return params
+
+
+def _fused_launch_params(m: int, k: int, n: int, device: torch.device) -> tuple[int, int, int, int, int, int, int]:
+ device_index = int(device.index if device.type == "cuda" else -1)
+ key = (device_index, m, k, n)
+ cached = _FUSED_LAUNCH_CACHE.get(key)
+ if cached is not None:
+ return cached
+ mod = _TRITON_MODULE
+ if mod is None:
+ raise RuntimeError("Triton backend not initialized")
+ block_m, block_n, block_k, num_warps, num_stages = mod._select_triton_int8_config(m, k, n, device=device, kernel_kind="fused")
+ grid_m = mod.triton.cdiv(m, block_m)
+ grid_n = mod.triton.cdiv(n, block_n)
+ params = (block_m, block_n, block_k, num_warps, num_stages, grid_m, grid_n)
+ return _cache_launch_params(_FUSED_LAUNCH_CACHE, _FUSED_LAUNCH_CACHE_FIFO, _FUSED_LAUNCH_CACHE_MAX, key, params)
+
+
+def _scaled_launch_params(m: int, k: int, n: int, device: torch.device) -> tuple[int, int, int, int, int, int, int]:
+ device_index = int(device.index if device.type == "cuda" else -1)
+ key = (device_index, m, k, n)
+ cached = _SCALED_LAUNCH_CACHE.get(key)
+ if cached is not None:
+ return cached
+ mod = _TRITON_MODULE
+ if mod is None:
+ raise RuntimeError("Triton backend not initialized")
+ block_m, block_n, block_k, num_warps, num_stages = mod._select_triton_int8_config(m, k, n, device=device, kernel_kind="scaled")
+ grid_m = mod.triton.cdiv(m, block_m)
+ grid_n = mod.triton.cdiv(n, block_n)
+ params = (block_m, block_n, block_k, num_warps, num_stages, grid_m, grid_n)
+ return _cache_launch_params(_SCALED_LAUNCH_CACHE, _SCALED_LAUNCH_CACHE_FIFO, _SCALED_LAUNCH_CACHE_MAX, key, params)
+
+
+def _is_compiling_graph() -> bool:
+ try:
+ if bool(torch.compiler.is_compiling()):
+ return True
+ except Exception:
+ pass
+ try:
+ import torch._dynamo as _dynamo
+
+ if bool(_dynamo.is_compiling()):
+ return True
+ except Exception:
+ pass
+ return False
+
+
+def _is_fake_tensor(t: object) -> bool:
+ if not torch.is_tensor(t):
+ return False
+ if _torch_is_fake_tensor is not None:
+ return bool(_torch_is_fake_tensor(t))
+ return False
+
+
+def _resolve_output_dtype(input: torch.Tensor, other: torch.Tensor) -> torch.dtype:
+ other_scale = getattr(other, "_scale", None)
+ if torch.is_tensor(other_scale) and other_scale.dtype in (torch.bfloat16, torch.float16, torch.float32):
+ return other_scale.dtype
+ if _is_qbytes_tensor(input):
+ input_scale = getattr(input, "_scale", None)
+ if torch.is_tensor(input_scale) and input_scale.dtype in (torch.bfloat16, torch.float16, torch.float32):
+ return input_scale.dtype
+ if isinstance(input, torch.Tensor) and input.dtype in (torch.bfloat16, torch.float16, torch.float32):
+ return input.dtype
+ return torch.bfloat16
+
+
+def _probe_triton_backend() -> Tuple[Optional[object], str]:
+ try:
+ mod = importlib.import_module("shared.kernels.quanto_int8_triton")
+ except Exception as exc:
+ return None, f"failed to import shared.kernels.quanto_int8_triton ({exc})"
+
+ if not hasattr(mod, "is_available"):
+ return None, "shared.kernels.quanto_int8_triton.is_available() missing"
+ try:
+ if not bool(mod.is_available()):
+ return None, "Triton backend unavailable on this runtime/GPU"
+ except Exception as exc:
+ return None, f"Triton availability check failed ({exc})"
+ return mod, "ok"
+
+
+def _register_int8_ops_for_namespace(ns: str, lib: torch.library.Library) -> None:
+ lib.define("fused_quant_scaled_mm(Tensor x2d, Tensor qweight, Tensor qweight_scale, int out_dtype_code=0) -> Tensor")
+ lib.define("scaled_int8_mm(Tensor a_int8, Tensor b_int8, Tensor a_scale, Tensor b_scale, int out_dtype_code=0) -> Tensor")
+
+ @torch.library.impl(f"{ns}::fused_quant_scaled_mm", "CUDA")
+ def _fused_quant_scaled_mm_cuda(x2d: torch.Tensor, qweight: torch.Tensor, qweight_scale: torch.Tensor, out_dtype_code: int = 0):
+ if _TRITON_MODULE is None:
+ raise RuntimeError("Triton backend not initialized")
+ out_dtype = _decode_dtype(out_dtype_code, x2d.dtype)
+ return _TRITON_MODULE.fused_quant_scaled_mm(x2d, qweight, qweight_scale, out_dtype=out_dtype)
+
+ @torch.library.impl(f"{ns}::scaled_int8_mm", "CUDA")
+ def _scaled_int8_mm_cuda(a_int8: torch.Tensor, b_int8: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, out_dtype_code: int = 0):
+ if _TRITON_MODULE is None:
+ raise RuntimeError("Triton backend not initialized")
+ out_dtype = _decode_dtype(out_dtype_code, torch.bfloat16)
+ return _TRITON_MODULE.scaled_int8_mm(a_int8, b_int8, a_scale, b_scale, out_dtype=out_dtype)
+
+ @torch.library.register_fake(f"{ns}::fused_quant_scaled_mm")
+ def _fused_quant_scaled_mm_fake(x2d: torch.Tensor, qweight: torch.Tensor, qweight_scale: torch.Tensor, out_dtype_code: int = 0):
+ if x2d.ndim != 2 or qweight.ndim != 2:
+ raise RuntimeError("fused_quant_scaled_mm expects 2D tensors")
+ out_dtype = _decode_dtype(out_dtype_code, x2d.dtype)
+ return x2d.new_empty((x2d.shape[0], qweight.shape[0]), dtype=out_dtype)
+
+ @torch.library.register_fake(f"{ns}::scaled_int8_mm")
+ def _scaled_int8_mm_fake(a_int8: torch.Tensor, b_int8: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, out_dtype_code: int = 0):
+ if a_int8.ndim != 2 or b_int8.ndim != 2:
+ raise RuntimeError("scaled_int8_mm expects 2D tensors")
+ out_dtype = _decode_dtype(out_dtype_code, torch.bfloat16)
+ return a_int8.new_empty((a_int8.shape[0], b_int8.shape[0]), dtype=out_dtype)
+
+
+def _ensure_compile_safe_ops() -> None:
+ global _OPS_REGISTERED, _OPS_LIBS
+ if _OPS_REGISTERED:
+ return
+
+ libs = []
+ try:
+ lib = torch.library.Library(_OPS_NAMESPACE, "DEF")
+ libs.append(lib)
+ _register_int8_ops_for_namespace(_OPS_NAMESPACE, lib)
+ except Exception:
+ # Namespace/op may already exist in long-lived processes.
+ op_ns = getattr(torch.ops, _OPS_NAMESPACE, None)
+ has_ops = bool(
+ op_ns is not None
+ and hasattr(op_ns, "fused_quant_scaled_mm")
+ and hasattr(op_ns, "scaled_int8_mm")
+ )
+ if not has_ops:
+ raise
+ _OPS_LIBS = libs
+
+ _OPS_REGISTERED = True
+
+
+def _fused_quant_scaled_mm_direct_call(x2d: torch.Tensor, qweight: torch.Tensor, qweight_scale: torch.Tensor, output_dtype: torch.dtype) -> torch.Tensor:
+ mod = _TRITON_MODULE
+ if mod is None:
+ raise RuntimeError("Triton backend not initialized")
+ if not _TRITON_DIRECT_FUSED_READY:
+ return mod.fused_quant_scaled_mm(x2d, qweight, qweight_scale, out_dtype=output_dtype)
+
+ m, k = x2d.shape
+ n, k2 = qweight.shape
+ if k != k2:
+ raise RuntimeError(f"Triton int8 GEMM shape mismatch: x={x2d.shape}, w={qweight.shape}")
+
+ block_m, block_n, block_k, num_warps, num_stages, grid_m, grid_n = _fused_launch_params(m, k, n, x2d.device)
+ out = torch.empty((m, n), device=x2d.device, dtype=output_dtype)
+ try:
+ mod._fused_dynamic_int8_blockscale_gemm_kernel[(grid_m, grid_n)](
+ x2d,
+ qweight,
+ qweight_scale,
+ out,
+ m,
+ n,
+ k,
+ x2d.stride(0),
+ x2d.stride(1),
+ qweight.stride(0),
+ qweight.stride(1),
+ out.stride(0),
+ out.stride(1),
+ block_m=block_m,
+ block_n=block_n,
+ block_k=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ except Exception as exc:
+ raise RuntimeError(
+ "Triton fused int8 kernel launch failed "
+ f"(shape m={m}, k={k}, n={n}; tile=({block_m},{block_n},{block_k}); "
+ f"warps={num_warps}, stages={num_stages}). {exc}"
+ ) from exc
+ return out
+
+
+def _scaled_int8_mm_direct_call(
+ a_int8: torch.Tensor,
+ b_int8: torch.Tensor,
+ a_scale: torch.Tensor,
+ b_scale: torch.Tensor,
+ output_dtype: torch.dtype,
+) -> torch.Tensor:
+ mod = _TRITON_MODULE
+ if mod is None:
+ raise RuntimeError("Triton backend not initialized")
+ if not _TRITON_DIRECT_SCALED_READY:
+ return mod.scaled_int8_mm(a_int8, b_int8, a_scale, b_scale, out_dtype=output_dtype)
+
+ m, k = a_int8.shape
+ n, k2 = b_int8.shape
+ if k != k2:
+ raise RuntimeError(f"Triton int8 GEMM shape mismatch: a={a_int8.shape}, w={b_int8.shape}")
+
+ block_m, block_n, block_k, num_warps, num_stages, grid_m, grid_n = _scaled_launch_params(m, k, n, a_int8.device)
+ out = torch.empty((m, n), device=a_int8.device, dtype=output_dtype)
+ try:
+ mod._scaled_int8_gemm_kernel[(grid_m, grid_n)](
+ a_int8,
+ b_int8,
+ a_scale,
+ b_scale,
+ out,
+ m,
+ n,
+ k,
+ a_int8.stride(0),
+ a_int8.stride(1),
+ b_int8.stride(0),
+ b_int8.stride(1),
+ out.stride(0),
+ out.stride(1),
+ block_m=block_m,
+ block_n=block_n,
+ block_k=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ except Exception as exc:
+ raise RuntimeError(
+ "Triton scaled int8 kernel launch failed "
+ f"(shape m={m}, k={k}, n={n}; tile=({block_m},{block_n},{block_k}); "
+ f"warps={num_warps}, stages={num_stages}). {exc}"
+ ) from exc
+ return out
+
+
+def _fused_quant_scaled_mm_call(x2d: torch.Tensor, qweight: torch.Tensor, qweight_scale: torch.Tensor, output_dtype: torch.dtype) -> torch.Tensor:
+ if _TRITON_MODULE is not None and not _is_compiling_graph() and not (_is_fake_tensor(x2d) or _is_fake_tensor(qweight) or _is_fake_tensor(qweight_scale)):
+ return _fused_quant_scaled_mm_direct_call(x2d, qweight, qweight_scale, output_dtype)
+ return torch.ops.wan2gp_int8.fused_quant_scaled_mm(x2d, qweight, qweight_scale, _encode_dtype(output_dtype))
+
+
+def _scaled_int8_mm_call(
+ a_int8: torch.Tensor,
+ b_int8: torch.Tensor,
+ a_scale: torch.Tensor,
+ b_scale: torch.Tensor,
+ output_dtype: torch.dtype,
+) -> torch.Tensor:
+ if _TRITON_MODULE is not None and not _is_compiling_graph() and not ( _is_fake_tensor(a_int8) or _is_fake_tensor(b_int8) or _is_fake_tensor(a_scale) or _is_fake_tensor(b_scale)):
+ return _scaled_int8_mm_direct_call(a_int8, b_int8, a_scale, b_scale, output_dtype)
+ return torch.ops.wan2gp_int8.scaled_int8_mm(a_int8, b_int8, a_scale, b_scale, _encode_dtype(output_dtype))
+
+
+def _use_int8_kernel(input: torch.Tensor, other: torch.Tensor) -> bool:
+ if _RUNTIME_DISABLED:
+ return False
+ if _TRITON_MODULE is None:
+ return False
+ if not _is_weight_qbytes(other):
+ return False
+ if other._data.dtype != torch.int8:
+ return False
+ if not other._data.is_cuda:
+ return False
+
+ if _is_qbytes_tensor(input):
+ return input._data.dtype == torch.int8 and input._data.is_cuda
+ return input.is_cuda and input.dtype in (torch.bfloat16, torch.float16, torch.float32)
+
+
+def _activation_rows(input_shape: torch.Size) -> int:
+ rows = 1
+ for dim in input_shape[:-1]:
+ rows *= int(dim)
+ return rows
+
+
+def _prefer_native_quanto_path(input: torch.Tensor) -> bool:
+ if _NATIVE_FALLBACK_MAX_M < 0:
+ return False
+ return _activation_rows(input.shape) <= _NATIVE_FALLBACK_MAX_M
+
+
+def _mark_kernel_used() -> None:
+ global _KERNEL_USED_PRINTED
+ if _KERNEL_USED_PRINTED:
+ return
+ _KERNEL_USED_PRINTED = True
+ _log("Injected Triton int8 kernels are being used.")
+
+
def _int8_linear_forward_triton_dense_fast(ctx, input: torch.Tensor, other: torch.Tensor, bias: Optional[torch.Tensor]):
    """Hot-path linear forward for dense (non-quantized) CUDA activations.

    Flattens the activation to 2D, runs the fused dynamic-int8-quantization
    GEMM against the int8 weight, reshapes back to the input's leading dims,
    then adds bias. Mirrors the patched Quanto forward's save_for_backward.
    """
    ctx.save_for_backward(input, other)
    if _TRITON_MODULE is None:
        raise RuntimeError("Triton backend not initialized")
    _mark_kernel_used()

    input_shape = input.shape
    in_features = int(input_shape[-1])
    out_features = int(other.shape[0])
    a_2d = input.reshape(-1, in_features)
    if not a_2d.is_contiguous():
        a_2d = a_2d.contiguous()
    b_int8 = other._data
    if not b_int8.is_contiguous():
        b_int8 = b_int8.contiguous()
    b_scale = _prepare_weight_scale(other._scale, out_features, b_int8.device)

    # Optional profiling: shape histogram and CUDA-event timing, reported at exit.
    if _SHAPE_PROFILE_ON:
        key = (int(a_2d.shape[0]), int(in_features), int(out_features))
        _SHAPE_COUNTS_FUSED[key] = _SHAPE_COUNTS_FUSED.get(key, 0) + 1
    if _TIME_PROFILE_ON and torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out_2d = _fused_quant_scaled_mm_call(a_2d, b_int8, b_scale, input.dtype)
        end.record()
        _TIME_PROFILE_EVENTS.append((start, end))
    else:
        out_2d = _fused_quant_scaled_mm_call(a_2d, b_int8, b_scale, input.dtype)

    out = out_2d.reshape(input_shape[:-1] + (out_features,))
    if bias is not None:
        out = out + bias
    return out
+
+
def _int8_linear_forward_triton(ctx, input: torch.Tensor, other: torch.Tensor, bias: Optional[torch.Tensor]):
    """General injected linear forward handling both QBytes and dense activations.

    QBytes int8 activations use the scaled-int8 GEMM with per-row activation
    scales; dense float activations use the fused dynamic-quantization GEMM.
    The 2D result is reshaped back to the input's leading dims and bias is
    added afterwards.
    """
    ctx.save_for_backward(input, other)
    if _TRITON_MODULE is None:
        raise RuntimeError("Triton backend not initialized")
    _mark_kernel_used()

    input_shape = input.shape
    in_features = int(input_shape[-1])
    out_features = int(other.shape[0])
    b_int8 = other._data
    if not b_int8.is_contiguous():
        b_int8 = b_int8.contiguous()
    b_scale = _prepare_weight_scale(other._scale, out_features, b_int8.device)
    output_dtype = _resolve_output_dtype(input, other)
    input_is_qbytes = _is_qbytes_tensor(input)

    if input_is_qbytes:
        # Pre-quantized activation: use its int8 payload and per-row scales.
        a_int8 = input._data.reshape(-1, in_features)
        if a_int8.dtype != torch.int8:
            raise RuntimeError("QBytes input must be int8 for injected path")
        if not a_int8.is_contiguous():
            a_int8 = a_int8.contiguous()
        a_scale = _expand_scale_to_rows(input._scale, a_int8.shape[0], torch.float32, device=a_int8.device)
        if _SHAPE_PROFILE_ON:
            key = (int(a_int8.shape[0]), int(in_features), int(out_features))
            _SHAPE_COUNTS_SCALED[key] = _SHAPE_COUNTS_SCALED.get(key, 0) + 1
        if _TIME_PROFILE_ON and torch.cuda.is_available():
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            out_2d = _scaled_int8_mm_call(a_int8, b_int8, a_scale, b_scale, output_dtype)
            end.record()
            _TIME_PROFILE_EVENTS.append((start, end))
        else:
            out_2d = _scaled_int8_mm_call(a_int8, b_int8, a_scale, b_scale, output_dtype)
    else:
        # Dense float activation: the kernel quantizes it on the fly.
        a_2d = input.reshape(-1, in_features)
        if not a_2d.is_contiguous():
            a_2d = a_2d.contiguous()
        if _SHAPE_PROFILE_ON:
            key = (int(a_2d.shape[0]), int(in_features), int(out_features))
            _SHAPE_COUNTS_FUSED[key] = _SHAPE_COUNTS_FUSED.get(key, 0) + 1
        if _TIME_PROFILE_ON and torch.cuda.is_available():
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            out_2d = _fused_quant_scaled_mm_call(a_2d, b_int8, b_scale, output_dtype)
            end.record()
            _TIME_PROFILE_EVENTS.append((start, end))
        else:
            out_2d = _fused_quant_scaled_mm_call(a_2d, b_int8, b_scale, output_dtype)

    out = out_2d.reshape(input_shape[:-1] + (out_features,))
    if bias is not None:
        out = out + bias
    return out
+
+
def enable_quanto_int8_kernel(triton_mod=None) -> bool:
    """Monkey-patch Quanto's WeightQBytesLinearFunction.forward with Triton int8 kernels.

    Returns True when the patch is (or already was) active; False when Quanto's
    qbytes module or the Triton backend is unavailable. The replacement forward
    falls back to the original Quanto implementation whenever the injected path
    does not apply, and — if the fallback env flag allows it — on runtime
    kernel failure as well.
    """
    global _TRITON_MODULE, _NATIVE_FALLBACK_MAX_M
    if _PATCH_STATE.enabled:
        return True

    try:
        from optimum.quanto.tensor.weights import qbytes as _qbytes
    except Exception as exc:
        _debug(f"cannot import optimum.quanto qbytes ({exc})")
        return False

    if triton_mod is None:
        triton_mod, _ = _probe_triton_backend()
        if triton_mod is None:
            return False
    _TRITON_MODULE = triton_mod
    _refresh_triton_direct_kernel_flags()
    _NATIVE_FALLBACK_MAX_M = _env_int(_ENV_NATIVE_FALLBACK_MAX_M, 0)
    _init_quanto_tensor_types()
    _ensure_compile_safe_ops()

    # Keep the original so disable_quanto_int8_kernel() can restore it.
    orig_forward = _qbytes.WeightQBytesLinearFunction.forward

    def forward(ctx, input, other, bias=None):
        # Inline re-check of _use_int8_kernel's conditions for the common
        # dense-activation case, avoiding per-layer helper-call overhead.
        dense_hot_path = (
            not _RUNTIME_DISABLED
            and type(input) is torch.Tensor
            and input.is_cuda
            and input.dtype in (torch.bfloat16, torch.float16, torch.float32)
            and _WEIGHT_QBYTES_CLS is not None
            and isinstance(other, _WEIGHT_QBYTES_CLS)
            and other._data.dtype == torch.int8
            and other._data.is_cuda
        )
        if dense_hot_path:
            if _prefer_native_quanto_path(input):
                return orig_forward(ctx, input, other, bias)
            try:
                return _int8_linear_forward_triton_dense_fast(ctx, input, other, bias)
            except Exception as exc:
                short_reason = _summarize_kernel_error(exc)
                if _allow_runtime_fallback():
                    # Disable the injected path for the rest of the process.
                    _disable_runtime(short_reason)
                    _debug(f"Full Triton failure detail:\n{_format_exception_detail(exc)}")
                    return orig_forward(ctx, input, other, bias)
                full_detail = _format_exception_detail(exc)
                raise RuntimeError(
                    "Injected Triton int8 kernel failed. "
                    f"Set {_ENV_ALLOW_RUNTIME_FALLBACK}=1 to force fallback to non-injected Quanto path. "
                    f"Reason: {short_reason}\n"
                    f"Full Triton error details:\n{full_detail}"
                ) from exc

        if not _use_int8_kernel(input, other):
            return orig_forward(ctx, input, other, bias)
        if _prefer_native_quanto_path(input):
            return orig_forward(ctx, input, other, bias)
        try:
            return _int8_linear_forward_triton(ctx, input, other, bias)
        except Exception as exc:
            short_reason = _summarize_kernel_error(exc)
            if _allow_runtime_fallback():
                _disable_runtime(short_reason)
                _debug(f"Full Triton failure detail:\n{_format_exception_detail(exc)}")
                return orig_forward(ctx, input, other, bias)
            full_detail = _format_exception_detail(exc)
            raise RuntimeError(
                "Injected Triton int8 kernel failed. "
                f"Set {_ENV_ALLOW_RUNTIME_FALLBACK}=1 to force fallback to non-injected Quanto path. "
                f"Reason: {short_reason}\n"
                f"Full Triton error details:\n{full_detail}"
            ) from exc

    _qbytes.WeightQBytesLinearFunction.forward = staticmethod(forward)
    _PATCH_STATE.enabled = True
    _PATCH_STATE.orig_forward = orig_forward
    return True
+
+
def disable_quanto_int8_kernel(notify_disabled = False) -> bool:
    """Restore Quanto's original linear forward and reset injected-kernel state.

    Clears launch caches and one-shot status flags so a later re-enable starts
    clean. Returns False when the patch was not active.
    """
    global _FUSED_LAUNCH_CACHE, _FUSED_LAUNCH_CACHE_FIFO, _SCALED_LAUNCH_CACHE, _SCALED_LAUNCH_CACHE_FIFO
    global _TRITON_DIRECT_FUSED_READY, _TRITON_DIRECT_SCALED_READY, _STARTUP_PRINTED

    if not _PATCH_STATE.enabled:
        return False
    from optimum.quanto.tensor.weights import qbytes as _qbytes

    _qbytes.WeightQBytesLinearFunction.forward = staticmethod(_PATCH_STATE.orig_forward)
    _PATCH_STATE.enabled = False
    _PATCH_STATE.orig_forward = None
    _FUSED_LAUNCH_CACHE = {}
    _FUSED_LAUNCH_CACHE_FIFO = []
    _SCALED_LAUNCH_CACHE = {}
    _SCALED_LAUNCH_CACHE_FIFO = []
    _TRITON_DIRECT_FUSED_READY = False
    _TRITON_DIRECT_SCALED_READY = False
    _STARTUP_PRINTED = False
    if notify_disabled:
        _startup_status(False, f"disabled by User.")
    return True
+
+
def maybe_enable_quanto_int8_kernel(verbose_level: Optional[int] = None) -> bool:
    """Conditionally enable the Triton int8 injection at startup.

    Honors the enable env flag, probes the Triton backend, wires debug
    verbosity (verbose_level >= 2 turns debug on), and sets the shape/time
    profiling flags from their env vars. Returns True when the patch was
    installed.
    """
    global _SHAPE_PROFILE_ON, _TIME_PROFILE_ON, _STARTUP_PRINTED

    _STARTUP_PRINTED = False
    # None means "leave the current debug setting alone" downstream.
    verbose_debug: Optional[bool] = None
    if verbose_level is not None:
        try:
            verbose_debug = int(verbose_level) >= 2
        except Exception:
            verbose_debug = False
    set_kernel_debug(verbose_debug)

    if not _env_flag(_ENV_ENABLE, "1"):
        # _startup_status(False, f"disabled by {_ENV_ENABLE}=0; using non-injected Quanto path.")
        return False

    triton_mod, reason = _probe_triton_backend()
    if triton_mod is None:
        # _startup_status(False, f"{reason}; using non-injected Quanto path.")
        return False
    # Forward the debug setting to the kernel module's autotuner when supported.
    set_triton_debug = getattr(triton_mod, "set_autotune_debug", None)
    if callable(set_triton_debug):
        set_triton_debug(verbose_debug)

    if not enable_quanto_int8_kernel(triton_mod=triton_mod):
        _startup_status(False, "failed to patch Quanto linear forward; using non-injected Quanto path.")
        return False

    _SHAPE_PROFILE_ON = _env_flag(_ENV_PROFILE_SHAPES, "0")
    _TIME_PROFILE_ON = _env_flag(_ENV_PROFILE_TIME, "0")
    _startup_status(
        True,
        (
            "Triton int8 kernels will be used for Quanto qint8 linear layers "
            "(QBytes int8 activations + fused dynamic int8 activation quantization)."
        ),
    )
    return True
+
+
+
def _print_shape_profile() -> None:
    """At-exit report: top GEMM shapes by hit count and total Triton kernel time."""
    if not _SHAPE_PROFILE_ON and not _TIME_PROFILE_ON:
        return
    if _SHAPE_PROFILE_ON and _SHAPE_COUNTS_FUSED:
        top_fused = sorted(_SHAPE_COUNTS_FUSED.items(), key=lambda kv: kv[1], reverse=True)[:10]
        _log(f"Fused shape profile (top {len(top_fused)}): {top_fused}")
    if _SHAPE_PROFILE_ON and _SHAPE_COUNTS_SCALED:
        top_scaled = sorted(_SHAPE_COUNTS_SCALED.items(), key=lambda kv: kv[1], reverse=True)[:10]
        _log(f"Scaled shape profile (top {len(top_scaled)}): {top_scaled}")

    if _TIME_PROFILE_ON:
        total_ms = 0.0
        calls = 0
        if _TIME_PROFILE_EVENTS:
            # CUDA events need a sync before elapsed_time is valid.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            for start, end in _TIME_PROFILE_EVENTS:
                total_ms += float(start.elapsed_time(end))
            calls = len(_TIME_PROFILE_EVENTS)
        else:
            total_ms = _TIME_PROFILE_CPU_MS
            calls = _TIME_PROFILE_CALLS
        _log(f"Triton kernel time profile: {total_ms / 1000.0:.3f}s over {calls} calls")


# Emit the profiling summary when the process exits.
atexit.register(_print_shape_profile)
diff --git a/Wan2GP/shared/kernels/quanto_int8_triton.py b/Wan2GP/shared/kernels/quanto_int8_triton.py
new file mode 100644
index 000000000..c06092708
--- /dev/null
+++ b/Wan2GP/shared/kernels/quanto_int8_triton.py
@@ -0,0 +1,1254 @@
+from __future__ import annotations
+
+import atexit
+import json
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ from triton.language.extra.cuda import libdevice as tl_libdevice
+
+ _TRITON_AVAILABLE = True
+except Exception: # pragma: no cover
+ triton = None # type: ignore
+ tl = None # type: ignore
+ tl_libdevice = None # type: ignore
+ _TRITON_AVAILABLE = False
+
+
# Environment-variable names controlling the injected kernels and the autotuner.
_ENV_ENABLE = "WAN2GP_QUANTO_INT8_TRITON"
_ENV_AUTOTUNE_ENABLE = "WAN2GP_QUANTO_INT8_AUTOTUNE"
_ENV_AUTOTUNE_DEBUG = "WAN2GP_QUANTO_INT8_AUTOTUNE_DEBUG"
_ENV_AUTOTUNE_MAX_M = "WAN2GP_QUANTO_INT8_AUTOTUNE_MAX_M"
_ENV_AUTOTUNE_MAX_SHAPES = "WAN2GP_QUANTO_INT8_AUTOTUNE_MAX_SHAPES"
_ENV_AUTOTUNE_WARMUP = "WAN2GP_QUANTO_INT8_AUTOTUNE_WARMUP"
_ENV_AUTOTUNE_ITERS = "WAN2GP_QUANTO_INT8_AUTOTUNE_ITERS"
_ENV_AUTOTUNE_MIN_SPEEDUP = "WAN2GP_QUANTO_INT8_AUTOTUNE_MIN_SPEEDUP"
_ENV_AUTOTUNE_CACHE = "WAN2GP_QUANTO_INT8_AUTOTUNE_CACHE"
_ENV_AUTOTUNE_VALIDATE = "WAN2GP_QUANTO_INT8_AUTOTUNE_VALIDATE"
_ENV_AUTOTUNE_MAX_ABS_ERR = "WAN2GP_QUANTO_INT8_AUTOTUNE_MAX_ABS_ERR"
_ENV_AUTOTUNE_MAX_REL_ERR = "WAN2GP_QUANTO_INT8_AUTOTUNE_MAX_REL_ERR"
_ENV_AUTOTUNE_LOCK_FUSED_BLOCK_K = "WAN2GP_QUANTO_INT8_AUTOTUNE_LOCK_FUSED_BLOCK_K"
# Lazily computed availability flag (None = not probed yet).
_IS_AVAILABLE = None
# Kernel configs are 5-tuples: (block_m, block_n, block_k, num_warps, num_stages).
_CONFIG_LEN = 5
_AUTOTUNE_CACHE_LOADED = False
_AUTOTUNE_CACHE_DIRTY = False
_AUTOTUNE_CONFIG_CACHE: dict[str, tuple[int, int, int, int, int]] = {}
# NOTE(review): _AUTOTUNE_SESSION_CACHE is declared but not referenced in this
# chunk — presumably used elsewhere in the file; verify before removing.
_AUTOTUNE_SESSION_CACHE: dict[tuple[int, str, str], tuple[int, int, int, int, int]] = {}
_AUTOTUNE_SEEN_SLOTS: set[tuple[int, str, str]] = set()
_AUTOTUNE_SLOTS_TUNED = 0
_AUTOTUNE_DEBUG_OVERRIDE: Optional[bool] = None

# Tuned decode-time configs reused from nanovllm int8 kernels.
# Keys are (K, N) weight shapes; values are launch configs (see _CONFIG_LEN note).
_TRITON_SMALL_M_CONFIGS = {
    (2048, 4096): (2, 32, 256, 8, 5),
    (2048, 2048): (1, 32, 64, 2, 4),
    (2048, 12288): (8, 64, 256, 8, 4),
    (6144, 2048): (1, 32, 512, 4, 5),
}
# Exact (M, K, N) overrides for very small batch (M <= 4) launches.
_TRITON_TINY_M_SHAPE_CONFIGS = {
    (2, 3072, 3072): (2, 128, 64, 8, 4),
    (4, 3072, 3072): (4, 256, 64, 4, 4),
    (2, 3072, 1024): (2, 256, 128, 8, 4),
    (4, 3072, 1024): (2, 256, 128, 8, 4),
    (2, 3072, 1536): (2, 64, 64, 4, 4),
    (4, 3072, 1536): (2, 256, 128, 8, 4),
    (2, 3072, 8192): (2, 128, 64, 8, 4),
    (4, 3072, 8192): (2, 128, 64, 8, 4),
    (2, 8192, 3072): (4, 128, 64, 4, 4),
    (4, 8192, 3072): (2, 128, 128, 8, 4),
}
# (K, N) fallbacks when no exact tiny-M shape entry matches.
_TRITON_TINY_M_PAIR_CONFIGS = {
    (3072, 3072): (2, 128, 64, 8, 4),
    (3072, 1024): (2, 256, 128, 8, 4),
    (3072, 1536): (2, 64, 64, 4, 4),
    (3072, 8192): (2, 128, 64, 8, 4),
    (8192, 3072): (4, 128, 64, 4, 4),
}
_TRITON_SMALL_M_DEFAULT = (4, 256, 64, 4, 4)
_TRITON_SMALL_M_K3072_DEFAULT = (2, 64, 64, 4, 4)
_TRITON_MID_M_DEFAULT = (32, 128, 64, 8, 4)
_TRITON_LARGE_M_DEFAULT = (64, 128, 64, 8, 4)
_TRITON_LARGE_M_SHAPE_CONFIGS = {
    # Hot LTX2 distilled fused-int8 shapes profiled on RTX 50xx.
    (3840, 2048): (64, 256, 64, 8, 4),
    (3840, 15360): (64, 256, 64, 8, 4),
    (3840, 4096): (64, 256, 64, 8, 4),
    (4096, 3840): (64, 256, 64, 8, 4),
    (15360, 3840): (64, 256, 64, 8, 4),
    # Hot WAN2 I2V enhanced-lightning fused-int8 shapes (M ~= 512).
    (4096, 4096): (64, 256, 64, 8, 4),
    (4096, 10240): (64, 256, 64, 8, 4),
    (10240, 4096): (64, 256, 64, 8, 4),
}

# Representative (M, K, N) shapes benchmarked when autotuning each default slot.
_AUTOTUNE_SLOT_REPS = {
    "tiny_k3072_default": ((2, 3072, 2048), (4, 3072, 4096)),
    "tiny_default": ((2, 4096, 1536), (4, 4096, 1536)),
    "mid_default": ((32, 2048, 4096), (32, 3072, 3072), (32, 4096, 4096)),
    "large_n_ge_2048": ((512, 4096, 4096), (3840, 3840, 4096)),
    "large_default": ((128, 4096, 1024), (192, 3072, 1536)),
}
+
+
+def _env_flag(name: str, default: str = "1") -> bool:
+ val = os.environ.get(name, default)
+ return str(val).strip().lower() in ("1", "true", "yes", "on")
+
+
+def _parse_version(ver: str) -> tuple[int, int]:
+ try:
+ parts = ver.split(".")
+ return int(parts[0]), int(parts[1])
+ except Exception:
+ return (0, 0)
+
+
+def _env_int(name: str, default: int) -> int:
+ try:
+ return int(os.environ.get(name, str(default)))
+ except Exception:
+ return default
+
+
+def _env_float(name: str, default: float) -> float:
+ try:
+ return float(os.environ.get(name, str(default)))
+ except Exception:
+ return default
+
+
def _autotune_debug(msg: str) -> None:
    """Print an autotune debug line when debug is enabled.

    The programmatic override (set_autotune_debug) wins over the env flag.
    """
    if _AUTOTUNE_DEBUG_OVERRIDE is None:
        debug_on = _env_flag(_ENV_AUTOTUNE_DEBUG, "0")
    else:
        debug_on = bool(_AUTOTUNE_DEBUG_OVERRIDE)
    if debug_on:
        print(f"[WAN2GP][INT8][autotune] {msg}")
+
+
def set_autotune_debug(enabled: Optional[bool] = None) -> None:
    """Override autotune debug logging; None restores the env-flag behavior."""
    global _AUTOTUNE_DEBUG_OVERRIDE
    _AUTOTUNE_DEBUG_OVERRIDE = None if enabled is None else bool(enabled)
+
+
def _runtime_compatible() -> bool:
    """Check that Triton, CUDA, and the GPU generation can run the int8 kernels."""
    if not (_TRITON_AVAILABLE and torch.cuda.is_available()):
        return False
    try:
        cc_major, _ = torch.cuda.get_device_capability()
    except Exception:
        return False

    # Triton int8 dot kernels require tensor-core generation GPUs.
    if cc_major < 8:
        return False

    # Keep SM120 safe on older Triton builds that abort at compile time.
    triton_ver = _parse_version(getattr(triton, "__version__", "0.0"))
    if cc_major >= 12 and triton_ver < (3, 6):
        return False

    return True
+
+
def is_available() -> bool:
    """Memoized: runtime compatibility AND the enable env flag (default on)."""
    global _IS_AVAILABLE
    if _IS_AVAILABLE is None:
        _IS_AVAILABLE = bool(_runtime_compatible() and _env_flag(_ENV_ENABLE, "1"))
    return _IS_AVAILABLE
+
+
def _select_static_triton_int8_config(m: int, k: int, n: int) -> tuple[int, int, int, int, int]:
    """Pick a launch config (block_m, block_n, block_k, num_warps, num_stages).

    Tiered lookup: exact tiny-M shapes, then (k, n) pair tables, then
    M-range defaults (tiny M<=4, mid M<64, large M>=256, else large default).
    """
    if m <= 4:
        cfg = _TRITON_TINY_M_SHAPE_CONFIGS.get((m, k, n))
        if cfg is not None:
            return cfg
        cfg = _TRITON_TINY_M_PAIR_CONFIGS.get((k, n))
        if cfg is not None:
            return cfg
        cfg = _TRITON_SMALL_M_CONFIGS.get((k, n))
        if cfg is not None:
            return cfg
        if k == 3072:
            return _TRITON_SMALL_M_K3072_DEFAULT
        return _TRITON_SMALL_M_DEFAULT
    if m < 64:
        return _TRITON_MID_M_DEFAULT
    if m >= 256:
        cfg = _TRITON_LARGE_M_SHAPE_CONFIGS.get((k, n))
        if cfg is not None:
            return cfg
        if n >= 2048:
            return (64, 256, 64, 8, 4)
    return _TRITON_LARGE_M_DEFAULT
+
+
+def _dedup_shapes(shapes: tuple[tuple[int, int, int], ...]) -> tuple[tuple[int, int, int], ...]:
+ out: list[tuple[int, int, int]] = []
+ seen: set[tuple[int, int, int]] = set()
+ for shape in shapes:
+ if not isinstance(shape, (list, tuple)) or len(shape) != 3:
+ continue
+ try:
+ m, k, n = (int(shape[0]), int(shape[1]), int(shape[2]))
+ except Exception:
+ continue
+ if m <= 0 or k <= 0 or n <= 0:
+ continue
+ key = (m, k, n)
+ if key in seen:
+ continue
+ seen.add(key)
+ out.append(key)
+ return tuple(out)
+
+
def _resolve_autotune_slot(m: int, k: int, n: int) -> tuple[str, tuple[tuple[int, int, int], ...]]:
    """Map a GEMM shape to an autotune slot id plus representative shapes.

    Shapes sharing a slot share one cached config. The representative shapes
    are filtered to those whose static config matches this shape's baseline,
    so tuning results transfer; if none survive, the shape itself is used.
    """
    baseline = _select_static_triton_int8_config(m, k, n)
    if m <= 4:
        # Mirror the tiny-M lookup order of _select_static_triton_int8_config.
        if (m, k, n) in _TRITON_TINY_M_SHAPE_CONFIGS:
            slot_id = f"tiny_shape|m={m}|k={k}|n={n}"
            reps = ((m, k, n),)
        elif (k, n) in _TRITON_TINY_M_PAIR_CONFIGS:
            slot_id = f"tiny_pair|k={k}|n={n}"
            reps = ((2, k, n), (4, k, n))
        elif (k, n) in _TRITON_SMALL_M_CONFIGS:
            slot_id = f"tiny_small_pair|k={k}|n={n}"
            reps = ((2, k, n), (4, k, n))
        elif k == 3072:
            slot_id = "tiny_k3072_default"
            reps = _AUTOTUNE_SLOT_REPS[slot_id]
        else:
            slot_id = "tiny_default"
            reps = _AUTOTUNE_SLOT_REPS[slot_id]
    elif m < 64:
        slot_id = "mid_default"
        reps = _AUTOTUNE_SLOT_REPS[slot_id]
    elif m >= 256 and (k, n) in _TRITON_LARGE_M_SHAPE_CONFIGS:
        slot_id = f"large_hot_pair|k={k}|n={n}"
        reps = ((512, k, n), (3840, k, n))
    elif m >= 256 and n >= 2048:
        slot_id = "large_n_ge_2048"
        reps = _AUTOTUNE_SLOT_REPS[slot_id]
    else:
        slot_id = "large_default"
        reps = _AUTOTUNE_SLOT_REPS[slot_id]

    filtered = [shape for shape in _dedup_shapes(reps) if _select_static_triton_int8_config(shape[0], shape[1], shape[2]) == baseline]
    if len(filtered) == 0:
        filtered = [(m, k, n)]
    return slot_id, tuple(filtered)
+
+
def _normalize_config(cfg) -> Optional[tuple[int, int, int, int, int]]:
    """Coerce *cfg* into a 5-tuple of positive ints, or None when invalid."""
    if not isinstance(cfg, (list, tuple)) or len(cfg) != _CONFIG_LEN:
        return None
    try:
        values = tuple(int(v) for v in cfg)
    except Exception:
        return None
    if any(v <= 0 for v in values):
        return None
    return values
+
+
def _autotune_cache_path() -> Path:
    """Resolve the autotune cache file path (env override wins over default)."""
    fallback = Path.home() / ".triton" / "autotune" / "wan2gp_int8_autotune_cache.json"
    configured = os.environ.get(_ENV_AUTOTUNE_CACHE, str(fallback))
    return Path(configured).expanduser()
+
+
def _load_autotune_cache() -> None:
    """Load the persisted autotune config cache once per process.

    Best-effort: any read/parse failure leaves the cache empty. Entries are
    validated through _normalize_config before being kept.
    """
    global _AUTOTUNE_CACHE_LOADED, _AUTOTUNE_CONFIG_CACHE
    if _AUTOTUNE_CACHE_LOADED:
        return
    # Mark loaded up-front so a corrupt file is not re-read on every call.
    _AUTOTUNE_CACHE_LOADED = True
    cache_path = _autotune_cache_path()
    try:
        payload = json.loads(cache_path.read_text(encoding="utf-8"))
    except Exception:
        return
    entries = payload.get("entries", {})
    if not isinstance(entries, dict):
        return
    parsed = {}
    for key, raw_cfg in entries.items():
        if not isinstance(key, str):
            continue
        cfg = _normalize_config(raw_cfg)
        if cfg is not None:
            parsed[key] = cfg
    _AUTOTUNE_CONFIG_CACHE = parsed
+
+
def _save_autotune_cache() -> None:
    """Persist the autotune cache if dirty, via atomic tmp-file replace."""
    global _AUTOTUNE_CACHE_DIRTY
    if not _AUTOTUNE_CACHE_DIRTY:
        return
    cache_path = _autotune_cache_path()
    try:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = Path(f"{cache_path}.tmp")
        payload = {"version": 1, "entries": {key: list(cfg) for key, cfg in _AUTOTUNE_CONFIG_CACHE.items()}}
        tmp_path.write_text(json.dumps(payload, sort_keys=True), encoding="utf-8")
        # replace() is atomic on POSIX, so readers never see a partial file.
        tmp_path.replace(cache_path)
        _AUTOTUNE_CACHE_DIRTY = False
    except Exception as exc:
        _autotune_debug(f"cache write failed: {exc}")
+
+
def _device_index(device: Optional[torch.device]) -> int:
    """Resolve a torch device to a CUDA device index (current device if unset)."""
    if device is not None and device.type == "cuda" and device.index is not None:
        return int(device.index)
    return int(torch.cuda.current_device())
+
+
def _device_fingerprint(device_index: int) -> str:
    """Fingerprint GPU + torch + triton versions so cached configs don't leak across setups."""
    props = torch.cuda.get_device_properties(device_index)
    triton_ver = getattr(triton, "__version__", "0.0")
    return (
        f"{props.name}|cc={props.major}.{props.minor}|sm={props.multi_processor_count}|"
        f"torch={torch.__version__}|triton={triton_ver}"
    )
+
+
def _autotune_slot_cache_key(device_index: int, kernel_kind: str, slot_id: str) -> str:
    """Cache key for a slot-level autotune entry on a specific device/runtime."""
    return f"{_device_fingerprint(device_index)}|{kernel_kind}|slot={slot_id}"
+
+
def _autotune_legacy_shape_cache_key(device_index: int, kernel_kind: str, m: int, k: int, n: int) -> str:
    """Old per-shape cache key format, kept so existing caches can be migrated."""
    return f"{_device_fingerprint(device_index)}|{kernel_kind}|{m}|{k}|{n}"
+
+
def _get_cached_config(device_index: int, kernel_kind: str, slot_id: str, m: int, k: int, n: int) -> Optional[tuple[int, int, int, int, int]]:
    """Look up a cached config for this slot; migrate legacy per-shape entries.

    When only a legacy per-shape entry exists it is copied under the new slot
    key and the cache marked dirty for persistence.
    """
    global _AUTOTUNE_CACHE_DIRTY
    _load_autotune_cache()
    slot_key = _autotune_slot_cache_key(device_index, kernel_kind, slot_id)
    cfg = _AUTOTUNE_CONFIG_CACHE.get(slot_key)
    if cfg is not None:
        return cfg
    legacy_key = _autotune_legacy_shape_cache_key(device_index, kernel_kind, m, k, n)
    legacy_cfg = _AUTOTUNE_CONFIG_CACHE.get(legacy_key)
    if legacy_cfg is None:
        return None
    _AUTOTUNE_CONFIG_CACHE[slot_key] = legacy_cfg
    _AUTOTUNE_CACHE_DIRTY = True
    return legacy_cfg
+
+
def _set_cached_config(device_index: int, kernel_kind: str, slot_id: str, cfg: tuple[int, int, int, int, int]) -> None:
    """Store a config under the slot key; no-op (stays clean) if unchanged."""
    global _AUTOTUNE_CACHE_DIRTY
    _load_autotune_cache()
    key = _autotune_slot_cache_key(device_index, kernel_kind, slot_id)
    if _AUTOTUNE_CONFIG_CACHE.get(key) == cfg:
        return
    _AUTOTUNE_CONFIG_CACHE[key] = cfg
    _AUTOTUNE_CACHE_DIRTY = True
+
+
def _drop_cached_config(device_index: int, kernel_kind: str, slot_id: str, m: int, k: int, n: int) -> None:
    """Evict both the slot-level and legacy per-shape entries for this lookup."""
    global _AUTOTUNE_CACHE_DIRTY
    _load_autotune_cache()
    keys = (
        _autotune_slot_cache_key(device_index, kernel_kind, slot_id),
        _autotune_legacy_shape_cache_key(device_index, kernel_kind, m, k, n),
    )
    removed = False
    for key in keys:
        if key in _AUTOTUNE_CONFIG_CACHE:
            del _AUTOTUNE_CONFIG_CACHE[key]
            removed = True
    if removed:
        _AUTOTUNE_CACHE_DIRTY = True
+
+
def _config_compatible_with_baseline(
    kind: str,
    baseline: tuple[int, int, int, int, int],
    cfg: tuple[int, int, int, int, int],
) -> bool:
    """Whether *cfg* may replace *baseline* without changing kernel numerics."""
    if kind == "fused" and _env_flag(_ENV_AUTOTUNE_LOCK_FUSED_BLOCK_K, "1"):
        # Fused blockscale kernel computes row scales per K-chunk; changing block_k changes numerics.
        return int(cfg[2]) == int(baseline[2])
    return True
+
+
def _candidate_configs(
    baseline: tuple[int, int, int, int, int],
    m: int,
    k: int,
    n: int,
    *,
    kind: str,
) -> list[tuple[int, int, int, int, int]]:
    """Build the autotune candidate list for one shape, baseline first.

    Adds hand-picked candidates for tiny (M<=4) and small (M<=16) batches plus
    any static-table entries, then de-duplicates and drops candidates that are
    numerically incompatible with the baseline (see fused block_k lock).
    """
    out = [baseline]
    if m <= 4:
        out.extend(
            [
                (1, 64, 64, 2, 4),
                (1, 128, 64, 4, 4),
                (2, 64, 64, 4, 4),
                (2, 128, 64, 4, 4),
                (2, 128, 128, 8, 4),
                (2, 256, 64, 8, 4),
                (4, 128, 64, 4, 4),
                (4, 256, 64, 4, 4),
                (8, 128, 64, 4, 4),
            ]
        )
        shape_cfg = _TRITON_TINY_M_SHAPE_CONFIGS.get((m, k, n))
        if shape_cfg is not None:
            out.append(shape_cfg)
        pair_cfg = _TRITON_TINY_M_PAIR_CONFIGS.get((k, n))
        if pair_cfg is not None:
            out.append(pair_cfg)
    elif m <= 16:
        out.extend([(8, 128, 64, 4, 4), (8, 256, 64, 8, 4), (16, 128, 64, 8, 4), (16, 256, 64, 8, 4), (32, 128, 64, 8, 4)])
    dedup: list[tuple[int, int, int, int, int]] = []
    seen = set()
    for cfg in out:
        norm = _normalize_config(cfg)
        if norm is None or norm in seen:
            continue
        if not _config_compatible_with_baseline(kind, baseline, norm):
            continue
        seen.add(norm)
        dedup.append(norm)
    return dedup
+
+
+def _looks_like_unsupported_dot_tile(cfg: tuple[int, int, int, int, int]) -> bool:
+ block_m, block_n, block_k, _, _ = cfg
+ # Triton int8 dot kernels can reject tiny tiles on some runtimes (e.g. decode-time M<=4).
+ return block_m < 16 or block_n < 16 or block_k < 32
+
+
def _compile_recovery_candidates(
    kind: str,
    baseline: tuple[int, int, int, int, int],
    preferred: tuple[int, int, int, int, int],
    m: int,
    k: int,
    n: int,
) -> list[tuple[int, int, int, int, int]]:
    """Ordered configs to try when the preferred one fails to compile.

    Order: preferred, the normal autotune candidates, then conservative large
    tiles sharing the baseline's block_k (rounded up to a multiple of 32).
    Tiles likely to be rejected by the int8 dot are pushed to the end, and the
    baseline is always included as a last resort.
    """
    block_k = max(32, int(baseline[2]))
    if block_k % 32 != 0:
        block_k = ((block_k + 31) // 32) * 32
    conservative_large_tiles = [
        (16, 32, block_k, 4, 4),
        (16, 64, block_k, 4, 4),
        (16, 128, block_k, 4, 4),
        (32, 32, block_k, 4, 4),
        (32, 64, block_k, 8, 4),
        (32, 128, block_k, 8, 4),
        (64, 64, block_k, 8, 4),
        (64, 128, block_k, 8, 4),
    ]
    raw = [preferred]
    raw.extend(_candidate_configs(baseline, m, k, n, kind=kind))
    raw.extend(conservative_large_tiles)

    dedup: list[tuple[int, int, int, int, int]] = []
    seen = set()
    for cfg in raw:
        norm = _normalize_config(cfg)
        if norm is None or norm in seen:
            continue
        if not _config_compatible_with_baseline(kind, baseline, norm):
            continue
        seen.add(norm)
        dedup.append(norm)

    if baseline not in dedup:
        dedup.append(baseline)

    if len(dedup) <= 1:
        return dedup

    # Keep the preferred config first; demote suspicious tiny tiles.
    head = dedup[0]
    tail = dedup[1:]
    non_tiny = [cfg for cfg in tail if not _looks_like_unsupported_dot_tile(cfg)]
    tiny = [cfg for cfg in tail if _looks_like_unsupported_dot_tile(cfg)]
    return [head, *non_tiny, *tiny]
+
+
def _launch_candidate(kind: str, cfg: tuple[int, int, int, int, int], tensors: tuple[torch.Tensor, ...], m: int, n: int, k: int) -> None:
    """Launch one kernel with the given config over pre-built tensors.

    kind "fused" expects (activation, int8 weight, weight scale, out);
    anything else uses the scaled-int8 GEMM and expects
    (int8 activation, int8 weight, act scale, weight scale, out).
    Output is written in place into the last tensor.
    """
    block_m, block_n, block_k, num_warps, num_stages = cfg
    # 2D launch grid: one program per (block_m x block_n) output tile.
    grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n))
    if kind == "fused":
        x_mm_c, qweight_c, b_scale_c, out = tensors
        _fused_dynamic_int8_blockscale_gemm_kernel[grid](
            x_mm_c,
            qweight_c,
            b_scale_c,
            out,
            m,
            n,
            k,
            x_mm_c.stride(0),
            x_mm_c.stride(1),
            qweight_c.stride(0),
            qweight_c.stride(1),
            out.stride(0),
            out.stride(1),
            block_m=block_m,
            block_n=block_n,
            block_k=block_k,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return
    a_int8_c, b_int8_c, a_scale_c, b_scale_c, out = tensors
    _scaled_int8_gemm_kernel[grid](
        a_int8_c,
        b_int8_c,
        a_scale_c,
        b_scale_c,
        out,
        m,
        n,
        k,
        a_int8_c.stride(0),
        a_int8_c.stride(1),
        b_int8_c.stride(0),
        b_int8_c.stride(1),
        out.stride(0),
        out.stride(1),
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_warps=num_warps,
        num_stages=num_stages,
    )
+
+
def _create_bench_tensors(kind: str, device: torch.device, m: int, k: int, n: int) -> tuple[torch.Tensor, ...]:
    """Allocate random benchmark tensors matching _launch_candidate's layout.

    Scales get a small positive offset so division by a scale never hits zero.
    """
    if kind == "fused":
        x_mm_c = torch.randn((m, k), device=device, dtype=torch.bfloat16)
        qweight_c = torch.randint(-128, 128, (n, k), device=device, dtype=torch.int8)
        b_scale_c = torch.rand((n,), device=device, dtype=torch.float32).add_(1e-4)
        out = torch.empty((m, n), device=device, dtype=torch.bfloat16)
        return (x_mm_c, qweight_c, b_scale_c, out)
    a_int8_c = torch.randint(-128, 128, (m, k), device=device, dtype=torch.int8)
    b_int8_c = torch.randint(-128, 128, (n, k), device=device, dtype=torch.int8)
    a_scale_c = torch.rand((m,), device=device, dtype=torch.float32).add_(1e-4)
    b_scale_c = torch.rand((n,), device=device, dtype=torch.float32).add_(1e-4)
    out = torch.empty((m, n), device=device, dtype=torch.bfloat16)
    return (a_int8_c, b_int8_c, a_scale_c, b_scale_c, out)
+
+
def _run_candidate_once_with_error(
    kind: str,
    cfg: tuple[int, int, int, int, int],
    tensors: tuple[torch.Tensor, ...],
    m: int,
    k: int,
    n: int,
) -> tuple[Optional[torch.Tensor], Optional[Exception]]:
    """Run one candidate config once; return (output, None) or (None, error).

    A fresh output tensor is allocated per run so comparisons between
    candidates never alias the same memory.
    """
    try:
        if kind == "fused":
            x_mm_c, qweight_c, b_scale_c, _ = tensors
            out = torch.empty((m, n), device=x_mm_c.device, dtype=torch.bfloat16)
            _launch_candidate(kind, cfg, (x_mm_c, qweight_c, b_scale_c, out), m, n, k)
            # Sync so asynchronous launch/compile failures surface here.
            torch.cuda.synchronize(x_mm_c.device)
            return out, None
        a_int8_c, b_int8_c, a_scale_c, b_scale_c, _ = tensors
        out = torch.empty((m, n), device=a_int8_c.device, dtype=torch.bfloat16)
        _launch_candidate(kind, cfg, (a_int8_c, b_int8_c, a_scale_c, b_scale_c, out), m, n, k)
        torch.cuda.synchronize(a_int8_c.device)
        return out, None
    except Exception as exc:
        _autotune_debug(f"single-run failed for {kind} shape=({m},{k},{n}) cfg={cfg}: {exc}")
        return None, exc
+
+
def _run_candidate_once(kind: str, cfg: tuple[int, int, int, int, int], tensors: tuple[torch.Tensor, ...], m: int, k: int, n: int) -> Optional[torch.Tensor]:
    """Error-swallowing wrapper: the output tensor, or None on failure."""
    out, _ = _run_candidate_once_with_error(kind, cfg, tensors, m, k, n)
    return out
+
+
def _ensure_compile_compatible_config(
    kind: str,
    device_index: int,
    slot_id: str,
    preferred: tuple[int, int, int, int, int],
    baseline: tuple[int, int, int, int, int],
    m: int,
    k: int,
    n: int,
    rep_shapes: tuple[tuple[int, int, int], ...],
) -> tuple[tuple[int, int, int, int, int], Optional[Exception]]:
    """Return a config that actually compiles/runs for this shape.

    Probes recovery candidates in order on the current shape only and returns
    (first working config, None); if none works, returns (preferred,
    last_error) so the caller can decide how to fail.
    """
    device = torch.device("cuda", device_index)
    # rep_shapes is accepted for interface symmetry but intentionally unused.
    _ = rep_shapes
    # Probing the current shape is enough to catch tile compile incompatibilities while
    # keeping allocations low for large representative shapes.
    probe_shapes = ((m, k, n),)
    tensors_by_shape = {shape: _create_bench_tensors(kind, device, *shape) for shape in probe_shapes}
    candidates = _compile_recovery_candidates(kind, baseline, preferred, m, k, n)
    last_error: Optional[Exception] = None

    for cfg in candidates:
        all_ok = True
        for probe_m, probe_k, probe_n in probe_shapes:
            _, probe_err = _run_candidate_once_with_error(
                kind,
                cfg,
                tensors_by_shape[(probe_m, probe_k, probe_n)],
                probe_m,
                probe_k,
                probe_n,
            )
            if probe_err is not None:
                all_ok = False
                last_error = probe_err
                break
        if all_ok:
            if cfg != preferred:
                _autotune_debug(
                    f"compile recovery picked {cfg} for {kind} slot={slot_id} shape=({m},{k},{n}) "
                    f"instead of {preferred}"
                )
            return cfg, None

    if last_error is not None:
        _autotune_debug(
            f"compile recovery failed for {kind} slot={slot_id} shape=({m},{k},{n}); "
            f"keeping {preferred}. last_error={last_error}"
        )
    return preferred, last_error
+
+
+def _candidate_matches_baseline(
+ baseline_out: torch.Tensor,
+ candidate_out: torch.Tensor,
+ *,
+ max_abs_limit: float,
+ rel_limit: float,
+) -> tuple[bool, float, float]:
+ if not torch.isfinite(candidate_out).all().item():
+ return False, float("inf"), float("inf")
+ base_f = baseline_out.float()
+ cand_f = candidate_out.float()
+ diff = (base_f - cand_f).abs()
+ max_abs = float(diff.max().item())
+ denom = base_f.abs().mean().clamp_min(1e-6)
+ rel = float((diff.mean() / denom).item())
+ return (max_abs <= max_abs_limit and rel <= rel_limit), max_abs, rel
+
+
def _validate_config(
    kind: str,
    device: torch.device,
    m: int,
    k: int,
    n: int,
    baseline: tuple[int, int, int, int, int],
    cfg: tuple[int, int, int, int, int],
) -> bool:
    """Numerically validate *cfg* against *baseline* on random tensors.

    The baseline config always validates against itself; validation can be
    disabled via env flag. Error tolerances also come from env vars.
    """
    if not _config_compatible_with_baseline(kind, baseline, cfg):
        return False
    if cfg == baseline:
        return True
    if not _env_flag(_ENV_AUTOTUNE_VALIDATE, "1"):
        return True
    max_abs_limit = max(0.0, _env_float(_ENV_AUTOTUNE_MAX_ABS_ERR, 0.25))
    rel_limit = max(0.0, _env_float(_ENV_AUTOTUNE_MAX_REL_ERR, 0.001))
    # Both configs run on the same random inputs so outputs are comparable.
    tensors = _create_bench_tensors(kind, device, m, k, n)
    baseline_out = _run_candidate_once(kind, baseline, tensors, m, k, n)
    candidate_out = _run_candidate_once(kind, cfg, tensors, m, k, n)
    if baseline_out is None or candidate_out is None:
        return False
    ok, max_abs, rel = _candidate_matches_baseline(
        baseline_out,
        candidate_out,
        max_abs_limit=max_abs_limit,
        rel_limit=rel_limit,
    )
    if not ok:
        _autotune_debug(
            f"rejecting config {cfg} for {kind} shape=({m},{k},{n}) "
            f"vs baseline {baseline}: max_abs={max_abs:.6f}, rel={rel:.6f}"
        )
    return ok
+
+
def _benchmark_config_ms(kind: str, cfg: tuple[int, int, int, int, int], tensors: tuple[torch.Tensor, ...], device: torch.device, m: int, k: int, n: int) -> Optional[float]:
    """Time one config with CUDA events; mean ms per launch, or None on failure.

    Warmup launches (and a sync) run first so compilation and cache effects do
    not pollute the timed iterations.
    """
    warmup = max(1, _env_int(_ENV_AUTOTUNE_WARMUP, 2))
    iters = max(1, _env_int(_ENV_AUTOTUNE_ITERS, 5))
    try:
        for _ in range(warmup):
            _launch_candidate(kind, cfg, tensors, m, n, k)
        torch.cuda.synchronize(device)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            _launch_candidate(kind, cfg, tensors, m, n, k)
        end.record()
        end.synchronize()
        return float(start.elapsed_time(end)) / float(iters)
    except Exception as exc:
        _autotune_debug(f"benchmark failed for {kind} shape=({m},{k},{n}) cfg={cfg}: {exc}")
        return None
+
+
def _can_tune_slot(slot_key: tuple[int, str, str]) -> bool:
    """Budget gate: allow tuning this slot unless the per-process cap is hit.

    Already-seen slots always pass; new slots consume one unit of the
    _ENV_AUTOTUNE_MAX_SHAPES budget.
    """
    global _AUTOTUNE_SLOTS_TUNED
    if slot_key in _AUTOTUNE_SEEN_SLOTS:
        return True
    max_shapes = max(0, _env_int(_ENV_AUTOTUNE_MAX_SHAPES, 32))
    if _AUTOTUNE_SLOTS_TUNED >= max_shapes:
        return False
    _AUTOTUNE_SEEN_SLOTS.add(slot_key)
    _AUTOTUNE_SLOTS_TUNED += 1
    return True
+
+
def _benchmark_slot_config_ms(
    kind: str,
    cfg: tuple[int, int, int, int, int],
    device: torch.device,
    rep_shapes: tuple[tuple[int, int, int], ...],
) -> Optional[float]:
    """Mean per-launch time of *cfg* across a slot's representative shapes.

    Returns None when the config fails validation or benchmarking on any
    representative shape — the config is then disqualified for the slot.
    """
    total = 0.0
    count = 0
    for rep_m, rep_k, rep_n in rep_shapes:
        rep_baseline = _select_static_triton_int8_config(rep_m, rep_k, rep_n)
        if not _validate_config(kind, device, rep_m, rep_k, rep_n, rep_baseline, cfg):
            return None
        tensors = _create_bench_tensors(kind, device, rep_m, rep_k, rep_n)
        ms = _benchmark_config_ms(kind, cfg, tensors, device, rep_m, rep_k, rep_n)
        if ms is None:
            return None
        total += ms
        count += 1
    if count == 0:
        return None
    return total / float(count)
+
+
def _autotune_config(
    kind: str,
    device_index: int,
    m: int,
    k: int,
    n: int,
    baseline: tuple[int, int, int, int, int],
    slot_id: str,
    rep_shapes: tuple[tuple[int, int, int], ...],
) -> tuple[int, int, int, int, int]:
    """Pick a kernel config for (kind, slot) by benchmarking candidates.

    Order matters here: (1) reuse a validated on-disk cached config,
    (2) respect the per-process tuning budget, (3) benchmark candidates over
    the slot's representative shapes, (4) only prefer a non-baseline config
    when it beats baseline by the configured speedup margin. The chosen
    config is persisted via ``_set_cached_config``.
    """
    device = torch.device("cuda", device_index)
    # Fast path: a previously persisted config, re-validated before reuse.
    cached = _get_cached_config(device_index, kind, slot_id, m, k, n)
    if cached is not None:
        if _validate_config(kind, device, m, k, n, baseline, cached):
            return cached
        # Stale/invalid cache entry (e.g. different driver/GPU): discard it.
        _drop_cached_config(device_index, kind, slot_id, m, k, n)
    slot_key = (device_index, kind, slot_id)
    if not _can_tune_slot(slot_key):
        _autotune_debug(f"slot budget reached; keeping baseline for {kind} slot={slot_id} shape=({m},{k},{n})")
        return baseline

    # Seed candidate generation from the slot's first representative shape so
    # all shapes in the slot are tuned against one candidate list.
    rep_m, rep_k, rep_n = rep_shapes[0]
    rep_baseline = _select_static_triton_int8_config(rep_m, rep_k, rep_n)
    candidate_seed = rep_baseline if _config_compatible_with_baseline(kind, baseline, rep_baseline) else baseline
    candidates = _candidate_configs(candidate_seed, rep_m, rep_k, rep_n, kind=kind)
    if baseline not in candidates:
        candidates = [baseline, *candidates]

    results: dict[tuple[int, int, int, int, int], float] = {}
    for cfg in candidates:
        ms = _benchmark_slot_config_ms(kind, cfg, device, rep_shapes)
        if ms is not None:
            results[cfg] = ms
    baseline_ms = results.get(baseline)
    if baseline_ms is None:
        # Baseline itself failed to benchmark: fall back to the fastest
        # config that did compile, or keep baseline if nothing compiled.
        if len(results) > 0:
            recovered_cfg, recovered_ms = min(results.items(), key=lambda item: item[1])
            _set_cached_config(device_index, kind, slot_id, recovered_cfg)
            _autotune_debug(
                f"baseline config failed for {kind} slot={slot_id} shape=({m},{k},{n}); "
                f"using first compilable cfg={recovered_cfg} (ms={recovered_ms:.4f})"
            )
            return recovered_cfg
        _set_cached_config(device_index, kind, slot_id, baseline)
        _autotune_debug(
            f"no compilable configs found during autotune for {kind} slot={slot_id} shape=({m},{k},{n}); "
            f"keeping baseline {baseline}"
        )
        return baseline
    best_cfg, best_ms = min(results.items(), key=lambda item: item[1])
    # Require a minimum speedup (>= 1.0) so noise does not displace baseline.
    min_speedup = max(1.0, _env_float(_ENV_AUTOTUNE_MIN_SPEEDUP, 1.02))
    use_best = best_cfg != baseline and best_ms > 0.0 and (baseline_ms / best_ms) >= min_speedup
    picked = best_cfg if use_best else baseline
    # Final safety net: re-validate the winner against the actual shape.
    if not _validate_config(kind, device, m, k, n, baseline, picked):
        picked = baseline
    _set_cached_config(device_index, kind, slot_id, picked)
    if use_best:
        _autotune_debug(
            f"picked {picked} over baseline {baseline} for {kind} slot={slot_id} shape=({m},{k},{n}), "
            f"baseline_ms={baseline_ms:.4f}, tuned_ms={best_ms:.4f}, speedup={baseline_ms / best_ms:.3f}x"
        )
    else:
        _autotune_debug(
            f"kept baseline {baseline} for {kind} slot={slot_id} shape=({m},{k},{n}), "
            f"baseline_ms={baseline_ms:.4f}, best_cfg={best_cfg}, best_ms={best_ms:.4f}"
        )
    return picked
+
+
def _select_triton_int8_config(
    m: int,
    k: int,
    n: int,
    *,
    device: Optional[torch.device] = None,
    kernel_kind: str = "fused",
) -> tuple[int, int, int, int, int]:
    """Resolve the (block_m, block_n, block_k, num_warps, num_stages) config.

    Falls back to the static heuristic when Triton/CUDA is unavailable or the
    device index cannot be resolved. Otherwise consults a per-process session
    cache keyed by (device, kernel kind, slot), optionally autotunes, and runs
    a compile-compatibility probe before committing the final choice.
    """
    baseline = _select_static_triton_int8_config(m, k, n)
    if not is_available() or not torch.cuda.is_available():
        return baseline
    try:
        device_index = _device_index(device)
    except Exception:
        # Unresolvable device (e.g. CPU tensor): static heuristic is safe.
        return baseline
    slot_id, rep_shapes = _resolve_autotune_slot(m, k, n)
    session_key = (device_index, kernel_kind, slot_id)
    cached = _AUTOTUNE_SESSION_CACHE.get(session_key)
    if cached is not None:
        return cached

    autotune_enabled = _env_flag(_ENV_AUTOTUNE_ENABLE, "1")
    # A non-negative max-M cap lets users skip tuning for large batches.
    max_m = _env_int(_ENV_AUTOTUNE_MAX_M, -1)
    if autotune_enabled and not (max_m >= 0 and m > max_m):
        preferred = _autotune_config(kernel_kind, device_index, m, k, n, baseline, slot_id, rep_shapes)
    else:
        preferred = baseline

    # Verify the preferred config also survives torch.compile-style capture;
    # the probe may substitute a safer alternative.
    compile_safe, compile_err = _ensure_compile_compatible_config(
        kernel_kind,
        device_index,
        slot_id,
        preferred,
        baseline,
        m,
        k,
        n,
        rep_shapes,
    )

    picked = compile_safe
    if compile_safe != preferred:
        # Persist the substitution so future processes skip the probe cost.
        _set_cached_config(device_index, kernel_kind, slot_id, compile_safe)
    elif compile_err is not None:
        _autotune_debug(
            f"compile probe could not find an alternative for {kernel_kind} slot={slot_id} "
            f"shape=({m},{k},{n}); will keep {preferred}"
        )
    _AUTOTUNE_SESSION_CACHE[session_key] = picked
    return picked
+
+
+atexit.register(_save_autotune_cache)
+
+
+if _TRITON_AVAILABLE:
+
    @triton.jit
    def _fused_dynamic_int8_gemm_kernel(
        a_ptr,
        b_ptr,
        s_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bn,
        stride_bk,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        # Fused dynamic-quant GEMM: A (float [M, K]) is quantized rowwise to
        # int8 on the fly, multiplied against int8 weight B ([N, K]), then
        # dequantized with the row scales and per-channel weight scales in
        # s_ptr (length N). A is read twice (two K sweeps) to avoid storing
        # the quantized tile.
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * block_m + tl.arange(0, block_m)
        offs_n = pid_n * block_n + tl.arange(0, block_n)
        offs_k = tl.arange(0, block_k)

        # Pass 1: rowwise absmax for dynamic symmetric int8 activation quantization.
        row_amax = tl.zeros((block_m,), dtype=tl.float32)
        for k0 in range(0, k, block_k):
            kk = k0 + offs_k
            a = tl.load(
                a_ptr + offs_m[:, None] * stride_am + kk[None, :] * stride_ak,
                mask=(offs_m[:, None] < m) & (kk[None, :] < k),
                other=0,
            ).to(tl.float32)
            row_amax = tl.maximum(row_amax, tl.max(tl.abs(a), axis=1))

        # Guard all-zero rows: scale of 1.0 keeps the quantized row at zero.
        row_scale = row_amax / 127.0
        row_scale = tl.where(row_scale > 0.0, row_scale, 1.0)
        row_inv_scale = 1.0 / row_scale

        # Pass 2: quantize activations on the fly + int8 dot.
        acc = tl.zeros((block_m, block_n), dtype=tl.int32)
        for k0 in range(0, k, block_k):
            kk = k0 + offs_k
            a = tl.load(
                a_ptr + offs_m[:, None] * stride_am + kk[None, :] * stride_ak,
                mask=(offs_m[:, None] < m) & (kk[None, :] < k),
                other=0,
            ).to(tl.float32)
            a = a * row_inv_scale[:, None]
            # Match torch.round behavior (ties-to-even) used by quanto::quantize_symmetric.
            a = tl_libdevice.rint(a)
            # Defensive clamp to the int8 range before the narrowing cast.
            a = tl.maximum(tl.minimum(a, 127.0), -128.0).to(tl.int8)

            # Weight is [N, K]; load as [K, N] tile for dot.
            b = tl.load(
                b_ptr + offs_n[None, :] * stride_bn + kk[:, None] * stride_bk,
                mask=(offs_n[None, :] < n) & (kk[:, None] < k),
                other=0,
            ).to(tl.int8)
            acc += tl.dot(a, b)

        # Dequantize: int32 accumulator * activation row scale * weight channel scale.
        scales = tl.load(s_ptr + offs_n, mask=offs_n < n, other=0).to(tl.float32)
        out = acc.to(tl.float32) * row_scale[:, None] * scales[None, :]
        tl.store(
            c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
            out,
            mask=(offs_m[:, None] < m) & (offs_n[None, :] < n),
        )
+
    @triton.jit
    def _fused_dynamic_int8_blockscale_gemm_kernel(
        a_ptr,
        b_ptr,
        s_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bn,
        stride_bk,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        # Single-pass variant: each block_k tile of A gets its own rowwise
        # scale (block-local absmax), and the per-tile int32 dot is
        # dequantized into a float32 accumulator immediately. Numerics differ
        # slightly from the two-pass kernel, which uses one global row scale.
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * block_m + tl.arange(0, block_m)
        offs_n = pid_n * block_n + tl.arange(0, block_n)
        offs_k = tl.arange(0, block_k)

        acc = tl.zeros((block_m, block_n), dtype=tl.float32)
        for k0 in range(0, k, block_k):
            kk = k0 + offs_k
            a = tl.load(
                a_ptr + offs_m[:, None] * stride_am + kk[None, :] * stride_ak,
                mask=(offs_m[:, None] < m) & (kk[None, :] < k),
                other=0,
            ).to(tl.float32)
            # Per-tile rowwise scale; 1.0 guard keeps all-zero rows at zero.
            row_amax = tl.max(tl.abs(a), axis=1)
            row_scale = row_amax / 127.0
            row_scale = tl.where(row_scale > 0.0, row_scale, 1.0)
            a = a / row_scale[:, None]
            # rint matches torch.round's ties-to-even before the int8 cast.
            a = tl_libdevice.rint(a)
            a = tl.maximum(tl.minimum(a, 127.0), -128.0).to(tl.int8)

            # Weight is [N, K]; load as a [K, N] tile for the dot.
            b = tl.load(
                b_ptr + offs_n[None, :] * stride_bn + kk[:, None] * stride_bk,
                mask=(offs_n[None, :] < n) & (kk[:, None] < k),
                other=0,
            ).to(tl.int8)

            # Dequantize each tile's contribution with its own row scale.
            dot_i32 = tl.dot(a, b)
            acc += dot_i32.to(tl.float32) * row_scale[:, None]

        # Apply the per-output-channel weight scales once at the end.
        scales = tl.load(s_ptr + offs_n, mask=offs_n < n, other=0).to(tl.float32)
        out = acc * scales[None, :]
        tl.store(
            c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
            out,
            mask=(offs_m[:, None] < m) & (offs_n[None, :] < n),
        )
+
    @triton.jit
    def _scaled_int8_gemm_kernel(
        a_ptr,
        b_ptr,
        a_scales_ptr,
        b_scales_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bn,
        stride_bk,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        # GEMM over pre-quantized int8 operands: C = (A_i8 @ B_i8^T) scaled
        # by per-row activation scales (length M) and per-channel weight
        # scales (length N). No quantization happens in this kernel.
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * block_m + tl.arange(0, block_m)
        offs_n = pid_n * block_n + tl.arange(0, block_n)
        offs_k = tl.arange(0, block_k)

        acc = tl.zeros((block_m, block_n), dtype=tl.int32)
        for k0 in range(0, k, block_k):
            kk = k0 + offs_k
            a = tl.load(
                a_ptr + offs_m[:, None] * stride_am + kk[None, :] * stride_ak,
                mask=(offs_m[:, None] < m) & (kk[None, :] < k),
                other=0,
            ).to(tl.int8)
            # Weight is [N, K]; load as [K, N] tile for dot.
            b = tl.load(
                b_ptr + offs_n[None, :] * stride_bn + kk[:, None] * stride_bk,
                mask=(offs_n[None, :] < n) & (kk[:, None] < k),
                other=0,
            ).to(tl.int8)
            acc += tl.dot(a, b)

        # `other=1` keeps masked lanes harmless; masked rows/cols are never stored.
        a_scales = tl.load(a_scales_ptr + offs_m, mask=offs_m < m, other=1).to(tl.float32)
        b_scales = tl.load(b_scales_ptr + offs_n, mask=offs_n < n, other=1).to(tl.float32)
        out = acc.to(tl.float32) * a_scales[:, None] * b_scales[None, :]
        tl.store(
            c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
            out,
            mask=(offs_m[:, None] < m) & (offs_n[None, :] < n),
        )
+
+
+def _flatten_scale(scale: torch.Tensor) -> torch.Tensor:
+ if scale.ndim == 2 and scale.shape[1] == 1:
+ return scale.view(-1)
+ if scale.ndim == 1:
+ return scale
+ return scale.reshape(-1)
+
+
def _expand_or_validate_scale(scale: torch.Tensor, expected: int) -> torch.Tensor:
    """Return a 1-D scale of length *expected*, broadcasting a scalar scale.

    Raises ``RuntimeError`` when the flattened scale has neither one element
    nor exactly *expected* elements.
    """
    flat = _flatten_scale(scale)
    if flat.numel() == 1:
        return flat.reshape(1).expand(expected)
    if flat.numel() != expected:
        raise RuntimeError(f"Scale length mismatch: expected {expected}, got {flat.numel()}")
    return flat
+
+
def _fused_quant_scaled_mm_common(
    x2d: torch.Tensor,
    qweight: torch.Tensor,
    b_scale: torch.Tensor,
    *,
    k: int,
    n: int,
    stride_bn: int,
    stride_bk: int,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Launch the fused blockscale int8 GEMM kernel and return [M, N].

    Callers pass weight strides explicitly so both the [N, K] and the
    transposed [K, N] layouts share this launch path. Inputs are assumed
    pre-validated by the public wrappers.
    """
    m = x2d.shape[0]
    out = torch.empty((m, n), device=x2d.device, dtype=out_dtype)
    # The kernel indexes with raw strides, so both operands must be contiguous.
    x_mm_c = x2d if x2d.is_contiguous() else x2d.contiguous()
    b_scale_c = b_scale if b_scale.is_contiguous() else b_scale.contiguous()

    block_m, block_n, block_k, num_warps, num_stages = _select_triton_int8_config(m, k, n, device=x2d.device, kernel_kind="fused")
    # One program per (block_m x block_n) output tile.
    grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n))
    _fused_dynamic_int8_blockscale_gemm_kernel[grid](
        x_mm_c,
        qweight,
        b_scale_c,
        out,
        m,
        n,
        k,
        x_mm_c.stride(0),
        x_mm_c.stride(1),
        stride_bn,
        stride_bk,
        out.stride(0),
        out.stride(1),
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return out
+
+
def fused_quant_scaled_mm(
    x2d: torch.Tensor,
    qweight: torch.Tensor,
    qweight_scale: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Fused dynamic-int8 GEMM against an [N, K] int8 weight.

    Activations are quantized rowwise inside the kernel; the result is
    dequantized with the per-output-channel ``qweight_scale``.

    Args:
        x2d: 2-D bf16/fp16/fp32 CUDA activations of shape [M, K].
        qweight: int8 CUDA weight of shape [N, K].
        qweight_scale: Scalar or length-N weight scales.
        out_dtype: Output dtype; defaults to ``x2d.dtype``.

    Returns:
        The [M, N] result tensor.

    Raises:
        RuntimeError: If Triton is unavailable or any input fails validation.
    """
    if not is_available():
        raise RuntimeError("Triton backend not available")
    if x2d.ndim != 2:
        raise RuntimeError("x2d must be 2D")
    if qweight.ndim != 2:
        raise RuntimeError("qweight must be 2D [N, K]")
    if x2d.dtype not in (torch.bfloat16, torch.float16, torch.float32):
        raise RuntimeError("x2d must be bf16/fp16/fp32")
    if qweight.dtype != torch.int8:
        raise RuntimeError("qweight must be int8")
    if not x2d.is_cuda or not qweight.is_cuda:
        raise RuntimeError("fused_quant_scaled_mm requires CUDA tensors")

    m, k = x2d.shape
    n, k2 = qweight.shape
    if k != k2:
        raise RuntimeError(f"Triton int8 GEMM shape mismatch: x={x2d.shape}, w={qweight.shape}")

    b_scale = _expand_or_validate_scale(qweight_scale, n)
    if b_scale.device != x2d.device or b_scale.dtype != torch.float32:
        b_scale = b_scale.to(device=x2d.device, dtype=torch.float32)
    elif not b_scale.is_contiguous():
        b_scale = b_scale.contiguous()
    # NOTE: a second x2d.dtype check used to live here; it duplicated the
    # identical check above and was unreachable, so it has been removed.

    out_dtype = out_dtype or x2d.dtype
    qweight_c = qweight if qweight.is_contiguous() else qweight.contiguous()
    return _fused_quant_scaled_mm_common(
        x2d,
        qweight_c,
        b_scale,
        k=k,
        n=n,
        stride_bn=qweight_c.stride(0),
        stride_bk=qweight_c.stride(1),
        out_dtype=out_dtype,
    )
+
+
def fused_quant_scaled_mm_transposed(
    x2d: torch.Tensor,
    qweight_t: torch.Tensor,
    qweight_scale: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Fused dynamic-int8 GEMM against a transposed [K, N] int8 weight.

    Same semantics as :func:`fused_quant_scaled_mm`, but the weight is stored
    as [K, N]; the kernel is fed swapped strides so no transpose copy occurs.

    Args:
        x2d: 2-D bf16/fp16/fp32 CUDA activations of shape [M, K].
        qweight_t: int8 CUDA weight of shape [K, N].
        qweight_scale: Scalar or length-N weight scales.
        out_dtype: Output dtype; defaults to ``x2d.dtype``.

    Returns:
        The [M, N] result tensor.

    Raises:
        RuntimeError: If Triton is unavailable or any input fails validation.
    """
    if not is_available():
        raise RuntimeError("Triton backend not available")
    if x2d.ndim != 2:
        raise RuntimeError("x2d must be 2D")
    if qweight_t.ndim != 2:
        raise RuntimeError("qweight_t must be 2D [K, N]")
    if x2d.dtype not in (torch.bfloat16, torch.float16, torch.float32):
        raise RuntimeError("x2d must be bf16/fp16/fp32")
    if qweight_t.dtype != torch.int8:
        raise RuntimeError("qweight_t must be int8")
    if not x2d.is_cuda or not qweight_t.is_cuda:
        raise RuntimeError("fused_quant_scaled_mm_transposed requires CUDA tensors")

    m, k = x2d.shape
    k2, n = qweight_t.shape
    if k != k2:
        raise RuntimeError(f"Triton int8 GEMM shape mismatch: x={x2d.shape}, w_t={qweight_t.shape}")

    b_scale = _expand_or_validate_scale(qweight_scale, n)
    if b_scale.device != x2d.device or b_scale.dtype != torch.float32:
        b_scale = b_scale.to(device=x2d.device, dtype=torch.float32)
    elif not b_scale.is_contiguous():
        b_scale = b_scale.contiguous()
    # NOTE: a second x2d.dtype check used to live here; it duplicated the
    # identical check above and was unreachable, so it has been removed.

    out_dtype = out_dtype or x2d.dtype
    qweight_t_c = qweight_t if qweight_t.is_contiguous() else qweight_t.contiguous()
    return _fused_quant_scaled_mm_common(
        x2d,
        qweight_t_c,
        b_scale,
        k=k,
        n=n,
        # Swapped strides present the [K, N] layout to the [N, K] kernel.
        stride_bn=qweight_t_c.stride(1),
        stride_bk=qweight_t_c.stride(0),
        out_dtype=out_dtype,
    )
+
+
def scaled_int8_mm(
    a_int8: torch.Tensor,
    b_int8: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """GEMM over pre-quantized int8 operands with external scales.

    Computes ``(a_int8 @ b_int8.T) * a_scale[:, None] * b_scale[None, :]``
    via the Triton kernel, where ``a_int8`` is [M, K], ``b_int8`` is [N, K],
    and the scales are scalar or length M / N respectively. Output dtype
    defaults to bfloat16. Raises RuntimeError on any validation failure.
    """
    if not is_available():
        raise RuntimeError("Triton backend not available")
    if a_int8.ndim != 2:
        raise RuntimeError("a_int8 must be 2D")
    if b_int8.ndim != 2:
        raise RuntimeError("b_int8 must be 2D [N, K]")
    if a_int8.dtype != torch.int8 or b_int8.dtype != torch.int8:
        raise RuntimeError("scaled_int8_mm requires int8 activations and int8 weights")
    if not a_int8.is_cuda or not b_int8.is_cuda:
        raise RuntimeError("scaled_int8_mm requires CUDA tensors")

    m, k = a_int8.shape
    n, k2 = b_int8.shape
    if k != k2:
        raise RuntimeError(f"Triton int8 GEMM shape mismatch: a={a_int8.shape}, w={b_int8.shape}")

    # Normalize scales to 1-D float32 on the activation device.
    a_scale = _expand_or_validate_scale(a_scale, m)
    b_scale = _expand_or_validate_scale(b_scale, n)
    if a_scale.device != a_int8.device or a_scale.dtype != torch.float32:
        a_scale = a_scale.to(device=a_int8.device, dtype=torch.float32)
    elif not a_scale.is_contiguous():
        a_scale = a_scale.contiguous()
    if b_scale.device != a_int8.device or b_scale.dtype != torch.float32:
        b_scale = b_scale.to(device=a_int8.device, dtype=torch.float32)
    elif not b_scale.is_contiguous():
        b_scale = b_scale.contiguous()

    out_dtype = out_dtype or torch.bfloat16
    out = torch.empty((m, n), device=a_int8.device, dtype=out_dtype)
    a_int8_c = a_int8 if a_int8.is_contiguous() else a_int8.contiguous()
    b_int8_c = b_int8 if b_int8.is_contiguous() else b_int8.contiguous()
    # NOTE(review): these re-checks look redundant after the normalization
    # above, but `.to()` preserves memory format by default, so an expanded
    # scalar scale may still be non-contiguous here — keep them.
    a_scale_c = a_scale if a_scale.is_contiguous() else a_scale.contiguous()
    b_scale_c = b_scale if b_scale.is_contiguous() else b_scale.contiguous()

    block_m, block_n, block_k, num_warps, num_stages = _select_triton_int8_config(m, k, n, device=a_int8.device, kernel_kind="scaled")
    # One program per (block_m x block_n) output tile.
    grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n))
    _scaled_int8_gemm_kernel[grid](
        a_int8_c,
        b_int8_c,
        a_scale_c,
        b_scale_c,
        out,
        m,
        n,
        k,
        a_int8_c.stride(0),
        a_int8_c.stride(1),
        b_int8_c.stride(0),
        b_int8_c.stride(1),
        out.stride(0),
        out.stride(1),
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return out
diff --git a/Wan2GP/shared/llm_engines/__init__.py b/Wan2GP/shared/llm_engines/__init__.py
new file mode 100644
index 000000000..da404f0f3
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/__init__.py
@@ -0,0 +1 @@
+"""Shared LLM engine helpers."""
diff --git a/Wan2GP/shared/llm_engines/nanovllm/__init__.py b/Wan2GP/shared/llm_engines/nanovllm/__init__.py
new file mode 100644
index 000000000..8ef752613
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/__init__.py
@@ -0,0 +1,11 @@
__all__ = ["LLM", "SamplingParams"]


def __getattr__(name):
    """Resolve the package's public classes lazily so `import` stays cheap."""
    if name == "SamplingParams":
        from .sampling_params import SamplingParams
        return SamplingParams
    if name == "LLM":
        from .llm import LLM
        return LLM
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
diff --git a/Wan2GP/shared/llm_engines/nanovllm/config.py b/Wan2GP/shared/llm_engines/nanovllm/config.py
new file mode 100644
index 000000000..69513e238
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/config.py
@@ -0,0 +1,34 @@
+import os
+from dataclasses import dataclass
+from transformers import AutoConfig
+
+
@dataclass
class Config:
    """Engine configuration; resolves model paths and the HF config on init."""

    # Path to a model directory or to a single `.safetensors` file.
    model: str
    max_num_batched_tokens: int = 16384
    max_num_seqs: int = 512
    max_model_len: int = 4096
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    weight_load_mode: str = "eager" # eager | lazy | pinned
    # Filled by __post_init__ from the resolved model directory.
    hf_config: AutoConfig | None = None
    # EOS token id; -1 until the engine sets it from the tokenizer.
    eos: int = -1
    kvcache_block_size: int = 256
    # -1 means "not yet computed"; set later from available GPU memory.
    num_kvcache_blocks: int = -1
    model_dir: str | None = None
    model_file: str | None = None

    def __post_init__(self):
        # Accept either a single-file checkpoint or a model directory.
        if os.path.isfile(self.model) and self.model.endswith(".safetensors"):
            self.model_file = self.model
            self.model_dir = os.path.dirname(self.model)
        else:
            assert os.path.isdir(self.model)
            self.model_dir = self.model
        assert self.kvcache_block_size % 256 == 0
        assert 1 <= self.tensor_parallel_size <= 8
        self.hf_config = AutoConfig.from_pretrained(self.model_dir)
        # Never exceed the model's trained positional range.
        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
        assert self.max_num_batched_tokens >= self.max_model_len
diff --git a/Wan2GP/shared/llm_engines/nanovllm/engine/block_manager.py b/Wan2GP/shared/llm_engines/nanovllm/engine/block_manager.py
new file mode 100644
index 000000000..4954044a9
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/engine/block_manager.py
@@ -0,0 +1,119 @@
+from collections import deque
+import xxhash
+import numpy as np
+
+from nanovllm.engine.sequence import Sequence
+
+
class Block:
    """One fixed-size KV-cache block tracked by the block manager."""

    def __init__(self, block_id):
        # Index of this block within the manager's block table.
        self.block_id = block_id
        # Tokens cached in this block (kept only once the block is hashed).
        self.token_ids = []
        # Prefix hash of the tokens stored here; -1 means "not yet hashed".
        self.hash = -1
        # Number of sequences currently referencing this block.
        self.ref_count = 0

    def update(self, hash: int, token_ids: list[int]):
        """Record the prefix hash and the tokens this block now holds."""
        self.hash = hash
        self.token_ids = token_ids

    def reset(self):
        """Reinitialize for a fresh allocation owned by a single sequence."""
        self.token_ids = []
        self.hash = -1
        self.ref_count = 1
+
+
class BlockManager:
    """Allocates fixed-size KV-cache blocks with hash-based prefix caching."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        # Prefix hash -> block id, enabling reuse of identical prefixes.
        self.hash_to_block_id: dict[int, int] = dict()
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()

    @classmethod
    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
        """Hash a full block of tokens, chained onto the previous block's hash."""
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids).tobytes())
        return h.intdigest()

    def _allocate_block(self, block_id: int) -> Block:
        """Move a free block into the used set and reset it for one owner."""
        block = self.blocks[block_id]
        assert block.ref_count == 0
        block.reset()
        # deque.remove is O(n); acceptable because allocation usually pops
        # from the front (callers pass free_block_ids[0]).
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
        return self.blocks[block_id]

    def _deallocate_block(self, block_id: int) -> None:
        """Return an unreferenced block to the free pool."""
        assert self.blocks[block_id].ref_count == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def can_allocate(self, seq: Sequence) -> bool:
        """True when enough free blocks exist to hold all of *seq*'s tokens."""
        return len(self.free_block_ids) >= seq.num_blocks

    def allocate(self, seq: Sequence):
        """Build *seq*'s block table, reusing cached prefix blocks when possible."""
        assert not seq.block_table
        h = -1
        cache_miss = False
        for i in range(seq.num_blocks):
            token_ids = seq.block(i)
            # Only full blocks are hashed; a partial tail block gets h == -1.
            h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
            block_id = self.hash_to_block_id.get(h, -1)
            # Token comparison guards against hash collisions.
            if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                cache_miss = True
            # Once one block misses, all following blocks must miss too
            # (their chained prefixes differ), so cache_miss stays sticky.
            if cache_miss:
                block_id = self.free_block_ids[0]
                block = self._allocate_block(block_id)
            else:
                seq.num_cached_tokens += self.block_size
                if block_id in self.used_block_ids:
                    # Cached block still live: share it by bumping the refcount.
                    block = self.blocks[block_id]
                    block.ref_count += 1
                else:
                    # Cached block was freed but its contents are still valid.
                    block = self._allocate_block(block_id)
            if h != -1:
                block.update(h, token_ids)
                self.hash_to_block_id[h] = block_id
            seq.block_table.append(block_id)

    def deallocate(self, seq: Sequence):
        """Release *seq*'s blocks, freeing any that drop to zero references."""
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                # Fix: Clean up hash_to_block_id mapping to prevent stale references
                # This prevents CUDA illegal memory access when prefix cache tries to
                # reuse a block_id that has already been freed
                if block.hash != -1:
                    cached_id = self.hash_to_block_id.get(block.hash)
                    if cached_id == block_id:
                        del self.hash_to_block_id[block.hash]
                self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()

    def can_append(self, seq: Sequence) -> bool:
        """True when appending one token cannot run out of blocks.

        The bool comparison is used as 0/1: a new block is only needed when
        the sequence length is exactly one past a block boundary.
        """
        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

    def may_append(self, seq: Sequence):
        """Extend *seq*'s block table after a token append, hashing full blocks."""
        block_table = seq.block_table
        last_block = self.blocks[block_table[-1]]
        if len(seq) % self.block_size == 1:
            # The previous block just became full (and hashed); start a new one.
            assert last_block.hash != -1
            block_id = self.free_block_ids[0]
            self._allocate_block(block_id)
            block_table.append(block_id)
        elif len(seq) % self.block_size == 0:
            # The last block just filled up: hash it and publish to the cache.
            assert last_block.hash == -1
            token_ids = seq.block(seq.num_blocks-1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = self.compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)
            self.hash_to_block_id[h] = last_block.block_id
        else:
            # Mid-block append: nothing to do, the tail block stays unhashed.
            assert last_block.hash == -1
diff --git a/Wan2GP/shared/llm_engines/nanovllm/engine/llm_engine.py b/Wan2GP/shared/llm_engines/nanovllm/engine/llm_engine.py
new file mode 100644
index 000000000..b48d179be
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/engine/llm_engine.py
@@ -0,0 +1,203 @@
+import atexit
+from dataclasses import fields
+from time import perf_counter
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer
+import torch.multiprocessing as mp
+import torch
+
+from nanovllm.config import Config
+from nanovllm.sampling_params import SamplingParams
+from nanovllm.engine.sequence import Sequence
+from nanovllm.engine.scheduler import Scheduler
+from nanovllm.engine.block_manager import BlockManager
+from nanovllm.engine.model_runner import ModelRunner
+
+
class LLMEngine:
    """Orchestrates tokenization, scheduling, and model execution.

    Spawns one ModelRunner per tensor-parallel rank (ranks > 0 run in
    separate processes) and drives generation step-by-step through the
    Scheduler. Supports classifier-free guidance (CFG) by pairing each
    conditional sequence with an unconditional twin.
    """

    def __init__(self, model, **kwargs):
        # Only forward kwargs that are actual Config fields; extras (e.g.
        # "tokenizer") are consumed below.
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        self.config = config
        self.ps = []
        self.events = []
        # "spawn" is required for CUDA-safe child processes.
        ctx = mp.get_context("spawn")
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)
        # Rank 0 runs in-process and coordinates the worker ranks via events.
        self.model_runner = ModelRunner(config, 0, self.events)
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        self.scheduler = Scheduler(config)
        atexit.register(self.exit)

    def exit(self):
        """Shut down the model runner and join all worker processes."""
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()

    def unload_weights(self):
        """Release model weights and invalidate all KV-cache bookkeeping."""
        self.model_runner.unload_weights()
        # KV cache is invalid after unload/reload, so cached prefix block metadata
        # must be dropped as well to prevent stale-cache reuse.
        try:
            self.reset()
        except Exception:
            pass
        self.scheduler.waiting.clear()
        self.scheduler.running.clear()
        self.scheduler.block_manager = BlockManager(
            self.config.num_kvcache_blocks,
            self.config.kvcache_block_size,
        )

    def clear_graph_cache(self):
        """Drop captured CUDA graphs held by the model runner."""
        self.model_runner.clear_graph_cache()

    def reset_guard_counts(self):
        self.model_runner.call("reset_guard_counts")

    def get_guard_counts(self, reset: bool = False):
        return self.model_runner.call("get_guard_counts", reset)

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
        """Queue one prompt; with cfg_scale > 1 also queues its CFG twin."""
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        # For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
        if sampling_params.cfg_scale > 1.0:
            if unconditional_prompt is None:
                # Try to construct unconditional prompt by replacing user input with "NO USER INPUT"
                # This is a fallback - ideally users should provide unconditional_prompt
                if isinstance(prompt, list):
                    # For now, just use the same prompt (user should provide unconditional_prompt)
                    # TODO: Implement automatic "NO USER INPUT" replacement if possible
                    unconditional_prompt = prompt
                else:
                    unconditional_prompt = prompt
            if isinstance(unconditional_prompt, str):
                unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
            # Create unconditional sequence first (so we can reference it from conditional)
            uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
            # Create conditional sequence with reference to unconditional
            cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
            uncond_seq.paired_seq = cond_seq  # Link them bidirectionally
            # Add both sequences to scheduler
            self.scheduler.add(cond_seq)
            self.scheduler.add(uncond_seq)
        else:
            seq = Sequence(prompt, sampling_params)
            self.scheduler.add(seq)

    def step(self):
        """Run one scheduler step; returns (finished outputs, token count).

        The token count is positive for prefill steps and negative for decode
        steps, which is how generate() distinguishes the two for throughput.
        """
        seqs, is_prefill = self.scheduler.schedule()
        token_ids = self.model_runner.call("run", seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids)
        # Only output conditional sequences (unconditional sequences are just for CFG computation)
        output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
        return outputs, num_tokens

    def is_finished(self):
        """True when no sequences remain waiting or running."""
        return self.scheduler.is_finished()

    def reset(self):
        """
        Reset the scheduler state and release all allocated blocks.
        This should be called when an exception occurs during generation to prevent
        KV cache block leaks that can cause 'deque index out of range' errors.
        """
        # Deallocate all running sequences
        while self.scheduler.running:
            seq = self.scheduler.running.popleft()
            if seq.block_table:  # Only deallocate if blocks are allocated
                self.scheduler.block_manager.deallocate(seq)

        # Deallocate all waiting sequences (they might have blocks from preemption)
        while self.scheduler.waiting:
            seq = self.scheduler.waiting.popleft()
            if seq.block_table:
                self.scheduler.block_manager.deallocate(seq)

    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
        unconditional_prompts: list[str] | list[list[int]] | None = None,
    ) -> list[str]:
        """Generate completions for *prompts*; returns dicts with text and token ids.

        Steps the engine until every request finishes, cleaning up scheduler
        state on exceptions so KV-cache blocks are never leaked.
        """
        # Ensure weights/KV cache are ready for lazy/pinned modes, and sync scheduler blocks.
        self.model_runner.ensure_weights_loaded()
        if (self.config.num_kvcache_blocks > 0 and
            len(self.scheduler.block_manager.blocks) != self.config.num_kvcache_blocks):
            self.scheduler.block_manager = BlockManager(
                self.config.num_kvcache_blocks,
                self.config.kvcache_block_size,
            )
        # Clean up any residual state from previous interrupted generations
        # This prevents 'deque index out of range' errors from accumulated block leaks
        if not self.is_finished():
            self.reset()

        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        # Seed once per request-batch; keeps deterministic decode without per-step overhead.
        seed_to_apply = None
        for sp in sampling_params:
            seed_val = getattr(sp, "seed", None)
            if seed_val is not None:
                try:
                    seed_to_apply = int(seed_val)
                    break
                except Exception:
                    seed_to_apply = None
        if seed_to_apply is not None:
            torch.manual_seed(seed_to_apply)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed_to_apply)
        if unconditional_prompts is None:
            unconditional_prompts = [None] * len(prompts)
        for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
            self.add_request(prompt, sp, uncond_prompt)
        outputs = {}
        prefill_throughput = decode_throughput = 0.
        try:
            while not self.is_finished():
                t = perf_counter()
                output, num_tokens = self.step()
                if use_tqdm:
                    # Positive counts are prefill tokens; negative are decode steps.
                    if num_tokens > 0:
                        prefill_throughput = num_tokens / (perf_counter() - t)
                    else:
                        decode_throughput = -num_tokens / (perf_counter() - t)
                    pbar.set_postfix({
                        "Prefill": f"{int(prefill_throughput)}tok/s",
                        "Decode": f"{int(decode_throughput)}tok/s",
                    })
                for seq_id, token_ids in output:
                    outputs[seq_id] = token_ids
                    if use_tqdm:
                        pbar.update(1)
        except Exception:
            # Clean up on exception to prevent block leaks
            self.reset()
            raise
        finally:
            if use_tqdm:
                pbar.close()

        # Restore submission order (seq ids are assigned monotonically).
        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
        return outputs
diff --git a/Wan2GP/shared/llm_engines/nanovllm/engine/model_runner.py b/Wan2GP/shared/llm_engines/nanovllm/engine/model_runner.py
new file mode 100644
index 000000000..f9015a535
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/engine/model_runner.py
@@ -0,0 +1,814 @@
+import pickle
+import torch
+import torch.distributed as dist
+from multiprocessing.synchronize import Event
+from multiprocessing.shared_memory import SharedMemory
+import sys
+
+from nanovllm.config import Config
+from nanovllm.engine.sequence import Sequence
+from nanovllm.models.qwen3 import Qwen3ForCausalLM
+from nanovllm.layers.sampler import Sampler
+from nanovllm.utils.context import set_context, get_context, reset_context
+from nanovllm.utils.loader import load_model, WeightStore
+
+import socket
+
+
+def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
+ """Find an available port starting from start_port.
+
+ Args:
+ start_port: The starting port number to check
+ max_attempts: Maximum number of ports to try
+
+ Returns:
+ An available port number
+
+ Raises:
+ RuntimeError: If no available port is found within max_attempts
+ """
+ for i in range(max_attempts):
+ port = start_port + i
+ try:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ s.bind(('localhost', port))
+ return port
+ except OSError:
+ # Port is in use, try next one
+ continue
+ raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
+
+
+class ModelRunner:
+
+ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
+ # Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
+ torch._dynamo.config.capture_scalar_outputs = True
+
+ self.config = config
+ hf_config = config.hf_config
+ self.block_size = config.kvcache_block_size
+ self.enforce_eager = config.enforce_eager
+ self.world_size = config.tensor_parallel_size
+ self.rank = rank
+ self.event = event
+ if self.world_size > 1:
+ dist_port = find_available_port()
+ print(f"[debug]dist_port: {dist_port}")
+ # Use gloo backend on Windows, nccl on Linux/other platforms
+ backend = "gloo" if sys.platform == "win32" else "nccl"
+ dist.init_process_group(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
+ torch.cuda.set_device(rank)
+ else:
+ if torch.cuda.is_available():
+ torch.cuda.set_device(0)
+ default_dtype = torch.get_default_dtype()
+ # Use dtype instead of deprecated torch_dtype
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', None))
+
+ # Validate and convert config_dtype to a valid torch floating-point dtype
+ # Default to bfloat16 for CUDA (required for Flash Attention 2)
+ if config_dtype is None:
+ config_dtype = torch.bfloat16
+ elif isinstance(config_dtype, str):
+ # Convert string dtype to torch dtype
+ dtype_map = {
+ 'float32': torch.float32,
+ 'float16': torch.float16,
+ 'bfloat16': torch.bfloat16,
+ 'float64': torch.float64,
+ 'torch.float32': torch.float32,
+ 'torch.float16': torch.float16,
+ 'torch.bfloat16': torch.bfloat16,
+ 'torch.float64': torch.float64,
+ }
+ config_dtype = dtype_map.get(config_dtype.lower(), torch.bfloat16)
+ elif not isinstance(config_dtype, torch.dtype) or not config_dtype.is_floating_point:
+ # If not a valid floating-point torch dtype, default to bfloat16
+ config_dtype = torch.bfloat16
+
+ self.dtype = config_dtype # Save for later use
+ self.weight_load_mode = (config.weight_load_mode or "eager").lower()
+ self._weights_loaded = False
+ self._weight_store = None
+ self._is_quanto_int8 = False
+ self._graph_cache = {}
+ self._graph_cache_order = []
+ self._logits_bias_cache = {}
+ self._guard_counts = {}
+ self._guard_seen_details = set()
+ torch.set_default_dtype(config_dtype)
+ if self.weight_load_mode in ("lazy", "pinned"):
+ torch.set_default_device("cpu")
+ self.model = Qwen3ForCausalLM(hf_config)
+ self._weight_store = WeightStore(config.model_file or config.model_dir, mode=self.weight_load_mode)
+ self._is_quanto_int8 = bool(getattr(self._weight_store, "is_quanto_int8", False))
+ else:
+ torch.set_default_device("cuda")
+ self.model = Qwen3ForCausalLM(hf_config)
+ load_model(self.model, config.model_file or config.model_dir)
+ self._retie_word_embeddings_if_needed()
+ self._weights_loaded = True
+ self.sampler = Sampler()
+
+ # Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
+ # Must be called before warmup_model() since it uses these buffers
+ self._allocate_sample_buffers()
+
+ if self._weights_loaded:
+ self.warmup_model()
+ self.allocate_kv_cache()
+ if not self.enforce_eager:
+ self.capture_cudagraph()
+
+ torch.set_default_device("cpu")
+ torch.set_default_dtype(default_dtype)
+
+ if self.world_size > 1:
+ if rank == 0:
+ self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
+ dist.barrier()
+ else:
+ dist.barrier()
+ self.shm = SharedMemory(name="nanovllm")
+ self.loop()
+
+ def ensure_weights_loaded(self):
+ if self._weights_loaded:
+ return
+ default_dtype = torch.get_default_dtype()
+ torch.set_default_dtype(self.dtype)
+ torch.set_default_device("cuda")
+ if self._is_quanto_int8:
+ for module in self.model.modules():
+ prepare = getattr(module, "prepare_for_quantized_load", None)
+ if callable(prepare):
+ prepare()
+ self.model = self.model.to("cuda")
+ load_model(self.model, "", weight_store=self._weight_store)
+ self._retie_word_embeddings_if_needed()
+ self._weights_loaded = True
+ self.warmup_model()
+ self.allocate_kv_cache()
+ if not self.enforce_eager:
+ self.capture_cudagraph()
+ torch.set_default_device("cpu")
+ torch.set_default_dtype(default_dtype)
+
+ def _retie_word_embeddings_if_needed(self):
+ # Some quantized checkpoints omit lm_head.weight and rely on tied embeddings.
+ # After device moves/load cycles, the tie can be broken; restore it explicitly.
+ try:
+ lm_head = getattr(self.model, "lm_head", None)
+ embed = getattr(self.model, "embed_tokens", None)
+ if lm_head is None or embed is None:
+ return
+ lm_w = getattr(lm_head, "weight", None)
+ emb_w = getattr(embed, "weight", None)
+ if lm_w is None or emb_w is None:
+ return
+ if lm_w.shape != emb_w.shape:
+ return
+ if lm_w.data_ptr() != emb_w.data_ptr():
+ lm_head.weight.data = emb_w.data
+ except Exception:
+ return
+
+ def unload_weights(self):
+ if not self._weights_loaded:
+ return
+ try:
+ self.model = self.model.to("cpu")
+ except Exception:
+ pass
+ # Clear attention KV cache refs so we don't write into freed storage later.
+ try:
+ for module in self.model.modules():
+ if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
+ module.k_cache = module.v_cache = torch.tensor([])
+ except Exception:
+ pass
+ if hasattr(self, "kv_cache"):
+ try:
+ del self.kv_cache
+ except Exception:
+ pass
+ # CUDA graphs captured against previous weight/KV pointers are unsafe after unload/reload.
+ # Force recapture on next load to avoid stale-pointer illegal memory access.
+ try:
+ self.clear_graph_cache()
+ except Exception:
+ pass
+ try:
+ self.graphs = {}
+ self.graph_vars = {}
+ self.graph_bs = []
+ self.graph_pool = None
+ except Exception:
+ pass
+ try:
+ torch.cuda.empty_cache()
+ except Exception:
+ pass
+ self._logits_bias_cache.clear()
+ self._weights_loaded = False
+
+ def _get_graph_capture_signature(self):
+ model_ptr = -1
+ kv_ptr = -1
+ try:
+ first_param = next(self.model.parameters())
+ if first_param.is_cuda:
+ model_ptr = int(first_param.data_ptr())
+ except Exception:
+ pass
+ try:
+ if hasattr(self, "kv_cache") and torch.is_tensor(self.kv_cache) and self.kv_cache.is_cuda:
+ kv_ptr = int(self.kv_cache.data_ptr())
+ except Exception:
+ pass
+ return (model_ptr, kv_ptr, int(self.config.max_model_len), int(self.config.max_num_seqs))
+
+ def _drop_graph_cache_entry(self, cache_key):
+ entry = self._graph_cache.pop(cache_key, None)
+ if cache_key in self._graph_cache_order:
+ self._graph_cache_order.remove(cache_key)
+ if entry is None:
+ return
+ try:
+ del entry["graphs"]
+ del entry["pool"]
+ del entry["vars"]
+ del entry["bs"]
+ except Exception:
+ pass
+
+ def clear_graph_cache(self):
+ if self._graph_cache:
+ for key in list(self._graph_cache.keys()):
+ self._drop_graph_cache_entry(key)
+ self._graph_cache.clear()
+ self._graph_cache_order.clear()
+
+ def _note_guard(self, name: str, detail: str | None = None):
+ count = self._guard_counts.get(name, 0) + 1
+ self._guard_counts[name] = count
+ if detail:
+ detail_key = (name, detail)
+ if detail_key not in self._guard_seen_details:
+ print(f"[nanovllm][guard] {name}: {detail}")
+ self._guard_seen_details.add(detail_key)
+ return
+ if count == 1:
+ print(f"[nanovllm][guard] {name}")
+
+ def reset_guard_counts(self):
+ self._guard_counts.clear()
+ self._guard_seen_details.clear()
+
+ def get_guard_counts(self, reset: bool = False):
+ counts = dict(self._guard_counts)
+ if reset:
+ self.reset_guard_counts()
+ return counts
+
+ def _get_logits_bias(self, seq: Sequence, logits: torch.Tensor):
+ bias = getattr(seq, "logits_bias", None)
+ if bias is None or not torch.is_tensor(bias):
+ return None
+ key = (id(bias), logits.device, logits.dtype)
+ cached = self._logits_bias_cache.get(key)
+ if cached is not None:
+ return cached
+ cached = bias.to(device=logits.device, dtype=logits.dtype)
+ self._logits_bias_cache[key] = cached
+ return cached
+
+ @staticmethod
+ def _apply_logits_bias(logits_row: torch.Tensor, bias: torch.Tensor):
+ logits_row.add_(bias)
+
+ def _allocate_sample_buffers(self):
+ """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
+ max_bs = self.config.max_num_seqs
+ max_tokens = self.config.max_num_batched_tokens
+ max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
+
+ # Pre-allocate pinned memory buffers on CPU for fast transfer
+ # Must explicitly specify device="cpu" since default device may be "cuda"
+ self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
+ self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
+ self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
+ self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
+ self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
+
+ # Pre-allocate decode buffers on CPU with pinned memory
+ self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
+ self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
+ self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
+ self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
+
+ # Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
+ self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
+ self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
+ self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
+ self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
+
+ # Pre-allocate block tables buffer (shared by both decode and prefill)
+ self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
+
+ # Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
+ # Max length is max_model_len since sequences can be that long
+ self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
+
+ def exit(self):
+ if self.world_size > 1:
+ self.shm.close()
+ dist.barrier()
+ if self.rank == 0:
+ self.shm.unlink()
+ if not self.enforce_eager:
+ if hasattr(self, "graphs"):
+ del self.graphs
+ if hasattr(self, "graph_pool"):
+ del self.graph_pool
+ try:
+ torch.cuda.synchronize()
+ except Exception:
+ pass
+ if dist.is_initialized():
+ dist.destroy_process_group()
+
+ def loop(self):
+ while True:
+ method_name, args = self.read_shm()
+ self.call(method_name, *args)
+ if method_name == "exit":
+ break
+
+ def read_shm(self):
+ assert self.world_size > 1 and self.rank > 0
+ self.event.wait()
+ n = int.from_bytes(self.shm.buf[0:4], "little")
+ method_name, *args = pickle.loads(self.shm.buf[4:n+4])
+ self.event.clear()
+ return method_name, args
+
+ def write_shm(self, method_name, *args):
+ assert self.world_size > 1 and self.rank == 0
+ data = pickle.dumps([method_name, *args])
+ n = len(data)
+ self.shm.buf[0:4] = n.to_bytes(4, "little")
+ self.shm.buf[4:n+4] = data
+ for event in self.event:
+ event.set()
+
+ def call(self, method_name, *args):
+ if self.world_size > 1 and self.rank == 0:
+ self.write_shm(method_name, *args)
+ method = getattr(self, method_name, None)
+ return method(*args)
+
+ def warmup_model(self):
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+ num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+ seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+ self.run(seqs, True)
+ torch.cuda.empty_cache()
+
+ def allocate_kv_cache(self):
+ config = self.config
+ hf_config = config.hf_config
+ free, total = torch.cuda.mem_get_info()
+ current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+ num_kv_heads = hf_config.num_key_value_heads // self.world_size
+ head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
+ block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * self.dtype.itemsize
+
+ # Calculate available memory for KV cache
+ # After warmup_model, empty_cache has been called, so current represents model memory only
+ # Use free memory but respect the gpu_memory_utilization limit
+ target_total_usage = total * config.gpu_memory_utilization
+ available_for_kv_cache = min(free * 0.9, target_total_usage - current)
+
+ # Ensure we have positive memory available
+ if available_for_kv_cache <= 0:
+ available_for_kv_cache = free * 0.5 # Fallback to 50% of free memory
+
+ config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
+ # Cap KV cache blocks to what is required by max_model_len and max_num_seqs.
+ # This keeps VRAM usage proportional to the requested token budget (incl. CFG).
+ required_blocks_per_seq = (config.max_model_len + self.block_size - 1) // self.block_size
+ required_total_blocks = required_blocks_per_seq * max(1, config.max_num_seqs)
+ if required_total_blocks > 0:
+ config.num_kvcache_blocks = min(config.num_kvcache_blocks, required_total_blocks)
+ if config.num_kvcache_blocks <= 0:
+ raise RuntimeError(
+ f"Insufficient GPU memory for KV cache. "
+ f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
+ f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
+ f"Block size: {block_bytes / 1024**2:.2f} MB"
+ )
+ self.kv_cache = torch.empty(
+ 2,
+ hf_config.num_hidden_layers,
+ config.num_kvcache_blocks,
+ self.block_size,
+ num_kv_heads,
+ head_dim,
+ device="cuda",
+ dtype=self.dtype,
+ )
+ layer_id = 0
+ for module in self.model.modules():
+ if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
+ module.k_cache = self.kv_cache[0, layer_id]
+ module.v_cache = self.kv_cache[1, layer_id]
+ layer_id += 1
+
+ def prepare_block_tables(self, seqs: list[Sequence]):
+ max_len = max(len(seq.block_table) for seq in seqs)
+ block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
+ block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+ return block_tables
+
+ def prepare_prefill(self, seqs: list[Sequence]):
+ input_ids = []
+ positions = []
+ cu_seqlens_q = [0]
+ cu_seqlens_k = [0]
+ max_seqlen_q = 0
+ max_seqlen_k = 0
+ slot_mapping = []
+ block_tables = None
+ for seq in seqs:
+ seqlen = len(seq)
+ input_ids.extend(seq[seq.num_cached_tokens:])
+ positions.extend(list(range(seq.num_cached_tokens, seqlen)))
+ seqlen_q = seqlen - seq.num_cached_tokens
+ seqlen_k = seqlen
+ cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
+ cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
+ max_seqlen_q = max(seqlen_q, max_seqlen_q)
+ max_seqlen_k = max(seqlen_k, max_seqlen_k)
+ if not seq.block_table: # warmup: no blocks allocated yet
+ slot_mapping.extend([-1] * seqlen_q)
+ continue
+ for i in range(seq.num_cached_blocks, seq.num_blocks):
+ start = seq.block_table[i] * self.block_size
+ if i != seq.num_blocks - 1:
+ end = start + self.block_size
+ else:
+ end = start + seq.last_block_num_tokens
+ slot_mapping.extend(list(range(start, end)))
+ if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
+ block_tables = self.prepare_block_tables(seqs)
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+ positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+ cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+ cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+ set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
+ return input_ids, positions
+
+ def prepare_decode(self, seqs: list[Sequence]):
+ """Optimized decode preparation using pre-allocated buffers."""
+ bs = len(seqs)
+
+ # Use pre-allocated CPU buffers
+ for i, seq in enumerate(seqs):
+ self._cpu_input_ids[i] = seq.last_token
+ self._cpu_positions[i] = len(seq) - 1
+ self._cpu_context_lens[i] = len(seq)
+ self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
+
+ # Transfer to GPU using sliced views
+ input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
+ positions = self._cpu_positions[:bs].cuda(non_blocking=True)
+ slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
+ context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
+ block_tables = self.prepare_block_tables(seqs)
+ set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
+ return input_ids, positions
+
+ def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
+ """Optimized sample preparation using pre-allocated buffers."""
+ if is_cfg_batch:
+ num_seqs = len(seqs) // 2
+ target_seqs = seqs[:num_seqs]
+ else:
+ num_seqs = len(seqs)
+ target_seqs = seqs
+
+ # Fill pre-allocated CPU buffers
+ top_ks_is_zero = True
+ top_ps_is_one = True
+ repetition_penalties_is_one = True
+ for i, seq in enumerate(target_seqs):
+ self._cpu_temperatures[i] = seq.temperature
+ self._cpu_cfg_scales[i] = seq.cfg_scale
+ self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
+ if seq.top_k is not None and seq.top_k > 0:
+ top_ks_is_zero = False
+ self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
+ if seq.top_p is not None and seq.top_p != 1.0:
+ top_ps_is_one = False
+ self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
+ if seq.repetition_penalty is not None and seq.repetition_penalty != 1.0:
+ repetition_penalties_is_one = False
+
+ # Transfer to GPU using sliced views (single batched transfer)
+ temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
+ cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
+ top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
+ top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
+ repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
+
+ return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
+
+ @torch.inference_mode()
+ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
+ if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
+ return self.model.compute_logits(self.model(input_ids, positions))
+ else:
+ bs = input_ids.size(0)
+ context = get_context()
+
+ # Check if block_tables size exceeds pre-allocated buffer size
+ # This can happen when conditional and unconditional sequences have different lengths
+ # in CFG mode, causing block_tables to have more columns than expected
+ max_num_blocks = self.graph_vars["block_tables"].size(1)
+ if context.block_tables.size(1) > max_num_blocks:
+ # Fall back to eager mode when block_tables is too large for CUDA graph
+ self._note_guard(
+ "cudagraph_fallback_block_table_cols",
+ f"requested={context.block_tables.size(1)} max={max_num_blocks}",
+ )
+ return self.model.compute_logits(self.model(input_ids, positions))
+
+ # Fix: Also check if block_tables row count matches batch size
+ # Dimension mismatch can cause CUDA illegal memory access during graph replay
+ if context.block_tables.size(0) != bs:
+ # Fall back to eager mode when block_tables row count doesn't match batch size
+ self._note_guard(
+ "cudagraph_fallback_block_table_rows",
+ f"rows={context.block_tables.size(0)} bs={bs}",
+ )
+ return self.model.compute_logits(self.model(input_ids, positions))
+
+ # Fix: Verify slot_mapping and context_lens dimensions match batch size
+ if context.slot_mapping.size(0) != bs or context.context_lens.size(0) != bs:
+ # Fall back to eager mode when dimensions don't match
+ self._note_guard(
+ "cudagraph_fallback_context_shape",
+ f"slot={context.slot_mapping.size(0)} ctx={context.context_lens.size(0)} bs={bs}",
+ )
+ return self.model.compute_logits(self.model(input_ids, positions))
+
+ graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
+ graph_vars = self.graph_vars
+ graph_vars["input_ids"][:bs] = input_ids
+ graph_vars["positions"][:bs] = positions
+ graph_vars["slot_mapping"].fill_(-1)
+ graph_vars["slot_mapping"][:bs] = context.slot_mapping
+ graph_vars["context_lens"].zero_()
+ graph_vars["context_lens"][:bs] = context.context_lens
+ # Clear block_tables first to ensure no stale data from previous runs
+ graph_vars["block_tables"][:bs].fill_(-1)
+ graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
+ graph.replay()
+ return self.model.compute_logits(graph_vars["outputs"][:bs])
+
+ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
+ """Run model forward and sampling. For CFG sequences, batch is structured as:
+ [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
+ where uncond_seqi is the paired unconditional sequence of cond_seqi."""
+ self.ensure_weights_loaded()
+ # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
+ is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
+ if is_cfg_batch:
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
+ num_cond = len(seqs) // 2
+ cond_seqs = seqs[:num_cond]
+ # uncond_seqs = seqs[num_cond:]
+
+ # Prepare inputs for both conditional and unconditional (they're already in the batch)
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
+ if sample_params is not None:
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
+ else:
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
+
+ # Run model forward (processes entire batch: cond + uncond)
+ logits_all = self.run_model(input_ids, positions, is_prefill)
+ reset_context()
+
+ if self.rank == 0:
+ # Split logits: first half is conditional, second half is unconditional
+ logits_cond = logits_all[:num_cond]
+ logits_uncond = logits_all[num_cond:]
+
+ # Apply repetition penalty to conditional logits (before CFG)
+ if repetition_penalties is not None:
+ for i, seq in enumerate(cond_seqs):
+ penalty = repetition_penalties[i].item()
+ if penalty != 1.0:
+ # Only penalize completion tokens (not prompt tokens)
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits_cond.device)
+ if len(completion_tokens) > 0:
+ # Create token mask: mark tokens that appeared in completion
+ token_mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
+ token_mask[completion_tokens] = True
+
+ # Apply standard repetition penalty formula (matching transformers implementation):
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
+ penalty_scores = torch.where(
+ logits_cond[i] < 0,
+ logits_cond[i] * penalty,
+ logits_cond[i] / penalty
+ )
+ # Only apply penalty to tokens that appeared in completion
+ logits_cond[i] = torch.where(token_mask, penalty_scores, logits_cond[i])
+
+ # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
+ cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
+ logits_cfg = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
+
+ # Apply optional per-sequence logits bias before processors/sampling.
+ for i, seq in enumerate(cond_seqs):
+ bias = self._get_logits_bias(seq, logits_cfg)
+ if bias is not None:
+ self._apply_logits_bias(logits_cfg[i], bias)
+
+ # Apply logits processor for constrained decoding (if any sequence has one)
+ for i, seq in enumerate(cond_seqs):
+ if seq.logits_processor is not None:
+ # Create input_ids tensor for this sequence
+ seq_input_ids = torch.tensor([seq.token_ids], device=logits_cfg.device)
+ # Apply processor to this sequence's logits
+ logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
+
+ # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
+ # cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
+
+ # Sample from CFG logits
+ token_ids_cfg = self.sampler(
+ logits_cfg,
+ temperatures,
+ top_ks=top_ks if top_ks is not None else None,
+ top_ps=top_ps if top_ps is not None else None,
+ repetition_penalties=None, # Already applied above
+ # input_ids=cond_input_ids,
+ ).tolist()
+
+ # Update logits processor state after sampling
+ # NOTE: Only update for the first sequence since all sequences share the same processor
+ # Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
+ if cond_seqs and cond_seqs[0].logits_processor_update_state is not None:
+ cond_seqs[0].logits_processor_update_state(token_ids_cfg[0])
+
+ # Return token_ids (will be applied to both conditional and unconditional sequences)
+ return token_ids_cfg
+ else:
+ return None
+ else:
+ # Normal batch (non-CFG)
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
+ else self.prepare_decode(seqs))
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else None
+ if sample_params is not None:
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
+ else:
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
+ logits = self.run_model(input_ids, positions, is_prefill)
+ reset_context()
+
+ if self.rank == 0:
+ # Apply repetition penalty to logits
+ if repetition_penalties is not None:
+ for i, seq in enumerate(seqs):
+ penalty = repetition_penalties[i].item()
+ if penalty != 1.0:
+ # Only penalize completion tokens (not prompt tokens)
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
+ if len(completion_tokens) > 0:
+ # Create token mask: mark tokens that appeared in completion
+ token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
+ token_mask[completion_tokens] = True
+
+ # Apply standard repetition penalty formula (matching transformers implementation):
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
+ penalty_scores = torch.where(
+ logits[i] < 0,
+ logits[i] * penalty,
+ logits[i] / penalty
+ )
+ # Only apply penalty to tokens that appeared in completion
+ logits[i] = torch.where(token_mask, penalty_scores, logits[i])
+
+ # Apply logits processor for constrained decoding (if any sequence has one)
+ # Clone logits to avoid in-place update issues in inference mode
+ logits = logits.clone()
+ for i, seq in enumerate(seqs):
+ bias = self._get_logits_bias(seq, logits)
+ if bias is not None:
+ self._apply_logits_bias(logits[i], bias)
+ for i, seq in enumerate(seqs):
+ if seq.logits_processor is not None:
+ # Create input_ids tensor for this sequence
+ seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
+ # Apply processor to this sequence's logits (clone to avoid inference mode issues)
+ processed = seq.logits_processor(seq_input_ids, logits[i:i+1].clone())
+ logits[i] = processed[0]
+
+ # Prepare input_ids for sampler
+ # seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
+
+ token_ids = self.sampler(
+ logits,
+ temperatures,
+ top_ks=top_ks if top_ks is not None else None,
+ top_ps=top_ps if top_ps is not None else None,
+ repetition_penalties=None, # Already applied above
+ # input_ids=seq_input_ids,
+ ).tolist()
+
+ # Update logits processor state after sampling
+ # NOTE: Only update for the first sequence since all sequences may share the same processor
+ # (when using a single SamplingParams for batch generation)
+ # Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
+ if seqs and seqs[0].logits_processor_update_state is not None:
+ seqs[0].logits_processor_update_state(token_ids[0])
+
+ return token_ids
+ else:
+ return None
+
+ @torch.inference_mode()
+ def capture_cudagraph(self):
+ config = self.config
+ cache_key = (config.max_model_len, config.max_num_seqs)
+ cached = self._graph_cache.get(cache_key)
+ if cached is not None:
+ current_sig = self._get_graph_capture_signature()
+ if cached.get("sig") == current_sig:
+ self.graphs = cached["graphs"]
+ self.graph_pool = cached["pool"]
+ self.graph_vars = cached["vars"]
+ self.graph_bs = cached["bs"]
+ if cache_key in self._graph_cache_order:
+ self._graph_cache_order.remove(cache_key)
+ self._graph_cache_order.append(cache_key)
+ return
+ self._drop_graph_cache_entry(cache_key)
+ hf_config = config.hf_config
+ max_bs = min(self.config.max_num_seqs, 512)
+ max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
+ input_ids = torch.zeros(max_bs, dtype=torch.int64)
+ positions = torch.zeros(max_bs, dtype=torch.int64)
+ slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
+ context_lens = torch.zeros(max_bs, dtype=torch.int32)
+ block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
+ outputs = torch.zeros(max_bs, hf_config.hidden_size)
+ self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
+ self.graphs = {}
+ self.graph_pool = None
+
+ for bs in reversed(self.graph_bs):
+ graph = torch.cuda.CUDAGraph()
+ set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
+ with torch.cuda.graph(graph, self.graph_pool):
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
+ if self.graph_pool is None:
+ self.graph_pool = graph.pool()
+ self.graphs[bs] = graph
+ torch.cuda.synchronize()
+ reset_context()
+
+ self.graph_vars = dict(
+ input_ids=input_ids,
+ positions=positions,
+ slot_mapping=slot_mapping,
+ context_lens=context_lens,
+ block_tables=block_tables,
+ outputs=outputs,
+ )
+ self._graph_cache[cache_key] = {
+ "graphs": self.graphs,
+ "pool": self.graph_pool,
+ "vars": self.graph_vars,
+ "bs": self.graph_bs,
+ "sig": self._get_graph_capture_signature(),
+ }
+ if cache_key in self._graph_cache_order:
+ self._graph_cache_order.remove(cache_key)
+ self._graph_cache_order.append(cache_key)
+ while len(self._graph_cache_order) > 5:
+ old_key = self._graph_cache_order.pop(0)
+ self._drop_graph_cache_entry(old_key)
diff --git a/Wan2GP/shared/llm_engines/nanovllm/engine/scheduler.py b/Wan2GP/shared/llm_engines/nanovllm/engine/scheduler.py
new file mode 100644
index 000000000..3178c2216
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/engine/scheduler.py
@@ -0,0 +1,230 @@
+from collections import deque
+
+from nanovllm.config import Config
+from nanovllm.engine.sequence import Sequence, SequenceStatus
+from nanovllm.engine.block_manager import BlockManager
+
+
+class Scheduler:
+    """FCFS scheduler over a paged KV cache, with CFG pair handling.
+
+    Keeps two queues: `waiting` (sequences awaiting prefill) and `running`
+    (sequences in decode).  Sequences created with cfg_scale > 1.0 come in
+    conditional/unconditional pairs linked via `paired_seq`; a pair must be
+    scheduled, decoded, and finished together.
+    """
+
+    def __init__(self, config: Config):
+        # Batch limits, the EOS token id, and KV block accounting all come
+        # from the engine Config.
+        self.max_num_seqs = config.max_num_seqs
+        self.max_num_batched_tokens = config.max_num_batched_tokens
+        self.eos = config.eos
+        self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
+        self.waiting: deque[Sequence] = deque()
+        self.running: deque[Sequence] = deque()
+
+    def is_finished(self):
+        """Return True when no sequence is waiting or running."""
+        return not self.waiting and not self.running
+
+    def add(self, seq: Sequence):
+        """Enqueue a new sequence for prefill (FIFO order)."""
+        self.waiting.append(seq)
+
+    def schedule(self) -> tuple[list[Sequence], bool]:
+        """Assemble the next batch; return (scheduled_seqs, is_prefill).
+
+        Prefill takes priority: if any waiting sequence fits within the
+        token budget and free KV blocks, a prefill batch is returned.
+        Otherwise a decode batch is built from `running`, preempting
+        sequences when blocks run out.  In both modes the batch is
+        reordered so non-CFG sequences come first, then CFG conditionals,
+        then their paired unconditionals (postprocess relies on this
+        layout).
+        """
+        # prefill
+        scheduled_seqs = []
+        num_seqs = 0
+        num_batched_tokens = 0
+        processed_seqs = set()  # Track processed sequences to handle CFG pairs
+
+        while self.waiting and num_seqs < self.max_num_seqs:
+            seq = self.waiting[0]
+
+            # For CFG sequences, ensure conditional and unconditional are scheduled together
+            if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
+                # This is a conditional sequence, need to schedule its paired unconditional sequence too
+                paired_seq = seq.paired_seq
+                if paired_seq.status != SequenceStatus.WAITING:
+                    # Paired sequence not in waiting, skip this conditional sequence for now
+                    break
+
+                # Calculate tokens for both sequences
+                total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
+
+                # FIX: Check if we have enough blocks for BOTH sequences combined
+                # The old check was wrong: it checked each sequence independently,
+                # but didn't account for the total blocks needed by both
+                total_blocks_needed = seq.num_blocks + paired_seq.num_blocks
+                can_allocate_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
+
+                if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
+                    break
+
+                # Schedule both sequences: conditional first, then unconditional
+                for s in [seq, paired_seq]:
+                    num_seqs += 1
+                    self.block_manager.allocate(s)
+                    num_batched_tokens += len(s) - s.num_cached_tokens
+                    s.status = SequenceStatus.RUNNING
+                    self.waiting.remove(s)
+                    self.running.append(s)
+                    scheduled_seqs.append(s)
+                    processed_seqs.add(s.seq_id)
+            else:
+                # Normal sequence or unconditional sequence (already processed with its conditional)
+                if seq.seq_id in processed_seqs:
+                    # Skip if already processed as part of a CFG pair
+                    self.waiting.popleft()
+                    continue
+
+                # NOTE(review): this budget check uses len(seq) rather than
+                # len(seq) - num_cached_tokens as the CFG branch does; it is
+                # the stricter (conservative) bound -- confirm intended.
+                if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
+                    break
+                num_seqs += 1
+                self.block_manager.allocate(seq)
+                num_batched_tokens += len(seq) - seq.num_cached_tokens
+                seq.status = SequenceStatus.RUNNING
+                self.waiting.popleft()
+                self.running.append(seq)
+                scheduled_seqs.append(seq)
+
+        if scheduled_seqs:
+            # For CFG batches, ensure conditional sequences come before their unconditional pairs
+            cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
+            cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
+            non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
+
+            # Reorder: non-CFG, then CFG conditional, then CFG unconditional
+            scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
+            return scheduled_seqs, True
+
+        # decode
+        processed_seqs = set()
+        temp_running = list(self.running)  # Work with a copy
+
+        while temp_running and num_seqs < self.max_num_seqs:
+            seq = temp_running.pop(0)
+
+            # For CFG sequences, ensure conditional and unconditional are scheduled together
+            if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
+                paired_seq = seq.paired_seq
+                if paired_seq not in temp_running:
+                    # Paired sequence not available, skip for now
+                    continue
+
+                # Remove paired_seq from temp_running
+                temp_running.remove(paired_seq)
+
+                # FIX: Check if we have enough blocks for BOTH sequences to append
+                # Each sequence needs 1 block when at block boundary (len % block_size == 1)
+                block_size = self.block_manager.block_size
+                blocks_needed_seq = 1 if len(seq) % block_size == 1 else 0
+                blocks_needed_paired = 1 if len(paired_seq) % block_size == 1 else 0
+                total_blocks_needed = blocks_needed_seq + blocks_needed_paired
+                can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
+
+                if not can_append_both:
+                    # Try preempting other sequences
+                    # NOTE(review): `preempted` is assigned but never read below.
+                    preempted = False
+                    while not can_append_both and temp_running:
+                        other_seq = temp_running.pop(0)
+                        if other_seq != seq and other_seq != paired_seq:
+                            self.preempt(other_seq)
+                            # Recalculate with the same correct logic
+                            can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
+                            preempted = True
+                        else:
+                            temp_running.append(other_seq)
+                            break
+
+                    if not can_append_both:
+                        # Can't schedule this pair right now
+                        temp_running.append(seq)
+                        temp_running.append(paired_seq)
+                        continue
+
+                # Schedule both sequences
+                for s in [seq, paired_seq]:
+                    num_seqs += 1
+                    self.block_manager.may_append(s)
+                    scheduled_seqs.append(s)
+                    processed_seqs.add(s.seq_id)
+                    # Remove from actual running list if scheduled
+                    if s in self.running:
+                        self.running.remove(s)
+            else:
+                # Normal sequence or unconditional (already processed)
+                if seq.seq_id in processed_seqs:
+                    continue
+
+                # Preempt other sequences (or, as a last resort, this one)
+                # until the block manager can accept another token.
+                while not self.block_manager.can_append(seq):
+                    if temp_running:
+                        other_seq = temp_running.pop(0)
+                        if other_seq != seq:
+                            self.preempt(other_seq)
+                        else:
+                            temp_running.append(other_seq)
+                            break
+                    else:
+                        self.preempt(seq)
+                        if seq in self.running:
+                            self.running.remove(seq)
+                        break
+                else:
+                    # while-else: reached only when the loop above exited
+                    # without break, i.e. a block is available for `seq`.
+                    num_seqs += 1
+                    self.block_manager.may_append(seq)
+                    scheduled_seqs.append(seq)
+                    if seq in self.running:
+                        self.running.remove(seq)
+
+        assert scheduled_seqs
+
+        # For CFG batches in decode, ensure conditional sequences come before unconditional
+        cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
+        cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
+        non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
+        scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
+
+        # Re-insert the scheduled sequences at the front of `running`,
+        # preserving batch order (extendleft reverses, hence reversed()).
+        self.running.extendleft(reversed(scheduled_seqs))
+        return scheduled_seqs, False
+
+    def preempt(self, seq: Sequence):
+        """Evict a sequence: free its KV blocks and requeue it at the head
+        of `waiting` so it is re-prefilled before newer requests."""
+        seq.status = SequenceStatus.WAITING
+        self.block_manager.deallocate(seq)
+        self.waiting.appendleft(seq)
+
+    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
+        """Append sampled tokens and retire finished sequences.
+
+        For a CFG batch (layout [cond..., uncond...] as built by
+        schedule()), token_ids holds one token per conditional sequence;
+        the same token is mirrored onto the paired unconditional sequence
+        and both halves finish together.  NOTE(review): despite the
+        annotation no list is returned -- confirm callers ignore the result.
+        """
+        # Check if this is a CFG batch
+        is_cfg_batch = False
+        if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
+            num_cond = len(seqs) // 2
+            is_cfg_batch = (num_cond > 0 and
+                           not seqs[0].is_unconditional and
+                           seqs[num_cond].is_unconditional)
+
+        if is_cfg_batch:
+            # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
+            # token_ids correspond to conditional sequences only (sampled from CFG logits)
+            num_cond = len(seqs) // 2
+            cond_seqs = seqs[:num_cond]
+            uncond_seqs = seqs[num_cond:]
+
+            # Apply the same sampled token to both conditional and unconditional sequences
+            # NOTE(review): loop index `i` is unused.
+            for i, (cond_seq, uncond_seq, token_id) in enumerate(zip(cond_seqs, uncond_seqs, token_ids)):
+                cond_seq.append_token(token_id)
+                uncond_seq.append_token(token_id)  # Same token for unconditional
+
+                # Check if either sequence is finished
+                cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
+                                cond_seq.num_completion_tokens == cond_seq.max_tokens)
+                uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
+                                  uncond_seq.num_completion_tokens == uncond_seq.max_tokens)
+
+                if cond_finished or uncond_finished:
+                    # Mark both as finished
+                    cond_seq.status = SequenceStatus.FINISHED
+                    uncond_seq.status = SequenceStatus.FINISHED
+                    self.block_manager.deallocate(cond_seq)
+                    self.block_manager.deallocate(uncond_seq)
+                    if cond_seq in self.running:
+                        self.running.remove(cond_seq)
+                    if uncond_seq in self.running:
+                        self.running.remove(uncond_seq)
+        else:
+            # Normal batch
+            for seq, token_id in zip(seqs, token_ids):
+                seq.append_token(token_id)
+                if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
+                    seq.status = SequenceStatus.FINISHED
+                    self.block_manager.deallocate(seq)
+                    self.running.remove(seq)
diff --git a/Wan2GP/shared/llm_engines/nanovllm/engine/sequence.py b/Wan2GP/shared/llm_engines/nanovllm/engine/sequence.py
new file mode 100644
index 000000000..bda2e53c1
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/engine/sequence.py
@@ -0,0 +1,97 @@
+from copy import copy
+from enum import Enum, auto
+from itertools import count
+from typing import Optional, Callable, Any
+
+from nanovllm.sampling_params import SamplingParams
+
+
+class SequenceStatus(Enum):
+ WAITING = auto()
+ RUNNING = auto()
+ FINISHED = auto()
+
+
+class Sequence:
+    """One generation request plus its KV-cache bookkeeping.
+
+    Tracks token ids, sampling parameters, the block table assigned by the
+    block manager, and optional CFG pairing state.  The class-level
+    `block_size` must match the KV-cache block size used elsewhere.
+    """
+    block_size = 256
+    counter = count()  # process-wide source of unique seq ids
+
+    # NOTE(review): `sampling_params` uses a shared mutable default
+    # instance; it is only read here, but a None default would be safer.
+    def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), is_unconditional: bool = False, conditional_seq = None):
+        self.seq_id = next(Sequence.counter)
+        self.status = SequenceStatus.WAITING
+        self.token_ids = copy(token_ids)  # own copy; caller's list untouched
+        self.last_token = token_ids[-1]
+        self.num_tokens = len(self.token_ids)
+        self.num_prompt_tokens = len(token_ids)
+        self.num_cached_tokens = 0  # prefix-cache hits, set externally
+        self.block_table = []
+        self.temperature = sampling_params.temperature
+        self.max_tokens = sampling_params.max_tokens
+        self.ignore_eos = sampling_params.ignore_eos
+        self.cfg_scale = sampling_params.cfg_scale
+        self.top_k = sampling_params.top_k
+        self.top_p = sampling_params.top_p
+        self.repetition_penalty = sampling_params.repetition_penalty
+        # For CFG: mark if this is an unconditional sequence
+        self.is_unconditional = is_unconditional
+        # For CFG: reference to the corresponding conditional sequence (if this is unconditional)
+        # For conditional sequences, this points to the unconditional sequence
+        self.paired_seq = conditional_seq  # For conditional seq, points to uncond; for uncond seq, points to cond
+        # For constrained decoding: logits processor and state update callback
+        self.logits_processor: Optional[Any] = sampling_params.logits_processor
+        self.logits_processor_update_state: Optional[Callable[[int], None]] = sampling_params.logits_processor_update_state
+        self.logits_bias: Optional[Any] = sampling_params.logits_bias
+
+    def __len__(self):
+        # Total length: prompt + completion tokens.
+        return self.num_tokens
+
+    def __getitem__(self, key):
+        return self.token_ids[key]
+
+    @property
+    def is_finished(self):
+        """True once the scheduler marked this sequence FINISHED."""
+        return self.status == SequenceStatus.FINISHED
+
+    @property
+    def num_completion_tokens(self):
+        """Number of tokens generated so far (beyond the prompt)."""
+        return self.num_tokens - self.num_prompt_tokens
+
+    @property
+    def prompt_token_ids(self):
+        return self.token_ids[:self.num_prompt_tokens]
+
+    @property
+    def completion_token_ids(self):
+        return self.token_ids[self.num_prompt_tokens:]
+
+    @property
+    def num_cached_blocks(self):
+        """Whole KV blocks covered by the prefix-cache hit."""
+        return self.num_cached_tokens // self.block_size
+
+    @property
+    def num_blocks(self):
+        """KV blocks needed to hold all tokens (ceil division)."""
+        return (self.num_tokens + self.block_size - 1) // self.block_size
+
+    @property
+    def last_block_num_tokens(self):
+        """How many tokens occupy the (possibly partial) last block."""
+        return self.num_tokens - (self.num_blocks - 1) * self.block_size
+
+    def block(self, i):
+        """Return the token ids stored in block index `i`."""
+        assert 0 <= i < self.num_blocks
+        return self.token_ids[i*self.block_size: (i+1)*self.block_size]
+
+    def append_token(self, token_id: int):
+        """Append one generated token and update the cached counters."""
+        self.token_ids.append(token_id)
+        self.last_token = token_id
+        self.num_tokens += 1
+
+    def __getstate__(self):
+        # Compact pickle form (presumably for inter-process transfer):
+        # once decoding has started only the last token is shipped.
+        return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
+                self.token_ids if self.num_completion_tokens == 0 else self.last_token)
+
+    def __setstate__(self, state):
+        # Mirror of __getstate__: restore either the full prompt token
+        # list (prefill) or just the last token (decode).
+        self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
+        if self.num_completion_tokens == 0:
+            self.token_ids = state[-1]
+        else:
+            self.last_token = state[-1]
diff --git a/Wan2GP/shared/llm_engines/nanovllm/layers/activation.py b/Wan2GP/shared/llm_engines/nanovllm/layers/activation.py
new file mode 100644
index 000000000..041ee2008
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/layers/activation.py
@@ -0,0 +1,14 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class SiluAndMul(nn.Module):
+    """SwiGLU-style activation: split the last dim in half, silu(a) * b."""
+
+    def __init__(self):
+        super().__init__()
+
+    @torch.compile
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x packs two projections along the last dimension; the first half
+        # is gated through SiLU and multiplied with the second half.
+        x, y = x.chunk(2, -1)
+        return F.silu(x) * y
diff --git a/Wan2GP/shared/llm_engines/nanovllm/layers/attention.py b/Wan2GP/shared/llm_engines/nanovllm/layers/attention.py
new file mode 100644
index 000000000..e416139ea
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/layers/attention.py
@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+import triton
+import triton.language as tl
+
+from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from nanovllm.utils.context import get_context
+
+
+@triton.jit
+def store_kvcache_kernel(
+    key_ptr,
+    key_stride,
+    value_ptr,
+    value_stride,
+    k_cache_ptr,
+    v_cache_ptr,
+    slot_mapping_ptr,
+    D: tl.constexpr,
+):
+    # One Triton program per token: copy its flattened (num_heads * head_dim
+    # = D) key and value vectors into the paged-cache slot given by
+    # slot_mapping.  A slot of -1 marks a token with no cache slot
+    # (skip; nothing is written for it).
+    idx = tl.program_id(0)
+    slot = tl.load(slot_mapping_ptr + idx)
+    if slot == -1: return
+    key_offsets = idx * key_stride + tl.arange(0, D)
+    value_offsets = idx * value_stride + tl.arange(0, D)
+    key = tl.load(key_ptr + key_offsets)
+    value = tl.load(value_ptr + value_offsets)
+    # Caches are addressed as flat [num_slots, D] buffers.
+    cache_offsets = slot * D + tl.arange(0, D)
+    tl.store(k_cache_ptr + cache_offsets, key)
+    tl.store(v_cache_ptr + cache_offsets, value)
+
+
+def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
+    """Scatter per-token K/V vectors into the paged cache via Triton.
+
+    `key`/`value` are (N, num_heads, head_dim); `slot_mapping` gives one
+    cache slot per token (-1 = skip).  The asserts pin the memory layout
+    the kernel assumes: innermost dims contiguous and caches laid out as
+    flat rows of D = num_heads * head_dim elements.
+    """
+    N, num_heads, head_dim = key.shape
+    D = num_heads * head_dim
+    assert key.stride(-1) == 1 and value.stride(-1) == 1
+    assert key.stride(1) == head_dim and value.stride(1) == head_dim
+    assert k_cache.stride(1) == D and v_cache.stride(1) == D
+    assert slot_mapping.numel() == N
+    # Launch one program per token.
+    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
+
+
+class Attention(nn.Module):
+    """FlashAttention wrapper over the paged KV cache.
+
+    `k_cache`/`v_cache` start as empty placeholders and are assigned
+    externally (presumably by the model runner when the cache is
+    allocated -- not visible here).  The prefill/decode split and all
+    batching metadata come from the thread-local context.
+    """
+
+    def __init__(
+        self,
+        num_heads,
+        head_dim,
+        scale,
+        num_kv_heads,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.scale = scale
+        self.num_kv_heads = num_kv_heads
+        # Placeholder tensors; replaced with real cache views later.
+        self.k_cache = self.v_cache = torch.tensor([])
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+        """Run attention for the current batch (prefill or decode)."""
+        context = get_context()
+        k_cache, v_cache = self.k_cache, self.v_cache
+        if k_cache.numel() and v_cache.numel():
+            # Persist this step's K/V into the paged cache first.
+            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        if context.is_prefill:
+            if context.block_tables is not None:    # prefix cache
+                k, v = k_cache, v_cache
+            o = flash_attn_varlen_func(q, k, v,
+                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
+                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
+                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
+        else:    # decode
+            # Single new token per sequence; attend against the cache.
+            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
+                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
+                                        softmax_scale=self.scale, causal=True)
+        return o
diff --git a/Wan2GP/shared/llm_engines/nanovllm/layers/embed_head.py b/Wan2GP/shared/llm_engines/nanovllm/layers/embed_head.py
new file mode 100644
index 000000000..4475b41ea
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/layers/embed_head.py
@@ -0,0 +1,71 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from nanovllm.utils.context import get_context
+
+
+def _get_tp_info():
+    """Return (tp_rank, tp_size); (0, 1) when torch.distributed is not up."""
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_rank(), dist.get_world_size()
+    return 0, 1
+
+
+class VocabParallelEmbedding(nn.Module):
+    """Embedding table sharded over the vocab dimension across TP ranks.
+
+    Each rank holds rows [vocab_start_idx, vocab_end_idx); lookups for
+    out-of-shard ids produce zeros locally and the partial results are
+    summed with all_reduce.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+    ):
+        super().__init__()
+        self.tp_rank, self.tp_size = _get_tp_info()
+        assert num_embeddings % self.tp_size == 0
+        self.num_embeddings = num_embeddings
+        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
+        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
+        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
+        self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
+        self.weight.weight_loader = self.weight_loader
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Copy this rank's row-shard out of the full checkpoint tensor."""
+        param_data = param.data
+        shard_size = param_data.size(0)
+        start_idx = self.tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        param_data.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor):
+        if self.tp_size > 1:
+            # Map out-of-shard ids to index 0; their rows are zeroed below.
+            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
+            x = mask * (x - self.vocab_start_idx)
+        y = F.embedding(x, self.weight)
+        if self.tp_size > 1:
+            # Zero out-of-shard lookups, then sum partials across ranks.
+            y = mask.unsqueeze(1) * y
+            dist.all_reduce(y)
+        return y
+
+
+class ParallelLMHead(VocabParallelEmbedding):
+    """Vocab-sharded LM head; logits are gathered onto rank 0.
+
+    With tp_size > 1 only rank 0 returns the full logits tensor; other
+    ranks return None.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        bias: bool = False,
+    ):
+        # Bias is not supported by this head.
+        assert not bias
+        super().__init__(num_embeddings, embedding_dim)
+
+    def forward(self, x: torch.Tensor):
+        context = get_context()
+        if context.is_prefill:
+            # Only the last position of each sequence needs logits;
+            # cu_seqlens_q[1:] - 1 indexes those rows in the packed batch.
+            last_indices = context.cu_seqlens_q[1:] - 1
+            x = x[last_indices].contiguous()
+        logits = F.linear(x, self.weight)
+        if self.tp_size > 1:
+            # Gather each rank's vocab shard onto rank 0 and concatenate.
+            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
+            dist.gather(logits, all_logits, 0)
+            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
+        return logits
diff --git a/Wan2GP/shared/llm_engines/nanovllm/layers/layernorm.py b/Wan2GP/shared/llm_engines/nanovllm/layers/layernorm.py
new file mode 100644
index 000000000..71bf4198f
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/layers/layernorm.py
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+
+
+class RMSNorm(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ eps: float = 1e-6,
+ ) -> None:
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+
+ @torch.compile
+ def rms_forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ orig_dtype = x.dtype
+ x = x.float()
+ var = x.pow(2).mean(dim=-1, keepdim=True)
+ x.mul_(torch.rsqrt(var + self.eps))
+ x = x.to(orig_dtype).mul_(self.weight)
+ return x
+
+ @torch.compile
+ def add_rms_forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_dtype = x.dtype
+ x = x.float().add_(residual.float())
+ residual = x.to(orig_dtype)
+ var = x.pow(2).mean(dim=-1, keepdim=True)
+ x.mul_(torch.rsqrt(var + self.eps))
+ x = x.to(orig_dtype).mul_(self.weight)
+ return x, residual
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ return self.rms_forward(x)
+ else:
+ return self.add_rms_forward(x, residual)
diff --git a/Wan2GP/shared/llm_engines/nanovllm/layers/linear.py b/Wan2GP/shared/llm_engines/nanovllm/layers/linear.py
new file mode 100644
index 000000000..3476cbbd4
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/layers/linear.py
@@ -0,0 +1,416 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.distributed as dist
+import math
+
+try:
+ from shared.kernels import quanto_int8_triton as _shared_quanto_int8_triton
+except Exception: # pragma: no cover
+ _shared_quanto_int8_triton = None # type: ignore
+
+
+def _shared_kernel_available() -> bool:
+    """Probe the optional shared Triton int8 kernel module defensively.
+
+    Returns False if the module failed to import, exposes no callable
+    `is_available`, or its probe itself raises.
+    """
+    if _shared_quanto_int8_triton is None:
+        return False
+    is_available = getattr(_shared_quanto_int8_triton, "is_available", None)
+    if not callable(is_available):
+        return False
+    try:
+        return bool(is_available())
+    except Exception:
+        return False
+
+
+_TRITON_AVAILABLE = _shared_kernel_available()
+
+
+def divide(numerator, denominator):
+    """Exact integer division; asserts there is no remainder."""
+    assert numerator % denominator == 0
+    return numerator // denominator
+
+
+def _get_tp_info():
+    """Return (tp_rank, tp_size); (0, 1) when torch.distributed is not up."""
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_rank(), dist.get_world_size()
+    return 0, 1
+
+
+def _flatten_scale(scale: torch.Tensor) -> torch.Tensor:
+    """Normalize a per-output-channel scale tensor to 1-D (column vectors
+    are squeezed; anything else is flattened)."""
+    if scale.ndim == 2 and scale.shape[1] == 1:
+        return scale.view(-1)
+    if scale.ndim == 1:
+        return scale
+    return scale.reshape(-1)
+
+
+def _run_triton_fused_int8_mm(
+    x2d: torch.Tensor,
+    qweight_t: torch.Tensor,
+    qweight_scale_fp32: torch.Tensor,
+    input_scale: float | None = None,
+) -> torch.Tensor:
+    """Run the shared fused int8 GEMM: (m,k) x int8 (k,n) -> (m,n).
+
+    Validates shapes and the per-output-channel scale length before
+    dispatching to the shared kernel.  `input_scale` is accepted for
+    interface compatibility but deliberately unused (the kernel quantizes
+    activations dynamically).
+    """
+    del input_scale
+    if _shared_quanto_int8_triton is None:
+        raise RuntimeError("shared.kernels.quanto_int8_triton is unavailable")
+    run_kernel = getattr(_shared_quanto_int8_triton, "fused_quant_scaled_mm_transposed", None)
+    if run_kernel is None:
+        raise RuntimeError("shared.kernels.quanto_int8_triton.fused_quant_scaled_mm_transposed is missing")
+
+    m, k = x2d.shape
+    k2, n = qweight_t.shape
+    if k != k2:
+        raise RuntimeError(f"Triton int8 GEMM shape mismatch: x={x2d.shape}, w_t={qweight_t.shape}")
+    qweight_scale_fp32 = _flatten_scale(qweight_scale_fp32)
+    if qweight_scale_fp32.numel() != n:
+        raise RuntimeError(
+            f"Triton int8 qweight_scale length mismatch: expected {n}, got {qweight_scale_fp32.numel()}"
+        )
+    return run_kernel(x2d, qweight_t, qweight_scale_fp32, out_dtype=x2d.dtype)
+
+
+class LinearBase(nn.Module):
+    """Base linear layer with TP awareness and optional int8 weights.
+
+    Holds both a regular fp weight and a set of buffers for
+    quanto-style int8 weights.  Loading flow: the `quant_*_loader`
+    methods fill the buffers shard by shard, then `finalize_quantized()`
+    (once all expected shards arrived) transposes the weight, switches
+    `forward` onto the int8 path, and frees the fp weight.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+        tp_dim: int | None = None,
+    ):
+        """`tp_dim` is the weight dim sharded across TP ranks (None =
+        replicated)."""
+        super().__init__()
+        self.tp_dim = tp_dim
+        self.tp_rank, self.tp_size = _get_tp_info()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.weight = nn.Parameter(torch.empty(output_size, input_size))
+        self.weight.weight_loader = self.weight_loader
+        # int8 quantization state (empty until a quantized checkpoint loads).
+        self.register_buffer("qweight_data", torch.empty(0, dtype=torch.int8))
+        self.register_buffer("qweight_t", torch.empty(0, dtype=torch.int8))
+        self.register_buffer("qweight_scale", torch.empty(0))
+        self.register_buffer("qweight_scale_fp32", torch.empty(0, dtype=torch.float32))
+        self.register_buffer("input_scale", torch.ones((), dtype=torch.bfloat16))
+        self.register_buffer("output_scale", torch.ones((), dtype=torch.bfloat16))
+        self.use_int8_weight = False
+        self.use_triton_int8 = False
+        self._input_scale_value = 1.0
+        # How many checkpoint shards must arrive before finalization
+        # (e.g. 3 for fused QKV, len(output_sizes) for merged layers).
+        self._quant_expected_shards = 1
+        self._quant_data_loaded = set()
+        self._quant_scale_loaded = set()
+        if bias:
+            self.bias = nn.Parameter(torch.empty(output_size))
+            self.bias.weight_loader = self.weight_loader
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Subclasses implement the actual matmul.
+        raise NotImplementedError
+
+    def _shard_weight(self, loaded_weight: torch.Tensor, loaded_shard_id=None) -> torch.Tensor:
+        """Hook: return this rank's shard of a checkpoint weight
+        (identity in the base class)."""
+        return loaded_weight
+
+    def _shard_weight_scale(self, loaded_scale: torch.Tensor, loaded_shard_id=None) -> torch.Tensor:
+        """Hook: return this rank's shard of a per-channel scale
+        (identity in the base class)."""
+        return loaded_scale
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id=None):
+        """Copy the (sharded) fp checkpoint weight into `param`."""
+        param.data.copy_(self._shard_weight(loaded_weight, loaded_shard_id))
+
+    def quant_weight_data_loader(self, loaded_weight: torch.Tensor, loaded_shard_id=None):
+        """Load the int8 weight tensor (this rank's shard) and record it."""
+        device = self.weight.device
+        shard = self._shard_weight(loaded_weight, loaded_shard_id).contiguous()
+        q = shard.to(device=device, dtype=torch.int8, non_blocking=True)
+        self.qweight_data = q
+        shard_key = loaded_shard_id if loaded_shard_id is not None else 0
+        self._quant_data_loaded.add(shard_key)
+
+    def quant_weight_scale_loader(self, loaded_scale: torch.Tensor, loaded_shard_id=None):
+        """Load the per-output-channel weight scale (this rank's shard)."""
+        device = self.weight.device if self.weight.numel() != 0 else self.qweight_data.device
+        shard_scale = self._shard_weight_scale(loaded_scale, loaded_shard_id).contiguous()
+        if shard_scale.dim() == 2 and shard_scale.size(1) == 1:
+            shard_scale = shard_scale.squeeze(1)
+        self.qweight_scale = shard_scale.to(device=device, dtype=torch.bfloat16, non_blocking=True)
+        shard_key = loaded_shard_id if loaded_shard_id is not None else 0
+        self._quant_scale_loaded.add(shard_key)
+
+    def quant_input_scale_loader(self, loaded_scale: torch.Tensor):
+        """Load the activation input scale from the checkpoint."""
+        device = self.weight.device if self.weight.numel() != 0 else self.qweight_data.device
+        self.input_scale = loaded_scale.to(device=device, dtype=torch.bfloat16, non_blocking=True)
+
+    def quant_output_scale_loader(self, loaded_scale: torch.Tensor):
+        """Load the output scale from the checkpoint."""
+        device = self.weight.device if self.weight.numel() != 0 else self.qweight_data.device
+        self.output_scale = loaded_scale.to(device=device, dtype=torch.bfloat16, non_blocking=True)
+
+    def finalize_quantized(self):
+        """Switch to the int8 path once all expected shards are loaded.
+
+        No-op until both weight data and scales for every expected shard
+        have arrived; then transposes the weight to KxN, validates the
+        scalar input scale, enables the Triton path when available, and
+        frees the fp weight.
+        """
+        if self.qweight_data.numel() == 0 or self.qweight_scale.numel() == 0:
+            return
+        if len(self._quant_data_loaded) < self._quant_expected_shards:
+            return
+        if len(self._quant_scale_loaded) < self._quant_expected_shards:
+            return
+        # Keep int8 weights in KxN layout for int8 GEMM paths.
+        self.qweight_t = self.qweight_data.transpose(0, 1).contiguous()
+        self.qweight_data = torch.empty(0, dtype=torch.int8, device=self.qweight_t.device)
+        self.qweight_scale_fp32 = self.qweight_scale.to(dtype=torch.float32)
+        if not torch.is_tensor(self.input_scale):
+            raise RuntimeError(f"Invalid input_scale type: {type(self.input_scale)}")
+        input_scale_flat = self.input_scale.reshape(-1)
+        if input_scale_flat.numel() != 1:
+            raise RuntimeError(
+                f"Expected scalar input_scale, got shape={tuple(self.input_scale.shape)}"
+            )
+        input_scale_value = float(input_scale_flat[0].item())
+        if not math.isfinite(input_scale_value) or input_scale_value <= 0:
+            raise RuntimeError(f"Invalid input_scale value: {input_scale_value}")
+        self._input_scale_value = max(input_scale_value, 1e-8)
+        self.use_triton_int8 = bool(_TRITON_AVAILABLE and self.qweight_t.is_cuda)
+        self.use_int8_weight = True
+        device = self.qweight_t.device
+        # The fp weight is no longer needed; shrink it to free memory.
+        if self.weight.numel() != 0:
+            self.weight = nn.Parameter(torch.empty(0, device=device, dtype=torch.bfloat16), requires_grad=False)
+        if self.bias is not None:
+            self.bias = nn.Parameter(self.bias.data.to(device=device), requires_grad=False)
+
+    def _quant_int8_mm(self, x: torch.Tensor) -> torch.Tensor:
+        """int8 matmul: Triton kernel when enabled, else quanto fallback."""
+        x_shape = x.shape
+        x2d = x.reshape(-1, x_shape[-1])
+        if self.use_triton_int8 and x2d.is_cuda and self.qweight_t.numel() != 0 and self.qweight_scale_fp32.numel() != 0:
+            y = _run_triton_fused_int8_mm(
+                x2d,
+                self.qweight_t,
+                self.qweight_scale_fp32,
+            )
+            return y.view(*x_shape[:-1], y.size(-1))
+
+        # Fallback path: use Quanto qbytes_mm for dynamic activation quantization parity.
+        if not hasattr(torch.ops, "quanto") or not hasattr(torch.ops.quanto, "qbytes_mm"):
+            raise RuntimeError("quanto.qbytes_mm op unavailable for int8 fallback path")
+        if self.qweight_t.numel() == 0 and self.qweight_data.numel() != 0:
+            self.qweight_t = self.qweight_data.transpose(0, 1).contiguous()
+        # qbytes_mm wants the NxK layout back.
+        qweight = self.qweight_t.transpose(0, 1).contiguous()
+        scales = self.qweight_scale_fp32
+        if scales.numel() == 0:
+            scales = self.qweight_scale.to(torch.float32)
+        scales = _flatten_scale(scales).reshape(-1, 1).contiguous()
+        out = torch.ops.quanto.qbytes_mm(x2d, qweight, scales).to(x2d.dtype)
+        return out.view(*x_shape[:-1], out.size(-1))
+
+    def prepare_for_quantized_load(self):
+        """Reset all quantization state ahead of a (re)load of int8 weights."""
+        device = None
+        if self.qweight_t.numel() != 0:
+            device = self.qweight_t.device
+        elif self.qweight_data.numel() != 0:
+            device = self.qweight_data.device
+        elif self.weight.numel() != 0:
+            device = self.weight.device
+        elif self.qweight_scale.numel() != 0:
+            device = self.qweight_scale.device
+        else:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        if self.weight.numel() != 0:
+            self.weight = nn.Parameter(torch.empty(0, device=device, dtype=torch.bfloat16), requires_grad=False)
+
+        # Reset all quantized buffers/state so each reload is complete and deterministic.
+        self.qweight_data = torch.empty(0, dtype=torch.int8, device=device)
+        self.qweight_t = torch.empty(0, dtype=torch.int8, device=device)
+        self.qweight_scale = torch.empty(0, dtype=torch.bfloat16, device=device)
+        self.qweight_scale_fp32 = torch.empty(0, dtype=torch.float32, device=device)
+        self.input_scale = torch.ones((), dtype=torch.bfloat16, device=device)
+        self.output_scale = torch.ones((), dtype=torch.bfloat16, device=device)
+        self.use_int8_weight = False
+        self.use_triton_int8 = False
+        self._input_scale_value = 1.0
+        self._quant_data_loaded.clear()
+        self._quant_scale_loaded.clear()
+
+
+class ReplicatedLinear(LinearBase):
+    """Linear layer replicated (not sharded) across TP ranks."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        super().__init__(input_size, output_size, bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # int8 path after finalize_quantized(); fp path otherwise.
+        if self.use_int8_weight:
+            y = self._quant_int8_mm(x)
+            if self.bias is not None:
+                y = y + self.bias
+            return y
+        return F.linear(x, self.weight, self.bias)
+
+
+class ColumnParallelLinear(LinearBase):
+    """Linear layer sharded along the output (column) dimension;
+    each TP rank holds output_size / tp_size rows of the weight."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        tp_size = _get_tp_info()[1]
+        # tp_dim=0: the weight's dim 0 (output features) is sharded.
+        super().__init__(input_size, divide(output_size, tp_size), bias, 0)
+
+    def _shard_weight(self, loaded_weight: torch.Tensor, loaded_shard_id=None) -> torch.Tensor:
+        """Slice this rank's contiguous output-dim shard of the weight."""
+        shard_size = divide(loaded_weight.size(self.tp_dim), self.tp_size)
+        start_idx = self.tp_rank * shard_size
+        return loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+
+    def _shard_weight_scale(self, loaded_scale: torch.Tensor, loaded_shard_id=None) -> torch.Tensor:
+        """Slice the matching per-output-channel scale shard
+        (scalar scales are replicated)."""
+        if loaded_scale.dim() == 0:
+            return loaded_scale
+        shard_size = divide(loaded_scale.size(0), self.tp_size)
+        start_idx = self.tp_rank * shard_size
+        return loaded_scale.narrow(0, start_idx, shard_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # int8 path after finalize_quantized(); fp path otherwise.
+        if self.use_int8_weight:
+            y = self._quant_int8_mm(x)
+            if self.bias is not None:
+                y = y + self.bias
+            return y
+        return F.linear(x, self.weight, self.bias)
+
+
+class MergedColumnParallelLinear(ColumnParallelLinear):
+    """Several column-parallel projections fused into one weight
+    (e.g. gate+up in an MLP); checkpoint shards arrive one sub-weight
+    at a time, identified by `loaded_shard_id`."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = False,
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size, sum(output_sizes), bias)
+        # One checkpoint shard per merged sub-projection.
+        self._quant_expected_shards = len(output_sizes)
+
+    def _merged_shard_meta(self, loaded_shard_id: int):
+        """Return (offset, size) of sub-weight `loaded_shard_id` within
+        this rank's fused output dimension."""
+        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
+        return shard_offset, shard_size
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
+        """Copy one sub-weight's TP shard into its slice of the fused param."""
+        param_data = param.data
+        shard_offset, shard_size = self._merged_shard_meta(loaded_shard_id)
+        param_slice = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+        shard = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+        param_slice.copy_(shard)
+
+    def quant_weight_data_loader(self, loaded_weight: torch.Tensor, loaded_shard_id=None):
+        """int8 analog of weight_loader: lazily allocate the fused int8
+        buffer, then fill this sub-weight's slice."""
+        device = self.weight.device
+        if self.qweight_data.numel() == 0:
+            self.qweight_data = torch.empty((self.output_size, self.input_size), dtype=torch.int8, device=device)
+        shard_offset, shard_size = self._merged_shard_meta(int(loaded_shard_id))
+        param_slice = self.qweight_data.narrow(0, shard_offset, shard_size)
+        shard = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+        param_slice.copy_(shard.to(device=device, dtype=torch.int8, non_blocking=True))
+        self._quant_data_loaded.add(int(loaded_shard_id))
+
+    def quant_weight_scale_loader(self, loaded_scale: torch.Tensor, loaded_shard_id=None):
+        """Fill this sub-weight's slice of the fused per-channel scale."""
+        device = self.weight.device if self.weight.numel() != 0 else self.qweight_data.device
+        if self.qweight_scale.numel() == 0:
+            self.qweight_scale = torch.empty((self.output_size,), dtype=torch.bfloat16, device=device)
+        shard_offset, shard_size = self._merged_shard_meta(int(loaded_shard_id))
+        scale_slice = self.qweight_scale.narrow(0, shard_offset, shard_size)
+        shard = loaded_scale.chunk(self.tp_size, 0)[self.tp_rank]
+        if shard.dim() == 2 and shard.size(1) == 1:
+            shard = shard.squeeze(1)
+        scale_slice.copy_(shard.to(device=device, dtype=torch.bfloat16, non_blocking=True))
+        self._quant_scale_loaded.add(int(loaded_shard_id))
+
+
+class QKVParallelLinear(ColumnParallelLinear):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: int | None = None,
+        bias: bool = False,
+    ):
+        """Fused Q/K/V projection; heads are split evenly across TP ranks.
+        `total_num_kv_heads` defaults to `total_num_heads` (no GQA)."""
+        tp_size = _get_tp_info()[1]
+        total_num_kv_heads = total_num_kv_heads or total_num_heads
+        self.head_size = head_size
+        self.num_heads = divide(total_num_heads, tp_size)
+        self.num_kv_heads = divide(total_num_kv_heads, tp_size)
+        # Global fused width; ColumnParallelLinear divides it by tp_size.
+        output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
+        super().__init__(hidden_size, output_size, bias)
+        # q, k and v each arrive as a separate checkpoint shard.
+        self._quant_expected_shards = 3
+
+ def _qkv_offset_size(self, loaded_shard_id: str):
+ assert loaded_shard_id in ["q", "k", "v"]
+ if loaded_shard_id == "q":
+ shard_size = self.num_heads * self.head_size
+ shard_offset = 0
+ elif loaded_shard_id == "k":
+ shard_size = self.num_kv_heads * self.head_size
+ shard_offset = self.num_heads * self.head_size
+ else:
+ shard_size = self.num_kv_heads * self.head_size
+ shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
+ return shard_offset, shard_size
+
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
+ param_data = param.data
+ shard_offset, shard_size = self._qkv_offset_size(loaded_shard_id)
+ param_slice = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+ shard = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+ param_slice.copy_(shard)
+
+ def quant_weight_data_loader(self, loaded_weight: torch.Tensor, loaded_shard_id=None):
+ device = self.weight.device
+ if self.qweight_data.numel() == 0:
+ self.qweight_data = torch.empty((self.output_size, self.input_size), dtype=torch.int8, device=device)
+ shard_offset, shard_size = self._qkv_offset_size(str(loaded_shard_id))
+ param_slice = self.qweight_data.narrow(0, shard_offset, shard_size)
+ shard = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+ param_slice.copy_(shard.to(device=device, dtype=torch.int8, non_blocking=True))
+ self._quant_data_loaded.add(str(loaded_shard_id))
+
+ def quant_weight_scale_loader(self, loaded_scale: torch.Tensor, loaded_shard_id=None):
+ device = self.weight.device if self.weight.numel() != 0 else self.qweight_data.device
+ if self.qweight_scale.numel() == 0:
+ self.qweight_scale = torch.empty((self.output_size,), dtype=torch.bfloat16, device=device)
+ shard_offset, shard_size = self._qkv_offset_size(str(loaded_shard_id))
+ scale_slice = self.qweight_scale.narrow(0, shard_offset, shard_size)
+ shard = loaded_scale.chunk(self.tp_size, 0)[self.tp_rank]
+ if shard.dim() == 2 and shard.size(1) == 1:
+ shard = shard.squeeze(1)
+ scale_slice.copy_(shard.to(device=device, dtype=torch.bfloat16, non_blocking=True))
+ self._quant_scale_loaded.add(str(loaded_shard_id))
+
+
+class RowParallelLinear(LinearBase):
+    """Linear layer sharded along the input (row) dimension.
+
+    Each rank holds input_size / tp_size input features; partial outputs are
+    summed across ranks with an all-reduce in forward().
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        tp_size = _get_tp_info()[1]
+        # tp_dim=1: the weight is partitioned along its input dimension.
+        super().__init__(divide(input_size, tp_size), output_size, bias, 1)
+
+    def _shard_weight(self, loaded_weight: torch.Tensor, loaded_shard_id=None) -> torch.Tensor:
+        """Return this rank's contiguous slice of the full weight along tp_dim."""
+        shard_size = divide(loaded_weight.size(self.tp_dim), self.tp_size)
+        start_idx = self.tp_rank * shard_size
+        return loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+
+    def _shard_weight_scale(self, loaded_scale: torch.Tensor, loaded_shard_id=None) -> torch.Tensor:
+        # Output rows are not sharded in RowParallelLinear.
+        return loaded_scale
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute the (possibly int8-quantized) linear projection and, when
+        tensor-parallel, all-reduce the partial results across ranks."""
+        if self.use_int8_weight:
+            y = self._quant_int8_mm(x)
+            # Add the bias on rank 0 only so it is not counted tp_size times
+            # after the all-reduce.
+            if self.bias is not None and self.tp_rank == 0:
+                y = y + self.bias
+        else:
+            y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
+        if self.tp_size > 1:
+            dist.all_reduce(y)
+        return y
diff --git a/Wan2GP/shared/llm_engines/nanovllm/layers/rotary_embedding.py b/Wan2GP/shared/llm_engines/nanovllm/layers/rotary_embedding.py
new file mode 100644
index 000000000..998d11646
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/layers/rotary_embedding.py
@@ -0,0 +1,61 @@
+from functools import lru_cache
+import torch
+from torch import nn
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    """Apply rotary position embedding to `x` using precomputed cos/sin.
+
+    Uses the "split-half" convention: the last dimension is split into two
+    halves (x1, x2) and rotated as a complex pair. Math is done in float32
+    and the result is cast back to x's dtype.
+    """
+    x1, x2 = torch.chunk(x.float(), 2, dim=-1)
+    y1 = x1 * cos - x2 * sin
+    y2 = x2 * cos + x1 * sin
+    return torch.cat((y1, y2), dim=-1).to(x.dtype)
+
+
+class RotaryEmbedding(nn.Module):
+    """Standard (non-scaled) rotary position embedding with a precomputed
+    cos/sin cache for all positions up to max_position_embeddings."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        # Partial-rotary (rotary_dim < head_size) is not supported here.
+        assert rotary_dim == head_size
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        # freqs[pos, i] = pos * inv_freq[i]
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        # Cache layout: [max_pos, 1, rotary_dim] with cos and sin concatenated
+        # on the last dim; the middle dim broadcasts over heads.
+        cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
+        # persistent=False keeps the cache out of state_dict.
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    @torch.compile
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Rotate query and key by the embeddings for `positions`."""
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        query = apply_rotary_emb(query, cos, sin)
+        key = apply_rotary_emb(key, cos, sin)
+        return query, key
+
+
+@lru_cache(1)
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: float,
+    rope_scaling: dict | None = None,
+):
+    """Return a (cached) RotaryEmbedding for the given configuration.
+
+    NOTE(review): lru_cache requires hashable arguments, so passing a dict for
+    `rope_scaling` would raise TypeError before the assert fires; only None is
+    effectively supported, consistent with the assertion below.
+    """
+    assert rope_scaling is None
+    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
+    return rotary_emb
diff --git a/Wan2GP/shared/llm_engines/nanovllm/layers/sampler.py b/Wan2GP/shared/llm_engines/nanovllm/layers/sampler.py
new file mode 100644
index 000000000..41fcb2710
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/layers/sampler.py
@@ -0,0 +1,122 @@
+import torch
+from torch import nn
+from typing import Optional
+import os
+
+
+_SAMPLER_NUMERIC_GUARD = os.environ.get("WAN2GP_NANOVLLM_SAMPLER_NUMERIC_GUARD", "0") == "1"
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits (vLLM style).
+
+    The logits tensor is updated in-place.
+    """
+    if p is None:
+        if k is None:
+            return logits
+        # Avoid sorting vocab for top-k only case
+        return apply_top_k_only(logits, k)
+
+    # Need to sort for top-p
+    # Ascending sort: lowest logits first, kept tokens sit at the right end.
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if k is not None:
+        # Apply top-k first
+        vocab_size = logits_sort.size(1)
+        # Clamp k to valid range
+        k_clamped = k.clamp(1, vocab_size).long()
+        # In ascending order, the k-th largest value sits at index vocab-k.
+        top_k_mask_idx = vocab_size - k_clamped  # shape: [B]
+        # Get the threshold value for each batch
+        top_k_thresh = logits_sort.gather(1, top_k_mask_idx.unsqueeze(1))
+        top_k_mask = logits_sort < top_k_thresh
+        logits_sort.masked_fill_(top_k_mask, float('-inf'))
+
+    # Apply top-p
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)  # reuse buffer
+    # Drop the low-probability prefix whose cumulative mass is <= 1 - p.
+    top_p_mask = probs_sum <= (1.0 - p.unsqueeze(1))
+    # Ensure at least one token is kept
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, float('-inf'))
+
+    # Re-sort back to original positions
+    logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_top_k_only(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    """Apply top-k mask without sorting the entire vocab (vLLM style).
+
+    This is much faster than sorting for top-k only cases.
+    The logits tensor is updated in-place.
+    """
+    vocab_size = logits.shape[1]
+    # Handle cases where k >= vocab_size (no filtering needed)
+    no_top_k_mask = (k <= 0) | (k >= vocab_size)
+    # Set invalid k to 1 so we can still gather
+    k_safe = k.masked_fill(no_top_k_mask, 1).long()
+    # NOTE: This int() causes CPU-GPU sync, but torch.topk requires Python int
+    max_top_k = int(k_safe.max().clamp(max=vocab_size))
+
+    # Get top-k values for all batches
+    # topk.values has shape [batch_size, max_top_k]
+    topk_values = logits.topk(max_top_k, dim=1).values
+
+    # Convert k to 0-based index: we want the k-th largest value (index k-1)
+    # Clamp to valid range for gather
+    k_index = (k_safe - 1).clamp(0, max_top_k - 1).unsqueeze(1)  # shape: [B, 1]
+    # Gather the threshold value (the k-th largest)
+    top_k_thresh = topk_values.gather(1, k_index)
+
+    # For rows with no top-k filtering, set threshold to -inf so nothing gets masked
+    top_k_thresh.masked_fill_(no_top_k_mask.unsqueeze(1), float('-inf'))
+
+    # Mask all values below the threshold
+    logits.masked_fill_(logits < top_k_thresh, float('-inf'))
+    return logits
+
+
+class Sampler(nn.Module):
+    """Token sampler: temperature scaling, optional top-k/top-p filtering,
+    then categorical sampling via the exponential-noise argmax trick."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        temperatures: torch.Tensor,
+        top_ks: Optional[torch.Tensor] = None,
+        top_ps: Optional[torch.Tensor] = None,
+        repetition_penalties: Optional[torch.Tensor] = None,
+        input_ids: Optional[torch.Tensor] = None,
+    ):
+        """
+        Sample tokens from logits with optional top-k and top-p filtering.
+
+        Condition checking is done OUTSIDE the compiled function to avoid
+        graph breaks from .any() calls.
+        """
+        # NOTE(review): repetition_penalties and input_ids are accepted but
+        # not used in this body — presumably applied by the caller; confirm.
+        # Apply temperature
+        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
+
+        logits = apply_top_k_top_p(
+            logits,
+            top_ks,
+            top_ps,
+        )
+        if _SAMPLER_NUMERIC_GUARD:
+            # Optional env-gated guard: scrub NaNs and force a valid token on
+            # rows where every logit became -inf after filtering.
+            logits = torch.nan_to_num(logits, nan=float("-inf"))
+            invalid_rows = ~torch.isfinite(logits).any(dim=-1)
+            if invalid_rows.any():
+                logits[invalid_rows, 0] = 0.0
+        probs = torch.softmax(logits, dim=-1)
+        # Dividing probs by Exp(1) noise and taking argmax draws from the
+        # categorical distribution (exponential-race / Gumbel-style trick);
+        # clamp_min_ avoids division by zero.
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
+        return sample_tokens
diff --git a/Wan2GP/shared/llm_engines/nanovllm/llm.py b/Wan2GP/shared/llm_engines/nanovllm/llm.py
new file mode 100644
index 000000000..4f51a44f1
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/llm.py
@@ -0,0 +1,5 @@
+from nanovllm.engine.llm_engine import LLMEngine
+
+
+class LLM(LLMEngine):
+    """Public-facing alias of LLMEngine (mirrors vLLM's `LLM` entry point)."""
+    pass
diff --git a/Wan2GP/shared/llm_engines/nanovllm/models/qwen3.py b/Wan2GP/shared/llm_engines/nanovllm/models/qwen3.py
new file mode 100644
index 000000000..3f1c84500
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/models/qwen3.py
@@ -0,0 +1,235 @@
+import torch
+from torch import nn
+import torch.distributed as dist
+from transformers import Qwen3Config
+
+from nanovllm.layers.activation import SiluAndMul
+from nanovllm.layers.attention import Attention
+from nanovllm.layers.layernorm import RMSNorm
+from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
+from nanovllm.layers.rotary_embedding import get_rope
+from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
+
+
+def _get_tp_size():
+    """Return the tensor-parallel world size, or 1 when torch.distributed
+    is unavailable or not initialized (single-process run)."""
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    return 1
+
+
+class Qwen3Attention(nn.Module):
+    """Multi-head attention block for Qwen3 with fused tensor-parallel QKV
+    projection, rotary embeddings and (Qwen3-style) per-head q/k RMSNorm."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: int | None = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        rope_scaling: tuple | None = None,
+    ) -> None:
+        super().__init__()
+        tp_size = _get_tp_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        # Per-rank widths of the Q and K/V slices inside the fused projection.
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim ** -0.5
+        self.qkv_bias = qkv_bias
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
+        # Qwen3-style checkpoints (no qkv bias) carry per-head q/k RMSNorm;
+        # biased (Qwen2-style) configs do not.
+        if not self.qkv_bias:
+            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Run fused QKV projection, rotary embedding, attention and the
+        output projection for the given token positions."""
+        qkv = self.qkv_proj(hidden_states)
+        # Split the fused projection back into per-head q, k, v.
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q = q.view(-1, self.num_heads, self.head_dim)
+        k = k.view(-1, self.num_kv_heads, self.head_dim)
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+        if not self.qkv_bias:
+            # Per-head normalization before RoPE (Qwen3 convention).
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+        q, k = self.rotary_emb(positions, q, k)
+        o = self.attn(q, k, v)
+        output = self.o_proj(o.flatten(1, -1))
+        return output
+
+
+class Qwen3MLP(nn.Module):
+    """Gated SwiGLU-style MLP: fused gate/up column-parallel projection,
+    SiLU-and-multiply activation, then a row-parallel down projection."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ) -> None:
+        super().__init__()
+        # gate_proj and up_proj are merged into one column-parallel matmul.
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+        )
+        # Only SiLU is supported by the fused activation below.
+        assert hidden_act == "silu"
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        """Apply gate/up projection, fused SiLU-and-mul, and down projection."""
+        gate_up = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x = self.down_proj(x)
+        return x
+
+
+class Qwen3DecoderLayer(nn.Module):
+    """One Qwen3 transformer layer: pre-norm attention and MLP with a fused
+    residual stream threaded through the RMSNorm calls."""
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+    ) -> None:
+        super().__init__()
+        self.self_attn = Qwen3Attention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=config.max_position_embeddings,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, 'attention_bias', True),
+            head_dim=getattr(config, 'head_dim', None),
+            rope_theta=getattr(config, "rope_theta", 1000000),
+            rope_scaling=getattr(config, "rope_scaling", None),
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Run one decoder layer; returns (hidden_states, residual) so the
+        residual add can be fused into the next layer's norm."""
+        if residual is None:
+            # First layer: seed the residual stream with the raw input.
+            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
+        else:
+            # RMSNorm here presumably fuses the residual add — confirm in layernorm.py.
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(positions, hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen3Model(nn.Module):
+    """Qwen3 transformer backbone: token embedding, decoder layers and a
+    final RMSNorm."""
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+    ) -> None:
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embed `input_ids`, run every decoder layer (threading the residual
+        stream) and apply the final norm."""
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(positions, hidden_states, residual)
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class Qwen3ForCausalLM(nn.Module):
+    """Causal-LM head over Qwen3Model with weight-loading metadata.
+
+    `packed_modules_mapping` tells the loader how separate checkpoint tensors
+    (q/k/v, gate/up) map onto the fused tensor-parallel modules and which
+    shard id each one fills.
+    """
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen3Config
+    ) -> None:
+        super().__init__()
+        self.model = Qwen3Model(config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if config.tie_word_embeddings:
+            # Share storage between the LM head and the input embedding.
+            self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    # Proxy attributes for weight loading compatibility
+    # Some model weights use "embed_tokens" instead of "model.embed_tokens"
+    @property
+    def embed_tokens(self):
+        return self.model.embed_tokens
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def norm(self):
+        return self.model.norm
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return the backbone's final hidden states (logits are computed
+        separately via compute_logits)."""
+        return self.model(input_ids, positions)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Project hidden states to vocabulary logits via the LM head."""
+        return self.lm_head(hidden_states)
diff --git a/Wan2GP/shared/llm_engines/nanovllm/sampling_params.py b/Wan2GP/shared/llm_engines/nanovllm/sampling_params.py
new file mode 100644
index 000000000..291a1c293
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/sampling_params.py
@@ -0,0 +1,32 @@
+from dataclasses import dataclass, field
+from typing import Optional, Callable, Any
+
+
+@dataclass
+class SamplingParams:
+    """Per-request sampling configuration, validated in __post_init__.
+
+    Greedy decoding (temperature ~ 0) is deliberately rejected; callers are
+    expected to use a small positive temperature instead.
+    """
+    temperature: float = 1.0
+    max_tokens: int = 64
+    ignore_eos: bool = False
+    cfg_scale: float = 1.0  # CFG guidance scale. When > 1.0, applies classifier-free guidance
+    top_k: Optional[int] = None  # Top-k sampling: consider only top k tokens
+    top_p: Optional[float] = None  # Top-p (nucleus) sampling: consider tokens with cumulative probability <= top_p
+    repetition_penalty: float = 1.0  # Repetition penalty: >1.0 reduces repetition, <1.0 increases it
+    # Optional logits processor for constrained decoding
+    # Should be a callable with signature: (input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor
+    logits_processor: Optional[Any] = field(default=None, repr=False)
+    # Optional callback to update processor state after each token
+    # Should be a callable with signature: (token_id: int) -> None
+    logits_processor_update_state: Optional[Callable[[int], None]] = field(default=None, repr=False)
+    # Optional additive logits bias (shape [vocab_size]) applied each decode step.
+    logits_bias: Optional[Any] = field(default=None, repr=False)
+    # Optional RNG seed for deterministic sampling.
+    seed: Optional[int] = None
+
+    def __post_init__(self):
+        # Validate ranges eagerly so misconfiguration fails at construction.
+        assert self.temperature > 1e-10, "greedy sampling is not permitted"
+        assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
+        if self.top_k is not None:
+            assert self.top_k > 0, "top_k must be > 0"
+        if self.top_p is not None:
+            assert 0.0 < self.top_p <= 1.0, "top_p must be in (0.0, 1.0]"
+        assert self.repetition_penalty > 0.0, "repetition_penalty must be > 0.0"
diff --git a/Wan2GP/shared/llm_engines/nanovllm/utils/context.py b/Wan2GP/shared/llm_engines/nanovllm/utils/context.py
new file mode 100644
index 000000000..2281888f8
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/utils/context.py
@@ -0,0 +1,27 @@
+from dataclasses import dataclass
+import torch
+
+
+@dataclass
+class Context:
+    """Per-step attention metadata shared via module-level state.
+
+    Holds the varlen/paged-attention inputs for the current forward pass
+    (prefill vs decode); all fields default to the empty/decode-reset state.
+    """
+    is_prefill: bool = False
+    # Cumulative sequence lengths for queries/keys (varlen attention inputs).
+    cu_seqlens_q: torch.Tensor | None = None
+    cu_seqlens_k: torch.Tensor | None = None
+    max_seqlen_q: int = 0
+    max_seqlen_k: int = 0
+    # KV-cache slot indices for the tokens in this batch.
+    slot_mapping: torch.Tensor | None = None
+    context_lens: torch.Tensor | None = None
+    # Per-sequence block tables for paged KV cache lookup.
+    block_tables: torch.Tensor | None = None
+
+_CONTEXT = Context()
+
+def get_context():
+    """Return the current module-level Context (set by set_context)."""
+    return _CONTEXT
+
+def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
+    """Replace the module-level Context with a new one for the next forward pass."""
+    global _CONTEXT
+    _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
+
+def reset_context():
+    """Reset the module-level Context to its default (empty) state."""
+    global _CONTEXT
+    _CONTEXT = Context()
diff --git a/Wan2GP/shared/llm_engines/nanovllm/utils/loader.py b/Wan2GP/shared/llm_engines/nanovllm/utils/loader.py
new file mode 100644
index 000000000..17ac80004
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/utils/loader.py
@@ -0,0 +1,221 @@
+import os
+from glob import glob
+import torch
+from torch import nn
+from safetensors import safe_open
+
+
+def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
+    """Fallback loader: copy the checkpoint tensor into the parameter as-is."""
+    param.data.copy_(loaded_weight)
+
+
+def _get_parameter_safe(model: nn.Module, weight_name: str):
+    """
+    Try to get parameter from model, handling name mismatches.
+
+    Some models have nested structure (e.g., Qwen3ForCausalLM has model.embed_tokens)
+    but weight files may have flat names (embed_tokens.weight).
+
+    Returns the matching nn.Parameter, or None if no variant of the name
+    resolves (callers log a warning and skip the weight).
+    """
+    # Try direct access first
+    try:
+        return model.get_parameter(weight_name)
+    except AttributeError:
+        pass
+
+    # Try with 'model.' prefix (for nested model structure)
+    try:
+        prefixed_name = f"model.{weight_name}"
+        return model.get_parameter(prefixed_name)
+    except AttributeError:
+        pass
+
+    # Try removing 'model.' prefix
+    if weight_name.startswith("model."):
+        try:
+            unprefixed_name = weight_name[6:]  # Remove 'model.' prefix
+            return model.get_parameter(unprefixed_name)
+        except AttributeError:
+            pass
+
+    return None
+
+
+def _get_submodule_safe(model: nn.Module, module_name: str):
+    """Resolve a submodule by name, trying the same flat/'model.'-prefixed
+    name variants as _get_parameter_safe. Returns None when nothing matches."""
+    try:
+        return model.get_submodule(module_name)
+    except Exception:
+        pass
+    try:
+        return model.get_submodule(f"model.{module_name}")
+    except Exception:
+        pass
+    if module_name.startswith("model."):
+        try:
+            return model.get_submodule(module_name[6:])
+        except Exception:
+            pass
+    return None
+
+
+def _list_safetensor_files(path: str) -> list[str]:
+    """Return the .safetensors files at `path` (a single file or a directory)."""
+    if os.path.isfile(path) and path.endswith(".safetensors"):
+        return [path]
+    return glob(os.path.join(path, "*.safetensors"))
+
+
+class WeightStore:
+    """Index over a checkpoint's .safetensors files with two access modes.
+
+    "lazy" (default): keep file handles open and read tensors on demand.
+    "pinned": preload every tensor into pinned CPU memory up front (faster
+    H2D copies at the cost of host RAM).
+
+    Also detects quanto-int8 checkpoints by the presence of any
+    ".weight._data" key.
+    """
+    def __init__(self, path: str, mode: str = "lazy"):
+        self.path = path
+        self.mode = (mode or "lazy").lower()
+        self.files = _list_safetensor_files(path)
+        if not self.files:
+            raise FileNotFoundError(f"No .safetensors files found in {path}")
+        self._file_handles = {}
+        # key -> file containing it; built once from all files' headers.
+        self._weight_to_file = {}
+        self._pinned_weights = {}
+        self.is_quanto_int8 = False
+        for file in self.files:
+            with safe_open(file, "pt", "cpu") as f:
+                for key in f.keys():
+                    self._weight_to_file[key] = file
+                    if key.endswith(".weight._data"):
+                        self.is_quanto_int8 = True
+        if self.mode == "pinned":
+            self._preload_pinned()
+
+    def _get_handle(self, file_path: str):
+        """Return a cached safe_open handle for `file_path`, opening it lazily."""
+        handle = self._file_handles.get(file_path)
+        if handle is None:
+            handle = safe_open(file_path, "pt", "cpu")
+            self._file_handles[file_path] = handle
+        return handle
+
+    def _preload_pinned(self):
+        """Read every tensor into pinned host memory (mode == "pinned")."""
+        for key, file in self._weight_to_file.items():
+            handle = self._get_handle(file)
+            tensor = handle.get_tensor(key)
+            if tensor.device.type != "cpu":
+                tensor = tensor.cpu()
+            if not tensor.is_pinned():
+                tensor = tensor.pin_memory()
+            self._pinned_weights[key] = tensor
+
+    def get_tensor(self, key: str) -> torch.Tensor:
+        """Return the tensor for `key` from the pinned cache or from disk."""
+        if self.mode == "pinned":
+            return self._pinned_weights[key]
+        handle = self._get_handle(self._weight_to_file[key])
+        return handle.get_tensor(key)
+
+
+def load_model(model: nn.Module, path: str, weight_store: WeightStore | None = None):
+    """Load all checkpoint tensors into `model`, then finalize quantized modules.
+
+    With no `weight_store`, safetensors files under `path` are streamed
+    directly; otherwise tensors come through the store (which may serve them
+    from pinned memory). Packed q/k/v and gate/up tensors are routed via the
+    model's `packed_modules_mapping`.
+    """
+    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
+    if weight_store is None:
+        safetensor_files = _list_safetensor_files(path)
+        if not safetensor_files:
+            raise FileNotFoundError(f"No .safetensors files found in {path}")
+        for file in safetensor_files:
+            with safe_open(file, "pt", "cpu") as f:
+                for weight_name in f.keys():
+                    tensor = f.get_tensor(weight_name)
+                    _apply_weight(model, packed_modules_mapping, weight_name, tensor)
+        _finalize_quantized_modules(model)
+        return
+
+    for weight_name in weight_store._weight_to_file.keys():
+        tensor = weight_store.get_tensor(weight_name)
+        _apply_weight(model, packed_modules_mapping, weight_name, tensor)
+    _finalize_quantized_modules(model)
+
+
+def _apply_weight(model: nn.Module, packed_modules_mapping, weight_name: str, tensor: torch.Tensor):
+    """Route one checkpoint tensor to the right parameter or quant loader.
+
+    Handles two orthogonal concerns:
+      * quanto-int8 suffixes (".weight._data", ".weight._scale",
+        ".input_scale", ".output_scale") are dispatched to the owning
+        module's quant_* loader methods instead of a plain parameter copy;
+      * packed-module names (q_proj/k_proj/... ) are renamed to their fused
+        module (qkv_proj/...) and carry a shard id for the slice to fill.
+    Missing parameters/modules/loaders are logged and skipped, not fatal.
+    """
+    # Classify the quantization suffix, if any.
+    quant_suffix = None
+    if weight_name.endswith(".weight._data"):
+        quant_suffix = "qdata"
+    elif weight_name.endswith(".weight._scale"):
+        quant_suffix = "qscale"
+    elif weight_name.endswith(".input_scale"):
+        quant_suffix = "input_scale"
+    elif weight_name.endswith(".output_scale"):
+        quant_suffix = "output_scale"
+
+    # Packed-module path: rename e.g. "...q_proj..." -> "...qkv_proj..." and
+    # remember which shard of the fused tensor this checkpoint tensor fills.
+    for k in packed_modules_mapping:
+        if k in weight_name:
+            v, shard_id = packed_modules_mapping[k]
+            mapped_name = weight_name.replace(k, v)
+            if quant_suffix is not None:
+                # Strip the quant suffix to get the owning module's name.
+                module_name = mapped_name
+                if quant_suffix == "qdata":
+                    module_name = mapped_name[: -len(".weight._data")]
+                elif quant_suffix == "qscale":
+                    module_name = mapped_name[: -len(".weight._scale")]
+                elif quant_suffix == "input_scale":
+                    module_name = mapped_name[: -len(".input_scale")]
+                elif quant_suffix == "output_scale":
+                    module_name = mapped_name[: -len(".output_scale")]
+                module = _get_submodule_safe(model, module_name)
+                if module is None:
+                    print(f"[loader] Warning: Module not found: {module_name}")
+                    return
+                if quant_suffix == "qdata":
+                    loader_fn = getattr(module, "quant_weight_data_loader", None)
+                elif quant_suffix == "qscale":
+                    loader_fn = getattr(module, "quant_weight_scale_loader", None)
+                elif quant_suffix == "input_scale":
+                    loader_fn = getattr(module, "quant_input_scale_loader", None)
+                else:
+                    loader_fn = getattr(module, "quant_output_scale_loader", None)
+                if loader_fn is None:
+                    print(f"[loader] Warning: Quant loader not found on module: {module_name}")
+                    return
+                # Only data/scale loaders are shard-aware; io-scale loaders are not.
+                if quant_suffix in ("qdata", "qscale"):
+                    loader_fn(tensor, shard_id)
+                else:
+                    loader_fn(tensor)
+                return
+            param_name = mapped_name
+            param = _get_parameter_safe(model, param_name)
+            if param is None:
+                print(f"[loader] Warning: Parameter not found: {param_name}")
+                return
+            # Packed params are expected to carry a custom weight_loader.
+            weight_loader = getattr(param, "weight_loader")
+            weight_loader(param, tensor, shard_id)
+            return
+    # Non-packed quantized tensor: same suffix handling without renaming.
+    if quant_suffix is not None:
+        if quant_suffix == "qdata":
+            module_name = weight_name[: -len(".weight._data")]
+        elif quant_suffix == "qscale":
+            module_name = weight_name[: -len(".weight._scale")]
+        elif quant_suffix == "input_scale":
+            module_name = weight_name[: -len(".input_scale")]
+        else:
+            module_name = weight_name[: -len(".output_scale")]
+        module = _get_submodule_safe(model, module_name)
+        if module is None:
+            print(f"[loader] Warning: Module not found: {module_name}")
+            return
+        if quant_suffix == "qdata":
+            loader_fn = getattr(module, "quant_weight_data_loader", None)
+        elif quant_suffix == "qscale":
+            loader_fn = getattr(module, "quant_weight_scale_loader", None)
+        elif quant_suffix == "input_scale":
+            loader_fn = getattr(module, "quant_input_scale_loader", None)
+        else:
+            loader_fn = getattr(module, "quant_output_scale_loader", None)
+        if loader_fn is None:
+            print(f"[loader] Warning: Quant loader not found on module: {module_name}")
+            return
+        loader_fn(tensor)
+        return
+    # Plain parameter: use its custom weight_loader if present, else copy.
+    param = _get_parameter_safe(model, weight_name)
+    if param is None:
+        print(f"[loader] Warning: Parameter not found: {weight_name}")
+        return
+    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+    weight_loader(param, tensor)
+
+
+def _finalize_quantized_modules(model: nn.Module):
+    """Call `finalize_quantized()` on every module that defines it, letting
+    quantized layers do post-load setup once all shards have arrived."""
+    for module in model.modules():
+        finalize = getattr(module, "finalize_quantized", None)
+        if callable(finalize):
+            finalize()
diff --git a/Wan2GP/shared/llm_engines/nanovllm/vllm_support.py b/Wan2GP/shared/llm_engines/nanovllm/vllm_support.py
new file mode 100644
index 000000000..f8c195ec9
--- /dev/null
+++ b/Wan2GP/shared/llm_engines/nanovllm/vllm_support.py
@@ -0,0 +1,139 @@
+_PROBE_CACHE = None
+_WARNED_REQUESTED_VLLM_UNAVAILABLE = False
+
+
+def _check_triton():
+    """Probe that triton and triton.language import cleanly.
+
+    Returns (ok, message); message carries the import error on failure.
+    """
+    try:
+        import triton  # noqa: F401
+        import triton.language as tl  # noqa: F401
+    except Exception as exc:
+        return False, f"Triton import failed: {exc}"
+    return True, "ok"
+
+
+def _check_flash_attention_2():
+    """Probe that flash_attn imports, exposes the varlen/kvcache entry points
+    used by the attention layer, and is major version >= 2.
+
+    Returns (ok, message).
+    """
+    try:
+        import flash_attn
+        from flash_attn import flash_attn_varlen_func  # noqa: F401
+        from flash_attn import flash_attn_with_kvcache  # noqa: F401
+        version = str(getattr(flash_attn, "__version__", ""))
+    except Exception as exc:
+        return False, f"FlashAttention import failed: {exc}"
+
+    # Parse the major version leniently; an unparseable version is accepted.
+    major = None
+    if len(version) > 0:
+        try:
+            major = int(version.split(".", 1)[0])
+        except Exception:
+            major = None
+    if major is not None and major < 2:
+        return False, f"FlashAttention major version is {major}, expected >= 2"
+    return True, "ok"
+
+
+def _load_linear_module():
+    """Import layers/linear.py by file path under a probe-only module name.
+
+    Loading by path avoids importing the whole nanovllm package just to run
+    the int8 kernel smoke test.
+    """
+    import importlib.util
+    import os
+
+    base_dir = os.path.dirname(os.path.abspath(__file__))
+    linear_path = os.path.join(base_dir, "layers", "linear.py")
+    if not os.path.isfile(linear_path):
+        raise RuntimeError(f"Missing nanovllm linear kernel file: {linear_path}")
+
+    spec = importlib.util.spec_from_file_location("nanovllm_linear_probe", linear_path)
+    if spec is None or spec.loader is None:
+        raise RuntimeError("Unable to build import spec for nanovllm linear kernel probe")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+def _check_triton_int8_kernel():
+    """Smoke-test the Triton fused int8 matmul kernel on a tiny problem.
+
+    Verifies CUDA availability, the kernel entry point, output shape (2, 32)
+    and finiteness. Returns (ok, message); any exception is reported as a
+    failure rather than propagated.
+    """
+    import inspect
+    import torch
+
+    if not torch.cuda.is_available():
+        return False, "CUDA is not available"
+
+    try:
+        linear_module = _load_linear_module()
+        if not getattr(linear_module, "_TRITON_AVAILABLE", False):
+            return False, "nanovllm Triton path is disabled"
+
+        run_kernel = getattr(linear_module, "_run_triton_fused_int8_mm", None)
+        if run_kernel is None:
+            return False, "nanovllm Triton int8 kernel entrypoint is missing"
+
+        # Tiny deterministic-shape problem: (2, 64) x (64, 32) int8 weights.
+        device = torch.device("cuda")
+        x = torch.randn((2, 64), device=device, dtype=torch.bfloat16)
+        qweight_t = torch.randint(-127, 128, (64, 32), device=device, dtype=torch.int8)
+        qweight_scale = torch.ones((32,), device=device, dtype=torch.float32)
+        # Support both legacy (4-arg) and current (3-arg) probe signatures.
+        param_count = len(inspect.signature(run_kernel).parameters)
+        if param_count >= 4:
+            out = run_kernel(x, qweight_t, qweight_scale, 0.01)
+        else:
+            out = run_kernel(x, qweight_t, qweight_scale)
+        torch.cuda.synchronize()
+
+        if tuple(out.shape) != (2, 32):
+            return False, f"Unexpected kernel output shape: {tuple(out.shape)}"
+        if not torch.isfinite(out).all().item():
+            return False, "Kernel output contains non-finite values"
+    except Exception as exc:
+        return False, f"Triton int8 kernel smoke test failed: {exc}"
+
+    return True, "ok"
+
+
+def probe_vllm_runtime(force=False):
+    """Run all runtime checks (triton, flash-attn 2, int8 kernel) once.
+
+    Results are cached at module level; pass force=True to re-probe.
+    Returns a dict with "supported", "preferred_engine" ("vllm"/"legacy")
+    and per-check {"ok", "message"} details. A copy is returned so callers
+    cannot mutate the cache.
+    """
+    global _PROBE_CACHE
+    if _PROBE_CACHE is not None and not force:
+        return _PROBE_CACHE.copy()
+
+    checks = {}
+
+    triton_ok, triton_msg = _check_triton()
+    checks["triton"] = {"ok": triton_ok, "message": triton_msg}
+
+    flash_ok, flash_msg = _check_flash_attention_2()
+    checks["flash_attention_2"] = {"ok": flash_ok, "message": flash_msg}
+
+    kernel_ok, kernel_msg = _check_triton_int8_kernel()
+    checks["triton_int8_kernel"] = {"ok": kernel_ok, "message": kernel_msg}
+
+    # All three checks must pass for the vllm engine to be usable.
+    supported = triton_ok and flash_ok and kernel_ok
+    result = {
+        "supported": supported,
+        "preferred_engine": "vllm" if supported else "legacy",
+        "checks": checks,
+    }
+
+    _PROBE_CACHE = result.copy()
+    return result
+
+
+def resolve_lm_decoder_engine(requested_engine):
+    """Map a requested decoder engine name to the one that will actually run.
+
+    "vllm" falls back to "legacy" when the runtime probe fails (logging the
+    failing checks once per process); "" means auto-select by probe result;
+    any other value is passed through unchanged.
+    """
+    probe_result = probe_vllm_runtime()
+    supported = bool(probe_result.get("supported", False))
+    if requested_engine == "vllm":
+        if supported:
+            return "vllm"
+        global _WARNED_REQUESTED_VLLM_UNAVAILABLE
+        if not _WARNED_REQUESTED_VLLM_UNAVAILABLE:
+            # Summarize every failing check into one truncated log line.
+            checks = probe_result.get("checks", {})
+            reasons = []
+            if isinstance(checks, dict):
+                for check_name, check_data in checks.items():
+                    if isinstance(check_data, dict) and not check_data.get("ok", False):
+                        msg = str(check_data.get("message", "failed")).replace("\n", " ").strip()
+                        if len(msg) > 220:
+                            msg = msg[:220] + "..."
+                        reasons.append(f"{check_name}={msg}")
+            reason_text = "; ".join(reasons) if len(reasons) > 0 else "unknown reason"
+            print(f"[LM] Requested decoder engine 'vllm' is unavailable at startup ({reason_text}).")
+            _WARNED_REQUESTED_VLLM_UNAVAILABLE = True
+        return "legacy"
+    if requested_engine == "":
+        return "vllm" if supported else "legacy"
+    return requested_engine
diff --git a/Wan2GP/shared/prompt_enhancer/__init__.py b/Wan2GP/shared/prompt_enhancer/__init__.py
new file mode 100644
index 000000000..198074eb2
--- /dev/null
+++ b/Wan2GP/shared/prompt_enhancer/__init__.py
@@ -0,0 +1,3 @@
+from .loader import load_florence2
+
+__all__ = ["load_florence2"]
diff --git a/Wan2GP/shared/prompt_enhancer/florence2/__init__.py b/Wan2GP/shared/prompt_enhancer/florence2/__init__.py
new file mode 100644
index 000000000..cff5ddda4
--- /dev/null
+++ b/Wan2GP/shared/prompt_enhancer/florence2/__init__.py
@@ -0,0 +1,11 @@
+from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig
+from .modeling_florence2 import Florence2ForConditionalGeneration
+from .processing_florence2 import Florence2Processor
+
+__all__ = [
+ "Florence2Config",
+ "Florence2LanguageConfig",
+ "Florence2VisionConfig",
+ "Florence2ForConditionalGeneration",
+ "Florence2Processor",
+]
diff --git a/Wan2GP/shared/prompt_enhancer/florence2/configuration_florence2.py b/Wan2GP/shared/prompt_enhancer/florence2/configuration_florence2.py
new file mode 100644
index 000000000..b4ca3f132
--- /dev/null
+++ b/Wan2GP/shared/prompt_enhancer/florence2/configuration_florence2.py
@@ -0,0 +1,339 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+""" Florence-2 configuration"""
+
+from typing import Optional
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+class Florence2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
+ The dropout rate of the drop path layer.
+ patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
+ The patch size of the image.
+ patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
+ The patch stride of the image.
+ patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
+ The patch padding of the image.
+ patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
+ Whether to apply layer normalization before the patch embedding layer.
+ enable_checkpoint (`bool`, *optional*, defaults to False):
+ Whether to enable checkpointing.
+ dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
+ The dimension of the embedding layer.
+ num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
+ The number of attention heads.
+ num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
+ The number of groups.
+ depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
+ The depth of the model.
+ window_size (`int`, *optional*, defaults to 12):
+ The window size of the model.
+ projection_dim (`int`, *optional*, defaults to 1024):
+ The dimension of the projection layer.
+ visual_temporal_embedding (`dict`, *optional*):
+ The configuration of the visual temporal embedding.
+ image_pos_embed (`dict`, *optional*):
+ The configuration of the image position embedding.
+ image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
+ The source of the image feature.
+ Example:
+
+ ```python
+ >>> from transformers import Florence2VisionConfig, Florence2VisionModel
+
+ >>> # Initializing a Florence2 Vision style configuration
+ >>> configuration = Florence2VisionConfig()
+
+ >>> # Initializing a model (with random weights)
+ >>> model = Florence2VisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "florence2_vision"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ drop_path_rate=0.1,
+ patch_size=[7, 3, 3, 3],
+ patch_stride=[4, 2, 2, 2],
+ patch_padding=[3, 1, 1, 1],
+ patch_prenorm=[False, True, True, True],
+ enable_checkpoint=False,
+ dim_embed=[256, 512, 1024, 2048],
+ num_heads=[8, 16, 32, 64],
+ num_groups=[8, 16, 32, 64],
+ depths=[1, 1, 9, 1],
+ window_size=12,
+ projection_dim=1024,
+ visual_temporal_embedding=None,
+ image_pos_embed=None,
+ image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
+ **kwargs,
+ ):
+ self.drop_path_rate = drop_path_rate
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+ self.patch_padding = patch_padding
+ self.patch_prenorm = patch_prenorm
+ self.enable_checkpoint = enable_checkpoint
+ self.dim_embed = dim_embed
+ self.num_heads = num_heads
+ self.num_groups = num_groups
+ self.depths = depths
+ self.window_size = window_size
+ self.projection_dim = projection_dim
+ self.visual_temporal_embedding = visual_temporal_embedding
+ self.image_pos_embed = image_pos_embed
+ self.image_feature_source = image_feature_source
+
+ super().__init__(**kwargs)
+
+
+
+class Florence2LanguageConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the BART
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 51289):
+ Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Florence2LanguageModel`].
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for classifier.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ num_labels (`int`, *optional*, defaults to 3):
+ The number of labels to use in [`Florence2LanguageForSequenceClassification`].
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
+
+ >>> # Initializing a Florence2 Language style configuration
+ >>> configuration = Florence2LanguageConfig()
+
+ >>> # Initializing a model (with random weights)
+    >>> model = Florence2LanguageModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "florence2_language"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=51289,
+ max_position_embeddings=1024,
+ encoder_layers=12,
+ encoder_ffn_dim=4096,
+ encoder_attention_heads=16,
+ decoder_layers=12,
+ decoder_ffn_dim=4096,
+ decoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ num_labels=3,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ is_encoder_decoder=True,
+ decoder_start_token_id=2,
+ forced_eos_token_id=2,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ super().__init__(
+ num_labels=num_labels,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ **kwargs,
+ )
+
+ # ensure backward compatibility for BART CNN models
+ forced_bos_token_id = getattr(self, "forced_bos_token_id", None)
+ if forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
+ self.forced_bos_token_id = self.bos_token_id
+ warnings.warn(
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
+ "The config can simply be saved and uploaded again to be fixed."
+ )
+
+class Florence2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
+ Florence-2 model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Florence2VisionConfig`, *optional*):
+ Custom vision config or dict
+ text_config (`Union[AutoConfig, dict]`, *optional*):
+ The config object of the text backbone.
+ ignore_index (`int`, *optional*, defaults to -100):
+ The ignore index for the loss function.
+ vocab_size (`int`, *optional*, defaults to 51289):
+ Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
+ projection_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the multimodal projection space.
+
+ Example:
+
+ ```python
+ >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
+
+ >>> # Initializing a clip-like vision config
+ >>> vision_config = CLIPVisionConfig()
+
+ >>> # Initializing a Bart config
+ >>> text_config = BartConfig()
+
+ >>> # Initializing a Florence-2 configuration
+ >>> configuration = Florence2Config(vision_config, text_config)
+
+ >>> # Initializing a model from the florence-2 configuration
+ >>> model = Florence2ForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "florence2"
+ is_composition = False
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ ignore_index=-100,
+ vocab_size=51289,
+ projection_dim=1024,
+ **kwargs,
+ ):
+ self.ignore_index = ignore_index
+ self.vocab_size = vocab_size
+ self.projection_dim = projection_dim
+ if vision_config is not None:
+ vision_config = PretrainedConfig(**vision_config)
+ self.vision_config = vision_config
+ self.vocab_size = self.vocab_size
+
+ self.text_config = text_config
+ if text_config is not None:
+ self.text_config = Florence2LanguageConfig(**text_config)
+
+
+ super().__init__(**kwargs)
diff --git a/Wan2GP/shared/prompt_enhancer/florence2/image_processing_florence2.py b/Wan2GP/shared/prompt_enhancer/florence2/image_processing_florence2.py
new file mode 100644
index 000000000..5046910bf
--- /dev/null
+++ b/Wan2GP/shared/prompt_enhancer/florence2/image_processing_florence2.py
@@ -0,0 +1,222 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image, ImageOps
+import torch
+from transformers.image_processing_base import ImageProcessingMixin
+
+
+def _as_list(val):
+ if isinstance(val, (list, tuple)):
+ return list(val)
+ return [val]
+
+
+def _to_numpy(image: Any) -> np.ndarray:
+ if isinstance(image, np.ndarray):
+ return image
+ if torch.is_tensor(image):
+ return image.detach().cpu().numpy()
+ if isinstance(image, Image.Image):
+ return np.array(image)
+ raise TypeError(f"Unsupported image type: {type(image)}")
+
+
+def _infer_input_format(arr: np.ndarray) -> str:
+ if arr.ndim == 3 and arr.shape[0] in (1, 3) and arr.shape[-1] not in (1, 3):
+ return "channels_first"
+ return "channels_last"
+
+
+def _to_channels_last(arr: np.ndarray, input_format: str) -> np.ndarray:
+ if input_format == "channels_first":
+ return np.transpose(arr, (1, 2, 0))
+ return arr
+
+
+def _to_channels_first(arr: np.ndarray, input_format: str) -> np.ndarray:
+ if input_format == "channels_last":
+ return np.transpose(arr, (2, 0, 1))
+ return arr
+
+
+def _compute_resize_size(image_size: Tuple[int, int], size: Dict[str, int]) -> Tuple[int, int]:
+ height, width = image_size
+ if "height" in size and "width" in size:
+ return int(size["height"]), int(size["width"])
+ if "shortest_edge" in size:
+ target = int(size["shortest_edge"])
+ if height <= width:
+ new_h = target
+ new_w = int(round(width * target / max(height, 1)))
+ else:
+ new_w = target
+ new_h = int(round(height * target / max(width, 1)))
+ return new_h, new_w
+ raise ValueError(f"Unsupported size dict: {size}")
+
+
+def _resolve_resample(resample: Optional[int]) -> int:
+ if resample is None:
+ return Image.BICUBIC
+ try:
+ return Image.Resampling(resample)
+ except Exception:
+ return resample
+
+
+def _center_crop_pil(image: Image.Image, crop_size: Dict[str, int]) -> Image.Image:
+ target_h = int(crop_size["height"])
+ target_w = int(crop_size["width"])
+ width, height = image.size
+ if width < target_w or height < target_h:
+ padded_w = max(width, target_w)
+ padded_h = max(height, target_h)
+ padded = Image.new(image.mode, (padded_w, padded_h), (0, 0, 0))
+ padded.paste(image, ((padded_w - width) // 2, (padded_h - height) // 2))
+ image = padded
+ width, height = image.size
+ left = int(round((width - target_w) / 2.0))
+ top = int(round((height - target_h) / 2.0))
+ return image.crop((left, top, left + target_w, top + target_h))
+
+
+def _normalize_return_tensors(value: Optional[Union[str, Any]]) -> Optional[str]:
+ if value is None:
+ return None
+ if isinstance(value, str):
+ return value.lower()
+ name = getattr(value, "name", None)
+ if name:
+ return name.lower()
+ return str(value).lower()
+
+
+class Florence2ImageProcessorLite(ImageProcessingMixin):
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ image_seq_length: int,
+ do_resize: bool = True,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional[int] = None,
+ do_center_crop: bool = False,
+ crop_size: Optional[Dict[str, int]] = None,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[List[float]] = None,
+ image_std: Optional[List[float]] = None,
+ do_convert_rgb: Optional[bool] = True,
+ ) -> None:
+ super().__init__()
+ self.image_seq_length = int(image_seq_length)
+ self.do_resize = bool(do_resize)
+ self.size = size or {"height": 224, "width": 224}
+ self.resample = resample
+ self.do_center_crop = bool(do_center_crop)
+ self.crop_size = crop_size or {"height": 224, "width": 224}
+ self.do_rescale = bool(do_rescale)
+ self.rescale_factor = float(rescale_factor)
+ self.do_normalize = bool(do_normalize)
+ self.image_mean = image_mean or [0.485, 0.456, 0.406]
+ self.image_std = image_std or [0.229, 0.224, 0.225]
+ self.do_convert_rgb = do_convert_rgb
+
+ @classmethod
+ def from_preprocessor_config(cls, model_dir: Union[str, Path]) -> "Florence2ImageProcessorLite":
+ config_path = Path(model_dir) / "preprocessor_config.json"
+ if not config_path.exists():
+ raise FileNotFoundError(f"Missing Florence2 preprocessor_config.json in {model_dir}")
+ data = json.loads(config_path.read_text(encoding="utf-8"))
+ return cls(
+ image_seq_length=data.get("image_seq_length", 0),
+ do_resize=data.get("do_resize", True),
+ size=data.get("size") or data.get("crop_size") or {"height": 224, "width": 224},
+ resample=data.get("resample"),
+ do_center_crop=data.get("do_center_crop", False),
+ crop_size=data.get("crop_size") or data.get("size") or {"height": 224, "width": 224},
+ do_rescale=data.get("do_rescale", True),
+ rescale_factor=data.get("rescale_factor", 1 / 255),
+ do_normalize=data.get("do_normalize", True),
+ image_mean=data.get("image_mean"),
+ image_std=data.get("image_std"),
+ do_convert_rgb=data.get("do_convert_rgb"),
+ )
+
+ def __call__(
+ self,
+ images: Union[Image.Image, np.ndarray, torch.Tensor, List[Any]],
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional[int] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[Dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Iterable[float]] = None,
+ image_std: Optional[Iterable[float]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, Any]] = "pt",
+ data_format: Optional[str] = "channels_first",
+ input_data_format: Optional[str] = None,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ do_resize = self.do_resize if do_resize is None else do_resize
+ size = self.size if size is None else size
+ resample = self.resample if resample is None else resample
+ do_center_crop = self.do_center_crop if do_center_crop is None else do_center_crop
+ crop_size = self.crop_size if crop_size is None else crop_size
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
+ image_mean = list(self.image_mean if image_mean is None else image_mean)
+ image_std = list(self.image_std if image_std is None else image_std)
+ do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb
+
+ resample = _resolve_resample(resample)
+ want_torch = _normalize_return_tensors(return_tensors) in ("pt", "pytorch", "tensortype.pytorch")
+
+ processed: List[np.ndarray] = []
+ for image in _as_list(images):
+ if isinstance(image, Image.Image):
+ img = image
+ if do_convert_rgb:
+ img = ImageOps.exif_transpose(img).convert("RGB")
+ else:
+ arr = _to_numpy(image)
+ input_fmt = input_data_format or _infer_input_format(arr)
+ arr = _to_channels_last(arr, input_fmt)
+ img = Image.fromarray(arr.astype(np.uint8))
+ if do_convert_rgb:
+ img = img.convert("RGB")
+
+ if do_resize:
+ out_h, out_w = _compute_resize_size((img.size[1], img.size[0]), size)
+ img = img.resize((out_w, out_h), resample=resample)
+
+ if do_center_crop:
+ img = _center_crop_pil(img, crop_size)
+
+ arr = np.array(img).astype(np.float32)
+ if do_rescale:
+ arr = arr * float(rescale_factor)
+ if do_normalize:
+ mean = np.array(image_mean, dtype=np.float32)
+ std = np.array(image_std, dtype=np.float32)
+ arr = (arr - mean) / std
+
+ if data_format in ("channels_first", "first"):
+ arr = _to_channels_first(arr, "channels_last")
+ processed.append(arr)
+
+ batch = np.stack(processed, axis=0)
+ if want_torch:
+ batch = torch.from_numpy(batch).float()
+ return {"pixel_values": batch}
diff --git a/Wan2GP/shared/prompt_enhancer/florence2/modeling_florence2.py b/Wan2GP/shared/prompt_enhancer/florence2/modeling_florence2.py
new file mode 100644
index 000000000..631077358
--- /dev/null
+++ b/Wan2GP/shared/prompt_enhancer/florence2/modeling_florence2.py
@@ -0,0 +1,2911 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" PyTorch Florence-2 model."""
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import math
+import torch
+import torch.utils.checkpoint
+from torch import nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from torch.nn import CrossEntropyLoss
+from collections import OrderedDict
+from einops import rearrange
+from timm.layers import DropPath, trunc_normal_
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.generation import GenerationMixin
+from transformers.utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ logging,
+ replace_return_docstrings,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+)
+from .configuration_florence2 import Florence2Config
+from .configuration_florence2 import Florence2LanguageConfig
+from .configuration_florence2 import Florence2VisionConfig
+
+
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+
+
+if is_flash_attn_2_available():
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Florence2Config"
+
+class LearnedAbsolutePositionEmbedding2D(nn.Module):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, embedding_dim=256, num_pos=50):
+ super().__init__()
+ self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
+ self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2))
+
+ def forward(self, pixel_values):
+ """
+ pixel_values: (batch_size, height, width, num_channels)
+ returns: (batch_size, height, width, embedding_dim * 2)
+ """
+ if len(pixel_values.shape) != 4:
+ raise ValueError('pixel_values must be a 4D tensor')
+ height, width = pixel_values.shape[1:3]
+ width_values = torch.arange(width, device=pixel_values.device)
+ height_values = torch.arange(height, device=pixel_values.device)
+ x_emb = self.column_embeddings(width_values)
+ y_emb = self.row_embeddings(height_values)
+ # (height, width, embedding_dim * 2)
+ pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
+ # (embedding_dim * 2, height, width)
+ pos = pos.permute(2, 0, 1)
+ pos = pos.unsqueeze(0)
+ # (batch_size, embedding_dim * 2, height, width)
+ pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+ # (batch_size, height, width, embedding_dim * 2)
+ pos = pos.permute(0, 2, 3, 1)
+ return pos
+
+class PositionalEmbeddingCosine1D(nn.Module):
+ """
+ This class implements a very simple positional encoding. It follows closely
+ the encoder from the link below:
+ https://pytorch.org/tutorials/beginner/translation_transformer.html
+
+ Args:
+ embed_dim: The dimension of the embeddings.
+ dropout_prob: The dropout probability.
+ max_seq_len: The maximum length to precompute the positional encodings.
+ """
+ def __init__(
+ self,
+ embed_dim: int = 512,
+ max_seq_len: int = 1024) -> None:
+ super(PositionalEmbeddingCosine1D, self).__init__()
+ self.embed_dim = embed_dim
+ self.max_seq_len = max_seq_len
+ # Generate the sinusoidal arrays.
+ factor = math.log(10000)
+ denominator = torch.exp(
+ -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim)
+ # Matrix where rows correspond to a positional embedding as a function
+ # of the position index (i.e., the row index).
+ frequencies = \
+ torch.arange(0, self.max_seq_len) \
+ .reshape(self.max_seq_len, 1) * denominator
+ pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
+ # Populate uneven entries.
+ pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
+ pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
+ # Save the positional embeddings in a constant buffer.
+ self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
+
+ def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ seq_embeds: The sequence embeddings in order. Allowed size:
+ 1. [T, D], where T is the length of the sequence, and D is the
+ frame embedding dimension.
+ 2. [B, T, D], where B is the batch size and T and D are the
+ same as above.
+
+ Returns a tensor of with the same dimensions as the input: i.e.,
+ [1, T, D] or [T, D].
+ """
+ shape_len = len(seq_embeds.shape)
+ assert 2 <= shape_len <= 3
+ len_seq = seq_embeds.size(-2)
+ assert len_seq <= self.max_seq_len
+ pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]
+ # Adapt pre-computed positional embeddings to the input.
+ if shape_len == 3:
+ pos_embeds = pos_embeds.view(
+ (1, pos_embeds.size(0), pos_embeds.size(1)))
+ return pos_embeds
+
+
+class LearnedAbsolutePositionEmbedding1D(nn.Module):
+ """
+ Learnable absolute positional embeddings for 1D sequences.
+
+ Args:
+ embed_dim: The dimension of the embeddings.
+ max_seq_len: The maximum length to precompute the positional encodings.
+ """
+ def __init__(
+ self,
+ embedding_dim: int = 512,
+ num_pos: int = 1024) -> None:
+ super(LearnedAbsolutePositionEmbedding1D, self).__init__()
+ self.embeddings = nn.Embedding(num_pos, embedding_dim)
+ self.num_pos = num_pos
+
+ def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ seq_embeds: The sequence embeddings in order. Allowed size:
+ 1. [T, D], where T is the length of the sequence, and D is the
+ frame embedding dimension.
+ 2. [B, T, D], where B is the batch size and T and D are the
+ same as above.
+
+ Returns a tensor of with the same dimensions as the input: i.e.,
+ [1, T, D] or [T, D].
+ """
+ shape_len = len(seq_embeds.shape)
+ assert 2 <= shape_len <= 3
+ len_seq = seq_embeds.size(-2)
+ assert len_seq <= self.num_pos
+ # [T, D]
+ pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device))
+ # Adapt pre-computed positional embeddings to the input.
+ if shape_len == 3:
+ pos_embeds = pos_embeds.view(
+ (1, pos_embeds.size(0), pos_embeds.size(1)))
+ return pos_embeds
+
+
+
+class MySequential(nn.Sequential):
+ def forward(self, *inputs):
+ for module in self._modules.values():
+ if type(inputs) == tuple:
+ inputs = module(*inputs)
+ else:
+ inputs = module(inputs)
+ return inputs
+
+
+class PreNorm(nn.Module):
+ def __init__(self, norm, fn, drop_path=None):
+ super().__init__()
+ self.norm = norm
+ self.fn = fn
+ self.drop_path = drop_path
+
+ def forward(self, x, *args, **kwargs):
+ shortcut = x
+ if self.norm != None:
+ x, size = self.fn(self.norm(x), *args, **kwargs)
+ else:
+ x, size = self.fn(x, *args, **kwargs)
+
+ if self.drop_path:
+ x = self.drop_path(x)
+
+ x = shortcut + x
+
+ return x, size
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.net = nn.Sequential(OrderedDict([
+ ("fc1", nn.Linear(in_features, hidden_features)),
+ ("act", act_layer()),
+ ("fc2", nn.Linear(hidden_features, out_features))
+ ]))
+
+ def forward(self, x, size):
+ return self.net(x), size
+
+
+class DepthWiseConv2d(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ kernel_size,
+ padding,
+ stride,
+ bias=True,
+ ):
+ super().__init__()
+ self.dw = nn.Conv2d(
+ dim_in, dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=dim_in,
+ stride=stride,
+ bias=bias
+ )
+
+ def forward(self, x, size):
+ B, N, C = x.shape
+ H, W = size
+ assert N == H * W
+
+ x = self.dw(x.transpose(1, 2).view(B, C, H, W))
+ size = (x.size(-2), x.size(-1))
+ x = x.flatten(2).transpose(1, 2)
+ return x, size
+
+
+class ConvEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(
+ self,
+ patch_size=7,
+ in_chans=3,
+ embed_dim=64,
+ stride=4,
+ padding=2,
+ norm_layer=None,
+ pre_norm=True
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim,
+ kernel_size=patch_size,
+ stride=stride,
+ padding=padding
+ )
+
+ dim_norm = in_chans if pre_norm else embed_dim
+ self.norm = norm_layer(dim_norm) if norm_layer else None
+
+ self.pre_norm = pre_norm
+
+ def forward(self, x, size):
+ H, W = size
+ if len(x.size()) == 3:
+ if self.norm and self.pre_norm:
+ x = self.norm(x)
+ x = rearrange(
+ x, 'b (h w) c -> b c h w',
+ h=H, w=W
+ )
+
+ x = self.proj(x)
+
+ _, _, H, W = x.shape
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ if self.norm and not self.pre_norm:
+ x = self.norm(x)
+
+ return x, (H, W)
+
+
+class ChannelAttention(nn.Module):
+
+ def __init__(self, dim, groups=8, qkv_bias=True):
+ super().__init__()
+
+ self.groups = groups
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x, size):
+ B, N, C = x.shape
+
+ qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * (float(N) ** -0.5)
+ attention = q.transpose(-1, -2) @ k
+ attention = attention.softmax(dim=-1)
+ x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ return x, size
+
+
class ChannelBlock(nn.Module):
    """DaViT channel-attention block.

    Pipeline: optional depthwise conv -> channel attention -> optional
    depthwise conv -> MLP, with every stage wrapped in a PreNorm residual.
    The same stochastic-depth module is shared by the attention and FFN
    residual branches.
    """

    def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
                 drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        self.channel_attn = PreNorm(
            norm_layer(dim),
            ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
            path,
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
            path,
        )

    def forward(self, x, size):
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.channel_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)

        return x, size
+
+
def window_partition(x, window_size: int):
    """Split a (B, H, W, C) feature map into non-overlapping windows.

    H and W must be divisible by ``window_size``. Returns a tensor of
    shape (num_windows * B, window_size, window_size, C), windows ordered
    row-major within each batch element.
    """
    batch, height, width, channels = x.shape
    tiled = x.view(
        batch,
        height // window_size, window_size,
        width // window_size, window_size,
        channels,
    )
    windows = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, channels)
+
+
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    """Inverse of ``window_partition``: reassemble windows into (B, H, W, C).

    ``batch_size`` is passed explicitly rather than derived from
    ``windows.shape[0] / (H * W / window_size**2)`` because that division
    would be folded into a constant during ONNX export with dynamic axes.
    """
    grid = windows.view(
        batch_size,
        H // window_size, W // window_size,
        window_size, window_size,
        -1,
    )
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return merged.view(batch_size, H, W, -1)
+
+
class WindowAttention(nn.Module):
    """Non-shifted window multi-head self-attention (W-MSA).

    Tokens are reshaped to a (H, W) map, padded up to a multiple of
    ``window_size``, partitioned into windows, attended within each window
    independently, then merged back and cropped to the original size.
    No relative position bias is used.
    """

    def __init__(self, dim, num_heads, window_size, qkv_bias=True):

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # 1/sqrt(d_head) query scaling.
        self.scale = float(head_dim) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        # x: (B, H*W, C); size: (H, W). Returns same shapes.

        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # Pad right/bottom so H and W divide evenly into windows.
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        x = window_partition(x, self.window_size)
        x = x.view(-1, self.window_size * self.window_size, C)

        # W-MSA/SW-MSA
        # attn_windows = self.attn(x_windows)

        # Standard multi-head attention within each window; B_ is the
        # total number of windows across the batch.
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)

        # merge windows
        x = x.view(
            -1, self.window_size, self.window_size, C
        )
        x = window_reverse(x, B, self.window_size, Hp, Wp)

        # Drop the padding rows/cols added above, if any.
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        return x, size
+
+
class SpatialBlock(nn.Module):
    """DaViT spatial (window) attention block.

    Pipeline: optional depthwise conv -> window attention -> optional
    depthwise conv -> MLP, each stage wrapped in a PreNorm residual. One
    stochastic-depth module is shared by both residual branches.
    """

    def __init__(self, dim, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        self.window_attn = PreNorm(
            norm_layer(dim),
            WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
            path,
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
            path,
        )

    def forward(self, x, size):
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.window_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)
        return x, size
+
+
class DaViT(nn.Module):
    """ DaViT: Dual-Attention Transformer

    Each stage is a ConvEmbed downsampling followed by ``depths[i]`` pairs
    of (SpatialBlock, ChannelBlock): window attention alternating with
    group-wise channel attention.

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head. Default: 1000.
        depths (tuple(int)): Number of (spatial, channel) block pairs per stage. Default: (1, 1, 3, 1).
        patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
        patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
        patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
        patch_prenorm (tuple(bool)): If True, perform norm before convolution layer. Default: (False, False, False, False).
        embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
        num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (3, 6, 12, 24).
        num_groups (tuple(int)): Number of channel groups in different stages. Default: (3, 6, 12, 24).
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        enable_checkpoint (bool): If True, enable checkpointing. Default: False.
        conv_at_attn (bool): If True, perform depthwise convolution before attention layer. Default: True.
        conv_at_ffn (bool): If True, perform depthwise convolution before ffn layer. Default: True.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        num_stages = len(embed_dims)
        # Use pure Python math to stay meta-tensor safe during init.
        # Linear stochastic-depth ramp from 0 to drop_path_rate over all
        # sub-blocks (2 per depth unit: spatial + channel).
        n_drop = sum(depths) * 2
        if n_drop <= 1:
            dpr = [0.0]
        else:
            step = drop_path_rate / (n_drop - 1)
            dpr = [i * step for i in range(n_drop)]

        depth_offset = 0
        convs = []
        blocks = []
        for i in range(num_stages):
            # Stage i downsamples with a strided conv embedding, then runs
            # depths[i] spatial/channel block pairs at embed_dims[i].
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i]
            )
            convs.append(conv_embed)

            block = MySequential(
                *[
                    MySequential(OrderedDict([
                        (
                            'spatial_block', SpatialBlock(
                                embed_dims[i],
                                num_heads[i],
                                window_size,
                                drop_path_rate=dpr[depth_offset+j*2],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                            )
                        ),
                        (
                            'channel_block', ChannelBlock(
                                embed_dims[i],
                                num_groups[i],
                                drop_path_rate=dpr[depth_offset+j*2+1],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                            )
                        )
                    ])) for j in range(depths[i])
                ]
            )
            blocks.append(block)
            depth_offset += depths[i]*2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)

        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    @property
    def dim_out(self):
        # Channel width of the final stage (the feature dimension callers see).
        return self.embed_dims[-1]

    def _init_weights(self, m):
        # Truncated-normal linears, normal convs, unit LayerNorm/BatchNorm;
        # all biases zeroed.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.02)
            # Zero the bias only when the conv actually has one.
            for name, _ in m.named_parameters():
                if name in ['bias']:
                    nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward_features_unpool(self, x):
        """
        forward until avg pooling
        Args:
            x (_type_): input image tensor of shape (B, C, H, W)
        Returns token tensor of shape (B, N, embed_dims[-1]).
        """
        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            if self.enable_checkpoint:
                # Trade compute for memory via activation checkpointing.
                x, input_size = checkpoint.checkpoint(block, x, input_size)
            else:
                x, input_size = block(x, input_size)
        return x

    def forward_features(self, x):
        x = self.forward_features_unpool(x)

        # (B, N, C) -> (B, C, N) -> average over tokens -> (B, C, 1)
        x = self.avgpool(x.transpose(1, 2))
        # flatten to (B, C); note the final norm runs *after* pooling
        x = torch.flatten(x, 1)
        x = self.norms(x)

        return x

    def forward(self, x):
        # Pooled features -> classification head (identity if num_classes == 0).
        x = self.forward_features(x)
        x = self.head(x)
        return x

    @classmethod
    def from_config(cls, config):
        # Alternate constructor mapping a config object's fields onto the
        # constructor arguments (note: config.dim_embed -> embed_dims).
        return cls(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )
+
+
+
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    Builds decoder inputs from labels: position 0 becomes the decoder start
    token, positions 1..L-1 receive the original tokens 0..L-2, and every
    ignore-index (-100) entry is replaced with ``pad_token_id``.
    """
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # -100 is the loss ignore index; it must never reach the embedding table.
    shifted.masked_fill_(shifted == -100, pad_token_id)

    return shifted
+
+
class Florence2LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Florence2 is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """Look up embeddings for positions derived from (bsz, seqlen) ids.

        Only the shape of ``input_ids`` matters; positions start at
        ``past_key_values_length`` so incremental decoding continues where
        the cache left off.
        """
        batch, seq_len = input_ids.shape[:2]
        position_ids = torch.arange(
            past_key_values_length, past_key_values_length + seq_len,
            dtype=torch.long, device=self.weight.device,
        ).expand(batch, -1)

        return super().forward(position_ids + self.offset)
+
+
class Florence2ScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Constant multiplier applied to every looked-up embedding vector.
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        embeddings = super().forward(input_ids)
        return embeddings * self.embed_scale
+
+
class Florence2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (materialized-softmax) implementation. Acts as self-attention
    when `key_value_states` is None and as cross-attention otherwise.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[Florence2LanguageConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1/sqrt(d_head), folded into the query projection in forward().
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq, embed) -> (bsz, num_heads, seq, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns (attn_output, attn_weights_or_None, past_key_value_or_None);
        cached key/value tensors are laid out (bsz, num_heads, seq, head_dim).
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fold batch and head dims together for batched matmul (bmm).
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # Additive mask (large negative values at masked positions).
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            # Per-head multiplicative mask applied after softmax.
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
+
+
class Florence2FlashAttention2(Florence2Attention):
    """
    Florence2 flash attention module. This module inherits from `Florence2Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.

    NOTE(review): requires the optional `flash_attn` package (imported at
    module level under `is_flash_attn_2_available()`).
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq, embed) -> (bsz, seq, num_heads, head_dim) — the layout
        # flash-attn expects (heads NOT moved ahead of the sequence dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Florence2FlashAttention2 attention does not support output_attentions
        if output_attentions:
            raise ValueError("Florence2FlashAttention2 attention does not support output_attentions")

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        # get query proj
        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            # Cache stores (bsz, heads, seq, head_dim); transpose back to
            # the flash-attn (bsz, seq, heads, head_dim) layout.
            key_states = past_key_value[0].transpose(1, 2)
            value_states = past_key_value[1].transpose(1, 2)
        elif is_cross_attention:
            # cross_attentions
            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
        else:
            # self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

        # NOTE(review): kv_seq_len is computed but never used below — dead
        # code inherited from the upstream port.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Scatter the packed rows back to their padded positions.
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # Remove padding rows so flash-attn's varlen kernel sees packed
        # sequences; returns packed q/k/v plus the cu_seqlens bookkeeping.
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
+
+class Florence2SdpaAttention(Florence2Attention):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ if output_attentions or layer_head_mask is not None:
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Florence2Model is using Florence2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states,
+ key_value_states=key_value_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ query_states = self._shape(query_states, tgt_len, bsz)
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+# Maps the config's `_attn_implementation` string ("eager" / "sdpa" /
+# "flash_attention_2") to the attention class instantiated by the
+# encoder/decoder layers below.
+FLORENCE2_ATTENTION_CLASSES = {
+    "eager": Florence2Attention,
+    "sdpa": Florence2SdpaAttention,
+    "flash_attention_2": Florence2FlashAttention2,
+}
+
+
+class Florence2EncoderLayer(nn.Module):
+    """Single encoder block: self-attention followed by a feed-forward network,
+    each sub-layer wrapped in dropout + residual + post-LayerNorm (BART-style).
+    """
+
+    def __init__(self, config: Florence2LanguageConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        # Attention implementation is selected by config._attn_implementation.
+        self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+            config=config,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        # --- self-attention sub-layer (dropout -> residual -> LayerNorm) ---
+        residual = hidden_states
+        hidden_states, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # --- feed-forward sub-layer (fc1 -> activation -> fc2), same wrapping ---
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        # Guard against fp16 overflow: clamp inf/nan activations into range.
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class Florence2DecoderLayer(nn.Module):
+    """Single decoder block: causal self-attention, optional cross-attention
+    over encoder states, then a feed-forward network; each sub-layer uses
+    dropout + residual + post-LayerNorm (BART-style).
+    """
+
+    def __init__(self, config: Florence2LanguageConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        # Causal self-attention over previously generated tokens.
+        self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            is_causal=True,
+            config=config,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        # Cross-attention over the encoder output (not causal).
+        self.encoder_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Cross-Attention Block
+        # Skipped entirely when running decoder-only (no encoder states given).
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        # Output layout: (hidden_states[, self_attn, cross_attn][, present_key_value])
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+
+class Florence2LanguagePreTrainedModel(PreTrainedModel, GenerationMixin ):
+    """Base class for Florence-2 language-model components: wires Hugging Face
+    weight initialization, gradient checkpointing and generation support.
+    """
+
+    # NOTE(review): GenerationMixin is also mixed into the Encoder/Decoder
+    # subclasses below even though an encoder does not generate — confirm
+    # this is intentional rather than copy/paste.
+    config_class = Florence2LanguageConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
+    _no_split_modules = [r"Florence2EncoderLayer", r"Florence2DecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        # BART-style init: normal(0, init_std) for Linear/Embedding weights,
+        # zeroed biases, and a zeroed embedding row for the padding token.
+        std = self.config.init_std
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def dummy_inputs(self):
+        # Small fixed batch used by Hugging Face utilities; second row ends in
+        # the pad token so the attention mask exercises the padded path.
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+        dummy_inputs = {
+            "attention_mask": input_ids.ne(pad_token),
+            "input_ids": input_ids,
+        }
+        return dummy_inputs
+
+
+class Florence2Encoder(Florence2LanguagePreTrainedModel, GenerationMixin):
+    """
+    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    [`Florence2EncoderLayer`].
+
+    Args:
+        config: Florence2LanguageConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__(config)
+
+        self.dropout = config.dropout
+        # Probability of dropping a whole encoder layer during training (LayerDrop).
+        self.layerdrop = config.encoder_layerdrop
+
+        embed_dim = config.d_model
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_position_embeddings
+        # sqrt(d_model) scaling of token embeddings, as in the original Transformer.
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+        self.embed_tokens = Florence2ScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )
+
+        # Share weights with the caller-provided embedding (tying with decoder).
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = Florence2LearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            embed_dim,
+        )
+        self.layers = nn.ModuleList([Florence2EncoderLayer(config) for _ in range(config.encoder_layers)])
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
+        self.layernorm_embedding = nn.LayerNorm(embed_dim)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
+        elif inputs_embeds is not None:
+            # `input` only carries shape/device for the positional embedding
+            # lookup below; the trailing feature dim is dropped on purpose.
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # expand attention_mask
+        # Mask format depends on the attention backend: FA2 takes the 2D mask
+        # (or None if nothing is padded); SDPA/eager need a 4D additive mask.
+        if attention_mask is not None:
+            if self._use_flash_attention_2:
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self._use_sdpa and head_mask is None and not output_attentions:
+                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            if head_mask.size()[0] != (len(self.layers)):
+                raise ValueError(
+                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.size()[0]}."
+                )
+
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
+                layer_outputs = (None, None)
+            else:
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = encoder_layer(
+                        hidden_states,
+                        attention_mask,
+                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+class Florence2Decoder(Florence2LanguagePreTrainedModel, GenerationMixin):
+    """
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`Florence2DecoderLayer`]
+
+    Args:
+        config: Florence2LanguageConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__(config)
+        self.dropout = config.dropout
+        # Probability of dropping a whole decoder layer during training (LayerDrop).
+        self.layerdrop = config.decoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_position_embeddings
+        # sqrt(d_model) scaling of token embeddings, as in the original Transformer.
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+        self.embed_tokens = Florence2ScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )
+
+        # Share weights with the caller-provided embedding (tying with encoder).
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = Florence2LearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+        )
+        self.layers = nn.ModuleList([Florence2DecoderLayer(config) for _ in range(config.decoder_layers)])
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
+
+        self.layernorm_embedding = nn.LayerNorm(config.d_model)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+                cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.shape
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            # `input` only carries shape/device for the positional embedding
+            # lookup below; the trailing feature dim is dropped on purpose.
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        # Number of tokens already cached; offsets the position embeddings.
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input)
+
+        # Build the (causal) self-attention mask in the format the selected
+        # attention backend expects: 2D for FA2, 4D additive for SDPA/eager.
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+            # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                input_shape,
+                inputs_embeds,
+                past_key_values_length,
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask,
+                    inputs_embeds.dtype,
+                    tgt_len=input_shape[-1],
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + positions
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        # NOTE(review): the reported size reads `head_mask.size()` even
+                        # when `cross_attn_head_mask` triggered the check — this looks
+                        # like it should be `attn_mask.size()[0]`; confirm and fix.
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {head_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                # Checkpointed path: past_key_value is forced to None (caching is
+                # disabled above when checkpointing).
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                # Cache index depends on whether attentions were also returned
+                # (layer output layout: hidden, [self_attn, cross_attn,] cache).
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
class Florence2LanguageModel(Florence2LanguagePreTrainedModel, GenerationMixin):
    """Bare BART-style encoder-decoder transformer used as Florence-2's language model.

    Encoder and decoder share a single token-embedding table (``self.shared``);
    the tying map below lets checkpoints store that matrix only once.
    """

    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
    }

    def __init__(self, config: Florence2LanguageConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        # One embedding table shared by encoder and decoder.
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = Florence2Encoder(config, self.shared)
        self.decoder = Florence2Decoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def _tie_weights(self):
        # Only tie when the config requests it; `_tie_or_clone_weights` clones
        # instead of tying when torchscript-style cloning is in effect.
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_input_embeddings(self):
        """Return the shared token-embedding table."""
        return self.shared

    def set_input_embeddings(self, value):
        """Replace the shared table and re-point encoder/decoder at it."""
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqModelOutput]:
        """Run the encoder (unless precomputed ``encoder_outputs`` are given), then the decoder.

        Returns a `Seq2SeqModelOutput` when ``return_dict`` is true, otherwise a
        plain tuple of decoder outputs followed by encoder outputs.
        """
        # different to other models, Florence2 automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )

            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Fall back to config-level defaults for any flag the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
+
+
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
    """BART-style seq2seq language model with an LM head; Florence-2's text generator.

    Wraps `Florence2LanguageModel` under ``self.model`` and adds a linear
    ``lm_head`` (weight-tied to the shared token embeddings) plus a
    non-trainable ``final_logits_bias`` buffer that is added to the logits.
    """

    base_model_prefix = "model"
    # lm_head and both embedding tables share one weight matrix on disk.
    _tied_weights_keys = {
        "model.encoder.embed_tokens.weight": "model.shared.weight",
        "model.decoder.embed_tokens.weight": "model.shared.weight",
        "lm_head.weight": "model.shared.weight",
    }
    # The bias buffer is re-created in __init__, so it may legitimately be
    # missing from older checkpoints.
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

    def __init__(self, config: Florence2LanguageConfig):
        super().__init__(config)
        self.model = Florence2LanguageModel(config)
        # Zero-initialized per-token bias added to the LM logits (BART legacy).
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def __repr__(self):
        # Keep debugger repr fast and side-effect free.
        return f"{self.__class__.__name__}()"

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        """Resize the token embeddings and keep `final_logits_bias` the same width."""
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        # Truncate, or zero-pad on the right, to match the new vocabulary size.
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            # Training: caching is pointless (full sequence in one pass), and
            # decoder inputs are derived from the labels shifted right.
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project decoder hidden states to vocabulary logits and add the bias buffer.
        lm_logits = self.lm_head(outputs[0])
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # CrossEntropyLoss ignores -100 labels by default.
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Assemble per-step decoder inputs for `generate`.

        Normalizes new-style `Cache` objects back to the legacy tuple-of-tuples
        layout this model consumes, then trims `decoder_input_ids` down to the
        tokens not yet covered by the cache.
        """
        # New-style Cache object -> legacy tuple format.
        if past_key_values is not None and not isinstance(past_key_values, (tuple, list)):
            if hasattr(past_key_values, "to_legacy_cache"):
                past_key_values = past_key_values.to_legacy_cache()
        # An empty cache means "no past" (first generation step).
        if past_key_values is not None and len(past_key_values) == 0:
            past_key_values = None
        # Placeholder layers filled with None also count as "no past".
        if past_key_values is not None and isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
            first = past_key_values[0]
            if isinstance(first, (tuple, list)) and len(first) > 0 and first[0] is None:
                past_key_values = None
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            # Cached key tensor is (batch, heads, seq_len, head_dim) — take seq_len.
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Shift labels right to build teacher-forcing decoder inputs."""
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Re-index the cached self-attention states along the batch dim when
        # beam search reorders hypotheses.
        reordered_past = ()
        for layer_past in past_key_values:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
+
@dataclass
class Florence2Seq2SeqLMOutput(ModelOutput):
    """
    Base class for Florence-2 model's outputs that also contains pre-computed hidden states that can speed up sequential
    decoding.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size,
            num_image_tokens, hidden_size)`.

            image_hidden_states of the model produced by the vision encoder
    """
    # Fields mirror `Seq2SeqLMOutput`, plus `last_hidden_state` and
    # `image_hidden_states`. All are optional because ModelOutput drops
    # None-valued entries from its tuple/dict views.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
# Shared class-level docstring fragment, prepended to the Florence-2 model
# classes through the @add_start_docstrings decorator below.
FLORENCE2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Florence2Config`] or [`Florence2VisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
+
@add_start_docstrings(
    "The bare Florence-2 Model outputting raw hidden-states without any specific head on top.",
    FLORENCE2_START_DOCSTRING,
)
class Florence2PreTrainedModel(PreTrainedModel, GenerationMixin):
    """Common base for Florence-2 models: config wiring plus attention-backend probes."""

    config_class = Florence2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"

    @property
    def _supports_flash_attn_2(self):
        """Whether the wrapped language model supports Flash Attention 2.

        Defers to ``self.language_model``; reports ``False`` while no language
        model has been attached yet.
        """
        lm = getattr(self, "language_model", None)
        return False if lm is None else getattr(lm, "_supports_flash_attn_2", False)

    @property
    def _supports_sdpa(self):
        """Whether the wrapped language model supports SDPA attention.

        Defers to ``self.language_model``; defaults to ``True`` while no
        language model has been attached yet.
        """
        lm = getattr(self, "language_model", None)
        return True if lm is None else getattr(lm, "_supports_sdpa", True)
+
+
# Argument documentation injected into model forward() docstrings via
# @add_start_docstrings_to_model_forward. Fixes two broken markdown spans in
# the original text: the unterminated backtick on the `pixel_values` shape and
# the malformed `([]`Florence2Processor`]` link.
FLORENCE2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`Florence2Processor`] uses
            [`CLIPImageProcessor`] for processing images).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
@add_start_docstrings(
    """The FLORENCE2 vision model without any head""",
    FLORENCE2_START_DOCSTRING,
)
class Florence2VisionModel(Florence2PreTrainedModel):
    """Thin wrapper around the DaViT backbone that returns unpooled patch features."""

    def __init__(self, config: Florence2VisionConfig):
        super().__init__(config)
        assert config.model_type == 'davit', 'only DaViT is supported for now'
        self.vision_tower = DaViT.from_config(config=config)

        self.post_init()

    def forward(self, pixel_values):
        """Run the backbone on a (batch, channels, height, width) image batch.

        Raises ValueError for any other tensor rank.
        """
        if len(pixel_values.shape) != 4:
            raise ValueError(f'invalid image shape {pixel_values.shape}')
        return self.vision_tower.forward_features_unpool(pixel_values)
+
+
@add_start_docstrings(
    """The FLORENCE2 vision model with projection layer""",
    FLORENCE2_START_DOCSTRING,
)
class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
    """DaViT backbone plus position/temporal embeddings and a projection head."""

    def __init__(self, config: Florence2VisionConfig):
        super().__init__(config)
        assert config.model_type == 'davit', 'only DaViT is supported for now'
        self.vision_tower = DaViT.from_config(config=config)

        self._build_image_projection_layers(config)

        self.post_init()

    def _build_image_projection_layers(self, config):
        """Create projection matrix, LayerNorm, 2-D position and temporal embeddings."""
        image_dim_out = config.dim_embed[-1]
        dim_projection = config.projection_dim
        # Plain matmul parameter (not nn.Linear): features @ image_projection.
        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection)
        )
        self.image_proj_norm = nn.LayerNorm(dim_projection)
        image_pos_embed_config = config.image_pos_embed
        if image_pos_embed_config['type'] == 'learned_abs_2d':
            self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out,
                num_pos=image_pos_embed_config['max_pos_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

        self.image_feature_source = config.image_feature_source

        # temporal embedding
        visual_temporal_embedding_config = config.visual_temporal_embedding
        if visual_temporal_embedding_config['type'] == 'COSINE':
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim_out,
                max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

    def forward(self, pixel_values):
        """Encode a (batch, C, H, W) image batch into projected visual tokens.

        Returns a (batch_size, num_tokens, projection_dim) tensor; num_tokens
        depends on the configured `image_feature_source` entries.
        """
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1  # single frame; the temporal axis exists for the video-style codepath
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f'invalid image shape {pixel_values.shape}')

        if self.image_pos_embed is not None:
            # Re-shape tokens onto a square (h, w) grid to add 2-D position embeddings.
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
            assert h * w == num_tokens, 'only support square feature maps for now'
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h*w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            # Temporal embedding keyed off the first token of each frame.
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        x_feat_dict = {}

        # Mean over spatial tokens -> one token per frame.
        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x

        # Mean over frames -> one token per spatial position.
        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x

        # Tokens of the last frame.
        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict['last_frame'] = x

        # Concatenate the configured feature sources along the token axis.
        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        # Project into the target embedding space and normalize.
        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x
+
+
+
@add_start_docstrings(
    """The FLORENCE2 model which consists of a vision backbone and a language model.""",
    FLORENCE2_START_DOCSTRING,
)
class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
    """Vision-language model: DaViT image encoder feeding a BART-style text decoder."""

    def __init__(self, config: Florence2Config):
        super().__init__(config)
        assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
        self.vision_tower = DaViT.from_config(config=config.vision_config)
        # remove unused layers: features are taken before pooling, so the
        # classification head and final norms are never used
        del self.vision_tower.head
        del self.vision_tower.norms

        self.vocab_size = config.vocab_size
        self._attn_implementation = config._attn_implementation
        self._build_image_projection_layers(config)

        language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)

        # Re-namespace the language model's tied-weight keys under
        # "language_model." so weight tying still resolves after wrapping.
        if language_model._tied_weights_keys is not None:
            if isinstance(language_model._tied_weights_keys, dict):
                self._tied_weights_keys = {
                    f"language_model.{k}": f"language_model.{v}"
                    for k, v in language_model._tied_weights_keys.items()
                }
            else:
                self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
        self.language_model = language_model

        # -1 sentinel when the config defines no pad token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()
+
    def _build_image_projection_layers(self, config):
        """Create the projection from DaViT features to the LM embedding space.

        Builds a projection matrix, a LayerNorm, a learned 2-D absolute
        position embedding, and a cosine temporal embedding; raises
        NotImplementedError for unsupported embedding types.
        """
        image_dim_out = config.vision_config.dim_embed[-1]
        dim_projection = config.vision_config.projection_dim
        # Plain matmul parameter (not nn.Linear): features @ image_projection.
        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection)
        )
        self.image_proj_norm = nn.LayerNorm(dim_projection)
        image_pos_embed_config = config.vision_config.image_pos_embed
        if image_pos_embed_config['type'] == 'learned_abs_2d':
            self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out,
                num_pos=image_pos_embed_config['max_pos_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

        self.image_feature_source = config.vision_config.image_feature_source

        # temporal embedding
        visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
        if visual_temporal_embedding_config['type'] == 'COSINE':
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim_out,
                max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')
+
    def get_encoder(self):
        """Expose the language model's encoder (used by generation utilities)."""
        return self.language_model.get_encoder()

    def get_decoder(self):
        """Expose the language model's decoder (used by generation utilities)."""
        return self.language_model.get_decoder()

    def get_input_embeddings(self):
        """Return the language model's shared token-embedding table."""
        return self.language_model.get_input_embeddings()
+
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ # update vocab size
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
+ self.config.vocab_size = model_embeds.num_embeddings
+ self.vocab_size = model_embeds.num_embeddings
+ return model_embeds
+
    def _encode_image(self, pixel_values):
        """Encode images into a sequence of projected visual token embeddings.

        Args:
            pixel_values: (batch_size, C, H, W) image batch; any other rank raises.

        Returns:
            (batch_size, num_image_tokens, projection_dim) tensor; the token
            count depends on the configured `image_feature_source` entries.
        """
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1  # single frame; the temporal axis exists for the video-style codepath
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f'invalid image shape {pixel_values.shape}')

        if self.image_pos_embed is not None:
            # Re-shape tokens onto a square (h, w) grid to add 2-D position embeddings.
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
            assert h * w == num_tokens, 'only support square feature maps for now'
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h*w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            # Temporal embedding keyed off the first token of each frame.
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        x_feat_dict = {}

        # Mean over spatial tokens -> one token per frame.
        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x

        # Mean over frames -> one token per spatial position.
        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x

        # Tokens of the last frame.
        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict['last_frame'] = x

        # Concatenate the configured feature sources along the token axis.
        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        # Project into the language-model embedding space and normalize.
        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x
+
+ def _merge_input_ids_with_image_features(
+ self, image_features, inputs_embeds
+ ):
+ batch_size, image_token_length = image_features.size()[:-1]
+ device = image_features.device
+ image_attention_mask = torch.ones(batch_size, image_token_length, device=device)
+
+ # task_prefix_embeds: [batch_size, padded_context_length, hidden_size]
+ # task_prefix_attention_mask: [batch_size, context_length]
+ if inputs_embeds is None:
+ return image_features, image_attention_mask
+
+ task_prefix_embeds = inputs_embeds
+ task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
+
+ if len(task_prefix_attention_mask.shape) == 3:
+ task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+
+ # concat [image embeds, task prefix embeds]
+ inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
+ attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
+
+ return inputs_embeds, attention_mask
+
+
@add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Florence2Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    decoder_head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[List[torch.FloatTensor]] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, Florence2Seq2SeqLMOutput]:
    r"""
    Vision + seq2seq forward pass: embeds `input_ids`, prepends encoded image
    features, and runs the wrapped seq2seq language model.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration

    >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
    >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")

    >>> prompt = "<CAPTION>"  # NOTE(review): the task token was garbled in the vendored source; <CAPTION> matches the shown output — confirm against upstream
    >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_length=100)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "A green car parked in front of a yellow building."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    image_features = None
    if inputs_embeds is None:
        # 1. Extract the text token embeddings.
        if input_ids is not None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        # 2. Merge text and images: image features are prepended and the
        # attention mask is rebuilt to cover the combined sequence.
        if pixel_values is not None:
            # (batch_size, num_image_tokens, hidden_size)
            image_features = self._encode_image(pixel_values)
            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)

    # Fix: the original called `.to()` unconditionally, which raised
    # AttributeError when a caller supplied `inputs_embeds` without a mask.
    if inputs_embeds is not None and attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.dtype)
    outputs = self.language_model(
        attention_mask=attention_mask,
        labels=labels,
        inputs_embeds=inputs_embeds,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        past_key_values=past_key_values,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # Upcast logits for numerically stable downstream use.
    logits = outputs.logits
    logits = logits.float()
    loss = outputs.loss
    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Florence2Seq2SeqLMOutput(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        cross_attentions=outputs.cross_attentions,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        encoder_hidden_states=outputs.encoder_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
        image_hidden_states=image_features
    )
+
def generate(
    self,
    input_ids,
    inputs_embeds=None,
    pixel_values=None,
    **kwargs
):
    """Build the multimodal encoder pass once, then delegate decoding.

    Embeds `input_ids`, prepends image features when `pixel_values` is given,
    runs the language-model encoder on the merged sequence, and hands the
    resulting `encoder_outputs` to `language_model.generate`.
    """
    # Take ownership of any caller-provided mask so it is not forwarded twice
    # through **kwargs; merging image features may replace it below.
    attention_mask = kwargs.pop("attention_mask", None)

    if inputs_embeds is None:
        if input_ids is not None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        if pixel_values is not None:
            visual_tokens = self._encode_image(pixel_values)
            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(visual_tokens, inputs_embeds)

    encoder_outputs = None
    if inputs_embeds is not None:
        encoder_outputs = self.language_model.get_encoder()(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=kwargs.get("output_attentions", None),
            output_hidden_states=kwargs.get("output_hidden_states", None),
            return_dict=True,
        )

    return self.language_model.generate(
        encoder_outputs=encoder_outputs,
        attention_mask=attention_mask,
        **kwargs
    )
+
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    past_key_values=None,
    attention_mask=None,
    pixel_values=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
    use_cache=None,
    encoder_outputs=None,
    **kwargs,
):
    """Assemble per-step decoder inputs, normalizing the KV cache first.

    Newer `transformers` versions pass a `Cache` object; it is converted to
    the legacy tuple layout, and empty/uninitialized caches are treated as
    "no cache". With a live cache, only the unprocessed suffix of
    `decoder_input_ids` is kept.
    """
    # Normalize a Cache object to the legacy nested-tuple representation.
    if past_key_values is not None and not isinstance(past_key_values, (tuple, list)):
        if hasattr(past_key_values, "to_legacy_cache"):
            past_key_values = past_key_values.to_legacy_cache()

    # An empty cache carries no state — drop it entirely.
    if past_key_values is not None and len(past_key_values) == 0:
        past_key_values = None
    if past_key_values is not None and isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
        head_layer = past_key_values[0]
        if isinstance(head_layer, (tuple, list)) and len(head_layer) > 0 and head_layer[0] is None:
            past_key_values = None

    # With cached keys/values, trim decoder ids to the not-yet-seen suffix.
    if past_key_values is not None:
        cached_len = past_key_values[0][0].shape[2]
        if decoder_input_ids.shape[1] > cached_len:
            keep_from = cached_len
        else:
            # Some generation methods already pass only the last input ID.
            keep_from = decoder_input_ids.shape[1] - 1
        decoder_input_ids = decoder_input_ids[:, keep_from:]

    return {
        "input_ids": None,  # encoder_outputs is defined. input_ids not needed
        "encoder_outputs": encoder_outputs,
        "past_key_values": past_key_values,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
        "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
    }
+
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor) -> torch.Tensor:
    # Delegate label shifting to the wrapped seq2seq language model.
    # NOTE(review): upstream transformers exposes `shift_tokens_right` as a
    # module-level function taking (labels, pad_token_id,
    # decoder_start_token_id); confirm the wrapped model really provides it
    # as a single-argument method.
    return self.language_model.shift_tokens_right(labels)
+
def _reorder_cache(self, *args, **kwargs):
    # Beam search reorders the KV cache by beam index; forward directly to
    # the wrapped language model, which owns the cache layout.
    return self.language_model._reorder_cache(*args, **kwargs)
diff --git a/Wan2GP/shared/prompt_enhancer/florence2/processing_florence2.py b/Wan2GP/shared/prompt_enhancer/florence2/processing_florence2.py
new file mode 100644
index 000000000..2c0befda9
--- /dev/null
+++ b/Wan2GP/shared/prompt_enhancer/florence2/processing_florence2.py
@@ -0,0 +1,1118 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Florence-2.
+"""
+
+import re
+import logging
+from typing import List, Optional, Union
+import numpy as np
+
+import torch
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput, is_valid_image
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import (
+ PaddingStrategy,
+ PreTokenizedInput,
+ TextInput,
+ TruncationStrategy,
+)
+from transformers.utils import TensorType
+
+
+logger = logging.getLogger(__name__)
+
# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
    """Return True when *val* is a string that starts with "http"."""
    if not isinstance(val, str):
        return False
    return val.startswith("http")
+
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
    """True for an http(s) URL string or anything `is_valid_image` accepts."""
    if is_url(elem):
        return True
    return is_valid_image(elem)
+
+
+def _is_str_or_image(elem):
+ return isinstance(elem, (str)) or is_image_or_image_url(elem)
+
+
+class Florence2Processor(ProcessorMixin):
+ r"""
+ Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
+
+ [`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the
+ [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.
+
+ Args:
+ image_processor ([`CLIPImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`BartTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "CLIPImageProcessor"
+ tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
+
def __init__(
    self,
    image_processor=None,
    tokenizer=None,
):
    """Wrap an image processor and a BART tokenizer, registering Florence-2 task tokens.

    The angle-bracket special-token and task-token literals below were stripped
    to empty strings in this vendored copy (an HTML-unescaping artifact, which
    also left one string literal split across lines); they are restored here
    from the upstream microsoft/Florence-2 processor.
    """
    if image_processor is None:
        raise ValueError("You need to specify an `image_processor`.")
    if tokenizer is None:
        raise ValueError("You need to specify a `tokenizer`.")
    if not hasattr(image_processor, "image_seq_length"):
        raise ValueError("Image processor is missing an `image_seq_length` attribute.")

    # Number of visual tokens prepended to every text prompt.
    self.image_seq_length = image_processor.image_seq_length

    # Locate any pre-existing additional special tokens across tokenizer
    # versions (attribute name and special_tokens_map key have both varied).
    extra_special_tokens = getattr(tokenizer, "additional_special_tokens", None)
    if extra_special_tokens is None:
        extra_special_tokens = getattr(tokenizer, "extra_special_tokens", None)
    if extra_special_tokens is None:
        extra_special_tokens = tokenizer.special_tokens_map.get("additional_special_tokens")
    if extra_special_tokens is None:
        extra_special_tokens = tokenizer.special_tokens_map.get("extra_special_tokens")
    if extra_special_tokens is None:
        extra_special_tokens = []

    tokens_to_add = {
        "additional_special_tokens": extra_special_tokens
        + ["<od>", "</od>", "<ocr>", "</ocr>"]
        + [f"<loc_{x}>" for x in range(1000)]
        + [
            "<cap>",
            "</cap>",
            "<ndcap>",
            "</ndcap>",
            "<dcap>",
            "</dcap>",
            "<grounding>",
            "</grounding>",
            "<seg>",
            "</seg>",
            "<sep>",
            "<region_cap>",
            "</region_cap>",
            "<region_to_desc>",
            "</region_to_desc>",
            "<proposal>",
            "</proposal>",
            "<poly>",
            "</poly>",
            "<and>",
        ],
    }
    tokenizer.add_special_tokens(tokens_to_add)

    # Maps each task token to the post-processing routine applied to the
    # raw generated text (see Florence2PostProcesser).
    self.tasks_answer_post_processing_type = {
        '<OCR>': 'pure_text',
        '<OCR_WITH_REGION>': 'ocr',
        '<CAPTION>': 'pure_text',
        '<DETAILED_CAPTION>': 'pure_text',
        '<MORE_DETAILED_CAPTION>': 'pure_text',
        '<OD>': 'description_with_bboxes',
        '<DENSE_REGION_CAPTION>': 'description_with_bboxes',
        '<CAPTION_TO_PHRASE_GROUNDING>': "phrase_grounding",
        '<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
        '<REGION_TO_SEGMENTATION>': 'polygons',
        '<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
        '<REGION_TO_CATEGORY>': 'pure_text',
        '<REGION_TO_DESCRIPTION>': 'pure_text',
        '<REGION_TO_OCR>': 'pure_text',
        '<REGION_PROPOSAL>': 'bboxes'
    }

    # Tasks whose prompt is fixed: the task token must be the whole input.
    self.task_prompts_without_inputs = {
        '<OCR>': 'What is the text in the image?',
        '<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
        '<CAPTION>': 'What does the image describe?',
        '<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
        '<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
        '<OD>': 'Locate the objects with category name in the image.',
        '<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
        '<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
    }

    # Tasks whose prompt embeds the user-provided text/region via `{input}`.
    self.task_prompts_with_input = {
        '<CAPTION_TO_PHRASE_GROUNDING>': "Locate the phrases in the caption: {input}",
        '<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
        '<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
        '<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
        '<REGION_TO_CATEGORY>': 'What is the region {input}?',
        '<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
        '<REGION_TO_OCR>': 'What text is in the region {input}?',
    }

    self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)

    super().__init__(image_processor, tokenizer)
+
def _construct_prompts(self, text):
    """Expand any task token found in each prompt into its natural-language prompt."""
    prompts = []
    for raw in text:
        expanded = raw
        # 1. Fixed prompts: the task token must be the entire input string.
        for token, prompt in self.task_prompts_without_inputs.items():
            if token in expanded:
                assert expanded == token, f"Task token {token} should be the only token in the text."
                expanded = prompt
                break
        # 2. Templated prompts: splice the remaining user text into `{input}`.
        for token, template in self.task_prompts_with_input.items():
            if token in expanded:
                expanded = template.format(input=expanded.replace(token, ''))
                break
        prompts.append(expanded)
    return prompts
+
def __call__(
    self,
    text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
    images: ImageInput = None,
    tokenize_newline_separately: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = None,
    max_length=None,
    return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    do_resize: bool = None,
    do_normalize: bool = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
    input_data_format: Optional[
        Union[str, "ChannelDimension"]  # noqa: F821
    ] = None,
    resample: "PILImageResampling" = None,  # noqa: F821
    do_convert_rgb: bool = None,
    do_thumbnail: bool = None,
    do_align_long_axis: bool = None,
    do_rescale: bool = None,
) -> BatchFeature:
    """
    Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
    and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
    the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
    CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
    of the above two methods for more information.

    Args:
        text (`str`, `List[str]`, `List[List[str]]`):
            The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
            (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
            `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
        images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
            The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
            tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
            number of channels, H and W are image height and width.
        tokenize_newline_separately (`bool`, defaults to `True`):
            Adds a separately tokenized '\n' at the end of the prompt.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding
            index) among:
            - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
              lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        truncation (`bool`, *optional*):
            Activates truncation to cut input sequences longer than `max_length` to `max_length`.
        return_tensors (`str` or [`~utils.TensorType`], *optional*):
            If set, will return tensors of a particular framework. Acceptable values are:

            - `'tf'`: Return TensorFlow `tf.constant` objects.
            - `'pt'`: Return PyTorch `torch.Tensor` objects.
            - `'np'`: Return NumPy `np.ndarray` objects.
            - `'jax'`: Return JAX `jnp.ndarray` objects.

    Returns:
        [`BatchFeature`]: A [`BatchFeature`] with the following fields:

        - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
          is provided, the `input_ids` will also contain the suffix input ids.
        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
          `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
          `None`).
        - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        - **labels** -- Labels compatible with training if `suffix` is not None
    """

    return_token_type_ids = False

    if images is None:
        raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")
    if text is None:
        # Fix: `logger` is a stdlib `logging` logger in this file, which has no
        # `warning_once` (that is a transformers-logging extension); calling it
        # raised AttributeError. Use the standard `warning` instead.
        logger.warning(
            "You are using Florence-2 without a text prompt."
        )
        text = ""

    if isinstance(text, List) and isinstance(images, List):
        if len(images) < len(text):
            raise ValueError(
                f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
            )
    # Normalize a single prompt to a one-element batch.
    if _is_str_or_image(text):
        text = [text]
    elif isinstance(text, list) and _is_str_or_image(text[0]):
        pass

    pixel_values = self.image_processor(
        images,
        do_resize=do_resize,
        do_normalize=do_normalize,
        return_tensors=return_tensors,
        image_mean=image_mean,
        image_std=image_std,
        input_data_format=input_data_format,
        data_format=data_format,
        resample=resample,
        do_convert_rgb=do_convert_rgb,
    )["pixel_values"]

    if max_length is not None:
        max_length -= self.image_seq_length  # max_length has to account for the image tokens

    # Replace task tokens with their natural-language prompts.
    text = self._construct_prompts(text)

    inputs = self.tokenizer(
        text,
        return_tensors=return_tensors,
        padding=padding,
        max_length=max_length,
        truncation=truncation,
        return_token_type_ids=return_token_type_ids,
    )

    return_data = {**inputs, "pixel_values": pixel_values}

    if return_token_type_ids:
        labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
        return_data.update({"labels": labels})
    return BatchFeature(data=return_data)
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2
def batch_decode(self, *args, **kwargs):
    """
    This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
    refer to the docstring of this method for more information.
    """
    # Pure delegation: the processor adds nothing to decoding.
    return self.tokenizer.batch_decode(*args, **kwargs)
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2
def decode(self, *args, **kwargs):
    """
    This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
    the docstring of this method for more information.
    """
    # Pure delegation: the processor adds nothing to decoding.
    return self.tokenizer.decode(*args, **kwargs)
+
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2
def model_input_names(self):
    """Ordered, de-duplicated union of tokenizer and image-processor input names."""
    merged = dict.fromkeys(self.tokenizer.model_input_names)
    merged.update(dict.fromkeys(self.image_processor.model_input_names))
    return list(merged)
+
def post_process_generation(self, text, task, image_size):
    """
    Post-process the output of the model to each of the task outputs.

    Args:
        text (`str`): The text to post-process.
        task (`str`): The task token the text was generated for.
        image_size (`Tuple[int, int]`): The size of the image. height x width.

    Returns:
        dict mapping `task` to its parsed result (structure depends on the
        task's post-processing type).
    """

    # Unknown tasks fall back to plain-text handling.
    task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
    task_answer = self.post_processor(
        text=text,
        image_size=image_size,
        parse_tasks=task_answer_post_processing_type,
    )[task_answer_post_processing_type]

    if task_answer_post_processing_type == 'pure_text':
        final_answer = task_answer
        # Remove the BOS/EOS markers. Fix: the vendored copy had the token
        # literals stripped to '' (a no-op replace); restored from upstream.
        final_answer = final_answer.replace('<s>', '').replace('</s>', '')
    elif task_answer_post_processing_type in ['od', 'description_with_bboxes', 'bboxes']:
        od_instances = task_answer
        bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
        labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
        final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
    elif task_answer_post_processing_type in ['ocr']:
        bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
        labels = [str(_od_instance['text']) for _od_instance in task_answer]
        final_answer = {'quad_boxes': bboxes, 'labels': labels}
    elif task_answer_post_processing_type in ['phrase_grounding']:
        # One phrase may ground to several boxes; repeat the label per box.
        bboxes = []
        labels = []
        for _grounded_phrase in task_answer:
            for _bbox in _grounded_phrase['bbox']:
                bboxes.append(_bbox)
                labels.append(_grounded_phrase['cat_name'])
        final_answer = {'bboxes': bboxes, 'labels': labels}
    elif task_answer_post_processing_type in ['description_with_polygons', 'polygons']:
        labels = []
        polygons = []
        for result in task_answer:
            label = result['cat_name']
            _polygons = result['polygons']
            labels.append(label)
            polygons.append(_polygons)
        final_answer = {'polygons': polygons, 'labels': labels}
    elif task_answer_post_processing_type in ['description_with_bboxes_or_polygons']:
        # Mixed output: each instance is either a polygon or a box.
        bboxes = []
        bboxes_labels = []
        polygons = []
        polygons_labels = []
        for result in task_answer:
            label = result['cat_name']
            if 'polygons' in result:
                _polygons = result['polygons']
                polygons.append(_polygons)
                polygons_labels.append(label)
            else:
                _bbox = result['bbox']
                bboxes.append(_bbox)
                bboxes_labels.append(label)
        final_answer = {'bboxes': bboxes, 'bboxes_labels': bboxes_labels, 'polygons': polygons, 'polygons_labels': polygons_labels}
    else:
        raise ValueError('Unknown task answer post processing type: {}'.format(task_answer_post_processing_type))

    final_answer = {
        task: final_answer}
    return final_answer
+
class BoxQuantizer(object):
    """Convert pixel-space xyxy boxes to and from discrete bin indices."""

    def __init__(self, mode, bins):
        self.mode = mode    # quantization strategy; only 'floor' is implemented
        self.bins = bins    # (bins_w, bins_h)

    def quantize(self, boxes: torch.Tensor, size):
        """Map pixel boxes `(..., 4)` to integer bin indices, clamped into range."""
        bins_w, bins_h = self.bins          # Quantization bins.
        size_w, size_h = size               # Original image size.
        per_bin_w = size_w / bins_w
        per_bin_h = size_h / bins_h
        xmin, ymin, xmax, ymax = boxes.split(1, dim=-1)  # Shape: 4 * [N, 1].

        if self.mode == 'floor':
            quantized = (
                (xmin / per_bin_w).floor().clamp(0, bins_w - 1),
                (ymin / per_bin_h).floor().clamp(0, bins_h - 1),
                (xmax / per_bin_w).floor().clamp(0, bins_w - 1),
                (ymax / per_bin_h).floor().clamp(0, bins_h - 1),
            )
        elif self.mode == 'round':
            raise NotImplementedError()
        else:
            raise ValueError('Incorrect quantization type.')

        return torch.cat(quantized, dim=-1).int()

    def dequantize(self, boxes: torch.Tensor, size):
        """Map integer bin indices back to pixel coordinates at bin centers."""
        bins_w, bins_h = self.bins          # Quantization bins.
        size_w, size_h = size               # Original image size.
        per_bin_w = size_w / bins_w
        per_bin_h = size_h / bins_h
        xmin, ymin, xmax, ymax = boxes.split(1, dim=-1)  # Shape: 4 * [N, 1].

        if self.mode == 'floor':
            # Add 0.5 to use the center position of the bin as the coordinate.
            dequantized = (
                (xmin + 0.5) * per_bin_w,
                (ymin + 0.5) * per_bin_h,
                (xmax + 0.5) * per_bin_w,
                (ymax + 0.5) * per_bin_h,
            )
        elif self.mode == 'round':
            raise NotImplementedError()
        else:
            raise ValueError('Incorrect quantization type.')

        return torch.cat(dequantized, dim=-1)
+
+
class CoordinatesQuantizer(object):
    """
    Quantize coornidates (Nx2)
    """

    def __init__(self, mode, bins):
        self.mode = mode    # quantization strategy; only 'floor' is implemented
        self.bins = bins    # (bins_w, bins_h)

    def quantize(self, coordinates: torch.Tensor, size):
        """Map pixel (x, y) points `(N, 2)` to integer bin indices, clamped into range."""
        bins_w, bins_h = self.bins          # Quantization bins.
        size_w, size_h = size               # Original image size.
        per_bin_w = size_w / bins_w
        per_bin_h = size_h / bins_h
        assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
        x, y = coordinates.split(1, dim=-1)

        if self.mode == 'floor':
            quantized = (
                (x / per_bin_w).floor().clamp(0, bins_w - 1),
                (y / per_bin_h).floor().clamp(0, bins_h - 1),
            )
        elif self.mode == 'round':
            raise NotImplementedError()
        else:
            raise ValueError('Incorrect quantization type.')

        return torch.cat(quantized, dim=-1).int()

    def dequantize(self, coordinates: torch.Tensor, size):
        """Map integer bin indices back to pixel coordinates at bin centers."""
        bins_w, bins_h = self.bins          # Quantization bins.
        size_w, size_h = size               # Original image size.
        per_bin_w = size_w / bins_w
        per_bin_h = size_h / bins_h
        assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
        x, y = coordinates.split(1, dim=-1)

        if self.mode == 'floor':
            # Add 0.5 to use the center position of the bin as the coordinate.
            dequantized = (
                (x + 0.5) * per_bin_w,
                (y + 0.5) * per_bin_h,
            )
        elif self.mode == 'round':
            raise NotImplementedError()
        else:
            raise ValueError('Incorrect quantization type.')

        return torch.cat(dequantized, dim=-1)
+
+
+class Florence2PostProcesser(object):
+ r"""
+ Florence-2 post process for converting text prediction to various tasks results.
+
+ Args:
+ config: A dict of configs.
+ tokenizer: A tokenizer for decoding text to spans.
+ sample config:
+ UNIFIED_POST_PROCESS:
+ # commom configs
+ NUM_BBOX_HEIGHT_BINS: 1000
+ NUM_BBOX_WIDTH_BINS: 1000
+ COORDINATES_HEIGHT_BINS: 1000
+ COORDINATES_WIDTH_BINS: 1000
+ # task specific configs, override the common configs
+ PRASE_TASKS:
+ - TASK_NAME: 'video_dense_caption'
+ PATTERN: 'r([a-zA-Z0-9 ]+)'
+ SCORE_MODE: 'avg_cat_name_scores'
+ NUM_BINS: 100
+ - TASK_NAME: 'od'
+ PATTERN: 'r([a-zA-Z0-9 ]+)'
+ SCORE_MODE: 'avg_cat_name_scores'
+
+ Returns:
+ parsed_dict (dict): A dict of parsed results.
+ """
def __init__(
    self,
    tokenizer=None
):
    """Build parse-task registry, quantizers, and grounding black list."""
    config = self._create_default_config()
    task_entries = config['PARSE_TASKS']

    self.config = config
    # Preserve the config's task order; one entry per TASK_NAME.
    self.parse_tasks = [entry['TASK_NAME'] for entry in task_entries]
    self.parse_tasks_configs = {entry['TASK_NAME']: entry for entry in task_entries}

    self.tokenizer = tokenizer
    if self.tokenizer is not None:
        self.all_special_tokens = set(self.tokenizer.all_special_tokens)

    self.init_quantizers()
    self.black_list_of_phrase_grounding = self._create_black_list_of_phrase_grounding()
+
def _create_black_list_of_phrase_grounding(self):
    # Phrases (pronouns, determiners, generic nouns) that carry no grounding
    # information and are filtered out of phrase-grounding results.
    # NOTE(review): when filtering is disabled this returns an empty dict
    # rather than an empty set; callers only do `in` membership tests, which
    # works for both. The literal list also contains duplicates — harmless,
    # since it is collapsed into a set.
    black_list = {}

    if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding']['FILTER_BY_BLACK_LIST']:
        black_list = set(
            ['it', 'I', 'me', 'mine',
             'you', 'your', 'yours',
             'he', 'him', 'his',
             'she', 'her', 'hers',
             'they', 'them', 'their', 'theirs',
             'one', 'oneself',
             'we', 'us', 'our', 'ours',
             'you', 'your', 'yours',
             'they', 'them', 'their', 'theirs',
             'mine', 'yours', 'his', 'hers', 'its',
             'ours', 'yours', 'theirs',
             'myself', 'yourself', 'himself', 'herself', 'itself',
             'ourselves', 'yourselves', 'themselves',
             'this', 'that',
             'these', 'those',
             'who', 'whom', 'whose', 'which', 'what',
             'who', 'whom', 'whose', 'which', 'that',
             'all', 'another', 'any', 'anybody', 'anyone', 'anything',
             'each', 'everybody', 'everyone', 'everything',
             'few', 'many', 'nobody', 'none', 'one', 'several',
             'some', 'somebody', 'someone', 'something',
             'each other', 'one another',
             'myself', 'yourself', 'himself', 'herself', 'itself',
             'ourselves', 'yourselves', 'themselves',
             'the image', 'image', 'images', 'the', 'a', 'an', 'a group',
             'other objects', 'lots', 'a set',
             ]
        )

    return black_list
+
def _create_default_config(self):
    # Default post-processing configuration: 1000x1000 coordinate bins and
    # the per-task parse entries consumed by __call__/init_quantizers.
    # NOTE(review): the PATTERN regexes below appear to have lost their
    # `<loc_\d+>` capture-group segments when this file was vendored (the
    # angle-bracket tokens were stripped); as written they cannot yield the
    # bbox groups that parse_od/parse_ocr index. Restore from the upstream
    # Florence-2 processor before relying on od/ocr parsing.
    config = {
        'NUM_BBOX_HEIGHT_BINS': 1000,
        'NUM_BBOX_WIDTH_BINS': 1000,
        'BOX_QUANTIZATION_MODE': 'floor',
        'COORDINATES_HEIGHT_BINS': 1000,
        'COORDINATES_WIDTH_BINS': 1000,
        'COORDINATES_QUANTIZATION_MODE': 'floor',
        'PARSE_TASKS': [
            {
                'TASK_NAME': 'od',
                'PATTERN': r'([a-zA-Z0-9 ]+)'
            },
            {
                'TASK_NAME': 'ocr',
                'PATTERN': r'(.+?)',
                'AREA_THRESHOLD': 0.00
            },
            {
                'TASK_NAME': 'phrase_grounding',
                'FILTER_BY_BLACK_LIST': True
            },
            {
                'TASK_NAME': 'pure_text',
            },
            {
                'TASK_NAME': 'description_with_bboxes',
            },
            {
                'TASK_NAME': 'description_with_polygons',
            },
            {
                'TASK_NAME': 'polygons',
            },
            {
                'TASK_NAME': 'bboxes',
            },
            {
                'TASK_NAME': 'description_with_bboxes_or_polygons',
            }
        ]
    }

    return config
+
def init_quantizers(self):
    """Build the box quantizer (od/grounding) and coordinate quantizer (ocr/segmentation)."""
    num_bbox_height_bins = self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
    num_bbox_width_bins = self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
    box_quantization_mode = self.config.get('BOX_QUANTIZATION_MODE', 'floor')
    self.box_quantizer = BoxQuantizer(
        box_quantization_mode,
        (num_bbox_width_bins, num_bbox_height_bins),
    )

    # The coordinate quantizer falls back to the box settings when the
    # COORDINATES_* keys are absent. Nested `.get()` is equivalent to the
    # original `x if key in cfg else cfg.get(...)` chains (a present key wins
    # even if its value is falsy), just shorter.
    coord_height_bins = self.config.get('COORDINATES_HEIGHT_BINS', self.config.get('NUM_BBOX_HEIGHT_BINS', 1000))
    coord_width_bins = self.config.get('COORDINATES_WIDTH_BINS', self.config.get('NUM_BBOX_WIDTH_BINS', 1000))
    coord_quantization_mode = self.config.get('COORDINATES_QUANTIZATION_MODE', self.config.get('BOX_QUANTIZATION_MODE', 'floor'))
    self.coordinates_quantizer = CoordinatesQuantizer(
        coord_quantization_mode,
        (coord_width_bins, coord_height_bins),
    )
+
def decode_with_spans(self, tokenizer, token_ids):
    """Decode `token_ids` to text while recording each token's [start, end) span.

    Returns:
        (text, spans): the concatenated string and one (start, end) pair per token.
    """
    # Fix: these tokenizer classes are referenced below but are never imported
    # at module level in this file, which made every call raise NameError.
    # Import lazily so module import stays cheap.
    from transformers import BartTokenizer, BartTokenizerFast, T5Tokenizer, T5TokenizerFast

    filtered_tokens = tokenizer.convert_ids_to_tokens(
        token_ids, skip_special_tokens=False)
    assert len(filtered_tokens) == len(token_ids)

    # To avoid mixing byte-level and unicode for byte-level BPT
    # we need to build string separately for added tokens and byte-level tokens
    # cf. https://github.com/huggingface/transformers/issues/1133
    sub_texts = []
    for token in filtered_tokens:
        if token in self.all_special_tokens:
            sub_texts.append(token)
        else:
            if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
                sub_text = tokenizer.convert_tokens_to_string([token])
            elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
                # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
                # Note: Do not strip sub_text as it may have functional whitespace
                sub_text = token.replace('▁', ' ')
            else:
                raise ValueError(f'type {type(tokenizer)} not supported')
            sub_texts.append(sub_text)

    text = ''
    spans = []
    for sub_text in sub_texts:
        span = (len(text), len(text) + len(sub_text))  # [start index, end index).
        text += sub_text
        spans.append(span)

    # Text format:
    # 1. T5Tokenizer/T5TokenizerFast:
    #      " transplanting dog cat"
    #    Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
    # 2. BartTokenizer (need to double check):
    #      "transplanting dogcat"
    #    Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
    return text, spans
+
def parse_od_from_text_and_spans(
    self,
    text,
    pattern,
    image_size,
    phrase_centric=False
):
    """Parse object-detection instances out of generated text.

    Args:
        text: raw generated string containing category names and location bins.
        pattern: regex whose groups carry the four bbox bins and the category
            name (bins in groups 1-4, name in group 5; with `phrase_centric`
            the name is group 1 and the bins are groups 2-5).
        image_size: (width, height) used to dequantize the bins.
        phrase_centric: selects the alternate group layout described above.

    Returns:
        list of {'bbox': [x0, y0, x1, y1], 'cat_name': str} dicts.
    """
    instances = []
    # Idiom fix: iterate matches directly instead of `range(len(list(...)))`.
    for match in re.finditer(pattern, text):
        if phrase_centric:
            bbox_bins = [int(match.group(j)) for j in range(2, 6)]
            cat_name = match.group(1)
        else:
            bbox_bins = [int(match.group(j)) for j in range(1, 5)]
            cat_name = match.group(5)

        instances.append({
            'bbox': self.box_quantizer.dequantize(
                boxes=torch.tensor(bbox_bins),
                size=image_size
            ).tolist(),
            'cat_name': cat_name.lower().strip(),
        })

    return instances
+
+ def parse_ocr_from_text_and_spans(self,
+ text,
+ pattern,
+ image_size,
+ area_threshold=-1.0,
+ ):
+ bboxes = []
+ labels = []
+ text = text.replace('', '')
+ # ocr with regions
+ parsed = re.findall(pattern, text)
+ instances = []
+ image_width, image_height = image_size
+
+ for ocr_line in parsed:
+ ocr_content = ocr_line[0]
+ quad_box = ocr_line[1:]
+ quad_box = [int(i) for i in quad_box]
+ quad_box = self.coordinates_quantizer.dequantize(
+ torch.tensor(np.array(quad_box).reshape(-1, 2)),
+ size=image_size
+ ).reshape(-1).tolist()
+
+ if area_threshold > 0:
+ x_coords = [i for i in quad_box[0::2]]
+ y_coords = [i for i in quad_box[1::2]]
+
+ # apply the Shoelace formula
+ area = 0.5 * abs(sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1)))
+
+ if area < (image_width * image_height) * area_threshold:
+ continue
+
+ bboxes.append(quad_box)
+ labels.append(ocr_content)
+ instances.append({
+ 'quad_box': quad_box,
+ 'text': ocr_content,
+ })
+ return instances
+
+ def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size):
+ # ignore and
+ cur_span = 0
+ if text.startswith(''):
+ cur_span += 3
+
+ text = text.replace('', '')
+ text = text.replace('', '')
+ text = text.replace('', '')
+
+ pattern = r"([^<]+(?:){4,})"
+ phrases = re.findall(pattern, text)
+
+ # pattern should be text pattern and od pattern
+ pattern = r'^\s*(.*?)(?=||||||'
+
+ instances = []
+ for pharse_text in phrases:
+ phrase_text_strip = pharse_text.replace('', '', 1)
+ phrase_text_strip = pharse_text.replace('', '', 1)
+
+ if phrase_text_strip == '':
+ cur_span += len(pharse_text)
+ continue
+
+ # Prepare instance.
+ instance = {}
+
+ # parse phrase, get string
+ phrase = re.search(pattern, phrase_text_strip)
+ if phrase is None:
+ cur_span += len(pharse_text)
+ continue
+
+ # parse bboxes by box_pattern
+ bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
+ if len(bboxes_parsed) == 0:
+ cur_span += len(pharse_text)
+ continue
+
+ phrase = phrase.group()
+ # remove leading and trailing spaces
+ phrase = phrase.strip()
+
+ if phrase in self.black_list_of_phrase_grounding:
+ cur_span += len(pharse_text)
+ continue
+
+ # a list of list
+ bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
+ instance['bbox'] = self.box_quantizer.dequantize(
+ boxes=torch.tensor(bbox_bins),
+ size=image_size
+ ).tolist()
+
+ # exclude non-ascii characters
+ phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
+ instance['cat_name'] = phrase
+
+ instances.append(instance)
+
+ return instances
+
+ def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
+ # temporary parse solution, split by '.'
+ # ignore and
+
+ text = text.replace('', '')
+ text = text.replace('', '')
+ text = text.replace('', '')
+
+ if allow_empty_phrase:
+ pattern = rf"(?:(?:){{4,}})"
+ else:
+ pattern = r"([^<]+(?:){4,})"
+ phrases = re.findall(pattern, text)
+
+ # pattern should be text pattern and od pattern
+ pattern = r'^\s*(.*?)(?=||||||'
+
+ instances = []
+ for pharse_text in phrases:
+ phrase_text_strip = pharse_text.replace('', '', 1)
+ phrase_text_strip = pharse_text.replace('', '', 1)
+
+ if phrase_text_strip == '' and not allow_empty_phrase:
+ continue
+
+ # parse phrase, get string
+ phrase = re.search(pattern, phrase_text_strip)
+ if phrase is None:
+ continue
+
+ phrase = phrase.group()
+ # remove leading and trailing spaces
+ phrase = phrase.strip()
+
+ # parse bboxes by box_pattern
+ bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
+ if len(bboxes_parsed) == 0:
+ continue
+
+ # a list of list
+ bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
+
+ bboxes = self.box_quantizer.dequantize(
+ boxes=torch.tensor(bbox_bins),
+ size=image_size
+ ).tolist()
+
+ phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
+ for _bboxes in bboxes:
+ # Prepare instance.
+ instance = {}
+ instance['bbox'] = _bboxes
+ # exclude non-ascii characters
+ instance['cat_name'] = phrase
+ instances.append(instance)
+
+ return instances
+
    def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
                                                            allow_empty_phrase=False,
                                                            polygon_sep_token='',
                                                            polygon_start_token='',
                                                            polygon_end_token='',
                                                            with_box_at_start=False,
                                                            ):
        """Parse referring-segmentation output into phrase/polygon instances.

        Args:
            text: raw generated text from the model.
            pattern: unused on entry; overwritten locally (kept for interface
                compatibility with the other parse_* methods).
            image_size: (width, height) used by both quantizers.
            allow_empty_phrase: accept polygon runs with no leading phrase.
            polygon_sep_token / polygon_start_token / polygon_end_token:
                delimiter tokens between / around polygon groups.
                NOTE(review): the defaults look truncated to '' in this copy —
                confirm against upstream.
            with_box_at_start: treat the first 4 bins of the first polygon as
                an instance-level bbox.

        Returns:
            List of {'cat_name': str, 'polygons': [[x, y, ...], ...]} dicts,
            plus a 'bbox' key when with_box_at_start found one.
        """
        # ref_seg format: '<><><><><><>'
        # ignore and
        # NOTE(review): several string/regex literals below look truncated in
        # this copy (empty '' arguments) — confirm the special tokens upstream.
        text = text.replace('', '')
        text = text.replace('', '')
        text = text.replace('', '')

        if allow_empty_phrase:
            pattern = rf"(?:(?:|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
        else:
            # [^<]+: This part matches one or more characters that are not the < symbol.
            # The ^ inside the square brackets [] is a negation, meaning it matches anything except <.
            #
            pattern = rf"([^<]+(?:|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
        phrases = re.findall(pattern, text)

        phrase_string_pattern = r'^\s*(.*?)(?=||||||)'
        box_pattern = rf'((?:)+)(?:{re.escape(polygon_sep_token)}|$)'

        # one polygons instance is separated by polygon_start_token and polygon_end_token
        polygons_instance_pattern = rf'{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}'

        instances = []
        for phrase_text in phrases:

            # exclude loc_\d+>
            # need to get span if want to include category score
            phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1)

            # phrase = phrase.replace('', '')
            # phrase = phrase.replace('poly>', '')

            if phrase_text_strip == '' and not allow_empty_phrase:
                continue


            # parse phrase, get string
            phrase = re.search(phrase_string_pattern, phrase_text_strip)
            if phrase is None:
                continue
            phrase = phrase.group()
            # remove leading and trailing spaces
            phrase = phrase.strip()

            # parse bboxes by box_pattern

            # split by polygon_start_token and polygon_end_token first using polygons_instance_pattern
            if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
                polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text))
            else:
                # No explicit delimiters: treat the whole run as one instance
                # (handled below via the isinstance(str) branch).
                polygons_instances_parsed = [phrase_text]

            for _polygons_instances_parsed in polygons_instances_parsed:
                # Prepare instance.
                instance = {}

                # polygons_parsed= list(re.finditer(box_pattern, phrase_text))
                if isinstance(_polygons_instances_parsed, str):
                    polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed))
                else:
                    polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed.group(1)))
                if len(polygons_parsed) == 0:
                    continue

                # a list of list (polygon)
                bbox = []
                polygons = []
                for _polygon_parsed in polygons_parsed:
                    # group 1: whole ...
                    _polygon = _polygon_parsed.group(1)
                    # parse into list of int
                    _polygon = [int(_loc_parsed.group(1)) for _loc_parsed in re.finditer(r'', _polygon)]
                    if with_box_at_start and len(bbox) == 0:
                        if len(_polygon) > 4:
                            # no valid bbox prediction
                            bbox = _polygon[:4]
                            _polygon = _polygon[4:]
                        else:
                            bbox = [0, 0, 0, 0]
                    # abandon last element if is not paired
                    if len(_polygon) % 2 == 1:
                        _polygon = _polygon[:-1]

                    # reshape into (n, 2)
                    _polygon = self.coordinates_quantizer.dequantize(
                        torch.tensor(np.array(_polygon).reshape(-1, 2)),
                        size=image_size
                    ).reshape(-1).tolist()
                    # reshape back
                    polygons.append(_polygon)

                instance['cat_name'] = phrase
                instance['polygons'] = polygons
                if len(bbox) != 0:
                    instance['bbox'] = self.box_quantizer.dequantize(
                        boxes=torch.tensor([bbox]),
                        size=image_size
                    ).tolist()[0]

                instances.append(instance)

        return instances
+
    def __call__(
        self,
        text=None,
        image_size=None,
        parse_tasks=None,
    ):
        """Dispatch post-processing of generated text to per-task parsers.

        Args:
            text: model outputs
            image_size: (width, height)
            parse_tasks: a list of tasks to parse, if None, parse all tasks.

        Returns:
            Dict mapping each parsed task name to its instances list (or the
            raw text for 'pure_text'), plus the original 'text' entry.

        Raises:
            AssertionError: if text is None or a requested task is unknown.
            ValueError: if a configured task has no handler branch below.
        """
        if parse_tasks is not None:
            if isinstance(parse_tasks, str):
                parse_tasks = [parse_tasks]
            for _parse_task in parse_tasks:
                assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'

        # sequence or text should be provided
        assert text is not None, 'text should be provided'

        parsed_dict = {
            'text': text
        }

        # Iterate configured tasks in self.parse_tasks order, skipping those
        # not requested by the caller.
        for task in self.parse_tasks:
            if parse_tasks is not None and task not in parse_tasks:
                continue

            pattern = self.parse_tasks_configs[task].get('PATTERN', None)

            if task == 'ocr':
                instances = self.parse_ocr_from_text_and_spans(
                    text,
                    pattern=pattern,
                    image_size=image_size,
                    area_threshold=self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.0),
                )
                parsed_dict['ocr'] = instances
            elif task == 'phrase_grounding':
                instances = self.parse_phrase_grounding_from_text_and_spans(
                    text,
                    pattern=pattern,
                    image_size=image_size,
                )
                parsed_dict['phrase_grounding'] = instances
            elif task == 'pure_text':
                parsed_dict['pure_text'] = text
            elif task == 'description_with_bboxes':
                instances = self.parse_description_with_bboxes_from_text_and_spans(
                    text,
                    pattern=pattern,
                    image_size=image_size,
                )
                parsed_dict['description_with_bboxes'] = instances
            elif task == 'description_with_polygons':
                instances = self.parse_description_with_polygons_from_text_and_spans(
                    text,
                    pattern=pattern,
                    image_size=image_size,
                )
                parsed_dict['description_with_polygons'] = instances
            elif task == 'polygons':
                # Same parser as description_with_polygons but tolerant of
                # location-only output with no phrase.
                instances = self.parse_description_with_polygons_from_text_and_spans(
                    text,
                    pattern=pattern,
                    image_size=image_size,
                    allow_empty_phrase=True,
                )
                parsed_dict['polygons'] = instances
            elif task == 'bboxes':
                instances = self.parse_description_with_bboxes_from_text_and_spans(
                    text,
                    pattern=pattern,
                    image_size=image_size,
                    allow_empty_phrase=True,
                )
                parsed_dict['bboxes'] = instances
            elif task == 'description_with_bboxes_or_polygons':
                # NOTE(review): the membership literal below looks truncated in
                # this copy ('' is a substring of everything) — confirm the
                # polygon marker token against upstream.
                if '' in text:
                    # only support either polygons or bboxes, not both at the same time
                    instances = self.parse_description_with_polygons_from_text_and_spans(
                        text,
                        pattern=pattern,
                        image_size=image_size,
                    )
                else:
                    instances = self.parse_description_with_bboxes_from_text_and_spans(
                        text,
                        pattern=pattern,
                        image_size=image_size,
                    )
                parsed_dict['description_with_bboxes_or_polygons'] = instances
            else:
                raise ValueError("task {} is not supported".format(task))

        return parsed_dict
diff --git a/Wan2GP/shared/prompt_enhancer/loader.py b/Wan2GP/shared/prompt_enhancer/loader.py
new file mode 100644
index 000000000..2bd530808
--- /dev/null
+++ b/Wan2GP/shared/prompt_enhancer/loader.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Tuple
+
+import torch
+from safetensors import safe_open
+
+from .florence2 import Florence2Config, Florence2ForConditionalGeneration, Florence2Processor
+from .florence2.image_processing_florence2 import Florence2ImageProcessorLite
+
+from transformers import BartTokenizer, BartTokenizerFast
+
+
def _load_state_dict(weights_path: Path) -> dict:
    """Load a checkpoint from *weights_path* into a CPU state dict.

    Safetensors files are read tensor-by-tensor via safe_open; anything else
    falls back to torch.load.
    """
    if weights_path.suffix != ".safetensors":
        # NOTE(review): torch.load is pickle-based — consider weights_only=True
        # if the supported torch versions allow it.
        return torch.load(str(weights_path), map_location="cpu")
    with safe_open(str(weights_path), framework="pt", device="cpu") as handle:
        return {key: handle.get_tensor(key) for key in handle.keys()}
+
+
+def _resolve_weights_path(model_path: Path) -> Path:
+ # Prefer fp32 weights for stability/quality when available.
+ preferred = model_path / "xmodel.safetensors"
+ if preferred.exists():
+ return preferred
+ fallback = model_path / "model.safetensors"
+ if fallback.exists():
+ return fallback
+ fallback = model_path / "pytorch_model.bin"
+ if fallback.exists():
+ return fallback
+ raise FileNotFoundError(
+ f"No Florence2 weights found in {model_path} (expected model.safetensors/xmodel.safetensors/pytorch_model.bin)"
+ )
+
+
def load_florence2(
    model_dir: str,
    attn_implementation: str = "sdpa",
) -> Tuple[Florence2ForConditionalGeneration, Florence2Processor]:
    """Build a Florence2 model + processor from a local folder.

    Args:
        model_dir: folder containing config, weights and tokenizer files.
        attn_implementation: attention backend written into the config
            (e.g. "sdpa"); skipped when falsy.

    Returns:
        (model, processor) with the model in eval mode on CPU.

    Raises:
        FileNotFoundError: missing folder or weights.
        RuntimeError: no usable tokenizer could be loaded.
    """
    model_path = Path(model_dir)
    if not model_path.exists():
        raise FileNotFoundError(f"Florence2 folder not found: {model_path}")

    config = Florence2Config.from_pretrained(str(model_path))
    if attn_implementation:
        config._attn_implementation = attn_implementation
    weights_path = _resolve_weights_path(model_path)
    state_dict = _load_state_dict(weights_path)

    # strict=False: embed_tokens weights are tied/derived, so they may be
    # legitimately absent from the checkpoint (see allowed_missing below).
    model = Florence2ForConditionalGeneration(config)
    load_info = model.load_state_dict(state_dict, strict=False)
    del state_dict
    if load_info.missing_keys:
        allowed_missing = {
            "language_model.model.encoder.embed_tokens.weight",
            "language_model.model.decoder.embed_tokens.weight",
        }
        extra_missing = [k for k in load_info.missing_keys if k not in allowed_missing]
        if extra_missing:
            print(f"Florence2 missing keys: {extra_missing}")
    if load_info.unexpected_keys:
        print(f"Florence2 unexpected keys: {len(load_info.unexpected_keys)}")
    model.eval()

    image_processor = Florence2ImageProcessorLite.from_preprocessor_config(model_path)
    # Try the fast tokenizer first, fall back to the slow implementation.
    tokenizer = None
    tokenizer_errors = []
    for tok_cls in (BartTokenizerFast, BartTokenizer):
        try:
            tokenizer = tok_cls.from_pretrained(str(model_path))
            break
        except Exception as exc:
            tokenizer_errors.append(exc)
    if tokenizer is None:
        raise RuntimeError(f"Unable to load Florence2 tokenizer: {tokenizer_errors}")
    try:
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)
    except TypeError as exc:
        # Some Florence2Processor versions insist on a CLIPImageProcessor;
        # retry with one only for that specific failure.
        if "CLIPImageProcessor" not in str(exc):
            raise
        try:
            from transformers import CLIPImageProcessor
        except Exception:
            from transformers.models.clip import CLIPImageProcessor
        image_processor = CLIPImageProcessor.from_pretrained(str(model_path))
        processor = Florence2Processor(image_processor=image_processor, tokenizer=tokenizer)

    return model, processor
diff --git a/Wan2GP/models/ltx_video/utils/prompt_enhance_utils.py b/Wan2GP/shared/prompt_enhancer/prompt_enhance_utils.py
similarity index 78%
rename from Wan2GP/models/ltx_video/utils/prompt_enhance_utils.py
rename to Wan2GP/shared/prompt_enhancer/prompt_enhance_utils.py
index 2a87c7eca..406cd97c3 100644
--- a/Wan2GP/models/ltx_video/utils/prompt_enhance_utils.py
+++ b/Wan2GP/shared/prompt_enhancer/prompt_enhance_utils.py
@@ -1,5 +1,6 @@
import logging
from typing import Union, List, Optional
+from contextlib import nullcontext
import torch
from PIL import Image
@@ -158,6 +159,11 @@ def generate_cinematic_prompt(
text_prompt = False,
max_new_tokens: int = 256,
prompt_enhancer_instructions = None,
+ do_sample: bool = True,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+ top_k: Optional[int] = None,
+ seed: Optional[int] = None,
) -> List[str]:
prompts = [prompt] if isinstance(prompt, str) else prompt
@@ -170,6 +176,11 @@ def generate_cinematic_prompt(
prompts,
max_new_tokens,
prompt_enhancer_instructions,
+ do_sample,
+ temperature,
+ top_p,
+ top_k,
+ seed,
)
else:
if prompt_enhancer_instructions is None:
@@ -184,6 +195,11 @@ def generate_cinematic_prompt(
images,
max_new_tokens,
prompt_enhancer_instructions,
+ do_sample,
+ temperature,
+ top_p,
+ top_k,
+ seed,
)
return prompts
@@ -203,6 +219,11 @@ def _generate_t2v_prompt(
prompts: List[str],
max_new_tokens: int,
system_prompt: str,
+ do_sample: bool,
+ temperature: Optional[float],
+ top_p: Optional[float],
+ top_k: Optional[int],
+ seed: Optional[int],
) -> List[str]:
messages = [
[
@@ -220,11 +241,24 @@ def _generate_t2v_prompt(
]
out_prompts = []
- for text in texts:
+ for idx, text in enumerate(texts):
model_inputs = prompt_enhancer_tokenizer(text, return_tensors="pt").to(
prompt_enhancer_model.device
)
- out_prompts.append(_generate_and_decode_prompts(prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens)[0])
+ prompt_seed = None if seed is None else int(seed) + idx
+ out_prompts.append(
+ _generate_and_decode_prompts(
+ prompt_enhancer_model,
+ prompt_enhancer_tokenizer,
+ model_inputs,
+ max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ seed=prompt_seed,
+ )[0]
+ )
return out_prompts
@@ -237,6 +271,11 @@ def _generate_i2v_prompt(
first_frames: List[Image.Image],
max_new_tokens: int,
system_prompt: str,
+ do_sample: bool,
+ temperature: Optional[float],
+ top_p: Optional[float],
+ top_k: Optional[int],
+ seed: Optional[int],
) -> List[str]:
image_captions = _generate_image_captions(
image_caption_model, image_caption_processor, first_frames
@@ -258,11 +297,24 @@ def _generate_i2v_prompt(
for m in messages
]
out_prompts = []
- for text in texts:
+ for idx, text in enumerate(texts):
model_inputs = prompt_enhancer_tokenizer(text, return_tensors="pt").to(
prompt_enhancer_model.device
)
- out_prompts.append(_generate_and_decode_prompts(prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens)[0])
+ prompt_seed = None if seed is None else int(seed) + idx
+ out_prompts.append(
+ _generate_and_decode_prompts(
+ prompt_enhancer_model,
+ prompt_enhancer_tokenizer,
+ model_inputs,
+ max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ seed=prompt_seed,
+ )[0]
+ )
return out_prompts
@@ -276,7 +328,12 @@ def _generate_image_captions(
image_caption_prompts = [system_prompt] * len(images)
inputs = image_caption_processor(
image_caption_prompts, images, return_tensors="pt"
- ).to("cuda") #.to(image_caption_model.device)
+ ).to(image_caption_model.device)
+
+ bad_words_ids = None
+ bos_id = getattr(image_caption_processor.tokenizer, "bos_token_id", None)
+ if bos_id is not None:
+ bad_words_ids = [[int(bos_id)]]
with torch.inference_mode():
generated_ids = image_caption_model.generate(
@@ -285,17 +342,50 @@ def _generate_image_captions(
max_new_tokens=1024,
do_sample=False,
num_beams=3,
+ bad_words_ids=bad_words_ids,
)
return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)
def _generate_and_decode_prompts(
- prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
+ prompt_enhancer_model,
+ prompt_enhancer_tokenizer,
+ model_inputs,
+ max_new_tokens: int,
+ do_sample: bool = True,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+ top_k: Optional[int] = None,
+ seed: Optional[int] = None,
) -> List[str]:
- with torch.inference_mode():
+ device = "cuda"
+ if seed is None:
+ rng_context = nullcontext()
+ else:
+ devices = []
+ if isinstance(device, torch.device) and device.type == "cuda":
+ devices = [device.index or 0]
+ rng_context = torch.random.fork_rng(devices=devices) if devices else torch.random.fork_rng()
+ with rng_context, torch.inference_mode():
+ if seed is not None:
+ torch.manual_seed(int(seed))
+ if isinstance(device, torch.device) and device.type == "cuda":
+ with torch.cuda.device(device):
+ torch.cuda.manual_seed(int(seed))
+ gen_kwargs = {
+ "max_new_tokens": max_new_tokens,
+ "do_sample": do_sample,
+ }
+ if temperature is not None:
+ gen_kwargs["temperature"] = float(temperature)
+ if top_p is not None:
+ gen_kwargs["top_p"] = float(top_p)
+ if top_k is not None:
+ gen_kwargs["top_k"] = int(top_k)
outputs = prompt_enhancer_model.generate(
- **model_inputs, max_new_tokens=max_new_tokens
+ **model_inputs,
+ **gen_kwargs,
)
generated_ids = [
output_ids[len(input_ids) :]
diff --git a/Wan2GP/shared/qtypes/gguf.py b/Wan2GP/shared/qtypes/gguf.py
new file mode 100644
index 000000000..33569b214
--- /dev/null
+++ b/Wan2GP/shared/qtypes/gguf.py
@@ -0,0 +1,1015 @@
+import ast
+import os
+import re
+
+import torch
+from torch.utils import _pytree as pytree
+
+from optimum.quanto import QModuleMixin
+from optimum.quanto.tensor.qtensor import QTensor
+from optimum.quanto.tensor.qtype import qtype as _quanto_qtype, qtypes as _quanto_qtypes
+from collections import OrderedDict
+
+
# Registry key under which this quantization handler is exposed.
HANDLER_NAME = "gguf"

# gguf is an optional dependency: keep this module importable without it and
# raise lazily inside the functions that actually need it.
try:
    import gguf
except Exception:
    gguf = None


# Register a synthetic "gguf" qtype with optimum.quanto so its machinery can
# tag GGUF-quantized modules.  NOTE(review): the numeric parameters (6 bits,
# range [-32, 31]) appear nominal for the mixed GGML block formats — confirm.
_GGUF_QTYPE_NAME = "gguf"
if _GGUF_QTYPE_NAME not in _quanto_qtypes:
    _quanto_qtypes[_GGUF_QTYPE_NAME] = _quanto_qtype(
        _GGUF_QTYPE_NAME,
        is_floating_point=False,
        bits=6,
        dtype=torch.uint8,
        qmin=-32.0,
        qmax=31.0,
    )
_GGUF_QTYPE = _quanto_qtypes[_GGUF_QTYPE_NAME]

# Module-level state: loader-supplied default dtype and a cache of
# filename -> human-readable quantization label.
_GGUF_DEFAULT_DTYPE = None
_GGUF_LABEL_CACHE = {}
+
+
+def _normalize_gguf_path(file_path):
+ try:
+ return os.path.normcase(os.path.abspath(file_path))
+ except Exception:
+ return str(file_path).lower()
+
+
def normalize(file_path):
    """Public alias for the path canonicalizer used as a cache key."""
    return _normalize_gguf_path(file_path)


def _set_default_dtype_from_loader(dtype):
    """Record the loader-supplied default dtype; None is ignored."""
    global _GGUF_DEFAULT_DTYPE
    if dtype is not None:
        _GGUF_DEFAULT_DTYPE = dtype


def _resolve_default_dtype(dtype, fallback=None):
    """Pick the effective dtype: explicit *dtype*, loader default, or *fallback*."""
    if dtype is None:
        return _GGUF_DEFAULT_DTYPE or fallback
    # When the caller merely passes the generic fallback, prefer the loader default.
    if _GGUF_DEFAULT_DTYPE is not None and fallback is not None and dtype == fallback:
        return _GGUF_DEFAULT_DTYPE
    return dtype
+
def get_file_metadata(file_path):
    """Read handler metadata from a GGUF file.

    Returns:
        (OrderedDict(), metadata) where metadata contains the file's "config"
        field when present.  The first element is an empty placeholder,
        presumably to match the signature of sibling handlers — confirm.

    Raises:
        RuntimeError: if the optional gguf package is not installed.
    """
    if gguf is None:
        raise RuntimeError("GGUF support requires the 'gguf' package.")
    reader = gguf.GGUFReader(file_path)
    metadata = {}
    field = reader.get_field("config")
    if field is not None:
        # Some gguf versions expose field.contents as a method, others as an
        # attribute; best-effort read, silently skipped on failure.
        try:
            metadata["config"] = field.contents() if callable(getattr(field, "contents", None)) else field.contents
        except Exception:
            pass
    return OrderedDict(), metadata
+
+
+def _filter_state_dict_basic(state_dict, base_model_prefix, keep_prefix=False):
+ new_state_dict = {}
+ start = -1
+ if keep_prefix:
+ for k, v in state_dict.items():
+ if k.startswith(base_model_prefix):
+ new_state_dict[k] = v
+ else:
+ for k, v in state_dict.items():
+ if k.startswith(base_model_prefix):
+ new_start = len(base_model_prefix)
+ else:
+ pos = k.find("." + base_model_prefix)
+ if pos < 0:
+ continue
+ new_start = pos + len(base_model_prefix) + 1
+ if start != -1 and start != new_start:
+ new_state_dict = state_dict
+ break
+ start = new_start
+ new_state_dict[k[start:]] = v
+ return new_state_dict
+
+
def _gguf_get_orig_shape(reader, tensor_name):
    """Look up the pre-quantization shape stored by ComfyUI-style GGUF files.

    Returns None when the "comfy.gguf.orig_shape.<name>" field is absent;
    raises TypeError when the field exists but is not an INT32 array.
    """
    if gguf is None:
        raise RuntimeError("GGUF support requires the 'gguf' package.")
    field_key = f"comfy.gguf.orig_shape.{tensor_name}"
    field = reader.get_field(field_key)
    if field is None:
        return None
    # Expect exactly [ARRAY, INT32] — anything else is malformed metadata.
    if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32:
        raise TypeError(f"Bad GGUF shape metadata for {field_key}: {field.types}")
    return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))
+
+
+def _gguf_resolve_prefix(tensor_names, prefixes):
+ for prefix in prefixes:
+ if any(name.startswith(prefix) for name in tensor_names):
+ return prefix
+ return None
+
+
def load_gguf_state_dict(
    file_path,
    filters=None,
    keep_prefixes=False,
    writable_tensors=True,
    verboseLevel=1,
    default_dtype=None,
    pin_to_memory=False,
):
    """Read a GGUF file into a state dict of GGUFSourceTensor wrappers.

    Quantized payloads are kept raw (to be dequantized lazily); plain
    F16/BF16/F32 tensors are view-cast to their torch dtype.  Known
    "diffusion_model." prefixes are stripped from tensor names, optional
    *filters* narrow the result, and biases are cast to match their weight's
    dtype (or *default_dtype*).

    Returns:
        (state_dict, None, None) — the trailing Nones presumably mirror the
        signature of sibling loaders; confirm against the caller.

    Raises:
        RuntimeError: gguf package missing.
        Exception: pin_to_memory requested (unsupported for GGUF).

    NOTE(review): *writable_tensors* is accepted but never used here.
    """
    if gguf is None:
        raise RuntimeError("GGUF support requires the 'gguf' package.")
    if pin_to_memory:
        raise Exception("Pinning to memory while loading GGUF files is not supported")

    import warnings

    def _cast_plain_tensor(torch_tensor, tensor_type):
        # View-cast raw uint storage to the declared float dtype, or convert
        # when the tensor already carries a (different) float dtype.
        if tensor_type == gguf.GGMLQuantizationType.F16:
            if torch_tensor.dtype in (torch.uint8, torch.uint16):
                torch_tensor = torch_tensor.view(torch.float16)
            elif torch_tensor.dtype != torch.float16:
                torch_tensor = torch_tensor.to(torch.float16)
        elif tensor_type == gguf.GGMLQuantizationType.BF16:
            if torch_tensor.dtype in (torch.uint8, torch.uint16):
                torch_tensor = torch_tensor.view(torch.bfloat16)
            elif torch_tensor.dtype != torch.bfloat16:
                torch_tensor = torch_tensor.to(torch.bfloat16)
        elif tensor_type == gguf.GGMLQuantizationType.F32:
            if torch_tensor.dtype in (torch.uint8, torch.uint16, torch.uint32):
                torch_tensor = torch_tensor.view(torch.float32)
            elif torch_tensor.dtype != torch.float32:
                torch_tensor = torch_tensor.to(torch.float32)
        return torch_tensor

    def _tensor_type_from_dtype(dtype):
        # Inverse mapping used when re-tagging casted biases below.
        if dtype == torch.float16:
            return gguf.GGMLQuantizationType.F16
        if dtype == torch.bfloat16:
            return gguf.GGMLQuantizationType.BF16
        if dtype == torch.float32:
            return gguf.GGMLQuantizationType.F32
        return None

    reader = gguf.GGUFReader(file_path)
    if verboseLevel >= 2:
        # Best-effort mmap accounting via mmgp; purely diagnostic.
        # NOTE(review): `tracker` is never read after registration.
        try:
            from mmgp import safetensors2
            safetensors2.verboseLevel = verboseLevel
            tracker = safetensors2.MmapTracker(file_path)
            tracker.register(reader.data, 0, 0, int(reader.data.nbytes))
        except Exception:
            tracker = None
    tensor_names = [tensor.name for tensor in reader.tensors]
    prefix = _gguf_resolve_prefix(tensor_names, ("model.diffusion_model.", "diffusion_model."))

    state_dict = {}
    qtype_counts = {}
    for tensor in reader.tensors:
        name = tensor.name
        if prefix and name.startswith(prefix):
            name = name[len(prefix):]

        # The numpy view into the mmap is read-only; silence torch's warning.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
            torch_tensor = torch.from_numpy(tensor.data)

        # Prefer the recorded original shape; fall back to the GGUF shape
        # (stored in reverse order).
        shape = _gguf_get_orig_shape(reader, tensor.name)
        if shape is None:
            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        if tensor.tensor_type in (
            gguf.GGMLQuantizationType.F32,
            gguf.GGMLQuantizationType.F16,
            gguf.GGMLQuantizationType.BF16,
        ):
            torch_tensor = _cast_plain_tensor(torch_tensor, tensor.tensor_type)
            torch_tensor = torch_tensor.view(*shape)
        wrapped = GGUFSourceTensor.wrap(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
        if name.endswith(".bias"):
            # Remember the on-disk dtype so later casts stay reversible.
            wrapped._gguf_bias_orig_dtype = wrapped.dtype
            wrapped._gguf_bias_orig_tensor_type = tensor.tensor_type
        state_dict[name] = wrapped
        type_name = getattr(tensor.tensor_type, "name", str(tensor.tensor_type))
        qtype_counts[type_name] = qtype_counts.get(type_name, 0) + 1

    if verboseLevel >= 2 and qtype_counts:
        print("GGUF qtypes: " + ", ".join(f"{k} ({v})" for k, v in qtype_counts.items()))

    if filters is not None:
        if not isinstance(filters, list):
            filters = [filters]
        new_sd = {}
        for one_filter in filters:
            new_sd.update(_filter_state_dict_basic(state_dict, one_filter, keep_prefixes))
        state_dict = new_sd

    if state_dict:
        # Align each bias dtype with its sibling weight (or default_dtype).
        for name, bias in list(state_dict.items()):
            if not name.endswith(".bias") or not torch.is_tensor(bias):
                continue
            weight_name = name[:-5] + ".weight"
            weight = state_dict.get(weight_name)
            target_dtype = None
            weight_type = getattr(weight, "tensor_type", None) if torch.is_tensor(weight) else None
            if torch.is_tensor(weight) and weight_type is None:
                target_dtype = weight.dtype if weight.dtype.is_floating_point else default_dtype
            elif weight_type in (
                gguf.GGMLQuantizationType.F16,
                gguf.GGMLQuantizationType.BF16,
                gguf.GGMLQuantizationType.F32,
            ):
                target_dtype = weight.dtype
            else:
                # Quantized (or missing) weight: fall back to the loader default.
                target_dtype = default_dtype
            if target_dtype is None or bias.dtype == target_dtype:
                continue
            casted = bias.to(target_dtype)
            if isinstance(casted, GGUFSourceTensor):
                # Preserve provenance and re-tag the GGUF type for the new dtype.
                casted._gguf_bias_orig_dtype = getattr(bias, "_gguf_bias_orig_dtype", bias.dtype)
                casted._gguf_bias_orig_tensor_type = getattr(
                    bias, "_gguf_bias_orig_tensor_type", getattr(bias, "tensor_type", None)
                )
                new_tensor_type = _tensor_type_from_dtype(target_dtype)
                if new_tensor_type is not None:
                    casted.tensor_type = new_tensor_type
                casted.tensor_shape = getattr(bias, "tensor_shape", casted.shape)
            state_dict[name] = casted

    return state_dict, None, None
+
+
def load_state_dict(*args, **kwargs):
    """Handler entry point: alias for load_gguf_state_dict."""
    return load_gguf_state_dict(*args, **kwargs)
+
+
class GGUFSourceTensor(torch.Tensor):
    """torch.Tensor subclass carrying GGUF provenance metadata.

    Attributes (set by wrap / propagated by to/clone/detach):
        tensor_type: the GGMLQuantizationType of the on-disk payload.
        tensor_shape: the logical (pre-quantization) shape.
    """

    @staticmethod
    def wrap(tensor, *, tensor_type, tensor_shape):
        # as_subclass keeps the same storage; only the Python type changes.
        wrapped = tensor.as_subclass(GGUFSourceTensor)
        wrapped.tensor_type = tensor_type
        wrapped.tensor_shape = tensor_shape
        return wrapped

    def to(self, *args, **kwargs):
        # Default .to() would drop the metadata; re-attach it explicitly.
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.shape)
        return new

    def clone(self, *args, **kwargs):
        cloned = super().clone(*args, **kwargs).as_subclass(GGUFSourceTensor)
        cloned.tensor_type = getattr(self, "tensor_type", None)
        cloned.tensor_shape = getattr(self, "tensor_shape", cloned.shape)
        return cloned

    def detach(self, *args, **kwargs):
        detached = super().detach(*args, **kwargs).as_subclass(GGUFSourceTensor)
        detached.tensor_type = getattr(self, "tensor_type", None)
        detached.tensor_shape = getattr(self, "tensor_shape", detached.shape)
        return detached

    def get_quantized_subtensors(self):
        # Offload protocol: expose the raw payload as a single named subtensor.
        return [("data", self)]

    def set_quantized_subtensors(self, sub_tensors):
        # Offload protocol: swap our storage for the provided payload in place.
        if isinstance(sub_tensors, dict):
            data = sub_tensors.get("data")
        else:
            data = dict(sub_tensors).get("data")
        if data is None or data is self:
            return
        torch.utils.swap_tensors(self, data)
+
+
def _split_gguf_tensor(src, *, dim, split_sizes, context):
    """Split handler for fused GGUF tensors used by sd_split_linear.

    Returns a list of GGUFSourceTensor chunks (metadata propagated), or None
    when *src* is not a GGUF tensor or the sizes do not cover the logical
    shape along *dim* — letting the caller fall back to default splitting.
    """
    if not torch.is_tensor(src):
        return None
    tensor_type = getattr(src, "tensor_type", None)
    if tensor_type is None:
        return None
    tensor_shape = getattr(src, "tensor_shape", None) or src.shape
    total = sum(split_sizes)
    # The logical (pre-quantization) dim must match the requested split exactly.
    if dim >= len(tensor_shape) or tensor_shape[dim] != total:
        return None
    # NOTE(review): the raw payload is split directly along *dim*; this
    # presumes the storage layout along that dim matches the logical shape —
    # confirm for block-quantized types.
    chunks = torch.split(src, split_sizes, dim=dim)
    out = []
    for chunk, size in zip(chunks, split_sizes):
        new_shape = list(tensor_shape)
        new_shape[dim] = size
        wrapped = GGUFSourceTensor.wrap(chunk, tensor_type=tensor_type, tensor_shape=tuple(new_shape))
        # Carry over bias provenance markers when present.
        if hasattr(src, "_gguf_bias_orig_dtype"):
            wrapped._gguf_bias_orig_dtype = getattr(src, "_gguf_bias_orig_dtype")
        if hasattr(src, "_gguf_bias_orig_tensor_type"):
            wrapped._gguf_bias_orig_tensor_type = getattr(src, "_gguf_bias_orig_tensor_type")
        out.append(wrapped)
    return out
+
+
def split_fused_weights(state_dict, fused_split_map, quantization_map=None, allowed_bases=None, default_dtype=None, verboseLevel=1):
    """Split fused linear weights/biases in *state_dict* via mmgp.offload.

    Delegates to offload.sd_split_linear with GGUF-aware split handlers.
    NOTE(review): quantization_map and default_dtype are accepted but not
    forwarded — presumably kept for interface parity with other handlers.
    """
    # Imported lazily so the module loads without mmgp installed.
    from mmgp import offload
    return offload.sd_split_linear(
        state_dict,
        fused_split_map,
        split_fields={"weight": 0, "bias": 0},
        split_handlers={"weight": _split_gguf_tensor, "bias": _split_gguf_tensor},
        verboseLevel=verboseLevel,
        allowed_bases=allowed_bases,
        return_split_bases=True,
    )
+
+
def _is_gguf_qtype(qtype_obj):
    """True iff *qtype_obj* is a genuinely quantized GGML type.

    Plain float storage types (F32/F16/BF16), None, or a missing gguf package
    all return False.
    """
    if gguf is None:
        return False
    if qtype_obj is None:
        return False
    return qtype_obj not in (
        gguf.GGMLQuantizationType.F32,
        gguf.GGMLQuantizationType.F16,
        gguf.GGMLQuantizationType.BF16,
    )
+
+
+def _gguf_qtype_name(qtype_obj):
+ if qtype_obj is None:
+ return None
+ return getattr(qtype_obj, "name", None) or str(qtype_obj)
+
+
+def _guess_variant_from_filename(filename):
+ base = os.path.basename(str(filename))
+ match = re.search(r"(?i)(?:^|[_-])(Q\d+_K|Q\d+_\d|Q\d+|IQ\d+_\w+)(?:$|[_.-])", base)
+ if match:
+ return match.group(1).upper()
+ return None
+
+
def detect_gguf_quantization_variant(file_path, verboseLevel=1):
    """Inspect a GGUF file and return its dominant quantization type name.

    Counts per-tensor quantized types (plain F32/F16/BF16 storage is ignored)
    and returns the most frequent name, or None when the file cannot be read,
    the gguf package is missing, or no quantized tensors exist.
    """
    if gguf is None:
        return None
    try:
        reader = gguf.GGUFReader(file_path)
    except Exception:
        return None
    counts = {}
    for tensor in reader.tensors:
        qtype = getattr(tensor, "tensor_type", None)
        if qtype in (
            gguf.GGMLQuantizationType.F32,
            gguf.GGMLQuantizationType.F16,
            gguf.GGMLQuantizationType.BF16,
        ):
            continue
        name = _gguf_qtype_name(qtype)
        if not name:
            continue
        counts[name] = counts.get(name, 0) + 1
    if not counts:
        return None
    # Majority vote across tensors.
    return max(counts, key=counts.get)
+
+
def detect_quantization_kind_for_file(file_path, verboseLevel=1):
    """Return the handler kind ("gguf") for *file_path*, or None.

    None is returned for empty paths, non-.gguf files, or when the optional
    gguf package is unavailable.
    """
    # FIX: replaced the non-idiomatic `endswith(...) is False` comparison.
    if not file_path or not str(file_path).lower().endswith(".gguf"):
        return None
    if gguf is None:
        return None
    return "gguf"
+
+
def detect_quantization_label_from_filename(filename, verboseLevel=1):
    """Return a display label ("GGUF" or "GGUF-<variant>") for *filename*.

    Returns "" for empty or non-.gguf names.  The variant is guessed from the
    filename first and, failing that, read from the file itself; results are
    memoized per normalized path in _GGUF_LABEL_CACHE.
    """
    # FIX: replaced the non-idiomatic `endswith(...) is False` comparison.
    if not filename or not str(filename).lower().endswith(".gguf"):
        return ""
    key = _normalize_gguf_path(filename)
    cached = _GGUF_LABEL_CACHE.get(key)
    if cached:
        return cached
    variant = _guess_variant_from_filename(filename)
    if not variant and os.path.isfile(filename):
        variant = detect_gguf_quantization_variant(filename, verboseLevel=verboseLevel)
    if variant:
        label = f"GGUF-{variant}"
    else:
        label = "GGUF"
    _GGUF_LABEL_CACHE[key] = label
    return label
+
+
+def _gguf_qfallback(op_fn, *args, **kwargs):
+    args, kwargs = pytree.tree_map_only(GGUFWeightTensor, lambda x: x.dequantize(), (args, kwargs or {}))
+    return op_fn(*args, **kwargs)
+
+
+def _reshape_scale(scale, weight):
+ if scale.ndim == 0 or scale.numel() == 1:
+ return scale
+ if scale.ndim == 1 and scale.shape[0] == weight.shape[0]:
+ return scale.view(weight.shape[0], *([1] * (weight.ndim - 1)))
+ return scale
+
+
+def _gguf_dequantize_tensor(raw, qtype_obj, oshape, dtype=None):
+ if gguf is None:
+ raise RuntimeError("gguf package is required to dequantize GGUF weights.")
+ if qtype_obj in (
+ gguf.GGMLQuantizationType.F32,
+ gguf.GGMLQuantizationType.F16,
+ gguf.GGMLQuantizationType.BF16,
+ ):
+ out = raw.view(*oshape)
+ return out.to(dtype) if dtype is not None else out
+ if qtype_obj not in _DEQUANTIZE_FUNCTIONS:
+ out = gguf.quants.dequantize(raw.cpu().numpy(), qtype_obj)
+ out = torch.from_numpy(out)
+ return out.to(dtype) if dtype is not None else out
+ block_size, type_size = gguf.GGML_QUANT_SIZES[qtype_obj]
+ dequantize_blocks = _DEQUANTIZE_FUNCTIONS[qtype_obj]
+ rows = raw.reshape((-1, raw.shape[-1])).view(torch.uint8)
+ n_blocks = rows.numel() // type_size
+ blocks = rows.reshape((n_blocks, type_size))
+ blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
+ return blocks.reshape(oshape)
+
+
+def _maybe_cast_bias(bias, target_dtype):
+ if bias is None or not torch.is_tensor(bias) or target_dtype is None:
+ return bias
+ if bias.dtype == target_dtype:
+ return bias
+ if isinstance(bias, GGUFSourceTensor):
+ tensor_type = getattr(bias, "tensor_type", None)
+ tensor_shape = getattr(bias, "tensor_shape", bias.shape)
+ if _is_gguf_qtype(tensor_type):
+ return _gguf_dequantize_tensor(bias, tensor_type, tensor_shape, dtype=target_dtype)
+ return bias.to(target_dtype)
+
+
+def _to_uint32(x):
+ x = x.view(torch.uint8).to(torch.int32)
+ return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
+
+
+def _to_uint16(x):
+ x = x.view(torch.uint8).to(torch.int32)
+ return (x[:, 0] | x[:, 1] << 8).unsqueeze(1)
+
+
+def _const_like(ref, values, dtype):
+ device = ref.device if torch.is_tensor(ref) else None
+ count = len(values)
+ if count == 0:
+ return torch.empty((0,), device=device, dtype=dtype)
+ if count == 1:
+ return torch.full((1,), values[0], device=device, dtype=dtype)
+ step = values[1] - values[0]
+ if all(values[idx] - values[idx - 1] == step for idx in range(1, count)):
+ end = values[0] + step * count
+ return torch.arange(values[0], end, step, device=device, dtype=dtype)
+ raise ValueError("Unsupported constant pattern for GGUF dequantization.")
+
+
+def _split_block_dims(blocks, *args):
+ n_max = blocks.shape[1]
+ dims = list(args) + [n_max - sum(args)]
+ return torch.split(blocks, dims, dim=1)
+
+
+def _dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
+ d, x = _split_block_dims(blocks, 2)
+ d = d.view(torch.float16).to(dtype)
+ x = x.view(torch.int8)
+ return d * x
+
+
+def _dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ d, m, qh, qs = _split_block_dims(blocks, 2, 2, 4)
+ d = d.view(torch.float16).to(dtype)
+ m = m.view(torch.float16).to(dtype)
+ qh = _to_uint32(qh)
+ qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+ ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> _const_like(d, [0, 4], torch.uint8).reshape(1, 1, 2, 1)
+ qh = (qh & 1).to(torch.uint8)
+ ql = (ql & 0x0F).reshape((n_blocks, -1))
+ qs = ql | (qh << 4)
+ return d * qs + m
+
+
+def _dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ d, qh, qs = _split_block_dims(blocks, 2, 4)
+ d = d.view(torch.float16).to(dtype)
+ qh = _to_uint32(qh)
+ qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+ ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> _const_like(d, [0, 4], torch.uint8).reshape(1, 1, 2, 1)
+ qh = (qh & 1).to(torch.uint8)
+ ql = (ql & 0x0F).reshape(n_blocks, -1)
+ qs = (ql | (qh << 4)).to(torch.int8) - 16
+ return d * qs
+
+
+def _dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ d, m, qs = _split_block_dims(blocks, 2, 2)
+ d = d.view(torch.float16).to(dtype)
+ m = m.view(torch.float16).to(dtype)
+ qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> _const_like(d, [0, 4], torch.uint8).reshape(1, 1, 2, 1)
+ qs = (qs & 0x0F).reshape((n_blocks, -1))
+ return d * qs + m
+
+
+def _dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ d, qs = _split_block_dims(blocks, 2)
+ d = d.view(torch.float16).to(dtype)
+ qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> _const_like(d, [0, 4], torch.uint8).reshape(1, 1, 2, 1)
+ qs = (qs & 0x0F).reshape((n_blocks, -1))
+ qs = qs.to(torch.int8) - 8
+ return d * qs
+
+
+QK_K = 256
+K_SCALE_SIZE = 12
+
+
+def _get_scale_min(scales):
+ n_blocks = scales.shape[0]
+ scales = scales.view(torch.uint8)
+ scales = scales.reshape((n_blocks, 3, 4))
+ d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
+ sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
+ mn = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
+ return sc.reshape((n_blocks, 8)), mn.reshape((n_blocks, 8))
+
+
+def _dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ ql, qh, scales, d = _split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
+ scales = scales.view(torch.int8).to(dtype)
+ d = d.view(torch.float16).to(dtype)
+ d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
+ ql = ql.reshape((n_blocks, -1, 1, 64)) >> _const_like(d, [0, 4], torch.uint8).reshape((1, 1, 2, 1))
+ ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
+ qh = qh.reshape((n_blocks, -1, 1, 32)) >> _const_like(d, [0, 2, 4, 6], torch.uint8).reshape((1, 1, 4, 1))
+ qh = (qh & 0x03).reshape((n_blocks, -1, 32))
+ q = (ql | (qh << 4)).to(torch.int8) - 32
+ q = q.reshape((n_blocks, QK_K // 16, -1))
+ return (d * q).reshape((n_blocks, QK_K))
+
+
+def _dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ d, dmin, scales, qh, qs = _split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
+ d = d.view(torch.float16).to(dtype)
+ dmin = dmin.view(torch.float16).to(dtype)
+ sc, m = _get_scale_min(scales)
+ d = (d * sc).reshape((n_blocks, -1, 1))
+ dm = (dmin * m).reshape((n_blocks, -1, 1))
+ ql = qs.reshape((n_blocks, -1, 1, 32)) >> _const_like(d, [0, 4], torch.uint8).reshape((1, 1, 2, 1))
+ qh = qh.reshape((n_blocks, -1, 1, 32)) >> _const_like(d, list(range(8)), torch.uint8).reshape((1, 1, 8, 1))
+ ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
+ qh = (qh & 0x01).reshape((n_blocks, -1, 32))
+ q = ql | (qh << 4)
+ return (d * q - dm).reshape((n_blocks, QK_K))
+
+
+def _dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ d, dmin, scales, qs = _split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
+ d = d.view(torch.float16).to(dtype)
+ dmin = dmin.view(torch.float16).to(dtype)
+ sc, m = _get_scale_min(scales)
+ d = (d * sc).reshape((n_blocks, -1, 1))
+ dm = (dmin * m).reshape((n_blocks, -1, 1))
+ qs = qs.reshape((n_blocks, -1, 1, 32)) >> _const_like(d, [0, 4], torch.uint8).reshape((1, 1, 2, 1))
+ qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
+ return (d * qs - dm).reshape((n_blocks, QK_K))
+
+
+def _dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ hmask, qs, scales, d = _split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
+ d = d.view(torch.float16).to(dtype)
+ lscales, hscales = scales[:, :8], scales[:, 8:]
+ lscales = lscales.reshape((n_blocks, 1, 8)) >> _const_like(d, [0, 4], torch.uint8).reshape((1, 2, 1))
+ lscales = lscales.reshape((n_blocks, 16))
+ hscales = hscales.reshape((n_blocks, 1, 4)) >> _const_like(d, [0, 2, 4, 6], torch.uint8).reshape((1, 4, 1))
+ hscales = hscales.reshape((n_blocks, 16))
+ scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
+ scales = (scales.to(torch.int8) - 32)
+ dl = (d * scales).reshape((n_blocks, 16, 1))
+ ql = qs.reshape((n_blocks, -1, 1, 32)) >> _const_like(d, [0, 2, 4, 6], torch.uint8).reshape((1, 1, 4, 1))
+ qh = hmask.reshape(n_blocks, -1, 1, 32) >> _const_like(d, list(range(8)), torch.uint8).reshape((1, 1, 8, 1))
+ ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
+ qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
+ q = (ql.to(torch.int8) - (qh << 2).to(torch.int8))
+ return (dl * q).reshape((n_blocks, QK_K))
+
+
+def _dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+ scales, qs, d, dmin = _split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
+ d = d.view(torch.float16).to(dtype)
+ dmin = dmin.view(torch.float16).to(dtype)
+ dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
+ ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
+ shift = _const_like(d, [0, 2, 4, 6], torch.uint8).reshape((1, 1, 4, 1))
+ qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
+ qs = qs.reshape((n_blocks, QK_K // 16, 16))
+ qs = dl * qs - ml
+ return qs.reshape((n_blocks, -1))
+
+
+if gguf is not None:
+ _DEQUANTIZE_FUNCTIONS = {
+ gguf.GGMLQuantizationType.Q8_0: _dequantize_blocks_Q8_0,
+ gguf.GGMLQuantizationType.Q5_1: _dequantize_blocks_Q5_1,
+ gguf.GGMLQuantizationType.Q5_0: _dequantize_blocks_Q5_0,
+ gguf.GGMLQuantizationType.Q4_1: _dequantize_blocks_Q4_1,
+ gguf.GGMLQuantizationType.Q4_0: _dequantize_blocks_Q4_0,
+ gguf.GGMLQuantizationType.Q6_K: _dequantize_blocks_Q6_K,
+ gguf.GGMLQuantizationType.Q5_K: _dequantize_blocks_Q5_K,
+ gguf.GGMLQuantizationType.Q4_K: _dequantize_blocks_Q4_K,
+ gguf.GGMLQuantizationType.Q3_K: _dequantize_blocks_Q3_K,
+ gguf.GGMLQuantizationType.Q2_K: _dequantize_blocks_Q2_K,
+ }
+else:
+ _DEQUANTIZE_FUNCTIONS = {}
+
+
+class GGUFWeightTensor(QTensor):
+ @staticmethod
+ def create(raw_tensor, size, stride, dtype, device=None, requires_grad=False, tensor_type=None, tensor_shape=None):
+ if tensor_type is None:
+ tensor_type = getattr(raw_tensor, "tensor_type", None)
+ if tensor_shape is None:
+ tensor_shape = getattr(raw_tensor, "tensor_shape", None) or size
+ device = raw_tensor.device if device is None else device
+ if raw_tensor.device != device:
+ raw_tensor = raw_tensor.to(device)
+ return GGUFWeightTensor(
+ qtype=_GGUF_QTYPE,
+ axis=0,
+ size=size,
+ stride=stride,
+ raw=raw_tensor,
+ tensor_type=tensor_type,
+ tensor_shape=tensor_shape,
+ dtype=dtype,
+ requires_grad=requires_grad,
+ )
+
+ @staticmethod
+ def __new__(cls, qtype, axis, size, stride, raw, tensor_type, tensor_shape, dtype, requires_grad=False):
+ return torch.Tensor._make_wrapper_subclass(
+ cls,
+ size,
+ strides=stride,
+ dtype=dtype,
+ device=raw.device,
+ requires_grad=requires_grad,
+ )
+
+ def __init__(self, qtype, axis, size, stride, raw, tensor_type, tensor_shape, dtype, requires_grad=False):
+ super().__init__(qtype, axis)
+ self._data = raw
+ self._tensor_type = tensor_type
+ self._tensor_shape = torch.Size(tensor_shape)
+ self._gguf_default_dtype = dtype
+
+ def __repr__(self):
+ cls_name = self.__class__.__name__
+ try:
+ shape = tuple(self.shape)
+ except Exception:
+ shape = ">"
+ try:
+ dtype = str(self.dtype).replace("torch.", "")
+ except Exception:
+ dtype = ">"
+ try:
+ device = str(self.device)
+ except Exception:
+ device = ">"
+ qtype = getattr(self, "_qtype", None)
+ qtype_name = getattr(qtype, "name", None) or str(qtype) if qtype is not None else ">"
+ tensor_type = _gguf_qtype_name(getattr(self, "_tensor_type", None)) or ">"
+ return (
+ f"{cls_name}(shape={shape}, dtype={dtype}, device={device}, "
+ f"qtype={qtype_name}, tensor_type={tensor_type})"
+ )
+
+ __str__ = __repr__
+
+ def dequantize(self, dtype=None, device=None):
+ if dtype is None:
+ dtype = self.dtype
+ if device is None:
+ device = self.device
+ raw = self._data if self._data.device == device else self._data.to(device)
+ return _gguf_dequantize_tensor(raw, self._tensor_type, self._tensor_shape, dtype=dtype)
+
+ def linear(self, input, bias=None):
+ if torch.is_tensor(input):
+ target_dtype = _resolve_default_dtype(self._gguf_default_dtype, fallback=input.dtype)
+ target_device = input.device
+ else:
+ target_dtype = _resolve_default_dtype(self._gguf_default_dtype, fallback=self.dtype)
+ target_device = self.device
+ weight = self.dequantize(dtype=target_dtype, device=target_device)
+ if torch.is_tensor(input) and input.dtype != weight.dtype:
+ input = input.to(weight.dtype)
+ bias = _maybe_cast_bias(bias, weight.dtype)
+ return torch.nn.functional.linear(input, weight, bias)
+
+ def get_quantized_subtensors(self):
+ return [("data", self._data)]
+
+ def set_quantized_subtensors(self, sub_tensors):
+ if isinstance(sub_tensors, dict):
+ sub_map = sub_tensors
+ else:
+ sub_map = {name: tensor for name, tensor in sub_tensors}
+ data = sub_map.get("data", None)
+ if data is not None:
+ old_data = self._data
+ if torch.is_tensor(old_data):
+ try:
+ torch.utils.swap_tensors(old_data, data)
+ self._data = old_data
+ except Exception:
+ self._data = data
+ else:
+ self._data = data
+ if hasattr(self, "_ggml_raw_cpu"):
+ self._ggml_raw_cpu = None
+
+ def __tensor_flatten__(self):
+ inner_tensors = ["_data"]
+ meta = {
+ "qtype": self._qtype.name,
+ "axis": str(self._axis),
+ "size": str(list(self.size())),
+ "stride": str(list(self.stride())),
+ "dtype": str(self.dtype),
+ "tensor_type": _gguf_qtype_name(self._tensor_type) or "",
+ "tensor_shape": str(list(self._tensor_shape)),
+ }
+ return inner_tensors, meta
+
+ @staticmethod
+ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+ qtype = _quanto_qtypes[meta["qtype"]]
+ axis = ast.literal_eval(meta["axis"])
+ size = ast.literal_eval(meta["size"])
+ stride = ast.literal_eval(meta["stride"])
+ dtype_str = meta.get("dtype", "torch.float16")
+ if dtype_str.startswith("torch."):
+ dtype_name = dtype_str.split(".", 1)[1]
+ dtype = getattr(torch, dtype_name, torch.float16)
+ else:
+ dtype = getattr(torch, dtype_str, torch.float16)
+ tensor_shape = ast.literal_eval(meta.get("tensor_shape", str(list(size))))
+ tensor_type = None
+ if gguf is not None and meta.get("tensor_type"):
+ tensor_type = getattr(gguf.GGMLQuantizationType, meta["tensor_type"], None)
+ return GGUFWeightTensor(
+ qtype=qtype,
+ axis=axis,
+ size=size,
+ stride=stride,
+ raw=inner_tensors["_data"],
+ tensor_type=tensor_type,
+ tensor_shape=tensor_shape,
+ dtype=dtype,
+ )
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ kwargs = kwargs or {}
+ if func is torch.nn.functional.linear:
+ input = args[0] if len(args) > 0 else kwargs.get("input", None)
+ weight = args[1] if len(args) > 1 else kwargs.get("weight", None)
+ bias = args[2] if len(args) > 2 else kwargs.get("bias", None)
+ if isinstance(weight, GGUFWeightTensor):
+ return weight.linear(input, bias=bias)
+ with torch._C.DisableTorchFunctionSubclass():
+ return func(*args, **kwargs)
+
+ @classmethod
+ def __torch_dispatch__(cls, op, types, args, kwargs=None):
+ op = op.overloadpacket
+ kwargs = kwargs or {}
+ if op is torch.ops.aten.linear:
+ input = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ if isinstance(weight, GGUFWeightTensor):
+ return weight.linear(input, bias=bias)
+ if op is torch.ops.aten.detach:
+ t = args[0]
+ return GGUFWeightTensor.create(
+ raw_tensor=op(t._data),
+ size=t.size(),
+ stride=t.stride(),
+ dtype=t.dtype,
+ device=t.device,
+ requires_grad=t.requires_grad,
+ tensor_type=getattr(t, "_tensor_type", None),
+ tensor_shape=getattr(t, "_tensor_shape", None),
+ )
+ if op in (torch.ops.aten._to_copy, torch.ops.aten.to):
+ t = args[0]
+ dtype = kwargs.pop("dtype", t.dtype) if kwargs else t.dtype
+ device = kwargs.pop("device", t.device) if kwargs else t.device
+ if dtype != t.dtype:
+ return t.dequantize(dtype=dtype, device=device)
+ out_data = op(t._data, device=device, **(kwargs or {}))
+ return GGUFWeightTensor.create(
+ raw_tensor=out_data,
+ size=t.size(),
+ stride=t.stride(),
+ dtype=t.dtype,
+ device=device,
+ requires_grad=t.requires_grad,
+ tensor_type=getattr(t, "_tensor_type", None),
+ tensor_shape=getattr(t, "_tensor_shape", None),
+ )
+ return _gguf_qfallback(op, *args, **(kwargs or {}))
+
+
+class QLinearGGUF(QModuleMixin, torch.nn.Linear):
+ def __init__(
+ self,
+ in_features,
+ out_features,
+ bias=True,
+ device=None,
+ dtype=None,
+ weights=None,
+ activations=None,
+ optimizer=None,
+ quantize_input=True,
+ ):
+ super().__init__(
+ in_features,
+ out_features,
+ bias=bias,
+ device=device,
+ dtype=dtype,
+ weights=weights,
+ activations=activations,
+ optimizer=optimizer,
+ quantize_input=quantize_input,
+ )
+ self._gguf_default_dtype = dtype
+
+ @classmethod
+ def qcreate(cls, module, weights, activations=None, optimizer=None, device=None):
+ if torch.is_tensor(module.weight) and module.weight.dtype.is_floating_point:
+ weight_dtype = module.weight.dtype
+ elif torch.is_tensor(getattr(module, "bias", None)) and module.bias.dtype.is_floating_point:
+ weight_dtype = module.bias.dtype
+ else:
+ weight_dtype = torch.float16
+ return cls(
+ module.in_features,
+ module.out_features,
+ module.bias is not None,
+ device=device,
+ dtype=weight_dtype,
+ weights=weights,
+ activations=activations,
+ optimizer=optimizer,
+ quantize_input=True,
+ )
+
+ def set_default_dtype(self, dtype):
+ self._gguf_default_dtype = dtype
+
+ @property
+ def qweight(self):
+ if self.weight_qtype == _GGUF_QTYPE:
+ return self.weight
+ return super().qweight
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ qweight = self.qweight
+ if isinstance(qweight, GGUFWeightTensor):
+ return qweight.linear(input, bias=self.bias)
+ return torch.nn.functional.linear(input, qweight, bias=self.bias)
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ if self.weight_qtype != _GGUF_QTYPE:
+ return super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
+
+ weight_key = prefix + "weight"
+ bias_key = prefix + "bias"
+ input_scale_key = prefix + "input_scale"
+ output_scale_key = prefix + "output_scale"
+
+ weight_raw = state_dict.pop(weight_key, None)
+ bias = state_dict.pop(bias_key, None)
+ input_scale = state_dict.pop(input_scale_key, None)
+ output_scale = state_dict.pop(output_scale_key, None)
+
+ if weight_raw is None:
+ missing_keys.append(weight_key)
+
+ target_dtype = _resolve_default_dtype(self._gguf_default_dtype, fallback=self.weight.dtype)
+ if weight_raw is not None:
+ gguf_weight = GGUFWeightTensor.create(
+ raw_tensor=weight_raw,
+ size=self.weight.size(),
+ stride=self.weight.stride(),
+ dtype=target_dtype,
+ device=weight_raw.device,
+ requires_grad=False,
+ )
+ self.weight = torch.nn.Parameter(gguf_weight, requires_grad=False)
+
+ if bias is not None:
+ self.bias = torch.nn.Parameter(bias, requires_grad=False)
+
+ if torch.is_tensor(weight_raw):
+ scale_device = weight_raw.device
+ elif torch.is_tensor(self.weight):
+ scale_device = self.weight.device
+ elif torch.is_tensor(bias):
+ scale_device = bias.device
+ else:
+ scale_device = torch.device("cpu")
+
+ if input_scale is not None:
+ self.input_scale = input_scale.to(scale_device)
+ else:
+ if not hasattr(self, "input_scale") or self.input_scale.is_meta:
+ scale_dtype = self.input_scale.dtype if hasattr(self, "input_scale") else torch.float32
+ self.input_scale = torch.ones((), dtype=scale_dtype, device=scale_device)
+
+ if output_scale is not None:
+ self.output_scale = output_scale.to(scale_device)
+ else:
+ if not hasattr(self, "output_scale") or self.output_scale.is_meta:
+ scale_dtype = self.output_scale.dtype if hasattr(self, "output_scale") else torch.float32
+ self.output_scale = torch.ones((), dtype=scale_dtype, device=scale_device)
+
+ return
+
+
+def _collect_gguf_specs(state_dict):
+ specs = []
+ for key, tensor in state_dict.items():
+ if not key.endswith(".weight"):
+ continue
+ if not _is_gguf_qtype(getattr(tensor, "tensor_type", None)):
+ continue
+ specs.append({"name": key[:-7], "tensor": tensor})
+ return specs
+
+
+def detect(state_dict, verboseLevel=1):
+ if gguf is None:
+ return {"matched": False, "kind": "none", "details": {"error": "gguf not installed"}}
+ specs = _collect_gguf_specs(state_dict)
+ if not specs:
+ return {"matched": False, "kind": "none", "details": {}}
+ names = [spec["name"] for spec in specs][:8]
+ return {"matched": True, "kind": "gguf", "details": {"count": len(specs), "names": names}}
+
+
+def convert_to_quanto(state_dict, default_dtype, verboseLevel=1, detection=None):
+ if gguf is None:
+ return {"state_dict": state_dict, "quant_map": {}}
+ if detection is not None and not detection.get("matched", False):
+ return {"state_dict": state_dict, "quant_map": {}}
+ specs = _collect_gguf_specs(state_dict)
+ if not specs:
+ return {"state_dict": state_dict, "quant_map": {}}
+ _set_default_dtype_from_loader(default_dtype)
+ quant_map = {spec["name"]: {"weights": _GGUF_QTYPE_NAME, "activations": "none"} for spec in specs}
+ return {"state_dict": state_dict, "quant_map": quant_map}
+
+
+def apply_pre_quantization(model, state_dict, quantization_map, default_dtype=None, verboseLevel=1):
+ if default_dtype is None or model is None or not quantization_map:
+ return quantization_map or {}, []
+ _set_default_dtype_from_loader(default_dtype)
+ quantized = set(quantization_map.keys())
+ for name, module in model.named_modules():
+ if name in quantized and isinstance(module, torch.nn.Linear):
+ module._router_default_dtype = default_dtype
+ return quantization_map or {}, []
diff --git a/Wan2GP/shared/qtypes/nunchaku_fp4.py b/Wan2GP/shared/qtypes/nunchaku_fp4.py
index 268380c3f..0907a4134 100644
--- a/Wan2GP/shared/qtypes/nunchaku_fp4.py
+++ b/Wan2GP/shared/qtypes/nunchaku_fp4.py
@@ -13,6 +13,7 @@
HANDLER_NAME = "nunchaku_fp4"
+HANDLER_PRIORITY = 2
_NUNCHAKU_FP4_QTYPE_NAME = "nunchaku_fp4"
if _NUNCHAKU_FP4_QTYPE_NAME not in _quanto_qtypes:
@@ -28,6 +29,7 @@
_NUNCHAKU_FP4_QTYPE = _quanto_qtypes[_NUNCHAKU_FP4_QTYPE_NAME]
_NUNCHAKU_OPS = None
_NUNCHAKU_FALLBACK_NOTICE = False
+_NUNCHAKU_KERNEL_NOTICE = False
_NUNCHAKU_SPLIT_FIELDS = {
"qweight": 0,
"wscales": 1,
@@ -71,6 +73,19 @@ def _split(state_dict, verboseLevel=1):
return _split
+def split_fused_weights(state_dict, fused_split_map, quantization_map=None, allowed_bases=None, default_dtype=None, verboseLevel=1):
+ from mmgp import offload
+ split_kwargs = get_nunchaku_split_kwargs()
+ return offload.sd_split_linear(
+ state_dict,
+ fused_split_map,
+ verboseLevel=verboseLevel,
+ allowed_bases=allowed_bases,
+ return_split_bases=True,
+ **split_kwargs,
+ )
+
+
def _install_nunchaku_shim(candidate_root):
candidate_pkg = candidate_root / "nunchaku"
if not candidate_pkg.exists():
@@ -152,6 +167,18 @@ def _notify_nunchaku_fallback(reason):
_NUNCHAKU_FALLBACK_NOTICE = True
+def _notify_nunchaku_kernel_status(verboseLevel=1):
+ global _NUNCHAKU_KERNEL_NOTICE
+ if _NUNCHAKU_KERNEL_NOTICE:
+ return
+ ops = _load_nunchaku_ops()
+ if ops:
+ print("[nunchaku_fp4] Using Nunchaku kernels.")
+ elif not _NUNCHAKU_FALLBACK_NOTICE:
+ print("[nunchaku_fp4] Nunchaku kernels unavailable; using Python fallback.")
+ _NUNCHAKU_KERNEL_NOTICE = True
+
+
def _is_float8_dtype(dtype):
return "float8" in str(dtype).lower() or "f8" in str(dtype).lower()
@@ -400,8 +427,10 @@ def _unpack_nunchaku_fp4_weight(qweight, out_features, in_features):
k_tiles = in_features // mem_k
packed_i32 = qweight.view(torch.int32)
packed_i32 = packed_i32.view(n_tiles, k_tiles, 1, 8, 8, 4, 2, 2, 1)
- shifts = torch.tensor([0, 4, 8, 12, 16, 20, 24, 28], device=packed_i32.device, dtype=torch.int32)
- vals = (packed_i32.unsqueeze(-1) >> shifts) & 0xF
+ vals = torch.stack(
+ [(packed_i32 >> shift) & 0xF for shift in (0, 4, 8, 12, 16, 20, 24, 28)],
+ dim=-1,
+ )
vals = vals.permute(0, 3, 6, 4, 8, 1, 2, 7, 5, 9).contiguous()
vals = vals.view(out_features, in_features).to(torch.int16)
return vals
@@ -428,9 +457,7 @@ def _unpack_int4_from_int32(qweight, out_features, in_features):
q = qweight.view(torch.int32).reshape(out_features, in_features // 8)
q = q.to(torch.int64) & 0xFFFFFFFF
- shifts = torch.arange(0, 32, 4, device=q.device, dtype=torch.int64)
- mask = torch.tensor(0x0F, device=q.device, dtype=torch.int64)
- vals = (q.unsqueeze(-1) >> shifts) & mask
+ vals = torch.stack([(q >> shift) & 0xF for shift in range(0, 32, 4)], dim=-1)
return vals.reshape(out_features, in_features)
@@ -477,6 +504,55 @@ class NunchakuBaseWeightTensor(QTensor):
def __init__(self, qtype, axis):
super().__init__(qtype, axis)
+ def __repr__(self):
+ cls_name = self.__class__.__name__
+ try:
+ shape = tuple(self.shape)
+ except Exception:
+ shape = ">"
+ try:
+ dtype = str(self.dtype).replace("torch.", "")
+ except Exception:
+ dtype = ">"
+ try:
+ device = str(self.device)
+ except Exception:
+ device = ">"
+ qtype = getattr(self, "_qtype", None)
+ qtype_name = getattr(qtype, "name", None) or str(qtype) if qtype is not None else ">"
+ parts = [
+ f"shape={shape}",
+ f"dtype={dtype}",
+ f"device={device}",
+ f"qtype={qtype_name}",
+ ]
+ group_size = getattr(self, "_group_size", None)
+ if group_size is not None:
+ parts.append(f"group_size={group_size}")
+ field_parts = []
+ for name in (
+ "_qweight",
+ "_wscales",
+ "_wzeros",
+ "_wtscale",
+ "_wcscales",
+ "_smooth_factor",
+ "_proj_down",
+ "_proj_up",
+ ):
+ if not hasattr(self, name):
+ continue
+ value = getattr(self, name)
+ if torch.is_tensor(value):
+ field_parts.append(f"{name[1:]}={tuple(value.shape)}:{value.dtype}")
+ else:
+ field_parts.append(f"{name[1:]}={value}")
+ if field_parts:
+ parts.append("fields={" + ", ".join(field_parts) + "}")
+ return f"{cls_name}(" + ", ".join(parts) + ")"
+
+ __str__ = __repr__
+
def get_quantized_subtensors(self):
raise NotImplementedError
@@ -697,7 +773,14 @@ def _linear_cuda(self, input, bias=None):
out = out[: x.shape[0]]
return out.reshape(*input.shape[:-1], self.shape[0])
+ @torch.compiler.disable()
def linear(self, input, bias=None):
+ if torch.is_tensor(input):
+ from torch._subclasses.fake_tensor import is_fake
+
+ if is_fake(input):
+ out_shape = (*input.shape[:-1], self.shape[0])
+ return torch.empty(out_shape, device=input.device, dtype=input.dtype)
if torch.is_tensor(input) and input.device.type == "cuda":
return self._linear_cuda(input, bias=bias)
return self._linear_fallback(input, bias=bias)
@@ -994,7 +1077,14 @@ def _linear_cuda(self, input, bias=None):
out.add_(bias.view(view_shape))
return out.reshape(*input.shape[:-1], out.shape[-1])
+ @torch.compiler.disable()
def linear(self, input, bias=None):
+ if torch.is_tensor(input):
+ from torch._subclasses.fake_tensor import is_fake
+
+ if is_fake(input):
+ out_shape = (*input.shape[:-1], self.shape[0])
+ return torch.empty(out_shape, device=input.device, dtype=input.dtype)
if torch.is_tensor(input) and input.device.type == "cuda":
return self._linear_cuda(input, bias=bias)
return self._linear_fallback(input, bias=bias)
@@ -1189,6 +1279,7 @@ def __init__(
def set_default_dtype(self, dtype):
self._nunchaku_default_dtype = dtype
+ @torch.compiler.disable()
def forward(self, input: torch.Tensor) -> torch.Tensor:
qweight = self.qweight
if isinstance(qweight, NunchakuBaseWeightTensor):
@@ -1503,6 +1594,7 @@ def convert_to_quanto(state_dict, default_dtype, verboseLevel=1, detection=None)
specs = _collect_nunchaku_specs(state_dict)
if not specs:
return {"state_dict": state_dict, "quant_map": {}}
+ _notify_nunchaku_kernel_status(verboseLevel=verboseLevel)
quant_map = {spec["name"]: {"weights": "nunchaku_fp4", "activations": "none"} for spec in specs}
return {"state_dict": state_dict, "quant_map": quant_map}
diff --git a/Wan2GP/shared/qtypes/nunchaku_int4.py b/Wan2GP/shared/qtypes/nunchaku_int4.py
index 9925e54b2..77d998d39 100644
--- a/Wan2GP/shared/qtypes/nunchaku_int4.py
+++ b/Wan2GP/shared/qtypes/nunchaku_int4.py
@@ -13,6 +13,7 @@
HANDLER_NAME = "nunchaku_int4"
+HANDLER_PRIORITY = 2
_NUNCHAKU_INT4_QTYPE_NAME = "nunchaku_int4"
if _NUNCHAKU_INT4_QTYPE_NAME not in _quanto_qtypes:
@@ -28,6 +29,7 @@
_NUNCHAKU_INT4_QTYPE = _quanto_qtypes[_NUNCHAKU_INT4_QTYPE_NAME]
_NUNCHAKU_OPS = None
_NUNCHAKU_FALLBACK_NOTICE = False
+_NUNCHAKU_KERNEL_NOTICE = False
_NUNCHAKU_SPLIT_FIELDS = {
"qweight": 0,
"wscales": 1,
@@ -71,6 +73,19 @@ def _split(state_dict, verboseLevel=1):
return _split
+def split_fused_weights(state_dict, fused_split_map, quantization_map=None, allowed_bases=None, default_dtype=None, verboseLevel=1):
+ from mmgp import offload
+ split_kwargs = get_nunchaku_split_kwargs()
+ return offload.sd_split_linear(
+ state_dict,
+ fused_split_map,
+ verboseLevel=verboseLevel,
+ allowed_bases=allowed_bases,
+ return_split_bases=True,
+ **split_kwargs,
+ )
+
+
def _install_nunchaku_shim(candidate_root):
candidate_pkg = candidate_root / "nunchaku"
if not candidate_pkg.exists():
@@ -135,6 +150,18 @@ def _notify_nunchaku_fallback(reason):
_NUNCHAKU_FALLBACK_NOTICE = True
+def _notify_nunchaku_kernel_status(verboseLevel=1):
+ global _NUNCHAKU_KERNEL_NOTICE
+ if _NUNCHAKU_KERNEL_NOTICE:
+ return
+ ops = _load_nunchaku_ops()
+ if ops:
+ print("[nunchaku_int4] Using Nunchaku kernels.")
+ elif not _NUNCHAKU_FALLBACK_NOTICE:
+ print("[nunchaku_int4] Nunchaku kernels unavailable; using Python fallback.")
+ _NUNCHAKU_KERNEL_NOTICE = True
+
+
def _is_float8_dtype(dtype):
return "float8" in str(dtype).lower() or "f8" in str(dtype).lower()
@@ -317,8 +344,10 @@ def _unpack_nunchaku_w4a4_weight(qweight, out_features, in_features):
k_tiles = in_features // mem_k
packed_i32 = qweight.view(torch.int32)
packed_i32 = packed_i32.view(n_tiles, k_tiles, 1, 8, 8, 4, 2, 2, 1)
- shifts = torch.tensor([0, 4, 8, 12, 16, 20, 24, 28], device=packed_i32.device, dtype=torch.int32)
- vals = (packed_i32.unsqueeze(-1) >> shifts) & 0xF
+ vals = torch.stack(
+ [(packed_i32 >> shift) & 0xF for shift in (0, 4, 8, 12, 16, 20, 24, 28)],
+ dim=-1,
+ )
vals = vals.permute(0, 3, 6, 4, 8, 1, 2, 7, 5, 9).contiguous()
vals = vals.view(out_features, in_features).to(torch.int16)
vals -= (vals >= 8).to(torch.int16) * 16
@@ -346,9 +375,7 @@ def _unpack_int4_from_int32(qweight, out_features, in_features):
q = qweight.view(torch.int32).reshape(out_features, in_features // 8)
q = q.to(torch.int64) & 0xFFFFFFFF
- shifts = torch.arange(0, 32, 4, device=q.device, dtype=torch.int64)
- mask = torch.tensor(0x0F, device=q.device, dtype=torch.int64)
- vals = (q.unsqueeze(-1) >> shifts) & mask
+ vals = torch.stack([(q >> shift) & 0xF for shift in range(0, 32, 4)], dim=-1)
return vals.reshape(out_features, in_features)
@@ -395,6 +422,53 @@ class NunchakuBaseWeightTensor(QTensor):
def __init__(self, qtype, axis):
super().__init__(qtype, axis)
+ def __repr__(self):
+ cls_name = self.__class__.__name__
+ try:
+ shape = tuple(self.shape)
+ except Exception:
+ shape = ">"
+ try:
+ dtype = str(self.dtype).replace("torch.", "")
+ except Exception:
+ dtype = ">"
+ try:
+ device = str(self.device)
+ except Exception:
+ device = ">"
+ qtype = getattr(self, "_qtype", None)
+ qtype_name = getattr(qtype, "name", None) or str(qtype) if qtype is not None else ">"
+ parts = [
+ f"shape={shape}",
+ f"dtype={dtype}",
+ f"device={device}",
+ f"qtype={qtype_name}",
+ ]
+ group_size = getattr(self, "_group_size", None)
+ if group_size is not None:
+ parts.append(f"group_size={group_size}")
+ field_parts = []
+ for name in (
+ "_qweight",
+ "_wscales",
+ "_wzeros",
+ "_smooth_factor",
+ "_proj_down",
+ "_proj_up",
+ ):
+ if not hasattr(self, name):
+ continue
+ value = getattr(self, name)
+ if torch.is_tensor(value):
+ field_parts.append(f"{name[1:]}={tuple(value.shape)}:{value.dtype}")
+ else:
+ field_parts.append(f"{name[1:]}={value}")
+ if field_parts:
+ parts.append("fields={" + ", ".join(field_parts) + "}")
+ return f"{cls_name}(" + ", ".join(parts) + ")"
+
+ __str__ = __repr__
+
def get_quantized_subtensors(self):
raise NotImplementedError
@@ -587,6 +661,7 @@ def _linear_cuda(self, input, bias=None):
out = out[: x.shape[0]]
return out.reshape(*input.shape[:-1], self.shape[0])
+ @torch.compiler.disable()
def linear(self, input, bias=None):
if torch.is_tensor(input) and input.device.type == "cuda":
return self._linear_cuda(input, bias=bias)
@@ -862,6 +937,7 @@ def _linear_cuda(self, input, bias=None):
out.add_(bias.view(view_shape))
return out.reshape(*input.shape[:-1], out.shape[-1])
+ @torch.compiler.disable()
def linear(self, input, bias=None):
if torch.is_tensor(input) and input.device.type == "cuda":
return self._linear_cuda(input, bias=bias)
@@ -1057,6 +1133,7 @@ def __init__(
def set_default_dtype(self, dtype):
self._nunchaku_default_dtype = dtype
+ @torch.compiler.disable()
def forward(self, input: torch.Tensor) -> torch.Tensor:
qweight = self.qweight
if isinstance(qweight, NunchakuBaseWeightTensor):
@@ -1370,6 +1447,7 @@ def convert_to_quanto(state_dict, default_dtype, verboseLevel=1, detection=None)
specs = _collect_nunchaku_specs(state_dict)
if not specs:
return {"state_dict": state_dict, "quant_map": {}}
+ _notify_nunchaku_kernel_status(verboseLevel=verboseLevel)
quant_map = {spec["name"]: {"weights": "nunchaku_int4", "activations": "none"} for spec in specs}
return {"state_dict": state_dict, "quant_map": quant_map}
diff --git a/Wan2GP/shared/qtypes/nvfp4.py b/Wan2GP/shared/qtypes/nvfp4.py
index 62cf33614..59b194c0f 100644
--- a/Wan2GP/shared/qtypes/nvfp4.py
+++ b/Wan2GP/shared/qtypes/nvfp4.py
@@ -1,4 +1,5 @@
import ast
+import os
import torch
from torch.utils import _pytree as pytree
@@ -6,13 +7,31 @@
from optimum.quanto.tensor.qtensor import QTensor
from optimum.quanto.tensor.qtype import qtype as _quanto_qtype, qtypes as _quanto_qtypes
+def _maybe_add_nvfp4_cu13_dll_dir():
+ if os.name != "nt":
+ return
+ try:
+ import nvidia.cu13
+ dll_dir = os.path.join(nvidia.cu13.__path__[0], "bin", "x86_64")
+ if os.path.isdir(dll_dir):
+ os.add_dll_directory(dll_dir)
+ except Exception:
+ pass
+
+try:
+ from comfy_kitchen.backends import cuda as _ck_cuda
+ _ck_cuda_available = getattr(_ck_cuda, "_EXT_AVAILABLE", False)
+except Exception:
+ _ck_cuda = None
+ _ck_cuda_available = False
+
try:
- from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm
- _KERNEL_AVAILABLE = True
+ _maybe_add_nvfp4_cu13_dll_dir()
+ from lightx2v_kernel import gemm as _lx_gemm
+ _lx_gemm_available = True
except Exception:
- scaled_nvfp4_quant = None
- cutlass_scaled_nvfp4_mm = None
- _KERNEL_AVAILABLE = False
+ _lx_gemm = None
+ _lx_gemm_available = False
_NVFP4_QTYPE_NAME = "nvfp4"
if _NVFP4_QTYPE_NAME not in _quanto_qtypes:
@@ -25,14 +44,439 @@
qmax=6.0,
)
_NVFP4_QTYPE = _quanto_qtypes[_NVFP4_QTYPE_NAME]
+HANDLER_PRIORITY = 1
-def _supports_nvfp4_kernel(device):
- if not _KERNEL_AVAILABLE:
+_NVFP4_LAYOUT_LEGACY = "legacy"
+_NVFP4_LAYOUT_TENSORCORE = "tensorcore"
+
+_NVFP4_BACKEND_AUTO = "auto"
+_NVFP4_BACKEND_COMFY = "comfy"
+_NVFP4_BACKEND_LIGHTX2V = "lightx2v"
+
+_NVFP4_KERNEL_LOGGED = False
+_NVFP4_FALLBACK_LOGGED = False
+_NVFP4_LOAD_LOGGED = False
+_NVFP4_KERNEL_AVAILABLE = False
+_NVFP4_KERNEL_CHECKED = False
+_NVFP4_KERNEL_BACKEND = None
+_NVFP4_ACT_SCALE_CACHE = {}
+
+_NVFP4_SPLIT_FIELDS = {
+ "weight": 0,
+ "bias": 0,
+ "weight_scale": 0,
+ "weight_scale_2": 0,
+ "input_scale": 0,
+ "input_global_scale": 0,
+ "alpha": 0,
+ "input_absmax": 0,
+ "weight_global_scale": 0,
+ "output_scale": 0,
+}
+
+_NVFP4_BACKEND = os.environ.get("WGP_NVFP4_BACKEND", _NVFP4_BACKEND_AUTO).strip().lower()
+_NVFP4_BACKEND = _NVFP4_BACKEND_LIGHTX2V
+
+def _normalize_nvfp4_backend(name):
+ if name is None:
+ return _NVFP4_BACKEND_AUTO
+ norm = str(name).strip().lower()
+ if norm in ("", "auto", "default"):
+ return _NVFP4_BACKEND_AUTO
+ if norm in ("comfy", "comfy-kitchen", "comfy_kitchen", "ck"):
+ return _NVFP4_BACKEND_COMFY
+ if norm in ("lightx2v", "lightx2v_kernel", "lightx2v-kernel", "lx"):
+ return _NVFP4_BACKEND_LIGHTX2V
+ if norm in ("off", "none", "fallback", "disable", "disabled"):
+ return "fallback"
+ return norm
+
+
+def _split_or_share_nvfp4_scale(src, *, dim, split_sizes, context):
+ if src is None or not torch.is_tensor(src):
+ return None
+ total = sum(split_sizes)
+ if src.numel() == 1:
+ return [src] * len(split_sizes)
+ if src.dim() > dim and src.size(dim) == total:
+ return torch.split(src, split_sizes, dim=dim)
+ if src.ndim > 1 and src.size(1) == total:
+ return torch.split(src, split_sizes, dim=1)
+ return [src] * len(split_sizes)
+
+
+def split_fused_weights(state_dict, fused_split_map, quantization_map=None, allowed_bases=None, default_dtype=None, verboseLevel=1):
+ from mmgp import offload
+ return offload.sd_split_linear(
+ state_dict,
+ fused_split_map,
+ split_fields=dict(_NVFP4_SPLIT_FIELDS),
+ split_handlers={
+ "weight_scale": _split_or_share_nvfp4_scale,
+ "weight_scale_2": _split_or_share_nvfp4_scale,
+ "input_scale": _split_or_share_nvfp4_scale,
+ "input_global_scale": _split_or_share_nvfp4_scale,
+ "alpha": _split_or_share_nvfp4_scale,
+ "input_absmax": _split_or_share_nvfp4_scale,
+ "weight_global_scale": _split_or_share_nvfp4_scale,
+ "output_scale": _split_or_share_nvfp4_scale,
+ },
+ verboseLevel=verboseLevel,
+ allowed_bases=allowed_bases,
+ return_split_bases=True,
+ )
+
+
+_NVFP4_BACKEND = _normalize_nvfp4_backend(_NVFP4_BACKEND)
+
+
+def _nvfp4_backend_candidates():
+ if _NVFP4_BACKEND == _NVFP4_BACKEND_AUTO:
+ return [_NVFP4_BACKEND_COMFY, _NVFP4_BACKEND_LIGHTX2V]
+ if _NVFP4_BACKEND in (_NVFP4_BACKEND_COMFY, _NVFP4_BACKEND_LIGHTX2V):
+ return [_NVFP4_BACKEND]
+ return []
+
+
+def _nvfp4_backend_label(backend):
+ if backend == _NVFP4_BACKEND_LIGHTX2V:
+ return "lightx2v"
+ if backend == _NVFP4_BACKEND_COMFY:
+ return "comfy-kitchen"
+ return backend
+
+
+def _nvfp4_lightx2v_device_ok(device):
+ force = os.environ.get("WGP_NVFP4_LIGHTX2V_FORCE", "").strip().lower()
+ if force in ("1", "true", "yes", "y"):
+ return True
+ try:
+ props = torch.cuda.get_device_properties(device)
+ except Exception:
return False
+ return props.major >= 12
+
+
+def set_nvfp4_backend(name):
+ global _NVFP4_BACKEND, _NVFP4_KERNEL_CHECKED, _NVFP4_KERNEL_AVAILABLE, _NVFP4_KERNEL_BACKEND
+ global _NVFP4_KERNEL_LOGGED, _NVFP4_LOAD_LOGGED
+ _NVFP4_BACKEND = _normalize_nvfp4_backend(name)
+ _NVFP4_KERNEL_CHECKED = False
+ _NVFP4_KERNEL_AVAILABLE = False
+ _NVFP4_KERNEL_BACKEND = None
+ _NVFP4_KERNEL_LOGGED = False
+ _NVFP4_LOAD_LOGGED = False
+ _init_nvfp4_kernel_support()
+
+
+def _nvfp4_note_kernel():
+ global _NVFP4_KERNEL_LOGGED
+ if not _NVFP4_KERNEL_LOGGED:
+ label = _nvfp4_backend_label(_NVFP4_KERNEL_BACKEND) if _NVFP4_KERNEL_BACKEND else "CUDA"
+ print(f"NVFP4: using {label} kernel")
+ _NVFP4_KERNEL_LOGGED = True
+
+
+def _nvfp4_note_fallback():
+ global _NVFP4_FALLBACK_LOGGED
+ global _NVFP4_KERNEL_LOGGED
+ if not _NVFP4_FALLBACK_LOGGED:
+ if _NVFP4_KERNEL_LOGGED:
+ print("NVFP4: linear fallback needed on some weights")
+ else:
+ print("NVFP4: linear fallback")
+ _NVFP4_FALLBACK_LOGGED = True
+
+def _nvfp4_note_reset():
+ global _NVFP4_FALLBACK_LOGGED
+ global _NVFP4_KERNEL_LOGGED
+ global _NVFP4_LOAD_LOGGED
+ _NVFP4_KERNEL_LOGGED = False
+ _NVFP4_FALLBACK_LOGGED = False
+ _NVFP4_LOAD_LOGGED = False
+
+def _nvfp4_note_load_backend():
+ global _NVFP4_LOAD_LOGGED
+ if _NVFP4_LOAD_LOGGED:
+ return
+ _NVFP4_LOAD_LOGGED = True
+ if _NVFP4_KERNEL_AVAILABLE:
+ label = _nvfp4_backend_label(_NVFP4_KERNEL_BACKEND) if _NVFP4_KERNEL_BACKEND else "unknown"
+ print(f"NVFP4: kernels available ({label}); optimized path will be used when compatible.")
+ else:
+ print("NVFP4: kernels unavailable; using fallback.")
+
+
+def _check_nvfp4_kernel_support(device, backend):
+ if device.type != "cuda":
+ return False
+ if backend == _NVFP4_BACKEND_COMFY:
+ if not _ck_cuda_available:
+ return False
+ if not hasattr(_ck_cuda, "scaled_mm_nvfp4"):
+ return False
+ if not hasattr(_ck_cuda, "quantize_nvfp4"):
+ return False
+ if not (hasattr(torch.ops, "comfy_kitchen") and hasattr(torch.ops.comfy_kitchen, "scaled_mm_nvfp4")):
+ return False
+ major, minor = torch.cuda.get_device_capability(device)
+ return (major, minor) >= (10, 0)
+ if backend == _NVFP4_BACKEND_LIGHTX2V:
+ if not _lx_gemm_available:
+ return False
+ if not _nvfp4_lightx2v_device_ok(device):
+ return False
+ if not (hasattr(torch.ops, "lightx2v_kernel") and hasattr(torch.ops.lightx2v_kernel, "cutlass_scaled_nvfp4_mm_sm120")):
+ return False
+ if not hasattr(torch.ops.lightx2v_kernel, "scaled_nvfp4_quant_sm120"):
+ return False
+ major, minor = torch.cuda.get_device_capability(device)
+ return (major, minor) >= (12, 0)
+ return False
+
+
+def _init_nvfp4_kernel_support():
+ global _NVFP4_KERNEL_AVAILABLE, _NVFP4_KERNEL_CHECKED, _NVFP4_KERNEL_BACKEND
+ if _NVFP4_KERNEL_CHECKED:
+ return
+ _NVFP4_KERNEL_CHECKED = True
+ _NVFP4_KERNEL_AVAILABLE = False
+ _NVFP4_KERNEL_BACKEND = None
+ if not torch.cuda.is_available():
+ return
+ device = torch.device("cuda")
+ for backend in _nvfp4_backend_candidates():
+ try:
+ if _check_nvfp4_kernel_support(device, backend):
+ _NVFP4_KERNEL_AVAILABLE = True
+ _NVFP4_KERNEL_BACKEND = backend
+ break
+ except Exception:
+ continue
+
+
+def _supports_nvfp4_kernel(device):
if device.type != "cuda":
return False
- major, _ = torch.cuda.get_device_capability(device)
- return major >= 12
+ if not _NVFP4_KERNEL_CHECKED:
+ _init_nvfp4_kernel_support()
+ return _NVFP4_KERNEL_AVAILABLE
+
+
+_init_nvfp4_kernel_support()
+
+
+def _nvfp4_layout(weight):
+ return getattr(weight, "_layout", _NVFP4_LAYOUT_LEGACY)
+
+
+def _nvfp4_can_use_kernel(input, weight):
+ if not torch.is_tensor(input):
+ return False
+ if not getattr(weight, "_allow_kernel", True):
+ return False
+ if not _supports_nvfp4_kernel(input.device):
+ return False
+ backend = _NVFP4_KERNEL_BACKEND
+ if backend is None:
+ return False
+ layout = _nvfp4_layout(weight)
+ if backend == _NVFP4_BACKEND_LIGHTX2V:
+ if input.shape[-1] % 32 != 0:
+ return False
+ if weight.size(0) % 32 != 0:
+ return False
+ else:
+ if layout == _NVFP4_LAYOUT_LEGACY:
+ if input.shape[-1] % 64 != 0:
+ return False
+ else:
+ if input.shape[-1] % 16 != 0:
+ return False
+ if weight.size(0) % 8 != 0:
+ return False
+ if weight._data.shape[1] * 2 != input.shape[-1]:
+ return False
+ if weight._block_size != 16:
+ return False
+ if not torch.is_tensor(weight._input_global_scale) or not torch.is_tensor(weight._alpha):
+ return False
+ if getattr(weight._input_global_scale, "is_meta", False):
+ return False
+ try:
+ if not torch.isfinite(weight._input_global_scale).all():
+ return False
+ except Exception:
+ return False
+ return True
+
+
+def _nvfp4_get_act_scale(device):
+ act_scale = _NVFP4_ACT_SCALE_CACHE.get(device)
+ if act_scale is None:
+ act_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
+ _NVFP4_ACT_SCALE_CACHE[device] = act_scale
+ return act_scale
+
+
+def _nvfp4_swap_nibbles(tensor):
+ return ((tensor & 0x0F) << 4) | ((tensor & 0xF0) >> 4)
+
+
+def _nvfp4_linear_cuda_comfy(input, weight, bias=None):
+ _nvfp4_note_kernel()
+ x2d = input.reshape(-1, input.shape[-1])
+ if not x2d.is_floating_point():
+ x2d = x2d.to(torch.float16)
+ orig_dtype = x2d.dtype
+ if orig_dtype not in (torch.float16, torch.bfloat16):
+ x2d = x2d.to(torch.float16)
+ out_dtype = torch.float16
+ else:
+ out_dtype = orig_dtype
+ if not x2d.is_contiguous():
+ x2d = x2d.contiguous()
+ weight_fp4 = weight._data
+ weight_scale = weight._scale
+ input_scale = weight._input_global_scale
+ alpha = weight._alpha
+ layout = _nvfp4_layout(weight)
+ device = x2d.device
+ if weight_fp4.device != device:
+ weight_fp4 = weight_fp4.to(device)
+ if weight_scale.device != device:
+ weight_scale = weight_scale.to(device)
+ if input_scale.device != device:
+ input_scale = input_scale.to(device)
+ if alpha.device != device:
+ alpha = alpha.to(device)
+ if bias is not None and torch.is_tensor(bias) and bias.dtype != out_dtype:
+ bias = bias.to(out_dtype)
+ orig_rows = x2d.shape[0]
+ pad_16x = (orig_rows % 16) != 0
+ if layout == _NVFP4_LAYOUT_TENSORCORE:
+ input_scale = input_scale.to(torch.float32)
+ tensor_scale = alpha.to(torch.float32)
+ qx, qx_scale = _ck_cuda.quantize_nvfp4(x2d, input_scale, 0.0, pad_16x)
+ out = _ck_cuda.scaled_mm_nvfp4(
+ qx,
+ weight_fp4,
+ tensor_scale_a=input_scale,
+ tensor_scale_b=tensor_scale,
+ block_scale_a=qx_scale,
+ block_scale_b=weight_scale,
+ bias=bias,
+ out_dtype=out_dtype,
+ )
+ else:
+ alpha = alpha * input_scale
+ if alpha.dtype != torch.float32:
+ alpha = alpha.to(torch.float32)
+ act_scale = _nvfp4_get_act_scale(device)
+ qx, qx_scale = _ck_cuda.quantize_nvfp4(x2d, act_scale, 0.0, pad_16x)
+ weight_fp4 = _nvfp4_swap_nibbles(weight_fp4)
+ out = _ck_cuda.scaled_mm_nvfp4(
+ qx,
+ weight_fp4,
+ act_scale,
+ input_scale,
+ qx_scale,
+ weight_scale,
+ bias=bias,
+ out_dtype=out_dtype,
+ alpha=alpha,
+ )
+ if pad_16x:
+ out = out[:orig_rows]
+ if out.dtype != orig_dtype:
+ out = out.to(orig_dtype)
+ return out.reshape(*input.shape[:-1], weight.size(0))
+
+
+def _nvfp4_linear_cuda_lightx2v(input, weight, bias=None):
+ _nvfp4_note_kernel()
+ x2d = input.reshape(-1, input.shape[-1])
+ if not x2d.is_floating_point():
+ x2d = x2d.to(torch.float16)
+ orig_dtype = x2d.dtype
+ if orig_dtype not in (torch.float16, torch.bfloat16):
+ x2d = x2d.to(torch.float16)
+ out_dtype = torch.float16
+ else:
+ out_dtype = orig_dtype
+ if not x2d.is_contiguous():
+ x2d = x2d.contiguous()
+ weight_fp4 = weight._data
+ weight_scale = weight._scale
+ input_scale = weight._input_global_scale
+ alpha = weight._alpha
+ layout = _nvfp4_layout(weight)
+ device = x2d.device
+ if weight_fp4.device != device:
+ weight_fp4 = weight_fp4.to(device)
+ if weight_scale.device != device:
+ weight_scale = weight_scale.to(device)
+ if not weight_fp4.is_contiguous():
+ weight_fp4 = weight_fp4.contiguous()
+ if not weight_scale.is_contiguous():
+ weight_scale = weight_scale.contiguous()
+ if input_scale.device != device:
+ input_scale = input_scale.to(device)
+ if alpha.device != device:
+ alpha = alpha.to(device)
+ if input_scale.dtype != torch.float32:
+ input_scale = input_scale.to(torch.float32)
+ if alpha.dtype != torch.float32:
+ alpha = alpha.to(torch.float32)
+ if bias is not None and torch.is_tensor(bias):
+ if bias.dtype != torch.bfloat16:
+ bias = bias.to(torch.bfloat16)
+ if not bias.is_contiguous():
+ bias = bias.contiguous()
+ if layout == _NVFP4_LAYOUT_TENSORCORE:
+ quant_scale = torch.reciprocal(torch.clamp(input_scale, min=1e-8))
+ alpha = alpha * input_scale
+ else:
+ quant_scale = input_scale
+ qx, qx_scale = _lx_gemm.scaled_nvfp4_quant(x2d, quant_scale)
+ if layout == _NVFP4_LAYOUT_TENSORCORE:
+ qx = _nvfp4_swap_nibbles(qx)
+ if not qx.is_contiguous():
+ qx = qx.contiguous()
+ if not qx_scale.is_contiguous():
+ qx_scale = qx_scale.contiguous()
+ out = _lx_gemm.cutlass_scaled_nvfp4_mm(
+ qx,
+ weight_fp4,
+ qx_scale,
+ weight_scale,
+ alpha=alpha,
+ bias=bias,
+ )
+ if out.dtype != orig_dtype:
+ out = out.to(orig_dtype)
+ return out.reshape(*input.shape[:-1], weight.size(0))
+
+
+def _nvfp4_linear_cuda(input, weight, bias=None):
+ if _NVFP4_KERNEL_BACKEND == _NVFP4_BACKEND_LIGHTX2V:
+ return _nvfp4_linear_cuda_lightx2v(input, weight, bias=bias)
+ return _nvfp4_linear_cuda_comfy(input, weight, bias=bias)
+
+
+@torch.compiler.disable()
+def _nvfp4_linear(input, weight, bias=None, op=None):
+ if _nvfp4_can_use_kernel(input, weight):
+ return _nvfp4_linear_cuda(input, weight, bias=bias)
+ _nvfp4_note_fallback()
+ dtype = input.dtype if torch.is_tensor(input) else weight.dtype
+ device = input.device if torch.is_tensor(input) else weight.device
+ w = weight.dequantize(dtype=dtype, device=device)
+ if bias is not None and torch.is_tensor(bias) and bias.dtype != dtype:
+ bias = bias.to(dtype)
+ if op is not None:
+ return op(input, w, bias)
+ return torch.nn.functional.linear(input, w, bias)
def _is_float8_dtype(dtype):
@@ -96,6 +540,7 @@ def _dequantize_nvfp4_weight(
dtype,
device,
block_size=16,
+ layout=_NVFP4_LAYOUT_LEGACY,
):
if weight_u8.device != device:
weight_u8 = weight_u8.to(device)
@@ -104,10 +549,18 @@ def _dequantize_nvfp4_weight(
alpha = alpha.to(device)
if input_global_scale.device != device:
input_global_scale = input_global_scale.to(device)
+ if layout == _NVFP4_LAYOUT_TENSORCORE and device.type == "cuda" and _ck_cuda_available:
+ try:
+ return _ck_cuda.dequantize_nvfp4(weight_u8, alpha.to(torch.float32), scale, output_type=dtype)
+ except Exception:
+ pass
m, k_bytes = weight_u8.shape
byte_lut = _get_fp4_byte_lut(device, dtype)
- idx = weight_u8.to(torch.int32)
+ if layout == _NVFP4_LAYOUT_TENSORCORE:
+ idx = _nvfp4_swap_nibbles(weight_u8).to(torch.int32)
+ else:
+ idx = weight_u8.to(torch.int32)
out = byte_lut[idx].reshape(m, k_bytes * 2)
scale = _deswizzle_nvfp4_scale(scale, out.shape[1], block_size=block_size, dtype=dtype)
@@ -115,7 +568,10 @@ def _dequantize_nvfp4_weight(
out.mul_(scale.unsqueeze(-1))
out = out.view(out.shape[0], -1)
- scale_factor = alpha.to(dtype) * input_global_scale.to(dtype)
+ if layout == _NVFP4_LAYOUT_TENSORCORE:
+ scale_factor = alpha.to(dtype)
+ else:
+ scale_factor = alpha.to(dtype) * input_global_scale.to(dtype)
out.mul_(scale_factor)
return out
@@ -129,20 +585,50 @@ def _collect_nvfp4_specs(state_dict):
continue
base = key[:-7]
scale_key = base + ".weight_scale"
- input_global_key = base + ".input_global_scale"
- alpha_key = base + ".alpha"
- if scale_key not in state_dict or input_global_key not in state_dict or alpha_key not in state_dict:
+ if scale_key not in state_dict:
continue
if not _is_float8_dtype(state_dict[scale_key].dtype):
continue
+
+ weight_scale_2_key = base + ".weight_scale_2"
+ input_scale_key = base + ".input_scale"
+ if weight_scale_2_key in state_dict:
+ specs.append(
+ {
+ "name": base,
+ "weight": tensor,
+ "weight_scale": state_dict[scale_key],
+ "weight_scale_2": state_dict[weight_scale_2_key],
+ "input_scale": state_dict.get(input_scale_key, None),
+ "bias": state_dict.get(base + ".bias", None),
+ "layout": _NVFP4_LAYOUT_TENSORCORE,
+ }
+ )
+ continue
+
+ input_global_key = base + ".input_global_scale"
+ alpha_key = base + ".alpha"
+ input_absmax_key = base + ".input_absmax"
+ weight_global_scale_key = base + ".weight_global_scale"
+ if input_global_key not in state_dict or alpha_key not in state_dict:
+ if input_absmax_key not in state_dict or weight_global_scale_key not in state_dict:
+ continue
+ input_absmax = state_dict[input_absmax_key]
+ weight_global_scale = state_dict[weight_global_scale_key]
+ input_global_scale = (2688.0 / input_absmax).to(torch.float32)
+ alpha = 1.0 / (input_global_scale * weight_global_scale.to(torch.float32))
+ else:
+ input_global_scale = state_dict[input_global_key]
+ alpha = state_dict[alpha_key]
specs.append(
{
"name": base,
"weight": tensor,
"weight_scale": state_dict[scale_key],
- "input_global_scale": state_dict[input_global_key],
- "alpha": state_dict[alpha_key],
+ "input_global_scale": input_global_scale,
+ "alpha": alpha,
"bias": state_dict.get(base + ".bias", None),
+ "layout": _NVFP4_LAYOUT_LEGACY,
}
)
return specs
@@ -162,7 +648,12 @@ def convert_nvfp4_to_quanto(state_dict, default_dtype=None, verboseLevel=1):
specs = _collect_nvfp4_specs(state_dict)
if not specs:
return {"state_dict": state_dict, "quant_map": {}}
- quant_map = {spec["name"]: {"weights": "nvfp4", "activations": "none"} for spec in specs}
+ _nvfp4_note_load_backend()
+ quant_map = {}
+ for spec in specs:
+ qcfg = {"weights": "nvfp4", "activations": "none"}
+ quant_map[spec["name"]] = qcfg
+ quant_map[spec["name"] + ".weight"] = qcfg
return {"state_dict": state_dict, "quant_map": quant_map}
@@ -175,6 +666,7 @@ def detect(state_dict, verboseLevel=1):
def convert_to_quanto(state_dict, default_dtype, verboseLevel=1, detection=None):
if detection is not None and not detection.get("matched", False):
return {"state_dict": state_dict, "quant_map": {}}
+ _nvfp4_note_reset()
return convert_nvfp4_to_quanto(state_dict, default_dtype=default_dtype, verboseLevel=verboseLevel)
@@ -192,14 +684,32 @@ class NVFP4WeightTensor(QTensor):
def create(
weight_u8,
weight_scale,
- input_global_scale,
- alpha,
size,
stride,
dtype,
+ input_global_scale=None,
+ alpha=None,
+ input_scale=None,
+ weight_scale_2=None,
device=None,
requires_grad=False,
+ layout=_NVFP4_LAYOUT_LEGACY,
+ allow_kernel=True,
):
+ if input_global_scale is None and input_scale is not None:
+ input_global_scale = input_scale
+ if alpha is None and weight_scale_2 is not None:
+ alpha = weight_scale_2
+ if layout == _NVFP4_LAYOUT_LEGACY and (weight_scale_2 is not None or input_scale is not None):
+ layout = _NVFP4_LAYOUT_TENSORCORE
+ if input_global_scale is None or alpha is None:
+ raise ValueError("NVFP4WeightTensor.create requires input_global_scale/alpha or input_scale/weight_scale_2")
+ if torch.is_tensor(input_global_scale):
+ try:
+ if not torch.isfinite(input_global_scale).all():
+ allow_kernel = False
+ except Exception:
+ allow_kernel = False
device = weight_u8.device if device is None else device
if weight_u8.device != device:
weight_u8 = weight_u8.to(device)
@@ -218,8 +728,10 @@ def create(
weight_scale=weight_scale,
input_global_scale=input_global_scale,
alpha=alpha,
+ allow_kernel=allow_kernel,
dtype=dtype,
requires_grad=requires_grad,
+ layout=layout,
)
@staticmethod
@@ -234,7 +746,9 @@ def __new__(
input_global_scale,
alpha,
dtype,
+ allow_kernel=True,
requires_grad=False,
+ layout=_NVFP4_LAYOUT_LEGACY,
):
return torch.Tensor._make_wrapper_subclass(
cls,
@@ -257,6 +771,8 @@ def __init__(
alpha,
dtype,
requires_grad=False,
+ layout=_NVFP4_LAYOUT_LEGACY,
+ allow_kernel=True,
):
super().__init__(qtype, axis)
self._data = weight_u8
@@ -264,6 +780,8 @@ def __init__(
self._input_global_scale = input_global_scale
self._alpha = alpha
self._block_size = 16
+ self._layout = layout
+ self._allow_kernel = allow_kernel
def dequantize(self, dtype=None, device=None):
if dtype is None:
@@ -278,9 +796,17 @@ def dequantize(self, dtype=None, device=None):
dtype=dtype,
device=device,
block_size=self._block_size,
+ layout=self._layout,
)
def get_quantized_subtensors(self):
+ if self._layout == _NVFP4_LAYOUT_TENSORCORE:
+ return [
+ ("weight_u8", self._data),
+ ("weight_scale", self._scale),
+ ("weight_scale_2", self._alpha),
+ ("input_scale", self._input_global_scale),
+ ]
return [
("weight_u8", self._data),
("weight_scale", self._scale),
@@ -298,9 +824,13 @@ def set_quantized_subtensors(self, sub_tensors):
self._data = data
if "weight_scale" in sub_map and sub_map["weight_scale"] is not None:
self._scale = sub_map["weight_scale"]
- if "input_global_scale" in sub_map and sub_map["input_global_scale"] is not None:
+ if "input_scale" in sub_map and sub_map["input_scale"] is not None:
+ self._input_global_scale = sub_map["input_scale"]
+ elif "input_global_scale" in sub_map and sub_map["input_global_scale"] is not None:
self._input_global_scale = sub_map["input_global_scale"]
- if "alpha" in sub_map and sub_map["alpha"] is not None:
+ if "weight_scale_2" in sub_map and sub_map["weight_scale_2"] is not None:
+ self._alpha = sub_map["weight_scale_2"]
+ elif "alpha" in sub_map and sub_map["alpha"] is not None:
self._alpha = sub_map["alpha"]
def __tensor_flatten__(self):
@@ -311,6 +841,8 @@ def __tensor_flatten__(self):
"size": str(list(self.size())),
"stride": str(list(self.stride())),
"dtype": str(self.dtype),
+ "layout": self._layout,
+ "allow_kernel": "1" if self._allow_kernel else "0",
}
return inner_tensors, meta
@@ -326,6 +858,8 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
dtype = getattr(torch, dtype_name, torch.float16)
else:
dtype = getattr(torch, dtype_str, torch.float16)
+ layout = meta.get("layout", _NVFP4_LAYOUT_LEGACY)
+ allow_kernel = str(meta.get("allow_kernel", "1")).strip().lower() not in ("0", "false", "no")
return NVFP4WeightTensor(
qtype=qtype,
axis=axis,
@@ -335,7 +869,9 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
weight_scale=inner_tensors["_scale"],
input_global_scale=inner_tensors["_input_global_scale"],
alpha=inner_tensors["_alpha"],
+ allow_kernel=allow_kernel,
dtype=dtype,
+ layout=layout,
)
@classmethod
@@ -346,31 +882,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
weight = args[1] if len(args) > 1 else kwargs.get("weight", None)
bias = args[2] if len(args) > 2 else kwargs.get("bias", None)
if isinstance(weight, NVFP4WeightTensor):
- if torch.is_tensor(input) and _supports_nvfp4_kernel(input.device):
- x2d = input.reshape(-1, input.shape[-1])
- if not x2d.is_floating_point():
- x2d = x2d.to(torch.float16)
- orig_dtype = x2d.dtype
- input_quant, input_scale = scaled_nvfp4_quant(x2d, weight._input_global_scale)
- if bias is not None and torch.is_tensor(bias) and bias.dtype != orig_dtype:
- bias = bias.to(orig_dtype)
- out = cutlass_scaled_nvfp4_mm(
- input_quant,
- weight._data,
- input_scale,
- weight._scale,
- alpha=weight._alpha,
- bias=bias,
- )
- if out.dtype != orig_dtype:
- out = out.to(orig_dtype)
- return out.reshape(*input.shape[:-1], weight.size(0))
- dtype = input.dtype if torch.is_tensor(input) else weight.dtype
- device = input.device if torch.is_tensor(input) else weight.device
- w = weight.dequantize(dtype=dtype, device=device)
- if bias is not None and torch.is_tensor(bias) and bias.dtype != dtype:
- bias = bias.to(dtype)
- return torch.nn.functional.linear(input, w, bias)
+ return _nvfp4_linear(input, weight, bias=bias)
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
@@ -382,31 +894,7 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(weight, NVFP4WeightTensor):
- if torch.is_tensor(input) and _supports_nvfp4_kernel(input.device):
- x2d = input.reshape(-1, input.shape[-1])
- if not x2d.is_floating_point():
- x2d = x2d.to(torch.float16)
- orig_dtype = x2d.dtype
- input_quant, input_scale = scaled_nvfp4_quant(x2d, weight._input_global_scale)
- if bias is not None and torch.is_tensor(bias) and bias.dtype != orig_dtype:
- bias = bias.to(orig_dtype)
- out = cutlass_scaled_nvfp4_mm(
- input_quant,
- weight._data,
- input_scale,
- weight._scale,
- alpha=weight._alpha,
- bias=bias,
- )
- if out.dtype != orig_dtype:
- out = out.to(orig_dtype)
- return out.reshape(*input.shape[:-1], weight.size(0))
- dtype = input.dtype if torch.is_tensor(input) else weight.dtype
- device = input.device if torch.is_tensor(input) else weight.device
- w = weight.dequantize(dtype=dtype, device=device)
- if bias is not None and torch.is_tensor(bias) and bias.dtype != dtype:
- bias = bias.to(dtype)
- return op(input, w, bias)
+ return _nvfp4_linear(input, weight, bias=bias, op=op)
if op is torch.ops.aten.detach:
t = args[0]
return NVFP4WeightTensor.create(
@@ -414,11 +902,13 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
weight_scale=op(t._scale),
input_global_scale=op(t._input_global_scale),
alpha=op(t._alpha),
+ allow_kernel=getattr(t, "_allow_kernel", True),
size=t.size(),
stride=t.stride(),
dtype=t.dtype,
device=t.device,
requires_grad=t.requires_grad,
+ layout=t._layout,
)
if op in (torch.ops.aten._to_copy, torch.ops.aten.to):
t = args[0]
@@ -435,11 +925,13 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
weight_scale=out_scale,
input_global_scale=out_igs,
alpha=out_alpha,
+ allow_kernel=getattr(t, "_allow_kernel", True),
size=t.size(),
stride=t.stride(),
dtype=t.dtype,
device=device,
requires_grad=t.requires_grad,
+ layout=t._layout,
)
return _nvfp4_qfallback(op, *args, **(kwargs or {}))
@@ -519,16 +1011,22 @@ def _load_from_state_dict(
weight_key = prefix + "weight"
scale_key = prefix + "weight_scale"
+ scale2_key = prefix + "weight_scale_2"
igs_key = prefix + "input_global_scale"
alpha_key = prefix + "alpha"
+ input_absmax_key = prefix + "input_absmax"
+ weight_global_scale_key = prefix + "weight_global_scale"
bias_key = prefix + "bias"
input_scale_key = prefix + "input_scale"
output_scale_key = prefix + "output_scale"
weight_u8 = state_dict.pop(weight_key, None)
weight_scale = state_dict.pop(scale_key, None)
+ weight_scale_2 = state_dict.pop(scale2_key, None)
input_global_scale = state_dict.pop(igs_key, None)
alpha = state_dict.pop(alpha_key, None)
+ input_absmax = state_dict.pop(input_absmax_key, None)
+ weight_global_scale = state_dict.pop(weight_global_scale_key, None)
bias = state_dict.pop(bias_key, None)
input_scale = state_dict.pop(input_scale_key, None)
output_scale = state_dict.pop(output_scale_key, None)
@@ -537,25 +1035,68 @@ def _load_from_state_dict(
missing_keys.append(weight_key)
if weight_scale is None:
missing_keys.append(scale_key)
- if input_global_scale is None:
- missing_keys.append(igs_key)
- if alpha is None:
- missing_keys.append(alpha_key)
+ layout = _NVFP4_LAYOUT_LEGACY
+ allow_kernel = True
+ if weight_scale_2 is not None or input_scale is not None:
+ layout = _NVFP4_LAYOUT_TENSORCORE
+ if weight_scale_2 is None:
+ missing_keys.append(scale2_key)
+ if input_scale is None:
+ allow_kernel = False
+ if torch.is_tensor(weight_scale_2):
+ input_scale = torch.full(
+ (),
+ float("nan"),
+ dtype=weight_scale_2.dtype,
+ device=weight_scale_2.device,
+ )
+ elif torch.is_tensor(weight_u8):
+ input_scale = torch.full((), float("nan"), dtype=torch.float32, device=weight_u8.device)
+ else:
+ input_scale = torch.tensor(float("nan"), dtype=torch.float32)
+ else:
+ if input_global_scale is None or alpha is None:
+ if input_absmax is not None and weight_global_scale is not None:
+ input_global_scale = (2688.0 / input_absmax).to(torch.float32)
+ alpha = 1.0 / (input_global_scale * weight_global_scale.to(torch.float32))
+ else:
+ if input_global_scale is None:
+ missing_keys.append(igs_key)
+ if alpha is None:
+ missing_keys.append(alpha_key)
target_dtype = self._nvfp4_default_dtype or self.weight.dtype
- if weight_u8 is not None and weight_scale is not None and input_global_scale is not None and alpha is not None:
- nvfp4_weight = NVFP4WeightTensor.create(
- weight_u8=weight_u8,
- weight_scale=weight_scale,
- input_global_scale=input_global_scale,
- alpha=alpha,
- size=self.weight.size(),
- stride=self.weight.stride(),
- dtype=target_dtype,
- device=weight_u8.device,
- requires_grad=False,
- )
- self.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False)
+ if layout == _NVFP4_LAYOUT_TENSORCORE:
+ if weight_u8 is not None and weight_scale is not None and weight_scale_2 is not None and input_scale is not None:
+ nvfp4_weight = NVFP4WeightTensor.create(
+ weight_u8=weight_u8,
+ weight_scale=weight_scale,
+ input_global_scale=input_scale,
+ alpha=weight_scale_2,
+ allow_kernel=allow_kernel,
+ size=self.weight.size(),
+ stride=self.weight.stride(),
+ dtype=target_dtype,
+ device=weight_u8.device,
+ requires_grad=False,
+ layout=layout,
+ )
+ self.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False)
+ else:
+ if weight_u8 is not None and weight_scale is not None and input_global_scale is not None and alpha is not None:
+ nvfp4_weight = NVFP4WeightTensor.create(
+ weight_u8=weight_u8,
+ weight_scale=weight_scale,
+ input_global_scale=input_global_scale,
+ alpha=alpha,
+ size=self.weight.size(),
+ stride=self.weight.stride(),
+ dtype=target_dtype,
+ device=weight_u8.device,
+ requires_grad=False,
+ layout=layout,
+ )
+ self.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False)
if bias is not None:
if target_dtype is not None and bias.dtype != target_dtype:
@@ -586,3 +1127,122 @@ def _load_from_state_dict(
self.output_scale = torch.ones((), dtype=scale_dtype, device=scale_device)
return
+
+
+def validate_nvfp4_kernel(
+    state_dict=None,
+    checkpoint_path=None,
+    device=None,
+    max_layers=4,
+    seed=0,
+    batch_size=2,
+    dtype=torch.bfloat16,
+    verbose=True,
+):
+    """Compare kernel vs fallback outputs for a few NVFP4 layers.
+
+    Either *state_dict* or *checkpoint_path* must be given.  The smallest
+    ``max_layers`` NVFP4 layers are rebuilt as NVFP4WeightTensor, fed a random
+    batch, and the CUDA kernel output is diffed against the dequantized
+    reference linear.  Returns ``{"ok": bool, ...}`` with per-layer results.
+    """
+    if state_dict is None:
+        if checkpoint_path is None:
+            raise ValueError("state_dict or checkpoint_path is required")
+        from mmgp import safetensors2
+
+        # Load the whole checkpoint on CPU; tensors are moved lazily below.
+        state_dict = {}
+        with safetensors2.safe_open(checkpoint_path, framework="pt", device="cpu", writable_tensors=False) as f:
+            for key in f.keys():
+                state_dict[key] = f.get_tensor(key)
+
+    specs = _collect_nvfp4_specs(state_dict)
+    if not specs:
+        return {"ok": False, "reason": "no nvfp4 weights found"}
+
+    # Validate the smallest layers first to keep the check fast.
+    candidates = sorted(specs, key=lambda spec: spec["weight"].numel())
+    if isinstance(max_layers, int) and max_layers > 0:
+        candidates = candidates[:max_layers]
+
+    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+    results = []
+    with torch.no_grad():
+        for spec in candidates:
+            weight = spec["weight"]
+            layout = spec.get("layout", _NVFP4_LAYOUT_LEGACY)
+            # Packed u8 storage holds two 4-bit values per byte.
+            in_features = weight.shape[1] * 2
+            bias = spec.get("bias")
+
+            # The two layouts store their per-tensor scales under different keys.
+            if layout == _NVFP4_LAYOUT_TENSORCORE:
+                input_scale = spec.get("input_scale")
+                tensor_scale = spec.get("weight_scale_2")
+            else:
+                input_scale = spec.get("input_global_scale")
+                tensor_scale = spec.get("alpha")
+
+            if input_scale is None or tensor_scale is None:
+                results.append({"name": spec["name"], "layout": layout, "kernel": False, "reason": "missing scales"})
+                continue
+
+            nvfp4_weight = NVFP4WeightTensor.create(
+                weight_u8=weight,
+                weight_scale=spec["weight_scale"],
+                input_global_scale=input_scale,
+                alpha=tensor_scale,
+                size=(weight.shape[0], in_features),
+                stride=(in_features, 1),
+                dtype=dtype,
+                device=device,
+                requires_grad=False,
+                layout=layout,
+            )
+
+            x = torch.randn(batch_size, in_features, device=device, dtype=dtype)
+            if bias is not None:
+                bias = bias.to(device=device, dtype=dtype)
+
+            kernel_ok = _nvfp4_can_use_kernel(x, nvfp4_weight)
+            y_kernel = _nvfp4_linear_cuda(x, nvfp4_weight, bias=bias) if kernel_ok else None
+            x_ref = x
+            # For the tensorcore layout, quantize/dequantize the input so the
+            # reference sees the same activation precision as the kernel.
+            if (
+                layout == _NVFP4_LAYOUT_TENSORCORE
+                and _ck_cuda_available
+                and device.type == "cuda"
+                and torch.is_tensor(input_scale)
+            ):
+                input_scale_fp32 = input_scale.to(device=device, dtype=torch.float32)
+                pad_16x = (x.shape[0] % 16) != 0
+                qx, qx_scale = _ck_cuda.quantize_nvfp4(x, input_scale_fp32, 0.0, pad_16x)
+                x_ref = _ck_cuda.dequantize_nvfp4(qx, input_scale_fp32, qx_scale, output_type=dtype)
+                if pad_16x:
+                    # Drop the rows added for 16-row alignment.
+                    x_ref = x_ref[: x.shape[0]]
+            y_ref = torch.nn.functional.linear(
+                x_ref,
+                nvfp4_weight.dequantize(dtype=dtype, device=device),
+                bias,
+            )
+
+            if y_kernel is None:
+                results.append({"name": spec["name"], "layout": layout, "kernel": False})
+                continue
+
+            diff = (y_kernel - y_ref).float()
+            results.append(
+                {
+                    "name": spec["name"],
+                    "layout": layout,
+                    "kernel": True,
+                    "max_abs": diff.abs().max().item(),
+                    "mean_abs": diff.abs().mean().item(),
+                }
+            )
+
+    if verbose:
+        print("NVFP4 kernel validation:")
+        for entry in results:
+            if not entry.get("kernel"):
+                print(f"  {entry['name']}: kernel skipped ({entry.get('reason', 'incompatible')})")
+                continue
+            print(
+                f"  {entry['name']}: max_abs={entry['max_abs']:.6f} mean_abs={entry['mean_abs']:.6f}"
+            )
+
+    return {"ok": True, "results": results}
diff --git a/Wan2GP/shared/qtypes/scaled_fp8.py b/Wan2GP/shared/qtypes/scaled_fp8.py
new file mode 100644
index 000000000..f801bef67
--- /dev/null
+++ b/Wan2GP/shared/qtypes/scaled_fp8.py
@@ -0,0 +1,735 @@
+import ast
+import os
+
+import torch
+from torch.utils import _pytree as pytree
+
+from optimum.quanto import QModuleMixin
+from optimum.quanto.tensor.qtensor import QTensor
+from optimum.quanto.tensor.qtype import qtype as _quanto_qtype, qtypes as _quanto_qtypes
+
+
+# Identifier and ordering priority used by the quantization handler registry.
+HANDLER_NAME = "fp8"
+HANDLER_PRIORITY = 10
+
+# Names under which the scaled-FP8 qtypes are registered with optimum.quanto.
+_SCALED_FP8_E4M3_QTYPE_NAME = "scaled_float8_e4m3fn"
+_SCALED_FP8_E5M2_QTYPE_NAME = "scaled_float8_e5m2"
+
+
+def _register_fp8_qtype(name, dtype):
+    """Register (once) and return a quanto qtype for a float8 torch *dtype*."""
+    if name not in _quanto_qtypes:
+        _quanto_qtypes[name] = _quanto_qtype(
+            name,
+            is_floating_point=True,
+            bits=8,
+            dtype=dtype,
+            # qmin/qmax come straight from the torch finfo for this dtype.
+            qmin=float(torch.finfo(dtype).min),
+            qmax=float(torch.finfo(dtype).max),
+        )
+    return _quanto_qtypes[name]
+
+
+_SCALED_FP8_QTYPE_E4M3 = _register_fp8_qtype(_SCALED_FP8_E4M3_QTYPE_NAME, torch.float8_e4m3fn)
+_SCALED_FP8_QTYPE_E5M2 = _register_fp8_qtype(_SCALED_FP8_E5M2_QTYPE_NAME, torch.float8_e5m2)
+
+# torch float8 dtype -> registered quanto qtype.
+_SCALED_FP8_QTYPE_BY_DTYPE = {
+    torch.float8_e4m3fn: _SCALED_FP8_QTYPE_E4M3,
+    torch.float8_e5m2: _SCALED_FP8_QTYPE_E5M2,
+}
+_SCALED_FP8_QTYPES = set(_SCALED_FP8_QTYPE_BY_DTYPE.values())
+
+# Representable (min, max) per float8 dtype, used when clamping activations.
+_FP8_RANGE = {
+    torch.float8_e4m3fn: (float(torch.finfo(torch.float8_e4m3fn).min), float(torch.finfo(torch.float8_e4m3fn).max)),
+    torch.float8_e5m2: (float(torch.finfo(torch.float8_e5m2).min), float(torch.finfo(torch.float8_e5m2).max)),
+}
+
+# Whether torch._scaled_mm works for each dtype on this machine; filled in by
+# _init_scaled_mm_support() at import time.
+_FP8_MM_SUPPORT = {
+    torch.float8_e4m3fn: False,
+    torch.float8_e5m2: False,
+}
+_FP8_MM_PROBED = False
+# Working (non-fp8) dtype requested by the loader; see _normalize_default_dtype.
+_SCALED_FP8_DEFAULT_DTYPE = None
+
+# Field name -> split dim for splitting fused linear weights (see
+# split_fused_weights); share-fields are copied to every split.
+_SCALED_FP8_SPLIT_FIELDS = {
+    "weight": 0,
+    "bias": 0,
+    "scale_weight": 0,
+    "weight_scale": 0,
+}
+_SCALED_FP8_SHARE_FIELDS = ("input_scale", "output_scale")
+
+def _is_float8_dtype(dtype):
+    """True when *dtype* is one of the supported float8 storage dtypes."""
+    return dtype in _SCALED_FP8_QTYPE_BY_DTYPE
+
+
+def _get_fp8_qtype(dtype):
+    """Return the registered quanto qtype for *dtype*, or None if not float8."""
+    return _SCALED_FP8_QTYPE_BY_DTYPE.get(dtype, None)
+
+
+def _set_default_dtype_from_loader(dtype):
+    """Remember the loader's working dtype; fp8 dtypes are ignored on purpose."""
+    global _SCALED_FP8_DEFAULT_DTYPE
+    if dtype is None:
+        return
+    # A float8 dtype is a storage format, not a valid compute dtype.
+    if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return
+    _SCALED_FP8_DEFAULT_DTYPE = dtype
+
+
+def _normalize_default_dtype(dtype):
+    """Map None / float8 dtypes to a usable compute dtype (default bfloat16)."""
+    if dtype is None:
+        return _SCALED_FP8_DEFAULT_DTYPE or torch.bfloat16
+    if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return _SCALED_FP8_DEFAULT_DTYPE or torch.bfloat16
+    return dtype
+
+
+def _split_fp8_scale(src, *, dim, split_sizes, context):
+    """Split a scale tensor alongside a fused weight split.
+
+    Scalar scales are replicated to every split; otherwise split along the
+    axis whose length equals sum(split_sizes).  Falls back to replication when
+    no axis matches.  *context* is accepted for handler-signature parity.
+    """
+    if src is None or not torch.is_tensor(src):
+        return None
+    total = sum(split_sizes)
+    if src.numel() == 1:
+        # Per-tensor scale: every split shares it.
+        return [src] * len(split_sizes)
+    if src.dim() > dim and src.size(dim) == total:
+        return torch.split(src, split_sizes, dim=dim)
+    if src.ndim > 1 and src.size(1) == total:
+        return torch.split(src, split_sizes, dim=1)
+    return [src] * len(split_sizes)
+
+
+def split_fused_weights(state_dict, fused_split_map, quantization_map=None, allowed_bases=None, default_dtype=None, verboseLevel=1):
+    """Split fused linear weights (and their scales) via offload.sd_split_linear.
+
+    NOTE(review): *quantization_map* and *default_dtype* are accepted but not
+    forwarded here — presumably kept for handler-signature parity; confirm.
+    """
+    from mmgp import offload
+    return offload.sd_split_linear(
+        state_dict,
+        fused_split_map,
+        split_fields=dict(_SCALED_FP8_SPLIT_FIELDS),
+        share_fields=_SCALED_FP8_SHARE_FIELDS,
+        # Scale tensors need shape-aware splitting, not the default slicing.
+        split_handlers={
+            "scale_weight": _split_fp8_scale,
+            "weight_scale": _split_fp8_scale,
+        },
+        verboseLevel=verboseLevel,
+        allowed_bases=allowed_bases,
+        return_split_bases=True,
+    )
+
+
+
+def _scaled_mm_available(dtype):
+    """True when torch._scaled_mm can be used for *dtype* on this machine.
+
+    WAN2GP_FORCE_FP8_FALLBACK=1 forces the dequantize fallback path.
+    """
+    if os.environ.get("WAN2GP_FORCE_FP8_FALLBACK", "").strip().lower() in ("1", "true", "yes", "on"):
+        return False
+    return bool(_FP8_MM_SUPPORT.get(dtype, False))
+
+
+def _reshape_scale(scale, weight):
+    """Reshape a per-output-channel *scale* so it broadcasts over weight rows.
+
+    A (out,) or (out, 1) scale becomes (out, 1, ..., 1) to match weight's rank;
+    scalars and unrecognized shapes are returned unchanged.
+    """
+    if scale.ndim == 0 or scale.numel() == 1:
+        return scale
+    if scale.ndim == 1 and scale.shape[0] == weight.shape[0]:
+        return scale.view(weight.shape[0], *([1] * (weight.ndim - 1)))
+    if scale.ndim == 2 and scale.shape[0] == weight.shape[0] and scale.shape[1] == 1:
+        return scale.view(weight.shape[0], *([1] * (weight.ndim - 1)))
+    return scale
+
+
+def _normalize_scaled_mm_scale(scale):
+    """Coerce *scale* to the 0-d tensor torch._scaled_mm expects, else None."""
+    if not torch.is_tensor(scale):
+        return None
+    if scale.numel() != 1:
+        # Per-channel scales are not supported by this call convention.
+        return None
+    if scale.ndim == 0:
+        return scale
+    return scale.reshape(())
+
+
+def _quantize_activation(x, fp8_dtype):
+    """Dynamically quantize activation *x* to fp8 with a per-tensor scale.
+
+    Returns (quantized tensor, float32 0-d scale).  An all-zero input keeps a
+    scale of 1 to avoid division by zero.
+    """
+    minv, maxv = _FP8_RANGE[fp8_dtype]
+    absmax = x.abs().max().float()
+    scale = absmax / maxv
+    scale = torch.where(absmax > 0, scale, torch.ones_like(scale))
+    scale_f16 = scale.to(dtype=x.dtype)
+    q = (x / scale_f16).clamp(minv, maxv).to(fp8_dtype)
+    return q, scale.reshape(()).to(torch.float32)
+
+
+def _scaled_mm_static_ok(weight, scale):
+    """Check the weight/scale shape constraints of torch._scaled_mm.
+
+    Requires a contiguous 2-D float8 weight with both dims multiple of 16 and
+    a single (per-tensor) scale.
+    """
+    if weight is None or scale is None:
+        return False
+    if not _is_float8_dtype(weight.dtype):
+        return False
+    if weight.ndim != 2:
+        return False
+    if not weight.is_contiguous():
+        return False
+    # cuBLASLt fp8 GEMM requires 16-aligned dimensions.
+    if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
+        return False
+    if scale.numel() != 1:
+        return False
+    return True
+
+
+
+
+def _init_scaled_mm_support():
+    """Probe (once) whether torch._scaled_mm works for each fp8 dtype.
+
+    Runs a tiny 16x16 matmul per dtype on the current CUDA device; any failure
+    (old GPU, old torch) marks that dtype unsupported.  Results land in
+    _FP8_MM_SUPPORT.
+    """
+    global _FP8_MM_PROBED
+    if _FP8_MM_PROBED:
+        return
+    _FP8_MM_PROBED = True
+    if not torch.cuda.is_available():
+        return
+    support = {torch.float8_e4m3fn: False, torch.float8_e5m2: False}
+    try:
+        device = torch.device("cuda", torch.cuda.current_device())
+        with torch.cuda.device(device):
+            a = torch.randn(1, 16, device=device, dtype=torch.float16)
+            b = torch.randn(16, 16, device=device, dtype=torch.float16)
+            scale = torch.ones((), device=device, dtype=torch.float32)
+            for fp8_dtype in support:
+                try:
+                    a_fp8 = a.to(fp8_dtype)
+                    b_fp8 = b.to(fp8_dtype)
+                    torch._scaled_mm(a_fp8, b_fp8.t(), scale, scale, out_dtype=torch.float16)
+                    support[fp8_dtype] = True
+                except Exception:
+                    support[fp8_dtype] = False
+    except Exception:
+        # Device-level failure: treat everything as unsupported.
+        support = {torch.float8_e4m3fn: False, torch.float8_e5m2: False}
+    _FP8_MM_SUPPORT.update(support)
+
+
+def _scaled_fp8_qfallback(op, *args, **kwargs):
+    """Dequantize every ScaledFP8WeightTensor in args/kwargs, then run *op*.
+
+    Generic fallback for torch ops with no quantized implementation.  The
+    first parameter was renamed from ``callable`` so it no longer shadows the
+    builtin; every call site passes it positionally.
+    """
+    args, kwargs = pytree.tree_map_only(ScaledFP8WeightTensor, lambda x: x.dequantize(), (args, kwargs or {}))
+    return op(*args, **kwargs)
+
+
+class ScaledFP8WeightTensor(QTensor):
+    """Quanto QTensor subclass storing an fp8 weight plus its scale."""
+
+    @staticmethod
+    def create(weight, scale, size, stride, dtype, device=None, requires_grad=False):
+        """Build a ScaledFP8WeightTensor from raw fp8 *weight* and *scale*.
+
+        A missing scale defaults to 1.0.  Raises TypeError for non-float8
+        weights.  NOTE(review): *device* is accepted but unused here — the
+        wrapper takes weight.device; confirm intent.
+        """
+        if scale is None:
+            scale = torch.ones((), device=weight.device, dtype=torch.float32)
+        qtype = _get_fp8_qtype(weight.dtype)
+        if qtype is None:
+            raise TypeError(f"Scaled FP8 weight requires float8 dtype, got {weight.dtype}.")
+        dtype = _normalize_default_dtype(dtype)
+        return ScaledFP8WeightTensor(
+            qtype=qtype,
+            axis=0,
+            size=size,
+            stride=stride,
+            weight=weight,
+            scale=scale,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+    @staticmethod
+    def __new__(cls, qtype, axis, size, stride, weight, scale, dtype, requires_grad=False):
+        # Wrapper subclass: the outer tensor advertises the logical (compute)
+        # dtype/size while the fp8 payload lives in __init__'s _data.
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            size,
+            strides=stride,
+            dtype=dtype,
+            device=weight.device,
+            requires_grad=requires_grad,
+        )
+
+    def __init__(self, qtype, axis, size, stride, weight, scale, dtype, requires_grad=False):
+        super().__init__(qtype, axis)
+        # _data holds the fp8 payload, _scale its (per-tensor or per-row) scale.
+        self._data = weight
+        self._scale = scale
+        # Cache whether the static torch._scaled_mm constraints are met.
+        self._scaled_mm_static_ok = _scaled_mm_static_ok(self._data, self._scale)
+        self._set_linear_impl()
+
+    def __repr__(self):
+        """Best-effort repr; each field is guarded so repr never raises."""
+        cls_name = self.__class__.__name__
+        try:
+            shape = tuple(self.shape)
+        except Exception:
+            shape = ">"
+        try:
+            dtype = str(self.dtype).replace("torch.", "")
+        except Exception:
+            dtype = ">"
+        try:
+            device = str(self.device)
+        except Exception:
+            device = ">"
+        qtype = getattr(self, "_qtype", None)
+        qtype_name = getattr(qtype, "name", None) or str(qtype) if qtype is not None else ">"
+        return f"{cls_name}(shape={shape}, dtype={dtype}, device={device}, qtype={qtype_name})"
+
+    __str__ = __repr__
+
+    def _set_linear_impl(self):
+        # Bind the linear strategy once: fast _scaled_mm path when both the
+        # static shape checks and the runtime probe pass, else dequantize.
+        if self._scaled_mm_static_ok and _scaled_mm_available(self._data.dtype):
+            self._linear_impl = ScaledFP8WeightTensor._linear_scaled
+        else:
+            self._linear_impl = ScaledFP8WeightTensor._linear_fallback
+
+    def linear(self, input, bias=None):
+        """Apply this weight as a linear layer to *input*."""
+        impl = getattr(self, "_linear_impl", None)
+        if impl is None:
+            # Instances restored via __tensor_unflatten__ re-bind lazily.
+            self._set_linear_impl()
+            impl = self._linear_impl
+        return impl(self, input, bias)
+
+    def dequantize(self, dtype=None, device=None):
+        """Return the full-precision weight (data * scale) on *device*/*dtype*."""
+        if dtype is None:
+            dtype = self.dtype
+        if device is None:
+            device = self.device
+        data = self._data if self._data.device == device else self._data.to(device)
+        scale = self._scale if self._scale.device == device else self._scale.to(device)
+        out = data.to(dtype)
+        if scale.numel() == 1:
+            return out * scale.to(dtype)
+        # Per-channel scale: reshape so it broadcasts over output rows.
+        return out * _reshape_scale(scale.to(dtype), out)
+
+    def _linear_fallback(self, input, bias=None):
+        """Dequantize-and-matmul path used when torch._scaled_mm is unavailable.
+
+        Equivalent to ``F.linear(input, dequantized_weight, bias)`` computed in
+        the normalized working dtype.
+        """
+        qweight = self
+        target_type = _normalize_default_dtype(qweight.dtype)
+        weights, weight_scale = qweight._data, qweight._scale
+        input = input.to(target_type)
+        in_features = input.shape[-1]
+        out_features = weights.shape[0]
+        output_shape = input.shape[:-1] + (out_features,)
+        # .to() always copies here (fp8 -> non-fp8 target_type), so the
+        # in-place scaling below never mutates the stored quantized weight.
+        weights = weights.to(target_type)
+        # Reshape per-output-channel scales to (out, 1) before broadcasting:
+        # a bare (out,) vector would broadcast along the in_features axis,
+        # giving a wrong result (or a shape error).  Matches dequantize().
+        weights *= _reshape_scale(weight_scale.to(target_type), weights)
+        out = torch.matmul(input.reshape(-1, in_features), weights.t())
+        out = out.reshape(output_shape)
+        if bias is not None:
+            out += bias
+        return out
+
+    def _linear_scaled(self, input, bias=None):
+        """Fast path: dynamic fp8 activation quantization + torch._scaled_mm.
+
+        Any precondition failure falls back to a dequantized F.linear.
+        """
+        if not torch.is_tensor(input):
+            return torch.nn.functional.linear(input, self.dequantize(), bias)
+        # Input must be a floating CUDA tensor on the same device as the
+        # weight, with a matching inner dimension.
+        if (
+            not input.is_floating_point()
+            or input.dtype not in (torch.float16, torch.bfloat16, torch.float32)
+            or input.ndim < 2
+            or input.device.type != "cuda"
+            or input.device != self._data.device
+            or input.shape[-1] != self._data.shape[1]
+        ):
+            return torch.nn.functional.linear(input, self.dequantize(dtype=input.dtype, device=input.device), bias)
+
+        scale_b = _normalize_scaled_mm_scale(self._scale)
+        if scale_b is None:
+            return torch.nn.functional.linear(input, self.dequantize(dtype=input.dtype, device=input.device), bias)
+
+        # Flatten leading dims to a 2-D GEMM.
+        x2d = input.reshape(-1, input.shape[-1])
+        if not x2d.is_contiguous():
+            x2d = x2d.contiguous()
+
+        # _scaled_mm needs 16-aligned K and N.
+        if x2d.shape[1] % 16 != 0 or self._data.shape[0] % 16 != 0:
+            return torch.nn.functional.linear(input, self.dequantize(dtype=input.dtype, device=input.device), bias)
+
+        fp8_dtype = self._data.dtype
+        x_fp8, scale_a = _quantize_activation(x2d, fp8_dtype)
+        scale_a = _normalize_scaled_mm_scale(scale_a)
+        if scale_a is None:
+            return torch.nn.functional.linear(input, self.dequantize(dtype=input.dtype, device=input.device), bias)
+        scale_b = scale_b.to(device=x_fp8.device, dtype=torch.float32)
+
+        # Bias must live on the same device and match the output dtype.
+        bias_arg = bias
+        if bias_arg is not None:
+            if bias_arg.device != x_fp8.device:
+                bias_arg = bias_arg.to(x_fp8.device)
+            if bias_arg.dtype != input.dtype:
+                bias_arg = bias_arg.to(dtype=input.dtype)
+
+        out = torch._scaled_mm(
+            x_fp8,
+            self._data.t(),
+            scale_a,
+            scale_b,
+            bias=bias_arg,
+            out_dtype=input.dtype,
+        )
+
+        # Restore the original leading dims.
+        return out.reshape(*input.shape[:-1], self._data.shape[0])
+
+    def get_quantized_subtensors(self):
+        """Expose the inner tensors as (name, tensor) pairs for offloading."""
+        return [("scale", self._scale), ("data", self._data)]
+
+    def set_quantized_subtensors(self, sub_tensors):
+        """Replace inner tensors from a dict or (name, tensor) pairs.
+
+        Re-evaluates the _scaled_mm eligibility afterwards, since either the
+        data or the scale may have changed shape/dtype/contiguity.
+        """
+        if isinstance(sub_tensors, dict):
+            sub_map = sub_tensors
+        else:
+            sub_map = {name: tensor for name, tensor in sub_tensors}
+        data = sub_map.get("data", None)
+        if data is not None:
+            self._data = data
+        scale = sub_map.get("scale", None)
+        if scale is not None:
+            self._scale = scale
+        self._scaled_mm_static_ok = _scaled_mm_static_ok(self._data, self._scale)
+        self._set_linear_impl()
+
+    def __tensor_flatten__(self):
+        # Serialize for torch's subclass protocol: inner tensor names plus
+        # string metadata sufficient to rebuild the wrapper.
+        inner_tensors = ["_data", "_scale"]
+        meta = {
+            "qtype": self._qtype.name,
+            "axis": str(self._axis),
+            "size": str(list(self.size())),
+            "stride": str(list(self.stride())),
+            "dtype": str(self.dtype),
+        }
+        return inner_tensors, meta
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        """Rebuild the wrapper from __tensor_flatten__ output."""
+        qtype = _quanto_qtypes[meta["qtype"]]
+        axis = ast.literal_eval(meta["axis"])
+        size = ast.literal_eval(meta["size"])
+        stride = ast.literal_eval(meta["stride"])
+        # dtype was stored as e.g. "torch.bfloat16"; fall back to float16.
+        dtype_str = meta.get("dtype", "torch.float16")
+        if dtype_str.startswith("torch."):
+            dtype_name = dtype_str.split(".", 1)[1]
+            dtype = getattr(torch, dtype_name, torch.float16)
+        else:
+            dtype = getattr(torch, dtype_str, torch.float16)
+        dtype = _normalize_default_dtype(dtype)
+        return ScaledFP8WeightTensor(
+            qtype=qtype,
+            axis=axis,
+            size=size,
+            stride=stride,
+            weight=inner_tensors["_data"],
+            scale=inner_tensors["_scale"],
+            dtype=dtype,
+        )
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        """Intercept F.linear so quantized weights use their own linear()."""
+        kwargs = kwargs or {}
+        if func is torch.nn.functional.linear:
+            input = args[0] if len(args) > 0 else kwargs.get("input", None)
+            weight = args[1] if len(args) > 1 else kwargs.get("weight", None)
+            bias = args[2] if len(args) > 2 else kwargs.get("bias", None)
+            if isinstance(weight, ScaledFP8WeightTensor):
+                return weight.linear(input, bias=bias)
+        # Everything else runs with the subclass disabled (plain tensor path).
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+
+    @classmethod
+    def __torch_dispatch__(cls, op, types, args, kwargs=None):
+        """ATen-level dispatch: handle linear/detach/to, dequantize otherwise."""
+        op = op.overloadpacket
+        kwargs = kwargs or {}
+        if op is torch.ops.aten.linear:
+            input = args[0]
+            weight = args[1]
+            bias = args[2] if len(args) > 2 else None
+            if isinstance(weight, ScaledFP8WeightTensor):
+                return weight.linear(input, bias=bias)
+        if op is torch.ops.aten.detach:
+            t = args[0]
+            # Detach both inner tensors, keep the quantized wrapper.
+            return ScaledFP8WeightTensor.create(
+                weight=op(t._data),
+                scale=op(t._scale),
+                size=t.size(),
+                stride=t.stride(),
+                dtype=t.dtype,
+                device=t.device,
+                requires_grad=t.requires_grad,
+            )
+        if op in (torch.ops.aten._to_copy, torch.ops.aten.to):
+            t = args[0]
+            dtype = kwargs.pop("dtype", t.dtype) if kwargs else t.dtype
+            device = kwargs.pop("device", t.device) if kwargs else t.device
+            if dtype != t.dtype:
+                # A dtype change leaves quantized form: return dense weights.
+                return t.dequantize(dtype=dtype, device=device)
+            # Pure device move: move the inner tensors, keep quantization.
+            out_data = op(t._data, device=device, **(kwargs or {}))
+            out_scale = op(t._scale, device=device, **(kwargs or {}))
+            return ScaledFP8WeightTensor.create(
+                weight=out_data,
+                scale=out_scale,
+                size=t.size(),
+                stride=t.stride(),
+                dtype=t.dtype,
+                device=device,
+                requires_grad=t.requires_grad,
+            )
+        # Unknown op: dequantize every quantized argument and retry.
+        return _scaled_fp8_qfallback(op, *args, **(kwargs or {}))
+
+
+class QLinearScaledFP8(QModuleMixin, torch.nn.Linear):
+    """Quanto linear module whose weight is a ScaledFP8WeightTensor."""
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        bias=True,
+        device=None,
+        dtype=None,
+        weights=None,
+        activations=None,
+        optimizer=None,
+        quantize_input=True,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+            weights=weights,
+            activations=activations,
+            optimizer=optimizer,
+            quantize_input=quantize_input,
+        )
+        # Working dtype used when (de)materializing loaded fp8 weights.
+        self._scaled_fp8_default_dtype = _normalize_default_dtype(dtype)
+
+    @classmethod
+    def qcreate(cls, module, weights, activations=None, optimizer=None, device=None):
+        """Build a QLinearScaledFP8 mirroring an existing nn.Linear *module*."""
+        # Pick a floating dtype from the source module (weight, then bias),
+        # falling back to float16; fp8 dtypes are normalized away.
+        if torch.is_tensor(module.weight) and module.weight.dtype.is_floating_point:
+            weight_dtype = module.weight.dtype
+        elif torch.is_tensor(getattr(module, "bias", None)) and module.bias.dtype.is_floating_point:
+            weight_dtype = module.bias.dtype
+        else:
+            weight_dtype = torch.float16
+        weight_dtype = _normalize_default_dtype(weight_dtype)
+        return cls(
+            module.in_features,
+            module.out_features,
+            module.bias is not None,
+            device=device,
+            dtype=weight_dtype,
+            weights=weights,
+            activations=activations,
+            optimizer=optimizer,
+            quantize_input=True,
+        )
+
+    def set_default_dtype(self, dtype):
+        """Override the working dtype used for dequantization and bias casts."""
+        self._scaled_fp8_default_dtype = _normalize_default_dtype(dtype)
+
+    @property
+    def qweight(self):
+        # Our scaled-fp8 qtypes store the quantized weight directly in
+        # self.weight; other qtypes defer to quanto's default behavior.
+        if self.weight_qtype in _SCALED_FP8_QTYPES:
+            return self.weight
+        return super().qweight
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Linear forward choosing the _scaled_mm or dequantize path.
+
+        Mirrors the preconditions of ScaledFP8WeightTensor._linear_scaled so
+        the fast path is only taken when it cannot fail.
+        """
+        qweight = self.qweight
+        if (
+            getattr(qweight, "_scaled_mm_static_ok", False)
+            and _scaled_mm_available(qweight._data.dtype)
+            and torch.is_tensor(input)
+            and input.is_floating_point()
+            and input.dtype in (torch.float16, torch.bfloat16, torch.float32)
+            and input.ndim >= 2
+            and input.device.type == "cuda"
+            and input.device == qweight._data.device
+            and input.shape[-1] == qweight._data.shape[1]
+        ):
+            return ScaledFP8WeightTensor._linear_scaled(qweight, input, bias=self.bias)
+        return ScaledFP8WeightTensor._linear_fallback(qweight, input, bias=self.bias)
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        """Load fp8 weight/scale keys, wrapping them in ScaledFP8WeightTensor.
+
+        Non-fp8 qtypes defer to the parent loader.  Consumes the keys it
+        recognizes (popping them from *state_dict*) so strict loading does not
+        report them as unexpected.
+        """
+        if self.weight_qtype not in _SCALED_FP8_QTYPES:
+            return super()._load_from_state_dict(
+                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+            )
+
+        weight_key = prefix + "weight"
+        scale_key = prefix + "scale_weight"
+        alt_scale_key = prefix + "weight_scale"
+        bias_key = prefix + "bias"
+        input_scale_key = prefix + "input_scale"
+        output_scale_key = prefix + "output_scale"
+
+        weight = state_dict.pop(weight_key, None)
+        scale = state_dict.pop(scale_key, None)
+        # Checkpoints use either "scale_weight" or "weight_scale".
+        alt_scale = state_dict.pop(alt_scale_key, None)
+        if scale is None:
+            scale = alt_scale
+        bias = state_dict.pop(bias_key, None)
+        input_scale = state_dict.pop(input_scale_key, None)
+        output_scale = state_dict.pop(output_scale_key, None)
+        # input_scale isn't used in FP8 inference; drop to avoid persisting it.
+        input_scale = None
+
+        if weight is None:
+            missing_keys.append(weight_key)
+
+        target_dtype = _normalize_default_dtype(self._scaled_fp8_default_dtype or self.weight.dtype)
+        if weight is not None:
+            qweight = ScaledFP8WeightTensor.create(
+                weight=weight,
+                scale=scale,
+                size=self.weight.size(),
+                stride=self.weight.stride(),
+                dtype=target_dtype,
+                device=weight.device,
+                requires_grad=False,
+            )
+            self.weight = torch.nn.Parameter(qweight, requires_grad=False)
+
+        if bias is not None:
+            if target_dtype is not None and bias.dtype != target_dtype:
+                bias = bias.to(target_dtype)
+            self.bias = torch.nn.Parameter(bias)
+
+        # Pick a real device for the scale buffers (avoid meta tensors).
+        if torch.is_tensor(weight):
+            scale_device = weight.device
+        elif torch.is_tensor(self.weight):
+            scale_device = self.weight.device
+        elif torch.is_tensor(bias):
+            scale_device = bias.device
+        else:
+            scale_device = torch.device("cpu")
+
+        if input_scale is not None:
+            self.input_scale = input_scale.to(scale_device)
+        else:
+            if not hasattr(self, "input_scale") or self.input_scale.is_meta:
+                scale_dtype = self.input_scale.dtype if hasattr(self, "input_scale") else torch.float32
+                self.input_scale = torch.ones((), dtype=scale_dtype, device=scale_device)
+
+        if output_scale is not None:
+            self.output_scale = output_scale.to(scale_device)
+        else:
+            if not hasattr(self, "output_scale") or self.output_scale.is_meta:
+                scale_dtype = self.output_scale.dtype if hasattr(self, "output_scale") else torch.float32
+                self.output_scale = torch.ones((), dtype=scale_dtype, device=scale_device)
+
+        return
+
+
+def _collect_fp8_specs(state_dict):
+    """List ``{"name", "weight"}`` entries for every float8 '.weight' tensor."""
+    found = []
+    for key, tensor in state_dict.items():
+        if not key.endswith(".weight"):
+            continue
+        if not _is_float8_dtype(tensor.dtype):
+            continue
+        found.append({"name": key[: -len(".weight")], "weight": tensor})
+    return found
+
+
+def detect(state_dict, verboseLevel=1):
+    """Report whether *state_dict* contains scaled-FP8 linear weights."""
+    found = _collect_fp8_specs(state_dict)
+    if not found:
+        return {"matched": False, "kind": "none", "details": {}}
+    sample_names = [entry["name"] for entry in found[:8]]
+    return {
+        "matched": True,
+        "kind": "fp8",
+        "details": {"count": len(found), "names": sample_names},
+    }
+
+
+def convert_to_quanto(state_dict, default_dtype, verboseLevel=1, detection=None):
+    """Produce a quanto quantization map for fp8 weights found in *state_dict*.
+
+    Tensors are left in place; only the map is built.  A prior *detection*
+    result can short-circuit the scan.
+    """
+    if detection is not None and not detection.get("matched", False):
+        return {"state_dict": state_dict, "quant_map": {}}
+    _set_default_dtype_from_loader(default_dtype)
+    # Marker tensor emitted by some exporters; not a real weight.
+    if "scaled_fp8" in state_dict:
+        state_dict.pop("scaled_fp8", None)
+    specs = _collect_fp8_specs(state_dict)
+    if not specs:
+        return {"state_dict": state_dict, "quant_map": {}}
+    quant_map = {}
+    for spec in specs:
+        qtype = _get_fp8_qtype(spec["weight"].dtype)
+        if qtype is None:
+            continue
+        quant_map[spec["name"]] = {"weights": qtype.name, "activations": "none"}
+    return {"state_dict": state_dict, "quant_map": quant_map}
+
+
+def _resolve_default_dtype(model, default_dtype):
+    """Choose a compute dtype: explicit arg > model dtype attr > first
+    floating parameter > bfloat16."""
+    if default_dtype is not None:
+        return default_dtype
+    if model is not None:
+        model_dtype = getattr(model, "_dtype", None) or getattr(model, "dtype", None)
+        if isinstance(model_dtype, torch.dtype):
+            return model_dtype
+        for _, param in model.named_parameters():
+            if torch.is_tensor(param) and param.dtype.is_floating_point:
+                return param.dtype
+    return torch.bfloat16
+
+
+def _collect_linear_param_keys(model):
+    """Return the set of '<module>.weight'/'<module>.bias' keys owned by
+    nn.Linear submodules of *model* (empty set when model is None)."""
+    if model is None:
+        return set()
+    keys = set()
+    for name, module in model.named_modules():
+        if not isinstance(module, torch.nn.Linear):
+            continue
+        keys.add(f"{name}.weight")
+        if module.bias is not None:
+            keys.add(f"{name}.bias")
+    return keys
+
+
+def _cast_non_linear_float8_params(model, target_dtype):
+    """Cast stray float8 params/buffers outside linear layers to *target_dtype*.
+
+    Linear and already-quantized (QModuleMixin) modules are left alone — their
+    fp8 storage is intentional.
+    """
+    if model is None:
+        return
+    for module in model.modules():
+        if isinstance(module, torch.nn.Linear) or isinstance(module, QModuleMixin):
+            continue
+        for name, param in list(module.named_parameters(recurse=False)):
+            if torch.is_tensor(param) and _is_float8_dtype(param.dtype):
+                module._parameters[name] = torch.nn.Parameter(
+                    param.to(dtype=target_dtype),
+                    requires_grad=False,
+                )
+        for name, buf in list(module.named_buffers(recurse=False)):
+            if torch.is_tensor(buf) and _is_float8_dtype(buf.dtype):
+                module._buffers[name] = buf.to(dtype=target_dtype)
+
+
+def apply_pre_quantization(model, state_dict, quantization_map, default_dtype=None, verboseLevel=1):
+    """Prepare *state_dict* before loading: cast float8 tensors that do NOT
+    belong to linear layers back to the working dtype.
+
+    Returns (quantization_map, post_load_callbacks); the callback casts any
+    remaining non-linear float8 params on the instantiated model.
+    """
+    _set_default_dtype_from_loader(default_dtype)
+    if not quantization_map:
+        quantization_map = {}
+
+    # Quick scan: skip all the work when there is nothing float8 to fix.
+    has_float8 = False
+    for tensor in state_dict.values():
+        if torch.is_tensor(tensor) and _is_float8_dtype(tensor.dtype):
+            has_float8 = True
+            break
+    if not quantization_map and not has_float8:
+        return quantization_map or {}, []
+
+    target_dtype = _resolve_default_dtype(model, default_dtype)
+    linear_param_keys = _collect_linear_param_keys(model)
+
+    # Collect float8 weight/bias tensors that do not belong to a linear layer.
+    to_cast = []
+    for key, tensor in state_dict.items():
+        if not torch.is_tensor(tensor):
+            continue
+        if not _is_float8_dtype(tensor.dtype):
+            continue
+        if key.endswith(".weight") or key.endswith(".bias"):
+            module_name = key.rsplit(".", 1)[0]
+        else:
+            continue
+        if key in linear_param_keys:
+            continue
+        to_cast.append((key, module_name, tensor))
+
+    for key, module_name, tensor in to_cast:
+        state_dict[key] = tensor.to(dtype=target_dtype)
+        # The cast tensor no longer needs its quantization scales or map entry.
+        state_dict.pop(module_name + ".scale_weight", None)
+        state_dict.pop(module_name + ".weight_scale", None)
+        state_dict.pop(module_name + ".input_scale", None)
+        state_dict.pop(module_name + ".output_scale", None)
+        if module_name in quantization_map:
+            del quantization_map[module_name]
+
+    post_load = []
+    def _post_cast(model):
+        # Runs after model construction: sweep float8 leftovers on the model.
+        cast_dtype = _resolve_default_dtype(model, default_dtype)
+        _cast_non_linear_float8_params(model, cast_dtype)
+
+    post_load.append(_post_cast)
+
+    return quantization_map or {}, post_load
+
+
+# Probe torch._scaled_mm support once at import time.
+_init_scaled_mm_support()
diff --git a/Wan2GP/shared/utils/audio_metadata.py b/Wan2GP/shared/utils/audio_metadata.py
index e3560e6a3..300818fcd 100644
--- a/Wan2GP/shared/utils/audio_metadata.py
+++ b/Wan2GP/shared/utils/audio_metadata.py
@@ -1,6 +1,7 @@
import struct
from typing import Optional
import json
+import os
def write_wav_text_chunk(in_path: str, out_path: str, text: str,
fourcc: bytes = b'json', encoding: str = 'utf-8') -> None:
@@ -96,8 +97,61 @@ def read_wav_text_chunk(path: str, fourcc: bytes = b'json', encoding: str = 'utf
return None
+def _write_mp3_text_tag(path: str, text: str, tag_key: str = "WanGP") -> None:
+    """Store *text* in an ID3 TXXX frame named *tag_key* on an mp3 file.
+
+    Raises RuntimeError if mutagen is not installed.  Any existing TXXX frame
+    with the same description is replaced.
+    """
+    try:
+        from mutagen.id3 import ID3, ID3NoHeaderError, TXXX
+    except Exception as exc:
+        raise RuntimeError("mutagen is required for mp3 metadata") from exc
+    try:
+        tag = ID3(path)
+    except ID3NoHeaderError:
+        # File has no ID3 header yet: start from an empty tag.
+        tag = ID3()
+    for key in list(tag.keys()):
+        frame = tag.get(key)
+        if isinstance(frame, TXXX) and frame.desc == tag_key:
+            del tag[key]
+    tag.add(TXXX(encoding=3, desc=tag_key, text=[text]))
+    tag.save(path)
+
+
+def _read_mp3_text_tag(path: str, tag_key: str = "WanGP") -> Optional[str]:
+    """Read the text stored under *tag_key* from an mp3's ID3 tag, or None.
+
+    Checks TXXX frames first, then legacy COMM frames.  Returns None when
+    mutagen is missing or the file has no ID3 header.
+    """
+    try:
+        from mutagen.id3 import ID3, ID3NoHeaderError, TXXX, COMM
+    except Exception:
+        return None
+    try:
+        tag = ID3(path)
+    except ID3NoHeaderError:
+        return None
+    for frame in tag.getall("TXXX"):
+        if isinstance(frame, TXXX) and frame.desc == tag_key:
+            if frame.text:
+                return frame.text[0]
+    for frame in tag.getall("COMM"):
+        if isinstance(frame, COMM) and frame.desc == tag_key:
+            return frame.text[0] if frame.text else None
+    return None
+
+
def save_audio_metadata(path, configs):
-    write_wav_text_chunk(path, path, json.dumps(configs))
+    """Embed *configs* (JSON) in an audio file's metadata (mp3 or wav)."""
+    ext = os.path.splitext(path)[1].lower()
+    payload = json.dumps(configs)
+    if ext == ".mp3":
+        _write_mp3_text_tag(path, payload)
+    elif ext == ".wav":
+        # In-place rewrite: source and destination are the same file.
+        write_wav_text_chunk(path, path, payload)
+    else:
+        raise ValueError(f"Unsupported audio metadata format: {ext}")
+
def read_audio_metadata(path):
-    return json.loads(read_wav_text_chunk(path))
\ No newline at end of file
+    """Read JSON metadata embedded by save_audio_metadata; None if absent
+    or the extension is unsupported."""
+    ext = os.path.splitext(path)[1].lower()
+    if ext == ".mp3":
+        raw = _read_mp3_text_tag(path)
+    elif ext == ".wav":
+        raw = read_wav_text_chunk(path)
+    else:
+        return None
+    if not raw:
+        return None
+    return json.loads(raw)
diff --git a/Wan2GP/shared/utils/audio_video.py b/Wan2GP/shared/utils/audio_video.py
index 74eb9a1a7..d581143b1 100644
--- a/Wan2GP/shared/utils/audio_video.py
+++ b/Wan2GP/shared/utils/audio_video.py
@@ -12,6 +12,8 @@
from PIL import Image
import os.path as osp
import json
+import numpy as np
+import soundfile as sf
def rand_name(length=8, suffix=''):
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
@@ -22,6 +24,83 @@ def rand_name(length=8, suffix=''):
return name
+def _prepare_audio_array(audio_data):
+    """Coerce tensor/array audio to a float32 numpy array shaped for soundfile.
+
+    A (channels, samples) layout with <= 8 channels is transposed to the
+    (samples, channels) layout soundfile expects.
+    """
+    if torch.is_tensor(audio_data):
+        audio_data = audio_data.detach().cpu().float().numpy()
+    else:
+        audio_data = np.asarray(audio_data, dtype=np.float32)
+    if audio_data.ndim == 2 and audio_data.shape[0] <= 8 and audio_data.shape[1] > audio_data.shape[0]:
+        audio_data = audio_data.T
+    return audio_data
+
+
+def write_wav_file(path, audio_data, sample_rate):
+    """Write *audio_data* to a wav file at *sample_rate*; returns *path*."""
+    audio_array = _prepare_audio_array(audio_data)
+    sf.write(path, audio_array, int(sample_rate))
+    return path
+
+
+def _get_audio_codec_settings(codec_key):
+    """Map a codec key ('wav', 'mp3', 'mp3_128', ...) to encoder settings.
+
+    Unknown or falsy keys resolve to wav; bare "mp3" means the 192k preset.
+    """
+    table = {
+        "wav": {"ext": "wav", "format": "wav"},
+        "mp3_128": {"ext": "mp3", "format": "mp3", "bitrate": "128k"},
+        "mp3_192": {"ext": "mp3", "format": "mp3", "bitrate": "192k"},
+        "mp3_320": {"ext": "mp3", "format": "mp3", "bitrate": "320k"},
+    }
+    key = str(codec_key).lower() if codec_key else "wav"
+    if key == "mp3":
+        key = "mp3_192"
+    return table.get(key, table["wav"])
+
+
+def get_audio_codec_extension(codec_key):
+    """Return the file extension ('wav' or 'mp3') for *codec_key*."""
+    settings = _get_audio_codec_settings(codec_key)
+    return settings["ext"]
+
+
+def _run_ffmpeg_encode(input_path, output_path, codec, bitrate=None, sample_rate=None, drop_video=False):
+    """Re-encode *input_path*'s audio to *output_path* with ffmpeg.
+
+    Raises subprocess.CalledProcessError on ffmpeg failure (check=True);
+    stderr is captured into the exception.
+    """
+    cmd = ["ffmpeg", "-y", "-v", "error", "-i", input_path]
+    if drop_video:
+        # -vn: strip any video stream from the output.
+        cmd.append("-vn")
+    cmd += ["-c:a", codec]
+    if bitrate:
+        cmd += ["-b:a", bitrate]
+    if sample_rate:
+        cmd += ["-ar", str(int(sample_rate))]
+    cmd.append(output_path)
+    subprocess.run(cmd, check=True, capture_output=True, text=True)
+
+
+def save_audio_file(path, audio_data, sample_rate, codec_key="wav"):
+    """Save audio as wav or mp3 per *codec_key*; returns the final path.
+
+    The extension of *path* is corrected to match the codec.  mp3 encoding
+    goes through a temporary wav file piped into ffmpeg/libmp3lame.
+    """
+    settings = _get_audio_codec_settings(codec_key)
+    ext = settings["ext"]
+    if not path.lower().endswith(f".{ext}"):
+        path = osp.splitext(path)[0] + f".{ext}"
+    if settings["format"] == "wav":
+        return write_wav_file(path, audio_data, sample_rate)
+    fd, tmp_path = tempfile.mkstemp(suffix=".wav", prefix="audio_")
+    os.close(fd)
+    try:
+        write_wav_file(tmp_path, audio_data, sample_rate)
+        _run_ffmpeg_encode(tmp_path, path, "libmp3lame", bitrate=settings.get("bitrate"), sample_rate=sample_rate)
+    finally:
+        # Best-effort cleanup of the intermediate wav.
+        try:
+            os.remove(tmp_path)
+        except OSError:
+            pass
+    return path
+
+
+def extract_audio_track_to_wav(video_path, output_path):
+    """Extract the first audio stream of *video_path* to a PCM wav file.
+
+    Returns *output_path*, or None when *video_path* is falsy.
+    """
+    if not video_path:
+        return None
+    video_path = os.fspath(video_path)
+    import ffmpeg
+    # 0:a:0 = first audio stream; pcm_s16le keeps it lossless for wav.
+    ffmpeg.input(video_path).output(output_path, **{"map": "0:a:0", "acodec": "pcm_s16le"}).overwrite_output().run(quiet=True)
+    return output_path
+
+
def extract_audio_tracks(source_video, verbose=False, query_only=False):
"""
@@ -263,27 +342,46 @@ def save_video(tensor,
error = None
for _ in range(retry):
try:
- if torch.is_tensor(tensor):
- # Preprocess tensor
- tensor = tensor.clamp(min(value_range), max(value_range))
- tensor = torch.stack([
- torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range)
- for u in tensor.unbind(2)
- ], dim=1).permute(1, 2, 3, 0)
- tensor = (tensor * 255).type(torch.uint8).cpu()
- arrays = tensor.numpy()
- else:
- arrays = tensor
-
# Write video (silence ffmpeg logs)
writer = imageio.get_writer(cache_file, fps=fps, ffmpeg_log_level='error', **codec_params)
- for frame in arrays:
- writer.append_data(frame)
-
- writer.close()
+ try:
+ if torch.is_tensor(tensor):
+ # Stream frames to avoid materializing the full video on CPU.
+ if tensor.dtype == torch.uint8 and tensor.ndim == 5 and tensor.shape[0] == 1 and nrow == 1:
+ frames = tensor[0].permute(1, 2, 3, 0)
+ for frame in frames:
+ writer.append_data(frame.cpu().numpy())
+ else:
+ if tensor.dtype == torch.uint8:
+ tensor = tensor.float().div_(127.5).sub_(1.0)
+ for u in tensor.unbind(2):
+ u = u.clamp(min(value_range), max(value_range))
+ grid = torchvision.utils.make_grid(
+ u, nrow=nrow, normalize=normalize, value_range=value_range
+ )
+ frame = grid.mul(255).type(torch.uint8).permute(1, 2, 0).cpu().numpy()
+ writer.append_data(frame)
+ elif isinstance(tensor, (list, tuple)) and tensor and torch.is_tensor(tensor[0]):
+ for chunk in tensor:
+ if chunk is None:
+ continue
+ if chunk.ndim == 4:
+ if chunk.shape[-1] in (1, 3, 4):
+ frames = chunk
+ else:
+ frames = chunk.permute(1, 2, 3, 0)
+ for frame in frames:
+ writer.append_data(frame.cpu().numpy())
+ else:
+ writer.append_data(chunk)
+ else:
+ for frame in tensor:
+ writer.append_data(frame)
+ finally:
+ writer.close()
return cache_file
-
+
except Exception as e:
error = e
print(f"error saving {save_file}: {e}")
@@ -445,3 +543,4 @@ def read_image_metadata(image_path):
return None
except Exception as e:
print(f"Error reading metadata: {e}"); return None
+
diff --git a/Wan2GP/shared/utils/basic_flowmatch.py b/Wan2GP/shared/utils/basic_flowmatch.py
index ceb4657b0..d9e6b823d 100644
--- a/Wan2GP/shared/utils/basic_flowmatch.py
+++ b/Wan2GP/shared/utils/basic_flowmatch.py
@@ -5,6 +5,7 @@
class FlowMatchScheduler():
+ is_stateful = False
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
self.num_train_timesteps = num_train_timesteps
diff --git a/Wan2GP/shared/utils/euler_scheduler.py b/Wan2GP/shared/utils/euler_scheduler.py
new file mode 100644
index 000000000..c2f4a1dae
--- /dev/null
+++ b/Wan2GP/shared/utils/euler_scheduler.py
@@ -0,0 +1,86 @@
+import numpy as np
+import torch
+
+
+def _timestep_transform(t, shift=5.0, num_timesteps=1000):
+ t = t / num_timesteps
+ new_t = shift * t / (1 + (shift - 1) * t)
+ return new_t * num_timesteps
+
+
+class EulerSchedulerOutput:
+ def __init__(self, prev_sample, pred_original_sample=None):
+ self.prev_sample = prev_sample
+ if pred_original_sample is not None:
+ self.pred_original_sample = pred_original_sample
+
+ def __getitem__(self, index):
+ if index == 0:
+ return self.prev_sample
+ raise IndexError("EulerSchedulerOutput only supports index 0.")
+
+ def __iter__(self):
+ yield self.prev_sample
+
+
+class EulerScheduler:
+ is_stateful = False
+
+ def __init__(self, num_train_timesteps=1000, use_timestep_transform=True):
+ self.num_train_timesteps = num_train_timesteps
+ self.use_timestep_transform = use_timestep_transform
+ self.timesteps = None
+ self.num_inference_steps = None
+
+ def set_timesteps(self, num_inference_steps, device=None, shift=5.0):
+ self.num_inference_steps = num_inference_steps
+ timesteps = list(
+ np.linspace(self.num_train_timesteps, 1, num_inference_steps, dtype=np.float32)
+ )
+ timesteps.append(0.0)
+ if device is None:
+ timesteps = [torch.tensor([t]) for t in timesteps]
+ else:
+ timesteps = [torch.tensor([t], device=device) for t in timesteps]
+ if self.use_timestep_transform:
+ timesteps = [
+ _timestep_transform(t, shift=shift, num_timesteps=self.num_train_timesteps)
+ for t in timesteps
+ ][:-1]
+        self.timesteps = torch.tensor([float(t) for t in timesteps])
+ return self.timesteps
+
+ def _timestep_to_index(self, timestep):
+ if self.timesteps is None:
+ raise ValueError("Timesteps are not set. Call set_timesteps first.")
+ if torch.is_tensor(timestep):
+ if timestep.numel() != 1:
+ t_val = timestep.flatten()[0].item()
+ else:
+ t_val = timestep.item()
+ else:
+ t_val = float(timestep)
+ diff = (self.timesteps - t_val).abs()
+ idx = int(torch.argmin(diff).item())
+ return idx, t_val
+
+ def step(self, model_output, timestep, sample, return_dict=True, **kwargs):
+ if self.timesteps is None:
+ raise ValueError("Timesteps are not set. Call set_timesteps first.")
+ idx, t_val = self._timestep_to_index(timestep)
+ if idx + 1 < len(self.timesteps):
+ dt_raw = self.timesteps[idx] - self.timesteps[idx + 1]
+ else:
+ dt_raw = self.timesteps[idx]
+ dt = dt_raw.item() / self.num_train_timesteps
+ prev_sample = sample - model_output * dt
+ pred_original_sample = sample - (t_val / self.num_train_timesteps) * model_output
+ if not return_dict:
+ return (prev_sample,)
+ return EulerSchedulerOutput(
+ prev_sample=prev_sample,
+ pred_original_sample=pred_original_sample,
+ )
+
+ def scale_model_input(self, sample, *args, **kwargs):
+ return sample
diff --git a/Wan2GP/shared/utils/files_locator.py b/Wan2GP/shared/utils/files_locator.py
index 0c455ce1e..8d4bec47a 100644
--- a/Wan2GP/shared/utils/files_locator.py
+++ b/Wan2GP/shared/utils/files_locator.py
@@ -12,12 +12,19 @@ def set_checkpoints_paths(checkpoints_paths):
_checkpoints_paths = [path.strip() for path in checkpoints_paths if len(path.strip()) > 0 ]
if len(checkpoints_paths) == 0:
_checkpoints_paths = default_checkpoints_paths
-def get_download_location(file_name = None):
+def get_download_location(file_name = None, force_path= None):
if file_name is not None and os.path.isabs(file_name): return file_name
+ if force_path is not None and isinstance(force_path, list) and len(force_path): force_path = force_path[0]
if file_name is not None:
- return os.path.join(_checkpoints_paths[0], file_name)
+ if force_path is None:
+ return os.path.join(_checkpoints_paths[0], file_name)
+ else:
+ return os.path.join(_checkpoints_paths[0], force_path, file_name)
else:
- return _checkpoints_paths[0]
+ if force_path is None:
+ return _checkpoints_paths[0]
+ else:
+            return os.path.join(_checkpoints_paths[0], force_path)
def locate_folder(folder_name, error_if_none = True):
searched_locations = []
@@ -34,13 +41,15 @@ def locate_folder(folder_name, error_if_none = True):
return None
-def locate_file(file_name, create_path_if_none = False, error_if_none = True):
+def locate_file(file_name, create_path_if_none = False, error_if_none = True, extra_paths = None):
+ if file_name.startswith("http"):
+ file_name = os.path.basename(file_name)
searched_locations = []
if os.path.isabs(file_name):
if os.path.isfile(file_name): return file_name
searched_locations.append(file_name)
else:
- for folder in _checkpoints_paths:
+ for folder in _checkpoints_paths + ([] if extra_paths is None else extra_paths):
path = os.path.join(folder, file_name)
if os.path.isfile(path):
return path
diff --git a/Wan2GP/shared/utils/hf.py b/Wan2GP/shared/utils/hf.py
new file mode 100644
index 000000000..30607e0fd
--- /dev/null
+++ b/Wan2GP/shared/utils/hf.py
@@ -0,0 +1,10 @@
+import posixpath
+
+
+def build_hf_url(repo_id, *path_parts):
+ repo = (repo_id or "").strip("/")
+ parts = [part.strip("/\\") for part in path_parts if part]
+ path = posixpath.join(*parts) if parts else ""
+ if not path:
+ return f"https://huggingface.co/{repo}/resolve/main"
+ return f"https://huggingface.co/{repo}/resolve/main/{path}"
diff --git a/Wan2GP/shared/utils/loras_mutipliers.py b/Wan2GP/shared/utils/loras_mutipliers.py
index 73d6e54bd..42a6cf923 100644
--- a/Wan2GP/shared/utils/loras_mutipliers.py
+++ b/Wan2GP/shared/utils/loras_mutipliers.py
@@ -113,20 +113,6 @@ def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_st
sz = len(slists_dict["phase1"])
slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ]
nos = [str(l) for l in range(sz)]
-
- # [LORA_STEP_MULTIPLIERS] Log the actual per-step multipliers being set for each LoRA
- print(f"[LORA_STEP_MULTIPLIERS] ═══════════════════════════════════════════════════════")
- print(f"[LORA_STEP_MULTIPLIERS] Setting per-step LoRA multipliers for {sz} LoRAs, {num_inference_steps} steps")
- print(f"[LORA_STEP_MULTIPLIERS] phase_switch_step={phase_switch_step}, phase_switch_step2={phase_switch_step2}")
- for lora_idx, slist in enumerate(slists):
- if isinstance(slist, list):
- # Per-step multipliers
- print(f"[LORA_STEP_MULTIPLIERS] LoRA[{lora_idx}]: {slist}")
- else:
- # Constant multiplier across all steps
- print(f"[LORA_STEP_MULTIPLIERS] LoRA[{lora_idx}]: constant={slist} for all steps")
- print(f"[LORA_STEP_MULTIPLIERS] ═══════════════════════════════════════════════════════")
-
offload.activate_loras(trans, nos, slists )
diff --git a/Wan2GP/shared/utils/plugins.py b/Wan2GP/shared/utils/plugins.py
index 4d13eedf4..fd454311e 100644
--- a/Wan2GP/shared/utils/plugins.py
+++ b/Wan2GP/shared/utils/plugins.py
@@ -3,6 +3,8 @@
import importlib
import importlib.util
import inspect
+import re
+import datetime
from typing import Dict, Any, Optional, List, Union, Set
from dataclasses import dataclass
import gradio as gr
@@ -12,7 +14,145 @@
import shutil
import stat
import json
+import requests
video_gen_label = "Video Generator"
+
+COMMUNITY_PLUGINS_URL = "https://github.com/deepbeepmeep/Wan2GP/raw/refs/heads/main/plugins.json"
+PLUGIN_CATALOG_FILENAME = "plugins.json"
+PLUGIN_LOCAL_CATALOG_FILENAME = "plugins_local.json"
+PLUGIN_METADATA_FILENAME = "plugin_info.json"
+PENDING_DELETIONS_KEY = "pending_plugin_deletions"
+
+def _has_value(value: Any) -> bool:
+ if value is None:
+ return False
+ if isinstance(value, str):
+ return value.strip() != ""
+ return True
+
+def _split_github_repo(url: str) -> Optional[tuple]:
+ if not isinstance(url, str):
+ return None
+ cleaned = url.strip()
+ if not cleaned:
+ return None
+ cleaned = cleaned.split("?", 1)[0].split("#", 1)[0]
+ if cleaned.startswith("git@github.com:"):
+ cleaned = "https://github.com/" + cleaned[len("git@github.com:"):]
+ cleaned = cleaned.rstrip("/")
+ marker = "github.com/"
+ idx = cleaned.lower().find(marker)
+ if idx < 0:
+ return None
+ tail = cleaned[idx + len(marker):]
+ parts = [part for part in tail.split("/") if part]
+ if len(parts) < 2:
+ return None
+ owner, repo = parts[0], parts[1]
+ if repo.endswith(".git"):
+ repo = repo[:-4]
+ if not owner or not repo:
+ return None
+ return owner, repo
+
+def normalize_plugin_url(url: str) -> str:
+ if not isinstance(url, str):
+ return ""
+ cleaned = url.strip()
+ if not cleaned:
+ return ""
+ cleaned = cleaned.split("?", 1)[0].split("#", 1)[0]
+ if cleaned.startswith("git@github.com:"):
+ cleaned = "https://github.com/" + cleaned[len("git@github.com:"):]
+ cleaned = cleaned.rstrip("/")
+ repo_info = _split_github_repo(cleaned)
+ if repo_info:
+ owner, repo = repo_info
+ return f"https://github.com/{owner}/{repo}"
+ if cleaned.endswith(".git"):
+ cleaned = cleaned[:-4]
+ return cleaned.rstrip("/")
+
+def _parse_version_parts(version: str) -> List[Any]:
+ if not isinstance(version, str):
+ return []
+ version = version.strip()
+ if not version:
+ return []
+ parts = re.split(r"[^0-9A-Za-z]+", version)
+ tokens = []
+ for part in parts:
+ if not part:
+ continue
+ for token in re.findall(r"\d+|[A-Za-z]+", part):
+ if token.isdigit():
+ tokens.append((0, int(token)))
+ else:
+ tokens.append((1, token.lower()))
+ return tokens
+
+def compare_versions(left: str, right: str) -> int:
+ left_text = left if isinstance(left, str) else ""
+ right_text = right if isinstance(right, str) else ""
+ left_has_digits = bool(re.search(r"\d", left_text))
+ right_has_digits = bool(re.search(r"\d", right_text))
+ if left_has_digits != right_has_digits:
+ return 1 if left_has_digits else -1
+ left_parts = _parse_version_parts(left_text)
+ right_parts = _parse_version_parts(right_text)
+ max_len = max(len(left_parts), len(right_parts))
+ if max_len == 0:
+ return 0
+ filler = (0, 0)
+ for idx in range(max_len):
+ left_part = left_parts[idx] if idx < len(left_parts) else filler
+ right_part = right_parts[idx] if idx < len(right_parts) else filler
+ if left_part == right_part:
+ continue
+ return 1 if left_part > right_part else -1
+ return 0
+
+def _parse_date(value: Any) -> Optional[datetime.datetime]:
+ if not isinstance(value, str):
+ return None
+ text = value.strip()
+ if not text:
+ return None
+ if text.endswith("Z"):
+ text = text[:-1] + "+00:00"
+ try:
+ return datetime.datetime.fromisoformat(text)
+ except ValueError:
+ return None
+
+def compare_release_metadata(left: Dict[str, Any], right: Dict[str, Any]) -> int:
+ left_date = _parse_date(left.get("date"))
+ right_date = _parse_date(right.get("date"))
+ if left_date or right_date:
+ if left_date and right_date:
+ if left_date != right_date:
+ return 1 if left_date > right_date else -1
+ elif left_date:
+ return 1
+ else:
+ return -1
+ return compare_versions(left.get("version", ""), right.get("version", ""))
+
+def is_wangp_compatible(required_version: str, current_version: str) -> bool:
+ if not _has_value(required_version):
+ return True
+ return compare_versions(current_version or "", required_version) >= 0
+
+def plugin_id_from_url(url: str) -> str:
+ if not isinstance(url, str):
+ return ""
+ repo_info = _split_github_repo(url)
+ if repo_info:
+ return repo_info[1]
+ clean = normalize_plugin_url(url)
+ if not clean:
+ return ""
+ return clean.split("/")[-1]
def auto_install_and_enable_default_plugins(manager: 'PluginManager', wgp_globals: dict):
server_config = wgp_globals.get("server_config")
server_config_filename = wgp_globals.get("server_config_filename")
@@ -21,13 +161,11 @@ def auto_install_and_enable_default_plugins(manager: 'PluginManager', wgp_global
print("[Plugins] WARNING: Cannot auto-install/enable default plugins. Server config not found.")
return
- default_plugins = {
- "wan2gp-gallery": "https://github.com/Tophness/wan2gp-gallery.git",
- "wan2gp-lora-multipliers-ui": "https://github.com/Tophness/wan2gp-lora-multipliers-ui.git"
- }
+ default_plugins = {}
config_modified = False
enabled_plugins = server_config.get("enabled_plugins", [])
+ newly_installed = []
for repo_name, url in default_plugins.items():
target_dir = os.path.join(manager.plugins_dir, repo_name)
@@ -37,13 +175,16 @@ def auto_install_and_enable_default_plugins(manager: 'PluginManager', wgp_global
print(f"[Plugins] Install result for {repo_name}: {result}")
if "[Success]" in result:
- if repo_name not in enabled_plugins:
- enabled_plugins.append(repo_name)
- config_modified = True
+ newly_installed.append(repo_name)
+ for repo_name in newly_installed:
+ if repo_name in enabled_plugins:
+ enabled_plugins.remove(repo_name)
+ config_modified = True
+
if config_modified:
print("[Plugins] Disabling newly installed default plugins...")
- server_config["enabled_plugins"] = []
+ server_config["enabled_plugins"] = enabled_plugins
try:
with open(server_config_filename, 'w', encoding='utf-8') as f:
json.dump(server_config, f, indent=4)
@@ -55,11 +196,13 @@ def auto_install_and_enable_default_plugins(manager: 'PluginManager', wgp_global
"wan2gp-video-mask-creator",
"wan2gp-motion-designer",
"wan2gp-guides",
- "wan2gp-downloads",
"wan2gp-configuration",
"wan2gp-plugin-manager",
"wan2gp-about",
]
+BUNDLED_PLUGINS = {
+ "wan2gp-sample",
+}
USER_PLUGIN_INSERT_POSITION = 3
@@ -81,6 +224,7 @@ def __init__(self):
self.name = self.__class__.__name__
self.version = "1.0.0"
self.description = "No description provided."
+ self.uninstallable = True
self._component_requests: List[str] = []
self._global_requests: List[str] = []
self._insert_after_requests: List[InsertAfterRequest] = []
@@ -158,24 +302,563 @@ def __init__(self, plugins_dir="plugins"):
self.data_hooks: Dict[str, List[callable]] = {}
self.restricted_globals: Set[str] = set()
self.custom_js_snippets: List[str] = []
+ self.repo_root = os.path.abspath(os.getcwd())
+ self.catalog_path = os.path.join(self.repo_root, PLUGIN_CATALOG_FILENAME)
+ self.local_catalog_path = os.path.join(self.repo_root, PLUGIN_LOCAL_CATALOG_FILENAME)
+ self.server_config: Optional[Dict[str, Any]] = None
+ self.server_config_filename: str = ""
+
+ def set_server_config(self, server_config: Optional[Dict[str, Any]], server_config_filename: str = "") -> None:
+ self.server_config = server_config if isinstance(server_config, dict) else None
+ self.server_config_filename = server_config_filename or ""
+
+ def _save_server_config(self) -> None:
+ if not self.server_config or not self.server_config_filename:
+ return
+ try:
+ with open(self.server_config_filename, "w", encoding="utf-8") as writer:
+ writer.write(json.dumps(self.server_config, indent=4))
+ except Exception as e:
+ print(f"[PluginManager] Failed to write config file '{self.server_config_filename}': {e}")
+
+ def _get_pending_deletions(self) -> List[str]:
+ if not self.server_config:
+ return []
+ pending = self.server_config.get(PENDING_DELETIONS_KEY, [])
+ if not isinstance(pending, list):
+ return []
+ cleaned = []
+ for item in pending:
+ if isinstance(item, str) and item.strip():
+ cleaned.append(item.strip())
+ return cleaned
+
+ def _set_pending_deletions(self, pending: List[str]) -> None:
+ if not self.server_config:
+ return
+ unique = []
+ seen = set()
+ for item in pending:
+ if not isinstance(item, str):
+ continue
+ key = item.strip()
+ if not key or key in seen:
+ continue
+ seen.add(key)
+ unique.append(key)
+ self.server_config[PENDING_DELETIONS_KEY] = unique
+ self._save_server_config()
+
+ def _add_pending_deletion(self, plugin_id: str) -> None:
+ if not plugin_id:
+ return
+ pending = self._get_pending_deletions()
+ if plugin_id not in pending:
+ pending.append(plugin_id)
+ self._set_pending_deletions(pending)
+
+ def _clear_pending_deletion(self, plugin_id: str) -> None:
+ if not plugin_id:
+ return
+ pending = self._get_pending_deletions()
+ if plugin_id in pending:
+ pending = [item for item in pending if item != plugin_id]
+ self._set_pending_deletions(pending)
+
+ def _is_cleanup_candidate(self, path: str) -> bool:
+ try:
+ for entry in os.scandir(path):
+ name = entry.name
+ if name.startswith("."):
+ continue
+ if entry.is_file(follow_symlinks=False):
+ if name.endswith(".pyc"):
+ continue
+ return False
+ if entry.is_dir(follow_symlinks=False):
+ if name == "__pycache__":
+ continue
+ if not self._is_cleanup_candidate(entry.path):
+ return False
+ return True
+ except Exception:
+ return False
+
+ def cleanup_pending_deletions(self) -> None:
+ pending = self._get_pending_deletions()
+ if not pending:
+ return
+ remaining = []
+ for plugin_id in pending:
+ plugin_dir = os.path.join(self.plugins_dir, plugin_id)
+ if not os.path.isdir(plugin_dir):
+ continue
+ if self._is_cleanup_candidate(plugin_dir):
+ try:
+ shutil.rmtree(plugin_dir, onerror=self._remove_readonly)
+ continue
+ except Exception:
+ remaining.append(plugin_id)
+ continue
+ remaining.append(plugin_id)
+ self._set_pending_deletions(remaining)
+
+ def _coerce_bool(self, value: Any, default: bool = True) -> bool:
+ if isinstance(value, bool):
+ return value
+ if isinstance(value, str):
+ lowered = value.strip().lower()
+ if lowered in ("true", "1", "yes"):
+ return True
+ if lowered in ("false", "0", "no"):
+ return False
+ return default
+
+ def _load_json_file(self, path: str) -> Optional[Any]:
+ if not path or not os.path.isfile(path):
+ return None
+ last_error = None
+ for encoding in ("utf-8", "utf-8-sig"):
+ try:
+ with open(path, "r", encoding=encoding) as reader:
+ return json.load(reader)
+ except UnicodeDecodeError as e:
+ last_error = e
+ continue
+ except json.JSONDecodeError as e:
+ last_error = e
+ if encoding == "utf-8" and "UTF-8 BOM" in str(e):
+ continue
+ break
+ except Exception as e:
+ last_error = e
+ break
+ if last_error is not None:
+ print(f"[PluginManager] Failed to read JSON from {path}: {last_error}")
+ return None
+
+ def _write_json_file(self, path: str, payload: Any) -> None:
+ if not path:
+ return
+ try:
+ os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+ with open(path, "w", encoding="utf-8") as writer:
+ json.dump(payload, writer, indent=2)
+ except Exception as e:
+ print(f"[PluginManager] Failed to write JSON to {path}: {e}")
+
+ def _load_plugin_metadata(self, plugin_path: str) -> Optional[Dict[str, Any]]:
+ metadata_path = os.path.join(plugin_path, PLUGIN_METADATA_FILENAME)
+ payload = self._load_json_file(metadata_path)
+ if not isinstance(payload, dict):
+ return None
+ return self._normalize_plugin_metadata(payload)
+
+ def _normalize_plugin_metadata(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+ metadata = dict(payload)
+ for key in ("name", "version", "description", "author", "date", "wan2gp_version"):
+ value = metadata.get(key)
+ if isinstance(value, str):
+ metadata[key] = value.strip()
+ elif value is None:
+ metadata[key] = ""
+ else:
+ metadata[key] = str(value)
+ legacy_version = metadata.get("wangp_version")
+ if not _has_value(metadata.get("wan2gp_version")) and _has_value(legacy_version):
+ metadata["wan2gp_version"] = str(legacy_version).strip()
+ metadata.pop("wangp_version", None)
+ metadata["url"] = ""
+ metadata["uninstallable"] = self._coerce_bool(metadata.get("uninstallable"), default=True)
+ return metadata
+
+ def _apply_metadata_to_plugin(self, plugin: WAN2GPPlugin, metadata: Optional[Dict[str, Any]], is_system: bool) -> None:
+ if not metadata:
+ if is_system:
+ plugin.uninstallable = False
+ return
+ for key in ("name", "version", "description", "author", "date", "wan2gp_version"):
+ if key in metadata:
+ setattr(plugin, key, metadata.get(key))
+ if not _has_value(getattr(plugin, "wan2gp_version", "")) and _has_value(metadata.get("wangp_version")):
+ setattr(plugin, "wan2gp_version", str(metadata.get("wangp_version")).strip())
+ if "uninstallable" in metadata:
+ plugin.uninstallable = self._coerce_bool(metadata.get("uninstallable"), default=True)
+ if is_system:
+ plugin.uninstallable = False
+
+ def _normalize_catalog_entry(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+ entry = dict(payload)
+ for key in ("name", "author", "version", "description", "url", "date", "wan2gp_version", "last_check"):
+ value = entry.get(key)
+ if key == "url":
+ if isinstance(value, str):
+ entry[key] = normalize_plugin_url(value)
+ elif value is None:
+ entry[key] = ""
+ else:
+ entry[key] = normalize_plugin_url(str(value))
+ continue
+ if isinstance(value, str):
+ entry[key] = value.strip()
+ elif value is None:
+ entry[key] = ""
+ else:
+ entry[key] = str(value)
+ legacy_version = entry.get("wangp_version")
+ if not _has_value(entry.get("wan2gp_version")) and _has_value(legacy_version):
+ entry["wan2gp_version"] = str(legacy_version).strip()
+ entry.pop("wangp_version", None)
+ return entry
+
+ def _merge_entry_fields(self, primary: Dict[str, Any], secondary: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+ result = dict(primary) if primary else {}
+ if not secondary:
+ return result
+ for key, value in secondary.items():
+ if _has_value(result.get(key)):
+ continue
+ if _has_value(value):
+ result[key] = value
+ return result
+
+ def _merge_catalog_entries(
+ self,
+ base_entries: List[Dict[str, Any]],
+ local_entries: List[Dict[str, Any]],
+ ) -> Dict[str, Dict[str, Any]]:
+ base_map: Dict[str, Dict[str, Any]] = {}
+ local_map: Dict[str, Dict[str, Any]] = {}
+ for entry in base_entries:
+ plugin_id = entry.get("id") or plugin_id_from_url(entry.get("url", ""))
+ if not plugin_id:
+ continue
+ base_map[plugin_id] = self._normalize_catalog_entry(entry)
+ for entry in local_entries:
+ plugin_id = entry.get("id") or plugin_id_from_url(entry.get("url", ""))
+ if not plugin_id:
+ continue
+ local_map[plugin_id] = self._normalize_catalog_entry(entry)
+
+ merged: Dict[str, Dict[str, Any]] = {}
+ all_ids = set(base_map.keys()) | set(local_map.keys())
+ for plugin_id in all_ids:
+ base_entry = base_map.get(plugin_id)
+ local_entry = local_map.get(plugin_id)
+ if base_entry and local_entry:
+ comparison = compare_release_metadata(local_entry, base_entry)
+ if comparison > 0:
+ merged_entry = dict(local_entry)
+ else:
+ merged_entry = dict(base_entry)
+ if _has_value(local_entry.get("last_check")):
+ merged_entry["last_check"] = local_entry.get("last_check")
+ merged[plugin_id] = merged_entry
+ else:
+ merged[plugin_id] = base_entry or local_entry
+ return merged
+
+ def _fetch_remote_catalog_entries(self) -> List[Dict[str, Any]]:
+ try:
+ response = requests.get(COMMUNITY_PLUGINS_URL, timeout=10)
+ response.raise_for_status()
+ payload = response.json()
+ if isinstance(payload, list):
+ return [self._normalize_catalog_entry(entry) for entry in payload if isinstance(entry, dict)]
+ except Exception as e:
+ print(f"[PluginManager] Could not fetch community plugins info: {e}")
+ return []
+
+ def load_catalog_entries(self, use_remote: bool = True) -> List[Dict[str, Any]]:
+ entries = None # self._fetch_remote_catalog_entries() if use_remote else []
+ if not entries:
+ payload = self._load_json_file(self.catalog_path)
+ if isinstance(payload, list):
+ entries = [self._normalize_catalog_entry(entry) for entry in payload if isinstance(entry, dict)]
+        return entries or []
+
+ def load_local_catalog_entries(self) -> List[Dict[str, Any]]:
+ payload = self._load_json_file(self.local_catalog_path)
+ if isinstance(payload, list):
+ return [self._normalize_catalog_entry(entry) for entry in payload if isinstance(entry, dict)]
+ return []
+
+ def get_merged_catalog_entries(self, use_remote: bool = True) -> Dict[str, Dict[str, Any]]:
+ base_entries = self.load_catalog_entries(use_remote=use_remote)
+ local_entries = self.load_local_catalog_entries()
+ return self._merge_catalog_entries(base_entries, local_entries)
+
+ def _build_plugin_json_url(self, url: str) -> str:
+ repo_info = _split_github_repo(url)
+ if not repo_info:
+ return ""
+ owner, repo = repo_info
+ return f"https://github.com/{owner}/{repo}/raw/HEAD/{PLUGIN_METADATA_FILENAME}"
+
+ def _fetch_plugin_json(self, url: str, quiet: bool = False) -> Optional[Dict[str, Any]]:
+ plugin_json_url = self._build_plugin_json_url(url)
+ if not plugin_json_url:
+ return None
+ try:
+ response = requests.get(plugin_json_url, timeout=10)
+ if response.status_code != 200:
+ return None
+ payload = response.json()
+ if not isinstance(payload, dict):
+ return None
+ return self._normalize_plugin_metadata(payload)
+ except Exception as e:
+ if not quiet:
+ print(f"[PluginManager] Could not fetch {PLUGIN_METADATA_FILENAME} for {url}: {e}")
+ return None
+
+ def _metadata_to_catalog_entry(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
+ entry = {}
+ for key in ("name", "author", "version", "description", "date", "wan2gp_version"):
+ entry[key] = metadata.get(key, "")
+ return self._normalize_catalog_entry(entry)
+
+ def _extract_catalog_metadata(self, entry: Dict[str, Any]) -> Dict[str, Any]:
+ payload = {}
+ for key in ("name", "author", "version", "description", "url", "date", "wan2gp_version"):
+ payload[key] = entry.get(key, "")
+ return self._normalize_catalog_entry(payload)
+
+ def _find_catalog_entry(self, plugin_id: str, url: str = "", use_remote: bool = False) -> Optional[Dict[str, Any]]:
+ entries = self.load_catalog_entries(use_remote=use_remote)
+ target_id = plugin_id.strip().lower() if isinstance(plugin_id, str) else ""
+ target_url = normalize_plugin_url(url).lower() if isinstance(url, str) else ""
+ for entry in entries:
+ entry_url = entry.get("url", "")
+ entry_id = entry.get("id") or plugin_id_from_url(entry_url)
+ entry_id = entry_id.strip().lower() if isinstance(entry_id, str) else ""
+ entry_url_norm = normalize_plugin_url(entry_url).lower() if isinstance(entry_url, str) else ""
+ if target_id and entry_id and entry_id == target_id:
+ return self._normalize_catalog_entry(entry)
+ if target_url and entry_url_norm and entry_url_norm == target_url:
+ return self._normalize_catalog_entry(entry)
+ return None
+
+ def _get_git_remote_url(self, plugin_path: str) -> str:
+ if not plugin_path:
+ return ""
+ git_dir = os.path.join(plugin_path, ".git")
+ if not os.path.isdir(git_dir):
+ return ""
+ try:
+ repo = git.Repo(plugin_path)
+ if repo.remotes:
+ return repo.remotes.origin.url
+ except Exception:
+ return ""
+ return ""
+
+ def refresh_catalog(self, installed_only: bool = True, use_remote: bool = True) -> Dict[str, int]:
+ base_entries = self.load_catalog_entries(use_remote=False)
+ local_entries = self.load_local_catalog_entries()
+ merged_catalog = self._merge_catalog_entries(base_entries, local_entries)
+ local_map: Dict[str, Dict[str, Any]] = {}
+ for entry in local_entries:
+ plugin_id = entry.get("id") or plugin_id_from_url(entry.get("url", ""))
+ if plugin_id:
+ local_map[plugin_id] = self._normalize_catalog_entry(entry)
+
+ def _info_to_entry(info: Dict[str, Any]) -> Dict[str, Any]:
+ return {
+ "name": info.get("name", ""),
+ "author": info.get("author", ""),
+ "version": info.get("version", ""),
+ "description": info.get("description", ""),
+ "url": info.get("url", ""),
+ "date": info.get("date", ""),
+ "wan2gp_version": info.get("wan2gp_version", ""),
+ }
+
+ targets: List[Dict[str, Any]] = []
+ if installed_only:
+ installed_plugins = [info for info in self.get_plugins_info() if not info.get("system")]
+ for info in installed_plugins:
+ plugin_id = info.get("id", "")
+ if not plugin_id:
+ continue
+ base_entry = merged_catalog.get(plugin_id, {})
+ git_url = self._get_git_remote_url(info.get("path", ""))
+ url = normalize_plugin_url(base_entry.get("url") or info.get("url") or git_url or "")
+ if git_url and plugin_id_from_url(git_url) == plugin_id:
+ url = normalize_plugin_url(git_url)
+ if not _has_value(url):
+ continue
+ targets.append(
+ {
+ "id": plugin_id,
+ "url": url,
+ "base_entry": base_entry,
+ "info_entry": _info_to_entry(info),
+ }
+ )
+ else:
+ for plugin_id, entry in merged_catalog.items():
+ url = normalize_plugin_url(entry.get("url") or "")
+ if not _has_value(url):
+ continue
+ targets.append({"id": plugin_id, "url": url, "base_entry": entry, "info_entry": {}})
+
+ checked = 0
+ updated = 0
+ updates_available = 0
+ for target in targets:
+ plugin_id = target["id"]
+ url = target["url"]
+ base_entry = target.get("base_entry", {})
+ info_entry = target.get("info_entry", {})
+ checked += 1
+ metadata = self._fetch_plugin_json(url, quiet=True)
+ if not metadata:
+ continue
+ catalog_entry = self._metadata_to_catalog_entry(metadata)
+ if _has_value(url):
+ catalog_entry_url = catalog_entry.get("url", "")
+ if not _has_value(catalog_entry_url) or plugin_id_from_url(catalog_entry_url) != plugin_id:
+ catalog_entry["url"] = url
+ catalog_entry = self._merge_entry_fields(catalog_entry, base_entry)
+ if info_entry:
+ catalog_entry = self._merge_entry_fields(catalog_entry, info_entry)
+ catalog_entry["last_check"] = datetime.datetime.now().isoformat(timespec="seconds")
+ catalog_entry = self._normalize_catalog_entry(catalog_entry)
+ existing_entry = local_map.get(plugin_id, {})
+ if info_entry:
+ if compare_release_metadata(catalog_entry, info_entry) > 0:
+ updates_available += 1
+ elif existing_entry:
+ if compare_release_metadata(catalog_entry, existing_entry) > 0:
+ updates_available += 1
+ local_map[plugin_id] = catalog_entry
+ updated += 1
+
+ if updated > 0:
+ self._write_json_file(self.local_catalog_path, list(local_map.values()))
+ return {"checked": checked, "updated": updated, "updates_available": updates_available}
+
+ def record_plugin_metadata(self, plugin_id: str, url: str = "") -> bool:
+ if not plugin_id:
+ return False
+ url = normalize_plugin_url(url)
+ plugin_path = os.path.join(self.plugins_dir, plugin_id)
+ metadata = self._load_plugin_metadata(plugin_path) if os.path.isdir(plugin_path) else None
+ plugin_json_found = metadata is not None
+ if not metadata and _has_value(url):
+ metadata = self._fetch_plugin_json(url)
+ plugin_json_found = metadata is not None
+ if not metadata:
+ plugin_info = next((info for info in self.get_plugins_info() if info.get("id") == plugin_id), None)
+ if plugin_info:
+ metadata = {
+ "name": plugin_info.get("name", ""),
+ "version": plugin_info.get("version", ""),
+ "description": plugin_info.get("description", ""),
+ "author": plugin_info.get("author", ""),
+ "url": plugin_info.get("url", ""),
+ "date": plugin_info.get("date", ""),
+ "wan2gp_version": plugin_info.get("wan2gp_version", ""),
+ }
+ if not metadata:
+ return False
+
+ catalog_entry = self._metadata_to_catalog_entry(metadata)
+ if _has_value(url):
+ catalog_entry_url = catalog_entry.get("url", "")
+ if not _has_value(catalog_entry_url) or plugin_id_from_url(catalog_entry_url) != plugin_id:
+ catalog_entry["url"] = url
+
+ merged_catalog = self.get_merged_catalog_entries(use_remote=False)
+ catalog_entry = self._merge_entry_fields(catalog_entry, merged_catalog.get(plugin_id, {}))
+ base_entry = self._find_catalog_entry(plugin_id, url=url, use_remote=False)
+ if base_entry:
+ catalog_entry = self._merge_entry_fields(catalog_entry, base_entry)
+ if plugin_json_found:
+ catalog_entry["last_check"] = datetime.datetime.now().isoformat(timespec="seconds")
+ catalog_entry = self._normalize_catalog_entry(catalog_entry)
+
+ local_entries = self.load_local_catalog_entries()
+ local_map: Dict[str, Dict[str, Any]] = {}
+ for entry in local_entries:
+ existing_id = entry.get("id") or plugin_id_from_url(entry.get("url", ""))
+ if existing_id:
+ local_map[existing_id] = self._normalize_catalog_entry(entry)
+ local_map[plugin_id] = catalog_entry
+ self._write_json_file(self.local_catalog_path, list(local_map.values()))
+ return True
+
+ def merge_local_catalog(self) -> str:
+ local_entries = self.load_local_catalog_entries()
+ if not local_entries:
+ return "[Info] No local catalog entries to merge."
+ base_entries = self.load_catalog_entries(use_remote=False)
+ merged = self._merge_catalog_entries(base_entries, local_entries)
+ merged_list = []
+ for entry in merged.values():
+ entry.pop("last_check", None)
+ merged_list.append(entry)
+ merged_list.sort(key=lambda item: item.get("name", ""))
+ self._write_json_file(self.catalog_path, merged_list)
+ try:
+ os.remove(self.local_catalog_path)
+ except FileNotFoundError:
+ pass
+ except Exception as e:
+ return f"[Warning] Catalog merged, but failed to remove {self.local_catalog_path}: {e}"
+ return "[Success] Catalog merged and local overrides removed."
def get_plugins_info(self) -> List[Dict[str, str]]:
    plugins_info = []
    for dir_name in self.discover_plugins():
        plugin_path = os.path.join(self.plugins_dir, dir_name)
        is_system = dir_name in SYSTEM_PLUGINS
-        info = {'id': dir_name, 'name': dir_name, 'version': 'N/A', 'description': 'No description provided.', 'path': plugin_path, 'system': is_system}
-        try:
-            module = importlib.import_module(f"{dir_name}.plugin")
-            for name, obj in inspect.getmembers(module, inspect.isclass):
-                if issubclass(obj, WAN2GPPlugin) and obj != WAN2GPPlugin:
-                    instance = obj()
-                    info['name'] = instance.name
-                    info['version'] = instance.version
-                    info['description'] = instance.description
-                    break
-        except Exception as e:
-            print(f"Could not load metadata for plugin {dir_name}: {e}")
+        # Default record; refined below from plugin.json metadata, then (as a
+        # fallback) from the imported plugin class, then from the catalog.
+        info = {
+            'id': dir_name,
+            'name': dir_name,
+            'version': 'N/A',
+            'description': 'No description provided.',
+            'author': '',
+            'url': '',
+            'date': '',
+            'wan2gp_version': '',
+            'path': plugin_path,
+            'system': is_system,
+            'uninstallable': True,
+        }
+        metadata = self._load_plugin_metadata(plugin_path)
+        if metadata:
+            info['name'] = metadata.get('name', info['name'])
+            info['version'] = metadata.get('version', info['version'])
+            info['description'] = metadata.get('description', info['description'])
+            info['author'] = metadata.get('author', info['author'])
+            info['url'] = metadata.get('url', info['url'])
+            info['date'] = metadata.get('date', info['date'])
+            info['wan2gp_version'] = metadata.get('wan2gp_version', info['wan2gp_version'])
+            info['uninstallable'] = bool(metadata.get('uninstallable', info['uninstallable']))
+        else:
+            try:
+                module = importlib.import_module(f"{dir_name}.plugin")
+                for name, obj in inspect.getmembers(module, inspect.isclass):
+                    if issubclass(obj, WAN2GPPlugin) and obj != WAN2GPPlugin:
+                        instance = obj()
+                        info['name'] = instance.name
+                        info['version'] = instance.version
+                        info['description'] = instance.description
+                        info['uninstallable'] = bool(getattr(instance, 'uninstallable', True))
+                        break
+            except Exception as e:
+                print(f"Could not load metadata for plugin {dir_name}: {e}")
+        # System and bundled plugins are always protected from uninstall.
+        if is_system:
+            info['uninstallable'] = False
+        if info['id'] in BUNDLED_PLUGINS:
+            info['uninstallable'] = False
+        # Non-system plugins may be enriched from the merged catalog.
+        if not is_system:
+            merged_catalog = self.get_merged_catalog_entries(use_remote=False)
+            catalog_entry = merged_catalog.get(info.get("id", ""))
+            if catalog_entry:
+                info = self._merge_entry_fields(info, self._extract_catalog_metadata(catalog_entry))
+        # NOTE(review): the sort below appears to run inside the loop — result
+        # is correct but O(n^2 log n); confirm intended placement.
        plugins_info.append(info)
        plugins_info.sort(key=lambda p: (not p['system'], p['name']))
@@ -188,16 +871,32 @@ def _remove_readonly(self, func, path, exc_info):
else:
raise
def _is_plugin_uninstallable(self, plugin_id: str) -> bool:
    """Return False for plugins that must never be removed (system/bundled),
    otherwise consult the plugin's own 'uninstallable' metadata.

    Defaults to True (removable) when the metadata lookup fails for any
    reason or the plugin is unknown.
    """
    if plugin_id in SYSTEM_PLUGINS or plugin_id in BUNDLED_PLUGINS:
        return False
    try:
        record = next(
            (entry for entry in self.get_plugins_info() if entry.get('id') == plugin_id),
            None,
        )
    except Exception:
        return True
    if record is None:
        return True
    return bool(record.get('uninstallable', True))
+
    def uninstall_plugin(self, plugin_id: str):
        if not plugin_id:
            return "[Error] No plugin selected for uninstallation."
        if plugin_id in SYSTEM_PLUGINS:
            return f"[Error] Cannot uninstall system plugin '{plugin_id}'."
+        # Metadata-protected (bundled or flagged) plugins are also refused.
+        if not self._is_plugin_uninstallable(plugin_id):
+            return f"[Error] Cannot uninstall protected plugin '{plugin_id}'."
        target_dir = os.path.join(self.plugins_dir, plugin_id)
        if not os.path.isdir(target_dir):
            return f"[Error] Plugin '{plugin_id}' directory not found."
+        # NOTE(review): presumably queued so cleanup_pending_deletions() can
+        # retry on next start if rmtree fails on locked files — confirm.
+        self._add_pending_deletion(plugin_id)
        try:
            shutil.rmtree(target_dir, onerror=self._remove_readonly)
@@ -216,13 +915,20 @@ def update_plugin(self, plugin_id: str, progress=None):
        try:
            if progress is not None: progress(0, desc=f"Updating '{plugin_id}'...")
            repo = git.Repo(target_dir)
+            if not repo.remotes:
+                return f"[Error] Update failed: no git remote configured for '{plugin_id}'."
            origin = repo.remotes.origin
            if progress is not None: progress(0.2, desc=f"Fetching updates for '{plugin_id}'...")
            origin.fetch()
-
+
            local_commit = repo.head.commit
-            remote_commit = origin.refs[repo.active_branch.name].commit
+            # A detached HEAD or a missing remote tracking ref raises here;
+            # report it cleanly instead of leaking a traceback.
+            try:
+                branch_name = repo.active_branch.name
+                remote_ref = origin.refs[branch_name]
+                remote_commit = remote_ref.commit
+            except Exception:
+                return f"[Error] Update failed: could not resolve remote branch for '{plugin_id}'."
            if local_commit == remote_commit:
                return f"[Info] Plugin '{plugin_id}' is already up to date."
@@ -239,7 +945,13 @@ def update_plugin(self, plugin_id: str, progress=None):
            return f"[Success] Plugin '{plugin_id}' updated. Please restart WanGP for changes to take effect."
        except git.exc.GitCommandError as e:
            traceback.print_exc()
-            return f"[Error] Git update failed for '{plugin_id}': {e.stderr}"
+            # Map common git failure modes to friendlier messages.
+            # NOTE(review): the "repository" token matches many unrelated git
+            # errors — consider tightening these substring checks.
+            stderr = (e.stderr or str(e)).strip()
+            lowered = stderr.lower()
+            if any(token in lowered for token in ("not found", "repository", "could not read from remote")):
+                return f"[Error] Update failed: remote repository not found or unreachable for '{plugin_id}'."
+            if any(token in lowered for token in ("authentication", "access denied", "permission denied")):
+                return f"[Error] Update failed: access denied to remote for '{plugin_id}'."
+            return f"[Error] Git update failed for '{plugin_id}': {stderr}"
        except Exception as e:
            traceback.print_exc()
            return f"[Error] An unexpected error occurred during update of '{plugin_id}': {str(e)}"
@@ -296,18 +1008,26 @@ def reinstall_plugin(self, plugin_id: str, progress=None):
        return f"[CRITICAL ERROR] Reinstallation failed AND could not restore backup. Plugin '{plugin_id}' is now in a broken state. Please manually rename '{backup_dir}' back to '{target_dir}'. Original error: {install_msg}. Restore error: {restore_e}"
    def install_plugin_from_url(self, git_url: str, progress=None):
-        if not git_url or not git_url.startswith("https://github.com/"):
-            return "[Error] Invalid GitHub URL."
+        cleaned_url = normalize_plugin_url(git_url)
+        if not cleaned_url or not cleaned_url.startswith("https://github.com/"):
+            return "[Error] Invalid URL."
        try:
-            repo_name = git_url.split('/')[-1].replace('.git', '')
+            repo_name = plugin_id_from_url(cleaned_url)
+            if not repo_name:
+                return "[Error] Invalid URL."
            target_dir = os.path.join(self.plugins_dir, repo_name)
            if os.path.exists(target_dir):
                return f"[Warning] Plugin '{repo_name}' already exists. Please remove it manually to reinstall."
            if progress is not None: progress(0.1, desc=f"Cloning '{repo_name}'...")
-            git.Repo.clone_from(git_url, target_dir)
+            git.Repo.clone_from(cleaned_url, target_dir)
+
+            # A valid plugin repository must ship a top-level plugin.py.
+            plugin_entry = os.path.join(target_dir, "plugin.py")
+            if not os.path.isfile(plugin_entry):
+                shutil.rmtree(target_dir, onerror=self._remove_readonly)
+                return "[Error] Invalid Plugin."
            requirements_path = os.path.join(target_dir, 'requirements.txt')
            if os.path.exists(requirements_path):
@@ -331,17 +1051,41 @@ def install_plugin_from_url(self, git_url: str, progress=None):
            if not os.path.exists(init_path):
                with open(init_path, 'w') as f:
                    pass
+
+            # Third-party repos may not ship a protective "uninstallable" flag.
+            self._strip_uninstallable_flag(target_dir)
            if progress is not None: progress(1.0, desc="Installation complete.")
+            self._clear_pending_deletion(repo_name)
            return f"[Success] Plugin '{repo_name}' installed. Please enable it in the list and restart WanGP."
        except git.exc.GitCommandError as e:
            traceback.print_exc()
-            return f"[Error] Git clone failed: {e.stderr}"
+            # NOTE(review): "fatal" appears in most git errors, so this maps
+            # nearly every clone failure (including network outages) to
+            # "Invalid URL" — consider tightening.
+            stderr = (e.stderr or str(e)).strip()
+            lowered = stderr.lower()
+            if any(token in lowered for token in ("not found", "repository", "fatal", "could not read from remote")):
+                return "[Error] Invalid URL."
+            return f"[Error] Git clone failed: {stderr}"
        except Exception as e:
            traceback.print_exc()
            return f"[Error] An unexpected error occurred: {str(e)}"
+ def _strip_uninstallable_flag(self, plugin_dir: str) -> None:
+ if not plugin_dir or not os.path.isdir(plugin_dir):
+ return
+ metadata_path = os.path.join(plugin_dir, PLUGIN_METADATA_FILENAME)
+ payload = self._load_json_file(metadata_path)
+ if not isinstance(payload, dict):
+ return
+ if "uninstallable" not in payload:
+ return
+ if not self._coerce_bool(payload.get("uninstallable"), default=True):
+ payload.pop("uninstallable", None)
+ try:
+ with open(metadata_path, "w", encoding="utf-8") as writer:
+ json.dump(payload, writer, indent=2, ensure_ascii=True)
+ except Exception as e:
+ print(f"[PluginManager] Failed to update {metadata_path}: {e}")
+
    def discover_plugins(self) -> List[str]:
        discovered = []
        for item in os.listdir(self.plugins_dir):
@@ -350,19 +1094,29 @@ def discover_plugins(self) -> List[str]:
            discovered.append(item)
        return sorted(discovered)
-    def load_plugins_from_directory(self, enabled_user_plugins: List[str]) -> None:
+    def load_plugins_from_directory(self, enabled_user_plugins: List[str], safe_mode: bool = False) -> None:
        self.custom_js_snippets = []
-        plugins_to_load = SYSTEM_PLUGINS + [p for p in enabled_user_plugins if p not in SYSTEM_PLUGINS]
+        # Safe mode: load system plugins only, ignoring user enablement.
+        if safe_mode:
+            print("[Safe Mode] User plugins are disabled. Only system plugins will be loaded.")
+            plugins_to_load = SYSTEM_PLUGINS
+        else:
+            plugins_to_load = SYSTEM_PLUGINS + [p for p in enabled_user_plugins if p not in SYSTEM_PLUGINS]
        for plugin_dir_name in self.discover_plugins():
            if plugin_dir_name not in plugins_to_load:
                continue
            try:
                module = importlib.import_module(f"{plugin_dir_name}.plugin")
+                plugin_path = os.path.join(self.plugins_dir, plugin_dir_name)
+                metadata = self._load_plugin_metadata(plugin_path)
+                is_bundled = plugin_dir_name in BUNDLED_PLUGINS
                for name, obj in inspect.getmembers(module, inspect.isclass):
                    if issubclass(obj, WAN2GPPlugin) and obj != WAN2GPPlugin:
                        plugin = obj()
+                        self._apply_metadata_to_plugin(plugin, metadata, plugin_dir_name in SYSTEM_PLUGINS)
+                        # Bundled plugins are always protected from uninstall.
+                        if is_bundled:
+                            plugin.uninstallable = False
                        plugin.setup_ui()
                        self.plugins[plugin_dir_name] = plugin
                        if plugin.custom_js_snippets:
@@ -496,16 +1250,23 @@ def __init__(self):
    def initialize_plugins(self, wgp_globals: dict):
        if not hasattr(self, 'plugin_manager'):
            return
-
-        auto_install_and_enable_default_plugins(self.plugin_manager, wgp_globals)
+
+        safe_mode = wgp_globals.get("SAFE_MODE", False)
+
+        # Safe mode skips auto-installation of default plugins entirely.
+        if not safe_mode:
+            auto_install_and_enable_default_plugins(self.plugin_manager, wgp_globals)
        server_config = wgp_globals.get("server_config")
+        server_config_filename = wgp_globals.get("server_config_filename", "")
        if not server_config:
            print("[PluginManager] ERROR: server_config not found in globals.")
            return
+        self.plugin_manager.set_server_config(server_config, server_config_filename)
+        # NOTE(review): presumably retries removals recorded earlier by
+        # _add_pending_deletion (e.g. files locked at uninstall) — confirm.
+        self.plugin_manager.cleanup_pending_deletions()
        self.enabled_plugins = server_config.get("enabled_plugins", [])
-        self.plugin_manager.load_plugins_from_directory(self.enabled_plugins)
+
+        self.plugin_manager.load_plugins_from_directory(self.enabled_plugins, safe_mode=safe_mode)
        self.plugin_manager.inject_globals(wgp_globals)
    def setup_ui_tabs(self, main_tabs_component: gr.Tabs, state_component: gr.State, set_save_form_event):
diff --git a/Wan2GP/shared/utils/prompt_extend.py b/Wan2GP/shared/utils/prompt_extend.py
index 50d8a056f..e7a21b536 100644
--- a/Wan2GP/shared/utils/prompt_extend.py
+++ b/Wan2GP/shared/utils/prompt_extend.py
@@ -345,32 +345,22 @@ def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            use_fast=True)
-        # Force bfloat16 to reduce RAM usage (~14GB vs ~28GB for float32)
-        # Also use low_cpu_mem_usage to reduce peak memory during loading
-        model_dtype = torch.bfloat16 if FLASH_VER == 2 else torch.float16
-        if "AWQ" in self.model_name:
-            model_dtype = torch.float16
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16 if FLASH_VER == 2 else
            torch.float16 if "AWQ" in self.model_name else "auto",
+            # NOTE(review): with FLASH_VER == 2 an AWQ model now loads in
+            # bfloat16, whereas the removed code forced float16 for AWQ —
+            # confirm this precedence change is intended.
            attn_implementation="flash_attention_2"
            if FLASH_VER == 2 else None,
-            device_map="cpu",
-            low_cpu_mem_usage=True)
+            device_map="cpu")
+        # NOTE(review): low_cpu_mem_usage=True was dropped; peak RAM during
+        # model loading will rise accordingly — confirm intended.
        else:
            from transformers import AutoModelForCausalLM, AutoTokenizer
-            # Force bfloat16/float16 to reduce RAM usage
-            # Also use low_cpu_mem_usage to reduce peak memory during loading
-            model_dtype = torch.bfloat16 if FLASH_VER == 2 else torch.float16
-            if "AWQ" in self.model_name:
-                model_dtype = torch.float16
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
-            torch_dtype=model_dtype,
+            torch_dtype=torch.float16
+            if "AWQ" in self.model_name else "auto",
            attn_implementation="flash_attention_2"
            if FLASH_VER == 2 else None,
-            device_map="cpu",
-            low_cpu_mem_usage=True)
+            device_map="cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
diff --git a/Wan2GP/shared/utils/prompt_parser.py b/Wan2GP/shared/utils/prompt_parser.py
index 46edec405..116a061c8 100644
--- a/Wan2GP/shared/utils/prompt_parser.py
+++ b/Wan2GP/shared/utils/prompt_parser.py
@@ -1,6 +1,6 @@
import re
-def process_template(input_text, keep_comments = False):
+def process_template(input_text, keep_comments=False, keep_empty_lines=False):
    """
    Process a text template with macro instructions and variable substitution.
    Supports multiple values for variables to generate multiple output versions.
@@ -14,7 +14,8 @@ def process_template(input_text, keep_comments = False):
    - output_text: Processed output with variables substituted, or empty string if error
    - error_message: Error description and problematic line, or empty string if no error
    """
-    lines = input_text.strip().split('\n')
+    # Normalize CRLF/CR so empty-line handling is uniform across platforms.
+    normalized_input = str(input_text or "").replace("\r\n", "\n").replace("\r", "\n")
+    lines = normalized_input.split("\n") if keep_empty_lines else normalized_input.strip().split("\n")
    current_variables = {}
    current_template_lines = []
    all_output_lines = []
@@ -29,6 +30,8 @@ def process_template(input_text, keep_comments = False):
        # Skip empty lines or comments
        if not line:
+            # Preserve blank lines verbatim when requested.
+            if keep_empty_lines:
+                current_template_lines.append("")
            continue
        if line.startswith('#') and not keep_comments:
@@ -292,4 +295,4 @@ def generate_macro_line(variables_dict):
    sections.append(section)
    # Join sections with a colon and space for readability
-    return "! " + " : ".join(sections)
\ No newline at end of file
+    return "! " + " : ".join(sections)
diff --git a/Wan2GP/shared/utils/self_refiner.py b/Wan2GP/shared/utils/self_refiner.py
new file mode 100644
index 000000000..730e00f6b
--- /dev/null
+++ b/Wan2GP/shared/utils/self_refiner.py
@@ -0,0 +1,418 @@
+import torch
+import copy
+import uuid
+from diffusers.utils.torch_utils import randn_tensor
+
def is_int_string(s: str) -> bool:
    """Return True when *s* parses as a base-10 integer via int()."""
    try:
        int(s)
    except ValueError:
        return False
    return True
+
def normalize_self_refiner_plan(plan_input, max_plans: int = 1):
    """Validate a user-supplied self-refiner plan.

    Parameters
    ----------
    plan_input:
        A list describing one refinement plan (list of {start, end, steps}
        dicts), or a falsy/non-list value to request the default plan.
    max_plans:
        Maximum number of plans accepted.

    Returns
    -------
    (plans, error):
        ``plans`` holds at most ``max_plans`` plan(s) (the default plan when
        none was supplied); ``error`` is "" on success or a message when too
        many plans were given (in which case ``plans`` is empty).
    """
    default_plan = [
        {"start": 1, "end": 5, "steps": 3},
        {"start": 6, "end": 13, "steps": 1},
    ]
    # Check type/emptiness BEFORE len(): the previous order raised TypeError
    # for None (or any non-sized input) instead of falling back to defaults.
    if not plan_input or not isinstance(plan_input, list):
        return [default_plan], ""
    if len(plan_input) > max_plans:
        return [], f"Self-refiner supports up to {max_plans} plan(s); found {len(plan_input)}."

    return [plan_input], ""
+
def ensure_refiner_list(plan_data):
    """Return plan_data as a list of rules, each guaranteed to carry an "id".

    Non-list input yields a fresh empty list.  Otherwise the list is mutated
    in place (missing ids are filled with new UUID strings) and returned.
    """
    if not isinstance(plan_data, list):
        return []
    for entry in plan_data:
        if "id" not in entry:
            entry["id"] = str(uuid.uuid4())
    return plan_data
+
def add_refiner_rule(current_rules, range_val, steps_val):
    """Append a refinement rule covering steps range_val=(start, end).

    Rejects (returning current_rules unchanged, with a UI notice) an empty
    range or any overlap with an existing rule; otherwise returns a new list
    sorted by rule start.
    """
    new_start, new_end = int(range_val[0]), int(range_val[1])

    if new_start >= new_end:
        from gradio import Info
        Info(f"Start step ({new_start}) must be smaller than End step ({new_end}).")
        return current_rules

    clash = next(
        (rule for rule in current_rules if new_start <= rule['end'] and new_end >= rule['start']),
        None,
    )
    if clash is not None:
        from gradio import Info
        Info(f"Overlap detected! Steps {new_start}-{new_end} conflict with existing rule {clash['start']}-{clash['end']}.")
        return current_rules

    entry = {
        "id": str(uuid.uuid4()),
        "start": new_start,
        "end": new_end,
        "steps": int(steps_val),
    }
    return sorted(current_rules + [entry], key=lambda rule: rule['start'])
+
def remove_refiner_rule(current_rules, rule_id):
    """Return current_rules without the rule whose "id" equals rule_id."""
    return list(filter(lambda rule: rule["id"] != rule_id, current_rules))
+
class PnPHandler:
    """Plug-and-play (PnP) self-refinement handler for flow-matching samplers.

    Tracks per-element uncertainty between successive x0 predictions and, once
    an element's prediction has converged, freezes it by blending in the
    previously buffered value.  ``stochastic_plan`` maps denoising step
    indices to the number of stochastic refinement (anneal) iterations.
    """
    def __init__(self, stochastic_plan, ths_uncertainty=0.0, p_norm=1, certain_percentage=0.999, channel_dim: int = 1):
        self.stochastic_step_map = self._build_stochastic_step_map(stochastic_plan)
        self.ths_uncertainty = ths_uncertainty
        self.p_norm = p_norm
        self.certain_percentage = certain_percentage
        self.channel_dim = channel_dim
        self.buffer = [None] # entries: [certain_mask, pred_original_sample, latents_next]
        self.certain_flag = False

    def _build_stochastic_step_map(self, plan):
        """Expand plan entries ({start, end, steps} dicts or (start, end, steps)
        tuples) into a {step_index: anneal_steps} lookup (inclusive ranges)."""
        step_map = {}
        if not plan:
            return step_map

        for entry in plan:
            if isinstance(entry, dict):
                start = entry.get("start", entry.get("begin"))
                end = entry.get("end", entry.get("stop"))
                steps = entry.get("steps", entry.get("anneal", entry.get("num_anneal_steps", 1)))
            else:
                start, end, steps = entry

            start_i = int(start)
            end_i = int(end)
            steps_i = int(steps)

            if steps_i > 0:
                for idx in range(start_i, end_i + 1):
                    step_map[idx] = steps_i
        return step_map

    def get_anneal_steps(self, step_index):
        """Number of refinement iterations for this step (0 = no refinement)."""
        return self.stochastic_step_map.get(step_index, 0)

    def reset_buffer(self):
        """Discard buffered predictions and the convergence flag."""
        self.buffer = [None]
        self.certain_flag = False

    def process_step(self, latents, noise_pred, sigma, sigma_next, generator=None, device=None, latents_next=None, pred_original_sample=None):
        """
        Blend the new step outputs with the buffered ones on converged
        elements, append the result to self.buffer, and return the (possibly
        blended) next latents.
        """
        # Flow-matching update with sigma acting as the time variable
        # (timesteps run 1000 -> 0, so sigma = t / 1000 — see step()):
        #   pred_original_sample (x0) = latents - sigma * noise_pred
        #   latents_next = latents + (sigma_next - sigma) * noise_pred
        # Callers may pass precomputed values instead (e.g. from a scheduler).

        if pred_original_sample is None:
            pred_original_sample = latents - sigma * noise_pred

        if latents_next is None:
            latents_next = latents + (sigma_next - sigma) * noise_pred

        if self.buffer[-1] is not None:
            # Uncertainty = per-position change of x0 versus the previous
            # refinement iteration (buffer[-1][1]), normalized over channels.
            diff = pred_original_sample - self.buffer[-1][1]
            channel_dim = self.channel_dim
            if channel_dim < 0:
                channel_dim += latents.ndim
            # dim=channel_dim is channels/features
            uncertainty = torch.norm(diff, p=self.p_norm, dim=channel_dim) / latents.shape[channel_dim]

            certain_mask = uncertainty < self.ths_uncertainty
            # Certainty is sticky: once converged, an element stays converged.
            if self.buffer[-1][0] is not None:
                certain_mask = certain_mask | self.buffer[-1][0]

            if certain_mask.sum() / certain_mask.numel() > self.certain_percentage:
                self.certain_flag = True

            certain_mask_float = certain_mask.to(latents.dtype).unsqueeze(channel_dim) # Broadcast channels

            # Blend: converged positions keep the buffered values.
            latents_next = certain_mask_float * self.buffer[-1][2] + (1.0 - certain_mask_float) * latents_next
            pred_original_sample = certain_mask_float * self.buffer[-1][1] + (1.0 - certain_mask_float) * pred_original_sample

            certain_mask_stored = certain_mask # keep bool
        else:
            certain_mask_stored = None
        self.buffer.append([certain_mask_stored, pred_original_sample, latents_next])
        return latents_next

    def perturb_latents(self, latents, buffer_latent, sigma, generator=None, device=None, noise_mask=None):
        """Re-noise the buffered x0 back to noise level *sigma*; noise_mask
        optionally restricts re-noising to selected positions."""
        noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)

        if noise_mask is None:
            return (1.0 - sigma) * buffer_latent + sigma * noise

        sigma_t = (noise_mask.to(latents.dtype) * sigma)
        return (1.0 - sigma_t) * buffer_latent + sigma_t * noise

    def run_refinement_loop(self,
                            latents,
                            noise_pred,
                            current_sigma,
                            next_sigma,
                            m_steps,
                            denoise_func,
                            step_func,
                            clone_func=None,
                            restore_func=None,
                            generator=None,
                            device=None,
                            noise_mask=None):
        """Run up to m_steps perturb->denoise->step->blend iterations for one
        latent stream; returns the refined next latents or None on failure.
        clone_func/restore_func snapshot and restore scheduler state so each
        iteration re-runs the same scheduler step."""

        if noise_pred is None:
            return None
        # Save initial state if needed
        scheduler_state = None
        if clone_func:
            scheduler_state = clone_func()

        # Step 0 (Initial)
        latents_next_0, pred_original_sample_0 = step_func(noise_pred, latents)
        if latents_next_0 is None or pred_original_sample_0 is None:
            return None

        latents_next = self.process_step(
            latents, noise_pred, current_sigma, next_sigma,
            latents_next=latents_next_0, pred_original_sample=pred_original_sample_0
        )

        if self.certain_flag:
            return latents_next

        # Refinement Loop
        for ii in range(1, m_steps):
            if restore_func and scheduler_state is not None:
                restore_func(scheduler_state)

            # Perturb the latest buffered x0 back to the current noise level.
            latents_perturbed = self.perturb_latents(
                latents,
                self.buffer[-1][1],
                current_sigma,
                generator=generator,
                device=device,
                noise_mask=noise_mask,
            )

            # Denoise
            n_pred = denoise_func(latents_perturbed)
            if n_pred is None:
                return None

            # Step
            latents_next_loop, pred_original_sample_loop = step_func(n_pred, latents_perturbed)
            if latents_next_loop is None or pred_original_sample_loop is None:
                return None

            # Refine
            latents_next = self.process_step(
                latents_perturbed, n_pred, current_sigma, next_sigma,
                latents_next=latents_next_loop, pred_original_sample=pred_original_sample_loop
            )

            if self.certain_flag:
                break

        return latents_next

    def step(self, step_index, latents, noise_pred, t, timesteps, target_shape, seed_g, sample_scheduler, scheduler_kwargs, denoise_func):
        """Perform one (possibly refined) scheduler step; returns
        (latents, sample_scheduler).  Falls back to a plain scheduler step when
        no refinement is planned for this step_index."""
        if noise_pred is None:
            return None, sample_scheduler
        # Reset per denoising step to avoid blending with stale buffers from prior timesteps.
        self.reset_buffer()
        # Calculate sigma for PnP
        current_sigma = t.item() / 1000.0
        next_sigma = (0. if step_index == len(timesteps)-1 else timesteps[step_index+1].item()) / 1000.0

        m_steps = self.get_anneal_steps(step_index)

        if m_steps > 1 and not self.certain_flag:

            def _get_prev_sample(step_out):
                # Scheduler output may be a dataclass, tuple, or bare tensor.
                if hasattr(step_out, "prev_sample"):
                    return step_out.prev_sample
                if isinstance(step_out, (tuple, list)):
                    return step_out[0]
                return step_out

            def _get_pred_original_sample(step_out, latents_in, n_pred_sliced):
                # Prefer the scheduler's x0; otherwise recompute it here.
                if hasattr(step_out, "pred_original_sample"):
                    return step_out.pred_original_sample
                t_val = t.item() if torch.is_tensor(t) else float(t)
                return latents_in - (t_val / 1000.0) * n_pred_sliced

            def step_func(n_pred_in, latents_in):
                # Correct slicing:
                # [:, :channels] slices Dimension 1
                # [:, :, :frames] slices Dimension 2
                n_pred_sliced = n_pred_in[:, :latents_in.shape[1], :target_shape[1]]

                nonlocal sample_scheduler
                step_out = sample_scheduler.step(n_pred_sliced, t, latents_in, **scheduler_kwargs)
                latents_next_out = _get_prev_sample(step_out)
                pred_original_sample_out = _get_pred_original_sample(step_out, latents_in, n_pred_sliced)
                return latents_next_out, pred_original_sample_out

            def clone_func():
                # Snapshot stateful schedulers so each refinement iteration
                # replays the same scheduler step.
                if sample_scheduler is None:
                    return None
                if getattr(sample_scheduler, "is_stateful", True):
                    return copy.deepcopy(sample_scheduler)
                return None

            def restore_func(saved_state):
                nonlocal sample_scheduler
                if saved_state:
                    sample_scheduler = copy.deepcopy(saved_state)

            latents = self.run_refinement_loop(
                latents=latents,
                noise_pred=noise_pred,
                current_sigma=current_sigma,
                next_sigma=next_sigma,
                m_steps=m_steps,
                denoise_func=denoise_func,
                step_func=step_func,
                clone_func=clone_func,
                restore_func=restore_func,
                generator=seed_g,
                device=latents.device
            )
            if latents is None:
                return None, sample_scheduler
        else:
            # Standard logic (no refinement for this step).
            # Correct slicing: [:, :channels, :frames]
            n_pred_sliced = noise_pred[:, :latents.shape[1], :target_shape[1]]
            step_out = sample_scheduler.step( n_pred_sliced, t, latents, **scheduler_kwargs)
            if hasattr(step_out, "prev_sample"):
                latents = step_out.prev_sample
            elif isinstance(step_out, (tuple, list)):
                latents = step_out[0]
            else:
                latents = step_out

        return latents, sample_scheduler
+
def create_self_refiner_handler(pnp_plan, pnp_f_uncertainty, pnp_p_norm, pnp_certain_percentage, channel_dim: int = 1):
    """Build a PnPHandler from raw UI settings, normalizing the plan first."""
    normalized_plans, _ = normalize_self_refiner_plan(pnp_plan, max_plans=1)

    return PnPHandler(
        normalized_plans[0],
        ths_uncertainty=pnp_f_uncertainty,
        p_norm=pnp_p_norm,
        certain_percentage=pnp_certain_percentage,
        channel_dim=channel_dim,
    )
+
+
def _preds_valid(noise_pred_list):
    """True when the denoiser produced a list/tuple with no None entries."""
    return isinstance(noise_pred_list, (list, tuple)) and not any(pred is None for pred in noise_pred_list)


def _step_outputs_valid(handlers, latents_next_list, pred_original_list):
    """True when step_func produced one non-None latent and x0 per handler."""
    if latents_next_list is None or pred_original_list is None:
        return False
    if not isinstance(latents_next_list, (list, tuple)) or not isinstance(pred_original_list, (list, tuple)):
        return False
    if len(latents_next_list) != len(handlers) or len(pred_original_list) != len(handlers):
        return False
    if any(item is None for item in latents_next_list) or any(item is None for item in pred_original_list):
        return False
    return True


def run_refinement_loop_multi(
    handlers,
    latents_list,
    noise_pred_list,
    current_sigma,
    next_sigma,
    m_steps,
    denoise_func,
    step_func,
    generators=None,
    devices=None,
    noise_masks=None,
    stop_when: str = "all",
):
    """Run the PnP self-refinement loop over several latent streams in lockstep.

    Each iteration perturbs the buffered x0 of every handler back to
    current_sigma, denoises the whole batch via denoise_func, steps it via
    step_func, and blends through each handler's process_step.  Stops early
    when "all" (default) or "any" handler reports convergence.

    Returns the refined latents list, latents_list unchanged when m_steps <= 1,
    or None when the denoiser or step function yields malformed output.
    (Validation was previously duplicated inline at both call sites; it now
    lives in the _preds_valid/_step_outputs_valid helpers.)
    """
    if m_steps <= 1:
        return latents_list
    if noise_pred_list is None or not _preds_valid(noise_pred_list):
        return None

    def _should_stop():
        flags = (handler.certain_flag for handler in handlers)
        return any(flags) if stop_when == "any" else all(flags)

    def _blend(inputs_list, next_list, x0_list):
        # Feed each handler's uncertainty buffer and collect blended latents.
        return [
            handler.process_step(
                latents,
                None,
                current_sigma,
                next_sigma,
                latents_next=latents_next,
                pred_original_sample=pred_original,
            )
            for handler, latents, latents_next, pred_original in zip(handlers, inputs_list, next_list, x0_list)
        ]

    latents_next_list, pred_original_list = step_func(noise_pred_list, latents_list)
    if not _step_outputs_valid(handlers, latents_next_list, pred_original_list):
        return None
    refined_latents_list = _blend(latents_list, latents_next_list, pred_original_list)
    if _should_stop():
        return refined_latents_list

    for _ in range(1, m_steps):
        perturbed_list = []
        for idx, (handler, latents) in enumerate(zip(handlers, latents_list)):
            generator = generators[idx] if generators is not None else None
            device = devices[idx] if devices is not None else latents.device
            noise_mask = noise_masks[idx] if noise_masks is not None else None
            perturbed_list.append(
                handler.perturb_latents(
                    latents,
                    handler.buffer[-1][1],
                    current_sigma,
                    generator=generator,
                    device=device,
                    noise_mask=noise_mask,
                )
            )

        noise_pred_list = denoise_func(perturbed_list)
        if noise_pred_list is None or not _preds_valid(noise_pred_list):
            return None
        latents_next_list, pred_original_list = step_func(noise_pred_list, perturbed_list)
        if not _step_outputs_valid(handlers, latents_next_list, pred_original_list):
            return None
        refined_latents_list = _blend(perturbed_list, latents_next_list, pred_original_list)
        if _should_stop():
            break

    return refined_latents_list
diff --git a/Wan2GP/shared/utils/text_encoder_cache.py b/Wan2GP/shared/utils/text_encoder_cache.py
new file mode 100644
index 000000000..9bbb60db9
--- /dev/null
+++ b/Wan2GP/shared/utils/text_encoder_cache.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Callable, Iterable, Hashable
+
+import torch
+
+
@dataclass
class _CacheEntry:
    # CPU-resident cached embedding plus its estimated footprint in bytes.
    value: Any
    size_bytes: int


class TextEncoderCache:
    """LRU cache for text-encoder embeddings, bounded by approximate byte size.

    Cached values are detached and kept on CPU; hits are returned moved to the
    requested device.  Once the estimated total exceeds the budget, entries
    are evicted oldest-first.
    """

    def __init__(self, max_size_mb: float = 100) -> None:
        self.max_size_bytes = int(max_size_mb * 1024 * 1024)
        self._entries: "OrderedDict[Hashable, _CacheEntry]" = OrderedDict()
        self._size_bytes = 0

    def encode(
        self,
        encode_fn: Callable[[list[str]], list[Any]],
        prompts: Iterable[str] | str,
        device: torch.device | str | None = None,
        parallel: bool = False,
        cache_keys: Iterable[Hashable] | Hashable | None = None,
    ) -> list[Any]:
        """Return one embedding per prompt, consulting the cache first.

        encode_fn receives a list of cache-miss prompts and must return one
        embedding per prompt.  With parallel=True all misses are encoded in a
        single batch; otherwise one call per miss.  cache_keys optionally
        overrides the prompts as cache keys and must match them in number.
        Raises ValueError on malformed encode_fn output or key mismatch.
        """
        prompt_list = [prompts] if isinstance(prompts, str) else list(prompts)
        if not prompt_list:
            return []
        key_list = self._resolve_keys(prompt_list, cache_keys)

        if not parallel:
            output: list[Any] = []
            for text, key in zip(prompt_list, key_list):
                hit = self._entries.get(key)
                if hit is not None:
                    self._entries.move_to_end(key)
                    output.append(self._to_device(hit.value, device))
                    continue
                produced = encode_fn([text])
                if isinstance(produced, (list, tuple)):
                    if not produced:
                        raise ValueError("encode_fn returned empty embeddings.")
                    produced = produced[0]
                output.append(self._store(key, produced, device))
            return output

        # Parallel path: serve hits immediately, batch-encode the misses.
        output = [None] * len(prompt_list)
        pending: list[tuple] = []
        for position, (text, key) in enumerate(zip(prompt_list, key_list)):
            hit = self._entries.get(key)
            if hit is None:
                pending.append((position, text, key))
            else:
                self._entries.move_to_end(key)
                output[position] = self._to_device(hit.value, device)

        if pending:
            batch = encode_fn([text for _, text, _ in pending])
            if not isinstance(batch, list):
                batch = list(batch)
            if len(batch) != len(pending):
                raise ValueError("encode_fn returned unexpected number of embeddings.")
            for (position, _, key), produced in zip(pending, batch):
                output[position] = self._store(key, produced, device)

        return output

    def _resolve_keys(self, prompt_list, cache_keys):
        # Map each prompt to its cache key; defaults to the prompt text.
        if cache_keys is None:
            return prompt_list
        if len(prompt_list) == 1 and not isinstance(cache_keys, list):
            return [cache_keys]
        key_list = list(cache_keys)
        if len(key_list) != len(prompt_list):
            raise ValueError("cache_keys must match the number of prompts.")
        return key_list

    def _store(self, cache_key: Hashable, encoded: Any, device: torch.device | str | None) -> Any:
        """Cache *encoded* (detached, on CPU) when it fits the budget, then
        return the original value moved to *device*."""
        cpu_copy = self._detach_to_cpu(encoded)
        footprint = self._estimate_size_bytes(cpu_copy)
        if footprint > self.max_size_bytes:
            # Oversized values are never cached; just refresh recency of any
            # stale entry under the same key.
            if cache_key in self._entries:
                self._entries.move_to_end(cache_key)
            return self._to_device(encoded, device)
        previous = self._entries.pop(cache_key, None)
        if previous is not None:
            self._size_bytes -= previous.size_bytes
        self._entries[cache_key] = _CacheEntry(cpu_copy, footprint)
        self._size_bytes += footprint
        self._purge_if_needed()
        return self._to_device(encoded, device)

    def _purge_if_needed(self) -> None:
        # Evict least-recently-used entries until the budget is respected.
        while self._entries and self._size_bytes > self.max_size_bytes:
            _, evicted = self._entries.popitem(last=False)
            self._size_bytes -= evicted.size_bytes

    def _estimate_size_bytes(self, value: Any) -> int:
        """Rough byte size: exact for tensors, recursive for containers,
        zero for anything else."""
        if torch.is_tensor(value):
            return int(value.numel() * value.element_size())
        if isinstance(value, dict):
            return sum(self._estimate_size_bytes(item) for item in value.values())
        if isinstance(value, (list, tuple)):
            return sum(self._estimate_size_bytes(item) for item in value)
        return 0

    def _detach_to_cpu(self, value: Any) -> Any:
        # Detach tensors from autograd and migrate them to CPU, recursing
        # through dicts/lists/tuples (namedtuples preserved).
        if torch.is_tensor(value):
            detached = value.detach()
            return detached if detached.device.type == "cpu" else detached.to("cpu")
        if isinstance(value, dict):
            return {key: self._detach_to_cpu(item) for key, item in value.items()}
        if isinstance(value, tuple):
            converted = [self._detach_to_cpu(item) for item in value]
            return value.__class__(*converted) if hasattr(value, "_fields") else tuple(converted)
        if isinstance(value, list):
            return [self._detach_to_cpu(item) for item in value]
        return value

    def _to_device(self, value: Any, device: torch.device | str | None) -> Any:
        # Move tensors to *device* (None = leave untouched), recursing through
        # dicts/lists/tuples (namedtuples preserved).
        if device is None:
            return value
        if torch.is_tensor(value):
            return value.to(device)
        if isinstance(value, dict):
            return {key: self._to_device(item, device) for key, item in value.items()}
        if isinstance(value, tuple):
            converted = [self._to_device(item, device) for item in value]
            return value.__class__(*converted) if hasattr(value, "_fields") else tuple(converted)
        if isinstance(value, list):
            return [self._to_device(item, device) for item in value]
        return value
diff --git a/Wan2GP/shared/utils/transformers_fast_tokenizer_patch.py b/Wan2GP/shared/utils/transformers_fast_tokenizer_patch.py
new file mode 100644
index 000000000..c1f0b8a97
--- /dev/null
+++ b/Wan2GP/shared/utils/transformers_fast_tokenizer_patch.py
@@ -0,0 +1,318 @@
+import json
+import os
+import pickle
+import sys
+
+
+_PATCH_ALLOWED_PATHS = None
+_ORIG_FAST_INIT = None
+
+
+def _normalize_path(path):
+ if not path:
+ return None
+ try:
+ return os.path.normcase(os.path.abspath(path))
+ except Exception:
+ return None
+
+
+def _path_allowed(path):
+ if not _PATCH_ALLOWED_PATHS:
+ return False
+ norm = _normalize_path(path)
+ if norm is None:
+ return False
+ for allowed in _PATCH_ALLOWED_PATHS:
+ if allowed is None:
+ continue
+ try:
+ if os.path.commonpath([norm, allowed]) == allowed:
+ return True
+ except Exception:
+ if norm.startswith(allowed):
+ return True
+ return False
+
+
+def _load_cached_tokenizer(tokenizer_file, TokenizerFast):
+ if not tokenizer_file:
+ return None
+ return TokenizerFast.from_file(tokenizer_file)
+
+
+def patch_pretrained_tokenizer_fast(allow_paths=None):
+ global _PATCH_ALLOWED_PATHS
+ global _ORIG_FAST_INIT
+ if allow_paths is not None:
+ _PATCH_ALLOWED_PATHS = [_normalize_path(p) for p in allow_paths if p]
+
+ try:
+ import transformers.tokenization_utils_fast as tuf
+ except Exception:
+ return
+
+ cls = tuf.PreTrainedTokenizerFast
+ if getattr(cls, "_wan2gp_fast_init_patched", False):
+ return
+
+ if _ORIG_FAST_INIT is None:
+ _ORIG_FAST_INIT = cls.__init__
+
+ def _patched_init(self, *args, **kwargs):
+ fast_tokenizer_file = kwargs.get("tokenizer_file")
+ from_slow = kwargs.get("from_slow", False)
+ if not fast_tokenizer_file or from_slow or not _path_allowed(fast_tokenizer_file):
+ return _ORIG_FAST_INIT(self, *args, **kwargs)
+
+ try:
+ fast_tokenizer = _load_cached_tokenizer(fast_tokenizer_file, tuf.TokenizerFast)
+ if fast_tokenizer is None:
+ return _ORIG_FAST_INIT(self, *args, **kwargs)
+ kwargs["tokenizer_object"] = fast_tokenizer
+ except Exception:
+ return _ORIG_FAST_INIT(self, *args, **kwargs)
+
+ tokenizer_object = kwargs.pop("tokenizer_object", None)
+ slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
+ fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
+ from_slow = kwargs.pop("from_slow", False)
+ added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
+ self.add_prefix_space = kwargs.get("add_prefix_space", False)
+
+ if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
+ raise ValueError(
+ "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
+ "have sentencepiece installed."
+ )
+
+ if tokenizer_object is not None:
+ fast_tokenizer = tokenizer_object
+ else:
+ fast_tokenizer = tuf.TokenizerFast.from_file(fast_tokenizer_file)
+
+ self._tokenizer = fast_tokenizer
+
+ if slow_tokenizer is not None:
+ kwargs.update(slow_tokenizer.init_kwargs)
+
+ self._decode_use_source_tokenizer = False
+
+ _truncation = self._tokenizer.truncation
+
+ if _truncation is not None:
+ self._tokenizer.enable_truncation(**_truncation)
+ kwargs.setdefault("max_length", _truncation["max_length"])
+ kwargs.setdefault("truncation_side", _truncation["direction"])
+ kwargs.setdefault("stride", _truncation["stride"])
+ kwargs.setdefault("truncation_strategy", _truncation["strategy"])
+ else:
+ self._tokenizer.no_truncation()
+
+ _padding = self._tokenizer.padding
+ if _padding is not None:
+ self._tokenizer.enable_padding(**_padding)
+ kwargs.setdefault("pad_token", _padding["pad_token"])
+ kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
+ kwargs.setdefault("padding_side", _padding["direction"])
+ kwargs.setdefault("max_length", _padding["length"])
+ kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
+
+ tuf.PreTrainedTokenizerBase.__init__(self, **kwargs)
+ self._tokenizer.encode_special_tokens = self.split_special_tokens
+
+ added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
+ tokens_to_add = [
+ token
+ for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])
+ if hash(repr(token)) not in added_tokens_decoder_hash
+ ]
+ encoder_set = set(self.added_tokens_encoder.keys())
+ for token in tokens_to_add:
+ if isinstance(token, tuf.AddedToken):
+ encoder_set.add(token.content)
+ else:
+ encoder_set.add(str(token))
+ tokens_to_add_set = set(tokens_to_add)
+ tokens_to_add += [
+ token
+ for token in self.all_special_tokens_extended
+ if token not in encoder_set and token not in tokens_to_add_set
+ ]
+
+ if len(tokens_to_add) > 0:
+ special_tokens = set(self.all_special_tokens)
+ tokens = []
+ append = tokens.append
+ for token in tokens_to_add:
+ if isinstance(token, tuf.AddedToken):
+ content = token.content
+ if (not token.special) and (content in special_tokens):
+ token.special = True
+ append(token)
+ else:
+ append(tuf.AddedToken(token, special=(token in special_tokens)))
+ if tokens:
+ self.add_tokens(tokens)
+
+ try:
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space:
+ pre_tok_class = getattr(tuf.pre_tokenizers_fast, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = self.add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+ except Exception:
+ pass
+
+ cls.__init__ = _patched_init
+ cls._wan2gp_fast_init_patched = True
+
+
+def unpatch_pretrained_tokenizer_fast():
+ global _ORIG_FAST_INIT
+ if _ORIG_FAST_INIT is None:
+ return
+ try:
+ import transformers.tokenization_utils_fast as tuf
+ except Exception:
+ return
+ cls = tuf.PreTrainedTokenizerFast
+ if not getattr(cls, "_wan2gp_fast_init_patched", False):
+ return
+ cls.__init__ = _ORIG_FAST_INIT
+ cls._wan2gp_fast_init_patched = False
+
+
+def _get_transformers_version():
+ try:
+ import transformers as _transformers
+ return getattr(_transformers, "__version__", None)
+ except Exception:
+ return None
+
+
+def _get_tokenizers_version():
+ try:
+ import tokenizers as _tokenizers
+ return getattr(_tokenizers, "__version__", None)
+ except Exception:
+ return None
+
+
+def _collect_tokenizer_files(tokenizer_dir):
+ candidates = [
+ "tokenizer.json",
+ "tokenizer_config.json",
+ "special_tokens_map.json",
+ "added_tokens.json",
+ "vocab.json",
+ "merges.txt",
+ "config.json",
+ "sentencepiece.bpe.model",
+ "tokenizer.model",
+ ]
+ files = []
+ for name in candidates:
+ path = os.path.join(tokenizer_dir, name)
+ if os.path.isfile(path):
+ try:
+ stat = os.stat(path)
+ files.append({"path": name, "mtime": stat.st_mtime, "size": stat.st_size})
+ except OSError:
+ files.append({"path": name, "mtime": None, "size": None})
+ return files
+
+
+def _sanitize_cache_tag(tag):
+ if not tag:
+ return ""
+ safe = "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in str(tag))
+ return safe.strip("._-")
+
+
+def _cache_paths(tokenizer_dir, cache_tag=None):
+ suffix = _sanitize_cache_tag(cache_tag)
+ if suffix:
+ cache_file = os.path.join(tokenizer_dir, f"tokenizer.wgp.full.{suffix}.pkl")
+ meta_file = os.path.join(tokenizer_dir, f"tokenizer.wgp.full.{suffix}.meta.json")
+ else:
+ cache_file = os.path.join(tokenizer_dir, "tokenizer.wgp.full.pkl")
+ meta_file = os.path.join(tokenizer_dir, "tokenizer.wgp.full.meta.json")
+ return cache_file, meta_file
+
+
+def _read_cache_meta(meta_file):
+ try:
+ with open(meta_file, "r", encoding="utf-8") as handle:
+ return json.load(handle)
+ except Exception:
+ return None
+
+
+def _meta_matches(meta, tokenizer_dir):
+ if not meta:
+ return False
+ if tuple(meta.get("py_version", [])) != tuple(sys.version_info[:3]):
+ return False
+ if meta.get("transformers_version") != _get_transformers_version():
+ return False
+ if meta.get("tokenizers_version") != _get_tokenizers_version():
+ return False
+ expected_files = meta.get("files", [])
+ current_files = _collect_tokenizer_files(tokenizer_dir)
+ if len(expected_files) != len(current_files):
+ return False
+ current_map = {f.get("path"): f for f in current_files}
+ for entry in expected_files:
+ cur = current_map.get(entry.get("path"))
+ if cur is None:
+ return False
+ if entry.get("mtime") != cur.get("mtime") or entry.get("size") != cur.get("size"):
+ return False
+ return True
+
+
+def _load_full_tokenizer_cache(tokenizer_dir, cache_tag=None):
+ cache_file, meta_file = _cache_paths(tokenizer_dir, cache_tag=cache_tag)
+ if not os.path.isfile(cache_file) or not os.path.isfile(meta_file):
+ return None
+ meta = _read_cache_meta(meta_file)
+ if not _meta_matches(meta, tokenizer_dir):
+ return None
+ try:
+ with open(cache_file, "rb") as handle:
+ return pickle.load(handle)
+ except Exception:
+ return None
+
+
+def _save_full_tokenizer_cache(tokenizer_dir, tokenizer, cache_tag=None):
+ cache_file, meta_file = _cache_paths(tokenizer_dir, cache_tag=cache_tag)
+ meta = {
+ "py_version": list(sys.version_info[:3]),
+ "transformers_version": _get_transformers_version(),
+ "tokenizers_version": _get_tokenizers_version(),
+ "files": _collect_tokenizer_files(tokenizer_dir),
+ }
+ try:
+ with open(cache_file, "wb") as handle:
+ pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
+ with open(meta_file, "w", encoding="utf-8") as handle:
+ json.dump(meta, handle)
+ except Exception:
+ pass
+
+
+def load_cached_lm_tokenizer(tokenizer_dir, loader_fn, cache_tag=None):
+ if not tokenizer_dir:
+ return loader_fn()
+ cached = _load_full_tokenizer_cache(tokenizer_dir, cache_tag=cache_tag)
+ if cached is not None:
+ return cached
+ patch_pretrained_tokenizer_fast(allow_paths=[tokenizer_dir])
+ try:
+ tokenizer = loader_fn()
+ finally:
+ unpatch_pretrained_tokenizer_fast()
+ _save_full_tokenizer_cache(tokenizer_dir, tokenizer, cache_tag=cache_tag)
+ return tokenizer
diff --git a/Wan2GP/shared/utils/utils.py b/Wan2GP/shared/utils/utils.py
index e6658f955..c30b4cc81 100644
--- a/Wan2GP/shared/utils/utils.py
+++ b/Wan2GP/shared/utils/utils.py
@@ -99,6 +99,25 @@ def truncate_for_filesystem(s, max_bytes=None):
def get_default_workers():
return os.cpu_count()/ 2
+def to_rgb_tensor(value, device="cpu", dtype=torch.float):
+ if isinstance(value, torch.Tensor):
+ tensor = value.to(device=device, dtype=dtype)
+ else:
+ if isinstance(value, (list, tuple, np.ndarray)):
+ vals = value
+ else:
+ vals = [value, value, value]
+ tensor = torch.tensor(vals, device=device, dtype=dtype)
+ if tensor.numel() == 1:
+ tensor = tensor.repeat(3)
+ elif tensor.numel() != 3:
+ tensor = tensor.flatten()
+ if tensor.numel() < 3:
+ tensor = tensor.repeat(3)[:3]
+ else:
+ tensor = tensor[:3]
+ return tensor.view(3, 1, 1)
+
def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2, in_place = False) :
if not items:
return []
@@ -244,6 +263,8 @@ def convert_tensor_to_image(t, frame_no = 0, mask_levels = False):
t = t[:, frame_no]
if t.shape[0]== 1:
t = t.expand(3,-1,-1)
+ if t.dtype == torch.uint8:
+ return Image.fromarray(t.permute(1, 2, 0).cpu().numpy())
if mask_levels:
return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
else:
@@ -372,7 +393,8 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
return output_list, output_mask_list
def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False):
- inpaint_color = canvas_tf_bg / 127.5 - 1
+ inpaint_color = to_rgb_tensor(canvas_tf_bg, device=device, dtype=torch.float) / 127.5 - 1
+ inpaint_color = inpaint_color.unsqueeze(1)
ref_width, ref_height = ref_img.size
if (ref_height, ref_width) == image_size and outpainting_dims == None:
@@ -401,10 +423,10 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if outpainting_dims != None:
- canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
+ canvas = inpaint_color.expand(3, 1, final_height, final_width).clone()
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
else:
- canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
+ canvas = inpaint_color.expand(3, 1, canvas_height, canvas_width).clone()
canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = canvas
canvas = None
@@ -423,7 +445,8 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"):
src_videos, src_masks = [], []
- inpaint_color_compressed = guide_inpaint_color/127.5 - 1
+ inpaint_color_compressed = to_rgb_tensor(guide_inpaint_color, device=device, dtype=torch.float) / 127.5 - 1
+ inpaint_color_compressed = inpaint_color_compressed.unsqueeze(1)
prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0
for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
src_video, src_mask = cur_video_guide, cur_video_mask
@@ -434,11 +457,14 @@ def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, im
if any_guide_padding:
if src_video is None:
- src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device)
+ src_video = inpaint_color_compressed.expand(3, current_video_length, *image_size).clone()
elif src_video.shape[1] < current_video_length:
- src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+ pad = inpaint_color_compressed.to(src_video.device).expand(3, current_video_length - src_video.shape[1], *src_video.shape[-2:]).clone()
+ src_video = torch.cat([src_video, pad], dim=1)
elif src_video is not None:
new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
+ if new_num_frames < src_video.shape[1]:
+ print(f"invalid number of control frames {src_video.shape[1]}, potentially {src_video.shape[1]-new_num_frames} frames will be lost")
src_video = src_video[:, :new_num_frames]
if any_mask and src_video is not None:
@@ -453,7 +479,7 @@ def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, im
for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
- src_video[:, pos:pos+1] = inpaint_color_compressed
+ src_video[:, pos:pos+1] = inpaint_color_compressed.to(src_video.device)
if any_mask: src_mask[:, pos:pos+1] = 1
for k, frame in enumerate(inject_frames):
diff --git a/Wan2GP/wgp.py b/Wan2GP/wgp.py
index 9d7c47066..af789e153 100644
--- a/Wan2GP/wgp.py
+++ b/Wan2GP/wgp.py
@@ -7,6 +7,9 @@
p = os.path.dirname(os.path.abspath(__file__))
if p not in sys.path:
sys.path.insert(0, p)
+# Ensure plugin-side `import wgp` resolves to this live module instance.
+if sys.modules.get("wgp") is not sys.modules.get(__name__):
+ sys.modules["wgp"] = sys.modules[__name__]
import asyncio
if os.name == "nt":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@@ -19,8 +22,7 @@
import argparse
import warnings
warnings.filterwarnings('ignore', message='Failed to find.*', module='triton')
-from mmgp import offload, safetensors2, profile_type , fp8_quanto_bridge, quant_router
-if not os.name == "nt": fp8_quanto_bridge.enable_fp8_marlin_fallback()
+from mmgp import offload, safetensors2, profile_type , quant_router
try:
import triton
except ImportError:
@@ -34,11 +36,11 @@
import importlib
from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
-from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask
+from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask, to_rgb_tensor
from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
from shared.utils.utils import has_video_file_extension, has_image_file_extension, has_audio_file_extension
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
-from shared.utils.audio_video import save_image_metadata, read_image_metadata
+from shared.utils.audio_video import save_image_metadata, read_image_metadata, extract_audio_track_to_wav, write_wav_file, save_audio_file, get_audio_codec_extension
from shared.utils.audio_metadata import save_audio_metadata, read_audio_metadata
from shared.utils.video_metadata import save_video_metadata
from shared.match_archi import match_nvidia_architecture
@@ -49,6 +51,7 @@
from huggingface_hub import hf_hub_download, snapshot_download
from shared.utils import files_locator as fl
from shared.gradio.audio_gallery import AudioGallery
+from shared.utils.self_refiner import normalize_self_refiner_plan, ensure_refiner_list, add_refiner_rule, remove_refiner_rule
import torch
import gc
import traceback
@@ -66,6 +69,8 @@
import glob
import cv2
import html
+from gradio_rangeslider import RangeSlider
+import re
from transformers.utils import logging
logging.set_verbosity_error
from tqdm import tqdm
@@ -73,24 +78,33 @@
from shared.gradio.gallery import AdvancedMediaGallery
from shared.ffmpeg_setup import download_ffmpeg
from shared.utils.plugins import PluginManager, WAN2GPApplication, SYSTEM_PLUGINS
+from shared.llm_engines.nanovllm.vllm_support import resolve_lm_decoder_engine
from collections import defaultdict
# import torch._dynamo as dynamo
# dynamo.config.recompile_limit = 2000 # default is 256
# dynamo.config.accumulated_recompile_limit = 2000 # or whatever limit you want
+STARTUP_LOCK_FILE = "startup.lock"
global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip"
AUTOSAVE_PATH = AUTOSAVE_FILENAME
+AUTOSAVE_ERROR_FILENAME = "error_queue.zip"
AUTOSAVE_TEMPLATE_PATH = AUTOSAVE_FILENAME
CONFIG_FILENAME = "wgp_config.json"
PROMPT_VARS_MAX = 10
-target_mmgp_version = "3.6.10"
-WanGP_version = "10.01"
-settings_version = 2.42
+target_mmgp_version = "3.7.4"
+WanGP_version = "10.83"
+settings_version = 2.50
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
image_names_list = ["image_start", "image_end", "image_refs"]
+CUSTOM_SETTINGS_MAX = 6
+CUSTOM_SETTINGS_PER_ROW = 2
+CUSTOM_SETTING_TYPES = {"int", "float", "text"}
+lm_decoder_engine = ""
+lm_decoder_engine_obtained = "legacy"
+enable_int8_kernels = 0
# All media attachment keys for queue save/load
ATTACHMENT_KEYS = ["image_start", "image_end", "image_refs", "image_guide", "image_mask",
"video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2", "audio_source", "custom_guide"]
@@ -108,12 +122,31 @@
offloadobj = enhancer_offloadobj = wan_model = None
reload_needed = True
_HANDLER_MODULES = [
+ "shared.qtypes.scaled_fp8",
"shared.qtypes.nvfp4",
"shared.qtypes.nunchaku_int4",
"shared.qtypes.nunchaku_fp4",
+ "shared.qtypes.gguf",
]
+quant_router.unregister_handler(".fp8_quanto_bridge")
for handler in _HANDLER_MODULES:
quant_router.register_handler(handler)
+from shared.qtypes import gguf as gguf_handler
+quant_router.register_file_extension("gguf", gguf_handler)
+from shared.kernels.quanto_int8_inject import maybe_enable_quanto_int8_kernel, disable_quanto_int8_kernel
+
+
+def apply_int8_kernel_setting(enabled: int, notify_disabled = False) -> bool:
+ global enable_int8_kernels, verbose_level
+ try:
+ enable_int8_kernels = 1 if int(enabled) == 1 else 0
+ except Exception:
+ enable_int8_kernels = 0
+ os.environ["WAN2GP_QUANTO_INT8_KERNEL"] = "1" if enable_int8_kernels == 1 else "0"
+ if enable_int8_kernels == 1:
+ return bool(maybe_enable_quanto_int8_kernel(verbose_level=verbose_level))
+ disable_quanto_int8_kernel(notify_disabled)
+ return False
def set_wgp_global(variable_name: str, new_value: any) -> str:
if variable_name not in globals():
@@ -133,38 +166,17 @@ def clear_gen_cache():
if "_cache" in offload.shared_state:
del offload.shared_state["_cache"]
-def _flush_torch_memory():
- gc.collect()
- if torch.cuda.is_available():
- try:
- torch.cuda.synchronize()
- except torch.cuda.CudaError:
- pass
- for idx in range(torch.cuda.device_count()):
- with torch.cuda.device(idx):
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
- torch.cuda.reset_peak_memory_stats()
- try:
- torch._C._host_emptyCache()
- except AttributeError:
- pass
- if os.name == "nt":
- try:
- import ctypes, ctypes.wintypes as wintypes, os as _os
- PROCESS_SET_QUOTA = 0x0100
- PROCESS_QUERY_INFORMATION = 0x0400
- kernel32 = ctypes.windll.kernel32
- psapi = ctypes.windll.psapi
- handle = kernel32.OpenProcess(PROCESS_SET_QUOTA | PROCESS_QUERY_INFORMATION, False, _os.getpid())
- if handle:
- psapi.EmptyWorkingSet(handle)
- kernel32.CloseHandle(handle)
- except Exception:
- pass
+
def release_model():
global wan_model, offloadobj, reload_needed
+ if wan_model is not None:
+ close_fn = getattr(wan_model, "close", None)
+ if callable(close_fn):
+ try:
+ close_fn()
+ except Exception:
+ pass
wan_model = None
clear_gen_cache()
if "_cache" in offload.shared_state:
@@ -172,12 +184,7 @@ def release_model():
if offloadobj is not None:
offloadobj.release()
offloadobj = None
- _flush_torch_memory()
- from accelerate import init_empty_weights
- with init_empty_weights():
- for _ in range(3):
- dummy_tensor = torch.nn.Embedding(256384, 1024)
- dummy_tensor = None
+ offload.flush_torch_caches()
reload_needed = True
def get_unique_id():
global unique_id
@@ -387,7 +394,7 @@ def ret():
inputs["state"] = state
inputs["model_type"] = model_type
- inputs.pop("lset_name")
+ inputs.pop("lset_name", None)
if inputs == None:
gr.Warning("Internal state error: Could not retrieve inputs for the model.")
queue = gen.get("queue", [])
@@ -419,7 +426,7 @@ def ret():
if has_image_file_extension(edit_video_source) and len(temporal_upsampling) > 0:
gr.Info("Temporal Upsampling can not be used with an Image")
return ret()
- film_grain_intensity = inputs.get("film_grain_intensity",0)
+ film_grain_intensity = inputs.get("film_grain<<_intensity",0)
film_grain_saturation = inputs.get("film_grain_saturation",0.5)
# if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"]
if film_grain_intensity >0: prompt += ["Film Grain"]
@@ -452,9 +459,9 @@ def ret():
queue= gen.get("queue", [])
return update_queue_data(queue), gr.update(open=True) if new_prompts_count > 1 else gr.update()
- override_inputs, prompts, image_start, image_end = validate_settings(state, model_type, False, inputs)
+ inputs, prompts, image_start, image_end = validate_settings(state, model_type, False, inputs)
- if override_inputs is None:
+ if inputs is None:
return ret()
multi_prompts_gen_type = inputs["multi_prompts_gen_type"]
@@ -502,23 +509,20 @@ def ret():
image_end = [None] * len(prompts)
for single_prompt, start, end in zip(prompts, image_start, image_end) :
- override_inputs.update({
+ inputs.update({
"prompt" : single_prompt,
"image_start": start,
"image_end" : end,
})
- inputs.update(override_inputs)
add_video_task(**inputs)
else:
for single_prompt in prompts :
- override_inputs["prompt"] = single_prompt
- inputs.update(override_inputs)
+ inputs["prompt"] = single_prompt
add_video_task(**inputs)
new_prompts_count = len(prompts)
else:
new_prompts_count = 1
- override_inputs["prompt"] = "\n".join(prompts)
- inputs.update(override_inputs)
+ inputs["prompt"] = "\n".join(prompts)
add_video_task(**inputs)
new_prompts_count += gen.get("prompts_max",0)
gen["prompts_max"] = new_prompts_count
@@ -526,6 +530,123 @@ def ret():
queue= gen.get("queue", [])
return update_queue_data(queue), gr.update(open=True) if new_prompts_count > 1 else gr.update()
+def get_custom_setting_key(index):
+ return f"custom_setting_{index + 1}"
+
+def _normalize_custom_setting_type(setting_type):
+ parsed_type = str(setting_type or "text").strip().lower()
+ return parsed_type if parsed_type in CUSTOM_SETTING_TYPES else "text"
+
+def _normalize_custom_setting_name(name):
+ normalized = re.sub(r"[^a-z0-9_]+", "_", str(name or "").strip().lower()).strip("_")
+ return normalized
+
+def get_custom_setting_id(setting_def, setting_index):
+ explicit_id = setting_def.get("id", None)
+ if explicit_id is not None and len(str(explicit_id).strip()) > 0:
+ normalized_id = _normalize_custom_setting_name(explicit_id)
+ if len(normalized_id) > 0:
+ return normalized_id
+ for field_name in ("name", "param"):
+ normalized_name = _normalize_custom_setting_name(setting_def.get(field_name, ""))
+ if len(normalized_name) > 0:
+ return normalized_name
+ return get_custom_setting_key(setting_index)
+
+def get_model_custom_settings(model_def):
+ if not isinstance(model_def, dict):
+ return []
+ custom_settings = model_def.get("custom_settings", [])
+ if not isinstance(custom_settings, list):
+ return []
+ normalized = []
+ used_ids = set()
+ for idx, setting in enumerate(custom_settings[:CUSTOM_SETTINGS_MAX]):
+ if not isinstance(setting, dict):
+ continue
+ one = setting.copy()
+ one["label"] = str(one.get("label", f"Custom Setting {idx + 1}"))
+ one["name"] = str(one.get("name", f"Custom Setting {idx + 1}"))
+ one["type"] = _normalize_custom_setting_type(one.get("type", "text"))
+ setting_id = get_custom_setting_id(one, idx)
+ if setting_id in used_ids:
+ setting_id = get_custom_setting_key(idx)
+ used_ids.add(setting_id)
+ one["id"] = setting_id
+ normalized.append(one)
+ return normalized
+
+def parse_custom_setting_typed_value(raw_value, setting_type):
+ if raw_value is None:
+ return None, None
+ if isinstance(raw_value, str):
+ raw_value = raw_value.strip()
+ if len(raw_value) == 0:
+ return None, None
+ setting_type = _normalize_custom_setting_type(setting_type)
+ if setting_type == "int":
+ if isinstance(raw_value, bool):
+ return None, "Expected an integer value."
+ if isinstance(raw_value, int):
+ return raw_value, None
+ if isinstance(raw_value, float):
+ if raw_value.is_integer():
+ return int(raw_value), None
+ return None, "Expected an integer value."
+ try:
+ return int(str(raw_value).strip()), None
+ except Exception:
+ try:
+ float_value = float(str(raw_value).strip())
+ if float_value.is_integer():
+ return int(float_value), None
+ except Exception:
+ pass
+ return None, "Expected an integer value."
+ if setting_type == "float":
+ if isinstance(raw_value, bool):
+ return None, "Expected a float value."
+ try:
+ return float(raw_value), None
+ except Exception:
+ return None, "Expected a float value."
+ return str(raw_value).strip(), None
+
+def get_custom_setting_value_from_dict(custom_settings_values, setting_def, setting_index):
+ setting_id = setting_def.get("id", get_custom_setting_id(setting_def, setting_index))
+ if isinstance(custom_settings_values, dict) and setting_id in custom_settings_values:
+ return custom_settings_values.get(setting_id, None)
+ return setting_def.get("default", "")
+
+def collect_custom_settings_from_inputs(model_def, inputs, strict=False):
+ custom_settings_dict = {}
+ existing_custom_settings = inputs.get("custom_settings", None)
+ if not isinstance(existing_custom_settings, dict):
+ existing_custom_settings = {}
+ custom_settings = get_model_custom_settings(model_def)
+ for idx, setting_def in enumerate(custom_settings):
+ slot_key = get_custom_setting_key(idx)
+ setting_id = setting_def["id"]
+ raw_value = inputs.get(slot_key, None)
+ if raw_value is None and setting_id in existing_custom_settings:
+ raw_value = existing_custom_settings.get(setting_id, None)
+ parsed_value, parse_error = parse_custom_setting_typed_value(raw_value, setting_def.get("type", "text"))
+ if parse_error is not None:
+ if strict:
+ return None, f"{setting_def.get('label', slot_key)} {parse_error}"
+ if raw_value is not None:
+ raw_text = str(raw_value).strip() if isinstance(raw_value, str) else raw_value
+ if not (isinstance(raw_text, str) and len(raw_text) == 0):
+ custom_settings_dict[setting_id] = raw_text
+ continue
+ if parsed_value is not None:
+ custom_settings_dict[setting_id] = parsed_value
+ return custom_settings_dict if len(custom_settings_dict) > 0 else None, None
+
+def clear_custom_setting_slots(inputs):
+ for idx in range(CUSTOM_SETTINGS_MAX):
+ inputs.pop(get_custom_setting_key(idx), None)
+
def validate_settings(state, model_type, single_prompt, inputs):
def ret():
return None, None, None, None
@@ -538,39 +659,35 @@ def ret():
model_filename = get_model_filename(model_type)
- if hasattr(model_handler, "validate_generative_settings"):
- error = model_handler.validate_generative_settings(model_type, model_def, inputs)
- if error is not None and len(error) > 0:
- gr.Info(error)
- return ret()
if inputs.get("cfg_star_switch", 0) != 0 and inputs.get("apg_switch", 0) != 0:
gr.Info("Adaptive Progressive Guidance and Classifier Free Guidance Star can not be set at the same time")
return ret()
prompt = inputs["prompt"]
- if len(prompt) ==0:
+ prompt, errors = prompt_parser.process_template(prompt, keep_empty_lines=model_def.get("preserve_empty_prompt_lines", False))
+ if len(errors) > 0:
+ gr.Info("Error processing prompt template: " + errors)
+ return ret()
+ prompt = prompt.strip("\n").strip()
+
+ if len(prompt) == 0:
gr.Info("Prompt cannot be empty.")
gen = get_gen_info(state)
queue = gen.get("queue", [])
return ret()
- prompt, errors = prompt_parser.process_template(prompt)
- if len(errors) > 0:
- gr.Info("Error processing prompt template: " + errors)
- return ret()
multi_prompts_gen_type = inputs["multi_prompts_gen_type"]
-
- prompts = prompt.replace("\r", "").split("\n")
- prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
-
if single_prompt or multi_prompts_gen_type == 2:
- prompts = ["\n".join(prompts)]
+ prompts = [prompt]
+ else:
+ prompts = [one_line.strip() for one_line in prompt.split("\n") if len(one_line.strip()) > 0]
- if len(prompts) == 0:
- gr.Info("Prompt cannot be empty.")
- gen = get_gen_info(state)
- queue = gen.get("queue", [])
+ parsed_custom_settings, custom_settings_error = collect_custom_settings_from_inputs(model_def, inputs, strict=True)
+ if custom_settings_error is not None:
+ gr.Info(custom_settings_error)
return ret()
+ inputs["custom_settings"] = parsed_custom_settings
+ clear_custom_setting_slots(inputs)
if hasattr(model_handler, "validate_generative_prompt"):
for one_prompt in prompts:
@@ -606,6 +723,7 @@ def ret():
keep_frames_video_source = inputs["keep_frames_video_source"]
denoising_strength= inputs["denoising_strength"]
masking_strength= inputs["masking_strength"]
+ input_video_strength = inputs.get("input_video_strength", 1.0)
sliding_window_size = inputs["sliding_window_size"]
sliding_window_overlap = inputs["sliding_window_overlap"]
sliding_window_discard_last_frames = inputs["sliding_window_discard_last_frames"]
@@ -624,12 +742,34 @@ def ret():
video_guide_outpainting = inputs["video_guide_outpainting"]
spatial_upsampling = inputs["spatial_upsampling"]
motion_amplitude = inputs["motion_amplitude"]
+ self_refiner_setting = inputs["self_refiner_setting"]
+ self_refiner_plan = inputs["self_refiner_plan"]
+ model_mode = inputs["model_mode"]
medium = "Videos" if image_mode == 0 else "Images"
+ if image_start is not None and not isinstance(image_start, list): image_start = [image_start]
+ outpainting_modes = model_def.get("video_guide_outpainting", [])
+ if image_mode not in outpainting_modes:
+ video_guide_outpainting = ""
+
outpainting_dims = get_outpainting_dims(video_guide_outpainting)
+ model_modes_visibility = [0,1,2]
+ model_mode_choices = model_def.get("model_modes", None)
+ if model_mode_choices is not None: model_modes_visibility= model_mode_choices.get("image_modes", model_modes_visibility)
+ if model_mode is not None and image_mode not in model_modes_visibility:
+ model_mode = None
if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"):
gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting")
+ if self_refiner_setting != 0:
+ if isinstance(self_refiner_plan, list):
+ max_plans = model_def.get("self_refiner_max_plans", 1)
+ _, error = normalize_self_refiner_plan(self_refiner_plan, max_plans=max_plans)
+ if len(error):
+ gr.Info(error)
+ return ret()
+ else:
+ self_refiner_plan = []
if not model_def.get("motion_amplitude", False): motion_amplitude = 1.
if "vae" in spatial_upsampling:
@@ -642,7 +782,11 @@ def ret():
if len(error) > 0:
gr.Info(error)
return ret()
-
+ if model_def.get("lock_guidance_phases", False):
+ guidance_phases = model_def.get("guidance_max_phases", 0)
+ else:
+ guidance_phases = min(guidance_phases, model_def.get("guidance_max_phases", 0))
+
if len(loras_multipliers) > 0:
_, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases)
if len(errors) > 0:
@@ -667,7 +811,11 @@ def ret():
if image_mode > 0:
audio_prompt_type = ""
- if "B" in audio_prompt_type or "X" in audio_prompt_type:
+ if "K" in audio_prompt_type and "V" not in video_prompt_type:
+ gr.Info("You must enable a Control Video to use the Control Video Audio Track as an audio prompt")
+ return ret()
+
+ if ("B" in audio_prompt_type or "X" in audio_prompt_type) and not model_def.get("one_speaker_only", False):
from models.wan.multitalk.multitalk import parse_speakers_locations
speakers_bboxes, error = parse_speakers_locations(speakers_locations)
if len(error) > 0:
@@ -718,20 +866,24 @@ def ret():
else:
video_source = None
+ if len(model_def.get("input_video_strength", ""))==0 or not any_letters(image_prompt_type, "SVL"):
+ input_video_strength = 1.0
+
if "A" in audio_prompt_type:
if audio_guide == None:
gr.Info("You must provide an Audio Source")
return ret()
- if "B" in audio_prompt_type:
- if audio_guide2 == None:
- gr.Info("You must provide a second Audio Source")
- return ret()
- else:
- audio_guide2 = None
else:
audio_guide = None
+
+
+ if "B" in audio_prompt_type:
+ if audio_guide2 == None:
+ gr.Info("You must provide a second Audio Source")
+ return ret()
+ else:
audio_guide2 = None
-
+
if model_type in ["vace_multitalk_14B"] and ("B" in audio_prompt_type or "X" in audio_prompt_type):
if not "I" in video_prompt_type and not not "V" in video_prompt_type:
gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame or a Control Video (potentially truncated) that contains the two people one on each side")
@@ -782,7 +934,7 @@ def ret():
image_mask = None
if "G" in video_prompt_type:
- if denoising_strength < 1.:
+ if denoising_strength < 1. and not model_def.get("custom_denoising_strength", False):
gr.Info(f"With Denoising Strength {denoising_strength:.1f}, Denoising will start at Step no {int(round(num_inference_steps * (1. - denoising_strength),4))} ")
else:
denoising_strength = 1.0
@@ -790,7 +942,8 @@ def ret():
if "G" in video_prompt_type or model_def.get("mask_strength_always_enabled", False):
if "A" in video_prompt_type and "U" not in video_prompt_type and masking_strength < 1.:
masking_duration = math.ceil(num_inference_steps * masking_strength)
- gr.Info(f"With Masking Strength {masking_strength:.1f}, Masking will last {masking_duration}{' Step' if masking_duration==1 else ' Steps'}")
+ if masking_strength:
+ gr.Info(f"With Masking Strength {masking_strength:.1f}, Masking will last {masking_duration}{' Step' if masking_duration==1 else ' Steps'}")
else:
masking_strength = 1.0
if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]:
@@ -818,7 +971,6 @@ def ret():
-
if "S" in image_prompt_type:
if model_def.get("black_frame", False) and len(image_start or [])==0:
if "E" in image_prompt_type and len(image_end or []):
@@ -913,6 +1065,7 @@ def ret():
"video_source": video_source,
"frames_positions": frames_positions,
"keep_frames_video_source": keep_frames_video_source,
+ "input_video_strength": input_video_strength,
"keep_frames_video_guide": keep_frames_video_guide,
"denoising_strength": denoising_strength,
"masking_strength": masking_strength,
@@ -922,8 +1075,17 @@ def ret():
"skip_steps_cache_type": skip_steps_cache_type,
"model_switch_phase": model_switch_phase,
"motion_amplitude": motion_amplitude,
+ "model_mode": model_mode,
+ "video_guide_outpainting": video_guide_outpainting,
+ "custom_settings": inputs.get("custom_settings", None),
}
- return override_inputs, prompts, image_start, image_end
+ inputs.update(override_inputs)
+ if hasattr(model_handler, "validate_generative_settings"):
+ error = model_handler.validate_generative_settings(model_type, model_def, inputs)
+ if error is not None and len(error) > 0:
+ gr.Info(error)
+ return ret()
+ return inputs, prompts, image_start, image_end
def get_preview_images(inputs):
@@ -1195,8 +1357,8 @@ def _load_task_attachments(params, media_base_path, cache_dir=None, log_prefix="
# Update params, preserving list/single structure
if loaded_items:
- has_pil_item = any(isinstance(item, Image.Image) for item in loaded_items)
- if is_originally_list or has_pil_item:
+ # has_pil_item = any(isinstance(item, Image.Image) for item in loaded_items)
+ if is_originally_list: # or has_pil_item
params[key] = loaded_items
else:
params[key] = loaded_items[0]
@@ -1515,6 +1677,7 @@ def clear_queue_action(state):
def quit_application():
print("Save and Quit requested...")
+ clear_startup_lock()
autosave_queue()
import signal
os.kill(os.getpid(), signal.SIGINT)
@@ -1551,32 +1714,60 @@ def autosave_queue():
def finalize_generation_with_state(current_state):
if not isinstance(current_state, dict) or 'gen' not in current_state:
- return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state
-
- gallery_update, audio_files_paths_update, audio_file_selected_update, audio_gallery_refresh_trigger_update, gallery_tabs_update, current_gallery_tab_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state)
+ return (
+ gr.update(),
+ gr.update(),
+ gr.update(),
+ gr.update(),
+ gr.update(),
+ gr.update(),
+ gr.update(interactive=True),
+ gr.update(interactive=True),
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False, value=""),
+ gr.update(),
+ current_state,
+ )
+
+ gallery_tabs_update, current_gallery_tab_update, gallery_update, audio_files_paths_update, audio_file_selected_update, audio_gallery_refresh_trigger_update, abort_btn_update, earlystop_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state)
accordion_update = gr.Accordion(open=False) if len(get_gen_info(current_state).get("queue", [])) <= 1 else gr.update()
- return gallery_update, audio_files_paths_update, audio_file_selected_update, audio_gallery_refresh_trigger_update, gallery_tabs_update, current_gallery_tab_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state
+ return gallery_tabs_update, current_gallery_tab_update, gallery_update, audio_files_paths_update, audio_file_selected_update, audio_gallery_refresh_trigger_update, abort_btn_update, earlystop_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state
def generate_queue_html(queue):
if len(queue) <= 1:
return "