diff --git a/README.md b/README.md index f8fdf3b1..0702bbbd 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,70 @@ +# AnimateHYBRID + +Mini-Guide: AnimateDiff <4 GB VRAM with N3R +1️⃣ Pick the Right Model + +N3RModelOptimized → ~3.6 GB VRAM, full features, stable. + +Mini GPU Mode + generate_latents_mini_gpu_320 → ~2.1 GB VRAM, ultra-light for quick tests. + +2️⃣ VRAM-Friendly Settings + +final_latent_scale → reduces the final latent resolution to save memory. + +num_fraps_per_image → limit the number of frames per input image. + +block_size & overlap → tweak for streaming decoding efficiency. + +3️⃣ Adaptive N3R Fusion + +Channel-wise normalization to prevent artifacts. + +Strict clamping of latents: torch.clamp(latents, -1.0, 1.0). + +Controlled latent injection: + +fused_latents = latent_injection * latents_frame + (1 - latent_injection) * n3r_latents +4️⃣ Motion / LoRA / VAE + +Motion modules and LoRA can be enabled, but watch VRAM usage → use attention slicing if needed. + +Light VAE + blockwise decoding ensures GPU stability. + +5️⃣ Pro Tips + +Free VRAM after each frame: + +del latents +torch.cuda.empty_cache() + +Adaptive embeddings for UNet → avoids dimension mismatch errors. + +decode_latents_ultrasafe_blockwise → stable decoding with high-quality output. + +💡 Bottom Line: With these settings, AnimateDiff can run even on GPUs with 3–4 GB VRAM without sacrificing output quality. + +And yes… N3R did it for you! 
🚀 +n3rProtoBoost: +``` +python -m scripts.n3rProtoBoost \ + --pretrained-model-path "/huggingface/miniSD" \ + --config "configs/prompts/2_animate/1080.yaml" \ + --device "cuda" \ + --vae-offload \ + --fp16 +``` +n3rProBoostNet: +``` +python -m scripts.n3rProBoostNet \ + --pretrained-model-path "/mnt/62G/huggingface/miniSD" \ + --config "configs/prompts/0_n3r/960.yaml" \ + --device "cuda" \ + --vae-offload \ + --fp16 +``` + + + # AnimateDiff This repository is the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725) [ICLR2024 Spotlight]. diff --git a/configs/prompts/0_animate/128.yaml b/configs/prompts/0_animate/128.yaml new file mode 100644 index 00000000..ca90f2fd --- /dev/null +++ b/configs/prompts/0_animate/128.yaml @@ -0,0 +1,61 @@ +#n3r_tiny.yaml +# ------------------------- +# Tiny-SD config +# ------------------------- +W: 128 +H: 128 +L: 4 +steps: 15 # un peu plus que le test pour meilleure qualité +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +text_encoder: null +tokenizer: null +dtype: float16 + +fps: 12 +guidance_scale: 4.5 +init_image_scale: 0.85 +use_real_esrgan: true # activé, mais batch plus petit + +scheduler: + type: DDIMScheduler + steps: 15 + beta_start: 0.00085 + beta_end: 0.012 + +dreambooth_path: null +lora_model_path: null + +key_frames: 4 +inter_frames: 4 + +block_size: 128 +overlap: 16 +num_frames_per_image: 4 # générer par petits batches (tu peux répéter pour 60 frames) + +motion_module: scripts/modules/motion_module_tiny.py +controlnet_adapter: null + +inference_config: configs/inference/inference-v4.yaml +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true + +seed: 1234 + +enable_xformers_memory_efficient_attention: true +vae_slicing: true +vae_tiling: true +vae_path: null + +upscale_factor: 4 # ESRGAN actif + +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and 
hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting" + - "natural walk cycle, same character, forward facing, slight tilt for realism, smooth transitions, crisp lines, photorealistic shading, coherent body movement, stable proportions, simple background" + +n_prompt: + - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background" + +input_images: + - "input/image_128x0.png" diff --git a/configs/prompts/0_animate/256.yaml b/configs/prompts/0_animate/256.yaml new file mode 100644 index 00000000..bb9e028f --- /dev/null +++ b/configs/prompts/0_animate/256.yaml @@ -0,0 +1,77 @@ +# ------------------------- +# n3r_tiny_256_vram_opt.yaml +# ------------------------- +# Tiny-SD 256x256 VRAM-safe optimisé pour 5D AnimateDiff +# - Frames fluides et cohérentes +# - FP16 UNET & TextEncoder, VAE FP32 pour stabilité couleurs +# - Motion module Tiny intégré +# ------------------------- + +# Image size +W: 256 +H: 256 +L: 4 +steps: 15 + +# Model paths +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" + +# FP settings +dtype: float32 # VAE en FP32 pour stabilité couleurs +fp16_unet: true # UNET et text_encoder en FP16 + +# Video / motion +fps: 15 +num_frames_per_image: 12 # permet séquences fluides +init_image_scale: 0.85 +use_real_esrgan: true + +# Scheduler +scheduler: + type: DDIMScheduler + steps: 15 + beta_start: 0.00085 + beta_end: 0.012 + +# Motion module +motion_module: scripts/modules/motion_module_tiny.py +key_frames: 1 +inter_frames: 11 + +# Tiling VAE +vae_slicing: true +vae_tiling: true +block_size: 256 # décode l'image entière → plus de mosaïque +overlap: 32 + +# Inference / memory +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 8 
+enable_xformers_memory_efficient_attention: true + +# Random seed +seed: 1234 + +# Diffusion guidance +guidance_scale: 4.5 +creative_noise: 0.0 + +# Upscale factor +upscale_factor: 4 + +# Prompts +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting" + - "natural walk cycle, same character, forward facing, slight tilt for realism, smooth transitions, crisp lines, photorealistic shading, coherent body movement, stable proportions, simple background" + +n_prompt: + - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background" + +# Input images +input_images: + - "input/256/image_256x0.png" + - "input/256/image_256x1.png" + - "input/256/image_256x1.png" diff --git a/configs/prompts/0_animate/Readme.txt b/configs/prompts/0_animate/Readme.txt new file mode 100644 index 00000000..da9e08f6 --- /dev/null +++ b/configs/prompts/0_animate/Readme.txt @@ -0,0 +1 @@ +Sample prompt for low Vram MAX 4Go diff --git a/configs/prompts/0_n3r/512-a.yaml b/configs/prompts/0_n3r/512-a.yaml new file mode 100644 index 00000000..ad1230a6 --- /dev/null +++ b/configs/prompts/0_n3r/512-a.yaml @@ -0,0 +1,122 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +latent_injection: 0.70 # 85% proche de l'image init - 0.6 on s'éloigne de l'init - 0.5 creatif - Valeur Stable (0.70) +creative_noise: 0.08 # moins de liberté, plus de cohérence 0.08 ou 0.15 rendu creatif; 0.00 rendu cinéma (0.10) -- Valeur test à conservé (0.00) +# Exploration maximale du latent 0. (0.25) - Augmenter init_image_scale → donne plus de signal à partir de ton image d’initiation. 
(1.0 image original) +init_image_scale: 0.5 # - on s'éloigne de l'init 0.5 - presque tout le signal de l'image d'origine -- Valeur test 0.85 +init_image_scale_end: 0.75 # Plus proche de init_image_scale pour limiter le flou +# guidance_scale = 12 → les prompts positifs et négatifs auront un impact plus marqué. +guidance_scale: 8.5 # un peu plus strict pour que le chat ressorte 8.0 - 1.5 rendu proche de l'input - 4.0 creatif - Valeur test (2.7) +guidance_scale_end: 6.5 # légère décroissance pour variation naturelle +steps: 50 # diffusion steps faible pour VRAM - Augmenter steps ou fps → le motion module a plus de temps pour transformer les latents. (20) +# ------------------------- +use_n3r_model: true +use_mini_gpu: true +n3r_L: 6 +n3r_N_samples: 24 +n3r_L_low: 4 +n3r_L_high: 8 +verbose: False + +final_latent_scale: 0.25 +tile_size: 64 +#--------------------latent_scale_boost 5.71 ------------ +latent_scale_boost: 1.0 # Permet d'augmenter le rendu à utiliser doucement :) +# +W: 536 +H: 960 +L: 4 +fps: 12 # moins de frames simultanées + +# ------------------------- +# Modèles n3oray +# ------------------------- +# Liste des modèles à utiliser (ordre = interpolation possible) LoRA est FP16 ou compatible avec ton UNet SDXL. + +n3oray_models: + #cyber_skin: "/mnt/62G/huggingface/cyber-fp16/Cyber skin_fp16.safetensors" #OK + +motion_module: "v3_sd15_mm.ckpt" + +upscale_factor: 2 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +unet_dtype: float32 +dtype: float16 # VRAM safe FP16 pour 4Go- +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + #type: PNDMScheduler + steps: 40 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 12 # plus de frames par image pour plus de cohérence +transition_frames: 1 # transitions plus douces +key_frames: 1 + + +#80*0.6+80 = 128 → ok. 
+block_size: 160 #80 ou 100 ou 110 ou 120 ou (128) ou 156 ou (160 MAX soit 160x0.6+160=256) 160 +overlap: 96 #40 ou 60 ou 66 ou 72 ou 76 ou (80) ou 100 ou (96) 96 +vae_slicing: true +vae_tiling: true +vae_device: cpu # force VAE sur CPU pour VRAM safe +motion_module_device: cpu +#motion_module_device: cuda + +# Vérifier le motion module → certains modules “lite” sont trop conservateurs avec des images petites ou des frames peu nombreuses. +#motion_module: scripts/modules/motion_module_tiny.py +#motion_module: scripts/modules/motion_module_cam.py +#motion_module: scripts/modules/motion_module_cam2.py +motion_module: scripts/modules/motion_module_show_safe.py +#motion_module: scripts/modules/Motion_module_Masked.py +#motion_module: scripts/modules/Motion_module_Enhanced.py +#motion_module: scripts/modules/Motion_module_Wind.py +#motion_module: scripts/modules/motion_module_lite.py +#motion_module: scripts/modules/motion_module_ultralite_debug.py +#motion_module: scripts/modules/motion_ulta_lite_fix.py +#motion_module: scripts/modules/motion_module_show.py (test ok) + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +#vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +#vae_path: "/mnt/62G/huggingface/vae/vaeKlF8Anime2_klF8Anime2VAE.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model_fp16.safetensors" +vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae" + + +prompt: + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, comic book style Marvel, cel shading, bold outlines, dramatic action pose, neon lighting, vibrant colors, glowing effects, dynamic perspective, 
cinematic composition, energy lines, stylized shadows, inked lines, comic panel framing, halftone shading, high contrast lighting, happy expression" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, Marvel comic style, dynamic pose, neon city background, glowing circuits, vibrant colors, cinematic lighting, dramatic angles, stylized shadows, action scene, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, katana on back, Marvel comic style, synthwave neon city, pink and purple lighting, reflective wet streets, dynamic perspective, cinematic composition, glowing atmosphere, motion energy, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, katana on back, Marvel comic style, angelic transformation, glowing neon wings, radiant energy, floating pose, dramatic lighting, vibrant colors, stylized aura, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, katana on back, Marvel comic style, cyberpunk city night, dense skyline, neon signs, holograms, rainy streets, cinematic composition, dramatic shadows, dynamic perspective, inked lines, comic panel framing, halftone shading, high contrast lighting" + + +n_prompt: + - "blurry, deformed, extra limbs, bad anatomy, low quality, messy background, dark colors, scary, realistic, human features" + - "blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad 
anatomy, extra fingers" + - "blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad anatomy, extra fingers" + - "blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad anatomy, extra fingers" + - "blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad anatomy, extra fingers" + + +input_images: + - "input/540-960/1.png" + - "input/540-960/2.png" + - "input/540-960/3.png" + - "input/540-960/4.png" + - "input/540-960/5.png" diff --git a/configs/prompts/0_n3r/512-c.yaml b/configs/prompts/0_n3r/512-c.yaml new file mode 100644 index 00000000..b0101435 --- /dev/null +++ b/configs/prompts/0_n3r/512-c.yaml @@ -0,0 +1,84 @@ +# --------------------------------------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Optimisé pour créativité + mouvement dynamique +# --------------------------------------------------------- +fps: 12 +num_fraps_per_image: 3 # plus de frames par image = interpolation plus fluide +upscale_factor: 2 +steps: 50 +guidance_scale: 7.5 +guidance_scale_end: 6.5 # légère décroissance pour variation naturelle +init_image_scale: 0.85 +init_image_scale_end: 0.75 # Plus proche de init_image_scale pour limiter le flou +creative_noise: 0.08 +creative_noise_end: 0.10 # Progression plus douce +latent_scale_boost: 1.0 +final_latent_scale: 0.25 +transition_frames: 4 # Plus de frames pour une interpolation uniforme +latent_injection: 0.85 + +block_size: 256 # Assurez-vous que divisible par la résolution +use_n3r_model: true 
+use_n3r_pro_net: true +use_mini_gpu: true +n3r_L: 6 +n3r_N_samples: 12 +n3r_L_low: 4 +n3r_L_high: 8 +verbose: False + +W: 536 +H: 960 +overlap: 96 +tile_size: 64 + +# ------------------------- +# VAE et upscale +# ------------------------- +vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model.safetensors" +vae_device: cpu +vae_slicing: true +vae_tiling: true +use_real_esrgan: true + +# ------------------------- +# Motion module → dynamique et créatif +# ------------------------- +#motion_module: scripts/modules/motion_module_lite.py +#motion_module: scripts/modules/motion_module_show.py +motion_module: scripts/modules/motion_module_show_safe.py # Cool effect +#motion_module: scripts/modules/Motion_module_Enhanced.py +motion_module_device: cpu + +# ------------------------- +# LoRA / UNet +# ------------------------- +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +unet_dtype: float32 +dtype: float16 +enable_xformers_memory_efficient_attention: true +batch_size: 1 +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true + +# ------------------------- +# Seeds & contrôle +# ------------------------- +seed: 1234 + +# ------------------------- +# Prompts +# ------------------------- +prompt: + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, comic book style Marvel, cel shading, bold outlines, dramatic action pose, neon lighting, vibrant colors, glowing effects, dynamic perspective, cinematic composition, energy lines, stylized shadows, inked lines, comic panel framing, halftone shading, high contrast lighting, happy expression" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, comic book style Marvel, cel shading, bold outlines, dramatic action pose, neon lighting, vibrant colors, glowing effects, dynamic perspective, cinematic composition, energy lines, stylized shadows, inked lines, 
comic panel framing, halftone shading, high contrast lighting, happy expression" + +n_prompt: + - "blurry, deformed, extra limbs, bad anatomy, low quality, messy background, dark colors, scary, realistic, human features" + - "blurry, deformed, extra limbs, bad anatomy, low quality, messy background, dark colors, scary, realistic, human features" + + +input_images: + - "input/540-960/5.png" + - "input/540-960/6.png" diff --git a/configs/prompts/0_n3r/960.yaml b/configs/prompts/0_n3r/960.yaml new file mode 100644 index 00000000..b41daa3f --- /dev/null +++ b/configs/prompts/0_n3r/960.yaml @@ -0,0 +1,99 @@ +# --------------------------------------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Optimisé pour créativité + mouvement dynamique +# --------------------------------------------------------- +fps: 12 +num_fraps_per_image: 3 # plus de frames par image = interpolation plus fluide +upscale_factor: 2 +steps: 50 +guidance_scale: 7.5 +guidance_scale_end: 6.5 # légère décroissance pour variation naturelle +init_image_scale: 0.85 +init_image_scale_end: 0.75 # Plus proche de init_image_scale pour limiter le flou +creative_noise: 0.08 +creative_noise_end: 0.10 # Progression plus douce +latent_scale_boost: 1.0 +final_latent_scale: 0.25 +transition_frames: 4 # Plus de frames pour une interpolation uniforme +latent_injection: 0.85 + +block_size: 256 # Assurez-vous que divisible par la résolution +use_n3r_model: true +use_n3r_pro_net: true +use_mini_gpu: true +n3r_L: 6 +n3r_N_samples: 12 +n3r_L_low: 4 +n3r_L_high: 8 +verbose: False + +W: 536 +H: 960 +overlap: 96 +tile_size: 64 + +# ------------------------- +# VAE et upscale +# ------------------------- +vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model.safetensors" +vae_device: cpu +vae_slicing: true +vae_tiling: true +use_real_esrgan: true + +# ------------------------- +# Motion module → dynamique et créatif +# ------------------------- + +# Vérifier le motion module → 
certains modules “lite” sont trop conservateurs avec des images petites ou des frames peu nombreuses. +#motion_module: scripts/modules/motion_module_tiny.py +#motion_module: scripts/modules/motion_module_cam.py +motion_module: scripts/modules/motion_module_cam2.py +#motion_module: scripts/modules/Motion_module_Masked.py +#motion_module: scripts/modules/Motion_module_Enhanced.py +#motion_module: scripts/modules/Motion_module_Wind.py +#motion_module: scripts/modules/motion_module_lite.py +#motion_module: scripts/modules/motion_module_ultralite_debug.py +#motion_module: scripts/modules/motion_ulta_lite_fix.py +#motion_module: scripts/modules/motion_module_show.py (test ok) + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +#vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +#vae_path: "/mnt/62G/huggingface/vae/vaeKlF8Anime2_klF8Anime2VAE.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model_fp16.safetensors" +vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae" + + +prompt: + - "cyber armor, mask, comic book style Marvel, cel shading, bold outlines, dramatic, neon lighting, vibrant colors, glowing effects, dynamic perspective, cinematic composition, energy lines, stylized shadows, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "superhero, consistent character, blue light, cyber armor, mask, Marvel comic style, dynamic, neon city background, glowing circuits, vibrant colors, cinematic lighting, dramatic angles, stylized shadows, action scene, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "cyber armor, mask, Marvel comic style, synthwave neon city, pink and purple 
lighting, reflective wet streets, dynamic perspective, cinematic composition, glowing atmosphere, motion energy, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "cyber armor, mask, Marvel comic style, angelic transformation, glowing neon wings, radiant energy, dramatic lighting, vibrant colors, stylized aura, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "Marvel comic style, cyberpunk city night, dense skyline, neon signs, holograms, rainy streets, cinematic composition, dramatic shadows, dynamic perspective, inked lines, comic panel framing, halftone shading, high contrast lighting" + + +n_prompt: + - "blurry, deformed, low quality, messy background, dark colors, scary, realistic" + - "blurry, deformed, stiff pose, unnatural movement, distorted, missing parts, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + - "blurry, deformed, stiff pose, unnatural movement, distorted, missing parts, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + - "blurry, deformed, stiff pose, unnatural movement, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + - "blurry, deformed, unnatural movement, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + + +input_images: + - "input/536x960/1.png" + - "input/536x960/2.png" + - "input/536x960/3.png" + - "input/536x960/4.png" + - "input/536x960/5.png" diff --git a/configs/prompts/0_n3r/Readme.txt b/configs/prompts/0_n3r/Readme.txt new file mode 100644 index 00000000..c1a5e6c1 --- /dev/null +++ b/configs/prompts/0_n3r/Readme.txt @@ -0,0 +1,23 @@ +⚡ High Quality (default) +motion: + enabled: true + 
num_fraps_per_image: 12 + transition_frames: 6 + init_image_scale: 0.85 + creative_noise_end: 0.08 + latent_clamp: true + smoothing_alpha: 0.3 +⚡ Fast Preview +motion: + enabled: true + num_fraps_per_image: 8 + transition_frames: 4 + init_image_scale: 0.9 + creative_noise_end: 0.06 +🎨 Experimental / Creative +motion: + enabled: true + num_fraps_per_image: 16 + transition_frames: 8 + init_image_scale: 0.75 + creative_noise_end: 0.1 diff --git a/configs/prompts/2_animate/1080.yaml b/configs/prompts/2_animate/1080.yaml new file mode 100644 index 00000000..56de564b --- /dev/null +++ b/configs/prompts/2_animate/1080.yaml @@ -0,0 +1,116 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +# ------------------------- +use_n3r_model: true +use_mini_gpu: false +n3r_L: 6 +n3r_N_samples: 12 +n3r_L_low: 4 +n3r_L_high: 8 +verbose: true +#L_low, L_high = 4, 8 # plus de couches = plus de structure +#final_latent_scale: 0.18215 +final_latent_scale: 0.25 +tile_size: 64 +#--------------------latent_scale_boost 5.71 ------------ +latent_injection: 1.00 # 85% +latent_scale_boost: 5.71 #0.18215 # 5.71 +# +W: 512 +H: 512 +#W: 512 +#H: 512 +L: 4 +steps: 50 # diffusion steps faible pour VRAM - Augmenter steps ou fps → le motion module a plus de temps pour transformer les latents. (20) + +fps: 12 # moins de frames simultanées +#init_image_scale: 0.25 # Exploration maximale du latent 0. (0.25) - Augmenter init_image_scale → donne plus de signal à partir de ton image d’initiation. (1.0 image original) +init_image_scale: 0.85 # presque tout le signal de l'image d'origine +creative_noise: 0.0 # moins de liberté, plus de cohérence +guidance_scale: 2.0 # un peu plus strict pour que le chat ressorte 8.0 + + + +# ------------------------- +# Modèles n3oray +# ------------------------- +# Liste des modèles à utiliser (ordre = interpolation possible) LoRA est FP16 ou compatible avec ton UNet SDXL. 
+ +n3oray_models: + #cyberpunk_style_v3: "/mnt/62G/huggingface/cyber-fp16/cyberpunk style v3_fp16.safetensors" #OK + #cybersamurai_v2: "/mnt/62G/huggingface/cyber-fp16/cybersamuraiV2E12_fp16.safetensors" #OK + #cyber_skin: "/mnt/62G/huggingface/cyber-fp16/Cyber skin_fp16.safetensors" #OK + #cyber_skin_girl: "/mnt/62G/huggingface/cyber/cyber_style_girl_v1.safetensors" #OK + +motion_module: "v3_sd15_mm.ckpt" + +upscale_factor: 2 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +unet_dtype: float32 +dtype: float16 # VRAM safe FP16 pour 4Go- +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + #type: PNDMScheduler + steps: 40 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 6 # plus de frames par image pour plus de cohérence +transition_frames: 0 # transitions plus douces +key_frames: 1 + + +#80*0.6+80 = 128 → ok. +block_size: 160 #80 ou 100 ou 110 ou 120 ou (128) ou 156 ou (160 MAX soit 160x0.6+160=256) 160 +overlap: 96 #40 ou 60 ou 66 ou 72 ou 76 ou (80) ou 100 ou (96) 96 +vae_slicing: true +vae_tiling: true +vae_device: cpu # force VAE sur CPU pour VRAM safe +motion_module_device: cpu +#motion_module_device: cuda + +# Vérifier le motion module → certains modules “lite” sont trop conservateurs avec des images petites ou des frames peu nombreuses. 
+#motion_module: scripts/modules/motion_module_tiny.py +#motion_module: scripts/modules/Motion_module_Masked.py +#motion_module: scripts/modules/Motion_module_Enhanced.py +motion_module: scripts/modules/Motion_module_Wind.py +#motion_module: scripts/modules/motion_module_lite.py +#motion_module: scripts/modules/motion_module_ultralite_debug.py +#motion_module: scripts/modules/motion_ulta_lite_fix.py +#motion_module: scripts/modules/motion_module_show.py (test ok) + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +#vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +#vae_path: "/mnt/62G/huggingface/vae/vaeKlF8Anime2_klF8Anime2VAE.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model_fp16.safetensors" +vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae" + + + +prompt: + - "Cute little kawaii cat, big sparkling eyes, small nose, fluffy fur, sitting, happy expression, pastel colors, soft lighting, anime style, chibi, simple background, fluffy, velvety fur, soft texture, gentle sparkles around, cozy atmosphere" + - "Cute little kawaii cat, big sparkling eyes, big nose, fluffy fur, sitting, happy expression, pastel colors, high lighting, anime style, chibi, simple background, soft morning light, warm ambient, subtle rim light, glowing eyes" + +n_prompt: + - "blurry, deformed, extra limbs, bad anatomy, low quality, messy background, dark colors, scary, realistic, human features" + - "blurry, deformed, extra limbs, bad anatomy, low quality, messy background, dark colors, scary, realistic, human features, accurate anatomy, correct proportions, four paws visible, cute rounded ears" + +input_images: + - "input/divers/image_512x.png" + - 
"input/divers/1.png" + diff --git a/configs/prompts/2_animate/128.yaml b/configs/prompts/2_animate/128.yaml new file mode 100644 index 00000000..8f116b75 --- /dev/null +++ b/configs/prompts/2_animate/128.yaml @@ -0,0 +1,45 @@ +# n3r_tiny.yaml +W: 128 +H: 128 +L: 4 +steps: 10 + +init_image_scale: 0.1 # 0.75 +#creative_noise: 0.8 # 0.1 +creative_noise: 0.03 # subtil et sûr max 0.8 + + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +dtype: float16 + +fps: 12 +num_frames_per_image: 6 +guidance_scale: 7.5 +use_real_esrgan: true + +#scheduler: +# type: DDIMScheduler +# steps: 22 +# beta_start: 0.00085 +# beta_end: 0.012 + +#motion_module: scripts/modules/motion_module_tiny.py +motion_module: scripts/modules/motion_module_lite.py + +seed: 1234 + +input_images: + - input/image_128x0.png + - input/image_128x1.png + - input/image_128x2.png + - input/image_128x3.png + - input/image_128x4.png + +#enable_xformers_memory_efficient_attention: true (Actif par défaut) +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" + +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, film grain, masterpiece" + +n_prompt: + - "ugly,low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, cluttered, broken motion, low resolution, inconsistent colors, messy background, poorly drawn face, disfigured, deformed, deformed text, easynegative, poorly drawn eyes" diff --git a/configs/prompts/2_animate/128p.yaml b/configs/prompts/2_animate/128p.yaml new file mode 100644 index 00000000..0ecb58bc --- /dev/null +++ b/configs/prompts/2_animate/128p.yaml @@ -0,0 +1,44 @@ +# n3r_tiny.yaml +W: 128 +H: 128 +L: 4 +steps: 8 + +init_image_scale: 0.75 # 0.75 +#creative_noise: 0.8 # 0.1 
+creative_noise: 0.03 # subtil et sûr max 0.8 + + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +dtype: float16 + +fps: 12 +num_frames_per_image: 6 +guidance_scale: 2.5 +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + steps: 8 + beta_start: 0.00085 + beta_end: 0.012 + +#motion_module: scripts/modules/motion_module_tiny.py +motion_module: scripts/modules/motion_module_lite.py + +seed: 1234 + +input_images: + - input/image_128x0.png + + +#enable_xformers_memory_efficient_attention: true (Actif par défaut) +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" + +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, film grain, masterpiece" + + - "best quality, ultra detailed, sharp focus, clean typography, readable text, clear letters, well-formed characters, correct spelling, high contrast text, professional graphic design, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, film grain, masterpiece" + +n_prompt: + - "ugly, low quality, blurry, deformed, distorted text, warped letters, broken typography, unreadable text, malformed characters, extra strokes, fused letters, overlapping letters, glitch text, noisy text, deformed anatomy, stiff pose, unnatural movement, missing parts, extra limbs, cluttered, broken motion, low resolution, inconsistent colors" diff --git a/configs/prompts/2_animate/256.yaml b/configs/prompts/2_animate/256.yaml new file mode 100644 index 00000000..71ae7329 --- /dev/null +++ b/configs/prompts/2_animate/256.yaml @@ -0,0 +1,59 @@ +# 
-------------------------
+# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff
+# Avec Motion Module Tiny et VAE offload
+# -------------------------
+
+W: 256
+H: 256
+L: 4
+steps: 4
+
+fps: 12
+init_image_scale: 0.5 # 0.75
+creative_noise: 0.03 # subtil et sûr max 0.8
+guidance_scale: 7.5 # 4.5
+
+pretrained_model_path: "/mnt/62G/huggingface/miniSD"
+dtype: float16
+use_real_esrgan: true
+
+scheduler:
+  type: DDIMScheduler
+  steps: 4
+  beta_start: 0.00085
+  beta_end: 0.012
+
+num_fraps_per_image: 4 # NOTE(review): likely a typo of num_frames_per_image (spelled that way in 128.yaml) — confirm which key the loader actually reads before renaming
+key_frames: 1
+inter_frames: 3
+
+block_size: 64
+overlap: 32 # overlap plus grand pour éviter mosaïques
+
+motion_module: scripts/modules/motion_module_tiny.py
+
+device: cuda
+offload_folder: /tmp/offload
+accelerate: true
+low_cpu_mem_usage: true
+batch_size: 1 # chaque image traitée individuellement
+
+seed: 1234
+
+enable_xformers_memory_efficient_attention: true
+
+# VAE complet pour couleurs correctes
+vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors"
+vae_slicing: true
+vae_tiling: true
+dtype: float16 # VRAM safe FP16 pour 4Go -- NOTE(review): duplicate key; `dtype` is already set above and most YAML loaders keep only the last occurrence
+
+
+
+prompt:
+  - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, high light"
+n_prompt:
+  - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background"
+
+input_images:
+  - "input/256/image_256x0.png"
diff --git a/configs/prompts/2_animate/256_quality.yaml b/configs/prompts/2_animate/256_quality.yaml
new file mode 100644
index 00000000..3785702e
--- /dev/null
+++ b/configs/prompts/2_animate/256_quality.yaml
@@ -0,0 +1,35 @@
+W: 128
+H: 128
+L: 4
+
+num_frames_per_image: 10
+key_frames: 1
+inter_frames: 1
+
+init_image_scale: 0.5
+creative_noise: 0.1
+steps: 
12 +guidance_scale: 7.5 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +fps: 12 + +use_real_esrgan: false + +motion_module: scripts/modules/motion_module_tiny.py +dtype: float16 + +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +vae_slicing: false +vae_tiling: false + + +seed: 42 + +prompt: + - "best quality, 1 girl walking, dynamic pose" +n_prompt: + - "low quality, blurry, deformed" + +input_images: + - "input/256/image_256x0.png" diff --git a/configs/prompts/2_animate/256_speed.yaml b/configs/prompts/2_animate/256_speed.yaml new file mode 100644 index 00000000..eda13290 --- /dev/null +++ b/configs/prompts/2_animate/256_speed.yaml @@ -0,0 +1,28 @@ +W: 128 +H: 128 +L: 4 +steps: 5 +num_frames_per_image: 1 +key_frames: 1 +inter_frames: 1 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +fps: 12 +init_image_scale: 0.85 +use_real_esrgan: false + +motion_module: null +dtype: float16 + +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +vae_slicing: false +vae_tiling: false +guidance_scale: 3.5 + +prompt: + - "best quality, 1 girl walking, dynamic pose" +n_prompt: + - "low quality, blurry, deformed" + +input_images: + - "input/256/image_256x0.png" diff --git a/configs/prompts/2_animate/256p.yaml b/configs/prompts/2_animate/256p.yaml new file mode 100644 index 00000000..9118e65c --- /dev/null +++ b/configs/prompts/2_animate/256p.yaml @@ -0,0 +1,59 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +# ------------------------- + +W: 256 +H: 256 +L: 4 +steps: 1 + +fps: 12 +init_image_scale: 0.75 # 0.75 +creative_noise: 0.03 # subtil et sûr max 0.8 +guidance_scale: 1.5 # 4.5 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +dtype: float16 +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + steps: 1 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 12 +key_frames: 1 +inter_frames: 3 + 
+block_size: 64 +overlap: 32 # overlap plus grand pour éviter mosaïques + +motion_module: scripts/modules/motion_module_tiny.py + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +vae_slicing: true +vae_tiling: true +dtype: float16 # VRAM safe FP16 pour 4Go + + + +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, high light" +n_prompt: + - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background" + +input_images: + - "input/256/image_256x0.png" diff --git a/configs/prompts/2_animate/512-c.yaml b/configs/prompts/2_animate/512-c.yaml new file mode 100644 index 00000000..8198f3f4 --- /dev/null +++ b/configs/prompts/2_animate/512-c.yaml @@ -0,0 +1,122 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +latent_injection: 0.90 # 85% proche de l'image init - 0.6 on s'éloigne de l'init - 0.5 creatif - Valeur Stable (0.70) +creative_noise: 0.00 # moins de liberté, plus de cohérence 0.08 ou 0.15 rendu creatif; 0.00 rendu cinéma (0.10) -- Valeur test à conservé (0.00) +# Exploration maximale du latent 0. (0.25) - Augmenter init_image_scale → donne plus de signal à partir de ton image d’initiation. 
(1.0 image original) +init_image_scale: 0.5 # - on s'éloigne de l'init 0.5 - presque tout le signal de l'image d'origine -- Valeur test 0.85 +# guidance_scale = 12 → les prompts positifs et négatifs auront un impact plus marqué. +guidance_scale: 3.5 # un peu plus strict pour que le chat ressorte 8.0 - 1.5 rendu proche de l'input - 4.0 creatif - Valeur test (2.7) +steps: 50 # diffusion steps faible pour VRAM - Augmenter steps ou fps → le motion module a plus de temps pour transformer les latents. (20) +# ------------------------- +use_n3r_model: true +use_mini_gpu: true +n3r_L: 6 +n3r_N_samples: 12 +n3r_L_low: 4 +n3r_L_high: 8 +verbose: False + +final_latent_scale: 0.25 +tile_size: 64 +#--------------------latent_scale_boost 5.71 ------------ +latent_scale_boost: 1.0 # Permet d'augmenter le rendu à utiliser doucement :) +# +W: 540 +H: 960 + +L: 4 + + +fps: 12 # moins de frames simultanées + +# ------------------------- +# Modèles n3oray +# ------------------------- +# Liste des modèles à utiliser (ordre = interpolation possible) LoRA est FP16 ou compatible avec ton UNet SDXL. + +n3oray_models: + #cyber_skin: "/mnt/62G/huggingface/cyber-fp16/Cyber skin_fp16.safetensors" #OK + +motion_module: "v3_sd15_mm.ckpt" + +upscale_factor: 2 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +unet_dtype: float32 +dtype: float16 # VRAM safe FP16 pour 4Go- +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + #type: PNDMScheduler + steps: 40 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 10 # plus de frames par image pour plus de cohérence +transition_frames: 2 # transitions plus douces +key_frames: 1 + + +#80*0.6+80 = 128 → ok. 
+block_size: 160 #80 ou 100 ou 110 ou 120 ou (128) ou 156 ou (160 MAX soit 160x0.6+160=256) 160 +overlap: 96 #40 ou 60 ou 66 ou 72 ou 76 ou (80) ou 100 ou (96) 96 +vae_slicing: true +vae_tiling: true +vae_device: cpu # force VAE sur CPU pour VRAM safe +motion_module_device: cpu +#motion_module_device: cuda + +# Vérifier le motion module → certains modules “lite” sont trop conservateurs avec des images petites ou des frames peu nombreuses. +#motion_module: scripts/modules/motion_module_tiny.py +#motion_module: scripts/modules/motion_module_cam.py +motion_module: scripts/modules/motion_module_cam2.py +#motion_module: scripts/modules/Motion_module_Masked.py +#motion_module: scripts/modules/Motion_module_Enhanced.py +#motion_module: scripts/modules/Motion_module_Wind.py +#motion_module: scripts/modules/motion_module_lite.py +#motion_module: scripts/modules/motion_module_ultralite_debug.py +#motion_module: scripts/modules/motion_ulta_lite_fix.py +#motion_module: scripts/modules/motion_module_show.py (test ok) + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +#vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +#vae_path: "/mnt/62G/huggingface/vae/vaeKlF8Anime2_klF8Anime2VAE.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model_fp16.safetensors" +vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae" + + +prompt: + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, comic book style Marvel, cel shading, bold outlines, dramatic action pose, neon lighting, vibrant colors, glowing effects, dynamic perspective, cinematic composition, energy lines, stylized shadows, inked 
lines, comic panel framing, halftone shading, high contrast lighting, happy expression" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, Marvel comic style, dynamic pose, neon city background, glowing circuits, vibrant colors, cinematic lighting, dramatic angles, stylized shadows, action scene, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, katana on back, Marvel comic style, synthwave neon city, pink and purple lighting, reflective wet streets, dynamic perspective, cinematic composition, glowing atmosphere, motion energy, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, katana on back, Marvel comic style, angelic transformation, glowing neon wings, radiant energy, floating pose, dramatic lighting, vibrant colors, stylized aura, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "female superhero, consistent character, same face, long flowing black hair, blue eyes, upper body, cyber armor, mask, katana on back, Marvel comic style, cyberpunk city night, dense skyline, neon signs, holograms, rainy streets, cinematic composition, dramatic shadows, dynamic perspective, inked lines, comic panel framing, halftone shading, high contrast lighting" + + +n_prompt: + - "blurry, deformed, extra limbs, bad anatomy, low quality, messy background, dark colors, scary, realistic, human features" + - "blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad anatomy, extra fingers" + - "blurry, deformed, stiff pose, 
unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad anatomy, extra fingers" + - "blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad anatomy, extra fingers" + - "blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad anatomy, extra fingers" + + +input_images: + - "input/540-960/1.png" + - "input/540-960/2.png" + - "input/540-960/3.png" + - "input/540-960/4.png" + - "input/540-960/5.png" diff --git a/configs/prompts/2_animate/512.yaml b/configs/prompts/2_animate/512.yaml new file mode 100644 index 00000000..dbe29a0c --- /dev/null +++ b/configs/prompts/2_animate/512.yaml @@ -0,0 +1,59 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +# ------------------------- + +W: 512 +H: 512 +L: 4 +steps: 1 + +fps: 12 +init_image_scale: 0.5 # 0.75 +creative_noise: 0.03 # subtil et sûr max 0.8 +guidance_scale: 7.5 # 4.5 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +dtype: float16 +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + steps: 1 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 12 +key_frames: 1 +inter_frames: 3 + +block_size: 64 +overlap: 32 # overlap plus grand pour éviter mosaïques + +motion_module: scripts/modules/motion_module_tiny.py + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour 
couleurs correctes +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +vae_slicing: true +vae_tiling: true +dtype: float16 # VRAM safe FP16 pour 4Go + + + +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, high light" +n_prompt: + - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background" + +input_images: + - "input/512/image_512x.png" diff --git a/configs/prompts/2_animate/640.yaml b/configs/prompts/2_animate/640.yaml new file mode 100644 index 00000000..af31a7a2 --- /dev/null +++ b/configs/prompts/2_animate/640.yaml @@ -0,0 +1,68 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +# ------------------------- + +W: 640 +H: 640 +L: 4 +steps: 20 + +fps: 5 +#init_image_scale: 0.25 # Moins de poids sur l'image d'entrée +init_image_scale: 0.15 # Exploration maximale du latent +#creative_noise: 0.1 # Plus de bruit pour une exploration plus libre +creative_noise: 0.15 +#guidance_scale: 4.5 # Valeur plus faible pour favoriser la créativité +guidance_scale: 3.8 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +dtype: float16 +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + steps: 20 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 5 +key_frames: 1 +inter_frames: 3 + +block_size: 64 +overlap: 45 # Augmentation pour réduire les traces visibles # overlap plus grand pour éviter mosaïques 32 Max 48 +#block_size: 128 +#overlap: 80 + +block_size: 76 +overlap: 48 +vae_slicing: true +vae_tiling: true +dtype: float16 # VRAM safe FP16 pour 4Go + +motion_module: 
scripts/modules/motion_module_tiny.py + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" + + + + +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, high light,vivid, surrealistic, futuristic, abstract, colorful, alien landscape" +n_prompt: + - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + +input_images: + - "input/640/image_640.png" diff --git a/configs/prompts/2_animate/640x512.yaml b/configs/prompts/2_animate/640x512.yaml new file mode 100644 index 00000000..60787ef6 --- /dev/null +++ b/configs/prompts/2_animate/640x512.yaml @@ -0,0 +1,69 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +# ------------------------- + +W: 512 +H: 640 +L: 4 +steps: 20 + +fps: 5 +#init_image_scale: 0.25 # Moins de poids sur l'image d'entrée +init_image_scale: 0.15 # Exploration maximale du latent +#creative_noise: 0.1 # Plus de bruit pour une exploration plus libre +creative_noise: 0.15 +#guidance_scale: 4.5 # Valeur plus faible pour favoriser la créativité +guidance_scale: 3.8 + +upscale_factor: 2 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +dtype: float16 +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + steps: 25 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 5 
+key_frames: 1 +inter_frames: 3 + + +block_size: 100 +overlap: 60 +vae_slicing: true +vae_tiling: true +dtype: float16 # VRAM safe FP16 pour 4Go + +#motion_module: scripts/modules/motion_module_tiny.py +motion_module: scripts/modules/Motion_module_Masked.py +motion_module: scripts/modules/Motion_module_Enhanced.py + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" + + + + +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, high light,vivid, surrealistic, futuristic, abstract, colorful, alien landscape, floating islands, magical lighting, cosmic clouds, bioluminescent plants, vibrant abstract patterns" +n_prompt: + - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + +input_images: + - "input/640/mon_image_640-1.png" + - "input/640/mon_image_640-2.png" diff --git a/configs/prompts/2_animate/768.yaml b/configs/prompts/2_animate/768.yaml new file mode 100644 index 00000000..66f3b219 --- /dev/null +++ b/configs/prompts/2_animate/768.yaml @@ -0,0 +1,59 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +# ------------------------- + +W: 768 +H: 768 +L: 4 +steps: 1 + +fps: 5 +init_image_scale: 0.5 # 0.75 +creative_noise: 0.03 # subtil et sûr max 0.8 +guidance_scale: 7.5 # 4.5 + 
+pretrained_model_path: "/mnt/62G/huggingface/miniSD" +dtype: float16 +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + steps: 1 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 5 +key_frames: 1 +inter_frames: 3 + +block_size: 64 +overlap: 32 # overlap plus grand pour éviter mosaïques + +motion_module: scripts/modules/motion_module_tiny.py + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +vae_slicing: true +vae_tiling: true +dtype: float16 # VRAM safe FP16 pour 4Go + + + +prompt: + - "best quality, 1 girl walking, natural dynamic pose, arms swinging, legs mid-step, balanced torso and head, flowing hair, consistent outfit and hairstyle, outdoor spring, cherry blossoms, petals, smooth motion across frames, coherent animation, vibrant colors, cinematic lighting, high light" +n_prompt: + - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background" + +input_images: + - "input/768/image_768.png" diff --git a/configs/prompts/2_animate/960.yaml b/configs/prompts/2_animate/960.yaml new file mode 100644 index 00000000..b251800f --- /dev/null +++ b/configs/prompts/2_animate/960.yaml @@ -0,0 +1,122 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +latent_injection: 0.90 # 85% proche de l'image init - 0.6 on s'éloigne de l'init - 0.5 creatif - Valeur Stable (0.70) +creative_noise: 0.00 # moins de liberté, plus de cohérence 0.08 ou 0.15 rendu creatif; 0.00 rendu cinéma (0.10) -- Valeur test à conservé (0.00) +# Exploration maximale du latent 0. 
(0.25) - Augmenter init_image_scale → donne plus de signal à partir de ton image d’initiation. (1.0 image original) +init_image_scale: 0.5 # - on s'éloigne de l'init 0.5 - presque tout le signal de l'image d'origine -- Valeur test 0.85 +# guidance_scale = 12 → les prompts positifs et négatifs auront un impact plus marqué. +guidance_scale: 3.5 # un peu plus strict pour que le chat ressorte 8.0 - 1.5 rendu proche de l'input - 4.0 creatif - Valeur test (2.7) +steps: 50 # diffusion steps faible pour VRAM - Augmenter steps ou fps → le motion module a plus de temps pour transformer les latents. (20) +# ------------------------- +use_n3r_model: true +use_mini_gpu: true +n3r_L: 6 +n3r_N_samples: 12 +n3r_L_low: 4 +n3r_L_high: 8 +verbose: False + +final_latent_scale: 0.25 +tile_size: 64 +#--------------------latent_scale_boost 5.71 ------------ +latent_scale_boost: 1.0 # Permet d'augmenter le rendu à utiliser doucement :) +# +W: 540 +H: 960 + +L: 4 + + +fps: 12 # moins de frames simultanées + +# ------------------------- +# Modèles n3oray +# ------------------------- +# Liste des modèles à utiliser (ordre = interpolation possible) LoRA est FP16 ou compatible avec ton UNet SDXL. + +n3oray_models: + #cyber_skin: "/mnt/62G/huggingface/cyber-fp16/Cyber skin_fp16.safetensors" #OK + +motion_module: "v3_sd15_mm.ckpt" + +upscale_factor: 2 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +unet_dtype: float32 +dtype: float16 # VRAM safe FP16 pour 4Go- +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + #type: PNDMScheduler + steps: 40 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 10 # plus de frames par image pour plus de cohérence +transition_frames: 2 # transitions plus douces +key_frames: 1 + + +#80*0.6+80 = 128 → ok. 
+block_size: 160 #80 ou 100 ou 110 ou 120 ou (128) ou 156 ou (160 MAX soit 160x0.6+160=256) 160 +overlap: 96 #40 ou 60 ou 66 ou 72 ou 76 ou (80) ou 100 ou (96) 96 +vae_slicing: true +vae_tiling: true +vae_device: cpu # force VAE sur CPU pour VRAM safe +motion_module_device: cpu +#motion_module_device: cuda + +# Vérifier le motion module → certains modules “lite” sont trop conservateurs avec des images petites ou des frames peu nombreuses. +#motion_module: scripts/modules/motion_module_tiny.py +#motion_module: scripts/modules/motion_module_cam.py +motion_module: scripts/modules/motion_module_cam2.py +#motion_module: scripts/modules/Motion_module_Masked.py +#motion_module: scripts/modules/Motion_module_Enhanced.py +#motion_module: scripts/modules/Motion_module_Wind.py +#motion_module: scripts/modules/motion_module_lite.py +#motion_module: scripts/modules/motion_module_ultralite_debug.py +#motion_module: scripts/modules/motion_ulta_lite_fix.py +#motion_module: scripts/modules/motion_module_show.py (test ok) + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +#vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +#vae_path: "/mnt/62G/huggingface/vae/vaeKlF8Anime2_klF8Anime2VAE.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model_fp16.safetensors" +vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae" + + +prompt: + - "cyber armor, mask, comic book style Marvel, cel shading, bold outlines, dramatic, neon lighting, vibrant colors, glowing effects, dynamic perspective, cinematic composition, energy lines, stylized shadows, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "superhero, consistent character, 
blue light, cyber armor, mask, Marvel comic style, dynamic, neon city background, glowing circuits, vibrant colors, cinematic lighting, dramatic angles, stylized shadows, action scene, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "cyber armor, mask, Marvel comic style, synthwave neon city, pink and purple lighting, reflective wet streets, dynamic perspective, cinematic composition, glowing atmosphere, motion energy, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "cyber armor, mask, Marvel comic style, angelic transformation, glowing neon wings, radiant energy, dramatic lighting, vibrant colors, stylized aura, inked lines, comic panel framing, halftone shading, high contrast lighting" + - "Marvel comic style, cyberpunk city night, dense skyline, neon signs, holograms, rainy streets, cinematic composition, dramatic shadows, dynamic perspective, inked lines, comic panel framing, halftone shading, high contrast lighting" + + +n_prompt: + - "blurry, deformed, low quality, messy background, dark colors, scary, realistic" + - "blurry, deformed, stiff pose, unnatural movement, distorted, missing parts, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + - "blurry, deformed, stiff pose, unnatural movement, distorted, missing parts, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + - "blurry, deformed, stiff pose, unnatural movement, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + - "blurry, deformed, unnatural movement, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic" + + +input_images: + - "input/540x960/1.png" + - "input/540x960/2.png" + - 
"input/540x960/3.png" + - "input/540x960/4.png" + - "input/540x960/5.png" diff --git a/configs/prompts/2_animate/Readme.txt b/configs/prompts/2_animate/Readme.txt new file mode 100644 index 00000000..e5eff9ca --- /dev/null +++ b/configs/prompts/2_animate/Readme.txt @@ -0,0 +1,6 @@ +Sample prompt for low Vram MAX 4Go + +128x128 -> run with n3rHYBRID14 or n3rHYBRID21 or n3rHYBRID22 or n3rHYBRID26 or n3rfast.py or n3rspeed.py +256x256 -> run with n3rHYBRID14 or n3rHYBRID21 or n3rHYBRID22 or n3rfast.py or n3rspeed.py +512x512 -> run with n3rfast.py or n3rspeed.py +640x640 -> run with n3rfast.py or n3rspeed.py diff --git a/configs/prompts/2_animate/cyber.yaml b/configs/prompts/2_animate/cyber.yaml new file mode 100644 index 00000000..9f4e4982 --- /dev/null +++ b/configs/prompts/2_animate/cyber.yaml @@ -0,0 +1,123 @@ +# ------------------------- +# Tiny-SD 256x256 VRAM-safe FP16 pour 5D AnimateDiff +# Avec Motion Module Tiny et VAE offload +# ------------------------- + +W: 256 +H: 256 +#W: 512 +#H: 512 +L: 4 +steps: 20 # diffusion steps faible pour VRAM - Augmenter steps ou fps → le motion module a plus de temps pour transformer les latents. (20) + +fps: 12 # moins de frames simultanées +#init_image_scale: 0.25 # Exploration maximale du latent 0. (0.25) - Augmenter init_image_scale → donne plus de signal à partir de ton image d’initiation. (1.0 image original) +init_image_scale: 0.9 # signal complet depuis l'image +creative_noise: 0.07 # Plus de bruit pour une exploration plus libre (0.1) - Augmenter creative_noise → permet au motion module de générer du mouvement. +guidance_scale: 5.0 # Valeur plus faible pour favoriser la créativité (4.0) + + + +# ------------------------- +# Modèles n3oray +# ------------------------- +# Liste des modèles à utiliser (ordre = interpolation possible) LoRA est FP16 ou compatible avec ton UNet SDXL. 
+ +n3oray_models: + #cyberpunk_style_v3: "/mnt/62G/huggingface/cyber-fp16/cyberpunk style v3_fp16.safetensors" #OK + #cybersamurai_v2: "/mnt/62G/huggingface/cyber-fp16/cybersamuraiV2E12_fp16.safetensors" #OK + #cyber_skin: "/mnt/62G/huggingface/cyber-fp16/Cyber skin_fp16.safetensors" #OK + #cyber_skin_girl: "/mnt/62G/huggingface/cyber/cyber_style_girl_v1.safetensors" #OK + #magic_skin: "/mnt/62G/huggingface/cyber-fp16/magic-fantasy-mech-v1.safetensors" # KO + #civchan_sd15: "/mnt/62G/huggingface/cyber-fp16/new/CivChan_SD1.safetensors" # KO + #night_city_sd15: "/mnt/62G/huggingface/cyber-fp16/new/NightCity.safetensors" # KO + #style2077_sd15: "/mnt/62G/huggingface/cyber-fp16/new/2077_Style-10.safetensors" + #Kara_Vex_sd15: "/mnt/62G/huggingface/cyber-fp16/new/Kara_Vex_Elite_SOLO_Night_City_Cyberpunk_Genre-000008.safetensors" + #CivBotFlux_sd15: "/mnt/62G/huggingface/cyber-fp16/new/CivBotFlux.safetensors" + #Sakiko_sd15: "/mnt/62G/huggingface/cyber-fp16/new/Sakiko Ichinose3216PDXL.safetensors" + +motion_module: "v3_sd15_mm.ckpt" + +upscale_factor: 2 + +pretrained_model_path: "/mnt/62G/huggingface/miniSD" +unet_dtype: float32 +dtype: float16 # VRAM safe FP16 pour 4Go +use_real_esrgan: true + +scheduler: + type: DDIMScheduler + #type: PNDMScheduler + steps: 20 + beta_start: 0.00085 + beta_end: 0.012 + +num_fraps_per_image: 2 +key_frames: 1 +transition_frames: 2 + +#80*0.6+80 = 128 → ok. +block_size: 80 #80 ou 100 ou 110 ou 120 ou (128) ou 156 ou (160 MAX soit 160x0.6+160=256) 160 +overlap: 40 #40 ou 60 ou 66 ou 72 ou 76 ou (80) ou 100 ou (96) 96 +vae_slicing: true +vae_tiling: true +vae_device: cpu # force VAE sur CPU pour VRAM safe +motion_module_device: cpu + +# Vérifier le motion module → certains modules “lite” sont trop conservateurs avec des images petites ou des frames peu nombreuses. 
+motion_module: scripts/modules/motion_module_tiny.py +#motion_module: scripts/modules/Motion_module_Masked.py +#motion_module: scripts/modules/Motion_module_Enhanced.py +#motion_module: scripts/modules/motion_module_lite.py +#motion_module: scripts/modules/motion_module_ultralite_debug.py +#motion_module: scripts/modules/motion_ulta_lite_fix.py +#motion_module: scripts/modules/motion_module_show.py (test ok) + +device: cuda +offload_folder: /tmp/offload +accelerate: true +low_cpu_mem_usage: true +batch_size: 1 # chaque image traitée individuellement + +seed: 1234 + +enable_xformers_memory_efficient_attention: true + +# VAE complet pour couleurs correctes +#vae_path: "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors" +#vae_path: "/mnt/62G/huggingface/vae/vaeKlF8Anime2_klF8Anime2VAE.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model_fp16.safetensors" +vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae/diffusion_pytorch_model.safetensors" +#vae_path: "/mnt/62G/huggingface/miniSD-fp16/vae" + + + +prompt: + - "Super realistic, hyper realistic, hyperdetailed, cybersamurai, long flowing hair, blue eyes, black hair, upper body, cyber armor, mask, science fiction, cyborg, cyberpunk, katana on back, bright glowing background, neon lights, soft colorful lighting, masterpiece, best quality, wide-angle, 8k, dynamic lighting, dramatic shadows, motion blur, sparks, holograms, futuristic calligraphy, vibrant colors, multi-layered texture, Japanese decor, ethereal atmosphere,Hyper realistic, hyper realistic, hyperdetailed, cybersamurai, long hair, green eyes, black hair, upper body, cyber armor, cyber mask, science fiction, cyborg, cyberpunk, katana on back, bright glowing background, neon lights, soft colorful lighting, masterpiece, best quality, wide-angle, 8k, natural sunlight, studio lighting, HDR, vibrant colors, maximum clarity, multi-layered texture, multi-layered texture, Japanese cyber decor, ethereal atmosphere, purple 
hair" + - "Super realistic, hyper realistic, hyperdetailed, cybersamurai, long flowing hair, green eyes, blue eyes, purple hair, black hair, upper body, armor, cyber armor, mask, cyber mask, science fiction, cyborg, cyberpunk, katana on back, bright glowing background, neon lights, soft colorful lighting, masterpiece, best quality, wide-angle, 8k, dynamic lighting, dramatic shadows, motion blur, sparks, holograms, futuristic calligraphy, natural sunlight, studio lighting, HDR, vibrant colors, maximum clarity, multi-layered texture, Japanese decor, Japanese cyber decor, ethereal atmosphere" + - "Synthwave neon lights, glowing grid streets, reflective wet asphalt, futuristic 80s aesthetic, pink and purple neon glows, dynamic lighting" + - "Digital angel, luminous wings made of neon light, holographic textures, glowing circuits, ethereal floating presence, radiant aura, cyber-futuristic divine being" + - "Cyberpunk city night, dense futuristic skyline, neon signs and holograms, rainy streets, flying cars, misty atmosphere, dramatic lighting" + +n_prompt: + - "low quality, blurry, deformed, stiff pose, unnatural movement, distorted anatomy, missing parts, extra limbs, broken motion, low resolution, inconsistent colors, messy background, realistic, clear, detailed, structured, photo-realistic, bad anatomy, extra fingers" + + +#input_images: +# - "input/512/image_512x.png" + +input_images: + - "input/256/image_256x0.png" + - "input/256/image_256x1.png" + +# prompt : Super realistic, hyper realistic, hyperdetailed, cybersamurai, long flowing hair, green eyes, blue eyes, purple hair, black hair, upper body, armor, cyber armor, mask, cyber mask, science fiction, cyborg, cyberpunk, katana on back, bright glowing background, neon lights, soft colorful lighting, masterpiece, best quality, wide-angle, 8k, dynamic lighting, dramatic shadows, motion blur, sparks, holograms, futuristic calligraphy, natural sunlight, studio lighting, HDR, vibrant colors, maximum clarity, 
multi-layered texture, Japanese decor, Japanese cyber decor, ethereal atmosphere +# prompt2: Super realistic, hyper realistic, hyperdetailed, cybersamurai, long flowing hair, green eyes, blue eyes, purple hair, black hair, upper body, armor, cyber armor, mask, cyber mask, science fiction, cyborg, cyberpunk, katana on back, bright glowing background, neon lights, soft colorful lighting, masterpiece, best quality, wide-angle, 8k, dynamic lighting, dramatic shadows, motion blur, sparks, holograms, futuristic calligraphy, natural sunlight, studio lighting, HDR, vibrant colors, maximum clarity, multi-layered texture, Japanese decor, Japanese cyber decor, ethereal atmosphere. she in shifting smoke. Fluid, lethal, eternal. +# Animate: Elle sort son katana fait des movements rapidement avec des effects de lumière sur la lame, animation de propultion de la camera vers l'arrière comme si elle reçois les coups. Mouvement réaliste, fluidité du mouvement, pas de déformation de la lame, pas de dégradation du visage, zoom rapide sur le visage. +# Animate2: Animate: Elle sort son katana fait des movements rapidement avec des effects de lumière sur la lame, animation de propultion de la camera vers l'arrière comme si elle reçois les coups. Mouvement réaliste, fluidité du mouvement, pas de déformation de la lame, pas de dégradation du visage, zoom rapide sur le visage. Elle esquive et bloque des projectils shuriken dans sa direction + + + +# Animate: She draws her katana and performs rapid movements with lighting effects on the blade. The camera pulls back as if she's being hit. Realistic movement, fluid motion, no blade distortion, no facial degradation, and a quick zoom on her face. + +# Animate2: She draws her katana and performs rapid movements with lighting effects on the blade. The camera pulls back as if she's being hit. Realistic movement, fluid motion, no blade distortion, no facial degradation, and a quick zoom on her face. 
She dodges and blocks shuriken projectiles coming her way. + diff --git a/input/128/Readme.txt b/input/128/Readme.txt new file mode 100644 index 00000000..78e017ae --- /dev/null +++ b/input/128/Readme.txt @@ -0,0 +1 @@ +add your file 128x128 here diff --git a/input/256/4.png b/input/256/4.png new file mode 100644 index 00000000..15d89103 Binary files /dev/null and b/input/256/4.png differ diff --git a/input/256/Readme.txt b/input/256/Readme.txt new file mode 100644 index 00000000..046de24d --- /dev/null +++ b/input/256/Readme.txt @@ -0,0 +1 @@ +Add your pics in 256x256 here diff --git a/input/536x960/0.png b/input/536x960/0.png new file mode 100644 index 00000000..d2c603bf Binary files /dev/null and b/input/536x960/0.png differ diff --git a/input/536x960/1.png b/input/536x960/1.png new file mode 100644 index 00000000..05696c46 Binary files /dev/null and b/input/536x960/1.png differ diff --git a/input/536x960/2.png b/input/536x960/2.png new file mode 100644 index 00000000..66e12e2e Binary files /dev/null and b/input/536x960/2.png differ diff --git a/input/536x960/3.png b/input/536x960/3.png new file mode 100644 index 00000000..0b70db8d Binary files /dev/null and b/input/536x960/3.png differ diff --git a/input/536x960/4.png b/input/536x960/4.png new file mode 100644 index 00000000..11625a50 Binary files /dev/null and b/input/536x960/4.png differ diff --git a/input/536x960/5.png b/input/536x960/5.png new file mode 100644 index 00000000..7d0070f2 Binary files /dev/null and b/input/536x960/5.png differ diff --git a/input/536x960/6.png b/input/536x960/6.png new file mode 100644 index 00000000..88ccb169 Binary files /dev/null and b/input/536x960/6.png differ diff --git a/input/536x960/7.png b/input/536x960/7.png new file mode 100644 index 00000000..c91c6454 Binary files /dev/null and b/input/536x960/7.png differ diff --git a/input/536x960/8.png b/input/536x960/8.png new file mode 100644 index 00000000..dcb6c23d Binary files /dev/null and b/input/536x960/8.png differ diff 
--git a/input/536x960/Readme.txt b/input/536x960/Readme.txt new file mode 100644 index 00000000..b06f6315 --- /dev/null +++ b/input/536x960/Readme.txt @@ -0,0 +1 @@ +Cyber Test diff --git a/input/Readm.txt b/input/Readm.txt new file mode 100644 index 00000000..c19dd496 --- /dev/null +++ b/input/Readm.txt @@ -0,0 +1 @@ +Here sample image for script diff --git a/input/image_128.png b/input/image_128.png new file mode 100644 index 00000000..1222f7e4 Binary files /dev/null and b/input/image_128.png differ diff --git a/input/image_128x0.jpg b/input/image_128x0.jpg new file mode 100644 index 00000000..a40653ce Binary files /dev/null and b/input/image_128x0.jpg differ diff --git a/input/image_128x0.png b/input/image_128x0.png new file mode 100644 index 00000000..f718be9a Binary files /dev/null and b/input/image_128x0.png differ diff --git a/input/image_128x1.png b/input/image_128x1.png new file mode 100644 index 00000000..ff861218 Binary files /dev/null and b/input/image_128x1.png differ diff --git a/input/image_128x2.png b/input/image_128x2.png new file mode 100644 index 00000000..5760a7c8 Binary files /dev/null and b/input/image_128x2.png differ diff --git a/input/image_128x3.png b/input/image_128x3.png new file mode 100644 index 00000000..4cf30393 Binary files /dev/null and b/input/image_128x3.png differ diff --git a/outputs/Readme.txt b/outputs/Readme.txt new file mode 100644 index 00000000..401f387a --- /dev/null +++ b/outputs/Readme.txt @@ -0,0 +1 @@ +Visualisation et test diff --git a/outputs/frame_00000.png b/outputs/frame_00000.png new file mode 100644 index 00000000..a7f572ba Binary files /dev/null and b/outputs/frame_00000.png differ diff --git a/requirements-full.txt b/requirements-full.txt new file mode 100644 index 00000000..6a5286d7 --- /dev/null +++ b/requirements-full.txt @@ -0,0 +1,140 @@ +absl-py==2.4.0 +accelerate==0.22.0 +aiofiles==25.1.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.3 +aiosignal==1.4.0 +altair==6.0.0 +annotated-doc==0.0.4 
+annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.13.0 +attrs==26.1.0 +certifi==2026.2.25 +cffi==2.0.0 +charset-normalizer==3.4.6 +click==8.3.1 +contourpy==1.3.3 +cycler==0.12.1 +decord==0.6.0 +diffusers==0.37.0 +einops==0.8.2 +fastapi==0.135.1 +ffmpeg-python==0.2.0 +ffmpy==1.0.0 +filelock==3.25.2 +flatbuffers==25.12.19 +fonttools==4.62.1 +frozenlist==1.8.0 +fsspec==2026.3.0 +future==1.0.0 +gradio==3.36.1 +gradio_client==2.4.0 +h11==0.16.0 +hf-xet==1.4.2 +httpcore==1.0.9 +httpx==0.28.1 +huggingface_hub==1.6.0 +idna==3.11 +imageio==2.27.0 +imageio-ffmpeg==0.4.9 +importlib_metadata==9.0.0 +jax==0.7.1 +jaxlib==0.7.1 +Jinja2==3.1.6 +jsonschema==4.26.0 +jsonschema-specifications==2025.9.1 +kiwisolver==1.5.0 +linkify-it-py==2.1.0 +markdown-it-py==2.2.0 +MarkupSafe==3.0.3 +matplotlib==3.10.8 +mdit-py-plugins==0.3.3 +mdurl==0.1.2 +mediapipe==0.10.14 +ml_dtypes==0.5.4 +mpmath==1.3.0 +multidict==6.7.1 +narwhals==2.18.1 +networkx==3.6.1 +numpy==1.26.4 +nvidia-cublas==13.0.0.19 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti==13.0.48 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc==13.0.48 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime==13.0.48 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cudnn-cu13==9.13.0.50 +nvidia-cufft==12.0.0.15 +nvidia-cufft-cu12==11.0.2.54 +nvidia-cufile==1.15.0.42 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand==10.4.0.35 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver==12.0.3.29 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse==12.6.2.49 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-cusparselt-cu12==0.7.1 +nvidia-cusparselt-cu13==0.8.0 +nvidia-nccl-cu12==2.21.5 +nvidia-nccl-cu13==2.27.7 +nvidia-nvjitlink==13.0.39 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvshmem-cu12==3.3.20 +nvidia-nvshmem-cu13==3.3.24 +nvidia-nvtx==13.0.39 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.3.0 +opencv-contrib-python==4.11.0.86 +opt_einsum==3.4.0 +orjson==3.11.7 +packaging==26.0 +pandas==2.3.3 +pillow==12.1.1 
+propcache==0.4.1 +protobuf==4.25.9 +psutil==7.2.2 +pycparser==3.0 +pydantic==2.12.5 +pydantic_core==2.41.5 +pydub==0.25.1 +Pygments==2.19.2 +pyparsing==3.3.2 +python-dateutil==2.9.0.post0 +python-multipart==0.0.22 +pytz==2026.1.post1 +PyYAML==6.0.3 +referencing==0.37.0 +regex==2026.2.28 +requests==2.32.5 +rich==14.3.3 +rpds-py==0.30.0 +safetensors==0.7.0 +scipy==1.17.1 +semantic-version==2.10.0 +shellingham==1.5.4 +six==1.17.0 +sounddevice==0.5.5 +starlette==0.52.1 +sympy==1.14.0 +tokenizers==0.22.2 +torch==2.9.0+cu130 +torchaudio==2.9.0+cu130 +torchvision==0.24.0+cu130 +tqdm==4.67.3 +transformers==5.3.0 +triton==3.5.0 +typer==0.24.1 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.3 +uc-micro-py==2.0.0 +urllib3==2.6.3 +uvicorn==0.41.0 +websockets==16.0 +yarl==1.23.0 +zipp==3.23.0 diff --git a/requirements-new.txt b/requirements-new.txt new file mode 100644 index 00000000..7e626321 --- /dev/null +++ b/requirements-new.txt @@ -0,0 +1,41 @@ +# PyTorch (stable avec ton GPU) +torch==2.5.1 +torchvision==0.20.1 +torchaudio==2.5.1 + +# Core SD +diffusers==0.37.0 +transformers==5.3.0 +accelerate==0.22.0 +safetensors==0.7.0 +huggingface_hub==1.6.0 + +# Utils +numpy<2 +pillow +einops +omegaconf +scipy +pandas + +# Control / vision +mediapipe==0.10.14 +opencv-contrib-python + +# API / UI +gradio==3.36.1 +fastapi +uvicorn +httpx + +# Video / audio +decord +ffmpeg-python +imageio +imageio-ffmpeg +pydub + +# Misc +tqdm +rich +protobuf>=4.25.3,<5 diff --git a/requirements_min.txt b/requirements_min.txt new file mode 100644 index 00000000..45805a1c --- /dev/null +++ b/requirements_min.txt @@ -0,0 +1,58 @@ +# PyTorch + CUDA 12 +torch==2.3.1 +torchvision==0.16.1 +triton==2.3.1 +xformers==0.0.27 + +# NVIDIA CUDA / cuBLAS / cuDNN / NCCL / cuFFT / cuSolver / cuRand / cuSparse +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 
+nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-nccl-cu12==2.27.5 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvtx-cu12==12.8.90 +nvidia-nvshmem-cu12==3.4.5 + +# Core ML / SD +diffusers==0.25.0 +transformers==4.25.1 +safetensors==0.7.0 +accelerate==0.22.0 +huggingface_hub==0.15.1 + +# Utilities +numpy<2 +pillow==12.1.1 +matplotlib==3.10.8 +einops==0.8.2 +omegaconf==2.3.0 +scipy==1.12.2 +pandas==2.3.3 +typing-extensions==4.15.0 +setuptools==82.0.0 +packaging==26.0 +python-dateutil==2.9.0.post0 +requests==2.32.5 +tqdm==4.67.3 +rich==14.3.3 +protobuf==6.33.5 + +# Web / API / GUI +gradio==3.36.1 +fastapi==0.135.1 +uvicorn==0.41.0 +httpx==0.28.1 +starlette==0.52.1 + +# Audio / Video +decord==0.6.0 +ffmpeg-python==0.2.0 +pydub==0.25.1 +imageio==2.27.0 +imageio-ffmpeg==0.4.9 diff --git a/run.sh b/run.sh new file mode 100644 index 00000000..1c6f211b --- /dev/null +++ b/run.sh @@ -0,0 +1,151 @@ +#!/bin/bash +set -e + +MODEL_PATH="/mnt/62G/huggingface/miniSD" +DEVICE="cuda" + +clear +echo "========================================" +echo " N3R LAUNCHER INTERACTIF " +echo "========================================" +echo "" + +# ------------------------- +# Température initiale +# ------------------------- +TEMP_BEFORE=$(nvidia-smi --query-gpu=temperature.gpu --format=csv,noheader,nounits) +echo "🌡️ Température GPU actuelle : ${TEMP_BEFORE}°C" +echo "" + +# ------------------------- +# MENU CONFIG +# ------------------------- +echo "Choisir la configuration :" +select CONFIG_CHOICE in \ + "128" \ + "128p" \ + "256" \ + "256p" \ + "512" \ + "512x640" \ + "640" \ + "Quitter" +do + case $CONFIG_CHOICE in + 128) CONFIG="configs/prompts/0_animate/128.yaml"; break ;; + 128p) CONFIG="configs/prompts/2_animate/128p.yaml"; break ;; + 256) CONFIG="configs/prompts/2_animate/256.yaml"; break ;; + 256p) CONFIG="configs/prompts/1_animate/256p.yaml"; break ;; + 512) CONFIG="configs/prompts/2_animate/512.yaml"; break ;; + 512x640) 
CONFIG="configs/prompts/2_animate/640x512.yaml"; break ;; + 640) CONFIG="configs/prompts/2_animate/640.yaml"; break ;; + Quitter) exit 0 ;; + *) echo "Choix invalide." ;; + esac +done + +echo "" +echo "Choisir le script :" +select SCRIPT_CHOICE in \ + "n3rfast" \ + "n3rspeed" \ + "n3rcreative" \ + "n3rHYBRID21" \ + "n3rHYBRID22" \ + "n3rHYBRID26" \ + "n3rHYBRID14" \ + "Quitter" +do + case $SCRIPT_CHOICE in + Quitter) exit 0 ;; + *) + SCRIPT="$SCRIPT_CHOICE" + break + ;; + esac +done + +# ------------------------- +# Vérification compatibilité +# ------------------------- +if [[ "$SCRIPT" == "n3rHYBRID26" && "$CONFIG_CHOICE" == "256" ]]; then + echo "❌ n3rHYBRID26 ne supporte pas 256." + exit 1 +fi + +if [[ "$SCRIPT" == "n3rHYBRID14" && "$CONFIG_CHOICE" == "512" ]]; then + echo "❌ n3rHYBRID14 ne supporte pas 512." + exit 1 +fi + +# ------------------------- +# Résumé +# ------------------------- +echo "" +echo "========================================" +echo "Config : $CONFIG" +echo "Script : $SCRIPT" +echo "========================================" +echo "" + +read -p "Confirmer l'exécution ? (y/n) : " CONFIRM +if [[ "$CONFIRM" != "y" ]]; then + echo "Annulé." + exit 0 +fi + +# ------------------------- +# Exécution avec timing +# ------------------------- +echo "" +echo "🚀 Lancement..." +echo "" + +START_TIME=$(date +%s) + +python -m scripts.$SCRIPT \ + --pretrained-model-path "$MODEL_PATH" \ + --config "$CONFIG" \ + --device "$DEVICE" \ + --vae-offload \ + --fp16 + +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) + +MIN=$((DURATION / 60)) +SEC=$((DURATION % 60)) + +echo "" +echo "========================================" +echo "✅ Exécution terminée." 
+echo "⏱️ Temps total : ${MIN}m ${SEC}s" +echo "========================================" + +# ------------------------- +# Nettoyage CUDA +# ------------------------- +python - < gray.mean(dim=[2,3], keepdim=True) * 0.9).float() # sujet plus clair que fond moyen + return mask # 1 = personnage, 0 = fond + + def forward(self, latents, input_image_latent=None): + """ + latents: [B, C, F, H, W] + input_image_latent: [B, C, H, W] image d'entrée pour créer le masque personnage + """ + if latents.dim() != 5: + return latents + + B, C, F, H, W = latents.shape + device = latents.device + + if input_image_latent is not None: + person_mask = self.create_person_mask(input_image_latent).to(device) # [B,1,H,W] + person_mask = person_mask.unsqueeze(2).expand(-1, -1, F, -1, -1) # [B,1,F,H,W] + bg_mask = 1.0 - person_mask + bg_mask = bg_mask.repeat(1,C,1,1,1) # canaux + person_mask = person_mask.repeat(1,C,1,1,1) + else: + # Pas de mask fourni -> mouvement uniforme + person_mask = torch.ones_like(latents) + bg_mask = torch.ones_like(latents) + + # ------------------------- + # Mouvement personnage + # ------------------------- + y = torch.linspace(1.0, 0.0, H, device=device).view(1,1,1,H,1) + hair_mask = y ** self.hair_bias + t = torch.linspace(0, 2*math.pi, F, device=device) + wave_person = torch.sin(t * self.wave_speed).view(1,1,F,1,1) * self.wave_amplitude + + noise = torch.randn(B, C, F, H, W, device=device) * 0.1 + person_motion = noise * wave_person * hair_mask * self.strength_person * person_mask + + # ------------------------- + # Mouvement décor + # ------------------------- + decor_wave = torch.sin(t * self.decor_wave_speed).view(1,1,F,1,1) * self.decor_wave_amplitude + decor_motion = noise * decor_wave * self.strength_bg * bg_mask + + return latents + person_motion + decor_motion + +# Alias pour compatibilité +MotionModule = MotionModuleMasked diff --git a/scripts/modules/Motion_module_Wind.py b/scripts/modules/Motion_module_Wind.py new file mode 100644 index 
class MotionModuleEnhanced(nn.Module):
    """Lightweight latent motion module (VRAM-friendly, ~2 GB targets).

    Adds three kinds of motion to a latent video tensor [B, C, F, H, W]:
      * a sinusoidal "subject" oscillation carried by per-frame noise,
      * a slower, weaker "background" oscillation,
      * a global camera pan (wrap-around roll of the frame content).
    Optionally eases from the previous latents for temporal continuity,
    then clamps the result to [-1, 1].
    """

    def __init__(
        self,
        strength: float = 0.03,        # subject motion intensity; also the continuity blend factor
        wave_speed: float = 1.5,
        wave_amplitude: float = 0.7,
        decor_strength: float = 0.05,  # background motion intensity
        decor_wave_speed: float = 0.8,
        decor_wave_amplitude: float = 0.3,
        camera_shift: float = 0.05,    # fraction of H/W converted into a pan offset
    ):
        super().__init__()
        self.strength = strength
        self.wave_speed = wave_speed
        self.wave_amplitude = wave_amplitude
        self.decor_strength = decor_strength
        self.decor_wave_speed = decor_wave_speed
        self.decor_wave_amplitude = decor_wave_amplitude
        self.camera_shift = camera_shift

    def forward(self, latents, previous_latent=None):
        """Apply motion to ``latents``.

        Args:
            latents: [B, C, F, H, W] latent video (F may be 1).
            previous_latent: tensor of the same shape or None; when given,
                the output is pulled toward it for frame-to-frame continuity.

        Returns:
            Tensor of the same shape, clamped to [-1, 1]. Inputs that are
            not 5-D are returned unchanged.
        """
        if latents.dim() != 5:
            return latents

        B, C, F, H, W = latents.shape
        device = latents.device

        # One phase value per frame for the sinusoidal oscillations.
        t = torch.linspace(0, 2 * math.pi, F, device=device).view(1, 1, F, 1, 1)

        wave_person = torch.sin(t * self.wave_speed) * self.wave_amplitude
        wave_decor = torch.sin(t * self.decor_wave_speed) * self.decor_wave_amplitude

        # Light noise carrier (non-deterministic unless the torch RNG is seeded).
        noise = torch.randn_like(latents) * 0.05

        motion_person = noise * wave_person * self.strength
        motion_decor = noise * wave_decor * self.decor_strength

        # Global camera pan.
        # BUGFIX: the original rolled a freshly-created zero tensor, which made
        # the whole camera shift a no-op; we now roll the latent content itself.
        # NOTE(review): .mean() collapses the per-frame shift curve into one
        # global pan for the clip — confirm per-frame panning is not intended.
        shift_x = torch.sin(t * 0.3) * self.camera_shift
        shift_y = torch.cos(t * 0.3) * self.camera_shift
        row_shift = int(H * shift_y.mean().item())
        col_shift = int(W * shift_x.mean().item())

        # Temporal continuity: ease from the previous frame's latents.
        if previous_latent is not None:
            latents = previous_latent + (latents - previous_latent) * self.strength

        latents = latents + motion_person + motion_decor
        if row_shift or col_shift:
            # Wrap-around pan over the spatial dims (H=3, W=4).
            latents = latents.roll(shifts=(row_shift, col_shift), dims=(3, 4))

        return torch.clamp(latents, -1.0, 1.0)


# Alias kept for compatibility with existing loader scripts.
MotionModule = MotionModuleEnhanced
class MotionModuleCam(nn.Module):
    """Post-process N3R latents with a temporal drift and periodic prompt injection.

    Works on latent videos shaped [B, C, F, H, W]; anything else passes
    through untouched.
    """

    def __init__(self, strength: float = 0.05,
                 prompt_injection_alpha: float = 0.3,
                 prompt_injection_every_n_frames: int = 5):
        super().__init__()
        self.strength = strength
        self.prompt_injection_alpha = prompt_injection_alpha
        self.prompt_injection_every_n_frames = prompt_injection_every_n_frames

    def forward(self, latents, prompt_latents=None):
        """
        Args:
            latents: [B, C, F, H, W] latents produced by N3R.
            prompt_latents: [B, C, F, H, W] latents generated from the prompt
                alone, or None to skip injection.

        Returns:
            A new tensor of the same shape (the input is not mutated).
        """
        if latents.dim() != 5:
            return latents

        num_channels = latents.shape[1]
        num_frames = latents.shape[2]
        dev = latents.device

        # 1. Progressive temporal drift: 0 at the first frame, full strength
        #    at the last one.
        drift = torch.linspace(0, 1, num_frames, device=dev).view(1, 1, num_frames, 1, 1)
        latents = latents + drift * self.strength

        # 2. Zero out channels whose "memory key" is 0.0. The probe reads a
        #    single voxel of batch 0, frame 0 — frame 0 carries no drift, so
        #    this sees the original values.
        if num_channels >= 4:
            keep = torch.ones(num_channels, device=dev)
            keep[latents[0, :, 0, 0, 0] == 0.0] = 0.0
            latents = latents * keep.view(1, -1, 1, 1, 1)

        # 3. Blend the prompt latents back in on every Nth frame. The strided
        #    slice hits exactly the frames where f % N == 0.
        if prompt_latents is not None:
            alpha = self.prompt_injection_alpha
            step = self.prompt_injection_every_n_frames
            latents[:, :, ::step, :, :] = (
                latents[:, :, ::step, :, :] * (1 - alpha)
                + prompt_latents[:, :, ::step, :, :] * alpha
            )

        return latents
class MotionModuleSafePatch:
    """Safety wrapper around ``MotionModuleUltraLiteFixComplete``.

    - Replaces the wrapped module's channel concatenation with an addition
      when the output channel count doubles, so the UNet's expected shape
      is preserved (avoids the "2048 vs 4096" dimension error).
    - Injects minimal noise into frames whose magnitude is too low.
    - Optional debug display of the latents via matplotlib.

    NOTE(review): requires matplotlib and the project-local
    ``motion_ulta_lite_fix`` module at runtime.
    """

    def __init__(self, verbose=False, min_threshold=1e-3, noise_scale=1e-2):
        # verbose: print debug traces; min_threshold: magnitude below which a
        # frame is considered "dead"; noise_scale: amplitude of rescue noise.
        self.motion_module = MotionModuleUltraLiteFixComplete(verbose=False)
        self.verbose = verbose
        self.min_threshold = min_threshold
        self.noise_scale = noise_scale

    def ensure_valid_latents(self, latents):
        """Add tiny noise (in place) to any frame weaker than ``min_threshold``."""
        B, C, F, H, W = latents.shape
        for f in range(F):
            frame_abs_max = latents[:, :, f, :, :].abs().max()
            if frame_abs_max < self.min_threshold:
                latents[:, :, f, :, :] += torch.randn_like(latents[:, :, f, :, :]) * self.noise_scale
                if self.verbose:
                    print(f"[SAFE DEBUG] Frame {f} trop faible ({frame_abs_max:.6f}) → bruit injecté")
        return latents

    def show_latents(self, latents, title="Latents"):
        """Debug display: show the first 3 channels of each frame, mapped to [0, 1]."""
        if latents.abs().max() < self.min_threshold:
            if self.verbose:
                print("[SAFE DEBUG] Frames trop faibles pour affichage")
            return
        F = latents.shape[2]
        fig, axes = plt.subplots(1, F, figsize=(3*F,3))
        # BUGFIX: with a single frame plt.subplots returns one Axes, not an
        # array, so axes[f] would raise — same guard as the show_safe variant.
        if F == 1:
            axes = [axes]
        for f in range(F):
            img = latents[0, :3, f, :, :].permute(1,2,0).clamp(-1,1)
            img = (img + 1)/2.0
            axes[f].imshow(img.detach().cpu())
            axes[f].axis('off')
            axes[f].set_title(f"Frame {f}")
        fig.suptitle(title)
        plt.show()

    def __call__(self, latents, init_image_scale_override=None):
        """Run the wrapped motion module on ``latents`` with all safety patches applied."""
        latents = self.ensure_valid_latents(latents)

        # Temporarily override the wrapped module's init_image_scale if asked.
        if init_image_scale_override is not None:
            original_scale = getattr(self.motion_module, "init_image_scale", 1.0)
            self.motion_module.init_image_scale = init_image_scale_override

        # Main patch: if the module concatenated channels (doubled C), fold
        # the extra half back in by addition instead, keeping the UNet shape.
        latents_after = self.motion_module(latents)
        if latents_after.shape != latents.shape:
            if latents_after.shape[1] == 2 * latents.shape[1]:
                latents_after = latents + latents_after[:, latents.shape[1]:, :, :, :]
                if self.verbose:
                    print(f"[SAFE PATCH] Fusion add : {latents.shape} ← {latents_after.shape}")

        if init_image_scale_override is not None:
            self.motion_module.init_image_scale = original_scale

        if latents_after.abs().max() > self.min_threshold:
            self.show_latents(latents_after, title="Après Motion Module")
        elif self.verbose:
            print("[SAFE DEBUG] Motion module appliqué mais frames trop faibles pour affichage")

        return latents_after
flou et flickering""" + if self.previous_latents is None: + self.previous_latents = latents.clone() + return latents + latents = (1 - self.smoothing_alpha) * self.previous_latents + self.smoothing_alpha * latents + self.previous_latents = latents.clone() + return latents + + def show_latents(self, latents, title="Latents"): + """Affiche les latents de manière propre""" + if latents.abs().max() < self.min_threshold: + if self.verbose: + print("[SAFE DEBUG] Frames trop faibles pour affichage") + return + F = latents.shape[2] + fig, axes = plt.subplots(1, F, figsize=(3*F, 3)) + if F == 1: + axes = [axes] + for f in range(F): + img = latents[0, :3, f, :, :].permute(1,2,0).clamp(-1,1) + img = (img + 1) / 2.0 + axes[f].imshow(img.detach().cpu()) + axes[f].axis('off') + axes[f].set_title(f"Frame {f}") + fig.suptitle(title) + plt.show() + + def __call__(self, latents, init_image_scale_override=None, apply_smoothing=True): + latents = self.ensure_valid_latents(latents) + + # Override temporaire init_image_scale si demandé + if init_image_scale_override is not None: + original_scale = getattr(self.motion_module, "init_image_scale", 1.0) + self.motion_module.init_image_scale = init_image_scale_override + + # Application du motion module + latents_after = self.motion_module(latents) + + # Patch dimensionnel : addition au lieu de concat + if latents_after.shape != latents.shape: + if latents_after.shape[1] == 2 * latents.shape[1]: + latents_after = latents + latents_after[:, latents.shape[1]:, :, :, :] + if self.verbose: + print(f"[SAFE PATCH] Fusion add : {latents.shape} ← {latents_after.shape}") + + # Restauration scale original + if init_image_scale_override is not None: + self.motion_module.init_image_scale = original_scale + + # Lissage pour réduire flou / flickering + if apply_smoothing: + latents_after = self.smooth_latents(latents_after) + + # Affichage debug + if latents_after.abs().max() > self.min_threshold: + self.show_latents(latents_after, title="Après Motion 
Module") + elif self.verbose: + print("[SAFE DEBUG] Motion module appliqué mais frames trop faibles pour affichage") + + return latents_after diff --git a/scripts/modules/motion_module_tiny.py b/scripts/modules/motion_module_tiny.py new file mode 100644 index 00000000..d9128a6e --- /dev/null +++ b/scripts/modules/motion_module_tiny.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import math + + +class MotionModuleTiny(nn.Module): + def __init__( + self, + strength: float = 0.03, + hair_bias: float = 0.7, + wave_speed: float = 1.5, + wave_amplitude: float = 1.0, + ): + super().__init__() + self.strength = strength + self.hair_bias = hair_bias + self.wave_speed = wave_speed + self.wave_amplitude = wave_amplitude + + def forward(self, latents): + """ + latents: [B, C, F, H, W] + Simule un léger mouvement de cheveux naturel. + """ + + if latents.dim() != 5: + return latents + + B, C, F, H, W = latents.shape + device = latents.device + + # ------------------------- + # 1️⃣ Masque vertical (plus fort en haut) + # ------------------------- + y = torch.linspace(1.0, 0.0, H, device=device).view(1, 1, 1, H, 1) + vertical_mask = y ** self.hair_bias # accentuation haut image + + # ------------------------- + # 2️⃣ Oscillation temporelle sinusoïdale + # ------------------------- + t = torch.linspace(0, 2 * math.pi, F, device=device) + wave = torch.sin(t * self.wave_speed) * self.wave_amplitude + wave = wave.view(1, 1, F, 1, 1) + + # ------------------------- + # 3️⃣ Bruit cohérent léger + # ------------------------- + noise = torch.randn_like(latents) * 0.5 + + # ------------------------- + # 4️⃣ Combinaison + # ------------------------- + motion = noise * wave * vertical_mask * self.strength + + return latents + motion + + +# Version générique (alias) +class MotionModule(MotionModuleTiny): + pass diff --git a/scripts/modules/motion_module_ultralite_debug.py b/scripts/modules/motion_module_ultralite_debug.py new file mode 100644 index 00000000..ba1bc268 --- /dev/null 
# motion_module_ultralite_debug.py
import torch
import torch.nn as nn

class MotionModuleUltraLiteDebug(nn.Module):
    """Ultra-light motion module: adds a tiny linear temporal drift and prints
    per-frame latent statistics for debugging."""
    def __init__(self, strength: float = 0.01):
        super().__init__()
        self.strength = strength  # scale of the per-frame drift

    def forward(self, latents):
        """
        latents: [B, C, F, H, W] or [B, C, H, W]. The input rank is preserved.
        """
        original_dim = latents.dim()
        if original_dim == 4:
            latents = latents.unsqueeze(2)  # add F=1

        B, C, F, H, W = latents.shape
        device = latents.device

        # Very light motion: frame f is offset by (f / (F-1)) * strength.
        motion = torch.linspace(0, 1, F, device=device).view(1, 1, F, 1, 1)
        latents = latents + motion * self.strength

        # Frame-by-frame debug trace.
        for f in range(F):
            frame_latents = latents[:, :, f, :, :]
            print(f"[DEBUG TRACE] Frame {f}: latents min={frame_latents.min().item():.6f}, max={frame_latents.max().item():.6f}")

        # BUGFIX: restore a 4D shape only when the input was 4D. The original
        # squeezed whenever F == 1, which silently dropped the frame axis of
        # genuine 5D single-frame inputs — inconsistent with
        # MotionModuleUltraLiteFixComplete below, which keys on the input rank.
        if original_dim == 4:
            latents = latents.squeeze(2)

        return latents

class MotionModule(MotionModuleUltraLiteDebug):
    pass


# scripts/modules/motion_ulta_lite_fix.py
class MotionModuleUltraLiteFixComplete(nn.Module):
    """
    Ultra-light motion module that repairs dead (all-zero) frames and adds a
    small amount of creative noise.
    """
    def __init__(self, creative_noise=0.07, verbose=True):
        super().__init__()
        self.creative_noise = creative_noise  # base scale of all injected noise
        self.verbose = verbose                # gates every debug print

    def forward(self, latents):
        """
        Apply the motion module to the latents.
        latents: torch.Tensor [B, C, H, W] or [B, C, F, H, W]
        """
        original_dim = latents.dim()
        if original_dim == 4:
            latents = latents.unsqueeze(2)  # [B, C, 1, H, W]

        # BUGFIX: work on a copy — the original wrote dead-frame repairs
        # straight into the caller's tensor (surprising in-place side effect).
        latents = latents.clone()

        B, C, F, H, W = latents.shape

        if self.verbose:
            print(f"[SAFE DEBUG] Latents avant motion: shape={latents.shape}, "
                  f"min={latents.min():.6f}, max={latents.max():.6f}, "
                  f"mean={latents.mean():.6f}, std={latents.std():.6f}")

        # Repair dead frames (all values ~0).
        for f in range(F):
            frame_latents = latents[:, :, f, :, :]
            if (frame_latents.abs() < 1e-6).all():
                if f == 0:
                    # First frame: seed it with a little noise.
                    latents[:, :, f, :, :] = frame_latents + torch.randn_like(frame_latents) * (self.creative_noise * 0.1)
                else:
                    # Later frames: copy the previous frame plus a little noise.
                    latents[:, :, f, :, :] = latents[:, :, f-1, :, :] + torch.randn_like(frame_latents) * (self.creative_noise * 0.05)
                if self.verbose:
                    if f == 0:
                        print(f"[SAFE DEBUG] Frame {f} morte remplacée par bruit (première frame)")
                    else:
                        print(f"[SAFE DEBUG] Frame {f} corrigée depuis frame précédente")

        # Add a little creative noise to every frame.
        latents = latents + torch.randn_like(latents) * (self.creative_noise * 0.5)

        if self.verbose:
            print(f"[SAFE DEBUG] Latents après motion: min={latents.min():.6f}, "
                  f"max={latents.max():.6f}, mean={latents.mean():.6f}, "
                  f"std={latents.std():.6f}")

            # BUGFIX: the per-frame trace used to print unconditionally; it is
            # debug output and now honours `verbose` like every other print.
            for f in range(F):
                frame_latents = latents[:, :, f, :, :]
                print(f"[SAFE TRACE] Frame {f}: min={frame_latents.min():.6f}, "
                      f"max={frame_latents.max():.6f}, mean={frame_latents.mean():.6f}, "
                      f"std={frame_latents.std():.6f}")

        if original_dim == 4:
            latents = latents.squeeze(2)

        return latents
import CLIPTokenizerFast, CLIPTextModel +import math + + +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import safe_load_vae, safe_load_unet, safe_load_scheduler +from scripts.utils.vae_utils import encode_images_to_latents, decode_latents_to_image_tiled +from scripts.utils.motion_utils import load_motion_module, apply_motion_module +from scripts.utils.safe_latent import ensure_valid +from scripts.utils.video_utils import save_frames_as_video +from scripts.utils.n3r_utils import load_image_file, generate_latents + +LATENT_SCALE = 0.18215 # Tiny-SD 128x128 + +# ------------------------- +# Main pipeline +# ------------------------- + +# ------------------------- +# Image utilities +# ------------------------- +def load_images(paths, W, H, device, dtype): + all_tensors = [] + for p in paths: + if p.lower().endswith(".gif"): + img = Image.open(p) + frames = [torch.tensor(np.array(f)).permute(2,0,1).to(device=device, dtype=dtype)/127.5 - 1.0 + for f in ImageSequence.Iterator(img)] + print(f"✅ GIF chargé : {p} avec {len(frames)} frames") + all_tensors.extend(frames) + else: + t = load_image_file(p, W, H, device, dtype) + print(f"✅ Image chargée : {p}") + all_tensors.append(t) + return torch.stack(all_tensors, dim=0) + +# ------------------------- +# Encode / Decode +# ------------------------- +def encode_images_to_latents(images, vae): + device = vae.device + images = images.to(device=device, dtype=torch.float32) + with torch.no_grad(): + if images.dim() == 5: # [B,C,F,H,W] + B, C, F, H, W = images.shape + images_2d = images.view(B*F, C, H, W) + latents_2d = vae.encode(images_2d).latent_dist.sample() * LATENT_SCALE + latent_shape = latents_2d.shape + latents = latents_2d.view(B, F, latent_shape[1], latent_shape[2], latent_shape[3]) + latents = latents.permute(0, 2, 1, 3, 4).contiguous() + else: + latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE + latents = latents.unsqueeze(2) + return latents + +def 
decode_latents_to_image(latents, vae): + latents = latents.to(vae.device).float() / LATENT_SCALE + with torch.no_grad(): + img = vae.decode(latents).sample + img = (img / 2 + 0.5).clamp(0,1) + return img + +# ------------------------- +# Video utilities +# ------------------------- +def save_frames_as_video(frames, output_path, fps=12): + temp_dir = Path("temp_frames") + if temp_dir.exists(): + shutil.rmtree(temp_dir) + temp_dir.mkdir() + + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + + ( + ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + shutil.rmtree(temp_dir) + +# ------------------------- +# Main ultra safe VRAM +# ------------------------- +def main(args): + + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + fps = cfg.get("fps", 12) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + steps = cfg.get("steps", 35) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + estimated_seconds = total_frames / fps + print("📌 Paramètres de génération :") + print(f" fps : {fps}") + print(f" num_fraps_per_image : {num_fraps_per_image}") + print(f" steps : {steps}") + print(f" guidance_scale : {guidance_scale}") + print(f" init_image_scale : {init_image_scale}") + print(f"⏱️ Durée totale estimée de la vidéo : {estimated_seconds:.1f}s") + + # ------------------------- + # Load models + # ------------------------- + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + vae = 
safe_load_vae(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload) + scheduler = safe_load_scheduler(args.pretrained_model_path) + if not unet or not vae or not scheduler: + print("❌ UNet, VAE ou Scheduler manquant.") + return + + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else default_motion_module + if not callable(motion_module): + motion_module = default_motion_module + + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/run_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + from torchvision.transforms import ToPILImage + to_pil = ToPILImage() + + frames_for_video = [] + frame_counter = 0 + + total_frames = len(input_paths) * num_fraps_per_image * max(len(embeddings), 1) + pbar = tqdm(total=total_frames, ncols=120) + + # ------------------------- + # 
Generation loop ultra-safe + # ------------------------- + for img_path in input_paths: + + input_image = load_images([img_path], + W=cfg["W"], + H=cfg["H"], + device=device, + dtype=dtype) + + input_latents = encode_images_to_latents(input_image, vae) + input_latents = input_latents.expand(-1, -1, num_fraps_per_image, -1, -1).clone() + input_latents += torch.randn_like(input_latents) * 0.01 # bruit initial + + for pos_embeds, neg_embeds in embeddings: + + B, C, F, H, W = input_latents.shape + + for f in range(F): + + scheduler.set_timesteps(steps, device=device) + + if f == 0: + latents_frame = input_latents[:, :, f:f+1, :, :] + else: + latents_frame = generate_latents( + latents=input_latents[:, :, f:f+1, :, :], + pos_embeds=pos_embeds, + neg_embeds=neg_embeds, + unet=unet, + scheduler=scheduler, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=guidance_scale * 0.3, + init_image_scale=init_image_scale * (1 - f / F) + ) + + # ------------------------- + # Vérification et correction des latents + # ------------------------- + mean_latent = latents_frame.abs().mean().item() + if mean_latent < 0.05 or math.isnan(mean_latent): + # Relance frame avec bruit contrôlé + latents_frame = input_latents[:, :, f:f+1, :, :].clone() + latents_frame += torch.randn_like(latents_frame) * 0.05 + print(f"⚠ Frame {frame_counter:05d} relancée, mean_latent={latents_frame.abs().mean().item():.6f}") + + # Clamp pour éviter valeurs extrêmes + latents_frame = latents_frame.squeeze(2).to(torch.float32).clamp(-3.0, 3.0) + + # Decode VAE tuilé + frame_tensor = decode_latents_to_image_tiled( + latents_frame, + vae, + tile_size=32, #16 24 32 + overlap=16 # valeur mini 4 , 8 conseillé , 16 max + ).clamp(0, 1) + + if frame_tensor.ndim == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()) + frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png") + + frames_for_video.append(frame_pil) + 
frame_counter += 1 + pbar.update(1) + + # --- log latent moyen --- + print(f"Frame {frame_counter:05d} | mean abs(latent) = {latents_frame.abs().mean().item():.6f}") + + pbar.close() + + # --- sauvegarde vidéo --- + save_frames_as_video(frames_for_video, out_video, fps=fps) + print(f"🎬 Vidéo générée : {out_video}") + print("✅ Pipeline terminé proprement.") + +# Entrée +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3r2good_v1.py b/scripts/n3r2good_v1.py new file mode 100644 index 00000000..369ca08d --- /dev/null +++ b/scripts/n3r2good_v1.py @@ -0,0 +1,307 @@ +import argparse +from pathlib import Path +from tqdm import tqdm +import torch +from datetime import datetime +from transformers import CLIPTokenizerFast, CLIPTextModel +import os +import math +import ffmpeg + + +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import safe_load_vae, safe_load_unet, safe_load_scheduler +from scripts.utils.vae_utils import encode_images_to_latents, decode_latents_to_image_tiled +from scripts.utils.motion_utils import load_motion_module, apply_motion_module +from scripts.utils.safe_latent import ensure_valid +from scripts.utils.video_utils import save_frames_as_video, upscale_video +from scripts.utils.n3r_utils import load_image_file, generate_latents + +LATENT_SCALE = 0.18215 # Tiny-SD 128x128 + +# ------------------------- +# Main pipeline +# ------------------------- + +# ------------------------- +# Image utilities +# ------------------------- +def load_images(paths, W, H, device, dtype): + all_tensors = [] + for p in paths: + if 
# -------------------------
# n3r2good_v1.py — utilities and main loop (creative mode + final upscale)
# -------------------------
import math
import os
import torch
from pathlib import Path
from datetime import datetime

LATENT_SCALE = 0.18215  # SD-family VAE latent scaling factor (Tiny-SD 128x128)


def _identity_motion(latents, **kwargs):
    """Fallback motion module (no-op).

    BUGFIX: the script referenced the undefined `default_motion_module`,
    raising NameError whenever no motion module could be loaded.
    """
    return latents


# -------------------------
# Image utilities
# -------------------------
def load_images(paths, W, H, device, dtype):
    """Load still images and GIFs into one stacked tensor in [-1, 1]."""
    all_tensors = []
    for p in paths:
        if p.lower().endswith(".gif"):
            # BUGFIX: PIL and numpy were never imported in this script, so
            # the GIF branch raised NameError.
            from PIL import Image, ImageSequence
            import numpy as np
            img = Image.open(p)
            frames = [torch.tensor(np.array(f)).permute(2, 0, 1).to(device=device, dtype=dtype) / 127.5 - 1.0
                      for f in ImageSequence.Iterator(img)]
            print(f"✅ GIF chargé : {p} avec {len(frames)} frames")
            all_tensors.extend(frames)
        else:
            t = load_image_file(p, W, H, device, dtype)
            print(f"✅ Image chargée : {p}")
            all_tensors.append(t)
    return torch.stack(all_tensors, dim=0)


# -------------------------
# Encode / Decode
# -------------------------
def encode_images_to_latents(images, vae):
    """Encode images ([B,C,H,W] or [B,C,F,H,W]) into scaled VAE latents [B,C',F,H',W']."""
    device = vae.device
    images = images.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        if images.dim() == 5:  # [B,C,F,H,W] → fold frames into the batch
            B, C, F, H, W = images.shape
            images_2d = images.view(B * F, C, H, W)
            latents_2d = vae.encode(images_2d).latent_dist.sample() * LATENT_SCALE
            latent_shape = latents_2d.shape
            latents = latents_2d.view(B, F, latent_shape[1], latent_shape[2], latent_shape[3])
            latents = latents.permute(0, 2, 1, 3, 4).contiguous()
        else:
            latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE
            latents = latents.unsqueeze(2)
    return latents


def decode_latents_to_image(latents, vae):
    """Decode scaled latents back to images in [0, 1]."""
    latents = latents.to(vae.device).float() / LATENT_SCALE
    with torch.no_grad():
        img = vae.decode(latents).sample
        img = (img / 2 + 0.5).clamp(0, 1)
    return img


# -------------------------
# Video utilities
# -------------------------
def save_frames_as_video(frames, output_path, fps=12):
    """Write PIL frames as an H.264 MP4 via ffmpeg, using a throw-away PNG dir."""
    import shutil  # BUGFIX: shutil was used but never imported in this script

    temp_dir = Path("temp_frames")
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    temp_dir.mkdir()

    for idx, frame in enumerate(frames):
        frame.save(temp_dir / f"frame_{idx:05d}.png")

    (
        ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps)
        .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p")
        .overwrite_output()
        .run(quiet=True)
    )
    shutil.rmtree(temp_dir)


# -------------------------
# Main ultra safe VRAM
# -------------------------
def main(args):
    """Low-VRAM pipeline with optional creative mode and a final 2x upscale."""
    from tqdm import tqdm
    from transformers import CLIPTokenizerFast, CLIPTextModel
    from torchvision.transforms import ToPILImage

    cfg = load_config(args.config)
    print("DEBUG: num_fraps_per_image =", cfg.get("num_fraps_per_image"))
    print("DEBUG: full cfg keys =", list(cfg.keys()))
    device = args.device if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if args.fp16 else torch.float32

    # Creative-mode options: ramp guidance per frame and add extra noise.
    creative_mode = cfg.get("creative_mode", False)
    creative_scale_min = cfg.get("creative_scale_min", 0.2)
    creative_scale_max = cfg.get("creative_scale_max", 0.8)
    creative_noise = cfg.get("creative_noise", 0.0)

    fps = cfg.get("fps", 12)
    num_fraps_per_image = cfg.get("num_fraps_per_image", 12)
    steps = cfg.get("steps", 35)
    guidance_scale = cfg.get("guidance_scale", 4.5)
    init_image_scale = cfg.get("init_image_scale", 0.85)

    input_paths = cfg.get("input_images") or [cfg.get("input_image")]
    prompts = cfg.get("prompt", [])
    negative_prompts = cfg.get("n_prompt", [])

    total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1)
    estimated_seconds = total_frames / fps
    print("📌 Paramètres de génération :")
    print(f"   fps                  : {fps}")
    print(f"   num_fraps_per_image  : {num_fraps_per_image}")
    print(f"   steps                : {steps}")
    print(f"   guidance_scale       : {guidance_scale}")
    print(f"   init_image_scale     : {init_image_scale}")
    print(f"⏱️ Durée totale estimée de la vidéo : {estimated_seconds:.1f}s")

    # -------------------------
    # Load models
    # -------------------------
    unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16)
    vae = safe_load_vae(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload)
    scheduler = safe_load_scheduler(args.pretrained_model_path)
    if not unet or not vae or not scheduler:
        print("❌ UNet, VAE ou Scheduler manquant.")
        return

    # ---------------------- Motion Module setup ----------------------
    motion_path = cfg.get("motion_module")
    if motion_path:
        motion_module = load_motion_module(motion_path, device=device)

        # Inject dynamic parameters from the config when the module exposes them.
        if hasattr(motion_module, "strength"):
            motion_module.strength = cfg.get("motion_strength", 0.03)
        if hasattr(motion_module, "hair_bias"):
            motion_module.hair_bias = cfg.get("motion_hair_bias", 0.7)
        if hasattr(motion_module, "wave_speed"):
            motion_module.wave_speed = cfg.get("motion_wave_speed", 1.5)
        if hasattr(motion_module, "wave_amplitude"):
            motion_module.wave_amplitude = cfg.get("motion_wave_amplitude", 1.0)
    else:
        motion_module = None

    # BUGFIX: the fallback used to reference the undefined `default_motion_module`.
    if not callable(motion_module):
        motion_module = _identity_motion

    tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer"))
    text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device)
    if args.fp16:
        text_encoder = text_encoder.half()

    # Pre-compute (positive, negative) CLIP embeddings for every prompt.
    embeddings = []
    for prompt_item in prompts:
        prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item)
        neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts)
        text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt")
        neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt")
        with torch.no_grad():
            pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
            neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state
        embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype)))

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(f"./outputs/run_{timestamp}")
    output_dir.mkdir(parents=True, exist_ok=True)
    out_video = output_dir / f"output_{timestamp}.mp4"

    to_pil = ToPILImage()

    frames_for_video = []
    frame_counter = 0

    total_frames = len(input_paths) * num_fraps_per_image * max(len(embeddings), 1)
    pbar = tqdm(total=total_frames, ncols=120)

    # -------------------------
    # Generation loop ultra-safe
    # -------------------------
    for img_path in input_paths:

        input_image = load_images([img_path],
                                  W=cfg["W"],
                                  H=cfg["H"],
                                  device=device,
                                  dtype=dtype)

        input_latents = encode_images_to_latents(input_image, vae)
        input_latents = input_latents.expand(-1, -1, num_fraps_per_image, -1, -1).clone()
        input_latents += torch.randn_like(input_latents) * 0.01  # small initial noise

        for pos_embeds, neg_embeds in embeddings:

            B, C, F, H, W = input_latents.shape

            for f in range(F):

                scheduler.set_timesteps(steps, device=device)

                if f == 0:
                    latents_frame = input_latents[:, :, f:f+1, :, :]
                else:
                    # Creative mode ramps the guidance scale across frames.
                    if creative_mode:
                        dynamic_scale = guidance_scale * (creative_scale_min + (creative_scale_max - creative_scale_min) * (f / F))
                    else:
                        dynamic_scale = guidance_scale

                    latents_frame = generate_latents(
                        latents=input_latents[:, :, f:f+1, :, :],
                        pos_embeds=pos_embeds,
                        neg_embeds=neg_embeds,
                        unet=unet,
                        scheduler=scheduler,
                        motion_module=motion_module,
                        device=device,
                        dtype=dtype,
                        guidance_scale=dynamic_scale,
                        init_image_scale=init_image_scale * (1 - f / F)
                    )

                    # Optional light creative noise on generated frames.
                    if creative_mode and creative_noise > 0:
                        latents_frame += torch.randn_like(latents_frame) * creative_noise

                # -------------------------
                # Retry dead/NaN frames with controlled noise.
                # -------------------------
                mean_latent = latents_frame.abs().mean().item()
                if mean_latent < 0.05 or math.isnan(mean_latent):
                    latents_frame = input_latents[:, :, f:f+1, :, :].clone()
                    latents_frame += torch.randn_like(latents_frame) * 0.05
                    print(f"⚠ Frame {frame_counter:05d} relancée, mean_latent={latents_frame.abs().mean().item():.6f}")

                # Clamp extreme values before decoding.
                latents_frame = latents_frame.squeeze(2).to(torch.float32).clamp(-3.0, 3.0)

                # Tiled VAE decode keeps peak VRAM low.
                frame_tensor = decode_latents_to_image_tiled(
                    latents_frame,
                    vae,
                    tile_size=32,  # 16 / 24 / 32
                    overlap=16     # min 4, 8 recommended, 16 max
                ).clamp(0, 1)

                if frame_tensor.ndim == 4 and frame_tensor.shape[0] == 1:
                    frame_tensor = frame_tensor.squeeze(0)

                frame_pil = to_pil(frame_tensor.cpu())
                frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")

                frames_for_video.append(frame_pil)
                frame_counter += 1
                pbar.update(1)

                # --- mean latent log ---
                print(f"Frame {frame_counter:05d} | mean abs(latent) = {latents_frame.abs().mean().item():.6f}")

    pbar.close()

    # --- video export ---
    save_frames_as_video(frames_for_video, out_video, fps=fps)
    print(f"🎬 Vidéo générée : {out_video}")

    # --- final 2x upscale ---
    upscaled_video = output_dir / f"output_{timestamp}_x2.mp4"
    upscale_video(out_video, upscaled_video, scale_factor=2)
    # BUGFIX: the log used to print `out_video` instead of the upscaled file.
    print(f"🎬 Vidéo générée X2 : {upscaled_video}")

    print("✅ Pipeline terminé proprement.")


# Entry point
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    # NOTE(review): store_true with default=True means --fp16 is always on.
    parser.add_argument("--fp16", action="store_true", default=True)
    parser.add_argument("--vae-offload", action="store_true")
    args = parser.parse_args()
    main(args)
# -------------------------
# n3rHYBRID14_ULTRA_STABLE.py
# -------------------------

import os, time, csv
from pathlib import Path
from datetime import datetime
import torch, numpy as np
from PIL import Image
import cv2

from transformers import CLIPTokenizerFast, CLIPTextModel

from scripts.utils.config_loader import load_config
from scripts.utils.vae_utils import safe_load_vae_safetensors, safe_load_unet, safe_load_scheduler, safe_load_vae_stable
from scripts.utils.vae_utils import decode_latents_to_image_tiled
from scripts.utils.vae_utils import test_vae_256
from scripts.utils.model_utils import load_pretrained_unet, get_text_embeddings, load_DDIMScheduler
from scripts.utils.motion_utils import load_motion_module
from scripts.utils.n3r_utils import load_images, decode_latents_correct, generate_latents_ai_5D_optimized

LATENT_SCALE = 0.18215  # SD-family VAE latent scaling factor
CLAMP_MAX = 1.0
torch.backends.cuda.matmul.allow_tf32 = True


def save_frame(img_array, filename):
    """Save an HxWxC float array in [0, 1] as an 8-bit PNG, creating parent dirs."""
    img_array = np.clip(img_array, 0.0, 1.0)
    img_uint8 = (img_array * 255).astype(np.uint8)
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    Image.fromarray(img_uint8).save(filename)


def encode_image_latents(image_tensor, vae, scale=LATENT_SCALE):
    """Encode one image batch into scaled VAE latents with a singleton frame axis."""
    device = next(vae.parameters()).device
    img = image_tensor.to(device=device, dtype=next(vae.parameters()).dtype)
    with torch.no_grad():
        latents = vae.encode(img).latent_dist.sample() * scale
    return latents.unsqueeze(2)


def main(args):
    """Ultra-stable per-frame generation with segmented latents and CSV timing log."""
    cfg = load_config(args.config)
    device = args.device if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if args.fp16 else torch.float32

    fps = cfg.get("fps", 12)
    num_fraps_per_image = cfg.get("num_fraps_per_image", 12)
    steps = cfg.get("steps", 35)
    seed = cfg.get("seed", 42)
    guidance_scale = cfg.get("guidance_scale", 4.5)
    init_image_scale = cfg.get("init_image_scale", 0.85)
    creative_noise = cfg.get("creative_noise", 0.0)

    input_paths = cfg.get("input_images") or [cfg.get("input_image")]
    prompts = cfg.get("prompt", [])
    negative_prompts = cfg.get("n_prompt", [])

    total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1)
    estimated_seconds = total_frames / fps
    print("📌 Paramètres de génération :")
    print(f"   fps                  : {fps}")
    print(f"   num_fraps_per_image  : {num_fraps_per_image}")
    print(f"   steps                : {steps}")
    print(f"   seed                 : {seed}")
    print(f"   guidance_scale       : {guidance_scale}")
    print(f"   init_image_scale     : {init_image_scale}")
    print(f"   creative_noise       : {creative_noise}")
    print(f"⏱ Durée totale estimée de la vidéo : {estimated_seconds:.1f}s")

    # -------------------------
    # Load models
    # -------------------------
    unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16)
    vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload)
    scheduler = safe_load_scheduler(args.pretrained_model_path)
    if not vae:
        print("❌ VAE manquant.")
        return
    if not scheduler:
        print("❌ Scheduler manquant.")
        return
    if not unet:
        print("❌ UNet manquant.")
        return

    # Sanity-check the VAE on a known image before the long run.
    test_vae_256(vae, Image.open("scripts/utils/logo.png").convert("RGB"))

    # BUGFIX: the motion module used to be loaded twice (here and again after
    # the UNet section); load it exactly once.
    motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None

    tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer"))
    text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device)
    if args.fp16:
        text_encoder = text_encoder.half()

    # Pre-compute (positive, negative) CLIP embeddings for every prompt.
    embeddings = []
    for prompt_item in prompts:
        prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item)
        neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts)
        text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True,
                                max_length=tokenizer.model_max_length, return_tensors="pt")
        neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True,
                               max_length=tokenizer.model_max_length, return_tensors="pt")
        with torch.no_grad():
            pos = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
            neg = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state
        embeddings.append((pos.to(dtype), neg.to(dtype)))

    # -------------------------
    # UNet setup
    # -------------------------
    unet.eval()
    try:
        unet.enable_xformers_memory_efficient_attention()
    except Exception:  # BUGFIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit
        pass

    # -------------------------
    # OUTPUT
    # -------------------------
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(f"./outputs/hybrid14_{timestamp}")
    output_dir.mkdir(parents=True, exist_ok=True)
    debug_dir = output_dir / "debug_frames"
    debug_dir.mkdir(exist_ok=True)
    csv_file = output_dir / "generation_log.csv"

    video = None
    frame_counter = 0

    with open(csv_file, "w", newline="") as f_csv:
        writer = csv.writer(f_csv)
        writer.writerow(["frame", "latent_min", "latent_max", "gen_time", "decode_time"])

        for img_path in input_paths:
            input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype)
            input_latents = encode_image_latents(input_image, vae)
            input_latents = input_latents.expand(-1, -1, num_fraps_per_image, -1, -1).clone()

            # BUGFIX: iterate over the prompt embeddings explicitly. The
            # original loop used whatever pos_embeds/neg_embeds were left
            # over from the embedding-building loop above (only the last
            # prompt was ever used; NameError when `prompt` was empty),
            # contradicting the total_frames computation.
            for pos_embeds, neg_embeds in embeddings:

                # --- Split into segments to avoid OOM ---
                segment_size = 4
                num_segments = (num_fraps_per_image + segment_size - 1) // segment_size

                for seg_idx in range(num_segments):
                    start_idx = seg_idx * segment_size
                    end_idx = min(start_idx + segment_size, num_fraps_per_image)

                    for f_idx in range(start_idx, end_idx):
                        torch.cuda.empty_cache()
                        latent_frame = input_latents[:, :, f_idx:f_idx+1, :, :].squeeze(2)
                        frame_seed = seed + f_idx  # deterministic per-frame seed

                        # --- Latent generation ---
                        gen_start = time.time()
                        latent_frame = generate_latents_ai_5D_optimized(
                            latent_frame=latent_frame,
                            scheduler=scheduler,
                            pos_embeds=pos_embeds,
                            neg_embeds=neg_embeds,
                            unet=unet,
                            motion_module=motion_module,
                            device=device,
                            dtype=dtype,
                            guidance_scale=guidance_scale,
                            init_image_scale=init_image_scale,
                            creative_noise=creative_noise,
                            seed=frame_seed,
                            steps=steps
                        )
                        gen_time = time.time() - gen_start

                        # --- VAE decode (optionally shuttling the VAE on/off GPU) ---
                        decode_start = time.time()
                        if args.vae_offload:
                            vae.to(device)
                        print("Latent stats:",
                              float(latent_frame.min()),
                              float(latent_frame.max()),
                              float(latent_frame.mean()))
                        frame_tensor = decode_latents_correct(latent_frame, vae)
                        decode_time = time.time() - decode_start
                        if args.vae_offload:
                            vae.cpu()
                            torch.cuda.empty_cache()

                        # --- Final conversion to an HxWxC array in [0, 1] ---
                        frame_tensor = frame_tensor.clamp(0.0, 1.0)
                        print("After decode:",
                              float(frame_tensor.min()),
                              float(frame_tensor.max()))
                        frame_array = frame_tensor[0].permute(1, 2, 0).cpu().numpy()  # (H, W, C)
                        save_frame(frame_array, debug_dir / f"frame_{frame_counter:05d}.png")

                        # Lazily open the video writer once the frame size is known.
                        if video is None:
                            h, w = frame_array.shape[:2]
                            video_path = output_dir / "animation.mp4"
                            video = cv2.VideoWriter(str(video_path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

                        video.write(cv2.cvtColor((frame_array * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))

                        writer.writerow([frame_counter, float(latent_frame.min()), float(latent_frame.max()), round(gen_time, 4), round(decode_time, 4)])

                        del latent_frame, frame_tensor
                        torch.cuda.empty_cache()
                        frame_counter += 1

    if video:
        video.release()
    print("✅ Génération ultra-stable terminée.")


# -------------------------
# Entry point
# -------------------------
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--vae-offload", action="store_true")
    args = parser.parse_args()
    main(args)
encode_images_to_latents(images, vae): + """Encode images -> latents [B,4,1,H,W] float32 pour stabilité couleur""" + vae_device = next(vae.parameters()).device + images = images.to(device=vae_device, dtype=torch.float32) + with torch.no_grad(): + latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE + latents = latents.unsqueeze(2) # [B,C,1,H,W] + return latents + +def decode_latents_to_image(latents, vae): + """Decode latents 4D ou 5D -> RGB [0,1]""" + vae_device = next(vae.parameters()).device + vae_dtype = next(vae.parameters()).dtype + if latents.ndim == 5: + B,C,T,H,W = latents.shape + latents = latents.permute(0,2,1,3,4).reshape(B*T,C,H,W) + latents = latents.to(device=vae_device, dtype=vae_dtype) + latents = latents / LATENT_SCALE + with torch.no_grad(): + images = vae.decode(latents).sample + return images.clamp(0,1) + + + +# ------------------------- +# MAIN +# ------------------------- +def main(args): + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + + fps = cfg.get("fps",12) + num_fraps_per_image = cfg.get("num_fraps_per_image",12) + steps = cfg.get("steps",10) + seed = cfg.get("seed",42) + guidance_scale = cfg.get("guidance_scale",7.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + creative_noise = cfg.get("creative_noise",0.03) + tile_size = cfg.get("tile_size",128) + tile_overlap = cfg.get("tile_overlap",32) + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + estimated_seconds = total_frames / fps + print("📌 Paramètres de génération :") + print(f" fps : {fps}") + print(f" num_fraps_per_image : {num_fraps_per_image}") + print(f" steps : {steps}") + print(f" seed : {seed}") + print(f" guidance_scale : {guidance_scale}") + print(f" init_image_scale : {init_image_scale}") + print(f" creative_noise : 
{creative_noise}") + print(f"⏱ Durée totale estimée de la vidéo : {estimated_seconds:.1f}s") + + # ------------------------- + # Load models + # ------------------------- + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + unet_dtype = next(unet.parameters()).dtype + dtype = unet_dtype # dtype pipeline + + vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=False, offload=args.vae_offload) + vae = vae.float() + scheduler = safe_load_scheduler(args.pretrained_model_path) + + if not vae or not unet or not scheduler: + print("❌ Un ou plusieurs modèles manquent.") + return + + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path,"tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path,"text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + # ------------------------- + # Encode prompts + # ------------------------- + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item,list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts,list) else str(negative_prompts) + + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + + # Adapter dtype UNet + pos_embeds = pos_embeds.to(device=device, dtype=dtype) + neg_embeds = neg_embeds.to(device=device, dtype=dtype) + + embeddings.append((pos_embeds, neg_embeds)) + + unet.eval() 
+ try: + unet.enable_xformers_memory_efficient_attention() + except: + pass + + # ------------------------- + # OUTPUT + # ------------------------- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/hybrid22_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + debug_dir = output_dir / "debug_frames" + debug_dir.mkdir(exist_ok=True) + csv_file = output_dir / "generation_log.csv" + + video = None + frame_counter = 0 + + with open(csv_file,"w",newline="") as f_csv: + writer = csv.writer(f_csv) + writer.writerow(["frame","latent_min","latent_max","gen_time","decode_time","warnings"]) + + for img_path in input_paths: + input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + tiles, positions = tile_image_vae(input_image, tile_size, tile_overlap) + + for f_idx in range(num_fraps_per_image): + decoded_tiles = [] + all_warnings = [] + + for tile_idx, tile_rgb in enumerate(tiles): + tile_rgb = clamp_and_warn_tile(tile_rgb, frame_counter, tile_idx, all_warnings) + + # ---- Encode tile -> latents float32 + tile_latent = encode_images_to_latents(tile_rgb, vae) + + # ---- Convert pour UNet (fp16 si --fp16) + tile_latent = tile_latent.to(device=device, dtype=dtype) + + # ---- Squeeze 5D -> 4D pour UNet + if tile_latent.ndim == 5 and tile_latent.shape[2] == 1: + tile_latent = tile_latent.squeeze(2) # -> [B,C,H,W] + + # ---- Génération latents + gen_start = time.time() + pos_embeds, neg_embeds = embeddings[0] + + batch_latents = generate_latents_ai_5D_stable( + latent_frame=tile_latent, + scheduler=scheduler, + pos_embeds=pos_embeds, + neg_embeds=neg_embeds, + unet=unet, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=guidance_scale, + creative_noise=creative_noise, + seed=seed + f_idx, + steps=steps + ) + gen_time = time.time() - gen_start + + # ---- Décodage VAE + decode_start = time.time() + if args.vae_offload: + vae.to(device) + + print("Latent stats:", + 
float(batch_latents.min()), + float(batch_latents.max()), + float(batch_latents.mean())) + decoded_tile = decode_latents_to_image(batch_latents, vae) + decode_time = time.time() - decode_start + if args.vae_offload: + vae.cpu() + torch.cuda.empty_cache() + + decoded_tile = clamp_and_warn_tile(decoded_tile, frame_counter, tile_idx, all_warnings) + + # ---- Log RGB stats par tile + tile_warnings = log_rgb_stats(decoded_tile, step=f"frame{frame_counter}_tile{tile_idx}") + all_warnings.extend(tile_warnings) + + decoded_tiles.append(decoded_tile) + + # ---- Fusion tiles + final_frame = merge_tiles_vae(decoded_tiles, positions, H=input_image.shape[2], W=input_image.shape[3]) + final_frame = clamp_and_warn_tile(final_frame, frame_counter, "final", all_warnings) + + # ---- Log RGB stats frame finale + frame_warnings = log_rgb_stats(final_frame, step=f"frame{frame_counter}_final") + all_warnings.extend(frame_warnings) + + #frame_array = final_frame[0].clamp(0,1).permute(1,2,0).cpu().numpy() + + final_frame = final_frame.clamp(0.0,1.0) + print("After decode:", + float(final_frame.min()), + float(final_frame.max())) + frame_array = final_frame[0].permute(1,2,0).cpu().numpy() # (H,W,C) + + + save_frame(frame_array, debug_dir / f"frame_{frame_counter:05d}.png") + + if video is None: + h,w = frame_array.shape[:2] + video_path = output_dir / "animation.mp4" + video = cv2.VideoWriter(str(video_path), + cv2.VideoWriter_fourcc(*'mp4v'), + fps, + (w,h)) + video.write(cv2.cvtColor((frame_array*255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + writer.writerow([ + frame_counter, + float(batch_latents.min()), + float(batch_latents.max()), + round(gen_time,4), + round(decode_time,4), + "; ".join(all_warnings) + ]) + frame_counter += 1 + + if video: + video.release() + print("✅ Génération 128x128 VRAM-safe avec correction couleur et logging RGB terminée.") + +# ------------------------- +# Entrée +# ------------------------- +if __name__=="__main__": + import argparse + parser = 
argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path",type=str,required=True) + parser.add_argument("--config",type=str,required=True) + parser.add_argument("--device",type=str,default="cuda") + parser.add_argument("--fp16",action="store_true") + parser.add_argument("--vae-offload",action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rHYBRID25.py b/scripts/n3rHYBRID25.py new file mode 100644 index 00000000..62e1d7bc --- /dev/null +++ b/scripts/n3rHYBRID25.py @@ -0,0 +1,147 @@ +# ------------------------- +# n3rHYBRID25.py (tiling 128x128) +# ------------------------- + +import os, csv, torch, numpy as np, cv2 +from pathlib import Path +from datetime import datetime +from PIL import Image + +from transformers import CLIPTokenizerFast, CLIPTextModel + +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import safe_load_vae_stable, safe_load_unet, safe_load_scheduler +from scripts.utils.n3r_utils import load_images, generate_frame_with_tiling +from scripts.utils.motion_utils import load_motion_module + +LATENT_SCALE = 0.18215 +torch.backends.cuda.matmul.allow_tf32 = True + +# ------------------------- +# Fonctions utilitaires +# ------------------------- +def save_frame(img_array, filename): + img_array = np.clip(img_array, 0.0, 1.0) + img_uint8 = (img_array*255).astype(np.uint8) + os.makedirs(os.path.dirname(filename), exist_ok=True) + Image.fromarray(img_uint8).save(filename) + +# ------------------------- +# Fonction principale +# ------------------------- +def main(args): + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + fps = cfg.get("fps", 12) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + steps = cfg.get("steps", 12) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.75) + creative_noise = 
cfg.get("creative_noise", 0.0) + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + estimated_seconds = total_frames / fps + print("📌 Paramètres de génération :") + print(f" fps : {fps}") + print(f" num_fraps_per_image : {num_fraps_per_image}") + print(f" steps : {steps}") + print(f" guidance_scale : {guidance_scale}") + print(f" init_image_scale : {init_image_scale}") + print(f" creative_noise : {creative_noise}") + print(f"⏱ Durée totale estimée : {estimated_seconds:.1f}s") + + # ------------------------- + # Load models + # ------------------------- + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=False, offload=args.vae_offload) + scheduler = safe_load_scheduler(args.pretrained_model_path) + + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + # --- Préparer embeddings + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + pos_embeds = 
text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + unet.eval() + try: unet.enable_xformers_memory_efficient_attention() + except: pass + + # ------------------------- + # Output + # ------------------------- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/hybrid25_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + debug_dir = output_dir / "debug_frames"; debug_dir.mkdir(exist_ok=True) + csv_file = output_dir / "generation_log.csv" + + video = None + frame_counter = 0 + + with open(csv_file, "w", newline="") as f_csv: + writer = csv.writer(f_csv) + writer.writerow(["frame", "gen_time", "decode_time", "warnings"]) + + for img_path in input_paths: + input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + + for f_idx in range(num_fraps_per_image): + torch.cuda.empty_cache() + frame_tensor = generate_frame_with_tiling( + input_image, vae, unet, scheduler, embeddings, motion_module, + tile_size=128, overlap=32, + fp16=args.fp16, + guidance_scale=guidance_scale, + init_image_scale=init_image_scale, + creative_noise=creative_noise, + steps=steps + ) + + # --- Conversion pour vidéo --- + frame_array = frame_tensor[0].permute(1, 2, 0).cpu().numpy() + save_frame(frame_array, debug_dir / f"frame_{frame_counter:05d}.png") + if video is None: + h, w = frame_array.shape[:2] + video_path = output_dir / "animation.mp4" + video = cv2.VideoWriter(str(video_path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + video.write(cv2.cvtColor((frame_array*255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + writer.writerow([frame_counter, "", "", ""]) + frame_counter += 1 + + if video: video.release() + print("✅ Génération hybride patch-based terminée.") +# ------------------------- +# Entrée +# ------------------------- +if __name__ == "__main__": + 
import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rHYBRID26.py b/scripts/n3rHYBRID26.py new file mode 100644 index 00000000..fdcac0e8 --- /dev/null +++ b/scripts/n3rHYBRID26.py @@ -0,0 +1,158 @@ +# ------------------------- +# n3rHYBRID26_cpu_safe_v3.py +# ------------------------- + +import os, csv, torch, numpy as np, cv2 +from pathlib import Path +from datetime import datetime +from PIL import Image + +from transformers import CLIPTokenizerFast, CLIPTextModel + +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import safe_load_vae_stable, safe_load_unet, safe_load_scheduler +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import patchify_latents, unpatchify_latents, generate_frame_patched_v3, load_images + +LATENT_SCALE = 0.18215 +torch.backends.cuda.matmul.allow_tf32 = True + +# ------------------------- +# Utilitaires +# ------------------------- +def save_frame(img_array, filename): + img_array = np.nan_to_num(img_array, nan=0.0, posinf=1.0, neginf=0.0) + img_array = np.clip(img_array, 0.0, 1.0) + img_uint8 = (img_array*255).astype(np.uint8) + os.makedirs(os.path.dirname(filename), exist_ok=True) + Image.fromarray(img_uint8).save(filename) + +# ------------------------- +# Main +# ------------------------- +def main(args): + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + fps = cfg.get("fps", 12) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + steps = cfg.get("steps", 12) + guidance_scale = 
cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.75) + creative_noise = cfg.get("creative_noise", 0.0) + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + print("📌 Paramètres de génération :") + print(f" fps : {fps}") + print(f" num_fraps_per_image : {num_fraps_per_image}") + print(f" steps : {steps}") + print(f" guidance_scale : {guidance_scale}") + print(f" init_image_scale : {init_image_scale}") + print(f" creative_noise : {creative_noise}") + + # ------------------------- + # Load models + # ------------------------- + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + vae = safe_load_vae_stable(args.pretrained_model_path, device="cpu", fp16=False, offload=True) + vae.to(device) + scheduler = safe_load_scheduler(args.pretrained_model_path) + + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = 
text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + unet.eval() + try: unet.enable_xformers_memory_efficient_attention() + except: pass + + # ------------------------- + # Output + # ------------------------- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/hybrid26_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + debug_dir = output_dir / "debug_frames"; debug_dir.mkdir(exist_ok=True) + csv_file = output_dir / "generation_log.csv" + + video = None + frame_counter = 0 + + with open(csv_file, "w", newline="") as f_csv: + writer = csv.writer(f_csv) + writer.writerow(["frame", "warnings"]) + + for img_path in input_paths: + input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + + for f_idx in range(num_fraps_per_image): + torch.cuda.empty_cache() + + # ---- batching des patchs ---- + pos_emb, neg_emb = embeddings[0] + frame_tensor = generate_frame_patched_v3( + input_image=input_image, + vae=vae, + unet=unet, + scheduler=scheduler, + pos_emb=pos_emb, + neg_emb=neg_emb, + tile_size=64, # plus petit pour réduire mémoire + overlap=16, + steps=steps, + guidance_scale=guidance_scale, + init_image_scale=init_image_scale, + creative_noise=creative_noise, + device=device, + dtype=dtype + ) + + frame_array = frame_tensor[0].permute(1, 2, 0).cpu().numpy() + save_frame(frame_array, debug_dir / f"frame_{frame_counter:05d}.png") + + if video is None: + h, w = frame_array.shape[:2] + video_path = output_dir / "animation.mp4" + video = cv2.VideoWriter(str(video_path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + frame_array = np.nan_to_num(frame_array, nan=0.0, posinf=1.0, neginf=0.0) + frame_array = np.clip(frame_array, 0.0, 1.0) + video.write(cv2.cvtColor((frame_array*255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + writer.writerow([frame_counter, ""]) + frame_counter += 1 + + if video: video.release() + 
print("✅ Génération patch-based optimisée terminée.") + +# ------------------------- +# Entrée +# ------------------------- +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rOpenPose_utils.py b/scripts/n3rOpenPose_utils.py new file mode 100644 index 00000000..ad1522b3 --- /dev/null +++ b/scripts/n3rOpenPose_utils.py @@ -0,0 +1,812 @@ +#******************************************** +# n3rOpenPose_utils.py +#******************************************** +import torch +from diffusers import ControlNetModel +import math +import torch.nn.functional as F +from .n3rControlNet import create_canny_control, control_to_latent, match_latent_size +import numpy as np +import cv2 + +from PIL import Image +import matplotlib.pyplot as plt + +def save_debug_pose_image(pose_tensor, frame_counter, output_dir, cfg=None, prefix="openpose"): + """ + Sauvegarde la pose détectée pour vérification visuelle. 
+ + Args: + pose_tensor (torch.Tensor): Tensor BCHW ou CHW (1,3,H,W ou 3,H,W) + frame_counter (int): numéro de la frame + output_dir (Path): dossier de sortie pour sauvegarde + cfg (dict, optional): configuration, active si cfg.get("debug_pose_visual", False) est True + prefix (str): préfixe du fichier image (default: 'openpose') + """ + if cfg is None or not cfg.get("debug_pose_visual", False): + return + + # S'assurer que le tensor est BCHW + if pose_tensor.ndim == 3: # CHW -> BCHW + pose_tensor = pose_tensor.unsqueeze(0) + + pose_tensor = pose_tensor[0] # retirer batch + + # Limiter à 3 canaux + if pose_tensor.shape[0] > 3: + pose_tensor = pose_tensor[:3, :, :] + + # CHW -> HWC + pose_np = pose_tensor.permute(1, 2, 0).cpu().numpy() + # Normalisation 0-255 + pose_np = (pose_np - pose_np.min()) / (pose_np.max() - pose_np.min() + 1e-8) * 255.0 + pose_np = pose_np.astype("uint8") + img = Image.fromarray(pose_np) + + # Nom de fichier : openpose_0001.png + output_dir.mkdir(parents=True, exist_ok=True) + filename = output_dir / f"{prefix}_{frame_counter:04d}.png" + img.save(filename) + +def debug_pose_visual(pose_tensor, frame_counter, cfg=None, title="Pose Debug"): + """ + Affiche la pose détectée pour vérification visuelle. 
+ + Args: + pose_tensor (torch.Tensor): Tensor BCHW ou CHW (1,3,H,W ou 3,H,W) + frame_counter (int): numéro de la frame + cfg (dict, optional): configuration, active si cfg.get("debug_pose_visual", False) est True + title (str): titre pour l'affichage + """ + if cfg is None or not cfg.get("debug_pose_visual", False): + return + + # S'assurer que le tensor est BCHW + if pose_tensor.ndim == 3: # CHW -> BCHW + pose_tensor = pose_tensor.unsqueeze(0) + + pose_tensor = pose_tensor[0] # retirer batch + + # Limiter à 3 canaux + if pose_tensor.shape[0] > 3: + pose_tensor = pose_tensor[:3, :, :] + + # CHW -> HWC pour PIL + pose_np = pose_tensor.permute(1, 2, 0).cpu().numpy() + pose_np = (pose_np - pose_np.min()) / (pose_np.max() - pose_np.min() + 1e-8) * 255.0 + pose_np = pose_np.astype("uint8") + img = Image.fromarray(pose_np) + + # Affichage rapide avec matplotlib + plt.figure(figsize=(4, 4)) + plt.imshow(img) + plt.axis("off") + plt.title(f"{title} - Frame {frame_counter}") + plt.show(block=False) + plt.pause(0.1) # court délai pour rafraîchir + plt.close() + +def convert_json_to_pose_sequence(anim_data, H=512, W=512, device="cuda", dtype=torch.float16, debug=False): + """ + Convertit un JSON d'animation OpenPose simplifié en tensor utilisable par ControlNet. 
+ + Output: + pose_sequence: tensor [num_frames, 3, H, W] (RGB image type) + """ + + frames = anim_data.get("animation", []) + pose_images = [] + + for idx, frame in enumerate(frames): + keypoints = frame.get("keypoints", []) + + # Image noire + canvas = np.zeros((H, W, 3), dtype=np.uint8) + + # --- Dessin des points --- + for kp in keypoints: + x = int(kp["x"]) + y = int(kp["y"]) + conf = kp.get("confidence", 1.0) + + if conf > 0.3: + cv2.circle(canvas, (x, y), 4, (255, 255, 255), -1) + + # --- Dessin des connexions (squelette simple) --- + skeleton = [ + (0, 1), # tête → torse + (1, 2), # torse → bras gauche + (1, 3), # torse → bras droit + (1, 4), # torse → jambe gauche + (1, 5), # torse → jambe droite + ] + + for a, b in skeleton: + if a < len(keypoints) and b < len(keypoints): + x1, y1 = int(keypoints[a]["x"]), int(keypoints[a]["y"]) + x2, y2 = int(keypoints[b]["x"]), int(keypoints[b]["y"]) + cv2.line(canvas, (x1, y1), (x2, y2), (255, 255, 255), 2) + + # --- Conversion en tensor --- + img = torch.from_numpy(canvas).float() / 255.0 # [H, W, C] + img = img.permute(2, 0, 1) # → [C, H, W] + + pose_images.append(img) + + pose_sequence = torch.stack(pose_images).to(device=device, dtype=dtype) + + if debug: + print(f"[JSON->POSE] shape: {pose_sequence.shape}") + print(f"[JSON->POSE] min/max: {pose_sequence.min().item()} / {pose_sequence.max().item()}") + + return pose_sequence + + +def apply_controlnet_openpose_step_safe( + latents, + timestep, + unet, + controlnet, + scheduler, + pose_image, + pos_embeds, + neg_embeds, + guidance_scale, + controlnet_scale=0.25, + device="cuda", + dtype=torch.float16, + debug=False +): + """ + Wrapper sécurisé pour apply_controlnet_openpose_step + - gère CPU/GPU + - corrige dtype + - convertit timestep en long pour scheduler + """ + # --- CPU float32 pour ControlNet --- + latents_cpu = latents.to("cpu", dtype=torch.float32) + unet_cpu = unet.to("cpu", dtype=torch.float32) + controlnet_cpu = controlnet.to("cpu", dtype=torch.float32) + 
pose_cpu = pose_image.to("cpu", dtype=torch.float32) + pos_embeds_cpu = pos_embeds.to("cpu", dtype=torch.float32) + neg_embeds_cpu = neg_embeds.to("cpu", dtype=torch.float32) + + # --- Préparer timestep --- + if timestep.ndim == 0: + timestep = timestep.unsqueeze(0) + batch_size = latents_cpu.shape[0] + timestep = timestep.repeat(batch_size).to(torch.long).to("cpu") + + # --- Appel ControlNet OpenPose --- + latents_cpu = apply_controlnet_openpose_step( + latents=latents_cpu, + t=timestep, + unet=unet_cpu, + controlnet=controlnet_cpu, + scheduler=scheduler, + pose_image=pose_cpu, + pos_embeds=pos_embeds_cpu, + neg_embeds=neg_embeds_cpu, + guidance_scale=guidance_scale, + controlnet_scale=controlnet_scale, + device="cpu", + dtype=torch.float32, + debug=debug + ) + + # --- Retour sur GPU et dtype final --- + latents_out = latents_cpu.to(device, dtype=dtype) + unet.to(device, dtype=dtype) + + return latents_out + +def build_control_latent_debug(input_pil, vae, device="cuda", latent_scale=0.18215): + import torch + + print("\n================ CONTROL LATENT DEBUG ================") + + # 1. Canny + control = create_canny_control(input_pil) + + print("[STEP 1] RAW CONTROL") + print(" shape:", control.shape) + print(" dtype:", control.dtype) + print(" min/max:", control.min().item(), control.max().item()) + + # 2. 1 → 3 channels + if control.shape[1] == 1: + control = control.repeat(1, 3, 1, 1) + + # 3. Normalize PROPERLY (CRUCIAL) + control = control.clamp(0, 1) # sécurité + control = control * 2.0 - 1.0 # [-1,1] + + print("[STEP 2] NORMALIZED") + print(" min/max:", control.min().item(), control.max().item()) + + # 4. Move to device FP32 + control = control.to(device=device, dtype=torch.float32) + + print("[STEP 3] DEVICE") + print(" device:", control.device) + print(" dtype:", control.dtype) + + # 5. 
Sync VAE + print("[STEP 4] VAE STATE") + print(" vae dtype:", next(vae.parameters()).dtype) + print(" vae device:", next(vae.parameters()).device) + + # 🔥 FORCER cohérence VAE + vae = vae.to(device=device, dtype=torch.float32) + + # 6. Encode SAFE (no autocast) + with torch.no_grad(): + try: + latent_dist = vae.encode(control).latent_dist + latent = latent_dist.sample() + except Exception as e: + print("❌ VAE ENCODE CRASH:", e) + raise + + print("[STEP 5] LATENT RAW") + print(" min/max:", latent.min().item(), latent.max().item()) + print(" NaN:", torch.isnan(latent).sum().item()) + + # 🚨 CHECK NaN + if torch.isnan(latent).any(): + print("⚠️ NaN DETECTED → applying fallback") + + # fallback 1: zero latent + latent = torch.zeros_like(latent) + + # fallback 2 (optionnel): + # latent = torch.randn_like(latent) * 0.1 + + # 7. Scale (SD standard) + latent = latent * latent_scale + + print("[STEP 6] SCALED LATENT") + print(" min/max:", latent.min().item(), latent.max().item()) + + # 8. Back to FP16 + latent = latent.to(dtype=torch.float16) + + print("[FINAL]") + print(" dtype:", latent.dtype) + print(" device:", latent.device) + print("=====================================================\n") + + return latent + +# ---------------- Control -> Latent sécurisé ---------------- +def control_to_latent_safe(control_tensor, vae, device="cuda", LATENT_SCALE=1.0): + # 🔥 FORCE VAE EN FP32 + vae = vae.to(device=device, dtype=torch.float32) + + control_tensor = control_tensor.to(device=device, dtype=torch.float32) + + with torch.no_grad(): + latent = vae.encode(control_tensor).latent_dist.sample() + + return latent * LATENT_SCALE + +def process_latents_streamed(control_latent, mini_latents=None, mini_weight=0.5, device="cuda"): + """ + Fusionne ControlNet / mini-latents frame par frame, patch par patch + pour réduire l'empreinte VRAM. 
+ """ + # On garde tout en float16 tant que possible + control_latent = control_latent.to(device=device, dtype=torch.float16) + + if mini_latents is not None: + mini_latents = mini_latents.to(device=device, dtype=torch.float16) + + # Initialisation finale du tensor latents en float16 + latents = control_latent.clone() + + # Si mini_latents existe, on fait un mix patch par patch + if mini_latents is not None: + B, C, H, W = latents.shape + patch_size = 16 # petit patch pour limiter la VRAM + for y in range(0, H, patch_size): + y1 = min(y + patch_size, H) + for x in range(0, W, patch_size): + x1 = min(x + patch_size, W) + + # Sélection patch + patch_main = latents[:, :, y:y1, x:x1] + patch_mini = mini_latents[:, :, y:y1, x:x1] + + # Mix float16 → float16 pour VRAM + patch_main = (1 - mini_weight) * patch_main + mini_weight * patch_mini + + # Écriture patch back + latents[:, :, y:y1, x:x1] = patch_main + + # Nettoyage immédiat pour libérer VRAM + del patch_main, patch_mini + torch.cuda.empty_cache() + + return latents + + +def match_latent_size(latents_main, *tensors): + """ + Interpole tous les tensors pour correspondre à la taille HxW de latents_main. + """ + matched = [] + for t in tensors: + if t.shape[2:] != latents_main.shape[2:]: + t = F.interpolate(t, size=latents_main.shape[2:], mode='bilinear', align_corners=False) + matched.append(t) + return matched if len(matched) > 1 else matched[0] + + +def match_latent_size_v1(latents_main, latents_mini): + """ + Assure que latents_mini a la même taille HxW que latents_main. 
    """
    if latents_mini.shape[2:] != latents_main.shape[2:]:
        latents_mini = F.interpolate(
            latents_mini,
            size=latents_main.shape[2:],  # target (H, W)
            mode='bilinear',
            align_corners=False
        )
    return latents_mini


def apply_controlnet_openpose_step(
    latents,
    t,
    unet,
    controlnet,
    scheduler,
    pose_image,
    pos_embeds,
    neg_embeds=None,
    guidance_scale=5.0,
    controlnet_scale=0.7,
    device="cuda",
    dtype=torch.float16,
    debug=False
):
    """
    Run one denoising step with ControlNet (OpenPose) conditioning.

    When neg_embeds is provided, classifier-free guidance is used: the latent
    batch is duplicated (uncond + cond), the ControlNet residuals are scaled by
    controlnet_scale, and the two noise predictions are recombined with
    guidance_scale before the scheduler step.

    Returns:
        The updated latents (scheduler.step(...).prev_sample).
    """
    import torch  # local import kept from the original (redundant with module-level torch)

    latents = latents.to(device=device, dtype=dtype)
    pose_image = pose_image.to(device=device, dtype=dtype)

    # 🔁 classifier-free guidance: duplicate inputs so one forward pass covers
    # both the unconditional and conditional branches
    if neg_embeds is not None:
        latent_model_input = torch.cat([latents] * 2)
        encoder_states = torch.cat([neg_embeds, pos_embeds])
        pose_input = torch.cat([pose_image] * 2)
    else:
        latent_model_input = latents
        encoder_states = pos_embeds
        pose_input = pose_image

    # Scheduler-specific input scaling (required by some diffusers schedulers)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    # 🔥 ControlNet forward → per-resolution down-block residuals + mid residual
    down_samples, mid_sample = controlnet(
        latent_model_input,
        t,
        encoder_hidden_states=encoder_states,
        controlnet_cond=pose_input,
        return_dict=False
    )

    # 🔥 UNet forward with ControlNet residuals, attenuated by controlnet_scale
    noise_pred = unet(
        latent_model_input,
        t,
        encoder_hidden_states=encoder_states,
        down_block_additional_residuals=[d * controlnet_scale for d in down_samples],
        mid_block_additional_residual=mid_sample * controlnet_scale
    ).sample

    # 🔁 CFG: recombine unconditional/conditional halves
    if neg_embeds is not None:
        noise_uncond, noise_text = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

    # 🔥 Scheduler step → previous (less noisy) latents
    latents = scheduler.step(noise_pred, t, latents).prev_sample

    if debug:
        print(f"[ControlNet] t={t}, latents min/max: {latents.min().item():.3f}/{latents.max().item():.3f}")

    return latents


def generate_pose_sequence(
    base_pose,
    num_frames=16,
    motion_type="idle",  # "idle", "sway", "zoom", "breath"
    amplitude=5.0,
    device="cuda",
    dtype=None,
    debug=False
):
    """
    Generate a sequence of animated (OpenPose-like) pose maps.

    Args:
        base_pose: tensor [1,3,H,W] or [1,1,H,W] (pose image)
        num_frames: number of frames
        motion_type: animation type
        amplitude: motion intensity
        device: device
        debug: print info

    Returns:
        List[Tensor]: list of control_tensor
    """


    if dtype is None:
        dtype = base_pose.dtype  # inherit dtype from the input pose

    base_pose = base_pose.to(device=device, dtype=dtype)
    B, C, H, W = base_pose.shape

    poses = []

    for t in range(num_frames):

        # alpha in [0, 1] over the sequence; phase covers one full sine period
        alpha = t / max(1, num_frames - 1)
        phase = 2 * math.pi * alpha

        pose = base_pose.clone()

        # --------------------------------------------------
        # 🎯 MOTION TYPES
        # --------------------------------------------------

        # 🔹 1. IDLE (natural micro movements)
        if motion_type == "idle":
            dx = math.sin(phase) * amplitude * 0.3
            dy = math.cos(phase) * amplitude * 0.2

        # 🔹 2. SWAY (side-to-side rocking)
        elif motion_type == "sway":
            dx = math.sin(phase) * amplitude
            dy = 0

        # 🔹 3. BREATH (subtle zoom, no translation)
        elif motion_type == "breath":
            scale = 1.0 + math.sin(phase) * 0.03
            pose = F.interpolate(
                pose,
                scale_factor=scale,
                mode="bilinear",
                align_corners=False
            )
            # resize back so every frame keeps the original (H, W)
            pose = F.interpolate(pose, size=(H, W))
            dx, dy = 0, 0

        # 🔹 4. light ZOOM + drift
        elif motion_type == "zoom":
            scale = 1.0 + math.sin(phase) * 0.05
            pose = F.interpolate(
                pose,
                scale_factor=scale,
                mode="bilinear",
                align_corners=False
            )
            pose = F.interpolate(pose, size=(H, W))
            dx = math.sin(phase) * amplitude * 0.2
            dy = math.cos(phase) * amplitude * 0.2

        else:
            # unknown motion_type → static frame
            dx, dy = 0, 0

        # --------------------------------------------------
        # 🔹 Affine translation (ultra stable)
        # --------------------------------------------------
        # dx/dy are in pixels; affine_grid expects normalized coords, hence /(W/2), /(H/2)
        if motion_type != "breath":
            theta = torch.tensor([
                [1, 0, dx / (W/2)],
                [0, 1, dy / (H/2)]
            ], device=device, dtype=dtype).unsqueeze(0)

            grid = F.affine_grid(theta, pose.size(), align_corners=False)
            pose = F.grid_sample(pose, grid, align_corners=False)

        # --------------------------------------------------
        # 🔒 Safety: strip NaNs and clamp to the [0, 1] image range
        # --------------------------------------------------
        pose = torch.nan_to_num(pose, 0.0)
        pose = pose.clamp(0, 1)

        poses.append(pose)

    # --------------------------------------------------
    # 🔍 Debug
    # --------------------------------------------------
    if debug:
        print(f"[PoseSeq] frames: {num_frames}")
        print(f"[PoseSeq] type: {motion_type}")
        print(f"[PoseSeq] shape: {poses[0].shape}")

    return poses


def apply_controlnet_openpose_step_v1(
    latents,
    t,
    unet,
    controlnet,
    control_tensor,
    pos_embeds=None,
    neg_embeds=None,
    guidance_scale=1.0,
    controlnet_strength=1.0,
    device="cuda",
    debug=False
):
    """
    Apply ControlNet OpenPose for one UNet step, with CFG.

    Args:
        latents: [B,C,H,W]
        t: timestep
        unet: UNet model
        controlnet: ControlNet model
        control_tensor: [B,1,H,W] or [B,3,H,W] (pose image)
        pos_embeds: positive embeddings
        neg_embeds: negative embeddings
        guidance_scale: CFG strength
        controlnet_strength: pose influence
        device: device
        debug: print info

    Returns:
        updated latents
    """

    # --------------------------------------------------
    # 🔒 Input sanitisation
    # --------------------------------------------------
    latents = torch.nan_to_num(latents, 0.0)
    control_tensor = torch.nan_to_num(control_tensor, 0.0)

    control_tensor = control_tensor.clamp(0, 1)

    # ControlNet expects a 3-channel conditioning image → tile grayscale input
    if control_tensor.shape[1] == 1:
        control_tensor = control_tensor.repeat(1, 3, 1, 1)

    control_tensor = control_tensor.to(device=device, dtype=latents.dtype)

    if debug:
        print(f"[ControlNet] latents: {latents.shape}")
        print(f"[ControlNet] control: {control_tensor.shape}")
        print(f"[ControlNet] timestep: {t}")

    # --------------------------------------------------
    # 🔹 POSITIVE PASS (conditional branch)
    # --------------------------------------------------
    down_pos, mid_pos = controlnet(
        latents,
        t,
        encoder_hidden_states=pos_embeds,
        controlnet_cond=control_tensor,
        return_dict=False
    )

    # scale residuals by the requested pose influence
    down_pos = [d * controlnet_strength for d in down_pos]
    mid_pos = mid_pos * controlnet_strength

    noise_pos = unet(
        latents,
        t,
        encoder_hidden_states=pos_embeds,
        down_block_additional_residuals=down_pos,
        mid_block_additional_residual=mid_pos
    ).sample

    # --------------------------------------------------
    # 🔹 NEGATIVE PASS (only when CFG is requested)
    # --------------------------------------------------
    if neg_embeds is not None:

        down_neg, mid_neg = controlnet(
            latents,
            t,
            encoder_hidden_states=neg_embeds,
            controlnet_cond=control_tensor,
            return_dict=False
        )

        down_neg = [d * controlnet_strength for d in down_neg]
        mid_neg = mid_neg * controlnet_strength

        noise_neg = unet(
            latents,
            t,
            encoder_hidden_states=neg_embeds,
            down_block_additional_residuals=down_neg,
            mid_block_additional_residual=mid_neg
        ).sample

        # 🔥 CFG recombination
        noise_pred = noise_neg + guidance_scale * (noise_pos - noise_neg)

    else:
        noise_pred = noise_neg if False else noise_pos  # NOTE(review): original simply uses noise_pos

    # --------------------------------------------------
    # 🔹 Update latents (simplified diffusion step — fixed Euler-like update,
    #    NOT a scheduler.step; the 0.1 factor keeps the update stable)
    # --------------------------------------------------
    latents = latents + noise_pred * 0.1  # stable factor (avoids blow-up)

    # 🔒 Safety clamp
    latents = torch.clamp(latents, -1.5, 1.5)

    # --------------------------------------------------
    # 🔍 Debug
    # --------------------------------------------------
    if debug:
        print(f"[ControlNet] noise min/max: {noise_pred.min():.3f}/{noise_pred.max():.3f}")
        print(f"[ControlNet] latents min/max: {latents.min():.3f}/{latents.max():.3f}")

    return latents

# Default loading:
# /mnt/62G/AnimateDiff main* 54s
# animatediff ❯ ls -l /mnt/62G/huggingface/sd-controlnet-openpose/
# .rw-r--r--@ 1,4G n3oray 25 mars 22:50  diffusion_pytorch_model.safetensors
# .rw-r--r--@ 67 n3oray 25 mars 22:51  Note.txt

def load_controlnet_openpose_local(
    local_model_path="/mnt/62G/huggingface/sd-controlnet-openpose",
    device="cuda",
    dtype=torch.float16,
    use_fp16=True,
    debug=True
):
    """
    Load ControlNet OpenPose from a local folder containing:
    - diffusion_pytorch_model.safetensors
    - config.json

    Args:
        local_model_path (str): path to the local model folder
        device (str): "cuda" or "cpu"
        dtype (torch.dtype): target dtype
        use_fp16 (bool): force fp16 when possible
        debug (bool): detailed logs

    Returns:
        controlnet (ControlNetModel)
    """
    print(f"Chargement ControlNet OpenPose depuis dossier local : {local_model_path}")
    print(f"device : {device}")
    print(f"dtype : {dtype}")

    try:
        # 🔹 Smart dtype choice: fp16 only makes sense on CUDA
        load_dtype = torch.float16 if (use_fp16 and device.startswith("cuda")) else torch.float32

        # local_files_only=True → never hit the network for this load
        controlnet = ControlNetModel.from_pretrained(
            local_model_path,
            torch_dtype=load_dtype,
            local_files_only=True
        )

        # 🔹 Move to target device
        controlnet = controlnet.to(device)

        # 🔹 Parameter sanity check (reported in millions)
        total_params = sum(p.numel() for p in controlnet.parameters()) / 1e6
        if debug:
            print(f"🧠 ControlNet prêt")
            print(f" params : {total_params:.1f}M")
            print(f" dtype : {next(controlnet.parameters()).dtype}")
            print(f" device : {next(controlnet.parameters()).device}")

        # 🔹 Inference mode
        controlnet.eval()

        # 🔹 Free cached GPU memory
        torch.cuda.empty_cache()

        return controlnet

    except Exception as e:
        print("❌ ERREUR chargement ControlNet depuis dossier local")
        print(str(e))

        # 🔥 CPU fallback (avoids a hard crash on CUDA failure); retries once in fp32
        if device.startswith("cuda"):
            print("⚠ Fallback CPU...")
            return load_controlnet_openpose_local(
                local_model_path=local_model_path,
                device="cpu",
                dtype=torch.float32,
                use_fp16=False,
                debug=debug
            )

        raise e


def load_controlnet_openpose(
    device="cuda",
    dtype=torch.float16,
    model_id="lllyasviel/sd-controlnet-openpose",
    use_fp16=True,
    debug=True
):
    """
    Load ControlNet OpenPose with clean GPU / CPU / dtype handling.
+ + Args: + device (str): "cuda" ou "cpu" + dtype (torch.dtype): dtype cible (fp16 recommandé) + model_id (str): repo HF + use_fp16 (bool): force fp16 si possible + debug (bool): logs détaillés + + Returns: + controlnet (ControlNetModel) + """ + print(f"Chargement ControlNet OpenPose - Parametres recommander:") + print(f"guidance_scale = 5.0 → 6.0") + print(f"controlnet_strength = 0.7 → 0.9") + print(f"latents update factor = 0.1 ✅ (ne pas monter)") + + if debug: + print("🔄 Chargement ControlNet OpenPose...") + print(f" model_id : {model_id}") + print(f" device : {device}") + print(f" dtype : {dtype}") + + try: + # 🔹 Choix dtype intelligent + load_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32 + + controlnet = ControlNetModel.from_pretrained( + model_id, + torch_dtype=load_dtype + ) + + if debug: + print("✅ Modèle chargé depuis HuggingFace") + + # 🔹 Move device + controlnet = controlnet.to(device) + + # 🔹 Vérification paramètres + total_params = sum(p.numel() for p in controlnet.parameters()) / 1e6 + + if debug: + print(f"🧠 ControlNet prêt") + print(f" params : {total_params:.1f}M") + print(f" dtype : {next(controlnet.parameters()).dtype}") + print(f" device : {next(controlnet.parameters()).device}") + + # 🔹 Mode eval + controlnet.eval() + + # 🔹 Sécurité mémoire (important pour 4GB) + torch.cuda.empty_cache() + + return controlnet + + except Exception as e: + print("❌ ERREUR chargement ControlNet") + print(str(e)) + + # 🔥 fallback CPU (évite crash) + if device == "cuda": + print("⚠ Fallback CPU...") + return load_controlnet_openpose( + device="cpu", + dtype=torch.float32, + use_fp16=False, + debug=debug + ) + + raise e diff --git a/scripts/n3rProBoost.py b/scripts/n3rProBoost.py new file mode 100644 index 00000000..2b905660 --- /dev/null +++ b/scripts/n3rProBoost.py @@ -0,0 +1,465 @@ +# -------------------------------------------------------------- +# n3rProBoost.py - AnimateDiff ultra-light ~2Go VRAM +# Prompt / Input → 
N3RModelOptimized → MotionModule → UNet → LoRA → VAE → Image / Vidéo +#Avec use_mini_gpu et generate_latents_mini_gpu_320 → ~2,1 Go VRAM, ultra léger ✅ Avec use_n3r_model et N3RModelOptimized → ~3,6 Go VRAM +# -------------------------------------------------------------- +import os, math, threading, random +import traceback +import hashlib +import torch +import pickle +from pathlib import Path +from datetime import datetime +from tqdm import tqdm +from torchvision.transforms.functional import to_pil_image +from PIL import Image, ImageFilter +import argparse +from diffusers import PNDMScheduler +from transformers import CLIPTokenizerFast, CLIPTextModel +from scripts.utils.lora_utils import apply_lora_smart +from scripts.utils.vae_config import load_vae +from scripts.utils.n3rModelUtils import generate_n3r_coords, process_n3r_latents, fuse_with_memory, inject_external, fuse_n3r_latents_adaptive_new +from scripts.utils.tools_utils import ensure_4_channels, print_generation_params, sanitize_latents, stabilize_latents_advanced, log_debug, compute_overlap, get_interpolated_embeddings, save_memory, load_memory, load_external_embedding_as_latent, inject_external_embeddings, update_n3r_memory, compute_weighted_params, adapt_embeddings_to_unet, get_dynamic_latent_injection, save_input_frame +from scripts.utils.config_loader import load_config +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import load_images_test, generate_latents_mini_gpu_320, run_diffusion_pipeline, generate_latents_robuste_4D +from scripts.utils.fx_utils import encode_images_to_latents_nuanced, decode_latents_ultrasafe_blockwise, adaptive_post_process, save_frames_as_video_from_folder, encode_images_to_latents_safe, apply_post_processing_adaptive, encode_images_to_latents_hybrid, interpolate_param_fast, fuse_n3r_latents_adaptive, adaptive_post_process, remove_white_noise, apply_post_processing + +from scripts.utils.vae_utils import safe_load_unet +from 
scripts.utils.n3rModelFast4Go import N3RModelFast4GB, N3RModelLazyCPU, N3RModelOptimized +from scripts.utils.n3rProNet import N3RProNet +from scripts.utils.n3rProNet_utils import apply_n3r_pro_net, soft_tone_map + +LATENT_SCALE = 0.18215 +stop_generation = False + +# Variation de l'interpolation' Valeurs de départ (fidèles à l'image)-----------------------interpolate_param_fast --- +#init_image_scale_start = 0.95 #guidance_scale_start = 1.5 #creative_noise_start = 0.0 +# ---------------- Thread stop ---------------- +def wait_for_stop(): + global stop_generation + inp = input("Appuyez sur '²' + Entrée pour arrêter : ") + if inp.lower() == "²": + stop_generation = True +threading.Thread(target=wait_for_stop, daemon=True).start() + +# ---------------- Utilitaires ---------------- +def apply_motion_safe(latents, motion_module, threshold=1e-3): + if latents.abs().max() < threshold: + return latents, False + return motion_module(latents), True + +# ---------------- MAIN FIABLE ---------------- +def main(args): + global stop_generation + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 + + use_mini_gpu = cfg.get("use_mini_gpu", True) + verbose = cfg.get("verbose", False) + latent_injection = float(cfg.get("latent_injection", 0.75)) + latent_injection = min(max(latent_injection, 0.5), 0.9) # plage sûre + final_latent_scale = cfg.get("final_latent_scale", 1/8) # 1/8 speed, 1/4 moyen, 1/2 low + fps = cfg.get("fps", 12) + upscale_factor = cfg.get("upscale_factor", 1) + transition_frames = cfg.get("transition_frames", 4) + num_fraps_per_image = cfg.get("num_fraps_per_image", 2) + steps = max(cfg.get("steps", 16), 4) + guidance_scale = cfg.get("guidance_scale", 6.5) # 0.15 peut de créativité 4.5 moderé + guidance_scale_end = cfg.get("guidance_scale_end", 7.0) # 0.15 peut de créativité 4.5 moderé + init_image_scale = cfg.get("init_image_scale", 0.75) # 0.85 ou 0.95 proche de l'init' (0.75) + 
    init_image_scale_end = cfg.get("init_image_scale_end", 0.9)  # 0.85 or 0.95 stays close to the init image
    creative_noise = cfg.get("creative_noise", 0.0)
    creative_noise_end = cfg.get("creative_noise_end", 0.08)
    latent_scale_boost = cfg.get("latent_scale_boost", 1.0)
    frames_per_prompt = cfg.get("frames_per_prompt", 10)  # number of frames per prompt
    contrast = cfg.get("contrast", 1.15)  # post-processing contrast
    saturation = cfg.get("saturation", 1.00)  # post-processing saturation
    blur_radius = cfg.get("blur_radius", 0.03)  # post-processing blur
    sharpen_percent = cfg.get("sharpen_percent", 90)  # post-processing sharpen
    H, W = cfg.get("H", 512), cfg.get("W", 512)
    block_size = min(256, H//2, W//2)  # block_size auto-derived from resolution
    use_n3r_model = cfg.get("use_n3r_model", False)

    # Random seed (drawn once per run, logged via params below)
    seed = torch.randint(0, 100000, (1,)).item()
    params = { 'use_mini_gpu': use_mini_gpu, 'fps': fps, 'upscale_factor': upscale_factor, 'num_fraps_per_image': num_fraps_per_image, 'steps': steps, 'guidance_scale': guidance_scale, 'guidance_scale_end': guidance_scale_end, 'init_image_scale': init_image_scale, 'init_image_scale_end': init_image_scale_end, 'creative_noise': creative_noise, 'creative_noise_end': creative_noise_end, 'latent_scale_boost': latent_scale_boost, 'final_latent_scale': final_latent_scale, 'seed': seed, 'latent_injection': latent_injection, 'transition_frames': transition_frames, 'block_size': block_size, 'use_n3r_model': use_n3r_model }
    print_generation_params(params)


    scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear", num_train_timesteps=1000)
    scheduler.set_timesteps(steps, device=device)

    # ---------------- UNET ----------------
    unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=True)
    # Memory savers are optional features → guarded by hasattr / try
    if hasattr(unet, "enable_attention_slicing"): unet.enable_attention_slicing()
    if hasattr(unet, "enable_xformers_memory_efficient_attention"):
        try: unet.enable_xformers_memory_efficient_attention(True)
        except: pass

    # ---------------- LoRA ----------------
    n3oray_models = cfg.get("n3oray_models")
    if n3oray_models:
        for model_name, lora_path in n3oray_models.items():
            applied = apply_lora_smart(unet, lora_path, alpha=0.5, device=device, verbose=verbose)
            if not applied: print(f"⚠ LoRA '{model_name}' ignorée (incompatible UNet)")
    else:
        print("⚠ Aucun modèle LoRA configuré, étape ignorée.")
    # init external_latent (filled further down, before the generation loop)
    external_latent = None
    # ---------------- Motion module ----------------
    motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None
    if motion_module and verbose:
        print(f"[INFO] motion_module type: {type(motion_module)}")

    # ---------------- Tokenizer / Text encoder ----------------
    tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path,"tokenizer"))
    text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path,"text_encoder")).to(device).to(dtype)

    # ---------------- VAE ----------------
    vae_path = cfg.get("vae_path")
    # NOTE(review): this rebinds the module-level LATENT_SCALE name locally
    vae, vae_type, latent_channels, LATENT_SCALE = load_vae(vae_path, device=device, dtype=dtype)

    # ---------------- Embeddings ----------------
    embeddings = []
    prompts = cfg.get("prompt", [])
    negative_prompts = cfg.get("n_prompt", [])
    unet_cross_attention_dim = getattr(unet.config, "cross_attention_dim", 1024)

    # --- Adaptive projection: probe the text-encoder hidden size with a dummy
    #     prompt and add a Linear projection if it differs from the UNet's
    #     cross-attention dimension (avoids dim-mismatch at inference)
    text_inputs_sample = tokenizer("test", padding="max_length", truncation=True,
                                   max_length=tokenizer.model_max_length, return_tensors="pt")
    with torch.no_grad():
        sample_embeds = text_encoder(text_inputs_sample.input_ids.to(device)).last_hidden_state
    current_dim = sample_embeds.shape[-1]
    projection = None
    if current_dim != unet_cross_attention_dim:
        projection = torch.nn.Linear(current_dim, unet_cross_attention_dim).to(device).to(dtype)

    # --- Pre-compute per-prompt embeddings for later interpolation
    pos_embeds_list = []
    neg_embeds_list = []

    # prompts / n_prompts may be lists of lists or plain strings
    for i, prompt_item in enumerate(prompts):
        # Positive text
        prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item)
        # Matching negative text (falls back to the first entry when shorter)
        neg_text_item = negative_prompts[i] if i < len(negative_prompts) else negative_prompts[0]
        neg_text = " ".join(neg_text_item) if isinstance(neg_text_item, list) else str(neg_text_item)

        text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True,
                                max_length=tokenizer.model_max_length, return_tensors="pt")
        neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True,
                               max_length=tokenizer.model_max_length, return_tensors="pt")

        with torch.no_grad():
            pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
            neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state

        if projection is not None:
            pos_embeds = projection(pos_embeds)
            neg_embeds = projection(neg_embeds)

        # Append to the full lists
        pos_embeds_list.append(pos_embeds)
        neg_embeds_list.append(neg_embeds)

    # ---------------- N3RModelOptimized ----------------
    n3r_model = None
    if use_n3r_model:
        n3r_model = N3RModelOptimized(
            L_low=cfg.get("n3r_L_low",3),  # 3 or 4 — slightly finer detail than 3
            L_high=cfg.get("n3r_L_high",6),  # keeps the global structure
            N_samples=cfg.get("n3r_N_samples",32),  # more samples → more detail (48 possible)
            tile_size=cfg.get("n3r_tile_size",64),  # unchanged, keeps VRAM reasonable
            cpu_offload=cfg.get("n3r_cpu_offload",True)
        ).to(device)
        n3r_model.eval()
        print(f"✅ N3RModelOptimized initialisé sur {device}")

    # ------------------- Latent-memory initialisation -------------------
    output_dir_m = Path("./outputs")
    memory_file = output_dir_m / "n3r_memory"
    memory_dict = load_memory(memory_file)

    # Configurable from the cfg file
    use_n3r_pro_net = cfg.get("use_n3r_pro_net", True)
    n3r_pro_strength = cfg.get("n3r_pro_strength", 0.3)

    n3r_pro_net = None
    if use_n3r_pro_net:
        n3r_pro_net = N3RProNet(channels=4).to(device).to(dtype)
        n3r_pro_net.eval()
        print("✅ N3RProNet activé")


    # ---------------- Input images ----------------
    input_paths = cfg.get("input_images") or [cfg.get("input_image")]
    total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(f"./outputs/ProBoost{timestamp}")
    output_dir.mkdir(parents=True, exist_ok=True)
    out_video = output_dir / f"output_{timestamp}.mp4"

    overlap = compute_overlap(cfg["W"], cfg["H"], block_size)

    previous_latent_single = None
    frame_counter = 0
    pbar = tqdm(total=total_frames, ncols=120)

    # ---------------- Main frames with prompt interpolation ----------------
    external_embeddings = None

    # Load the external latent once, before generation starts
    external_path = "/mnt/62G/huggingface/cyber-fp16/pt/KnxCOmiXNeg.safetensors"
    external_latent = load_external_embedding_as_latent(
        external_path, (1, 4, cfg["H"]//8, cfg["W"]//8)
    ).to(device)
    #------------------------------------------------------------------------------
    for img_idx, img_path in enumerate(input_paths):
        if stop_generation: break
        try:
            # Interpolated per-frame parameters (cosine schedule over the run)
            current_init_image_scale, current_creative_noise, current_guidance_scale = compute_weighted_params( frame_counter, total_frames, init_start=0.85, init_end=0.5,noise_start=0.0, noise_end=0.08, guidance_start=3.5, guidance_end=4.5, mode="cosine" )
            print(f"[Frame {frame_counter:03d}] " f"init_image_scale={current_init_image_scale:.3f}, " f"guidance_scale={current_guidance_scale:.3f}, " f"creative_noise={current_creative_noise:.3f}")

            # Load and encode the image on GPU
            input_image = load_images_test([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype)
            input_image = ensure_4_channels(input_image)
            frame_counter = save_input_frame( input_image, output_dir, frame_counter, pbar=pbar, blur_radius=blur_radius, contrast=contrast, saturation=saturation, apply_post=False )

            current_latent_single = encode_images_to_latents_hybrid(input_image, vae, device=device, latent_scale=LATENT_SCALE)
            # Resize to the latent resolution (H/8, W/8 — standard SD latent grid)
            current_latent_single = torch.nn.functional.interpolate(
                current_latent_single, size=(cfg["H"]//8, cfg["W"]//8),
                #current_latent_single, size=(cfg["H"]//6, cfg["W"]//6),
                mode='bilinear', align_corners=False
            )

            # 🔥 NaN fix / stability
            current_latent_single = sanitize_latents(current_latent_single)

            # Robust initial generation.
            # Seed notes: 42 classic community test seed; 1234 faithful/stable;
            # 5555 faithful to the init image; 2026 subtle texture/pose change;
            # 9876 more visible variation while keeping global structure.
            pos_embeds, neg_embeds = get_interpolated_embeddings( frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device )
            try:
                current_latent_single = generate_latents_robuste_4D(
                    latents=current_latent_single.to(device),
                    pos_embeds=pos_embeds,
                    neg_embeds=neg_embeds,
                    unet=unet,
                    scheduler=scheduler,
                    motion_module=None,
                    device=device,
                    dtype=dtype,
                    guidance_scale=current_guidance_scale,  # e.g. 1.5 = stricter, keeps the subject
                    init_image_scale=current_init_image_scale,  # e.g. 0.85 = mostly the original image signal
                    creative_noise=current_creative_noise  # e.g. 0.08 = less freedom, more coherence
                    , seed=seed  # 42, 1234, 2026, 5555
                )

                # 🔥 NaN fix / stability
                current_latent_single = sanitize_latents(current_latent_single)
            except Exception as e:
                print(f"[Robuste INIT ERROR] {e}")

            current_latent_single = ensure_4_channels(current_latent_single)
            # Park on CPU between uses to keep VRAM low
            current_latent_single = current_latent_single.to('cpu')
            del input_image
            torch.cuda.empty_cache()

            # ---------------- Transition frames ----------------
            if previous_latent_single is not None and transition_frames > 0:
                for t in range(transition_frames):
                    if stop_generation: break
                    alpha = 0.5 - 0.5*math.cos(math.pi*t/max(transition_frames-1,1))
                    with torch.no_grad():
                        # --- Adaptive fusion: the previous frame's influence decays
                        #     linearly from injection_start to injection_end
                        injection_start = 0.8  # initial influence of the old frame
                        injection_end = 0.1  # final influence
                        denom = max(transition_frames-1, 1)
                        injection_alpha = injection_start * (1 - t/denom) + injection_end * (t/denom)

                        latent_interp = injection_alpha * previous_latent_single.to(device) + (1 - injection_alpha) * current_latent_single.to(device)
                        # 🔥 NaN fix / stability
                        latent_interp = sanitize_latents(latent_interp)

                        if motion_module:
                            latent_interp, _ = apply_motion_safe(latent_interp, motion_module)


                        # Apply n3r_pro_net refinement
                        latent_interp = apply_n3r_pro_net( latent_interp, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_latents )
                        # Streaming decode: rescale latents before the VAE
                        latent_interp = latent_interp / LATENT_SCALE
                        # contrast=1.5, saturation=1.3, latent_scale_boost — recommended 1.0
                        frame_pil = decode_latents_ultrasafe_blockwise( latent_interp, vae, block_size=block_size, overlap=overlap, gamma=1.0, brightness=1.0, contrast=1.0, saturation=1.0, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost )
                        #frame_pil = apply_post_processing_adaptive(frame_pil, blur_radius=blur_radius, contrast=contrast, brightness=1.05, saturation=saturation, vibrance_base=1.0, vibrance_max=1.1, sharpen=True, sharpen_radius=1, sharpen_percent=sharpen_percent, sharpen_threshold=2)


                        frame_pil = soft_tone_map(frame_pil)
                        frame_pil = apply_post_processing_adaptive(
                            frame_pil,
                            blur_radius=0.02,  # lighter (avoids washed-out look)
                            contrast=1.03,  # 🔥 important → near neutral
                            brightness=1.0,  # never touch unless needed
                            saturation=0.93,  # 🔥 key → avoids global over-saturation
                            vibrance_base=0.98,  # 🔥 global reduction
                            vibrance_max=1.02,  # very low upper limit
                            sharpen=True,
                            sharpen_radius=0.6,  # 🔥 finer
                            sharpen_percent=60,  # 🔥 far less aggressive
                            sharpen_threshold=3  # avoids amplifying noise
                        )

                        # save
                        print(f"[ init SAVE Frame {frame_counter:03d}]")
                        frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                        frame_counter += 1
                        pbar.update(1)

                        del latent_interp
                        torch.cuda.empty_cache()

            # ---------------- Main frames ----------------

            for f in range(num_fraps_per_image):
                if stop_generation: break
                with torch.no_grad():
                    latents_frame = current_latent_single.to(device)

                    # --- Prompt-embedding interpolation ---
                    cf_embeds = get_interpolated_embeddings( frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device )

                    # --- N3R or mini-GPU diffusion ---
                    n3r_latents = None
                    latents = latents_frame.clone()

                    # 🔥 NaN fix / stability
                    latents = sanitize_latents(latents)

                    # ---------------- N3R with conditioned latent memory ----------------
                    # N3R only runs on a random subset of frames (every 4-6 frames)
                    use_n3r_this_frame = use_n3r_model and (frame_counter % random.choice([4,5,6]) == 0)
                    # ------------------- Per-frame N3R block -------------------
                    if use_n3r_this_frame:
                        try:
                            H, W = latents.shape[-2], latents.shape[-1]
                            N_samples = n3r_model.N_samples
                            coords = generate_n3r_coords(H, W, N_samples, seed, frame_counter, device)
                            n3r_latents = process_n3r_latents(n3r_model, coords, H, W, H, W)
                            fused_latents = fuse_with_memory(n3r_latents, memory_dict, cf_embeds, frame_counter)
                            fused_latents = inject_external(fused_latents, external_latent, frame_counter, device)
                            #latent_injection = get_dynamic_latent_injection(frame_counter, total_frames, start=0.90, end=0.55)
                            #latents = fuse_n3r_latents_adaptive(latents, fused_latents, latent_injection=latent_injection)
                            latents = fuse_n3r_latents_adaptive_new(latents, fused_latents, frame_counter=frame_counter, total_frames=total_frames, latent_injection_start=0.90, latent_injection_end=0.55)
                            latents = sanitize_latents(latents)
                        except Exception as e:
                            print(f"[N3R ERROR] {e}")

                        # Periodic memory save
                        if frame_counter % 30 == 0:
                            save_memory(memory_dict, memory_file)

                    elif use_mini_gpu:
                        latents = generate_latents_mini_gpu_320(
                            unet=unet, scheduler=scheduler,
                            input_latents=latents_frame, embeddings=cf_embeds,
                            motion_module=motion_module, guidance_scale=current_guidance_scale,
                            device=device, fp16=True, steps=steps,
                            debug=verbose, init_image_scale=current_init_image_scale,
                            creative_noise=current_creative_noise
                        )
                        # Blend generated latents back toward the source frame
                        if latent_injection > 0:
                            if latents.shape[-2:] != latents_frame.shape[-2:]:
                                latents = torch.nn.functional.interpolate(
                                    latents,
                                    size=latents_frame.shape[-2:],
                                    mode='bilinear', align_corners=False
                                ).contiguous()
                            latents = latent_injection*latents_frame + (1-latent_injection)*latents

                    # --- Motion module, clean and safe ---
                    # Build a 3-frame sequence [prev, cur, cur+noise] so the motion
                    # module sees temporal context; 
                    # assumes the module takes [B,C,F,H,W] — TODO confirm
                    if motion_module is not None:
                        if previous_latent_single is not None:
                            latents_seq = torch.stack([
                                previous_latent_single.to(device),
                                latents,
                                latents + 0.01 * torch.randn_like(latents)
                            ], dim=2)  # [B,C,F,H,W]
                        else:
                            latents_seq = latents.unsqueeze(2).repeat(1, 1, 3, 1, 1)

                        # 🔥 sanitise BEFORE motion
                        latents_seq = sanitize_latents(latents_seq)

                        # 🔥 motion-safe (applied exactly once)
                        latents_seq, applied = apply_motion_safe(latents_seq, motion_module)

                        if applied:
                            # keep the middle frame of the motion output
                            latents = latents_seq[:, :, 1, :, :]
                        else:
                            latents = latents

                        latents = sanitize_latents(latents)

                    # 🔥 NO blending here → just update the memory frame
                    previous_latent_single = latents.detach().cpu()
                    # Apply n3r_pro_net refinement
                    latents = apply_n3r_pro_net( latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_latents )
                    # Final clamp/resize — 🔥 NaN fix / stability; smart final cleanup (the key point)
                    latents = latents / LATENT_SCALE
                    # Decode
                    frame_pil = decode_latents_ultrasafe_blockwise( latents, vae, block_size=block_size, overlap=overlap, gamma=1.0, brightness=1.0, contrast=1.0, saturation=1.0, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost )
                    #frame_pil = apply_post_processing_adaptive(frame_pil, blur_radius=blur_radius, contrast=contrast, brightness=1.00, saturation=saturation, vibrance_base=1.1, vibrance_max=1.2, sharpen=True, sharpen_radius=1, sharpen_percent=sharpen_percent, sharpen_threshold=2)

                    frame_pil = soft_tone_map(frame_pil)
                    frame_pil = apply_post_processing_adaptive(
                        frame_pil,
                        blur_radius=0.02,  # lighter (avoids washed-out look)
                        contrast=1.03,  # 🔥 important → near neutral
                        brightness=1.0,  # never touch unless needed
                        saturation=0.93,  # 🔥 key → avoids global over-saturation
                        vibrance_base=0.98,  # 🔥 global reduction
                        vibrance_max=1.02,  # very low upper limit
                        sharpen=True,
                        sharpen_radius=0.6,  # 🔥 finer
                        sharpen_percent=60,  # 🔥 far less aggressive
                        sharpen_threshold=3  # avoids amplifying noise
                    )

                    frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                    frame_counter += 1
                    pbar.update(1)

                    # VRAM cleanup
                    del latents, latents_frame, cf_embeds, n3r_latents
                    torch.cuda.empty_cache()

            previous_latent_single = current_latent_single

        except Exception as e:
            # Per-image failure is non-fatal: log with traceback and move on
            print(f"\n[FRAME ERROR] {img_path}")
            print(f"Type d'erreur : {type(e).__name__}")
            print(f"Message d'erreur : {e}")
            print("Traceback complet :")
            traceback.print_exc()
            continue

    pbar.close()
    save_frames_as_video_from_folder(output_dir, out_video, fps=fps, upscale_factor=upscale_factor)
    print(f"🎬 Vidéo générée : {out_video}")

# ---------------- ENTRY ----------------
if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rProBoostNet.py b/scripts/n3rProBoostNet.py new file mode 100644 index 00000000..0ddeb67c --- /dev/null +++ b/scripts/n3rProBoostNet.py @@ -0,0 +1,434 @@ +# ---------------------------------------------------------------------------------------- +# n3rProBoostNet.py - AnimateDiff stables, ProNet + HDR ultra-light ~2Go VRAM +# Prompt / Input → N3RModelOptimized → MotionModule → UNet → LoRA → VAE → Image / Vidéo +#Avec use_mini_gpu et generate_latents_mini_gpu_320 → ~2,1 Go VRAM, ultra léger ✅ Avec use_n3r_model et N3RModelOptimized → ~3,6 Go VRAM +# ---------------------------------------------------------------------------------------- +import os, math, threading, random +import traceback +import hashlib +import torch +import pickle +from pathlib import Path +from datetime import datetime +from tqdm import tqdm +from torchvision.transforms.functional import to_pil_image +from PIL import Image, ImageFilter +import argparse +from diffusers import PNDMScheduler +from transformers import CLIPTokenizerFast, CLIPTextModel +from scripts.utils.lora_utils import apply_lora_smart +from scripts.utils.vae_config import load_vae +from scripts.utils.n3rModelUtils import generate_n3r_coords, process_n3r_latents, fuse_with_memory, inject_external, fuse_n3r_latents_adaptive_new +from scripts.utils.tools_utils import ensure_4_channels, print_generation_params, sanitize_latents, stabilize_latents_advanced, log_debug, compute_overlap, get_interpolated_embeddings, save_memory, load_memory, load_external_embedding_as_latent, inject_external_embeddings, update_n3r_memory, compute_weighted_params, adapt_embeddings_to_unet, get_dynamic_latent_injection, save_input_frame +from scripts.utils.config_loader import load_config +from 
scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import load_images_test, generate_latents_mini_gpu_320, run_diffusion_pipeline, generate_latents_robuste_4D +from scripts.utils.fx_utils import encode_images_to_latents_nuanced, adaptive_post_process, save_frames_as_video_from_folder, encode_images_to_latents_safe, encode_images_to_latents_hybrid, interpolate_param_fast, fuse_n3r_latents_adaptive, adaptive_post_process, remove_white_noise + +from scripts.utils.vae_utils import safe_load_unet +from scripts.utils.n3rModelFast4Go import N3RModelFast4GB, N3RModelLazyCPU, N3RModelOptimized +from scripts.utils.n3rProNet import N3RProNet +from scripts.utils.n3rProNet_utils import apply_n3r_pro_net, save_frame_verbose, full_frame_postprocess, decode_latents_ultrasafe_blockwise, get_eye_coords_safe, create_volumetrique_mask, create_eye_mask, tensor_to_pil, apply_pro_net_volumetrique, apply_pro_net_with_eyes, get_eye_coords_safe, scale_eye_coords_to_latents, get_coords, get_coords_safe + +LATENT_SCALE = 0.18215 +stop_generation = False + +# ---------------- Thread stop ---------------- + +def wait_for_stop(): + global stop_generation + inp = input("Appuyez sur '²' + Entrée pour arrêter : ") + if inp.lower() == "²": + stop_generation = True +threading.Thread(target=wait_for_stop, daemon=True).start() + +# ---------------- Utilitaires ---------------- +def apply_motion_safe(latents, motion_module, threshold=1e-3): + if latents.abs().max() < threshold: + return latents, False + return motion_module(latents), True + + +# ---------------- MAIN FIABLE ---------------- +def main(args): + global stop_generation + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 + + use_mini_gpu = cfg.get("use_mini_gpu", True) + verbose = cfg.get("verbose", False) + psave = cfg.get("psave", False) + latent_injection = float(cfg.get("latent_injection", 0.75)) + latent_injection = 
min(max(latent_injection, 0.5), 0.9) # plage sûre + final_latent_scale = cfg.get("final_latent_scale", 1/8) # 1/8 speed, 1/4 moyen, 1/2 low + fps = cfg.get("fps", 12) + upscale_factor = cfg.get("upscale_factor", 1) + transition_frames = cfg.get("transition_frames", 4) + num_fraps_per_image = cfg.get("num_fraps_per_image", 2) + steps = max(cfg.get("steps", 16), 4) + guidance_scale = cfg.get("guidance_scale", 6.5) # 0.15 peut de créativité 4.5 moderé + guidance_scale_end = cfg.get("guidance_scale_end", 7.0) # 0.15 peut de créativité 4.5 moderé + init_image_scale = cfg.get("init_image_scale", 0.75) # 0.85 ou 0.95 proche de l'init' (0.75) + init_image_scale_end = cfg.get("init_image_scale_end", 0.9) # 0.85 ou 0.95 proche de l'init' + creative_noise = cfg.get("creative_noise", 0.0) + creative_noise_end = cfg.get("creative_noise_end", 0.08) + latent_scale_boost = cfg.get("latent_scale_boost", 1.0) + frames_per_prompt = cfg.get("frames_per_prompt", 10) # nombre de frames par prompt + contrast = cfg.get("contrast", 1.15) # Post Traitement constrat + blur_radius = cfg.get("blur_radius", 0.03) # Post Traitement blur + sharpen_percent = cfg.get("sharpen_percent", 90) #Post Traitement sharpen + H, W = cfg.get("H", 512), cfg.get("W", 512) + block_size = min(256, H//2, W//2) # block_size auto selon résolution + use_n3r_model = cfg.get("use_n3r_model", False) + # Configurable depuis ton fichier cfg + use_n3r_pro_net = cfg.get("use_n3r_pro_net", True) + n3r_pro_strength = cfg.get("n3r_pro_strength", 0.2) # 0.1, 0.2, 0.3 + + + #target_temp = 8000 reference_temp = 6000 (Froid) + target_temp = 7800 + reference_temp = 6500 + + # Seed aléatoire + seed = torch.randint(0, 100000, (1,)).item() + params = { 'use_mini_gpu': use_mini_gpu, 'fps': fps, 'upscale_factor': upscale_factor, 'num_fraps_per_image': num_fraps_per_image, 'steps': steps, 'guidance_scale': guidance_scale, 'guidance_scale_end': guidance_scale_end, 'init_image_scale': init_image_scale, 'init_image_scale_end': 
init_image_scale_end, 'creative_noise': creative_noise, 'creative_noise_end': creative_noise_end, 'latent_scale_boost': latent_scale_boost, 'final_latent_scale': final_latent_scale, 'seed': seed, 'latent_injection': latent_injection, 'transition_frames': transition_frames, 'block_size': block_size, 'use_n3r_model': use_n3r_model } + print_generation_params(params) + + + scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", num_train_timesteps=1000) + scheduler.set_timesteps(steps, device=device) + + # ---------------- UNET ---------------- + unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=True) + if hasattr(unet, "enable_attention_slicing"): unet.enable_attention_slicing() + if hasattr(unet, "enable_xformers_memory_efficient_attention"): + try: unet.enable_xformers_memory_efficient_attention(True) + except: pass + + # ---------------- LoRA ---------------- + n3oray_models = cfg.get("n3oray_models") + if n3oray_models: + for model_name, lora_path in n3oray_models.items(): + applied = apply_lora_smart(unet, lora_path, alpha=0.5, device=device, verbose=verbose) + if not applied: print(f"⚠ LoRA '{model_name}' ignorée (incompatible UNet)") + else: + print("⚠ Aucun modèle LoRA configuré, étape ignorée.") + #iniy external_latent + external_latent = None + # ---------------- Motion module ---------------- + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + if motion_module and verbose: + print(f"[INFO] motion_module type: {type(motion_module)}") + + # ---------------- Tokenizer / Text encoder ---------------- + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path,"tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path,"text_encoder")).to(device).to(dtype) + + # ---------------- VAE ---------------- + vae_path = cfg.get("vae_path") + vae, vae_type, latent_channels, 
LATENT_SCALE = load_vae(vae_path, device=device, dtype=dtype) + + # ---------------- Embeddings ---------------- + embeddings = [] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + unet_cross_attention_dim = getattr(unet.config, "cross_attention_dim", 1024) + + # --- Projection adaptative + text_inputs_sample = tokenizer("test", padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + sample_embeds = text_encoder(text_inputs_sample.input_ids.to(device)).last_hidden_state + current_dim = sample_embeds.shape[-1] + projection = None + if current_dim != unet_cross_attention_dim: + projection = torch.nn.Linear(current_dim, unet_cross_attention_dim).to(device).to(dtype) + + # --- Pré-calcul des embeddings pour interpolation + pos_embeds_list = [] + neg_embeds_list = [] + + # Si prompts et n_prompts sont des listes de listes ou chaînes + for i, prompt_item in enumerate(prompts): + # Texte positif + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + # Texte négatif correspondant + neg_text_item = negative_prompts[i] if i < len(negative_prompts) else negative_prompts[0] + neg_text = " ".join(neg_text_item) if isinstance(neg_text_item, list) else str(neg_text_item) + + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + + if projection is not None: + pos_embeds = projection(pos_embeds) + neg_embeds = projection(neg_embeds) + + # Ajouter à la liste complète + pos_embeds_list.append(pos_embeds) + neg_embeds_list.append(neg_embeds) + + 
# ---------------- N3RModelOptimized ---------------- + n3r_model = None + if use_n3r_model: + n3r_model = N3RModelOptimized( + L_low=cfg.get("n3r_L_low",3), # 3 ou 4 # plutôt que 3, un peu plus de finesse + L_high=cfg.get("n3r_L_high",6), # garde structure globale + N_samples=cfg.get("n3r_N_samples",32), # plus de samples pour un rendu détaillé 48 + tile_size=cfg.get("n3r_tile_size",64), # inchangé pour VRAM raisonnable + cpu_offload=cfg.get("n3r_cpu_offload",True) + ).to(device) + n3r_model.eval() + print(f"✅ N3RModelOptimized initialisé sur {device}") + + # ------------------- Initialisation mémoire ------------------- + output_dir_m = Path("./outputs") + memory_file = output_dir_m / "n3r_memory" + memory_dict = load_memory(memory_file) + # ---------------- n3r_pro_net ---------------- + n3r_pro_net = None + if use_n3r_pro_net: + n3r_pro_net = N3RProNet(channels=4).to(device).to(dtype) + n3r_pro_net.eval() + print("✅ N3RProNet activé") + + + # ---------------- Input images ---------------- + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/ProBoostNet{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + overlap = compute_overlap(cfg["W"], cfg["H"], block_size) + + previous_latent_single = None + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + # ---------------- Frames principales avec interpolation prompts ---------------- + external_embeddings = None + + # Charger latent externe avant la génération + external_path = "/mnt/62G/huggingface/cyber-fp16/pt/KnxCOmiXNeg.safetensors" + external_latent = load_external_embedding_as_latent( + external_path, (1, 4, cfg["H"]//8, cfg["W"]//8) + ).to(device) + #------------------------------------------------------------------------------ + for img_idx, img_path 
in enumerate(input_paths): + if stop_generation: break + try: + # Paramètres interpolés + current_init_image_scale, current_creative_noise, current_guidance_scale = compute_weighted_params( frame_counter, total_frames, init_start=0.85, init_end=0.5,noise_start=0.0, noise_end=0.08, guidance_start=3.5, guidance_end=4.5, mode="cosine" ) + print(f"[Frame {frame_counter:03d}] " f"init_image_scale={current_init_image_scale:.3f}, " f"guidance_scale={current_guidance_scale:.3f}, " f"creative_noise={current_creative_noise:.3f}") + + # Charger et encoder l'image sur GPU + input_image = load_images_test([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + # 🔥 Détection yeux (une seule fois par image) + input_pil = tensor_to_pil(input_image) # à créer si tu ne l'as pas + eye_coords = get_eye_coords_safe(input_pil) + coords_v = get_coords_safe( input_pil, H=cfg["H"], W=cfg["W"] ) + input_image = ensure_4_channels(input_image) + if frame_counter > 0: + initframe = frame_counter+transition_frames + else: + initframe = frame_counter + save_input_frame( input_image, output_dir, initframe, pbar=pbar, blur_radius=blur_radius, contrast=contrast, saturation=1.0, apply_post=False ) + + current_latent_single = encode_images_to_latents_hybrid(input_image, vae, device=device, latent_scale=LATENT_SCALE) + current_latent_single = torch.nn.functional.interpolate( + current_latent_single, size=(cfg["H"]//8, cfg["W"]//8), + #current_latent_single, size=(cfg["H"]//6, cfg["W"]//6), + mode='bilinear', align_corners=False + ) + + # 🔥 FIX NaN / stabilité + current_latent_single = sanitize_latents(current_latent_single) + # Génération initiale robuste : + #42 Classique, beaucoup de tests communautaires utilisent ce seed. #1234 Fidèle, stable, souvent utilisé pour des tests de cohérence. 
+ #5555 Fidélité à l’image initiale (ton choix actuel) #2026 Léger changement dans la texture ou la posture, subtil mais prévisible + #9876 Variation un peu plus visible, garde la structure globale + pos_embeds, neg_embeds = get_interpolated_embeddings( frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device ) + try: + current_latent_single = generate_latents_robuste_4D( + latents=current_latent_single.to(device), + pos_embeds=pos_embeds, + neg_embeds=neg_embeds, + unet=unet, + scheduler=scheduler, + motion_module=None, + device=device, + dtype=dtype, + guidance_scale=current_guidance_scale, #guidance_scale: 1.5 # un peu plus strict pour que le chat ressorte + init_image_scale=current_init_image_scale, #init_image_scale: 0.85 # presque tout le signal de l'image d'origine + creative_noise=current_creative_noise, # creative_noise: 0.08 # moins de liberté, plus de cohérence + seed=seed # 42, 1234, 2026, 5555 + ) + + # 🔥 FIX NaN / stabilité + current_latent_single = sanitize_latents(current_latent_single) + except Exception as e: + print(f"[Robuste INIT ERROR] {e}") + + current_latent_single = ensure_4_channels(current_latent_single) + current_latent_single = current_latent_single.to('cpu') + del input_image + torch.cuda.empty_cache() + + # ---------------- Transition frames ---------------- + if previous_latent_single is not None and transition_frames > 0: + for t in range(transition_frames): + if stop_generation: break + alpha = 0.5 - 0.5*math.cos(math.pi*t/max(transition_frames-1,1)) + with torch.no_grad(): + # --- Fusion adaptative avec diminution progressive de l'influence de la frame précédente + injection_start = 0.8 # influence initiale de l'ancienne frame + injection_end = 0.1 # influence finale + denom = max(transition_frames-1, 1) + injection_alpha = injection_start * (1 - t/denom) + injection_end * (t/denom) + + latent_interp = injection_alpha * previous_latent_single.to(device) + (1 - injection_alpha) * current_latent_single.to(device) + 
# 🔥 FIX NaN / stabilité + latent_interp = sanitize_latents(latent_interp) + + if motion_module: + latent_interp, _ = apply_motion_safe(latent_interp, motion_module) + + # Application de n3r_pro_net - réutilisé pour toutes les frames - creation des masques + eye_coords_latent = scale_eye_coords_to_latents( eye_coords, img_H=cfg["H"], img_W=cfg["W"], lat_H=latent_interp.shape[-2], lat_W=latent_interp.shape[-1] ) + if eye_coords_latent: + eye_mask = create_eye_mask(latent_interp, eye_coords_latent) + volume_mask = create_volumetrique_mask(latent_interp, coords_v, debug=False) + # Application du ProNet tout en protégeant les yeux + if use_n3r_pro_net: + latents = apply_pro_net_volumetrique(latent_interp, coords_v, n3r_pro_net, n3r_pro_strength, sanitize_latents, debug=False) + eye_coords_latent = scale_eye_coords_to_latents( eye_coords, img_H=cfg["H"], img_W=cfg["W"], lat_H=latents.shape[-2], lat_W=latents.shape[-1] ) + if eye_coords_latent: + latents = apply_pro_net_with_eyes(latents, eye_coords_latent, n3r_pro_net, n3r_pro_strength, sanitize_fn=sanitize_latents) + + # Décodage streaming + latent_interp = latent_interp / LATENT_SCALE # “rescale” avant décodage + frame_pil = decode_latents_ultrasafe_blockwise( latent_interp, vae, block_size=block_size, overlap=overlap, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost ) + + #Post Traitement + frame_pil = full_frame_postprocess( frame_pil, output_dir, frame_counter, target_temp=target_temp, reference_temp=reference_temp, blur_radius=blur_radius, contrast=contrast, sharpen_percent=sharpen_percent, psave=psave ) + save_frame_verbose(frame_pil, output_dir, frame_counter-1, suffix="0i", psave=True) + frame_counter += 1 + pbar.update(1) + + del latent_interp + torch.cuda.empty_cache() + + # ---------------- Frames principales ---------------- + + for f in range(num_fraps_per_image): + if stop_generation: break + with torch.no_grad(): + latents_frame = current_latent_single.to(device) + + # --- 
Interpolation des embeddings prompts --- + cf_embeds = get_interpolated_embeddings( frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device ) + + # --- N3R ou mini GPU diffusion --- + n3r_latents = None + latents = latents_frame.clone() + + # 🔥 FIX NaN / stabilité + latents = sanitize_latents(latents) + + # ---------------- N3R avec mémoire latente conditionnée ---------------- + use_n3r_this_frame = use_n3r_model and (frame_counter % random.choice([4,5,6]) == 0) + # ------------------- Bloc N3R par frame ------------------- + if use_n3r_this_frame: + try: + H, W = latents.shape[-2], latents.shape[-1] + N_samples = n3r_model.N_samples + coords = generate_n3r_coords(H, W, N_samples, seed, frame_counter, device) + n3r_latents = process_n3r_latents(n3r_model, coords, H, W, H, W) + fused_latents = fuse_with_memory(n3r_latents, memory_dict, cf_embeds, frame_counter) + fused_latents = inject_external(fused_latents, external_latent, frame_counter, device) + latents = fuse_n3r_latents_adaptive_new(latents, fused_latents, frame_counter=frame_counter, total_frames=total_frames, latent_injection_start=0.90, latent_injection_end=0.55) + latents = sanitize_latents(latents) + except Exception as e: + print(f"[N3R ERROR] {e}") + + # Sauvegarde mémoire périodique + if frame_counter % 30 == 0: + save_memory(memory_dict, memory_file) + + elif use_mini_gpu: + latents = generate_latents_mini_gpu_320( + unet=unet, scheduler=scheduler, + input_latents=latents_frame, embeddings=cf_embeds, + motion_module=motion_module, guidance_scale=current_guidance_scale, + device=device, fp16=True, steps=steps, + debug=verbose, init_image_scale=current_init_image_scale, + creative_noise=current_creative_noise + ) + if latent_injection > 0: + if latents.shape[-2:] != latents_frame.shape[-2:]: + latents = torch.nn.functional.interpolate( latents, size=latents_frame.shape[-2:], mode='bilinear', align_corners=False ).contiguous() + latents = latent_injection*latents_frame + 
(1-latent_injection)*latents + + # --- Motion module propre et safe --- + # Motion safe + if motion_module is not None: + latents_seq = latents.unsqueeze(2).repeat(1,1,3,1,1) if previous_latent_single is None else torch.stack([previous_latent_single.to(device), latents, latents+0.01*torch.randn_like(latents)], dim=2) + latents_seq = sanitize_latents(latents_seq) + latents_seq, applied = apply_motion_safe(latents_seq, motion_module) + latents = latents_seq[:, :, 1, :, :] if applied else latents + latents = sanitize_latents(latents) + + # ProNet avec yeux + if use_n3r_pro_net: + latents = apply_pro_net_volumetrique(latents, coords_v, n3r_pro_net, n3r_pro_strength, sanitize_latents, debug=False) + eye_coords_latent = scale_eye_coords_to_latents( eye_coords, img_H=cfg["H"], img_W=cfg["W"], lat_H=latents.shape[-2], lat_W=latents.shape[-1] ) + if eye_coords_latent: + latents = apply_pro_net_with_eyes(latents, eye_coords_latent, n3r_pro_net, n3r_pro_strength, sanitize_fn=sanitize_latents) + + # Décodage final + latents = latents / LATENT_SCALE + frame_pil = decode_latents_ultrasafe_blockwise(latents, vae, block_size=block_size, overlap=overlap, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost) + frame_pil = full_frame_postprocess(frame_pil, output_dir, frame_counter, target_temp=target_temp, reference_temp=reference_temp, blur_radius=blur_radius, contrast=contrast, sharpen_percent=sharpen_percent, psave=psave) + save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="0f", psave=True) + + previous_latent_single = latents.detach().cpu() + frame_counter += 1 + pbar.update(1) + + # Nettoyage VRAM + del latents, latents_frame, cf_embeds, n3r_latents + torch.cuda.empty_cache() + + previous_latent_single = current_latent_single + + except Exception as e: + print(f"\n[FRAME ERROR] {img_path}") + print(f"Type d'erreur : {type(e).__name__}") + print(f"Message d'erreur : {e}") + print("Traceback complet :") + traceback.print_exc() + continue 
+
+    pbar.close()
+    save_frames_as_video_from_folder(output_dir, out_video, fps=fps, upscale_factor=upscale_factor)
+    print(f"🎬 Vidéo générée : {out_video}")
+
+# ---------------- ENTRY ----------------
+if __name__=="__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--pretrained-model-path", type=str, required=True)
+    parser.add_argument("--config", type=str, required=True)
+    parser.add_argument("--device", type=str, default="cuda")
+    # NOTE(review): default=True makes --fp16 always truthy, so the flag can
+    # never be switched off from the command line — confirm intent.
+    parser.add_argument("--fp16", action="store_true", default=True)
+    parser.add_argument("--vae-offload", action="store_true")
+    args = parser.parse_args()
+    main(args)
diff --git a/scripts/n3rProNet_utils.py b/scripts/n3rProNet_utils.py
new file mode 100644
index 00000000..3fc65d93
--- /dev/null
+++ b/scripts/n3rProNet_utils.py
@@ -0,0 +1,2766 @@
+# n3rProNet_utils.py
+#-------------------------------------------------------------------------------
+from .tools_utils import ensure_4_channels, sanitize_latents, log_debug
+import torch
+import math
+import numpy as np
+from PIL import Image, ImageFilter
+import torch.nn.functional as F
+from pathlib import Path
+
+from torchvision.transforms.functional import to_pil_image
+
+
+def scale_eye_coords_to_latents(eye_coords, img_H, img_W, lat_H, lat_W):
+    """
+    Map eye coordinates from image pixel space into latent space.
+
+    Args:
+        eye_coords: list of (x, y) pixel coordinates, or None.
+        img_H, img_W: source image height/width in pixels.
+        lat_H, lat_W: target latent height/width.
+
+    Returns:
+        List of scaled (x, y) integer tuples, or None when no coords given.
+    """
+
+    # Guard: None or an empty list means "no eyes detected" upstream.
+    if not eye_coords:
+        return None
+
+    scale_x = lat_W / img_W
+    scale_y = lat_H / img_H
+
+    return [(int(x * scale_x), int(y * scale_y)) for x, y in eye_coords]
+
+
+def get_eye_coords_safe(image_pil, H=None, W=None):
+    """
+    Best-effort wrapper around get_eye_coords(): returns None instead of raising.
+
+    NOTE(review): H and W are accepted but unused — kept for caller compatibility.
+    """
+    try:
+        coords = get_eye_coords(image_pil)
+        if coords is None:
+            print("⚠️ Aucun visage détecté")
+            return None
+        print(f"👁 Eyes detected: {coords}")
+        return coords
+    except Exception as e:
+        # Broad catch is deliberate here: eye detection is optional polish and
+        # must never abort the frame-generation loop.
+        print(f"[Eye detection ERROR] {e}")
+        return None
+
+
+def get_eye_coords(image_pil):
+    """
+    Détecte les coordonnées des yeux avec MediaPipe.
+
+    Args:
+        image_pil (PIL.Image): image d'entrée
+
+    Returns:
+        list[(x, y)]: centres des yeux en coordonnées image
+    """
+    # Local imports keep mediapipe an optional dependency of this module.
+    import numpy as np
+    import mediapipe as mp
+
+    mp_face_mesh = mp.solutions.face_mesh
+
+    image = np.array(image_pil.convert("RGB"))
+    h, w, _ = image.shape
+
+    with mp_face_mesh.FaceMesh(
+        static_image_mode=True,
+        max_num_faces=1,
+        refine_landmarks=True
+    ) as face_mesh:
+
+        results = face_mesh.process(image)
+
+        if not results.multi_face_landmarks:
+            return None
+
+        face_landmarks = results.multi_face_landmarks[0]
+
+        # MediaPipe iris landmark indices (only present when
+        # refine_landmarks=True is set above).
+        LEFT_IRIS = [474, 475, 476, 477]
+        RIGHT_IRIS = [469, 470, 471, 472]
+
+        def get_center(indices):
+            # Centroid of the iris landmarks, converted from normalized
+            # landmark coordinates to pixel coordinates.
+            xs, ys = [], []
+            for idx in indices:
+                lm = face_landmarks.landmark[idx]
+                xs.append(lm.x * w)
+                ys.append(lm.y * h)
+            return int(sum(xs) / len(xs)), int(sum(ys) / len(ys))
+
+        left_eye = get_center(LEFT_IRIS)
+        right_eye = get_center(RIGHT_IRIS)
+
+        return [left_eye, right_eye]
+
+# NOTE(review): this function is redefined later in this module with different
+# defaults (strength=0.2, blur_kernel=7); at import time the later definition
+# shadows this one, so this version is effectively dead code.
+def apply_glow_froid_iris(latents, eye_coords, iris_radius_ratio=0.08, strength=0.25, blur_kernel=5):
+    """
+    Applique un glow froid ciblé sur l'iris des yeux dans les latents [B,C,H,W].
+
+    Args:
+        latents (torch.Tensor): Latents [B,C,H,W].
+        eye_coords (list of tuples): Coordonnées yeux [(x1,y1),(x2,y2)].
+        iris_radius_ratio (float): Ratio de rayon de l'iris par rapport à la plus petite dimension H/W.
+        strength (float): Intensité du glow (0.0 à 1.0).
+        blur_kernel (int): Taille du noyau pour un léger flou gaussien.
+
+    Returns:
+        torch.Tensor: Latents avec glow appliqué sur les iris.
+ """ + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + # 1️⃣ Créer un masque radial pour chaque œil + mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + min_dim = min(H, W) + iris_radius = iris_radius_ratio * min_dim + + yy, xx = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij') + for x_eye, y_eye in eye_coords: + dist = torch.sqrt((xx - x_eye)**2 + (yy - y_eye)**2) + eye_mask = torch.exp(-(dist**2) / (2 * iris_radius**2)) + mask += eye_mask.unsqueeze(0) # broadcast batch dimension + + # Clamp à 1 pour éviter dépassement si 2 yeux se chevauchent + mask = mask.clamp(0.0, 1.0) + + # 2️⃣ Appliquer léger blur pour adoucir les bords + if blur_kernel > 1: + kernel = torch.ones((C, 1, blur_kernel, blur_kernel), device=device, dtype=dtype) + kernel = kernel / kernel.sum() + mask = F.conv2d(mask.repeat(1, C, 1, 1), kernel, padding=blur_kernel//2, groups=C) + + # 3️⃣ Créer glow gaussien via convolution légère + sigma = blur_kernel / 3.0 + glow_kernel = torch.exp(-((torch.arange(-blur_kernel//2+1, blur_kernel//2+2, device=device).view(-1,1))**2)/ (2*sigma**2)) + glow_kernel = glow_kernel / glow_kernel.sum() + glow_kernel = glow_kernel.view(1,1,blur_kernel,1).repeat(C,1,1,1) + glow = F.conv2d(latents, glow_kernel, padding=(blur_kernel//2,0), groups=C) + glow = F.conv2d(glow, glow_kernel.transpose(2,3), padding=(0,blur_kernel//2), groups=C) # convolution 2D approximative + + # 4️⃣ Fusion glow sur iris seulement + latents_out = latents * (1 - mask) + glow * mask * strength + latents_out = latents_out.clamp(-1.0, 1.0) + + return latents_out + +def apply_intelligent_glow_froid_latents(latents, strength=0.2, blur_kernel=7): + """ + Applique un effet "glow froid" directement sur des latents [B, C, H, W]. + + Args: + latents (torch.Tensor): Latents [B,C,H,W]. + strength (float): Intensité du glow (0.0 à 1.0). + blur_kernel (int): Taille du noyau pour le flou gaussien (doit être impair). 
+
+    Returns:
+        torch.Tensor: Latents avec glow appliqué.
+    """
+    if latents.ndim != 4:
+        raise ValueError("Latents doivent être de shape [B, C, H, W]")
+
+    B, C, H, W = latents.shape
+
+    # Build a normalized 2-D Gaussian kernel, replicated per channel for a
+    # depthwise (groups=C) convolution.
+    def gaussian_kernel(kernel_size, sigma, channels):
+        ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1., device=latents.device)
+        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
+        kernel = torch.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
+        kernel = kernel / kernel.sum()
+        kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
+        return kernel
+
+    sigma = blur_kernel / 3.0
+    # Cast to the latents device/dtype so fp16 inputs convolve cleanly.
+    kernel = gaussian_kernel(blur_kernel, sigma, C).to(latents.device, latents.dtype)
+
+    padding = blur_kernel // 2
+    # Depthwise convolution produces the blurred "glow" layer.
+    glow = F.conv2d(latents, kernel, padding=padding, groups=C)
+
+    # Linear blend of the original latents with their blurred version.
+    latents_out = latents * (1 - strength) + glow * strength
+
+    # Clamp for numerical stability downstream.
+    latents_out = latents_out.clamp(-1.0, 1.0)
+
+    return latents_out
+
+
+# Application of the glow effect on the eye irises.
+# NOTE(review): duplicate definition — this shadows the apply_glow_froid_iris
+# defined earlier in this module (which had strength=0.25, blur_kernel=5);
+# only this later definition is live at import time.
+def apply_glow_froid_iris(latents, eye_coords, iris_radius_ratio=0.08, strength=0.2, blur_kernel=7):
+    """
+    Applique un glow froid uniquement sur l'iris des yeux dans les latents [B,C,H,W].
+
+    Args:
+        latents (torch.Tensor): Latents SD [B,C,H,W]
+        eye_coords (list of tuples): Coordonnées des yeux [(x1,y1),(x2,y2)]
+        iris_radius_ratio (float): proportion de H/W pour rayon iris
+        strength (float): intensité du glow
+        blur_kernel (int): taille du kernel gaussien (impair)
+
+    Returns:
+        torch.Tensor: latents avec glow sur iris
+    """
+    B, C, H, W = latents.shape
+    device, dtype = latents.device, latents.dtype
+
+    # 1. Hard elliptical mask over each iris.
+    # NOTE(review): `i` from enumerate is unused, and the mask is written to
+    # batch element 0 only, so this assumes B == 1 — confirm against callers.
+    iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype)
+    for i, (x, y) in enumerate(eye_coords):
+        rx = int(W * iris_radius_ratio)
+        ry = int(H * iris_radius_ratio)
+        # coordinate grid
+        Y, X = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij')
+        dist2 = ((X - x)**2) / (rx**2) + ((Y - y)**2) / (ry**2)
+        iris_mask[0, 0] += (dist2 <= 1).float()
+    iris_mask = iris_mask.clamp(0, 1)  # avoid > 1 where the two eye masks overlap
+
+    # 2. Normalized 2-D Gaussian kernel.
+    # NOTE(review): built in default float32 — with fp16 latents the F.conv2d
+    # below would hit a dtype mismatch; the "_good" variant casts via dtype.
+    sigma = blur_kernel / 3
+    ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device)
+    xx, yy = torch.meshgrid(ax, ax, indexing='ij')
+    kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
+    kernel_2d = kernel_2d / kernel_2d.sum()
+    kernel = kernel_2d.view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1)  # [C,1,kH,kW]
+
+    # 3. Depthwise convolution of the masked latents.
+    glow = F.conv2d(latents * iris_mask, kernel, padding=blur_kernel // 2, groups=C)
+
+    # 4. Blend the glow back in on the iris region only.
+    latents_out = latents * (1 - iris_mask) + glow * iris_mask * strength
+    latents_out = latents_out.clamp(-1.0, 1.0)
+
+    return latents_out
+
+
+# NOTE(review): redundant mid-file re-imports — torch and F are already
+# imported at the top of this module; matplotlib is only used by debug paths.
+import torch
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+#----------- Rendu HD ------------------------------
+def apply_pro_net_volumetrique(
+    latents,
+    coords_v,
+    n3r_pro_net,
+    n3r_pro_strength,
+    sanitize_fn,
+    glow_strength=0.2,
+    blur_kernel=3,          # smaller kernel = sharper details
+    iris_radius_ratio=0.08,
+    mask_blur_kernel=1,     # very light blur applied to the mask only
+    debug=False
+):
+    """
+    ProNet volumétrique HD + glow iris avec contours adoucis mais plus net
+    """
+
+    # NOTE(review): redundant local imports — both names already exist at
+    # module scope.
+    import torch
+    import torch.nn.functional as F
+
+    # No coordinates -> plain ProNet pass, no iris treatment.
+    if not coords_v:
+        return apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn)
+
+    B, C, H, W = latents.shape
+    device, dtype = latents.device, latents.dtype
+
+    # 1. ProNet pass, cast back to the input dtype (fp16 safety).
+    latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn).to(dtype)
+
+    # 2. Hard elliptical iris mask. NOTE(review): written to batch element 0
+    # only — assumes B == 1; confirm against callers.
+    iris_mask = torch.zeros((B,1,H,W), device=device, dtype=dtype)
+    Y, X = torch.meshgrid(
+        torch.arange(H, device=device),
+        torch.arange(W, device=device),
+        indexing='ij'
+    )
+
+    for x, y in coords_v:
+        # max(1, ...) and the 1e-6 epsilons guard against a zero radius /
+        # division by zero at very small latent resolutions.
+        rx = max(1, int(W * iris_radius_ratio))
+        ry = max(1, int(H * iris_radius_ratio))
+        dist2 = ((X - x)**2)/(rx**2 + 1e-6) + ((Y - y)**2)/(ry**2 + 1e-6)
+        iris_mask[0,0] += (dist2 <= 1).float()
+    iris_mask = iris_mask.clamp(0,1)
+
+    # 3. Optional light Gaussian blur of the mask only (softens edges).
+    if mask_blur_kernel > 1:
+        sigma = mask_blur_kernel / 3
+        ax = torch.arange(-mask_blur_kernel//2 + 1., mask_blur_kernel//2 + 1., device=device, dtype=dtype)
+        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
+        mask_kernel = torch.exp(-(xx**2 + yy**2)/(2*sigma**2))
+        mask_kernel = mask_kernel / mask_kernel.sum()
+        mask_kernel = mask_kernel.view(1,1,mask_blur_kernel,mask_blur_kernel)
+        iris_mask = F.conv2d(iris_mask, mask_kernel, padding=mask_blur_kernel//2)
+        iris_mask = iris_mask.clamp(0,1)
+
+    # 4. High-frequency detail = original minus its Gaussian blur
+    # (unsharp-mask style).
+    if blur_kernel > 1:
+        sigma = blur_kernel / 3
+        ax = torch.arange(-blur_kernel//2 + 1., blur_kernel//2 + 1., device=device, dtype=dtype)
+        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
+        kernel_2d = torch.exp(-(xx**2 + yy**2)/(2*sigma**2))
+        kernel_2d = kernel_2d / kernel_2d.sum()
+        kernel = kernel_2d.view(1,1,blur_kernel,blur_kernel).repeat(C,1,1,1).to(dtype)
+        blurred = F.conv2d(latents_prot, kernel, padding=blur_kernel//2, groups=C)
+        high_freq = latents_prot - blurred
+    else:
+        high_freq = latents_prot - latents_prot  # no blur -> zero high-frequency term
+
+    # 5. Add the sharpened detail only inside the iris mask.
+    latents_out = latents_prot + glow_strength * high_freq * iris_mask
+    latents_out = latents_out.clamp(-1.0,1.0)
+
+    # 6. Debug visualization (requires matplotlib; pulls tensors to CPU).
+    if debug:
+        import matplotlib.pyplot as plt
+        plt.figure(figsize=(12,4))
+        plt.subplot(1,3,1); plt.imshow(latents_prot[0,0].detach().cpu(), cmap='gray'); plt.title("ProNet")
+        plt.subplot(1,3,2); plt.imshow(high_freq[0,0].detach().cpu(), cmap='gray'); plt.title("High-Freq")
+        plt.subplot(1,3,3); plt.imshow(iris_mask[0,0].detach().cpu(), cmap='Reds', alpha=0.5); plt.title("Mask Iris")
+        plt.tight_layout(); plt.show()
+        print("👁 DEBUG HD sharp appliqué")
+
+    return latents_out
+
+def apply_pro_net_volumetrique_good(
+    latents,
+    coords_v,
+    n3r_pro_net,
+    n3r_pro_strength,
+    sanitize_fn,
+    glow_strength=0.2,
+    blur_kernel=7,
+    iris_radius_ratio=0.08,
+    debug=False
+):
+    """
+    Applique ProNet et un effet "HDR / détail" sur les iris des yeux,
+    compatible FP16 / latents interpolés.
+
+    Args:
+        latents (torch.Tensor): [B,C,H,W] Latents à traiter.
+        coords_v (list of tuples): Coordonnées yeux [(x1,y1),(x2,y2)].
+ n3r_pro_net: modèle ProNet + n3r_pro_strength (float): force ProNet + sanitize_fn: fonction de nettoyage latents + glow_strength (float): intensité du glow / amplification + blur_kernel (int): taille du kernel pour flou + iris_radius_ratio (float): proportion de H/W pour rayon iris + debug (bool): visualisation mask + latents + + Returns: + torch.Tensor: latents avec effet HDR sur iris uniquement + """ + if not coords_v: + # Aucun yeux détectés → ProNet seul + return apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn) + + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + # 1️⃣ Appliquer ProNet + latents_prot = apply_n3r_pro_net( + latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn + ) + + # 2️⃣ Créer masque iris + iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + Y, X = torch.meshgrid( + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + + for x, y in coords_v: + rx = int(W * iris_radius_ratio) + ry = int(H * iris_radius_ratio) + dist2 = ((X - x)**2)/(rx**2) + ((Y - y)**2)/(ry**2) + iris_mask[0, 0] += (dist2 <= 1).float() + iris_mask = iris_mask.clamp(0, 1) + + # 3️⃣ Kernel gaussien, même dtype que latents (FP16 ok) + sigma = blur_kernel / 3 + ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device, dtype=dtype) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + kernel_2d = kernel_2d / kernel_2d.sum() + kernel = kernel_2d.view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1) + + # 4️⃣ Convolution channel-wise → amplification détails iris + glow = F.conv2d(latents_prot * iris_mask, kernel, padding=blur_kernel // 2, groups=C) + + # 5️⃣ Fusion ProNet + iris glow + latents_out = latents_prot * (1 - iris_mask) + glow * iris_mask * glow_strength + latents_out = latents_out.clamp(-1.0, 1.0) + + # ---------------- DEBUG 
---------------- + if debug: + lat_vis = latents[0, 0].detach().cpu() + prot_vis = latents_prot[0, 0].detach().cpu() + glow_vis = glow[0, 0].detach().cpu() + mask_vis = iris_mask[0, 0].detach().cpu() + + plt.figure(figsize=(12, 4)) + plt.subplot(1, 4, 1) + plt.imshow(lat_vis, cmap='gray') + plt.title("Latent original") + plt.subplot(1, 4, 2) + plt.imshow(prot_vis, cmap='gray') + plt.title("ProNet") + plt.subplot(1, 4, 3) + plt.imshow(glow_vis, cmap='gray') + plt.title("HDR / Glow Iris") + plt.subplot(1, 4, 4) + plt.imshow(lat_vis, cmap='gray', alpha=0.7) + plt.imshow(mask_vis, cmap='Reds', alpha=0.4) + plt.title("Mask Iris") + plt.tight_layout() + plt.show() + print("👁 DEBUG activé → vérifie position / taille iris") + + return latents_out + +#----- Amplification des détails des yeux + +def apply_pro_net_with_eyes( + latents, + eye_coords, + n3r_pro_net, + n3r_pro_strength, + sanitize_fn, + detail_strength=0.35, # intensité HDR + blur_kernel=5, # kernel pour détails + iris_radius_ratio=0.06, # plus petit = cible mieux iris + mask_blur_kernel=3 # flou du masque pour adoucir les contours +): + """ + ProNet optimisé + amplification HDR des détails sur l’iris (pas glow) + avec fusion douce pour éviter halo sur les contours. 
+ """ + + import torch + import torch.nn.functional as F + + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + # 1️⃣ Appliquer ProNet standard + latents_prot = apply_n3r_pro_net( + latents, + model=n3r_pro_net, + strength=n3r_pro_strength, + sanitize_fn=sanitize_fn + ).to(dtype) + + # 2️⃣ Si pas d’yeux → fallback + if not eye_coords: + return latents_prot + + # 3️⃣ Création masque IRIS + iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + Y, X = torch.meshgrid( + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + + for x, y in eye_coords: + rx = int(W * iris_radius_ratio) + ry = int(H * iris_radius_ratio) + dist = ((X - x)**2) / (rx**2 + 1e-6) + ((Y - y)**2) / (ry**2 + 1e-6) + iris_mask[0, 0] += (dist <= 1).float() + + iris_mask = iris_mask.clamp(0, 1) + + # 4️⃣ Flouter le masque pour adoucir les contours + if mask_blur_kernel > 1: + sigma = mask_blur_kernel / 3 + ax = torch.arange(-mask_blur_kernel // 2 + 1., mask_blur_kernel // 2 + 1., device=device, dtype=dtype) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + mask_kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + mask_kernel_2d = mask_kernel_2d / mask_kernel_2d.sum() + mask_kernel = mask_kernel_2d.view(1, 1, mask_blur_kernel, mask_blur_kernel) + iris_mask = F.conv2d(iris_mask, mask_kernel, padding=mask_blur_kernel // 2) + iris_mask = iris_mask.clamp(0, 1) + + # 5️⃣ Blur pour récupérer les détails (high-frequency) + sigma = blur_kernel / 3 + ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device, dtype=dtype) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + kernel_2d = kernel_2d / kernel_2d.sum() + kernel = kernel_2d.view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1).to(dtype) + blurred = F.conv2d(latents_prot, kernel, padding=blur_kernel // 2, groups=C) + details = latents_prot - blurred + + # 6️⃣ Amplification HDR adaptative selon 
le masque flou + detail_strength_map = detail_strength * iris_mask + enhanced = latents_prot + details * detail_strength_map + + # 7️⃣ Fusion douce + latents_out = latents_prot * (1 - iris_mask) + enhanced * iris_mask + + # 8️⃣ Clamp final pour sécurité + latents_out = torch.clamp(latents_out, -1.0, 1.0) + + print("👁 HDR détails appliqué sur iris avec contours adoucis") + + return latents_out + +#------------ Stable version mais un peu fort ---- +def apply_pro_net_with_eyes_boost( + latents, + eye_coords, + n3r_pro_net, + n3r_pro_strength, + sanitize_fn, + detail_strength=0.35, # intensité HDR + blur_kernel=5, # plus petit = plus précis + iris_radius_ratio=0.06 # plus petit = cible mieux iris +): + """ + ProNet + amplification HDR des détails sur l’iris (pas glow). + """ + + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + # 1️⃣ ProNet + latents_prot = apply_n3r_pro_net( + latents, + model=n3r_pro_net, + strength=n3r_pro_strength, + sanitize_fn=sanitize_fn + ) + + # 🔒 sécurité dtype (évite ton erreur Half/Float) + latents_prot = latents_prot.to(dtype) + + # 2️⃣ Si pas d’yeux → fallback + if not eye_coords: + return latents_prot + + # 3️⃣ Création masque IRIS (ellipse fine) + iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + + Y, X = torch.meshgrid( + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + + for x, y in eye_coords: + rx = int(W * iris_radius_ratio) + ry = int(H * iris_radius_ratio) + + dist = ((X - x)**2) / (rx**2 + 1e-6) + ((Y - y)**2) / (ry**2 + 1e-6) + iris_mask[0, 0] += (dist <= 1).float() + + iris_mask = iris_mask.clamp(0, 1) + + # 4️⃣ Kernel GAUSSIEN (corrigé) + sigma = blur_kernel / 3 + + ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device, dtype=dtype) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + + kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + kernel_2d = kernel_2d / kernel_2d.sum() + + kernel = kernel_2d.view(1, 1, 
blur_kernel, blur_kernel).repeat(C, 1, 1, 1) + + # 🔒 même dtype que latents + kernel = kernel.to(dtype) + + # 5️⃣ Blur = base low-frequency + blurred = F.conv2d( + latents_prot, + kernel, + padding=blur_kernel // 2, + groups=C + ) + + # 6️⃣ Détails (high-frequency) + details = latents_prot - blurred + + # 7️⃣ Amplification HDR + enhanced = latents_prot + detail_strength * details + + # 8️⃣ Fusion UNIQUEMENT sur iris + latents_out = latents_prot * (1 - iris_mask) + enhanced * iris_mask + + # 9️⃣ Clamp sécurité + latents_out = latents_out.clamp(-1.0, 1.0) + + print("👁 HDR détails appliqué sur iris") + + return latents_out + +def apply_pro_net_with_eyes_test(latents, eye_coords, n3r_pro_net, n3r_pro_strength, sanitize_fn, + glow_strength=0.2, blur_kernel=7, iris_radius_ratio=0.08): + """ + Applique ProNet et un glow froid uniquement sur l’iris des yeux. + + Args: + latents (torch.Tensor): [B,C,H,W] Latents à traiter. + eye_coords (list of tuples): Coordonnées yeux [(x1,y1),(x2,y2)] + n3r_pro_net: modèle ProNet + n3r_pro_strength (float): force ProNet + sanitize_fn: fonction de nettoyage latents + glow_strength (float): intensité du glow + blur_kernel (int): kernel pour flou gaussien + iris_radius_ratio (float): proportion de H/W pour rayon iris + + Returns: + torch.Tensor: latents avec glow sur iris uniquement + """ + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + # 1️⃣ Application ProNet + latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn) + + # 2️⃣ Glow froid uniquement sur l’iris + if eye_coords: + iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + for x, y in eye_coords: + rx = int(W * iris_radius_ratio) + ry = int(H * iris_radius_ratio) + Y, X = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij') + dist2 = ((X - x)**2) / (rx**2) + ((Y - y)**2) / (ry**2) + iris_mask[0, 0] += (dist2 <= 1).float() + iris_mask = 
iris_mask.clamp(0, 1) + + # Kernel gaussien 2D + sigma = blur_kernel / 3 + ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + kernel_2d = kernel_2d / kernel_2d.sum() + kernel = kernel_2d.view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1) + + # Convolution channel-wise pour glow + glow = F.conv2d(latents_prot * iris_mask, kernel, padding=blur_kernel // 2, groups=C) + + # Fusion uniquement sur l’iris + latents_out = latents_prot * (1 - iris_mask) + glow * iris_mask * glow_strength + latents_out = latents_out.clamp(-1.0, 1.0) + print("👁 Glow froid appliqué sur iris uniquement") + else: + # fallback si pas d’yeux détectés + latents_out = latents_prot + + return latents_out + + +def apply_pro_net_with_eye_glow(latents, eye_coords, n3r_pro_net, n3r_pro_strength, sanitize_fn, glow_strength=0.2, blur_kernel=7): + """ + Applique ProNet et un glow froid uniquement sur les yeux. + + Args: + latents (torch.Tensor): [B,C,H,W] Latents à traiter. 
+ eye_coords (list of tuples): Coordonnées yeux [(x1,y1),(x2,y2)] + n3r_pro_net: modèle ProNet + n3r_pro_strength (float): force ProNet + sanitize_fn: fonction de nettoyage latents + glow_strength (float): intensité du glow + blur_kernel (int): kernel pour le flou + + Returns: + torch.Tensor: latents avec glow sur yeux uniquement + """ + # 1️⃣ Appliquer ProNet + latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn) + + # 2️⃣ Glow froid sur latents ProNet + glow_latents = apply_intelligent_glow_froid_latents(latents_prot, strength=glow_strength, blur_kernel=blur_kernel) + + + # 3️⃣ Fusion glow uniquement sur les yeux + if eye_coords: + eye_radius = int(min(latents.shape[-2:]) * 0.15) + eye_mask = create_eye_mask(latents, eye_coords, eye_radius) + if eye_mask is not None: + eye_mask = eye_mask.to(latents.device).float() + if eye_mask.ndim == 3: # [B,H,W] -> [B,1,H,W] + eye_mask = eye_mask.unsqueeze(1) + latents = latents * (1 - eye_mask) + glow_latents * eye_mask + print("👁 Glow froid appliqué uniquement sur yeux") + else: + latents = glow_latents # fallback + else: + latents = glow_latents # pas d’yeux détectés → glow global + + return latents + +# Application effect en dehors de yeux: +def apply_pro_net_with_out_eyes(latents, eye_coords, n3r_pro_net, n3r_pro_strength, sanitize_fn): + # 1️⃣ Application du ProNet + latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn) + + # 2️⃣ Application du glow froid intelligent en dehors des yeux sur le ProNet + latents_prot = apply_intelligent_glow_froid_out(latents_prot) + + # 3️⃣ Fusion avec le masque yeux si détecté + if eye_coords: + print("Eye coords:", eye_coords) + eye_radius = int(min(latents.shape[-2:]) * 0.15) # augmenter légèrement pour protection + eye_mask = create_eye_mask(latents, eye_coords, eye_radius) + + if eye_mask is not None: + eye_mask = eye_mask.to(latents.device) + # protection yeux + 
ProNet + glow
            latents = latents * eye_mask + latents_prot * (1 - eye_mask)
            print("👁 protection yeux appliquée avec glow froid")
        else:
            # If the mask could not be built, apply ProNet + glow everywhere.
            latents = latents_prot
    else:
        # No eyes detected → ProNet + glow applied globally.
        latents = latents_prot

    return latents


def apply_pro_net_with_eye_simple(latents, eye_coords, n3r_pro_net, n3r_pro_strength, sanitize_fn):
    """
    Apply ProNet everywhere except inside the eye mask: pixels covered by the
    mask keep the ORIGINAL latents (eye protection), the rest takes the
    ProNet-processed latents. Falls back to the plain ProNet result when no
    eyes are detected or the mask cannot be built.
    """
    latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn)
    if eye_coords:
        print("Eye coords:", eye_coords)
        eye_radius = int(min(latents.shape[-2:]) * 0.15) # slightly wider; 0.12 or 0.15 both work
        eye_mask = create_eye_mask(latents, eye_coords, eye_radius)
        if eye_mask is not None:
            eye_mask = eye_mask.to(latents.device)
            # mask==1 → keep original latents (protect eyes), mask==0 → ProNet.
            latents = latents * eye_mask + latents_prot * (1 - eye_mask)
            print("👁 protection yeux appliquée (main frames)")
        else:
            latents = latents_prot
    else:
        latents = latents_prot
    return latents

def tensor_to_pil(tensor):
    """
    Convert a [-1,1] image tensor ([1,3,H,W] or [3,H,W]) to a PIL image.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]
    tensor = (tensor.clamp(-1,1) + 1) / 2
    return to_pil_image(tensor.cpu())

# Optional dependency: without mediapipe, eye detection degrades to the
# fixed fallback coordinates in get_coords_safe().
try:
    import mediapipe as mp
    from mediapipe.python.solutions import face_mesh as mp_face_mesh
    MP_FACE_MESH = mp_face_mesh
except Exception:
    MP_FACE_MESH = None
    print("⚠ mediapipe non disponible → fallback sans yeux")


def get_coords_safe(image, H, W):
    """
    Detect eye coordinates, falling back to fixed guesses when detection fails.

    Returns (row, col) pairs, i.e. (y, x): the fallback scales the first
    component by H and the second by W.
    NOTE(review): several consumers in this file unpack these pairs as
    (x, y) — confirm the expected order at each call site.
    """
    coords = get_coords(image)

    if coords:
        print(f"👁 Eyes detected: {coords}")
        return coords

    print("⚠ fallback eye coords used")

    # 🔥 tuned for vertical portrait frames (e.g. 536x960)
    return [
        (int(H * 0.32), int(W * 0.38)),
        (int(H * 0.32), int(W * 0.62))
    ]

# --------------------------------------------------
# 🔥 Eye detection (clean version, no cv2)
# --------------------------------------------------
def get_coords(image):
    """
    Retourne [(y_left, x_left), (y_right, x_right)]
Compatible PIL ou numpy + """ + if MP_FACE_MESH is None: + return [] + + # Conversion propre + if isinstance(image, Image.Image): + img = np.array(image) + else: + img = image + + if img is None or img.ndim != 3: + return [] + + h, w, _ = img.shape + + with MP_FACE_MESH.FaceMesh(static_image_mode=True, max_num_faces=1) as face_mesh: + results = face_mesh.process(img) # ✅ déjà RGB → pas besoin de cv2 + + if not results.multi_face_landmarks: + return [] + + lm = results.multi_face_landmarks[0].landmark + + # 🔥 Points clés yeux (stables) + left_eye_pts = [33, 133] + right_eye_pts = [362, 263] + + left_eye = np.mean([(lm[i].y * h, lm[i].x * w) for i in left_eye_pts], axis=0) + right_eye = np.mean([(lm[i].y * h, lm[i].x * w) for i in right_eye_pts], axis=0) + + return [ + (int(left_eye[0]), int(left_eye[1])), + (int(right_eye[0]), int(right_eye[1])) + ] + +# -------------------------------------------------- +# 🔥 Création mask yeux (latents) +# -------------------------------------------------- +import torch +import matplotlib.pyplot as plt + + +def create_volumetrique_mask(latents, coords, radius_ratio=0.15, only=False, in_radius_ratio=0.08, debug=False): + """ + Crée un masque pour les yeux ou uniquement pour l’iris. 
+ + Args: + latents (torch.Tensor): [B,C,H,W] Latents + coords (list of tuples): [(x1,y1),(x2,y2)] coordonnées yeux + radius_ratio (float): proportion H/W pour rayon + only (bool): True → masque uniquement iris, False → masque œil entier + in_radius_ratio (float): proportion H/W pour rayon iris si only=True + debug (bool): Si True, affiche le masque + + Returns: + torch.Tensor: [B,1,H,W] masque float (0=hors masque, 1=masque) + """ + #if not coords or latents.ndim != 4: + if coords is None or len(coords) == 0 or latents.ndim != 4: + return None + + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + + for x, y in coords: + r = int(min(H, W) * (radius_ratio if only else in_radius_ratio)) + Y, X = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij') + dist2 = (X - x)**2 + (Y - y)**2 + mask[0, 0] += (dist2 <= r**2).float() + + mask = mask.clamp(0, 1) + + if debug: + # Affiche le masque superposé à un latents converti en image pour vérification + lat_vis = latents[0, 0].detach().cpu() # canal 0 + plt.figure(figsize=(6,6)) + plt.imshow(lat_vis, cmap='gray', alpha=0.7) + plt.imshow(mask[0,0].cpu(), cmap='Reds', alpha=0.3) + plt.title("Debug Eye/Iris Mask") + plt.show() + + return mask + +def create_eye_mask(latents, eye_coords, eye_radius=8, falloff=4): + """ + Soft mask gaussien → transitions naturelles + """ + if eye_coords is None or len(eye_coords) == 0: + return None + + B, C, H, W = latents.shape + mask = torch.zeros((1, 1, H, W), device=latents.device) + + for y_c, x_c in eye_coords: + y_lat = int(y_c / 8) + x_lat = int(x_c / 8) + + for y in range(H): + for x in range(W): + dist = ((y - y_lat)**2 + (x - x_lat)**2)**0.5 + value = max(0, 1 - dist / (eye_radius + falloff)) + mask[0, 0, y, x] = torch.maximum(mask[0, 0, y, x], torch.tensor(value, device=latents.device)) + + return mask.repeat(B, C, 1, 1) + +def detect_eyes_auto(frame_pil): 
+ """Retourne les coordonnées (y,x) approximatives des yeux""" + img = np.array(frame_pil) + h, w, _ = img.shape + with MP_FACE_MESH.FaceMesh(static_image_mode=True, max_num_faces=1) as face_mesh: + results = face_mesh.process(cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + if not results.multi_face_landmarks: + return [] + lm = results.multi_face_landmarks[0].landmark + left_eye = np.mean([(lm[i].y*h, lm[i].x*w) for i in [33, 133]], axis=0) + right_eye = np.mean([(lm[i].y*h, lm[i].x*w) for i in [362, 263]], axis=0) + return [(int(left_eye[0]), int(left_eye[1])), (int(right_eye[0]), int(right_eye[1]))] + +# Decode avec blending optimise : +# +# --------------------------------------------------------------------------------------------- +def decode_latents_ultrasafe_blockwise_ultranatural( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.015, + sharpen_edges_strength=0.02, + gamma_boost=1.03 # légèrement plus de punch naturel +): + import torch + import torch.nn.functional as F + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - 
torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + # ========================================================= + # 🔥 SHARPEN ADAPTATIF + # ========================================================= + if sharpen_mode is not None: + + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + local_std = detail.std(dim=[2,3], keepdim=True) + 1e-6 + adapt_strength = sharpen_strength / (1 + 5*(1-local_std)) + output_rgb = output_rgb + adapt_strength * torch.tanh(detail) + + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + kernel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], device=device, dtype=output_rgb.dtype) + kernel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], device=device, dtype=output_rgb.dtype) + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + edge_mask = torch.sigmoid(6.0 * (edges - 0.7)) + output_rgb = output_rgb 
+ sharpen_edges_strength * edges * edge_mask + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- Gamma adaptatif ---------------- + output_rgb_gamma = ((output_rgb + 1) / 2.0).clamp(0,1) + luminance = output_rgb_gamma.mean(dim=1, keepdim=True) + adapt_gamma = gamma_boost * (1.0 + 0.1*(0.5-luminance)) # boost plus fort pour zones un peu ternes + output_rgb_gamma = output_rgb_gamma ** adapt_gamma + output_rgb = output_rgb_gamma * 2 - 1 + + # ---------------- Micro-boost couleur (zones un peu plates) ---------------- + mean_c = output_rgb.mean(dim=[2,3], keepdim=True) + color_boost = torch.sigmoid(5.0*(output_rgb - mean_c)) * 0.03 + output_rgb = (output_rgb + color_boost).clamp(-1.0, 1.0) + + # ---------------- To PIL ---------------- + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + +def decode_latents_ultrasafe_blockwise_natural( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.02, + sharpen_edges_strength=0.02, + gamma_boost=1.10 # 12% plus de punch naturel +): + import torch + import torch.nn.functional as F + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = 
torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + # ========================================================= + # 🔥 SHARPEN SECTION ADAPTATIVE + # ========================================================= + if sharpen_mode is not None: + + # ---- 1. Tanh sharpen (détails globaux adaptatifs) + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + # facteur adaptatif selon contraste local + local_std = detail.std(dim=[2,3], keepdim=True) + 1e-6 + adapt_strength = sharpen_strength / (1 + 5*(1-local_std)) + output_rgb = output_rgb + adapt_strength * torch.tanh(detail) + + # ---- 2. 
Edge sharpen adaptatif + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + + kernel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], device=device, dtype=output_rgb.dtype) + kernel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], device=device, dtype=output_rgb.dtype) + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + edge_mask = torch.sigmoid(6.0 * (edges - 0.7)) + output_rgb = output_rgb + sharpen_edges_strength * edges * edge_mask + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- Gamma adaptatif ---------------- + output_rgb_gamma = ((output_rgb + 1) / 2.0).clamp(0,1) # [0,1] + output_rgb_gamma = output_rgb_gamma ** gamma_boost + output_rgb = output_rgb_gamma * 2 - 1 + + # ---------------- To PIL ---------------- + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +def decode_latents_ultrasafe_blockwise_sharp( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.04, + sharpen_edges_strength=0.05 +): + import torch + import torch.nn.functional as F + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) 
+ + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + # ========================================================= + # 🔥 SHARPEN SECTION + # ========================================================= + + if sharpen_mode is not None: + + # ---- 1. 
Tanh sharpen (détails globaux) + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + output_rgb = output_rgb + sharpen_strength * torch.tanh(detail) + + # ---- Edge sharpen PRO (anti plastique) + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + + kernel_x = torch.tensor( + [[-1,0,1],[-2,0,2],[-1,0,1]], + device=device, + dtype=output_rgb.dtype + ) + + kernel_y = torch.tensor( + [[-1,-2,-1],[0,0,0],[1,2,1]], + device=device, + dtype=output_rgb.dtype + ) + + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + + # 🔥 NORMALISATION douce (pas globale) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + + # 🔥 MASQUE BEAUCOUP plus sélectif (clé) + edge_mask = torch.sigmoid(6.0 * (edges - 0.7)) + + # 🔥 DIRECTION du contraste (pas ajout brut) + sign = torch.sign(output_rgb) + + output_rgb = output_rgb + sharpen_edges_strength * edge_mask * sign * edges * 0.5 + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- To PIL ---------------- + # Ajouter gamma boost ici + gamma = 1.10 + output_rgb_gamma = ((output_rgb + 1.0) / 2.0).clamp(0,1) + output_rgb_gamma = output_rgb_gamma ** gamma + output_rgb_gamma = output_rgb_gamma * 2.0 - 1.0 + output_rgb = output_rgb_gamma + + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +def decode_latents_ultrasafe_blockwise_plastique( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.04, + sharpen_edges_strength=0.05 +): + import torch + import torch.nn.functional as F + from 
torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + # ========================================================= + # 🔥 SHARPEN SECTION + # ========================================================= + + if sharpen_mode is not None: + + # ---- 1. 
Tanh sharpen (détails globaux) + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + output_rgb = output_rgb + sharpen_strength * torch.tanh(detail) + + # ---- 2. Edge sharpen (version PRO stable) + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + + kernel_x = torch.tensor( + [[-1,0,1],[-2,0,2],[-1,0,1]], + device=device, + dtype=output_rgb.dtype + ) + + kernel_y = torch.tensor( + [[-1,-2,-1],[0,0,0],[1,2,1]], + device=device, + dtype=output_rgb.dtype + ) + + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + + # 🔥 NORMALISATION LOCALE (clé stabilité) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + + # 🔥 MASQUE edges (évite bruit dans zones plates) + edge_mask = torch.sigmoid(4.0 * (edges - 0.5)) + + # 🔥 Sharpen intelligent + output_rgb = output_rgb + sharpen_edges_strength * edges * edge_mask + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- To PIL ---------------- + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +def decode_latents_ultrasafe_blockwise_pro( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True +): + import torch + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions 
= list(range(0, W, stride)) + + # 🔥 Fenêtre de blending PRO (Hann = ultra stable) + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + # 🔥 feather dynamique (corrige bord image) + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # 🔥 sécurité critique (évite artefacts) + weight = torch.clamp(weight, min=1e-3) + + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +# Decode latents par blockwise - ultrasafe : +def decode_latents_ultrasafe_blockwise(latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0): + """ + Décodage ultra-safe par blocs des latents en image PIL. 
+ Paramètres conservés uniquement : block_size, overlap, device, frame_counter, latent_scale_boost + """ + import torch + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + decoded.shape[2], ix0 + decoded.shape[3] + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded + weight[:, :, iy0:iy1, ix0:ix1] += 1.0 + + output_rgb = (output_rgb / weight.clamp(min=1e-6)).clamp(-1.0, 1.0) + + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +def apply_intelligent_glow_pro( + frame_pil, + strength=0.18, + edge_weight=0.6, + luminance_weight=0.8, + blur_radius=1.2 +): + from PIL import Image, ImageFilter + import numpy as np + + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + arr = np.array(frame_pil).astype(np.float32) / 255.0 + + # ---------------- Luminance ---------------- + lum = 0.299 * arr[:, :, 0] + 0.587 * arr[:, :, 1] + 0.114 * arr[:, :, 2] + lum_mask = np.clip((lum - 0.6) / 0.4, 0, 1) + lum_mask = np.power(lum_mask, 1.5) + + # ---------------- Edge ---------------- + gray = (lum * 255).astype(np.uint8) + edge = Image.fromarray(gray).filter(ImageFilter.FIND_EDGES) + edge = np.array(edge).astype(np.float32) / 255.0 + edge = np.clip(edge * 1.2, 0, 1) + edge 
= np.power(edge, 1.3) + + # ---------------- Mask combiné ---------------- + combined_mask = np.clip(luminance_weight * lum_mask + edge_weight * edge, 0, 1) + + # ---------------- Glow ---------------- + glow_img = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + glow_arr = np.array(glow_img).astype(np.float32) / 255.0 + + # ---------------- Appliquer glow seulement sur la luminance ---------------- + glow_lum = 0.299 * glow_arr[:, :, 0] + 0.587 * glow_arr[:, :, 1] + 0.114 * glow_arr[:, :, 2] + + # mixer luminance glow + couleur originale + result = arr.copy() + for c in range(3): + # conserver la teinte originale mais injecter glow sur la luminosité + result[:, :, c] = arr[:, :, c] + (glow_lum - lum) * combined_mask * strength + + result = np.clip(result, 0, 1) + return Image.fromarray((result * 255).astype(np.uint8)) + + +def apply_intelligent_glow_froid( + frame_pil, + strength=0.18, + edge_weight=0.6, + luminance_weight=0.8, + blur_radius=1.2 +): + from PIL import Image, ImageFilter, ImageEnhance + import numpy as np + + # ---------------- Base ---------------- + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + arr = np.array(frame_pil).astype(np.float32) / 255.0 + + # ---------------- Luminance mask ---------------- + # luminance perceptuelle + lum = 0.299 * arr[:, :, 0] + 0.587 * arr[:, :, 1] + 0.114 * arr[:, :, 2] + + # masque doux (favorise les zones claires) + lum_mask = np.clip((lum - 0.6) / 0.4, 0, 1) + lum_mask = np.power(lum_mask, 1.5) # douceur + + # ---------------- Edge mask ---------------- + gray = (lum * 255).astype(np.uint8) + edge = Image.fromarray(gray).filter(ImageFilter.FIND_EDGES) + edge = np.array(edge).astype(np.float32) / 255.0 + + # adoucir les edges (évite bruit) + edge = np.clip(edge * 1.2, 0, 1) + edge = np.power(edge, 1.3) + + # ---------------- Fusion intelligente ---------------- + combined_mask = ( + luminance_weight * lum_mask + + edge_weight * edge + ) + + combined_mask = 
np.clip(combined_mask, 0, 1) + + # ---------------- Glow ---------------- + # blur image pour glow + glow_img = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + glow_arr = np.array(glow_img).astype(np.float32) / 255.0 + + # appliquer glow uniquement où mask actif + result = arr + (glow_arr - arr) * combined_mask[..., None] * strength + + result = np.clip(result, 0, 1) + + return Image.fromarray((result * 255).astype(np.uint8)) + + +def apply_post_processing_adaptive( + frame_pil, + blur_radius=0.03, + contrast=1.10, + vibrance_strength=0.25, # 🔥 contrôle simple (0 → off, 0.3 = doux) + sharpen=False, + sharpen_radius=1, + sharpen_percent=90, + sharpen_threshold=2, + clamp_r=True +): + from PIL import ImageEnhance, ImageFilter + import numpy as np + + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + # ---------------- 1️⃣ Micro blur (anti pixel) ---------------- + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # ---------------- 2️⃣ Contrast (seul vrai levier global) ---------------- + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + + # ---------------- 3️⃣ Vibrance douce (version stable) ---------------- + if vibrance_strength > 0: + try: + arr = np.array(frame_pil).astype(np.float32) + + # saturation simple + max_rgb = arr.max(axis=2) + min_rgb = arr.min(axis=2) + sat = (max_rgb - min_rgb) / 255.0 + + # 🔥 boost UNIQUEMENT zones peu saturées + boost = 1.0 + vibrance_strength * (1.0 - sat) + + arr = arr * boost[..., None] + + frame_pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)) + + except Exception as e: + print(f"[WARNING] vibrance skipped: {e}") + + # ---------------- 4️⃣ Clamp rouge (anti rose / peau cramée) ---------------- + if clamp_r: + try: + arr = np.array(frame_pil).astype(np.float32) + + r = arr[:, :, 0] + r_mean = r.mean() + + if r_mean > 160: # 🔥 seuil plus bas = plus stable + factor = 160 / (r_mean + 1e-6) + 
arr[:, :, 0] *= factor + + frame_pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)) + + except Exception as e: + print(f"[WARNING] clamp rouge skipped: {e}") + + # ---------------- 5️⃣ Sharpen léger ---------------- + if sharpen: + try: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + except Exception as e: + print(f"[WARNING] sharpening skipped: {e}") + + return frame_pil + + + + +def smooth_edges(frame_pil, strength=0.4, blur_radius=1.2): + from PIL import ImageFilter, ImageChops + import numpy as np + + # 1️⃣ edges + edges = frame_pil.convert("L").filter(ImageFilter.FIND_EDGES) + + # 2️⃣ normalisation du masque + edges_np = np.array(edges).astype(np.float32) / 255.0 + edges_np = np.clip(edges_np * 2.0, 0, 1) # renforce zones edges + + # 3️⃣ blur global (source) + blurred = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # 4️⃣ blend intelligent + orig = np.array(frame_pil).astype(np.float32) + blur = np.array(blurred).astype(np.float32) + + mask = edges_np[..., None] * strength + + result = orig * (1 - mask) + blur * mask + + return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8)) + + +def apply_post_processing_unreal_cinematic( + frame_pil, + exposure=1.0, + vibrance=1.02, + edge_strength=0.25, + sharpen=True, + brightness_adj=0.90, # 🔻 -5% + contrast_adj=1.65 # 🔺 +65% +): + from PIL import Image, ImageEnhance, ImageFilter, ImageChops + import numpy as np + + # 🔥 1. 
Base (sans toucher contraste global) + arr = np.array(frame_pil).astype(np.float32) / 255.0 + arr *= exposure + + # Vibrance douce + mean_c = arr.mean(axis=2, keepdims=True) + arr = mean_c + (arr - mean_c) * vibrance + arr = np.clip(arr, 0, 1) + + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ========================= + # ✏️ EDGE CRAYON BLANC + # ========================= + gray = img.convert("L") + edges = gray.filter(ImageFilter.FIND_EDGES) + + edges = edges.filter(ImageFilter.GaussianBlur(radius=0.8)) + edges = ImageChops.invert(edges) + edges = ImageEnhance.Contrast(edges).enhance(1.2) + + edge_rgb = Image.merge("RGB", (edges, edges, edges)) + + # Screen = effet lumineux propre + img_edges = ImageChops.screen(img, edge_rgb) + + # Blend final contrôlé + img = Image.blend(frame_pil, img_edges, edge_strength) + + # ========================= + # 🔥 AJUSTEMENTS DEMANDÉS + # ========================= + img = ImageEnhance.Brightness(img).enhance(brightness_adj) + img = ImageEnhance.Contrast(img).enhance(contrast_adj) + + # ========================= + # 🔧 Sharpen doux + # ========================= + if sharpen: + img = img.filter(ImageFilter.UnsharpMask( + radius=0.5, + percent=30, + threshold=3 + )) + + # 🔥 micro lissage final + img = img.filter(ImageFilter.GaussianBlur(radius=0.25)) + + return img + +def apply_post_processing_minimal( + frame_pil, + blur_radius=0.05, + contrast=1.15, + vibrance_base=1.0, + vibrance_max=1.25, + sharpen=False, + sharpen_radius=1, + sharpen_percent=90, + sharpen_threshold=2, + clamp_r=True +): + from PIL import Image, ImageFilter, ImageEnhance + import numpy as np + + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + # ---------------- 1. Blur léger ---------------- + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # ---------------- 2. 
Contraste ---------------- + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + + # ---------------- 3. Vibrance adaptative ---------------- + try: + frame_np = np.array(frame_pil).astype(np.float32) + + max_rgb = np.max(frame_np, axis=2) + min_rgb = np.min(frame_np, axis=2) + sat = max_rgb - min_rgb + + factor_map = vibrance_base + (vibrance_max - vibrance_base) * (1 - sat / 255.0) + factor_map = np.clip(factor_map, vibrance_base, vibrance_max) + + frame_np *= factor_map[..., None] + frame_np = np.clip(frame_np, 0, 255) + + frame_pil = Image.fromarray(frame_np.astype(np.uint8)) + + except Exception as e: + print(f"[WARNING] vibrance skipped: {e}") + + # ---------------- 4. Clamp rouge ---------------- + if clamp_r: + try: + arr = np.array(frame_pil).astype(np.float32) + r_mean = arr[..., 0].mean() + + if r_mean > 180: + factor = 180 / r_mean + arr[..., 0] *= factor + + frame_pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)) + + except Exception as e: + print(f"[WARNING] clamp rouge skipped: {e}") + + # ---------------- 5. 
Sharpen ---------------- + if sharpen: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + + return frame_pil + +def apply_intelligent_glow(frame_pil, + glow_strength=0.22, + blur_radius=1.2, + luminance_threshold=0.7, + edge_strength=1.2, + detail_preservation=0.85): + """ + Glow intelligent : + - basé sur luminance + edges + - évite effet flou global + - boost détails lumineux uniquement + """ + from PIL import Image, ImageFilter, ImageEnhance, ImageChops + import numpy as np + + # ----------------------- + # 1️⃣ Base numpy + # ----------------------- + arr = np.array(frame_pil).astype(np.float32) / 255.0 + + # ----------------------- + # 2️⃣ Luminance mask + # ----------------------- + gray = frame_pil.convert("L") + lum = np.array(gray).astype(np.float32) / 255.0 + + lum_mask = np.clip((lum - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + + # ----------------------- + # 3️⃣ Edge mask (important 🔥) + # ----------------------- + edges = gray.filter(ImageFilter.FIND_EDGES) + edges = ImageEnhance.Contrast(edges).enhance(edge_strength) + + edge_arr = np.array(edges).astype(np.float32) / 255.0 + + # 🔥 combinaison intelligente + combined_mask = lum_mask * edge_arr + + # ----------------------- + # 4️⃣ Glow blur + # ----------------------- + blurred = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + blurred_arr = np.array(blurred).astype(np.float32) / 255.0 + + # ----------------------- + # 5️⃣ Application du glow + # ----------------------- + for c in range(3): + arr[..., c] = arr[..., c] + glow_strength * combined_mask * blurred_arr[..., c] + + arr = np.clip(arr, 0, 1) + + # ----------------------- + # 6️⃣ Reconstruction + # ----------------------- + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ----------------------- + # 7️⃣ Préservation détails + # ----------------------- + img = Image.blend(frame_pil, img, 1 - detail_preservation) 
+ + # ----------------------- + # 8️⃣ Micro sharpen + # ----------------------- + img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=25, threshold=2)) + + return img + + +def apply_chromatic_soft_glow(frame_pil, + glow_strength=0.25, + exposure=1.05, + blur_radius=2.0, + luminance_threshold=0.8, + color_saturation=1.05, + sharpen=True): + """ + Soft Glow chromatique localisé : + - Glow appliqué sur pixels clairs selon leur canal (R/G/B) + - Zones sombres préservées + - Détails conservés + """ + from PIL import Image, ImageFilter, ImageChops, ImageEnhance + import numpy as np + + arr = np.array(frame_pil).astype(np.float32) / 255.0 + arr = np.clip(arr * exposure, 0, 1) + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ----------------------- + # Masque par canal + # ----------------------- + r, g, b = arr[...,0], arr[...,1], arr[...,2] + mask_r = np.clip((r - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + mask_g = np.clip((g - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + mask_b = np.clip((b - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + + # ----------------------- + # Glow par canal + # ----------------------- + bright = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + bright_arr = np.array(bright).astype(np.float32) / 255.0 + + # Mélange selon masque couleur + arr[...,0] = np.clip(arr[...,0] + glow_strength * mask_r * bright_arr[...,0], 0, 1) + arr[...,1] = np.clip(arr[...,1] + glow_strength * mask_g * bright_arr[...,1], 0, 1) + arr[...,2] = np.clip(arr[...,2] + glow_strength * mask_b * bright_arr[...,2], 0, 1) + + img = Image.fromarray((arr*255).astype(np.uint8)) + + # ----------------------- + # Saturation douce + # ----------------------- + img = ImageEnhance.Color(img).enhance(color_saturation) + + # ----------------------- + # Micro sharpen subtil + # ----------------------- + if sharpen: + img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=30, threshold=2)) + + return img + + 
def apply_localized_soft_glow(frame_pil,
                              glow_strength=0.25,
                              exposure=1.05,
                              blur_radius=2.0,
                              luminance_threshold=0.6,
                              color_saturation=1.05,
                              sharpen=True):
    """'Localised soft glow' filter:
    - glow applied only on bright areas
    - subtle effect, preserves dark regions
    - keeps detail
    """
    from PIL import Image, ImageFilter, ImageChops, ImageEnhance
    import numpy as np

    # 1. float conversion + exposure
    arr = np.array(frame_pil).astype(np.float32) / 255.0
    arr = np.clip(arr * exposure, 0, 1)
    img = Image.fromarray((arr * 255).astype(np.uint8))

    # 2. brightness mask
    gray = img.convert("L")
    lum_arr = np.array(gray).astype(np.float32) / 255.0
    mask = np.clip((lum_arr - luminance_threshold) / (1.0 - luminance_threshold), 0, 1)
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))

    # 3. light glow, composited only where the mask is active
    bright = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    glow_img = ImageChops.screen(img, bright)
    glow_img = Image.composite(glow_img, img, mask_img)
    img = Image.blend(img, glow_img, glow_strength)

    # 4. gentle saturation
    img = ImageEnhance.Color(img).enhance(color_saturation)

    # 5. subtle micro sharpen
    if sharpen:
        img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=30, threshold=2))

    return img


def apply_soft_glow(frame_pil,
                    glow_strength=0.25,
                    exposure=1.05,
                    blur_radius=2.0,
                    color_saturation=1.05,
                    sharpen=True):
    """'Soft glow' filter:
    - gentle over-exposure on bright areas
    - light, subtle glow
    - keeps detail and texture
    """
    from PIL import Image, ImageFilter, ImageChops, ImageEnhance
    import numpy as np

    # 1. float conversion + light exposure
    arr = np.array(frame_pil).astype(np.float32) / 255.0
    arr = np.clip(arr * exposure, 0, 1)
    img = Image.fromarray((arr * 255).astype(np.uint8))

    # 2. subtle glow (light bloom)
    bright = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    img = ImageChops.screen(img, bright)
    img = Image.blend(img, bright, glow_strength)

    # 3. gentle saturation
    img = ImageEnhance.Color(img).enhance(color_saturation)

    # 4. subtle micro sharpen
    if sharpen:
        img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=30, threshold=2))

    return img


def apply_cinematic_neon_glow(frame_pil,
                              glow_strength=0.25,
                              edge_strength=0.15,
                              color_saturation=1.15,
                              exposure=1.05,
                              contrast=1.25,
                              blur_radius=0.4,
                              sharpen=True):
    """'Cinematic neon glow' filter:
    - subtle glow around bright areas
    - saturated, neon / cinematic colours
    - slightly luminous, sketch-like edges
    """
    from PIL import Image, ImageFilter, ImageChops, ImageEnhance
    import numpy as np

    # 1. float conversion + 2. light exposure
    arr = np.array(frame_pil).astype(np.float32) / 255.0
    arr *= exposure
    arr = np.clip(arr, 0, 1)
    img = Image.fromarray((arr * 255).astype(np.uint8))

    # 3. subtle glow (light bloom) — fixed radius-5 blur for the bloom source
    bright = img.filter(ImageFilter.GaussianBlur(radius=5))
    img = ImageChops.screen(img, bright)  # luminous effect
    img = Image.blend(img, bright, glow_strength)

    # 4. light edge sketch
    gray = img.convert("L").filter(ImageFilter.GaussianBlur(radius=1.0))
    edges = gray.filter(ImageFilter.FIND_EDGES)
    edges = ImageChops.invert(edges)
    edges_rgb = Image.merge("RGB", (edges, edges, edges))
    img = ImageChops.blend(img, edges_rgb, edge_strength)

    # 5. saturation & contrast
    img = ImageEnhance.Color(img).enhance(color_saturation)
    img = ImageEnhance.Contrast(img).enhance(contrast)

    # 6. micro anti-pixel blur
    img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    # 7. subtle sharpen
    if sharpen:
        img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=40, threshold=2))

    return img


def apply_post_processing_sketch(frame_pil, edge_strength=0.2, blur_radius=0.3, sharpen=True,
                                 contrast_boost=1.6,  # +60% contrast
                                 exposure=0.80):      # -20% brightness
    """Subtle sketch / light-drawing effect:
    - slightly visible contours (soft white)
    - contrast boost, reduced brightness
    - smooths isolated pixels
    - does not distort the base colours
    """
    from PIL import Image, ImageFilter, ImageChops, ImageEnhance
    import numpy as np

    # 1. gentle edge detection
    gray = frame_pil.convert("L").filter(ImageFilter.GaussianBlur(radius=0.5))
    edges = gray.filter(ImageFilter.FIND_EDGES)
    edges = edges.filter(ImageFilter.MedianFilter(size=3))    # removes isolated dots
    edges = edges.filter(ImageFilter.GaussianBlur(radius=0.6))  # smoothing
    edges = ImageEnhance.Contrast(edges).enhance(1.2)
    edges = ImageChops.invert(edges)
    edge_rgb = Image.merge("RGB", (edges, edges, edges))

    # 2. gentle edge fusion
    img = ImageChops.blend(frame_pil, edge_rgb, edge_strength)

    # 3. exposure / brightness
    img = ImageEnhance.Brightness(img).enhance(exposure)

    # 4. contrast
    img = ImageEnhance.Contrast(img).enhance(contrast_boost)

    # 5. light anti-pixel blur
    if blur_radius > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    # 6. subtle sharpen
    if sharpen:
        img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=40, threshold=2))

    return img


def apply_post_processing_drawing(frame_pil,
                                  edge_strength=0.7,
                                  color_levels=48,
                                  saturation=0.95,
                                  contrast=1.10,
                                  sharpen=True):
    """Line-art style drawing post-processing.
    Simplifies colours, adds white-pencil contours,
    removes black speckles and keeps a crisp render.
    """
    from PIL import Image, ImageFilter, ImageEnhance, ImageChops
    import numpy as np

    # 1. gentle colour simplification (posterise)
    arr = np.array(frame_pil).astype(np.float32)
    levels = color_levels
    arr = np.round(arr / (256 / levels)) * (256 / levels)
    img = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))

    # 2. clean edge detection
    gray = frame_pil.convert("L").filter(ImageFilter.GaussianBlur(radius=0.6))
    edges = gray.filter(ImageFilter.FIND_EDGES)
    edges = edges.filter(ImageFilter.GaussianBlur(radius=0.8))
    edges = edges.filter(ImageFilter.MedianFilter(size=3))  # removes isolated dots
    edges = ImageEnhance.Contrast(edges).enhance(1.4)
    edges = edges.point(lambda x: 0 if x < 15 else int(x * 1.2))
    edges = ImageChops.invert(edges)
    edge_rgb = Image.merge("RGB", (edges, edges, edges))

    # 3. gentle contour fusion
    img_edges = ImageChops.multiply(img, edge_rgb)
    img = Image.blend(img, img_edges, edge_strength * 0.85)

    # 4. colour / contrast / sharpen
    img = ImageEnhance.Color(img).enhance(saturation)
    img = ImageEnhance.Contrast(img).enhance(contrast)
    if sharpen:
        img = img.filter(ImageFilter.UnsharpMask(radius=0.6, percent=60, threshold=3))

    return img


def save_frame_verbose(frame: Image.Image, output_dir: Path, frame_counter: int, suffix: str = "00", psave: bool = True):
    """Save a frame with a suffix and print a message when requested.

    Args:
        frame (Image.Image): PIL image to save
        output_dir (Path): output folder
        frame_counter (int): frame number
        suffix (str): suffix distinguishing the processing stage
        psave (bool): print the save message when True
    """
    file_path = output_dir / f"frame_{frame_counter:05d}_{suffix}.png"

    if psave:
        print(f"[SAVE Frame {frame_counter:03d}_{suffix}] -> {file_path}")
    # NOTE(review): original indentation was lost in the paste — the save is
    # placed outside the `if psave` per the docstring (psave gates the print only);
    # confirm against the upstream file.
    frame.save(file_path)
    return file_path


def neutralize_color_cast(img, strength=0.45, warm_bias=0.015, green_bias=-0.07):
    """Neutralise the colour cast while correcting an excess of green.

    Args:
        img (PIL.Image): image to correct
        strength (float): neutralisation intensity (0.0 = off, 1.0 = full)
        warm_bias (float): slight warm shift (red +, blue -)
        green_bias (float): green adjustment (-0.07 = minus 7%)
    """
    import numpy as np
    from PIL import Image

    arr = np.array(img).astype(np.float32)

    mean = arr.mean(axis=(0, 1))
    gray = mean.mean()

    # Gray-world gains, interpolated by `strength`.
    gain = gray / (mean + 1e-6)
    gain = 1.0 + (gain - 1.0) * strength

    arr[..., 0] *= gain[0] * (1 + warm_bias)   # red +
    arr[..., 1] *= gain[1] * (1 + green_bias)  # corrected green
    arr[..., 2] *= gain[2] * (1 - warm_bias)   # blue -

    arr = np.clip(arr, 0, 255)
    return Image.fromarray(arr.astype(np.uint8))


def neutralize_color_cast_clean(img, strength=0.6, warm_bias=0.02):
    """Gray-world neutralisation with a slight warm bias (red +, blue -)."""
    import numpy as np
    from PIL import Image

    arr = np.array(img).astype(np.float32)

    mean = arr.mean(axis=(0, 1))
    gray = mean.mean()

    gain = gray / (mean + 1e-6)
    gain = 1.0 + (gain - 1.0) * strength

    arr[..., 0] *= gain[0] * (1 + warm_bias)  # slight red +
    arr[..., 1] *= gain[1]
    arr[..., 2] *= gain[2] * (1 - warm_bias)  # slight blue -

    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))


def neutralize_color_cast_str(img, strength=0.6):
    """Gray-world neutralisation, interpolated by `strength`."""
    import numpy as np
    from PIL import Image

    arr = np.array(img).astype(np.float32)

    mean = arr.mean(axis=(0, 1))
    gray = mean.mean()

    gain = gray / (mean + 1e-6)

    # Interpolation (the key step).
    gain = 1.0 + (gain - 1.0) * strength

    arr[..., 0] *= gain[0]
    arr[..., 1] *= gain[1]
    arr[..., 2] *= gain[2]

    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))


def neutralize_color_cast_simple(img):
    """Full-strength gray-world white balance (neutral-gray target)."""
    import numpy as np
    arr = np.array(img).astype(np.float32)

    mean = arr.mean(axis=(0, 1))

    # Neutral gray target.
    gray = mean.mean()
    gain = gray / (mean + 1e-6)

    arr[..., 0] *= gain[0]
    arr[..., 1] *= gain[1]
    arr[..., 2] *= gain[2]

    # NOTE(review): relies on a module-level `Image` import — confirm at file top.
    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))


def kelvin_to_rgb(temp):
    """Realistic Kelvin -> RGB approximation (photography-inspired).

    Returns an (r, g, b) tuple with each component in [0, 1].
    """
    temp = temp / 100.0

    # Red
    if temp <= 66:
        r = 255
    else:
        r = temp - 60
        r = 329.698727446 * (r ** -0.1332047592)

    # Green
    if temp <= 66:
        g = temp
        g = 99.4708025861 * math.log(g) - 161.1195681661
    else:
        g = temp - 60
        g = 288.1221695283 * (g ** -0.0755148492)

    # Blue
    if temp >= 66:
        b = 255
    elif temp <= 19:
        b = 0
    else:
        b = temp - 10
        b = 138.5177312231 * math.log(b) - 305.0447927307

    return (
        max(0, min(255, r)) / 255.0,
        max(0, min(255, g)) / 255.0,
        max(0, min(255, b)) / 255.0,
    )


def adjust_color_temperature(
    image,
    target_temp=7800,
    reference_temp=6500,
    strength=0.5,
    adaptive=True,
    max_gain=2.0,
    debug=False,
):
    """Shift the image colour temperature with optional adaptive correction.

    Gains come from the Kelvin->RGB curves; when `adaptive`, a gray-world
    estimate of the current white balance scales them, and they are clamped
    to [1/max_gain, max_gain] for safety.
    """
    import numpy as np

    img = np.array(image).astype(np.float32) / 255.0

    # 1. temperature gains
    r1, g1, b1 = kelvin_to_rgb(reference_temp)
    r2, g2, b2 = kelvin_to_rgb(target_temp)

    base_gain = np.array([
        r2 / r1,
        g2 / g1,
        b2 / b1,
    ])

    # 2. quick estimate of the current WB (simplified gray-world)
    if adaptive:
        mean_rgb = img.reshape(-1, 3).mean(axis=0)
        mean_rgb = np.maximum(mean_rgb, 1e-6)

        # normalise on G
        wb_ratio = mean_rgb / mean_rgb[1]

        # imbalance measure
        imbalance = np.std(wb_ratio)

        # gentle adaptive factor (avoids overcorrection)
        adaptive_factor = 1.0 + min(1.0, imbalance * 2.0)
    else:
        adaptive_factor = 1.0

    # 3. interpolation by `strength`
    final_gain = (1 - strength) + strength * base_gain * adaptive_factor

    # 4. safety clamp (very important in practice)
    final_gain = np.clip(final_gain, 1 / max_gain, max_gain)

    # 5. application
    img *= final_gain
    img = np.clip(img, 0, 1)

    if debug:
        print("=== DEBUG TEMP ===")
        print(f"mean_rgb: {mean_rgb if adaptive else 'disabled'}")
        print(f"base_gain: {base_gain}")
        print(f"adaptive_factor: {adaptive_factor}")
        print(f"final_gain: {final_gain}")
        print("==================")

    # NOTE(review): relies on a module-level `Image` import — confirm at file top.
    return Image.fromarray((img * 255).astype(np.uint8))


def adjust_color_temperature_basic(image, target_temp=10000, reference_temp=6500, strength=0.5):
    """Temperature shift with per-channel gains interpolated by `strength`."""
    import numpy as np

    img = np.array(image).astype(np.float32) / 255.0

    r1, g1, b1 = kelvin_to_rgb(reference_temp)
    r2, g2, b2 = kelvin_to_rgb(target_temp)

    # Interpolation (the key step).
    r_gain = (1 - strength) + strength * (r2 / r1)
    g_gain = (1 - strength) + strength * (g2 / g1)
    b_gain = (1 - strength) + strength * (b2 / b1)

    img[..., 0] *= r_gain
    img[..., 1] *= g_gain
    img[..., 2] *= b_gain

    img = np.clip(img, 0, 1)
    return Image.fromarray((img * 255).astype(np.uint8))


def adjust_color_temperature_simple(image, target_temp=7800, reference_temp=6500):
    """Temperature shift using raw relative gains (GIMP-like behaviour)."""
    import numpy as np

    img = np.array(image).astype(np.float32) / 255.0

    # Relative gains (IMPORTANT -> like GIMP).
    r1, g1, b1 = kelvin_to_rgb(reference_temp)
    r2, g2, b2 = kelvin_to_rgb(target_temp)

    img[..., 0] *= r2 / r1
    img[..., 1] *= g2 / g1
    img[..., 2] *= b2 / b1

    img = np.clip(img, 0, 1)
    return Image.fromarray((img * 255).astype(np.uint8))


def soft_tone_map(img):
    """Light contrast boost around the mean (instead of range compression)."""
    import numpy as np

    arr = np.array(img).astype(np.float32) / 255.0

    mean = arr.mean(axis=(0, 1), keepdims=True)
    arr = (arr - mean) * 1.1 + mean

    return Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8))


def soft_tone_map_unreal(img, exposure=1.0):
    """Reinhard-style tone mapping with exposure and a light local contrast lift."""
    import numpy as np

    arr = np.array(img).astype(np.float32) / 255.0

    # exposure
    arr = arr * exposure

    # Reinhard-type tone mapping (more natural)
    mapped = arr / (1.0 + arr)

    # light local contrast (key!)
    mapped = np.power(mapped, 0.9)

    return Image.fromarray((np.clip(mapped, 0, 1) * 255).astype(np.uint8))


def soft_tone_map_v1(img):
    """Log-like soft compression followed by a gentle contrast softening."""
    # NOTE(review): uses module-level `np` and `Image` — confirm file-top imports.
    arr = np.array(img).astype(np.float32) / 255.0

    # softer, log-like compression
    arr = np.log1p(arr * 1.5) / np.log1p(1.5)

    # light contrast softening
    arr = np.power(arr, 0.95)

    return Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8))


def soft_tone_map1(img):
    """Simple rational tone curve with a slight gamma lift."""
    # NOTE(review): uses module-level `np` and `Image` — confirm file-top imports.
    arr = np.array(img).astype(np.float32) / 255.0
    arr = arr / (arr + 0.2)
    arr = np.power(arr, 0.95)
    arr = np.clip(arr, 0, 1)
    return Image.fromarray((arr * 255).astype(np.uint8))


def apply_n3r_pro_net(latents, model=None, strength=0.3, sanitize_fn=None):
    """Refine latents with an N3R model by injecting a smoothed detail map.

    Returns the input unchanged when `model` is None or `strength` <= 0,
    and on any error (which is printed, never raised).
    """
    if model is None or strength <= 0:
        return latents

    try:
        latents = latents.to(next(model.parameters()).dtype)
        refined = model(latents)

        # Difference = detail map.
        detail = refined - latents

        # Smooth the detail (the key step) — completes past this hunk boundary.
        detail = F.avg_pool2d(detail, kernel_size=3, stride=1, padding=1)

        # Controlled injection.
        latents = latents + strength * detail

        if sanitize_fn:
            latents = sanitize_fn(latents)

        return latents

    except Exception as e:
        print(f"[N3RProNet ERROR] {e}")
        return latents
def apply_n3r_pro_net1(latents, model=None, strength=0.3, sanitize_fn=None):
    """Refine latents by soft-blending the model output, then renormalize.

    Returns the input unchanged when no model is given or strength <= 0;
    on any runtime failure the latents computed so far are returned.
    """
    if model is None or strength <= 0:
        return latents

    try:
        model_dtype = next(model.parameters()).dtype
        latents = latents.to(model_dtype)

        # Clamp the refinement to avoid value blow-ups.
        refined = torch.clamp(model(latents), -2.5, 2.5)

        # Gentle blend — far more stable than a raw replacement.
        latents = (1 - strength) * latents + strength * refined

        # Light per-sample normalization by the latent std.
        latents = latents / (latents.std(dim=[1, 2, 3], keepdim=True) + 1e-6)

        if sanitize_fn:
            latents = sanitize_fn(latents)
        return latents
    except Exception as e:
        print(f"[N3RProNet ERROR] {e}")
        return latents


def apply_n3r_pro_net_v1(latents, model=None, strength=0.3, sanitize_fn=None, frame_idx=None, total_frames=None):
    """Refine latents with a peak-normalized delta; strength ramps over the clip.

    When `frame_idx`/`total_frames` are given, the injection strength follows
    a cosine ramp from 30% to 100% of `strength` across the clip.
    """
    if model is None or strength <= 0:
        return latents

    try:
        reference_param = next(model.parameters())
        latents = latents.to(dtype=reference_param.dtype, device=reference_param.device)
        latents = ensure_4_channels(latents)

        # Cosine ramp of the injection strength across the clip.
        if frame_idx is not None and total_frames is not None:
            ramp = 0.5 * (1 - math.cos(math.pi * frame_idx / total_frames))
            adaptive_strength = strength * (0.3 + 0.7 * ramp)
        else:
            adaptive_strength = strength

        refined = model(latents)

        # Normalize the delta to unit peak magnitude to avoid saturation.
        delta = refined - latents
        peak = delta.abs().amax(dim=(1, 2, 3), keepdim=True).clamp(min=1e-5)
        latents = latents + adaptive_strength * (delta / peak)

        # Soft clamp for stability (only rescales when |latents| exceeds 1).
        latents = latents / latents.abs().amax(dim=(1, 2, 3), keepdim=True).clamp(min=1.0)

        if sanitize_fn:
            latents = sanitize_fn(latents)
        return latents
    except Exception as e:
        print(f"[N3RProNet ERROR] {e}")
        return latents
def full_frame_postprocess_add(
    frame_pil: Image.Image,
    output_dir: Path,
    frame_counter: int,
    target_temp: int = 7800,
    reference_temp: int = 6500,
    temp_strength: float = 0.22,
    blur_radius: float = 0.03,
    contrast: float = 1.10,
    saturation: float = 1.0,
    sharpen_percent: int = 90,
    psave: bool = True,
    unreal: bool = False,
    cartoon: bool = False,
    glow: bool = False,
    removewhite: bool = False,  # FIX: was a hard-coded local; now a parameter (default keeps old behavior)
    minimal: bool = False,      # FIX: was a hard-coded local; now a parameter (default keeps old behavior)
) -> Image.Image:
    """Full post-processing chain for a decoded frame.

    Pipeline: color temperature -> color-cast neutralization -> tone mapping
    -> adaptive (or minimal) cleanup -> optional white-noise removal ->
    optional unreal/cartoon stylization -> glow (forced or intelligent).
    Intermediate stages are saved via `save_frame_verbose` when `psave`.

    NOTE(review): `saturation`, and in the adaptive branch `blur_radius`,
    `contrast` and `sharpen_percent`, are currently unused (the adaptive call
    hard-codes its values) — kept for interface compatibility.

    Returns:
        The fully processed PIL frame.
    """
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="01", psave=psave)

    # 1. Color temperature
    frame_pil = adjust_color_temperature(
        frame_pil,
        target_temp=target_temp,
        reference_temp=reference_temp,
        strength=temp_strength
    )
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="02", psave=psave)

    # 2. Neutralize the dominant color cast
    frame_pil = neutralize_color_cast(frame_pil)
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="03", psave=psave)

    # 3. Tone mapping
    frame_pil = soft_tone_map(frame_pil)
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="04", psave=psave)

    # 4. Adaptive (or minimal) post-processing
    if minimal:
        frame_pil = apply_post_processing_minimal(
            frame_pil,
            blur_radius=blur_radius,
            contrast=contrast,
            vibrance_base=1.0,
            vibrance_max=1.1,
            sharpen=True,
            sharpen_radius=1,
            sharpen_percent=sharpen_percent,
            sharpen_threshold=2
        )
    else:
        frame_pil = apply_post_processing_adaptive(
            frame_pil,
            blur_radius=0.03,
            contrast=1.10,
            vibrance_strength=0.05,  # simple control (0 -> off, 0.3 = soft)
            sharpen=False,
            sharpen_radius=1,
            sharpen_percent=90,
            sharpen_threshold=2,
            clamp_r=True
        )
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="05", psave=psave)

    # 5. Optional white-noise cleanup
    if removewhite:
        frame_pil = remove_white_noise(frame_pil)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="06", psave=psave)

    # 6. Stylization
    if unreal:
        frame_pil = apply_post_processing_unreal_cinematic(frame_pil)
        frame_pil = smooth_edges(frame_pil, strength=0.35, blur_radius=1.0)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="07", psave=psave)
    elif cartoon:
        frame_pil = apply_post_processing_sketch(frame_pil)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="08", psave=psave)

    # 7. Glow
    if glow:
        # Forced glow for the style
        frame_pil = apply_chromatic_soft_glow(frame_pil)
        frame_pil = apply_localized_soft_glow(frame_pil)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="09", psave=psave)
    else:
        # Intelligent glow + final micro-contrast
        frame_pil = apply_intelligent_glow(frame_pil)
        from PIL import ImageEnhance
        frame_pil = ImageEnhance.Contrast(frame_pil).enhance(1.04)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="09", psave=psave)

    return frame_pil


def full_frame_postprocess(
    frame_pil: Image.Image,
    output_dir: Path,
    frame_counter: int,
    target_temp: int = 7800,
    reference_temp: int = 6500,
    temp_strength: float = 0.20,  # slightly reduced (less blue)
    blur_radius: float = 0.025,   # slightly less global blur
    contrast: float = 1.08,       # avoids cumulative over-contrast
    sharpen_percent: int = 90,
    psave: bool = True,
    unreal: bool = False,
    cartoon: bool = False
) -> Image.Image:
    """Rebalanced post-processing chain (softer variant of the `_add` version).

    Pipeline: color temperature -> softened cast neutralization -> tone
    mapping -> adaptive cleanup -> optional unreal/cartoon stylization ->
    intelligent pro glow -> final micro-contrast.

    NOTE(review): `sharpen_percent` is currently unused here — kept for
    interface compatibility.

    Returns:
        The fully processed PIL frame.
    """
    # 1. Input
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="01", psave=psave)

    # 2. Color temperature
    frame_pil = adjust_color_temperature(
        frame_pil,
        target_temp=target_temp,
        reference_temp=reference_temp,
        strength=temp_strength
    )
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="02", psave=psave)

    # 3. Softened color-cast neutralization
    frame_pil = neutralize_color_cast(frame_pil, strength=0.6)
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="03", psave=psave)

    # 4. Tone mapping (soft)
    frame_pil = soft_tone_map(frame_pil)
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="04", psave=psave)

    # 5. Adaptive cleanup + micro boost
    frame_pil = apply_post_processing_adaptive(
        frame_pil,
        blur_radius=blur_radius,
        contrast=contrast,
        vibrance_strength=0.22,  # slightly reduced
        sharpen=False,
        clamp_r=True
    )
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="05", psave=psave)

    # 6. Stylization
    if unreal:
        frame_pil = apply_post_processing_unreal_cinematic(frame_pil)
        frame_pil = smooth_edges(frame_pil, strength=0.30, blur_radius=0.8)  # less destructive
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="06", psave=psave)
    elif cartoon:
        frame_pil = apply_post_processing_sketch(frame_pil)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="07", psave=psave)

    # 7. Intelligent glow (rebalanced): edges prioritized, glow on bright areas
    frame_pil = apply_intelligent_glow_pro(
        frame_pil,
        strength=0.18,       # less aggressive
        edge_weight=0.6,
        luminance_weight=0.8
    )

    # Final micro-contrast — applied after the glow on purpose.
    from PIL import ImageEnhance
    frame_pil = ImageEnhance.Contrast(frame_pil).enhance(1.04)

    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="09", psave=psave)

    return frame_pil
# --------------------------------------------------------------
# n3rProtoBoost.py - AnimateDiff ultra-light, ~2 GB VRAM
# Prompt / Input -> N3RModelOptimized -> MotionModule -> UNet -> LoRA -> VAE -> Image / Video
# With use_mini_gpu + generate_latents_mini_gpu_320 -> ~2.1 GB VRAM, ultra light.
# With use_n3r_model + N3RModelOptimized -> ~3.6 GB VRAM, heavier but reasonable.
# --------------------------------------------------------------
import argparse
import math
import os
import threading
from datetime import datetime
from pathlib import Path

import torch
from PIL import Image, ImageFilter
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm

from diffusers import PNDMScheduler
from transformers import CLIPTokenizerFast, CLIPTextModel

from scripts.utils.config_loader import load_config
from scripts.utils.lora_utils import apply_lora_smart
from scripts.utils.motion_utils import load_motion_module
from scripts.utils.n3r_utils import (
    load_images_test,
    generate_latents_mini_gpu_320,
    run_diffusion_pipeline,
    generate_latents_robuste_4D,
)
from scripts.utils.tools_utils import (
    ensure_4_channels,
    print_generation_params,
    sanitize_latents,
    stabilize_latents_advanced,
    log_debug,
    compute_overlap,
    get_interpolated_embeddings,
)
# FIX: `adaptive_post_process` was imported twice in the original list.
from scripts.utils.fx_utils import (
    encode_images_to_latents_nuanced,
    decode_latents_ultrasafe_blockwise,
    adaptive_post_process,
    save_frames_as_video_from_folder,
    encode_images_to_latents_safe,
    apply_post_processing_adaptive,
    encode_images_to_latents_hybrid,
    interpolate_param_fast,
    fuse_n3r_latents_adaptive,
    apply_post_processing_unreal_smooth_pro,
    apply_post_processing_cinematic_ultra_refined_pro,
    remove_white_noise,
    apply_post_processing,
)
from scripts.utils.vae_config import load_vae
from scripts.utils.vae_utils import safe_load_unet
from scripts.utils.n3rModelFast4Go import N3RModelFast4GB, N3RModelLazyCPU, N3RModelOptimized

# Scale factor used by Stable Diffusion VAEs when mapping images <-> latents.
LATENT_SCALE = 0.18215
# Set to True by the stop-watcher thread to request a clean shutdown.
stop_generation = False
# --- Parameter interpolation targets (used with interpolate_param_fast) ---
# Start values (faithful to the input image) come from the config; the end
# values below give more creativity / less input influence at the end of a run.
init_image_scale_end = 0.9
guidance_scale_end = 4.0
creative_noise_end = 0.0


# -------------------------------------------------------------------------
# --- Simple per-frame prompt-embedding selection ---
def get_embeddings_for_frame(frame_idx, frames_per_prompt, pos_list, neg_list, device="cuda"):
    """Return the (positive, negative) embeddings for `frame_idx`.

    Each prompt drives `frames_per_prompt` consecutive frames; the last
    prompt is reused for any overflow frames.
    """
    num_prompts = len(pos_list)
    # FIX: guard frames_per_prompt <= 0 (would raise ZeroDivisionError).
    prompt_idx = min(frame_idx // max(frames_per_prompt, 1), num_prompts - 1)
    return pos_list[prompt_idx].to(device), neg_list[prompt_idx].to(device)


# ---------------- Stop-watcher thread ----------------
def wait_for_stop():
    """Block on stdin; typing '²' then Enter requests a clean stop."""
    global stop_generation
    inp = input("Appuyez sur '²' + Entrée pour arrêter : ")
    if inp.lower() == "²":
        stop_generation = True
threading.Thread(target=wait_for_stop, daemon=True).start()

# ---------------- Helpers ----------------

def apply_motion_safe(latents, motion_module, threshold=1e-3):
    """Apply the motion module unless the latents are essentially all-zero.

    Returns:
        (latents, applied): the (possibly transformed) latents and a flag
        telling whether the motion module was actually applied.
    """
    if latents.abs().max() < threshold:
        return latents, False
    return motion_module(latents), True

def adapt_embeddings_to_unet(pos_embeds, neg_embeds, target_dim):
    """Truncate or zero-pad text embeddings to match the UNet cross_attention_dim."""
    current_dim = pos_embeds.shape[-1]
    if current_dim == target_dim:
        return pos_embeds, neg_embeds
    # Truncate
    if current_dim > target_dim:
        pos_embeds = pos_embeds[..., :target_dim]
        neg_embeds = neg_embeds[..., :target_dim]
    # Pad with zeros
    elif current_dim < target_dim:
        pad = target_dim - current_dim
        pos_embeds = torch.nn.functional.pad(pos_embeds, (0, pad))
        neg_embeds = torch.nn.functional.pad(neg_embeds, (0, pad))
    return pos_embeds, neg_embeds
# ---------------- MAIN ----------------
def main(args):
    """Run the full ProtoBoost generation pipeline.

    Loads config/UNet/LoRA/motion/VAE/text encoder, pre-computes prompt
    embeddings, then generates frames per input image (with optional N3R
    fusion or mini-GPU diffusion, transition frames between images), and
    finally assembles the saved frames into a video.

    FIXES vs. original:
      * the `previous_latent_single` / `frame_counter` / `pbar` init block
        was duplicated, creating a stray second tqdm bar;
      * `t / (transition_frames - 1)` raised ZeroDivisionError when
        `transition_frames == 1` (the sibling `alpha` line already guarded
        with `max(..., 1)` — now both use the same guard);
      * bare `except:` around the xformers enable narrowed to `Exception`;
      * unused `embeddings = []` local removed.
    """
    global stop_generation
    cfg = load_config(args.config)
    device = args.device if torch.cuda.is_available() else "cpu"
    dtype = torch.float16

    use_mini_gpu = cfg.get("use_mini_gpu", True)
    verbose = cfg.get("verbose", False)
    latent_injection = float(cfg.get("latent_injection", 0.5))
    latent_injection = min(max(latent_injection, 0.5), 0.9)  # safe range
    final_latent_scale = cfg.get("final_latent_scale", 1/8)  # 1/8 fast, 1/4 medium, 1/2 slow
    fps = cfg.get("fps", 12)
    upscale_factor = cfg.get("upscale_factor", 1)
    transition_frames = cfg.get("transition_frames", 4)
    num_fraps_per_image = cfg.get("num_fraps_per_image", 2)
    steps = max(cfg.get("steps", 16), 4)
    guidance_scale = cfg.get("guidance_scale", 2.5)      # low = little creativity, ~4.5 moderate
    init_image_scale = cfg.get("init_image_scale", 0.5)  # 0.85-0.95 stays close to the init image
    creative_noise = cfg.get("creative_noise", 0.0)
    latent_scale_boost = cfg.get("latent_scale_boost", 1.0)
    frames_per_prompt = cfg.get("frames_per_prompt", 10)  # frames driven by each prompt
    # Random seed per run
    seed = torch.randint(0, 100000, (1,)).item()

    params = {'fps': fps, 'upscale_factor': upscale_factor, 'num_fraps_per_image': num_fraps_per_image, 'steps': steps, 'guidance_scale': guidance_scale, 'init_image_scale': init_image_scale, 'creative_noise': creative_noise, 'latent_scale_boost': latent_scale_boost, 'final_latent_scale': final_latent_scale, 'seed': seed}
    print_generation_params(params)

    scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear", num_train_timesteps=1000)
    scheduler.set_timesteps(steps, device=device)

    # ---------------- UNet ----------------
    unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=True)
    if hasattr(unet, "enable_attention_slicing"):
        unet.enable_attention_slicing()
    if hasattr(unet, "enable_xformers_memory_efficient_attention"):
        try:
            unet.enable_xformers_memory_efficient_attention(True)
        except Exception:  # best-effort: xformers may be unavailable
            pass

    # ---------------- LoRA ----------------
    n3oray_models = cfg.get("n3oray_models")
    if n3oray_models:
        for model_name, lora_path in n3oray_models.items():
            applied = apply_lora_smart(unet, lora_path, alpha=0.5, device=device, verbose=verbose)
            if not applied:
                print(f"⚠ LoRA '{model_name}' ignorée (incompatible UNet)")
    else:
        print("⚠ Aucun modèle LoRA configuré, étape ignorée.")

    # ---------------- Motion module ----------------
    motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None
    if motion_module and verbose:
        print(f"[INFO] motion_module type: {type(motion_module)}")

    # ---------------- Tokenizer / text encoder ----------------
    tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer"))
    text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device).to(dtype)

    # ---------------- VAE ----------------
    vae_path = cfg.get("vae_path")
    vae, vae_type, latent_channels, LATENT_SCALE = load_vae(vae_path, device=device, dtype=dtype)

    # ---------------- Embeddings ----------------
    prompts = cfg.get("prompt", [])
    negative_prompts = cfg.get("n_prompt", [])
    unet_cross_attention_dim = getattr(unet.config, "cross_attention_dim", 1024)

    # Adaptive projection: probe the encoder's hidden size once and build a
    # linear projection only if it differs from the UNet's expected dim.
    text_inputs_sample = tokenizer("test", padding="max_length", truncation=True,
                                   max_length=tokenizer.model_max_length, return_tensors="pt")
    with torch.no_grad():
        sample_embeds = text_encoder(text_inputs_sample.input_ids.to(device)).last_hidden_state
    current_dim = sample_embeds.shape[-1]
    projection = None
    if current_dim != unet_cross_attention_dim:
        projection = torch.nn.Linear(current_dim, unet_cross_attention_dim).to(device).to(dtype)

    # Pre-compute embeddings for interpolation; prompts may be strings or
    # lists of strings.
    pos_embeds_list = []
    neg_embeds_list = []
    for i, prompt_item in enumerate(prompts):
        prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item)
        neg_text_item = negative_prompts[i] if i < len(negative_prompts) else negative_prompts[0]
        neg_text = " ".join(neg_text_item) if isinstance(neg_text_item, list) else str(neg_text_item)

        text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True,
                                max_length=tokenizer.model_max_length, return_tensors="pt")
        neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True,
                               max_length=tokenizer.model_max_length, return_tensors="pt")

        with torch.no_grad():
            pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
            neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state

        if projection is not None:
            pos_embeds = projection(pos_embeds)
            neg_embeds = projection(neg_embeds)

        pos_embeds_list.append(pos_embeds)
        neg_embeds_list.append(neg_embeds)

    # ---------------- N3RModelOptimized ----------------
    use_n3r_model = cfg.get("use_n3r_model", False)
    n3r_model = None
    if use_n3r_model:
        n3r_model = N3RModelOptimized(
            L_low=cfg.get("n3r_L_low", 3),
            L_high=cfg.get("n3r_L_high", 6),
            N_samples=cfg.get("n3r_N_samples", 32),
            tile_size=cfg.get("n3r_tile_size", 64),
            cpu_offload=cfg.get("n3r_cpu_offload", True)
        ).to(device)
        n3r_model.eval()
        print(f"✅ N3RModelOptimized initialisé sur {device}")

    # ---------------- Input images / output layout ----------------
    input_paths = cfg.get("input_images") or [cfg.get("input_image")]
    total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(f"./outputs/ProtoBoost{timestamp}")
    output_dir.mkdir(parents=True, exist_ok=True)
    out_video = output_dir / f"output_{timestamp}.mp4"
    block_size = cfg.get("block_size", 160)
    overlap = compute_overlap(cfg["W"], cfg["H"], block_size)

    previous_latent_single = None
    frame_counter = 0
    pbar = tqdm(total=total_frames, ncols=120)

    # ---------------- Main frames with prompt interpolation ----------------
    for img_idx, img_path in enumerate(input_paths):
        if stop_generation:
            break
        try:
            # Interpolated parameters for this frame
            current_init_image_scale = interpolate_param_fast(init_image_scale, init_image_scale_end, frame_counter, total_frames, mode="cosine")
            current_guidance_scale = interpolate_param_fast(guidance_scale, guidance_scale_end, frame_counter, total_frames, mode="cosine")
            current_creative_noise = interpolate_param_fast(creative_noise, creative_noise_end, frame_counter, total_frames, mode="cosine")
            print(f"[Frame {frame_counter:03d}] " f"init_image_scale={current_init_image_scale:.3f}, " f"guidance_scale={current_guidance_scale:.3f}, " f"creative_noise={current_creative_noise:.3f}")

            # Load and encode the image on GPU
            input_image = load_images_test([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype)
            input_image = ensure_4_channels(input_image)

            current_latent_single = encode_images_to_latents_hybrid(input_image, vae, device=device, latent_scale=LATENT_SCALE)
            current_latent_single = torch.nn.functional.interpolate(
                current_latent_single, size=(cfg["H"] // 8, cfg["W"] // 8),
                mode='bilinear', align_corners=False
            )

            # NaN / stability fix
            current_latent_single = sanitize_latents(current_latent_single)

            # Robust initial generation
            pos_embeds, neg_embeds = get_interpolated_embeddings(frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device)
            try:
                current_latent_single = generate_latents_robuste_4D(
                    latents=current_latent_single.to(device),
                    pos_embeds=pos_embeds,
                    neg_embeds=neg_embeds,
                    unet=unet,
                    scheduler=scheduler,
                    motion_module=None,
                    device=device,
                    dtype=dtype,
                    guidance_scale=current_guidance_scale,
                    init_image_scale=current_init_image_scale,
                    creative_noise=current_creative_noise,
                    seed=seed
                )
                current_latent_single = sanitize_latents(current_latent_single)
            except Exception as e:
                print(f"[Robuste INIT ERROR] {e}")

            current_latent_single = ensure_4_channels(current_latent_single)
            current_latent_single = current_latent_single.to('cpu')
            del input_image
            torch.cuda.empty_cache()

            # ---------------- Transition frames ----------------
            if previous_latent_single is not None and transition_frames > 0:
                # FIX: shared guarded denominator (transition_frames == 1 used
                # to raise ZeroDivisionError in the injection_alpha line).
                denom = max(transition_frames - 1, 1)
                for t in range(transition_frames):
                    if stop_generation:
                        break
                    alpha = 0.5 - 0.5 * math.cos(math.pi * t / denom)
                    with torch.no_grad():
                        # Adaptive fusion: the previous frame's influence
                        # decreases progressively across the transition.
                        injection_start = 0.8  # initial influence of the old frame
                        injection_end = 0.1    # final influence
                        injection_alpha = injection_start * (1 - t / denom) + injection_end * (t / denom)

                        latent_interp = injection_alpha * previous_latent_single.to(device) + (1 - injection_alpha) * current_latent_single.to(device)
                        latent_interp = sanitize_latents(latent_interp)

                        if motion_module:
                            latent_interp, _ = apply_motion_safe(latent_interp, motion_module)

                        # Streaming decode (rescale before decoding)
                        latent_interp = latent_interp / LATENT_SCALE
                        frame_pil = decode_latents_ultrasafe_blockwise(latent_interp, vae, block_size=block_size, overlap=overlap, gamma=1.0, brightness=1.0, contrast=1.0, saturation=1.0, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost)
                        frame_pil = apply_post_processing_adaptive(frame_pil, blur_radius=0.05, contrast=1.05, brightness=1.05, saturation=0.80, vibrance_base=1.0, vibrance_max=1.1, sharpen=True, sharpen_radius=1, sharpen_percent=60, sharpen_threshold=2)
                        print(f"[ init SAVE Frame {frame_counter:03d}]")
                        frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                        frame_counter += 1
                        pbar.update(1)

                        del latent_interp
                        torch.cuda.empty_cache()

            # ---------------- Main frames ----------------
            for f in range(num_fraps_per_image):
                if stop_generation:
                    break
                with torch.no_grad():
                    latents_frame = current_latent_single.to(device)

                    # Interpolated prompt embeddings for this frame
                    cf_embeds = get_interpolated_embeddings(frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device)

                    # --- N3R or mini-GPU diffusion ---
                    n3r_latents = None
                    latents = latents_frame.clone()
                    latents = sanitize_latents(latents)

                    # N3R only every third frame to bound its cost
                    use_n3r_this_frame = use_n3r_model and (frame_counter % 3 == 0)

                    if use_n3r_this_frame:
                        try:
                            H, W = cfg["H"], cfg["W"]
                            N_samples = n3r_model.N_samples

                            # Normalized coordinates in [-1, 1]
                            ys = torch.linspace(-1.0, 1.0, H, device=device)
                            xs = torch.linspace(-1.0, 1.0, W, device=device)
                            ss = torch.arange(N_samples, device=device, dtype=torch.float32)
                            ys, xs, ss = torch.meshgrid(ys, xs, ss, indexing='ij')
                            coords = torch.stack([xs, ys, ss], dim=-1).reshape(-1, 3)

                            # Temporal variation + jitter (seeded for reproducibility)
                            noise_scale = 0.01 + 0.02 * math.sin(frame_counter * 0.1)
                            torch.manual_seed(seed)
                            jitter = (torch.rand_like(coords) - 0.5) * 0.02
                            coords = coords + jitter + torch.randn_like(coords) * noise_scale
                            coords = torch.nan_to_num(coords)

                            # N3R forward: average over the sample axis
                            n3r_latents_raw = n3r_model(coords, H, W)[:, :3]
                            n3r_latents = n3r_latents_raw.view(H, W, N_samples, 3).mean(dim=2)
                            n3r_latents = n3r_latents.permute(2, 0, 1).unsqueeze(0)

                            # Add an alpha channel if needed
                            if n3r_latents.shape[1] == 3:
                                n3r_latents = torch.cat([n3r_latents, torch.zeros_like(n3r_latents[:, :1, :, :])], dim=1)

                            # Resize to the latent grid if needed
                            target_H, target_W = latents.shape[-2], latents.shape[-1]
                            if n3r_latents.shape[-2:] != (target_H, target_W):
                                n3r_latents = torch.nn.functional.interpolate(
                                    n3r_latents, size=(target_H, target_W),
                                    mode='bilinear', align_corners=False
                                ).contiguous()

                            # Adaptive fusion
                            n3r_latents = torch.clamp(n3r_latents, -1.0, 1.0)
                            n3r_latents = torch.nan_to_num(n3r_latents)
                            latents = fuse_n3r_latents_adaptive(
                                latents,
                                n3r_latents,
                                latent_injection=latent_injection,
                                clamp_val=1.0,
                                creative_noise=0.0
                            )

                            # Final cleanup
                            latents = sanitize_latents(latents)
                        except Exception as e:
                            print(f"[N3R ERROR] {e}")

                    elif use_mini_gpu:
                        latents = generate_latents_mini_gpu_320(
                            unet=unet, scheduler=scheduler,
                            input_latents=latents_frame, embeddings=cf_embeds,
                            motion_module=motion_module, guidance_scale=current_guidance_scale,
                            device=device, fp16=True, steps=steps,
                            debug=verbose, init_image_scale=current_init_image_scale,
                            creative_noise=current_creative_noise
                        )
                        if latent_injection > 0:
                            if latents.shape[-2:] != latents_frame.shape[-2:]:
                                latents = torch.nn.functional.interpolate(
                                    latents,
                                    size=latents_frame.shape[-2:],
                                    mode='bilinear', align_corners=False
                                ).contiguous()
                            latents = latent_injection * latents_frame + (1 - latent_injection) * latents

                    # --- Motion module (expects [B, C, F, H, W]; F=1 here) ---
                    if motion_module is not None:
                        latents = sanitize_latents(latents)
                        latents = latents.unsqueeze(2)
                        latents = motion_module(latents)
                        latents = latents.squeeze(2)
                        latents = sanitize_latents(latents)

                    # No blending here — just remember the last latents
                    previous_latent_single = latents.detach().cpu()

                    # Rescale and decode blockwise
                    latents = latents / LATENT_SCALE
                    frame_pil = decode_latents_ultrasafe_blockwise(latents, vae, block_size=block_size, overlap=overlap, gamma=1.0, brightness=1.0, contrast=1.0, saturation=1.0, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost)
                    frame_pil = apply_post_processing_adaptive(frame_pil, blur_radius=0.05, contrast=1.15, brightness=1.05, saturation=0.85, vibrance_base=1.1, vibrance_max=1.2, sharpen=True, sharpen_radius=1, sharpen_percent=60, sharpen_threshold=2)
                    frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                    frame_counter += 1
                    pbar.update(1)

                    # VRAM cleanup
                    del latents, latents_frame, cf_embeds, n3r_latents
                    torch.cuda.empty_cache()

            previous_latent_single = current_latent_single

        except Exception as e:
            print(f"[FRAME ERROR] {img_path} : {e}")
            continue

    pbar.close()
    save_frames_as_video_from_folder(output_dir, out_video, fps=fps, upscale_factor=upscale_factor)
    print(f"🎬 Vidéo générée : {out_video}")

# ---------------- ENTRY ----------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--fp16", action="store_true", default=True)
    parser.add_argument("--vae-offload", action="store_true")
    args = parser.parse_args()
    main(args)
LATENT_SCALE = 0.18215
stop_generation = False


# ---------------- Adaptive N3R/VAE fusion ----------------

def fuse_n3r_latents_adaptive(latents_frame, n3r_latents, latent_injection=0.7, clamp_val=1.0, creative_noise=0.0):
    """Blend N3R latents into the frame latents with per-channel statistics matching.

    Each channel of ``n3r_latents`` is re-normalised to the mean/std of the
    corresponding channel of ``latents_frame`` (prevents colour/contrast
    artifacts), optionally perturbed with Gaussian noise, then linearly mixed:
    ``latent_injection * frame + (1 - latent_injection) * n3r``.

    Args:
        latents_frame: reference latents, shape (B, C, H, W).
        n3r_latents: latents produced by the N3R model, same shape.
        latent_injection: weight of ``latents_frame`` in the final mix (0..1).
        clamp_val: symmetric clamp applied to inputs and output.
        creative_noise: stddev of optional Gaussian noise added to the
            normalised N3R latents (0 disables it).

    Returns:
        The fused latents, clamped to ``[-clamp_val, clamp_val]``.
    """
    n3r_latents = n3r_latents.clone()

    # Channel-wise normalisation: match every N3R channel to the statistics of
    # the corresponding frame channel.  Generalised from a hard-coded
    # ``range(4)`` to the actual channel count — identical behaviour for the
    # usual 4-channel latents, but no longer breaks on other channel counts.
    for c in range(n3r_latents.shape[1]):
        n3r_c = n3r_latents[:, c:c+1, :, :]
        frame_c = latents_frame[:, c:c+1, :, :]
        mean, std = n3r_c.mean(), n3r_c.std()
        n3r_c = (n3r_c - mean) / (std + 1e-6)   # centre and scale to unit std
        n3r_c = n3r_c * frame_c.std() + frame_c.mean()
        n3r_latents[:, c:c+1, :, :] = n3r_c

    # Optional light creative noise.
    if creative_noise > 0.0:
        noise = torch.randn_like(n3r_latents) * creative_noise
        n3r_latents += noise

    # Strict clamping to avoid overflow further down the pipeline.
    n3r_latents = torch.clamp(n3r_latents, -clamp_val, clamp_val)
    latents_frame = torch.clamp(latents_frame, -clamp_val, clamp_val)

    # Final fusion.
    fused_latents = latent_injection * latents_frame + (1 - latent_injection) * n3r_latents
    fused_latents = torch.clamp(fused_latents, -clamp_val, clamp_val)

    print(f"[N3R fusion frame] mean/std par canal: {fused_latents.mean(dim=(2,3))}, injection={latent_injection:.2f}")
    return fused_latents


# ---------------- DEBUG UTILS ----------------
def log_debug(message, level="INFO", verbose=True):
    """Print ``message`` prefixed with ``level`` when ``verbose`` is True.

    level: "INFO", "DEBUG", "WARNING"
    """
    if verbose:
        print(f"[{level}] {message}")


# ---------------- Thread stop ----------------
def wait_for_stop():
    """Background watcher: set the global stop flag when the user types '²'.

    ``input()`` is guarded so the daemon thread exits quietly when stdin is
    closed or unavailable (piped/headless runs) instead of dying with an
    unhandled ``EOFError``.
    """
    global stop_generation
    try:
        inp = input("Appuyez sur '²' + Entrée pour arrêter : ")
    except EOFError:
        return
    if inp.lower() == "²":
        stop_generation = True
threading.Thread(target=wait_for_stop, daemon=True).start()

# ---------------- Utilities ----------------

def compute_overlap(W, H, block_size, max_overlap_ratio=0.6):
    """Overlap (in latent pixels) for blockwise decoding: a ratio of the block
    size, capped at a quarter of the smallest image dimension."""
    overlap = int(block_size * max_overlap_ratio)
    return min(overlap, min(W, H) // 4)

def apply_motion_safe(latents, motion_module, threshold=1e-3):
    """Run ``motion_module`` on ``latents`` unless they are near-zero.

    Returns ``(latents, applied_flag)``; near-empty latents are passed through
    untouched so the motion module never sees degenerate input.
    """
    if latents.abs().max() < threshold:
        return latents, False
    return motion_module(latents), True

def adapt_embeddings_to_unet(pos_embeds, neg_embeds, target_dim):
    """Truncate or zero-pad text embeddings so their last dimension matches the
    UNet's ``cross_attention_dim`` (avoids dimension-mismatch errors)."""
    current_dim = pos_embeds.shape[-1]
    if current_dim == target_dim:
        return pos_embeds, neg_embeds
    if current_dim > target_dim:
        # Truncate
        pos_embeds = pos_embeds[..., :target_dim]
        neg_embeds = neg_embeds[..., :target_dim]
    else:
        # Zero-pad
        pad = target_dim - current_dim
        pos_embeds = torch.nn.functional.pad(pos_embeds, (0, pad))
        neg_embeds = torch.nn.functional.pad(neg_embeds, (0, pad))
    return pos_embeds, neg_embeds

# ---------------- MAIN ----------------
def main(args):
    """Entry point: run the VRAM-friendly AnimateDiff pipeline described by
    ``args.config`` (body continues below this block)."""
    global stop_generation
    cfg = load_config(args.config)
    device = args.device if torch.cuda.is_available() else "cpu"
    dtype = torch.float16

    use_mini_gpu = cfg.get("use_mini_gpu", True)
    verbose = cfg.get("verbose", False)
    # Keep the injection weight inside [0, 1].
    latent_injection = max(0.0, min(1.0, cfg.get("latent_injection", 0.7)))
    final_latent_scale = cfg.get("final_latent_scale", 1/8)
    # 1/8 = fast, 1/4 = medium, 1/2 = slow (final latent downscale factor)
    fps = cfg.get("fps", 12)
    upscale_factor = cfg.get("upscale_factor", 1)
    transition_frames = cfg.get("transition_frames", 4)
    # NOTE(review): the bundled YAML configs define `num_frames_per_image`;
    # this reads `num_fraps_per_image`, so the config value is silently ignored
    # and the default of 2 is always used — confirm which spelling is intended.
    num_fraps_per_image = cfg.get("num_fraps_per_image", 2)
    steps = max(cfg.get("steps", 16), 4)  # enforce at least 4 denoising steps
    guidance_scale = cfg.get("guidance_scale", 4.5)
    init_image_scale = cfg.get("init_image_scale", 0.85)
    creative_noise = cfg.get("creative_noise", 0.0)
    latent_scale_boost = cfg.get("latent_scale_boost", 5.71)

    # Log the effective generation parameters.
    print("📌 Paramètres de génération :")
    print(f" fps : {fps}")
    print(f" upscale_factor : {upscale_factor}")
    print(f" num_fraps_per_image : {num_fraps_per_image}")
    print(f" steps : {steps}")
    print(f" guidance_scale : {guidance_scale}")
    print(f" init_image_scale : {init_image_scale}")
    print(f" creative_noise : {creative_noise}")
    print(f" latent_scale_boost : {latent_scale_boost}")
    print(f" final_latent_scale : {final_latent_scale}")

    scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear", num_train_timesteps=1000)
    scheduler.set_timesteps(steps, device=device)

    # ---------------- UNET ----------------
    unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=True)
    if hasattr(unet, "enable_attention_slicing"): unet.enable_attention_slicing()
    if hasattr(unet, "enable_xformers_memory_efficient_attention"):
        # Best-effort: xformers may be absent or incompatible, failure is ignored.
        try: unet.enable_xformers_memory_efficient_attention(True)
        except: pass

    # ---------------- LoRA ----------------
    n3oray_models = cfg.get("n3oray_models")
    if n3oray_models:
        for model_name, lora_path in n3oray_models.items():
            applied = apply_lora_smart(unet, lora_path, alpha=0.5, device=device, verbose=verbose)
            if not applied: print(f"⚠ LoRA '{model_name}' ignorée (incompatible UNet)")
    else:
        print("⚠ Aucun modèle LoRA configuré, étape ignorée.")

    # ---------------- Motion module ----------------
    motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None
    if motion_module and verbose:
        print(f"[INFO] motion_module type: {type(motion_module)}")

    # ---------------- Tokenizer / Text encoder ----------------
    tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path,"tokenizer"))
    text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path,"text_encoder")).to(device).to(dtype)

    # ---------------- VAE ----------------
    vae_path = cfg.get("vae_path")
    # NOTE(review): this rebinds the module-level name LATENT_SCALE locally with
    # whatever scale load_vae returns for the chosen VAE.
    vae, vae_type, latent_channels, LATENT_SCALE = load_vae(vae_path, device=device, dtype=dtype)

    # ---------------- Embeddings ----------------
    embeddings = []
    prompts = cfg.get("prompt", [])
    negative_prompts = cfg.get("n_prompt", [])
    unet_cross_attention_dim = getattr(unet.config, "cross_attention_dim", 1024)

    # --- Adaptive projection: probe the text encoder's hidden size with a
    # dummy prompt and, if it differs from the UNet's cross-attention dim,
    # build a linear projection to bridge the two.
    text_inputs_sample = tokenizer("test", padding="max_length", truncation=True,
                                   max_length=tokenizer.model_max_length, return_tensors="pt")
    with torch.no_grad():
        sample_embeds = text_encoder(text_inputs_sample.input_ids.to(device)).last_hidden_state
    current_dim = sample_embeds.shape[-1]
    projection = None
    if current_dim != unet_cross_attention_dim:
        # NOTE(review): freshly initialised (untrained) projection weights —
        # confirm this is intentional rather than a loaded adapter.
        projection = torch.nn.Linear(current_dim, unet_cross_attention_dim).to(device).to(dtype)

    # --- Pre-compute one (positive, negative) embedding pair per prompt.
    for prompt_item in prompts:
        prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item)
        neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts)

        text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True,
                                max_length=tokenizer.model_max_length, return_tensors="pt")
        neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True,
                               max_length=tokenizer.model_max_length, return_tensors="pt")
        with torch.no_grad():
            pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
            neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state

        if projection is not None:
            pos_embeds = projection(pos_embeds)
            neg_embeds = projection(neg_embeds)

        embeddings.append((pos_embeds, neg_embeds))

    # ---------------- N3RModelOptimized ----------------
    use_n3r_model = cfg.get("use_n3r_model", False)
    n3r_model = None
    if use_n3r_model:
        n3r_model = N3RModelOptimized(
            L_low=cfg.get("n3r_L_low",3),
            L_high=cfg.get("n3r_L_high",6),
            N_samples=cfg.get("n3r_N_samples",32),
            tile_size=cfg.get("n3r_tile_size",64),
            cpu_offload=cfg.get("n3r_cpu_offload",True)
        ).to(device)
        n3r_model.eval()
        print(f"✅ N3RModelOptimized initialisé sur {device}")

    # ---------------- Input images ----------------
    input_paths = cfg.get("input_images") or [cfg.get("input_image")]
    total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(f"./outputs/ProtoHybrid_{timestamp}")
    output_dir.mkdir(parents=True, exist_ok=True)
    out_video = output_dir / f"output_{timestamp}.mp4"
    block_size = cfg.get("block_size", 160)
    overlap = compute_overlap(cfg["W"], cfg["H"], block_size)

    previous_latent_single = None
    frame_counter = 0
    pbar = tqdm(total=total_frames, ncols=120)

    # ---------------- Main VRAM-safe frame loop ----------------
    # NOTE(review): the three lines below duplicate the initialisation just
    # above — a second tqdm bar is created and the first is never closed.
    previous_latent_single = None
    frame_counter = 0
    pbar = tqdm(total=total_frames, ncols=120)

    for img_idx, img_path in enumerate(input_paths):
        if stop_generation: break
        try:
            # Load and encode the image on GPU, then move to CPU for storage.
            input_image = load_images_test([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype)
            input_image = ensure_4_channels(input_image)
            current_latent_single = encode_images_to_latents_safe(input_image, vae, device=device, latent_scale=LATENT_SCALE)
            current_latent_single = torch.nn.functional.interpolate(
                current_latent_single, size=(cfg["H"]//8, cfg["W"]//8),
                mode='bilinear', align_corners=False
            )

            # ------------------- NEW -------------------
            # Robust generation of the initial latents (optional; reduced
            # settings to keep VRAM usage low).
            try:
                current_latent_single = generate_latents_robuste_4D(
                    latents=current_latent_single.to(device),
                    pos_embeds=None,  # no prompt yet / neutral embeddings
                    neg_embeds=None,
                    unet=unet,
                    scheduler=scheduler,
                    motion_module=None,
                    device=device,
                    dtype=dtype,
                    guidance_scale=1.5,  # safer value for initialisation
                    init_image_scale=0.85,  # not 1.0, to avoid amplifying NaNs
                    creative_noise=0.0,  # disabled for the init pass
                    seed=42
                )
            except Exception as e:
                print(f"[Robuste INIT ERROR] {e}")
            # -----------------------------------------------
            current_latent_single = ensure_4_channels(current_latent_single)
            # Move to CPU as soon as possible to free VRAM.
            current_latent_single = current_latent_single.to('cpu')
            del input_image
            torch.cuda.empty_cache()

            # ---------------- Transition frames ----------------
            # Cosine-eased interpolation between the previous and current
            # image latents.
            if previous_latent_single is not None and transition_frames > 0:
                for t in range(transition_frames):
                    if stop_generation: break
                    alpha = 0.5 - 0.5*math.cos(math.pi*t/max(transition_frames-1,1))

                    # Interpolate on GPU.
                    with torch.no_grad():
                        latent_interp = (1-alpha)*previous_latent_single.to(device) + alpha*current_latent_single.to(device)
                        latent_interp = torch.clamp(latent_interp, -1.0, 1.0).contiguous()

                    if motion_module:
                        latent_interp, _ = apply_motion_safe(latent_interp, motion_module)

                    # Resize for the VAE.
                    final_H, final_W = int(cfg["H"]*final_latent_scale), int(cfg["W"]*final_latent_scale)
                    if latent_interp.shape[-2:] != (final_H, final_W):
                        latent_interp = torch.nn.functional.interpolate(
                            latent_interp, size=(final_H, final_W),
                            mode='bilinear', align_corners=False
                        ).contiguous()

                    # Streaming blockwise decode towards CPU.
                    frame_pil = decode_latents_ultrasafe_blockwise(
                        latent_interp, vae,
                        block_size=block_size, overlap=overlap,
                        gamma=1.0, brightness=1.0,
                        contrast=1.5, saturation=1.3,
                        device=device,  # GPU only per block; final frame goes to CPU
                        frame_counter=frame_counter,
                        latent_scale_boost=latent_scale_boost
                    )

                    frame_pil = apply_post_processing(frame_pil, blur_radius=0.2)
                    frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                    frame_counter += 1
                    pbar.update(1)

                    del latent_interp
                    torch.cuda.empty_cache()

            # ---------------- Main frames ----------------
            for pos_embeds, neg_embeds in embeddings:
                for f in range(num_fraps_per_image):
                    if stop_generation: break

                    with torch.no_grad():
                        # Base latent on GPU for this frame.
                        latents_frame = current_latent_single.to(device)
                        cf_embeds = (pos_embeds.to(device), neg_embeds.to(device))
                        latents = latents_frame.clone()
                        n3r_latents = None

                        # --- N3R path ---
                        if use_n3r_model:
                            try:
                                H, W = cfg["H"], cfg["W"]
                                # Dense (x, y, sample) coordinate grid for the
                                # implicit N3R model.
                                ys, xs, ss = torch.meshgrid(
                                    torch.arange(H, device=device),
                                    torch.arange(W, device=device),
                                    torch.arange(n3r_model.N_samples, device=device),
                                    indexing='ij'
                                )
                                coords = torch.stack([xs, ys, ss.float()], dim=-1).reshape(-1,3).float()
                                n3r_latents_raw = n3r_model(coords, H, W)[:, :3]
                                # Average over the sample axis, then to BCHW.
                                n3r_latents = n3r_latents_raw.view(H, W, n3r_model.N_samples, 3).mean(dim=2)
                                n3r_latents = n3r_latents.permute(2,0,1).unsqueeze(0)
                                if n3r_latents.shape[1] == 3:
                                    # Pad a zero 4th channel to match VAE latents.
                                    n3r_latents = torch.cat([n3r_latents, torch.zeros_like(n3r_latents[:, :1, :, :])], dim=1)
                                # Resize and clamp to the working latent size.
                                target_H, target_W = latents.shape[-2], latents.shape[-1]
                                if n3r_latents.shape[-2:] != (target_H, target_W):
                                    n3r_latents = torch.nn.functional.interpolate(
                                        n3r_latents, size=(target_H, target_W),
                                        mode='bilinear', align_corners=False
                                    ).contiguous()
                                n3r_latents = torch.clamp(n3r_latents, -1.0, 1.0)
                                latents = fuse_n3r_latents_adaptive(latents, n3r_latents, latent_injection=latent_injection)
                            except Exception as e:
                                print(f"[N3R ERROR] {e}")

                        # --- Mini GPU diffusion path ---
                        elif use_mini_gpu:
                            latents = generate_latents_mini_gpu_320(
                                unet=unet, scheduler=scheduler,
                                input_latents=latents_frame, embeddings=cf_embeds,
                                motion_module=motion_module, guidance_scale=guidance_scale,
                                device=device, fp16=True, steps=steps,
                                debug=verbose, init_image_scale=init_image_scale,
                                creative_noise=creative_noise
                            )
                            if latent_injection > 0:
                                # ⚡ Match sizes before blending.
                                if latents.shape[-2:] != latents_frame.shape[-2:]:
                                    latents = torch.nn.functional.interpolate(
                                        latents,
                                        size=latents_frame.shape[-2:],
                                        mode='bilinear',
                                        align_corners=False
                                    ).contiguous()
                                latents = latent_injection*latents_frame + (1-latent_injection)*latents

                        # Motion module
                        if motion_module:
                            latents, _ = apply_motion_safe(latents, motion_module)

                        # Final clamp and resize.
                        latents = torch.clamp(latents, -1.0, 1.0)
                        final_H, final_W = int(cfg["H"]*final_latent_scale), int(cfg["W"]*final_latent_scale)
                        if latents.shape[-2:] != (final_H, final_W):
                            latents = torch.nn.functional.interpolate(latents, size=(final_H, final_W),
                                                                      mode='bilinear', align_corners=False).contiguous()

                        # Streaming blockwise decode towards CPU.
                        frame_pil = decode_latents_ultrasafe_blockwise(
                            latents, vae,
                            block_size=block_size, overlap=overlap,
                            gamma=1.0, brightness=1.0,
                            contrast=1.5, saturation=1.3,
                            device=device,
                            frame_counter=frame_counter,
                            latent_scale_boost=latent_scale_boost
                        )
                        frame_pil = apply_post_processing(frame_pil, blur_radius=0.05, contrast=1.15, brightness=1.05, saturation=0.85, sharpen=True, sharpen_radius=1, sharpen_percent=90, sharpen_threshold=2)
                        frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                        frame_counter += 1
                        pbar.update(1)

                    # VRAM cleanup after each frame.
                    del latents, latents_frame, cf_embeds, n3r_latents
                    torch.cuda.empty_cache()

            previous_latent_single = current_latent_single  # stays on CPU

        except Exception as e:
            # Best-effort per-image handling: log and move to the next input.
            print(f"[FRAME ERROR] {img_path} : {e}")
            continue

    pbar.close()
save_frames_as_video_from_folder(output_dir, out_video, fps=fps, upscale_factor=upscale_factor) + print(f"🎬 Vidéo générée : {out_video}") + print("✅ Pipeline terminé avec motion module safe et N3RModelOptimized.") + +# ---------------- ENTRY ---------------- +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rRealControl.py b/scripts/n3rRealControl.py new file mode 100644 index 00000000..23febdfd --- /dev/null +++ b/scripts/n3rRealControl.py @@ -0,0 +1,555 @@ +# ---------------------------------------------------------------------------------------- +# n3rRealControl.py - AnimateDiff stables, ProNet + HDR ultra-light ~2Go VRAM - pipeline 4D +# Prompt / Input → N3RModelOptimized → MotionModule → UNet → LoRA → VAE → Image / Vidéo +#Avec use_mini_gpu et generate_latents_mini_gpu_320 → ~2,1 Go VRAM, ultra léger ✅ Avec use_n3r_model et N3RModelOptimized → ~3,6 Go VRAM +# Image input ↓ OpenPose → skeleton (frame t) ↓ ControlNet (condition pose) ↓ UNet (avec pos/neg embeds) ↓ Latents 4D (animés) ↓ N3RProNet (détails + iris + sharpen) ↓ Decode blockwise ↓ Frames animées +# ---------------------------------------------------------------------------------------- +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64" + +import math, threading, random, json, traceback, hashlib, pickle, argparse +from pathlib import Path +from datetime import datetime + +import torch +import torch.nn.functional as F +import torchvision.transforms as T +from torchvision.transforms.functional import to_pil_image +from tqdm import tqdm +from PIL import Image, 
ImageFilter +from diffusers import PNDMScheduler +from transformers import CLIPTokenizerFast, CLIPTextModel +from scripts.utils.lora_utils import apply_lora_smart +from scripts.utils.vae_config import load_vae +from scripts.utils.n3rModelUtils import generate_n3r_coords, process_n3r_latents, fuse_with_memory, inject_external, fuse_n3r_latents_adaptive_new +from scripts.utils.tools_utils import ensure_4_channels, print_generation_params, sanitize_latents, stabilize_latents_advanced, log_debug, compute_overlap, get_interpolated_embeddings, save_memory, load_memory, load_external_embedding_as_latent, inject_external_embeddings, update_n3r_memory, compute_weighted_params, adapt_embeddings_to_unet, get_dynamic_latent_injection, save_input_frame, apply_motion_safe, encode_prompts_batch +from scripts.utils.config_loader import load_config +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import load_images_test, generate_latents_mini_gpu_320, run_diffusion_pipeline, generate_latents_robuste_4D +from scripts.utils.fx_utils import encode_images_to_latents_nuanced, adaptive_post_process, save_frames_as_video_from_folder, encode_images_to_latents_safe, encode_images_to_latents_hybrid, interpolate_param_fast, fuse_n3r_latents_adaptive, adaptive_post_process, remove_white_noise + +from scripts.utils.vae_utils import safe_load_unet +from scripts.utils.n3rModelFast4Go import N3RModelFast4GB, N3RModelLazyCPU, N3RModelOptimized +from scripts.utils.n3rProNet import N3RProNet +from scripts.utils.n3rProNet_utils import apply_n3r_pro_net, save_frame_verbose, full_frame_postprocess, decode_latents_ultrasafe_blockwise, get_eye_coords_safe, create_volumetrique_mask, create_eye_mask, tensor_to_pil, apply_pro_net_volumetrique, apply_pro_net_with_eyes, get_eye_coords_safe, scale_eye_coords_to_latents, get_coords, get_coords_safe, decode_latents_ultrasafe_blockwise_pro, decode_latents_ultrasafe_blockwise_sharp, decode_latents_ultrasafe_blockwise_natural, 
decode_latents_ultrasafe_blockwise_ultranatural +from scripts.utils.n3rControlNet import create_canny_control, control_to_latent, match_latent_size +# OpenPose : +from scripts.utils.n3rOpenPose_utils import generate_pose_sequence, apply_controlnet_openpose_step, load_controlnet_openpose, load_controlnet_openpose_local, match_latent_size, control_to_latent_safe, build_control_latent_debug, convert_json_to_pose_sequence, debug_pose_visual, save_debug_pose_image, fix_pose_sequence, prepare_controlnet, log_frame_error, apply_controlnet_openpose_step_ultrasafe, apply_openpose_tilewise, controlnet_tile_fn + +LATENT_SCALE = 0.18215 +stop_generation = False +# ---------------- Thread stop ---------------- +def wait_for_stop(): + global stop_generation + inp = input("Appuyez sur '²' + Entrée pour arrêter : ") + if inp.lower() == "²": + stop_generation = True +threading.Thread(target=wait_for_stop, daemon=True).start() + +# ---------------- MAIN FIABLE ---------------- +def main(args): + global stop_generation + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 + # Configurable depuis ton fichier cfg + use_mini_gpu = cfg.get("use_mini_gpu", True) + verbose, psave = cfg.get("verbose", False), cfg.get("psave", False) + latent_injection = float(cfg.get("latent_injection", 0.75)) + latent_injection = min(max(latent_injection, 0.5), 0.9) # plage sûre + final_latent_scale = cfg.get("final_latent_scale", 1/8) # 1/8 speed, 1/4 moyen, 1/2 low + fps, upscale_factor = cfg.get("fps", 12), cfg.get("upscale_factor", 1) + transition_frames, num_fraps_per_image = cfg.get("transition_frames", 4), cfg.get("num_fraps_per_image", 2) + steps = max(cfg.get("steps", 16), 4) + guidance_scale = cfg.get("guidance_scale", 6.5) # 0.15 peut de créativité 4.5 moderé + guidance_scale_end = cfg.get("guidance_scale_end", 7.0) # 0.15 peut de créativité 4.5 moderé + init_image_scale = cfg.get("init_image_scale", 0.75) # 0.85 ou 0.95 proche de 
l'init' (0.75) + init_image_scale_end = cfg.get("init_image_scale_end", 0.9) # 0.85 ou 0.95 proche de l'init' + creative_noise, creative_noise_end = cfg.get("creative_noise", 0.0), cfg.get("creative_noise_end", 0.08) + latent_scale_boost = cfg.get("latent_scale_boost", 1.0) + frames_per_prompt = cfg.get("frames_per_prompt", 20) # nombre de frames par prompt + contrast, blur_radius, sharpen_percent = cfg.get("contrast", 1.15), cfg.get("blur_radius", 0.03), cfg.get("sharpen_percent", 90) # Post Traitement + H, W = cfg.get("H", 512), cfg.get("W", 512) + block_size = min(32, H//2, W//2) # block_size auto selon résolution + overlap = compute_overlap(cfg["W"], cfg["H"], block_size) # overlap = 64 + + + use_n3r_model, use_n3r_pro_net = cfg.get("use_n3r_model", False), cfg.get("use_n3r_pro_net", True) + use_openpose = cfg.get("use_openpose", True) + controlnet_scale = cfg.get("controlnet_scale", 1.0) # typiquement 0.5 → 1.0 + control_strength = cfg.get("control_strength", 1.5) + + n3r_pro_strength = cfg.get("n3r_pro_strength", 0.2) # 0.1, 0.2, 0.3 + target_temp, reference_temp = 7800, 6500 #target_temp = 8000 reference_temp = 6000 (Froid) + facteur = cfg.get("facteur", 4) # 8, 6, 4 + + # Seed aléatoire + seed = torch.randint(0, 100000, (1,)).item() + params = { 'use_mini_gpu': use_mini_gpu, 'fps': fps, 'upscale_factor': upscale_factor, 'num_fraps_per_image': num_fraps_per_image, 'steps': steps, 'guidance_scale': guidance_scale, 'guidance_scale_end': guidance_scale_end, 'init_image_scale': init_image_scale, 'init_image_scale_end': init_image_scale_end, 'creative_noise': creative_noise, 'creative_noise_end': creative_noise_end, 'latent_scale_boost': latent_scale_boost, 'final_latent_scale': final_latent_scale, 'seed': seed, 'latent_injection': latent_injection, 'transition_frames': transition_frames, 'block_size': block_size, 'use_n3r_model': use_n3r_model } + print_generation_params(params) + + scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, + 
beta_schedule="scaled_linear", num_train_timesteps=1000) + scheduler.set_timesteps(steps, device=device) + + # 🔥 FIX CRITIQUE GLOBAL + if hasattr(scheduler, "alphas_cumprod"): + scheduler.alphas_cumprod = scheduler.alphas_cumprod.cpu() + + if hasattr(scheduler, "final_alpha_cumprod"): + scheduler.final_alpha_cumprod = scheduler.final_alpha_cumprod.cpu() + + if hasattr(scheduler, "timesteps") and isinstance(scheduler.timesteps, torch.Tensor): + scheduler.timesteps = scheduler.timesteps.cpu() + + # ---------------- UNET ---------------- + unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=True) + if hasattr(unet, "enable_attention_slicing"): unet.enable_attention_slicing() + if hasattr(unet, "enable_xformers_memory_efficient_attention"): + try: unet.enable_xformers_memory_efficient_attention(True) + except: pass + + # ---------------- LoRA ---------------- + n3oray_models = cfg.get("n3oray_models") + if n3oray_models: + for model_name, lora_path in n3oray_models.items(): + applied = apply_lora_smart(unet, lora_path, alpha=0.5, device=device, verbose=verbose) + if not applied: print(f"⚠ LoRA '{model_name}' ignorée (incompatible UNet)") + else: + print("⚠ Aucun modèle LoRA configuré, étape ignorée.") + #iniy external_latent + external_latent = None + # ---------------- Motion module ---------------- + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + if motion_module and verbose: + print(f"[INFO] motion_module type: {type(motion_module)}") + # ---------------- Tokenizer / Text encoder ---------------- + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path,"tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path,"text_encoder")).to("cpu").to(dtype) + # ---------------- VAE ---------------- + vae_path = cfg.get("vae_path") + vae, vae_type, latent_channels, LATENT_SCALE = load_vae(vae_path, device=device, 
dtype=dtype) + # ---------------- Embeddings ---------------- + embeddings = [] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + unet_cross_attention_dim = getattr(unet.config, "cross_attention_dim", 1024) + + # --- Projection adaptative + text_inputs_sample = tokenizer("test", padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + sample_embeds = text_encoder(text_inputs_sample.input_ids.to("cpu")).last_hidden_state + current_dim = sample_embeds.shape[-1] + projection = None + if current_dim != unet_cross_attention_dim: + projection = torch.nn.Linear(current_dim, unet_cross_attention_dim).to(device).to(dtype) + + # --- Pré-calcul des embeddings pour interpolation + # Appel de la fonction - encode_prompts_batch + pos_embeds_list, neg_embeds_list = encode_prompts_batch( prompts=prompts, negative_prompts=negative_prompts, tokenizer=tokenizer, text_encoder=text_encoder, device="cpu", projection=None) + print(f"Pos embeds shape: {pos_embeds_list[0].shape} | Neg embeds shape: {neg_embeds_list[0].shape}") + + # ---------- Input image ----------------------------------- + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + + # ---------------- load_controlnet_openpose ---------------- + if use_openpose: + controlnet = load_controlnet_openpose_local( device=device, dtype=torch.float16, use_fp16=True, debug=True ) + controlnet, pose_sequence = prepare_controlnet( controlnet, device=device, dtype=dtype ) + + try: + base_dir = Path(__file__).resolve().parent + json_file = base_dir / "json" / "anim2.json" + + with open(json_file, "r") as f: + anim_data = json.load(f) + + print(f"✅ JSON chargé : {json_file}") + pose_sequence = convert_json_to_pose_sequence( anim_data, H=cfg["H"], W=cfg["W"], device=device, dtype=dtype, debug=True) + + if pose_sequence is None: + print("❌ Aucun 
pose_sequence → OpenPose désactivé") + use_openpose = False + else: + # 🔥 Fix interpolation + pose_sequence = fix_pose_sequence( pose_sequence, total_frames=total_frames, device=device, dtype=dtype ) + + except Exception as e: + print(f"[Load Json animation INIT ERROR] {e}") + + # ---------------- N3RModelOptimized ---------------- + n3r_model = None + if use_n3r_model: + n3r_model = N3RModelOptimized( + L_low=cfg.get("n3r_L_low",3), L_high=cfg.get("n3r_L_high",6), + N_samples=cfg.get("n3r_N_samples",32), tile_size=cfg.get("n3r_tile_size",64), + cpu_offload=cfg.get("n3r_cpu_offload",True) + ).to(device) + n3r_model.eval() + print(f"✅ N3RModelOptimized initialisé sur {device}") + + # ------------------- Initialisation mémoire ------------------- + output_dir_m = Path("./outputs") + memory_file = output_dir_m / "n3r_memory" + memory_dict = load_memory(memory_file) + # ---------------- n3r_pro_net ---------------- + n3r_pro_net = None + if use_n3r_pro_net: + n3r_pro_net = N3RProNet(channels=4).to(device).to(dtype) + n3r_pro_net.eval() + print("✅ N3RProNet activé") + + # ---------------- Input ---------------- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/RealControl{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + + previous_latent_single = None + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + # ---------------- Frames principales avec interpolation prompts ---------------- + external_embeddings = None + + # Charger latent externe avant la génération + external_path = "/mnt/62G/huggingface/cyber-fp16/pt/KnxCOmiXNeg.safetensors" + external_latent = load_external_embedding_as_latent( + external_path, (1, 4, cfg["H"]//facteur, cfg["W"]//facteur) + ).to(device) + + for img_idx, img_path in enumerate(input_paths): + if stop_generation: break + try: + # Paramètres interpolés + current_init_image_scale, current_creative_noise, 
current_guidance_scale = compute_weighted_params( frame_counter, total_frames, init_start=init_image_scale, init_end=init_image_scale_end,noise_start=creative_noise, noise_end=creative_noise_end, guidance_start=guidance_scale, guidance_end=guidance_scale_end, mode="cosine" ) + print(f"[Frame {frame_counter:03d}] " f"init_image_scale={current_init_image_scale:.3f}, " f"guidance_scale={current_guidance_scale:.3f}, " f"creative_noise={current_creative_noise:.3f}") + + # Charger et encoder l'image sur GPU + input_image = load_images_test([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + # ---------------- Pose sequence --------------------------------------------- + # start_pose = tensor 4D BCHW directement + start_pose = input_image.to(device=device, dtype=dtype) + # Pose sequence + if use_openpose and pose_sequence is None: + pose_sequence = generate_pose_sequence(base_pose=start_pose, num_frames=total_frames, device=device, dtype=dtype, debug=True) + # 🔥 Détection yeux (une seule fois par image) + input_pil = tensor_to_pil(input_image) # à créer si tu ne l'as pas + + # 🔥 n3rControl - encode Canny en sécurité------------------------------------------------------------------------------------ + base_control_latent = build_control_latent_debug( + input_pil, + vae, + device="cuda", + latent_scale=LATENT_SCALE + ) + base_control_latent = sanitize_latents(base_control_latent) + base_control_latent = torch.clamp(torch.nan_to_num(base_control_latent), -1.0, 1.0) + + control_latent = base_control_latent + 0.01 * torch.randn_like(base_control_latent, dtype=torch.float16, device="cuda") + control_latent = sanitize_latents(control_latent) + # ----------------------------------------------------------------------------------------- + # coordonner masque eye et masque volumetrique + eye_coords = get_eye_coords_safe(input_pil) + coords_v = get_coords_safe( input_pil, H=cfg["H"], W=cfg["W"] ) + input_image = ensure_4_channels(input_image) + if frame_counter > 0: + 
initframe = frame_counter+transition_frames + else: + initframe = frame_counter + save_input_frame( input_image, output_dir, initframe, pbar=pbar, blur_radius=blur_radius, contrast=contrast, saturation=1.0, apply_post=False ) + + current_latent_single = encode_images_to_latents_hybrid(input_image, vae, device=device, latent_scale=LATENT_SCALE) + current_latent_single = torch.nn.functional.interpolate( + current_latent_single, size=(cfg["H"]//facteur, cfg["W"]//facteur), + #current_latent_single, size=(cfg["H"]//6, cfg["W"]//6), + mode='bilinear', align_corners=False + ) + + # 🔥 FIX NaN / stabilité + current_latent_single = sanitize_latents(current_latent_single) + # Génération initiale robuste : + pos_embeds, neg_embeds = get_interpolated_embeddings( frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device, debug=False) + try: + current_latent_single = generate_latents_robuste_4D( + latents=current_latent_single.to(device), + pos_embeds=pos_embeds, neg_embeds=neg_embeds, unet=unet, scheduler=scheduler, + motion_module=None, device=device, dtype=dtype, + guidance_scale=current_guidance_scale, #guidance_scale: 1.5 # un peu plus strict pour que le chat ressorte + init_image_scale=current_init_image_scale, #init_image_scale: 0.85 # presque tout le signal de l'image d'origine + creative_noise=current_creative_noise, seed=seed # 42, 1234, 2026, 5555 + ) + + # 🔥 FIX NaN / stabilité + current_latent_single = sanitize_latents(current_latent_single) + except Exception as e: + print(f"[Robuste INIT ERROR] {e}") + + current_latent_single = ensure_4_channels(current_latent_single) + current_latent_single = current_latent_single.to('cpu') + del input_image + torch.cuda.empty_cache() + + # ---------------- Transition frames ---------------- + if previous_latent_single is not None and transition_frames > 0: + for t in range(transition_frames): + if stop_generation: break + alpha = 0.5 - 0.5*math.cos(math.pi*t/max(transition_frames-1,1)) + with torch.no_grad(): + 
# --- Fusion adaptative avec diminution progressive de l'influence de la frame précédente + injection_start = 0.01 # influence initiale de l'ancienne frame + injection_end = 0.1 # influence finale + denom = max(transition_frames-1, 1) + injection_alpha = injection_start * (1 - t/denom) + injection_end * (t/denom) + + latent_interp = injection_alpha * previous_latent_single.to(device) + (1 - injection_alpha) * current_latent_single.to(device) + # 🔥 FIX NaN / stabilité + latent_interp = sanitize_latents(latent_interp) + + if motion_module: + latent_interp, _ = apply_motion_safe(latent_interp, motion_module) + + # Application de n3r_pro_net - réutilisé pour toutes les frames - creation des masques + eye_coords_latent = scale_eye_coords_to_latents( eye_coords, img_H=cfg["H"], img_W=cfg["W"], lat_H=latent_interp.shape[-2], lat_W=latent_interp.shape[-1] ) + if eye_coords_latent: + eye_mask = create_eye_mask(latent_interp, eye_coords_latent) + volume_mask = create_volumetrique_mask(latent_interp, coords_v, debug=False) + # Application du ProNet tout en protégeant les yeux + if use_n3r_pro_net: + latents = apply_pro_net_volumetrique(latent_interp, coords_v, n3r_pro_net, n3r_pro_strength, sanitize_latents, debug=False) + eye_coords_latent = scale_eye_coords_to_latents( eye_coords, img_H=cfg["H"], img_W=cfg["W"], lat_H=latents.shape[-2], lat_W=latents.shape[-1] ) + if eye_coords_latent: + latents = apply_pro_net_with_eyes(latents, eye_coords_latent, n3r_pro_net, n3r_pro_strength, sanitize_fn=sanitize_latents) + + # Décodage streaming + latent_interp = latent_interp / LATENT_SCALE # “rescale” avant décodage + print(f"Dimention frame inter: Shape de latent_interp :", latent_interp.shape) + frame_pil = decode_latents_ultrasafe_blockwise_ultranatural( latent_interp, vae, block_size=block_size, overlap=overlap, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost ) + + #Post Traitement + frame_pil = full_frame_postprocess( frame_pil, output_dir, 
frame_counter, target_temp=target_temp, reference_temp=reference_temp, blur_radius=blur_radius, contrast=contrast, sharpen_percent=sharpen_percent, psave=psave ) + save_frame_verbose(frame_pil, output_dir, frame_counter-1, suffix="0i", psave=True) + frame_counter += 1 + pbar.update(1) + + del latent_interp + torch.cuda.empty_cache() + + # ---------------- Frames principales ---------------- + for f in range(num_fraps_per_image): + if stop_generation: + break + with torch.no_grad(): + latents_frame = current_latent_single.to(device) + print(f"Dimention inital: Shape de latents_frame :", latents_frame.shape) + + # --- Interpolation des embeddings prompts --- + cf_embeds = get_interpolated_embeddings(frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device, debug=False) + latents = sanitize_latents(latents_frame.clone()) # 🔥 FIX NaN / stabilité + # --- volume mask --- + volume_mask = create_volumetrique_mask(latents, coords_v, debug=False) + control_weight_map = 0.05 + 0.25 * volume_mask**1.5 + control_latent = sanitize_latents(base_control_latent + 0.005 * torch.randn_like(base_control_latent)) + control_latent, control_weight_map = match_latent_size(latents, control_latent, control_weight_map) + + # ---------------- N3R avec mémoire latente conditionnée ---------------- + use_n3r_this_frame = math.sin(frame_counter * 0.2) > 0.7 + #control_strength = 0.05 * (1 - frame_counter / total_frames) + 0.02 + print(f"[DEBUG] 🧠 Pose control_strength ={control_strength:.4f}") + print(f"[DEBUG] 🧠 Pose controlnet_scale ={controlnet_scale:.4f}") + + + if use_n3r_this_frame and use_n3r_model: + print(f"[DEBUG] N3R active: {use_n3r_this_frame}") + try: + H, W = latents.shape[-2], latents.shape[-1] + coords = generate_n3r_coords(H, W, n3r_model.N_samples, seed, frame_counter, device) + n3r_latents = process_n3r_latents(n3r_model, coords, H, W, H, W) + fused_latents = fuse_with_memory(n3r_latents, memory_dict, cf_embeds, frame_counter) + external_weight = 0.2 * (1 - 
frame_counter / total_frames) + fused_latents = (1 - external_weight) * fused_latents + external_weight * external_latent + latents = fuse_n3r_latents_adaptive_new(latents, fused_latents, frame_counter, + total_frames=total_frames, + latent_injection_start=0.90, latent_injection_end=0.55) + latents = sanitize_latents(latents) + except Exception as e: + print(f"[N3R ERROR] {e}") + + if frame_counter % 30 == 0: + save_memory(memory_dict, memory_file) + + # ---------------- Mini-GPU diffusion ---------------- + if use_mini_gpu: + mini_latents = generate_latents_mini_gpu_320( + unet=unet, scheduler=scheduler, input_latents=latents, embeddings=cf_embeds, motion_module=motion_module, guidance_scale=current_guidance_scale, + device=device, fp16=True, steps=steps, debug=verbose, init_image_scale=current_init_image_scale, creative_noise=current_creative_noise + ) + mini_weight = (1 - frame_counter / total_frames) * (1 - latent_injection) + # S'assurer que les dimensions correspondent + mini_latents = match_latent_size(latents, mini_latents) + # Fusion pondérée + latents = (1 - mini_weight) * latents + mini_weight * mini_latents + latents = sanitize_latents(latents) + + # ---------------- ControlNet OpenPose ------------------------ + if use_openpose: + try: + import torch.nn.functional as F + import traceback + + # 🔥 dtype cible réel (UNet) + target_dtype = next(unet.parameters()).dtype + + # 🔹 ===== 1. 
PREPARE POSE FULL ===== + pose_full = pose_sequence[frame_counter % pose_sequence.shape[0]] + + # → BCHW + if pose_full.ndim == 3: + if pose_full.shape[0] in [1, 3]: # C,H,W + pose_full = pose_full.unsqueeze(0) + else: # H,W,C + pose_full = pose_full.permute(2, 0, 1).unsqueeze(0) + + # → channels fix + if pose_full.shape[1] > 3: + pose_full = pose_full[:, :3] + elif pose_full.shape[1] == 1: + pose_full = pose_full.repeat(1, 3, 1, 1) + + # → normalize [-1,1] + pose_full = (pose_full - 0.5) * 2.0 + pose_full = torch.clamp(pose_full, -1.0, 1.0) + + # → device + dtype + pose_full = pose_full.to(device=device, dtype=target_dtype) + + print(f"[DEBUG] Pose full {pose_full.shape} dtype={pose_full.dtype}") + + # 🔹 ===== 2. BUILD LATENT-SPACE POSE (CRUCIAL) ===== + latent_h, latent_w = latents.shape[2], latents.shape[3] + + pose_latent_full = F.interpolate( + pose_full, + size=(latent_h, latent_w), + mode='bilinear', + align_corners=False + ).to(device=device, dtype=target_dtype) + + print(f"[DEBUG] Pose latent {pose_latent_full.shape} dtype={pose_latent_full.dtype}") + + # 🔹 backup latents + latents_before_openpose = latents.clone() + + print(f"[DEBUG] Latents avant OpenPose min={latents.min().item():.4f}, max={latents.max().item():.4f}") + + from functools import partial + + # Préparer la tile function avec tous les arguments sauf latent_tile et tile_coords + tile_fn_partial = partial( controlnet_tile_fn, frame_counter=frame_counter, pose_full=pose_full, unet=unet, controlnet=controlnet, scheduler=scheduler, cf_embeds=cf_embeds, current_guidance_scale=current_guidance_scale, controlnet_scale=controlnet_scale, device=device, target_dtype=target_dtype ) + + # 🔹 Appel correct de apply_openpose_tilewise + latents = apply_openpose_tilewise( latents, pose_latent_full, tile_fn_partial, block_size=block_size, overlap=overlap, device=device, debug=True, debug_dir=output_dir,frame_idx=frame_counter) + + # 🔹 ===== 5. 
CLEANUP ===== + latents = torch.nan_to_num(latents) + latents = sanitize_latents(latents) + latents = torch.clamp(latents, -0.85, 0.85) + + diff = (latents - latents_before_openpose).abs().mean() + + print(f"[DEBUG] OpenPose impact: {diff.item():.6f}") + print(f"[DEBUG] Latents après OpenPose min={latents.min().item():.4f}, max={latents.max().item():.4f}") + + except Exception: + print("[ERROR] ControlNet OpenPose failed:") + traceback.print_exc() + latents = latents_before_openpose.clone() + + # 🔹 debug image + save_debug_pose_image(pose_full, frame_counter, output_dir, cfg, prefix="openpose") + + # ---------------- Injection finale ControlNet ---------------- + control_latent, control_weight_map = match_latent_size(latents, control_latent, control_weight_map) + latents = latents + control_strength * control_weight_map * control_latent + latents = sanitize_latents(latents) + latents = torch.clamp(latents, -0.85, 0.85) + print(f"[DEBUG] Après Injection finale ControlNet min={latents.min().item():.4f}, max={latents.max().item():.4f}, NaN={torch.isnan(latents).any().item()}") + + # ---------------- Fusion frame + latent injection ---------------- + if latent_injection > 0: + if latents.shape[-2:] != latents_frame.shape[-2:]: + latents = torch.nn.functional.interpolate(latents, size=latents_frame.shape[-2:], + mode='bilinear', align_corners=False).contiguous() + latents = latent_injection * latents_frame + (1 - latent_injection) * latents + latents = sanitize_latents(latents) + print(f"[DEBUG] Après Fusion frame min={latents.min().item():.4f}, max={latents.max().item():.4f}, NaN={torch.isnan(latents).any().item()}") + + # ---------------- Motion module ---------------- + if motion_module is not None: + latents_seq = latents.unsqueeze(2).repeat(1, 1, 3, 1, 1) if previous_latent_single is None \ + else torch.stack([previous_latent_single.to(device), latents, latents + 0.01 * torch.randn_like(latents)], dim=2) + latents_seq = sanitize_latents(latents_seq) + latents_seq, 
applied = apply_motion_safe(latents_seq, motion_module) + latents = latents_seq[:, :, 1, :, :] if applied else latents + latents = sanitize_latents(latents) + print(f"[DEBUG] Après Motion module min={latents.min().item():.4f}, max={latents.max().item():.4f}, NaN={torch.isnan(latents).any().item()}") + + # ---------------- ProNet yeux ---------------- + if use_n3r_pro_net: + latents = apply_pro_net_volumetrique(latents, coords_v, n3r_pro_net, n3r_pro_strength, sanitize_latents, debug=False) + print(f"[DEBUG] Après ProNet volumetrique min={latents.min().item():.4f}, max={latents.max().item():.4f}, NaN={torch.isnan(latents).any().item()}") + + eye_coords_latent = scale_eye_coords_to_latents(eye_coords, img_H=cfg["H"], img_W=cfg["W"], + lat_H=latents.shape[-2], lat_W=latents.shape[-1]) + if eye_coords_latent: + latents = apply_pro_net_with_eyes(latents, eye_coords_latent, n3r_pro_net, n3r_pro_strength, + sanitize_fn=sanitize_latents) + print(f"[DEBUG] Après ProNet yeux min={latents.min().item():.4f}, max={latents.max().item():.4f}, NaN={torch.isnan(latents).any().item()}") + + # ---------------- Clamp latents ---------------- + latents = torch.clamp(latents, -1.5, 1.5) + # ---------------- Décodage final ---------------- + # 🔥 SANITY AVANT DECODE + latents = sanitize_latents(latents) + latents = torch.clamp(latents, -1.0, 1.0) + print("FINAL LATENTS SAFE:", latents.min().item(), latents.max().item()) + latents = latents / LATENT_SCALE + print(f"Dimention : Shape de latents :", latents.shape) + frame_pil = decode_latents_ultrasafe_blockwise_ultranatural(latents, vae, block_size=block_size, overlap=overlap, device=device, + frame_counter=frame_counter, latent_scale_boost=latent_scale_boost + ) + frame_pil = full_frame_postprocess(frame_pil, output_dir, frame_counter, target_temp=target_temp, reference_temp=reference_temp, + blur_radius=blur_radius, contrast=contrast, sharpen_percent=sharpen_percent, psave=psave) + save_frame_verbose(frame_pil, output_dir, frame_counter, 
suffix="0f", psave=True) + + previous_latent_single = latents.detach().cpu() + frame_counter += 1 + pbar.update(1) + for var in ["latents", "latents_frame", "cf_embeds", "n3r_latents"]: + if var in locals(): + del locals()[var] + torch.cuda.empty_cache() + + previous_latent_single = current_latent_single + + except Exception as e: + log_frame_error(img_path, e) + continue + + pbar.close() + save_frames_as_video_from_folder(output_dir, out_video, fps=fps, upscale_factor=upscale_factor) + print(f"🎬 Vidéo générée : {out_video}") + +# ---------------- ENTRY ---------------- +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rRealProNet.py b/scripts/n3rRealProNet.py new file mode 100644 index 00000000..e19e6d96 --- /dev/null +++ b/scripts/n3rRealProNet.py @@ -0,0 +1,415 @@ +# ---------------------------------------------------------------------------------------- +# n3rProBoostNet.py - AnimateDiff stables, ProNet + HDR ultra-light ~2Go VRAM - pipeline 4D +# Prompt / Input → N3RModelOptimized → MotionModule → UNet → LoRA → VAE → Image / Vidéo +#Avec use_mini_gpu et generate_latents_mini_gpu_320 → ~2,1 Go VRAM, ultra léger ✅ Avec use_n3r_model et N3RModelOptimized → ~3,6 Go VRAM +# ---------------------------------------------------------------------------------------- +import os, math, threading, random +import traceback +import hashlib +import torch +import pickle +from pathlib import Path +from datetime import datetime +from tqdm import tqdm +from torchvision.transforms.functional import to_pil_image +from PIL import Image, ImageFilter +import argparse +from diffusers import 
PNDMScheduler +from transformers import CLIPTokenizerFast, CLIPTextModel +from scripts.utils.lora_utils import apply_lora_smart +from scripts.utils.vae_config import load_vae +from scripts.utils.n3rModelUtils import generate_n3r_coords, process_n3r_latents, fuse_with_memory, inject_external, fuse_n3r_latents_adaptive_new +from scripts.utils.tools_utils import ensure_4_channels, print_generation_params, sanitize_latents, stabilize_latents_advanced, log_debug, compute_overlap, get_interpolated_embeddings, save_memory, load_memory, load_external_embedding_as_latent, inject_external_embeddings, update_n3r_memory, compute_weighted_params, adapt_embeddings_to_unet, get_dynamic_latent_injection, save_input_frame, apply_motion_safe, encode_prompts_batch +from scripts.utils.config_loader import load_config +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import load_images_test, generate_latents_mini_gpu_320, run_diffusion_pipeline, generate_latents_robuste_4D +from scripts.utils.fx_utils import encode_images_to_latents_nuanced, adaptive_post_process, save_frames_as_video_from_folder, encode_images_to_latents_safe, encode_images_to_latents_hybrid, interpolate_param_fast, fuse_n3r_latents_adaptive, adaptive_post_process, remove_white_noise + +from scripts.utils.vae_utils import safe_load_unet +from scripts.utils.n3rModelFast4Go import N3RModelFast4GB, N3RModelLazyCPU, N3RModelOptimized +from scripts.utils.n3rProNet import N3RProNet +from scripts.utils.n3rProNet_utils import apply_n3r_pro_net, save_frame_verbose, full_frame_postprocess, decode_latents_ultrasafe_blockwise, get_eye_coords_safe, create_volumetrique_mask, create_eye_mask, tensor_to_pil, apply_pro_net_volumetrique, apply_pro_net_with_eyes, get_eye_coords_safe, scale_eye_coords_to_latents, get_coords, get_coords_safe, decode_latents_ultrasafe_blockwise_pro, decode_latents_ultrasafe_blockwise_sharp, decode_latents_ultrasafe_blockwise_natural, 
decode_latents_ultrasafe_blockwise_ultranatural +from scripts.utils.n3rControlNet import create_canny_control, control_to_latent, match_latent_size + +LATENT_SCALE = 0.18215 +stop_generation = False +# ---------------- Thread stop ---------------- +def wait_for_stop(): + global stop_generation + inp = input("Appuyez sur '²' + Entrée pour arrêter : ") + if inp.lower() == "²": + stop_generation = True +threading.Thread(target=wait_for_stop, daemon=True).start() + +# ---------------- MAIN FIABLE ---------------- +def main(args): + global stop_generation + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 + # Configurable depuis ton fichier cfg + use_mini_gpu = cfg.get("use_mini_gpu", True) + verbose, psave = cfg.get("verbose", False), cfg.get("psave", False) + latent_injection = float(cfg.get("latent_injection", 0.75)) + latent_injection = min(max(latent_injection, 0.5), 0.9) # plage sûre + final_latent_scale = cfg.get("final_latent_scale", 1/8) # 1/8 speed, 1/4 moyen, 1/2 low + fps = cfg.get("fps", 12) + upscale_factor = cfg.get("upscale_factor", 1) + transition_frames = cfg.get("transition_frames", 4) + num_fraps_per_image = cfg.get("num_fraps_per_image", 2) + steps = max(cfg.get("steps", 16), 4) + guidance_scale = cfg.get("guidance_scale", 6.5) # 0.15 peut de créativité 4.5 moderé + guidance_scale_end = cfg.get("guidance_scale_end", 7.0) # 0.15 peut de créativité 4.5 moderé + init_image_scale = cfg.get("init_image_scale", 0.75) # 0.85 ou 0.95 proche de l'init' (0.75) + init_image_scale_end = cfg.get("init_image_scale_end", 0.9) # 0.85 ou 0.95 proche de l'init' + creative_noise = cfg.get("creative_noise", 0.0) + creative_noise_end = cfg.get("creative_noise_end", 0.08) + latent_scale_boost = cfg.get("latent_scale_boost", 1.0) + frames_per_prompt = cfg.get("frames_per_prompt", 10) # nombre de frames par prompt + contrast = cfg.get("contrast", 1.15) # Post Traitement constrat + blur_radius = 
cfg.get("blur_radius", 0.03) # Post Traitement blur + sharpen_percent = cfg.get("sharpen_percent", 90) #Post Traitement sharpen + H, W = cfg.get("H", 512), cfg.get("W", 512) + block_size = min(256, H//2, W//2) # block_size auto selon résolution + use_n3r_model = cfg.get("use_n3r_model", False) + use_n3r_pro_net = cfg.get("use_n3r_pro_net", True) + n3r_pro_strength = cfg.get("n3r_pro_strength", 0.2) # 0.1, 0.2, 0.3 + #target_temp = 8000 reference_temp = 6000 (Froid) + target_temp, reference_temp = 7800, 6500 + + # Seed aléatoire + seed = torch.randint(0, 100000, (1,)).item() + params = { 'use_mini_gpu': use_mini_gpu, 'fps': fps, 'upscale_factor': upscale_factor, 'num_fraps_per_image': num_fraps_per_image, 'steps': steps, 'guidance_scale': guidance_scale, 'guidance_scale_end': guidance_scale_end, 'init_image_scale': init_image_scale, 'init_image_scale_end': init_image_scale_end, 'creative_noise': creative_noise, 'creative_noise_end': creative_noise_end, 'latent_scale_boost': latent_scale_boost, 'final_latent_scale': final_latent_scale, 'seed': seed, 'latent_injection': latent_injection, 'transition_frames': transition_frames, 'block_size': block_size, 'use_n3r_model': use_n3r_model } + print_generation_params(params) + + scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", num_train_timesteps=1000) + scheduler.set_timesteps(steps, device=device) + + # ---------------- UNET ---------------- + unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=True) + if hasattr(unet, "enable_attention_slicing"): unet.enable_attention_slicing() + if hasattr(unet, "enable_xformers_memory_efficient_attention"): + try: unet.enable_xformers_memory_efficient_attention(True) + except: pass + + # ---------------- LoRA ---------------- + n3oray_models = cfg.get("n3oray_models") + if n3oray_models: + for model_name, lora_path in n3oray_models.items(): + applied = apply_lora_smart(unet, lora_path, alpha=0.5, device=device, 
verbose=verbose) + if not applied: print(f"⚠ LoRA '{model_name}' ignorée (incompatible UNet)") + else: + print("⚠ Aucun modèle LoRA configuré, étape ignorée.") + #iniy external_latent + external_latent = None + # ---------------- Motion module ---------------- + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + if motion_module and verbose: + print(f"[INFO] motion_module type: {type(motion_module)}") + # ---------------- Tokenizer / Text encoder ---------------- + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path,"tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path,"text_encoder")).to(device).to(dtype) + # ---------------- VAE ---------------- + vae_path = cfg.get("vae_path") + vae, vae_type, latent_channels, LATENT_SCALE = load_vae(vae_path, device=device, dtype=dtype) + # ---------------- Embeddings ---------------- + embeddings = [] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + unet_cross_attention_dim = getattr(unet.config, "cross_attention_dim", 1024) + + # --- Projection adaptative + text_inputs_sample = tokenizer("test", padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + sample_embeds = text_encoder(text_inputs_sample.input_ids.to(device)).last_hidden_state + current_dim = sample_embeds.shape[-1] + projection = None + if current_dim != unet_cross_attention_dim: + projection = torch.nn.Linear(current_dim, unet_cross_attention_dim).to(device).to(dtype) + + # --- Pré-calcul des embeddings pour interpolation + # Appel de la fonction - encode_prompts_batch + pos_embeds_list, neg_embeds_list = encode_prompts_batch( prompts=prompts, negative_prompts=negative_prompts, tokenizer=tokenizer, text_encoder=text_encoder, device="cuda", projection=None) + # pos_embeds_list et neg_embeds_list sont des listes de 
tenseurs [1, seq_len, dim] + print(f"Pos embeds shape: {pos_embeds_list[0].shape}") + print(f"Neg embeds shape: {neg_embeds_list[0].shape}") + + # ---------------- N3RModelOptimized ---------------- + n3r_model = None + if use_n3r_model: + n3r_model = N3RModelOptimized( + L_low=cfg.get("n3r_L_low",3), # 3 ou 4 # plutôt que 3, un peu plus de finesse + L_high=cfg.get("n3r_L_high",6), # garde structure globale + N_samples=cfg.get("n3r_N_samples",32), # plus de samples pour un rendu détaillé 48 + tile_size=cfg.get("n3r_tile_size",64), # inchangé pour VRAM raisonnable + cpu_offload=cfg.get("n3r_cpu_offload",True) + ).to(device) + n3r_model.eval() + print(f"✅ N3RModelOptimized initialisé sur {device}") + + # ------------------- Initialisation mémoire ------------------- + output_dir_m = Path("./outputs") + memory_file = output_dir_m / "n3r_memory" + memory_dict = load_memory(memory_file) + # ---------------- n3r_pro_net ---------------- + n3r_pro_net = None + if use_n3r_pro_net: + n3r_pro_net = N3RProNet(channels=4).to(device).to(dtype) + n3r_pro_net.eval() + print("✅ N3RProNet activé") + + # ---------------- Input images ---------------- + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/RealProNet{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + overlap = compute_overlap(cfg["W"], cfg["H"], block_size) + + previous_latent_single = None + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + # ---------------- Frames principales avec interpolation prompts ---------------- + external_embeddings = None + + # Charger latent externe avant la génération + external_path = "/mnt/62G/huggingface/cyber-fp16/pt/KnxCOmiXNeg.safetensors" + external_latent = load_external_embedding_as_latent( + external_path, (1, 4, 
cfg["H"]//8, cfg["W"]//8) + ).to(device) + #------------------------------------------------------------------------------ + for img_idx, img_path in enumerate(input_paths): + if stop_generation: break + try: + # Paramètres interpolés + current_init_image_scale, current_creative_noise, current_guidance_scale = compute_weighted_params( frame_counter, total_frames, init_start=0.85, init_end=0.5,noise_start=0.0, noise_end=0.08, guidance_start=3.5, guidance_end=4.5, mode="cosine" ) + print(f"[Frame {frame_counter:03d}] " f"init_image_scale={current_init_image_scale:.3f}, " f"guidance_scale={current_guidance_scale:.3f}, " f"creative_noise={current_creative_noise:.3f}") + + # Charger et encoder l'image sur GPU + input_image = load_images_test([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + # 🔥 Détection yeux (une seule fois par image) + input_pil = tensor_to_pil(input_image) # à créer si tu ne l'as pas + + # 🔥 n3rControlNet + # Convertir en 3 canaux pour VAE + base_control = create_canny_control(input_pil) + if base_control.shape[1] == 1: # si 1 canal + base_control = base_control.repeat(1,3,1,1) # dupliquer pour RGB + base_control = base_control.to(dtype=torch.float16, device=device) + base_control_latent = control_to_latent(base_control, vae, device, LATENT_SCALE) + control_latent = base_control_latent + 0.01 * torch.randn_like(base_control_latent, dtype=torch.float16, device=device) + + # coordonner masque eye et masque volumetrique + eye_coords = get_eye_coords_safe(input_pil) + coords_v = get_coords_safe( input_pil, H=cfg["H"], W=cfg["W"] ) + input_image = ensure_4_channels(input_image) + if frame_counter > 0: + initframe = frame_counter+transition_frames + else: + initframe = frame_counter + save_input_frame( input_image, output_dir, initframe, pbar=pbar, blur_radius=blur_radius, contrast=contrast, saturation=1.0, apply_post=False ) + + current_latent_single = encode_images_to_latents_hybrid(input_image, vae, device=device, 
latent_scale=LATENT_SCALE) + current_latent_single = torch.nn.functional.interpolate( + current_latent_single, size=(cfg["H"]//8, cfg["W"]//8), + #current_latent_single, size=(cfg["H"]//6, cfg["W"]//6), + mode='bilinear', align_corners=False + ) + + # 🔥 FIX NaN / stabilité + current_latent_single = sanitize_latents(current_latent_single) + # Génération initiale robuste : + pos_embeds, neg_embeds = get_interpolated_embeddings( frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device, debug=False) + try: + current_latent_single = generate_latents_robuste_4D( + latents=current_latent_single.to(device), + pos_embeds=pos_embeds, neg_embeds=neg_embeds, unet=unet, scheduler=scheduler, + motion_module=None, device=device, dtype=dtype, + guidance_scale=current_guidance_scale, #guidance_scale: 1.5 # un peu plus strict pour que le chat ressorte + init_image_scale=current_init_image_scale, #init_image_scale: 0.85 # presque tout le signal de l'image d'origine + creative_noise=current_creative_noise, seed=seed # 42, 1234, 2026, 5555 + ) + + # 🔥 FIX NaN / stabilité + current_latent_single = sanitize_latents(current_latent_single) + except Exception as e: + print(f"[Robuste INIT ERROR] {e}") + + current_latent_single = ensure_4_channels(current_latent_single) + current_latent_single = current_latent_single.to('cpu') + del input_image + torch.cuda.empty_cache() + + # ---------------- Transition frames ---------------- + if previous_latent_single is not None and transition_frames > 0: + for t in range(transition_frames): + if stop_generation: break + alpha = 0.5 - 0.5*math.cos(math.pi*t/max(transition_frames-1,1)) + with torch.no_grad(): + # --- Fusion adaptative avec diminution progressive de l'influence de la frame précédente + injection_start = 0.8 # influence initiale de l'ancienne frame + injection_end = 0.1 # influence finale + denom = max(transition_frames-1, 1) + injection_alpha = injection_start * (1 - t/denom) + injection_end * (t/denom) + + latent_interp 
= injection_alpha * previous_latent_single.to(device) + (1 - injection_alpha) * current_latent_single.to(device) + # 🔥 FIX NaN / stabilité + latent_interp = sanitize_latents(latent_interp) + + if motion_module: + latent_interp, _ = apply_motion_safe(latent_interp, motion_module) + + # Application de n3r_pro_net - réutilisé pour toutes les frames - creation des masques + eye_coords_latent = scale_eye_coords_to_latents( eye_coords, img_H=cfg["H"], img_W=cfg["W"], lat_H=latent_interp.shape[-2], lat_W=latent_interp.shape[-1] ) + if eye_coords_latent: + eye_mask = create_eye_mask(latent_interp, eye_coords_latent) + volume_mask = create_volumetrique_mask(latent_interp, coords_v, debug=False) + # Application du ProNet tout en protégeant les yeux + if use_n3r_pro_net: + latents = apply_pro_net_volumetrique(latent_interp, coords_v, n3r_pro_net, n3r_pro_strength, sanitize_latents, debug=False) + eye_coords_latent = scale_eye_coords_to_latents( eye_coords, img_H=cfg["H"], img_W=cfg["W"], lat_H=latents.shape[-2], lat_W=latents.shape[-1] ) + if eye_coords_latent: + latents = apply_pro_net_with_eyes(latents, eye_coords_latent, n3r_pro_net, n3r_pro_strength, sanitize_fn=sanitize_latents) + + # Décodage streaming + latent_interp = latent_interp / LATENT_SCALE # “rescale” avant décodage + frame_pil = decode_latents_ultrasafe_blockwise_ultranatural( latent_interp, vae, block_size=block_size, overlap=overlap, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost ) + + #Post Traitement + frame_pil = full_frame_postprocess( frame_pil, output_dir, frame_counter, target_temp=target_temp, reference_temp=reference_temp, blur_radius=blur_radius, contrast=contrast, sharpen_percent=sharpen_percent, psave=psave ) + save_frame_verbose(frame_pil, output_dir, frame_counter-1, suffix="0i", psave=True) + frame_counter += 1 + pbar.update(1) + + del latent_interp + torch.cuda.empty_cache() + + # ---------------- Frames principales ---------------- + + for f in 
range(num_fraps_per_image): + if stop_generation: break + with torch.no_grad(): + latents_frame = current_latent_single.to(device) + + # --- Interpolation des embeddings prompts --- + cf_embeds = get_interpolated_embeddings( frame_counter, frames_per_prompt, pos_embeds_list, neg_embeds_list, device, debug=False) + + # --- N3R ou mini GPU diffusion --- + n3r_latents = None + latents = latents_frame.clone() + + # 🔥 FIX NaN / stabilité + latents = sanitize_latents(latents) + + # --- volume mask --- + volume_mask = create_volumetrique_mask(latents, coords_v, debug=False) + # --- ControlNet Lite ---------------------------------------------------------------- + control_weight_map = 0.05 + 0.25 * volume_mask**1.5 + control_latent = base_control_latent + 0.005 * torch.randn_like(base_control_latent) + control_latent, control_weight_map = match_latent_size(latents, control_latent, control_weight_map) + control_latent = sanitize_latents(control_latent) # Ne pas oublier ! + # ---------------- N3R avec mémoire latente conditionnée ---------------- + use_n3r_this_frame = use_n3r_model and (frame_counter % random.choice([4,5,6]) == 0) + # ControlNet + control_strength = 0.08 * (1 - frame_counter / total_frames) + 0.03 + # ------------------- Bloc N3R par frame ------------------- + if use_n3r_this_frame: + try: + H, W = latents.shape[-2], latents.shape[-1] + N_samples = n3r_model.N_samples + coords = generate_n3r_coords(H, W, N_samples, seed, frame_counter, device) + n3r_latents = process_n3r_latents(n3r_model, coords, H, W, H, W) + fused_latents = fuse_with_memory(n3r_latents, memory_dict, cf_embeds, frame_counter) + fused_latents = inject_external(fused_latents, external_latent, frame_counter, device) + latents = fuse_n3r_latents_adaptive_new(latents, fused_latents, frame_counter=frame_counter, total_frames=total_frames, latent_injection_start=0.90, latent_injection_end=0.55) + latents = sanitize_latents(latents) + + # ControlNet injection - Avant d’appliquer ControlNet + 
control_latent, control_weight_map = match_latent_size(latents, control_latent, control_weight_map) + print(f"[DEBUG] latents: {latents.shape}, control_latent: {control_latent.shape}, control_weight_map: {control_weight_map.shape}") + latents = latents + control_strength * control_weight_map * control_latent + latents = sanitize_latents(latents) + except Exception as e: + print(f"[N3R ERROR] {e}") + + # Sauvegarde mémoire périodique + if frame_counter % 30 == 0: + save_memory(memory_dict, memory_file) + + elif use_mini_gpu: + latents = generate_latents_mini_gpu_320( + unet=unet, scheduler=scheduler, + input_latents=latents_frame, embeddings=cf_embeds, + motion_module=motion_module, guidance_scale=current_guidance_scale, + device=device, fp16=True, steps=steps, + debug=verbose, init_image_scale=current_init_image_scale, + creative_noise=current_creative_noise + ) + # ControlNet injection : + control_latent, control_weight_map = match_latent_size(latents, control_latent, control_weight_map) + print(f"[DEBUG] latents: {latents.shape}, control_latent: {control_latent.shape}, control_weight_map: {control_weight_map.shape}") + latents = latents + control_strength * control_weight_map * control_latent + latents = sanitize_latents(latents) + + if latent_injection > 0: + if latents.shape[-2:] != latents_frame.shape[-2:]: + latents = torch.nn.functional.interpolate( latents, size=latents_frame.shape[-2:], mode='bilinear', align_corners=False ).contiguous() + latents = latent_injection*latents_frame + (1-latent_injection)*latents + + # --- Motion module propre et safe --- + if motion_module is not None: + latents_seq = latents.unsqueeze(2).repeat(1,1,3,1,1) if previous_latent_single is None else torch.stack([previous_latent_single.to(device), latents, latents+0.01*torch.randn_like(latents)], dim=2) + latents_seq = sanitize_latents(latents_seq) + latents_seq, applied = apply_motion_safe(latents_seq, motion_module) + latents = latents_seq[:, :, 1, :, :] if applied else latents 
+ latents = sanitize_latents(latents) + + # ProNet avec yeux + if use_n3r_pro_net: + latents = apply_pro_net_volumetrique(latents, coords_v, n3r_pro_net, n3r_pro_strength, sanitize_latents, debug=False) + eye_coords_latent = scale_eye_coords_to_latents( eye_coords, img_H=cfg["H"], img_W=cfg["W"], lat_H=latents.shape[-2], lat_W=latents.shape[-1] ) + if eye_coords_latent: + latents = apply_pro_net_with_eyes(latents, eye_coords_latent, n3r_pro_net, n3r_pro_strength, sanitize_fn=sanitize_latents) + + # Décodage final + latents = latents / LATENT_SCALE + print(f"[DEBUG] LATENT_SCALE: {LATENT_SCALE}, latents min/max: {latents.min().item():.3f}/{latents.max().item():.3f}") + frame_pil = decode_latents_ultrasafe_blockwise_ultranatural(latents, vae, block_size=block_size, overlap=overlap, device=device, frame_counter=frame_counter, latent_scale_boost=latent_scale_boost) + frame_pil = full_frame_postprocess(frame_pil, output_dir, frame_counter, target_temp=target_temp, reference_temp=reference_temp, blur_radius=blur_radius, contrast=contrast, sharpen_percent=sharpen_percent, psave=psave) + save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="0f", psave=True) + + previous_latent_single = latents.detach().cpu() + frame_counter += 1 + pbar.update(1) + # Nettoyage VRAM + del latents, latents_frame, cf_embeds, n3r_latents + torch.cuda.empty_cache() + + previous_latent_single = current_latent_single + + except Exception as e: + print(f"\n[FRAME ERROR] {img_path}") + print(f"Type d'erreur : {type(e).__name__}") + print(f"Message d'erreur : {e}") + print("Traceback complet :") + traceback.print_exc() + continue + + pbar.close() + save_frames_as_video_from_folder(output_dir, out_video, fps=fps, upscale_factor=upscale_factor) + print(f"🎬 Vidéo générée : {out_video}") + +# ---------------- ENTRY ---------------- +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + 
parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rauto.py b/scripts/n3rauto.py new file mode 100644 index 00000000..2c2178e5 --- /dev/null +++ b/scripts/n3rauto.py @@ -0,0 +1,32 @@ +# n3rUnet5D_auto_tile128.py + +import torch +from scripts.utils import log_gpu_memory +from scripts.modules.motion_module_tiny import MotionModuleTiny +from scripts.vae import decode_latents_to_image_tiled +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import safe_load_vae, safe_load_unet, safe_load_scheduler +from scripts.utils.vae_utils import decode_latents_to_image_tiled, decode_latents_frame_auto, generate_5D_video_auto +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import generate_latents_ai_5D, load_image_file, generate_5D_video_auto +from scripts.utils.n3r_utils import decode_latents_frame_auto, generate_5D_video_auto, log_gpu_memory + +tile_size = 128 +overlap = 64 + + + + + +# --- UTILISATION --- +if __name__ == "__main__": + import yaml + config_path = "configs/prompts/1_animate/256.yaml" + with open(config_path) as f: + config = yaml.safe_load(f) + + generate_5D_video_auto( + pretrained_model_path="/mnt/62G/huggingface/miniSD", + config=config, + device='cuda' + ) diff --git a/scripts/n3rcreative.py b/scripts/n3rcreative.py new file mode 100644 index 00000000..36e1b4a2 --- /dev/null +++ b/scripts/n3rcreative.py @@ -0,0 +1,225 @@ +import argparse +from pathlib import Path +from tqdm import tqdm +import torch +from datetime import datetime +import os +import math +import shutil +from PIL import Image, ImageSequence +import numpy as np +import ffmpeg +from transformers import CLIPTokenizerFast, CLIPTextModel + +from scripts.utils.config_loader 
import load_config +from scripts.utils.vae_utils import safe_load_vae_stable, safe_load_unet, safe_load_scheduler +from scripts.utils.vae_utils import encode_images_to_latents, decode_latents_to_image_tiled +from scripts.utils.motion_utils import load_motion_module, apply_motion_module +from scripts.utils.safe_latent import ensure_valid +from scripts.utils.video_utils import save_frames_as_video +from scripts.utils.n3r_utils import load_image_file, generate_latents + +LATENT_SCALE = 0.18215 # Tiny-SD 128x128 + +# ------------------------- +# Image utilities +# ------------------------- +def load_images(paths, W, H, device, dtype): + all_tensors = [] + for p in paths: + if p.lower().endswith(".gif"): + img = Image.open(p) + frames = [torch.tensor(np.array(f)).permute(2,0,1).to(device=device, dtype=dtype)/127.5 - 1.0 + for f in ImageSequence.Iterator(img)] + print(f"✅ GIF chargé : {p} avec {len(frames)} frames") + all_tensors.extend(frames) + else: + t = load_image_file(p, W, H, device, dtype) + print(f"✅ Image chargée : {p}") + all_tensors.append(t) + return torch.stack(all_tensors, dim=0) + +# ------------------------- +# Main +# ------------------------- +def main(args): + + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + # Paramètres principaux + fps = cfg.get("fps", 12) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + steps = cfg.get("steps", 35) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + + # Paramètres créatifs + creative_mode = cfg.get("creative_mode", False) + creative_scale_min = cfg.get("creative_scale_min", 0.2) + creative_scale_max = cfg.get("creative_scale_max", 0.8) + creative_noise = cfg.get("creative_noise", 0.0) + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + 
total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + estimated_seconds = total_frames / fps + + print("📌 Paramètres de génération :") + print(f" fps : {fps}") + print(f" num_fraps_per_image : {num_fraps_per_image}") + print(f" steps : {steps}") + print(f" guidance_scale : {guidance_scale}") + print(f" init_image_scale : {init_image_scale}") + print(f" creative_noise : {creative_noise}") + print(f"⏱ Durée totale estimée de la vidéo : {estimated_seconds:.1f}s") + + # ------------------------- + # Load models + # ------------------------- + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload) + scheduler = safe_load_scheduler(args.pretrained_model_path) + if not unet or not vae or not scheduler: + print("❌ UNet, VAE ou Scheduler manquant.") + return + + # Motion module + motion_path = cfg.get("motion_module") + motion_module = load_motion_module(motion_path, device=device) if motion_path else None + if not callable(motion_module): + from scripts.utils.motion_utils import default_motion_module + motion_module = default_motion_module + + # Tokenizer / Text Encoder + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + # Préparer embeddings + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, 
max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + # ------------------------- + # Output + # ------------------------- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/creative_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + from torchvision.transforms import ToPILImage + to_pil = ToPILImage() + + frames_for_video = [] + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + # ------------------------- + # Generation loop + # ------------------------- + for img_path in input_paths: + + input_image = load_images([img_path], + W=cfg["W"], + H=cfg["H"], + device=device, + dtype=dtype) + + input_latents = encode_images_to_latents(input_image, vae) + # On s'assure que les latents ont bien une dimension F=1 si image unique + if input_latents.dim() == 4: + input_latents = input_latents.unsqueeze(2) + + B, C, F, H_lat, W_lat = input_latents.shape + + # On génère chaque frame + for pos_embeds, neg_embeds in embeddings: + for f_idx in range(F): + scheduler.set_timesteps(steps, device=device) + + latents_frame = input_latents[:, :, f_idx:f_idx+1, :, :].clone() + # Appliquer bruit créatif si activé + if creative_mode and creative_noise > 0.0: + latents_frame += torch.randn_like(latents_frame) * creative_noise + + # guidance dynamique + dynamic_scale = guidance_scale + if creative_mode: + dynamic_scale *= creative_scale_min + (creative_scale_max - creative_scale_min) * (f_idx / F) + + # init_image_scale progressif + progressive_init_scale = init_image_scale * (1 - f_idx / max(F-1,1)) + + # Génération du latent + latents_frame = generate_latents( + latents=latents_frame, + pos_embeds=pos_embeds, + 
neg_embeds=neg_embeds, + unet=unet, + scheduler=scheduler, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=dynamic_scale, + init_image_scale=progressive_init_scale + ) + + # Clamp et check NaN + if torch.isnan(latents_frame).any(): + print(f"⚠ Frame {frame_counter:05d} contient NaN, réinitialisation avec petit bruit") + latents_frame = input_latents[:, :, f_idx:f_idx+1, :, :].clone() + latents_frame += torch.randn_like(latents_frame) * max(creative_noise, 0.01) + latents_frame = latents_frame.clamp(-3.0, 3.0).squeeze(2).to(torch.float32) + + mean_lat = latents_frame.abs().mean().item() + if mean_lat < 0.01: + print(f"⚠ Frame {frame_counter:05d} a latent moyen trop faible ({mean_lat:.6f}), ajout de bruit minimal") + latents_frame += torch.randn_like(latents_frame) * max(creative_noise, 0.01) + + # Décodage tuilé + frame_tensor = decode_latents_to_image_tiled(latents_frame, vae, tile_size=32, overlap=16).clamp(0,1) + if frame_tensor.ndim == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()) + frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png") + frames_for_video.append(frame_pil) + + frame_counter += 1 + pbar.update(1) + print(f"Frame {frame_counter:05d} | mean abs(latent) = {mean_lat:.6f}") + + pbar.close() + + # ------------------------- + # Sauvegarde vidéo + # ------------------------- + save_frames_as_video(frames_for_video, out_video, fps=fps) + print(f"🎬 Vidéo générée : {out_video}") + print("✅ Pipeline terminé proprement.") + +# ------------------------- +# Entrée +# ------------------------- +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", 
action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rcyber.py b/scripts/n3rcyber.py new file mode 100644 index 00000000..ad9f2736 --- /dev/null +++ b/scripts/n3rcyber.py @@ -0,0 +1,265 @@ +# -------------------------------------------------------------- +# nr3perfect - INTERPOLATION fast movie - Multi-Model n3oray (VAE séparé) +# -------------------------------------------------------------- +import os +import torch +import argparse +from pathlib import Path +from datetime import datetime +import math +from tqdm import tqdm +from PIL import Image, ImageFilter +from torchvision.transforms import ToPILImage +import threading + +from diffusers import AutoencoderKL +from transformers import CLIPTokenizerFast, CLIPTextModel + +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import safe_load_unet, safe_load_scheduler +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import generate_latents_robuste, load_image_file, decode_latents_to_image_auto +from safetensors.torch import load_file + +LATENT_SCALE = 0.18215 +stop_generation = False + +# ---------------- Thread pour stopper la génération ---------------- +def wait_for_stop(): + global stop_generation + inp = input("Appuyez sur '²' + Entrée pour arrêter : ") + if inp.lower() == "²": + stop_generation = True + +threading.Thread(target=wait_for_stop, daemon=True).start() + +# ---------------- Fonctions utilitaires ---------------- +def normalize_frame(frame_tensor): + if frame_tensor.min() < 0: + frame_tensor = (frame_tensor + 1.0) / 2.0 + return frame_tensor.clamp(0, 1) + +def compute_overlap(W, H, block_size, max_overlap_ratio=0.6): + overlap = int(block_size * max_overlap_ratio) + overlap = min(overlap, min(W, H) // 4) + return overlap + +def load_images(paths, W, H, device, dtype): + all_tensors = [] + for p in paths: + t = load_image_file(p, W, H, device, dtype) + print(f"✅ Image chargée : {p}") + 
all_tensors.append(t) + return torch.stack(all_tensors, dim=0) + +def save_frames_as_video_from_folder(folder_path, output_path, fps=12): + import ffmpeg + folder_path = Path(folder_path) + frame_files = sorted(folder_path.glob("frame_*.png")) + if not frame_files: + print("❌ Aucun frame trouvé dans le dossier") + return + pattern = str(folder_path / "frame_*.png") + ( + ffmpeg.input(pattern, framerate=fps, pattern_type='glob') + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + +def encode_images_to_latents(images, vae): + images = images.to(device=vae.device, dtype=torch.float32) + with torch.inference_mode(): + latents = vae.encode(images).latent_dist.sample() + latents = latents * LATENT_SCALE + latents = latents.unsqueeze(2) # [B, C, 1, H/8, W/8] + return latents + +# ---------------- MAIN ---------------- +def main(args): + global stop_generation + cfg = load_config(args.config) + + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + fps = cfg.get("fps", 12) + upscale_factor = cfg.get("upscale_factor", 2) + transition_frames = cfg.get("transition_frames", 8) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + steps = cfg.get("steps", 50) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + creative_noise = cfg.get("creative_noise", 0.0) + vae_path = cfg.get("vae_path", "/mnt/62G/huggingface/vae/vae-ft-mse-840000-ema-pruned.safetensors") + # Nom du modèle choisi + n3_model_name = args.n3_model # exemple: "cyber_skin" + + # Chemin du modèle depuis le YAML + n3_model_path = cfg["n3oray_models"].get(n3_model_name) + if n3_model_path is None: + raise ValueError(f"Le modèle N3 '{n3_model_name}' n'est pas défini dans le YAML") + print(f"✅ Chargement du modèle N3 '{n3_model_name}' depuis : {n3_model_path}") + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + 
prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + total_frames = ( + len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + + max(len(input_paths) - 1, 0) * transition_frames + ) + print(f"🎞 Frames totales estimées : {total_frames}") + print("⏹ Touche '²' pour arrêter la génération et création de la vidéo directement...") + + # ---------------- LOAD MODELS ---------------- + # Créer UNET vide avec la config standard + unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=args.fp16) + + # Charger le safetensors correspondant au N3 choisi + state_dict = load_file(n3_model_path, device=device) + unet.load_state_dict(state_dict, strict=False) + print(f"✅ UNET N3 '{n3_model_name}' chargé correctement") + + scheduler = safe_load_scheduler(args.pretrained_model_path) + scheduler.set_timesteps(steps, device=device) + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + # ---------------- LOAD SEPARATE VAE ---------------- + device = "cuda" + + # 1️⃣ Crée le modèle VAE vide correspondant + vae = AutoencoderKL( + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D"]*4, + up_block_types=["UpDecoderBlock2D"]*4, + block_out_channels=[128, 256, 512, 512], + latent_channels=4, + layers_per_block=2, + sample_size=256 + ) + + # 2️⃣ Charge les poids safetensors via safetensors + state_dict = load_file(vae_path, device=device) + vae.load_state_dict(state_dict, strict=False) + + #vae = vae.to(device) + offload=True + vae = vae.to("cpu" if offload else device).float() + #vae.enable_tiling() + vae.enable_slicing() + + print(f"✅ VAE safetensors chargé depuis : {vae_path}") + + # 
---------------- PROMPT EMBEDDINGS ---------------- + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + # ---------------- OUTPUT ---------------- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/fastperfect_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + to_pil = ToPILImage() + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + previous_latent_single = None + stop_generation = False + + # ================= MAIN LOOP ================= + for img_idx, img_path in enumerate(input_paths): + if stop_generation: + break + + input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + input_latents_single = encode_images_to_latents(input_image, vae) + input_latents = input_latents_single.repeat(1, 1, num_fraps_per_image, 1, 1) + current_latent_single = input_latents_single.clone() + block_size = cfg.get("block_size", 64) + overlap = cfg.get("overlap", compute_overlap(cfg["W"], cfg["H"], block_size)) + + # --- Transition latente --- + if previous_latent_single is not None and transition_frames > 0: + for t in range(transition_frames): + if stop_generation: + print("⏹ Arrêt demandé, création de la vidéo...") + 
break + alpha = 0.5 - 0.5 * math.cos(math.pi * t / (transition_frames - 1)) + latent_interp = ((1 - alpha) * previous_latent_single + alpha * current_latent_single).squeeze(2).clamp(-3.0, 3.0) + frame_tensor = decode_latents_to_image_auto(latent_interp, vae) + frame_tensor = normalize_frame(frame_tensor) + if frame_tensor.ndim == 4: + frame_tensor = frame_tensor.squeeze(0) + frame_pil = to_pil(frame_tensor.cpu()).filter(ImageFilter.GaussianBlur(radius=0.2)) + if upscale_factor > 1: + frame_pil = frame_pil.resize((frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), Image.BICUBIC) + frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png") + frame_counter += 1 + pbar.update(1) + + # --- Boucle principale fraps/prompts (sans mélange dynamique de styles) --- + for pos_embeds, neg_embeds in embeddings: + for f in range(num_fraps_per_image): + if f == 0: + frame_tensor = (input_image.squeeze(0) + 1.0) / 2.0 + frame_tensor = frame_tensor.clamp(0, 1) + else: + latents_frame = input_latents[:, :, f:f+1, :, :].clone() + latents_frame = generate_latents_robuste( + latents_frame, pos_embeds, neg_embeds, + unet, scheduler, motion_module, device, dtype, + guidance_scale, init_image_scale, creative_noise, seed=frame_counter + ) + frame_tensor = decode_latents_to_image_auto(latents_frame, vae) + frame_tensor = normalize_frame(frame_tensor) + if frame_tensor.ndim == 4: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()).filter(ImageFilter.GaussianBlur(radius=0.2)) + if upscale_factor > 1: + frame_pil = frame_pil.resize((frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), Image.BICUBIC) + frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png") + + del frame_tensor, frame_pil + if f != 0: + del latents_frame + frame_counter += 1 + if frame_counter % 10 == 0: + torch.cuda.empty_cache() + pbar.update(1) + + previous_latent_single = current_latent_single.clone() + + pbar.close() + 
save_frames_as_video_from_folder(output_dir, out_video, fps=fps) + print(f"🎬 Vidéo générée : {out_video}") + print("✅ Pipeline terminé proprement.") + +# ---------------- ENTRY POINT ---------------- +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + parser.add_argument("--n3_model", type=str, default="cyberpunk_style_v3") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rfast.py b/scripts/n3rfast.py new file mode 100644 index 00000000..cbaa25f9 --- /dev/null +++ b/scripts/n3rfast.py @@ -0,0 +1,268 @@ +import argparse +from pathlib import Path +from tqdm import tqdm +import torch +from datetime import datetime +import os +import math +import shutil +from PIL import Image +from PIL import ImageFilter +from torchvision.transforms import ToPILImage + +from transformers import CLIPTokenizerFast, CLIPTextModel + +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import ( + safe_load_unet, + safe_load_scheduler, + safe_load_vae_stable, + decode_latents_to_image_tiled +) +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import generate_latents_robuste, load_image_file + +LATENT_SCALE = 0.18215 + + +def normalize_frame(frame_tensor): + if frame_tensor.min() < 0: + frame_tensor = (frame_tensor + 1.0) / 2.0 + return frame_tensor.clamp(0,1) + +# ------------------------- +def compute_overlap(W, H, block_size, max_overlap_ratio=0.6): + overlap = int(block_size * max_overlap_ratio) + overlap = min(overlap, min(W,H)//4) + return overlap + +# ------------------------- +def load_images(paths, W, H, device, dtype): + all_tensors = [] + for p in paths: + t = 
load_image_file(p, W, H, device, dtype) + print(f"✅ Image chargée : {p}") + all_tensors.append(t) + return torch.stack(all_tensors, dim=0) + +# ------------------------- +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.benchmark = True + +# ------------------------- +def save_frames_as_video(frames, output_path, fps=12): + import ffmpeg + temp_dir = Path("temp_frames") + if temp_dir.exists(): + shutil.rmtree(temp_dir) + temp_dir.mkdir() + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + ( + ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + shutil.rmtree(temp_dir) + +# ------------------------- +def encode_images_to_latents(images, vae): + images = images.to(device=vae.device, dtype=torch.float32) + with torch.inference_mode(): + latents = vae.encode(images).latent_dist.sample() + latents = latents * LATENT_SCALE + latents = latents.unsqueeze(2) + return latents + +# ------------------------- +def main(args): + cfg = load_config(args.config) + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + fps = cfg.get("fps", 12) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + steps = cfg.get("steps", 50) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + creative_noise = cfg.get("creative_noise", 0.0) + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + estimated_seconds = total_frames / fps + print("📌 Paramètres de génération :") + print(f" fps : {fps}") + print(f" num_fraps_per_image : {num_fraps_per_image}") + print(f" steps : {steps}") + print(f" guidance_scale : {guidance_scale}") + 
print(f" init_image_scale : {init_image_scale}") + print(f" creative_noise : {creative_noise}") + print(f"⏱ Durée totale estimée de la vidéo : {estimated_seconds:.1f}s") + + # ------------------------- + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + try: + unet.enable_xformers_memory_efficient_attention() + unet.set_attention_slice("max") + vae.enable_slicing() + vae.enable_tiling() + print("✅ xFormers memory efficient attention activé") + except Exception: + print("⚠ xFormers non disponible") + + vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload) + scheduler = safe_load_scheduler(args.pretrained_model_path) + + if not vae: + print("❌ VAE manquant.") + return + if not scheduler: + print("❌ Scheduler manquant.") + return + scheduler.set_timesteps(steps, device=device) + if not unet: + print("❌ UNet manquant.") + return + + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = 
text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/fast_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + to_pil = ToPILImage() + frames_for_video = [] + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + # ------------------------- + for img_idx, img_path in enumerate(input_paths): + + input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + input_latents = encode_images_to_latents(input_image, vae) + input_latents = input_latents.expand(-1, -1, num_fraps_per_image, -1, -1).clone() + + block_size = cfg.get("block_size", 64) + overlap = compute_overlap(cfg["W"], cfg["H"], block_size, max_overlap_ratio=0.6) + + for pos_embeds, neg_embeds in embeddings: + for f in range(num_fraps_per_image): + # ------------------------- Première frame + #if frame_counter == 0: + if f == 0: + # Première frame = image input originale + frame_tensor = input_image.squeeze(0) + # Si le tenseur est normalisé [-1,1], le ramener à [0,1] + frame_tensor = (frame_tensor + 1.0) / 2.0 + frame_tensor = frame_tensor.clamp(0,1) + print(f"ℹ️ Frame {frame_counter:05d} = image input originale") + latents_frame = None + else: + latents_frame = input_latents[:, :, f:f+1, :, :].clone() + + if latents_frame is not None: # new code + try: + latents_frame = generate_latents_robuste( + latents_frame, + pos_embeds, + neg_embeds, + unet, + scheduler, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=guidance_scale, + init_image_scale=init_image_scale, + creative_noise=creative_noise, + seed=frame_counter + ) + except Exception as e: + print(f"⚠ Erreur génération frame {frame_counter:05d}, reset avec petit bruit: {e}") + latents_frame = input_latents[:, :, f:f+1, :, :] + 
torch.randn_like(input_latents[:, :, f:f+1, :, :]) * 0.05 + latents_frame = latents_frame.to(dtype=dtype) + + # Clamp et decode tuilé + if latents_frame is not None: + latents_frame = latents_frame.squeeze(2).clamp(-3.0, 3.0) + #frame_tensor = decode_latents_to_image_tiled(latents_frame, vae, tile_size=32, overlap=16).clamp(0,1) + + #------------------- NEW CODE ------------------------------------------- + block_size = cfg.get("block_size", 64) + overlap = compute_overlap(cfg["W"], cfg["H"], block_size, max_overlap_ratio=0.6) # 0.6 ou 0.65 Max + + if latents_frame is not None: + frame_tensor = decode_latents_to_image_tiled( + latents_frame, vae, + tile_size=block_size, + overlap=overlap + ).clamp(0,1) + + frame_tensor = normalize_frame(frame_tensor) # A test ok + #------------------------------------------------------------------------ + + + if frame_tensor.ndim == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()) + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=0.2)) # Ajout flou sur le bord pour lisser l'overlap' + # ---------------- UPSCALE ---------------- + upscale_factor = cfg.get("upscale_factor", 2) # ajouter dans config si besoin + if upscale_factor > 1: + frame_pil = frame_pil.resize( + (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), + resample=Image.BICUBIC + ) + # ---------------------------------------- + + frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png") + + frames_for_video.append(frame_pil) + frame_counter += 1 + pbar.update(1) + + if latents_frame is not None: + mean_lat = latents_frame.abs().mean().item() + if math.isnan(mean_lat) or mean_lat < 1e-5: + print(f"⚠ Frame {frame_counter:05d} contient NaN ou latent trop petit, réinitialisation") + + del latents_frame, frame_tensor + torch.cuda.empty_cache() + + pbar.close() + save_frames_as_video(frames_for_video, out_video, fps=fps) + print(f"🎬 Vidéo générée : {out_video}") + print("✅ 
Pipeline terminé proprement.") + +# ------------------------- +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rfastinterpol.py b/scripts/n3rfastinterpol.py new file mode 100644 index 00000000..8cee96b0 --- /dev/null +++ b/scripts/n3rfastinterpol.py @@ -0,0 +1,290 @@ +import argparse +from pathlib import Path +from tqdm import tqdm +import torch +from datetime import datetime +import os +import math +import shutil +from PIL import Image +from PIL import ImageFilter +from torchvision.transforms import ToPILImage + +from transformers import CLIPTokenizerFast, CLIPTextModel + +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import ( + safe_load_unet, + safe_load_scheduler, + safe_load_vae_stable, + decode_latents_to_image_tiled +) +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import generate_latents_robuste, load_image_file + +LATENT_SCALE = 0.18215 + + +def normalize_frame(frame_tensor): + if frame_tensor.min() < 0: + frame_tensor = (frame_tensor + 1.0) / 2.0 + return frame_tensor.clamp(0, 1) + + +def compute_overlap(W, H, block_size, max_overlap_ratio=0.6): + overlap = int(block_size * max_overlap_ratio) + overlap = min(overlap, min(W, H) // 4) + return overlap + + +def load_images(paths, W, H, device, dtype): + all_tensors = [] + for p in paths: + t = load_image_file(p, W, H, device, dtype) + print(f"✅ Image chargée : {p}") + all_tensors.append(t) + return torch.stack(all_tensors, dim=0) + + +def save_frames_as_video(frames, output_path, fps=12): + import ffmpeg + temp_dir = Path("temp_frames") + 
if temp_dir.exists(): + shutil.rmtree(temp_dir) + temp_dir.mkdir() + + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + + ( + ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + + shutil.rmtree(temp_dir) + + +def encode_images_to_latents(images, vae): + images = images.to(device=vae.device, dtype=torch.float32) + with torch.inference_mode(): + latents = vae.encode(images).latent_dist.sample() + latents = latents * LATENT_SCALE + latents = latents.unsqueeze(2) + return latents + + +def main(args): + cfg = load_config(args.config) + + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + fps = cfg.get("fps", 12) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + transition_frames = cfg.get("transition_frames", 8) + + steps = cfg.get("steps", 50) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + creative_noise = cfg.get("creative_noise", 0.0) + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + total_frames = ( + len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + + max(len(input_paths) - 1, 0) * transition_frames + ) + + print(f"🎞 Frames totales estimées : {total_frames}") + + # ---------------- LOAD MODELS ---------------- + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload) + scheduler = safe_load_scheduler(args.pretrained_model_path) + + scheduler.set_timesteps(steps, device=device) + + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + + tokenizer = 
CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained( + os.path.join(args.pretrained_model_path, "text_encoder") + ).to(device) + + if args.fp16: + text_encoder = text_encoder.half() + + # ---------------- PROMPT EMBEDDINGS ---------------- + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + # ---------------- OUTPUT ---------------- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/fastinterpol_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + to_pil = ToPILImage() + frames_for_video = [] + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + previous_latent_single = None + + # ================= MAIN LOOP ================= + for img_idx, img_path in enumerate(input_paths): + + input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + input_latents = encode_images_to_latents(input_image, vae) + current_latent_single = input_latents.clone() + + block_size = cfg.get("block_size", 64) + overlap = compute_overlap(cfg["W"], cfg["H"], block_size) + + # ----------- LATENT INTERPOLATION ----------- + if previous_latent_single 
is not None and transition_frames > 0: + print("🎬 Transition latente...") + + for t in range(transition_frames): + alpha = 0.5 - 0.5 * math.cos(math.pi * t / (transition_frames - 1)) + + latent_interp = ( + (1 - alpha) * previous_latent_single + + alpha * current_latent_single + ) + + latent_interp = latent_interp.squeeze(2).clamp(-3.0, 3.0) + + frame_tensor = decode_latents_to_image_tiled( + latent_interp, + vae, + tile_size=block_size, + overlap=overlap + ).clamp(0, 1) + + frame_tensor = normalize_frame(frame_tensor) + + if frame_tensor.ndim == 4: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()) + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=0.2)) + + # -------- UPSCALE -------- + upscale_factor = cfg.get("upscale_factor", 2) + if upscale_factor > 1: + frame_pil = frame_pil.resize( + (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), + resample=Image.BICUBIC + ) + + frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png") + frames_for_video.append(frame_pil) + + frame_counter += 1 + pbar.update(1) + + # ------------------------------------------- + + input_latents = input_latents.expand(-1, -1, num_fraps_per_image, -1, -1).clone() + + for pos_embeds, neg_embeds in embeddings: + for f in range(num_fraps_per_image): + + if f == 0: + frame_tensor = input_image.squeeze(0) + frame_tensor = (frame_tensor + 1.0) / 2.0 + frame_tensor = frame_tensor.clamp(0, 1) + latents_frame = None + else: + latents_frame = input_latents[:, :, f:f+1, :, :].clone() + + try: + latents_frame = generate_latents_robuste( + latents_frame, + pos_embeds, + neg_embeds, + unet, + scheduler, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=guidance_scale, + init_image_scale=init_image_scale, + creative_noise=creative_noise, + seed=frame_counter + ) + except Exception: + latents_frame = input_latents[:, :, f:f+1, :, :] + + latents_frame = latents_frame.squeeze(2).clamp(-3.0, 3.0) + + 
frame_tensor = decode_latents_to_image_tiled( + latents_frame, + vae, + tile_size=block_size, + overlap=overlap + ).clamp(0, 1) + + frame_tensor = normalize_frame(frame_tensor) + + if frame_tensor.ndim == 4: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()) + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=0.2)) + + # -------- UPSCALE -------- + upscale_factor = cfg.get("upscale_factor", 2) + if upscale_factor > 1: + frame_pil = frame_pil.resize( + (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), + resample=Image.BICUBIC + ) + + frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png") + frames_for_video.append(frame_pil) + + if latents_frame is not None: + mean_lat = latents_frame.abs().mean().item() + if math.isnan(mean_lat) or mean_lat < 1e-5: + print(f"⚠ Frame {frame_counter:05d} latent suspect") + + frame_counter += 1 + pbar.update(1) + + previous_latent_single = current_latent_single.clone() + + pbar.close() + save_frames_as_video(frames_for_video, out_video, fps=fps) + + print(f"🎬 Vidéo générée : {out_video}") + print("✅ Pipeline terminé proprement.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rfastmovie.py b/scripts/n3rfastmovie.py new file mode 100644 index 00000000..3c649104 --- /dev/null +++ b/scripts/n3rfastmovie.py @@ -0,0 +1,339 @@ +#-------------------------------------------------------------- +# nr3fastmovie - INTERPOLATION cinema movie +#-------------------------------------------------------------- + +import argparse +from pathlib import Path +from tqdm import tqdm 
LATENT_SCALE = 0.18215  # SD 1.5 VAE latent scaling factor


def normalize_frame(frame_tensor):
    """Map a decoded frame from [-1, 1] to [0, 1] if needed, then clamp to [0, 1]."""
    if frame_tensor.min() < 0:
        frame_tensor = (frame_tensor + 1.0) / 2.0
    return frame_tensor.clamp(0, 1)


def compute_overlap(W, H, block_size, max_overlap_ratio=0.6):
    """Overlap in pixels for tiled decoding: a ratio of block_size, capped at min(W, H) // 4."""
    overlap = int(block_size * max_overlap_ratio)
    return min(overlap, min(W, H) // 4)


def load_images(paths, W, H, device, dtype):
    """Load every image file in *paths* and stack the tensors on a new batch dimension."""
    tensors = []
    for p in paths:
        t = load_image_file(p, W, H, device, dtype)
        print(f"✅ Image chargée : {p}")
        tensors.append(t)
    return torch.stack(tensors, dim=0)


def save_frames_as_video(frames, output_path, fps=12):
    """Write PIL *frames* to an H.264 MP4 at *output_path* via ffmpeg.

    Fix: the original used a fixed ./temp_frames folder (wiped on entry), so
    two concurrent runs clobbered each other and a failed ffmpeg call leaked
    the folder. A unique temporary directory plus a finally-cleanup avoids both.
    """
    import ffmpeg
    import tempfile

    temp_dir = Path(tempfile.mkdtemp(prefix="frames_"))
    try:
        for idx, frame in enumerate(frames):
            frame.save(temp_dir / f"frame_{idx:05d}.png")
        (
            ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps)
            .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p")
            .overwrite_output()
            .run(quiet=True)
        )
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


def encode_images_to_latents(images, vae):
    """Encode *images* with the VAE; returns latents shaped [B, C, 1, H/8, W/8]."""
    images = images.to(device=vae.device, dtype=torch.float32)
    with torch.inference_mode():
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * LATENT_SCALE
        latents = latents.unsqueeze(2)  # add a singleton time axis
    return latents
def _tensor_to_frame_pil(frame_tensor, to_pil):
    """Normalise a decoded frame, drop the batch dim and convert to a lightly blurred PIL image."""
    frame_tensor = normalize_frame(frame_tensor)
    if frame_tensor.ndim == 4:
        frame_tensor = frame_tensor.squeeze(0)
    frame_pil = to_pil(frame_tensor.cpu())
    return frame_pil.filter(ImageFilter.GaussianBlur(radius=0.2))


def _upscale(frame_pil, upscale_factor):
    """Bicubic upscale of a PIL frame; no-op for factors <= 1."""
    if upscale_factor > 1:
        frame_pil = frame_pil.resize(
            (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor),
            resample=Image.BICUBIC,
        )
    return frame_pil


def main(args):
    """Run the fast-movie pipeline: per-image frame generation plus stylised latent transitions."""
    cfg = load_config(args.config)

    device = args.device if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if args.fp16 else torch.float32

    fps = cfg.get("fps", 12)
    transition_frames = cfg.get("transition_frames", 8)
    transition_zoom = cfg.get("transition_zoom", 0.0)
    # Historical config key is misspelled ("fraps"); also accept the correctly
    # spelled key used by the shipped YAML configs (num_frames_per_image).
    num_fraps_per_image = cfg.get("num_fraps_per_image",
                                  cfg.get("num_frames_per_image", 12))

    steps = cfg.get("steps", 50)
    guidance_scale = cfg.get("guidance_scale", 4.5)
    init_image_scale = cfg.get("init_image_scale", 0.85)
    creative_noise = cfg.get("creative_noise", 0.0)

    input_paths = cfg.get("input_images") or [cfg.get("input_image")]
    prompts = cfg.get("prompt", [])
    negative_prompts = cfg.get("n_prompt", [])

    total_frames = (
        len(input_paths) * num_fraps_per_image * max(len(prompts), 1)
        + max(len(input_paths) - 1, 0) * transition_frames
    )
    print(f"🎞 Frames totales estimées : {total_frames}")

    # ---------------- LOAD MODELS ----------------
    unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16)
    vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload)
    scheduler = safe_load_scheduler(args.pretrained_model_path)
    scheduler.set_timesteps(steps, device=device)

    motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None

    tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer"))
    text_encoder = CLIPTextModel.from_pretrained(
        os.path.join(args.pretrained_model_path, "text_encoder")
    ).to(device)
    if args.fp16:
        text_encoder = text_encoder.half()

    # ---------------- PROMPT EMBEDDINGS ----------------
    embeddings = []
    # The negative prompt is identical for every entry; build it once.
    neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts)
    for prompt_item in prompts:
        prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item)

        text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True,
                                max_length=tokenizer.model_max_length, return_tensors="pt")
        neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True,
                               max_length=tokenizer.model_max_length, return_tensors="pt")

        with torch.no_grad():
            pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
            neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state

        embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype)))

    # ---------------- OUTPUT ----------------
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(f"./outputs/fastmovie_{timestamp}")
    output_dir.mkdir(parents=True, exist_ok=True)
    out_video = output_dir / f"output_{timestamp}.mp4"

    to_pil = ToPILImage()
    frames_for_video = []
    frame_counter = 0
    pbar = tqdm(total=total_frames, ncols=120)

    previous_latent_single = None

    # ================= MAIN LOOP =================
    for img_idx, img_path in enumerate(input_paths):

        input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype)
        input_latents = encode_images_to_latents(input_image, vae)
        current_latent_single = input_latents.clone()

        block_size = cfg.get("block_size", 64)
        overlap = compute_overlap(cfg["W"], cfg["H"], block_size)

        # ================= CINEMA++ TRANSITION =================
        if previous_latent_single is not None and transition_frames > 0:
            print("🎬 Transition CINEMA++ ...")

            for t in range(transition_frames):
                # Cosine ease-in/out between the two source latents.
                # Guard the denominator: transition_frames == 1 divided by zero before.
                alpha = 0.5 - 0.5 * math.cos(math.pi * t / max(transition_frames - 1, 1))
                latent_interp = ((1 - alpha) * previous_latent_single
                                 + alpha * current_latent_single)

                # Stylise the morph with the UNet (first prompt only).
                for pos_embeds, neg_embeds in embeddings[:1]:
                    latent_interp = generate_latents_robuste(
                        latent_interp, pos_embeds, neg_embeds, unet, scheduler,
                        motion_module=motion_module, device=device, dtype=dtype,
                        guidance_scale=guidance_scale,
                        init_image_scale=0.6,   # looser guidance during the morph
                        creative_noise=0.02,    # slight artistic noise
                        seed=frame_counter,
                    )

                latent_interp = latent_interp.squeeze(2).clamp(-3.0, 3.0)
                frame_tensor = decode_latents_to_image_tiled(
                    latent_interp, vae, tile_size=block_size, overlap=overlap
                ).clamp(0, 1)
                frame_pil = _tensor_to_frame_pil(frame_tensor, to_pil)

                # -------- CINEMATIC ZOOM --------
                if transition_zoom > 0:
                    zoom_factor = 1.0 + transition_zoom * alpha
                    w, h = frame_pil.size
                    new_w, new_h = int(w * zoom_factor), int(h * zoom_factor)
                    frame_zoom = frame_pil.resize((new_w, new_h), Image.BICUBIC)
                    left = (new_w - w) // 2
                    top = (new_h - h) // 2
                    frame_pil = frame_zoom.crop((left, top, left + w, top + h))

                frame_pil = _upscale(frame_pil, cfg.get("upscale_factor", 2))
                frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                frames_for_video.append(frame_pil)

                frame_counter += 1
                pbar.update(1)

                del latent_interp, frame_tensor
                torch.cuda.empty_cache()

        # ----------------- FRAME GENERATION -----------------
        input_latents = input_latents.expand(-1, -1, num_fraps_per_image, -1, -1).clone()

        for pos_embeds, neg_embeds in embeddings:
            for f in range(num_fraps_per_image):

                if f == 0:
                    # First frame: the input image itself, mapped from [-1, 1] to [0, 1].
                    frame_tensor = ((input_image.squeeze(0) + 1.0) / 2.0).clamp(0, 1)
                    latents_frame = None
                else:
                    latents_frame = input_latents[:, :, f:f + 1, :, :].clone()
                    try:
                        latents_frame = generate_latents_robuste(
                            latents_frame, pos_embeds, neg_embeds, unet, scheduler,
                            motion_module=motion_module, device=device, dtype=dtype,
                            guidance_scale=guidance_scale,
                            init_image_scale=init_image_scale,
                            creative_noise=creative_noise,
                            seed=frame_counter,
                        )
                    except Exception as exc:
                        # Fall back to the raw encoded latent, but do not hide the failure
                        # (the original swallowed the exception silently).
                        print(f"⚠ generate_latents_robuste a échoué ({exc}); latent brut utilisé")
                        latents_frame = input_latents[:, :, f:f + 1, :, :]

                    latents_frame = latents_frame.squeeze(2).clamp(-3.0, 3.0)
                    frame_tensor = decode_latents_to_image_tiled(
                        latents_frame, vae, tile_size=block_size, overlap=overlap
                    ).clamp(0, 1)

                frame_pil = _tensor_to_frame_pil(frame_tensor, to_pil)
                frame_pil = _upscale(frame_pil, cfg.get("upscale_factor", 2))
                frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                frames_for_video.append(frame_pil)

                # Sanity check: warn on NaN or near-zero latents.
                if latents_frame is not None:
                    mean_lat = latents_frame.abs().mean().item()
                    if math.isnan(mean_lat) or mean_lat < 1e-5:
                        print(f"⚠ Frame {frame_counter:05d} latent suspect")

                frame_counter += 1
                pbar.update(1)

        previous_latent_single = current_latent_single.clone()

    pbar.close()
    save_frames_as_video(frames_for_video, out_video, fps=fps)

    print(f"🎬 Vidéo générée : {out_video}")
    print("✅ Pipeline terminé proprement.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    # NOTE(review): store_true with default=True means the flag can never be
    # disabled from the command line — confirm whether this is intended.
    parser.add_argument("--fp16", action="store_true", default=True)
    parser.add_argument("--vae-offload", action="store_true")
    args = parser.parse_args()
    main(args)
LATENT_SCALE = 0.18215  # SD 1.5 VAE latent scaling factor
stop_generation = False  # set by the watcher thread to request an early stop


def prepare_frame_tensor(frame_tensor):
    """Reduce a frame tensor to [C, H, W] so ToPILImage can consume it.

    Accepts [B, C, T, H, W], [B, C, H, W], [C, H, W] or channels-last [H, W, C].
    """
    if frame_tensor.ndim == 5:   # [B, C, T, H, W] -> [B, C, H, W] (T expected to be 1)
        frame_tensor = frame_tensor.squeeze(2)
    if frame_tensor.ndim == 4:   # [B, C, H, W] -> [C, H, W] (B expected to be 1)
        frame_tensor = frame_tensor.squeeze(0)
    if frame_tensor.ndim == 3 and frame_tensor.shape[0] != 3:
        # Heuristic: treat as channels-last [H, W, C] -> [C, H, W].
        # NOTE(review): misfires for 1-channel CHW input — confirm inputs are RGB.
        frame_tensor = frame_tensor.permute(2, 0, 1)
    return frame_tensor


# ---------------- Thread stop ----------------
def wait_for_stop():
    """Watcher thread: pressing '²' then Enter requests a clean stop."""
    global stop_generation
    try:
        inp = input("Appuyez sur '²' + Entrée pour arrêter : ")
    except (EOFError, OSError):
        # Non-interactive stdin (piped/daemonised run): nothing to watch.
        return
    if inp.lower() == "²":
        stop_generation = True


threading.Thread(target=wait_for_stop, daemon=True).start()
LATENT_SCALE = 0.18215  # SD 1.5 VAE latent scaling factor (module constant)


def normalize_frame(frame_tensor):
    """Map a decoded frame from [-1, 1] to [0, 1] if needed, then clamp to [0, 1]."""
    if frame_tensor.min() < 0:
        frame_tensor = (frame_tensor + 1.0) / 2.0
    return frame_tensor.clamp(0, 1)


def compute_overlap(W, H, block_size, max_overlap_ratio=0.6):
    """Overlap in pixels for tiled decoding: ratio of block_size, capped at min(W, H) // 4."""
    overlap = int(block_size * max_overlap_ratio)
    return min(overlap, min(W, H) // 4)


def load_images(paths, W, H, device, dtype):
    """Load every image file in *paths* and stack the tensors on a new batch dimension."""
    tensors = []
    for p in paths:
        t = load_image_file(p, W, H, device, dtype)
        print(f"✅ Image chargée : {p}")
        tensors.append(t)
    return torch.stack(tensors, dim=0)


def save_frames_as_video_from_folder(folder_path, output_path, fps=12):
    """Assemble the frame_*.png files in *folder_path* into an H.264 MP4 via ffmpeg."""
    import ffmpeg
    folder_path = Path(folder_path)
    frame_files = sorted(folder_path.glob("frame_*.png"))
    if not frame_files:
        print("❌ Aucun frame trouvé")
        return
    pattern = str(folder_path / "frame_*.png")
    (
        ffmpeg.input(pattern, framerate=fps, pattern_type='glob')
        .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p")
        .overwrite_output()
        .run(quiet=True)
    )


def encode_images_to_latents(images, vae):
    """Encode *images* with the VAE; returns latents shaped [B, C, 1, H/8, W/8]."""
    images = images.to(device=vae.device, dtype=torch.float32)
    with torch.inference_mode():
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * LATENT_SCALE
        latents = latents.unsqueeze(2)  # add a singleton time axis
    return latents


def decode_latents_to_image_auto_test(latents, vae):
    """Decode latents to images in [0, 1], auto-casting to the VAE's dtype/device."""
    if latents.ndim == 5:      # [B, C, T, H, W] -> drop the time axis
        latents = latents.squeeze(2)
    elif latents.ndim == 3:    # [C, H, W] -> add a batch axis
        latents = latents.unsqueeze(0)
    latents = latents.to(dtype=next(vae.parameters()).dtype, device=vae.device)
    with torch.no_grad():
        images = vae.decode(latents / LATENT_SCALE).sample
    images = ((images + 1.0) / 2.0).clamp(0, 1)
    return images
def decode_latents_to_image_auto_new(latents, vae):
    """Decode latents to images with NaN scrubbing, soft clamping and light smoothing.

    Returns a float image batch in [0, 1].

    Fix: the original smoothing used F.avg_pool2d(kernel_size=2, stride=1,
    padding=1), which grows each spatial side by one pixel (H -> H + 1), so
    decoded frames no longer matched the requested W x H. A 3x3 kernel with
    stride 1 and padding 1 keeps the spatial size while still averaging noise.
    """
    import torch.nn.functional as F

    # Scrub NaN/inf values before they poison the VAE.
    latents = torch.nan_to_num(latents, nan=0.0, posinf=4.0, neginf=-4.0)

    # Soft clamp into [-3, 3] — smoother than a hard clamp.
    latents = torch.tanh(latents) * 3.0

    # AnimateDiff layout: keep only the first temporal frame.
    if latents.ndim == 5:  # [B, C, T, H, W]
        latents = latents[:, :, 0, :, :]

    # SD 1.5 VAE scaling factor.
    latents = latents / 0.18215

    with torch.no_grad():
        image = vae.decode(latents).sample

    # Map from [-1, 1] to [0, 1].
    image = ((image + 1) / 2).clamp(0, 1)

    # Light anti-noise averaging; 3x3 with padding 1 preserves H and W.
    image = F.avg_pool2d(image, kernel_size=3, stride=1, padding=1)

    return image


def tensor_to_pil(frame_tensor):
    """Convert a [B, C, H, W] or [C, H, W] tensor in [0, 1] to a PIL image."""
    import torchvision.transforms as T

    if frame_tensor.ndim == 4:
        frame_tensor = frame_tensor[0]

    frame_tensor = frame_tensor.clamp(0, 1)

    return T.ToPILImage()(frame_tensor.cpu())
# ---------------- MAIN ----------------
def main(args):
    """Run the N3R AnimateDiff pipeline: load models, generate frames, assemble the video."""
    global stop_generation
    cfg = load_config(args.config)

    device = args.device if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if args.fp16 else torch.float32

    fps = cfg.get("fps", 12)
    upscale_factor = cfg.get("upscale_factor", 2)
    transition_frames = cfg.get("transition_frames", 8)
    # Historical key is misspelled ("fraps"); also accept the corrected
    # spelling used by the shipped YAML configs.
    num_fraps_per_image = cfg.get("num_fraps_per_image",
                                  cfg.get("num_frames_per_image", 12))
    steps = cfg.get("steps", 50)
    guidance_scale = cfg.get("guidance_scale", 4.5)
    init_image_scale = cfg.get("init_image_scale", 0.85)
    creative_noise = cfg.get("creative_noise", 0.0)

    # ---------------- LOAD UNET ----------------
    n3_model_name = args.n3_model
    n3_model_path = cfg["n3oray_models"].get(n3_model_name)
    if n3_model_path is None:
        raise ValueError(f"N3 model '{n3_model_name}' non défini dans le YAML")
    print(f"✅ Chargement du modèle N3 '{n3_model_name}' depuis : {n3_model_path}")

    unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=True)
    if hasattr(unet, "enable_attention_slicing"):
        unet.enable_attention_slicing()  # lower VRAM usage
    state_dict = load_file(n3_model_path, device=device)
    unet.load_state_dict(state_dict, strict=False)
    print(f"✅ UNET N3 '{n3_model_name}' chargé correctement")

    # ---------------- Scheduler ----------------
    scheduler = safe_load_scheduler(args.pretrained_model_path)
    scheduler.set_timesteps(steps, device=device)

    # ---------------- Motion module ----------------
    motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None
    print(f"✅ Motion module chargé")

    # ---------------- Tokenizer & Text Encoder ----------------
    tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer"))
    text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device)
    # NOTE(review): cast to fp16 unconditionally, regardless of args.fp16 — confirm intended.
    text_encoder = text_encoder.half()

    # ---------------- VAE on CPU ----------------
    vae_path = cfg.get("vae_path")
    if not vae_path:
        # Fail with a clear message instead of an opaque load_file(None) crash.
        raise ValueError("vae_path manquant dans la configuration")
    vae = AutoencoderKL(
        in_channels=3,
        out_channels=3,
        down_block_types=["DownEncoderBlock2D"] * 4,
        up_block_types=["UpDecoderBlock2D"] * 4,
        block_out_channels=[128, 256, 512, 512],
        latent_channels=4,
        layers_per_block=2,
        sample_size=256,
    )
    vae.load_state_dict(load_file(vae_path, device="cpu"), strict=False)
    vae.to("cpu").float()
    vae.enable_slicing()
    print(f"✅ VAE safetensors chargé depuis : {vae_path}")

    # ---------------- Embeddings ----------------
    prompts = cfg.get("prompt", [])
    negative_prompts = cfg.get("n_prompt", [])
    embeddings = []
    # The negative prompt is identical for every entry; build it once.
    neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts)
    for prompt_item in prompts:
        prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item)
        text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True,
                                max_length=tokenizer.model_max_length, return_tensors="pt")
        neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True,
                               max_length=tokenizer.model_max_length, return_tensors="pt")
        with torch.no_grad():
            pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
            neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state
        embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype)))

    # ---------------- Input images ----------------
    input_paths = cfg.get("input_images") or [cfg.get("input_image")]
    total_frames = (len(input_paths) * num_fraps_per_image
                    + max(len(input_paths) - 1, 0) * transition_frames)
    print(f"🎞 Frames totales estimées : {total_frames}")

    # ---------------- OUTPUT ----------------
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(f"./outputs/model_{timestamp}")
    output_dir.mkdir(parents=True, exist_ok=True)
    out_video = output_dir / f"output_{timestamp}.mp4"

    to_pil = ToPILImage()
    frame_counter = 0
    pbar = tqdm(total=total_frames, ncols=120)
    previous_latent_single = None
    stop_generation = False  # reset the global stop flag for this run

    # ---------------- Main loop ----------------
    for img_idx, img_path in enumerate(input_paths):
        if stop_generation:
            break
        input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype)
        input_latents_single = encode_images_to_latents(input_image, vae)
        input_latents = input_latents_single.repeat(1, 1, num_fraps_per_image, 1, 1)
        current_latent_single = input_latents_single.clone()
        block_size = cfg.get("block_size", 64)
        overlap = cfg.get("overlap", compute_overlap(cfg["W"], cfg["H"], block_size))

        # --- Latent transition between consecutive source images ---
        if previous_latent_single is not None and transition_frames > 0:
            for t in range(transition_frames):
                if stop_generation:
                    break
                # Cosine ease between the previous and current latent; guard the
                # denominator so transition_frames == 1 cannot divide by zero.
                alpha = 0.5 - 0.5 * math.cos(math.pi * t / max(transition_frames - 1, 1))
                latent_interp = ((1 - alpha) * previous_latent_single
                                 + alpha * current_latent_single)
                latent_interp = latent_interp.squeeze(2).clamp(-3.0, 3.0)

                frame_tensor = decode_latents_to_image_auto_new(latent_interp, vae)
                frame_tensor = normalize_frame(frame_tensor)
                frame_tensor = prepare_frame_tensor(frame_tensor)

                frame_pil = to_pil(frame_tensor.cpu())
                if t != 0:
                    # Light blur on every transition frame except the first.
                    frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(0.2))
                if upscale_factor > 1:
                    frame_pil = frame_pil.resize(
                        (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor),
                        Image.BICUBIC,
                    )
                frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")
                frame_counter += 1
                pbar.update(1)

        # --- Frame generation for the current image ---
        for pos_embeds, neg_embeds in embeddings:
            for f in range(num_fraps_per_image):
                if f == 0:
                    # First frame: the input image itself, mapped from [-1, 1] to [0, 1].
                    frame_tensor = ((input_image.squeeze(0) + 1.0) / 2.0).clamp(0, 1)
                else:
                    latents_frame = input_latents[:, :, f:f + 1, :, :].clone()
                    latents_frame = generate_latents_robuste(
                        latents_frame, pos_embeds, neg_embeds, unet, scheduler,
                        motion_module, device, dtype,
                        guidance_scale, init_image_scale, creative_noise, seed=frame_counter,
                    )
                    # Scrub NaN/inf and clamp before decoding.
                    latents_frame = torch.nan_to_num(latents_frame, nan=0.0, posinf=5.0, neginf=-5.0)
                    latents_frame = latents_frame.clamp(-5.0, 5.0)

                    frame_tensor = decode_latents_to_image_auto_new(latents_frame, vae)

                frame_tensor = normalize_frame(frame_tensor)
                frame_tensor = prepare_frame_tensor(frame_tensor)

                frame_pil = to_pil(frame_tensor.cpu())
                if f != 0:
                    frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(0.2))
                if upscale_factor > 1:
                    frame_pil = frame_pil.resize(
                        (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor),
                        Image.BICUBIC,
                    )
                frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png")

                frame_counter += 1
                if f != 0:
                    del latents_frame
                if frame_counter % 10 == 0:
                    torch.cuda.empty_cache()  # periodic VRAM housekeeping
                pbar.update(1)
        previous_latent_single = current_latent_single.clone()

    pbar.close()
    save_frames_as_video_from_folder(output_dir, out_video, fps=fps)
    print(f"🎬 Vidéo générée : {out_video}")
    print("✅ Pipeline terminé proprement.")


# ---------------- ENTRY ----------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    # NOTE(review): store_true with default=True means the flag can never be
    # disabled from the command line — confirm whether this is intended.
    parser.add_argument("--fp16", action="store_true", default=True)
    parser.add_argument("--vae-offload", action="store_true")
    # Choices seen in configs: cyber_skin, cyberpunk_style_v3, cybersamurai_v2.
    parser.add_argument("--n3_model", type=str, default="cyberpunk_style_v3")
    args = parser.parse_args()
    main(args)
LATENT_SCALE = 0.18215  # SD 1.5 VAE latent scaling factor
stop_generation = False  # set by the watcher thread to request an early stop


# ---------------- DEBUG UTILS ----------------
def log_debug(message, level="INFO", verbose=True):
    """Print *message* prefixed with its level when *verbose* is true.

    level: "INFO", "DEBUG" or "WARNING".
    """
    if verbose:
        print(f"[{level}] {message}")


# ---------------- Thread stop ----------------
def wait_for_stop():
    """Watcher thread: pressing '²' then Enter requests a clean stop."""
    global stop_generation
    try:
        inp = input("Appuyez sur '²' + Entrée pour arrêter : ")
    except (EOFError, OSError):
        # Non-interactive stdin (piped/daemonised run): nothing to watch.
        return
    if inp.lower() == "²":
        stop_generation = True


threading.Thread(target=wait_for_stop, daemon=True).start()


# ---------------- Utilities ----------------
def compute_overlap(W, H, block_size, max_overlap_ratio=0.6):
    """Overlap in pixels for tiled decoding: ratio of block_size, capped at min(W, H) // 4."""
    overlap = int(block_size * max_overlap_ratio)
    return min(overlap, min(W, H) // 4)


def apply_motion_safe(latents, motion_module, threshold=1e-2):
    """Apply *motion_module* to *latents* unless they are nearly all zero.

    Returns (latents, applied): *applied* tells the caller whether the module ran.
    """
    if latents.abs().max() < threshold:
        return latents, False
    return motion_module(latents), True
+ + + fps = cfg.get("fps", 12) + upscale_factor = cfg.get("upscale_factor", 1) + transition_frames = cfg.get("transition_frames", 4) + num_fraps_per_image = cfg.get("num_fraps_per_image", 2) + steps = max(cfg.get("steps", 16), 4) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + creative_noise = cfg.get("creative_noise", 0.0) + latent_scale_boost = cfg.get("latent_scale_boost", 5.71) + + + print("📌 Paramètres de génération :") + print(f" fps : {fps}") + print(f" upscale_factor : {upscale_factor}") + print(f" num_fraps_per_image : {num_fraps_per_image}") + print(f" steps : {steps}") + print(f" guidance_scale : {guidance_scale}") + print(f" init_image_scale : {init_image_scale}") + print(f" creative_noise : {creative_noise}") + print(f" latent_scale_boost : {latent_scale_boost}") + + scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) + scheduler.set_timesteps(steps, device=device) + + # ---------------- UNET ---------------- + unet = safe_load_unet(args.pretrained_model_path, device=device, fp16=True) + if hasattr(unet, "enable_attention_slicing"): unet.enable_attention_slicing() + if hasattr(unet, "enable_xformers_memory_efficient_attention"): + try: unet.enable_xformers_memory_efficient_attention(True) + except: pass + + # ---------------- LoRA ---------------- + unet_cross_attention_dim = getattr(unet.config, "cross_attention_dim", 768) + n3oray_models = cfg.get("n3oray_models") + if n3oray_models: + for model_name, lora_path in n3oray_models.items(): + applied = apply_lora_smart(unet, lora_path, alpha=0.5, device=device, verbose=verbose) + if not applied: + print(f"⚠ LoRA '{model_name}' ignorée (incompatible UNet)") + else: + print("⚠ Aucun modèle LoRA n'est configuré, étape ignorée.") + # ---------------- Motion module ---------------- + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") 
else None + #motion_module = None + if motion_module is not None: + log_debug(f"motion_module type: {type(motion_module)}", level="INFO", verbose=cfg.get("verbose", True)) + + + # ---------------- Tokenizer / Text encoder ---------------- + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path,"tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path,"text_encoder")).to(device).to(dtype) + + # ---------------- VAE ---------------- + vae_path = cfg.get("vae_path") + vae, vae_type, latent_channels, LATENT_SCALE = load_vae(vae_path, device=device, dtype=dtype) + + # ---------------- Embeddings ---------------- + # ---------------- Embeddings ---------------- + embeddings = [] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + # Récupération du cross_attention_dim attendu par le UNet + unet_cross_attention_dim = getattr(unet.config, "cross_attention_dim", 1024) + + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + + # ---------------- CORRECTION DES DIMENSIONS ---------------- + current_dim = pos_embeds.shape[-1] + if current_dim != unet_cross_attention_dim: + # Projection linéaire 768 -> 1024 + projection = torch.nn.Linear(current_dim, unet_cross_attention_dim).to(device).to(dtype) + pos_embeds = projection(pos_embeds) + neg_embeds = 
projection(neg_embeds) + + embeddings.append((pos_embeds, neg_embeds)) + + print(f"✅ Embeddings adaptées à UNet cross_attention_dim={unet_cross_attention_dim}") + + # ---------------- DEBUG DIMENSIONS ---------------- + print("\n🔍 Vérification des dimensions avant génération") + for i, (pos, neg) in enumerate(embeddings): + print(f"Embedding {i}: pos {pos.shape}, neg {neg.shape}") + if pos.shape[-1] != unet_cross_attention_dim: + log_debug(f"Attention : pos_embedding dim {pos.shape[-1]} != UNet {unet_cross_attention_dim}", level="WARNING", verbose=verbose) + if neg.shape[-1] != unet_cross_attention_dim: + print(f"⚠ Attention : neg_embedding dim {neg.shape[-1]} != UNet {unet_cross_attention_dim}") + + print(f"UNet cross_attention_dim attendu : {unet_cross_attention_dim}") + print("✅ Toutes les dimensions semblent correctes\n") + + + # ---------------- Input images ---------------- + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + total_frames = len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/modelFast_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + print(f"📌 fps: {fps}, frames/image: {num_fraps_per_image}, steps: {steps}, guidance_scale: {guidance_scale}") + + block_size = cfg.get("block_size", 64) + overlap = compute_overlap(cfg["W"], cfg["H"], block_size) + + previous_latent_single = None + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + #to_pil = ToPILImage() + + for img_idx, img_path in enumerate(input_paths): + if stop_generation: break + + # Charger et normaliser l'image + input_image = load_images_test([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + input_image = ensure_4_channels(input_image) + + # Encoder l'image en latents + current_latent_single = encode_images_to_latents_safe(input_image, vae, device=device, 
latent_scale=LATENT_SCALE) + + log_debug(f"Latents shape after encoding: {current_latent_single.shape}", level="DEBUG", verbose=verbose) + + + # ---------------- Ajuster la taille des latents pour UNet ---------------- + # UNet.sample_size correspond à la taille d'entrée attendue par le modèle (ex: 320 ou 512) + target_H = getattr(unet.config, "sample_size", cfg["H"]) // 8 + target_W = getattr(unet.config, "sample_size", cfg["W"]) // 8 + + # Interpolation bilinéaire pour correspondre à UNet + current_latent_single = torch.nn.functional.interpolate( + current_latent_single, size=(target_H, target_W), mode='bilinear', align_corners=False + ) + + # Assurer 4 channels (sécurité) + current_latent_single = ensure_4_channels(current_latent_single) + + log_debug(f"DEBUG latents shape after interpolation: {current_latent_single.shape}", level="DEBUG", verbose=verbose) + + # ---------------- Transition frames ---------------- + if previous_latent_single is not None and transition_frames > 0: + for t in range(transition_frames): + if stop_generation: break + alpha = 0.5 - 0.5*math.cos(math.pi*t/max(transition_frames-1,1)) + latent_interp = (1-alpha)*previous_latent_single + alpha*current_latent_single + if motion_module: + latent_interp, _ = apply_motion_safe(latent_interp, motion_module) + + # ⚡ Forcer la taille finale (VRAM-safe), exactement comme les autres latents + final_latent_H = int(cfg["H"] * final_latent_scale) + final_latent_W = int(cfg["W"] * final_latent_scale) + if latent_interp.shape[-2:] != (final_latent_H, final_latent_W): + latent_interp = torch.nn.functional.interpolate( + latent_interp, + size=(final_latent_H, final_latent_W), + mode='bilinear', + align_corners=False + ) + frame_pil = decode_latents_ultrasafe_blockwise(latent_interp, vae, + block_size=block_size, overlap=overlap, + gamma=1.0, brightness=1.0, + contrast=1.5, saturation=1.3, + device=device, frame_counter=frame_counter, + latent_scale_boost=latent_scale_boost) + frame_pil.save(output_dir / 
f"frame_{frame_counter:05d}.png") + frame_counter += 1 + pbar.update(1) + + # ---------------- Frames principales ---------------- + for pos_embeds, neg_embeds in embeddings: + for f in range(num_fraps_per_image): + if stop_generation: break + if f == 0: + # Frame initiale = image d'entrée + frame_tensor = torch.clamp((input_image.squeeze(0)+1)/2, 0, 1) + + # Upscale proportionnel à final_latent_scale + # On multiplie par final_latent_scale > 0 pour correspondre à la taille des latents + # Si final_latent_scale <1 → on agrandit légèrement, si >1 → on réduit légèrement + upscale_H = int(frame_tensor.shape[-2] * final_latent_scale * 8) # 8 = facteur latent->image + upscale_W = int(frame_tensor.shape[-1] * final_latent_scale * 8) + frame_tensor = torch.nn.functional.interpolate( + frame_tensor.unsqueeze(0), + size=(upscale_H, upscale_W), + mode='bilinear', + align_corners=False + ).squeeze(0) + + frame_pil = to_pil_image(frame_tensor) + else: + latents_frame = current_latent_single.clone() + + # ---------------- Redimension latents selon final_latent_scale avant UNet ---------------- + target_H = int(latents_frame.shape[-2] * final_latent_scale) + target_W = int(latents_frame.shape[-1] * final_latent_scale) + if latents_frame.shape[-2:] != (target_H, target_W): + latents_frame = torch.nn.functional.interpolate( + latents_frame, + size=(target_H, target_W), + mode='bilinear', + align_corners=False + ) + + # Assurer 4 canaux + latents_frame = ensure_4_channels(latents_frame) + + # ---------------- Génération latents ---------------- + cf_embeds = (pos_embeds.to(device), neg_embeds.to(device)) + if use_mini_gpu: + latents = generate_latents_mini_gpu_320( + unet=unet, + scheduler=scheduler, + input_latents=latents_frame, + embeddings=cf_embeds, + motion_module=motion_module, + guidance_scale=guidance_scale, + device=device, + fp16=True, + steps=steps, + debug=verbose, + init_image_scale=init_image_scale, # <-- ajouté + creative_noise=creative_noise # <-- ajouté + ) + # 
Injection contrôlée du latent d'entrée depuis le cfg + if latent_injection > 0.0: + latents = latent_injection * current_latent_single + (1 - latent_injection) * latents + else: + latents_input = latents_frame[:, :3, :, :] + latents = run_diffusion_pipeline( + unet=unet, + vae=vae, + scheduler=scheduler, + images=latents_input, + embeddings=cf_embeds, + timesteps=scheduler.timesteps, + device=device + ) + + # ---------------- Motion Module ---------------- + if motion_module: + latents, _ = apply_motion_safe(latents, motion_module) + + # ---------------- Interpolation vers latents finaux (VRAM-safe) ---------------- + final_latent_H = int(cfg["H"] * final_latent_scale) + final_latent_W = int(cfg["W"] * final_latent_scale) + if latents.shape[-2:] != (final_latent_H, final_latent_W): + latents = torch.nn.functional.interpolate( + latents, + size=(final_latent_H, final_latent_W), + mode='bilinear', + align_corners=False + ) + + # ---------------- Décodage bloc par bloc ---------------- + frame_pil = decode_latents_ultrasafe_blockwise( + latents, vae, + block_size=block_size, + overlap=overlap, + gamma=1.0, + brightness=1.0, + contrast=1.5, + saturation=1.3, + device=device, + frame_counter=frame_counter, + latent_scale_boost=latent_scale_boost + ) + + del latents + torch.cuda.empty_cache() + + # Appliquer un flou léger sur toute l'image + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=0.2)) + + # ---------------- Sauvegarde ---------------- + frame_pil.save(output_dir / f"frame_{frame_counter:05d}.png") + frame_counter += 1 + pbar.update(1) + + + previous_latent_single = current_latent_single + + pbar.close() + save_frames_as_video_from_folder(output_dir, out_video, fps=fps, upscale_factor=2) + print(f"🎬 Vidéo générée : {out_video}") + print("✅ Pipeline terminé avec motion module safe.") + +# ---------------- ENTRY ---------------- +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, 
required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rperfect.py b/scripts/n3rperfect.py new file mode 100644 index 00000000..40fde403 --- /dev/null +++ b/scripts/n3rperfect.py @@ -0,0 +1,266 @@ +#-------------------------------------------------------------- +# nr3perfect - INTERPOLATION fast movie - Optimal (version finale) +#-------------------------------------------------------------- + +import argparse +from pathlib import Path +from tqdm import tqdm +import torch +from datetime import datetime +import os +import math +from PIL import Image, ImageFilter +from torchvision.transforms import ToPILImage + +from transformers import CLIPTokenizerFast, CLIPTextModel + +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import ( + safe_load_unet, + safe_load_scheduler, + safe_load_vae_stable +) +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import generate_latents_robuste, load_image_file, decode_latents_to_image_auto +#import keyboard # pip install keyboard + +LATENT_SCALE = 0.18215 + +import threading + +stop_generation = False + +def wait_for_stop(): + global stop_generation + inp = input("Appuyez sur '²' + Entrée pour arrêter : ") + if inp.lower() == "²": + stop_generation = True + +# Lance le thread +threading.Thread(target=wait_for_stop, daemon=True).start() + +def normalize_frame(frame_tensor): + if frame_tensor.min() < 0: + frame_tensor = (frame_tensor + 1.0) / 2.0 + return frame_tensor.clamp(0, 1) + +def compute_overlap(W, H, block_size, max_overlap_ratio=0.6): + overlap = int(block_size * max_overlap_ratio) + overlap = min(overlap, min(W, H) // 4) + return overlap + +def load_images(paths, W, H, device, dtype): + 
all_tensors = [] + for p in paths: + t = load_image_file(p, W, H, device, dtype) + print(f"✅ Image chargée : {p}") + all_tensors.append(t) + return torch.stack(all_tensors, dim=0) + +def save_frames_as_video_from_folder(folder_path, output_path, fps=12): + import ffmpeg + folder_path = Path(folder_path) + # Tri alphabétique de tous les fichiers commençant par "frame_" + frame_files = sorted(folder_path.glob("frame_*.png")) + if not frame_files: + print("❌ Aucun frame trouvé dans le dossier") + return + + # ffmpeg peut utiliser un pattern, mais attention à l'ordre + first_frame = frame_files[0] + pattern = str(folder_path / "frame_*.png") + + ( + ffmpeg.input(pattern, framerate=fps, pattern_type='glob') + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + +def encode_images_to_latents(images, vae): + images = images.to(device=vae.device, dtype=torch.float32) + with torch.inference_mode(): + latents = vae.encode(images).latent_dist.sample() + latents = latents * LATENT_SCALE + latents = latents.unsqueeze(2) # [B, C, 1, H/8, W/8] + return latents + +# ---------------- MAIN ---------------- +def main(args): + cfg = load_config(args.config) + + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + fps = cfg.get("fps", 12) + upscale_factor = cfg.get("upscale_factor", 2) + transition_frames = cfg.get("transition_frames", 8) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + + steps = cfg.get("steps", 50) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + creative_noise = cfg.get("creative_noise", 0.0) + + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + total_frames = ( + len(input_paths) * num_fraps_per_image * max(len(prompts), 1) + + max(len(input_paths) - 1, 0) * transition_frames 
+ ) + print(f"🎞 Frames totales estimées : {total_frames}") + print("⏹ Touche '²' pour arrêter la génération et création de la vidéo directement...") + + # ---------------- LOAD MODELS ---------------- + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload) + scheduler = safe_load_scheduler(args.pretrained_model_path) + scheduler.set_timesteps(steps, device=device) + + motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None + + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + # ---------------- PROMPT EMBEDDINGS ---------------- + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + # ---------------- OUTPUT ---------------- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/fastperfect_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + to_pil = 
ToPILImage() + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + previous_latent_single = None + + + stop_generation = False # option d'arret du script on passe la saugarde de la video' + # ================= MAIN LOOP ================= + for img_idx, img_path in enumerate(input_paths): + if stop_generation: + break + + # Charge et encode l'image d'entrée + input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + input_latents_single = encode_images_to_latents(input_image, vae) # [B, C, 1, H/8, W/8] + + # Réplication des latents pour toutes les fraps + input_latents = input_latents_single.repeat(1, 1, num_fraps_per_image, 1, 1) # [B, C, num_fraps, H/8, W/8] + + current_latent_single = input_latents_single.clone() + block_size = cfg.get("block_size", 64) + overlap = compute_overlap(cfg["W"], cfg["H"], block_size) + + # --- Transition latente avec interpolation --- + if previous_latent_single is not None and transition_frames > 0: + for t in range(transition_frames): + if stop_generation: + print("⏹ Arrêt demandé, création de la vidéo...") + break + + # ... génération normale des frames ... 
+ alpha = 0.5 - 0.5 * math.cos(math.pi * t / (transition_frames - 1)) + latent_interp = ((1 - alpha) * previous_latent_single + alpha * current_latent_single) + latent_interp = latent_interp.squeeze(2).clamp(-3.0, 3.0) + + frame_tensor = decode_latents_to_image_auto(latent_interp, vae) + frame_tensor = normalize_frame(frame_tensor) + if frame_tensor.ndim == 4: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()).filter(ImageFilter.GaussianBlur(radius=0.2)) + if upscale_factor > 1: + frame_pil = frame_pil.resize((frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), resample=Image.BICUBIC) + + frame_path = output_dir / f"frame_{frame_counter:05d}.png" + frame_pil.save(frame_path) + + del latent_interp, frame_tensor, frame_pil + torch.cuda.empty_cache() + frame_counter += 1 + pbar.update(1) + + # --- Boucle principale sur fraps et prompts --- + for pos_embeds, neg_embeds in embeddings: + for f in range(num_fraps_per_image): + if f == 0: + frame_tensor = (input_image.squeeze(0) + 1.0) / 2.0 + frame_tensor = frame_tensor.clamp(0, 1) + else: + latents_frame = input_latents[:, :, f:f+1, :, :].clone() + try: + latents_frame = generate_latents_robuste( + latents_frame, + pos_embeds, + neg_embeds, + unet, + scheduler, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=guidance_scale, + init_image_scale=init_image_scale, + creative_noise=creative_noise, + seed=frame_counter + ) + except Exception: + latents_frame = input_latents[:, :, f:f+1, :, :].clone() + + frame_tensor = decode_latents_to_image_auto(latents_frame, vae) + frame_tensor = normalize_frame(frame_tensor) + if frame_tensor.ndim == 4: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()).filter(ImageFilter.GaussianBlur(radius=0.2)) + if upscale_factor > 1: + frame_pil = frame_pil.resize((frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), resample=Image.BICUBIC) + + frame_path = output_dir / 
f"frame_{frame_counter:05d}.png" + frame_pil.save(frame_path) + + if f != 0: + del latents_frame + del frame_tensor, frame_pil + torch.cuda.empty_cache() + frame_counter += 1 + pbar.update(1) + + previous_latent_single = current_latent_single.clone() + + pbar.close() + save_frames_as_video_from_folder(output_dir, out_video, fps=fps) + print(f"🎬 Vidéo générée : {out_video}") + print("✅ Pipeline terminé proprement.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + args = parser.parse_args() + main(args) diff --git a/scripts/n3rvideo2video.py b/scripts/n3rvideo2video.py new file mode 100644 index 00000000..8d6cb39e --- /dev/null +++ b/scripts/n3rvideo2video.py @@ -0,0 +1,476 @@ +# -------------------------------------------------------------- +# nr3perfect - INTERPOLATION fast movie - Optimal (video support) +# -------------------------------------------------------------- + +import argparse +from pathlib import Path +from tqdm import tqdm +import torch +from datetime import datetime +import os +import math +import cv2 +from PIL import Image, ImageFilter +from torchvision.transforms import ToPILImage +from transformers import CLIPTokenizerFast, CLIPTextModel +from PIL import ImageDraw, ImageFont +from scripts.utils.config_loader import load_config +from scripts.utils.vae_utils import ( + safe_load_unet, + safe_load_scheduler, + safe_load_vae_stable +) +from scripts.utils.motion_utils import load_motion_module +from scripts.utils.n3r_utils import ( + generate_latents_robuste, + load_image_file, + decode_latents_to_image_auto +) + +LATENT_SCALE = 0.18215 + + +# -------------------------------------------------------------- +# STOP 
THREAD +# -------------------------------------------------------------- + +import threading +import numpy as np +stop_generation = False + +def wait_for_stop(): + global stop_generation + inp = input("Appuyez sur '²' + Entrée pour arrêter : ") + if inp.lower() == "²": + stop_generation = True + +threading.Thread(target=wait_for_stop, daemon=True).start() + +# -------------------------------------------------------------- +# UTILS +# -------------------------------------------------------------- +# ---------------- TRACKING GLOBALS ---------------- +tracker = None +tracking_initialized = False + +# ---------------- WATERMARK ---------------- +import matplotlib.pyplot as plt + + +# ---------------- WATERMARK OPTIMISÉ ---------------- +def remove_watermark_with_cached_mask( + frame_pil, + target_hex_list, + candidate_zones, + tolerance=26, + threshold=0.4, + blur_radius=10, + feather_radius=8, + overlay_text=None, + text_opacity=0.6, + text_scale=0.7, + text_color=(255, 255, 255), + cached_mask=None, + force_recalc=False +): + """ + Retire le watermark avec masque mis en cache pour accélérer le traitement. 
+ - cached_mask: liste de tuples (x, y, w, h, mask_soft) pour réutilisation + - force_recalc: recalculer le masque même si cached_mask existe + """ + img_np = np.array(frame_pil).astype(np.int16) + H, W, _ = img_np.shape + + if cached_mask is None or force_recalc: + mask_store = [] + + target_colors = np.array( + [[int(h[i:i+2], 16) for i in (1, 3, 5)] for h in target_hex_list], + dtype=np.int16 + ) + + if candidate_zones is None: + candidate_zones = [(0, 0, W, H)] + + for (x, y, w, h) in candidate_zones: + patch = img_np[y:y+h, x:x+w] + + mask_total = np.zeros((h, w), dtype=np.uint8) + for color in target_colors: + dist = np.linalg.norm(patch - color, axis=2) + mask_total += (dist <= tolerance).astype(np.uint8) + + ratio = mask_total.sum() / (w * h) + if ratio >= threshold: + mask_binary = mask_total * 255 + mask_binary = cv2.GaussianBlur(mask_binary, (0, 0), feather_radius) + mask_binary = mask_binary.astype(np.float32) / 255.0 + else: + mask_binary = np.zeros((h, w), dtype=np.float32) + + mask_store.append((x, y, w, h, mask_binary)) + + cached_mask = mask_store + + # ---- APPLICATION DU MASK ---- + for (x, y, w, h, mask_soft) in cached_mask: + mask_soft_exp = np.expand_dims(mask_soft, axis=2).astype(np.float32) + + region = img_np[y:y+h, x:x+w].astype(np.float32) + region_blur = cv2.GaussianBlur(region, (0, 0), blur_radius) + + blended = region * (1 - mask_soft_exp) + region_blur * mask_soft_exp + img_np[y:y+h, x:x+w] = blended.astype(np.uint8) + + blended_img = Image.fromarray(img_np.astype(np.uint8)) + + # ---- TEXT OVERLAY ---- + if overlay_text is not None: + draw_layer = Image.new("RGBA", blended_img.size, (0,0,0,0)) + draw = ImageDraw.Draw(draw_layer) + + for (x, y, w, h, _) in cached_mask: + try: + font_size = int(h * text_scale) + font = ImageFont.truetype("arial.ttf", font_size) + except: + font = ImageFont.load_default() + + bbox = draw.textbbox((0,0), overlay_text, font=font) + text_w = bbox[2] - bbox[0] + text_h = bbox[3] - bbox[1] + + tx = x + (w - 
text_w)//2 + ty = y + (h - text_h)//2 + + draw.text( + (tx, ty), + overlay_text, + font=font, + fill=(*text_color, int(255 * text_opacity)) + ) + + blended_img = Image.alpha_composite(blended_img.convert("RGBA"), draw_layer).convert("RGB") + + return blended_img, cached_mask + +#--------------------------------------------------------------------------------------- +def normalize_frame(frame_tensor): + if frame_tensor.min() < 0: + frame_tensor = (frame_tensor + 1.0) / 2.0 + return frame_tensor.clamp(0, 1) + +def compute_overlap(W, H, block_size, max_overlap_ratio=0.6): + overlap = int(block_size * max_overlap_ratio) + overlap = min(overlap, min(W, H) // 4) + return overlap + +def load_images(paths, W, H, device, dtype): + all_tensors = [] + for p in paths: + t = load_image_file(p, W, H, device, dtype) + print(f"✅ Image chargée : {p}") + all_tensors.append(t) + return torch.stack(all_tensors, dim=0) + + +def load_images_from_pil(pil_images, W, H, device, dtype, preproc_fn=None): + """ + Charge une liste de PIL.Images et retourne un tensor 4D [B, C, H, W]. 
+ - preproc_fn : fonction optionnelle à appliquer après resize et avant conversion tensor + """ + all_tensors = [] + + for idx, img_pil in enumerate(pil_images): + # Redimensionner + img_resized = img_pil.resize((W, H), Image.LANCZOS) + + # Appliquer un pré-traitement optionnel (ex: flou watermark) + if preproc_fn is not None: + img_resized = preproc_fn(img_resized) + + # Convertir en numpy 0..1 + img_np = np.array(img_resized).astype(np.float32) / 255.0 + # Convertir en tensor torch [C, H, W] + img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).to(device=device, dtype=dtype) + # Normaliser [-1, 1] + img_tensor = img_tensor * 2 - 1 + all_tensors.append(img_tensor) + print(f"✅ Image chargée et préparée : {idx}") + + return torch.stack(all_tensors, dim=0) + + +def extract_frames_from_video(video_path, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + cap = cv2.VideoCapture(video_path) + frame_paths = [] + idx = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + frame_path = output_dir / f"video_frame_{idx:05d}.png" + cv2.imwrite(str(frame_path), frame) + frame_paths.append(str(frame_path)) + idx += 1 + + cap.release() + print(f"🎬 {idx} frames extraites depuis la vidéo.") + return frame_paths + +def save_frames_as_video_from_folder(folder_path, output_path, fps=12): + import ffmpeg + folder_path = Path(folder_path) + + frame_files = sorted(folder_path.glob("frame_*.png")) + if not frame_files: + print("❌ Aucun frame trouvé.") + return + + pattern = str(folder_path / "frame_*.png") + + ( + ffmpeg.input(pattern, framerate=fps, pattern_type='glob') + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + +def encode_images_to_latents(images, vae): + images = images.to(device=vae.device, dtype=torch.float32) + with torch.inference_mode(): + latents = vae.encode(images).latent_dist.sample() + latents = latents * LATENT_SCALE + latents = 
latents.unsqueeze(2) + return latents + +# -------------------------------------------------------------- +# MAIN +# -------------------------------------------------------------- + +def main(args): + + cfg = load_config(args.config) + + device = args.device if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if args.fp16 else torch.float32 + + fps = cfg.get("fps", 12) + upscale_factor = cfg.get("upscale_factor", 2) + transition_frames = cfg.get("transition_frames", 8) + num_fraps_per_image = cfg.get("num_fraps_per_image", 12) + rm_watermark = cfg.get("rm_watermark", True) + + steps = cfg.get("steps", 50) + guidance_scale = cfg.get("guidance_scale", 4.5) + init_image_scale = cfg.get("init_image_scale", 0.85) + creative_noise = cfg.get("creative_noise", 0.0) + + # ---------------------------------------------------------- + # INPUT IMAGE / VIDEO + # ---------------------------------------------------------- + + if args.input_video: + print("🎥 Mode vidéo activé") + temp_dir = Path("./temp_video_frames") + input_paths = extract_frames_from_video(args.input_video, temp_dir) + else: + input_paths = cfg.get("input_images") or [cfg.get("input_image")] + + prompts = cfg.get("prompt", []) + negative_prompts = cfg.get("n_prompt", []) + + total_frames = len(input_paths) * max(len(prompts), 1) * num_fraps_per_image + if rm_watermark: + print(f"🎞 Remove Water Active") + + print(f"🎞 Frames estimées : {total_frames}") + print("⏹ Appuyez sur '²' + Entrée pour stopper.") + + # ---------------------------------------------------------- + # LOAD MODELS + # ---------------------------------------------------------- + + unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16) + vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload) + scheduler = safe_load_scheduler(args.pretrained_model_path) + scheduler.set_timesteps(steps, device=device) + + motion_module = load_motion_module(cfg.get("motion_module"), 
device=device) if cfg.get("motion_module") else None + + tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer")) + text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device) + if args.fp16: + text_encoder = text_encoder.half() + + # ---------------------------------------------------------- + # PROMPT EMBEDDINGS + # ---------------------------------------------------------- + + embeddings = [] + for prompt_item in prompts: + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts) + + text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True, + max_length=tokenizer.model_max_length, return_tensors="pt") + + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + + embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype))) + + # ---------------------------------------------------------- + # OUTPUT + # ---------------------------------------------------------- + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"./outputs/fastperfect_{timestamp}") + output_dir.mkdir(parents=True, exist_ok=True) + out_video = output_dir / f"output_{timestamp}.mp4" + + to_pil = ToPILImage() + frame_counter = 0 + pbar = tqdm(total=total_frames, ncols=120) + + # ---------------------------------------------------------- + # MAIN LOOP + # ---------------------------------------------------------- + # Variables pour watermark mis en cache + watermark_recalc_every = 10 + last_mask = None + frame_index = 0 + 
#----------------------------------------------------------- + for img_path in input_paths: + + if stop_generation: + break + + #----------- fonction remove watermark auto + if rm_watermark: + colors = ["#EADBDE", "#FDF8F4", "#FFFAF9", "#F9EDEB", "#C5BABA", + "#FBF8FA", "#FAF8F9", "#FBFAF5", "#ECE3E2", "#6D6C6B", + "#F9F3EF", "#DDCAC6", "#DDCAC6", "#F3EBEB", "#FDF8F6", + "#FFFAFF", "#FDFCFB", "#FAF8F8", "#F9F4F7", "#222021"] + candidate_zones = [ + (21, 366, 123, 28), + (421, 572, 125, 28) + ] + + # Chargement PIL + pil_img = Image.open(img_path).convert("RGB") + + # Recalcul du watermark tous les watermark_recalc_every frames + force_recalc = (frame_index % watermark_recalc_every == 0) + + pil_img, last_mask = remove_watermark_with_cached_mask( + pil_img, + target_hex_list=colors, + candidate_zones=candidate_zones, + overlay_text="N3ORAY", + text_opacity=0.9, + text_scale=4.0, + cached_mask=last_mask, + force_recalc=force_recalc + ) + + # Convertir en tensor + input_image = load_images_from_pil( + [pil_img], + W=cfg["W"], + H=cfg["H"], + device=device, + dtype=dtype + ) + else: + input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype) + + input_latents_single = encode_images_to_latents(input_image, vae) + input_latents = input_latents_single.repeat(1, 1, num_fraps_per_image, 1, 1) + + for pos_embeds, neg_embeds in embeddings: + for f in range(num_fraps_per_image): + + if stop_generation: + break + + if f == 0: + frame_tensor = (input_image.squeeze(0) + 1.0) / 2.0 + else: + latents_frame = input_latents[:, :, f:f+1, :, :].clone() + try: + latents_frame = generate_latents_robuste( + latents_frame, + pos_embeds, + neg_embeds, + unet, + scheduler, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=guidance_scale, + init_image_scale=init_image_scale, + creative_noise=creative_noise, + seed=frame_counter + ) + except: + pass + + frame_tensor = decode_latents_to_image_auto(latents_frame, vae) + del 
latents_frame + + frame_tensor = normalize_frame(frame_tensor) + if frame_tensor.ndim == 4: + frame_tensor = frame_tensor.squeeze(0) + + frame_pil = to_pil(frame_tensor.cpu()).filter(ImageFilter.GaussianBlur(radius=0.2)) + + # ----------------------------------------------------------------------------- + if upscale_factor > 1: + frame_pil = frame_pil.resize( + (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), + resample=Image.BICUBIC + ) + + frame_path = output_dir / f"frame_{frame_counter:05d}.png" + frame_pil.save(frame_path) + + del frame_tensor, frame_pil + #torch.cuda.empty_cache() + + frame_counter += 1 + pbar.update(1) + + # Nettoyage VRAM toutes les 20 frames + if frame_counter % 20 == 0: + torch.cuda.empty_cache() + frame_index += 1 + pbar.close() + + save_frames_as_video_from_folder(output_dir, out_video, fps=fps) + print(f"🎬 Vidéo générée : {out_video}") + print("✅ Pipeline terminé proprement.") + +# -------------------------------------------------------------- +# ARGPARSE +# -------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained-model-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--fp16", action="store_true", default=True) + parser.add_argument("--vae-offload", action="store_true") + parser.add_argument("--input-video", type=str, default=None) + args = parser.parse_args() + + main(args) diff --git a/scripts/n3rvideotovideo.py b/scripts/n3rvideotovideo.py new file mode 100644 index 00000000..aa084f6a --- /dev/null +++ b/scripts/n3rvideotovideo.py @@ -0,0 +1,473 @@ +# -------------------------------------------------------------- +# nr3perfect - INTERPOLATION fast movie - Optimal (video support) +# -------------------------------------------------------------- + +import argparse +from pathlib 
def remove_watermark_auto_blur(
    frame_pil,
    target_hex_list,
    tolerance=26,
    threshold=0.4,
    candidate_zones=None,
    blur_radius=10,
    feather_radius=8,
    overlay_text=None,
    text_opacity=0.6,
    text_scale=0.7,
    text_color=(255, 255, 255)
):
    """Detect and blur a watermark inside known candidate zones of a frame.

    Each zone is scanned for pixels close to any colour in ``target_hex_list``
    (Euclidean RGB distance <= ``tolerance``).  When the fraction of matching
    pixels reaches ``threshold``, the zone is blurred through a feathered mask
    and, optionally, ``overlay_text`` is drawn on top of it.

    Args:
        frame_pil: PIL image; modified in place and returned.
        target_hex_list: list of "#RRGGBB" strings describing the watermark.
        tolerance: per-pixel RGB distance allowed for a colour match.
        threshold: minimum matching-pixel ratio for a zone to be treated.
        candidate_zones: list of (x, y, w, h) boxes; whole image if None.
        blur_radius: Gaussian blur strength applied to the matched region.
        feather_radius: softness of the blending mask edges.
        overlay_text / text_opacity / text_scale / text_color: optional
            replacement text drawn centred in the treated zone.

    Returns:
        The (possibly modified) PIL image.
    """
    import numpy as np
    import cv2
    from PIL import Image, ImageFilter, ImageDraw, ImageFont

    img_np = np.array(frame_pil).astype(np.int16)
    H, W, _ = img_np.shape

    target_colors = np.array(
        [[int(h[i:i+2], 16) for i in (1, 3, 5)] for h in target_hex_list],
        dtype=np.int16
    )

    if candidate_zones is None:
        candidate_zones = [(0, 0, W, H)]

    for (x, y, w, h) in candidate_zones:

        patch = img_np[y:y+h, x:x+w]

        # BUG FIX: matches are accumulated with a logical OR instead of
        # summing per-colour uint8 masks: summing counted a pixel once per
        # matching colour, inflating `ratio` above the true coverage fraction.
        mask_total = np.zeros((h, w), dtype=bool)

        for color in target_colors:
            dist = np.linalg.norm(patch - color, axis=2)
            mask_total |= dist <= tolerance

        ratio = mask_total.sum() / (w * h)

        if ratio >= threshold:

            mask_binary = mask_total.astype(np.uint8) * 255

            # Grow the mask slightly so watermark edges are fully covered.
            kernel = np.ones((3, 3), np.uint8)
            mask_binary = cv2.dilate(mask_binary, kernel, iterations=1)

            # Feather the mask so the blur blends smoothly into the frame.
            mask_soft = cv2.GaussianBlur(mask_binary, (0, 0), feather_radius)
            mask_soft = mask_soft.astype(np.float32) / 255.0
            mask_soft = np.expand_dims(mask_soft, axis=2)

            region = frame_pil.crop((x, y, x+w, y+h))
            region_np = np.array(region).astype(np.float32)

            region_blur = region.filter(ImageFilter.GaussianBlur(radius=blur_radius))
            region_blur_np = np.array(region_blur).astype(np.float32)

            blended = region_np * (1 - mask_soft) + region_blur_np * mask_soft
            blended = blended.astype(np.uint8)

            blended_img = Image.fromarray(blended)

            # ------------------ TEXT OVERLAY ------------------
            if overlay_text is not None:

                draw = ImageDraw.Draw(blended_img)

                # BUG FIX: the bare `except:` also swallowed KeyboardInterrupt
                # and SystemExit; only font-loading failures should trigger
                # the default-font fallback.
                try:
                    font_size = int(h * text_scale)
                    font = ImageFont.truetype("arial.ttf", font_size)
                except OSError:
                    font = ImageFont.load_default()

                bbox = draw.textbbox((0, 0), overlay_text, font=font)
                text_w = bbox[2] - bbox[0]
                text_h = bbox[3] - bbox[1]

                tx = (w - text_w) // 2
                ty = (h - text_h) // 2

                # Draw on a transparent layer so `text_opacity` is honoured.
                text_layer = Image.new("RGBA", blended_img.size, (0, 0, 0, 0))
                text_draw = ImageDraw.Draw(text_layer)

                text_draw.text(
                    (tx, ty),
                    overlay_text,
                    font=font,
                    fill=(*text_color, int(255 * text_opacity))
                )

                blended_img = Image.alpha_composite(
                    blended_img.convert("RGBA"),
                    text_layer
                ).convert("RGB")

            frame_pil.paste(blended_img, (x, y))

    return frame_pil
#---------------------------------------------------------------------------------------
def normalize_frame(frame_tensor):
    """Return the frame clamped to [0, 1]; [-1, 1] inputs are rescaled first."""
    needs_rescale = frame_tensor.min() < 0
    scaled = (frame_tensor + 1.0) * 0.5 if needs_rescale else frame_tensor
    return scaled.clamp(0, 1)

def compute_overlap(W, H, block_size, max_overlap_ratio=0.6):
    """Overlap (in latent pixels), capped at a quarter of the smaller side."""
    cap = min(W, H) // 4
    return min(int(block_size * max_overlap_ratio), cap)

def load_images(paths, W, H, device, dtype):
    """Load image files into one stacked tensor of shape [B, C, H, W]."""
    tensors = []
    for path in paths:
        tensors.append(load_image_file(path, W, H, device, dtype))
        print(f"✅ Image chargée : {path}")
    return torch.stack(tensors, dim=0)


def load_images_from_pil(pil_images, W, H, device, dtype, preproc_fn=None):
    """Convert a list of PIL images into a normalized tensor batch [B, C, H, W].

    Each image is resized to (W, H), optionally run through `preproc_fn`
    (e.g. a watermark blur) after the resize, scaled to [0, 1] and finally
    mapped to the [-1, 1] range expected by the VAE.
    """
    batch = []

    for idx, pil_img in enumerate(pil_images):
        resized = pil_img.resize((W, H), Image.LANCZOS)

        if preproc_fn is not None:
            resized = preproc_fn(resized)

        # PIL -> float32 numpy in [0, 1] -> torch [C, H, W] on the target device.
        as_float = np.asarray(resized, dtype=np.float32) / 255.0
        tensor = torch.from_numpy(as_float).permute(2, 0, 1).to(device=device, dtype=dtype)
        # Map [0, 1] -> [-1, 1].
        tensor = tensor * 2 - 1
        batch.append(tensor)
        print(f"✅ Image chargée et préparée : {idx}")

    return torch.stack(batch, dim=0)
def save_frames_as_video_from_folder(folder_path, output_path, fps=12):
    """Assemble every frame_*.png of a folder into an H.264 MP4 via ffmpeg."""
    import ffmpeg
    folder_path = Path(folder_path)

    if not sorted(folder_path.glob("frame_*.png")):
        print("❌ Aucun frame trouvé.")
        return

    stream = ffmpeg.input(str(folder_path / "frame_*.png"),
                          framerate=fps, pattern_type='glob')
    stream = stream.output(str(output_path), vcodec="libx264", pix_fmt="yuv420p")
    stream.overwrite_output().run(quiet=True)

def encode_images_to_latents(images, vae):
    """Encode an image batch into scaled VAE latents with a frame axis.

    Returns latents shaped [B, C_latent, 1, H/8, W/8]; the singleton third
    dimension is the frame axis the animation pipeline expects.
    """
    images = images.to(device=vae.device, dtype=torch.float32)
    with torch.inference_mode():
        encoded = vae.encode(images).latent_dist.sample()
    # Scale to the diffusion latent space, then add the frame dimension.
    return (encoded * LATENT_SCALE).unsqueeze(2)
def main(args):
    """Run the image/video → animated-video pipeline.

    Loads the models described by the YAML config, optionally strips a
    watermark from each input frame, generates `num_fraps_per_image`
    latent frames per input through the UNet/scheduler, decodes and saves
    each frame, then assembles everything into an MP4.

    Args:
        args: argparse namespace with pretrained_model_path, config,
            device, fp16, vae_offload and optional input_video.
    """

    cfg = load_config(args.config)

    device = args.device if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if args.fp16 else torch.float32

    fps = cfg.get("fps", 12)
    upscale_factor = cfg.get("upscale_factor", 2)
    transition_frames = cfg.get("transition_frames", 8)  # NOTE(review): currently unused
    num_fraps_per_image = cfg.get("num_fraps_per_image", 12)
    rm_watermark = cfg.get("rm_watermark", True)

    steps = cfg.get("steps", 50)
    guidance_scale = cfg.get("guidance_scale", 4.5)
    init_image_scale = cfg.get("init_image_scale", 0.85)
    creative_noise = cfg.get("creative_noise", 0.0)

    # ----------------------------------------------------------
    # INPUT IMAGE / VIDEO
    # ----------------------------------------------------------

    if args.input_video:
        print("🎥 Mode vidéo activé")
        temp_dir = Path("./temp_video_frames")
        input_paths = extract_frames_from_video(args.input_video, temp_dir)
    else:
        input_paths = cfg.get("input_images") or [cfg.get("input_image")]

    prompts = cfg.get("prompt", [])
    negative_prompts = cfg.get("n_prompt", [])

    total_frames = len(input_paths) * max(len(prompts), 1) * num_fraps_per_image
    if rm_watermark:
        print(f"🎞 Remove Water Active")

    print(f"🎞 Frames estimées : {total_frames}")
    print("⏹ Appuyez sur '²' + Entrée pour stopper.")

    # ----------------------------------------------------------
    # LOAD MODELS
    # ----------------------------------------------------------

    unet = safe_load_unet(args.pretrained_model_path, device, fp16=args.fp16)
    vae = safe_load_vae_stable(args.pretrained_model_path, device, fp16=args.fp16, offload=args.vae_offload)
    scheduler = safe_load_scheduler(args.pretrained_model_path)
    scheduler.set_timesteps(steps, device=device)

    motion_module = load_motion_module(cfg.get("motion_module"), device=device) if cfg.get("motion_module") else None

    tokenizer = CLIPTokenizerFast.from_pretrained(os.path.join(args.pretrained_model_path, "tokenizer"))
    text_encoder = CLIPTextModel.from_pretrained(os.path.join(args.pretrained_model_path, "text_encoder")).to(device)
    if args.fp16:
        text_encoder = text_encoder.half()

    # ----------------------------------------------------------
    # PROMPT EMBEDDINGS
    # ----------------------------------------------------------

    embeddings = []
    for prompt_item in prompts:
        prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item)
        neg_text = " ".join(negative_prompts) if isinstance(negative_prompts, list) else str(negative_prompts)

        text_inputs = tokenizer(prompt_text, padding="max_length", truncation=True,
                                max_length=tokenizer.model_max_length, return_tensors="pt")
        neg_inputs = tokenizer(neg_text, padding="max_length", truncation=True,
                               max_length=tokenizer.model_max_length, return_tensors="pt")

        with torch.no_grad():
            pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state
            neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state

        embeddings.append((pos_embeds.to(dtype), neg_embeds.to(dtype)))

    # ----------------------------------------------------------
    # OUTPUT
    # ----------------------------------------------------------

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(f"./outputs/fastperfect_{timestamp}")
    output_dir.mkdir(parents=True, exist_ok=True)
    out_video = output_dir / f"output_{timestamp}.mp4"

    to_pil = ToPILImage()
    frame_counter = 0
    pbar = tqdm(total=total_frames, ncols=120)

    # Watermark palette / zones are constants: hoisted out of the frame loop
    # (they were previously rebuilt for every input image).
    colors = ["#EADBDE", "#FDF8F4", "#FFFAF9", "#F9EDEB", "#C5BABA",
              "#FBF8FA", "#FAF8F9", "#FBFAF5", "#ECE3E2", "#6D6C6B",
              "#F9F3EF", "#DDCAC6", "#DDCAC6", "#F3EBEB", "#FDF8F6",
              "#FFFAFF", "#FDFCFB", "#FAF8F8", "#F9F4F7", "#222021"]
    candidate_zones = [
        (21, 366, 123, 28),
        (421, 572, 125, 28)
    ]

    def watermark_blur_preproc(img):
        # Applied after the resize so the hard-coded zones match the frame size.
        return remove_watermark_auto_blur(
            img,
            target_hex_list=colors,
            candidate_zones=candidate_zones,
            overlay_text="N3ORAY",
            text_opacity=0.9,
            text_scale=4.0
        )

    # ----------------------------------------------------------
    # MAIN LOOP
    # ----------------------------------------------------------

    for img_path in input_paths:

        if stop_generation:
            break

        if rm_watermark:
            input_image = load_images_from_pil(
                [Image.open(img_path).convert("RGB")],
                W=cfg["W"],
                H=cfg["H"],
                device=device,
                dtype=dtype,
                preproc_fn=watermark_blur_preproc
            )
        else:
            input_image = load_images([img_path], W=cfg["W"], H=cfg["H"], device=device, dtype=dtype)

        input_latents_single = encode_images_to_latents(input_image, vae)
        input_latents = input_latents_single.repeat(1, 1, num_fraps_per_image, 1, 1)

        for pos_embeds, neg_embeds in embeddings:
            for f in range(num_fraps_per_image):

                if stop_generation:
                    break

                if f == 0:
                    # Frame 0 is the (preprocessed) input image itself.
                    frame_tensor = (input_image.squeeze(0) + 1.0) / 2.0
                else:
                    latents_frame = input_latents[:, :, f:f+1, :, :].clone()
                    # BUG FIX: a bare `except: pass` silently decoded the
                    # un-denoised latents on any failure; report the error so
                    # corrupted frames are traceable (the original latents are
                    # still decoded, keeping the best-effort behaviour).
                    try:
                        latents_frame = generate_latents_robuste(
                            latents_frame,
                            pos_embeds,
                            neg_embeds,
                            unet,
                            scheduler,
                            motion_module=motion_module,
                            device=device,
                            dtype=dtype,
                            guidance_scale=guidance_scale,
                            init_image_scale=init_image_scale,
                            creative_noise=creative_noise,
                            seed=frame_counter
                        )
                    except Exception as e:
                        print(f"⚠️ generate_latents_robuste failed on frame {frame_counter}: {e}")

                    frame_tensor = decode_latents_to_image_auto(latents_frame, vae)
                    del latents_frame

                frame_tensor = normalize_frame(frame_tensor)
                if frame_tensor.ndim == 4:
                    frame_tensor = frame_tensor.squeeze(0)

                # Very light blur softens decode artefacts.
                frame_pil = to_pil(frame_tensor.cpu()).filter(ImageFilter.GaussianBlur(radius=0.2))

                if upscale_factor > 1:
                    frame_pil = frame_pil.resize(
                        (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor),
                        resample=Image.BICUBIC
                    )

                frame_path = output_dir / f"frame_{frame_counter:05d}.png"
                frame_pil.save(frame_path)

                del frame_tensor, frame_pil

                frame_counter += 1
                pbar.update(1)

                # Free VRAM every 20 frames.
                if frame_counter % 20 == 0:
                    torch.cuda.empty_cache()

    pbar.close()

    save_frames_as_video_from_folder(output_dir, out_video, fps=fps)
    print(f"🎬 Vidéo générée : {out_video}")
    print("✅ Pipeline terminé proprement.")
def apply_post_processing(frame_pil,
                          blur_radius=0.1,
                          contrast=1.3,
                          brightness=1.1,
                          saturation=0.6,
                          sharpen=True,
                          sharpen_radius=1,
                          sharpen_percent=80,
                          sharpen_threshold=3):
    """Apply post-decode effects to a PIL frame.

    Order of operations: Gaussian blur, contrast, brightness, saturation,
    then an optional unsharp mask.  A factor of 1.0 (or a blur radius of 0)
    leaves the corresponding step untouched.
    """
    if blur_radius > 0:
        frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    # Each (enhancer, factor) pair is skipped when the factor is neutral.
    adjustments = (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, saturation),
    )
    for enhancer_cls, factor in adjustments:
        if factor != 1.0:
            frame_pil = enhancer_cls(frame_pil).enhance(factor)

    if sharpen:
        frame_pil = frame_pil.filter(ImageFilter.UnsharpMask(
            radius=sharpen_radius,
            percent=sharpen_percent,
            threshold=sharpen_threshold
        ))

    return frame_pil
def test_parameter_grid_with_boost(
    latents_motion, vae, img_orig,
    gammas=(1.0, 1.1, 1.2, 1.5),
    contrasts=(1.1, 1.2, 1.3, 1.5),
    saturations=(1.0, 1.1, 1.2, 1.5),
    brightnesses=(1.0,),
    epsilon=1e-5,
    latent_scale_boost=5.71,
    device="cuda"
):
    """Grid-search decode corrections against a reference image.

    Every (gamma, contrast, saturation, brightness) combination is decoded
    with `decode_latents_ultrasafe_blockwise` (using `epsilon` and
    `latent_scale_boost`) and scored against `img_orig` with
    `compare_images_stats_v2`.

    Args:
        latents_motion: latents to decode.
        vae: VAE used for decoding.
        img_orig: reference image (PIL, tensor, or a 1-element list/tuple).
        gammas / contrasts / saturations / brightnesses: values to sweep.
        epsilon: zero-patch safety value forwarded to the decoder.
        latent_scale_boost: latent scaling forwarded to the decoder.
        device: device for decoding and comparison.

    Returns:
        List of stats dicts sorted by ascending mean_diff_total (best first).
    """
    # BUG FIX: the defaults were mutable lists (a classic Python pitfall);
    # tuples are behaviour-identical for itertools.product and immutable.
    results = []

    # Normalise img_orig to a float32 tensor on `device`.
    if isinstance(img_orig, (list, tuple)):
        img_orig_tensor = img_orig[0]
    else:
        img_orig_tensor = img_orig
    if isinstance(img_orig_tensor, Image.Image):
        img_orig_tensor = torch.tensor(np.array(img_orig_tensor) / 255.0).permute(2, 0, 1)
    img_orig_tensor = img_orig_tensor.to(device=device, dtype=torch.float32)

    for gamma, contrast, saturation, brightness in itertools.product(gammas, contrasts, saturations, brightnesses):
        decoded_img = decode_latents_ultrasafe_blockwise(
            latents_motion, vae,
            gamma=gamma,
            contrast=contrast,
            saturation=saturation,
            brightness=brightness,
            device=device,
            epsilon=epsilon,
            latent_scale_boost=latent_scale_boost
        )

        # The decoder returns a list for batched input; keep the first frame.
        if isinstance(decoded_img, list):
            decoded_img = decoded_img[0]

        # Resize the decoded PIL image to the reference resolution.
        decoded_img_resized = decoded_img.resize(
            (img_orig_tensor.shape[2], img_orig_tensor.shape[1]), Image.BICUBIC
        )

        stats = compare_images_stats_v2(img_orig_tensor, decoded_img_resized, device=device)
        stats.update({
            "gamma": gamma,
            "contrast": contrast,
            "saturation": saturation,
            "brightness": brightness,
            "epsilon": epsilon,
            "latent_scale_boost": latent_scale_boost,
        })
        results.append(stats)

    # Best fidelity (lowest mean difference) first.
    return sorted(results, key=lambda s: s["mean_diff_total"])
def compare_images_stats_v2(img_orig, img_decoded, threshold=0.05, device="cuda"):
    """Compare two images and return per-channel and global difference stats.

    Accepts PIL images or tensors shaped [C,H,W] (a leading batch dimension
    is dropped); the decoded image is resized to the original's resolution
    before comparison.

    Args:
        img_orig / img_decoded: PIL.Image or torch tensor with values in [0, 1].
        threshold: per-pixel max-channel difference above which a pixel counts
            as "different" for percent_diff_pixels.
        device: device the comparison runs on ('cuda' or 'cpu').

    Returns:
        dict with mean_diff_r/g/b, mean_diff_total and percent_diff_pixels.
    """

    # Convert non-tensor inputs (e.g. PIL) to float tensors in [0, 1].
    # Checking "not a tensor" instead of isinstance(…, Image.Image) keeps the
    # helper independent of the module-level PIL import.
    if not isinstance(img_orig, torch.Tensor):
        img_orig = torch.tensor(np.array(img_orig) / 255.0).permute(2, 0, 1)
    if not isinstance(img_decoded, torch.Tensor):
        img_decoded = torch.tensor(np.array(img_decoded) / 255.0).permute(2, 0, 1)

    # Drop an optional batch dimension [B,C,H,W] -> [C,H,W].
    if img_orig.ndim == 4:
        img_orig = img_orig[0]
    if img_decoded.ndim == 4:
        img_decoded = img_decoded[0]

    C, H, W = img_orig.shape
    if img_decoded.shape[1:] != (H, W):
        # BUG FIX: the module-level torchvision alias `F` is shadowed by a
        # later `import torch.nn.functional as F`, so the original
        # `F.resize(...)` raised AttributeError here. Use
        # torch.nn.functional.interpolate directly instead.
        img_decoded = torch.nn.functional.interpolate(
            img_decoded.unsqueeze(0).float(), size=(H, W),
            mode="bilinear", align_corners=False, antialias=True
        ).squeeze(0)

    img_orig = img_orig.to(device=device, dtype=torch.float32)
    img_decoded = img_decoded.to(device=device, dtype=torch.float32)

    diff = torch.abs(img_orig - img_decoded)
    # BUG FIX: .view() can fail on non-contiguous tensors (e.g. after the
    # resize above); .reshape() copies only when needed.
    mean_diff_per_channel = diff.reshape(3, -1).mean(dim=1).cpu().numpy()
    mean_diff_total = diff.mean().item()
    percent_diff = 100 * (diff.max(dim=0)[0] > threshold).sum().item() / (diff.shape[1] * diff.shape[2])

    return {
        "mean_diff_r": mean_diff_per_channel[0],
        "mean_diff_g": mean_diff_per_channel[1],
        "mean_diff_b": mean_diff_per_channel[2],
        "mean_diff_total": mean_diff_total,
        "percent_diff_pixels": percent_diff
    }
def compare_images_stats(img_orig, img_decoded, threshold=0.05, device="cuda"):
    """Compare two same-sized images (PIL or tensor [C,H,W]) and return stats.

    Args:
        img_orig / img_decoded: PIL.Image or torch tensor in [0, 1]; both
            must have the same shape after conversion.
        threshold: per-pixel max-channel difference above which a pixel
            counts as "different" for percent_diff_pixels.
        device: device the comparison runs on.

    Returns:
        dict with mean_diff_r/g/b, mean_diff_total and percent_diff_pixels.
    """
    # Convert non-tensor inputs (e.g. PIL) to float tensors in [0, 1].
    if not isinstance(img_orig, torch.Tensor):
        img_orig = torch.tensor(np.array(img_orig) / 255.0).permute(2, 0, 1)
    if not isinstance(img_decoded, torch.Tensor):
        img_decoded = torch.tensor(np.array(img_decoded) / 255.0).permute(2, 0, 1)

    # BUG FIX: tensor inputs were never moved to `device`, so mixing a PIL
    # input (which was moved) with a tensor input crashed on a cross-device
    # subtraction. Move both unconditionally.
    img_orig = img_orig.to(device)
    img_decoded = img_decoded.to(device)

    assert img_orig.shape == img_decoded.shape, "Les images doivent avoir la même taille"

    diff = torch.abs(img_orig - img_decoded)
    # .reshape() instead of .view(): safe on non-contiguous tensors.
    mean_diff_per_channel = diff.reshape(3, -1).mean(dim=1).cpu().numpy()
    mean_diff_total = diff.mean().item()
    percent_diff = 100 * (diff.max(dim=0)[0] > threshold).sum().item() / (diff.shape[1] * diff.shape[2])

    return {
        "mean_diff_r": mean_diff_per_channel[0],
        "mean_diff_g": mean_diff_per_channel[1],
        "mean_diff_b": mean_diff_per_channel[2],
        "mean_diff_total": mean_diff_total,
        "percent_diff_pixels": percent_diff
    }
def compare_images_stats_v1(img_orig, img_decoded, threshold=0.05):
    """Compare two same-sized images on CPU and return difference statistics.

    CPU-forced variant of compare_images_stats: both inputs are moved to the
    CPU before comparison to avoid cross-device RuntimeErrors.

    Args:
        img_orig / img_decoded: PIL.Image or torch tensor in [0, 1].
        threshold: per-pixel max-channel difference above which a pixel
            counts as "different".

    Returns:
        dict with mean_diff_r/g/b, mean_diff_total and percent_diff_pixels.
    """
    # Convert non-tensor inputs (e.g. PIL) to float tensors in [0, 1].
    # Checking "not a tensor" avoids a hard dependency on the module-level
    # PIL import.
    if not isinstance(img_orig, torch.Tensor):
        img_orig = torch.tensor(np.array(img_orig) / 255.0).permute(2, 0, 1)
    if not isinstance(img_decoded, torch.Tensor):
        img_decoded = torch.tensor(np.array(img_decoded) / 255.0).permute(2, 0, 1)

    # Force CPU to avoid the RuntimeError on mixed-device inputs.
    img_orig = img_orig.cpu()
    img_decoded = img_decoded.cpu()

    assert img_orig.shape == img_decoded.shape, "Les images doivent avoir la même taille"

    diff = torch.abs(img_orig - img_decoded)
    # .reshape() instead of .view(): safe on non-contiguous tensors.
    mean_diff_per_channel = diff.reshape(3, -1).mean(dim=1).numpy()
    mean_diff_total = diff.mean().item()
    percent_diff = 100 * (diff.max(dim=0)[0] > threshold).sum().item() / (diff.shape[1] * diff.shape[2])

    return {
        "mean_diff_r": mean_diff_per_channel[0],
        "mean_diff_g": mean_diff_per_channel[1],
        "mean_diff_b": mean_diff_per_channel[2],
        "mean_diff_total": mean_diff_total,
        "percent_diff_pixels": percent_diff
    }
@torch.no_grad()
def decode_latents_ultrasafe_blockwise_test(latents, vae, block_size=32, device='cuda', dtype=torch.float16):
    """Decode latents tile-by-tile with a cosine blending window.

    Args:
        latents: tensor [B, 4, H_latent, W_latent].
        vae: VAE whose ``decode(patch).sample`` upsamples a latent patch 8x.
        block_size: tile size in latent space.
        device: 'cuda' or 'cpu'.
        dtype: compute dtype (torch.float16 to speed up on GPU).

    Returns:
        Blended RGB tensor [B, 3, H_latent*8, W_latent*8] clamped to [-1, 1].
    """

    B, C, H, W = latents.shape
    latents = latents.to(device, dtype=dtype)

    # Output image size (the VAE upsamples by a factor of 8).
    out_h, out_w = H * 8, W * 8

    output_rgb = torch.zeros((B, 3, out_h, out_w), device=device, dtype=dtype)
    weight = torch.zeros_like(output_rgb)

    # Cosine blending window spanning the whole output.
    # BUG FIX: cos(±π/2) evaluates to ~0 and was clamped to exactly 0,
    # giving zero total weight on the border rows/columns and hence NaN
    # after the final division. Clamp to a small positive floor instead.
    yy = torch.linspace(-torch.pi / 2, torch.pi / 2, out_h, device=device)
    xx = torch.linspace(-torch.pi / 2, torch.pi / 2, out_w, device=device)
    wy = torch.cos(yy).clamp(min=1e-4)
    wx = torch.cos(xx).clamp(min=1e-4)
    weight_patch_full = torch.outer(wy, wx)[None, None, :, :]  # [1,1,H,W]

    # Decode each latent tile and accumulate it with its window weight.
    for y0 in range(0, H, block_size):
        y1 = min(y0 + block_size, H)
        for x0 in range(0, W, block_size):
            x1 = min(x0 + block_size, W)

            latent_patch = latents[:, :, y0:y1, x0:x1]

            decoded = vae.decode(latent_patch).sample  # [B,3,h*8,w*8]
            decoded = decoded.to(dtype=dtype)

            ih0, ih1 = y0 * 8, y1 * 8
            iw0, iw1 = x0 * 8, x1 * 8
            weight_patch = weight_patch_full[:, :, ih0:ih1, iw0:iw1]

            output_rgb[:, :, ih0:ih1, iw0:iw1] += decoded * weight_patch
            weight[:, :, ih0:ih1, iw0:iw1] += weight_patch

    # Normalise by the accumulated weights (strictly positive after the fix).
    output_rgb /= weight
    output_rgb = output_rgb.clamp(-1.0, 1.0)

    return output_rgb
# Version précédente fonctionnel ********************************
def decode_latents_ultrasafe_blockwise(
    latents, vae,
    block_size=32, overlap=16,
    gamma=1.2, brightness=1.0,
    contrast=1.2, saturation=1.3,
    device="cuda", frame_counter=0, output_dir=Path("."),
    epsilon=1e-6,
    latent_scale_boost=1.1  # slight boost to recover colour nuance
):
    """Ultra-safe blockwise decode of latents into PIL image(s).

    The latent grid is decoded in overlapping blocks that are averaged
    together, then gamma / brightness / contrast / saturation corrections
    are applied. Optimised to preserve colour nuance and reduce the
    "photocopy" look.

    Args:
        latents: tensor [B, 4, H_latent, W_latent].
        vae: VAE whose ``decode(patch).sample`` upsamples a latent patch 8x.
        block_size / overlap: latent tile size and overlap between tiles.
        gamma / brightness / contrast / saturation: post-decode corrections.
        device: decode device.
        frame_counter / output_dir: kept for debugging/logging compatibility.
        epsilon: value added to all-zero patches so the VAE never sees a
            fully degenerate input.
        latent_scale_boost: multiplier applied to the latents before decode.

    Returns:
        A single PIL image when B == 1, otherwise a list of PIL images.
    """

    B, C, H, W = latents.shape
    latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost

    out_H = H * 8
    out_W = W * 8
    output_rgb = torch.zeros(B, 3, out_H, out_W, device=device)
    weight = torch.zeros_like(output_rgb)

    stride = block_size - overlap

    # Block origins guaranteeing full coverage (the last block is clamped
    # to the bottom/right edge).
    # NOTE(review): assumes block_size <= H and block_size <= W; smaller
    # latents would produce negative origins — confirm against callers.
    y_positions = list(range(0, H - block_size + 1, stride)) or [0]
    x_positions = list(range(0, W - block_size + 1, stride)) or [0]

    if y_positions[-1] != H - block_size:
        y_positions.append(H - block_size)
    if x_positions[-1] != W - block_size:
        x_positions.append(W - block_size)

    for y in y_positions:
        for x in x_positions:
            y1 = y + block_size
            x1 = x + block_size

            patch = latents[:, :, y:y1, x:x1]

            # Safety: scrub NaN / Inf, and nudge all-zero patches by epsilon.
            patch = torch.nan_to_num(patch, nan=0.0, posinf=5.0, neginf=-5.0)
            if torch.all(patch == 0):
                patch += epsilon

            with torch.no_grad():
                decoded = vae.decode(patch).sample.to(torch.float32)

            # Accumulate the decoded tile; overlaps are averaged below.
            iy0, ix0 = y * 8, x * 8
            iy1, ix1 = y1 * 8, x1 * 8
            output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded
            weight[:, :, iy0:iy1, ix0:ix1] += 1.0

    # Average overlapping contributions.
    output_rgb = output_rgb / weight.clamp(min=1e-6)
    output_rgb = output_rgb.clamp(-1.0, 1.0)

    # Convert to PIL and apply gamma / contrast / saturation / brightness.
    # BUG FIX: `F.to_pil_image` raised AttributeError because the module's
    # torchvision alias `F` is shadowed by a later
    # `import torch.nn.functional as F`. Use the unshadowed ToPILImage
    # transform instead (and move to CPU first so CUDA tensors convert).
    to_pil = ToPILImage()
    frame_pil_list = []
    for i in range(B):
        img = to_pil(((output_rgb[i] + 1) / 2).cpu())  # [-1,1] -> [0,1]
        img = ImageEnhance.Brightness(img).enhance(brightness)
        img = ImageEnhance.Contrast(img).enhance(contrast)
        img = ImageEnhance.Color(img).enhance(saturation)
        if gamma != 1.0:
            img = img.point(lambda x: 255 * ((x / 255) ** (1 / gamma)))
        frame_pil_list.append(img)

    return frame_pil_list[0] if B == 1 else frame_pil_list
def encode_images_to_latents_nuanced(images, vae, device="cuda", latent_scale=LATENT_SCALE):
    """Encode an image batch into VAE latents while preserving colour nuance.

    Uses the latent distribution *mean* (deterministic and more stable than
    sampling), applies `latent_scale`, scrubs NaN/Inf values without any
    global normalisation, and pads single-channel latents to the 4 channels
    the rest of the pipeline expects.
    """
    images = images.to(device=device, dtype=torch.float32)
    vae = vae.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        # Mean of the latent distribution: keeps the encoding stable.
        raw_latents = vae.encode(images).latent_dist.mean

    # Scale, then scrub NaN/Inf only — the dynamic range is kept intact.
    latents = torch.nan_to_num(raw_latents * latent_scale,
                               nan=0.0, posinf=5.0, neginf=-5.0)

    # The downstream VAE pipeline expects 4 latent channels.
    if latents.ndim == 4 and latents.shape[1] == 1:
        latents = latents.repeat(1, 4, 1, 1)

    return latents
latents_motion = ton tenseur [1, 4, F, H, W] ou [1, 4, H, W] +# vae = ton VAE chargé +# device = "cuda" ou "cpu" + +# Paramètres boostés +gamma = 1.0 # 1.2 +brightness = 1.0 # 1.0 +contrast = 1.5 # 1.2 +saturation = 1.5 # 1.2 +upscale_factor = 2 +frame_counter = 0 # pour debug/log si nécessaire + +# Décodage ultra-safe (blockwise comme ton code) +frame_pil = decode_latents_ultrasafe_blockwise( + latents_motion, vae, + block_size=32, overlap=24, + gamma=gamma, + brightness=brightness, + contrast=contrast, + saturation=saturation, + device=device, + frame_counter=frame_counter, + output_dir=Path("."), + epsilon=1e-5, + latent_scale_boost=5.71 +) + + +#frame_pil = decode_latents_ultrasafe_blockwise_test(latents_motion, vae, block_size=32, device='cuda', dtype=torch.float16) + +# Upscale pour debug visuel +if upscale_factor > 1: + frame_pil = frame_pil.resize( + (frame_pil.width * upscale_factor, frame_pil.height * upscale_factor), + Image.BICUBIC + ) + +# Affichage rapide +frame_pil.show() + + +# ------------------------------------------------------------- +# Exemple d'utilisation après ton décodage normal +# ------------------------------------------------------------- +#results = test_parameter_grid(latents_motion, vae, image) + + +#print("Top 5 configurations les plus proches de l'image originale - test_parameter_grid :") +#for r in results[:5]: +# print(r) + +results = test_parameter_grid_with_boost(latents_motion, vae, image) + + +print("Top 5 configurations les plus proches de l'image originale - test_parameter_grid_with_boost:") +for r in results[:5]: + print(r) diff --git a/scripts/utils/Readme.txt b/scripts/utils/Readme.txt new file mode 100644 index 00000000..3986dbd1 --- /dev/null +++ b/scripts/utils/Readme.txt @@ -0,0 +1,34 @@ +All function utils + +Sample run: +python -m scripts.n3rHYBRID10 \ + --pretrained-model-path "/mnt/62G/huggingface/miniSD" \ + --config configs/prompts/2_animate/128.yaml \ + --device cuda +📌 Paramètres : fps=12, frames/image=12, 
steps=12, seed=1234 +⏱ Durée totale estimée : 5.0s +🔄 Chargement tokenizer et text_encoder +✅ Text encoder OK +✅ State dict VAE chargé, clés: ['decoder.conv_in.bias', 'decoder.conv_in.weight', 'decoder.conv_out.bias', 'decoder.conv_out.weight', 'decoder.mid.attn_1.k.bias'] +🔎 Latent shape: torch.Size([1, 4, 32, 32]) +🔎 Decoded shape: torch.Size([1, 3, 256, 256]) +✅ Test VAE 256 OK +✅ VAE OK +✅ UNet + Scheduler OK +✅ Motion module (Python) loaded and instantiated: scripts/modules/motion_module_tiny.py +✅ Image chargée : input/image_128x0.png +✅ Image chargée : input/image_128x1.png +✅ Image chargée : input/image_128x2.png +✅ Image chargée : input/image_128x3.png +✅ Image chargée : input/image_128x4.png +✅ Génération terminée. + + + +python -m scripts.n3rHYBRID11 \ + --pretrained-model-path "/mnt/62G/huggingface/miniSD" \ + --config configs/prompts/2_animate/256_quality.yaml \ + --device cuda \ + --vae-offload \ + --fp16 + diff --git a/scripts/utils/config_loader.py b/scripts/utils/config_loader.py new file mode 100644 index 00000000..feff389e --- /dev/null +++ b/scripts/utils/config_loader.py @@ -0,0 +1,14 @@ +# scripts/utils/config_loader.py +import yaml + +def load_config(path): + cfg_main = yaml.safe_load(open(path)) + + inference_cfg_path = cfg_main.get("inference_config") + if inference_cfg_path: + cfg_infer = yaml.safe_load(open(inference_cfg_path)) + # Merge : priorité au YAML principal + for k, v in cfg_infer.items(): + if k not in cfg_main: + cfg_main[k] = v + return cfg_main diff --git a/scripts/utils/fx_utils.py b/scripts/utils/fx_utils.py new file mode 100644 index 00000000..cfa1fad5 --- /dev/null +++ b/scripts/utils/fx_utils.py @@ -0,0 +1,855 @@ +# fx_utils.py +# ------------------- ENCODE ------------------- + +import os, math, threading +from pathlib import Path +from PIL import Image, ImageFilter, ImageEnhance, ImageChops + +import torch +import numpy as np +import subprocess +from torchvision.transforms import functional as F +import 
torch.nn.functional as Fu +import torch.nn.functional as TF +import torch.nn.functional as FF +LATENT_SCALE = 0.18215 + + + +def compress_highlights(frame_pil, threshold=235, strength=0.6): + import numpy as np + arr = np.array(frame_pil).astype("float32") + + # luminance approx + lum = 0.299 * arr[:,:,0] + 0.587 * arr[:,:,1] + 0.114 * arr[:,:,2] + + mask = lum > threshold + + # compression douce + factor = 1.0 - strength * ((lum - threshold) / (255 - threshold)) + factor = np.clip(factor, 0.6, 1.0) + + arr[mask] = arr[mask] * factor[mask][:, None] + + arr = np.clip(arr, 0, 255).astype("uint8") + return Image.fromarray(arr) + + +def remove_white_noise(frame_pil, threshold=254, blur_radius=0.1): + """ + Atténue les pixels trop blancs (artefacts) par lissage local. + threshold : valeur RGB au-dessus de laquelle un pixel est considéré bruit + blur_radius : rayon du lissage local + """ + frame_rgb = frame_pil.convert("RGB") + # créer un masque des pixels trop blancs + mask = frame_rgb.point(lambda i: 255 if i > threshold else 0) + mask = mask.convert("L") + # flouter la zone bruyante + blurred = frame_rgb.filter(ImageFilter.GaussianBlur(blur_radius)) + # fusionner uniquement sur le masque + cleaned = Image.composite(blurred, frame_rgb, mask) + return cleaned + + + +def apply_post_processing_unreal_smooth(frame_pil, + contrast=1.15, + vibrance=1.05, + edge_strength=1.5, + simplify_radius=0.8, + smooth_radius=0.05, + sharpen_percent=70): + # 1️⃣ Unreal (volume + bords) + frame_pil = apply_post_processing_unreal_safe( + frame_pil, + contrast=contrast, + vibrance=vibrance, + edge_strength=edge_strength, + simplify_radius=simplify_radius + ) + + # 2️⃣ Lissage adaptatif (smoothing léger, préserve contours) + # On utilise un GaussianBlur très léger et on peut mélanger avec l'image originale pour contrôler le lissage + frame_blur = frame_pil.filter(ImageFilter.GaussianBlur(radius=smooth_radius)) + frame_pil = Image.blend(frame_pil, frame_blur, alpha=0.35) # alpha <1 pour ne 
pas tout écraser + + # 3️⃣ Adaptive / final tweaks + frame_pil = apply_post_processing_adaptive( + frame_pil, + blur_radius=0.0, + contrast=1.05, + brightness=1.05, + saturation=0.90, + vibrance_base=1.0, + vibrance_max=1.05, + sharpen=True, + sharpen_radius=1, + sharpen_percent=sharpen_percent, + sharpen_threshold=2 + ) + + # 4️⃣ Clamp final pour éviter pixels blancs + frame_pil = frame_pil.point(lambda i: max(0, min(255, int(i)))) + + return frame_pil + + +def apply_post_processing_unreal_safe( + frame_pil, + blur_radius=0.01, # 🔽 plus subtil + contrast=1.08, # 🔽 réduit + brightness=1.02, + saturation=0.98, + vibrance=1.02, + edge_boost=True, + edge_strength=0.4, # 🔥 énorme différence + simplify_radius=0.4, # 🔽 moins de blur destructif + sharpen=True, + sharpen_radius=0.8, + sharpen_percent=60, # 🔽 moins agressif + sharpen_threshold=3 +): + """ + Version douce : rendu naturel, moins de pixelisation + """ + from PIL import ImageFilter, ImageEnhance, ImageChops + import numpy as np + + # 1️⃣ Micro-smooth (léger) + if simplify_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=simplify_radius)) + + # 2️⃣ Vibrance douce (continue, pas de seuil) + if vibrance != 1.0: + arr = np.array(frame_pil).astype(np.float32) + mean = arr.mean(axis=2, keepdims=True) + arr = mean + (arr - mean) * vibrance + frame_pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)) + + # 3️⃣ Ajustements globaux doux + frame_pil = ImageEnhance.Brightness(frame_pil).enhance(brightness) + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + frame_pil = ImageEnhance.Color(frame_pil).enhance(saturation) + + # 4️⃣ Edge boost subtil (blend au lieu de add) + if edge_boost: + gray = frame_pil.convert("L") + edge = gray.filter(ImageFilter.FIND_EDGES) + edge = ImageEnhance.Contrast(edge).enhance(1.2) + + edge_rgb = Image.merge("RGB", (edge, edge, edge)) + + # 🔥 blend au lieu de add → beaucoup plus naturel + frame_pil = Image.blend(frame_pil, edge_rgb, edge_strength) 
+ + # 5️⃣ Sharpen léger (micro détails seulement) + if sharpen: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + + # 6️⃣ Micro blur final (anti pixel) + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + return frame_pil + + +def apply_post_processing_unreal(frame_pil, + blur_radius=0.02, + contrast=1.2, + brightness=1.05, + saturation=1.0, + vibrance=1.1, + sharpen=True, + sharpen_radius=2, + sharpen_percent=150, + sharpen_threshold=2, + edge_boost=True, + edge_strength=1.5, + texture_simplify=True, + simplify_radius=1.0): + """ + Post-processing de type 'Unreal' pour donner plus de volume et styliser les images. + """ + + # ----------------- 1) Lissage / simplification textures ----------------- + if texture_simplify and simplify_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=simplify_radius)) + + # ----------------- 2) Contraste / Luminosité / Saturation ----------------- + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + if brightness != 1.0: + frame_pil = ImageEnhance.Brightness(frame_pil).enhance(brightness) + if saturation != 1.0: + frame_pil = ImageEnhance.Color(frame_pil).enhance(saturation) + + # ----------------- 3) Vibrance (boost couleurs faibles) ----------------- + if vibrance != 1.0: + frame_hsv = frame_pil.convert("HSV") + h, s, v = frame_hsv.split() + s = s.point(lambda i: min(255, int(i * vibrance) if i < 128 else i)) + frame_pil = Image.merge("HSV", (h, s, v)).convert("RGB") + + # ----------------- 4) Edge Enhance / Relief ----------------- + if edge_boost: + # Edge enhance + high contrast overlay pour effet 'bordelands / unreal' + edge = frame_pil.filter(ImageFilter.FIND_EDGES) + edge = ImageEnhance.Contrast(edge).enhance(edge_strength) + frame_pil = ImageChops.add(frame_pil, edge, scale=1.0, offset=0) + + # ----------------- 5) Sharp 
----------------- + if sharpen: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + + # ----------------- 6) Option blur final léger ----------------- + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + return frame_pil + + +# ---------------- Fusion N3R/VAE adaptative ---------------- + +def fuse_n3r_latents_adaptive(latents_frame, n3r_latents, latent_injection=0.7, clamp_val=1.0, creative_noise=0.0): + # Assurer même taille + if n3r_latents.shape != latents_frame.shape: + n3r_latents = torch.nn.functional.interpolate( + n3r_latents, size=latents_frame.shape[-2:], mode='bilinear', align_corners=False + ) + # Ajuster le nombre de canaux si nécessaire + if n3r_latents.shape[1] < latents_frame.shape[1]: + extra = latents_frame[:, n3r_latents.shape[1]:, :, :].clone() * 0 + n3r_latents = torch.cat([n3r_latents, extra], dim=1) + elif n3r_latents.shape[1] > latents_frame.shape[1]: + n3r_latents = n3r_latents[:, :latents_frame.shape[1], :, :] + + n3r_latents = n3r_latents.clone() + + # Normalisation **canal par canal** (RGB uniquement) + for c in range(min(3, n3r_latents.shape[1])): + n3r_c = n3r_latents[:, c:c+1, :, :] + frame_c = latents_frame[:, c:c+1, :, :] + mean, std = n3r_c.mean(), n3r_c.std() + n3r_c = (n3r_c - mean) / (std + 1e-6) + n3r_c = n3r_c * frame_c.std() + frame_c.mean() + n3r_latents[:, c:c+1, :, :] = n3r_c + + # Ajouter un bruit créatif léger si nécessaire + if creative_noise > 0.0: + noise = torch.randn_like(n3r_latents) * creative_noise + n3r_latents += noise + + # Clamp stricte + n3r_latents = torch.clamp(n3r_latents, -clamp_val, clamp_val) + latents_frame = torch.clamp(latents_frame, -clamp_val, clamp_val) + + # Fusion finale + fused_latents = latent_injection * latents_frame + (1 - latent_injection) * n3r_latents + fused_latents = torch.clamp(fused_latents, -clamp_val, clamp_val) + + print(f"[N3R fusion 
frame] mean/std par canal: {fused_latents.mean(dim=(2,3))}, injection={latent_injection:.2f}") + return fused_latents + +def fuse_n3r_latents_adaptive_v1(latents_frame, n3r_latents, latent_injection=0.7, clamp_val=1.0, creative_noise=0.0): + n3r_latents = n3r_latents.clone() + + # Normalisation **canal par canal** + for c in range(3): # RGB uniquement + n3r_c = n3r_latents[:,c:c+1,:,:] + frame_c = latents_frame[:,c:c+1,:,:] + mean, std = n3r_c.mean(), n3r_c.std() + n3r_c = (n3r_c - mean) / (std + 1e-6) + n3r_c = n3r_c * frame_c.std() + frame_c.mean() + n3r_latents[:,c:c+1,:,:] = n3r_c + + # Ajouter un bruit créatif léger si nécessaire + if creative_noise > 0.0: + noise = torch.randn_like(n3r_latents) * creative_noise + n3r_latents += noise + + # Clamp stricte pour éviter débordement + n3r_latents = torch.clamp(n3r_latents, -clamp_val, clamp_val) + latents_frame = torch.clamp(latents_frame, -clamp_val, clamp_val) + + # Fusion finale + fused_latents = latent_injection * latents_frame + (1 - latent_injection) * n3r_latents + fused_latents = torch.clamp(fused_latents, -clamp_val, clamp_val) + + print(f"[N3R fusion frame] mean/std par canal: {fused_latents.mean(dim=(2,3))}, injection={latent_injection:.2f}") + return fused_latents + + +def interpolate_param_fast(start_val, end_val, current_frame, total_frames, mode="linear", speed=2.0): + """ + Interpolation accélérée pour faire varier les paramètres plus rapidement au début. 
+ speed > 1 → plus rapide, speed < 1 → plus lent + """ + t = current_frame / max(total_frames-1, 1) + t = min(1.0, t * speed) # accélère la progression + + if mode == "linear": + return start_val + (end_val - start_val) * t + elif mode == "cosine": + t = (1 - math.cos(math.pi * t)) / 2 + return start_val + (end_val - start_val) * t + elif mode == "ease_in_out": + t = t*t*(3 - 2*t) + return start_val + (end_val - start_val) * t + else: + return start_val + (end_val - start_val) * t + + +def interpolate_param(start_val, end_val, current_frame, total_frames, mode="linear"): + """ + Interpolation d'un paramètre entre start_val -> end_val sur total_frames. + Modes disponibles: 'linear', 'cosine', 'ease_in_out' + """ + t = current_frame / max(total_frames-1,1) + if mode == "linear": + return start_val + (end_val - start_val) * t + elif mode == "cosine": + # Cosine easing pour un départ/arrivée plus doux + t = (1 - math.cos(math.pi * t)) / 2 + return start_val + (end_val - start_val) * t + elif mode == "ease_in_out": + t = t*t*(3 - 2*t) + return start_val + (end_val - start_val) * t + else: + return start_val + (end_val - start_val) * t + + +def estimate_sharpness(image): + gray = image.convert("L") + arr = np.array(gray, dtype=np.float32) + laplacian = np.abs( + arr[:-2,1:-1] + arr[2:,1:-1] + arr[1:-1,:-2] + arr[1:-1,2:] - 4*arr[1:-1,1:-1] + ) + return laplacian.mean() + +def adaptive_post_process(image): + sharpness = estimate_sharpness(image) + + # seuils empiriques (à ajuster) + if sharpness < 8: + # image floue → sharpen fort + return apply_post_processing( + image, + blur_radius=0.02, + contrast=1.1, + brightness=1.05, + saturation=0.9, + sharpen=True, + sharpen_radius=1, + sharpen_percent=120, + sharpen_threshold=2 + ) + + elif sharpness > 15: + # image déjà très nette → adoucir + return apply_post_processing( + image, + blur_radius=0.15, + contrast=1.05, + brightness=1.02, + saturation=0.85, + sharpen=False + ) + + else: + # équilibré + return 
apply_post_processing( + image, + blur_radius=0.05, + contrast=1.1, + brightness=1.05, + saturation=0.9, + sharpen=True, + sharpen_radius=1, + sharpen_percent=80, + sharpen_threshold=2 + ) + + +def apply_post_processing_adaptive( + frame_pil, + blur_radius=0.05, + contrast=1.15, + brightness=1.05, + saturation=0.85, + vibrance_base=1.1, # vibrance de base + vibrance_max=1.3, # max booster pour zones peu saturées + sharpen=False, + sharpen_radius=1, + sharpen_percent=90, + sharpen_threshold=2, + clamp_r=True # clamp adaptatif du canal rouge +): + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + # ---------------- GaussianBlur léger ---------------- + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # ---------------- Contrast / Brightness / Saturation ---------------- + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + if brightness != 1.0: + frame_pil = ImageEnhance.Brightness(frame_pil).enhance(brightness) + if saturation != 1.0: + frame_pil = ImageEnhance.Color(frame_pil).enhance(saturation) + + # ---------------- Vibrance adaptative ---------------- + try: + frame_np = np.array(frame_pil).astype(np.float32) + # calculer la saturation relative par pixel + max_rgb = np.max(frame_np, axis=2) + min_rgb = np.min(frame_np, axis=2) + sat = max_rgb - min_rgb + # vibrance: plus le pixel est peu saturé, plus on boost + factor_map = vibrance_base + (vibrance_max - vibrance_base) * (1 - sat/255.0) + factor_map = np.clip(factor_map, vibrance_base, vibrance_max) + for c in range(3): + frame_np[:,:,c] = np.clip(frame_np[:,:,c] * factor_map, 0, 255) + frame_pil = Image.fromarray(frame_np.astype(np.uint8)) + except Exception as e: + print(f"[WARNING] vibrance adaptative skipped: {e}") + + # ---------------- Clamp adaptatif du canal rouge ---------------- + if clamp_r: + try: + r, g, b = frame_pil.split() + r_np = np.array(r).astype(np.float32) + r_mean = r_np.mean() + # 
si la moyenne est trop haute, réduire légèrement + if r_mean > 180: + factor = 180 / r_mean + r_np = np.clip(r_np * factor, 0, 255) + r = Image.fromarray(r_np.astype(np.uint8)) + frame_pil = Image.merge("RGB", (r, g, b)) + except Exception as e: + print(f"[WARNING] clamp rouge skipped: {e}") + + # ---------------- Sharp / UnsharpMask ---------------- + if sharpen: + try: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + except Exception as e: + print(f"[WARNING] sharpening skipped: {e}") + + return frame_pil + +def apply_post_processing( + frame_pil, + blur_radius=0.05, + contrast=1.15, + brightness=1.05, + saturation=0.85, + vibrance=1.0, # valeurs raisonnables pour éviter doré + sharpen=False, + sharpen_radius=1, + sharpen_percent=90, + sharpen_threshold=2, + clamp_r=True # optionnel: clamp canal rouge pour éviter doré +): + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + # GaussianBlur + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # Contrast / Brightness / Saturation + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + if brightness != 1.0: + frame_pil = ImageEnhance.Brightness(frame_pil).enhance(brightness) + if saturation != 1.0: + frame_pil = ImageEnhance.Color(frame_pil).enhance(saturation) + + # Vibrance: booster légèrement les couleurs peu saturées + if vibrance != 1.0: + try: + frame_hsv = frame_pil.convert("HSV") + h, s, v = frame_hsv.split() + s = s.point(lambda i: min(255, int(i * vibrance) if i < 128 else i)) + frame_pil = Image.merge("HSV", (h, s, v)).convert("RGB") + except Exception as e: + print(f"[WARNING] vibrance skipped due to error: {e}") + + # Clamp canal rouge pour éviter les zones trop dorées + if clamp_r: + r, g, b = frame_pil.split() + r = r.point(lambda i: min(230, i)) # clamp max à 230 (~90% du max) + frame_pil = 
Image.merge("RGB", (r, g, b)) + + # Sharp / UnsharpMask + if sharpen: + try: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + except Exception as e: + print(f"[WARNING] sharpening skipped due to error: {e}") + + return frame_pil + +def apply_post_processing_v1(frame_pil, + blur_radius=0.05, + contrast=1.15, + brightness=1.05, + saturation=0.85, + vibrance=1.1, # <-- ajout vibrance + sharpen=False, + sharpen_radius=1, + sharpen_percent=90, + sharpen_threshold=2): + # GaussianBlur + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # Contrast / Brightness / Saturation + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + if brightness != 1.0: + frame_pil = ImageEnhance.Brightness(frame_pil).enhance(brightness) + if saturation != 1.0: + frame_pil = ImageEnhance.Color(frame_pil).enhance(saturation) + + # Vibrance: booster les couleurs peu saturées + if vibrance != 1.0: + frame_hsv = frame_pil.convert("HSV") + h, s, v = frame_hsv.split() + s = s.point(lambda i: min(255, int(i * vibrance) if i < 128 else i)) + frame_pil = Image.merge("HSV", (h, s, v)).convert("RGB") + + # Sharp / UnsharpMask + if sharpen: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + + return frame_pil + + +def apply_post_processing_blur(frame_pil, blur_radius=0.2, contrast=1.0, brightness=1.0, saturation=1.0): + """ + Appliquer des effets post-decode sur une frame PIL. 
+ + Args: + frame_pil (PIL.Image): L'image décodée depuis les latents + blur_radius (float): Rayon du flou gaussien + contrast (float): Facteur de contraste (1.0 = inchangé) + brightness (float): Facteur de luminosité (1.0 = inchangé) + saturation (float): Facteur de saturation (1.0 = inchangé) + + Returns: + PIL.Image: Image modifiée + """ + # GaussianBlur simple + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # Ajustements optionnels (contrast, brightness, saturation) + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + if brightness != 1.0: + frame_pil = ImageEnhance.Brightness(frame_pil).enhance(brightness) + if saturation != 1.0: + frame_pil = ImageEnhance.Color(frame_pil).enhance(saturation) + + return frame_pil + + +def encode_images_to_latents_hybrid(images, vae, device="cuda", latent_scale=LATENT_SCALE): + """ + Encodage hybride pour conserver la fidélité et la richesse de détails. + - Utilise un échantillon de la distribution (plus vivant que mean) + - Clamp léger pour éviter débordements extrêmes mais garder micro-contrastes + - Assure 4 channels si nécessaire pour compatibilité UNet/N3R + """ + images = images.to(device=device, dtype=torch.float32) + vae = vae.to(device=device, dtype=torch.float32) + + with torch.no_grad(): + latents = vae.encode(images).latent_dist.sample() # <-- sample() pour micro-variations + + latents = latents * latent_scale + latents = torch.clamp(latents, -5.0, 5.0) # Clamp léger, pas de nan_to_num + + # Assurer 4 channels si nécessaire + if latents.ndim == 4 and latents.shape[1] == 1: + latents = latents.repeat(1, 4, 1, 1) + + return latents + +def encode_images_to_latents_safe(images, vae, device="cuda", latent_scale=0.18215): + """ + Encode une image en latents VAE en gardant la stabilité et le contraste. 
+ """ + images = images.to(device=device, dtype=torch.float32) + vae = vae.to(device=device, dtype=torch.float32) + + with torch.no_grad(): + latents = vae.encode(images).latent_dist.mean # moyenne pour stabilité + + latents = latents * latent_scale + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + + if latents.ndim == 4 and latents.shape[1] == 1: + latents = latents.repeat(1, 4, 1, 1) + + return latents + + +def save_frames_as_video_from_folder( + folder_path, + out_path, + fps=12, + upscale_factor=1 +): + """ + Sauvegarde toutes les images PNG d'un dossier en vidéo MP4. + - Utilise ffmpeg directement (plus besoin de imageio) + - Supporte l'upscale des images + """ + folder_path = Path(folder_path) + images = sorted(folder_path.glob("*.png")) + + if not images: + raise ValueError(f"Aucune image trouvée dans {folder_path}") + + # Créer un dossier temporaire pour les images upscale + tmp_dir = folder_path / "_tmp_upscaled" + tmp_dir.mkdir(exist_ok=True) + + # Redimensionner les images si nécessaire + for idx, img_path in enumerate(images): + img = Image.open(img_path) + if upscale_factor != 1: + img = img.resize((img.width * upscale_factor, img.height * upscale_factor), Image.BICUBIC) + tmp_file = tmp_dir / f"frame_{idx:05d}.png" + img.save(tmp_file) + + # Appel ffmpeg pour créer la vidéo + cmd = [ + "ffmpeg", + "-y", # overwrite + "-framerate", str(fps), + "-i", str(tmp_dir / "frame_%05d.png"), + "-c:v", "libx264", + "-pix_fmt", "yuv420p", + str(out_path) + ] + + print("⚡ Génération vidéo avec ffmpeg…") + subprocess.run(cmd, check=True) + print(f"🎬 Vidéo sauvegardée : {out_path}") + + # Optionnel : supprimer le dossier temporaire + for f in tmp_dir.glob("*.png"): + f.unlink() + tmp_dir.rmdir() + + +def encode_images_to_latents_nuanced_v1(images, vae, device="cuda", latent_scale=LATENT_SCALE): + """ + Encode une image en latents VAE tout en préservant le contraste et les nuances de couleur. 
+ - Utilise la moyenne de la distribution latente + - Clamp minimal seulement pour sécurité + - Force 4 canaux si nécessaire + """ + + images = images.to(device=device, dtype=torch.float32) + vae = vae.to(device=device, dtype=torch.float32) + + with torch.no_grad(): + latents = vae.encode(images).latent_dist.mean # moyenne pour plus de stabilité + + # Appliquer le scaling mais garder la dynamique + latents = latents * latent_scale + + # Sécurité NaN / Inf (mais pas normalisation globale) + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + + # Forcer 4 canaux si nécessaire (VAE attend souvent 4) + if latents.ndim == 4 and latents.shape[1] == 1: + latents = latents.repeat(1, 4, 1, 1) + + return latents + + + + +def encode_images_to_latents_nuanced(images, vae, unet, device="cuda", latent_scale=LATENT_SCALE): + """ + Encode une image en latents VAE, en préservant nuances et contraste, + et redimensionne dynamiquement pour correspondre à la taille attendue par le UNet. 
+ """ + images = images.to(device=device, dtype=torch.float32) + vae = vae.to(device=device, dtype=torch.float32) + + with torch.no_grad(): + # Encoder l'image → latents + latents = vae.encode(images).latent_dist.mean # moyenne pour stabilité + + # Appliquer le scaling + latents = latents * latent_scale + + # Sécurité NaN / Inf + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + + # Forcer 4 canaux si nécessaire + if latents.ndim == 4 and latents.shape[1] == 1: + latents = latents.repeat(1, 4, 1, 1) + + # 🔹 Redimensionner dynamiquement pour correspondre au UNet + # On regarde la taille attendue à partir de l'UNet (ou ses skip connections) + # Supposons que le UNet a un module `config` avec "sample_size" ou attention_resolutions + try: + # Si UNet a `sample_size` ou autre attribut + target_H = getattr(unet.config, "sample_size", latents.shape[2]) + target_W = getattr(unet.config, "sample_size", latents.shape[3]) + except AttributeError: + # fallback : garder la taille actuelle + target_H, target_W = latents.shape[2], latents.shape[3] + + # Interpolation bilinéaire pour adapter la taille + if (latents.shape[2], latents.shape[3]) != (target_H, target_W): + latents = TF.interpolate(latents, size=(target_H, target_W), mode="bilinear", align_corners=False) + print(f"[DEBUG] Latents resized to ({target_H}, {target_W})") + + return latents + +# ------------------- ENCODE ------------------- + +def encode_images_to_latents_target(images, vae, device="cuda", latent_scale=LATENT_SCALE, target_size=64): + """ + Encode une image en latents VAE, compatible MiniSD/AnimateDiff ultra-light (~2Go VRAM) + - garde la dynamique et contraste + - clamp minimal pour sécurité + - force 4 canaux + - resize latents à target_size (MiniSD) + """ + images = images.to(device=device, dtype=torch.float32) + vae = vae.to(device=device, dtype=torch.float32) + + with torch.no_grad(): + latents = vae.encode(images).latent_dist.mean # moyenne pour plus de stabilité + + # Appliquer 
le scaling + latents = latents * latent_scale + + # Sécurité NaN / Inf + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + + # Forcer 4 canaux si nécessaire + if latents.ndim == 4 and latents.shape[1] == 1: + latents = latents.repeat(1, 4, 1, 1) + + # Redimensionner à target_size x target_size + if latents.shape[2] != target_size or latents.shape[3] != target_size: + latents = torch.nn.functional.interpolate( + latents, size=(target_size, target_size), mode="bilinear", align_corners=False + ) + + return latents + + + +# ------------------- DECODE ------------------- + + +def decode_latents_ultrasafe_blockwise( + latents, vae, + block_size=32, overlap=16, + gamma=1.0, brightness=1.0, + contrast=1.0, saturation=1.0, + device="cuda", frame_counter=0, output_dir=Path("."), + epsilon=1e-6, + latent_scale_boost=1.0 # boost léger pour récupérer les nuances +): + """ + Décodage ultra-safe par blocs des latents en image PIL. + Optimisé pour préserver les nuances de couleur et réduire l'effet "photocopie". 
+ """ + + # 🔹 Correctif : forcer le VAE sur le bon device et en float32 + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + # Dimensions finales + out_H = H * 8 + out_W = W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + + # Calcul positions garanties pour full coverage + y_positions = list(range(0, H - block_size + 1, stride)) or [0] + x_positions = list(range(0, W - block_size + 1, stride)) or [0] + + if y_positions[-1] != H - block_size: + y_positions.append(H - block_size) + if x_positions[-1] != W - block_size: + x_positions.append(W - block_size) + + for y in y_positions: + for x in x_positions: + y1 = y + block_size + x1 = x + block_size + + patch = latents[:, :, y:y1, x:x1] + + # Sécurité : NaN / Inf / epsilon + patch = torch.nan_to_num(patch, nan=0.0, posinf=5.0, neginf=-5.0) + if torch.all(patch == 0): + patch += epsilon + + # Décodage + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + # Intégration dans l'image finale + iy0, ix0 = y*8, x*8 + iy1, ix1 = y1*8, x1*8 + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded + weight[:, :, iy0:iy1, ix0:ix1] += 1.0 + + # Moyenne pour blending final + output_rgb = output_rgb / weight.clamp(min=1e-6) + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # Convertir en PIL et appliquer corrections gamma / contrast / saturation / brightness + frame_pil_list = [] + for i in range(B): + img = F.to_pil_image((output_rgb[i] + 1) / 2) # [-1,1] -> [0,1] + img = ImageEnhance.Brightness(img).enhance(brightness) + img = ImageEnhance.Contrast(img).enhance(contrast) + img = ImageEnhance.Color(img).enhance(saturation) + if gamma != 1.0: + img = img.point(lambda x: 255 * ((x / 255) ** (1 / gamma))) + frame_pil_list.append(img) + + return frame_pil_list[0] if B == 1 else frame_pil_list diff --git 
a/scripts/utils/logging_utils.py b/scripts/utils/logging_utils.py new file mode 100644 index 00000000..2dd87c68 --- /dev/null +++ b/scripts/utils/logging_utils.py @@ -0,0 +1,72 @@ +# utils/logging_utils.py +import os +import csv +import torch + +def log_latent_stats(frame_idx, latents, csv_path="latent_stats.csv"): + """ + Écrit les stats latentes dans un CSV. + Args: + frame_idx (int): numéro de la frame + latents (torch.Tensor): tensor latent + csv_path (str or Path): chemin vers le CSV + """ + min_val = float(latents.min()) + max_val = float(latents.max()) + mean_val = float(latents.mean()) + std_val = float(latents.std()) + + # Écrire l'en-tête si le fichier n'existe pas + write_header = not os.path.exists(csv_path) + + with open(csv_path, "a", newline="") as f: + writer = csv.writer(f) + if write_header: + writer.writerow(["frame", "min", "max", "mean", "std"]) + writer.writerow([frame_idx, min_val, max_val, mean_val, std_val]) + + +def log_patch_stats(frame_idx, patch_idx, patch, csv_path="patch_stats.csv"): + """ + Écrit les stats de chaque patch VAE dans un CSV. 
def detect_unet_type(unet):
    """Classify a UNet family from its ``cross_attention_dim`` config value.

    Args:
        unet: model whose optional ``config.cross_attention_dim`` is inspected.

    Returns:
        tuple: (human-readable family label, cross_attention_dim or None).
    """
    config = getattr(unet, "config", None)
    dim = getattr(config, "cross_attention_dim", None)

    # Known Stable Diffusion families, keyed by text-embedding width.
    families = {
        768: "SD1.x compatible",
        1024: "SD2.x compatible",
        2048: "SDXL compatible",
    }
    return families.get(dim, "UNet custom"), dim
dict(unet.named_parameters()) + + # Affichage filtré + if verbose: + for k, v in lora_state.items(): + if "up_blocks" in k and "attn1.to_q.weight" in k: + print(k, v.shape) + + # Compatibilité : intersection avec le UNet + compatible_keys = [k for k in lora_state if k in unet_state and unet_state[k].shape == lora_state[k].shape] + if not compatible_keys: + print(f"⚠ LoRA '{lora_path}' incompatible avec ce UNet, chargement annulé.") + return unet + + # Application + applied = 0 + skipped = 0 + missing = 0 + for k, lora_param in lora_state.items(): + if k not in unet_state: + missing += 1 + continue + param = unet_state[k] + if param.shape != lora_param.shape: + skipped += 1 + continue + lora_param = lora_param.to(device=device, dtype=param.dtype) + with torch.no_grad(): + param.copy_(param*(1-alpha) + lora_param*alpha) + applied += 1 + + if verbose: + print("✅ LoRA résumé") + print(f" couches appliquées : {applied}") + print(f" couches ignorées : {skipped}") + print(f" couches absentes : {missing}") + + return unet + + +def list_lora_parameters(lora_path): + """ + Liste les paramètres contenus dans un LoRA + """ + lora_state = load_file(lora_path, device="cpu") + print(f"📌 Paramètres dans {lora_path}") + for k, v in lora_state.items(): + print(f"{k} -> {tuple(v.shape)}") + return list(lora_state.keys()) diff --git a/scripts/utils/model_utils.py b/scripts/utils/model_utils.py new file mode 100644 index 00000000..5981d910 --- /dev/null +++ b/scripts/utils/model_utils.py @@ -0,0 +1,84 @@ +# ------------------------- +# scripts/utils/model_utils.py +# ------------------------- +import torch +from diffusers import UNet2DConditionModel, AutoencoderKL, LMSDiscreteScheduler +from transformers import CLIPTextModel, CLIPTokenizer +from pathlib import Path + +from diffusers import DDIMScheduler, LMSDiscreteScheduler + +def get_text_embeddings(text_encoder, tokenizer, prompt, negative_prompt="", device="cuda", dtype=torch.float32): + """ + Retourne les embeddings textuels 
positifs et négatifs pour l'inference guidée. + + Args: + text_encoder: modèle CLIPTextModel chargé + tokenizer: tokenizer CLIP + prompt: texte positif (str) + negative_prompt: texte négatif (str) + device: "cuda" ou "cpu" + dtype: torch dtype (float32 ou float16) + + Returns: + pos_embeds, neg_embeds: tensors shape [1, seq_len, hidden_dim] + """ + + # --- Tokenization --- + pos_inputs = tokenizer(prompt, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt") + neg_inputs = tokenizer(negative_prompt, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt") + + # --- Passage dans le text encoder --- + with torch.no_grad(): + pos_embeds = text_encoder(pos_inputs.input_ids.to(device))[0].to(dtype) + neg_embeds = text_encoder(neg_inputs.input_ids.to(device))[0].to(dtype) + + return pos_embeds, neg_embeds + +# ------------------------- +# Chargement UNet +# ------------------------- +def load_pretrained_unet(pretrained_path: str, device: str = "cuda", dtype=torch.float32): + unet_path = Path(pretrained_path) + unet = UNet2DConditionModel.from_pretrained(unet_path, subfolder="unet", torch_dtype=dtype) + unet.to(device) + unet.eval() + return unet + +# ------------------------- +# Chargement scheduler +# ------------------------- +def load_scheduler(pretrained_path: str): + scheduler_path = Path(pretrained_path) + scheduler = LMSDiscreteScheduler.from_pretrained(scheduler_path, subfolder="scheduler") + return scheduler + + +# ------------------------- +# Chargement scheduler +# ------------------------- + +def load_DDIMScheduler(model_path, scheduler_type="DDIMScheduler"): + """ + Charge le scheduler de diffusion depuis un modèle pré-entraîné. + Renvoie un objet Scheduler, pas un tensor. 
+ """ + if scheduler_type == "DDIMScheduler": + scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler") + elif scheduler_type == "LMSDiscreteScheduler": + scheduler = LMSDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler") + else: + raise ValueError(f"Scheduler inconnu: {scheduler_type}") + + return scheduler + +# ------------------------- +# Chargement Text Encoder + Tokenizer +# ------------------------- +def load_text_encoder(pretrained_path: str, device: str = "cuda"): + encoder_path = Path(pretrained_path) + tokenizer = CLIPTokenizer.from_pretrained(encoder_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(encoder_path, subfolder="text_encoder") + text_encoder.to(device) + text_encoder.eval() + return text_encoder, tokenizer diff --git a/scripts/utils/motion_utils.py b/scripts/utils/motion_utils.py new file mode 100644 index 00000000..8bab4472 --- /dev/null +++ b/scripts/utils/motion_utils.py @@ -0,0 +1,186 @@ +# motion_utils.py + +from pathlib import Path +import os +import importlib.util +import torch + +try: + import safetensors.torch +except ImportError: + safetensors = None + +# ------------------------- +# Default motion module +# ------------------------- +class DefaultMotionModule(torch.nn.Module): + def forward(self, latents): + return latents + +default_motion_module = DefaultMotionModule() + + +# --------------------------------------------------------- +# Diffusion FONCTIONNE PARFAITEMENT +# images_latents: [B,4,T,H,W] +# apply_motion_module = generate_latent +# --------------------------------------------------------- +def apply_motion_module(latents, pos_embeds, neg_embeds, unet, scheduler, motion_module=None, device="cuda", dtype=torch.float16, guidance_scale=7.5, init_image_scale=2.0, seed=42): + """ + latents: [B,4,T,H,W] (déjà encodés et scalés) init_image_scale: poids de l'image initiale + """ + torch.manual_seed(seed) + B, C, T, H, W = latents.shape + latents = 
def generate_latents_1(latents, pos_embeds, neg_embeds, unet, scheduler, motion_module=None, device="cuda", dtype=torch.float16, guidance_scale=4.5, init_image_scale=0.85, seed=42):
    """Run the diffusion denoising loop over a batch of video latents.

    Args:
        latents: tensor [B, 4, T, H, W], already VAE-encoded and scaled.
        pos_embeds / neg_embeds: positive / negative text embeddings.
        unet: denoising UNet, called as unet(x, t, encoder_hidden_states=...).
        scheduler: diffusion scheduler exposing `timesteps` and `step()`.
        motion_module: optional callable applied to the latents each step.
        device / dtype: execution device and precision.
        guidance_scale: classifier-free guidance weight.
        init_image_scale: strength of the pull back toward the initial latents.
        seed: RNG seed. Previously hard-coded to 42; now a parameter with the
            same default (matching sibling `apply_motion_module`), so existing
            callers are unaffected.

    Returns:
        Denoised latents reshaped back to [B, 4, T, H, W].
    """
    torch.manual_seed(seed)
    B, C, T, H, W = latents.shape
    latents = latents.to(device=device, dtype=dtype)
    # Fold the time axis into the batch so the 2D UNet sees [B*T, C, H, W].
    latents = latents.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W).contiguous()
    # Keep a copy of the initial latents to preserve the source image.
    init_latents = latents.clone()
    for t in scheduler.timesteps:
        if motion_module is not None:
            latents = motion_module(latents)

        # Classifier-free guidance: one UNet call for [uncond, cond].
        latent_model_input = torch.cat([latents] * 2)
        embeds = torch.cat([neg_embeds, pos_embeds])

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=embeds
            ).sample

        noise_uncond, noise_text = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

        # Scheduler update, then blend back toward the initial latents.
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        latents = latents + init_image_scale * (init_latents - latents)

    latents = latents.reshape(B, T, C, H, W).permute(0, 2, 1, 3, 4).contiguous()

    return latents
torch.randn_like(x)*1e-2 + if verbose: + print(f"[SAFE DEBUG] Frame trop faible ({frame_max:.6f}) → bruit injecté") + return original_forward(x) + motion_module.forward = safe_forward + + if verbose: + print(f"✅ Motion module (Python) loaded and patched safe: {module_path}") + return motion_module + + # ---------------- Checkpoint / safetensors ---------------- + elif module_path.endswith(".ckpt") or module_path.endswith(".safetensors"): + try: + if module_path.endswith(".ckpt"): + state_dict = torch.load(module_path, map_location="cpu") + else: + if safetensors is None: + raise ImportError("safetensors not installed") + state_dict = safetensors.torch.load_file(module_path, device="cpu") + except Exception as e: + raise RuntimeError(f"Failed to load checkpoint: {e}") + + class MotionModule(torch.nn.Module): + def __init__(self, sd): + super().__init__() + self.sd = sd + def forward(self, x): + # Safe patch minimal + frame_max = x.abs().max() + if frame_max < 1e-3 and verbose: + print(f"[SAFE DEBUG] Frame trop faible ({frame_max:.6f}) → bruit injecté") + x = x + torch.randn_like(x)*1e-2 + return x + + motion_module = MotionModule(state_dict) + motion_module.to(device=device, dtype=dtype) + motion_module.eval() + if verbose: + print(f"✅ Motion module (checkpoint) loaded safe: {module_path}") + return motion_module + + else: + raise ValueError("Unsupported motion module file type: must be .py, .ckpt, or .safetensors") diff --git a/scripts/utils/n3rControlNet.py b/scripts/utils/n3rControlNet.py new file mode 100644 index 00000000..b140b540 --- /dev/null +++ b/scripts/utils/n3rControlNet.py @@ -0,0 +1,83 @@ +#n3rControlNet.py ------------------------------------------------ +#**** Ensemble des outils pour n3rControlNet +#------------------------------------------------------------------ + +import torch, numpy as np, cv2, gc + +def create_canny_control(image_pil, low=100, high=200, device='cuda', dtype=torch.float16): + """ + Génère un tenseur de contrôle à partir d'une 
image PIL via Canny. + Logs détaillés pour debug. + """ + print("[Canny] Conversion PIL -> grayscale") + img = np.array(image_pil.convert("L"), dtype=np.float32) / 255.0 + print(f"[Canny] Image shape: {img.shape}, min: {img.min():.6f}, max: {img.max():.6f}") + + # Canny edges + edges = cv2.Canny((img * 255).astype(np.uint8), low, high).astype(np.float32) / 255.0 + print(f"[Canny] Edges computed, min: {edges.min():.6f}, max: {edges.max():.6f}") + + # Convertir en tensor directement sur device + edges_tensor = torch.tensor(edges, device=device, dtype=dtype).unsqueeze(0).unsqueeze(0) + print(f"[Canny] Tensor shape: {edges_tensor.shape}, dtype: {edges_tensor.dtype}, device: {edges_tensor.device}") + + # Clamp pour éviter 0 ou 1 exacts + edges_tensor = edges_tensor.clamp(1e-5, 1-1e-5) + + return edges_tensor + + +def control_to_latent(control_tensor, vae, device='cuda', LATENT_SCALE=1.0): + # Si 1 canal, dupliquer pour obtenir 3 canaux + if control_tensor.shape[1] == 1: + control_tensor = control_tensor.repeat(1, 3, 1, 1) + + # Assurer le type float16 pour économiser la VRAM + control_tensor = control_tensor.to(device=device, dtype=vae.dtype) # <- correction ici + print(f"[Control->Latent] Converted tensor dtype: {control_tensor.dtype}, device: {control_tensor.device}") + + # Encode VAE + with torch.no_grad(): # économise un peu de VRAM + latent = vae.encode(control_tensor).latent_dist.sample() + + print(f"[Control->Latent] Latent shape: {latent.shape}, min: {latent.min()}, max: {latent.max()}") + + return latent * LATENT_SCALE + + +import torch +import torch.nn.functional as F + +def match_latent_size(latents, control_latent, control_weight_map): + """ + Redimensionne control_latent et control_weight_map pour correspondre + à la taille de latents, et ajuste dtype/device. 
class N3RModelOptimized(nn.Module):
    """NeRF-style coordinate MLP evaluated tile-by-tile to bound peak VRAM.

    Maps 3D coordinates to 4 channels (RGB + density) via a two-band
    positional encoding followed by a small FP16 MLP. With
    ``cpu_offload=True`` the output accumulator lives on the CPU, each
    tile's result is copied into it, and the finished buffer is moved back
    to the input device at the end.
    """

    def __init__(self, L_low=3, L_high=6, N_samples=6, tile_size=64, cpu_offload=True):
        # L_low / L_high: frequency-band split for the positional encoding.
        # N_samples: samples per pixel; tile_size: spatial tile edge length.
        super().__init__()
        self.L_low = L_low
        self.L_high = L_high
        self.N_samples = N_samples
        self.tile_size = tile_size
        self.cpu_offload = cpu_offload

        input_dim = 3 + 2 * 3 * L_high  # 3 coords + sin/cos per axis up to the max frequency
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 4)  # RGB + density
        ).half()  # FP16

    def positional_encoding(self, x):
        """Encode coords (expected in [-1, 1]) mixing low and high frequency
        bands; high frequencies are damped (x 0.2) to limit neon artifacts.
        """
        x = x.half()
        enc = [x]
        # low frequencies (overall structure)
        for i in range(self.L_low):
            for fn in [torch.sin, torch.cos]:
                enc.append(fn((2.0 ** i) * x))
        # high frequencies (detail), reduced amplitude
        for i in range(self.L_low, self.L_high):
            for fn in [torch.sin, torch.cos]:
                enc.append(fn((2.0 ** i) * x) * 0.2)
        return torch.cat(enc, dim=-1)

    def normalize_coords(self, coords, H, W):
        """Map pixel coordinates (x in column 0, y in column 1) to [-1, 1]."""
        coords = coords.clone()
        coords[:, 0] = coords[:, 0] / (W-1) * 2 - 1
        coords[:, 1] = coords[:, 1] / (H-1) * 2 - 1
        return coords

    def forward_tile(self, coords_tile):
        # Always evaluate on the MLP's own device, wherever it currently is.
        mlp_device = next(self.mlp.parameters()).device
        x_enc = self.positional_encoding(coords_tile.to(mlp_device))
        out = self.mlp(x_enc)
        return out

    def forward(self, coords, H, W):
        """
        coords : tensor (H*W*N_samples, 3) on the final device (cuda)
        H, W : image dimensions
        """
        device = coords.device
        ts = self.tile_size

        # Output accumulator device: CPU when offloading to save VRAM.
        output_device = torch.device("cpu") if self.cpu_offload else device
        output = torch.zeros((H*W*self.N_samples, 4), device=output_device, dtype=torch.float16)

        for y in range(0, H, ts):
            for x in range(0, W, ts):
                y_end = min(y + ts, H)
                x_end = min(x + ts, W)

                # Vectorized meshgrid -> flat sample indices for this tile.
                y_range = torch.arange(y, y_end, device=device)
                x_range = torch.arange(x, x_end, device=device)
                s_range = torch.arange(self.N_samples, device=device)
                yy, xx, s = torch.meshgrid(y_range, x_range, s_range, indexing='ij')
                idx_tile = (yy * W * self.N_samples + xx * self.N_samples + s).reshape(-1)

                coords_tile = coords[idx_tile]

                # NOTE(review): with cpu_offload=True the MLP itself is never
                # moved to CPU here (unlike N3RModelLazyCPU); forward_tile
                # re-routes coords to the MLP's actual device, so this .to()
                # only matters if the MLP was relocated externally — confirm.
                mlp_device = torch.device("cpu") if self.cpu_offload else device
                output_tile = self.forward_tile(coords_tile.to(mlp_device))

                # Write the tile on the accumulator's device.
                output[idx_tile.to(output_device)] = output_tile.to(output_device)

        # Bring the finished buffer back to the caller's device.
        if self.cpu_offload:
            output = output.to(device)

        return output
tile_size=64, cpu_offload=True): + """ + L : nombre de fréquences pour positional encoding + N_samples : échantillons par pixel + tile_size : taille de chaque tile + cpu_offload: True pour décharger le MLP sur CPU + """ + super().__init__() + self.L = L + self.N_samples = N_samples + self.tile_size = tile_size + self.cpu_offload = cpu_offload + + input_dim = 3 + 2 * 3 * self.L + self.mlp = nn.Sequential( + nn.Linear(input_dim, 64), + nn.ReLU(), + nn.Linear(64, 64), + nn.ReLU(), + nn.Linear(64, 4) # RGB + density + ).half() # FP16 + + if self.cpu_offload: + self.mlp = self.mlp.to("cpu") + + def positional_encoding(self, x): + x = x.half() + enc = [x] + for i in range(self.L): + for fn in [torch.sin, torch.cos]: + enc.append(fn((2.0 ** i) * x)) + return torch.cat(enc, dim=-1) + + def forward_tile(self, coords_tile): + """ + Forward d'un tile avec CPU-offload safe + """ + device_orig = coords_tile.device + + # On envoie tout sur le même device que le MLP + mlp_device = next(self.mlp.parameters()).device + coords_tile = coords_tile.to(mlp_device).half() + + x_enc = self.positional_encoding(coords_tile) + out = self.mlp(x_enc) + + # Si MLP est sur CPU, remettre le output sur le device original (GPU) + if mlp_device.type == "cpu" and device_orig.type != "cpu": + out = out.to(device_orig) + + return out + + def forward(self, coords, H, W): + """ + coords : (H*W*N_samples, 3) + """ + device = coords.device + output = torch.zeros((H*W*self.N_samples, 4), device=device, dtype=torch.float16) + + ts = self.tile_size + + for y in range(0, H, ts): + for x in range(0, W, ts): + y_end = min(y + ts, H) + x_end = min(x + ts, W) + + # Meshgrid vectorisé pour tile + y_range = torch.arange(y, y_end, device=device) + x_range = torch.arange(x, x_end, device=device) + s_range = torch.arange(self.N_samples, device=device) + yy, xx, ss = torch.meshgrid(y_range, x_range, s_range, indexing='ij') + idx_tile = (yy * W * self.N_samples + xx * self.N_samples + ss).reshape(-1) + + coords_tile = 
class N3RModelFast4GB(nn.Module):
    """Tile-based coordinate MLP variant tuned for ~4 GB GPUs.

    Single frequency band, FP16 MLP under CUDA autocast, and optional
    per-tile CPU offload of the output buffer.
    """

    def __init__(self, L=6, N_samples=16, tile_size=64, cpu_offload=False):
        """
        L : number of frequencies for the positional encoding
        N_samples : samples per pixel
        tile_size : tile size for the forward pass
        cpu_offload: True to offload tiles to CPU when VRAM is limited
        """
        super().__init__()
        self.L = L
        self.N_samples = N_samples
        self.tile_size = tile_size
        self.cpu_offload = cpu_offload

        # Input size: 3 coords + 2*3*L (sin/cos per axis per frequency)
        input_dim = 3 + 2 * 3 * self.L
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 4)  # RGB + density
        ).half()  # FP16 everywhere

    def positional_encoding(self, x):
        """Encode input coordinates with sin/cos at powers-of-two frequencies."""
        x = x.half()
        enc = [x]
        for i in range(self.L):
            factor = 2.0 ** i
            enc.append(torch.sin(factor * x))
            enc.append(torch.cos(factor * x))
        return torch.cat(enc, dim=-1)

    def forward_tile(self, coords):
        """Forward one tile under FP16 autocast.

        NOTE(review): autocast is pinned to device_type='cuda' — this
        assumes a CUDA run; confirm behavior before using on CPU.
        """
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            x_enc = self.positional_encoding(coords)
            out = self.mlp(x_enc)
        return out

    def forward(self, coords, H, W):
        """
        coords : tensor (H*W*N_samples, 3)
        H, W : image dimensions
        """
        device = coords.device
        output = torch.zeros((H*W*self.N_samples, 4), device=device, dtype=torch.float16)

        ts = self.tile_size
        for y in range(0, H, ts):
            for x in range(0, W, ts):
                y_end = min(y + ts, H)
                x_end = min(x + ts, W)

                # Vectorized meshgrid -> flat sample indices for this tile.
                yy, xx = torch.meshgrid(
                    torch.arange(y, y_end, device=device),
                    torch.arange(x, x_end, device=device),
                    indexing='ij'
                )
                idx_tile = (yy[..., None] * W * self.N_samples + xx[..., None] * self.N_samples +
                            torch.arange(self.N_samples, device=device)).reshape(-1)

                coords_tile = coords[idx_tile]
                output_tile = self.forward_tile(coords_tile)

                # Selective CPU offload of the written tile.
                if self.cpu_offload:
                    output[idx_tile.cpu()] = output_tile.cpu()
                else:
                    output[idx_tile] = output_tile

        return output
if std.item() == 0: # sécurité division par zéro + std = 1.0 + n3r_c = (n3r_c - mean) / std + frame_std = frame_c.std() + frame_mean = frame_c.mean() + if frame_std.item() == 0: + frame_std = 1.0 + n3r_c = n3r_c * frame_std + frame_mean + n3r_latents[:, c:c+1, :, :] = n3r_c + + # ---------------- Bruit créatif léger ---------------- + if creative_noise > 0.0: + noise = torch.randn_like(n3r_latents) * creative_noise + n3r_latents += noise + + # ---------------- Clamp ---------------- + n3r_latents = torch.clamp(n3r_latents, -clamp_val, clamp_val) + latents_frame = torch.clamp(latents_frame, -clamp_val, clamp_val) + + # ---------------- Injection dynamique ---------------- + if frame_counter is not None and total_frames is not None: + alpha = 0.5 - 0.5 * math.cos(math.pi * frame_counter / max(total_frames - 1, 1)) + latent_injection = latent_injection_start + (latent_injection_end - latent_injection_start) * alpha + latent_injection = min(max(latent_injection, 0.0), 1.0) + else: + latent_injection = latent_injection_start + + fused_latents = latent_injection * latents_frame + (1 - latent_injection) * n3r_latents + fused_latents = torch.clamp(fused_latents, -clamp_val, clamp_val) + + return fused_latents + + +def generate_n3r_coords(H, W, N_samples, seed, frame_counter, device): + """Génère des coordonnées normalisées + jitter + bruit.""" + ys = torch.linspace(-1.0, 1.0, H, device=device) + xs = torch.linspace(-1.0, 1.0, W, device=device) + ss = torch.arange(N_samples, device=device, dtype=torch.float32) + ys, xs, ss = torch.meshgrid(ys, xs, ss, indexing='ij') + coords = torch.stack([xs, ys, ss], dim=-1).reshape(-1, 3) + + # jitter léger + noise_scale = 0.01 + 0.02 * math.sin(frame_counter * 0.1) + torch.manual_seed(seed + frame_counter) + coords = coords + (torch.rand_like(coords) - 0.5) * 0.02 + coords = coords + torch.randn_like(coords) * noise_scale + coords = torch.nan_to_num(coords) + + return coords + + +def process_n3r_latents(n3r_model, coords, H, W, 
target_H, target_W): + """Forward N3R et interpolation vers HxW cible + ajout canal alpha.""" + n3r_latents_raw = n3r_model(coords, H, W)[:, :3] + expected = H * W * n3r_model.N_samples + if n3r_latents_raw.shape[0] != expected: + raise RuntimeError(f"N3R reshape mismatch: {n3r_latents_raw.shape[0]} vs {expected}") + + n3r_latents = n3r_latents_raw.view(H, W, n3r_model.N_samples, 3).mean(dim=2) + n3r_latents = n3r_latents.permute(2, 0, 1).unsqueeze(0) # (1,3,H,W) + + if n3r_latents.shape[1] == 3: + n3r_latents = torch.cat([n3r_latents, torch.zeros_like(n3r_latents[:, :1])], dim=1) + + if n3r_latents.shape[-2:] != (target_H, target_W): + n3r_latents = F.interpolate(n3r_latents, size=(target_H, target_W), mode='bilinear', align_corners=False) + + return n3r_latents + + +def fuse_with_memory(n3r_latents, memory_dict, cf_embeds, frame_counter): + """Fusionne N3R latents avec la mémoire et calcule alpha adaptatif.""" + memory_alpha = 0.1 + 0.1 * math.sin(frame_counter * 0.05) + pos_emb, neg_emb = cf_embeds + key_embed = pos_emb - 0.5 * neg_emb + fused_latents = update_n3r_memory(memory_dict, key_embed, n3r_latents, memory_alpha=memory_alpha) + + if fused_latents.shape[-2:] != n3r_latents.shape[-2:]: + fused_latents = F.interpolate(fused_latents, size=n3r_latents.shape[-2:], mode='bilinear', align_corners=False) + + similarity = torch.cosine_similarity(n3r_latents.flatten(), fused_latents.flatten(), dim=0) + adaptive_alpha = 0.1 + 0.2 * (1 - similarity) + fused_latents = (1 - adaptive_alpha) * fused_latents + adaptive_alpha * n3r_latents + + return fused_latents + + +def inject_external(fused_latents, external_latent, frame_counter, device): + """Injection du latent externe avec poids dynamique.""" + if external_latent is None or external_latent.shape != fused_latents.shape: + raise RuntimeError("External latent absent ou dimensions incorrectes") + + dynamic_weight = 0.08 * (0.6 + 0.4 * math.sin(frame_counter * 0.1)) + external_embeddings = [{"key": "knx_neg", "latent": 
external_latent, "weight": dynamic_weight, "type": "negative"}] + + if external_embeddings: + fused_latents = 0.9 * fused_latents + 0.1 * inject_external_embeddings(fused_latents, external_embeddings, device) + + return fused_latents diff --git a/scripts/utils/n3rOpenPose_utils.py b/scripts/utils/n3rOpenPose_utils.py new file mode 100644 index 00000000..be1f3227 --- /dev/null +++ b/scripts/utils/n3rOpenPose_utils.py @@ -0,0 +1,1909 @@ +#******************************************** +# n3rOpenPose_utils.py +#******************************************** +import torch +from diffusers import ControlNetModel +import math +import torch.nn.functional as F +from .n3rControlNet import create_canny_control, control_to_latent, match_latent_size +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import os +import torchvision.transforms.functional as TF +from PIL import Image, ImageDraw +import traceback + +import torch +import torch.nn.functional as F + +def gaussian_blur_tensor(x, kernel_size=5, sigma=1.0): + """Applique un flou gaussien sur un tensor 2D ou 4D (B,C,H,W).""" + if x.ndim == 2: + x = x.unsqueeze(0).unsqueeze(0) # [1,1,H,W] + elif x.ndim == 3: + x = x.unsqueeze(0) # [1,C,H,W] + + # créer kernel gaussien 1D + coords = torch.arange(kernel_size).float() - (kernel_size - 1) / 2 + gauss = torch.exp(-(coords**2) / (2 * sigma**2)) + gauss = gauss / gauss.sum() + + # kernel 2D par produit extérieur + kernel2d = gauss[:, None] * gauss[None, :] + kernel2d = kernel2d.unsqueeze(0).unsqueeze(0) # [1,1,K,K] + kernel2d = kernel2d.to(x.device, dtype=x.dtype) + + # padding "same" + pad = kernel_size // 2 + x = F.conv2d(x, kernel2d, padding=pad) + + if x.shape[0] == 1 and x.shape[1] == 1: + x = x.squeeze() + return x + +def log_frame_error(img_path, error: Exception, verbose: bool = True): + """ + Log propre d'une erreur sur une frame. 
+ + Args: + img_path: chemin de l'image/frame + error: exception capturée + verbose: afficher le traceback complet + """ + + print(f"\n[FRAME ERROR] {img_path}") + print(f"Type d'erreur : {type(error).__name__}") + print(f"Message d'erreur : {error}") + + if verbose: + print("Traceback complet :") + traceback.print_exc() + + +def prepare_controlnet( + controlnet, + freeze: bool = True, + enable_slicing: bool = True, + device=None, + dtype=None, + verbose: bool = True +): + """ + Prépare un ControlNet : + - eval mode + - freeze des poids + - attention slicing (si dispo) + - move device / dtype + - init pose_sequence + + Returns: + controlnet, pose_sequence (None par défaut) + """ + + # ---- eval mode + controlnet.eval() + if verbose: + print("✅ ControlNet en mode eval") + + # ---- freeze + if freeze: + for p in controlnet.parameters(): + p.requires_grad = False + if verbose: + print("✅ Paramètres gelés") + + # ---- attention slicing + if enable_slicing: + fn = getattr(controlnet, "enable_attention_slicing", None) + if callable(fn): + fn() + if verbose: + print("✅ Attention slicing activé") + else: + if verbose: + print("⚠ enable_attention_slicing non disponible") + + # ---- device / dtype + if device is not None or dtype is not None: + controlnet = controlnet.to(device=device, dtype=dtype) + if verbose: + print(f"✅ Déplacé sur {device} / {dtype}") + + # ---- init pose + pose_sequence = None + + return controlnet, pose_sequence + +def fix_pose_sequence( + pose_sequence: torch.Tensor, + total_frames: int, + device=None, + dtype=None, + verbose: bool = True +) -> torch.Tensor: + """ + Ajuste une séquence de poses au bon nombre de frames avec interpolation. 
+ + Args: + pose_sequence: Tensor (F, C, H, W) + total_frames: nombre de frames cible + device: device cible (optionnel) + dtype: dtype cible (optionnel) + verbose: afficher logs + + Returns: + Tensor (F, C, H, W) + """ + print(f"🎞 fix_pose_sequence - Frames JSON: {pose_sequence.shape[0]}") + print(f"🎞 fix_pose_sequence - Frames attendues: {total_frames}") + + if pose_sequence.shape[0] != total_frames: + if verbose: + print("⚠ Ajustement du nombre de frames OpenPose") + + # (F, C, H, W) → (1, C, F, H, W) + pose_sequence = pose_sequence.permute(1, 0, 2, 3).unsqueeze(0) + + pose_sequence = F.interpolate( + pose_sequence, + size=(total_frames, pose_sequence.shape[-2], pose_sequence.shape[-1]), + mode='trilinear', + align_corners=False + ) + + # retour → (F, C, H, W) + pose_sequence = pose_sequence.squeeze(0).permute(1, 0, 2, 3) + + # Fix device + dtype + if device is not None or dtype is not None: + pose_sequence = pose_sequence.to(device=device, dtype=dtype) + + if verbose: + print( + "✅ PoseSequence final:", + pose_sequence.shape, + pose_sequence.device, + pose_sequence.dtype + ) + + return pose_sequence + + + +def tensor_to_pil(tensor): + """ + Convertit un tensor torch [C,H,W] ou [H,W] en PIL.Image RGB. + """ + if tensor.dim() == 3: + C, H, W = tensor.shape + if C == 1: + array = tensor[0].cpu().numpy() # [H,W] + pil_img = Image.fromarray(array).convert("RGB") + elif C == 3: + array = tensor.permute(1, 2, 0).cpu().numpy() # [H,W,C] + pil_img = Image.fromarray(array) + else: + raise ValueError(f"Tensor avec {C} canaux non supporté") + elif tensor.dim() == 2: + pil_img = Image.fromarray(tensor.cpu().numpy()).convert("RGB") + else: + raise ValueError(f"Tensor shape non supportée: {tensor.shape}") + return pil_img + +import os +from PIL import Image +import torch + +def save_debug_pose_image(pose_tensor, frame_counter, output_dir, cfg=None, prefix="openpose"): + """ + Sauvegarde une image de pose pour contrôle visuel. 
+ + pose_tensor : torch.Tensor [C,H,W] ou [H,W] + frame_counter : int, numéro de frame + output_dir : str, dossier où sauvegarder + cfg : dict ou None, peut contenir paramètre 'visual_debug' pour activer/désactiver + prefix : str, préfixe du fichier + """ + + # Vérifie si le debug visuel est activé + if cfg is not None and cfg.get("visual_debug") is False: + return + + # Convertir tensor en uint8 [0,255] + pose_img = (pose_tensor * 255).clamp(0, 255).byte() + + # Fonction interne pour gérer tous les formats [C,H,W], [H,W] + def tensor_to_pil(tensor): + if tensor.dim() == 3: + C, H, W = tensor.shape + if C == 1: + array = tensor[0].cpu().numpy() # [H,W] + pil_img = Image.fromarray(array).convert("RGB") + elif C == 3: + array = tensor.permute(1, 2, 0).cpu().numpy() # [H,W,C] + pil_img = Image.fromarray(array) + else: + raise ValueError(f"Tensor avec {C} canaux non supporté") + elif tensor.dim() == 2: + pil_img = Image.fromarray(tensor.cpu().numpy()).convert("RGB") + else: + # Si la tensor a une forme inattendue, on essaie de la "squeezer" + tensor = tensor.squeeze() + if tensor.dim() in [2, 3]: + return tensor_to_pil(tensor) + raise ValueError(f"Tensor shape non supportée: {tensor.shape}") + return pil_img + + pil_pose = tensor_to_pil(pose_img) + + # Création du dossier si nécessaire + os.makedirs(output_dir, exist_ok=True) + + # Nom du fichier : openpose_00001.png + filename = f"{prefix}_{frame_counter:05d}.png" + path = os.path.join(output_dir, filename) + + pil_pose.save(path) + print(f"[DEBUG] Pose sauvegardée : {path}") + +def save_debug_pose_image_mini(pose_tensor, frame_counter, output_dir, cfg=None, prefix="openpose"): + """ + Sauvegarde la pose détectée pour vérification visuelle. 
+ + Args: + pose_tensor (torch.Tensor): Tensor BCHW ou CHW (1,3,H,W ou 3,H,W) + frame_counter (int): numéro de la frame + output_dir (Path): dossier de sortie pour sauvegarde + cfg (dict, optional): configuration, active si cfg.get("debug_pose_visual", False) est True + prefix (str): préfixe du fichier image (default: 'openpose') + """ + if cfg is None or not cfg.get("debug_pose_visual", False): + return + + # S'assurer que le tensor est BCHW + if pose_tensor.ndim == 3: # CHW -> BCHW + pose_tensor = pose_tensor.unsqueeze(0) + + pose_tensor = pose_tensor[0] # retirer batch + + # Limiter à 3 canaux + if pose_tensor.shape[0] > 3: + pose_tensor = pose_tensor[:3, :, :] + + # CHW -> HWC + pose_np = pose_tensor.permute(1, 2, 0).cpu().numpy() + # Normalisation 0-255 + pose_np = (pose_np - pose_np.min()) / (pose_np.max() - pose_np.min() + 1e-8) * 255.0 + pose_np = pose_np.astype("uint8") + img = Image.fromarray(pose_np) + + # Nom de fichier : openpose_0001.png + output_dir.mkdir(parents=True, exist_ok=True) + filename = output_dir / f"{prefix}_{frame_counter:04d}.png" + img.save(filename) + +def debug_pose_visual(pose_tensor, frame_counter, cfg=None, title="Pose Debug"): + """ + Affiche la pose détectée pour vérification visuelle. 
+ + Args: + pose_tensor (torch.Tensor): Tensor BCHW ou CHW (1,3,H,W ou 3,H,W) + frame_counter (int): numéro de la frame + cfg (dict, optional): configuration, active si cfg.get("debug_pose_visual", False) est True + title (str): titre pour l'affichage + """ + if cfg is None or not cfg.get("debug_pose_visual", False): + return + + # S'assurer que le tensor est BCHW + if pose_tensor.ndim == 3: # CHW -> BCHW + pose_tensor = pose_tensor.unsqueeze(0) + + pose_tensor = pose_tensor[0] # retirer batch + + # Limiter à 3 canaux + if pose_tensor.shape[0] > 3: + pose_tensor = pose_tensor[:3, :, :] + + # CHW -> HWC pour PIL + pose_np = pose_tensor.permute(1, 2, 0).cpu().numpy() + pose_np = (pose_np - pose_np.min()) / (pose_np.max() - pose_np.min() + 1e-8) * 255.0 + pose_np = pose_np.astype("uint8") + img = Image.fromarray(pose_np) + + # Affichage rapide avec matplotlib + plt.figure(figsize=(4, 4)) + plt.imshow(img) + plt.axis("off") + plt.title(f"{title} - Frame {frame_counter}") + plt.show(block=False) + plt.pause(0.1) # court délai pour rafraîchir + plt.close() + +def convert_json_to_pose_sequence(anim_data, H=512, W=512, device="cuda", dtype=torch.float16, debug=False): + """ + Convertit un JSON d'animation OpenPose simplifié en tensor utilisable par ControlNet. 
+ + Output: + pose_sequence: tensor [num_frames, 3, H, W] (RGB image type) + """ + + frames = anim_data.get("animation", []) + pose_images = [] + + for idx, frame in enumerate(frames): + keypoints = frame.get("keypoints", []) + + # Image noire + canvas = np.zeros((H, W, 3), dtype=np.uint8) + + # --- Dessin des points --- + for kp in keypoints: + x = int(kp["x"]) + y = int(kp["y"]) + conf = kp.get("confidence", 1.0) + + if conf > 0.3: + cv2.circle(canvas, (x, y), 4, (255, 255, 255), -1) + + # --- Dessin des connexions (squelette simple) --- + skeleton = [ + (0, 1), # tête → torse + (1, 2), # torse → bras gauche + (1, 3), # torse → bras droit + (1, 4), # torse → jambe gauche + (1, 5), # torse → jambe droite + ] + + for a, b in skeleton: + if a < len(keypoints) and b < len(keypoints): + x1, y1 = int(keypoints[a]["x"]), int(keypoints[a]["y"]) + x2, y2 = int(keypoints[b]["x"]), int(keypoints[b]["y"]) + cv2.line(canvas, (x1, y1), (x2, y2), (255, 255, 255), 2) + + # --- Conversion en tensor --- + img = torch.from_numpy(canvas).float() / 255.0 # [H, W, C] + img = img.permute(2, 0, 1) # → [C, H, W] + + pose_images.append(img) + + pose_sequence = torch.stack(pose_images).to(device=device, dtype=dtype) + + if debug: + print(f"[JSON->POSE] shape: {pose_sequence.shape}") + print(f"[JSON->POSE] min/max: {pose_sequence.min().item()} / {pose_sequence.max().item()}") + + return pose_sequence + + +def apply_controlnet_openpose_step_safe( + latents, + timestep, + unet, + controlnet, + scheduler, + pose_image, + pos_embeds, + neg_embeds, + guidance_scale, + controlnet_scale=0.25, + device="cuda", + dtype=torch.float16, + debug=False +): + """ + Wrapper sécurisé pour apply_controlnet_openpose_step + - gère CPU/GPU + - corrige dtype + - convertit timestep en long pour scheduler + """ + # --- CPU float32 pour ControlNet --- + latents_cpu = latents.to("cpu", dtype=torch.float32) + unet_cpu = unet.to("cpu", dtype=torch.float32) + controlnet_cpu = controlnet.to("cpu", dtype=torch.float32) + 
pose_cpu = pose_image.to("cpu", dtype=torch.float32) + pos_embeds_cpu = pos_embeds.to("cpu", dtype=torch.float32) + neg_embeds_cpu = neg_embeds.to("cpu", dtype=torch.float32) + + # --- Préparer timestep --- + if timestep.ndim == 0: + timestep = timestep.unsqueeze(0) + batch_size = latents_cpu.shape[0] + timestep = timestep.repeat(batch_size).to(torch.long).to("cpu") + + # --- Appel ControlNet OpenPose --- + latents_cpu = apply_controlnet_openpose_step( + latents=latents_cpu, + t=timestep, + unet=unet_cpu, + controlnet=controlnet_cpu, + scheduler=scheduler, + pose_image=pose_cpu, + pos_embeds=pos_embeds_cpu, + neg_embeds=neg_embeds_cpu, + guidance_scale=guidance_scale, + controlnet_scale=controlnet_scale, + device="cpu", + dtype=torch.float32, + debug=debug + ) + + # --- Retour sur GPU et dtype final --- + latents_out = latents_cpu.to(device, dtype=dtype) + unet.to(device, dtype=dtype) + + return latents_out + +def build_control_latent_debug(input_pil, vae, device="cuda", latent_scale=0.18215): + import torch + + print("\n================ CONTROL LATENT DEBUG ================") + + # 1. Canny + control = create_canny_control(input_pil) + + print("[STEP 1] RAW CONTROL") + print(" shape:", control.shape) + print(" dtype:", control.dtype) + print(" min/max:", control.min().item(), control.max().item()) + + # 2. 1 → 3 channels + if control.shape[1] == 1: + control = control.repeat(1, 3, 1, 1) + + # 3. Normalize PROPERLY (CRUCIAL) + control = control.clamp(0, 1) # sécurité + control = control * 2.0 - 1.0 # [-1,1] + + print("[STEP 2] NORMALIZED") + print(" min/max:", control.min().item(), control.max().item()) + + # 4. Move to device FP32 + control = control.to(device=device, dtype=torch.float32) + + print("[STEP 3] DEVICE") + print(" device:", control.device) + print(" dtype:", control.dtype) + + # 5. 
Sync VAE + print("[STEP 4] VAE STATE") + print(" vae dtype:", next(vae.parameters()).dtype) + print(" vae device:", next(vae.parameters()).device) + + # 🔥 FORCER cohérence VAE + vae = vae.to(device=device, dtype=torch.float32) + + # 6. Encode SAFE (no autocast) + with torch.no_grad(): + try: + latent_dist = vae.encode(control).latent_dist + latent = latent_dist.sample() + except Exception as e: + print("❌ VAE ENCODE CRASH:", e) + raise + + print("[STEP 5] LATENT RAW") + print(" min/max:", latent.min().item(), latent.max().item()) + print(" NaN:", torch.isnan(latent).sum().item()) + + # 🚨 CHECK NaN + if torch.isnan(latent).any(): + print("⚠️ NaN DETECTED → applying fallback") + + # fallback 1: zero latent + latent = torch.zeros_like(latent) + + # fallback 2 (optionnel): + # latent = torch.randn_like(latent) * 0.1 + + # 7. Scale (SD standard) + latent = latent * latent_scale + + print("[STEP 6] SCALED LATENT") + print(" min/max:", latent.min().item(), latent.max().item()) + + # 8. Back to FP16 + latent = latent.to(dtype=torch.float16) + + print("[FINAL]") + print(" dtype:", latent.dtype) + print(" device:", latent.device) + print("=====================================================\n") + + return latent + +# ---------------- Control -> Latent sécurisé ---------------- +def control_to_latent_safe(control_tensor, vae, device="cuda", LATENT_SCALE=1.0): + # 🔥 FORCE VAE EN FP32 + vae = vae.to(device=device, dtype=torch.float32) + + control_tensor = control_tensor.to(device=device, dtype=torch.float32) + + with torch.no_grad(): + latent = vae.encode(control_tensor).latent_dist.sample() + + return latent * LATENT_SCALE + +def process_latents_streamed(control_latent, mini_latents=None, mini_weight=0.5, device="cuda"): + """ + Fusionne ControlNet / mini-latents frame par frame, patch par patch + pour réduire l'empreinte VRAM. 
+ """ + # On garde tout en float16 tant que possible + control_latent = control_latent.to(device=device, dtype=torch.float16) + + if mini_latents is not None: + mini_latents = mini_latents.to(device=device, dtype=torch.float16) + + # Initialisation finale du tensor latents en float16 + latents = control_latent.clone() + + # Si mini_latents existe, on fait un mix patch par patch + if mini_latents is not None: + B, C, H, W = latents.shape + patch_size = 16 # petit patch pour limiter la VRAM + for y in range(0, H, patch_size): + y1 = min(y + patch_size, H) + for x in range(0, W, patch_size): + x1 = min(x + patch_size, W) + + # Sélection patch + patch_main = latents[:, :, y:y1, x:x1] + patch_mini = mini_latents[:, :, y:y1, x:x1] + + # Mix float16 → float16 pour VRAM + patch_main = (1 - mini_weight) * patch_main + mini_weight * patch_mini + + # Écriture patch back + latents[:, :, y:y1, x:x1] = patch_main + + # Nettoyage immédiat pour libérer VRAM + del patch_main, patch_mini + torch.cuda.empty_cache() + + return latents + + +def match_latent_size(latents_main, *tensors): + """ + Interpole tous les tensors pour correspondre à la taille HxW de latents_main. + """ + matched = [] + for t in tensors: + if t.shape[2:] != latents_main.shape[2:]: + t = F.interpolate(t, size=latents_main.shape[2:], mode='bilinear', align_corners=False) + matched.append(t) + return matched if len(matched) > 1 else matched[0] + + +def match_latent_size_v1(latents_main, latents_mini): + """ + Assure que latents_mini a la même taille HxW que latents_main. 
+ """ + if latents_mini.shape[2:] != latents_main.shape[2:]: + latents_mini = F.interpolate( + latents_mini, + size=latents_main.shape[2:], # H, W + mode='bilinear', + align_corners=False + ) + return latents_mini + + +import torch +# ****************************** A TESTER ******************************** +def apply_openpose_tilewise( + latents, + pose, + apply_fn, + block_size=32, + overlap=16, + device='cuda', + debug=False, + debug_dir=None, + frame_idx=0, + full_res=(960, 512) +): + import torch + import torch.nn.functional as F + import os + import numpy as np + from PIL import Image + + B, C_lat, H_lat, W_lat = latents.shape + + # Redimensionner pose aux latents + pose_resized = F.interpolate( + pose, size=(H_lat, W_lat), mode='bilinear', align_corners=False + ).clamp(0.0, 1.0) + + stride = block_size - overlap + + # 🔹 Padding pour éviter tiles vides + pad_h = (stride - H_lat % stride) % stride + pad_w = (stride - W_lat % stride) % stride + latents_padded = F.pad(latents, (0, pad_w, 0, pad_h), mode='reflect') + H_pad, W_pad = H_lat + pad_h, W_lat + pad_w + + latents_out = latents_padded.clone() + if debug: + impact_map = torch.zeros((H_pad, W_pad), device=device) + + tile_id = 0 + for i in range(0, H_pad, stride): + for j in range(0, W_pad, stride): + i_end = min(i + block_size, H_pad) + j_end = min(j + block_size, W_pad) + i_start = max(0, i_end - block_size) + j_start = max(0, j_end - block_size) + + # 🔹 Ignorer les tiles invalides + if (i_end - i_start) <= 0 or (j_end - j_start) <= 0: + continue + + latent_tile = latents_padded[:, :, i_start:i_end, j_start:j_end].to(device=device, dtype=torch.float16) + tile_coords = (j_start, i_start, j_end, i_end) + + # Apply avec fallback + try: + latent_tile_processed = apply_fn(latent_tile, tile_coords) + latent_tile_processed = latent_tile_processed.clamp(-5.0, 5.0) + except Exception as e: + print(f"[WARNING] Tile {tile_id} ({i_start}:{i_end},{j_start}:{j_end}) failed: {e}") + latent_tile_processed = latent_tile + 
+ # 🔹 Blend sur les overlaps + blend_mask = torch.ones((1, 1, i_end-i_start, j_end-j_start), device=device) + # horizontal + if j_start != 0: + fade = torch.linspace(0, 1, min(overlap, j_end-j_start), device=device).view(1,1,1,-1) + blend_mask[:, :, :, :fade.shape[-1]] *= fade + if j_end != W_pad: + fade = torch.linspace(1, 0, min(overlap, j_end-j_start), device=device).view(1,1,1,-1) + blend_mask[:, :, :, -fade.shape[-1]:] *= fade + # vertical + if i_start != 0: + fade = torch.linspace(0, 1, min(overlap, i_end-i_start), device=device).view(1,1,-1,1) + blend_mask[:, :, :fade.shape[-2], :] *= fade + if i_end != H_pad: + fade = torch.linspace(1, 0, min(overlap, i_end-i_start), device=device).view(1,1,-1,1) + blend_mask[:, :, -fade.shape[-2]:, :] *= fade + + latents_out[:, :, i_start:i_end, j_start:j_end] = ( + latents_out[:, :, i_start:i_end, j_start:j_end] * (1 - blend_mask) + + latent_tile_processed * blend_mask + ) + + if debug: + diff_map = (latent_tile_processed - latent_tile).abs().mean(dim=1) + impact_map[i_start:i_end, j_start:j_end] += diff_map.squeeze(0) + print(f"[TILE {tile_id}] diff={diff_map.mean().item():.6f}") + + tile_id += 1 + + # Retirer padding + latents_out = latents_out[:, :, :H_lat, :W_lat] + + # 🔹 DEBUG impact map + if debug and debug_dir is not None: + os.makedirs(debug_dir, exist_ok=True) + impact_map_full = impact_map[:H_lat, :W_lat].unsqueeze(0).unsqueeze(0) + impact_map_full = F.interpolate(impact_map_full, size=full_res, mode='bilinear', align_corners=False).squeeze() + impact_np = impact_map_full.detach().cpu().numpy() + impact_np = impact_np - impact_np.min() + if impact_np.max() > 0: + impact_np = impact_np / impact_np.max() + Image.fromarray((impact_np*255).astype(np.uint8)).save( + os.path.join(debug_dir, f"impact_map_{frame_idx:05d}.png") + ) + + return latents_out + +def apply_openpose_tilewise_v3( + latents, + pose, + apply_fn, + block_size=64, + overlap=32, + device='cuda', + debug=False, + debug_dir=None, + frame_idx=0, + 
full_res=(960, 512) # résolution originale pour l'impact map +): + import torch + import torch.nn.functional as F + import os + import numpy as np + from PIL import Image + + B, C_lat, H_lat, W_lat = latents.shape + + # 🔹 Redimensionner le pose à la résolution des latents et clamp + pose_resized = F.interpolate( + pose, size=(H_lat, W_lat), + mode='bilinear', align_corners=False + ).clamp(0.0, 1.0) + + latents_out = latents.clone() + + # 🔹 DEBUG MAP (impact spatial) + if debug: + impact_map = torch.zeros((H_lat, W_lat), device=device) + + stride = block_size - overlap + tile_id = 0 + + for i in range(0, H_lat, stride): + for j in range(0, W_lat, stride): + + i_end = min(i + block_size, H_lat) + j_end = min(j + block_size, W_lat) + + i_start = max(0, i_end - block_size) + j_start = max(0, j_end - block_size) + + latent_tile = latents[:, :, i_start:i_end, j_start:j_end].to(device=device, dtype=torch.float16) + tile_coords = (j_start, i_start, j_end, i_end) + + # 🔹 APPLY avec fallback en cas de crash + try: + latent_tile_processed = apply_fn(latent_tile, tile_coords) + latent_tile_processed = latent_tile_processed.clamp(-5.0, 5.0) + except Exception as e: + print(f"[WARNING] Tile {tile_id} ({i_start}:{i_end},{j_start}:{j_end}) failed: {e}") + latent_tile_processed = latent_tile # fallback + + # 🔹 DEBUG METRICS + if debug: + diff_map = (latent_tile_processed - latent_tile).abs().mean(dim=1) + impact_map[i_start:i_end, j_start:j_end] += diff_map.squeeze(0) + + diff_mean = diff_map.mean().item() + min_val = latent_tile_processed.min().item() + max_val = latent_tile_processed.max().item() + print(f"[TILE {tile_id}] ({i_start}:{i_end},{j_start}:{j_end}) " + f"diff={diff_mean:.6f} min={min_val:.4f} max={max_val:.4f}") + + # 🔹 BLEND au lieu d'overwrite + latents_out[:, :, i_start:i_end, j_start:j_end] = ( + 0.7 * latents_out[:, :, i_start:i_end, j_start:j_end] + + 0.3 * latent_tile_processed + ) + + tile_id += 1 + + # 🔹 VISUAL DEBUG GLOBAL + if debug and debug_dir is not 
None: + os.makedirs(debug_dir, exist_ok=True) + + # 🔹 Upsample à la résolution originale + impact_map_full = F.interpolate( + impact_map.unsqueeze(0).unsqueeze(0), + size=full_res, + mode='bilinear', + align_corners=False + ).squeeze() + + # 🔹 Lissage pour réduire le bruit + impact_map_full = gaussian_blur_tensor(impact_map_full, kernel_size=5, sigma=1.0) + + # 🔹 Normalisation pour visualisation + impact_np = impact_map_full.detach().cpu().numpy() + impact_np = impact_np - impact_np.min() + if impact_np.max() > 0: + impact_np = impact_np / impact_np.max() + impact_img = (impact_np * 255).astype(np.uint8) + impact_img = Image.fromarray(impact_img) + + impact_img.save(os.path.join(debug_dir, f"impact_map_{frame_idx:05d}.png")) + print(f"[DEBUG] Impact map saved") + + return latents_out + +def apply_openpose_tilewise_v2( + latents, + pose, + apply_fn, + block_size=64, + overlap=32, + device='cuda', + debug=False, + debug_dir=None, + frame_idx=0, + full_res=(960, 512) # résolution originale pour l'impact map +): + import torch + import torch.nn.functional as F + import os + import numpy as np + from PIL import Image + + B, C_lat, H_lat, W_lat = latents.shape + + # 🔹 Redimensionner le pose à la résolution des latents + pose_resized = F.interpolate( + pose, size=(H_lat, W_lat), + mode='bilinear', align_corners=False + ) + + latents_out = latents.clone() + + # 🔹 DEBUG MAP (impact spatial) + if debug: + impact_map = torch.zeros((H_lat, W_lat), device=device) + + stride = block_size - overlap + tile_id = 0 + + for i in range(0, H_lat, stride): + for j in range(0, W_lat, stride): + + i_end = min(i + block_size, H_lat) + j_end = min(j + block_size, W_lat) + + i_start = max(0, i_end - block_size) + j_start = max(0, j_end - block_size) + + latent_tile = latents[:, :, i_start:i_end, j_start:j_end] + tile_coords = (j_start, i_start, j_end, i_end) + + # 🔹 APPLY + latent_tile_processed = apply_fn(latent_tile, tile_coords) + + # 🔹 DEBUG METRICS + if debug: + diff_map = 
(latent_tile_processed - latent_tile).abs().mean(dim=1) + impact_map[i_start:i_end, j_start:j_end] += diff_map.squeeze(0) + + diff_mean = diff_map.mean().item() + min_val = latent_tile_processed.min().item() + max_val = latent_tile_processed.max().item() + print(f"[TILE {tile_id}] ({i_start}:{i_end},{j_start}:{j_end}) " + f"diff={diff_mean:.6f} min={min_val:.4f} max={max_val:.4f}") + + # 🔹 BLEND au lieu d'overwrite + latents_out[:, :, i_start:i_end, j_start:j_end] = ( + 0.7 * latents_out[:, :, i_start:i_end, j_start:j_end] + + 0.3 * latent_tile_processed + ) + + tile_id += 1 + + # 🔹 VISUAL DEBUG GLOBAL + if debug and debug_dir is not None: + os.makedirs(debug_dir, exist_ok=True) + + # 🔹 Upsample à la résolution originale + impact_map_full = F.interpolate( + impact_map.unsqueeze(0).unsqueeze(0), + size=full_res, + mode='bilinear', + align_corners=False + ).squeeze() + + # 🔹 Lissage pour réduire le bruit + impact_map_full = gaussian_blur_tensor(impact_map_full, kernel_size=5, sigma=1.0) + + # 🔹 Normalisation pour visualisation + impact_np = impact_map_full.detach().cpu().numpy() + impact_np = impact_np - impact_np.min() + if impact_np.max() > 0: + impact_np = impact_np / impact_np.max() + impact_img = (impact_np * 255).astype(np.uint8) + impact_img = Image.fromarray(impact_img) + + impact_img.save(os.path.join(debug_dir, f"impact_map_{frame_idx:05d}.png")) + print(f"[DEBUG] Impact map saved") + + return latents_out + + +def apply_openpose_tilewise_v1( + latents, + pose, + apply_fn, + block_size=64, + overlap=32, + device='cuda', + debug=False, + debug_dir=None, + frame_idx=0 +): + import torch + import torch.nn.functional as F + import os + import numpy as np + from PIL import Image + + B, C_lat, H_lat, W_lat = latents.shape + + # 🔹 Redimensionner pose pour matcher les latents + pose_resized = F.interpolate( + pose, size=(H_lat, W_lat), + mode='bilinear', align_corners=False + ) + + latents_out = latents.clone() + + # 🔹 DEBUG MAP (impact spatial) + if debug: + 
impact_map = torch.zeros((H_lat, W_lat), device=device) + + stride = block_size - overlap + tile_id = 0 + + for i in range(0, H_lat, stride): + for j in range(0, W_lat, stride): + i_end = min(i + block_size, H_lat) + j_end = min(j + block_size, W_lat) + i_start = max(0, i_end - block_size) + j_start = max(0, j_end - block_size) + + latent_tile = latents[:, :, i_start:i_end, j_start:j_end] + pose_tile = pose_resized[:, :, i_start:i_end, j_start:j_end] + + tile_coords = (j_start, i_start, j_end, i_end) + + # 🔹 APPLY + latent_tile_processed = apply_fn(latent_tile, tile_coords) + + # 🔹 DEBUG METRICS + if debug: + # Diff par pixel (moyenne sur canaux) + diff_map = (latent_tile_processed - latent_tile).abs().mean(dim=1).squeeze(0) + impact_map[i_start:i_end, j_start:j_end] += diff_map + + diff_scalar = diff_map.mean().item() + min_val = latent_tile_processed.min().item() + max_val = latent_tile_processed.max().item() + + print(f"[TILE {tile_id}] ({i_start}:{i_end},{j_start}:{j_end}) " + f"diff={diff_scalar:.6f} min={min_val:.4f} max={max_val:.4f}") + + # Détecter tile mort + if max_val == 0 and min_val == 0: + print(f"[⚠️ ZERO TILE] {tile_id}") + + # 🔹 BLEND au lieu d'overwrite + latents_out[:, :, i_start:i_end, j_start:j_end] = ( + 0.7 * latents_out[:, :, i_start:i_end, j_start:j_end] + + 0.3 * latent_tile_processed + ) + + tile_id += 1 + + # 🔹 VISUAL DEBUG GLOBAL + if debug and debug_dir is not None: + os.makedirs(debug_dir, exist_ok=True) + + impact_np = impact_map.detach().float().cpu().numpy() + # normalisation visuelle + impact_np = impact_np - impact_np.min() + if impact_np.max() > 0: + impact_np = impact_np / impact_np.max() + + impact_img = (impact_np * 255).astype(np.uint8) + impact_img = Image.fromarray(impact_img) + impact_img.save(os.path.join(debug_dir, f"impact_map_{frame_idx:05d}.png")) + + print(f"[DEBUG] Impact map saved") + + return latents_out + +def apply_openpose_tilewise_old( + latents, + pose, + apply_fn, + block_size=64, + overlap=32, + 
device='cuda', + debug=False, + debug_dir=None, + frame_idx=0 +): + import torch + import torch.nn.functional as F + import os + import numpy as np + from PIL import Image + + B, C_lat, H_lat, W_lat = latents.shape + + pose_resized = F.interpolate( + pose, size=(H_lat, W_lat), + mode='bilinear', align_corners=False + ) + + latents_out = latents.clone() + + # 🔥 DEBUG MAP (impact spatial) + if debug: + impact_map = torch.zeros((H_lat, W_lat), device=device) + + stride = block_size - overlap + + tile_id = 0 + + for i in range(0, H_lat, stride): + for j in range(0, W_lat, stride): + + i_end = min(i + block_size, H_lat) + j_end = min(j + block_size, W_lat) + + #i_start = i_end - block_size if i_end - i < block_size else i + #j_start = j_end - block_size if j_end - j < block_size else j + + i_start = max(0, i_end - block_size) + j_start = max(0, j_end - block_size) + + latent_tile = latents[:, :, i_start:i_end, j_start:j_end] + pose_tile = pose_resized[:, :, i_start:i_end, j_start:j_end] + + tile_coords = (j_start, i_start, j_end, i_end) + + # 🔹 APPLY + latent_tile_processed = apply_fn(latent_tile, tile_coords) + + # 🔥 DEBUG METRICS + if debug: + diff = (latent_tile_processed - latent_tile).abs().mean().item() + min_val = latent_tile_processed.min().item() + max_val = latent_tile_processed.max().item() + + print(f"[TILE {tile_id}] ({i_start}:{i_end},{j_start}:{j_end}) " + f"diff={diff:.6f} min={min_val:.4f} max={max_val:.4f}") + + # détecter tile mort + if max_val == 0 and min_val == 0: + print(f"[⚠️ ZERO TILE] {tile_id}") + + # remplir impact map + impact_map[i_start:i_end, j_start:j_end] += diff + + # 🔥 IMPORTANT → BLEND au lieu d'overwrite + latents_out[:, :, i_start:i_end, j_start:j_end] = ( + 0.7 * latents_out[:, :, i_start:i_end, j_start:j_end] + + 0.3 * latent_tile_processed + ) + + tile_id += 1 + + # 🔥 VISUAL DEBUG GLOBAL + if debug and debug_dir is not None: + os.makedirs(debug_dir, exist_ok=True) + + impact_np = impact_map.detach().float().cpu().numpy() + + # 
normalisation visuelle + impact_np = impact_np - impact_np.min() + if impact_np.max() > 0: + impact_np = impact_np / impact_np.max() + + impact_img = (impact_np * 255).astype(np.uint8) + impact_img = Image.fromarray(impact_img) + + impact_img.save(os.path.join(debug_dir, f"impact_map_{frame_idx:05d}.png")) + + print(f"[DEBUG] Impact map saved") + + return latents_out + +def apply_openpose_tilewise_ori(latents, pose, apply_fn, block_size=64, overlap=96, device='cuda'): + """ + Applique OpenPose tile par tile sur le latent. + + latents : torch.Tensor [B, 4, H_latent, W_latent] (diffusion latents) + pose : torch.Tensor [B, 3, H_img, W_img] (full OpenPose) + apply_fn: fonction qui applique ControlNet sur un tile + """ + B, C_lat, H_lat, W_lat = latents.shape + B, C_pose, H_img, W_img = pose.shape + + # Redimensionner le pose pour matcher le latent full size + pose_resized = F.interpolate(pose, size=(H_lat, W_lat), mode='bilinear', align_corners=False) + + # Créer une copie pour modification + latents_out = latents.clone() + + stride = block_size - overlap + for i in range(0, H_lat, stride): + for j in range(0, W_lat, stride): + # Limites du tile (clip si dépasse) + i_end = min(i + block_size, H_lat) + j_end = min(j + block_size, W_lat) + i_start = i_end - block_size if i_end - i < block_size else i + j_start = j_end - block_size if j_end - j < block_size else j + + # Extraire tile latent et tile pose + latent_tile = latents[:, :, i_start:i_end, j_start:j_end] + pose_tile = pose_resized[:, :, i_start:i_end, j_start:j_end] + + # Appliquer ControlNet sur le tile + #latent_tile_processed = apply_fn(latent_tile, pose_tile) + tile_coords = (j_start, i_start, j_end, i_end) # ⚠️ ordre important + latent_tile_processed = apply_fn(latent_tile, tile_coords) + + # Écraser dans latents_out + latents_out[:, :, i_start:i_end, j_start:j_end] = latent_tile_processed + + return latents_out + + +def apply_controlnet_openpose_step_safe( + latents, + t, + unet, + controlnet, + scheduler, + 
pose_image, + pos_embeds, + neg_embeds=None, + guidance_scale=5.0, + controlnet_scale=0.7, + device="cuda", + dtype=torch.float16, + debug=False +): + # 🔹 Déplacement sur device / dtype + latents = latents.to(device=device, dtype=dtype) + pose_image = pose_image.to(device=device, dtype=dtype) + + # 🔹 Préparer batch pour classifier-free guidance + if neg_embeds is not None: + latent_model_input = torch.cat([latents] * 2) + encoder_states = torch.cat([neg_embeds, pos_embeds]) + pose_input = torch.cat([pose_image] * 2) + else: + latent_model_input = latents + encoder_states = pos_embeds + pose_input = pose_image + + # 🔹 Scheduler scale + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + # 🔹 ControlNet + down_samples, mid_sample = controlnet( + latent_model_input, + t, + encoder_hidden_states=encoder_states, + controlnet_cond=pose_input, + return_dict=False + ) + + # 🔹 Safe normalization des résidus ControlNet + down_samples = [d / (d.abs().mean() + 1e-6) for d in down_samples] + mid_sample = mid_sample / (mid_sample.abs().mean() + 1e-6) + + # 🔹 UNet avec ControlNet + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=encoder_states, + down_block_additional_residuals=[d * controlnet_scale for d in down_samples], + mid_block_additional_residual=mid_sample * controlnet_scale + ).sample + + # 🔹 Classifier-Free Guidance + if neg_embeds is not None: + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # 🔹 Scheduler step avec batch safe + latents_input = torch.cat([latents] * 2) if neg_embeds is not None else latents + latents = scheduler.step(noise_pred, t, latents_input).prev_sample + + # 🔹 Récupérer batch original si CFG + if neg_embeds is not None: + latents = latents.chunk(2)[0] + + # 🔹 Clamp final pour sécurité + latents = torch.clamp(latents, -1.0, 1.0) + + if debug: + print(f"[ControlNet SAFE] t={t}, latents min/max: 
{latents.min().item():.3f}/{latents.max().item():.3f}") + + return latents + + +def apply_controlnet_openpose_step( + latents, + t, + unet, + controlnet, + scheduler, + pose_image, + pos_embeds, + neg_embeds=None, + guidance_scale=5.0, + controlnet_scale=0.7, + device="cuda", + dtype=torch.float16, + debug=False +): + import torch + + latents = latents.to(device=device, dtype=dtype) + pose_image = pose_image.to(device=device, dtype=dtype) + + # 🔁 classifier-free guidance + if neg_embeds is not None: + latent_model_input = torch.cat([latents] * 2) + encoder_states = torch.cat([neg_embeds, pos_embeds]) + pose_input = torch.cat([pose_image] * 2) + else: + latent_model_input = latents + encoder_states = pos_embeds + pose_input = pose_image + + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + # 🔥 ControlNet + down_samples, mid_sample = controlnet( + latent_model_input, + t, + encoder_hidden_states=encoder_states, + controlnet_cond=pose_input, + return_dict=False + ) + + # 🔥 UNet avec ControlNet + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=encoder_states, + down_block_additional_residuals=[d * controlnet_scale for d in down_samples], + mid_block_additional_residual=mid_sample * controlnet_scale + ).sample + + # 🔁 CFG + if neg_embeds is not None: + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # 🔥 Scheduler step + latents = scheduler.step(noise_pred, t, latents).prev_sample + + if debug: + print(f"[ControlNet] t={t}, latents min/max: {latents.min().item():.3f}/{latents.max().item():.3f}") + + return latents + + +def generate_pose_sequence( + base_pose, + num_frames=16, + motion_type="idle", # "idle", "sway", "zoom", "breath" + amplitude=5.0, + device="cuda", + dtype=None, + debug=False +): + """ + Génère une séquence de poses animées (OpenPose-like). 
def generate_pose_sequence(
    base_pose,
    num_frames=16,
    motion_type="idle",   # "idle", "sway", "zoom", "breath"
    amplitude=5.0,
    device="cuda",
    dtype=None,
    debug=False
):
    """Generate an animated sequence of pose maps (OpenPose-like).

    Args:
        base_pose: tensor [1,3,H,W] or [1,1,H,W] (pose image)
        num_frames: number of frames to produce
        motion_type: animation kind ("idle", "sway", "zoom", "breath")
        amplitude: motion intensity in pixels
        device: target device
        dtype: target dtype (defaults to base_pose.dtype)
        debug: print summary info

    Returns:
        List[Tensor]: one pose tensor per frame
    """
    if dtype is None:
        dtype = base_pose.dtype

    base_pose = base_pose.to(device=device, dtype=dtype)
    _, _, H, W = base_pose.shape

    frames = []

    for frame_idx in range(num_frames):
        # Phase sweeps 0 -> 2*pi across the whole clip.
        progress = frame_idx / max(1, num_frames - 1)
        phase = 2 * math.pi * progress

        pose = base_pose.clone()

        if motion_type == "idle":
            # Small natural micro-movements.
            shift_x = math.sin(phase) * amplitude * 0.3
            shift_y = math.cos(phase) * amplitude * 0.2
        elif motion_type == "sway":
            # Horizontal swing only.
            shift_x = math.sin(phase) * amplitude
            shift_y = 0
        elif motion_type == "breath":
            # Subtle zoom in/out, no translation.
            zoom = 1.0 + math.sin(phase) * 0.03
            pose = F.interpolate(pose, scale_factor=zoom, mode="bilinear", align_corners=False)
            pose = F.interpolate(pose, size=(H, W))
            shift_x, shift_y = 0, 0
        elif motion_type == "zoom":
            # Light zoom plus a slow drift.
            zoom = 1.0 + math.sin(phase) * 0.05
            pose = F.interpolate(pose, scale_factor=zoom, mode="bilinear", align_corners=False)
            pose = F.interpolate(pose, size=(H, W))
            shift_x = math.sin(phase) * amplitude * 0.2
            shift_y = math.cos(phase) * amplitude * 0.2
        else:
            shift_x, shift_y = 0, 0

        # Affine translation (every mode except "breath", even when the
        # shift is zero, matching the original behaviour).
        if motion_type != "breath":
            theta = torch.tensor(
                [[1, 0, shift_x / (W / 2)],
                 [0, 1, shift_y / (H / 2)]],
                device=device, dtype=dtype
            ).unsqueeze(0)
            grid = F.affine_grid(theta, pose.size(), align_corners=False)
            pose = F.grid_sample(pose, grid, align_corners=False)

        # Safety: strip NaNs and keep values in the image range.
        pose = torch.nan_to_num(pose, 0.0)
        frames.append(pose.clamp(0, 1))

    if debug:
        print(f"[PoseSeq] frames: {num_frames}")
        print(f"[PoseSeq] type: {motion_type}")
        print(f"[PoseSeq] shape: {frames[0].shape}")

    return frames
def apply_controlnet_openpose_step_v1(
    latents,
    t,
    unet,
    controlnet,
    control_tensor,
    pos_embeds=None,
    neg_embeds=None,
    guidance_scale=1.0,
    controlnet_strength=1.0,
    device="cuda",
    debug=False
):
    """Apply one ControlNet-OpenPose guided update to the latents, with CFG.

    Unlike the scheduler-based variants, this version nudges the latents
    directly with ``latents += 0.1 * noise_pred`` (a deliberately damped,
    simplified diffusion step) and clamps the result to [-1.5, 1.5].

    Args:
        latents: [B,C,H,W]
        t: timestep
        unet: UNet model
        controlnet: ControlNet model
        control_tensor: [B,1,H,W] or [B,3,H,W] pose image
        pos_embeds: positive prompt embeddings
        neg_embeds: negative prompt embeddings (enables CFG)
        guidance_scale: CFG strength
        controlnet_strength: weight of the pose residuals
        device: target device
        debug: print diagnostics

    Returns:
        Updated latents.
    """
    # Sanitize inputs: no NaNs, pose in [0,1], 3 channels, matching dtype.
    latents = torch.nan_to_num(latents, 0.0)
    control_tensor = torch.nan_to_num(control_tensor, 0.0).clamp(0, 1)
    if control_tensor.shape[1] == 1:
        control_tensor = control_tensor.repeat(1, 3, 1, 1)
    control_tensor = control_tensor.to(device=device, dtype=latents.dtype)

    if debug:
        print(f"[ControlNet] latents: {latents.shape}")
        print(f"[ControlNet] control: {control_tensor.shape}")
        print(f"[ControlNet] timestep: {t}")

    def _guided_noise(embeds):
        # One ControlNet + UNet pass for a single set of embeddings.
        down, mid = controlnet(
            latents,
            t,
            encoder_hidden_states=embeds,
            controlnet_cond=control_tensor,
            return_dict=False
        )
        down = [d * controlnet_strength for d in down]
        mid = mid * controlnet_strength
        return unet(
            latents,
            t,
            encoder_hidden_states=embeds,
            down_block_additional_residuals=down,
            mid_block_additional_residual=mid
        ).sample

    # Positive pass first, then (optionally) the negative pass for CFG.
    noise_pos = _guided_noise(pos_embeds)

    if neg_embeds is not None:
        noise_neg = _guided_noise(neg_embeds)
        # Classifier-free guidance blend.
        noise_pred = noise_neg + guidance_scale * (noise_pos - noise_neg)
    else:
        noise_pred = noise_pos

    # Simplified, damped latent update (0.1 keeps it from exploding).
    latents = torch.clamp(latents + noise_pred * 0.1, -1.5, 1.5)

    if debug:
        print(f"[ControlNet] noise min/max: {noise_pred.min():.3f}/{noise_pred.max():.3f}")
        print(f"[ControlNet] latents min/max: {latents.min():.3f}/{latents.max():.3f}")

    return latents
def load_controlnet_openpose_local(
    local_model_path="/mnt/62G/huggingface/sd-controlnet-openpose",
    device="cuda",
    dtype=torch.float16,
    use_fp16=True,
    debug=True
):
    """Load a ControlNet-OpenPose checkpoint from a local folder.

    The folder is expected to contain ``diffusion_pytorch_model.safetensors``
    and ``config.json``. On any failure while targeting CUDA, the function
    retries once on CPU in float32 before re-raising.

    Args:
        local_model_path (str): path to the local model directory
        device (str): "cuda" or "cpu"
        dtype (torch.dtype): requested dtype (informational)
        use_fp16 (bool): load fp16 weights when running on CUDA
        debug (bool): verbose logging

    Returns:
        ControlNetModel
    """
    print(f"Chargement ControlNet OpenPose depuis dossier local : {local_model_path}")
    print(f"device : {device}")
    print(f"dtype : {dtype}")

    try:
        # fp16 only makes sense on the GPU path.
        weights_dtype = torch.float16 if (use_fp16 and device.startswith("cuda")) else torch.float32

        controlnet = ControlNetModel.from_pretrained(
            local_model_path,
            torch_dtype=weights_dtype,
            local_files_only=True
        ).to(device)

        if debug:
            n_params = sum(p.numel() for p in controlnet.parameters()) / 1e6
            print(f"🧠 ControlNet prêt")
            print(f"   params : {n_params:.1f}M")
            print(f"   dtype : {next(controlnet.parameters()).dtype}")
            print(f"   device : {next(controlnet.parameters()).device}")

        controlnet.eval()

        # Free loader scratch memory right away (important on small GPUs).
        torch.cuda.empty_cache()

        return controlnet

    except Exception as e:
        print("❌ ERREUR chargement ControlNet depuis dossier local")
        print(str(e))

        # Last-resort CPU fallback instead of crashing the whole pipeline.
        if device.startswith("cuda"):
            print("⚠ Fallback CPU...")
            return load_controlnet_openpose_local(
                local_model_path=local_model_path,
                device="cpu",
                dtype=torch.float32,
                use_fp16=False,
                debug=debug
            )

        raise e
def load_controlnet_openpose(
    device="cuda",
    dtype=torch.float16,
    model_id="lllyasviel/sd-controlnet-openpose",
    use_fp16=True,
    debug=True
):
    """Load ControlNet-OpenPose from the HuggingFace hub.

    Fixes vs. previous revision:
      * CUDA detection now uses ``device.startswith("cuda")`` so explicit
        devices like "cuda:0" also get fp16 weights and the CPU fallback,
        consistent with ``load_controlnet_openpose_local``;
      * typo in the recommended-settings banner corrected.

    Args:
        device (str): "cuda", "cuda:N" or "cpu"
        dtype (torch.dtype): requested dtype (informational)
        model_id (str): HF repository id
        use_fp16 (bool): load fp16 weights when on CUDA
        debug (bool): verbose logging

    Returns:
        ControlNetModel
    """
    print(f"Chargement ControlNet OpenPose - Paramètres recommandés :")
    print(f"guidance_scale = 5.0 → 6.0")
    print(f"controlnet_strength = 0.7 → 0.9")
    print(f"latents update factor = 0.1 ✅ (ne pas monter)")

    if debug:
        print("🔄 Chargement ControlNet OpenPose...")
        print(f"   model_id : {model_id}")
        print(f"   device : {device}")
        print(f"   dtype : {dtype}")

    try:
        # fp16 only on CUDA; startswith also matches "cuda:0", "cuda:1", ...
        load_dtype = torch.float16 if (use_fp16 and device.startswith("cuda")) else torch.float32

        controlnet = ControlNetModel.from_pretrained(
            model_id,
            torch_dtype=load_dtype
        )

        if debug:
            print("✅ Modèle chargé depuis HuggingFace")

        controlnet = controlnet.to(device)

        total_params = sum(p.numel() for p in controlnet.parameters()) / 1e6

        if debug:
            print(f"🧠 ControlNet prêt")
            print(f"   params : {total_params:.1f}M")
            print(f"   dtype : {next(controlnet.parameters()).dtype}")
            print(f"   device : {next(controlnet.parameters()).device}")

        controlnet.eval()

        # Memory safety (important for 4 GB cards): drop loader scratch now.
        torch.cuda.empty_cache()

        return controlnet

    except Exception as e:
        print("❌ ERREUR chargement ControlNet")
        print(str(e))

        # Best-effort CPU fallback rather than a hard crash.
        if device.startswith("cuda"):
            print("⚠ Fallback CPU...")
            return load_controlnet_openpose(
                device="cpu",
                dtype=torch.float32,
                use_fp16=False,
                debug=debug
            )

        raise e
def apply_controlnet_openpose_step_ultrasafe(
    latents,
    t,
    unet,
    controlnet,
    scheduler,
    pose_image,
    pos_embeds,
    neg_embeds=None,
    guidance_scale=5.0,
    controlnet_scale=0.7,
    device="cuda",
    dtype=torch.float16,
    debug=False
):
    """Crash-proof ControlNet-OpenPose denoising step.

    Every model call is wrapped in try/except; on failure the untouched
    input latents are returned so the caller's loop can continue.

    Fix vs. previous revision: ``scheduler.step`` is now called with the
    *unscaled* original latents whose batch matches ``noise_pred`` after
    the CFG merge. Previously the scaled, CFG-doubled model input was
    passed, which both mismatched the batch of the merged noise prediction
    and violated the scheduler's contract (it expects the raw sample).

    Returns latents advanced by one step, NaN-scrubbed and clamped.
    """
    import traceback

    # Keep a pristine copy to return if anything below blows up.
    latents_prev = latents.clone().to(device=device, dtype=dtype)

    # Strict device/dtype casts for every tensor that enters a model.
    latents = latents.to(device=device, dtype=dtype)
    pose_image = pose_image.to(device=device, dtype=dtype)
    pos_embeds = pos_embeds.to(device=device, dtype=dtype)
    if neg_embeds is not None:
        neg_embeds = neg_embeds.to(device=device, dtype=dtype)

    # CFG batching.
    if neg_embeds is not None:
        latent_model_input = torch.cat([latents] * 2)
        encoder_states = torch.cat([neg_embeds, pos_embeds])
        pose_input = torch.cat([pose_image] * 2)
    else:
        latent_model_input = latents
        encoder_states = pos_embeds
        pose_input = pose_image

    latent_model_input = latent_model_input.to(dtype=dtype)
    encoder_states = encoder_states.to(dtype=dtype)
    pose_input = pose_input.to(dtype=dtype)

    # Scheduler scale.
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    # ControlNet.
    try:
        down_samples, mid_sample = controlnet(
            latent_model_input,
            t,
            encoder_hidden_states=encoder_states,
            controlnet_cond=pose_input,
            return_dict=False
        )
    except Exception as e:
        if debug:
            print(f"[ControlNet ERROR] {e}")
            traceback.print_exc()
        return latents_prev

    # Normalize residuals so a single hot layer cannot dominate.
    down_samples = [torch.nan_to_num(d / (d.std() + 1e-6), nan=0.0, posinf=1.0, neginf=-1.0) for d in down_samples]
    mid_sample = torch.nan_to_num(mid_sample / (mid_sample.abs().mean() + 1e-6), nan=0.0, posinf=1.0, neginf=-1.0)

    # UNet.
    try:
        noise_pred = unet(
            latent_model_input,
            t,
            encoder_hidden_states=encoder_states,
            down_block_additional_residuals=[d * controlnet_scale for d in down_samples],
            mid_block_additional_residual=mid_sample * controlnet_scale
        ).sample
    except Exception as e:
        if debug:
            print(f"[UNet ERROR] {e}")
            traceback.print_exc()
        return latents_prev

    # CFG merge -> noise_pred is back to the original batch size.
    if neg_embeds is not None:
        noise_uncond, noise_text = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

    # BUGFIX: step on the raw (unscaled, un-doubled) latents.
    try:
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    except Exception as e:
        if debug:
            print(f"[Scheduler ERROR] {e}")
            traceback.print_exc()
        return latents_prev

    # Final safety: no NaN/inf, tight value range.
    latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=-1.0)
    latents = torch.clamp(latents, -0.85, 0.85)

    if debug:
        print(f"[ControlNet OK] t={t}, dtype={latents.dtype}, min/max={latents.min().item():.3f}/{latents.max().item():.3f}")

    return latents
def controlnet_tile_fn(
    latent_tile,
    tile_coords,
    frame_counter,
    pose_full,
    unet,
    controlnet,
    scheduler,
    cf_embeds,
    current_guidance_scale,
    controlnet_scale,
    device,
    target_dtype
):
    """Denoise a single latent tile under ControlNet pose guidance.

    Crops the matching region from the full-resolution pose map, runs one
    noised diffusion step (fp32 internally for numerical stability, model
    calls in ``target_dtype``), and returns the tile plus a tightly
    clamped delta. ``cf_embeds`` is (pos_embeds, neg_embeds_or_None).
    """
    import torch
    import torch.nn.functional as F

    x0, y0, x1, y1 = tile_coords
    scale = 8  # latent -> pixel scale factor (SD VAE)

    # Guard against negative crop origins.
    x0, y0 = max(0, x0), max(0, y0)

    # Crop the pose to this tile and resize to the tile's pixel size.
    pose_tile = pose_full[:, :, y0*scale:y1*scale, x0*scale:x1*scale]
    pose_tile = F.interpolate(
        pose_tile,
        size=(latent_tile.shape[2]*scale, latent_tile.shape[3]*scale),
        mode='bilinear',
        align_corners=False
    )

    pos_embeds = cf_embeds[0]
    neg_embeds = cf_embeds[1] if cf_embeds[1] is not None else None

    # Work in fp32 for the noise/step math (critical fix).
    tile_fp32 = latent_tile.to(device=device, dtype=torch.float32)
    pose_fp32 = pose_tile.to(device=device, dtype=torch.float32)
    pos_fp32 = pos_embeds.to(device=device, dtype=torch.float32)
    neg_fp32 = neg_embeds.to(device=device, dtype=torch.float32) if neg_embeds is not None else None

    # Clamp the frame counter into the scheduler's timestep table.
    timesteps = scheduler.timesteps
    t = timesteps[min(frame_counter, len(timesteps)-1)]

    # Add scheduler noise, clamp, and scale the model input.
    noise = torch.randn_like(tile_fp32)
    latent_noisy = torch.clamp(scheduler.add_noise(tile_fp32, noise, t), -20, 20)
    latent_model_input = scheduler.scale_model_input(latent_noisy, t)

    # CFG batching.
    if neg_fp32 is not None:
        latent_model_input = torch.cat([latent_model_input] * 2)
        embeds = torch.cat([neg_fp32, pos_fp32])
    else:
        embeds = pos_fp32

    # Models themselves run in target_dtype (usually fp16).
    latent_model_input = latent_model_input.to(target_dtype)
    embeds = embeds.to(target_dtype)
    pose_fp32 = pose_fp32.to(target_dtype)

    down_samples, mid_sample = controlnet(
        latent_model_input,
        t,
        encoder_hidden_states=embeds,
        controlnet_cond=pose_fp32,
        return_dict=False
    )

    noise_pred = unet(
        latent_model_input,
        t,
        encoder_hidden_states=embeds,
        down_block_additional_residuals=down_samples,
        mid_block_additional_residual=mid_sample,
        return_dict=False
    )[0]

    # Scrub NaN/inf before any arithmetic.
    noise_pred = torch.nan_to_num(noise_pred, nan=0.0, posinf=1.0, neginf=-1.0)

    # CFG merge.
    if neg_fp32 is not None:
        noise_uncond, noise_text = noise_pred.chunk(2)
        noise_pred = noise_uncond + current_guidance_scale * (noise_text - noise_uncond)

    # One diffusion step on the (fp32, unscaled) noisy tile.
    latents_out = scheduler.step(noise_pred, t, latent_noisy).prev_sample

    # Constrain the update to a small, finite delta.
    delta = torch.nan_to_num(latents_out - tile_fp32, nan=0.0, posinf=1.0, neginf=-1.0)
    delta = controlnet_scale * delta
    delta = torch.clamp(delta, -0.15, 0.15)

    # Minimal debug output.
    if torch.isnan(delta).any():
        print("[⚠️ NaN détecté dans delta]")
    else:
        print(f"[ControlNet OK] delta min/max: {delta.min().item():.4f}/{delta.max().item():.4f}")

    # Back to the caller's dtype (usually fp16).
    return (tile_fp32 + delta).to(dtype=target_dtype)
class N3RProNet(nn.Module):
    """Tiny residual refinement net operating directly on SD latents.

    Two 3x3 convolutions (SiLU activations) squeezed through 8 hidden
    channels, followed by a 1x1 projection back to ``channels``; the input
    is added back, so the module learns a residual correction.
    """

    def __init__(self, channels=4):
        super().__init__()

        self.conv1 = nn.Conv2d(channels, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 8, 3, padding=1)
        self.conv3 = nn.Conv2d(8, channels, 1)

        self.act = nn.SiLU()

    def forward(self, x):
        # Residual connection: output = refinement(x) + x.
        hidden = self.act(self.conv1(x))
        hidden = self.act(self.conv2(hidden))
        return self.conv3(hidden) + x
def scale_eye_coords_to_latents(eye_coords, img_H, img_W, lat_H, lat_W):
    """Map eye coordinates from image space into latent space.

    Returns None when no coordinates were detected (None or empty list).
    """
    if not eye_coords:
        return None

    sx = lat_W / img_W
    sy = lat_H / img_H
    return [(int(x * sx), int(y * sy)) for x, y in eye_coords]


def get_eye_coords_safe(image_pil, H=None, W=None):
    """Never-raising wrapper around get_eye_coords.

    Returns None (with a log line) when no face is found or when the
    detector itself fails. H/W are accepted for API compatibility only.
    """
    try:
        coords = get_eye_coords(image_pil)
    except Exception as e:
        print(f"[Eye detection ERROR] {e}")
        return None
    if coords is None:
        print("⚠️ Aucun visage détecté")
        return None
    print(f"👁 Eyes detected: {coords}")
    return coords


def get_eye_coords(image_pil):
    """Detect iris centers with MediaPipe FaceMesh.

    Args:
        image_pil (PIL.Image): input image

    Returns:
        list[(x, y)]: iris centers in image coordinates, or None when no
        face is detected.
    """
    import numpy as np
    import mediapipe as mp

    frame = np.array(image_pil.convert("RGB"))
    h, w, _ = frame.shape

    with mp.solutions.face_mesh.FaceMesh(
        static_image_mode=True,
        max_num_faces=1,
        refine_landmarks=True
    ) as face_mesh:
        results = face_mesh.process(frame)

    if not results.multi_face_landmarks:
        return None

    landmarks = results.multi_face_landmarks[0].landmark

    # Iris landmark indices (available only with refine_landmarks=True).
    LEFT_IRIS = [474, 475, 476, 477]
    RIGHT_IRIS = [469, 470, 471, 472]

    def _center(indices):
        xs = [landmarks[i].x * w for i in indices]
        ys = [landmarks[i].y * h for i in indices]
        return int(sum(xs) / len(xs)), int(sum(ys) / len(ys))

    return [_center(LEFT_IRIS), _center(RIGHT_IRIS)]
def apply_glow_froid_iris(latents, eye_coords, iris_radius_ratio=0.08, strength=0.25, blur_kernel=5):
    """Apply a targeted cold glow on the iris region of latents [B,C,H,W].

    NOTE(review): a second function with the same name is defined later in
    this module and shadows this one at import time.

    Bugfixes vs. previous revision:
      * the separable glow kernel was built from ``arange(-k//2+1, k//2+2)``,
        which yields k+1 samples for odd k and made ``view(1,1,k,1)`` raise
        a RuntimeError — it is now built with exactly ``blur_kernel``
        centered offsets;
      * the meshgrid was integer-typed, so ``torch.sqrt`` raised on Long
        tensors — the grid is now created in the latents' dtype.

    Args:
        latents (torch.Tensor): latents [B,C,H,W].
        eye_coords (list of tuples): eye centers [(x1,y1),(x2,y2)].
        iris_radius_ratio (float): iris radius as a fraction of min(H,W).
        strength (float): glow intensity (0.0 to 1.0).
        blur_kernel (int): odd kernel size for blur/glow.

    Returns:
        torch.Tensor: latents with glow blended on the irises.
    """
    B, C, H, W = latents.shape
    device, dtype = latents.device, latents.dtype

    # 1) Radial (Gaussian) mask around each eye center.
    mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype)
    iris_radius = iris_radius_ratio * min(H, W)

    # Float grid so sqrt/exp work in the latents' dtype.
    yy, xx = torch.meshgrid(
        torch.arange(H, device=device, dtype=dtype),
        torch.arange(W, device=device, dtype=dtype),
        indexing='ij'
    )
    for x_eye, y_eye in eye_coords:
        dist = torch.sqrt((xx - x_eye)**2 + (yy - y_eye)**2)
        mask += torch.exp(-(dist**2) / (2 * iris_radius**2)).unsqueeze(0)

    # Overlapping eyes must not push the mask above 1.
    mask = mask.clamp(0.0, 1.0)

    # 2) Soften the mask edges with a small box blur.
    if blur_kernel > 1:
        box = torch.ones((C, 1, blur_kernel, blur_kernel), device=device, dtype=dtype)
        box = box / box.sum()
        mask = F.conv2d(mask.repeat(1, C, 1, 1), box, padding=blur_kernel//2, groups=C)

    # 3) Separable Gaussian glow: vertical pass, then horizontal pass.
    sigma = blur_kernel / 3.0
    # Exactly blur_kernel centered offsets (fix for the off-by-one arange).
    offsets = (torch.arange(blur_kernel, device=device, dtype=dtype) - blur_kernel // 2).view(-1, 1)
    glow_kernel = torch.exp(-(offsets**2) / (2 * sigma**2))
    glow_kernel = glow_kernel / glow_kernel.sum()
    glow_kernel = glow_kernel.view(1, 1, blur_kernel, 1).repeat(C, 1, 1, 1)
    glow = F.conv2d(latents, glow_kernel, padding=(blur_kernel//2, 0), groups=C)
    glow = F.conv2d(glow, glow_kernel.transpose(2, 3), padding=(0, blur_kernel//2), groups=C)

    # 4) Blend the glow in on the iris region only, clamp for stability.
    latents_out = latents * (1 - mask) + glow * mask * strength
    return latents_out.clamp(-1.0, 1.0)
def apply_intelligent_glow_froid_latents(latents, strength=0.2, blur_kernel=7):
    """Blend a "cold glow" (Gaussian-blurred copy) into latents [B,C,H,W].

    Args:
        latents (torch.Tensor): latents [B,C,H,W].
        strength (float): glow intensity (0.0 to 1.0).
        blur_kernel (int): Gaussian kernel size (must be odd).

    Returns:
        torch.Tensor: blended, clamped latents.

    Raises:
        ValueError: if latents is not 4-D.
    """
    if latents.ndim != 4:
        raise ValueError("Latents doivent être de shape [B, C, H, W]")

    _, C, _, _ = latents.shape

    # Depthwise 2-D Gaussian kernel, one copy per channel.
    def _make_kernel(size, sigma, channels):
        ax = torch.arange(-size // 2 + 1., size // 2 + 1., device=latents.device)
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        k = torch.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
        k = k / k.sum()
        return k.view(1, 1, size, size).repeat(channels, 1, 1, 1)

    kernel = _make_kernel(blur_kernel, blur_kernel / 3.0, C).to(latents.device, latents.dtype)

    # Channel-wise blur = the glow layer.
    glow = F.conv2d(latents, kernel, padding=blur_kernel // 2, groups=C)

    # Linear blend original <-> glow, clamped for stability.
    blended = latents * (1 - strength) + glow * strength
    return blended.clamp(-1.0, 1.0)
def apply_glow_froid_iris(latents, eye_coords, iris_radius_ratio=0.08, strength=0.2, blur_kernel=7):
    """Apply a cold glow restricted to the iris region of latents [B,C,H,W].

    NOTE(review): this definition shadows the earlier function of the same
    name in this module.

    Fix vs. previous revision: the elliptical iris mask was written only
    into batch index 0, so every batch item after the first was returned
    untouched. The mask is now built once as [1,1,H,W] and broadcast over
    the whole batch (identical result for B == 1).

    Args:
        latents (torch.Tensor): SD latents [B,C,H,W]
        eye_coords (list of tuples): eye centers [(x1,y1),(x2,y2)]
        iris_radius_ratio (float): fraction of H/W used as iris radius
        strength (float): glow intensity
        blur_kernel (int): Gaussian kernel size (odd)

    Returns:
        torch.Tensor: latents with glow on the irises
    """
    B, C, H, W = latents.shape
    device, dtype = latents.device, latents.dtype

    # 1) Elliptical mask around each eye, shared by every batch item.
    iris_mask = torch.zeros((1, 1, H, W), device=device, dtype=dtype)
    Y, X = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij')
    for x, y in eye_coords:
        rx = int(W * iris_radius_ratio)
        ry = int(H * iris_radius_ratio)
        dist2 = ((X - x)**2) / (rx**2) + ((Y - y)**2) / (ry**2)
        iris_mask[0, 0] += (dist2 <= 1).float()
    iris_mask = iris_mask.clamp(0, 1)  # overlapping eyes must not exceed 1

    # 2) Depthwise 2-D Gaussian kernel.
    sigma = blur_kernel / 3
    ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device)
    xx, yy = torch.meshgrid(ax, ax, indexing='ij')
    kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
    kernel_2d = kernel_2d / kernel_2d.sum()
    kernel = kernel_2d.view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1)  # [C,1,kH,kW]

    # 3) Channel-wise blur of the masked latents -> glow layer.
    glow = F.conv2d(latents * iris_mask, kernel, padding=blur_kernel // 2, groups=C)

    # 4) Blend the glow back in on the iris region only.
    latents_out = latents * (1 - iris_mask) + glow * iris_mask * strength
    return latents_out.clamp(-1.0, 1.0)
def apply_pro_net_volumetrique(
    latents,
    coords_v,
    n3r_pro_net,
    n3r_pro_strength,
    sanitize_fn,
    glow_strength=0.2,
    blur_kernel=3,          # smaller kernel -> sharper detail
    iris_radius_ratio=0.08,
    mask_blur_kernel=1,     # very light mask feathering
    debug=False
):
    """HD volumetric ProNet pass plus a sharpened iris glow.

    Runs the N3R ProNet refinement over the latents, then re-injects the
    high-frequency component (latents minus a small Gaussian blur) only
    inside a feathered iris mask. Falls back to a plain ProNet pass when
    no eye coordinates are available.
    """
    import torch
    import torch.nn.functional as F

    # No detected eyes -> plain ProNet pass.
    if not coords_v:
        return apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn)

    B, C, H, W = latents.shape
    device, dtype = latents.device, latents.dtype

    # 1) ProNet refinement (cast back to the incoming dtype).
    refined = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn).to(dtype)

    # 2) Elliptical iris mask.
    iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype)
    Y, X = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing='ij'
    )
    for x, y in coords_v:
        rx = max(1, int(W * iris_radius_ratio))
        ry = max(1, int(H * iris_radius_ratio))
        dist2 = ((X - x)**2)/(rx**2 + 1e-6) + ((Y - y)**2)/(ry**2 + 1e-6)
        iris_mask[0, 0] += (dist2 <= 1).float()
    iris_mask = iris_mask.clamp(0, 1)

    # 3) Feather only the mask (keeps the detail pass sharp).
    if mask_blur_kernel > 1:
        sigma = mask_blur_kernel / 3
        ax = torch.arange(-mask_blur_kernel//2 + 1., mask_blur_kernel//2 + 1., device=device, dtype=dtype)
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        feather = torch.exp(-(xx**2 + yy**2)/(2*sigma**2))
        feather = (feather / feather.sum()).view(1, 1, mask_blur_kernel, mask_blur_kernel)
        iris_mask = F.conv2d(iris_mask, feather, padding=mask_blur_kernel//2).clamp(0, 1)

    # 4) High-frequency component of the refined latents.
    if blur_kernel > 1:
        sigma = blur_kernel / 3
        ax = torch.arange(-blur_kernel//2 + 1., blur_kernel//2 + 1., device=device, dtype=dtype)
        xx, yy = torch.meshgrid(ax, ax, indexing='ij')
        gauss = torch.exp(-(xx**2 + yy**2)/(2*sigma**2))
        gauss = (gauss / gauss.sum()).view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1).to(dtype)
        blurred = F.conv2d(refined, gauss, padding=blur_kernel//2, groups=C)
        high_freq = refined - blurred
    else:
        high_freq = refined - refined  # kernel 1 -> no detail layer

    # 5) Add the detail glow inside the iris only.
    latents_out = (refined + glow_strength * high_freq * iris_mask).clamp(-1.0, 1.0)

    # 6) Optional visual debug.
    if debug:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 3, 1); plt.imshow(refined[0, 0].detach().cpu(), cmap='gray'); plt.title("ProNet")
        plt.subplot(1, 3, 2); plt.imshow(high_freq[0, 0].detach().cpu(), cmap='gray'); plt.title("High-Freq")
        plt.subplot(1, 3, 3); plt.imshow(iris_mask[0, 0].detach().cpu(), cmap='Reds', alpha=0.5); plt.title("Mask Iris")
        plt.tight_layout(); plt.show()
        print("👁 DEBUG HD sharp appliqué")

    return latents_out
def apply_pro_net_volumetrique_good(
    latents,
    coords_v,
    n3r_pro_net,
    n3r_pro_strength,
    sanitize_fn,
    glow_strength=0.2,
    blur_kernel=7,
    iris_radius_ratio=0.08,
    debug=False
):
    """ProNet pass plus an HDR/detail glow on the iris region.

    FP16-friendly variant: the Gaussian kernel is created in the latents'
    dtype. Falls back to a plain ProNet pass when no eyes were detected.

    Args:
        latents (torch.Tensor): [B,C,H,W] latents to process.
        coords_v (list of tuples): eye centers [(x1,y1),(x2,y2)].
        n3r_pro_net: ProNet model
        n3r_pro_strength (float): ProNet blend strength
        sanitize_fn: latent cleanup function
        glow_strength (float): glow / amplification intensity
        blur_kernel (int): Gaussian kernel size
        iris_radius_ratio (float): iris radius as a fraction of H/W
        debug (bool): show mask + latents via matplotlib

    Returns:
        torch.Tensor: latents with the HDR effect on the irises only.
    """
    if not coords_v:
        # No eyes detected -> ProNet alone.
        return apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn)

    B, C, H, W = latents.shape
    device, dtype = latents.device, latents.dtype

    # 1) ProNet refinement.
    latents_prot = apply_n3r_pro_net(
        latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn
    )

    # 2) Elliptical iris mask.
    iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype)
    Y, X = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing='ij'
    )
    for x, y in coords_v:
        rx = int(W * iris_radius_ratio)
        ry = int(H * iris_radius_ratio)
        dist2 = ((X - x)**2)/(rx**2) + ((Y - y)**2)/(ry**2)
        iris_mask[0, 0] += (dist2 <= 1).float()
    iris_mask = iris_mask.clamp(0, 1)

    # 3) Gaussian kernel in the latents' dtype (FP16 ok).
    sigma = blur_kernel / 3
    ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device, dtype=dtype)
    xx, yy = torch.meshgrid(ax, ax, indexing='ij')
    kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
    kernel_2d = kernel_2d / kernel_2d.sum()
    kernel = kernel_2d.view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1)

    # 4) Channel-wise convolution -> amplified iris detail layer.
    glow = F.conv2d(latents_prot * iris_mask, kernel, padding=blur_kernel // 2, groups=C)

    # 5) Fuse ProNet output with the iris glow, clamp for stability.
    latents_out = latents_prot * (1 - iris_mask) + glow * iris_mask * glow_strength
    latents_out = latents_out.clamp(-1.0, 1.0)

    # ---------------- DEBUG ----------------
    if debug:
        lat_vis = latents[0, 0].detach().cpu()
        prot_vis = latents_prot[0, 0].detach().cpu()
        glow_vis = glow[0, 0].detach().cpu()
        mask_vis = iris_mask[0, 0].detach().cpu()

        plt.figure(figsize=(12, 4))
        plt.subplot(1, 4, 1); plt.imshow(lat_vis, cmap='gray'); plt.title("Latent original")
        plt.subplot(1, 4, 2); plt.imshow(prot_vis, cmap='gray'); plt.title("ProNet")
        plt.subplot(1, 4, 3); plt.imshow(glow_vis, cmap='gray'); plt.title("HDR / Glow Iris")
        plt.subplot(1, 4, 4)
        plt.imshow(lat_vis, cmap='gray', alpha=0.7)
        plt.imshow(mask_vis, cmap='Reds', alpha=0.4)
        plt.title("Mask Iris")
        plt.tight_layout()
        plt.show()
        print("👁 DEBUG activé → vérifie position / taille iris")

    return latents_out
---------------- + if debug: + lat_vis = latents[0, 0].detach().cpu() + prot_vis = latents_prot[0, 0].detach().cpu() + glow_vis = glow[0, 0].detach().cpu() + mask_vis = iris_mask[0, 0].detach().cpu() + + plt.figure(figsize=(12, 4)) + plt.subplot(1, 4, 1) + plt.imshow(lat_vis, cmap='gray') + plt.title("Latent original") + plt.subplot(1, 4, 2) + plt.imshow(prot_vis, cmap='gray') + plt.title("ProNet") + plt.subplot(1, 4, 3) + plt.imshow(glow_vis, cmap='gray') + plt.title("HDR / Glow Iris") + plt.subplot(1, 4, 4) + plt.imshow(lat_vis, cmap='gray', alpha=0.7) + plt.imshow(mask_vis, cmap='Reds', alpha=0.4) + plt.title("Mask Iris") + plt.tight_layout() + plt.show() + print("👁 DEBUG activé → vérifie position / taille iris") + + return latents_out + +#----- Amplification des détails des yeux + +def apply_pro_net_with_eyes( + latents, + eye_coords, + n3r_pro_net, + n3r_pro_strength, + sanitize_fn, + detail_strength=0.35, # intensité HDR + blur_kernel=5, # kernel pour détails + iris_radius_ratio=0.06, # plus petit = cible mieux iris + mask_blur_kernel=3 # flou du masque pour adoucir les contours +): + """ + ProNet optimisé + amplification HDR des détails sur l’iris (pas glow) + avec fusion douce pour éviter halo sur les contours. 
+ """ + + import torch + import torch.nn.functional as F + + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + # 1️⃣ Appliquer ProNet standard + latents_prot = apply_n3r_pro_net( + latents, + model=n3r_pro_net, + strength=n3r_pro_strength, + sanitize_fn=sanitize_fn + ).to(dtype) + + # 2️⃣ Si pas d’yeux → fallback + if not eye_coords: + return latents_prot + + # 3️⃣ Création masque IRIS + iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + Y, X = torch.meshgrid( + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + + for x, y in eye_coords: + rx = int(W * iris_radius_ratio) + ry = int(H * iris_radius_ratio) + dist = ((X - x)**2) / (rx**2 + 1e-6) + ((Y - y)**2) / (ry**2 + 1e-6) + iris_mask[0, 0] += (dist <= 1).float() + + iris_mask = iris_mask.clamp(0, 1) + + # 4️⃣ Flouter le masque pour adoucir les contours + if mask_blur_kernel > 1: + sigma = mask_blur_kernel / 3 + ax = torch.arange(-mask_blur_kernel // 2 + 1., mask_blur_kernel // 2 + 1., device=device, dtype=dtype) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + mask_kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + mask_kernel_2d = mask_kernel_2d / mask_kernel_2d.sum() + mask_kernel = mask_kernel_2d.view(1, 1, mask_blur_kernel, mask_blur_kernel) + iris_mask = F.conv2d(iris_mask, mask_kernel, padding=mask_blur_kernel // 2) + iris_mask = iris_mask.clamp(0, 1) + + # 5️⃣ Blur pour récupérer les détails (high-frequency) + sigma = blur_kernel / 3 + ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device, dtype=dtype) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + kernel_2d = kernel_2d / kernel_2d.sum() + kernel = kernel_2d.view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1).to(dtype) + blurred = F.conv2d(latents_prot, kernel, padding=blur_kernel // 2, groups=C) + details = latents_prot - blurred + + # 6️⃣ Amplification HDR adaptative selon 
le masque flou + detail_strength_map = detail_strength * iris_mask + enhanced = latents_prot + details * detail_strength_map + + # 7️⃣ Fusion douce + latents_out = latents_prot * (1 - iris_mask) + enhanced * iris_mask + + # 8️⃣ Clamp final pour sécurité + latents_out = torch.clamp(latents_out, -1.0, 1.0) + + print("👁 HDR détails appliqué sur iris avec contours adoucis") + + return latents_out + +#------------ Stable version mais un peu fort ---- +def apply_pro_net_with_eyes_boost( + latents, + eye_coords, + n3r_pro_net, + n3r_pro_strength, + sanitize_fn, + detail_strength=0.35, # intensité HDR + blur_kernel=5, # plus petit = plus précis + iris_radius_ratio=0.06 # plus petit = cible mieux iris +): + """ + ProNet + amplification HDR des détails sur l’iris (pas glow). + """ + + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + # 1️⃣ ProNet + latents_prot = apply_n3r_pro_net( + latents, + model=n3r_pro_net, + strength=n3r_pro_strength, + sanitize_fn=sanitize_fn + ) + + # 🔒 sécurité dtype (évite ton erreur Half/Float) + latents_prot = latents_prot.to(dtype) + + # 2️⃣ Si pas d’yeux → fallback + if not eye_coords: + return latents_prot + + # 3️⃣ Création masque IRIS (ellipse fine) + iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + + Y, X = torch.meshgrid( + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + + for x, y in eye_coords: + rx = int(W * iris_radius_ratio) + ry = int(H * iris_radius_ratio) + + dist = ((X - x)**2) / (rx**2 + 1e-6) + ((Y - y)**2) / (ry**2 + 1e-6) + iris_mask[0, 0] += (dist <= 1).float() + + iris_mask = iris_mask.clamp(0, 1) + + # 4️⃣ Kernel GAUSSIEN (corrigé) + sigma = blur_kernel / 3 + + ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device, dtype=dtype) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + + kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + kernel_2d = kernel_2d / kernel_2d.sum() + + kernel = kernel_2d.view(1, 1, 
blur_kernel, blur_kernel).repeat(C, 1, 1, 1) + + # 🔒 même dtype que latents + kernel = kernel.to(dtype) + + # 5️⃣ Blur = base low-frequency + blurred = F.conv2d( + latents_prot, + kernel, + padding=blur_kernel // 2, + groups=C + ) + + # 6️⃣ Détails (high-frequency) + details = latents_prot - blurred + + # 7️⃣ Amplification HDR + enhanced = latents_prot + detail_strength * details + + # 8️⃣ Fusion UNIQUEMENT sur iris + latents_out = latents_prot * (1 - iris_mask) + enhanced * iris_mask + + # 9️⃣ Clamp sécurité + latents_out = latents_out.clamp(-1.0, 1.0) + + print("👁 HDR détails appliqué sur iris") + + return latents_out + +def apply_pro_net_with_eyes_test(latents, eye_coords, n3r_pro_net, n3r_pro_strength, sanitize_fn, + glow_strength=0.2, blur_kernel=7, iris_radius_ratio=0.08): + """ + Applique ProNet et un glow froid uniquement sur l’iris des yeux. + + Args: + latents (torch.Tensor): [B,C,H,W] Latents à traiter. + eye_coords (list of tuples): Coordonnées yeux [(x1,y1),(x2,y2)] + n3r_pro_net: modèle ProNet + n3r_pro_strength (float): force ProNet + sanitize_fn: fonction de nettoyage latents + glow_strength (float): intensité du glow + blur_kernel (int): kernel pour flou gaussien + iris_radius_ratio (float): proportion de H/W pour rayon iris + + Returns: + torch.Tensor: latents avec glow sur iris uniquement + """ + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + # 1️⃣ Application ProNet + latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn) + + # 2️⃣ Glow froid uniquement sur l’iris + if eye_coords: + iris_mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + for x, y in eye_coords: + rx = int(W * iris_radius_ratio) + ry = int(H * iris_radius_ratio) + Y, X = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij') + dist2 = ((X - x)**2) / (rx**2) + ((Y - y)**2) / (ry**2) + iris_mask[0, 0] += (dist2 <= 1).float() + iris_mask = 
iris_mask.clamp(0, 1) + + # Kernel gaussien 2D + sigma = blur_kernel / 3 + ax = torch.arange(-blur_kernel // 2 + 1., blur_kernel // 2 + 1., device=device) + xx, yy = torch.meshgrid(ax, ax, indexing='ij') + kernel_2d = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2)) + kernel_2d = kernel_2d / kernel_2d.sum() + kernel = kernel_2d.view(1, 1, blur_kernel, blur_kernel).repeat(C, 1, 1, 1) + + # Convolution channel-wise pour glow + glow = F.conv2d(latents_prot * iris_mask, kernel, padding=blur_kernel // 2, groups=C) + + # Fusion uniquement sur l’iris + latents_out = latents_prot * (1 - iris_mask) + glow * iris_mask * glow_strength + latents_out = latents_out.clamp(-1.0, 1.0) + print("👁 Glow froid appliqué sur iris uniquement") + else: + # fallback si pas d’yeux détectés + latents_out = latents_prot + + return latents_out + + +def apply_pro_net_with_eye_glow(latents, eye_coords, n3r_pro_net, n3r_pro_strength, sanitize_fn, glow_strength=0.2, blur_kernel=7): + """ + Applique ProNet et un glow froid uniquement sur les yeux. + + Args: + latents (torch.Tensor): [B,C,H,W] Latents à traiter. 
+ eye_coords (list of tuples): Coordonnées yeux [(x1,y1),(x2,y2)] + n3r_pro_net: modèle ProNet + n3r_pro_strength (float): force ProNet + sanitize_fn: fonction de nettoyage latents + glow_strength (float): intensité du glow + blur_kernel (int): kernel pour le flou + + Returns: + torch.Tensor: latents avec glow sur yeux uniquement + """ + # 1️⃣ Appliquer ProNet + latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn) + + # 2️⃣ Glow froid sur latents ProNet + glow_latents = apply_intelligent_glow_froid_latents(latents_prot, strength=glow_strength, blur_kernel=blur_kernel) + + + # 3️⃣ Fusion glow uniquement sur les yeux + if eye_coords: + eye_radius = int(min(latents.shape[-2:]) * 0.15) + eye_mask = create_eye_mask(latents, eye_coords, eye_radius) + if eye_mask is not None: + eye_mask = eye_mask.to(latents.device).float() + if eye_mask.ndim == 3: # [B,H,W] -> [B,1,H,W] + eye_mask = eye_mask.unsqueeze(1) + latents = latents * (1 - eye_mask) + glow_latents * eye_mask + print("👁 Glow froid appliqué uniquement sur yeux") + else: + latents = glow_latents # fallback + else: + latents = glow_latents # pas d’yeux détectés → glow global + + return latents + +# Application effect en dehors de yeux: +def apply_pro_net_with_out_eyes(latents, eye_coords, n3r_pro_net, n3r_pro_strength, sanitize_fn): + # 1️⃣ Application du ProNet + latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn) + + # 2️⃣ Application du glow froid intelligent en dehors des yeux sur le ProNet + latents_prot = apply_intelligent_glow_froid_out(latents_prot) + + # 3️⃣ Fusion avec le masque yeux si détecté + if eye_coords: + print("Eye coords:", eye_coords) + eye_radius = int(min(latents.shape[-2:]) * 0.15) # augmenter légèrement pour protection + eye_mask = create_eye_mask(latents, eye_coords, eye_radius) + + if eye_mask is not None: + eye_mask = eye_mask.to(latents.device) + # protection yeux + 
ProNet + glow + latents = latents * eye_mask + latents_prot * (1 - eye_mask) + print("👁 protection yeux appliquée avec glow froid") + else: + # si le masque échoue, on applique ProNet + glow sur tout + latents = latents_prot + else: + # pas d’yeux détectés → ProNet + glow global + latents = latents_prot + + return latents + + +def apply_pro_net_with_eye_simple(latents, eye_coords, n3r_pro_net, n3r_pro_strength, sanitize_fn): + latents_prot = apply_n3r_pro_net(latents, model=n3r_pro_net, strength=n3r_pro_strength, sanitize_fn=sanitize_fn) + if eye_coords: + print("Eye coords:", eye_coords) + eye_radius = int(min(latents.shape[-2:]) * 0.15) # un peu plus large valeur 0.12 ou 0.15 + eye_mask = create_eye_mask(latents, eye_coords, eye_radius) + if eye_mask is not None: + eye_mask = eye_mask.to(latents.device) + latents = latents * eye_mask + latents_prot * (1 - eye_mask) + print("👁 protection yeux appliquée (main frames)") + else: + latents = latents_prot + else: + latents = latents_prot + return latents + +def tensor_to_pil(tensor): + """ + tensor: [1,3,H,W] ou [3,H,W] dans [-1,1] + """ + if tensor.dim() == 4: + tensor = tensor[0] + tensor = (tensor.clamp(-1,1) + 1) / 2 + return to_pil_image(tensor.cpu()) + +try: + import mediapipe as mp + from mediapipe.python.solutions import face_mesh as mp_face_mesh + MP_FACE_MESH = mp_face_mesh +except Exception: + MP_FACE_MESH = None + print("⚠ mediapipe non disponible → fallback sans yeux") + + +def get_coords_safe(image, H, W): + coords = get_coords(image) + + if coords: + print(f"👁 Eyes detected: {coords}") + return coords + + print("⚠ fallback eye coords used") + + # 🔥 adapté portrait vertical (ton cas 536x960) + return [ + (int(H * 0.32), int(W * 0.38)), + (int(H * 0.32), int(W * 0.62)) + ] + +# -------------------------------------------------- +# 🔥 Détection yeux (version clean sans cv2) +# -------------------------------------------------- +def get_coords(image): + """ + Retourne [(y_left, x_left), (y_right, x_right)] + 
Compatible PIL ou numpy + """ + if MP_FACE_MESH is None: + return [] + + # Conversion propre + if isinstance(image, Image.Image): + img = np.array(image) + else: + img = image + + if img is None or img.ndim != 3: + return [] + + h, w, _ = img.shape + + with MP_FACE_MESH.FaceMesh(static_image_mode=True, max_num_faces=1) as face_mesh: + results = face_mesh.process(img) # ✅ déjà RGB → pas besoin de cv2 + + if not results.multi_face_landmarks: + return [] + + lm = results.multi_face_landmarks[0].landmark + + # 🔥 Points clés yeux (stables) + left_eye_pts = [33, 133] + right_eye_pts = [362, 263] + + left_eye = np.mean([(lm[i].y * h, lm[i].x * w) for i in left_eye_pts], axis=0) + right_eye = np.mean([(lm[i].y * h, lm[i].x * w) for i in right_eye_pts], axis=0) + + return [ + (int(left_eye[0]), int(left_eye[1])), + (int(right_eye[0]), int(right_eye[1])) + ] + +# -------------------------------------------------- +# 🔥 Création mask yeux (latents) +# -------------------------------------------------- +import torch +import matplotlib.pyplot as plt + + +def create_volumetrique_mask(latents, coords, radius_ratio=0.15, only=False, in_radius_ratio=0.08, debug=False): + """ + Crée un masque pour les yeux ou uniquement pour l’iris. 
+ + Args: + latents (torch.Tensor): [B,C,H,W] Latents + coords (list of tuples): [(x1,y1),(x2,y2)] coordonnées yeux + radius_ratio (float): proportion H/W pour rayon + only (bool): True → masque uniquement iris, False → masque œil entier + in_radius_ratio (float): proportion H/W pour rayon iris si only=True + debug (bool): Si True, affiche le masque + + Returns: + torch.Tensor: [B,1,H,W] masque float (0=hors masque, 1=masque) + """ + #if not coords or latents.ndim != 4: + if coords is None or len(coords) == 0 or latents.ndim != 4: + return None + + B, C, H, W = latents.shape + device, dtype = latents.device, latents.dtype + + mask = torch.zeros((B, 1, H, W), device=device, dtype=dtype) + + for x, y in coords: + r = int(min(H, W) * (radius_ratio if only else in_radius_ratio)) + Y, X = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij') + dist2 = (X - x)**2 + (Y - y)**2 + mask[0, 0] += (dist2 <= r**2).float() + + mask = mask.clamp(0, 1) + + if debug: + # Affiche le masque superposé à un latents converti en image pour vérification + lat_vis = latents[0, 0].detach().cpu() # canal 0 + plt.figure(figsize=(6,6)) + plt.imshow(lat_vis, cmap='gray', alpha=0.7) + plt.imshow(mask[0,0].cpu(), cmap='Reds', alpha=0.3) + plt.title("Debug Eye/Iris Mask") + plt.show() + + return mask + +def create_eye_mask(latents, eye_coords, eye_radius=8, falloff=4): + """ + Soft mask gaussien → transitions naturelles + """ + if eye_coords is None or len(eye_coords) == 0: + return None + + B, C, H, W = latents.shape + mask = torch.zeros((1, 1, H, W), device=latents.device) + + for y_c, x_c in eye_coords: + y_lat = int(y_c / 8) + x_lat = int(x_c / 8) + + for y in range(H): + for x in range(W): + dist = ((y - y_lat)**2 + (x - x_lat)**2)**0.5 + value = max(0, 1 - dist / (eye_radius + falloff)) + mask[0, 0, y, x] = torch.maximum(mask[0, 0, y, x], torch.tensor(value, device=latents.device)) + + return mask.repeat(B, C, 1, 1) + +def detect_eyes_auto(frame_pil): 
+ """Retourne les coordonnées (y,x) approximatives des yeux""" + img = np.array(frame_pil) + h, w, _ = img.shape + with MP_FACE_MESH.FaceMesh(static_image_mode=True, max_num_faces=1) as face_mesh: + results = face_mesh.process(cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + if not results.multi_face_landmarks: + return [] + lm = results.multi_face_landmarks[0].landmark + left_eye = np.mean([(lm[i].y*h, lm[i].x*w) for i in [33, 133]], axis=0) + right_eye = np.mean([(lm[i].y*h, lm[i].x*w) for i in [362, 263]], axis=0) + return [(int(left_eye[0]), int(left_eye[1])), (int(right_eye[0]), int(right_eye[1]))] + +# Decode avec blending optimise : +# +# --------------------------------------------------------------------------------------------- +def decode_latents_ultrasafe_blockwise_ultranatural( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.015, + sharpen_edges_strength=0.02, + gamma_boost=1.03 # légèrement plus de punch naturel +): + import torch + import torch.nn.functional as F + from torchvision.transforms.functional import to_pil_image + + vae.eval() # pas besoin de caster tout le VAE + B, C, H, W = latents.shape + + # ⚡ latents en float16 pour réduire VRAM, multiplication par scale + latents = latents.to(device=device, dtype=torch.float16) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device, dtype=torch.float32) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device, dtype=torch.float32) + wx = torch.hann_window(w, device=device, dtype=torch.float32) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = 
torch.linspace(0, 1, h, device=device, dtype=torch.float32) + x = torch.linspace(0, 1, w, device=device, dtype=torch.float32) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode patch par patch ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + # ⚡ Convertir temporairement patch en float32 pour compatibilité VAE + decoded = vae.decode(patch.to(torch.float32)).sample + decoded = decoded.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + feather = create_feather(fh, fw).unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ⚡ Libération VRAM patch + del patch, decoded, feather + torch.cuda.empty_cache() + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + # ---------------- Sharp adaptatif ---------------- + if sharpen_mode is not None: + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + local_std = detail.std(dim=[2,3], keepdim=True) + 1e-6 + adapt_strength = sharpen_strength / (1 + 5*(1-local_std)) + output_rgb = output_rgb + adapt_strength * torch.tanh(detail) + + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + kernel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], device=device, dtype=output_rgb.dtype) + kernel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], device=device, dtype=output_rgb.dtype) + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, 
padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + edge_mask = torch.sigmoid(6.0 * (edges - 0.7)) + output_rgb = output_rgb + sharpen_edges_strength * edges * edge_mask + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- Gamma adaptatif ---------------- + output_rgb_gamma = ((output_rgb + 1) / 2.0).clamp(0,1) + luminance = output_rgb_gamma.mean(dim=1, keepdim=True) + adapt_gamma = gamma_boost * (1.0 + 0.1*(0.5-luminance)) + output_rgb_gamma = output_rgb_gamma ** adapt_gamma + output_rgb = output_rgb_gamma * 2 - 1 + + # ---------------- Micro-boost couleur ---------------- + mean_c = output_rgb.mean(dim=[2,3], keepdim=True) + color_boost = torch.sigmoid(5.0*(output_rgb - mean_c)) * 0.03 + output_rgb = (output_rgb + color_boost).clamp(-1.0, 1.0) + + # ---------------- Conversion PIL ---------------- + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + +def decode_latents_ultrasafe_blockwise_ultranatural_optimized( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", + sharpen_strength=0.015, + sharpen_edges_strength=0.02, + gamma_boost=1.03 +): + import torch + import torch.nn.functional as F + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + # accumulation directement sur CPU + output_rgb = torch.zeros(B, 3, out_H, out_W, dtype=torch.float32, device="cpu") + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # Feather patch + def 
create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + # Decode sur GPU + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + # feather + fh, fw = decoded.shape[2], decoded.shape[3] + feather = create_feather(fh, fw).unsqueeze(0).unsqueeze(0) + + # Move decoded sur CPU immédiatement + decoded_cpu = (decoded * feather).to("cpu") + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded_cpu + weight[:, :, iy0:iy1, ix0:ix1] += feather.to("cpu") + + # Libération VRAM + del patch, decoded, feather, decoded_cpu + torch.cuda.empty_cache() + + # Normalisation + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + del weight + torch.cuda.empty_cache() + + # 🔥 Sharpen adaptatif (CPU) + if sharpen_mode is not None: + output_rgb = output_rgb.clone() # pour sécurité + + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + local_std = detail.std(dim=[2,3], keepdim=True) + 1e-6 + adapt_strength = sharpen_strength / (1 + 5*(1-local_std)) + output_rgb += adapt_strength * torch.tanh(detail) + + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + kernel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], dtype=torch.float32) + kernel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], dtype=torch.float32) + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = 
kernel_y.view(1,1,3,3).repeat(C,1,1,1) + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + edge_mask = torch.sigmoid(6.0 * (edges - 0.7)) + output_rgb += sharpen_edges_strength * edges * edge_mask + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # Gamma adaptatif (CPU) + output_rgb_gamma = ((output_rgb + 1) / 2.0).clamp(0,1) + luminance = output_rgb_gamma.mean(dim=1, keepdim=True) + adapt_gamma = gamma_boost * (1.0 + 0.1*(0.5-luminance)) + output_rgb_gamma = output_rgb_gamma ** adapt_gamma + output_rgb = output_rgb_gamma * 2 - 1 + del output_rgb_gamma, luminance, adapt_gamma + + # Micro-boost couleur + mean_c = output_rgb.mean(dim=[2,3], keepdim=True) + color_boost = torch.sigmoid(5.0*(output_rgb - mean_c)) * 0.03 + output_rgb = (output_rgb + color_boost).clamp(-1.0, 1.0) + del mean_c, color_boost + + # Conversion PIL frame par frame + frames = [to_pil_image((output_rgb[i]+1)/2) for i in range(B)] + del output_rgb + torch.cuda.empty_cache() + return frames[0] if B==1 else frames + +def decode_latents_ultrasafe_blockwise_ultranatural_stable( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.015, + sharpen_edges_strength=0.02, + gamma_boost=1.03 # légèrement plus de punch naturel +): + import torch + import torch.nn.functional as F + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device, dtype=torch.float32) + weight = torch.zeros_like(output_rgb) + + stride = 
block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ⚡ Libération VRAM patch + del patch, decoded, feather + torch.cuda.empty_cache() + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + del weight + torch.cuda.empty_cache() + + # ========================================================= + # 🔥 SHARPEN ADAPTATIF + # ========================================================= + if sharpen_mode is not None: + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + local_std = detail.std(dim=[2,3], keepdim=True) + 1e-6 + adapt_strength = sharpen_strength / (1 + 5*(1-local_std)) + output_rgb = output_rgb + adapt_strength * torch.tanh(detail) + + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + 
kernel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], device=device, dtype=output_rgb.dtype) + kernel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], device=device, dtype=output_rgb.dtype) + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + edge_mask = torch.sigmoid(6.0 * (edges - 0.7)) + output_rgb = output_rgb + sharpen_edges_strength * edges * edge_mask + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- Gamma adaptatif ---------------- + output_rgb_gamma = ((output_rgb + 1) / 2.0).clamp(0,1) + luminance = output_rgb_gamma.mean(dim=1, keepdim=True) + adapt_gamma = gamma_boost * (1.0 + 0.1*(0.5-luminance)) + output_rgb_gamma = output_rgb_gamma ** adapt_gamma + output_rgb = output_rgb_gamma * 2 - 1 + del output_rgb_gamma, luminance, adapt_gamma + torch.cuda.empty_cache() + + # ---------------- Micro-boost couleur ---------------- + mean_c = output_rgb.mean(dim=[2,3], keepdim=True) + color_boost = torch.sigmoid(5.0*(output_rgb - mean_c)) * 0.03 + output_rgb = (output_rgb + color_boost).clamp(-1.0, 1.0) + del mean_c, color_boost + torch.cuda.empty_cache() + + # ---------------- To PIL ---------------- + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + del output_rgb + torch.cuda.empty_cache() + return frames[0] if B == 1 else frames + + +def decode_latents_ultrasafe_blockwise_natural( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.02, + sharpen_edges_strength=0.02, + gamma_boost=1.10 # 12% plus de punch naturel +): + import torch + import torch.nn.functional as F + from 
torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + # ========================================================= + # 🔥 SHARPEN SECTION ADAPTATIVE + # ========================================================= + if sharpen_mode is not None: + + # ---- 1. 
Tanh sharpen (détails globaux adaptatifs) + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + # facteur adaptatif selon contraste local + local_std = detail.std(dim=[2,3], keepdim=True) + 1e-6 + adapt_strength = sharpen_strength / (1 + 5*(1-local_std)) + output_rgb = output_rgb + adapt_strength * torch.tanh(detail) + + # ---- 2. Edge sharpen adaptatif + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + + kernel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], device=device, dtype=output_rgb.dtype) + kernel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], device=device, dtype=output_rgb.dtype) + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + edge_mask = torch.sigmoid(6.0 * (edges - 0.7)) + output_rgb = output_rgb + sharpen_edges_strength * edges * edge_mask + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- Gamma adaptatif ---------------- + output_rgb_gamma = ((output_rgb + 1) / 2.0).clamp(0,1) # [0,1] + output_rgb_gamma = output_rgb_gamma ** gamma_boost + output_rgb = output_rgb_gamma * 2 - 1 + + # ---------------- To PIL ---------------- + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +def decode_latents_ultrasafe_blockwise_sharp( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.04, + sharpen_edges_strength=0.05 +): + import torch + import torch.nn.functional as F + from torchvision.transforms.functional import to_pil_image + + vae = 
vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + # ========================================================= + # 🔥 SHARPEN SECTION + # ========================================================= + + if sharpen_mode is not None: + + # ---- 1. 
Tanh sharpen (détails globaux) + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + output_rgb = output_rgb + sharpen_strength * torch.tanh(detail) + + # ---- Edge sharpen PRO (anti plastique) + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + + kernel_x = torch.tensor( + [[-1,0,1],[-2,0,2],[-1,0,1]], + device=device, + dtype=output_rgb.dtype + ) + + kernel_y = torch.tensor( + [[-1,-2,-1],[0,0,0],[1,2,1]], + device=device, + dtype=output_rgb.dtype + ) + + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + + # 🔥 NORMALISATION douce (pas globale) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + + # 🔥 MASQUE BEAUCOUP plus sélectif (clé) + edge_mask = torch.sigmoid(6.0 * (edges - 0.7)) + + # 🔥 DIRECTION du contraste (pas ajout brut) + sign = torch.sign(output_rgb) + + output_rgb = output_rgb + sharpen_edges_strength * edge_mask * sign * edges * 0.5 + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- To PIL ---------------- + # Ajouter gamma boost ici + gamma = 1.10 + output_rgb_gamma = ((output_rgb + 1.0) / 2.0).clamp(0,1) + output_rgb_gamma = output_rgb_gamma ** gamma + output_rgb_gamma = output_rgb_gamma * 2.0 - 1.0 + output_rgb = output_rgb_gamma + + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +def decode_latents_ultrasafe_blockwise_plastique( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True, + sharpen_mode="both", # None, "tanh", "edges", "both" + sharpen_strength=0.04, + sharpen_edges_strength=0.05 +): + import torch + import torch.nn.functional as F + from 
torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # ---------------- Feather ---------------- + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + # ---------------- Decode ---------------- + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # ---------------- Normalisation ---------------- + weight = torch.clamp(weight, min=1e-3) + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + # ========================================================= + # 🔥 SHARPEN SECTION + # ========================================================= + + if sharpen_mode is not None: + + # ---- 1. 
Tanh sharpen (détails globaux) + if sharpen_mode in ["tanh", "both"]: + mean = output_rgb.mean(dim=[2,3], keepdim=True) + detail = output_rgb - mean + output_rgb = output_rgb + sharpen_strength * torch.tanh(detail) + + # ---- 2. Edge sharpen (version PRO stable) + if sharpen_mode in ["edges", "both"]: + B, C, H, W = output_rgb.shape + + kernel_x = torch.tensor( + [[-1,0,1],[-2,0,2],[-1,0,1]], + device=device, + dtype=output_rgb.dtype + ) + + kernel_y = torch.tensor( + [[-1,-2,-1],[0,0,0],[1,2,1]], + device=device, + dtype=output_rgb.dtype + ) + + kernel_x = kernel_x.view(1,1,3,3).repeat(C,1,1,1) + kernel_y = kernel_y.view(1,1,3,3).repeat(C,1,1,1) + + grad_x = F.conv2d(output_rgb, kernel_x, padding=1, groups=C) + grad_y = F.conv2d(output_rgb, kernel_y, padding=1, groups=C) + + edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) + + # 🔥 NORMALISATION LOCALE (clé stabilité) + edges = edges / (edges.mean(dim=[2,3], keepdim=True) + 1e-6) + + # 🔥 MASQUE edges (évite bruit dans zones plates) + edge_mask = torch.sigmoid(4.0 * (edges - 0.5)) + + # 🔥 Sharpen intelligent + output_rgb = output_rgb + sharpen_edges_strength * edges * edge_mask + + output_rgb = output_rgb.clamp(-1.0, 1.0) + + # ---------------- To PIL ---------------- + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +def decode_latents_ultrasafe_blockwise_pro( + latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0, + use_hann=True +): + import torch + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions 
= list(range(0, W, stride)) + + # 🔥 Fenêtre de blending PRO (Hann = ultra stable) + def create_feather(h, w): + if use_hann: + wy = torch.hann_window(h, device=device) + wx = torch.hann_window(w, device=device) + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + else: + y = torch.linspace(0, 1, h, device=device) + x = torch.linspace(0, 1, w, device=device) + wy = 1 - torch.abs(y - 0.5) * 2 + wx = 1 - torch.abs(x - 0.5) * 2 + return (wy[:, None] * wx[None, :]).clamp(min=1e-3) + + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + fh, fw = decoded.shape[2], decoded.shape[3] + + # 🔥 feather dynamique (corrige bord image) + feather = create_feather(fh, fw) + feather = feather.unsqueeze(0).unsqueeze(0) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + fh, ix0 + fw + + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded * feather + weight[:, :, iy0:iy1, ix0:ix1] += feather + + # 🔥 sécurité critique (évite artefacts) + weight = torch.clamp(weight, min=1e-3) + + output_rgb = (output_rgb / weight).clamp(-1.0, 1.0) + + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +# Decode latents par blockwise - ultrasafe : +def decode_latents_ultrasafe_blockwise(latents, vae, + block_size=32, overlap=16, + device="cuda", + frame_counter=0, + latent_scale_boost=1.0): + """ + Décodage ultra-safe par blocs des latents en image PIL. 
+ Paramètres conservés uniquement : block_size, overlap, device, frame_counter, latent_scale_boost + """ + import torch + from torchvision.transforms.functional import to_pil_image + + vae = vae.to(device=device, dtype=torch.float32) + vae.eval() + + B, C, H, W = latents.shape + latents = latents.to(device=device, dtype=torch.float32) * latent_scale_boost + + out_H, out_W = H * 8, W * 8 + output_rgb = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output_rgb) + + stride = block_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + for y in y_positions: + for x in x_positions: + y1 = min(y + block_size, H) + x1 = min(x + block_size, W) + patch = latents[:, :, y:y1, x:x1] + patch = torch.nan_to_num(patch, nan=0.0) + + with torch.no_grad(): + decoded = vae.decode(patch).sample.to(torch.float32) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = iy0 + decoded.shape[2], ix0 + decoded.shape[3] + output_rgb[:, :, iy0:iy1, ix0:ix1] += decoded + weight[:, :, iy0:iy1, ix0:ix1] += 1.0 + + output_rgb = (output_rgb / weight.clamp(min=1e-6)).clamp(-1.0, 1.0) + + frames = [to_pil_image((output_rgb[i] + 1) / 2) for i in range(B)] + return frames[0] if B == 1 else frames + + +def apply_intelligent_glow_pro( + frame_pil, + strength=0.18, + edge_weight=0.6, + luminance_weight=0.8, + blur_radius=1.2 +): + from PIL import Image, ImageFilter + import numpy as np + + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + arr = np.array(frame_pil).astype(np.float32) / 255.0 + + # ---------------- Luminance ---------------- + lum = 0.299 * arr[:, :, 0] + 0.587 * arr[:, :, 1] + 0.114 * arr[:, :, 2] + lum_mask = np.clip((lum - 0.6) / 0.4, 0, 1) + lum_mask = np.power(lum_mask, 1.5) + + # ---------------- Edge ---------------- + gray = (lum * 255).astype(np.uint8) + edge = Image.fromarray(gray).filter(ImageFilter.FIND_EDGES) + edge = np.array(edge).astype(np.float32) / 255.0 + edge = np.clip(edge * 1.2, 0, 1) + edge 
= np.power(edge, 1.3) + + # ---------------- Mask combiné ---------------- + combined_mask = np.clip(luminance_weight * lum_mask + edge_weight * edge, 0, 1) + + # ---------------- Glow ---------------- + glow_img = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + glow_arr = np.array(glow_img).astype(np.float32) / 255.0 + + # ---------------- Appliquer glow seulement sur la luminance ---------------- + glow_lum = 0.299 * glow_arr[:, :, 0] + 0.587 * glow_arr[:, :, 1] + 0.114 * glow_arr[:, :, 2] + + # mixer luminance glow + couleur originale + result = arr.copy() + for c in range(3): + # conserver la teinte originale mais injecter glow sur la luminosité + result[:, :, c] = arr[:, :, c] + (glow_lum - lum) * combined_mask * strength + + result = np.clip(result, 0, 1) + return Image.fromarray((result * 255).astype(np.uint8)) + + +def apply_intelligent_glow_froid( + frame_pil, + strength=0.18, + edge_weight=0.6, + luminance_weight=0.8, + blur_radius=1.2 +): + from PIL import Image, ImageFilter, ImageEnhance + import numpy as np + + # ---------------- Base ---------------- + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + arr = np.array(frame_pil).astype(np.float32) / 255.0 + + # ---------------- Luminance mask ---------------- + # luminance perceptuelle + lum = 0.299 * arr[:, :, 0] + 0.587 * arr[:, :, 1] + 0.114 * arr[:, :, 2] + + # masque doux (favorise les zones claires) + lum_mask = np.clip((lum - 0.6) / 0.4, 0, 1) + lum_mask = np.power(lum_mask, 1.5) # douceur + + # ---------------- Edge mask ---------------- + gray = (lum * 255).astype(np.uint8) + edge = Image.fromarray(gray).filter(ImageFilter.FIND_EDGES) + edge = np.array(edge).astype(np.float32) / 255.0 + + # adoucir les edges (évite bruit) + edge = np.clip(edge * 1.2, 0, 1) + edge = np.power(edge, 1.3) + + # ---------------- Fusion intelligente ---------------- + combined_mask = ( + luminance_weight * lum_mask + + edge_weight * edge + ) + + combined_mask = 
np.clip(combined_mask, 0, 1) + + # ---------------- Glow ---------------- + # blur image pour glow + glow_img = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + glow_arr = np.array(glow_img).astype(np.float32) / 255.0 + + # appliquer glow uniquement où mask actif + result = arr + (glow_arr - arr) * combined_mask[..., None] * strength + + result = np.clip(result, 0, 1) + + return Image.fromarray((result * 255).astype(np.uint8)) + + +def apply_post_processing_adaptive( + frame_pil, + blur_radius=0.03, + contrast=1.10, + vibrance_strength=0.25, # 🔥 contrôle simple (0 → off, 0.3 = doux) + sharpen=False, + sharpen_radius=1, + sharpen_percent=90, + sharpen_threshold=2, + clamp_r=True +): + from PIL import ImageEnhance, ImageFilter + import numpy as np + + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + # ---------------- 1️⃣ Micro blur (anti pixel) ---------------- + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # ---------------- 2️⃣ Contrast (seul vrai levier global) ---------------- + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + + # ---------------- 3️⃣ Vibrance douce (version stable) ---------------- + if vibrance_strength > 0: + try: + arr = np.array(frame_pil).astype(np.float32) + + # saturation simple + max_rgb = arr.max(axis=2) + min_rgb = arr.min(axis=2) + sat = (max_rgb - min_rgb) / 255.0 + + # 🔥 boost UNIQUEMENT zones peu saturées + boost = 1.0 + vibrance_strength * (1.0 - sat) + + arr = arr * boost[..., None] + + frame_pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)) + + except Exception as e: + print(f"[WARNING] vibrance skipped: {e}") + + # ---------------- 4️⃣ Clamp rouge (anti rose / peau cramée) ---------------- + if clamp_r: + try: + arr = np.array(frame_pil).astype(np.float32) + + r = arr[:, :, 0] + r_mean = r.mean() + + if r_mean > 160: # 🔥 seuil plus bas = plus stable + factor = 160 / (r_mean + 1e-6) + 
arr[:, :, 0] *= factor + + frame_pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)) + + except Exception as e: + print(f"[WARNING] clamp rouge skipped: {e}") + + # ---------------- 5️⃣ Sharpen léger ---------------- + if sharpen: + try: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + except Exception as e: + print(f"[WARNING] sharpening skipped: {e}") + + return frame_pil + + + + +def smooth_edges(frame_pil, strength=0.4, blur_radius=1.2): + from PIL import ImageFilter, ImageChops + import numpy as np + + # 1️⃣ edges + edges = frame_pil.convert("L").filter(ImageFilter.FIND_EDGES) + + # 2️⃣ normalisation du masque + edges_np = np.array(edges).astype(np.float32) / 255.0 + edges_np = np.clip(edges_np * 2.0, 0, 1) # renforce zones edges + + # 3️⃣ blur global (source) + blurred = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # 4️⃣ blend intelligent + orig = np.array(frame_pil).astype(np.float32) + blur = np.array(blurred).astype(np.float32) + + mask = edges_np[..., None] * strength + + result = orig * (1 - mask) + blur * mask + + return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8)) + + +def apply_post_processing_unreal_cinematic( + frame_pil, + exposure=1.0, + vibrance=1.02, + edge_strength=0.25, + sharpen=True, + brightness_adj=0.90, # 🔻 -5% + contrast_adj=1.65 # 🔺 +65% +): + from PIL import Image, ImageEnhance, ImageFilter, ImageChops + import numpy as np + + # 🔥 1. 
Base (sans toucher contraste global) + arr = np.array(frame_pil).astype(np.float32) / 255.0 + arr *= exposure + + # Vibrance douce + mean_c = arr.mean(axis=2, keepdims=True) + arr = mean_c + (arr - mean_c) * vibrance + arr = np.clip(arr, 0, 1) + + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ========================= + # ✏️ EDGE CRAYON BLANC + # ========================= + gray = img.convert("L") + edges = gray.filter(ImageFilter.FIND_EDGES) + + edges = edges.filter(ImageFilter.GaussianBlur(radius=0.8)) + edges = ImageChops.invert(edges) + edges = ImageEnhance.Contrast(edges).enhance(1.2) + + edge_rgb = Image.merge("RGB", (edges, edges, edges)) + + # Screen = effet lumineux propre + img_edges = ImageChops.screen(img, edge_rgb) + + # Blend final contrôlé + img = Image.blend(frame_pil, img_edges, edge_strength) + + # ========================= + # 🔥 AJUSTEMENTS DEMANDÉS + # ========================= + img = ImageEnhance.Brightness(img).enhance(brightness_adj) + img = ImageEnhance.Contrast(img).enhance(contrast_adj) + + # ========================= + # 🔧 Sharpen doux + # ========================= + if sharpen: + img = img.filter(ImageFilter.UnsharpMask( + radius=0.5, + percent=30, + threshold=3 + )) + + # 🔥 micro lissage final + img = img.filter(ImageFilter.GaussianBlur(radius=0.25)) + + return img + +def apply_post_processing_minimal( + frame_pil, + blur_radius=0.05, + contrast=1.15, + vibrance_base=1.0, + vibrance_max=1.25, + sharpen=False, + sharpen_radius=1, + sharpen_percent=90, + sharpen_threshold=2, + clamp_r=True +): + from PIL import Image, ImageFilter, ImageEnhance + import numpy as np + + if frame_pil.mode != "RGB": + frame_pil = frame_pil.convert("RGB") + + # ---------------- 1. Blur léger ---------------- + if blur_radius > 0: + frame_pil = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # ---------------- 2. 
Contraste ---------------- + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + + # ---------------- 3. Vibrance adaptative ---------------- + try: + frame_np = np.array(frame_pil).astype(np.float32) + + max_rgb = np.max(frame_np, axis=2) + min_rgb = np.min(frame_np, axis=2) + sat = max_rgb - min_rgb + + factor_map = vibrance_base + (vibrance_max - vibrance_base) * (1 - sat / 255.0) + factor_map = np.clip(factor_map, vibrance_base, vibrance_max) + + frame_np *= factor_map[..., None] + frame_np = np.clip(frame_np, 0, 255) + + frame_pil = Image.fromarray(frame_np.astype(np.uint8)) + + except Exception as e: + print(f"[WARNING] vibrance skipped: {e}") + + # ---------------- 4. Clamp rouge ---------------- + if clamp_r: + try: + arr = np.array(frame_pil).astype(np.float32) + r_mean = arr[..., 0].mean() + + if r_mean > 180: + factor = 180 / r_mean + arr[..., 0] *= factor + + frame_pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)) + + except Exception as e: + print(f"[WARNING] clamp rouge skipped: {e}") + + # ---------------- 5. 
Sharpen ---------------- + if sharpen: + frame_pil = frame_pil.filter(ImageFilter.UnsharpMask( + radius=sharpen_radius, + percent=sharpen_percent, + threshold=sharpen_threshold + )) + + return frame_pil + +def apply_intelligent_glow(frame_pil, + glow_strength=0.22, + blur_radius=1.2, + luminance_threshold=0.7, + edge_strength=1.2, + detail_preservation=0.85): + """ + Glow intelligent : + - basé sur luminance + edges + - évite effet flou global + - boost détails lumineux uniquement + """ + from PIL import Image, ImageFilter, ImageEnhance, ImageChops + import numpy as np + + # ----------------------- + # 1️⃣ Base numpy + # ----------------------- + arr = np.array(frame_pil).astype(np.float32) / 255.0 + + # ----------------------- + # 2️⃣ Luminance mask + # ----------------------- + gray = frame_pil.convert("L") + lum = np.array(gray).astype(np.float32) / 255.0 + + lum_mask = np.clip((lum - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + + # ----------------------- + # 3️⃣ Edge mask (important 🔥) + # ----------------------- + edges = gray.filter(ImageFilter.FIND_EDGES) + edges = ImageEnhance.Contrast(edges).enhance(edge_strength) + + edge_arr = np.array(edges).astype(np.float32) / 255.0 + + # 🔥 combinaison intelligente + combined_mask = lum_mask * edge_arr + + # ----------------------- + # 4️⃣ Glow blur + # ----------------------- + blurred = frame_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + blurred_arr = np.array(blurred).astype(np.float32) / 255.0 + + # ----------------------- + # 5️⃣ Application du glow + # ----------------------- + for c in range(3): + arr[..., c] = arr[..., c] + glow_strength * combined_mask * blurred_arr[..., c] + + arr = np.clip(arr, 0, 1) + + # ----------------------- + # 6️⃣ Reconstruction + # ----------------------- + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ----------------------- + # 7️⃣ Préservation détails + # ----------------------- + img = Image.blend(frame_pil, img, 1 - detail_preservation) 
+ + # ----------------------- + # 8️⃣ Micro sharpen + # ----------------------- + img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=25, threshold=2)) + + return img + + +def apply_chromatic_soft_glow(frame_pil, + glow_strength=0.25, + exposure=1.05, + blur_radius=2.0, + luminance_threshold=0.8, + color_saturation=1.05, + sharpen=True): + """ + Soft Glow chromatique localisé : + - Glow appliqué sur pixels clairs selon leur canal (R/G/B) + - Zones sombres préservées + - Détails conservés + """ + from PIL import Image, ImageFilter, ImageChops, ImageEnhance + import numpy as np + + arr = np.array(frame_pil).astype(np.float32) / 255.0 + arr = np.clip(arr * exposure, 0, 1) + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ----------------------- + # Masque par canal + # ----------------------- + r, g, b = arr[...,0], arr[...,1], arr[...,2] + mask_r = np.clip((r - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + mask_g = np.clip((g - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + mask_b = np.clip((b - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + + # ----------------------- + # Glow par canal + # ----------------------- + bright = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + bright_arr = np.array(bright).astype(np.float32) / 255.0 + + # Mélange selon masque couleur + arr[...,0] = np.clip(arr[...,0] + glow_strength * mask_r * bright_arr[...,0], 0, 1) + arr[...,1] = np.clip(arr[...,1] + glow_strength * mask_g * bright_arr[...,1], 0, 1) + arr[...,2] = np.clip(arr[...,2] + glow_strength * mask_b * bright_arr[...,2], 0, 1) + + img = Image.fromarray((arr*255).astype(np.uint8)) + + # ----------------------- + # Saturation douce + # ----------------------- + img = ImageEnhance.Color(img).enhance(color_saturation) + + # ----------------------- + # Micro sharpen subtil + # ----------------------- + if sharpen: + img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=30, threshold=2)) + + return img + + 
+def apply_localized_soft_glow(frame_pil, + glow_strength=0.25, + exposure=1.05, + blur_radius=2.0, + luminance_threshold=0.6, + color_saturation=1.05, + sharpen=True): + """ + Filtre 'Soft Glow Localisé': + - Glow appliqué seulement sur les zones lumineuses + - Effet subtil, préserve les zones sombres + - Maintien des détails + """ + from PIL import Image, ImageFilter, ImageChops, ImageEnhance + import numpy as np + + # ----------------------- + # 1️⃣ Convertir en float + exposure + # ----------------------- + arr = np.array(frame_pil).astype(np.float32) / 255.0 + arr = np.clip(arr * exposure, 0, 1) + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ----------------------- + # 2️⃣ Masque de luminosité + # ----------------------- + gray = img.convert("L") + lum_arr = np.array(gray).astype(np.float32) / 255.0 + mask = np.clip((lum_arr - luminance_threshold) / (1.0 - luminance_threshold), 0, 1) + mask_img = Image.fromarray((mask * 255).astype(np.uint8)) + + # ----------------------- + # 3️⃣ Glow léger + # ----------------------- + bright = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + glow_img = ImageChops.screen(img, bright) + # Appliquer glow uniquement là où mask > 0 + glow_img = Image.composite(glow_img, img, mask_img) + img = Image.blend(img, glow_img, glow_strength) + + # ----------------------- + # 4️⃣ Saturation douce + # ----------------------- + img = ImageEnhance.Color(img).enhance(color_saturation) + + # ----------------------- + # 5️⃣ Micro sharpen subtil + # ----------------------- + if sharpen: + img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=30, threshold=2)) + + return img + + +def apply_soft_glow(frame_pil, + glow_strength=0.25, + exposure=1.05, + blur_radius=2.0, + color_saturation=1.05, + sharpen=True): + """ + Filtre 'Soft Glow' : + - Surexposition douce sur les zones claires + - Glow léger et subtil + - Maintien des détails et textures + """ + from PIL import Image, ImageFilter, ImageChops, ImageEnhance + 
import numpy as np + + # ----------------------- + # 1️⃣ Convertir en float + exposure léger + # ----------------------- + arr = np.array(frame_pil).astype(np.float32) / 255.0 + arr = np.clip(arr * exposure, 0, 1) + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ----------------------- + # 2️⃣ Glow subtil (Light Bloom) + # ----------------------- + bright = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + img = ImageChops.screen(img, bright) + img = Image.blend(img, bright, glow_strength) + + # ----------------------- + # 3️⃣ Saturation douce + # ----------------------- + img = ImageEnhance.Color(img).enhance(color_saturation) + + # ----------------------- + # 4️⃣ Micro sharpen subtil + # ----------------------- + if sharpen: + img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=30, threshold=2)) + + return img + + +def apply_cinematic_neon_glow(frame_pil, + glow_strength=0.25, + edge_strength=0.15, + color_saturation=1.15, + exposure=1.05, + contrast=1.25, + blur_radius=0.4, + sharpen=True): + """ + Filtre original 'Cinematic Neon Glow': + - Glow subtil autour des zones claires + - Couleurs saturées style néon / cinématographique + - Bords légèrement lumineux type sketch + """ + from PIL import Image, ImageFilter, ImageChops, ImageEnhance + import numpy as np + + # ----------------------- + # 1️⃣ Convertir en float + # ----------------------- + arr = np.array(frame_pil).astype(np.float32) / 255.0 + + # ----------------------- + # 2️⃣ Exposure léger + # ----------------------- + arr *= exposure + arr = np.clip(arr, 0, 1) + + img = Image.fromarray((arr * 255).astype(np.uint8)) + + # ----------------------- + # 3️⃣ Glow subtil (Light Bloom) + # ----------------------- + bright = img.filter(ImageFilter.GaussianBlur(radius=5)) + img = ImageChops.screen(img, bright) # effet lumineux + img = Image.blend(img, bright, glow_strength) + + # ----------------------- + # 4️⃣ Edge sketch léger + # ----------------------- + gray = 
img.convert("L").filter(ImageFilter.GaussianBlur(radius=1.0)) + edges = gray.filter(ImageFilter.FIND_EDGES) + edges = ImageChops.invert(edges) + edges_rgb = Image.merge("RGB", (edges, edges, edges)) + img = ImageChops.blend(img, edges_rgb, edge_strength) + + # ----------------------- + # 5️⃣ Saturation & Contraste + # ----------------------- + img = ImageEnhance.Color(img).enhance(color_saturation) + img = ImageEnhance.Contrast(img).enhance(contrast) + + # ----------------------- + # 6️⃣ Micro blur anti-pixel + # ----------------------- + img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # ----------------------- + # 7️⃣ Sharpen subtil + # ----------------------- + if sharpen: + img = img.filter(ImageFilter.UnsharpMask(radius=0.5, percent=40, threshold=2)) + + return img + + +def apply_post_processing_sketch(frame_pil, edge_strength=0.2, blur_radius=0.3, sharpen=True, + contrast_boost=1.6, # +60% contraste + exposure=0.80): # -20% brillance + """ + Effet dessin subtil / croquis clair ajusté : + - Contours légèrement visibles (blancs doux) + - +40% contraste, -10% brillance + - Lisse les pixels isolés + - Ne dénature pas les couleurs de base + """ + from PIL import Image, ImageFilter, ImageChops, ImageEnhance + import numpy as np + + # ----------------------- + # 1️⃣ Edge detection doux + # ----------------------- + gray = frame_pil.convert("L").filter(ImageFilter.GaussianBlur(radius=0.5)) + edges = gray.filter(ImageFilter.FIND_EDGES) + edges = edges.filter(ImageFilter.MedianFilter(size=3)) # supprime points isolés + edges = edges.filter(ImageFilter.GaussianBlur(radius=0.6)) # lissage + edges = ImageEnhance.Contrast(edges).enhance(1.2) + edges = ImageChops.invert(edges) + edge_rgb = Image.merge("RGB", (edges, edges, edges)) + + # ----------------------- + # 2️⃣ Fusion douce des edges + # ----------------------- + img = ImageChops.blend(frame_pil, edge_rgb, edge_strength) + + # ----------------------- + # 3️⃣ Exposure / Brillance + # 
def apply_post_processing_drawing(frame_pil,
                                  edge_strength=0.7,
                                  color_levels=48,
                                  saturation=0.95,
                                  contrast=1.10,
                                  sharpen=True):
    """
    Line-art style post-processing.

    Simplifies colors, overlays pencil-like contours, removes isolated
    black dots and keeps the output crisp.

    Args:
        frame_pil (PIL.Image): input frame.
        edge_strength (float): blend weight of the contour layer (0..1).
        color_levels (int): number of quantization levels per channel.
        saturation (float): color saturation multiplier.
        contrast (float): contrast multiplier.
        sharpen (bool): apply a final unsharp mask.

    Returns:
        PIL.Image: processed frame.
    """
    from PIL import Image, ImageFilter, ImageEnhance, ImageChops
    import numpy as np

    # -----------------------
    # 1) Gentle color quantization
    # -----------------------
    arr = np.array(frame_pil).astype(np.float32)
    levels = color_levels
    arr = np.round(arr / (256 / levels)) * (256 / levels)
    img = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))

    # -----------------------
    # 2) Clean edge detection
    # -----------------------
    gray = frame_pil.convert("L").filter(ImageFilter.GaussianBlur(radius=0.6))
    edges = gray.filter(ImageFilter.FIND_EDGES)
    edges = edges.filter(ImageFilter.GaussianBlur(radius=0.8))
    edges = edges.filter(ImageFilter.MedianFilter(size=3))  # removes isolated dots
    edges = ImageEnhance.Contrast(edges).enhance(1.4)
    edges = edges.point(lambda x: 0 if x < 15 else int(x * 1.2))
    edges = ImageChops.invert(edges)
    edge_rgb = Image.merge("RGB", (edges, edges, edges))

    # -----------------------
    # 3) Soft contour fusion
    # -----------------------
    img_edges = ImageChops.multiply(img, edge_rgb)
    img = Image.blend(img, img_edges, edge_strength * 0.85)

    # -----------------------
    # 4) Color / contrast / sharpen
    # -----------------------
    img = ImageEnhance.Color(img).enhance(saturation)
    img = ImageEnhance.Contrast(img).enhance(contrast)
    if sharpen:
        img = img.filter(ImageFilter.UnsharpMask(radius=0.6, percent=60, threshold=3))

    return img


def save_frame_verbose(frame: Image.Image, output_dir: Path, frame_counter: int, suffix: str = "00", psave: bool = True):
    """
    Save a frame with a suffix, logging the destination path.

    FIX: the original docstring documented a ``verbose`` parameter that does
    not exist; the actual flag is ``psave``.

    Args:
        frame (Image.Image): PIL image to save.
        output_dir (Path): output directory.
        frame_counter (int): frame number used in the filename.
        suffix (str): suffix distinguishing pipeline stages.
        psave (bool): when True, print the message and write the file.
            NOTE(review): assumed to gate both the print and the save
            (the stale docstring implied it only gated the print) — confirm.

    Returns:
        Path: the (possibly unwritten) target file path.
    """
    file_path = output_dir / f"frame_{frame_counter:05d}_{suffix}.png"

    if psave:
        print(f"[SAVE Frame {frame_counter:03d}_{suffix}] -> {file_path}")
        frame.save(file_path)
    return file_path


def neutralize_color_cast(img, strength=0.45, warm_bias=0.015, green_bias=-0.07):
    """
    Neutralize the overall color cast while correcting a green excess.

    Args:
        img (PIL.Image): image to correct.
        strength (float): neutralization intensity (0.0 = off, 1.0 = full).
        warm_bias (float): slight warm shift (red+ / blue-).
        green_bias (float): green adjustment (-0.07 = minus 7%).
    """
    import numpy as np
    from PIL import Image

    arr = np.array(img).astype(np.float32)

    # Gray-world assumption: per-channel gains pulling each mean toward gray.
    mean = arr.mean(axis=(0, 1))
    gray = mean.mean()

    gain = gray / (mean + 1e-6)
    gain = 1.0 + (gain - 1.0) * strength

    arr[..., 0] *= gain[0] * (1 + warm_bias)   # red
    arr[..., 1] *= gain[1] * (1 + green_bias)  # green, corrected
    arr[..., 2] *= gain[2] * (1 - warm_bias)   # blue -

    arr = np.clip(arr, 0, 255)

    return Image.fromarray(arr.astype(np.uint8))


def neutralize_color_cast_clean(img, strength=0.6, warm_bias=0.02):
    """Gray-world neutralization with a slight warm bias (red+ / blue-)."""
    import numpy as np
    from PIL import Image

    arr = np.array(img).astype(np.float32)

    mean = arr.mean(axis=(0, 1))
    gray = mean.mean()

    gain = gray / (mean + 1e-6)
    gain = 1.0 + (gain - 1.0) * strength

    arr[..., 0] *= gain[0] * (1 + warm_bias)  # slight red boost
    arr[..., 1] *= gain[1]
    arr[..., 2] *= gain[2] * (1 - warm_bias)  # slight blue cut

    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))


def neutralize_color_cast_str(img, strength=0.6):
    """Gray-world neutralization with adjustable strength, no bias."""
    import numpy as np
    from PIL import Image

    arr = np.array(img).astype(np.float32)

    mean = arr.mean(axis=(0, 1))
    gray = mean.mean()

    gain = gray / (mean + 1e-6)

    # Key step: interpolate between identity and full correction.
    gain = 1.0 + (gain - 1.0) * strength

    arr[..., 0] *= gain[0]
    arr[..., 1] *= gain[1]
    arr[..., 2] *= gain[2]

    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))


def neutralize_color_cast_simple(img):
    """Full-strength gray-world neutralization (no interpolation, no bias)."""
    import numpy as np
    from PIL import Image  # FIX: siblings import Image locally; this one relied on module scope

    arr = np.array(img).astype(np.float32)

    mean = arr.mean(axis=(0, 1))

    # Target: neutral gray.
    gray = mean.mean()

    gain = gray / (mean + 1e-6)

    arr[..., 0] *= gain[0]
    arr[..., 1] *= gain[1]
    arr[..., 2] *= gain[2]

    return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
def kelvin_to_rgb(temp):
    """
    Realistic Kelvin -> normalized RGB approximation (photography-inspired fit).

    Returns a tuple of three floats in [0, 1].
    """
    t = temp / 100.0

    # Red channel
    if t <= 66:
        red = 255
    else:
        red = t - 60
        red = 329.698727446 * (red ** -0.1332047592)

    # Green channel
    if t <= 66:
        green = 99.4708025861 * math.log(t) - 161.1195681661
    else:
        green = t - 60
        green = 288.1221695283 * (green ** -0.0755148492)

    # Blue channel
    if t >= 66:
        blue = 255
    elif t <= 19:
        blue = 0
    else:
        blue = t - 10
        blue = 138.5177312231 * math.log(blue) - 305.0447927307

    def clamp01(v):
        # Clamp to the 8-bit range, then normalize.
        return max(0, min(255, v)) / 255.0

    return (clamp01(red), clamp01(green), clamp01(blue))


def adjust_color_temperature(
    image,
    target_temp=7800,
    reference_temp=6500,
    strength=0.5,
    adaptive=True,
    max_gain=2.0,
    debug=False
):
    """
    Shift an image's white balance from reference_temp toward target_temp.

    Optionally scales the correction by how imbalanced the current white
    balance looks (gray-world estimate), and clamps gains for safety.
    """
    import numpy as np

    pixels = np.array(image).astype(np.float32) / 255.0

    # --- 1. Per-channel gains derived from the two Kelvin points
    ref_rgb = kelvin_to_rgb(reference_temp)
    tgt_rgb = kelvin_to_rgb(target_temp)

    temp_gain = np.array([
        tgt_rgb[0] / ref_rgb[0],
        tgt_rgb[1] / ref_rgb[1],
        tgt_rgb[2] / ref_rgb[2]
    ])

    # --- 2. Quick estimate of the current white balance (simplified gray-world)
    if adaptive:
        mean_rgb = pixels.reshape(-1, 3).mean(axis=0)
        mean_rgb = np.maximum(mean_rgb, 1e-6)

        # normalize on the green channel
        wb_ratio = mean_rgb / mean_rgb[1]

        # measure of the imbalance
        imbalance = np.std(wb_ratio)

        # gentle adaptive factor (avoids overcorrection)
        adaptive_factor = 1.0 + min(1.0, imbalance * 2.0)
    else:
        adaptive_factor = 1.0

    # --- 3. Interpolate between identity and the (scaled) temperature gains
    final_gain = (1 - strength) + strength * temp_gain * adaptive_factor

    # --- 4. Safety clamp (important in practice)
    final_gain = np.clip(final_gain, 1 / max_gain, max_gain)

    # --- 5. Apply
    pixels *= final_gain
    pixels = np.clip(pixels, 0, 1)

    if debug:
        print("=== DEBUG TEMP ===")
        print(f"mean_rgb: {mean_rgb if adaptive else 'disabled'}")
        print(f"base_gain: {temp_gain}")
        print(f"adaptive_factor: {adaptive_factor}")
        print(f"final_gain: {final_gain}")
        print("==================")

    return Image.fromarray((pixels * 255).astype(np.uint8))


def adjust_color_temperature_basic(image, target_temp=10000, reference_temp=6500, strength=0.5):
    """Temperature shift with a simple identity↔gain interpolation (no adaptivity)."""
    import numpy as np

    pixels = np.array(image).astype(np.float32) / 255.0

    r1, g1, b1 = kelvin_to_rgb(reference_temp)
    r2, g2, b2 = kelvin_to_rgb(target_temp)

    # Key step: blend each gain with identity.
    gains = (
        (1 - strength) + strength * (r2 / r1),
        (1 - strength) + strength * (g2 / g1),
        (1 - strength) + strength * (b2 / b1),
    )

    for channel, gain in enumerate(gains):
        pixels[..., channel] *= gain

    pixels = np.clip(pixels, 0, 1)
    return Image.fromarray((pixels * 255).astype(np.uint8))


def adjust_color_temperature_simple(image, target_temp=7800, reference_temp=6500):
    """Raw temperature shift using relative gains (GIMP-style), full strength."""
    import numpy as np

    pixels = np.array(image).astype(np.float32) / 255.0

    # Relative gains (IMPORTANT — mirrors GIMP's behavior)
    r1, g1, b1 = kelvin_to_rgb(reference_temp)
    r2, g2, b2 = kelvin_to_rgb(target_temp)

    pixels[..., 0] *= r2 / r1
    pixels[..., 1] *= g2 / g1
    pixels[..., 2] *= b2 / b1

    pixels = np.clip(pixels, 0, 1)
    return Image.fromarray((pixels * 255).astype(np.uint8))


def soft_tone_map(img):
    """Apply a light contrast boost around the per-channel mean (no compression)."""
    import numpy as np

    arr = np.array(img).astype(np.float32) / 255.0

    # Expand deviations from the mean by 10%.
    mean = arr.mean(axis=(0, 1), keepdims=True)
    arr = (arr - mean) * 1.1 + mean

    return Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8))
def soft_tone_map_unreal(img, exposure=1.0):
    """
    Reinhard-style tone mapping with an exposure control.

    Args:
        img (PIL.Image): input image.
        exposure (float): linear exposure multiplier applied before mapping.
    """
    import numpy as np

    arr = np.array(img).astype(np.float32) / 255.0

    # exposure
    arr = arr * exposure

    # Reinhard tone mapping (more natural rolloff)
    mapped = arr / (1.0 + arr)

    # light local contrast lift (key!)
    mapped = np.power(mapped, 0.9)

    return Image.fromarray((np.clip(mapped, 0, 1) * 255).astype(np.uint8))


def soft_tone_map_v1(img):
    """Log-like soft compression followed by a slight contrast softening."""
    import numpy as np  # FIX: siblings import numpy locally; this one relied on module scope

    arr = np.array(img).astype(np.float32) / 255.0

    # softer, log-like compression
    arr = np.log1p(arr * 1.5) / np.log1p(1.5)

    # slight contrast softening
    arr = np.power(arr, 0.95)

    return Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8))


def soft_tone_map1(img):
    """Strong x/(x+0.2) compression variant with a slight gamma lift."""
    import numpy as np  # FIX: siblings import numpy locally; this one relied on module scope

    arr = np.array(img).astype(np.float32) / 255.0
    arr = arr / (arr + 0.2)
    arr = np.power(arr, 0.95)
    arr = np.clip(arr, 0, 1)
    return Image.fromarray((arr * 255).astype(np.uint8))


def apply_n3r_pro_net(latents, model=None, strength=0.3, sanitize_fn=None):
    """
    Inject smoothed model-refined detail into latents.

    Args:
        latents (torch.Tensor): latent tensor; assumed 4D [B, C, H, W] for
            avg_pool2d — TODO confirm at call sites.
        model (nn.Module | None): refiner; when None (or strength <= 0) the
            input is returned untouched.
        strength (float): detail injection weight.
        sanitize_fn (callable | None): optional final cleanup hook.

    Returns:
        torch.Tensor: refined latents, or the input on error / no-op.
    """
    if model is None or strength <= 0:
        return latents

    try:
        latents = latents.to(next(model.parameters()).dtype)
        refined = model(latents)

        # difference = detail map
        detail = refined - latents

        # smoothing the detail map is the key step
        detail = F.avg_pool2d(detail, kernel_size=3, stride=1, padding=1)

        # controlled injection
        latents = latents + strength * detail

        if sanitize_fn:
            latents = sanitize_fn(latents)

        return latents

    except Exception as e:
        # best-effort: never break the pipeline over a refiner failure
        print(f"[N3RProNet ERROR] {e}")
        return latents


def apply_n3r_pro_net1(latents, model=None, strength=0.3, sanitize_fn=None):
    """
    Blend-based latent refinement variant: clamp, soft blend, light renorm.

    Same contract as apply_n3r_pro_net; returns the input untouched when
    model is None, strength <= 0, or on any error.
    """
    if model is None or strength <= 0:
        return latents

    try:
        dtype = next(model.parameters()).dtype
        latents = latents.to(dtype)

        refined = model(latents)

        # clamp to avoid blow-ups
        refined = torch.clamp(refined, -2.5, 2.5)

        # soft blend (much more stable than additive injection)
        latents = (1 - strength) * latents + strength * refined

        # light normalization — NOTE(review): this forces per-sample std to ~1,
        # which rescales the latent distribution; confirm downstream expects it
        latents = latents / (latents.std(dim=[1, 2, 3], keepdim=True) + 1e-6)

        if sanitize_fn:
            latents = sanitize_fn(latents)

        return latents

    except Exception as e:
        print(f"[N3RProNet ERROR] {e}")
        return latents
def apply_n3r_pro_net_v1(latents, model=None, strength=0.3, sanitize_fn=None, frame_idx=None, total_frames=None):
    # Latent refinement with normalized delta injection and a cosine ramp of
    # the strength across the clip (weak at clip start, full near the end).
    # Returns the input untouched when model is None / strength <= 0 or on error.
    if model is None or strength <= 0:
        return latents

    try:
        model_dtype = next(model.parameters()).dtype
        model_device = next(model.parameters()).device
        latents = latents.to(dtype=model_dtype, device=model_device)
        # ensure_4_channels is defined elsewhere in this file — presumably pads/
        # truncates the channel axis to 4; TODO confirm.
        latents = ensure_4_channels(latents)

        if frame_idx is not None and total_frames is not None:
            # Cosine ramp: factor goes 0.3 -> 1.0 as frame_idx approaches total_frames.
            adaptive_strength = strength * (0.3 + 0.7 * 0.5 * (1 - math.cos(math.pi * frame_idx / total_frames)))
        else:
            adaptive_strength = strength

        refined = model(latents)

        # Normalize the delta so its max magnitude is 1 — avoids saturation.
        delta = refined - latents
        max_delta = delta.abs().amax(dim=(1,2,3), keepdim=True).clamp(min=1e-5)
        delta = delta / max_delta
        latents = latents + adaptive_strength * delta

        # Light rescale for stability: only shrinks samples whose peak exceeds 1.
        latents = latents / latents.abs().amax(dim=(1,2,3), keepdim=True).clamp(min=1.0)

        if sanitize_fn:
            latents = sanitize_fn(latents)

        return latents

    except Exception as e:
        # Best-effort: never break the pipeline over a refiner failure.
        print(f"[N3RProNet ERROR] {e}")
        return latents



def full_frame_postprocess_add( frame_pil: Image.Image, output_dir: Path, frame_counter: int, target_temp: int = 7800, reference_temp: int = 6500, temp_strength: float = 0.22, blur_radius: float = 0.03, contrast: float = 1.10, saturation: float = 1.0, sharpen_percent: int = 90, psave: bool = True, unreal: bool = False, cartoon: bool = False , glow: bool = False) -> Image.Image:
    """
    Full per-frame post-processing chain (variant with style add-ons).

    Runs temperature shift, cast neutralization, tone mapping, adaptive
    cleanup, then optional unreal / cartoon / glow styling; each stage is
    dumped via save_frame_verbose with suffixes "01".."09" when psave=True.

    NOTE(review): saturation, blur_radius, contrast and sharpen_percent are
    only consumed by the `minimal` branch, which is hard-disabled below — in
    the default path the adaptive call uses hard-coded values; confirm intended.

    Returns:
        Image.Image: the fully processed frame.
    """
    removewhite = False
    minimal = False

    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="01", psave=psave)
    # 1) Color temperature
    frame_pil = adjust_color_temperature(
        frame_pil,
        target_temp=target_temp,
        reference_temp=reference_temp,
        strength=temp_strength
    )
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="02", psave=psave)

    # 2) Color-cast neutralization
    frame_pil = neutralize_color_cast(frame_pil)
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="03", psave=psave)

    # 3) Tone mapping
    frame_pil = soft_tone_map(frame_pil)
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="04", psave=psave)

    # 4) Adaptive post-processing (minimal branch is currently disabled above)
    if minimal:
        frame_pil = apply_post_processing_minimal(
            frame_pil,
            blur_radius=blur_radius,
            contrast=contrast,
            vibrance_base=1.0,
            vibrance_max=1.1,
            sharpen=True,
            sharpen_radius=1,
            sharpen_percent=sharpen_percent,
            sharpen_threshold=2
        )
    else:
        frame_pil = apply_post_processing_adaptive(
            frame_pil,
            blur_radius=0.03,
            contrast=1.10,
            vibrance_strength=0.05,  # simple control (0 = off, 0.3 = gentle)
            sharpen=False,
            sharpen_radius=1,
            sharpen_percent=90,
            sharpen_threshold=2,
            clamp_r=True
        )
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="05", psave=psave)


    # 5) White-noise cleanup (currently disabled via removewhite flag)
    if removewhite:
        frame_pil = remove_white_noise(frame_pil)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="06", psave=psave)

    # 6) Unreal style
    if unreal:
        frame_pil = apply_post_processing_unreal_cinematic(frame_pil)
        frame_pil = smooth_edges(frame_pil, strength=0.35, blur_radius=1.0)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="07", psave=psave)

    elif cartoon:
        # 6) Cartoon style
        frame_pil = apply_post_processing_sketch(frame_pil)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="08", psave=psave)

    # 7) Glow style
    if glow:
        # Forced glow for the stylized look
        frame_pil = apply_chromatic_soft_glow(frame_pil)
        frame_pil = apply_localized_soft_glow(frame_pil)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="09", psave=psave)
    else:
        # Intelligent glow + a final micro contrast bump
        frame_pil = apply_intelligent_glow( frame_pil )
        from PIL import ImageEnhance
        frame_pil = ImageEnhance.Contrast(frame_pil).enhance(1.04)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="09", psave=psave)

    return frame_pil



def full_frame_postprocess(
    frame_pil: Image.Image,
    output_dir: Path,
    frame_counter: int,
    target_temp: int = 7800,
    reference_temp: int = 6500,
    temp_strength: float = 0.20,   # slightly reduced (less blue)
    blur_radius: float = 0.025,    # a bit less global blur
    contrast: float = 1.08,        # avoids cumulative over-contrast
    sharpen_percent: int = 90,
    psave: bool = True,
    unreal: bool = False,
    cartoon: bool = False
) -> Image.Image:
    # Rebalanced per-frame chain: temperature -> softened neutralization ->
    # tone mapping -> adaptive cleanup -> optional styling -> pro glow +
    # final micro contrast. Stage dumps use suffixes "01".."09".
    # NOTE(review): sharpen_percent is accepted but never used here — confirm.

    # ---------------- 1) Input ----------------
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="01", psave=psave)

    # ---------------- 2) Color temperature ----------------
    frame_pil = adjust_color_temperature(
        frame_pil,
        target_temp=target_temp,
        reference_temp=reference_temp,
        strength=temp_strength
    )
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="02", psave=psave)

    # ---------------- 3) Neutralization (softened) ----------------
    frame_pil = neutralize_color_cast(frame_pil, strength=0.6)  # key setting
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="03", psave=psave)

    # ---------------- 4) Tone mapping (gentler) ----------------
    frame_pil = soft_tone_map(frame_pil)
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="04", psave=psave)

    # ---------------- 5) Adaptive (cleanup + micro boost) ----------------
    frame_pil = apply_post_processing_adaptive(
        frame_pil,
        blur_radius=blur_radius,
        contrast=contrast,
        vibrance_strength=0.22,  # slightly reduced
        sharpen=False,
        clamp_r=True
    )
    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="05", psave=psave)

    # ---------------- 6) Styling ----------------
    if unreal:
        frame_pil = apply_post_processing_unreal_cinematic(frame_pil)
        frame_pil = smooth_edges(frame_pil, strength=0.30, blur_radius=0.8)  # less destructive
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="06", psave=psave)

    elif cartoon:
        frame_pil = apply_post_processing_sketch(frame_pil)
        save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="07", psave=psave)

    # ---------------- 7) Intelligent glow (rebalanced) ----------------
    # previous tuning: strength=0.15 edge_weight=0.5 luminance_weight=0.8
    frame_pil = apply_intelligent_glow_pro(
        frame_pil,
        strength=0.18,        # less aggressive
        edge_weight=0.6,      # prioritize edges
        luminance_weight=0.8  # glow on bright areas
    )

    # Final micro contrast (after the glow — important ordering)
    from PIL import ImageEnhance
    frame_pil = ImageEnhance.Contrast(frame_pil).enhance(1.04)

    save_frame_verbose(frame_pil, output_dir, frame_counter, suffix="09", psave=psave)

    return frame_pil
LATENT_SCALE = 0.18215  # canonical VAE latent scaling factor for SD / AnimateDiff


def adapt_embeddings_to_unet(pos_embeds, neg_embeds, target_dim):
    """
    Make text embeddings match the UNet's cross_attention_dim.

    Truncates the last dimension when it is too large, zero-pads it when it
    is too small, and returns the tensors untouched when it already matches.
    """
    dim = pos_embeds.shape[-1]
    if dim == target_dim:
        return pos_embeds, neg_embeds
    if dim > target_dim:
        # truncate the trailing features
        return pos_embeds[..., :target_dim], neg_embeds[..., :target_dim]
    # zero-pad up to the target width
    pad = target_dim - dim
    return (
        torch.nn.functional.pad(pos_embeds, (0, pad)),
        torch.nn.functional.pad(neg_embeds, (0, pad)),
    )

# -----n3r_utils----------- Ultra-safe embeddings for the UNet ----------------
# Singleton linear projection used to map 1024-dim text embeddings to the
# UNet's 768-dim cross-attention space.

class UNetEmbeddingProjector(nn.Module):
    _singleton_instance = None

    def __init__(self, in_dim=1024, out_dim=768, device='cuda'):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim).to(device)

    @classmethod
    def get_instance(cls, in_dim=1024, out_dim=768, device='cuda'):
        # NOTE(review): later calls with different dims/device silently reuse
        # the first instance — confirm that is intended.
        if cls._singleton_instance is None:
            cls._singleton_instance = cls(in_dim, out_dim, device)
        return cls._singleton_instance

    def forward(self, embeds):
        # Hard guard: this projector only accepts 1024-dim inputs.
        if embeds.shape[-1] != 1024:
            raise ValueError(
                f"❌ Embeddings inattendus : dernier dim={embeds.shape[-1]}, attendu=1024"
            )
        return self.proj(embeds)


def prepare_embeddings_for_unet(embeds, device='cuda'):
    """Project 1024-dim embeddings to 768 dims via the singleton projector."""
    embeds = embeds.to(device)
    projector = UNetEmbeddingProjector.get_instance(device=device)
    projected = projector(embeds)
    # Sanity check on the projected width.
    if projected.shape[-1] != 768:
        raise ValueError(
            f"❌ UNet embeddings projetés ont mauvaise dimension : {projected.shape[-1]}, attendu=768"
        )
    return projected
def generate_latents_safe_debug(unet, **kwargs):
    """
    FP16-safe latent denoising loop around a UNet, with debug logging.

    Accepts keyword args: scheduler, input_latents (required), embeddings
    (tuple of (pos, neg) or a single tensor), motion_module, guidance_scale
    (currently unused — kept for interface compatibility), device, fp16,
    steps, debug. Video latents are [B, C, F, H, W]; image latents 4D.

    BUGFIX: the video path used to flatten frames with a bare
    reshape(B*F, C, H, W) on a [B, C, F, H, W] tensor — which interleaves
    channel and frame data — and then reshaped the output to (B, F, C, H, W),
    so axis labels drifted on every step. Both directions now go through an
    explicit permute. The local unpack also shadowed torch.nn.functional (F).

    Returns:
        torch.Tensor: the denoised latents (NaN/Inf-cleaned, clamped to ±5).
    """
    # Keep only known kwargs so stray keys are ignored silently.
    known_kwargs = [
        "scheduler", "input_latents", "embeddings", "motion_module",
        "guidance_scale", "device", "fp16", "steps", "debug"
    ]
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in known_kwargs}

    input_latents = filtered_kwargs.get("input_latents")
    if input_latents is None:
        raise ValueError("⚠️ 'input_latents' doit être fourni")

    device = filtered_kwargs.get("device", "cuda")
    fp16 = filtered_kwargs.get("fp16", True)
    debug = filtered_kwargs.get("debug", False)
    motion_module = filtered_kwargs.get("motion_module", None)
    scheduler = filtered_kwargs.get("scheduler", None)
    steps = filtered_kwargs.get("steps", 20)
    embeddings = filtered_kwargs.get("embeddings", None)
    guidance_scale = filtered_kwargs.get("guidance_scale", 4.0)  # NOTE(review): read but never applied

    # Initial latents
    latents = input_latents.clone().to(device=device, dtype=torch.float16 if fp16 else torch.float32)
    is_video = latents.ndim == 5  # [B, C, F, H, W]
    steps_list = getattr(scheduler, "timesteps", range(steps))

    if debug:
        print(f"[DEBUG] Initial latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    # Disable Dynamo/Inductor so python-level control flow stays observable.
    if hasattr(torch, "_dynamo"):
        torch._dynamo.reset()
        torch._dynamo.disable()

    for step_idx, t in enumerate(steps_list):

        # Optional temporal module, with NaN/Inf scrubbing after it.
        if motion_module is not None:
            latents = motion_module(latents)
            latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0)

        # Fold frames into the batch for the 2D UNet.
        if is_video:
            # `num_frames` (not `F`) to avoid shadowing torch.nn.functional.
            B, C, num_frames, H, W = latents.shape
            latents_unet = latents.permute(0, 2, 1, 3, 4).reshape(B * num_frames, C, H, W)
        else:
            latents_unet = latents

        latents_unet = latents_unet.half() if fp16 else latents_unet.float()

        # Classifier-free guidance concat: negatives first, then positives.
        if isinstance(embeddings, tuple):
            pos_embeds, neg_embeds = embeddings
            encoder_states = torch.cat([neg_embeds.to(device), pos_embeds.to(device)], dim=0)
        else:
            encoder_states = embeddings

        # Forward through the UNet; on error skip the step (best-effort).
        try:
            unet_out = unet(latents_unet, t, encoder_hidden_states=encoder_states)

            if isinstance(unet_out, dict):
                latents_out = unet_out["sample"]
            elif isinstance(unet_out, (tuple, list)):
                latents_out = unet_out[0]
            else:
                raise TypeError(f"Unexpected UNet output type: {type(unet_out)}")

        except Exception as e:
            print(f"⚠️ [UNet ERROR] step={step_idx} - {e}")
            continue

        # Restore the [B, C, F, H, W] layout (inverse of the permute above).
        if is_video:
            latents = latents_out.reshape(B, num_frames, C, H, W).permute(0, 2, 1, 3, 4)
        else:
            latents = latents_out

        # Scrub and clamp for numerical stability.
        latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0).clamp(-5.0, 5.0)

        if debug:
            print(f"[DEBUG Step {step_idx}] latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    if debug:
        print(f"[INFO] Finished latents generation, final min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    return latents


# ---------------- UTILS ----------------
def prepare_embeddings_for_unet_safe(pos_embeds, neg_embeds, target_dim=512, device="cuda"):
    """
    Prepare embeddings for the UNet by truncating to target_dim, then
    concatenating negatives + positives for classifier-free guidance.

    LoRA can stay at 1024 dims; the UNet receives target_dim.
    NOTE(review): inputs narrower than target_dim are passed through without
    padding — confirm callers never send narrower embeddings.
    """
    pos_embeds = pos_embeds.to(device)
    neg_embeds = neg_embeds.to(device)

    # Truncate when wider than target_dim.
    if pos_embeds.size(-1) > target_dim:
        pos_embeds = pos_embeds[..., :target_dim]
        neg_embeds = neg_embeds[..., :target_dim]

    # Concat for CF guidance: negatives first.
    encoder_states = torch.cat([neg_embeds, pos_embeds], dim=0)
    return encoder_states


# ---------------- Latents generator wrapper ----------------
def generate_latents_safe_wrapper(unet, **kwargs):
    """
    Ultra-safe wrapper around generate_latents_safe_debug.

    - Catches CUDA OOM and falls back to returning the input latents.
    - Accepts `encoder_states` directly (projection already done upstream).
    """
    input_latents = kwargs.get("input_latents")
    if input_latents is None:
        raise ValueError("⚠️ 'input_latents' must be provided")

    encoder_states = kwargs.get("encoder_states", None)
    motion_module = kwargs.get("motion_module", None)
    guidance_scale = kwargs.get("guidance_scale", 1.0)
    scheduler = kwargs.get("scheduler", None)
    device = kwargs.get("device", "cuda")
    fp16 = kwargs.get("fp16", True)
    steps = kwargs.get("steps", 20)
    debug = kwargs.get("debug", False)

    filtered_kwargs = {
        "input_latents": input_latents,
        "scheduler": scheduler,
        "embeddings": encoder_states,  # pass encoder_states through as embeddings
        "motion_module": motion_module,
        "guidance_scale": guidance_scale,
        "device": device,
        "fp16": fp16,
        "steps": steps,
        "debug": debug
    }

    try:
        return generate_latents_safe_debug(unet, **filtered_kwargs)
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print("⚠️ [SAFE WRAPPER] CUDA out of memory caught, returning input latents")
            return input_latents.clone()
        else:
            raise e
def generate_latents_safe_debug_v3(unet, **kwargs):
    """
    FP16-safe latent denoising loop; auto-adapts embedding width to the
    UNet's cross-attention dimension. Video-stable: all frames are kept.

    BUGFIXES vs. the previous revision:
    - The fallback projection for mismatched single embeddings used to build a
      fresh randomly-initialized nn.Linear INSIDE the loop, i.e. a different
      random projection at every step. It is now created once and reused.
    - The video path flattened frames with a bare reshape on [B, C, F, H, W]
      (interleaving channels and frames) and restored with a differently
      ordered reshape; both directions now use an explicit permute.
    - The frame-count local no longer shadows torch.nn.functional (F).
    """
    known_kwargs = [
        "scheduler", "input_latents", "embeddings", "motion_module",
        "guidance_scale", "device", "fp16", "steps", "debug"
    ]
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in known_kwargs}

    input_latents = filtered_kwargs.get("input_latents")
    if input_latents is None:
        raise ValueError("⚠️ 'input_latents' doit être fourni")

    device = filtered_kwargs.get("device", "cuda")
    fp16 = filtered_kwargs.get("fp16", True)
    debug = filtered_kwargs.get("debug", False)
    motion_module = filtered_kwargs.get("motion_module", None)
    scheduler = filtered_kwargs.get("scheduler", None)
    steps = filtered_kwargs.get("steps", 20)
    embeddings = filtered_kwargs.get("embeddings", None)

    latents = input_latents.clone().to(device=device, dtype=torch.float16 if fp16 else torch.float32)
    is_video = latents.ndim == 5  # [B, C, F, H, W]
    steps_list = getattr(scheduler, "timesteps", range(steps))

    if debug:
        print(f"[DEBUG] Initial latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    # Disable Dynamo/Inductor so python-level control flow stays observable.
    if hasattr(torch, "_dynamo"):
        torch._dynamo.reset()
        torch._dynamo.disable()

    # UNet cross-attention width (SD1.x default = 768). Robust against fakes
    # without a .config attribute.
    unet_cross_attention_dim = getattr(getattr(unet, "config", None), "cross_attention_dim", 768)

    def adapt_embeddings(pos_embeds, neg_embeds, target_dim):
        # Truncate or zero-pad the last dim to target_dim.
        current_dim = pos_embeds.shape[-1]
        if current_dim == target_dim:
            return pos_embeds, neg_embeds
        elif current_dim > target_dim:
            pos_embeds = pos_embeds[..., :target_dim]
            neg_embeds = neg_embeds[..., :target_dim]
        else:
            pad = target_dim - current_dim
            pos_embeds = torch.nn.functional.pad(pos_embeds, (0, pad))
            neg_embeds = torch.nn.functional.pad(neg_embeds, (0, pad))
        return pos_embeds, neg_embeds

    # Created once on first need (see BUGFIX note in the docstring).
    fallback_proj = None

    for step_idx, t in enumerate(steps_list):
        if motion_module is not None:
            latents = motion_module(latents)
            latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0)

        if is_video:
            B, C, num_frames, H, W = latents.shape
            latents_unet = latents.permute(0, 2, 1, 3, 4).reshape(B * num_frames, C, H, W)
        else:
            latents_unet = latents

        latents_unet = latents_unet.half() if fp16 else latents_unet.float()

        # Build encoder states for CF guidance (negatives first).
        encoder_states = None
        if isinstance(embeddings, tuple) and len(embeddings) == 2:
            pos_embeds, neg_embeds = embeddings
            pos_embeds, neg_embeds = adapt_embeddings(pos_embeds.to(device), neg_embeds.to(device), unet_cross_attention_dim)
            encoder_states = torch.cat([neg_embeds, pos_embeds], dim=0)
        elif embeddings is not None:
            if embeddings.shape[-1] != unet_cross_attention_dim:
                if fallback_proj is None:
                    fallback_proj = torch.nn.Linear(embeddings.shape[-1], unet_cross_attention_dim).to(device)
                encoder_states = fallback_proj(embeddings.to(device))
            else:
                encoder_states = embeddings.to(device)

        # Forward; on failure skip the step (best-effort debugging loop).
        try:
            unet_out = unet(latents_unet, t, encoder_hidden_states=encoder_states)
            if isinstance(unet_out, dict):
                latents_out = unet_out["sample"]
            elif isinstance(unet_out, (tuple, list)):
                latents_out = unet_out[0]
            else:
                raise TypeError(f"Unexpected UNet output type: {type(unet_out)}")
        except Exception as e:
            print(f"⚠️ [UNet ERROR] step={step_idx} - {e}")
            continue

        # Restore the [B, C, F, H, W] layout.
        if is_video:
            latents = latents_out.reshape(B, num_frames, C, H, W).permute(0, 2, 1, 3, 4)
        else:
            latents = latents_out

        latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0).clamp(-5.0, 5.0)

        if debug:
            print(f"[DEBUG Step {step_idx}] latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    if debug:
        print(f"[INFO] Finished latents generation, final min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    return latents




def generate_latents_safe_debug_v2(unet, **kwargs):
    """
    FP16-safe latent denoising loop, video-stable (all frames kept).

    BUGFIX: same video reshape repair as v3 — frames are folded into the
    batch via an explicit permute, and restored symmetrically, instead of
    the bare reshapes that scrambled channel/frame data and drifted axes.
    """
    import torch

    known_kwargs = [
        "scheduler", "input_latents", "embeddings", "motion_module",
        "guidance_scale", "device", "fp16", "steps", "debug"
    ]
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in known_kwargs}

    input_latents = filtered_kwargs.get("input_latents")
    if input_latents is None:
        raise ValueError("⚠️ 'input_latents' doit être fourni")

    device = filtered_kwargs.get("device", "cuda")
    fp16 = filtered_kwargs.get("fp16", True)
    debug = filtered_kwargs.get("debug", False)
    motion_module = filtered_kwargs.get("motion_module", None)
    scheduler = filtered_kwargs.get("scheduler", None)
    steps = filtered_kwargs.get("steps", 20)
    embeddings = filtered_kwargs.get("embeddings", None)

    latents = input_latents.clone().to(device=device, dtype=torch.float16 if fp16 else torch.float32)
    is_video = latents.ndim == 5  # [B, C, F, H, W]
    steps_list = getattr(scheduler, "timesteps", range(steps))

    if debug:
        print(f"[DEBUG] Initial latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    if hasattr(torch, "_dynamo"):
        torch._dynamo.reset()
        torch._dynamo.disable()

    for step_idx, t in enumerate(steps_list):

        if motion_module is not None:
            latents = motion_module(latents)
            latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0)

        if is_video:
            B, C, num_frames, H, W = latents.shape
            latents_unet = latents.permute(0, 2, 1, 3, 4).reshape(B * num_frames, C, H, W)
        else:
            latents_unet = latents

        latents_unet = latents_unet.half() if fp16 else latents_unet.float()

        # CF guidance concat: negatives first.
        if isinstance(embeddings, tuple):
            pos_embeds, neg_embeds = embeddings
            encoder_states = torch.cat([neg_embeds.to(device), pos_embeds.to(device)], dim=0)
        else:
            encoder_states = embeddings

        try:
            unet_out = unet(latents_unet, t, encoder_hidden_states=encoder_states)

            if isinstance(unet_out, dict):
                latents_out = unet_out["sample"]
            elif isinstance(unet_out, (tuple, list)):
                latents_out = unet_out[0]  # first value is the sample
            else:
                raise TypeError(f"Unexpected UNet output type: {type(unet_out)}")

        except Exception as e:
            print(f"⚠️ [UNet ERROR] step={step_idx} - {e}")
            continue

        if is_video:
            latents = latents_out.reshape(B, num_frames, C, H, W).permute(0, 2, 1, 3, 4)
        else:
            latents = latents_out

        latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0).clamp(-5.0, 5.0)

        if debug:
            print(f"[DEBUG Step {step_idx}] latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    if debug:
        print(f"[INFO] Finished latents generation, final min/max={latents.min().item():.4f}/{latents.max().item():.4f}")

    return latents
def safe_forward_unet(unet, latents, t, encoder_hidden_states=None, device="cuda", fp16=True, debug=False):
    """
    Guarded UNet forward for small GPUs.

    - Resizes latents to match the UNet's expected spatial size if needed.
    - Always returns the prediction tensor (the 'sample').

    NOTE(review): this treats unet.config.sample_size as a PIXEL size (hence
    // 8); in diffusers, UNet2DConditionModel.sample_size is usually already
    the LATENT size — confirm against the models used here. Also both H and W
    derive from the same scalar, so non-square targets are not supported.
    `device` is accepted but unused (kept for interface compatibility).
    """
    target_H = getattr(unet.config, "sample_size", latents.shape[2] * 8) // 8
    target_W = getattr(unet.config, "sample_size", latents.shape[3] * 8) // 8
    if (latents.shape[2], latents.shape[3]) != (target_H, target_W):
        latents = F.interpolate(latents, size=(target_H, target_W),
                                mode="bilinear", align_corners=False)
        if debug:
            print(f"[DEBUG] Latents resized to {latents.shape[2:]}")

    # fp16 if requested
    if fp16:
        latents = latents.half()

    # Forward and normalize the output shape to a bare tensor.
    out = unet(latents, t, encoder_hidden_states=encoder_hidden_states)
    if isinstance(out, dict):
        sample = out.get("sample", None)
        if sample is None:
            raise ValueError("UNet output dict does not contain 'sample'")
        return sample
    elif isinstance(out, (tuple, list)):
        return out[0]
    else:
        raise TypeError(f"Unexpected UNet output type: {type(out)}")


def run_diffusion_pipeline(unet, vae, scheduler, images, embeddings,
                           timesteps, guidance_scale=4.0, device="cuda",
                           fp16=True, debug=True):
    """
    Full pipeline with classifier-free guidance:
    - encode images to latents,
    - step the UNet safely through `timesteps`,
    - apply the guidance scale correctly.

    BUGFIXES:
    - Encoder states are concatenated [positive, negative], so the FIRST half
      of the prediction is the conditional branch. The previous combination
      had cond/uncond swapped, steering away from the prompt for g > 1.
    - The doubled UNet input is now rebuilt from the UPDATED latents each
      step; it previously stayed frozen at the initial latents.
    """
    # Encode images into latents.
    latents = encode_images_to_latents_safe(images, vae, device=device)
    latents = latents.to(device=device, dtype=torch.float16 if fp16 else torch.float32)

    # Positive/negative embeddings: conditional first, unconditional second.
    pos_embeds, neg_embeds = embeddings
    encoder_hidden_states = torch.cat([pos_embeds.to(device), neg_embeds.to(device)], dim=0)

    for step, timestep in enumerate(timesteps):
        if debug:
            print(f"\n[INFO] Step {step}/{len(timesteps)-1} - timestep={timestep}")

        # Duplicate the CURRENT latents for the pos+neg batch.
        latents_input = torch.cat([latents, latents], dim=0)

        # Guarded UNet forward.
        noise_pred_all = safe_forward_unet(
            unet=unet,
            latents=latents_input,
            t=timestep,
            encoder_hidden_states=encoder_hidden_states,
            device=device
        )

        # Classifier-free guidance: uncond + g * (cond - uncond).
        batch = latents.shape[0]
        noise_cond = noise_pred_all[:batch]
        noise_uncond = noise_pred_all[batch:]
        noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # Scheduler update.
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample

        # Debug min/max + NaN/Inf detection.
        if debug:
            print(f"[DEBUG] Latents min/max: {latents.min().item():.4f}/{latents.max().item():.4f}")
            nan_count = torch.isnan(latents).sum().item()
            inf_count = torch.isinf(latents).sum().item()
            if nan_count or inf_count:
                print(f"⚠️ [DEBUG Step {step}] Found NaN: {nan_count}, Inf: {inf_count}")

    return latents
+ """ + import torch + import torch.nn.functional as F + + # ---------------- Filtrage kwargs ---------------- + known_kwargs = [ + "scheduler", "input_latents", "embeddings", "motion_module", + "guidance_scale", "device", "fp16", "steps", "debug", + "init_image_scale", "creative_noise" + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in known_kwargs} + + input_latents = filtered_kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' doit être fourni") + + device = filtered_kwargs.get("device", "cuda") + fp16 = filtered_kwargs.get("fp16", True) + debug = filtered_kwargs.get("debug", False) + motion_module = filtered_kwargs.get("motion_module", None) + scheduler = filtered_kwargs.get("scheduler", None) + steps = filtered_kwargs.get("steps", 20) + embeddings = filtered_kwargs.get("embeddings", None) + + init_image_scale = filtered_kwargs.get("init_image_scale", 1.0) + creative_noise = filtered_kwargs.get("creative_noise", 0.0) + + latents = input_latents.clone().to(device=device, dtype=torch.float16 if fp16 else torch.float32) + original_latents = latents.clone() + is_video = latents.ndim == 5 # [B,C,F,H,W] + steps_list = getattr(scheduler, "timesteps", range(steps)) + + if debug: + print(f"[DEBUG] Initial latents shape: {latents.shape}") + print(f"[DEBUG] Initial latents min/max: {latents.min().item():.4f}/{latents.max().item():.4f}") + if motion_module: + print(f"[DEBUG] Motion module detected: {motion_module}") + if embeddings: + if isinstance(embeddings, tuple): + print(f"[DEBUG] Embeddings shapes - pos: {embeddings[0].shape}, neg: {embeddings[1].shape}") + else: + print(f"[DEBUG] Embeddings shape: {embeddings.shape}") + + # ---------------- Désactive Dynamo ---------------- + torch._dynamo.reset() + torch._dynamo.disable() + + # ---------------- Boucle de steps ---------------- + for step_idx, t in enumerate(steps_list): + if debug: + print(f"\n[DEBUG] Step {step_idx}/{len(steps_list)} - timestep {t}") + + # 
---------------- Motion Module ---------------- + if motion_module is not None: + latents_before_motion = latents.clone() + latents = motion_module(latents) + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + if debug: + diff = (latents - latents_before_motion).abs().max() + print(f"[DEBUG] Motion module diff max abs: {diff:.4f}") + + # ---------------- Préparer latents pour UNet ---------------- + if is_video: + B, C, F, H, W = latents.shape + latents_unet = latents.reshape(B*F, C, H, W) + else: + latents_unet = latents + + latents_unet = latents_unet.half() if fp16 else latents_unet.float() + + # ---------------- Préparer embeddings ---------------- + if isinstance(embeddings, tuple): + pos_embeds, neg_embeds = embeddings + encoder_states = pos_embeds.to(device) # ⚠ seulement pos_embeds + else: + encoder_states = embeddings + + if debug: + print(f"[DEBUG] Latents input to UNet shape: {latents_unet.shape}") + if embeddings: print(f"[DEBUG] Encoder states shape: {encoder_states.shape}") + + # ---------------- UNet SAFE FORWARD ---------------- + try: + latents_out = safe_forward_unet(unet, latents_unet, t, encoder_hidden_states=encoder_states, device=device) + except Exception as e: + print(f"⚠️ [UNet ERROR] step={step_idx} - {e}") + continue + + # ---------------- Reshape si vidéo ---------------- + if is_video: + latents = latents_out.reshape(B, F, C, H, W) + else: + latents = latents_out + + # ---------------- Clamp & Nan/Inf ---------------- + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0).clamp(-5.0, 5.0) + + # ---------------- Ajout init_image_scale ---------------- + if init_image_scale > 0.0: + # Redimension original_latents vers latents actuels + resized_orig = F.interpolate(original_latents, size=latents.shape[-2:], mode='bilinear', align_corners=False) + latents = init_image_scale * resized_orig + (1.0 - init_image_scale) * latents + + # ---------------- Ajout creative_noise ---------------- + if 
creative_noise > 0.0: + noise = torch.randn_like(latents) * creative_noise + latents = latents + noise + + if debug: + nan_count = torch.isnan(latents).sum().item() + inf_count = torch.isinf(latents).sum().item() + print(f"[DEBUG Step {step_idx}] latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}, NaN={nan_count}, Inf={inf_count}") + + if debug: + print(f"\n[INFO] Finished latents generation, final shape={latents.shape}, min/max={latents.min().item():.4f}/{latents.max().item():.4f}") + + return latents + + + +# ---------------- LATENTS GENERATION SAFE DEBUG (avec init_image_scale & creative_noise) ---------------- + +def generate_latents_safe_debug_v4(unet, **kwargs): + """ + Génération de latents avec UNet, FP16-safe et debug. + Logs détaillés : shape, min/max, cross_attention_dim, NaN/Inf. + Compatible mini GPU. + """ + import torch + import torch.nn.functional as F + + # ---------------- Filtrage kwargs ---------------- + known_kwargs = [ + "scheduler", "input_latents", "embeddings", "motion_module", + "guidance_scale", "device", "fp16", "steps", "debug" + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in known_kwargs} + + input_latents = filtered_kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' doit être fourni") + + device = filtered_kwargs.get("device", "cuda") + fp16 = filtered_kwargs.get("fp16", True) + debug = filtered_kwargs.get("debug", False) + motion_module = filtered_kwargs.get("motion_module", None) + scheduler = filtered_kwargs.get("scheduler", None) + steps = filtered_kwargs.get("steps", 20) + embeddings = filtered_kwargs.get("embeddings", None) + + latents = input_latents.clone().to(device=device, dtype=torch.float16 if fp16 else torch.float32) + is_video = latents.ndim == 5 # [B,C,F,H,W] + steps_list = getattr(scheduler, "timesteps", range(steps)) + + if debug: + print(f"[DEBUG] Initial latents shape: {latents.shape}") + print(f"[DEBUG] Initial latents min/max: 
{latents.min().item():.4f}/{latents.max().item():.4f}") + if motion_module: + print(f"[DEBUG] Motion module detected: {motion_module}") + if embeddings: + if isinstance(embeddings, tuple): + print(f"[DEBUG] Embeddings shapes - pos: {embeddings[0].shape}, neg: {embeddings[1].shape}") + else: + print(f"[DEBUG] Embeddings shape: {embeddings.shape}") + + # ---------------- Désactive Dynamo ---------------- + torch._dynamo.reset() + torch._dynamo.disable() + + # ---------------- Boucle de steps ---------------- + for step_idx, t in enumerate(steps_list): + if debug: + print(f"\n[DEBUG] Step {step_idx}/{len(steps_list)} - timestep {t}") + + # ---------------- Motion Module ---------------- + if motion_module is not None: + latents_before_motion = latents.clone() + latents = motion_module(latents) + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + if debug: + diff = (latents - latents_before_motion).abs().max() + print(f"[DEBUG] Motion module diff max abs: {diff:.4f}") + + # ---------------- Préparer latents pour UNet ---------------- + if is_video: + B, C, F, H, W = latents.shape + latents_unet = latents.reshape(B*F, C, H, W) + else: + latents_unet = latents + + latents_unet = latents_unet.half() if fp16 else latents_unet.float() + + # ---------------- Préparer embeddings ---------------- + if isinstance(embeddings, tuple): + pos_embeds, neg_embeds = embeddings + encoder_states = pos_embeds.to(device) # ⚠ seulement pos_embeds + else: + encoder_states = embeddings + + if debug: + print(f"[DEBUG] Latents input to UNet shape: {latents_unet.shape}") + if embeddings: print(f"[DEBUG] Encoder states shape: {encoder_states.shape}") + + # ---------------- UNet SAFE FORWARD ---------------- + try: + latents_out = safe_forward_unet(unet, latents_unet, t, encoder_hidden_states=encoder_states, device=device) + except Exception as e: + print(f"⚠️ [UNet ERROR] step={step_idx} - {e}") + continue + + # ---------------- Reshape si vidéo ---------------- + if 
is_video: + latents = latents_out.reshape(B, F, C, H, W) + else: + latents = latents_out + + # ---------------- Clamp & Nan/Inf ---------------- + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0).clamp(-5.0, 5.0) + + if debug: + nan_count = torch.isnan(latents).sum().item() + inf_count = torch.isinf(latents).sum().item() + print(f"[DEBUG Step {step_idx}] latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}, NaN={nan_count}, Inf={inf_count}") + + if debug: + print(f"\n[INFO] Finished latents generation, final shape={latents.shape}, min/max={latents.min().item():.4f}/{latents.max().item():.4f}") + + return latents + + +# ---------------- MINI GPU SAFE 320 ---------------- +def generate_latents_mini_gpu_320(unet, **kwargs): + """ + Wrapper ultra-safe pour GPU <4 Go et UNet miniSD 320-640-1280-1280. + - Utilise generate_latents_safe_debug_v4 + - Forçage latents aux bonnes dimensions + - Adaptation embeddings automatique + - Ajout : init_image_scale & creative_noise + """ + import torch + import torch.nn.functional as F + + input_latents = kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' must be fourni") + + device = kwargs.get("device", "cuda") + fp16 = kwargs.get("fp16", True) + motion_module = kwargs.get("motion_module", None) + scheduler = kwargs.get("scheduler", None) + guidance_scale = kwargs.get("guidance_scale", 1.0) + steps = kwargs.get("steps", 20) + debug = kwargs.get("debug", False) + + # Nouveaux params + init_image_scale = kwargs.get("init_image_scale", 0.0) + creative_noise = kwargs.get("creative_noise", 0.0) + + # ---------------- FORCER LATENTS ---------------- + expected_channels = getattr(unet.config, "in_channels", 4) + if input_latents.shape[1] != expected_channels: + pad = expected_channels - input_latents.shape[1] + if pad > 0: + input_latents = F.pad(input_latents, (0,0,0,0,0,0,0,pad)) + else: + input_latents = input_latents[:, :expected_channels] + + # Adapter 
H/W si nécessaire + if input_latents.ndim == 4: # B,C,H,W + H, W = input_latents.shape[2:] + elif input_latents.ndim == 5: # B,C,T,H,W + T = input_latents.shape[2] + if T == 1: + input_latents = input_latents.squeeze(2) + H, W = input_latents.shape[2:] + else: + raise ValueError(f"[LATENT FIX] Shape inattendue: {input_latents.shape}") + + # Mini GPU → forcer petite taille latents compatible UNet + target_size = min(getattr(unet.config, "sample_size", 320), 32) + if (H != target_size) or (W != target_size): + input_latents = F.interpolate(input_latents, size=(target_size, target_size), mode='nearest') + + kwargs["input_latents"] = input_latents + kwargs["init_image_scale"] = init_image_scale + kwargs["creative_noise"] = creative_noise + + # ---------------- Embeddings ---------------- + embeddings = kwargs.get("embeddings", None) + if isinstance(embeddings, tuple) and len(embeddings) == 2: + pos_embeds, neg_embeds = embeddings + target_dim = getattr(unet.config, "cross_attention_dim", pos_embeds.shape[-1]) + pos_embeds, neg_embeds = adapt_embeddings_to_unet(pos_embeds, neg_embeds, target_dim) + embeddings = (pos_embeds.to(device), neg_embeds.to(device)) + elif embeddings is not None: + embeddings = embeddings.to(device) + kwargs["embeddings"] = embeddings + + # ---------------- Génération safe ---------------- + if motion_module is not None and input_latents.ndim == 5: # B,C,T,H,W + B,C,T,H,W = input_latents.shape + for t in range(T): + frame_latents = input_latents[:, :, t, :, :] + kwargs["input_latents"] = frame_latents + frame_latents = generate_latents_safe_debug_v5(unet, **kwargs) + input_latents[:, :, t, :, :] = frame_latents + return input_latents + else: + return generate_latents_safe_debug_v5(unet, **kwargs) + + + + + +def generate_latents_mini_gpu_320_v1(unet, **kwargs): + """ + Wrapper ultra-safe pour GPU <4 Go et UNet miniSD 320-640-1280-1280. 
+ - Pas de duplication batch pour CF guidance + - Motion module appliqué frame par frame + - Forçage latents aux bonnes dimensions (miniGPU safe) + - Adaptation automatique des embeddings au cross_attention_dim du UNet + - Logs détaillés pour débogage + - Gère CUDA OOM + """ + input_latents = kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' must be fourni") + + device = kwargs.get("device", "cuda") + fp16 = kwargs.get("fp16", True) + motion_module = kwargs.get("motion_module", None) + scheduler = kwargs.get("scheduler", None) + guidance_scale = kwargs.get("guidance_scale", 1.0) + steps = kwargs.get("steps", 20) + debug = kwargs.get("debug", True) + + if debug: print(f"🔹 [DEBUG] input_latents.shape: {input_latents.shape}") + + # ---------------- FORCER LATENTS ---------------- + expected_channels = getattr(unet.config, "in_channels", 4) + if input_latents.shape[1] != expected_channels: + pad = expected_channels - input_latents.shape[1] + if pad > 0: + if debug: print(f"🔹 [DEBUG] Padding latents: +{pad} channels") + input_latents = F.pad(input_latents, (0,0,0,0,0,0,0,pad)) + else: + if debug: print(f"🔹 [DEBUG] Troncature latents: {input_latents.shape[1]} -> {expected_channels}") + input_latents = input_latents[:, :expected_channels] + + # Adapter H/W si nécessaire + if input_latents.ndim == 4: # B,C,H,W + H, W = input_latents.shape[2:] + elif input_latents.ndim == 5: # B,C,T,H,W + T = input_latents.shape[2] + if T == 1: + input_latents = input_latents.squeeze(2) + H, W = input_latents.shape[2:] + else: + raise ValueError(f"[LATENT FIX] Shape inattendue: {input_latents.shape}") + + if debug: print(f"🔹 [DEBUG] Latents H/W: {H}x{W}, channels: {input_latents.shape[1]}") + + # Mini GPU → forcer petite taille latents compatible UNet + target_size = getattr(unet.config, "sample_size", 32) # SD2-1.5 → 64 + if (H != target_size) or (W != target_size): + if debug: print(f"🔹 [DEBUG] Interpolation latents {H}x{W} -> 
{target_size}x{target_size}") + input_latents = F.interpolate(input_latents, size=(target_size, target_size), mode='nearest') + + kwargs["input_latents"] = input_latents + + # ---------------- Embeddings ---------------- + embeddings = kwargs.get("embeddings", None) + if isinstance(embeddings, tuple) and len(embeddings) == 2: + pos_embeds, neg_embeds = embeddings + # Lecture du cross_attention_dim attendu par le UNet + target_dim = getattr(unet.config, "cross_attention_dim", pos_embeds.shape[-1]) + if pos_embeds.shape[-1] != target_dim: + if debug: print(f"🔄 Projection embeddings {pos_embeds.shape[-1]} -> {target_dim}") + projection = torch.nn.Linear(pos_embeds.shape[-1], target_dim).to(pos_embeds.device).to(pos_embeds.dtype) + pos_embeds = projection(pos_embeds) + neg_embeds = projection(neg_embeds) + if debug: + print(f"🔹 [DEBUG] pos_embeds.shape: {pos_embeds.shape}, neg_embeds.shape: {neg_embeds.shape}") + embeddings = (pos_embeds.to(device), neg_embeds.to(device)) + elif embeddings is not None: + embeddings = embeddings.to(device) + kwargs["embeddings"] = embeddings + + # ---------------- Filtrage arguments ---------------- + allowed_keys = [ + "input_latents", + "scheduler", + "embeddings", + "motion_module", + "guidance_scale", + "device", + "fp16", + "steps", + "debug", + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys} + + # ---------------- Génération safe ---------------- + try: + if motion_module is not None: + latents = filtered_kwargs["input_latents"] + if latents.ndim == 5: # B,C,T,H,W + B, C, T, H, W = latents.shape + if debug: print(f"🔹 [DEBUG] Motion module présent, latents shape: {latents.shape}") + for t in range(T): + frame_latents = latents[:, :, t, :, :] + filtered_kwargs["input_latents"] = frame_latents + frame_latents = generate_latents_safe_debug_v4(unet, **filtered_kwargs) + latents[:, :, t, :, :] = frame_latents + return latents + else: + if debug: print("🔹 [DEBUG] Motion module absent ou latents 4D") + return 
generate_latents_safe_debug_v4(unet, **filtered_kwargs) + else: + if debug: print("🔹 [DEBUG] Motion module absent, génération standard") + return generate_latents_safe_debug_v4(unet, **filtered_kwargs) + + except RuntimeError as e: + if "CUDA out of memory" in str(e): + print("⚠️ [MINI GPU] CUDA OOM, retour des latents initiaux") + return input_latents.clone() + else: + raise e + + + + + +def generate_latents_mini_gpu_4go(unet, **kwargs): + """ + Wrapper ultra-safe pour GPU <4 Go et UNet miniSD 320-640-1280-1280. + - Pas de duplication batch pour CF guidance + - Motion module appliqué frame par frame + - Forçage latents aux bonnes dimensions (miniGPU safe) + - Gère CUDA OOM + """ + input_latents = kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' must be fourni") + + device = kwargs.get("device", "cuda") + fp16 = kwargs.get("fp16", True) + motion_module = kwargs.get("motion_module", None) + scheduler = kwargs.get("scheduler", None) + guidance_scale = kwargs.get("guidance_scale", 1.0) + steps = kwargs.get("steps", 20) + debug = kwargs.get("debug", True) + + # ---------------- FORCER LATENTS ---------------- + expected_channels = getattr(unet.config, "in_channels", 4) + if input_latents.shape[1] != expected_channels: + pad = expected_channels - input_latents.shape[1] + if pad > 0: + input_latents = F.pad(input_latents, (0,0,0,0,0,0,0,pad)) + else: + input_latents = input_latents[:, :expected_channels] + + # Adapter H/W si nécessaire + if input_latents.ndim == 4: # B,C,H,W + H, W = input_latents.shape[2:] + elif input_latents.ndim == 5: # B,C,T,H,W + T = input_latents.shape[2] + if T == 1: + input_latents = input_latents.squeeze(2) + H, W = input_latents.shape[2:] + else: + raise ValueError(f"[LATENT FIX] Shape inattendue: {input_latents.shape}") + + # Mini GPU → forcer petite taille latents compatible UNet 320-640-1280-1280 + target_size = min(unet.sample_size, 32) + if (H != target_size) or (W != target_size): + 
input_latents = F.interpolate( + input_latents, + size=(target_size, target_size), + mode='nearest' + ) + + kwargs["input_latents"] = input_latents + + # ---------------- Embeddings ---------------- + embeddings = kwargs.get("embeddings", None) + if isinstance(embeddings, tuple) and len(embeddings) == 2: + pos_embeds, neg_embeds = embeddings + embeddings = (pos_embeds.to(device), neg_embeds.to(device)) + elif embeddings is not None: + embeddings = embeddings.to(device) + kwargs["embeddings"] = embeddings + + # ---------------- Filtrage arguments ---------------- + allowed_keys = [ + "input_latents", + "scheduler", + "embeddings", + "motion_module", + "guidance_scale", + "device", + "fp16", + "steps", + "debug", + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys} + + # ---------------- Génération safe ---------------- + try: + if motion_module is not None: + latents = filtered_kwargs["input_latents"] + if latents.ndim == 5: # B,C,T,H,W + B, C, T, H, W = latents.shape + for t in range(T): + frame_latents = latents[:, :, t, :, :] + filtered_kwargs["input_latents"] = frame_latents + frame_latents = generate_latents_safe_debug_v2(unet, **filtered_kwargs) + latents[:, :, t, :, :] = frame_latents + return latents + else: + return generate_latents_safe_debug_v2(unet, **filtered_kwargs) + else: + return generate_latents_safe_debug_v2(unet, **filtered_kwargs) + + except RuntimeError as e: + if "CUDA out of memory" in str(e): + print("⚠️ [MINI GPU] CUDA OOM, retour des latents initiaux") + return input_latents.clone() + else: + raise e + + +def generate_latents_mini_gpu(unet, **kwargs): + """ + Wrapper ultra-safe pour GPU <4 Go. 
+ - Forçage des latents aux dimensions attendues par le UNet + - Adaptation embeddings pour cross_attention_dim + - Gère CUDA OOM + """ + input_latents = kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' must be fourni") + + device = kwargs.get("device", "cuda") + fp16 = kwargs.get("fp16", True) + motion_module = kwargs.get("motion_module", None) # peut être None + scheduler = kwargs.get("scheduler", None) + guidance_scale = kwargs.get("guidance_scale", 1.0) + steps = kwargs.get("steps", 20) + debug = kwargs.get("debug", True) + + # ---------------- FORCER LATENTS ---------------- + expected_channels = getattr(unet.config, "in_channels", 4) + if input_latents.shape[1] != expected_channels: + pad = expected_channels - input_latents.shape[1] + if pad > 0: + input_latents = F.pad(input_latents, (0,0,0,0,0,0,0,pad)) + else: + input_latents = input_latents[:, :expected_channels] + + # Adapter H/W si nécessaire + if input_latents.ndim == 4: # B,C,H,W + H, W = input_latents.shape[2:] + elif input_latents.ndim == 5: # B,C,T,H,W + T = input_latents.shape[2] + if T == 1: + input_latents = input_latents.squeeze(2) + H, W = input_latents.shape[2:] + else: + raise ValueError(f"[LATENT FIX] Shape inattendue: {input_latents.shape}") + + # Mini GPU → forcer petite taille latents + target_size = min(getattr(unet, "sample_size", 32), 32) + if (H != target_size) or (W != target_size): + input_latents = F.interpolate( + input_latents, + size=(target_size, target_size), + mode='nearest' + ) + + kwargs["input_latents"] = input_latents + + # ---------------- Embeddings ---------------- + embeddings = kwargs.get("embeddings", None) + cross_dim = getattr(unet.config, "cross_attention_dim", 768) + if isinstance(embeddings, tuple) and len(embeddings) == 2: + pos_embeds, neg_embeds = embeddings + # Pad / truncate embeddings pour correspondre au UNet + if pos_embeds.shape[-1] != cross_dim: + diff = cross_dim - pos_embeds.shape[-1] + if diff > 0: + 
pos_embeds = F.pad(pos_embeds, (0,diff)) + neg_embeds = F.pad(neg_embeds, (0,diff)) + else: + pos_embeds = pos_embeds[..., :cross_dim] + neg_embeds = neg_embeds[..., :cross_dim] + embeddings = (pos_embeds.to(device), neg_embeds.to(device)) + elif embeddings is not None: + embeddings = embeddings.to(device) + kwargs["embeddings"] = embeddings + + # ---------------- Filtrage arguments ---------------- + allowed_keys = [ + "input_latents", + "scheduler", + "embeddings", + "motion_module", + "guidance_scale", + "device", + "fp16", + "steps", + "debug", + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys} + + # ---------------- Génération safe ---------------- + try: + # Motion module désactivé pour miniSD / VRAM ultra-light + return generate_latents_safe_debug_v2(unet, **filtered_kwargs) + + except RuntimeError as e: + if "CUDA out of memory" in str(e): + print("⚠️ [MINI GPU] CUDA OOM, retour des latents initiaux") + return input_latents.clone() + else: + raise e + + +def generate_latents_safe_miniGPU(unet, **kwargs): + """ + Wrapper ultra-safe mini GPU pour UNet SD1.x / SD2. 
+ - Forcer latents ≥ 64x64 + - Motion module frame par frame + - Projection embeddings si nécessaire + - OOM safe + """ + input_latents = kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' must be provided") + + device = kwargs.get("device", "cuda") + fp16 = kwargs.get("fp16", True) + debug = kwargs.get("debug", True) + embeddings = kwargs.get("embeddings", None) + motion_module = kwargs.get("motion_module", None) + steps = kwargs.get("steps", 20) + scheduler = kwargs.get("scheduler", None) + guidance_scale = kwargs.get("guidance_scale", 7.5) + + # ---------------- Projection embeddings automatique ---------------- + unet_cross_dim = getattr(unet.config, "cross_attention_dim", 768) + + if isinstance(embeddings, tuple) and len(embeddings) == 2: + pos_embeds, neg_embeds = embeddings + if pos_embeds.shape[-1] != unet_cross_dim: + linear_proj = torch.nn.Linear(pos_embeds.shape[-1], unet_cross_dim).to( + device, dtype=torch.float16 if fp16 else torch.float32 + ) + pos_embeds = linear_proj(pos_embeds.to(device)) + neg_embeds = linear_proj(neg_embeds.to(device)) + embeddings = (pos_embeds, neg_embeds) + elif embeddings is not None and embeddings.shape[-1] != unet_cross_dim: + linear_proj = torch.nn.Linear(embeddings.shape[-1], unet_cross_dim).to( + device, dtype=torch.float16 if fp16 else torch.float32 + ) + embeddings = linear_proj(embeddings.to(device)) + + kwargs["embeddings"] = embeddings + + # ---------------- FORCER LATENTS CHANNEL ---------------- + expected_channels = getattr(unet.config, "in_channels", 4) + if input_latents.shape[1] != expected_channels: + if input_latents.shape[1] < expected_channels: + pad = expected_channels - input_latents.shape[1] + input_latents = F.pad(input_latents, (0,0,0,0,0,0,0,pad)) + else: + input_latents = input_latents[:, :expected_channels] + + # ---------------- FORCER LATENTS H/W ---------------- + if input_latents.ndim == 4: # B,C,H,W + H, W = input_latents.shape[2:] + elif 
input_latents.ndim == 5: # B,C,T,H,W + T = input_latents.shape[2] + if T == 1: + input_latents = input_latents.squeeze(2) + H, W = input_latents.shape[2:] + else: + raise ValueError(f"[LATENT FIX] Shape inattendue: {input_latents.shape}") + + # SD1.x mini GPU → on force ≥ 64 + target_size = max(64, min(unet.sample_size, H, W)) + if (H != target_size) or (W != target_size): + input_latents = F.interpolate(input_latents, size=(target_size, target_size), mode='nearest') + + kwargs["input_latents"] = input_latents + + # ---------------- Filtrage arguments ---------------- + allowed_keys = [ + "input_latents", + "scheduler", + "embeddings", + "motion_module", + "guidance_scale", + "device", + "fp16", + "steps", + "debug", + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys} + + # ---------------- Appel safe debug frame par frame ---------------- + try: + if motion_module is not None and input_latents.ndim == 5: # B,C,T,H,W + B, C, T, H, W = input_latents.shape + latents_out = [] + for t in range(T): + frame_latents = input_latents[:, :, t, :, :] + filtered_kwargs["input_latents"] = frame_latents + latents_frame = generate_latents_safe_debug_v2(unet, **filtered_kwargs) + latents_out.append(latents_frame.unsqueeze(2)) + return torch.cat(latents_out, dim=2) + else: + return generate_latents_safe_debug_v2(unet, **filtered_kwargs) + except RuntimeError as e: + if "CUDA out of memory" in str(e): + print("⚠️ [SAFE WRAPPER] CUDA out of memory caught, returning input latents") + return input_latents.clone() + else: + raise e + +def generate_latents_safe_wrapper_final(unet, **kwargs): + """ + Wrapper ultra-safe pour générer les latents avec UNet. 
+ - Gère CUDA OOM + - Prépare embeddings pour classifier-free guidance + - Applique motion module avant duplication + - Force latents à correspondre à la taille UNet (H/W) + """ + input_latents = kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' must be fourni") + + device = kwargs.get("device", "cuda") + fp16 = kwargs.get("fp16", True) + debug = kwargs.get("debug", True) + embeddings = kwargs.get("embeddings", None) + motion_module = kwargs.get("motion_module", None) + steps = kwargs.get("steps", 20) + scheduler = kwargs.get("scheduler", None) + guidance_scale = kwargs.get("guidance_scale", 7.5) + + latents = input_latents.to(device) + + # ---------------- Motion module avant duplication ---------------- + if motion_module is not None: + latents = motion_module(latents) + + # ---------------- Adapter H/W ---------------- + if latents.ndim == 4: # B,C,H,W + H, W = latents.shape[2:] + elif latents.ndim == 5: # B,C,T,H,W + T = latents.shape[2] + if T == 1: + latents = latents.squeeze(2) + H, W = latents.shape[2:] + else: + raise ValueError(f"[LATENT FIX] Shape inattendue: {latents.shape}") + + if (H != unet.sample_size) or (W != unet.sample_size): + latents = torch.nn.functional.interpolate( + latents, + size=(unet.sample_size, unet.sample_size), + mode='nearest' + ) + + # ---------------- Préparer embeddings ---------------- + unet_cross_dim = getattr(unet.config, "cross_attention_dim", 768) + if isinstance(embeddings, tuple) and len(embeddings) == 2: + pos_embeds, neg_embeds = embeddings + # Projection automatique si nécessaire + if pos_embeds.shape[-1] != unet_cross_dim: + proj = torch.nn.Linear(pos_embeds.shape[-1], unet_cross_dim).to(device, dtype=torch.float16 if fp16 else torch.float32) + pos_embeds = proj(pos_embeds.to(device)) + neg_embeds = proj(neg_embeds.to(device)) + embeddings = (pos_embeds, neg_embeds) + elif embeddings is not None and embeddings.shape[-1] != unet_cross_dim: + proj = 
torch.nn.Linear(embeddings.shape[-1], unet_cross_dim).to(device, dtype=torch.float16 if fp16 else torch.float32) + embeddings = proj(embeddings.to(device)) + + # ---------------- Duplication batch pour CFG ---------------- + if guidance_scale > 1.0: + latents = torch.cat([latents, latents], dim=0) + if isinstance(embeddings, tuple): + pos_embeds, neg_embeds = embeddings + embeddings = torch.cat([neg_embeds.to(device), pos_embeds.to(device)], dim=0) + else: + embeddings = embeddings.to(device) + + kwargs["input_latents"] = latents + kwargs["embeddings"] = embeddings + + # ---------------- Filtrage arguments ---------------- + allowed_keys = [ + "input_latents", + "scheduler", + "embeddings", + "motion_module", + "guidance_scale", + "device", + "fp16", + "steps", + "debug", + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys} + + # ---------------- Appel safe debug ---------------- + try: + return generate_latents_safe_debug_v2(unet, **filtered_kwargs) + except RuntimeError as e: + if "CUDA out of memory" in str(e): + print("⚠️ [SAFE WRAPPER] CUDA out of memory caught, returning input latents") + return latents.clone() + else: + raise e + +# ici erreur cross_attention_dim mais la génération reste correct: +def generate_latents_safe_wrapper_v2(unet, **kwargs): + """ + Wrapper ultra-safe pour générer les latents avec UNet. 
+ - Gère CUDA OOM + - Prépare embeddings tuple pour classifier-free guidance + """ + + input_latents = kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' must be provided") + + embeddings = kwargs.get("embeddings", None) + + # Si embeddings est un tuple de (pos, neg), concat pour CF guidance + if isinstance(embeddings, tuple) and len(embeddings) == 2: + pos_embeds, neg_embeds = embeddings + kwargs["embeddings"] = torch.cat([neg_embeds.to(kwargs.get("device", "cuda")), + pos_embeds.to(kwargs.get("device", "cuda"))], dim=0) + + # Arguments autorisés pour la version ultra-light + allowed_keys = [ + "input_latents", + "scheduler", + "embeddings", + "motion_module", + "guidance_scale", + "device", + "fp16", + "steps", + "debug", + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys} + + try: + return generate_latents_safe_debug_v2(unet, **filtered_kwargs) + except RuntimeError as e: + if "CUDA out of memory" in str(e): + print("⚠️ [SAFE WRAPPER] CUDA out of memory caught, returning input latents") + return input_latents.clone() + else: + raise e + + + +def generate_latents_ultralight( + unet, + input_latents, + scheduler, + embeddings=None, + motion_module=None, + guidance_scale=1.0, + device="cuda", + fp16=True, + steps=20, + patch_size=32, + debug=False +): + """ + Génération de latents ultra-light, patch par patch pour GPU <4Go. 
+ - input_latents: torch.Tensor [B,C,H,W] ou [B,C,F,H,W] pour vidéo + - motion_module: optionnel, appliqué à chaque step + - embeddings: CLIP embeddings + """ + latents = input_latents.clone().to(device=device, dtype=torch.float16 if fp16 else torch.float32) + is_video = latents.ndim == 5 + steps_list = getattr(scheduler, "timesteps", range(steps)) + + for step_idx, t in enumerate(steps_list): + if motion_module is not None: + latents = motion_module(latents) + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + + # UNet patch par patch pour économiser la VRAM + if is_video: + B, C, F, H, W = latents.shape + latents_patch = latents.permute(0,2,1,3,4).reshape(B*F, C, H, W) + else: + latents_patch = latents + + H, W = latents_patch.shape[-2:] + output = torch.zeros_like(latents_patch, device=latents_patch.device) + + # Patch UNet + stride = patch_size + for y0 in range(0, H, stride): + for x0 in range(0, W, stride): + y1 = min(y0 + patch_size, H) + x1 = min(x0 + patch_size, W) + + patch = latents_patch[:, :, y0:y1, x0:x1] + patch = patch.half() if fp16 else patch.float() + + try: + with torch.no_grad(): + patch_out = unet(patch, t, encoder_hidden_states=embeddings)["sample"] + except Exception as e: + print(f"⚠ [UNet ERROR patch {y0},{x0}] {e}") + patch_out = patch.clone() + + # Clamp et cleanup + patch_out = torch.nan_to_num(patch_out, nan=0.0, posinf=5.0, neginf=-5.0).clamp(-5,5) + output[:, :, y0:y1, x0:x1] = patch_out + del patch, patch_out + torch.cuda.empty_cache() + + if is_video: + latents = output.reshape(B, F, C, H, W).permute(0,2,1,3,4) + else: + latents = output + + if debug: + print(f"[Step {step_idx}] latents min/max={latents.min():.4f}/{latents.max():.4f}") + + return latents + + + + + +# -------------------------------------------------------------------------------------------------------- + + +def decode_latents_ultrasafe_cpu(latents, vae, block_size=160, overlap=96, debug=True, threaded=True): + """ + Decode latents safely on 
CPU in patches, optionally threaded. + Supports 4D [B, C, H, W] and 5D [B, C, F, H, W] latents. + Logs min/max for debugging. + """ + import threading + + vae.to("cpu") + latents = latents.to("cpu") + is_5d = latents.dim() == 5 + if is_5d: + B, C, F, H, W = latents.shape + else: + B, C, H, W = latents.shape + F = 1 + latents = latents.unsqueeze(2) + + frames = [None] * F + + def decode_frame(f): + frame_latents = latents[:, :, f] + if debug: + print(f"[DEBUG] Decoding frame {f} latents shape: {frame_latents.shape} min/max: {frame_latents.min().item():.4f}/{frame_latents.max().item():.4f}") + + H_idx = list(range(0, H, block_size - overlap)) + W_idx = list(range(0, W, block_size - overlap)) + decoded_frame = torch.zeros((B, 3, H, W), dtype=torch.float32) + + for i, h_start in enumerate(H_idx): + h_end = min(h_start + block_size, H) + for j, w_start in enumerate(W_idx): + w_end = min(w_start + block_size, W) + patch = frame_latents[:, :, h_start:h_end, w_start:w_end] + with torch.no_grad(): + decoded_patch = vae.decode(patch).sample + decoded_frame[:, :, h_start:h_end, w_start:w_end] = decoded_patch + + frames[f] = decoded_frame + + if threaded: + threads = [] + for f in range(F): + t = threading.Thread(target=decode_frame, args=(f,)) + t.start() + threads.append(t) + for t in threads: + t.join() + else: + for f in range(F): + decode_frame(f) + + output = torch.stack(frames, dim=2) if is_5d else frames[0] + if debug: + print(f"[DEBUG] Final output shape: {output.shape} min/max: {output.min().item():.4f}/{output.max().item():.4f}") + return output + + +# ---------------- Encodage images → latents safe ---------------- +def encode_images_to_latents_safe(images, vae, device="cuda", dtype=torch.float16): + # -- assure cohérence dtype/device -- + vae = vae.to(device=device, dtype=dtype) + images_t = images.to(device=device, dtype=dtype) + + # sauvegarde dtype original pour restore + original_dtype = next(vae.parameters()).dtype + + with torch.no_grad(): + latents = 
vae.encode(images_t).latent_dist.sample() + + latents = latents * LATENT_SCALE + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + + # normalisation pour UNet / motion module + max_abs = latents.abs().max() + if max_abs > 0: + latents = latents / max_abs # scale [-1,1] + + # reshape pour pipeline [B,C,T,H,W] + if latents.ndim == 4: + latents = latents.unsqueeze(2) + + print( + "[SAFE ENCODE]", + "min=", latents.min().item(), + "max=", latents.max().item(), + "std=", latents.std().item(), + "dtype=", latents.dtype, + "device=", latents.device, + "shape=", latents.shape + ) + + return latents + +def decode_latents_ultrasafe_blockwise_adapted(latents, vae, block_size=160, overlap=96, gamma=1.0, brightness=1.0, contrast=1.0, saturation=1.0): + """ + Décodage ultrasafe des latents en PIL.Image + Adapté pour accepter 4D [B,C,H,W] ou 5D [B,C,T,H,W]. + Si 5D, retourne une liste de frames PIL. + """ + # Convert 5D → list de frames 4D + if latents.ndim == 5: + B,C,T,H,W = latents.shape + frames_4d = [latents[:, :, t, :, :] for t in range(T)] + elif latents.ndim == 4: + frames_4d = [latents] + else: + raise ValueError(f"Latents doivent être 4D ou 5D, got {latents.shape}") + + pil_frames = [] + for fidx, latent_4d in enumerate(frames_4d): + # On décode par blocs (supposons que decode_latents_ultrasafe_blockwise existe) + frame_pil = decode_latents_ultrasafe_blockwise(latent_4d, vae, block_size=block_size, overlap=overlap, + gamma=gamma, brightness=brightness, contrast=contrast, + saturation=saturation) + pil_frames.append(frame_pil) + + if len(pil_frames) == 1: + return pil_frames[0] + return pil_frames + + +# ---------------- Décodage latents → image ultrasafe (blockwise) ---------------- +def decode_latents_ultrasafe_blockwise(latents, vae, block_size=160, overlap=96, + gamma=1.0, brightness=1.0, contrast=1.0, saturation=1.0): + """ + Décodage sécurisé des latents en image PIL, bloc par bloc pour réduire VRAM. 
+ Input: + latents : [B,C,H,W] ou [B,C,1,H,W] + vae : AutoencoderKL + """ + if latents.ndim == 5: + latents = latents.squeeze(2) # [B,C,H,W] + + B, C, H, W = latents.shape + device = next(vae.parameters()).device + + stride = block_size - overlap + h_steps = max(1, (H - overlap + stride - 1) // stride) + w_steps = max(1, (W - overlap + stride - 1) // stride) + + output = torch.zeros(B, C, H, W, device="cpu") + weight_map = torch.zeros(B, 1, H, W, device="cpu") + + for i in range(h_steps): + for j in range(w_steps): + y0 = i * stride + x0 = j * stride + y1 = min(y0 + block_size, H) + x1 = min(x0 + block_size, W) + + patch = latents[:, :, y0:y1, x0:x1].to(device) + patch = torch.nan_to_num(patch, nan=0.0, posinf=5.0, neginf=-5.0) + + with torch.no_grad(): + patch_decoded = vae.decode(patch.to(vae.dtype)).sample + patch_decoded = ((patch_decoded + 1)/2).clamp(0,1).cpu() + + # Weighted blend pour recouvrement + h_patch, w_patch = patch_decoded.shape[2], patch_decoded.shape[3] + mask = torch.ones(1, 1, h_patch, w_patch) + output[:, :, y0:y0+h_patch, x0:x0+w_patch] += patch_decoded * mask + weight_map[:, :, y0:y0+h_patch, x0:x0+w_patch] += mask + + torch.cuda.empty_cache() + + output = output / weight_map.clamp(min=1e-5) + frame_pil = ToPILImage()(output[0].clamp(0,1)) + + # Ajustements image + if gamma != 1.0: + frame_pil = ImageEnhance.Brightness(frame_pil).enhance(gamma) + if brightness != 1.0: + frame_pil = ImageEnhance.Brightness(frame_pil).enhance(brightness) + if contrast != 1.0: + frame_pil = ImageEnhance.Contrast(frame_pil).enhance(contrast) + if saturation != 1.0: + frame_pil = ImageEnhance.Color(frame_pil).enhance(saturation) + + return frame_pil + +# ---------------- Decode safe 4Go VRAM ---------------- +def decode_latents_safe_vram(latents, vae, gamma=0.7, brightness=1.2, contrast=1.1, saturation=1.15): + from torchvision.transforms import ToPILImage + import torch + from PIL import ImageEnhance + + latents = torch.nan_to_num(latents, nan=0.0, posinf=4.0, 
neginf=-4.0) + + # Si batch T multiple, on prend le 0 pour chaque frame + if latents.ndim == 5: + latents = latents[:, :, 0, :, :] + + latents = latents / LATENT_SCALE + latents = latents.to(dtype=torch.float32, device="cpu") # decode safe 4Go + + with torch.no_grad(): + image_tensor = vae.decode(latents).sample + + image_tensor = ((image_tensor + 1) / 2).clamp(0, 1) + image_tensor = image_tensor.pow(1.0 / gamma) + + images = [] + to_pil = ToPILImage() + for i in range(image_tensor.shape[0]): + img = image_tensor[i] + pil_img = to_pil(img.cpu()) + pil_img = ImageEnhance.Brightness(pil_img).enhance(brightness) + pil_img = ImageEnhance.Contrast(pil_img).enhance(contrast) + pil_img = ImageEnhance.Color(pil_img).enhance(saturation) + images.append(pil_img) + return images + + +def decode_latents_to_image_auto(latents, vae): + """ + Décodage automatique du latent, quel que soit son format (3D, 4D, 5D), + retourne un tenseur [B, C, H, W]. + """ + # S'assurer que latents est 4D : [B, C, H, W] + if latents.ndim == 5: # [B, C, T, H, W] + latents = latents.squeeze(2) + elif latents.ndim == 3: # [C, H, W] -> [1, C, H, W] + latents = latents.unsqueeze(0) + elif latents.ndim != 4: + raise ValueError(f"Latents avec dimension inattendue {latents.shape}") + + # Décodage avec VAE + with torch.no_grad(): + images = vae.decode(latents / LATENT_SCALE).sample + + # Normalisation 0-1 + images = (images + 1.0) / 2.0 + images = images.clamp(0, 1) + return images + + +# ------------------------- +# scripts/utils/n3r_utils.py +# ------------------------- + +def patchify_latents(latents, tile_size=128, overlap=32): + """ + Découpe les latents en patches pour traitement patch-based. + Retourne une liste plate de patches + leurs positions. 
+ """ + _, C, H, W = latents.shape + patches = [] + coords = [] + + stride = tile_size - overlap + for y in range(0, H, stride): + for x in range(0, W, stride): + y1 = min(y + tile_size, H) + x1 = min(x + tile_size, W) + y0 = y1 - tile_size if y1 - y < tile_size else y + x0 = x1 - tile_size if x1 - x < tile_size else x + + patch = latents[:, :, y0:y1, x0:x1] + patches.append(patch) # ✅ Patch est un Tensor + coords.append((y0, y1, x0, x1)) # ✅ Coordonnees patch + + return patches, coords + + + +# ------------------------- +# Dépatchification sécurisée +# ------------------------- +def unpatchify_latents(patches, coords, full_shape, device=None, dtype=None): + """ + Recompose les latents à partir de la liste de patches. + """ + B, C, H, W = full_shape + latents = torch.zeros(full_shape, device=device, dtype=dtype) + + for patch, (y0, y1, x0, x1) in zip(patches, coords): + latents[:, :, y0:y1, x0:x1] = patch + + return latents +# ------------------------- +# Génération d’un frame patch-based v3 +# ------------------------- +# ------------------------- +# generate_frame_patched_v3 +# ------------------------- +def generate_frame_patched_v3( + input_image, + vae, + unet, + scheduler, + pos_emb, + neg_emb=None, + tile_size=128, + overlap=32, + steps=12, + guidance_scale=4.5, + init_image_scale=0.75, + creative_noise=0.0, + device="cuda", + dtype=torch.float16 +): + """ + Génère un frame patch-based safe pour CPU/GPU. 
+ """ + # ------------------------- + # Encode l'image en latents (toujours sur CPU pour VAE) + # ------------------------- + with torch.no_grad(): + latents = vae.encode(input_image.to(torch.float32)).latent_dist.sample() * LATENT_SCALE + latents = latents.to(device=device, dtype=dtype) + + # ------------------------- + # Patchify latents + # ------------------------- + patches, patch_coords = patchify_latents(latents, tile_size=tile_size, overlap=overlap) + + # ------------------------- + # Scheduler timesteps + # ------------------------- + for t in scheduler.timesteps[:steps]: + t_tensor = torch.tensor([t], device=device, dtype=dtype) + new_patches = [] + + for patch in patches: + patch = patch.to(device=device, dtype=dtype) + + # Flatten 5D -> 4D si nécessaire (batch x channels x H x W) + if patch.dim() == 5: + B,C,D,H,W = patch.shape + patch = patch.view(B*C, D, H, W) + + # UNet forward + noise_pred = unet( + patch, + timestep=t_tensor, + encoder_hidden_states=pos_emb + ).sample + + # Guidance si neg_emb fourni + if neg_emb is not None: + noise_pred_neg = unet( + patch, + timestep=t_tensor, + encoder_hidden_states=neg_emb + ).sample + noise_pred = noise_pred_neg + guidance_scale * (noise_pred - noise_pred_neg) + + # Creative noise + if creative_noise > 0.0: + noise_pred += creative_noise * torch.randn_like(noise_pred) + + new_patches.append(noise_pred) + + # Libération GPU intermédiaire + del patch, noise_pred + torch.cuda.empty_cache() + + patches = new_patches + + # ------------------------- + # Recomposition latents + # ------------------------- + latents = unpatchify_latents(patches, patch_coords, latents.shape) + + # ------------------------- + # Décodage VAE + # ------------------------- + with torch.no_grad(): + frame_tensor = vae.decode(latents.to(torch.float32)).sample + frame_tensor = frame_tensor.to(device=device, dtype=dtype) + + return frame_tensor +# ------------------------- +# Génération d’un frame patch-based v1 +# ------------------------- + 
+def generate_frame_patched( + input_image, vae, unet, scheduler, + pos_emb, neg_emb=None, + tile_size=128, overlap=32, + steps=12, + guidance_scale=4.5, + init_image_scale=0.75, + creative_noise=0.0, + device="cuda", + dtype=torch.float16 +): + """ + Génère un frame en utilisant la logique patch-based safe pour CPU/GPU. + """ + # ------------------------- + # Encode l'image en latents (toujours sur CPU pour VAE) + # ------------------------- + with torch.no_grad(): + latents = vae.encode(input_image.to(torch.float32)).latent_dist.sample() * LATENT_SCALE + latents = latents.to(device=device, dtype=dtype) + + # ------------------------- + # Patchify + # ------------------------- + patches, patch_coords = patchify_latents(latents, tile_size=tile_size, overlap=overlap) + + # ------------------------- + # Parcours des timesteps + # ------------------------- + for t in scheduler.timesteps[:steps]: + t_tensor = torch.tensor([t], device=device, dtype=dtype) + + new_patches = [] + for patch in patches: + patch = patch.to(device=device, dtype=dtype) + + # UNet call : timestep obligatoire + embeddings + noise_pred = unet( + patch, + timestep=t_tensor, + encoder_hidden_states=pos_emb + ).sample + + # Guidance si négatif fourni + if neg_emb is not None: + noise_pred_neg = unet( + patch, + timestep=t_tensor, + encoder_hidden_states=neg_emb + ).sample + noise_pred = noise_pred_neg + guidance_scale * (noise_pred - noise_pred_neg) + + new_patches.append(noise_pred) + + patches = new_patches + + # Optionnel : ajouter un peu de noise créatif + if creative_noise > 0.0: + for i in range(len(patches)): + patches[i] += creative_noise * torch.randn_like(patches[i]) + + # ------------------------- + # Unpatchify + # ------------------------- + latents = unpatchify_latents(patches, patch_coords, latents.shape) + + # ------------------------- + # Décode en image + # ------------------------- + with torch.no_grad(): + frame_tensor = vae.decode(latents.to(torch.float32)).sample + frame_tensor 
= frame_tensor.to(dtype=dtype, device=device) + + return frame_tensor + + +# ------------------------- +# Génération split_image_into_patches +# ------------------------- + +def split_image_into_patches(img, patch_size=128, overlap=16): + """Découpe l'image en patches avec overlap.""" + _, C, H, W = img.shape + stride = patch_size - overlap + patches = [] + positions = [] + + for y in range(0, H, stride): + for x in range(0, W, stride): + y0, x0 = y, x + y1, x1 = min(y+patch_size, H), min(x+patch_size, W) + patch = img[:, :, y0:y1, x0:x1] + patches.append(F.pad(patch, (0, patch_size-(x1-x0), 0, patch_size-(y1-y0)))) + positions.append((y0, y1, x0, x1)) + return torch.stack(patches), positions + +def reassemble_patches(patches, positions, H, W, overlap=16): + """Reconstitue l'image à partir des patches avec blending.""" + device = patches.device + out = torch.zeros(1, 3, H, W, device=device) + count = torch.zeros(1, 1, H, W, device=device) + + for patch, (y0, y1, x0, x1) in zip(patches, positions): + h, w = y1-y0, x1-x0 + out[:, :, y0:y1, x0:x1] += patch[:, :, :h, :w] + count[:, :, y0:y1, x0:x1] += 1 + return out / count + + + + +def generate_frame_with_tiling( + input_image, vae, unet, scheduler, embeddings, motion_module=None, + tile_size=128, overlap=16, fp16=True, + guidance_scale=4.5, init_image_scale=0.75, + creative_noise=0.0, steps=12 +): + """ + Génère une image à partir d'une input_image patchée, recomposée automatiquement. + Optimisé pour MiniSD/TinySD afin de réduire la VRAM et éviter OOM. 
+ """ + + device = input_image.device + dtype = torch.float16 if fp16 else torch.float32 + _, C, H, W = input_image.shape + + # Calculer nombre de patches + stride = tile_size - overlap + y_positions = list(range(0, H, stride)) + x_positions = list(range(0, W, stride)) + + # Préparer canvas final + frame_latents = torch.zeros((1, C, H, W), device=device, dtype=dtype) + weight_map = torch.zeros((1, 1, H, W), device=device, dtype=dtype) + + for y in y_positions: + for x in x_positions: + # Extraire patch + y0, y1 = y, min(y + tile_size, H) + x0, x1 = x, min(x + tile_size, W) + patch = input_image[:, :, y0:y1, x0:x1].to(dtype) + + # --- Encoder en latents (VAE FP32-safe) --- + with torch.no_grad(): + patch_latents = vae.encode(patch.float()).latent_dist.sample() * 0.18215 + patch_latents = patch_latents.to(dtype) + + # --- Motion module --- + if motion_module is not None: + patch_latents = motion_module(patch_latents) + + # --- Scheduler / UNet --- + # Utiliser FP16 pour UNet si demandé + patch_latents = patch_latents.half() if fp16 else patch_latents.float() + for pos_embeds, neg_embeds in embeddings: + for t in scheduler.timesteps: + with torch.no_grad(): + noise_pred = unet( + patch_latents, timestep=t, encoder_hidden_states=pos_embeds + ).sample + # Guidance + patch_latents = patch_latents + guidance_scale * (noise_pred - patch_latents) + + # --- Décoder patch --- + with torch.no_grad(): + patch_img = vae.decode(patch_latents.float()).sample # toujours FP32 pour VAE + patch_img = patch_img.to(dtype) + + # --- Ajouter au canvas final --- + frame_latents[:, :, y0:y1, x0:x1] += patch_img + weight_map[:, :, y0:y1, x0:x1] += 1.0 + + # Libérer mémoire GPU + torch.cuda.empty_cache() + + # Normaliser par superposition + frame_latents /= weight_map + return frame_latents + + +# ------------------------- +# encode_image_latents reste FP32 pour stabilité +# ------------------------- +def encode_image_latents_fp32(image_tensor, vae, scale=LATENT_SCALE): + device = 
next(vae.parameters()).device + img = image_tensor.to(device=device, dtype=next(vae.parameters()).dtype) + with torch.no_grad(): + latents = vae.encode(img).latent_dist.sample() * scale + return latents.unsqueeze(2) # [B,C,1,H,W] + + +# ------------------------- +# Génération patch par patch +# ------------------------- +from scripts.modules.motion_module_tiny import MotionModuleTiny + +def generate_frame_with_tiling_v1( + input_image, + vae, + unet, + motion_module, + patch_size: int = 128, + overlap: int = 16, + fp16: bool = True, + **kwargs +): + """ + Découpe l'image en patches, encode, génère latents, puis recompose. + FP16-safe pour Mini/Tiny SD. + """ + # Convertir en FP16 si demandé + dtype = torch.float16 if fp16 else torch.float32 + + _, h, w = input_image.shape + stride = patch_size - overlap + frame_tensor = torch.zeros_like(input_image, dtype=dtype) + + # Eviter double passage de fp16 + kwargs = dict(kwargs) + kwargs.pop("fp16", None) + + for y in range(0, h, stride): + for x in range(0, w, stride): + y1, x1 = y, x + y2, x2 = min(y + patch_size, h), min(x + patch_size, w) + patch = input_image[:, y1:y2, x1:x2] + + # Encoder le patch + patch_latents = encode_image_latents_fp32(patch, vae, fp16=fp16) + + # Génération latents + patch_latents = generate_latents_ai_5D_optimized( + patch_latents, + unet, + motion_module, + fp16=fp16, + **kwargs + ) + + # Décodage patch + patch_frame = vae.decode(patch_latents).sample.to(dtype) + + # Recomposer dans l'image finale + frame_tensor[:, y1:y2, x1:x2] = patch_frame + + return frame_tensor + + +# ------------------------- +# Génération et décodage sécurisée pour n3rHYBRID24 +# ------------------------- +def generate_and_decode(latent_frame, unet, scheduler, pos_embeds, neg_embeds, + motion_module, vae, device="cuda", dtype=torch.float32, + guidance_scale=4.5, init_image_scale=0.85, creative_noise=0.0, + seed=42, steps=35, tile_size=128, overlap=32, vae_offload=False): + """ + Génère les latents pour un frame et les 
décode en image finale, + avec gestion automatique des devices, FP16, offload et tiling. + """ + import torch, time + + torch.manual_seed(seed) + + # ------------------------- + # Déplacer latents et embeddings sur le bon device et dtype + # ------------------------- + latent_frame = latent_frame.to(device=device, dtype=dtype) + pos_embeds = pos_embeds.to(device=device, dtype=dtype) + neg_embeds = neg_embeds.to(device=device, dtype=dtype) + + # ------------------------- + # Génération avec UNet + Scheduler + # ------------------------- + gen_start = time.time() + batch_latents = generate_latents_ai_5D_optimized( + latent_frame=latent_frame, + scheduler=scheduler, + pos_embeds=pos_embeds, + neg_embeds=neg_embeds, + unet=unet, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=guidance_scale, + init_image_scale=init_image_scale, + creative_noise=creative_noise, + seed=seed, + steps=steps + ) + gen_time = time.time() - gen_start + + +# ------------------------- +# Encode tile safe FP32 +# ------------------------- +def encode_tile_safe_fp32(vae, tile_np, device="cuda", vae_offload=False): + """ + Encode une tile numpy [C,H,W] en latent VAE [1,4,H/8,W/8] + VRAM-safe, compatible FP32 VAE complet et offload + """ + tile_tensor = torch.from_numpy(tile_np).unsqueeze(0).to(device=device, dtype=torch.float32) # [1,3,H,W] + with torch.no_grad(): + if vae_offload: + vae.to(device) # mettre VAE sur le même device que le tile + latent = vae.encode(tile_tensor).latent_dist.sample() * LATENT_SCALE + if vae_offload: + vae.cpu() # remettre VAE sur CPU pour économiser VRAM + if device.startswith("cuda"): + torch.cuda.synchronize() + return latent + +# ------------------------- +# Merge tiles FP32 +# ------------------------- +def merge_tiles_fp32(tile_list, positions, H, W, latent_scale=1.0): + """ + Fusionne les tiles latents [1,C,th,tw] en image complète [1,C,H,W]. + Supporte tiles de tailles différentes et bordures. 
+ """ + device = tile_list[0].device + C = tile_list[0].shape[1] + + out = torch.zeros(1, C, H, W, dtype=tile_list[0].dtype, device=device) + count = torch.zeros(1, C, H, W, dtype=tile_list[0].dtype, device=device) + + for tile, (y1, y2, x1, x2) in zip(tile_list, positions): + _, c, th, tw = tile.shape + h_len = y2 - y1 + w_len = x2 - x1 + th = min(th, h_len) + tw = min(tw, w_len) + out[:, :, y1:y1+th, x1:x1+tw] += tile[:, :, :th, :tw] + count[:, :, y1:y1+th, x1:x1+tw] += 1.0 + + count[count==0] = 1.0 + out = out / count + return out + +def encode_tile_safe_latent(vae, tile, device, LATENT_SCALE=0.18215): + """ + Encode une tuile en latent FP32 et pad si nécessaire. + tile: np.array (H,W,3) float32 0-1 + return: torch tensor (1,4,H_latent_max,W_latent_max) + """ + tile_tensor = torch.tensor(tile).permute(2,0,1).unsqueeze(0).to(device) + latent = vae.encode(tile_tensor).latent_dist.sample() * LATENT_SCALE + # Vérifier H,W du latent + H_lat, W_lat = latent.shape[2], latent.shape[3] + H_max = (tile.shape[0] + 7)//8 # VAE scale + W_max = (tile.shape[1] + 7)//8 + if H_lat != H_max or W_lat != W_max: + padH = H_max - H_lat + padW = W_max - W_lat + latent = torch.nn.functional.pad(latent, (0,padW,0,padH)) + return latent + +# --- Découper une image en tiles avec overlap --- +def tile_image_128(image, tile_size=128, overlap=16): + """ + Découpe une image (H,W,C ou C,H,W) en tiles avec overlap. + Retourne une liste de tiles (numpy arrays) et leurs positions (x1,y1,x2,y2). 
+ """ + # Assure shape [C,H,W] + if image.ndim == 3 and image.shape[2] in [1,3]: + # H,W,C -> C,H,W + image = image.transpose(2,0,1) + elif image.ndim != 3: + raise ValueError(f"Image doit être 3D, shape={image.shape}") + + C,H,W = image.shape + stride = tile_size - overlap + tiles = [] + positions = [] + + for y in range(0, H, stride): + for x in range(0, W, stride): + y1, y2 = y, min(y + tile_size, H) + x1, x2 = x, min(x + tile_size, W) + tile = image[:, y1:y2, x1:x2] + tiles.append(tile.astype(np.float32)) # reste numpy + positions.append((x1, y1, x2, y2)) + return tiles, positions + + +# --- Normalisation d'une tile --- +def normalize_tile_128(img_array): + """ + img_array: np.ndarray, shape [H,W,C] ou [C,H,W], valeurs 0-255 + Retour: torch.Tensor [1,3,H,W] float32, valeurs 0-1 + """ + if img_array.ndim == 3 and img_array.shape[2] == 3: # HWC + img_array = img_array.transpose(2,0,1) + img_tensor = torch.from_numpy(img_array).unsqueeze(0).float() / 255.0 + return img_tensor + + + + +# --- Merge tiles pour reconstruire l'image --- +# ------------------------- +# Merge tiles +# ------------------------- +def merge_tiles(tile_list, positions, H, W): + out = torch.zeros(1, 3, H, W, dtype=torch.float32) + count = torch.zeros(1, 3, H, W, dtype=torch.float32) + + for tile, (y1, y2, x1, x2) in zip(tile_list, positions): + _, c, th, tw = tile.shape + out[:,:,y1:y2,x1:x2] += tile[:,:, :th, :tw] + count[:,:,y1:y2,x1:x2] += 1.0 + + count[count==0] = 1.0 + out = out / count + return out + + +def save_frame(img_array, filename): + img_array = np.clip(img_array, 0.0, 1.0) + img_uint8 = (img_array * 255).astype(np.uint8) + os.makedirs(os.path.dirname(filename), exist_ok=True) + Image.fromarray(img_uint8).save(filename) + +def encode_image_latents(image_tensor, vae, scale=LATENT_SCALE): + """Encode RGB -> latents 4 canaux""" + device = next(vae.parameters()).device + img = image_tensor.to(device=device, dtype=next(vae.parameters()).dtype) + with torch.no_grad(): + latents = 
vae.encode(img).latent_dist.sample() * scale + return latents # [B,4,H/8,W/8] + + +def generate_latents_ai_5D_stable( + latent_frame, # [B,4,H,W] ou [1,4,H,W] + pos_embeds, # [1,77,768] + neg_embeds, # [1,77,768] + unet, + scheduler, + motion_module=None, + device="cuda", + dtype=torch.float16, + guidance_scale=7.5, + init_image_scale=0.7, + creative_noise=0.03, + seed=42, + steps=40 +): + """ + Génération de latents ultra-stable avec : + - réinjection du latent initial pour stabilité + - creative_noise modéré + - support batch [B,4,H,W] ou frame unique [1,4,H,W] + Sortie : latents [B,4,H,W] + """ + + torch.manual_seed(seed) + B = latent_frame.shape[0] + + latents = latent_frame.to(device=device, dtype=dtype) + init_latents = latents.clone() + + scheduler.set_timesteps(steps, device=device) + + for t in scheduler.timesteps: + # 🔹 Motion module + if motion_module is not None: + latents = motion_module(latents) + + # 🔹 Creative noise + if creative_noise > 0: + latents = latents + torch.randn_like(latents) * creative_noise + + # 🔹 Classifier-Free Guidance + latent_model_input = torch.cat([latents, latents], dim=0) + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + with torch.no_grad(): + noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeds).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # 🔹 Scheduler step + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # 🔹 Réinjection du latent initial (stabilité couleur / contenu) + latents = latents + init_image_scale * (init_latents - latents) + + # 🔹 Sécurité NaN / inf + if torch.isnan(latents).any() or torch.isinf(latents).any(): + latents = latents + torch.randn_like(latents) * 1e-3 + + # 🔹 Debug + mean_val = latents.abs().mean().item() + if mean_val < 1e-5: + print(f"⚠ Latent trop petit à timestep {t}, mean={mean_val:.6f}") + + return latents + + + +def 
generate_latents_ai_5D_optimized( + latent_frame, # [1,4,H,W] + pos_embeds, # [1,77,768] + neg_embeds, # [1,77,768] + unet, + scheduler, + motion_module=None, + device="cuda", + dtype=torch.float16, + guidance_scale=7.5, # aligné robuste + init_image_scale=2.0, # aligné robuste + creative_noise=0.0, + seed=42, + steps=40 +): + """ + Version équivalente à generate_latents_robuste mais pour une seule frame. + Sortie: [1,4,H,W] + """ + + torch.manual_seed(seed) + + # ---- Setup ---- + latents = latent_frame.to(device=device, dtype=dtype) + init_latents = latents.clone() + + scheduler.set_timesteps(steps, device=device) + + for t in scheduler.timesteps: + + # 🔹 Motion module + if motion_module is not None: + latents = motion_module(latents) + + # 🔹 Creative noise (même endroit que robuste) + if creative_noise > 0: + latents = latents + torch.randn_like(latents) * creative_noise + + # 🔹 Classifier-Free Guidance + latent_model_input = torch.cat([latents, latents], dim=0) + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + with torch.no_grad(): + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=embeds + ).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # 🔹 Scheduler step (IMPORTANT: batch normal, pas latents[:1]) + latents = scheduler.step( + noise_pred, + t, + latents + ).prev_sample + + # 🔹 Réinjection identique à robuste + latents = latents + init_image_scale * (init_latents - latents) + + # 🔹 Sécurité NaN / inf + if torch.isnan(latents).any() or torch.isinf(latents).any(): + print(f"⚠ NaN/inf détecté à timestep {t}, correction légère") + latents = latents.clone() + latents = latents + torch.randn_like(latents) * 1e-3 + + # 🔹 Debug stabilité + mean_val = latents.abs().mean().item() + if math.isnan(mean_val) or mean_val < 1e-5: + print(f"⚠ Latent trop petit à timestep {t}, mean={mean_val:.6f}") + + return latents + + + +# 
------------------------- +# 🔹 Génération des latents +# ------------------------- + +def generate_latents_4Go( + latent_frame, # [1,4,H,W] + pos_embeds, # [1,77,768] + neg_embeds, # [1,77,768] + unet, + scheduler, + motion_module=None, + device="cuda", + dtype=torch.float16, + guidance_scale=4.0, + init_image_scale=0.9, + steps=20, + seed=1234 +): + """ + Génération de latents optimisée pour AnimateDiff avec Classifier-Free Guidance. + Corrige le mismatch batch entre UNet et embeddings. + """ + torch.manual_seed(seed) + + # 🔹 Mettre le latent sur le device + latents = latent_frame.to(device=device, dtype=dtype) + + batch_size = latents.shape[0] # normalement 1 + + # 🔹 Répliquer embeddings selon batch size + neg_embeds = neg_embeds.repeat(batch_size, 1, 1) + pos_embeds = pos_embeds.repeat(batch_size, 1, 1) + + # 🔹 Embeddings C.F.G + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + # batch = 2 * batch_size + + # 🔹 Scheduler + scheduler.set_timesteps(steps, device=device) + + for t in scheduler.timesteps: + + # 🔹 Input latent doublé pour C.F.G + latent_input = torch.cat([latents, latents], dim=0) + + # 🔹 Motion module si utilisé + if motion_module is not None: + latent_input = motion_module(latent_input) + + # 🔹 Blend avec latent original + latent_input = latent_input * init_image_scale + torch.cat([latents, latents], dim=0) * (1 - init_image_scale) + + # 🔹 UNet forward + noise_pred = unet(latent_input, t, encoder_hidden_states=embeds).sample + + # 🔹 Classifier-Free Guidance + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # 🔹 Scheduler step + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # 🔹 Clamp sécurité fp16 + latents = latents.clamp(-1.5, 1.5) + + return latents + +# ------------------------- +# 🔹 Génération des texts +# ------------------------- + +def encode_text_embeddings( + tokenizer: CLIPTokenizer, + text_encoder: 
def load_image_latent(image_path, vae, device="cuda", dtype=torch.float16, target_size=128):
    """
    Load an image from disk and encode it into a VAE latent tensor for the UNet.

    Args:
        image_path (str): path to the input image.
        vae (AutoencoderKL): the model's VAE.
        device (str): "cuda" or "cpu".
        dtype (torch.dtype): torch.float16 or torch.float32.
        target_size (int): square size the image is resized to before encoding.

    Returns:
        torch.Tensor: latent tensor [1, 4, target_size/8, target_size/8].

    NOTE(review): unlike ``encode_image_latents`` below, this function does NOT
    multiply the latent by the SD scaling factor (0.18215); callers that later
    divide by that factor when decoding will get mismatched magnitudes — confirm
    against the call sites.
    """
    # 1️⃣ Load and resize to a square of target_size.
    image = Image.open(image_path).convert("RGB")
    transform = T.Compose([
        T.Resize((target_size, target_size)),
        T.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0).to(device=device, dtype=dtype)  # [1,3,H,W]

    # 2️⃣ Normalize from [0,1] to [-1,1] as the VAE expects.
    img_tensor = img_tensor * 2.0 - 1.0  # [-1,1]

    # 3️⃣ Encode through the VAE (no_grad to save memory during inference).
    with torch.no_grad():
        latent = vae.encode(img_tensor).latent_dist.sample()  # [1,4,H/8,W/8]

    return latent
+ + Returns: + torch.Tensor: latent tensor [1, 4, H/8, W/8] + """ + # 1️⃣ Charger et redimensionner + image = Image.open(image_path).convert("RGB") + transform = T.Compose([ + T.Resize((target_size, target_size)), + T.ToTensor(), + ]) + img_tensor = transform(image).unsqueeze(0).to(device=device, dtype=dtype) # [1,3,H,W] + + # 2️⃣ Normalisation pour VAE + img_tensor = img_tensor * 2.0 - 1.0 # [-1,1] + + # 3️⃣ Encoder via VAE (no_grad pour économiser mémoire) + with torch.no_grad(): + latent = vae.encode(img_tensor).latent_dist.sample() # [1,4,H/8,W/8] + + return latent + + +# ------------------------------ +# 🔹 Fonction latents optimisée +# ------------------------------ + +def generate_latents_ai_5D_light( + latent_frame, + pos_embeds, + neg_embeds, + unet, + scheduler, + motion_module=None, + device="cuda", + dtype=torch.float16, + guidance_scale=4.0, + init_image_scale=0.9, + creative_noise=0.0, + seed=1234, + steps=25 +): + torch.manual_seed(seed) + + latent_frame = latent_frame.to(device=device, dtype=dtype) + latents = latent_frame.repeat(2,1,1,1) # batch=2 pour CFG une seule fois + + if creative_noise > 0.0: + latents += torch.randn_like(latents) * creative_noise + + # embeddings + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + scheduler.set_timesteps(steps, device=device) + + for t in scheduler.timesteps: + latent_input = latents + + if motion_module: + latent_input = motion_module(latent_input) + + # blend avec image initiale + latent_input = latent_input * init_image_scale + latent_frame.repeat(2,1,1,1) * (1 - init_image_scale) + + # UNet forward + noise_pred = unet(latent_input, t, encoder_hidden_states=embeds).sample + + # guidance + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # scheduler step + latents_step = scheduler.step(noise_pred, t, latents[:1]).prev_sample + + # clamp sécurité + latents_step = latents_step.clamp(-2.0, 2.0) + 
+ # préparer pour prochaine étape (reduplication batch=2) + latents = torch.cat([latents_step, latents_step], dim=0) + + return latents[:1] + +# ------------------------------------------------------------ +# generate_latents_ai_5D_optimized_test - Anomalie sortie Frame Noir +# ------------------------------------------------------------ + +def generate_latents_ai_5D_optimized_test( + latent_frame, # [1,4,H,W] + pos_embeds, # [1,77,768] + neg_embeds, # [1,77,768] + unet, + scheduler, + motion_module=None, + device="cuda", + dtype=torch.float16, + guidance_scale=4.0, + init_image_scale=0.9, + creative_noise=0.0, + seed=1234, + steps=40 +): + + torch.manual_seed(seed) + + latent_frame = latent_frame.to(device=device, dtype=dtype) + + if creative_noise > 0.0: + latent_frame = latent_frame + torch.randn_like(latent_frame) * creative_noise + + # 🔹 Classifier-Free Guidance embeddings + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + # shape = [2,77,768] + + # 🔹 Duplicate latent for CFG + latents = latent_frame.repeat(2, 1, 1, 1) # [2,4,H,W] + + scheduler.set_timesteps(steps, device=device) + + for t in scheduler.timesteps: + + latent_input = latents + + # 🔹 Motion module (si utilisé) + if motion_module is not None: + latent_input = motion_module(latent_input) + + # 🔹 Blend avec image initiale (stabilité) + latent_input = ( + latent_input * init_image_scale + + latent_frame.repeat(2,1,1,1) * (1 - init_image_scale) + ) + + # 🔹 UNet forward + noise_pred = unet( + latent_input, + t, + encoder_hidden_states=embeds + ).sample + + # 🔹 Guidance + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # 🔹 Scheduler step + latents = scheduler.step( + noise_pred, + t, + latents[:1] # on garde batch=1 pour le step + ).prev_sample + + # 🔹 Clamp sécurité fp16 + latents = latents.clamp(-1.5, 1.5) + + # 🔹 Re-dupliquer pour prochaine itération + latents = latents.repeat(2,1,1,1) + 
def generate_latents_ai_5D_batch(
    latent_frames,           # [B,C,T,H,W] tensor
    pos_embeds,              # positive text embeddings
    neg_embeds,              # negative text embeddings
    unet,
    scheduler,
    motion_module=None,
    device="cuda",
    dtype=torch.float16,
    guidance_scale=4.5,
    init_image_scale=0.85,
    creative_noise=0.0,
    seed=42,
    steps=35
):
    """
    Animated latent generation over several frames, one scheduler pass shared
    by all frames.

    Fixes vs. the previous revision:
    - CFG embeddings were concatenated as [pos | neg] but then unpacked as
      (uncond, text), which inverted the guidance direction.  The batch is now
      built as [neg | pos], matching every other generator in this file.
    - The motion-module wrapper had a try/except whose two branches were the
      identical call; replaced by a plain conditional call.
    - The UNet forward pass runs under ``torch.no_grad()``.

    Args:
        latent_frames: [B,C,T,H,W] latents; each frame slice is denoised in
            place across all timesteps.
        pos_embeds / neg_embeds: text embeddings for CFG.
        unet / scheduler / motion_module: as elsewhere in this file.
        creative_noise: optional extra Gaussian noise added once up front.

    Returns:
        torch.Tensor: the denoised [B,C,T,H,W] latents (same tensor, mutated).
    """
    torch.manual_seed(seed)

    B, C, T, H, W = latent_frames.shape
    latent_frames = latent_frames.to(device=device, dtype=dtype)

    # Optional creative noise, added once before the diffusion loop.
    if creative_noise > 0.0:
        latent_frames = latent_frames + torch.randn_like(latent_frames) * creative_noise

    scheduler.set_timesteps(steps, device=device)
    timesteps = scheduler.timesteps

    # CFG batch: [neg | pos] so chunk(2) -> (uncond, text).
    total_embeds = torch.cat([neg_embeds, pos_embeds], dim=0)
    total_batch = total_embeds.shape[0]

    for t in timesteps:
        # Process every frame at this timestep.
        for f_idx in range(T):
            latent_frame = latent_frames[:, :, f_idx, :, :]  # [B,C,H,W]

            if motion_module is not None:
                latent_frame = motion_module(latent_frame)

            # Duplicate for guidance; blend keeps the pre-step frame as anchor.
            latent_input = latent_frame.repeat(total_batch, 1, 1, 1)
            latent_input = (
                latent_input * init_image_scale
                + latent_frame.repeat(total_batch, 1, 1, 1) * (1 - init_image_scale)
            )

            with torch.no_grad():
                noise_pred = unet(latent_input, t, encoder_hidden_states=total_embeds).sample

            # Classifier-free guidance (only meaningful with a 2x batch).
            if total_batch > 1:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latent_frame = scheduler.step(noise_pred, t, latent_frame).prev_sample

            # fp16 safety clamp.
            latent_frame = latent_frame.clamp(-1.5, 1.5)

            # Write the frame back in place.
            latent_frames[:, :, f_idx, :, :] = latent_frame

    return latent_frames
def generate_latents_ai_5D_optimized1(
    latent_frame,            # [B,C,H,W] tensor
    pos_embeds,              # positive text embeddings
    neg_embeds,              # negative text embeddings
    unet,
    scheduler,
    motion_module=None,
    device="cuda",
    dtype=torch.float16,
    guidance_scale=4.5,
    init_image_scale=0.85,
    creative_noise=0.0,
    seed=42,
    steps=35
):
    """
    Single-frame latent generation with CFG.

    Fixes vs. the previous revision:
    - CFG embeddings were concatenated as [pos | neg] but unpacked as
      (uncond, text): inverted guidance.  Now built as [neg | pos] like the
      rest of this file.
    - The scheduler was stepped on the *repeated* CFG batch, so the return
      batch was 2x the documented [B,C,H,W] input.  The state is now kept at
      batch B and only the UNet input is duplicated per step.
    - UNet forward runs under ``torch.no_grad()``; the dead ``else`` no-op
      branch was removed.

    Returns:
        torch.Tensor: denoised latents, same shape as ``latent_frame``.
    """
    torch.manual_seed(seed)

    latent_frame = latent_frame.to(device=device, dtype=dtype)

    # Optional creative noise, once up front.
    if creative_noise > 0.0:
        latent_frame = latent_frame + torch.randn_like(latent_frame) * creative_noise

    scheduler.set_timesteps(steps, device=device)

    # CFG batch: [neg | pos] so chunk(2) -> (uncond, text).
    embeds = torch.cat([neg_embeds, pos_embeds], dim=0)
    cfg_batch = embeds.shape[0]

    latents = latent_frame

    for t in scheduler.timesteps:
        model_state = latents

        if motion_module is not None:
            model_state = motion_module(model_state)

        # Blend the UNet input with the original frame (state is untouched).
        model_state = model_state * init_image_scale + latent_frame * (1 - init_image_scale)

        # Duplicate only the UNet input for guidance.
        model_input = model_state.repeat(cfg_batch, 1, 1, 1)

        with torch.no_grad():
            noise_pred = unet(model_input, t, encoder_hidden_states=embeds).sample

        if cfg_batch > 1:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents).prev_sample

        # fp16 safety clamp.
        latents = latents.clamp(-1.5, 1.5)

    return latents
def generate_latents_ai_5D_fixed(
    latent_frame,
    pos_embeds,
    neg_embeds,
    unet,
    scheduler,
    motion_module=None,
    device="cuda",
    dtype=torch.float16,
    guidance_scale=4.5,
    init_image_scale=0.85,
    creative_noise=0.0,
    seed=42,
    steps=35,
    clamp_max=1.0
):
    """
    Animated latent generation (B,C,H,W) with CFG.

    Fixes vs. the previous revision:
    - CFG embeddings were built as [pos | neg] then unpacked as (uncond, text),
      inverting the guidance direction; now built as [neg | pos] like the rest
      of this file, and hoisted out of the loop.
    - Redundant try/except in the motion wrapper removed (both branches were
      the same call).
    - UNet forward runs under ``torch.no_grad()``.

    Args:
        clamp_max: symmetric clamp bound applied to the latents each step.

    Returns:
        torch.Tensor: denoised latents, same shape as ``latent_frame``.
    """
    torch.manual_seed(seed)

    latent_model_input = latent_frame.clone()

    # Optional creative noise, once up front.
    if creative_noise > 0.0:
        latent_model_input = latent_model_input + torch.randn_like(latent_model_input) * creative_noise

    scheduler.set_timesteps(steps, device=device)
    timesteps = scheduler.timesteps

    # CFG batch: [neg | pos] so chunk(2) -> (uncond, text).  Built once.
    batch_size = pos_embeds.shape[0]
    embeds = torch.cat([neg_embeds, pos_embeds], dim=0)

    for t in timesteps:
        latent_model_input = latent_model_input.to(device=device, dtype=dtype)

        if motion_module is not None:
            latent_model_input = motion_module(latent_model_input)

        # Mix with the original frame to keep it anchored.
        if init_image_scale < 1.0:
            latent_model_input = latent_model_input * init_image_scale + latent_frame * (1 - init_image_scale)

        # Duplicate the input for the two guidance branches.
        latent_model_input_in = latent_model_input.repeat(batch_size * 2, 1, 1, 1)

        with torch.no_grad():
            noise_pred = unet(latent_model_input_in, t, encoder_hidden_states=embeds).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latent_model_input = scheduler.step(noise_pred, t, latent_model_input).prev_sample

        # Safety clamp.
        latent_model_input = latent_model_input.clamp(-clamp_max, clamp_max)

    return latent_model_input
+ latent_model_input = latent_model_input.to(device=device, dtype=dtype) + + # Motion module + latent_model_input = motion_wrapper(latent_model_input, t) + + # Init image scale mix + if init_image_scale < 1.0: + latent_model_input = latent_model_input * init_image_scale + latent_frame * (1 - init_image_scale) + + # Guidance batch + batch_size = pos_embeds.shape[0] + latent_model_input_in = latent_model_input.repeat(batch_size * 2, 1, 1, 1) + embeds = torch.cat([pos_embeds, neg_embeds], dim=0) + + # UNet + noise_pred = unet(latent_model_input_in, t, encoder_hidden_states=embeds).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Scheduler step + latent_model_input = scheduler.step(noise_pred, t, latent_model_input).prev_sample + + # Clamp + latent_model_input = latent_model_input.clamp(-clamp_max, clamp_max) + + return latent_model_input + +def generate_latents_ai_5D_256( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.75, + creative_noise=0.1, seed=0, steps=20 +): + """ + Génère des latents animés optimisés pour AnimateDiff avec logs détaillés. + Correctifs pour éviter fonds blancs et contours néons. 
+ """ + torch.manual_seed(seed) + + # Scheduler + scheduler.set_timesteps(steps, device=device) + + # Déplacer latents et embeddings + latents = latents.to(device=device, dtype=dtype) + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + # Ajout de bruit initial avec minimum 0.1 pour contraste + noise = torch.randn_like(latents) * max(creative_noise, 0.1) + latents = scheduler.add_noise(latents, noise, scheduler.timesteps[0:1]) + print(f"[Init Noise] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + print(f"🔥 Seed: {seed}, Steps: {steps}, Guidance: {guidance_scale}, Init_scale: {init_image_scale}, Noise: {creative_noise}") + print(f"[Init] Latents shape: {latents.shape}, min: {latents.min():.4f}, max: {latents.max():.4f}") + print(f"Embeddings shape: {embeds.shape}, batch_size: {embeds.shape[0]}") + + with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype): + for i, t in enumerate(scheduler.timesteps): + + # Motion module + if motion_module is not None: + if latents.dim() == 4: + latents = motion_module(latents.unsqueeze(2)).squeeze(2) + else: + latents = motion_module(latents) + print(f"[Step {i} Motion] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + # Guidance + latent_model_input = torch.cat([latents, latents], dim=0) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeds).sample + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Scheduler step (clip_sample désactivé pour éviter fonds blancs) + if "clip_sample" in scheduler.step.__code__.co_varnames: + latents = scheduler.step(noise_pred, t, latents, clip_sample=False).prev_sample + else: + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # Logs avant clamp + print(f"[Step {i} pre-clamp] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + 
# Clamp léger pour éviter overflow fp16 + LATENT_CLAMP = 1.5 + latents = latents.clamp(-LATENT_CLAMP, LATENT_CLAMP) + + # Logs après clamp + print(f"[Step {i} post-clamp] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + # Amplification finale pour contraste avant VAE + latents = latents * 1.2 + latents = latents.clamp(-LATENT_CLAMP, LATENT_CLAMP) + + return latents + +def generate_latents_ai_5D_optimized_v1( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, + creative_noise=0.0, seed=0, steps=20 +): + """ + Génère des latents animés optimisés pour AnimateDiff avec logs détaillés. + ✅ Modification : aucun clamp ou tanh sur les latents pour éviter fonds blancs et néons. + """ + torch.manual_seed(seed) + + # Scheduler + scheduler.set_timesteps(steps, device=device) + + # Déplacer latents et embeddings + latents = latents.to(device=device, dtype=dtype) + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + # Ajout de bruit initial si demandé + noise = torch.randn_like(latents) * creative_noise if creative_noise > 0 else torch.zeros_like(latents) + latents = scheduler.add_noise(latents, noise, scheduler.timesteps[0:1]) + print(f"[Init Noise] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + print(f"🔥 Seed: {seed}, Steps: {steps}, Guidance: {guidance_scale}, Init_scale: {init_image_scale}, Noise: {creative_noise}") + print(f"[Init] Latents shape: {latents.shape}, min: {latents.min():.4f}, max: {latents.max():.4f}") + print(f"Embeddings shape: {embeds.shape}, batch_size: {embeds.shape[0]}") + + with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype): + for i, t in enumerate(scheduler.timesteps): + + # Motion module + if motion_module is not None: + if latents.dim() == 4: + latents = motion_module(latents.unsqueeze(2)).squeeze(2) + else: + latents = motion_module(latents) + print(f"[Step {i} 
def decode_latents_correct(latents, vae, latent_scale=0.18215):
    """
    Decode UNet latents to images with the standard SD convention.

    Args:
        latents: raw latents from the UNet.
        vae: loaded VAE model.
        latent_scale: VAE latent scaling factor (0.18215 for SD/miniSD).

    Returns:
        torch.Tensor on CPU, values in [0, 1].
    """
    vae_device = next(vae.parameters()).device
    vae_dtype = next(vae.parameters()).dtype

    # Move + rescale latents to the VAE's device/dtype.
    latents = latents.to(device=vae_device, dtype=vae_dtype) / latent_scale

    with torch.no_grad():
        img = vae.decode(latents).sample

    # ✅ Standard SD output conversion: [-1,1] -> [0,1].
    img = (img / 2 + 0.5).clamp(0.0, 1.0)

    return img.cpu()


def decode_latents_correct_v1(latents, vae, latent_scale=0.18215):
    """
    Decode AnimateDiff latents, forcing float32 when the VAE is on CPU.

    Fix vs. the previous revision: the VAE decodes to the SD convention
    [-1, 1], but this function clamped the raw output straight to [0, 1],
    crushing the entire dark half of the range (its sibling
    ``decode_latents_correct`` applies ``img/2 + 0.5`` first).  The same
    conversion is now applied here.

    Args:
        latents: raw UNet output latents.
        vae: loaded VAE model.
        latent_scale: exact VAE scale (usually 0.18215 for SD/miniSD).

    Returns:
        torch.Tensor on CPU, values in [0, 1].
    """
    vae_device = next(vae.parameters()).device
    vae_dtype = next(vae.parameters()).dtype

    # Move latents to the VAE's device; CPU VAEs require float32.
    if vae_device.type == "cpu":
        latents = latents.to(device=vae_device, dtype=torch.float32) / latent_scale
    else:
        latents = latents.to(device=vae_device, dtype=vae_dtype) / latent_scale

    with torch.no_grad():
        img = vae.decode(latents).sample
        # [-1,1] -> [0,1] before clamping (the actual fix).
        img = (img / 2 + 0.5).clamp(0.0, 1.0)

    return img.cpu()
def decode_latents_safe_torch(latent_frame, vae, tile_size=128, overlap=64, latent_scale=LATENT_SCALE):
    """
    Decode latents through the VAE with NaN/Inf sanitisation.

    The dtype follows the VAE's device: float16 on CUDA, float32 on CPU.
    ``tile_size`` / ``overlap`` are accepted for API compatibility with the
    other decode helpers but are not used here (single-pass decode).

    Returns:
        torch.Tensor on CPU, clamped to [0, 1].
    """
    target_device = next(vae.parameters()).device
    target_dtype = torch.float16 if target_device.type == "cuda" else torch.float32

    # Rescale by the latent scale while moving to the VAE's device/dtype.
    scaled = latent_frame.to(device=target_device, dtype=target_dtype) / latent_scale

    # Replace NaN / ±Inf so the decoder never sees invalid values.
    scaled = torch.nan_to_num(scaled, nan=0.0, posinf=10.0, neginf=-10.0)

    # Optional: a wide clamp for fp16 could go here (kept disabled, as before).

    with torch.no_grad():
        decoded = vae.decode(scaled).sample
        decoded = decoded.clamp(0.0, 1.0)

    return decoded.cpu()
+ """ + vae_device = next(vae.parameters()).device + + # 1️⃣ Forcer dtype float16 si GPU, sinon float32 CPU + dtype = torch.float16 if vae_device.type == "cuda" else torch.float32 + latents = latent_frame.to(device=vae_device, dtype=dtype) / latent_scale + + # 2️⃣ Éviter NaN / Inf + latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=-1.0) + + # 3️⃣ Clamp pour éviter overflow fp16 + LATENT_CLAMP = 1.5 + latents = latents.clamp(-LATENT_CLAMP, LATENT_CLAMP) + + # 4️⃣ Décodage tile par tile si nécessaire (ici on simplifie pour petits latents) + B, C, H, W = latents.shape + with torch.no_grad(): + # Décode et force sortie entre 0 et 1 + frame_tensor = vae.decode(latents).sample + frame_tensor = frame_tensor.clamp(0.0, 1.0) + + return frame_tensor.cpu() + + + +def decode_latents_safe_ai1(latent_frame, vae, tile_size=128, overlap=64, latent_scale=LATENT_SCALE): + """ + Décodage sécurisé des latents en tensor float32, compatible fp16/fp32 et VAE-offload. + Découpe en tiles pour limiter l’usage VRAM. + """ + vae_device = next(vae.parameters()).device + latents = latent_frame.to(device=vae_device, dtype=torch.float32) / latent_scale # float32 obligatoire pour VAE + + # Éviter NaN / Inf + latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=0.0) + + # Décodage tile par tile si nécessaire + B, C, H, W = latents.shape + frame_tensor = torch.zeros(B, 3, H*16, W*16, device=vae_device) # dimension finale approximative + # Pour AnimateDiff on peut simplifier avec une seule passe si pas très grand + with torch.no_grad(): + frame_tensor = vae.decode(latents).sample.clamp(0, 1) + + return frame_tensor.cpu() + +# ------------------------- +# Encode image en latents +# ------------------------- +def encode_image_latents(image_tensor, vae, scale=0.18215, dtype=torch.float16): + """ + Encode une image en latents avec VAE. 
+ """ + vae_device = next(vae.parameters()).device + img = image_tensor.to(device=vae_device, dtype=torch.float32 if vae_device.type=="cpu" else dtype) + with torch.no_grad(): + latents = vae.encode(img).latent_dist.sample() * scale + print(f"[Encode] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + return latents.unsqueeze(2) # [B,C,1,H,W] +# ------------------------- +# Génération latents animés optimisés +# ------------------------- +def generate_latents_ai_5D_optimized_old( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, + creative_noise=0.0, seed=0, steps=20 +): + """ + Génère des latents animés optimisés pour AnimateDiff avec logs détaillés. + """ + torch.manual_seed(seed) + + # Scheduler + scheduler.set_timesteps(steps, device=device) + + # Déplacer latents et embeddings + latents = latents.to(device=device, dtype=dtype) + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + # Ajout de bruit initial + noise = torch.randn_like(latents) * creative_noise if creative_noise > 0 else torch.zeros_like(latents) + latents = scheduler.add_noise(latents, noise, scheduler.timesteps[0:1]) + print(f"[Init Noise] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + print(f"🔥 Seed: {seed}, Steps: {steps}, Guidance: {guidance_scale}, Init_scale: {init_image_scale}, Noise: {creative_noise}") + print(f"[Init] Latents shape: {latents.shape}, min: {latents.min():.4f}, max: {latents.max():.4f}") + print(f"Embeddings shape: {embeds.shape}, batch_size: {embeds.shape[0]}") + + with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype): + for i, t in enumerate(scheduler.timesteps): + + # Motion module + if motion_module is not None: + if latents.dim() == 4: + latents = motion_module(latents.unsqueeze(2)).squeeze(2) + else: + latents = motion_module(latents) + print(f"[Step {i} Motion] Latents min: 
{latents.min():.4f}, max: {latents.max():.4f}") + + # Guidance + latent_model_input = torch.cat([latents, latents], dim=0) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeds).sample + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Scheduler step + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # Logs avant clamp + print(f"[Step {i} pre-clamp] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + # Clamp progressif pour éviter overflow + #clamp_val = max(1.0, init_image_scale * 10) + #latents = latents.clamp(-clamp_val, clamp_val) + + # Clamp fixe à ±1.0 + #CLAMP_MAX = 1.0 + #latents = latents.clamp(-CLAMP_MAX, CLAMP_MAX) + + # Calculer max absolu actuel + #max_abs = latents.abs().max() + #if max_abs > 1.0: + # latents = latents / max_abs # ramène les latents dans [-1, 1] + + latents = torch.tanh(latents) # comprime automatiquement dans [-1, 1] + + # Logs après clamp + print(f"[Step {i} post-clamp] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + return latents + +# ------------------------- +# Décodage sécurisé +# ------------------------- +def decode_latents_safe(latents, vae, device="cuda", tile_size=128, overlap=64): + """ + Décodage sécurisé des latents en tensor float32 sur CPU. 
+ - Compatible avec vae_offload (CPU) et latents GPU + """ + vae_device = next(vae.parameters()).device + latents = latents.to(vae_device).float() + + # Nettoyage NaN/Inf + latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=0.0) + print(f"[Decode start] Latents shape: {latents.shape}, min: {latents.min():.4f}, max: {latents.max():.4f}") + + # Décodage en tiles (pour VRAM limitée) + frame_tensor = decode_latents_to_image_tiled128( + latents, + vae, + tile_size=tile_size, + overlap=overlap, + device=vae_device + ).clamp(0, 1) + + print(f"[Decode end] Frame tensor min: {frame_tensor.min():.4f}, max: {frame_tensor.max():.4f}") + + return frame_tensor.cpu() + +def generate_latents_ai_5D_debug_opti( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, + creative_noise=0.0, seed=0, steps=20 +): + """ + Génère des latents animés optimisés pour AnimateDiff. + """ + import torch + + torch.manual_seed(seed) + + # Configure le scheduler + scheduler.set_timesteps(steps, device=device) + + # Déplace latents et embeddings sur device + latents = latents.to(device=device, dtype=dtype) + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + # Ajout de bruit initial contrôlé par creative_noise + if creative_noise > 0.0: + noise = torch.randn_like(latents) * creative_noise + else: + noise = torch.zeros_like(latents) + latents = scheduler.add_noise(latents, noise, scheduler.timesteps[0:1]) + + print(f"🔥 Seed: {seed}, Steps: {steps}, Guidance scale: {guidance_scale}, Init scale: {init_image_scale}, Creative noise: {creative_noise}") + print(f"[Init] Latents shape: {latents.shape}, min: {latents.min():.4f}, max: {latents.max():.4f}") + print(f"Embeddings shape: {embeds.shape}, batch_size: {embeds.shape[0]}") + + with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype): + for i, t in enumerate(scheduler.timesteps): + # Motion module 
(optionnel) + if motion_module is not None: + if latents.dim() == 4: + latents = motion_module(latents.unsqueeze(2)).squeeze(2) + else: + latents = motion_module(latents) + print(f"[Step {i}] t: {t:.4f} | [Motion] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + # Guidance + latent_model_input = torch.cat([latents, latents], dim=0) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeds).sample + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Step scheduler et scaling progressif pour éviter la compression + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # Scaling progressif pour garder amplitude cohérente + clamp_val = max(1.0, init_image_scale * 10) + latents = latents.clamp(-clamp_val, clamp_val) + + print(f"[Step {i} post-step] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + return latents + +def generate_latents_ai_5D_brouillard( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, + creative_noise=0.0, seed=0, steps=20 +): + """ + Génère des latents animés 5D à partir d'une image de base avec logs. 
+ """ + + torch.manual_seed(seed) + print(f"🔥 Seed: {seed}, Steps: {steps}, Guidance scale: {guidance_scale}, Init scale: {init_image_scale}, Creative noise: {creative_noise}") + + # ------------------------------------- + # Scheduler et timesteps + # ------------------------------------- + scheduler.set_timesteps(steps, device=device) + print(f"Scheduler timesteps: {scheduler.timesteps}") + + # ------------------------------------- + # Initialisation des latents + # ------------------------------------- + latents = latents.to(device=device, dtype=torch.float32) + noise = torch.randn_like(latents) * (creative_noise if creative_noise > 0 else 1.0) + latents = scheduler.add_noise(latents, noise, scheduler.timesteps[0:1]) + print(f"[Init] Latents shape: {latents.shape}, min: {latents.min():.4f}, max: {latents.max():.4f}") + + # ------------------------------------- + # Concat embeddings une seule fois + # ------------------------------------- + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + batch_size = embeds.shape[0] + print(f"Embeddings shape: {embeds.shape}, batch_size: {batch_size}") + + # ------------------------------------- + # Boucle principale + # ------------------------------------- + with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype): + for step_idx, t in enumerate(scheduler.timesteps): + print(f"\n[Step {step_idx}] t: {t}") + + # Motion module si disponible + if motion_module is not None: + if latents.dim() == 4: # [B, C, H, W] -> ajouter F=1 + latents = motion_module(latents.unsqueeze(2)).squeeze(2) + else: + latents = motion_module(latents) + print(f"[Motion] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + # Répéter latents pour correspondre au batch d'embeddings + if latents.shape[0] != batch_size: + repeats = batch_size // latents.shape[0] + latent_model_input = latents.repeat(repeats, 1, 1, 1) + else: + latent_model_input = latents + + latent_model_input = 
scheduler.scale_model_input(latent_model_input, t) + + # UNet + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=embeds + ).sample + + # CFG : split batch + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Step scheduler + latents = scheduler.step(noise_pred, t, latents).prev_sample + latents = latents.clamp(-20, 20) # pour stabilité + + print(f"[Step {step_idx} post-step] Latents min: {latents.min():.4f}, max: {latents.max():.4f}") + + return latents.to(dtype) + +def generate_latents_ai_5D_testcreative( + latents, + pos_embeds, + neg_embeds, + unet, + scheduler, + motion_module=None, + device="cuda", + dtype=torch.float16, + guidance_scale=5.0, + init_image_scale=0.6, + creative_noise=0.0, # <-- AJOUTE ÇA + seed=0, + steps=12 +): + torch.manual_seed(seed) + + scheduler.set_timesteps(steps, device=device) + + latents = latents.to(device=device, dtype=dtype) + + # ---------------- IMG2IMG strength ---------------- + t_start = int(steps * init_image_scale) + t_start = min(t_start, steps - 1) + + timesteps = scheduler.timesteps[t_start:] + t_noise = timesteps[0:1] + + noise = torch.randn_like(latents) + + latents = scheduler.add_noise( + latents, + noise, + t_noise + ) + + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype): + + for t in timesteps: + + if motion_module is not None: + if latents.dim() == 4: + latents = motion_module(latents.unsqueeze(2)).squeeze(2) + else: + latents = motion_module(latents) + + latent_model_input = torch.cat([latents, latents], dim=0) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=embeds + ).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + + noise_pred = noise_uncond + guidance_scale * ( + noise_text - noise_uncond + ) + + 
latents = scheduler.step(noise_pred, t, latents).prev_sample + + return latents + + + +def generate_latents_ai_5D_optitest( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, + creative_noise=0.0, seed=0, steps=20 +): + torch.manual_seed(seed) + + # 🔥 IMPORTANT : config réelle des steps + scheduler.set_timesteps(steps, device=device) + + # Utilise dtype natif (fp16 si cuda) + latents = latents.to(device=device, dtype=dtype) + #latents = latents * scheduler.init_noise_sigma + noise = torch.randn_like(latents) + latents = scheduler.add_noise( + latents, + noise, + scheduler.timesteps[0:1] + ) + # ----------------------------------------------------- + # 🔥 concat embeddings UNE SEULE FOIS + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype) + + with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype): + + for t in scheduler.timesteps: + + if motion_module: + #latents = motion_module.apply(latents, t) + if motion_module is not None: + # Si latents = [B,C,H,W], on ajoute F=1 + if latents.dim() == 4: + latents = motion_module(latents.unsqueeze(2)).squeeze(2) + else: + latents = motion_module(latents) + + # CFG batching propre + latent_model_input = torch.cat([latents, latents], dim=0) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=embeds + ).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + + noise_pred = noise_uncond + guidance_scale * ( + noise_text - noise_uncond + ) + + latents = scheduler.step(noise_pred, t, latents).prev_sample + + return latents + + +# ------------------------- +# Décodage hybride safe VAE +# ------------------------- +def decode_latents_hybrid(vae, latents, scale=LATENT_SCALE): + """ + Décodage sécurisé des latents pour VAE fp32 + latents fp16. 
+ latents : Tensor GPU (fp16 ou fp32) + vae : modèle VAE (fp32 ou offload) + scale : facteur d'échelle des latents + """ + # Sauvegarder dtype des latents + lat_dtype = latents.dtype + + # Conversion en fp32 pour VAE + latents_fp32 = latents.to(torch.float32) + + with torch.no_grad(): + decoded = vae.decode(latents_fp32 / scale).sample + + # Clamp pour éviter valeurs extrêmes + decoded = decoded.clamp(-1, 1) + + # Retourner au dtype original pour GPU/fp16 si besoin + return decoded.to(lat_dtype) + + +# --- PIPELINE PRINCIPALE --- +def generate_5D_video_auto(pretrained_model_path, config, device='cuda'): + print("🔄 Chargement des modèles...") + motion_module = MotionModuleTiny(device=device) + scheduler = init_scheduler(config) # ta fonction existante + vae = load_vae(pretrained_model_path, device=device) + + total_frames = config['total_frames'] + fps = config['fps'] + H_src, W_src = config['image_size'] # résolution source + + # Génère les latents initiaux + latents = torch.randn(1, 4, H_src//8, W_src//8, device=device, dtype=torch.float16) + print(f"[INFO] Latents initiaux shape={latents.shape}") + + video_frames = [] + for t in range(total_frames): + try: + latents = motion_module.step(latents, t) + frame = decode_latents_frame_auto(latents, vae, H_src, W_src) + video_frames.append(frame) + except Exception as e: + print(f"⚠ Erreur frame {t:05d} → reset léger: {e}") + continue + + save_video(video_frames, fps, output_path=config['output_path']) + print(f"🎬 Vidéo générée : {config['output_path']}") + +# ------------------------- +# Load images utility +# ------------------------- +def load_images(paths, W, H, device, dtype): + all_tensors = [] + for p in paths: + t = load_image_file(p, W, H, device, dtype) + print(f"✅ Image chargée : {p}, shape={t.shape}, dtype={t.dtype}, device={t.device}") + all_tensors.append(t) + return torch.stack(all_tensors, dim=0) + +# ------------------------- +# Mémoire GPU utils +# ------------------------- +def log_gpu_memory(tag=""): 
+ if torch.cuda.is_available(): + print(f"[GPU MEM] {tag} → allocated={torch.cuda.memory_allocated()/1e6:.1f}MB, " + f"reserved={torch.cuda.memory_reserved()/1e6:.1f}MB, " + f"max_allocated={torch.cuda.max_memory_allocated()/1e6:.1f}MB") + +def decode_latents_frame_auto(latents, vae, H_src, W_src): + """ + Decode des latents VAE en images avec tiles 128x128, auto-adapté à la taille source. + """ + device = vae.device + print(f"[VAE] Decode → tile_size={tile_size}, overlap={overlap}, device={device}, latents.shape={latents.shape}") + log_gpu_memory("avant decode VAE") + + # Assure batch 4D + latents = latents.unsqueeze(0) if latents.dim() == 3 else latents + + # Décodage VAE en tiles + with torch.no_grad(): + frame_tensor = decode_latents_to_image_tiled( + latents, + vae, + tile_size=tile_size, + overlap=overlap + ).clamp(0,1) + + # Redimensionnement proportionnel à l'image source + H_out, W_out = H_src, W_src + if frame_tensor.shape[-2:] != (H_out, W_out): + frame_tensor = torch.nn.functional.interpolate( + frame_tensor, + size=(H_out, W_out), + mode='bicubic', + align_corners=False + ) + + log_gpu_memory("après decode VAE") + return frame_tensor.squeeze(0) +# ------------------------------------------------------------------- +# ------------- PATCH Stable Diffusion -------- +# ------------------------------------------------------------------ + +def generate_latents_ai_5D_std( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, creative_noise=0.0, seed=0 +): + torch.manual_seed(seed) + + # Toujours float32 pour stabilité du scheduler + latents = latents.to(device=device, dtype=torch.float32) + latents = latents * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + + # Sécurité NaN / Inf + if torch.isnan(latents).any() or torch.isinf(latents).any(): + latents = torch.randn_like(latents) * 0.1 + + # Si Motion Module est actif + if motion_module: + 
#latents = motion_module.apply(latents, t) + if motion_module is not None: + # Si latents = [B,C,H,W], on ajoute F=1 + if latents.dim() == 4: + latents = motion_module(latents.unsqueeze(2)).squeeze(2) + else: + latents = motion_module(latents) + # Motion Module fin + + # 🔥 Préparer batch pour guidance + batch_size = pos_embeds.shape[0] + neg_embeds.shape[0] + if latents.shape[0] != batch_size: + # Répéter latents pour matcher embeddings + repeats = batch_size // latents.shape[0] + latent_model_input = latents.repeat(repeats, 1, 1, 1) + else: + latent_model_input = latents + + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + latent_model_input = latent_model_input.to(dtype=dtype) + + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(dtype=dtype) + + with torch.no_grad(): + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=embeds + ).sample + + # Guidance : séparer le batch + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond.float() + guidance_scale * ( + noise_text.float() - noise_uncond.float() + ) + + # Step scheduler + latents = scheduler.step(noise_pred, t, latents).prev_sample + latents = latents.clamp(-20, 20) + + # Retour dtype original + return latents.to(dtype) + +# ------------------------------------------------------------------- +# ------------- 5 D Original -------- +# ------------------------------------------------------------------ + +def generate_latents_ai_5D( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, creative_noise=0.0, seed=0 +): + torch.manual_seed(seed) + + latents = latents.to(device=device, dtype=torch.float32) + latents = latents * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + + if torch.isnan(latents).any() or torch.isinf(latents).any(): + latents = torch.randn_like(latents) * 0.1 + + # 🔥 DUPLICATION POUR GUIDANCE + latent_model_input = 
torch.cat([latents, latents], dim=0) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + latent_model_input = latent_model_input.to(dtype=dtype) + + embeds = torch.cat([neg_embeds, pos_embeds], dim=0).to(dtype=dtype) + + with torch.no_grad(): + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=embeds + ).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + + noise_pred = noise_uncond.float() + guidance_scale * ( + noise_text.float() - noise_uncond.float() + ) + + latents = scheduler.step(noise_pred, t, latents).prev_sample + latents = latents.clamp(-20, 20) + + return latents.to(dtype) + + +def generate_latents_ai_5D_v1( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, creative_noise=0.0, seed=0 +): + torch.manual_seed(seed) + latents = latents.to(device=device, dtype=torch.float32) # scheduler en float32 + latents = latents * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + if torch.isnan(latents).any() or torch.isinf(latents).any(): + latents = torch.randn_like(latents) * 0.1 + + # préparation input UNet 5D + latent_model_input = scheduler.scale_model_input(latents, t) + latent_model_input = latent_model_input.to(dtype=dtype) # UNet en fp16 si demandé + + # concat embeddings pour guidance + embeds = torch.cat([neg_embeds, pos_embeds]) + embeds = embeds.to(dtype=dtype) + + with torch.no_grad(): + noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeds).sample + + # guidance + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond.float() + guidance_scale * (noise_text.float() - noise_uncond.float()) + + # scheduler step + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # clamp anti-explosion + latents = latents.clamp(-20, 20) + + return latents.to(dtype) + + +# ------------------------- +# UNet 3D latents generation +# ------------------------- 
+def generate_latents_3d( + latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, creative_noise=0.0, seed=0 +): + torch.manual_seed(seed) + + # Toujours float32 pour scheduler + latents = latents.to(device=device, dtype=torch.float32) + + if init_image_scale < 1.0: + noise = torch.randn_like(latents) + latents = scheduler.add_noise(latents, noise, scheduler.timesteps[0]) + else: + latents = latents * scheduler.init_noise_sigma + + latents = latents.clamp(-10, 10) + + for t in scheduler.timesteps: + + if torch.isnan(latents).any() or torch.isinf(latents).any(): + print(f"⚠ NaN détecté à timestep {int(t)} → reset léger") + latents = torch.randn_like(latents) * 0.1 + + # 3D CFG concat + latent_model_input = torch.cat([latents]*2) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + embeds = torch.cat([neg_embeds, pos_embeds]) + + if motion_module: + latent_model_input = motion_module.apply(latent_model_input) + + with torch.no_grad(): + noise_pred = unet(latent_model_input.to(dtype), t, encoder_hidden_states=embeds).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond.float() + guidance_scale * (noise_text.float() - noise_uncond.float()) + + latents = scheduler.step(noise_pred, t, latents).prev_sample + latents = latents.clamp(-20,20) + + mean_val = latents.abs().mean().item() + if math.isnan(mean_val) or mean_val < 1e-6: + print(f"⚠ Latent instable à timestep {int(t)} → reset") + latents = torch.randn_like(latents) * 0.05 + + return latents.to(dtype) + +# --------------------------------------------------------- +# Génération de latents AI compatible Fp16 et Fp32 +# --------------------------------------------------------- + +def generate_latents_ai( + latents, + pos_embeds, + neg_embeds, + unet, + scheduler, + motion_module=None, + device="cuda", + dtype=torch.float16, + guidance_scale=4.5, + 
init_image_scale=0.85, + creative_noise=0.0, + seed=0, +): + torch.manual_seed(seed) + + use_fp16 = dtype == torch.float16 and device == "cuda" + + # ------------------------------------------------ + # Toujours garder les latents scheduler en float32 + # ------------------------------------------------ + latents = latents.to(device=device, dtype=torch.float32) + + # ------------------------------------------------ + # Initialisation correcte (image vs text mode) + # ------------------------------------------------ + if init_image_scale < 1.0: + noise = torch.randn_like(latents) + latents = scheduler.add_noise(latents, noise, scheduler.timesteps[0]) + else: + latents = latents * scheduler.init_noise_sigma + + latents = latents.clamp(-10, 10) + + for t in scheduler.timesteps: + + # ------------------------------------------------ + # Sécurité anti-NaN + # ------------------------------------------------ + if torch.isnan(latents).any() or torch.isinf(latents).any(): + print(f"⚠ NaN détecté à timestep {int(t)} → reset léger") + latents = torch.randn_like(latents) * 0.1 + + # ------------------------------------------------ + # CFG concat (plus stable et plus rapide) + # ------------------------------------------------ + latent_model_input = torch.cat([latents] * 2) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + embeds = torch.cat([neg_embeds, pos_embeds]) + + # UNet en fp16 si activé + model_input = latent_model_input.to(dtype if use_fp16 else torch.float32) + + with torch.no_grad(): + noise_pred = unet( + model_input, + t, + encoder_hidden_states=embeds + ).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + + # ------------------------------------------------ + # Guidance toujours en float32 pour stabilité + # ------------------------------------------------ + noise_pred = noise_uncond.float() + guidance_scale * ( + noise_text.float() - noise_uncond.float() + ) + + # ------------------------------------------------ + # Step 
scheduler en float32 + # ------------------------------------------------ + latents = scheduler.step( + noise_pred, + t, + latents + ).prev_sample + + latents = latents.clamp(-20, 20) + + mean_val = latents.abs().mean().item() + if math.isnan(mean_val) or mean_val < 1e-6: + print(f"⚠ Latent instable à timestep {int(t)} → reset") + latents = torch.randn_like(latents) * 0.05 + + # Retour au dtype demandé + return latents.to(dtype if use_fp16 else torch.float32) + + + +# --------------------------------------------------------- +# Génération de latents par bloc SAFE avec logs +# --------------------------------------------------------- +def generate_latents_2(latents, pos_embeds, neg_embeds, unet, scheduler, motion_module=None, + device="cuda", dtype=torch.float16, guidance_scale=4.5, init_image_scale=0.85): + """ + latents: [B, C, F, H, W] + pos_embeds / neg_embeds: [B, L, D] + """ + torch.manual_seed(42) + B, C, F, H, W = latents.shape + + # ⚡ Assurer dtype/device compatibilité UNet + unet_dtype = next(unet.parameters()).dtype + latents = latents.to(device=device, dtype=unet_dtype) + pos_embeds = pos_embeds.to(device=device, dtype=unet_dtype) + if neg_embeds is not None: + neg_embeds = neg_embeds.to(device=device, dtype=unet_dtype) + + motion_module = motion_module.to(device=device, dtype=unet_dtype) if motion_module else None + + # Reshape pour timesteps + latents = latents.permute(0, 2, 1, 3, 4).reshape(B*F, C, H, W).contiguous() + init_latents = latents.clone() + + for t in scheduler.timesteps: + try: + if motion_module: + latents = motion_module(latents) + + # classifier-free guidance + latent_model_input = torch.cat([latents] * 2) + embeds = torch.cat([neg_embeds, pos_embeds]) if neg_embeds is not None else pos_embeds + + # Vérification NaN avant UNet + if torch.isnan(latents).any(): + print(f"❌ Warning: NaN detected in latents before UNet | t={t}") + + with torch.no_grad(): + noise_pred = unet(latent_model_input, timestep=t, 
encoder_hidden_states=embeds).sample + + if neg_embeds is not None: + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Scheduler step + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # ⚡ conserver influence de l'image initiale + latents = latents + init_image_scale * (init_latents - latents) + + # Logs min/max + print(f"🔹 Step t={t} | latents min: {latents.min():.6f}, max: {latents.max():.6f}") + + # Stop si NaN + if torch.isnan(latents).any(): + raise RuntimeError(f"NaN detected in latents after UNet at timestep {t}") + + except Exception as e: + print(f"❌ Erreur UNet/scheduler à t={t}: {e}") + # On peut remplacer par latents initiaux pour continuer + latents = init_latents.clone() + torch.cuda.empty_cache() + + # Reshape final + latents = latents.reshape(B, F, C, H, W).permute(0, 2, 1, 3, 4).contiguous() + return latents + + +# --------------------------------------------------------- +# Tuilage sécurisé +# --------------------------------------------------------- +def decode_latents_to_image_tiled(latents, vae, tile_size=32, overlap=8): + """ + Decode VAE en tuiles avec couverture complète garantie. 
+ - Aucun trou possible + - Blending propre + - Stable mathématiquement + """ + + device = vae.device + latents = latents.to(device).float() / LATENT_SCALE + + B, C, H, W = latents.shape + stride = tile_size - overlap + + # Dimensions image finale (scale factor VAE = 8) + out_H = H * 8 + out_W = W * 8 + + output = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output) + + # --- positions garanties --- + y_positions = list(range(0, H - tile_size + 1, stride)) + x_positions = list(range(0, W - tile_size + 1, stride)) + + if not y_positions: + y_positions = [0] + if not x_positions: + x_positions = [0] + + if y_positions[-1] != H - tile_size: + y_positions.append(H - tile_size) + + if x_positions[-1] != W - tile_size: + x_positions.append(W - tile_size) + + for y in y_positions: + for x in x_positions: + + y1 = y + tile_size + x1 = x + tile_size + + tile = latents[:, :, y:y1, x:x1] + + with torch.no_grad(): + decoded = vae.decode(tile).sample + + decoded = (decoded / 2 + 0.5).clamp(0, 1) + + iy0 = y * 8 + ix0 = x * 8 + iy1 = y1 * 8 + ix1 = x1 * 8 + + output[:, :, iy0:iy1, ix0:ix1] += decoded + weight[:, :, iy0:iy1, ix0:ix1] += 1.0 + + return output / weight.clamp(min=1e-6) +# ------------------------- +# Génération tuilée 128x128 ultra safe VRAM +# ------------------------- +def decode_latents_to_image_tiled_ori(latents, vae, tile_size=32, overlap=8): + """ + Decode VAE en tuiles côté latent. + Stable, sans toucher au scheduler. 
+ latents: [B, 4, H, W] + """ + + device = vae.device + latents = latents.to(device).float() / LATENT_SCALE + + B, C, H, W = latents.shape + stride = tile_size - overlap + + # Taille image finale (VAE scale factor = 8) + out_H = H * 8 + out_W = W * 8 + + output = torch.zeros(B, 3, out_H, out_W, device="cpu") + weight = torch.zeros_like(output) + + for y in range(0, H, stride): + for x in range(0, W, stride): + + y1 = min(y + tile_size, H) + x1 = min(x + tile_size, W) + + tile = latents[:, :, y:y1, x:x1] + + with torch.no_grad(): + decoded = vae.decode(tile).sample + + decoded = (decoded / 2 + 0.5).clamp(0, 1) + + # coordonnées en image space + iy0 = y * 8 + ix0 = x * 8 + iy1 = y1 * 8 + ix1 = x1 * 8 + + output[:, :, iy0:iy1, ix0:ix1] += decoded.cpu() + weight[:, :, iy0:iy1, ix0:ix1] += 1 + + return (output / weight.clamp(min=1e-6)).to(device) + + +# --------------------------------------------------------- +# Chargement image unique → [1,C,1,H,W] +# --------------------------------------------------------- +def load_input_image(image_path, W, H, device, dtype): + img = Image.open(image_path).convert("RGB") + + preprocess = transforms.Compose([ + transforms.Resize((H, W)), + transforms.ToTensor(), + transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]) + ]) + + img = preprocess(img).unsqueeze(0).unsqueeze(2) # [1,C,1,H,W] + return img.to(device=device, dtype=dtype) + + +# --------------------------------------------------------- +# Encode images → latents +# input: [B,C,T,H,W] +# output: [B,4,T,H_lat,W_lat] +# --------------------------------------------------------- +def encode_images(input_images, vae): + device = next(vae.parameters()).device + dtype = next(vae.parameters()).dtype + + B, C, T, H, W = input_images.shape + input_images = input_images.to(device=device, dtype=dtype) + + latents_list = [] + + with torch.no_grad(): + for t in range(T): + imgs_2d = input_images[:, :, t, :, :] # [B,C,H,W] + latent = vae.encode(imgs_2d).latent_dist.sample() + latent = latent * 
LATENT_SCALE + latents_list.append(latent) + + latents = torch.stack(latents_list, dim=2) # [B,4,T,H_lat,W_lat] + return latents + +#--------------------------------------------------------- +# VERSION ROBUSTE +# # Ajouter un peu de bruit créatif initial si demandé +# # Reshape pour traitement UNet [B*T, C, H, W] +#--------------------------------------------------------- + + +from functools import wraps +import torch +from contextlib import contextmanager + +# 🔹 Context manager pour désactiver xFormers temporairement +@contextmanager +def disable_xformers(unet): + """ + Désactive memory-efficient xFormers attention si disponible. + Utile pour éviter les erreurs FakeTensor pendant le safe encode/fallback. + """ + orig = getattr(unet, "enable_xformers_memory_efficient_attention", None) + if callable(orig): + # Désactiver temporairement + unet.enable_xformers_memory_efficient_attention(False) + yield + if callable(orig): + # Réactiver + unet.enable_xformers_memory_efficient_attention(True) + + +def generate_latents_safe_test(unet, **kwargs): + """ + Génération de latents avec UNet, FP16-safe et debug. + Affiche min/max à chaque étape pour détecter si les latents restent à zéro. 
+ """ + + # Arguments connus + known_kwargs = [ + "scheduler", "input_latents", "embeddings", "motion_module", + "guidance_scale", "device", "fp16", "steps", "debug" + ] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in known_kwargs} + + # Required + input_latents = filtered_kwargs.get("input_latents") + if input_latents is None: + raise ValueError("⚠️ 'input_latents' doit être fourni") + + device = filtered_kwargs.get("device", "cuda") + fp16 = filtered_kwargs.get("fp16", True) + debug = filtered_kwargs.get("debug", False) + motion_module = filtered_kwargs.get("motion_module", None) + scheduler = filtered_kwargs.get("scheduler", None) + steps = filtered_kwargs.get("steps", 20) + embeddings = filtered_kwargs.get("embeddings", None) + + # Latents initiaux + latents = input_latents.clone().to(device=device, dtype=torch.float16 if fp16 else torch.float32) + is_video = latents.ndim == 5 + steps_list = getattr(scheduler, "timesteps", range(steps)) + + if debug: + print(f"[DEBUG] Initial latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}") + + # Désactiver Dynamo/Inductor pour debug + torch._dynamo.reset() + torch._dynamo.disable() + + for step_idx, t in enumerate(steps_list): + if motion_module is not None: + latents = motion_module(latents) + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + + if is_video: + B, C, F, H, W = latents.shape + latents_unet = latents.permute(0, 2, 1, 3, 4).reshape(B*F, C, H, W) + else: + latents_unet = latents + + # FP16 si demandé + latents_unet = latents_unet.half() if fp16 else latents_unet.float() + + # 🔹 Forward UNet + try: + unet_out = unet(latents_unet, t, encoder_hidden_states=embeddings) + latents_out = unet_out["sample"] + except Exception as e: + print(f"⚠️ [UNet ERROR] step={step_idx} - {e}") + continue + + # Reshape si vidéo + if is_video: + latents = latents_out.reshape(B, F, C, H, W).permute(0, 2, 1, 3, 4) + else: + latents = latents_out + + # Nettoyage et clamp + latents = 
torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0).clamp(-5.0, 5.0) + + if debug: + print(f"[DEBUG Step {step_idx}] latents min/max={latents.min().item():.4f}/{latents.max().item():.4f}") + + if debug: + print(f"[INFO] Finished latents generation, final min/max={latents.min().item():.4f}/{latents.max().item():.4f}") + + return latents + + +# 🔹 Wrapper général sûr +def generate_latents_safe_wrapper_test(unet, **kwargs): + try: + return generate_latents_safe_test(unet, **kwargs) + except TypeError as e: + if "meta__efficient_attention_backward() got multiple values for argument 'scale'" in str(e): + print("⚠️ [SAFE WRAPPER] Efficient attention error caught, returning input latents") + return kwargs.get("input_latents").clone() + else: + raise e + +# ---------------- SAFE WRAPPER ---------------- + + +def generate_latents_robuste_safe(latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=7.5, init_image_scale=2.0, + creative_noise=0.0, seed=42): + + B, C_orig, T, H, W = latents.shape + latents = latents.to(device=device, dtype=dtype) + + latents = latents.permute(0,2,1,3,4).reshape(B*T, C_orig, H, W).contiguous() + init_latents = latents.clone() + + torch.manual_seed(seed) + + # 🔹 IMPORTANT pour DPMSolver + scheduler._step_index = None + + for t_step in scheduler.timesteps: + + if motion_module: + latents = motion_module(latents) + + if creative_noise > 0: + latents = latents + torch.randn_like(latents) * creative_noise + + latent_model_input = torch.cat([latents, latents], dim=0) + + if latent_model_input.shape[1] == 1: + latent_model_input = latent_model_input.repeat(1,4,1,1) + + embeds = torch.cat([neg_embeds, pos_embeds], dim=0) + + with torch.no_grad(): + noise_pred = unet( + latent_model_input, + t_step, + encoder_hidden_states=embeds + ).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + 
step_output = scheduler.step( + model_output=noise_pred, + timestep=t_step, + sample=latents + ) + + latents = getattr(step_output, "prev_sample", step_output) + + latents = latents + init_image_scale * (init_latents - latents) + + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + latents = latents.clamp(-5.0,5.0) + + C_final = latents.shape[1] + latents = latents.reshape(B, T, C_final, H, W).permute(0,2,1,3,4).contiguous() + + return latents + + +#------------------------------------------------------------------------------- +# VERY STABLE +#------------------------------------------------------------------------------- +def generate_latents_robuste(latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=7.5, init_image_scale=2.0, + creative_noise=0.0, seed=42): + """ + Génère des latents pour une frame ou un batch de frames avec robustesse. + + latents : [B,C,T,H,W] latents encodés et scalés + pos_embeds / neg_embeds : embeddings textuels + creative_noise : float, bruit supplémentaire pour variation frame à frame + """ + torch.manual_seed(seed) + B, C, T, H, W = latents.shape + latents = latents.to(device=device, dtype=dtype) + latents = latents.permute(0,2,1,3,4).reshape(B*T, C, H, W).contiguous() + init_latents = latents.clone() # copie pour init_image_scale + + for t_step in scheduler.timesteps: + # Motion module (optionnel) + if motion_module is not None: + latents = motion_module(latents) + + # Ajouter un peu de noise créatif si demandé + if creative_noise > 0: + latents = latents + torch.randn_like(latents) * creative_noise + + # Classifier-free guidance + latent_model_input = torch.cat([latents, latents], dim=0) + embeds = torch.cat([neg_embeds, pos_embeds], dim=0) + with torch.no_grad(): + noise_pred = unet(latent_model_input, t_step, encoder_hidden_states=embeds).sample + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * 
(noise_text - noise_uncond) + + # Scheduler step + latents = scheduler.step(noise_pred, t_step, latents).prev_sample + + # Réinjection de l'image initiale + latents = latents + init_image_scale * (init_latents - latents) + + # Vérification NaN / inf + if torch.isnan(latents).any() or torch.isinf(latents).any(): + print(f"⚠ NaN/inf détecté à timestep {t_step}, réinitialisation avec petit bruit") + latents = latents.clone() + latents = latents + torch.randn_like(latents) * 1e-3 + + # Log de debug + mean_val = latents.abs().mean().item() + std_val = latents.std().item() + if math.isnan(mean_val) or mean_val < 1e-5: + print(f"⚠ Latent trop petit à timestep {t_step}, mean={mean_val:.6f}") + + # Remettre la forme [B,C,T,H,W] + latents = latents.reshape(B, T, C, H, W).permute(0,2,1,3,4).contiguous() + return latents + +#------------------------------------------------------------------------------- +# VERY STABLE - Parfait pour l'init de la video' - Version la plus adapter +#------------------------------------------------------------------------------- +def generate_latents_robuste_4D( + latents, + unet, + scheduler, + pos_embeds=None, + neg_embeds=None, + motion_module=None, + device='cuda', + dtype=torch.float16, + guidance_scale=1.0, + init_image_scale=0.85, + creative_noise=0.0, + steps=8, + seed=None, + gamma_boost=1.08, # renforce couleurs ~10% + detail_strength=0.03, # amplification douce des détails + adaptive_noise_strength=0.005 # bruit créatif local sur zones plates + ): + """ + Génération initiale de latents ultra-safe 4D pour AnimateDiff. + Clamp, correction NaN/Inf et boost adaptatif de détails et couleurs. 
+ """ + import torch + import torch.nn.functional as F + + if seed is not None: + torch.manual_seed(seed) + + latents = latents.to(device=device, dtype=dtype) + + # Échelle initiale + if init_image_scale != 1.0: + latents = latents * init_image_scale + + for t in range(steps): + with torch.no_grad(): + # ----- Dummy UNet step ----- + noise = torch.randn_like(latents) * 0.01 + latents = latents + noise + + # ----- Motion module si présent ----- + if motion_module is not None: + latents, _ = motion_module(latents) + + # ----- Clamp strict et correction NaN/Inf ----- + latents = torch.clamp(latents, -1.0, 1.0) + latents[torch.isnan(latents)] = 0.0 + latents[torch.isinf(latents)] = 0.0 + + # ----- Amplification douce des détails ----- + mean = latents.mean(dim=[2,3], keepdim=True) + detail = latents - mean + latents = latents + detail_strength * torch.tanh(detail) + + # ----- Bruit créatif adaptatif local ----- + if adaptive_noise_strength > 0: + low_contrast_mask = (latents.std(dim=[2,3], keepdim=True) < 0.1).float() + latents = latents + low_contrast_mask * torch.randn_like(latents) * adaptive_noise_strength + + # ----- Bruit créatif global ----- + if creative_noise > 0.0: + latents += torch.randn_like(latents) * creative_noise + + # ----- Gamma adaptatif pour les couleurs ----- + latents_norm = ((latents + 1.0) / 2.0).clamp(0,1) + latents_norm = latents_norm ** gamma_boost + latents = latents_norm * 2 - 1 + + # ----- Assurer 4D ----- + if latents.ndim != 4: + B, C, H, W = latents.shape[0], 4, latents.shape[-2], latents.shape[-1] + latents = latents[:, :C, :, :] if latents.shape[1] >= 4 else F.pad(latents, (0,0,0,0,0,4-latents.shape[1])) + + return latents + +def generate_latents_robuste_4D_net( + latents, + unet, + scheduler, + pos_embeds=None, + neg_embeds=None, + motion_module=None, + device='cuda', + dtype=torch.float16, + guidance_scale=1.0, + init_image_scale=0.85, + creative_noise=0.0, + steps=8, + seed=None, + clamp_percentile=99.5, + smoothing_factor=0.05 
+): + """ + Génération initiale de latents ultra-safe 4D pour AnimateDiff. + - Clamp adaptatif par percentile + - Correction NaN / Inf intelligente + - Bruit créatif progressif + - Motion module sécurisé + - Option smoothing léger pour stabiliser latents + """ + import torch + import torch.nn.functional as F + + if seed is not None: + torch.manual_seed(seed) + + latents = latents.to(device=device, dtype=dtype) + if init_image_scale != 1.0: + latents = latents * init_image_scale + + for t in range(steps): + with torch.no_grad(): + # === Bruit UNet (placeholder) === + noise = torch.randn_like(latents) * 0.01 + latents = latents + noise + + # === Motion module === + if motion_module is not None: + latents, _ = motion_module(latents) + + # === Correction NaN / Inf === + mask_invalid = torch.isnan(latents) | torch.isinf(latents) + if mask_invalid.any(): + latents[mask_invalid] = latents[~mask_invalid].mean() + + # === Clamp adaptatif === + upper = torch.quantile(latents, clamp_percentile / 100.0) + lower = torch.quantile(latents, 1 - clamp_percentile / 100.0) + latents = latents.clamp(min=lower.item(), max=upper.item()) + + # === Bruit créatif progressif === + if creative_noise > 0.0: + progressive_noise = creative_noise * (1.0 - t / steps) + latents += torch.randn_like(latents) * progressive_noise + + # === Smoothing léger pour stabilité === + if smoothing_factor > 0.0 and t > 0: + latents = (1 - smoothing_factor) * latents + smoothing_factor * latents_prev + + latents_prev = latents.clone() + + # === Assurer 4D correct === + if latents.ndim != 4: + B, H, W = latents.shape[0], latents.shape[-2], latents.shape[-1] + C = min(4, latents.shape[1]) + latents = latents[:, :C, :, :] if latents.shape[1] >= 4 else F.pad(latents, (0,0,0,0,0,4-latents.shape[1])) + + return latents + # Génération initiale robuste : + #42 Classique, beaucoup de tests communautaires utilisent ce seed. #1234 Fidèle, stable, souvent utilisé pour des tests de cohérence. 
# Seed notes (from the original author, translated):
#   42   classic, widely used in community tests    1234 faithful, stable
#   5555 faithful to the init image (current pick)  2026 subtle texture/pose change
#   9876 slightly more visible variation, keeps the global structure
def generate_latents_robuste_4D_v1(
    latents,
    unet,
    scheduler,
    pos_embeds=None,
    neg_embeds=None,
    motion_module=None,
    device='cuda',
    dtype=torch.float16,
    guidance_scale=1.0,
    init_image_scale=0.85,
    creative_noise=0.0,
    steps=8,
    seed=None
):
    """Ultra-safe 4D latent generation for AnimateDiff.

    Clamps to [-1, 1] and zeroes NaN/Inf after every step. The UNet step is
    a placeholder (small random noise); ``unet``/``scheduler`` are accepted
    for interface compatibility only.
    """
    if seed is not None:
        torch.manual_seed(seed)

    latents = latents.to(device=device, dtype=dtype)

    # Initial weighting of the starting image.
    if init_image_scale != 1.0:
        latents = latents * init_image_scale

    for t in range(steps):
        with torch.no_grad():
            # Placeholder UNet step (swap in a real step if needed).
            noise = torch.randn_like(latents) * 0.01
            latents = latents + noise

            # Motion module, if provided (expected to return a pair).
            if motion_module is not None:
                latents, _ = motion_module(latents)

            # Strict clamp + NaN/Inf correction.
            latents = torch.clamp(latents, -1.0, 1.0)
            latents[torch.isnan(latents)] = 0.0
            latents[torch.isinf(latents)] = 0.0

            # Optional small creative noise.
            if creative_noise > 0.0:
                latents += torch.randn_like(latents) * creative_noise

    # Make sure the result is 4D.
    if latents.ndim != 4:
        B, C, H, W = latents.shape[0], 4, latents.shape[-2], latents.shape[-1]
        latents = latents[:, :C, :, :] if latents.shape[1] >= 4 else F.pad(latents, (0, 0, 0, 0, 0, 4 - latents.shape[1]))

    return latents


# --------------------- V1 ----------------------------------------
def generate_latents_robuste_v1(
    latents,
    pos_embeds,
    neg_embeds,
    unet,
    scheduler,
    motion_module=None,
    device="cuda",
    dtype=torch.float16,
    guidance_scale=7.5,
    init_image_scale=2.0,
    creative_noise=0.0,
    seed=42
):
    """Generate animated latents from an initial sequence.

    latents: [B, 4, T, H, W] (already encoded and scaled)
    pos_embeds / neg_embeds: text embeddings for guidance
    guidance_scale: classifier-free guidance weight
    init_image_scale: weight of the initial image
    creative_noise: random noise added up-front for variation
    """
    torch.manual_seed(seed)

    # Shape check.
    if latents.ndim != 5 or latents.shape[1] != 4:
        raise ValueError(f"Latents attendus en [B,4,T,H,W], got {latents.shape}")

    B, C, T, H, W = latents.shape
    latents = latents.to(device=device, dtype=dtype)

    # Optional creative noise before denoising.
    if creative_noise > 0:
        latents = latents + torch.randn_like(latents) * creative_noise

    # Reshape for the UNet: [B*T, C, H, W].
    latents = latents.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W).contiguous()
    init_latents = latents.clone()

    for t_idx, t in enumerate(scheduler.timesteps):

        # Apply motion module if present.
        # NOTE(review): unlike the 4D variants above, this one expects the
        # motion module to return the latents directly (not a pair).
        if motion_module is not None:
            latents = motion_module(latents)

        # Classifier-free guidance.
        latent_model_input = torch.cat([latents, latents], dim=0)
        embeds = torch.cat([neg_embeds, pos_embeds], dim=0)

        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeds).sample

        noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
        noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

        # Scheduler step.
        latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Re-apply init_image_scale to keep the initial image's influence.
        latents = latents + init_image_scale * (init_latents - latents)

        # Detect degenerate values and reset with a little noise.
        mean_val = latents.abs().mean().item()
        if math.isnan(mean_val) or mean_val < 1e-5:
            print(f"⚠ Step {t_idx}/{len(scheduler.timesteps)}: mean latent {mean_val:.6f}, reset avec petit bruit")
            latents = init_latents + torch.randn_like(init_latents) * 0.01

    # Back to [B, C, T, H, W].
    latents = latents.reshape(B, T, C, H, W).permute(0, 2, 1, 3, 4).contiguous()

    return latents


# ---------------------------------------------------------
# Diffusion — WORKS PERFECTLY
# images_latents: [B,4,T,H,W]
# ---------------------------------------------------------
def generate_latents(latents, pos_embeds, neg_embeds, unet, scheduler, motion_module=None, device="cuda", dtype=torch.float16, guidance_scale=7.5, init_image_scale=2.0, seed=42):
    """latents: [B,4,T,H,W] (already encoded and scaled); init_image_scale:
    weight of the initial image."""
    torch.manual_seed(seed)
    B, C, T, H, W = latents.shape
    latents = latents.to(device=device, dtype=dtype)
    latents = latents.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W).contiguous()
    # Keep a copy of the initial latents.
    init_latents = latents.clone()
    for t in scheduler.timesteps:
        if motion_module is not None:
            latents = motion_module(latents)

        # Classifier-free guidance.
        latent_model_input = torch.cat([latents] * 2)
        embeds = torch.cat([neg_embeds, pos_embeds])

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=embeds
            ).sample

        noise_uncond, noise_text = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

        # Apply init_image_scale to keep the initial image's influence.
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        latents = latents + init_image_scale * (init_latents - latents)

    latents = latents.reshape(B, T, C, H, W).permute(0, 2, 1, 3, 4).contiguous()

    return latents


# -------------------------
# Per-block latent generation (same loop as generate_latents, fixed seed 42,
# different guidance defaults) — called by generate_tiled below.
# -------------------------
def generate_latents_1(latents, pos_embeds, neg_embeds, unet, scheduler, motion_module=None,
                       device="cuda", dtype=torch.float16, guidance_scale=4.5, init_image_scale=0.85):
    """latents: [B,4,T,H,W] (already encoded and scaled); init_image_scale:
    weight of the initial image."""
    torch.manual_seed(42)
    B, C, T, H, W = latents.shape
    latents = latents.to(device=device, dtype=dtype)
    latents = latents.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W).contiguous()
    # Keep a copy of the initial latents.
    init_latents = latents.clone()
    for t in scheduler.timesteps:
        if motion_module is not None:
            latents = motion_module(latents)

        # Classifier-free guidance.
        latent_model_input = torch.cat([latents] * 2)
        embeds = torch.cat([neg_embeds, pos_embeds])

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=embeds
            ).sample

        noise_uncond, noise_text = noise_pred.chunk(2)
        noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

        # Apply init_image_scale to keep the initial image's influence.
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        latents = latents + init_image_scale * (init_latents - latents)

    latents = latents.reshape(B, T, C, H, W).permute(0, 2, 1, 3, 4).contiguous()

    return latents


# -------------------------
# Tiled latent generation (block_size + overlap)
# -------------------------
def generate_tiled(input_latents, pos_embeds, neg_embeds, unet, scheduler, motion_module,
                   device, dtype, guidance_scale=4.5, init_image_scale=0.85,
                   block_size=128, overlap=16):
    """input_latents: [B, C, F, H, W]; returns latents [B, C, F, H, W].

    Splits the spatial plane into overlapping tiles, denoises each tile with
    generate_latents_1 and writes it back (later tiles overwrite the overlap).
    """
    B, C, n_frames, H, W = input_latents.shape
    output_latents = torch.zeros_like(input_latents)

    h_blocks = math.ceil(H / (block_size - overlap))
    w_blocks = math.ceil(W / (block_size - overlap))

    for hi in range(h_blocks):
        for wi in range(w_blocks):
            h_start = hi * (block_size - overlap)
            w_start = wi * (block_size - overlap)
            h_end = min(h_start + block_size, H)
            w_end = min(w_start + block_size, W)
            # Re-anchor so edge tiles are still block_size wide when possible.
            h_start = max(h_end - block_size, 0)
            w_start = max(w_end - block_size, 0)

            block = input_latents[:, :, :, h_start:h_end, w_start:w_end]

            # FIX: the original called generate_latents_2, which does not
            # exist (NameError on first use). generate_latents_1 matches the
            # ten positional arguments exactly.
            block_out = generate_latents_1(
                block, pos_embeds, neg_embeds, unet, scheduler, motion_module,
                device, dtype, guidance_scale, init_image_scale
            )

            output_latents[:, :, :, h_start:h_end, w_start:w_end] = block_out

    return output_latents


# ---------------------------------------------------------
# Decode latents → images
# latents: [B,4,T,H,W]
# ---------------------------------------------------------
def decode_latents(latents, vae):
    """Decode a latent video frame by frame; returns [B,3,T,H,W] in [0,1]."""
    device = next(vae.parameters()).device
    dtype = next(vae.parameters()).dtype

    B, C, T, H, W = latents.shape
    latents = latents.to(device=device, dtype=dtype)

    frames = []

    with torch.no_grad():
        for t in range(T):
            latent = latents[:, :, t, :, :]

            # Undo the encoding scale before decoding (important).
            latent = latent / LATENT_SCALE

            img = vae.decode(latent).sample
            img = (img / 2 + 0.5).clamp(0, 1)

            frames.append(img.float())

    images = torch.stack(frames, dim=2)  # [B,3,T,H,W]
    return images


# ---------------------------------------------------------
# Save video
# ---------------------------------------------------------
def create_video_from_latents(latents, vae, output_dir, fps=12):
    """Decode latents, dump PNG frames into output_dir and encode output.mp4."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    images = decode_latents(latents, vae)

    images = images.squeeze(0)          # [3,T,H,W] — assumes batch size 1
    images = images.permute(1, 0, 2, 3)  # [T,3,H,W]

    for i, img in enumerate(images):
        save_image(img, output_dir / f"frame_{i:04d}.png")

    (
        ffmpeg
        .input(f"{output_dir}/frame_%04d.png", framerate=fps)
        .output(str(output_dir / "output.mp4"),
                vcodec="libx264",
                pix_fmt="yuv420p")
        .overwrite_output()
        .run(quiet=True)
    )

    print(f"🎬 Vidéo générée : {output_dir / 'output.mp4'}")


# ---------- NEW FUNCTIONS ----------------------------------
# safe_load_unet, safe_load_vae, safe_load_scheduler,
# encode_images_to_latents, decode_latents_to_image, load_images,
# save_frames_as_video

# -------------------------
# Safe model loaders
# -------------------------
def safe_load_unet(model_path, device, fp16=True):
    """Load the UNet from <model_path>/unet; returns None if the folder is
    missing. fp16 halves VRAM usage."""
    folder = os.path.join(model_path, "unet")
    if os.path.exists(folder):
        model = UNet2DConditionModel.from_pretrained(folder)
        if fp16:
            model = model.half()
        return model.to(device)
    return None


def safe_load_vae(model_path, device, fp16=False, offload=False):
    """Load the VAE; optionally keep it on CPU (offload) and/or in fp16."""
    model = AutoencoderKL.from_pretrained(os.path.join(model_path, "vae"))
    model = model.to("cpu" if offload else device)
    if fp16:
        model = model.half()
    return model


def safe_load_scheduler(model_path):
    """Load the DPMSolver scheduler from <model_path>/scheduler."""
    return DPMSolverMultistepScheduler.from_pretrained(os.path.join(model_path, "scheduler"))


# --------------------------------------
# Encode FP16-friendly (simple variant)
# --------------------------------------
def encode_images_to_latents_simple(images, vae):
    """Encode [B,C,F,H,W] or [B,C,H,W] images to [B,4,F,H,W] latents,
    following the VAE's dtype (fp16-safe) and applying LATENT_SCALE."""
    device = vae.device
    dtype = next(vae.parameters()).dtype  # fp16 when the VAE is in FP16
    images = images.to(device=device, dtype=dtype)

    with torch.no_grad():
        if images.dim() == 5:  # [B,C,F,H,W]
            B, C, F, H, W = images.shape
            images_2d = images.view(B * F, C, H, W)
            latents_2d = vae.encode(images_2d).latent_dist.sample() * LATENT_SCALE
            latent_shape = latents_2d.shape
            latents = latents_2d.view(B, F, latent_shape[1], latent_shape[2], latent_shape[3])
            latents = latents.permute(0, 2, 1, 3, 4).contiguous()
        else:
            latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE
            latents = latents.unsqueeze(2)
    return latents
def decode_latents_to_image_safe(latents, vae):
    """Decode latents with the VAE, aligning dtype with the VAE weights."""
    dtype = next(vae.parameters()).dtype
    latents = latents.to(vae.device).to(dtype) / LATENT_SCALE
    with torch.no_grad():
        img = vae.decode(latents).sample
        img = (img / 2 + 0.5).clamp(0, 1)
    return img


# ------------------------------
def encode_images_to_latents_half(images, vae):
    """Encode images to latents using the VAE's real device/dtype.

    Accepts [B,C,F,H,W] video batches or [B,C,H,W] single frames; always
    returns [B,4,F,H,W] latents (F=1 for single frames), scaled by
    LATENT_SCALE.
    """
    vae_device = next(vae.parameters()).device
    vae_dtype = next(vae.parameters()).dtype

    images = images.to(device=vae_device, dtype=vae_dtype)

    with torch.no_grad():

        if images.dim() == 5:  # [B,C,F,H,W]
            B, C, F, H, W = images.shape
            images_2d = images.view(B * F, C, H, W)

            latents_2d = vae.encode(images_2d).latent_dist.sample()
            latents_2d = latents_2d * LATENT_SCALE

            latents = latents_2d.view(
                B, F,
                latents_2d.shape[1],
                latents_2d.shape[2],
                latents_2d.shape[3]
            )

            latents = latents.permute(0, 2, 1, 3, 4).contiguous()

        else:
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * LATENT_SCALE
            latents = latents.unsqueeze(2)

    return latents


def decode_latents_to_image_vae(latents, vae):
    """Decode latents on the VAE's device/dtype; returns float32 for PNG
    saving."""
    vae_device = next(vae.parameters()).device
    vae_dtype = next(vae.parameters()).dtype

    # Align dtype/device with the VAE, then undo the encoding scale.
    latents = latents.to(device=vae_device, dtype=vae_dtype)
    latents = latents / LATENT_SCALE

    with torch.no_grad():
        img = vae.decode(latents).sample
        img = (img / 2 + 0.5).clamp(0, 1)

    # Back to float32 for PNG saving.
    return img.float()


# -------------------------
# Encode / Decode (original float32 variants)
# -------------------------
def encode_images_to_latents_ori(images, vae):
    """Original float32 encoder: [B,C,F,H,W] or [B,C,H,W] → [B,4,F,H,W]."""
    device = vae.device
    images = images.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        if images.dim() == 5:  # [B,C,F,H,W]
            B, C, F, H, W = images.shape
            images_2d = images.view(B * F, C, H, W)
            latents_2d = vae.encode(images_2d).latent_dist.sample() * LATENT_SCALE
            latent_shape = latents_2d.shape
            latents = latents_2d.view(B, F, latent_shape[1], latent_shape[2], latent_shape[3])
            latents = latents.permute(0, 2, 1, 3, 4).contiguous()
        else:
            latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE
            latents = latents.unsqueeze(2)
    return latents


def decode_latents_to_image_ori(latents, vae):
    """Original float32 decoder."""
    latents = latents.to(vae.device).float() / LATENT_SCALE
    with torch.no_grad():
        img = vae.decode(latents).sample
    img = (img / 2 + 0.5).clamp(0, 1)
    return img


# --------------------------------------------------------
# | Mode        | VAE     | Images  | Latents | Result   |
# | ----------- | ------- | ------- | ------- | -------- |
# | fp32        | float32 | float32 | float32 | OK       |
# | fp16        | float16 | float16 | float16 | OK       |
# | offload CPU | float32 | float32 | float32 | OK       |
# ------------------------- below:
# -------------------------
# Encode / Decode — FP16-safe corrected variants
# -------------------------
def encode_images_to_latents(images, vae):
    """FP16-safe encoder: aligns images with the VAE's dtype, no frame dim."""
    device = next(vae.parameters()).device
    dtype = next(vae.parameters()).dtype  # align with the VAE
    images = images.to(device=device, dtype=dtype)
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE
    return latents


def decode_latents_to_image(latents, vae):
    """FP16-safe decoder: forces latents onto the VAE's device/dtype and
    returns float32 images in [0,1] so torchvision save_image works."""
    vae_dtype = next(vae.parameters()).dtype
    vae_device = next(vae.parameters()).device
    latents = latents.to(device=vae_device, dtype=vae_dtype) / LATENT_SCALE

    with torch.no_grad():
        img = vae.decode(latents).sample

    # Safe normalisation to 0-1.
    img = (img / 2 + 0.5).clamp(0, 1)

    # FP16 → float32 for torchvision save_image.
    if img.dtype == torch.float16:
        img = img.float()

    return img


# ---------------------------------------------------------
# Decode latents to image, with logging and a safe fallback
# ---------------------------------------------------------
def decode_latents_to_image_2(latents, vae, latent_scale=0.18215):
    """Verbose decoder: logs shapes and NaN checks at every stage.

    latents: [B, C, F, H, W] or [B, C, 1, H, W] for a single frame.
    Returns a black image of the expected size if the VAE raises.
    """
    try:
        print(f"🔹 decode_latents_to_image_2 | input shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")

        # Squeeze a singleton frame dimension.
        if latents.shape[2] == 1:
            latents = latents.squeeze(2)
            print(f"🔹 Squeeze frame dimension → shape: {latents.shape}")

        # Align dtype/device with the VAE.
        vae_dtype = next(vae.parameters()).dtype
        vae_device = next(vae.parameters()).device
        latents = latents.to(device=vae_device, dtype=vae_dtype) / latent_scale

        # NaN check before the VAE.
        print(f"🔹 Latents before VAE decode | min: {latents.min()}, max: {latents.max()}, dtype: {latents.dtype}")
        if torch.isnan(latents).any():
            print("❌ Warning: NaN detected in latents before VAE decode!")

        with torch.no_grad():
            img = vae.decode(latents).sample

        # NaN check after decoding.
        print(f"🔹 Image after VAE decode | min: {img.min()}, max: {img.max()}, dtype: {img.dtype}")
        if torch.isnan(img).any():
            print("❌ Warning: NaN detected in decoded image!")

        # Safe normalisation to 0-1.
        img = (img / 2 + 0.5).clamp(0, 1)
        print(f"🔹 Image final | min: {img.min()}, max: {img.max()}, dtype: {img.dtype}, shape: {img.shape}")

        # FP16 → FP32 if needed.
        if img.dtype == torch.float16:
            img = img.float()

        return img

    except Exception as e:
        print(f"❌ Exception in decode_latents_to_image_2: {e}")
        # Safe black image if the VAE fails (≈8x upscale for the SD VAE).
        B, C, H, W = latents.shape[:4]
        return torch.zeros(B, 3, H * 8, W * 8, device=latents.device)


def decode_latents_to_image_old(latents, vae):
    """Legacy decoder kept for compatibility; returns float32 for PNG."""
    vae_device = next(vae.parameters()).device
    vae_dtype = next(vae.parameters()).dtype
    latents = latents.to(device=vae_device, dtype=vae_dtype)
    latents = latents / LATENT_SCALE
    with torch.no_grad():
        img = vae.decode(latents).sample
    img = (img / 2 + 0.5).clamp(0, 1)
    return img.float()  # back to float32 for PNG


# -------------------------
# Image utilities
# -------------------------
def load_image_file(path, W, H, device, dtype):
    """Load one image, resize to (W,H) and map to a CHW tensor in [-1,1]."""
    img = Image.open(path).convert("RGB")
    img = img.resize((W, H), Image.LANCZOS)
    img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).to(device=device, dtype=dtype) / 127.5 - 1.0
    return img_tensor


def load_images_test(paths, W, H, device, dtype):
    """Debug loader: loads stills and GIF frames with shape/min-max logging.

    NOTE(review): GIF frames are NOT resized to (W,H) nor converted to RGB
    here, so mixed-size inputs or palette GIFs can break torch.stack below —
    kept as-is because this is the explicit test variant (see load_images
    for the fixed behaviour).
    """
    all_tensors = []

    for p in paths:

        if p.lower().endswith(".gif"):

            img = Image.open(p)

            for f in ImageSequence.Iterator(img):

                t = torch.tensor(np.array(f)).float() / 127.5 - 1.0

                if t.ndim == 3:
                    t = t.permute(2, 0, 1)

                all_tensors.append(t)

            print(f"✅ GIF chargé : {p}")

        else:

            t = load_image_file(p, W, H, device="cpu", dtype=torch.float32)

            # No-op in practice: load_image_file already returns CHW.
            if t.ndim == 3 and t.shape[-1] == 3:
                t = t.permute(2, 0, 1)

            all_tensors.append(t)

            print(f"✅ Image chargée : {p}")

    imgs = torch.stack(all_tensors, dim=0)

    print("IMAGE SHAPE:", imgs.shape)
    print("IMAGE MIN/MAX:", imgs.min().item(), imgs.max().item())

    return imgs.to(device=device, dtype=dtype)


def load_images(paths, W, H, device, dtype):
    """Load a mix of stills and GIFs into one [N,3,H,W] tensor in [-1,1]."""
    all_tensors = []
    for p in paths:
        if p.lower().endswith(".gif"):
            img = Image.open(p)
            # FIX: convert + resize each GIF frame. Palette ("P") frames give
            # 2-D arrays that crash permute(2,0,1), and unresized frames
            # cannot be stacked with the (W,H)-resized stills below.
            frames = [torch.tensor(np.array(f.convert("RGB").resize((W, H), Image.LANCZOS))).permute(2, 0, 1).to(device=device, dtype=dtype) / 127.5 - 1.0
                      for f in ImageSequence.Iterator(img)]
            print(f"✅ GIF chargé : {p} avec {len(frames)} frames")
            all_tensors.extend(frames)
        else:
            t = load_image_file(p, W, H, device, dtype)
            print(f"✅ Image chargée : {p}")
            all_tensors.append(t)
    return torch.stack(all_tensors, dim=0)


def load_images_s(paths, W, H, device, dtype):
    """Still-image-only loader: every path is treated as a single RGB image."""
    all_tensors = []
    for p in paths:
        img = Image.open(p).convert("RGB").resize((W, H), Image.LANCZOS)
        t = torch.tensor(np.array(img)).permute(2, 0, 1).to(device=device, dtype=dtype) / 127.5 - 1.0
        all_tensors.append(t)
    return torch.stack(all_tensors, dim=0)


def load_images_all(paths, W, H, device, dtype):
    """Duplicate of load_images kept for compatibility (same fix applied)."""
    all_tensors = []
    for p in paths:
        if p.lower().endswith(".gif"):
            img = Image.open(p)
            # FIX: see load_images — convert/resize GIF frames before stacking.
            frames = [torch.tensor(np.array(f.convert("RGB").resize((W, H), Image.LANCZOS))).permute(2, 0, 1).to(device=device, dtype=dtype) / 127.5 - 1.0
                      for f in ImageSequence.Iterator(img)]
            print(f"✅ GIF chargé : {p} avec {len(frames)} frames")
            all_tensors.extend(frames)
        else:
            t = load_image_file(p, W, H, device, dtype)
            print(f"✅ Image chargée : {p}")
            all_tensors.append(t)
    return torch.stack(all_tensors, dim=0)


# -------------------------
# Save video
# -------------------------
def save_frames_as_video_rmtmp(frames, output_path, fps=12):
    """Write PIL frames to a temp dir, encode with ffmpeg, then clean up."""
    temp_dir = Path("temp_frames")
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    temp_dir.mkdir()
    for idx, frame in enumerate(frames):
        frame.save(temp_dir / f"frame_{idx:05d}.png")
    (
        ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps)
        .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p")
        .overwrite_output()
        .run(quiet=True)
    )
    shutil.rmtree(temp_dir)


# -------------------------
# Video utilities
# -------------------------
def save_frames_as_video(frames, output_path, fps=12):
    """Duplicate of save_frames_as_video_rmtmp kept for compatibility."""
    temp_dir = Path("temp_frames")
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    temp_dir.mkdir()

    for idx, frame in enumerate(frames):
        frame.save(temp_dir / f"frame_{idx:05d}.png")

    (
        ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps)
        .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p")
        .overwrite_output()
        .run(quiet=True)
    )
    shutil.rmtree(temp_dir)


# -----------------------------
# -------- MOTION -------------
# -----------------------------
def default_motion_module(latents: torch.Tensor, frame_idx: int = 0, total_frames: int = 1) -> torch.Tensor:
    """Default motion module for Tiny-SD, with small random effects:

    - subtle random translation (pan/tilt via torch.roll)
    - soft sinusoidal oscillation over the frame index
    - subtle "zoom" (±2% scaling of the latent VALUES only — H/W are
      untouched, so it is safe for the scheduler)
    - light noise on the latents
    """
    # Subtle zoom (value scale only).
    zoom_factor = random.uniform(0.98, 1.02)  # ±2%
    latents = latents * zoom_factor

    B, C, H, W = latents.shape

    # Subtle random pan/tilt.
    dx = random.randint(-1, 1)
    dy = random.randint(-1, 1)

    # Soft oscillation.
    osc_amp = 1
    dx_osc = int(osc_amp * math.sin(2 * math.pi * frame_idx / max(total_frames, 1)))
    dy_osc = int(osc_amp * math.cos(2 * math.pi * frame_idx / max(total_frames, 1)))

    # Safe pan/tilt with torch.roll (wraps around the borders).
    latents = torch.roll(latents, shifts=(dy + dy_osc, dx + dx_osc), dims=(2, 3))

    # Light noise.
    noise_sigma = 0.003
    latents = latents + torch.randn_like(latents) * noise_sigma

    return latents


def default_motion_module_test(latents: torch.Tensor, frame_idx: int = 0, total_frames: int = 1) -> torch.Tensor:
    """Experimental motion module: real spatial zoom via interpolate (cropped
    back to H×W) plus a coordinate-remap oscillation."""
    # Subtle zoom.
    zoom_factor = 0.01
    factor = 1.0 + zoom_factor * frame_idx / max(total_frames, 1)
    B, C, H, W = latents.shape
    latents = F.interpolate(latents, scale_factor=factor, mode='bilinear', align_corners=False)
    latents = latents[:, :, :H, :W]  # crop back if needed

    # Subtle oscillation via clamped coordinate remap (broadcasts over B, C).
    dx = int(1 * math.sin(2 * math.pi * frame_idx / max(total_frames, 1)))
    dy = int(1 * math.cos(2 * math.pi * frame_idx / max(total_frames, 1)))
    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid_x = torch.clamp(grid_x - dx, 0, W - 1)
    grid_y = torch.clamp(grid_y - dy, 0, H - 1)
    latents = latents[:, :, grid_y, grid_x]

    return latents


# -------------------------
# Config loader
# -------------------------
def load_config(config_path):
    """Load a YAML config file into a dict."""
    with open(config_path, "r") as f:
        return yaml.safe_load(f)
# -------------------------
# Video upscaling with Real-ESRGAN
# -------------------------
def upscale_video_with_realesrgan(video_path, output_path, device="cuda", scale=2, fps=12):
    """Upscale a video with Real-ESRGAN (torch).

    - video_path: input video path
    - output_path: output video path
    - scale: upscale factor (2, 4, ...)
    - fps: output framerate

    Frames are extracted with ffmpeg, upscaled one by one, then re-encoded.
    Prints a message and returns early if realesrgan is not installed.
    """
    try:
        from realesrgan import RealESRGAN
    except ImportError:
        print("❌ Module Real-ESRGAN non installé. pip install realesrgan")
        return

    import tempfile
    from PIL import Image
    import ffmpeg

    temp_dir = tempfile.mkdtemp()
    temp_out_dir = tempfile.mkdtemp()
    try:
        # 1. Extract frames.
        (
            ffmpeg
            .input(str(video_path))
            .output(f"{temp_dir}/frame_%05d.png")
            .overwrite_output()
            .run(quiet=True)
        )

        # 2. Load the model (downloads the weights on first use).
        model = RealESRGAN(device, scale=scale)
        model.load_weights(f"RealESRGAN_x{scale}.pth", download=True)

        # 3. Upscale frame by frame.
        frame_paths = sorted(Path(temp_dir).glob("frame_*.png"))
        for idx, fpath in enumerate(frame_paths):
            img = Image.open(fpath).convert("RGB")
            upscaled = model.predict(img)
            upscaled.save(Path(temp_out_dir) / f"frame_{idx:05d}.png")

        # 4. Re-encode the video.
        (
            ffmpeg
            .input(f"{temp_out_dir}/frame_%05d.png", framerate=fps)
            .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p", crf=18)
            .overwrite_output()
            .run(quiet=True)
        )
    finally:
        # FIX: the original leaked both temp dirs whenever extraction,
        # weight download or encoding raised — always clean up.
        shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(temp_out_dir, ignore_errors=True)

    print(f"✅ Vidéo finale Real-ESRGAN x{scale} générée : {output_path}")


# -------------------------
# Cinematic upscale (no torch)
# -------------------------
def upscale_video_cinematic_smooth(input_video_path, output_video_path, scale=4, fps=12, interp_frames=1):
    """Upscale x`scale` with LANCZOS plus linear frame interpolation for a
    smooth cinematic look.

    - input_video_path: low-res input video (e.g. 128x128)
    - output_video_path: final upscaled video
    - scale: upscale factor (default 4)
    - interp_frames: number of blended frames inserted between two frames
    """
    input_video_path = Path(input_video_path)
    temp_dir = input_video_path.parent / "temp_frames_smooth"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    temp_dir.mkdir()

    # Extract frames.
    (
        ffmpeg
        .input(str(input_video_path))
        .output(str(temp_dir / "frame_%05d.png"))
        .overwrite_output()
        .run(quiet=True)
    )

    frame_paths = sorted(temp_dir.glob("frame_*.png"))
    # FIX: fail fast (with cleanup) if extraction produced no frames — the
    # original crashed later inside the encode step with a cryptic error.
    if not frame_paths:
        shutil.rmtree(temp_dir)
        raise ValueError(f"No frames extracted from {input_video_path}")

    upscaled_frames = []

    for idx in range(len(frame_paths)):
        frame = Image.open(frame_paths[idx])
        W, H = frame.size
        frame_up = frame.resize((W * scale, H * scale), Image.LANCZOS)

        upscaled_frames.append(frame_up)

        # Linear interpolation towards the next frame for smoothing.
        if idx < len(frame_paths) - 1 and interp_frames > 0:
            next_frame = Image.open(frame_paths[idx + 1]).resize((W * scale, H * scale), Image.LANCZOS)
            for t in range(1, interp_frames + 1):
                alpha = t / (interp_frames + 1)
                interp_frame = Image.blend(frame_up, next_frame, alpha)
                upscaled_frames.append(interp_frame)

    # Re-encode the video.
    temp_up_dir = input_video_path.parent / "temp_frames_upscaled_smooth"
    if temp_up_dir.exists():
        shutil.rmtree(temp_up_dir)
    temp_up_dir.mkdir()

    for i, f in enumerate(upscaled_frames):
        f.save(temp_up_dir / f"frame_{i:05d}.png")

    (
        ffmpeg
        .input(str(temp_up_dir / "frame_%05d.png"), framerate=fps * (interp_frames + 1))
        .output(str(output_video_path), vcodec="libx264", pix_fmt="yuv420p", crf=18)
        .overwrite_output()
        .run(quiet=True)
    )

    shutil.rmtree(temp_dir)
    shutil.rmtree(temp_up_dir)
    print(f"✅ Vidéo finale x{scale} smooth cinematic générée : {output_video_path}")
#---------------------------------------------------------------------------------------
# -------------- Watermark removal helpers ----------------------------------------------
#---------------------------------------------------------------------------------------

def remove_watermark_auto_blur_v1(
    frame_pil,
    target_hex_list,
    tolerance=40,
    threshold=0.45,
    candidate_zones=None,
    blur_radius=10,
    feather_radius=10,
    show_mask=False
):
    """ULTRA-INVISIBLE PRO version of watermark removal:

    - colour detection against ``target_hex_list`` (``#rrggbb`` strings)
    - progressive (feathered) mask
    - soft blend of a blurred copy over the detected region

    Mutates and returns ``frame_pil``.
    NOTE(review): mask_total accumulates one count per matching colour, so
    ratio can exceed 1.0 when colours overlap — harmless for the >= test.
    """

    import numpy as np
    import cv2
    from PIL import Image, ImageFilter

    img_np = np.array(frame_pil).astype(np.int16)
    H, W, _ = img_np.shape

    # Hex → RGB triples.
    target_colors = np.array(
        [[int(h[i:i + 2], 16) for i in (1, 3, 5)] for h in target_hex_list],
        dtype=np.int16
    )

    if candidate_zones is None:
        candidate_zones = [(0, 0, W, H)]

    for idx, (x, y, w, h) in enumerate(candidate_zones):

        patch = img_np[y:y + h, x:x + w]
        mask_total = np.zeros((h, w), dtype=np.uint8)

        # Colour detection.
        for color in target_colors:
            dist = np.linalg.norm(patch - color, axis=2)
            mask_total += (dist <= tolerance).astype(np.uint8)

        ratio = mask_total.sum() / (w * h)

        if ratio >= threshold:

            # Binary mask.
            mask_binary = (mask_total > 0).astype(np.uint8) * 255

            # Slight dilation to cover the borders.
            kernel = np.ones((3, 3), np.uint8)
            mask_binary = cv2.dilate(mask_binary, kernel, iterations=1)

            # Feather (progressive softening).
            mask_soft = cv2.GaussianBlur(mask_binary, (0, 0), feather_radius)

            # Normalise the mask to 0..1.
            mask_soft = mask_soft.astype(np.float32) / 255.0
            mask_soft = np.expand_dims(mask_soft, axis=2)

            # Extract the region.
            region = frame_pil.crop((x, y, x + w, y + h))
            region_np = np.array(region).astype(np.float32)

            # Natural-looking light blur.
            region_blur = region.filter(ImageFilter.GaussianBlur(radius=blur_radius))
            region_blur_np = np.array(region_blur).astype(np.float32)

            # Ultra-soft blend.
            blended = region_np * (1 - mask_soft) + region_blur_np * mask_soft
            blended = blended.astype(np.uint8)

            # Paste back into the image.
            frame_pil.paste(Image.fromarray(blended), (x, y))

    return frame_pil


def remove_watermark_auto_blur_simple(frame_pil, target_hex_list, tolerance=15,
                                      threshold=0.5, candidate_zones=None,
                                      blur_radius=20, show_mask=False):
    """Detect watermark zones by colour and blur them (simple variant).

    - frame_pil: PIL image (mutated and returned)
    - target_hex_list: watermark colours as ``#rrggbb`` strings
    - tolerance: colour tolerance
    - threshold: minimum matching-pixel ratio to act on a zone
    - candidate_zones: likely zones [(x,y,w,h), ...]; whole image if None
    - blur_radius: Gaussian blur radius
    - show_mask: show the detection mask for debugging (needs matplotlib)
    """
    import numpy as np
    import cv2
    from PIL import Image, ImageFilter

    img_np = np.array(frame_pil).astype(np.int16)
    H, W, _ = img_np.shape

    # Hex → RGB triples.
    target_colors = np.array([[int(h[i:i + 2], 16) for i in (1, 3, 5)] for h in target_hex_list], dtype=np.int16)

    if candidate_zones is None:
        candidate_zones = [(0, 0, W, H)]

    for idx, (x, y, w, h) in enumerate(candidate_zones):
        patch = img_np[y:y + h, x:x + w]
        mask_total = np.zeros((h, w), dtype=np.uint8)

        # Detect pixels matching the watermark colours.
        for c_idx, color in enumerate(target_colors):
            dist = np.linalg.norm(patch - color, axis=2)
            mask_total += (dist <= tolerance).astype(np.uint8)
            if show_mask:
                print(f"[DEBUG] Zone {idx}, Couleur {c_idx} match pixels: {(dist <= tolerance).sum()}")

        ratio = mask_total.sum() / (w * h)
        if show_mask:
            print(f"[DEBUG] Zone {idx} ratio total: {ratio:.3f} (seuil={threshold})")
            import matplotlib.pyplot as plt
            plt.imshow(mask_total, cmap='gray')
            plt.title(f"Zone {idx} mask")
            plt.show()

        if ratio >= threshold:
            # Binary mask for OpenCV contour search.
            mask_uint8 = (mask_total > 0).astype(np.uint8) * 255
            contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            for cnt in contours:
                bx, by, bw, bh = cv2.boundingRect(cnt)
                left = x + bx
                top = y + by
                right = left + bw
                bottom = top + bh

                # Extract and blur the region.
                region = frame_pil.crop((left, top, right, bottom))
                region_blur = region.filter(ImageFilter.GaussianBlur(radius=blur_radius))
                frame_pil.paste(region_blur, (left, top))

            if show_mask:
                print(f"[DEBUG] Zone {idx} floutée.")

    return frame_pil


def remove_watermark_white(frame_pil, bbox, padding=2):
    """Replace the watermark with a white rectangle.

    bbox = (x, y, w, h); padding widens the rectangle to cover the edges.
    Returns frame_pil unchanged when bbox is None.
    """
    if bbox is None:
        return frame_pil

    x, y, w, h = bbox

    # Small padding to cover the borders, clamped to the image.
    x = max(0, x - padding)
    y = max(0, y - padding)
    w += padding * 2
    h += padding * 2

    img_w, img_h = frame_pil.size
    w = min(w, img_w - x)
    h = min(h, img_h - y)

    # Draw the white rectangle.
    from PIL import ImageDraw
    draw = ImageDraw.Draw(frame_pil)
    draw.rectangle([x, y, x + w, y + h], fill=(255, 255, 255))

    return frame_pil


def decode_latents_ultrasafe_blockwise_test(
    latents, vae,
    block_size=32, overlap=16,
    gamma=1.0, brightness=1.0,
    contrast=1.0, saturation=1.0,
    device="cuda", frame_counter=0, output_dir=Path("."),
    epsilon=1e-5  # minimal value to avoid all-zero patches
):
    """Ultra-safe blockwise decoding of latents to a PIL image, with
    automatic correction and per-patch min/max debug logging.

    - latents: [B, 4, H, W] on CPU or GPU
    - vae: loaded VAE (fp16 or fp32)
    - block_size / overlap: patch tiling parameters
    - gamma/brightness/contrast/saturation: final colour correction
    - device: target device for the assembled RGB output
    - epsilon: fill value for all-zero patches

    NOTE(review): overlapping patches simply overwrite each other (no
    blending), and the per-patch min/max renormalisation can differ between
    patches. NOTE(review): F.to_pil_image / F.to_tensor below assume F is
    torchvision.transforms.functional, while other functions in this file
    use F as torch.nn.functional — confirm the module-level import.
    """
    vae_dtype = next(vae.parameters()).dtype
    B, C, H, W = latents.shape
    output_rgb = torch.zeros(B, 3, H * 8, W * 8, device=device, dtype=torch.float32)

    # FIX: overlap >= block_size used to hand range() a step <= 0
    # (ValueError); clamp the stride to at least 1.
    stride = max(1, block_size - overlap)
    y_steps = list(range(0, H, stride))
    x_steps = list(range(0, W, stride))

    for y0 in y_steps:
        y1 = min(y0 + block_size, H)
        for x0 in x_steps:
            x1 = min(x0 + block_size, W)
            patch = latents[:, :, y0:y1, x0:x1].to(device=device, dtype=vae_dtype)

            # Debug before the VAE.
            print(f"[DEBUG] patch avant VAE ({y0},{x0}): shape={patch.shape}, "
                  f"dtype={patch.dtype}, min={patch.min():.6f}, max={patch.max():.6f}")

            # NaN/Inf correction and minimal epsilon.
            patch = torch.nan_to_num(patch, nan=0.0, posinf=5.0, neginf=-5.0)
            if torch.all(patch == 0):
                patch += epsilon

            # Log the patch.
            patch_idx = f"{y0}_{x0}"
            log_patch_stats(frame_idx=frame_counter, patch_idx=patch_idx, patch=patch, csv_path=output_dir / "patch_stats.csv")

            # Decode.
            with torch.no_grad():
                patch_decoded = vae.decode(patch).sample  # [B, 3, h*8, w*8]
                patch_decoded = patch_decoded.to(torch.float32)

            # Automatic recentring to avoid dark frames.
            min_val = patch_decoded.min()
            max_val = patch_decoded.max()
            if max_val - min_val > 1e-6:
                patch_decoded = (patch_decoded - min_val) / (max_val - min_val)

            # Debug after the VAE.
            log_patch_stats(frame_idx=frame_counter, patch_idx=patch_idx + "_decoded", patch=patch_decoded, csv_path=output_dir / "patch_stats.csv")

            h_start, h_end = y0 * 8, y1 * 8
            w_start, w_end = x0 * 8, x1 * 8
            output_rgb[:, :, h_start:h_end, w_start:w_end] = patch_decoded

    # Debug before clamping.
    print(f"[DEBUG] output_rgb final avant clamp: shape={output_rgb.shape}, "
          f"min={output_rgb.min():.6f}, max={output_rgb.max():.6f}")

    output_rgb = output_rgb.clamp(0.0, 1.0)

    # Gamma / contrast / brightness / saturation correction.
    frame_pil_list = []
    for i in range(B):
        img = F.to_pil_image(output_rgb[i])
        img = ImageEnhance.Brightness(img).enhance(brightness)
        img = ImageEnhance.Contrast(img).enhance(contrast)
        img = ImageEnhance.Color(img).enhance(saturation)
        img = img.point(lambda x: (x / 255) ** (1 / gamma) * 255)
        frame_pil_list.append(img)

        # Debug after the PIL corrections.
        pil_tensor = F.to_tensor(img)
        print(f"[DEBUG] frame {i} après PIL correction: shape={pil_tensor.shape}, "
              f"min={pil_tensor.min():.6f}, max={pil_tensor.max():.6f}")

    return frame_pil_list[0] if B == 1 else frame_pil_list


def encode_images_to_latents_safe(images, vae, device="cuda", epsilon=1e-5):
    """
    Encode des images en latents sûrs pour UNet.
    Retourne toujours un tensor [B, 4, H_latent, W_latent], dtype=vae.dtype.

    - epsilon : valeur minimale ajoutée aux latents nuls pour éviter frames noires
    """
    images_t = images.to(device=device, dtype=torch.float32)
    original_dtype = next(vae.parameters()).dtype

    vae = vae.to(device=device, dtype=torch.float32)  # safe pour l'encodage

    with torch.no_grad():
        latents = vae.encode(images_t).latent_dist.sample()

    print(f"[DEBUG encode] latents min/max après sample: {latents.min().item():.6f}/{latents.max().item():.6f}")

    # Scaling
    latents = latents * LATENT_SCALE
    print(f"[DEBUG encode] latents min/max après LATENT_SCALE: {latents.min().item():.6f}/{latents.max().item():.6f}")

    # Clamp NaN / Inf
    latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0)

    # ✅ Remplir les latents entièrement nuls
    if torch.all(latents == 0):
        latents += epsilon

    # Normalisation pour éviter overflow
    max_abs = latents.abs().max()
    if max_abs > 0:
        latents = latents / max_abs

    # Conversion en dtype final du VAE
    latents = latents.to(original_dtype)

    #
---------------- FORCE 4 CANAUX ---------------- + if latents.ndim == 4 and latents.shape[1] == 1: + latents = latents.repeat(1, 4, 1, 1) + if latents.ndim == 5 and latents.shape[2] == 1: + latents = latents.repeat(1, 1, 4, 1, 1) + + print(f"[DEBUG encode] shape finale latents: {latents.shape}, min/max: {latents.min().item():.6f}/{latents.max().item():.6f}") + return latents diff --git a/scripts/utils/safe_latent.py b/scripts/utils/safe_latent.py new file mode 100644 index 00000000..1787781f --- /dev/null +++ b/scripts/utils/safe_latent.py @@ -0,0 +1,13 @@ +import torch + +def ensure_valid(latents, eps=1e-3): + """ + Remplace NaN/inf et ajoute un petit bruit si le latent est trop faible. + """ + latents = torch.nan_to_num(latents, nan=eps, posinf=eps, neginf=-eps) + if latents.abs().mean() < eps: + latents += torch.randn_like(latents) * eps + return latents.clamp(-3.0, 3.0) + + + diff --git a/scripts/utils/tools_utils.py b/scripts/utils/tools_utils.py new file mode 100644 index 00000000..5205f486 --- /dev/null +++ b/scripts/utils/tools_utils.py @@ -0,0 +1,761 @@ +# -------------------------------------------------------------- +# tools_utils.py - Fonctions utilitaires génériques +# -------------------------------------------------------------- +import os, math +import hashlib +from PIL import Image, ImageEnhance +from torchvision.transforms import ToPILImage +from torchvision.transforms import functional as F +from .fx_utils import apply_post_processing_adaptive + +LATENT_SCALE = 0.18215 # valeur globale, peut être importée si nécessaire + +import json +import torch +from pathlib import Path + + +# ---------------- Utilitaires prompt ---------------- +def encode_prompts_batch(prompts, negative_prompts, tokenizer, text_encoder, + device="cuda", projection=None): + """ + Encode une liste de prompts positifs et négatifs en embeddings utilisables pour la génération. 
+ + Args: + prompts (list[str] or list[list[str]]): prompts positifs + negative_prompts (list[str] or list[list[str]]): prompts négatifs + tokenizer: tokenizer du modèle textuel + text_encoder: modèle textuel pour embeddings + device (str): "cuda" ou "cpu" + projection (callable, optional): fonction pour transformer les embeddings (ex: LoRA) + + Returns: + pos_embeds_list, neg_embeds_list : listes de torch.Tensor [B, seq_len, dim] + """ + pos_embeds_list = [] + neg_embeds_list = [] + + for i, prompt_item in enumerate(prompts): + # Texte positif + prompt_text = " ".join(prompt_item) if isinstance(prompt_item, list) else str(prompt_item) + # Texte négatif correspondant (fallback au premier si trop court) + neg_text_item = negative_prompts[i] if i < len(negative_prompts) else negative_prompts[0] + neg_text = " ".join(neg_text_item) if isinstance(neg_text_item, list) else str(neg_text_item) + + # Tokenize + text_inputs = tokenizer( + prompt_text, + padding="max_length", + truncation=True, + max_length=tokenizer.model_max_length, + return_tensors="pt" + ) + neg_inputs = tokenizer( + neg_text, + padding="max_length", + truncation=True, + max_length=tokenizer.model_max_length, + return_tensors="pt" + ) + + # Encoder + with torch.no_grad(): + pos_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state + neg_embeds = text_encoder(neg_inputs.input_ids.to(device)).last_hidden_state + + # Appliquer projection si fourni + if projection is not None: + pos_embeds = projection(pos_embeds) + neg_embeds = projection(neg_embeds) + + # Ajouter à la liste + pos_embeds_list.append(pos_embeds) + neg_embeds_list.append(neg_embeds) + + return pos_embeds_list, neg_embeds_list + +# ---------------- Utilitaires motion ---------------- +def apply_motion_safe(latents, motion_module, threshold=1e-3): + if latents.abs().max() < threshold: + return latents, False + return motion_module(latents), True + +def save_input_frame(input_image, output_dir, frame_counter, pbar=None, + 
blur_radius=0.0, contrast=1.0, saturation=1.0, apply_post=False): + try: + from torchvision.transforms.functional import to_pil_image + + # Tensor → CPU + clamp + img = input_image[0].detach().cpu().clamp(-1, 1) + + # [-1,1] → [0,1] + img = (img + 1) / 2 + + # → PIL + img_pil = to_pil_image(img) + + # Option post-process + if apply_post: + img_pil = apply_post_processing_adaptive( + img_pil, + blur_radius=blur_radius, + contrast=contrast, + brightness=1.0, + saturation=saturation + ) + + # Save + img_pil.save(output_dir / f"frame_{frame_counter:05d}_00.png") + print(f"[INPUT SAVE Frame {frame_counter:03d}_00]") + + # Update compteur + progress bar + frame_counter += 1 + if pbar: + pbar.update(1) + + return frame_counter + + except Exception as e: + print(f"[INPUT SAVE ERROR] {e}") + return frame_counter + +def get_dynamic_latent_injection(frame_counter, total_frames, start=0.90, end=0.55, mode="cosine"): + """ + Calcule latent_injection pour chaque frame, avec protection contre division par zéro. + """ + if total_frames <= 1: + return start # Pas de progression possible + + t = frame_counter / (total_frames - 1) # toujours safe + if mode == "cosine": + alpha = 0.5 - 0.5 * math.cos(math.pi * t) + else: + alpha = t + latent_injection = start + (end - start) * alpha + return min(max(latent_injection, 0.0), 1.0) + +# ------------------------------------------------------------------------------------------- +# --- Sélection simple des embeddings prompts par frame --- +def get_embeddings_for_frame(frame_idx, frames_per_prompt, pos_list, neg_list, device="cuda"): + #Retourne les embeddings du prompt correspondant à la frame_idx. Chaque prompt produit `frames_per_prompt` frames consécutives. 
+ num_prompts = len(pos_list) + prompt_idx = min(frame_idx // frames_per_prompt, num_prompts - 1) + return pos_list[prompt_idx].to(device), neg_list[prompt_idx].to(device) + + +def adapt_embeddings_to_unet(pos_embeds, neg_embeds, target_dim): + """Adapte automatiquement les embeddings texte pour correspondre au cross_attention_dim du UNet.""" + current_dim = pos_embeds.shape[-1] + if current_dim == target_dim: + return pos_embeds, neg_embeds + # Troncature + if current_dim > target_dim: + pos_embeds = pos_embeds[..., :target_dim] + neg_embeds = neg_embeds[..., :target_dim] + # Padding + elif current_dim < target_dim: + pad = target_dim - current_dim + pos_embeds = torch.nn.functional.pad(pos_embeds, (0, pad)) + neg_embeds = torch.nn.functional.pad(neg_embeds, (0, pad)) + return pos_embeds, neg_embeds + +def compute_weighted_params(frame_idx, total_frames, + init_start=0.85, init_end=0.5, + noise_start=0.0, noise_end=0.08, + guidance_start=3.5, guidance_end=4.5, + mode="cosine"): + """ + Calcule init_image_scale, creative_noise et guidance_scale de manière pondérée. + Ajuste guidance en fonction du signal de l'image et du bruit. 
+ """ + # interpolation linéaire ou cosinus + def interp(a, b, t, mode="cosine"): + if mode=="cosine": + mu = (1 - math.cos(math.pi * t)) / 2 + else: + mu = t + return a*(1-mu) + b*mu + + t = frame_idx / max(total_frames-1,1) + init_scale = interp(init_start, init_end, t, mode) + creative_noise = interp(noise_start, noise_end, t, mode) + + # pondération guidance_scale : + # si init_scale élevé → fidèle → guidance plus faible + # si init_scale faible → moins fidèle → guidance plus créative + base_guidance = interp(guidance_start, guidance_end, t, mode) + weighted_guidance = base_guidance * (1 + 0.5*(1-init_scale)) * (1 - 0.5*creative_noise) + + return init_scale, creative_noise, weighted_guidance + +def load_external_embedding_as_latent(path, target_shape): + from safetensors.torch import load_file + emb = load_file(path) + clip_vec = list(emb.values())[0] + + # projection simple + latent = clip_vec.mean() * torch.randn(target_shape) + return latent + +def inject_external_embeddings( + latents, + external_embeddings, + device, + normalize=True, + clamp_range=(-1.0, 1.0) +): + """ + Injecte des embeddings latents externes dans les latents principaux. 
+ + Args: + latents (torch.Tensor): latents [B,C,H,W] + external_embeddings (list[dict]): liste de dicts avec : + - "latent": tensor + - "weight": float + - "type": "positive" ou "negative" + device (str): device cible + normalize (bool): normalise les embeddings pour éviter domination + clamp_range (tuple): clamp final + + Returns: + torch.Tensor: latents modifiés + """ + + if not external_embeddings: + return latents + + latents = latents.to(device) + + for emb in external_embeddings: + try: + ext = emb.get("latent", None) + weight = float(emb.get("weight", 0.0)) + emb_type = emb.get("type", "positive") + + if ext is None or weight == 0.0: + continue + + # --- Device + dtype safe --- + ext = ext.to(device=device, dtype=latents.dtype) + + # --- Resize si nécessaire --- + if ext.shape != latents.shape: + ext = torch.nn.functional.interpolate( + ext, + size=latents.shape[-2:], + mode='bilinear', + align_corners=False + ) + + # Ajustement batch/channel si besoin + if ext.shape[1] != latents.shape[1]: + ext = ext[:, :latents.shape[1], :, :] + + # --- Nettoyage --- + ext = torch.nan_to_num(ext) + + # --- Normalisation (important) --- + if normalize: + ext_std = ext.std() + lat_std = latents.std() + + if ext_std > 1e-6: + ext = ext * (lat_std / ext_std) + + # --- Injection --- + if emb_type == "negative": + latents = latents - weight * ext + else: + latents = latents + weight * ext + + except Exception as e: + print(f"[inject_external_embeddings ERROR] {e}") + continue + + # --- Clamp + sécurité --- + latents = torch.clamp(latents, clamp_range[0], clamp_range[1]) + latents = torch.nan_to_num(latents) + + return latents + + +def update_n3r_memory(memory_dict, cf_embeds, n3r_latents, memory_alpha=0.15): + prompt_bytes = cf_embeds[0,0].cpu().numpy().tobytes() + prompt_key = hashlib.sha256(prompt_bytes).hexdigest() + print(f"[DEBUG] clé mémoire : {prompt_key[:8]}..., latents fusionnés") + previous_memory = memory_dict.get(prompt_key, torch.zeros_like(n3r_latents)) + 
fused_latents = (1 - memory_alpha) * previous_memory.to(n3r_latents.device) + memory_alpha * n3r_latents + memory_dict[prompt_key] = fused_latents.detach().cpu() + return fused_latents + +# ------------------- Sauvegarde mémoire ------------------- +def save_memory(memory_dict, memory_file: Path): + """ + Sauvegarde la mémoire N3R au format JSON. + Convertit les tensors en listes pour compatibilité JSON. + """ + serializable_memory = {k: v.tolist() for k, v in memory_dict.items()} + memory_file = memory_file.with_suffix(".json") + memory_file.parent.mkdir(parents=True, exist_ok=True) # créer dossier si absent + with open(memory_file, "w") as f: + json.dump(serializable_memory, f, indent=2) + print(f"💾 Mémoire N3R sauvegardée : {memory_file}") + + +# ------------------- Chargement mémoire ------------------- +def load_memory(memory_file: Path): + """ + Charge la mémoire N3R depuis un fichier JSON. + Convertit les listes en tensors. + """ + memory_file = memory_file.with_suffix(".json") + if memory_file.exists(): + with open(memory_file, "r") as f: + mem = json.load(f) + # Convertir listes → tensors + memory_dict = {k: torch.tensor(v) for k, v in mem.items()} + print(f"✅ Mémoire N3R chargée depuis {memory_file}") + return memory_dict + else: + print("⚡ Nouvelle mémoire N3R initialisée") + return {} + +def stabilize_latents_before_decode( + latents, + latent_scale, + clamp_val=0.95, + smooth_kernel=3, + enable_smoothing=True +): + """ + Stabilise les latents avant décodage pour éviter les artefacts de tiles. 
Args: latents (torch.Tensor): latents [B,C,H,W] latent_scale (float): facteur VAE (ex: 0.18215) clamp_val (float): limite de clamp (0.9–1.0 recommandé) smooth_kernel (int): taille du noyau de smoothing enable_smoothing (bool): active le lissage spatial Returns: torch.Tensor: latents prêts pour decode + """ + + # 🔥 sécurité NaN / inf + latents = torch.nan_to_num(latents) + + # 🔥 clamp doux (évite contrastes violents entre tiles) + latents = torch.clamp(latents, -clamp_val, clamp_val) + + # 🔥 lissage spatial (corrige les seams) + if enable_smoothing and smooth_kernel > 1: + latents = torch.nn.functional.avg_pool2d( + latents, + kernel_size=smooth_kernel, + stride=1, + padding=smooth_kernel // 2 + ) + + # 🔥 continuité mémoire (important pour decode blockwise) + latents = latents.contiguous() + + # 🔥 scale VAE + latents = latents / latent_scale + + return latents + + +import math + +# linear → transition constante (basique) +# cosine → très fluide (ton choix actuel 👍) +# smoothstep → encore plus doux (souvent top pour vidéo) +# ease_in → démarre lentement +# ease_out → ralentit à la fin +def get_interpolated_embeddings( + frame_idx, + frames_per_prompt, + pos_list, + neg_list, + device="cuda", + debug=False, + profile="cosine" # "linear", "cosine", "smoothstep", "ease_in", "ease_out" +): + num_prompts = len(pos_list) + + if num_prompts == 0: + raise ValueError("pos_list ne peut pas être vide") + + # index de base + idx = frame_idx // frames_per_prompt + idx = min(idx, num_prompts - 1) + + idx_next = min(idx + 1, num_prompts - 1) + + # progression locale dans le segment + t_raw = (frame_idx % frames_per_prompt) / frames_per_prompt + + # profils d'interpolation + if profile == "linear": + t = t_raw + elif profile == "cosine": + t = 0.5 - 0.5 * math.cos(math.pi * t_raw) + elif profile == "smoothstep": + t = t_raw * t_raw * (3 - 2 * t_raw) + elif profile == "ease_in": + t = t_raw ** 2 + elif profile == "ease_out": + t = 1 - (1 - t_raw) ** 2 + else: + raise 
ValueError(f"Profil inconnu: {profile}") + + # interpolation + pos = (1 - t) * pos_list[idx] + t * pos_list[idx_next] + neg = (1 - t) * neg_list[idx] + t * neg_list[idx_next] + + if debug: + print("=== DEBUG INTERPOLATION ===") + print(f"frame_idx: {frame_idx}") + print(f"segment idx: {idx} → {idx_next}") + print(f"t_raw: {t_raw:.4f}") + print(f"t ({profile}): {t:.4f}") + print(f"frames_per_prompt: {frames_per_prompt}") + print(f"num_prompts: {num_prompts}") + print("===========================") + + return pos.to(device), neg.to(device) + + +def get_interpolated_embeddings_s(frame_idx, frames_per_prompt, pos_list, neg_list, device="cuda"): + num_prompts = len(pos_list) + + # index de base + idx = frame_idx // frames_per_prompt + idx_next = min(idx + 1, num_prompts - 1) + + # progression locale dans le segment + t = (frame_idx % frames_per_prompt) / frames_per_prompt + + # cosine smooth (beaucoup mieux que linéaire) + t = 0.5 - 0.5 * math.cos(math.pi * t) + + pos = (1 - t) * pos_list[idx] + t * pos_list[idx_next] + neg = (1 - t) * neg_list[idx] + t * neg_list[idx_next] + + return pos.to(device), neg.to(device) + + +def compute_overlap(W, H, block_size, max_overlap_ratio=0.5, min_overlap=8): + overlap = int(block_size * max_overlap_ratio) + overlap = min(overlap, min(W, H) // 4) + overlap = max(overlap, min_overlap) + return overlap + +# ---------------- DEBUG UTILS ---------------- +def log_debug(message, level="INFO", verbose=True): + """ + Affiche le message si verbose=True. + level: "INFO", "DEBUG", "WARNING" + """ + if verbose: + print(f"[{level}] {message}") +# ------------------------------------------------------------------------------------------- +# Version vraiment stable +def sanitize_latents( + latents, + ref_stats=None, # (mean, std) référence EMA + max_val=1.2, # clamp hard + eps=1e-6, + min_momentum=0.90, # momentum minimum pour plus de détail + max_momentum=0.95, # momentum maximum pour ultra stabilité + debug=False +): + import torch + + # --- 1. 
Nettoyage NaN / Inf --- + latents = torch.nan_to_num(latents, nan=0.0, posinf=max_val, neginf=-max_val) + + mean = latents.mean() + std = latents.std() + + # --- 2. Ajustement dynamique du momentum --- + # Variance relative : plus std élevé → plus on réduit momentum pour conserver détails + std_factor = torch.clamp(std / 1.0, 0.0, 1.0) # normalisation relative + momentum = max_momentum - (max_momentum - min_momentum) * std_factor + + # --- 3. Initialisation référence --- + if ref_stats is None: + ref_mean = mean.detach() + ref_std = std.detach() + else: + prev_mean, prev_std = ref_stats + # EMA adaptatif selon momentum + ref_mean = momentum * prev_mean + (1 - momentum) * mean + ref_std = momentum * prev_std + (1 - momentum) * std + + # --- 4. Alignement vers référence --- + latents = latents - mean + latents = latents * (ref_std / (std + eps)) + latents = latents + ref_mean + + # --- 5. Clamp hard final + anti-saturation --- + latents = torch.clamp(latents, -max_val, max_val) + max_abs = latents.abs().max() + if max_abs > max_val: + latents = latents * (max_val / (max_abs + eps)) + + if debug: + print(f"[adaptive] mean={mean:.3f}->{ref_mean:.3f}, std={std:.3f}->{ref_std:.3f}, momentum={momentum:.3f}") + + #return latents, (ref_mean.detach(), ref_std.detach()) + return latents +#------------------------------------------------------------------------------------------------- +# Version vraiment stable +# 0.95 → ultra stable (vidéo cinéma) +# 0.9 → compromis détail / stabilité +# 0.7 → un peu plus réactif (créatif) + +def sanitize_latents_adaptive( + latents, + ref_stats=None, # 🔥 référence globale stable + momentum=0.95, # stabilité temporelle + max_val=1.2, + eps=1e-6, + debug=False +): + import torch + + latents = torch.nan_to_num(latents, nan=0.0, posinf=max_val, neginf=-max_val) + + mean = latents.mean() + std = latents.std() + + # --- 1. 
Initialisation référence --- + if ref_stats is None: + ref_mean = mean.detach() + ref_std = std.detach() + else: + prev_mean, prev_std = ref_stats + + # 🔥 EMA (clé) + ref_mean = momentum * prev_mean + (1 - momentum) * mean + ref_std = momentum * prev_std + (1 - momentum) * std + + # --- 2. Normalisation vers référence stable --- + latents = latents - mean + latents = latents * (ref_std / (std + eps)) + latents = latents + ref_mean + + # --- 3. Clamp propre (hard → ton meilleur résultat) + latents = torch.clamp(latents, -max_val, max_val) + + # --- 4. Anti saturation final + max_abs = latents.abs().max() + if max_abs > max_val: + latents = latents * (max_val / (max_abs + eps)) + + if debug: + print(f"[stable] mean={mean:.3f}→{ref_mean:.3f}, std={std:.3f}→{ref_std:.3f}") + + #return latents, (ref_mean.detach(), ref_std.detach()) + return latents + + +# Version stable hard +def sanitize_latents_hard( + latents, + clamp_mode="hard", # "hard", "tanh" + max_val=1.2, + std_threshold=1.5, + percentile=0.995, + eps=1e-6, + debug=False +): + import torch + + # --- 1. Nettoyage NaN / Inf (safe) + latents = torch.nan_to_num(latents, nan=0.0, posinf=max_val, neginf=-max_val) + + # --- 2. Clamp intelligent (meilleur que clamp brut) + if clamp_mode == "hard": + latents = torch.clamp(latents, -max_val, max_val) + + elif clamp_mode == "tanh": + latents = torch.tanh(latents / max_val) * max_val + + else: + raise ValueError(f"Unknown clamp_mode: {clamp_mode}") + + # --- 3. Détection explosion (robuste) + std = latents.std() + + if std > std_threshold: + # normalisation robuste basée sur percentiles + flat = latents.flatten() + + high = torch.quantile(flat, percentile) + low = torch.quantile(flat, 1 - percentile) + + scale = max(abs(high), abs(low), eps) + + latents = latents / scale + + if debug: + print(f"[sanitize] percentile scaling applied: scale={scale:.4f}") + + # --- 4. 
Stabilisation fine (évite drift) + mean = latents.mean() + latents = latents - mean * 0.05 # recentrage léger (pas destructif) + + if debug: + print(f"[sanitize] std={std:.4f}, mean={mean:.4f}") + + return latents + +# Version original: +def sanitize_latents_v1(latents): + latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=-1.0) + + # clamp doux (évite saturation brutale) + latents = torch.clamp(latents, -1.2, 1.2) + # normalisation légère si explosion + if latents.std() > 1.5: + latents = latents / latents.std() + + return latents + +# ------------------------------------------------------------------------------------------- +def stabilize_latents_advanced(latents, strength=0.99, knee=0.7): + # sécurité + latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=-1.0) + + # clamp doux + latents = torch.clamp(latents, -1.2, 1.2) + # normalisation si explosion + std = latents.std() + if std > 1.5: + latents = latents / std + # 🔥 compression non-linéaire (anti crispy blanc) + latents = torch.tanh(latents * (1.0 / knee)) * knee + # léger scaling global + latents = latents * strength + + return latents + + + +def print_generation_params(params: dict): + """ + Affiche les paramètres de génération dans un tableau clair. + + params : dict + Dictionnaire contenant tous les paramètres nécessaires. 
+ Exemple de clés attendues : + 'fps', 'upscale_factor', 'num_fraps_per_image', 'steps', + 'guidance_scale', 'init_image_scale', 'creative_noise', + 'latent_scale_boost', 'final_latent_scale', 'seed' + """ + print("📌 Paramètres de génération :") + print(f"{'Paramètre':<20} {'Valeur':>10} {'Paramètre':<20} {'Valeur':>10}") + + left_keys = ['fps', 'num_fraps_per_image', 'guidance_scale', 'guidance_scale_end', 'creative_noise', 'creative_noise_end','final_latent_scale', 'transition_frames', 'use_n3r_model'] + right_keys = ['use_mini_gpu', 'upscale_factor', 'steps', 'init_image_scale', 'init_image_scale_end', 'latent_scale_boost', 'seed', 'latent_injection', 'block_size'] + + for l, r in zip(left_keys, right_keys): + print(f"{l:<20} {params.get(l, ''):>10} {r:<20} {params.get(r, ''):>10}") + + +# ---------------- Tensor / PIL utils ---------------- +def prepare_frame_tensor(frame_tensor): + """ + Prépare un tensor de frame pour traitement (squeeze / permute / clamp) + """ + if frame_tensor.ndim == 5: frame_tensor = frame_tensor.squeeze(2) + if frame_tensor.ndim == 4: frame_tensor = frame_tensor.squeeze(0) + if frame_tensor.ndim == 3 and frame_tensor.shape[0] != 3: + frame_tensor = frame_tensor.permute(2,0,1) + return frame_tensor.clamp(0,1) + +def normalize_frame(frame_tensor): + """ + Normalise un tensor image dans l'intervalle [0, 1]. + + Cette fonction évite les problèmes d'overflow ou de valeurs hors plage + en re-scalant dynamiquement les valeurs du tensor. + + Args: + frame_tensor (torch.Tensor): + Tensor image de forme [C, H, W] ou [B, C, H, W], + avec des valeurs arbitraires. + + Returns: + torch.Tensor: + Tensor normalisé dans [0, 1]. + + Notes: + - Si min == max, aucune normalisation n'est appliquée. + - Un clamp final garantit la stabilité numérique. 
+ """ + min_val = frame_tensor.min() + max_val = frame_tensor.max() + + if max_val > min_val: + frame_tensor = (frame_tensor - min_val) / (max_val - min_val) + + return frame_tensor.clamp(0, 1) + + +def tensor_to_pil(frame_tensor): + """ + Convertit un tensor torch en image PIL. + + Args: + frame_tensor (torch.Tensor): + Tensor image de forme [C, H, W] ou [1, C, H, W], + avec des valeurs attendues dans [0, 1]. + + Returns: + PIL.Image.Image: + Image PIL prête à être sauvegardée ou affichée. + + Notes: + - Si un batch est fourni ([B, C, H, W]), seule la première image est utilisée. + - Les valeurs sont automatiquement clampées dans [0, 1]. + - Le tensor est déplacé sur CPU avant conversion. + """ + if frame_tensor.ndim == 4: + frame_tensor = frame_tensor[0] + + return ToPILImage()(frame_tensor.cpu().clamp(0, 1)) + + +def ensure_4_channels(latents): + """ + Garantit que le tensor latent possède 4 canaux (format attendu par les modèles SD). + + Args: + latents (torch.Tensor): + Tensor latent de forme [B, C, H, W]. + + Returns: + torch.Tensor: + Tensor avec exactement 4 canaux. + + Notes: + - Si C == 1, les canaux sont dupliqués pour obtenir 4 canaux. + - Si C == 4, le tensor est retourné tel quel. + - Ne gère pas les cas C != 1 et C != 4 (à étendre si besoin). 
+ """ + if latents.shape[1] == 1: + latents = latents.repeat(1, 4, 1, 1) + + return latents + + +# ---------------- Video utils ---------------- +def save_frames_as_video_from_folder(folder_path, output_path, fps=12): + """Sauvegarde un dossier de frames PNG en vidéo mp4 via ffmpeg""" + import ffmpeg + from pathlib import Path + folder_path = Path(folder_path) + frame_files = sorted(folder_path.glob("frame_*.png")) + if not frame_files: + print("❌ Aucun frame trouvé") + return + pattern = str(folder_path / "frame_*.png") + ( + ffmpeg.input(pattern, framerate=fps, pattern_type='glob') + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) diff --git a/scripts/utils/vae_config.py b/scripts/utils/vae_config.py new file mode 100644 index 00000000..be289857 --- /dev/null +++ b/scripts/utils/vae_config.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------ +# vae_config.py - utilitaires pour VAE / détection type et infos +# ------------------------------------------------------------------ +import torch +from diffusers import AutoencoderKL + +def load_vae(vae_path, device="cpu", dtype=torch.float16): + """ + Charge un VAE depuis un fichier .safetensors ou .ckpt et active slicing/tiling. + Retourne le VAE et ses informations de compatibilité. + """ + print(f"📦 Chargement VAE : {vae_path}") + + vae = AutoencoderKL.from_single_file( + vae_path, + torch_dtype=dtype + ).to(device) + + vae.enable_slicing() + vae.enable_tiling() + + # Détection du type + vae_type, latent_channels, scaling_factor = detect_vae_type(vae) + print("🧠 Détection VAE") + print(f" type : {vae_type}") + print(f" latent_channels : {latent_channels}") + print(f" scaling_factor : {scaling_factor}") + + return vae, vae_type, latent_channels, scaling_factor + + +def detect_vae_type(vae): + """ + Détecte le type de VAE chargé en se basant sur scaling_factor et latent_channels. 
+ Retourne : type VAE (str), latent_channels (int), scaling_factor (float) + """ + latent_channels = getattr(vae.config, "latent_channels", None) + scaling_factor = getattr(vae.config, "scaling_factor", None) + + # fallback classique pour SD1/SD2 + if scaling_factor is None: + scaling_factor = 0.18215 + + if abs(scaling_factor - 0.18215) < 1e-4: + vae_type = "SD1 / SD2 compatible" + elif abs(scaling_factor - 0.13025) < 1e-4: + vae_type = "SDXL compatible" + else: + vae_type = "VAE custom" + + return vae_type, latent_channels, scaling_factor + + +def vae_summary(vae): + """ + Affiche un résumé complet du VAE pour debug. + """ + vae_type, latent_channels, scaling_factor = detect_vae_type(vae) + print("─────────────────────────────") + print("📌 VAE SUMMARY") + print(f" type : {vae_type}") + print(f" latent_channels: {latent_channels}") + print(f" scaling_factor : {scaling_factor}") + print("─────────────────────────────") diff --git a/scripts/utils/vae_utils.py b/scripts/utils/vae_utils.py new file mode 100644 index 00000000..39cef53e --- /dev/null +++ b/scripts/utils/vae_utils.py @@ -0,0 +1,1692 @@ +# utils/vae_utils.py +import torch +from diffusers import AutoencoderKL +from pathlib import Path +import os +from diffusers import UNet2DConditionModel, AutoencoderKL, DPMSolverMultistepScheduler +from torch.nn.functional import interpolate +from torch.nn.functional import pad +from safetensors.torch import load_file +from torchvision.transforms import ToPILImage + +import torchvision.transforms as T +import torch, numpy as np +from PIL import Image +import math + + +LATENT_SCALE = 0.18215 + + +import torch +from torchvision.transforms import ToPILImage + +to_pil = ToPILImage() + +# ------------------------- +# Decode frame via VAE (compatible vae_offload) +# ------------------------- +def decode_latents_safe(latents, vae, device, tile_size=128, overlap=64): + """ + Décodage sécurisé des latents en image PIL, compatible avec VAE sur CPU (vae_offload) + et avec latents 
sur GPU. + """ + # Déplacer latents sur le device du VAE + vae_device = next(vae.parameters()).device + latents = latents.to(vae_device).float() + + # Éviter NaN/Inf + latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=0.0) + + # Décodage en tiles pour VRAM limitée + frame_tensor = decode_latents_to_image_tiled128( + latents, + vae, + tile_size=tile_size, + overlap=overlap, + device=vae_device + ).clamp(0, 1) + + # Renvoie un tensor CPU float32 pour sauvegarde + return frame_tensor.cpu() + +def decode_latents_safe_64(latents, vae, patch_size=64): + """ + Decode latents into images safely on low VRAM by splitting into patches. + latents: [B, C, H, W] + """ + B, C, H, W = latents.shape + decoded_imgs = [] + + for b in range(B): + latent = latents[b:b+1] # [1,C,H,W] + img = torch.zeros(1, C, H, W, device=latent.device) + + # Patch decoding + for i in range(0, H, patch_size): + for j in range(0, W, patch_size): + hi = min(i + patch_size, H) + wj = min(j + patch_size, W) + patch = latent[:, :, i:hi, j:wj] + decoded_patch = vae.decode(patch).sample + img[:, :, i:hi, j:wj] = decoded_patch + + # Clamp et convertir en PIL + img = img.squeeze(0).clamp(0, 1) + img_pil = to_pil(img.cpu()) + decoded_imgs.append(img_pil) + + # Nettoyage VRAM + del latent, img + torch.cuda.empty_cache() + + return decoded_imgs + +# Exemple dans ta boucle de génération +# for frame_idx, frame_latents in enumerate(latents_list): +# frame_pil = decode_latents_safe(frame_latents, vae)[0] +# frame_pil.save(f"output/frame_{frame_idx:03d}.png") + + +def decode_latents_to_image_bright_enhanced(latents, vae, gamma=0.7, brightness=1.2, contrast=1.1, saturation=1.15): + """ + Décodage des latents en image PIL avec : + - Correction gamma pour éclaircir + - Augmentation de luminosité, contraste et saturation pour un rendu plus vivant + """ + latents = torch.nan_to_num(latents, nan=0.0, posinf=4.0, neginf=-4.0) + + if latents.ndim == 5: # [B,C,T,H,W] + latents = latents[:, :, 0, :, :] + + # 
Revenir à l'échelle attendue par le VAE + latents = latents / LATENT_SCALE + + with torch.no_grad(): + # 🔹 S’assurer que les latents sont du même dtype que le VAE + latents = latents.to(vae.dtype) + image = vae.decode(latents).sample + + # Normalisation [-1,1] -> [0,1] + image = (image + 1.0) / 2.0 + image = image.clamp(0, 1) + + # Correction gamma + image = image.pow(1.0 / gamma) + + # Convertir en PIL pour post-processing + + + image = image[0] # ✅ retire dimension batch + pil_image = ToPILImage()(image.cpu().clamp(0, 1)) + + # Boost luminosité, contraste et saturation + pil_image = ImageEnhance.Brightness(pil_image).enhance(brightness) + pil_image = ImageEnhance.Contrast(pil_image).enhance(contrast) + pil_image = ImageEnhance.Color(pil_image).enhance(saturation) + + return pil_image + + + +def decode_latents_ultrasafe(latents, vae, gamma=0.7, brightness=1.2, contrast=1.1, saturation=1.15): + """ + Décodage ultra-sécurisé des latents en image PIL. + Protège contre : + - NaN / inf + - latents trop grands ou trop petits + - mauvais dtype + - images trop sombres / noires + """ + from torchvision.transforms import ToPILImage + import torch + from PIL import Image, ImageEnhance + + # ---------------- sécurité latents ---------------- + latents = torch.nan_to_num(latents, nan=0.0, posinf=5.0, neginf=-5.0) + + if latents.ndim == 5: # [B,C,T,H,W] + latents = latents[:, :, 0, :, :] # retirer dimension temporelle + + latents = latents / LATENT_SCALE # remise à l'échelle attendue par le VAE + + # Clamp très strict pour éviter valeurs extrêmes + latents = latents.clamp(-5.0, 5.0) + + # Assurer le dtype correct + latents = latents.to(vae.dtype) + + with torch.no_grad(): + image = vae.decode(latents).sample # [B,3,H,W] + + # Normalisation [-1,1] -> [0,1] + image = (image + 1.0) / 2.0 + image = image.clamp(0, 1) + + # Correction gamma pour éclaircir + image = image.pow(1.0 / gamma) + + # S'assurer qu'on a bien [3,H,W] + if image.ndim == 4: + image = image[0] # retirer batch si 
présent + + # Convertir en PIL + pil_image = ToPILImage()(image.cpu().clamp(0,1)) + + # Boost luminosité, contraste, saturation + pil_image = ImageEnhance.Brightness(pil_image).enhance(brightness) + pil_image = ImageEnhance.Contrast(pil_image).enhance(contrast) + pil_image = ImageEnhance.Color(pil_image).enhance(saturation) + + return pil_image + + +def decode_latents_to_image_vram_safe(latents, vae, gamma=0.7, brightness=1.2, contrast=1.1, saturation=1.15): + """ + Décodage des latents en PIL.Image avec corrections gamma, luminosité, contraste et saturation, + optimisé pour faible VRAM (4Go). Utilise float32 pour la stabilité. + + Args: + latents (torch.Tensor): [B,C,T,H,W] ou [B,C,H,W] + vae (AutoencoderKL): modèle VAE + gamma (float): correction gamma + brightness, contrast, saturation (float): boosts PIL.Image + + Returns: + List[PIL.Image]: images décodées + """ + from torchvision.transforms import ToPILImage + import torch + from PIL import ImageEnhance + + # Sécurisation NaN / Inf + latents = torch.nan_to_num(latents, nan=0.0, posinf=4.0, neginf=-4.0) + + # Si [B,C,T,H,W] → on prend la première "T" pour l'instant + if latents.ndim == 5: + latents = latents[:,:,0,:,:] + + # Revenir à l'échelle attendue par le VAE + latents = latents / LATENT_SCALE + + # Décode en float32 pour éviter frames noires + latents = latents.to(dtype=torch.float32) + + with torch.no_grad(): + image_tensor = vae.decode(latents).sample + + # Normalisation [-1,1] -> [0,1] + image_tensor = (image_tensor + 1.0) / 2.0 + image_tensor = image_tensor.clamp(0,1) + + # Correction gamma + image_tensor = image_tensor.pow(1.0 / gamma) + + images = [] + to_pil = ToPILImage() + for i in range(image_tensor.shape[0]): + img = image_tensor[i] + pil_img = to_pil(img.cpu()) + + # Boost luminosité / contraste / saturation + pil_img = ImageEnhance.Brightness(pil_img).enhance(brightness) + pil_img = ImageEnhance.Contrast(pil_img).enhance(contrast) + pil_img = ImageEnhance.Color(pil_img).enhance(saturation) + 
+ images.append(pil_img) + + return images + + + +def generate_latents_robuste_model(latents, pos_embeds, neg_embeds, unet, scheduler, + motion_module=None, device="cuda", dtype=torch.float16, + guidance_scale=4.5, init_image_scale=0.85, + creative_noise=0.0, seed=42): + """ + Génère des latents robustes avec protection NaN/inf + """ + torch.manual_seed(seed) + B, C, T, H, W = latents.shape + latents = latents.to(device=device, dtype=dtype) + latents = latents.permute(0,2,1,3,4).reshape(B*T, C, H, W).contiguous() + init_latents = latents.clone() + + for t_step in scheduler.timesteps: + # Motion module optionnel + if motion_module is not None: + latents = motion_module(latents) + + # Bruit créatif + if creative_noise > 0: + latents = latents + torch.randn_like(latents) * creative_noise + + # Classifier-free guidance + latent_model_input = torch.cat([latents, latents], dim=0) + embeds = torch.cat([neg_embeds, pos_embeds], dim=0) + + # ⚡ Autocast pour FP16 stable + with torch.autocast(device_type=device, dtype=dtype): + with torch.no_grad(): + noise_pred = unet(latent_model_input, t_step, encoder_hidden_states=embeds).sample + + noise_uncond, noise_text = noise_pred.chunk(2) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Scheduler step + latents = scheduler.step(noise_pred, t_step, latents).prev_sample + + # Réinjection image initiale + latents = latents + init_image_scale * (init_latents - latents) + + # Protection NaN / inf + if torch.isnan(latents).any() or torch.isinf(latents).any(): + latents = torch.nan_to_num(latents, nan=0.0, posinf=1.0, neginf=-1.0) + latents = latents + torch.randn_like(latents) * 1e-2 + + # Vérification latents trop petits + mean_val = latents.abs().mean().item() + if math.isnan(mean_val) or mean_val < 1e-5: + latents = latents + torch.randn_like(latents) * 1e-2 + + latents = latents.reshape(B, T, C, H, W).permute(0,2,1,3,4).contiguous() + return latents + + +# ------------------------- +# Génération et 
décodage sécurisée pour n3rHYBRID24 +# ------------------------- +def generate_and_decode(latent_frame, unet, scheduler, pos_embeds, neg_embeds, + motion_module, vae, device="cuda", dtype=torch.float32, + guidance_scale=4.5, init_image_scale=0.85, creative_noise=0.0, + seed=42, steps=35, tile_size=128, overlap=32, vae_offload=False): + """ + Génère les latents pour un frame et les décode en image finale, + avec gestion automatique des devices, FP16, offload et tiling. + """ + import torch, time + + torch.manual_seed(seed) + + # ------------------------- + # Déplacer latents et embeddings sur le bon device et dtype + # ------------------------- + latent_frame = latent_frame.to(device=device, dtype=dtype) + pos_embeds = pos_embeds.to(device=device, dtype=dtype) + neg_embeds = neg_embeds.to(device=device, dtype=dtype) + + # ------------------------- + # Génération avec UNet + Scheduler + # ------------------------- + gen_start = time.time() + batch_latents = generate_latents_ai_5D_optimized( + latent_frame=latent_frame, + scheduler=scheduler, + pos_embeds=pos_embeds, + neg_embeds=neg_embeds, + unet=unet, + motion_module=motion_module, + device=device, + dtype=dtype, + guidance_scale=guidance_scale, + init_image_scale=init_image_scale, + creative_noise=creative_noise, + seed=seed, + steps=steps + ) + gen_time = time.time() - gen_start + + # ------------------------- + # Gestion VAE offload / device + # ------------------------- + vae_device = next(vae.parameters()).device + vae_dtype = next(vae.parameters()).dtype + + if vae_offload: + vae.to(device) + + # Assure correspondance latents / VAE + batch_latents = batch_latents.to(device=vae_device, dtype=vae_dtype) + + # ------------------------- + # Décodage avec tiling universel + # ------------------------- + decode_start = time.time() + frame_tensor = decode_latents_to_image_tiled_universel( + batch_latents, + vae, + tile_size=tile_size, + overlap=overlap + ) + decode_time = time.time() - decode_start + + # Revenir en 
CPU si offload + if vae_offload: + vae.cpu() + torch.cuda.empty_cache() + + # ------------------------- + # Clamp final pour éviter NaN/inf + # ------------------------- + frame_tensor = frame_tensor.clamp(0.0, 1.0) + + return frame_tensor, batch_latents, gen_time, decode_time + + +def decode_latents_to_image_tiled_universel(latents, vae, tile_size=64, overlap=16): + """ + Decode latents 4D [B,C,H,W] ou 5D [B,C,T,H,W] en images [B,3,H*8,W*8]. + Tiling avec blending sécurisé. + """ + vae_dtype = next(vae.parameters()).dtype + device = vae.device + + if latents.ndim == 5: + B,C,T,H,W = latents.shape + latents = latents.permute(0,2,1,3,4).reshape(B*T,C,H,W) + elif latents.ndim == 4: + B,C,H,W = latents.shape + else: + raise ValueError(f"Latents attendus 4D ou 5D, got {latents.shape}") + + latents = latents.to(vae_dtype) + latents_scaled = latents / LATENT_SCALE + + output = torch.zeros(B,3,H*8,W*8,device=device,dtype=torch.float32) + weight = torch.zeros_like(output) + + stride = tile_size - overlap + y_positions = list(range(0,H-tile_size+1,stride)) or [0] + x_positions = list(range(0,W-tile_size+1,stride)) or [0] + if y_positions[-1] != H-tile_size: y_positions.append(H-tile_size) + if x_positions[-1] != W-tile_size: x_positions.append(W-tile_size) + + for y in y_positions: + for x in x_positions: + y1 = y+tile_size + x1 = x+tile_size + tile = latents_scaled[:,:,y:y1,x:x1] + + with torch.no_grad(): + decoded = vae.decode(tile).sample.float() + decoded = (decoded/2 + 0.5).clamp(0,1) + + iy0, ix0, iy1, ix1 = y*8, x*8, y1*8, x1*8 + output[:,:,iy0:iy1,ix0:ix1] += decoded + weight[:,:,iy0:iy1,ix0:ix1] += 1.0 + + return output / weight.clamp(min=1e-6) + + +def decode_latents_to_image_tiled128(latents, vae, tile_size=128, overlap=32, device="cuda"): + """ + Decode les latents [B,4,H,W] ou [B,4,T,H,W] en RGB [B,3,H,W] ou [B,3,T,H,W] + avec tiling pour éviter OOM et blending correct. 
+ """ + vae_dtype = next(vae.parameters()).dtype + vae_device = next(vae.parameters()).device if device=="cuda" else device + + # Support 5D + if latents.ndim == 5: + B, C, T, H, W = latents.shape + latents = latents.permute(0,2,1,3,4).reshape(B*T, C, H, W) + reshape_back = True + elif latents.ndim == 4: + B, C, H, W = latents.shape + reshape_back = False + else: + raise ValueError(f"Latents attendus 4D ou 5D, got {latents.shape}") + + if C != 4: + raise ValueError(f"Latents doivent avoir 4 canaux, got {C}") + + # Convert dtype & device + latents = latents.to(vae_device, vae_dtype) + + stride = tile_size - overlap + out_H = H * 8 + out_W = W * 8 + + output = torch.zeros(latents.shape[0], 3, out_H, out_W, device=vae_device, dtype=torch.float32) + weight = torch.zeros_like(output) + + # Positions tuiles + y_positions = list(range(0, H - tile_size + 1, stride)) + x_positions = list(range(0, W - tile_size + 1, stride)) + if not y_positions: y_positions = [0] + if not x_positions: x_positions = [0] + if y_positions[-1] != H - tile_size: y_positions.append(H - tile_size) + if x_positions[-1] != W - tile_size: x_positions.append(W - tile_size) + + for y in y_positions: + for x in x_positions: + y1 = y + tile_size + x1 = x + tile_size + tile = latents[:, :, y:y1, x:x1] + + with torch.no_grad(): + decoded = vae.decode(tile / LATENT_SCALE).sample + + # Correction Stable Diffusion + decoded = (decoded / 2 + 0.5).clamp(0,1) + + iy0, ix0 = y*8, x*8 + iy1, ix1 = y1*8, x1*8 + + output[:, :, iy0:iy1, ix0:ix1] += decoded + weight[:, :, iy0:iy1, ix0:ix1] += 1.0 + + output = output / weight.clamp(min=1e-6) + + if reshape_back: + # Retour à [B,3,T,H,W] + output = output.reshape(B, T, 3, out_H, out_W).permute(0,2,1,3,4) + + return output + +def log_rgb_stats(image_tensor, step=""): + """Enregistre les statistiques RGB et renvoie les warnings sous forme de liste.""" + messages = [] + R = image_tensor[0, 0].cpu().numpy() + G = image_tensor[0, 1].cpu().numpy() + B = image_tensor[0, 
2].cpu().numpy() + + R_min, R_max, R_mean = np.min(R), np.max(R), np.mean(R) + G_min, G_max, G_mean = np.min(G), np.max(G), np.mean(G) + B_min, B_max, B_mean = np.min(B), np.max(B), np.mean(B) + + # Vérifications + if R_min < 0.0 or R_max > 1.0: + messages.append(f"{step}: canal R hors plage [{R_min:.3f},{R_max:.3f}]") + if G_min < 0.0 or G_max > 1.0: + messages.append(f"{step}: canal G hors plage [{G_min:.3f},{G_max:.3f}]") + if B_min < 0.0 or B_max > 1.0: + messages.append(f"{step}: canal B hors plage [{B_min:.3f},{B_max:.3f}]") + + if abs(R_mean - G_mean) > 0.2 or abs(G_mean - B_mean) > 0.2: + messages.append(f"{step}: écart important entre R/G/B (R={R_mean:.3f}, G={G_mean:.3f}, B={B_mean:.3f})") + + return messages + +def encode_tile_vae(tile_rgb, vae, fp16=False): + """Encode une tile RGB [1,3,H,W] -> latent [1,4,H/8,W/8]""" + device = next(vae.parameters()).device + dtype = torch.float16 if fp16 else torch.float32 + tile_rgb = tile_rgb.to(device=device, dtype=dtype) + with torch.no_grad(): + latent = vae.encode(tile_rgb).latent_dist.sample() * LATENT_SCALE + return latent + +def tile_image_vae(img_tensor, tile_size=128, overlap=32): + """Découpe un tensor [B,3,H,W] en tiles RGB""" + B,C,H,W = img_tensor.shape + stride = tile_size - overlap + tiles, positions = [], [] + for y in range(0, H, stride): + for x in range(0, W, stride): + y1, y2 = y, min(y+tile_size,H) + x1, x2 = x, min(x+tile_size,W) + tile = img_tensor[:,:,y1:y2,x1:x2] + tiles.append(tile) + positions.append((y1,y2,x1,x2)) + return tiles, positions + +def merge_tiles_vae(tiles, positions, H, W): + """Fusionne les tiles [1,3,h,w] en image finale [1,3,H,W]""" + device = tiles[0].device + out = torch.zeros((tiles[0].shape[0], 3, H, W), device=device) + count = torch.zeros((tiles[0].shape[0], 3, H, W), device=device) + for t,(y1,y2,x1,x2) in zip(tiles,positions): + th,tw = t.shape[2], t.shape[3] + out[:,:,y1:y1+th,x1:x1+tw] += t + count[:,:,y1:y1+th,x1:x1+tw] += 1.0 + out /= count.clamp(min=1.0) + 
return out + +# -------------------------------------------- +# Vérification des tiles +#----------------------------------------------- +def clamp_and_warn_tile(tile_rgb, frame_idx, tile_idx, warnings_list): + """ + Clamp chaque canal à [0,1] et détecte les écarts importants R/G/B + tile_rgb : tensor [1,3,H,W] + """ + # Calcul stats + R, G, B = tile_rgb[0,0], tile_rgb[0,1], tile_rgb[0,2] + r_min, r_max, r_mean = R.min().item(), R.max().item(), R.mean().item() + g_min, g_max, g_mean = G.min().item(), G.max().item(), G.mean().item() + b_min, b_max, b_mean = B.min().item(), B.max().item(), B.mean().item() + + # Warning si hors plage [0,1] + if r_min < 0 or r_max > 1: + warnings_list.append(f"Frame {frame_idx} - Tile {tile_idx}: canal R hors plage [{r_min:.3f},{r_max:.3f}]") + if g_min < 0 or g_max > 1: + warnings_list.append(f"Frame {frame_idx} - Tile {tile_idx}: canal G hors plage [{g_min:.3f},{g_max:.3f}]") + if b_min < 0 or b_max > 1: + warnings_list.append(f"Frame {frame_idx} - Tile {tile_idx}: canal B hors plage [{b_min:.3f},{b_max:.3f}]") + + # Warning si écart important R/G/B + r_g_b = [r_mean, g_mean, b_mean] + if max(r_g_b) - min(r_g_b) > 0.3: # seuil configurable + warnings_list.append(f"Frame {frame_idx} - Tile {tile_idx}: écart important entre R/G/B (R={r_mean:.3f}, G={g_mean:.3f}, B={b_mean:.3f})") + + # Clamp pour éviter de propager l’erreur + tile_rgb = torch.clamp(tile_rgb, 0.0, 1.0) + return tile_rgb + +# scripts/utils/vae_utils.py + +# ------------------------- +# Encode tile safe FP32 +# ------------------------- +def encode_tile_safe_fp32(vae, tile_np, device="cuda", vae_offload=False): + """ + Encode une tile numpy [C,H,W] en latent VAE [1,4,H/8,W/8] + VRAM-safe, compatible FP32 VAE complet et offload + """ + tile_tensor = torch.from_numpy(tile_np).unsqueeze(0).to(device=device, dtype=torch.float32) # [1,3,H,W] + with torch.no_grad(): + if vae_offload: + vae.to(device) # mettre VAE sur le même device que le tile + latent = 
vae.encode(tile_tensor).latent_dist.sample() * LATENT_SCALE + if vae_offload: + vae.cpu() # remettre VAE sur CPU pour économiser VRAM + if device.startswith("cuda"): + torch.cuda.synchronize() + return latent + +# ------------------------- +# Merge tiles FP32 +# ------------------------- +def merge_tiles_fp32(tile_list, positions, H, W, latent_scale=1.0): + """ + Fusionne les tiles latents [1,C,th,tw] en image complète [1,C,H,W]. + Supporte tiles de tailles différentes et bordures. + """ + device = tile_list[0].device + C = tile_list[0].shape[1] + + out = torch.zeros(1, C, H, W, dtype=tile_list[0].dtype, device=device) + count = torch.zeros(1, C, H, W, dtype=tile_list[0].dtype, device=device) + + for tile, (y1, y2, x1, x2) in zip(tile_list, positions): + _, c, th, tw = tile.shape + h_len = y2 - y1 + w_len = x2 - x1 + th = min(th, h_len) + tw = min(tw, w_len) + out[:, :, y1:y1+th, x1:x1+tw] += tile[:, :, :th, :tw] + count[:, :, y1:y1+th, x1:x1+tw] += 1.0 + + count[count==0] = 1.0 + out = out / count + return out + +def encode_tile_safe_latent(vae, tile, device, LATENT_SCALE=0.18215): + """ + Encode une tuile en latent FP32 et pad si nécessaire. + tile: np.array (H,W,3) float32 0-1 + return: torch tensor (1,4,H_latent_max,W_latent_max) + """ + tile_tensor = torch.tensor(tile).permute(2,0,1).unsqueeze(0).to(device) + latent = vae.encode(tile_tensor).latent_dist.sample() * LATENT_SCALE + # Vérifier H,W du latent + H_lat, W_lat = latent.shape[2], latent.shape[3] + H_max = (tile.shape[0] + 7)//8 # VAE scale + W_max = (tile.shape[1] + 7)//8 + if H_lat != H_max or W_lat != W_max: + padH = H_max - H_lat + padW = W_max - W_lat + latent = torch.nn.functional.pad(latent, (0,padW,0,padH)) + return latent + +# --- Découper une image en tiles avec overlap --- +def tile_image_128(image, tile_size=128, overlap=16): + """ + Découpe une image (H,W,C ou C,H,W) en tiles avec overlap. + Retourne une liste de tiles (numpy arrays) et leurs positions (x1,y1,x2,y2). 
+ """ + # Assure shape [C,H,W] + if image.ndim == 3 and image.shape[2] in [1,3]: + # H,W,C -> C,H,W + image = image.transpose(2,0,1) + elif image.ndim != 3: + raise ValueError(f"Image doit être 3D, shape={image.shape}") + + C,H,W = image.shape + stride = tile_size - overlap + tiles = [] + positions = [] + + for y in range(0, H, stride): + for x in range(0, W, stride): + y1, y2 = y, min(y + tile_size, H) + x1, x2 = x, min(x + tile_size, W) + tile = image[:, y1:y2, x1:x2] + tiles.append(tile.astype(np.float32)) # reste numpy + positions.append((x1, y1, x2, y2)) + return tiles, positions + + +# --- Normalisation d'une tile --- +def normalize_tile_128(img_array): + """ + img_array: np.ndarray, shape [H,W,C] ou [C,H,W], valeurs 0-255 + Retour: torch.Tensor [1,3,H,W] float32, valeurs 0-1 + """ + if img_array.ndim == 3 and img_array.shape[2] == 3: # HWC + img_array = img_array.transpose(2,0,1) + img_tensor = torch.from_numpy(img_array).unsqueeze(0).float() / 255.0 + return img_tensor + + +def decode_latents_correct(latents, vae): + """ + Décodage des latents en image RGB float32 + """ + vae_device = next(vae.parameters()).device + latents = latents.to(device=vae_device, dtype=torch.float32) + with torch.no_grad(): + decoded = vae.decode(latents).sample + decoded = torch.clamp(decoded, -1, 1) + decoded = (decoded + 1) / 2 + return decoded + + +# ------------------------- +# Fonction pour charger un VAE et tester son décodage +# ------------------------- +def safe_load_vae(vae_path, device="cuda", fp16=False, offload=False): + """ + Charge un VAE (FP32 ou FP16), renvoie l'objet VAE prêt à l'emploi. 
+ """ + try: + # Chargement state_dict + state_dict = load_file(vae_path, device="cpu") + print("✅ State dict VAE chargé, clés:", list(state_dict.keys())[:5]) + + # Création d'un VAE compatible SD + vae = AutoencoderKL( + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D"]*4, + up_block_types=["UpDecoderBlock2D"]*4, + block_out_channels=[128, 256, 512, 512], + latent_channels=4, + sample_size=32 # à adapter selon le checkpoint + ) + + # Chargement des poids + vae.load_state_dict(state_dict, strict=False) + + # Offload si demandé + if offload: + vae = vae.to("cpu") + else: + vae = vae.to(device) + if fp16: + vae = vae.half() + + return vae + + except Exception as e: + print(f"⚠ Erreur lors du chargement du VAE : {e}") + return None + + + + + +# --------------------------------- +# Debug UNet sur latents +# --------------------------------- +# Ajoutons des fonctions de test pour valider les latents avant de les passer dans UNet et VAE. + +def test_unet_on_latents(latents, unet, device): + """ + Teste l'entrée latente avant de la passer dans le UNet + Cette fonction aide à vérifier si la forme des latents est correcte avant de les utiliser. 
+ """ + print(f"[Test UNet] Latents avant UNet - min={latents.min():.4f}, max={latents.max():.4f}, mean={latents.mean():.4f}") + + # Assurer que les latents ont la forme correcte (batch, channels, height, width) + if latents.ndimension() != 4: + raise ValueError(f"Les latents doivent avoir 4 dimensions, mais ils en ont {latents.ndimension()}") + + # Test rapide avec le UNet pour vérifier si la forme est correcte + try: + # Assurez-vous que le UNet est configuré avec les bons paramètres (timestep, encoder_hidden_states) + # Simulez les entrées nécessaires à un UNet standard si nécessaire + timestep = torch.tensor([0]).to(device) # Exemple de timestep (à ajuster selon votre modèle) + encoder_hidden_states = torch.zeros((latents.size(0), 77, latents.size(2)), device=device) # Ajustez selon votre taille + + output = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + print(f"[Test UNet] Latents après UNet - min={output.min():.4f}, max={output.max():.4f}, mean={output.mean():.4f}") + return output + except Exception as e: + print(f"[Test UNet] Erreur pendant l'exécution du UNet : {e}") + return None + +def test_vae_on_latents(latents, vae, device): + """ + Teste l'entrée latente avant de la passer au VAE + Cette fonction vérifie la forme et le contenu des latents avant le décodage avec VAE. 
+ """ + print(f"[Test VAE] Latents avant VAE - min={latents.min():.4f}, max={latents.max():.4f}, mean={latents.mean():.4f}") + + # Assurez-vous que les latents ont la forme correcte pour le VAE + if latents.ndimension() != 4: + raise ValueError(f"Les latents doivent avoir 4 dimensions, mais ils en ont {latents.ndimension()}") + + # Essayons de décoder avec le VAE + try: + decoded_image = vae.decode(latents).sample # Ajustez selon la fonction exacte du VAE + print(f"[Test VAE] Image décodée - min={decoded_image.min():.4f}, max={decoded_image.max():.4f}, mean={decoded_image.mean():.4f}") + return decoded_image + except Exception as e: + print(f"[Test VAE] Erreur pendant le décodage avec le VAE : {e}") + return None + + +# ------------------------- +# Décodage tiled (pour grandes images) +# ------------------------- +def encode_images_to_latents_ai(images, vae): + """ + Encode une batch d'images [B, 3, H, W] en latents [B, 4, H/8, W/8]. + """ + device = images.device + vae_dtype = next(vae.parameters()).dtype + + with torch.no_grad(): + # On encode en latent avec le VAE, en s'assurant que la sortie a bien 4 canaux + latents = vae.encode(images.to(vae_dtype)).latent_dist.sample() + + # Assure que les latents sont en 4 canaux, ce qui est attendu pour le VAE + if latents.shape[1] != 4: + raise ValueError(f"Latents doivent avoir 4 canaux, mais ont {latents.shape[1]} canaux.") + + # Scale pour correspondre au SD + latents = latents * LATENT_SCALE + return latents + +# -------------------------------------------------------- +#---------- deprecated +#--------------------------------------------------------- + +def decode_latents_to_image_tiled128_old(latents, vae, tile_size=128, overlap=64, device="cuda"): + """ + Decode les latents [B, 4, H, W] en images [B, 3, H, W] avec tiling pour éviter OOM. + Supporte latents 5D [B, C, T, H, W] en reshaping automatique. 
+ """ + vae_dtype = next(vae.parameters()).dtype + + # Support 5D : [B, C, T, H, W] -> [B*T, C, H, W] + if latents.ndim == 5: + B, C, T, H, W = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W) + elif latents.ndim == 4: + B, C, H, W = latents.shape + else: + raise ValueError(f"Latents attendus 4D ou 5D, got {latents.shape}") + + # Assure que C=4 + if C != 4: + raise ValueError(f"Latents doivent avoir 4 canaux, got {C} canaux") + + # Convert dtype pour correspondre au VAE + latents = latents.to(vae_dtype) + + H_out, W_out = H, W + output = torch.zeros(B if latents.ndim == 4 else B*T, 3, H_out, W_out, device=device, dtype=torch.float32) + + # Décodage en tiling si image > tile_size + if max(H, W) <= tile_size: + with torch.no_grad(): + decoded = vae.decode(latents / LATENT_SCALE).sample + return decoded.clamp(0, 1) + + # Tiling (optionnel, pour très grandes images) + # Ici simplifié : on peut ajouter tiling si nécessaire + with torch.no_grad(): + decoded = vae.decode(latents / LATENT_SCALE).sample + return decoded.clamp(0, 1) +# ------------------------- +# vae_utils.py (version corrigée) +# ------------------------- +# ------------------------- +# Encode images en latents +# ------------------------- +def encode_images_to_latents_ai_old(images, vae): + """ + Encode une batch d'images [B, 3, H, W] en latents [B, 4, H/8, W/8]. + """ + device = images.device + vae_dtype = next(vae.parameters()).dtype + + with torch.no_grad(): + latents = vae.encode(images.to(vae_dtype)).latent_dist.sample() + # Scale pour correspondre au SD + latents = latents * LATENT_SCALE + return latents + +# ------------------------- +# Décodage tiled des latents 5D +# ------------------------- +def decode_latents_to_image_tiled128_5D(latents, vae, tile_size=128, overlap=64, device="cuda"): + """ + Decode les latents [B, 4, H, W] en images [B, 3, H, W] avec tiling pour éviter OOM. + Supporte latents 5D [B, C, T, H, W] en reshaping automatique. 
+ """ + vae_dtype = next(vae.parameters()).dtype + + # Support 5D : [B, C, T, H, W] -> [B*T, C, H, W] + if latents.ndim == 5: + B, C, T, H, W = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W) + elif latents.ndim == 4: + B, C, H, W = latents.shape + else: + raise ValueError(f"Latents attendus 4D ou 5D, got {latents.shape}") + + # Assure que C=4 + assert C == 4, f"Latents doivent avoir 4 canaux, got {C}" + + # Convert dtype pour correspondre au VAE + latents = latents.to(vae_dtype) + + H_out, W_out = H, W + output = torch.zeros(B if latents.ndim == 4 else B*T, 3, H_out, W_out, device=device, dtype=torch.float32) + + # Décodage en tiling si image > tile_size + if max(H, W) <= tile_size: + with torch.no_grad(): + decoded = vae.decode(latents / LATENT_SCALE).sample + return decoded.clamp(0, 1) + + # Tiling (optionnel, pour très grandes images) + # Ici simplifié : on peut ajouter tiling si nécessaire + with torch.no_grad(): + decoded = vae.decode(latents / LATENT_SCALE).sample + return decoded.clamp(0, 1) + +# --------------------------------------------------------------------------------------------- +# Test KO - deprecated +# ------------------------------------------------------------------------------------------- + +def decode_latents_to_image_test(latents, vae, tile_size=128, overlap=64, device="cuda"): + latents = latents.to(torch.float32) + + # For 3D or 4D latents + if latents.ndim == 2: # [H*W] ou [N, H*W] + raise ValueError(f"Latents trop aplatis, shape={latents.shape}") + elif latents.ndim == 3: # [C,H,W] -> ajouter batch + latents = latents.unsqueeze(0) # [1,C,H,W] + elif latents.ndim == 4: # [B,C,H,W] + pass + elif latents.ndim == 5: # [B,C,T,H,W] + B, C, T, H, W = latents.shape + latents = latents.reshape(B*T, C, H, W) + else: + raise ValueError(f"Latents must be 3D, 4D or 5D, got {latents.ndim}D") + + # Dupliquer si 1 seul canal + if latents.shape[1] == 1: + latents = latents.repeat(1, 4, 1, 1) + + # Décodage + with 
torch.no_grad(): + decoded = vae.decode(latents / 0.18215).sample + decoded = decoded.clamp(0, 1) + + return decoded + + +def decode_latents_to_image_SDiffusion(latents, vae, tile_size=128, overlap=64, device="cuda"): + """ + Décodage VAE en tuiles 128x128, compatible latents 4D ou 5D. + """ + # Convertir en float32 + latents = latents.to(torch.float32) + + # Fusion batch+time si nécessaire + if latents.ndim == 5: + B, C, T, H, W = latents.shape + latents = latents.reshape(B*T, C, H, W) + elif latents.ndim == 4: + B, C, H, W = latents.shape + else: + raise ValueError(f"Latents must be 4D or 5D, got {latents.ndim}D") + + assert C == 4, f"Expected 4 channels in latents, got {C}" + + # Décodage complet (pas en tuiles pour simplifier ici) + with torch.no_grad(): + decoded = vae.decode(latents / 0.18215).sample # shape [B*T, 3, H, W] + decoded = decoded.clamp(0, 1) + + # Remettre en 5D si nécessaire + if 'T' in locals(): + decoded = decoded.reshape(B, T, 3, H, W) + + return decoded + +def decode_latents_to_image_tiled4D(latents, vae, tile_size=128, overlap=64, device="cuda"): + """ + Fonction corrigée pour décoder les latents en image via VAE, en traitant les tuiles. + Cette version gère correctement la dimension des canaux latents (C=4) et la conversion en float32. 
+ """ + + # Assurez-vous que latents ont la bonne forme [B, C, H, W] où C = 4 + B, C, H, W = latents.shape + assert C == 4, f"Expected 4 channels in latents, got {C}" + + # Convertir en float32 si nécessaire (VAE attend des latents en float32) + latents = latents.to(torch.float32) + + # Variable pour stocker les tuiles décodées + decoded_image = torch.zeros(B, 3, H * tile_size, W * tile_size, device=device) + + # Transformation en image + to_pil = ToPILImage() + + # Décodez les tuiles par petits morceaux + for i in range(0, H, tile_size - overlap): + for j in range(0, W, tile_size - overlap): + # Extraire la tuile avec chevauchement + y1, y2 = i, min(i + tile_size, H) + x1, x2 = j, min(j + tile_size, W) + latent_tile = latents[:, :, y1:y2, x1:x2] + + # Décoder la tuile + with torch.no_grad(): + decoded_tile = vae.decode(latent_tile / 0.18215).sample + + # Ajuster les dimensions pour fusionner les tuiles + decoded_tile = decoded_tile.clamp(0, 1) + + # Insérer la tuile dans la position correspondante de l'image finale + decoded_image[:, :, y1 * tile_size:(y2 * tile_size), x1 * tile_size:(x2 * tile_size)] = decoded_tile + + return decoded_image + + + + + +# ------------------------- +# Fonction test VAE (sans affichage d'image) +# ------------------------- +def test_vae_simple(vae_path, device="cuda"): + try: + vae = safe_load_vae(vae_path, device=device, fp16=False) + if vae is None: + return False + + # Tenseur aléatoire pour test + test_latent = torch.randn(1, 4, 32, 32).to(device) + with torch.no_grad(): + decoded_out = vae.decode(test_latent / 0.18215) + decoded = decoded_out.sample if hasattr(decoded_out, "sample") else decoded_out + + # Vérification simple + if decoded is not None and decoded.shape[1] == 3: + return True + return False + + except Exception as e: + print(f"⚠ Test VAE échoué : {e}") + return False + + +def test_vae_256(vae, image): + """ + Test rapide du VAE 256x256. + Vérifie encode -> decode sans afficher d'image. 
+ """ + + device = next(vae.parameters()).device + dtype = next(vae.parameters()).dtype + + transform = T.Compose([ + T.Resize(256), + T.ToTensor(), + T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + + image_tensor = transform(image).unsqueeze(0).to(device=device, dtype=dtype) + + with torch.no_grad(): + latents = vae.encode(image_tensor).latent_dist.sample() + print("🔎 Latent shape:", latents.shape) + + decoded = vae.decode(latents).sample + print("🔎 Decoded shape:", decoded.shape) + + print("✅ Test VAE 256 OK") + + + +def test_vae(vae_path: str, device: str = "cuda") -> bool: + """ + Charge un VAE depuis un .safetensors et effectue un test de décodage rapide. + Retourne True si le VAE est opérationnel, False sinon. + """ + try: + # Charge le state_dict depuis le fichier .safetensors + state_dict = load_file(vae_path, device="cpu") + print("✅ State dict chargé avec succès, clés:", list(state_dict.keys())[:5]) + + # Crée un VAE standard compatible SD + vae = AutoencoderKL( + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D"]*4, + up_block_types=["UpDecoderBlock2D"]*4, + block_out_channels=[128, 256, 512, 512], + latent_channels=4, + sample_size=32 # adapter selon le checkpoint + ) + + # Ignore les clés manquantes ou inattendues + vae.load_state_dict(state_dict, strict=False) + vae = vae.to(device) + print(f"✅ VAE chargé et déplacé sur {device}") + + # Test rapide avec un tenseur latent aléatoire + test_latent = torch.randn(1, 4, 32, 32).to(device) + with torch.no_grad(): + decoded_out = vae.decode(test_latent / 0.18215) + decoded = decoded_out.sample if hasattr(decoded_out, "sample") else decoded_out + + # Vérifie juste la forme sans afficher l'image + if decoded.shape[1] == 3: + print(f"✅ Décodage test OK, output shape: {decoded.shape}") + return True + else: + print(f"⚠ Décodage test incorrect, output shape: {decoded.shape}") + return False + + except Exception as e: + print("⚠ Erreur lors du test VAE :", e) + return False + + + +def 
safe_load_vae_safetensors(vae_path, device="cuda", fp16=False, offload=False): + """ + Charge un VAE depuis un fichier .safetensors seul. + """ + if not vae_path or not os.path.exists(vae_path): + print(f"⚠ VAE non trouvé à {vae_path}") + return None + + try: + # Charger le state dict du fichier safetensors + state_dict = load_file(vae_path, device="cpu") # d'abord en CPU pour éviter OOM + + # Créer un AutoencoderKL vide (Tiny-SD 128x128) + vae = AutoencoderKL( + in_channels=4, + out_channels=4, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + block_out_channels=[128, 256], + latent_channels=4, + sample_size=32 + ) + + # Charger le state dict + vae.load_state_dict(state_dict, strict=False) + + # Déplacer sur le device voulu et ajuster dtype + vae = vae.to(torch.float16 if fp16 else torch.float32) + vae = vae.to(device) + + return vae + except Exception as e: + print(f"⚠ Erreur lors du chargement du VAE : {e}") + return None + + + +# ------------------------- +# Decode latents VAE auto tile +# ------------------------- +def decode_latents_frame_ai_auto(latents: torch.Tensor, vae, device="cuda"): + """ + Decode latents en image PIL ou tensor, tile_size automatique: + - Si max(H,W) <= 256: decode complet + - Sinon: tiles 128x128, overlap 50% + latents: [C,H/8,W/8] ou [B,C,H/8,W/8] + vae: modèle VAE + device: "cuda" ou "cpu" + Retourne: Tensor float32 [3,H,W] en [0,1] + """ + import torch + latents = latents.to(device) + + # Si pas de batch dim, ajouter + if latents.ndim == 3: + latents = latents.unsqueeze(0) # [1,C,H,W] + + _, C, H_lat, W_lat = latents.shape + H_out, W_out = H_lat*8, W_lat*8 # upscaling par VAE + + # Petites images → decode complet + if max(H_out, W_out) <= 256: + with torch.no_grad(): + frame_tensor = vae.decode(latents / 0.18215).sample # Tiny-SD scale + frame_tensor = torch.clamp(frame_tensor, -1.0, 1.0) + frame_tensor = (frame_tensor + 1) / 2 # [-1,1] → [0,1] + 
frame_tensor = frame_tensor[0] # enlever batch dim + return frame_tensor + + # Grandes images → decode en tiles + tile_size = 128 + overlap = tile_size // 2 + _, C, H, W = latents.shape + output = torch.zeros((1, 3, H*8, W*8), device=device) + + count_map = torch.zeros_like(output) + + # Générer les positions de tiles + xs = list(range(0, W, tile_size - overlap)) + ys = list(range(0, H, tile_size - overlap)) + + for y in ys: + for x in xs: + y1, y2 = y, min(y + tile_size, H) + x1, x2 = x, min(x + tile_size, W) + + lat_tile = latents[:, :, y1:y2, x1:x2] + with torch.no_grad(): + dec_tile = vae.decode(lat_tile / 0.18215).sample + dec_tile = torch.clamp(dec_tile, -1, 1) + dec_tile = (dec_tile + 1) / 2 # [0,1] + + H_t, W_t = dec_tile.shape[2], dec_tile.shape[3] + output[:, :, y1*8:y1*8+H_t*8, x1*8:x1*8+W_t*8] += dec_tile.repeat(1,1,8,8) + count_map[:, :, y1*8:y1*8+H_t*8, x1*8:x1*8+W_t*8] += 1.0 + + output /= count_map.clamp(min=1.0) + return output[0] + + +# ------------------------- +# Mémoire GPU utils +# ------------------------- +def log_gpu_memory(tag=""): + if torch.cuda.is_available(): + print(f"[GPU MEM] {tag} → allocated={torch.cuda.memory_allocated()/1e6:.1f}MB, " + f"reserved={torch.cuda.memory_reserved()/1e6:.1f}MB, " + f"max_allocated={torch.cuda.max_memory_allocated()/1e6:.1f}MB") + + + +# ------------------------- +# Encode / Decode avec logs +# ------------------------- +def encode_images_to_latents_ai_test(images, vae): + device = vae.device + images = images.to(device=device, dtype=torch.float32) + print(f"[VAE] Encode → device={device}, images.shape={images.shape}, dtype={images.dtype}") + log_gpu_memory("avant encode VAE") + with torch.no_grad(): + latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE + latents = latents.unsqueeze(2) # [B,C,1,H,W] + print(f"[VAE] Latents shape après encode: {latents.shape}") + log_gpu_memory("après encode VAE") + return latents + +def decode_latents_frame_ai(latents, vae): + # --- Tile size auto --- + _, 
C, H, W = latents.shape[-4:] # [B,C,H,W] + tile_size = min(H, W) # tile couvre toute la latente pour éviter mosaïque + overlap = tile_size // 2 + print(f"[VAE] Decode avec tile_size={tile_size}, overlap={overlap}, device={vae.device}, latents.shape={latents.shape}") + log_gpu_memory("avant decode VAE") + frame_tensor = decode_latents_to_image_tiled(latents, vae, tile_size=tile_size, overlap=overlap).clamp(0,1) + print(f"[VAE] Frame tensor shape après decode: {frame_tensor.shape}") + log_gpu_memory("après decode VAE") + return frame_tensor + + +# ------------------------- +# Encode / Decode +# ------------------------- +def encode_images_to_latents(images, vae): + device = vae.device + images = images.to(device=device, dtype=torch.float32) + with torch.no_grad(): + if images.dim() == 5: # [B,C,F,H,W] + B, C, F, H, W = images.shape + images_2d = images.view(B*F, C, H, W) + latents_2d = vae.encode(images_2d).latent_dist.sample() * LATENT_SCALE + latent_shape = latents_2d.shape + latents = latents_2d.view(B, F, latent_shape[1], latent_shape[2], latent_shape[3]) + latents = latents.permute(0, 2, 1, 3, 4).contiguous() + else: + latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE + latents = latents.unsqueeze(2) + return latents + +def decode_latents_to_image(latents, vae): + latents = latents.to(vae.device).float() / LATENT_SCALE + with torch.no_grad(): + img = vae.decode(latents).sample + img = (img / 2 + 0.5).clamp(0,1) + return img + +# ------------------------- +# Model loaders +# ------------------------- + +def safe_load_unet(model_path, device, fp16=False): + """ + Chargement sécurisé du UNet depuis un dossier local. 
+ - supporte .safetensors ou .bin + - force strict=False pour éviter les mismatches + - affiche les poids non chargés + - support fp16/fp32 + """ + + folder = os.path.join(model_path, "unet") + if not os.path.exists(folder): + raise FileNotFoundError(f"Le dossier UNet n'existe pas : {folder}") + + print(f"🔄 Chargement UNet depuis {folder} ...") + model = UNet2DConditionModel.from_pretrained(folder) + + # --- Chemins fichiers de poids --- + state_dict_path_safetensors = os.path.join(folder, "diffusion_pytorch_model.safetensors") + state_dict_path_bin = os.path.join(folder, "diffusion_pytorch_model.bin") + + # --- Chargement des poids selon format --- + if os.path.exists(state_dict_path_safetensors): + print("✅ Chargement poids safetensors") + state_dict = load_file(state_dict_path_safetensors, device="cpu") + elif os.path.exists(state_dict_path_bin): + print("✅ Chargement poids bin") + state_dict = torch.load(state_dict_path_bin, map_location="cpu") + else: + raise FileNotFoundError("Aucun fichier de poids UNet trouvé (.safetensors ou .bin)") + + # --- Charger dans le modèle avec strict=False --- + missing, unexpected = model.load_state_dict(state_dict, strict=False) + if missing or unexpected: + print(f"⚠️ Poids manquants : {missing}") + print(f"⚠️ Poids inattendus : {unexpected}") + + # --- Convert dtype si demandé --- + if fp16: + model = model.half() + + # --- Envoyer sur device --- + model = model.to(device) + + print(f"✅ UNet chargé avec dtype={next(model.parameters()).dtype}, device={next(model.parameters()).device}") + return model +# ------------------------- +# Model loaders +# ------------------------- + +def safe_load_unet_ori(model_path, device, fp16=False): + folder = os.path.join(model_path, "unet") + if os.path.exists(folder): + model = UNet2DConditionModel.from_pretrained(folder) + if fp16: + model = model.half() + return model.to(device) + return None + +def safe_load_vae_stable(model_path, device, fp16=False, offload=False): + folder = 
os.path.join(model_path, "vae") + if os.path.exists(folder): + model = AutoencoderKL.from_pretrained(folder) + model = model.to("cpu" if offload else device).float() + return model + return None + +def safe_load_scheduler(model_path): + folder = os.path.join(model_path, "scheduler") + if os.path.exists(folder): + return DPMSolverMultistepScheduler.from_pretrained(folder) + return None + + +# --- PIPELINE PRINCIPALE --- +def generate_5D_video_auto(pretrained_model_path, config, device='cuda'): + print("🔄 Chargement des modèles...") + motion_module = MotionModuleTiny(device=device) + scheduler = init_scheduler(config) # ta fonction existante + vae = load_vae(pretrained_model_path, device=device) + + total_frames = config['total_frames'] + fps = config['fps'] + H_src, W_src = config['image_size'] # résolution source + + # Génère les latents initiaux + latents = torch.randn(1, 4, H_src//8, W_src//8, device=device, dtype=torch.float16) + print(f"[INFO] Latents initiaux shape={latents.shape}") + + video_frames = [] + for t in range(total_frames): + try: + latents = motion_module.step(latents, t) + frame = decode_latents_frame_auto(latents, vae, H_src, W_src) + video_frames.append(frame) + except Exception as e: + print(f"⚠ Erreur frame {t:05d} → reset léger: {e}") + continue + + save_video(video_frames, fps, output_path=config['output_path']) + print(f"🎬 Vidéo générée : {config['output_path']}") + + +def decode_latents_frame_auto(latents, vae, H_src, W_src): + """ + Decode des latents VAE en images avec tiles 128x128, auto-adapté à la taille source. 
+ """ + device = vae.device + print(f"[VAE] Decode → tile_size={tile_size}, overlap={overlap}, device={device}, latents.shape={latents.shape}") + log_gpu_memory("avant decode VAE") + + # Assure batch 4D + latents = latents.unsqueeze(0) if latents.dim() == 3 else latents + + # Décodage VAE en tiles + with torch.no_grad(): + frame_tensor = decode_latents_to_image_tiled( + latents, + vae, + tile_size=tile_size, + overlap=overlap + ).clamp(0,1) + + # Redimensionnement proportionnel à l'image source + H_out, W_out = H_src, W_src + if frame_tensor.shape[-2:] != (H_out, W_out): + frame_tensor = torch.nn.functional.interpolate( + frame_tensor, + size=(H_out, W_out), + mode='bicubic', + align_corners=False + ) + + log_gpu_memory("après decode VAE") + return frame_tensor.squeeze(0) + +# --------------------------------------------------------- +# Tuilage sécurisé +# --------------------------------------------------------- +def decode_latents_to_image_tiled(latents, vae, tile_size=32, overlap=8): + """ + Decode VAE en tuiles avec couverture complète garantie. 
+ - Aucun trou possible + - Blending propre + - Stable mathématiquement + """ + + device = vae.device + latents = latents.to(device).float() / LATENT_SCALE + + B, C, H, W = latents.shape + stride = tile_size - overlap + + # Dimensions image finale (scale factor VAE = 8) + out_H = H * 8 + out_W = W * 8 + + output = torch.zeros(B, 3, out_H, out_W, device=device) + weight = torch.zeros_like(output) + + # --- positions garanties --- + y_positions = list(range(0, H - tile_size + 1, stride)) + x_positions = list(range(0, W - tile_size + 1, stride)) + + if not y_positions: + y_positions = [0] + if not x_positions: + x_positions = [0] + + if y_positions[-1] != H - tile_size: + y_positions.append(H - tile_size) + + if x_positions[-1] != W - tile_size: + x_positions.append(W - tile_size) + + for y in y_positions: + for x in x_positions: + + y1 = y + tile_size + x1 = x + tile_size + + tile = latents[:, :, y:y1, x:x1] + + with torch.no_grad(): + decoded = vae.decode(tile).sample + + decoded = (decoded / 2 + 0.5).clamp(0, 1) + + iy0 = y * 8 + ix0 = x * 8 + iy1 = y1 * 8 + ix1 = x1 * 8 + + output[:, :, iy0:iy1, ix0:ix1] += decoded + weight[:, :, iy0:iy1, ix0:ix1] += 1.0 + + return output / weight.clamp(min=1e-6) + + + +# ------------------------- +# Encode / Decode FP16 safe +# ------------------------- +def encode_images_to_latents_safe(images, vae): + device = vae.device + dtype = next(vae.parameters()).dtype # prend fp16 si le VAE est en FP16 + images = images.to(device=device, dtype=dtype) + + with torch.no_grad(): + if images.dim() == 5: # [B,C,F,H,W] + B, C, F, H, W = images.shape + images_2d = images.view(B*F, C, H, W) + latents_2d = vae.encode(images_2d).latent_dist.sample() * LATENT_SCALE + latent_shape = latents_2d.shape + latents = latents_2d.view(B, F, latent_shape[1], latent_shape[2], latent_shape[3]) + latents = latents.permute(0, 2, 1, 3, 4).contiguous() + else: + latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE + latents = latents.unsqueeze(2) + 
return latents + + +def decode_latents_to_image_safe(latents, vae): + dtype = next(vae.parameters()).dtype + latents = latents.to(vae.device).to(dtype) / LATENT_SCALE + with torch.no_grad(): + img = vae.decode(latents).sample + img = (img / 2 + 0.5).clamp(0,1) + return img + +# ------------------------------ +def encode_images_to_latents_half(images, vae): + # récupère dtype réel du VAE + vae_device = next(vae.parameters()).device + vae_dtype = next(vae.parameters()).dtype + + images = images.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + + if images.dim() == 5: # [B,C,F,H,W] + B, C, F, H, W = images.shape + images_2d = images.view(B * F, C, H, W) + + latents_2d = vae.encode(images_2d).latent_dist.sample() + latents_2d = latents_2d * LATENT_SCALE + + latents = latents_2d.view( + B, F, + latents_2d.shape[1], + latents_2d.shape[2], + latents_2d.shape[3] + ) + + latents = latents.permute(0, 2, 1, 3, 4).contiguous() + + else: + latents = vae.encode(images).latent_dist.sample() + latents = latents * LATENT_SCALE + latents = latents.unsqueeze(2) + + return latents + +def decode_latents_to_image_vae(latents, vae): + + # Récupère device + dtype réel du VAE + vae_device = next(vae.parameters()).device + vae_dtype = next(vae.parameters()).dtype + + # Aligne le dtype sur celui du VAE + latents = latents.to(device=vae_device, dtype=vae_dtype) + + latents = latents / LATENT_SCALE + + with torch.no_grad(): + img = vae.decode(latents).sample + img = (img / 2 + 0.5).clamp(0, 1) + + # On repasse en float32 pour sauvegarde PNG + return img.float() +# ------------------------- +# Encode / Decode +# ------------------------- +def encode_images_to_latents_ori(images, vae): + device = vae.device + images = images.to(device=device, dtype=torch.float32) + with torch.no_grad(): + if images.dim() == 5: # [B,C,F,H,W] + B, C, F, H, W = images.shape + images_2d = images.view(B*F, C, H, W) + latents_2d = vae.encode(images_2d).latent_dist.sample() * LATENT_SCALE + 
latent_shape = latents_2d.shape + latents = latents_2d.view(B, F, latent_shape[1], latent_shape[2], latent_shape[3]) + latents = latents.permute(0, 2, 1, 3, 4).contiguous() + else: + latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE + latents = latents.unsqueeze(2) + return latents + +def decode_latents_to_image_ori(latents, vae): + latents = latents.to(vae.device).float() / LATENT_SCALE + with torch.no_grad(): + img = vae.decode(latents).sample + img = (img / 2 + 0.5).clamp(0,1) + return img + + +# -------------------------------------------------------- +# | Mode | VAE | Images | Latents | Résultat | +# | ----------- | ------- | ------- | ------- | -------- | +# | fp32 | float32 | float32 | float32 | ✅ | +# | fp16 | float16 | float16 | float16 | ✅ | +# | offload CPU | float32 | float32 | float32 | ✅ | +# ------------------------- ci dessous: +# ------------------------- +# Encode / Decode corrigé FP16 safe +# ------------------------- +def encode_images_to_latents(images, vae): + device = next(vae.parameters()).device + dtype = next(vae.parameters()).dtype # on aligne avec le VAE + images = images.to(device=device, dtype=dtype) + with torch.no_grad(): + latents = vae.encode(images).latent_dist.sample() * LATENT_SCALE + return latents + +def decode_latents_to_image(latents, vae): + # On force latents à avoir le même dtype et device que le VAE + vae_dtype = next(vae.parameters()).dtype + vae_device = next(vae.parameters()).device + latents = latents.to(device=vae_device, dtype=vae_dtype) / LATENT_SCALE + + with torch.no_grad(): + img = vae.decode(latents).sample + + # Normalisation sûre vers 0-1 + img = (img / 2 + 0.5).clamp(0, 1) + + # Si FP16 → convertir en float32 pour torchvision save_image + if img.dtype == torch.float16: + img = img.float() + + return img + +# NEW +# +# +# --------------------------------------------------------- +# Decode latents to image avec logs et sécurité +# --------------------------------------------------------- +def 
decode_latents_to_image_2(latents, vae, latent_scale=0.18215): + """ + latents: [B, C, F, H, W] ou [B, C, 1, H, W] pour frame unique + vae: VAE pour décodage + """ + try: + print(f"🔹 decode_latents_to_image_2 | input shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}") + + # Si latents a une dimension de frame singleton, la squeeze + if latents.shape[2] == 1: + latents = latents.squeeze(2) + print(f"🔹 Squeeze frame dimension → shape: {latents.shape}") + + # Assurer dtype et device compatible VAE + vae_dtype = next(vae.parameters()).dtype + vae_device = next(vae.parameters()).device + latents = latents.to(device=vae_device, dtype=vae_dtype) / latent_scale + + # Check NaN avant VAE + print(f"🔹 Latents before VAE decode | min: {latents.min()}, max: {latents.max()}, dtype: {latents.dtype}") + if torch.isnan(latents).any(): + print("❌ Warning: NaN detected in latents before VAE decode!") + + with torch.no_grad(): + img = vae.decode(latents).sample + + # Check NaN après décodage + print(f"🔹 Image after VAE decode | min: {img.min()}, max: {img.max()}, dtype: {img.dtype}") + if torch.isnan(img).any(): + print("❌ Warning: NaN detected in decoded image!") + + # Normalisation safe vers 0-1 + img = (img / 2 + 0.5).clamp(0, 1) + print(f"🔹 Image final | min: {img.min()}, max: {img.max()}, dtype: {img.dtype}, shape: {img.shape}") + + # Conversion FP16 -> FP32 si nécessaire + if img.dtype == torch.float16: + img = img.float() + + return img + + except Exception as e: + print(f"❌ Exception in decode_latents_to_image_2: {e}") + # Retourne une image noire safe si VAE échoue + B, C, H, W = latents.shape[:4] + return torch.zeros(B, 3, H*8, W*8, device=latents.device) # scale approx 8x pour SD VAE +# ------------------------- +# Encode / Decode corrigé +# ------------------------- +# ------------------------- + +def decode_latents_to_image_old(latents, vae): + vae_device = next(vae.parameters()).device + vae_dtype = next(vae.parameters()).dtype + latents = 
latents.to(device=vae_device, dtype=vae_dtype) + latents = latents / LATENT_SCALE + with torch.no_grad(): + img = vae.decode(latents).sample + img = (img / 2 + 0.5).clamp(0, 1) + return img.float() # on repasse en float32 pour PNG diff --git a/scripts/utils/video_utils.py b/scripts/utils/video_utils.py new file mode 100644 index 00000000..123145cf --- /dev/null +++ b/scripts/utils/video_utils.py @@ -0,0 +1,107 @@ +from pathlib import Path +import shutil +from PIL import Image +import os +import math +import ffmpeg + +def upscale_video( + input_video: Path, + output_video: Path, + scale_factor: int = 2, + method: str = "lanczos" +): + """ + Upscale une vidéo existante sans IA. + + method: + - lanczos (recommandé) + - bicubic + - bilinear + """ + + input_video = Path(input_video) + output_video = Path(output_video) + + print(f"🔎 Upscaling vidéo x{scale_factor} ({method})...") + + ( + ffmpeg + .input(str(input_video)) + .filter("scale", + f"iw*{scale_factor}", + f"ih*{scale_factor}", + flags=method) + .output( + str(output_video), + vcodec="libx264", + pix_fmt="yuv420p", + crf=18 # qualité élevée + ) + .overwrite_output() + .run(quiet=True) + ) + + print(f"✅ Vidéo upscalée générée : {output_video}") + + +# ------------------------- +# Video save +# ------------------------- + +# ------------------------- +# Video utilities +# ------------------------- +def save_frames_as_video(frames, output_path, fps=12): + temp_dir = Path("temp_frames") + if temp_dir.exists(): + shutil.rmtree(temp_dir) + temp_dir.mkdir() + + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + + ( + ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + shutil.rmtree(temp_dir) + + +# ------------------------- +# Save video +# ------------------------- +def save_frames_as_video_rmtmp(frames, output_path, fps=12): + temp_dir = Path("temp_frames") + if 
temp_dir.exists(): shutil.rmtree(temp_dir) + temp_dir.mkdir() + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + ( + ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + shutil.rmtree(temp_dir) +# ------------------------- +# Video utilities +# ------------------------- +def save_frames_as_video_ori(frames, output_path, fps=12): + temp_dir = Path("temp_frames") + if temp_dir.exists(): + shutil.rmtree(temp_dir) + temp_dir.mkdir() + + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + + ( + ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + shutil.rmtree(temp_dir) diff --git a/scripts/video_utils.py b/scripts/video_utils.py new file mode 100644 index 00000000..123145cf --- /dev/null +++ b/scripts/video_utils.py @@ -0,0 +1,107 @@ +from pathlib import Path +import shutil +from PIL import Image +import os +import math +import ffmpeg + +def upscale_video( + input_video: Path, + output_video: Path, + scale_factor: int = 2, + method: str = "lanczos" +): + """ + Upscale une vidéo existante sans IA. 
+ + method: + - lanczos (recommandé) + - bicubic + - bilinear + """ + + input_video = Path(input_video) + output_video = Path(output_video) + + print(f"🔎 Upscaling vidéo x{scale_factor} ({method})...") + + ( + ffmpeg + .input(str(input_video)) + .filter("scale", + f"iw*{scale_factor}", + f"ih*{scale_factor}", + flags=method) + .output( + str(output_video), + vcodec="libx264", + pix_fmt="yuv420p", + crf=18 # qualité élevée + ) + .overwrite_output() + .run(quiet=True) + ) + + print(f"✅ Vidéo upscalée générée : {output_video}") + + +# ------------------------- +# Video save +# ------------------------- + +# ------------------------- +# Video utilities +# ------------------------- +def save_frames_as_video(frames, output_path, fps=12): + temp_dir = Path("temp_frames") + if temp_dir.exists(): + shutil.rmtree(temp_dir) + temp_dir.mkdir() + + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + + ( + ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + shutil.rmtree(temp_dir) + + +# ------------------------- +# Save video +# ------------------------- +def save_frames_as_video_rmtmp(frames, output_path, fps=12): + temp_dir = Path("temp_frames") + if temp_dir.exists(): shutil.rmtree(temp_dir) + temp_dir.mkdir() + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + ( + ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + shutil.rmtree(temp_dir) +# ------------------------- +# Video utilities +# ------------------------- +def save_frames_as_video_ori(frames, output_path, fps=12): + temp_dir = Path("temp_frames") + if temp_dir.exists(): + shutil.rmtree(temp_dir) + temp_dir.mkdir() + + for idx, frame in enumerate(frames): + frame.save(temp_dir / f"frame_{idx:05d}.png") + + ( + 
ffmpeg.input(f"{temp_dir}/frame_%05d.png", framerate=fps) + .output(str(output_path), vcodec="libx264", pix_fmt="yuv420p") + .overwrite_output() + .run(quiet=True) + ) + shutil.rmtree(temp_dir)