From 0b86a6b21750ad7c765100b961bf8206ba49b69a Mon Sep 17 00:00:00 2001 From: lzws <2538048363@qq.com> Date: Fri, 31 Oct 2025 11:33:28 +0800 Subject: [PATCH 1/6] add Video-As-Prompt-Wan2.1-14B inference --- diffsynth/configs/model_config.py | 2 + diffsynth/models/wan_video_dit.py | 29 ++++- diffsynth/pipelines/wan_video_new.py | 108 +++++++++++++++++- .../model_inference/Wan2.1-VAP-14B.py | 66 +++++++++++ 4 files changed, 201 insertions(+), 4 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.1-VAP-14B.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 47e26e0c..743ffba5 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -64,6 +64,7 @@ from ..models.wan_video_vace import VaceWanModel from ..models.wav2vec import WanS2VAudioEncoder from ..models.wan_video_animate_adapter import WanAnimateAdapter +from ..models.wan_video_mot import MotWanModel from ..models.step1x_connector import Qwen2Connector @@ -157,6 +158,7 @@ (None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"), (None, "5ec04e02b42d2580483ad69f4e76346a", ["wan_video_dit"], [WanModel], "civitai"), (None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"), + (None, "5f90e66a0672219f12d9a626c8c21f61", ["wan_video_dit", "wan_video_vap"], [WanModel,MotWanModel], "diffusers"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index cdebad43..b4af5fa2 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -437,6 +437,11 @@ def from_diffusers(self, state_dict): "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias", + "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight", + "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias", + "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight", + "blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight", "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", @@ -454,6 +459,14 @@ def from_diffusers(self, state_dict): "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", "condition_embedder.time_proj.bias": "time_projection.1.bias", "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias", + "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight", + "condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias", + "condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight", + "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias", + "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight", + "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias", + "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight", "patch_embedding.bias": "patch_embedding.bias", "patch_embedding.weight": "patch_embedding.weight", "scale_shift_table": "head.modulation", @@ -470,7 +483,7 @@ def from_diffusers(self, state_dict): name_ = rename_dict[name_] name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) state_dict_[name_] = param - if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b": + if hash_state_dict_keys(state_dict_) == "cb104773c6c2cb6df4f9529ad5c60d0b": config = { "model_type": "t2v", "patch_size": (1, 2, 2), @@ -488,6 +501,20 @@ def from_diffusers(self, state_dict): "cross_attn_norm": True, "eps": 1e-6, } + elif hash_state_dict_keys(state_dict_) == "6bfcfb3b342cb286ce886889d519a77e": + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6 + } else: config = {} return state_dict_, config diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index d374afd8..efefb661 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -22,6 +22,7 @@ from ..models.wan_video_vace import VaceWanModel from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_animate_adapter import WanAnimateAdapter +from ..models.wan_video_mot import MotWanModel from ..models.longcat_video_dit import LongCatVideoTransformer3DModel from ..schedulers.flow_match import FlowMatchScheduler from ..prompters import WanPrompter @@ -47,8 +48,9 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=Non self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None self.vace2: VaceWanModel = None + self.vap: MotWanModel = None self.animate_adapter: WanAnimateAdapter = None - self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter") + self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter","vap") self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter") self.unit_runner = PipelineUnitRunner() self.units = [ @@ -69,6 +71,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=Non WanVideoPostUnit_AnimatePoseLatents(), WanVideoPostUnit_AnimateFacePixelValues(), WanVideoPostUnit_AnimateInpaint(), + WanVideoUnit_VAP(), WanVideoUnit_UnifiedSequenceParallel(), WanVideoUnit_TeaCache(), WanVideoUnit_CfgMerger(), @@ -392,6 +395,7 @@ def from_pretrained( pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") vace = model_manager.fetch_model("wan_video_vace", index=2) + pipe.vap = model_manager.fetch_model("wan_video_vap") if isinstance(vace, list): pipe.vace, pipe.vace2 = vace else: @@ -455,6 +459,10 @@ def __call__( animate_face_video: Optional[list[Image.Image]] = None, animate_inpaint_video: Optional[list[Image.Image]] = None, animate_mask_video: Optional[list[Image.Image]] = None, + # VAP + vap_video: Optional[list[Image.Image]] = None, + vap_prompt: Optional[str] = " ", + negative_vap_prompt: Optional[str] = " ", # Randomness seed: Optional[int] = None, rand_device: Optional[str] = "cpu", @@ -493,10 +501,12 @@ def __call__( # Inputs inputs_posi = { "prompt": prompt, + "vap_prompt":vap_prompt, "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, } inputs_nega = { "negative_prompt": negative_prompt, + "negative_vap_prompt": negative_vap_prompt, "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, } inputs_shared = { @@ -516,6 +526,7 @@ def __call__( "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, + "vap_video": vap_video, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -927,6 +938,73 @@ def process( else: return {"vace_context": None, "vace_scale": vace_scale} +class WanVideoUnit_VAP(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder","vae","image_encoder") + ) + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if pipe.vap is None or inputs_shared.get("vap_video") is None: + return inputs_shared, inputs_posi, inputs_nega + else: + # 1. encode vap prompt + pipe.load_models_to_device(["text_encoder"]) + vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt"), inputs_nega.get("negative_vap_prompt") + vap_prompt_emb = pipe.prompter.encode_prompt(vap_prompt, positive=inputs_posi.get('positive',None), device=pipe.device) + negative_vap_prompt_emb = pipe.prompter.encode_prompt(negative_vap_prompt, positive=inputs_nega.get('positive',None), device=pipe.device) + inputs_posi.update({"context_vap":vap_prompt_emb}) + inputs_nega.update({"context_vap":negative_vap_prompt_emb}) + # 2. prepare vap image clip embedding + pipe.load_models_to_device(["vae"]) + pipe.load_models_to_device(["image_encoder"]) + vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image") + + num_frames, height, width, mot_num = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num",1) + + image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device) + + vap_clip_context = pipe.image_encoder.encode_image([image_vap]) + if end_image is not None: + vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + vap_clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1) + vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_clip_feature":vap_clip_context}) + + # 3. prepare vap latents + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + msk[:, -1:] = 1 + last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1) + else: + vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_video = pipe.preprocess_video(vap_video) + vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + # latent_mot_ref = (latent_mot_ref - latents_mean) * latents_std + + vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_hidden_state":vap_latent}) + pipe.load_models_to_device([]) + + return inputs_shared, inputs_posi, inputs_nega + class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): @@ -1285,6 +1363,7 @@ def model_fn_wan_video( dit: WanModel, motion_controller: WanMotionControllerModel = None, vace: VaceWanModel = None, + vap: MotWanModel = None, animate_adapter: WanAnimateAdapter = None, latents: torch.Tensor = None, timestep: torch.Tensor = None, @@ -1297,6 +1376,9 @@ def model_fn_wan_video( audio_embeds: Optional[torch.Tensor] = None, motion_latents: Optional[torch.Tensor] = None, s2v_pose_latents: Optional[torch.Tensor] = None, + vap_hidden_state = None, + vap_clip_feature = None, + context_vap = None, drop_motion_frames: bool = True, tea_cache: TeaCache = None, use_unified_sequence_parallel: bool = False, @@ -1406,7 +1488,6 @@ def model_fn_wan_video( if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - # Camera control x = dit.patchify(x, control_camera_latents_input) @@ -1431,6 +1512,25 @@ def model_fn_wan_video( dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # VAP + if vap is not None: + # hidden state + x_vap = vap_hidden_state + x_vap = vap.patchify(x_vap) + x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous() + # Timestep + clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype) + t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep)) + t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim)) + + # rope + freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device) + + # context + vap_clip_embedding = vap.img_emb(vap_clip_feature) + context_vap = vap.text_embedding(context_vap) + context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1) # TeaCache if tea_cache is not None: @@ -1462,7 +1562,9 @@ def custom_forward(*inputs): for block_id, block in enumerate(dit.blocks): # Block - if use_gradient_checkpointing_offload: + if vap is not None and block_id in vap.mot_layers_mapping: + x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) + elif use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), diff --git a/examples/wanvideo/model_inference/Wan2.1-VAP-14B.py b/examples/wanvideo/model_inference/Wan2.1-VAP-14B.py new file mode 100644 index 00000000..9931a9a1 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-VAP-14B.py @@ -0,0 +1,66 @@ +import torch +import PIL +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download +from typing import List + +def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]: + if len(video_frames) == 0: + return [] + if mode == "first": + return video_frames[:num] + if mode == "evenly": + import torch as _torch + idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist() + return [video_frames[i] for i in idx] + if mode == "random": + if len(video_frames) <= num: + return video_frames + import random as _random + start = _random.randint(0, len(video_frames) - num) + return video_frames[start:start+num] + return video_frames + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors",), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) + +ref_video_path = 'data/examples/wanvap/vap_ref.mp4' +target_image_path = 'data/examples/wanvap/input_image.jpg' + + +image = Image.open(target_image_path).convert("RGB") +ref_video = VideoData(ref_video_path, height=480, width=832) +ref_frames = select_frames(ref_video, num=49, mode= "evenly") + +vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery." +prompt="A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent." + +video = pipe( + prompt=prompt, + negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + input_image=image, + seed=42, tiled=True, + height=480, width=832, + num_frames=49, + vap_video=ref_frames, + vap_prompt=vap_prompt, + negative_vap_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +) + +save_video(video, "video.mp4", fps=15, quality=5) \ No newline at end of file From 30bea528df8a979135dcd480e1ce59b5c6fd7d17 Mon Sep 17 00:00:00 2001 From: lzws <2538048363@qq.com> Date: Fri, 31 Oct 2025 11:41:08 +0800 Subject: [PATCH 2/6] add wan2.1-vap-14 inference --- diffsynth/pipelines/wan_video_new.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index efefb661..f03ee032 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -23,7 +23,10 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_animate_adapter import WanAnimateAdapter from ..models.wan_video_mot import MotWanModel +<<<<<<< Updated upstream from ..models.longcat_video_dit import LongCatVideoTransformer3DModel +======= +>>>>>>> Stashed changes from ..schedulers.flow_match import FlowMatchScheduler from ..prompters import WanPrompter from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm From 9b8c9c3ee0f5ac563b9a0ab203c8afcf8bc03950 Mon Sep 17 00:00:00 2001 From: lzws <2538048363@qq.com> Date: Fri, 31 Oct 2025 11:46:32 +0800 Subject: [PATCH 3/6] add wan2.1-vap-14B-inference --- diffsynth/pipelines/wan_video_new.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index f03ee032..efefb661 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -23,10 +23,7 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_animate_adapter import WanAnimateAdapter from ..models.wan_video_mot import MotWanModel -<<<<<<< Updated upstream from ..models.longcat_video_dit import LongCatVideoTransformer3DModel -======= ->>>>>>> Stashed changes from ..schedulers.flow_match import FlowMatchScheduler from ..prompters import WanPrompter from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm From ec872d9b2268e460faf69d66d41ba9394712abc5 Mon Sep 17 00:00:00 2001 From: lzws <2538048363@qq.com> Date: Fri, 31 Oct 2025 13:11:26 +0800 Subject: [PATCH 4/6] add wan2.1-vap-14B-inference --- Wan2.1-VAP-14B.py | 69 ++++++++++++++++++++++++++++ diffsynth/pipelines/wan_video_new.py | 4 +- 2 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 Wan2.1-VAP-14B.py diff --git a/Wan2.1-VAP-14B.py b/Wan2.1-VAP-14B.py new file mode 100644 index 00000000..dda49686 --- /dev/null +++ b/Wan2.1-VAP-14B.py @@ -0,0 +1,69 @@ +import torch +import PIL +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download +from typing import List + +def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]: + if len(video_frames) == 0: + return [] + if mode == "first": + return video_frames[:num] + if mode == "evenly": + import torch as _torch + idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist() + return [video_frames[i] for i in idx] + if mode == "random": + if len(video_frames) <= num: + return video_frames + import random as _random + start = _random.randint(0, len(video_frames) - num) + return video_frames[start:start+num] + return video_frames + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors",), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], +) + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) + +ref_video_path = 'data/examples/wanvap/vap_ref.mp4' +target_image_path = 'data/examples/wanvap/input_image.jpg' + +ref_video_path = '/mnt/nas2/zhiwen/wan2.2-fun/Video-As-Prompt/assets/videos/demo/man-534.mp4' +target_image_path = '/mnt/nas2/zhiwen/wan2.2-fun/Video-As-Prompt/assets/images/demo/woman-7.jpg' + + +image = Image.open(target_image_path).convert("RGB") +ref_video = VideoData(ref_video_path, height=480, width=832) +ref_frames = select_frames(ref_video, num=49, mode= "evenly") + +vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery." +prompt="A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent." + +video = pipe( + prompt=prompt, + negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + input_image=image, + seed=42, tiled=True, + height=480, width=832, + num_frames=49, + vap_video=ref_frames, + vap_prompt=vap_prompt, + negative_vap_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +) + +save_video(video, "video.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index efefb661..9cd63b61 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -997,8 +997,6 @@ def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_neg vap_video = pipe.preprocess_video(vap_video) vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - # latent_mot_ref = (latent_mot_ref - latents_mean) * latents_std - vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device) inputs_shared.update({"vap_hidden_state":vap_latent}) pipe.load_models_to_device([]) @@ -1488,6 +1486,7 @@ def model_fn_wan_video( if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + # Camera control x = dit.patchify(x, control_camera_latents_input) @@ -1606,7 +1605,6 @@ def custom_forward(*inputs): x = dit.unpatchify(x, (f, h, w)) return x - def model_fn_longcat_video( dit: LongCatVideoTransformer3DModel, latents: torch.Tensor = None, From 870b46f24de760616570a69365a6877a19d8fc6a Mon Sep 17 00:00:00 2001 From: lzws <2538048363@qq.com> Date: Fri, 31 Oct 2025 13:17:02 +0800 Subject: [PATCH 5/6] add wan2.1-vap-14B inference --- Wan2.1-VAP-14B.py | 69 ------- diffsynth/models/wan_video_mot.py | 307 ++++++++++++++++++++++++++++++ 2 files changed, 307 insertions(+), 69 deletions(-) delete mode 100644 Wan2.1-VAP-14B.py create mode 100644 diffsynth/models/wan_video_mot.py diff --git a/Wan2.1-VAP-14B.py b/Wan2.1-VAP-14B.py deleted file mode 100644 index dda49686..00000000 --- a/Wan2.1-VAP-14B.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import PIL -from PIL import Image -from diffsynth import save_video, VideoData -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download -from typing import List - -def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]: - if len(video_frames) == 0: - return [] - if mode == "first": - return video_frames[:num] - if mode == "evenly": - import torch as _torch - idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist() - return [video_frames[i] for i in idx] - if mode == "random": - if len(video_frames) <= num: - return video_frames - import random as _random - start = _random.randint(0, len(video_frames) - num) - return video_frames[start:start+num] - return video_frames - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors",), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"), - ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), - ], -) - -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) - -ref_video_path = 'data/examples/wanvap/vap_ref.mp4' -target_image_path = 'data/examples/wanvap/input_image.jpg' - -ref_video_path = '/mnt/nas2/zhiwen/wan2.2-fun/Video-As-Prompt/assets/videos/demo/man-534.mp4' -target_image_path = '/mnt/nas2/zhiwen/wan2.2-fun/Video-As-Prompt/assets/images/demo/woman-7.jpg' - - -image = Image.open(target_image_path).convert("RGB") -ref_video = VideoData(ref_video_path, height=480, width=832) -ref_frames = select_frames(ref_video, num=49, mode= "evenly") - -vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery." -prompt="A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent." - -video = pipe( - prompt=prompt, - negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", - input_image=image, - seed=42, tiled=True, - height=480, width=832, - num_frames=49, - vap_video=ref_frames, - vap_prompt=vap_prompt, - negative_vap_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" -) - -save_video(video, "video.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/diffsynth/models/wan_video_mot.py b/diffsynth/models/wan_video_mot.py new file mode 100644 index 00000000..c1d8ed53 --- /dev/null +++ b/diffsynth/models/wan_video_mot.py @@ -0,0 +1,307 @@ +import torch +from .wan_video_dit import DiTBlock, SelfAttention, CrossAttention, rope_apply,flash_attention,modulate,MLP +from .utils import hash_state_dict_keys +import einops +import torch.nn as nn + + +class MotSelfAttention(SelfAttention): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__(dim, num_heads, eps) + def forward(self, x, freqs, is_before_attn=False): + if is_before_attn: + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + return q, k, v + else: + return self.o(x) + + +class MotWanAttentionBlock(DiTBlock): + def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): + super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps) + self.block_id = block_id + + self.self_attn = MotSelfAttention(dim, num_heads, eps) + + + def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot): + + # 1. prepare scale parameter + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + + scale_params_mot_ref = self.modulation + t_mod_mot.float() + scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1) + shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2) + + # 2. Self-attention + input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa) + # original block self-attn + attn1 = wan_block.self_attn + q = attn1.norm_q(attn1.q(input_x)) + k = attn1.norm_k(attn1.k(input_x)) + v = attn1.v(input_x) + q = rope_apply(q, freqs, attn1.num_heads) + k = rope_apply(k, freqs, attn1.num_heads) + + # mot block self-attn + norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1) + norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot) + norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1) + q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True) + + tmp_hidden_states = flash_attention( + torch.cat([q, q_mot], dim=-2), + torch.cat([k, k_mot], dim=-2), + torch.cat([v, v_mot], dim=-2), + num_heads=attn1.num_heads) + + attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2) + + attn_output = attn1.o(attn_output) + x = wan_block.gate(x, gate_msa, attn_output) + + attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False) + # gate + attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1) + attn_output_mot = attn_output_mot * gate_msa_mot_ref + attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1) + x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot) + + # 3. cross-attention and feed-forward + x = x + wan_block.cross_attn(wan_block.norm3(x), context) + input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp) + x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x)) + + x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot) + # modulate + norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1) + norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot) + norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1) + input_x_mot = self.ffn(norm_x_mot_ref) + # gate + input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1) + input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref + input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1) + x_mot = (x_mot.float() + input_x_mot).type_as(x_mot) + + return x, x_mot + + +class MotWanModel(torch.nn.Module): + def __init__( + self, + mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36), + patch_size=(1, 2, 2), + has_image_input=True, + has_image_pos_emb=False, + dim=5120, + num_heads=40, + ffn_dim=13824, + freq_dim=256, + text_dim=4096, + in_dim=36, + eps=1e-6, + ): + super().__init__() + self.mot_layers = mot_layers + self.freq_dim = freq_dim + self.dim = dim + + self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)} + self.head_dim = dim // num_heads + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) + + # mot blocks + self.blocks = torch.nn.ModuleList([ + MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i) + for i in self.mot_layers + ]) + + + def patchify(self, x: torch.Tensor): + x = self.patch_embedding(x) + return x + + def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0): + def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta) + h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta) + w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta) + + freqs = torch.cat([ + f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1), + h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1), + w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1) + return freqs + + def forward( + self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + + ): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + block = self.blocks[self.mot_layers_mapping[block_id]] + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x,x_mot = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x,x_mot = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, + use_reentrant=False, + ) + else: + x,x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot) + + return x,x_mot + + @staticmethod + def state_dict_converter(): + return MotWanModelDictConverter() + + +class MotWanModelDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias", + "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight", + "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias", + "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight", + "blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias", + "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight", + "condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias", + "condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight", + "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias", + "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight", + "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias", + "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name} + if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7': + mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36) + else: + mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36) + mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)} + + state_dict_ = {} + + for name, param in state_dict.items(): + name = name.replace("_mot_ref", "") + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + if name.split(".")[1].isdigit(): + block_id = int(name.split(".")[1]) + name = name.replace(str(block_id), str(mot_layers_mapping[block_id])) + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + + if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0': # vap 14B + config = { + "mot_layers":(0, 4, 8, 12, 16, 20, 24, 28, 32, 36), + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "num_heads": 40, + "eps": 1e-6 + } + else: + config = {} + return state_dict_, config + + + \ No newline at end of file From 4000b59582d5abe92910c270f3802ff09d762cab Mon Sep 17 00:00:00 2001 From: lzws <63908509+lzws@users.noreply.github.com> Date: Fri, 31 Oct 2025 13:20:42 +0800 Subject: [PATCH 6/6] wan2.1-vap-14B inference Removed dataset snapshot download function call. --- examples/wanvideo/model_inference/Wan2.1-VAP-14B.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/wanvideo/model_inference/Wan2.1-VAP-14B.py b/examples/wanvideo/model_inference/Wan2.1-VAP-14B.py index 9931a9a1..d7f22d47 100644 --- a/examples/wanvideo/model_inference/Wan2.1-VAP-14B.py +++ b/examples/wanvideo/model_inference/Wan2.1-VAP-14B.py @@ -34,11 +34,6 @@ def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> L ], ) -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) ref_video_path = 'data/examples/wanvap/vap_ref.mp4' target_image_path = 'data/examples/wanvap/input_image.jpg' @@ -63,4 +58,4 @@ def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> L negative_vap_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" ) -save_video(video, "video.mp4", fps=15, quality=5) \ No newline at end of file +save_video(video, "video.mp4", fps=15, quality=5)