diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..5c1ec2c2
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,18 @@
+samples/
+wandb/
+outputs/
+__pycache__/
+
+scripts/animate_inter.py
+scripts/gradio_app.py
+models/Controlnet/*
+models/DreamBooth_LoRA/*
+models/DreamBooth_LoRA/Put*personalized*T2I*checkpoints*here.txt
+models/Motion_Module/*
+models/*
+*.ipynb
+*.safetensors
+*.ckpt
+.ossutil_checkpoint/
+ossutil_output/
+debugs/
diff --git a/README.md b/README.md
index b5450a73..aaf9392f 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-# AnimateDiff
+# Controled AnimateDiff (V2 is also available)
-This repository is the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725).
+This repository is an Controlnet Extension of the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725).
**[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)**
@@ -11,48 +11,113 @@ Yaohui Wang,
Yu Qiao,
Dahua Lin,
Bo Dai
-
*Corresponding Author
-[Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/)
+
+[](https://arxiv.org/abs/2307.04725)
+[](https://animatediff.github.io/)
+[](https://openxlab.org.cn/apps/detail/Masbfca/AnimateDiff)
+[](https://huggingface.co/spaces/guoyww/AnimateDiff)
-## Todo
-- [x] Code Release
-- [x] Arxiv Report
-- [x] GPU Memory Optimization
-- [ ] Gradio Interface
+***WARNING! This version works as well as official but not compatible with the official implementation due to the difference of library versions.***
+
+
+
+  |
+  |
+  |
+  |
+
+
+  |
+  |
+  |
+  |
+
+
+Test video sources: dance and smiling.
+## Todo
+- [x] Add Controlnet in the pipeline.
+- [x] Add Controlnet in Gradio Demo.
+- [X] Optimize code in attention processor style.
+
+## Features
+- Added Controlnet for Video to Video control.
+- GPU Memory, ~12-14GB VRAM to inference w/o Controlnet and ~15-17GB VRAM with Controlnet.
+
+- **[2023/09/10]** New Motion Module release ! `mm_sd_v15_v2.ckpt` was trained on larger resolution & batch size, and gains noticabe quality improvements.Check it out at [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) and use it with `configs/inference/inference-v2.yaml`. Example:
+ ```
+ python -m scripts.animate --config configs/prompts/v2/5-RealisticVision.yaml
+ ```
+ Here is a qualitative comparison between `mm_sd_v15.ckpt` (left) and `mm_sd_v15_v2.ckpt` (right):
+
+- GPU Memory Optimization, ~12GB VRAM to inference
+
+- User Interface: [Gradio](#gradio-demo), A1111 WebUI Extension [sd-webui-animatediff](https://github.com/continue-revolution/sd-webui-animatediff) (by [@continue-revolution](https://github.com/continue-revolution))
+- Google Colab: [Colab](https://colab.research.google.com/github/camenduru/AnimateDiff-colab/blob/main/AnimateDiff_colab.ipynb) (by [@camenduru](https://github.com/camenduru))
## Common Issues
Installation
+
Please ensure the installation of [xformer](https://github.com/facebookresearch/xformers) that is applied to reduce the inference memory.
+
Various resolution or number of frames
Currently, we recommend users to generate animation with 16 frames and 512 resolution that are aligned with our training settings. Notably, various resolution/frames may affect the quality more or less.
+
+
+How to use it without any coding
+
+1) Get lora models: train lora model with [A1111](https://github.com/continue-revolution/sd-webui-animatediff) based on a collection of your own favorite images (e.g., tutorials [English](https://www.youtube.com/watch?v=mfaqqL5yOO4), [Japanese](https://www.youtube.com/watch?v=N1tXVR9lplM), [Chinese](https://www.bilibili.com/video/BV1fs4y1x7p2/))
+or download Lora models from [Civitai](https://civitai.com/).
+
+2) Animate lora models: using gradio interface or A1111
+(e.g., tutorials [English](https://github.com/continue-revolution/sd-webui-animatediff), [Japanese](https://www.youtube.com/watch?v=zss3xbtvOWw), [Chinese](https://941ai.com/sd-animatediff-webui-1203.html))
+
+3) Be creative togther with other techniques, such as, super resolution, frame interpolation, music generation, etc.
+
+
+
Animating a given image
+
We totally agree that animating a given image is an appealing feature, which we would try to support officially in future. For now, you may enjoy other efforts from the [talesofai](https://github.com/talesofai/AnimateDiff).
Contributions from community
-Contributions are always welcome!! We will create another branch which community could contribute to. As for the main branch, we would like to align it with the original technical report:)
+Contributions are always welcome!! The dev branch is for community contributions. As for the main branch, we would like to align it with the original technical report :)
-
-## Setup for Inference
+## Setups for Inference
### Prepare Environment
-~~Our approach takes around 60 GB GPU memory to inference. NVIDIA A100 is recommanded.~~
-
-***We updated our inference code with xformers and a sequential decoding trick. Now AnimateDiff takes only ~12GB VRAM to inference, and run on a single RTX3090 !!***
```
git clone https://github.com/guoyww/AnimateDiff.git
@@ -71,7 +136,7 @@ git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDif
bash download_bashscripts/0-MotionModule.sh
```
-You may also directly download the motion module checkpoints from [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing), then put them in `models/Motion_Module/` folder.
+You may also directly download the motion module checkpoints from [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836), then put them in `models/Motion_Module/` folder.
### Prepare Personalize T2I
Here we provide inference configs for 6 demo T2I on CivitAI.
@@ -123,6 +188,59 @@ Then run the following commands:
```
python -m scripts.animate --config [path to the config file]
```
+## Inference with Controlnet
+Controlnet appoach is using video as source of content. It takes first `L` (usualy 16) frames from video.
+
+Download controlnet models using script:
+```bash
+bash download_bashscripts/9-Controlnets.sh
+```
+
+Run examples:
+```bash
+python -m scripts.animate --config configs/prompts/1-ToonYou-Controlnet.yaml
+python -m scripts.animate --config configs/prompts/2-Lyriel-Controlnet.yaml
+python -m scripts.animate --config configs/prompts/3-RcnzCartoon-Controlnet.yaml
+```
+
+Add controlnet to other config (see example in 1-ToonYou-Controlnet.yaml):
+```yaml
+control:
+ video_path: "./videos/smiling.mp4"
+ get_each: 2 # get each frame from video
+ controlnet_processor: "softedge" # softedge, canny, depth
+ controlnet_pipeline: "models/StableDiffusion/stable-diffusion-v1-5"
+ controlnet_processor_path: "models/Controlnet/control_v11p_sd15_softedge" # control_v11p_sd15_softedge, control_v11f1p_sd15_depth, control_v11p_sd15_canny
+ guess_mode: True
+```
+
+## Steps for Training
+
+### Dataset
+Before training, download the videos files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local mechine.
+Note that our examplar training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`.
+
+### Configuration
+After dataset preparations, update the below data paths in the config `.yaml` files in `configs/training/` folder:
+```
+train_data:
+ csv_path: [Replace with .csv Annotation File Path]
+ video_folder: [Replace with Video Folder Path]
+ sample_size: 256
+```
+Other training parameters (lr, epochs, validation settings, etc.) are also included in the config files.
+
+### Training
+To train motion modules
+```
+torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/training.yaml
+```
+
+To finetune the unet's image layers
+```
+torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/image_finetune.yaml
+```
+
## Gradio Demo
We have created a Gradio demo to make AnimateDiff easier to use. To launch the demo, please run the following commands:
@@ -131,6 +249,8 @@ conda activate animatediff
python app.py
```
By default, the demo will run at `localhost:7860`.
+Be sure that imageio with backend is installed. (pip install imageio[ffmpeg])
+
## Gallery
@@ -241,4 +361,4 @@ Pose Model:Hold Sig
**Bo Dai**: [daibo@pjlab.org.cn](mailto:daibo@pjlab.org.cn)
## Acknowledgements
-Codebase built upon [Tune-a-Video](https://github.com/showlab/Tune-A-Video).
\ No newline at end of file
+Codebase built upon [Tune-a-Video](https://github.com/showlab/Tune-A-Video).
diff --git a/__assets__/animations/compare/ffmpeg b/__assets__/animations/compare/ffmpeg
new file mode 100644
index 00000000..e69de29b
diff --git a/__assets__/animations/compare/new_0.gif b/__assets__/animations/compare/new_0.gif
new file mode 100644
index 00000000..8681fa58
Binary files /dev/null and b/__assets__/animations/compare/new_0.gif differ
diff --git a/__assets__/animations/compare/new_1.gif b/__assets__/animations/compare/new_1.gif
new file mode 100644
index 00000000..dd0b296c
Binary files /dev/null and b/__assets__/animations/compare/new_1.gif differ
diff --git a/__assets__/animations/compare/new_2.gif b/__assets__/animations/compare/new_2.gif
new file mode 100644
index 00000000..7baeb7b8
Binary files /dev/null and b/__assets__/animations/compare/new_2.gif differ
diff --git a/__assets__/animations/compare/new_3.gif b/__assets__/animations/compare/new_3.gif
new file mode 100644
index 00000000..07dc3202
Binary files /dev/null and b/__assets__/animations/compare/new_3.gif differ
diff --git a/__assets__/animations/compare/old_0.gif b/__assets__/animations/compare/old_0.gif
new file mode 100644
index 00000000..70709b86
Binary files /dev/null and b/__assets__/animations/compare/old_0.gif differ
diff --git a/__assets__/animations/compare/old_1.gif b/__assets__/animations/compare/old_1.gif
new file mode 100644
index 00000000..5c605bea
Binary files /dev/null and b/__assets__/animations/compare/old_1.gif differ
diff --git a/__assets__/animations/compare/old_2.gif b/__assets__/animations/compare/old_2.gif
new file mode 100644
index 00000000..2e20b7b8
Binary files /dev/null and b/__assets__/animations/compare/old_2.gif differ
diff --git a/__assets__/animations/compare/old_3.gif b/__assets__/animations/compare/old_3.gif
new file mode 100644
index 00000000..b035a95e
Binary files /dev/null and b/__assets__/animations/compare/old_3.gif differ
diff --git a/__assets__/animations/control/canny/dance_1girl.gif b/__assets__/animations/control/canny/dance_1girl.gif
new file mode 100644
index 00000000..59e47d1b
Binary files /dev/null and b/__assets__/animations/control/canny/dance_1girl.gif differ
diff --git a/__assets__/animations/control/canny/dance_medival_portrait.gif b/__assets__/animations/control/canny/dance_medival_portrait.gif
new file mode 100644
index 00000000..053eb8a2
Binary files /dev/null and b/__assets__/animations/control/canny/dance_medival_portrait.gif differ
diff --git a/__assets__/animations/control/canny/smiling_medival_portrait.gif b/__assets__/animations/control/canny/smiling_medival_portrait.gif
new file mode 100644
index 00000000..dc1d19c3
Binary files /dev/null and b/__assets__/animations/control/canny/smiling_medival_portrait.gif differ
diff --git a/__assets__/animations/control/depth/smiling_1girl.gif b/__assets__/animations/control/depth/smiling_1girl.gif
new file mode 100644
index 00000000..cd4c5756
Binary files /dev/null and b/__assets__/animations/control/depth/smiling_1girl.gif differ
diff --git a/__assets__/animations/control/depth/smiling_forbidden_castle.gif b/__assets__/animations/control/depth/smiling_forbidden_castle.gif
new file mode 100644
index 00000000..7603fc2d
Binary files /dev/null and b/__assets__/animations/control/depth/smiling_forbidden_castle.gif differ
diff --git a/__assets__/animations/control/depth/smiling_halo.gif b/__assets__/animations/control/depth/smiling_halo.gif
new file mode 100644
index 00000000..6bfc6817
Binary files /dev/null and b/__assets__/animations/control/depth/smiling_halo.gif differ
diff --git a/__assets__/animations/control/depth/smiling_medival.gif b/__assets__/animations/control/depth/smiling_medival.gif
new file mode 100644
index 00000000..1b53e1e0
Binary files /dev/null and b/__assets__/animations/control/depth/smiling_medival.gif differ
diff --git a/__assets__/animations/control/depth/smiling_realistic_0.gif b/__assets__/animations/control/depth/smiling_realistic_0.gif
new file mode 100644
index 00000000..cb9d645f
Binary files /dev/null and b/__assets__/animations/control/depth/smiling_realistic_0.gif differ
diff --git a/__assets__/animations/control/depth/smiling_realistic_1.gif b/__assets__/animations/control/depth/smiling_realistic_1.gif
new file mode 100644
index 00000000..5ad6bef7
Binary files /dev/null and b/__assets__/animations/control/depth/smiling_realistic_1.gif differ
diff --git a/__assets__/animations/control/depth/smiling_realistic_2.gif b/__assets__/animations/control/depth/smiling_realistic_2.gif
new file mode 100644
index 00000000..fd6d604a
Binary files /dev/null and b/__assets__/animations/control/depth/smiling_realistic_2.gif differ
diff --git a/__assets__/animations/control/original/dance_original_16_2.gif b/__assets__/animations/control/original/dance_original_16_2.gif
new file mode 100644
index 00000000..b8da7b65
Binary files /dev/null and b/__assets__/animations/control/original/dance_original_16_2.gif differ
diff --git a/__assets__/animations/control/original/smiling_original_16_2.gif b/__assets__/animations/control/original/smiling_original_16_2.gif
new file mode 100644
index 00000000..edaed566
Binary files /dev/null and b/__assets__/animations/control/original/smiling_original_16_2.gif differ
diff --git a/__assets__/animations/control/softedge/dance_1girl.gif b/__assets__/animations/control/softedge/dance_1girl.gif
new file mode 100644
index 00000000..3d531421
Binary files /dev/null and b/__assets__/animations/control/softedge/dance_1girl.gif differ
diff --git a/__assets__/animations/control/softedge/smiling_realistic_0.gif b/__assets__/animations/control/softedge/smiling_realistic_0.gif
new file mode 100644
index 00000000..10fbb58b
Binary files /dev/null and b/__assets__/animations/control/softedge/smiling_realistic_0.gif differ
diff --git a/animatediff/controlnet/controlnet_module.py b/animatediff/controlnet/controlnet_module.py
new file mode 100644
index 00000000..b377d74b
--- /dev/null
+++ b/animatediff/controlnet/controlnet_module.py
@@ -0,0 +1,191 @@
+from collections import defaultdict
+from typing import Any
+
+import cv2
+import torch
+from PIL import Image
+from tqdm import tqdm
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+
+from .controlnet_processors import CONTROLNET_PROCESSORS
+
+
+def get_video_info(cap):
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ return height, width, fps, frame_count
+
+
+class ControlnetModule:
+ def __init__(self, config):
+ self.config = config
+ self.video_length = self.config['video_length']
+ self.img_w = self.config['img_w']
+ self.img_h = self.config['img_h']
+ self.do_cfg = self.config['guidance_scale'] > 1.0
+ self.num_inference_steps = config['steps']
+ self.guess_mode = config['guess_mode']
+ self.conditioning_scale = config['conditioning_scale']
+ self.device = config['device']
+
+ controlnet_info = CONTROLNET_PROCESSORS[self.config['controlnet_processor']]
+
+ if ('controlnet_processor_path' not in config) or not len(config['controlnet_processor_path']):
+ config['controlnet_processor_path'] = controlnet_info['controlnet']
+
+ controlnet = ControlNetModel.from_pretrained(
+ controlnet_info['controlnet'], torch_dtype=torch.float16)
+
+ if controlnet_info['is_custom']:
+ self.processor = controlnet_info['processor'](
+ **controlnet_info['processor_params'])
+ else:
+ self.processor = controlnet_info['processor'].from_pretrained(
+ 'lllyasviel/Annotators')
+
+ self.controlnet_pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ config['controlnet_pipeline'], #"runwayml/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ torch_dtype=torch.float16
+ )
+
+ del self.controlnet_pipe.vae
+ del self.controlnet_pipe.unet
+ del self.controlnet_pipe.feature_extractor
+
+ self.controlnet_pipe.to(self.device)
+
+ def process_video(self, video_path):
+ cap = cv2.VideoCapture(video_path)
+ orig_height, orig_width, fps, frames_count = get_video_info(cap)
+ print('| --- START VIDEO PROCESSING --- |')
+ print(f'| HxW: {orig_height}x{orig_width} | FPS: {fps} | FRAMES COUNT: {frames_count} |')
+
+ get_each = self.config.get('get_each', 1)
+ processed_images = []
+
+ for frame_index in tqdm(range(self.config['video_length'] * get_each)):
+ ret, image = cap.read()
+ if not ret or image is None:
+ break
+
+ if frame_index % get_each != 0:
+ continue
+
+ image = cv2.resize(image, (self.img_w, self.img_h))
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ condition_image = self.processor(Image.fromarray(image))
+ processed_images.append(condition_image)
+
+ return processed_images
+
+ def generate_control_blocks(self, processed_images, prompt, negative_prompt, seed):
+ print('| --- EXTRACT CONTROLNET FEATURES --- |')
+
+ shape = (1, 4, self.video_length, self.img_h // 8, self.img_w // 8)
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+ control_latents = torch.randn(
+ shape,
+ generator=generator,
+ device=self.device,
+ dtype=torch.float16
+ )
+
+ prompt_embeds = self.controlnet_pipe._encode_prompt(
+ prompt,
+ self.device,
+ 1,
+ self.do_cfg,
+ negative_prompt,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ lora_scale=None,
+ )
+
+ self.controlnet_pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
+ timesteps = self.controlnet_pipe.scheduler.timesteps
+
+ control_blocks = []
+ for t in tqdm(timesteps):
+ down_block_samples = []
+ mid_block_samples = []
+
+ for img_index, image in enumerate(processed_images):
+ latents = control_latents[:, :, img_index, :, :]
+ image = self.controlnet_pipe.control_image_processor.preprocess(
+ image,
+ height=self.img_h,
+ width=self.img_w
+ ).to(dtype=torch.float32)
+
+ image = image.repeat_interleave(1, dim=0)
+ image = image.to(device=self.device, dtype=torch.float16)
+
+ if self.do_cfg and not self.guess_mode:
+ image = torch.cat([image] * 2)
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_cfg else latents
+ latent_model_input = self.controlnet_pipe.scheduler.scale_model_input(latent_model_input, t)
+
+ if self.guess_mode and self.do_cfg:
+ control_model_input = latents
+ control_model_input = self.controlnet_pipe.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet_pipe.controlnet(
+ control_model_input.to(self.device),
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds.to(self.device),
+ controlnet_cond=image,
+ conditioning_scale=self.conditioning_scale,
+ guess_mode=self.guess_mode,
+ return_dict=False,
+ )
+
+ if self.guess_mode and self.do_cfg:
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ down_block_samples.append([x.detach().cpu() for x in down_block_res_samples])
+ mid_block_samples.append(mid_block_res_sample.detach().cpu())
+
+ control_blocks.append({
+ 'down_block_samples': down_block_samples,
+ 'mid_block_samples': mid_block_samples,
+ })
+
+ return control_blocks
+
+ def resort_features(self, control_blocks):
+ mid_blocks = []
+ down_blocks = []
+
+ for c_block in control_blocks:
+ d_blocks = defaultdict(list)
+ for image_weights in c_block['down_block_samples']:
+ for b_index, block_weights in enumerate(image_weights):
+ d_blocks[b_index] += block_weights.unsqueeze(0)
+
+ down_block = []
+ for _, value in d_blocks.items():
+ down_block.append(torch.stack(value).permute(1, 2, 0, 3, 4))
+
+ mid_block = torch.stack(c_block['mid_block_samples']).permute(1, 2, 0, 3, 4)
+
+ down_blocks.append(down_block)
+ mid_blocks.append(mid_block)
+
+ return down_blocks, mid_blocks
+
+ def __call__(self, video_path, prompt, negative_prompt, generator):
+ processed_images = self.process_video(video_path)
+ control_blocks = self.generate_control_blocks(
+ processed_images, prompt, negative_prompt, generator)
+ down_features, mid_features = self.resort_features(control_blocks)
+ return down_features, mid_features
diff --git a/animatediff/controlnet/controlnet_processors.py b/animatediff/controlnet/controlnet_processors.py
new file mode 100644
index 00000000..56f2dd07
--- /dev/null
+++ b/animatediff/controlnet/controlnet_processors.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+import cv2
+import numpy as np
+from PIL import Image
+from transformers import pipeline
+from controlnet_aux import HEDdetector, OpenposeDetector, NormalBaeDetector
+
+
+class CannyProcessor:
+ def __init__(self, t1, t2, **kwargs):
+ self.t1 = t1
+ self.t2 = t2
+
+ def __call__(self, input_image):
+ image = np.array(input_image)
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ image = cv2.Canny(image, self.t1, self.t2)
+ image = image[:, :, None]
+ image = np.concatenate([image, image, image], axis=2)
+ control_image = Image.fromarray(image)
+ return control_image
+
+
+class DepthProcessor:
+ def __init__(self, **kwargs):
+ self.depth_estimator = pipeline('depth-estimation')
+
+ def __call__(self, input_image):
+ image = self.depth_estimator(input_image)['depth']
+ image = np.array(image)
+ image = image[:, :, None]
+ image = np.concatenate([image, image, image], axis=2)
+ control_image = Image.fromarray(image)
+ return control_image
+
+
+CONTROLNET_PROCESSORS = {
+ 'canny': {
+ 'controlnet': 'lllyasviel/control_v11p_sd15_canny',
+ 'processor': CannyProcessor,
+ 'processor_params': {'t1': 50, 't2': 150},
+ 'is_custom': True,
+ },
+ 'depth': {
+ 'controlnet': 'lllyasviel/control_v11f1p_sd15_depth',
+ 'processor': DepthProcessor,
+ 'processor_params': {},
+ 'is_custom': True,
+ },
+ 'softedge': {
+ 'controlnet': 'lllyasviel/control_v11p_sd15_softedge',
+ 'processor': HEDdetector, # PidiNetDetector
+ 'processor_params': {},
+ 'is_custom': False,
+ },
+ 'pose': {
+ 'controlnet': 'lllyasviel/sd-controlnet-openpose',
+ 'processor': OpenposeDetector,
+ 'processor_params': {},
+ 'is_custom': False,
+ },
+ 'norm': {
+ 'controlnet': 'lllyasviel/control_v11p_sd15_normalbae',
+ 'processor': NormalBaeDetector,
+ 'processor_params': {},
+ 'is_custom': False,
+ },
+}
diff --git a/animatediff/data/dataset.py b/animatediff/data/dataset.py
new file mode 100644
index 00000000..3f6ec102
--- /dev/null
+++ b/animatediff/data/dataset.py
@@ -0,0 +1,98 @@
+import os, io, csv, math, random
+import numpy as np
+from einops import rearrange
+from decord import VideoReader
+
+import torch
+import torchvision.transforms as transforms
+from torch.utils.data.dataset import Dataset
+from animatediff.utils.util import zero_rank_print
+
+
+
+class WebVid10M(Dataset):
+ def __init__(
+ self,
+ csv_path, video_folder,
+ sample_size=256, sample_stride=4, sample_n_frames=16,
+ is_image=False,
+ ):
+ zero_rank_print(f"loading annotations from {csv_path} ...")
+ with open(csv_path, 'r') as csvfile:
+ self.dataset = list(csv.DictReader(csvfile))
+ self.length = len(self.dataset)
+ zero_rank_print(f"data scale: {self.length}")
+
+ self.video_folder = video_folder
+ self.sample_stride = sample_stride
+ self.sample_n_frames = sample_n_frames
+ self.is_image = is_image
+
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+ self.pixel_transforms = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ transforms.Resize(sample_size[0]),
+ transforms.CenterCrop(sample_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ])
+
+ def get_batch(self, idx):
+ video_dict = self.dataset[idx]
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
+ video_reader = VideoReader(video_dir)
+ video_length = len(video_reader)
+
+ if not self.is_image:
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+ start_idx = random.randint(0, video_length - clip_length)
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+ else:
+ batch_index = [random.randint(0, video_length - 1)]
+
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
+ pixel_values = pixel_values / 255.
+ del video_reader
+
+ if self.is_image:
+ pixel_values = pixel_values[0]
+
+ return pixel_values, name
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ pixel_values, name = self.get_batch(idx)
+ break
+
+ except Exception as e:
+ idx = random.randint(0, self.length-1)
+
+ pixel_values = self.pixel_transforms(pixel_values)
+ sample = dict(pixel_values=pixel_values, text=name)
+ return sample
+
+
+
+if __name__ == "__main__":
+ from animatediff.utils.util import save_videos_grid
+
+ dataset = WebVid10M(
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
+ sample_size=256,
+ sample_stride=4, sample_n_frames=16,
+ is_image=True,
+ )
+ import pdb
+ pdb.set_trace()
+
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
+ for idx, batch in enumerate(dataloader):
+ print(batch["pixel_values"].shape, len(batch["text"]))
+ # for i in range(batch["pixel_values"].shape[0]):
+ # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
diff --git a/animatediff/models/attention.py b/animatediff/models/attention.py
index ad23583c..56a535f9 100644
--- a/animatediff/models/attention.py
+++ b/animatediff/models/attention.py
@@ -4,17 +4,17 @@
from typing import Optional
import torch
-import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.modeling_utils import ModelMixin
+from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
+from diffusers.models.attention import FeedForward, AdaLayerNorm
+from diffusers.models.attention_processor import Attention
from einops import rearrange, repeat
-import pdb
+
@dataclass
class Transformer3DModelOutput(BaseOutput):
@@ -165,32 +165,19 @@ def __init__(
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
- # SC-Attn
- assert unet_use_cross_frame_attention is not None
- if unet_use_cross_frame_attention:
- self.attn1 = SparseCausalAttention2D(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
- upcast_attention=upcast_attention,
- )
- else:
- self.attn1 = CrossAttention(
- query_dim=dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- dropout=dropout,
- bias=attention_bias,
- upcast_attention=upcast_attention,
- )
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# Cross-Attn
if cross_attention_dim is not None:
- self.attn2 = CrossAttention(
+ self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
@@ -214,7 +201,7 @@ def __init__(
# Temp-Attn
assert unet_use_temporal_attention is not None
if unet_use_temporal_attention:
- self.attn_temp = CrossAttention(
+ self.attn_temp = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
@@ -225,48 +212,11 @@ def __init__(
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
- def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
- if not is_xformers_available():
- print("Here is how to install it")
- raise ModuleNotFoundError(
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
- " xformers",
- name="xformers",
- )
- elif not torch.cuda.is_available():
- raise ValueError(
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
- " available for GPU "
- )
- else:
- try:
- # Make sure we can run the memory efficient attention
- _ = xformers.ops.memory_efficient_attention(
- torch.randn((1, 2, 40), device="cuda"),
- torch.randn((1, 2, 40), device="cuda"),
- torch.randn((1, 2, 40), device="cuda"),
- )
- except Exception as e:
- raise e
- self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
- if self.attn2 is not None:
- self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
- # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
# SparseCausal-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
-
- # if self.only_cross_attention:
- # hidden_states = (
- # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
- # )
- # else:
- # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
-
- # pdb.set_trace()
if self.unet_use_cross_frame_attention:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
else:
diff --git a/animatediff/models/motion_module.py b/animatediff/models/motion_module.py
index 2359e712..1909c82c 100644
--- a/animatediff/models/motion_module.py
+++ b/animatediff/models/motion_module.py
@@ -1,17 +1,15 @@
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Callable
import torch
-import numpy as np
import torch.nn.functional as F
from torch import nn
-import torchvision
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.modeling_utils import ModelMixin
+
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.models.attention import CrossAttention, FeedForward
+from diffusers.models.attention import FeedForward
+from diffusers.models.attention_processor import Attention, XFormersAttnProcessor, AttnProcessor
from einops import rearrange, repeat
import math
@@ -245,7 +243,7 @@ def forward(self, x):
return self.dropout(x)
-class VersatileAttention(CrossAttention):
+class VersatileAttention(Attention):
def __init__(
self,
attention_mode = None,
@@ -268,10 +266,48 @@ def __init__(
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ):
+ if use_memory_efficient_attention_xformers:
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ (
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers"
+ ),
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+
+ # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13.
+ # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13.
+ # You don't need XFormersAttnProcessor here.
+ # processor = XFormersAttnProcessor(
+ # attention_op=attention_op,
+ # )
+ processor = AttnProcessor()
+ else:
+ processor = AttnProcessor()
- def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
- batch_size, sequence_length, _ = hidden_states.shape
+ self.set_processor(processor)
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, **cross_attention_kwargs):
if self.attention_mode == "Temporal":
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
@@ -283,49 +319,16 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
else:
raise NotImplementedError
- encoder_hidden_states = encoder_hidden_states
-
- if self.group_norm is not None:
- hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = self.to_q(hidden_states)
- dim = query.shape[-1]
- query = self.reshape_heads_to_batch_dim(query)
-
- if self.added_kv_proj_dim is not None:
- raise NotImplementedError
-
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
- key = self.to_k(encoder_hidden_states)
- value = self.to_v(encoder_hidden_states)
-
- key = self.reshape_heads_to_batch_dim(key)
- value = self.reshape_heads_to_batch_dim(value)
-
- if attention_mask is not None:
- if attention_mask.shape[-1] != query.shape[1]:
- target_length = query.shape[1]
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
- attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
-
- # attention, what we cannot get enough of
- if self._use_memory_efficient_attention_xformers:
- hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
- # Some versions of xformers return output in fp32, cast it back to the dtype of the input
- hidden_states = hidden_states.to(query.dtype)
- else:
- if self._slice_size is None or query.shape[0] // self._slice_size == 1:
- hidden_states = self._attention(query, key, value, attention_mask)
- else:
- hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
-
- # linear proj
- hidden_states = self.to_out[0](hidden_states)
-
- # dropout
- hidden_states = self.to_out[1](hidden_states)
+ hidden_states = self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
if self.attention_mode == "Temporal":
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
+
diff --git a/animatediff/models/resnet.py b/animatediff/models/resnet.py
index ad28eb0c..da80f174 100644
--- a/animatediff/models/resnet.py
+++ b/animatediff/models/resnet.py
@@ -18,6 +18,17 @@ def forward(self, x):
return x
+class InflatedGroupNorm(nn.GroupNorm):
+ def forward(self, x):
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
+
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
@@ -112,6 +123,7 @@ def __init__(
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
+ use_inflated_groupnorm=None,
):
super().__init__()
self.pre_norm = pre_norm
@@ -126,7 +138,11 @@ def __init__(
if groups_out is None:
groups_out = groups
- self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ assert use_inflated_groupnorm != None
+ if use_inflated_groupnorm:
+ self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ else:
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -142,7 +158,11 @@ def __init__(
else:
self.time_emb_proj = None
- self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ if use_inflated_groupnorm:
+ self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ else:
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
diff --git a/animatediff/models/unet.py b/animatediff/models/unet.py
index 9d67e8ae..48fa062d 100644
--- a/animatediff/models/unet.py
+++ b/animatediff/models/unet.py
@@ -12,7 +12,7 @@
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.modeling_utils import ModelMixin
+from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import (
@@ -24,7 +24,7 @@
get_down_block,
get_up_block,
)
-from .resnet import InflatedConv3d
+from .resnet import InflatedConv3d, InflatedGroupNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -77,6 +77,8 @@ def __init__(
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
+ use_inflated_groupnorm=False,
+
# Additional
use_motion_module = False,
motion_module_resolutions = ( 1,2,4,8 ),
@@ -88,7 +90,7 @@ def __init__(
unet_use_temporal_attention = None,
):
super().__init__()
-
+
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
@@ -150,6 +152,7 @@ def __init__(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
motion_module_type=motion_module_type,
@@ -175,6 +178,7 @@ def __init__(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and motion_module_mid_block,
motion_module_type=motion_module_type,
@@ -227,6 +231,7 @@ def __init__(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and (res in motion_module_resolutions),
motion_module_type=motion_module_type,
@@ -236,7 +241,10 @@ def __init__(
prev_output_channel = output_channel
# out
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ if use_inflated_groupnorm:
+ self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ else:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
@@ -317,6 +325,8 @@ def forward(
class_labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
+ down_block_additional_residuals: Optional[List[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
) -> Union[UNet3DConditionOutput, Tuple]:
r"""
Args:
@@ -406,11 +416,23 @@ def forward(
down_block_res_samples += res_samples
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual.to(dtype=down_block_res_sample.dtype)
+ new_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
# mid
sample = self.mid_block(
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
-
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual.to(dtype=sample.dtype)
# up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
diff --git a/animatediff/models/unet_blocks.py b/animatediff/models/unet_blocks.py
index 8a17f201..711ad6cc 100644
--- a/animatediff/models/unet_blocks.py
+++ b/animatediff/models/unet_blocks.py
@@ -30,7 +30,8 @@ def get_down_block(
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
-
+ use_inflated_groupnorm=None,
+
use_motion_module=None,
motion_module_type=None,
@@ -50,6 +51,8 @@ def get_down_block(
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
@@ -77,6 +80,7 @@ def get_down_block(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
@@ -106,6 +110,7 @@ def get_up_block(
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
use_motion_module=None,
motion_module_type=None,
@@ -125,6 +130,8 @@ def get_up_block(
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
@@ -152,6 +159,7 @@ def get_up_block(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
@@ -181,6 +189,7 @@ def __init__(
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
use_motion_module=None,
@@ -206,6 +215,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
)
]
attentions = []
@@ -248,6 +259,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
)
)
@@ -290,6 +303,7 @@ def __init__(
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
use_motion_module=None,
@@ -318,6 +332,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
@@ -421,6 +437,8 @@ def __init__(
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
+
+ use_inflated_groupnorm=None,
use_motion_module=None,
motion_module_type=None,
@@ -444,6 +462,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
@@ -526,6 +546,7 @@ def __init__(
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
use_motion_module=None,
@@ -556,6 +577,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
@@ -661,6 +684,8 @@ def __init__(
output_scale_factor=1.0,
add_upsample=True,
+ use_inflated_groupnorm=None,
+
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
@@ -685,6 +710,8 @@ def __init__(
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
diff --git a/animatediff/pipelines/pipeline_animation.py b/animatediff/pipelines/pipeline_animation.py
index 58f22d16..4458c237 100644
--- a/animatediff/pipelines/pipeline_animation.py
+++ b/animatediff/pipelines/pipeline_animation.py
@@ -306,7 +306,7 @@ def prepare_latents(self, batch_size, num_channels_latents, video_length, height
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
- latents = latents.to(device)
+ latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
@@ -330,6 +330,8 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
+ down_block_control: Optional[List[torch.FloatTensor]] = None,
+ mid_block_control: Optional[torch.FloatTensor] = None,
**kwargs,
):
# Default height and width to unet
@@ -392,7 +394,13 @@ def __call__(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=text_embeddings,
+ down_block_additional_residuals=[x.to(self.device) for x in down_block_control[i]] if down_block_control is not None else None,
+ mid_block_additional_residual=mid_block_control[i].to(self.device) if mid_block_control is not None else None,
+ ).sample.to(dtype=latents_dtype)
# noise_pred = []
# import pdb
# pdb.set_trace()
diff --git a/animatediff/utils/convert_from_ckpt.py b/animatediff/utils/convert_from_ckpt.py
index 9ee269d8..7730fc5c 100644
--- a/animatediff/utils/convert_from_ckpt.py
+++ b/animatediff/utils/convert_from_ckpt.py
@@ -660,7 +660,17 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
+
+ renamed_vae_checkpoint = {}
+ for name, tensor in new_checkpoint.items():
+ if 'encoder.mid_block.attentions.0.' in name or 'decoder.mid_block.attentions.0.' in name:
+ renamed = name.replace('key', 'to_k').replace('query', 'to_q').replace('value', 'to_v')
+ renamed = renamed.replace('proj_attn', 'to_out.0')
+ renamed_vae_checkpoint[renamed] = tensor
+ else:
+ renamed_vae_checkpoint[name] = tensor
+
+ return renamed_vae_checkpoint
def convert_ldm_bert_checkpoint(checkpoint, config):
@@ -723,7 +733,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
- text_model.load_state_dict(text_model_dict)
+ text_model.load_state_dict(text_model_dict, strict=False)
return text_model
diff --git a/animatediff/utils/util.py b/animatediff/utils/util.py
index 83f31614..ee2dd2b8 100644
--- a/animatediff/utils/util.py
+++ b/animatediff/utils/util.py
@@ -5,11 +5,16 @@
import torch
import torchvision
+import torch.distributed as dist
from tqdm import tqdm
from einops import rearrange
+def zero_rank_print(s):
+ if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
+
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
videos = rearrange(videos, "b c t h w -> t b c h w")
outputs = []
diff --git a/app.py b/app.py
index 488daf81..38aa1005 100644
--- a/app.py
+++ b/app.py
@@ -1,36 +1,29 @@
-import gradio as gr
-import os
-from glob import glob
-import random
-import pdb
-from transformers import CLIPTextModel, CLIPTokenizer
-from animatediff.models.unet import UNet3DConditionModel
-from animatediff.pipelines.pipeline_animation import AnimationPipeline
-from diffusers import AutoencoderKL
-from datetime import datetime
import os
-from omegaconf import OmegaConf
import json
import torch
+import random
+
+import gradio as gr
+from glob import glob
+from omegaconf import OmegaConf
+from datetime import datetime
+from safetensors import safe_open
from diffusers import AutoencoderKL
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
-
+from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.controlnet.controlnet_module import ControlnetModule
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
-from diffusers.utils.import_utils import is_xformers_available
-
-from safetensors import safe_open
sample_idx = 0
-
scheduler_dict = {
"Euler": EulerDiscreteScheduler,
"PNDM": PNDMScheduler,
@@ -46,7 +39,6 @@
}
"""
-
class AnimateController:
def __init__(self):
@@ -55,6 +47,8 @@ def __init__(self):
self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
+ self.controlnet_dir = os.path.join(self.basedir, "models", "Controlnet")
+ self.videos_dir = os.path.join(self.basedir, "videos")
self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
self.savedir_sample = os.path.join(self.savedir, "sample")
os.makedirs(self.savedir, exist_ok=True)
@@ -62,17 +56,21 @@ def __init__(self):
self.stable_diffusion_list = []
self.motion_module_list = []
self.personalized_model_list = []
+ self.controlnet_list = []
+ self.videos_list = []
self.refresh_stable_diffusion()
self.refresh_motion_module()
self.refresh_personalized_model()
-
+ self.refresh_controlnet()
+ self.refresh_videos()
# config models
self.tokenizer = None
self.text_encoder = None
self.vae = None
self.unet = None
self.pipeline = None
+ self.controlnet = None
self.lora_model_state_dict = {}
self.inference_config = OmegaConf.load("configs/inference/inference.yaml")
@@ -88,6 +86,12 @@ def refresh_personalized_model(self):
personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
+ def refresh_controlnet(self):
+ self.controlnet_list = glob(os.path.join(self.controlnet_dir, "*/"))
+
+ def refresh_videos(self):
+ self.videos_list = glob(os.path.join(self.videos_dir, "*.mp4"))
+
def update_stable_diffusion(self, stable_diffusion_dropdown):
self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
@@ -150,7 +154,13 @@ def animate(
length_slider,
height_slider,
cfg_scale_slider,
- seed_textbox
+ seed_textbox,
+ videos_path_dropdown,
+ get_each_slider,
+ controlnet_processor_name_dropdown,
+ controlnet_processor_path_dropdown,
+ controlnet_guess_mode_checkbox,
+ controlnet_conditioning_scale_slider,
):
if self.unet is None:
raise gr.Error(f"Please select a pretrained model path.")
@@ -174,7 +184,28 @@ def animate(
if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
else: torch.seed()
seed = torch.initial_seed()
-
+
+ down_features, mid_features = None, None
+ controlnet = None
+ if videos_path_dropdown and videos_path_dropdown != "none":
+ controlnet_config = {
+ 'video_length': length_slider,
+ 'img_h': height_slider,
+ 'img_w': width_slider,
+ 'guidance_scale': cfg_scale_slider,
+ 'steps': sample_step_slider,
+ 'get_each': get_each_slider,
+ 'conditioning_scale': controlnet_conditioning_scale_slider,
+ 'controlnet_processor': controlnet_processor_name_dropdown,
+ 'controlnet_pipeline': stable_diffusion_dropdown,
+ 'controlnet_processor_path': controlnet_processor_path_dropdown,
+ 'guess_mode': controlnet_guess_mode_checkbox,
+ 'device': 'cuda',
+ }
+ controlnet = ControlnetModule(controlnet_config)
+ down_features, mid_features = controlnet(
+ videos_path_dropdown, prompt_textbox, negative_prompt_textbox, seed)
+
sample = pipeline(
prompt_textbox,
negative_prompt = negative_prompt_textbox,
@@ -183,6 +214,8 @@ def animate(
width = width_slider,
height = height_slider,
video_length = length_slider,
+ down_block_control = down_features,
+ mid_block_control = mid_features,
).videos
save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
@@ -288,7 +321,54 @@ def update_personalized_model():
prompt_textbox = gr.Textbox(label="Prompt", lines=2)
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
-
+
+ gr.Markdown(
+ """
+ ### 2.* Controlnet for AnimateDiff (Optional).
+ """
+ )
+
+ with gr.Column(visible=False) as controlnet_column:
+ with gr.Row().style(equal_height=True):
+ videos_path_dropdown = gr.Dropdown(
+ label="Select video for applying controlnet (optional)",
+ choices=["none"] + controller.videos_list,
+ value="none",
+ interactive=True,
+ )
+ videos_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+ def update_videos():
+ controller.refresh_videos()
+ return gr.Dropdown.update(choices=controller.videos_list)
+ videos_refresh_button.click(fn=update_videos, inputs=[], outputs=[videos_path_dropdown])
+
+ controlnet_processor_name_dropdown = gr.Dropdown(
+ label="Select controlnet processor (if video selected)",
+ choices=["canny", "depth", "softedge", "pose", "norm"],
+ value="none",
+ interactive=True,
+ )
+
+ controlnet_processor_path_dropdown = gr.Dropdown(
+ label="Set controlnet processor path (if video selected)",
+ choices=["none"] + controller.controlnet_list,
+ value="none",
+ interactive=True,
+ )
+ controlnet_processor_path_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+ def update_videos():
+ controller.refresh_videos()
+ return gr.Dropdown.update(choices=controller.controlnet_list)
+ controlnet_processor_path_refresh_button.click(fn=update_videos, inputs=[], outputs=[controlnet_processor_path_dropdown])
+
+ with gr.Row().style(equal_height=True):
+ controlnet_guess_mode_checkbox = gr.Checkbox(value=True, label="Controlnet Guess mode")
+ get_each_slider = gr.Slider(label="Get Each Frame", value=2, minimum=1, maximum=4, step=1)
+ controlnet_conditioning_scale_slider = gr.Slider(label="Controlnet strenth", value=0.5, minimum=0.1, maximum=1.0, step=0.1)
+
+ change_visibility = gr.Button(value="SHOW CONTROLNET SETTINGS (OPTIONAL)")
+ change_visibility.click(lambda :gr.update(visible=True), None, controlnet_column)
+
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Row():
@@ -325,6 +405,12 @@ def update_personalized_model():
height_slider,
cfg_scale_slider,
seed_textbox,
+ videos_path_dropdown,
+ get_each_slider,
+ controlnet_processor_name_dropdown,
+ controlnet_processor_path_dropdown,
+ controlnet_guess_mode_checkbox,
+ controlnet_conditioning_scale_slider,
],
outputs=[result_video]
)
diff --git a/configs/inference/inference.yaml b/configs/inference/inference-v1.yaml
similarity index 100%
rename from configs/inference/inference.yaml
rename to configs/inference/inference-v1.yaml
diff --git a/configs/inference/inference-v2.yaml b/configs/inference/inference-v2.yaml
new file mode 100644
index 00000000..a33bc124
--- /dev/null
+++ b/configs/inference/inference-v2.yaml
@@ -0,0 +1,27 @@
+unet_additional_kwargs:
+ use_inflated_groupnorm: true
+ unet_use_cross_frame_attention: false
+ unet_use_temporal_attention: false
+ use_motion_module: true
+ motion_module_resolutions:
+ - 1
+ - 2
+ - 4
+ - 8
+ motion_module_mid_block: true
+ motion_module_decoder_only: false
+ motion_module_type: Vanilla
+ motion_module_kwargs:
+ num_attention_heads: 8
+ num_transformer_block: 1
+ attention_block_types:
+ - Temporal_Self
+ - Temporal_Self
+ temporal_position_encoding: true
+ temporal_position_encoding_max_len: 32
+ temporal_attention_dim_div: 1
+
+noise_scheduler_kwargs:
+ beta_start: 0.00085
+ beta_end: 0.012
+ beta_schedule: "linear"
diff --git a/configs/prompts/1-ToonYou-Controlnet.yaml b/configs/prompts/1-ToonYou-Controlnet.yaml
new file mode 100644
index 00000000..bc7d583b
--- /dev/null
+++ b/configs/prompts/1-ToonYou-Controlnet.yaml
@@ -0,0 +1,30 @@
+ToonYou:
+ base: ""
+ path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
+ motion_module:
+ - "models/Motion_Module/mm_sd_v15.ckpt"
+
+ control:
+ video_path: "./videos/dance.mp4" # smiling, dance or your video
+ get_each: 2 # get each frame from video
+ conditioning_scale: 0.75 # controlnet strength
+ controlnet_processor: "softedge" # softedge, canny, depth
+ controlnet_pipeline: "models/StableDiffusion/stable-diffusion-v1-5"
+ controlnet_processor_path: "models/Controlnet/control_v11p_sd15_softedge" # control_v11p_sd15_softedge, control_v11f1p_sd15_depth, control_v11p_sd15_canny
+ guess_mode: True
+
+ seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
+ steps: 25
+ guidance_scale: 7.5
+
+ prompt:
+ - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
+ - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,"
+ - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern"
+ - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle,"
+
+ n_prompt:
+ - ""
+ - "badhandv4,easynegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth"
+ - ""
+ - ""
diff --git a/configs/prompts/2-Lyriel-Controlnet.yaml b/configs/prompts/2-Lyriel-Controlnet.yaml
new file mode 100644
index 00000000..b0fc2558
--- /dev/null
+++ b/configs/prompts/2-Lyriel-Controlnet.yaml
@@ -0,0 +1,31 @@
+Lyriel:
+ base: ""
+ path: "models/DreamBooth_LoRA/lyriel_v16.safetensors"
+ motion_module:
+ # - "models/Motion_Module/mm_sd_v14.ckpt"
+ - "models/Motion_Module/mm_sd_v15.ckpt"
+
+ control:
+ video_path: "./videos/smiling.mp4" # smiling, dance or your video
+ get_each: 2 # get each frame from video
+ conditioning_scale: 0.75 # controlnet strength
+ controlnet_processor: "canny" # softedge, canny, depth
+ controlnet_pipeline: "models/StableDiffusion/stable-diffusion-v1-5"
+ controlnet_processor_path: "models/Controlnet/control_v11p_sd15_canny" # control_v11p_sd15_softedge, control_v11f1p_sd15_depth, control_v11p_sd15_canny
+ guess_mode: True
+
+ seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551]
+ steps: 25
+ guidance_scale: 7.5
+
+ prompt:
+ - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange"
+ - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal"
+ - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray"
+ - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown."
+
+ n_prompt:
+ - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration"
+ - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular"
+ - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome"
+ - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render"
diff --git a/configs/prompts/3-RcnzCartoon-Controlnet.yaml b/configs/prompts/3-RcnzCartoon-Controlnet.yaml
new file mode 100644
index 00000000..4dcd09ca
--- /dev/null
+++ b/configs/prompts/3-RcnzCartoon-Controlnet.yaml
@@ -0,0 +1,31 @@
+RcnzCartoon:
+ base: ""
+ path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
+ motion_module:
+ # - "models/Motion_Module/mm_sd_v14.ckpt"
+ - "models/Motion_Module/mm_sd_v15.ckpt"
+
+ control:
+ video_path: "./videos/smiling.mp4" # smiling, dance or your video
+ get_each: 2 # get each frame from video
+ conditioning_scale: 0.75 # controlnet strength
+ controlnet_processor: "softedge" # softedge, canny, depth
+ controlnet_pipeline: "models/StableDiffusion/stable-diffusion-v1-5"
+ controlnet_processor_path: "models/Controlnet/control_v11p_sd15_softedge" # control_v11p_sd15_softedge, control_v11p_sd15_canny, control_v11f1p_sd15_depth
+ guess_mode: True
+
+ seed: [16931037867122267877, 2094308009433392066, 4292543217695451092, 15572665120852309890]
+ steps: 25
+ guidance_scale: 7.5
+
+ prompt:
+ - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded"
+ - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face"
+ - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes"
+ - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering"
+
+ n_prompt:
+ - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
+ - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular"
+ - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face,"
+ - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand"
diff --git a/configs/prompts/v2/1-ToonYou-Controlnet.yaml b/configs/prompts/v2/1-ToonYou-Controlnet.yaml
new file mode 100644
index 00000000..69600784
--- /dev/null
+++ b/configs/prompts/v2/1-ToonYou-Controlnet.yaml
@@ -0,0 +1,31 @@
+ToonYou:
+ base: ""
+ path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
+ inference_config: "configs/inference/inference-v2.yaml"
+ motion_module:
+ - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+ control:
+ video_path: "./videos/dance.mp4" # smiling, dance or your video
+ get_each: 2 # get each frame from video
+ conditioning_scale: 0.75 # controlnet strength
+ controlnet_processor: "softedge" # softedge, canny, depth
+ controlnet_pipeline: "models/StableDiffusion/stable-diffusion-v1-5"
+ controlnet_processor_path: "models/Controlnet/control_v11p_sd15_softedge" # control_v11p_sd15_softedge, control_v11f1p_sd15_depth, control_v11p_sd15_canny
+ guess_mode: True
+
+ seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
+ steps: 25
+ guidance_scale: 7.5
+
+ prompt:
+ - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
+ - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,"
+ - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern"
+ - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle,"
+
+ n_prompt:
+ - ""
+ - "badhandv4,easynegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth"
+ - ""
+ - ""
diff --git a/configs/prompts/v2/2-Lyriel-Controlnet.yaml b/configs/prompts/v2/2-Lyriel-Controlnet.yaml
new file mode 100644
index 00000000..b77bb558
--- /dev/null
+++ b/configs/prompts/v2/2-Lyriel-Controlnet.yaml
@@ -0,0 +1,31 @@
+Lyriel:
+ base: ""
+ path: "models/DreamBooth_LoRA/lyriel_v16.safetensors"
+ inference_config: "configs/inference/inference-v2.yaml"
+ motion_module:
+ - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+ control:
+ video_path: "./videos/smiling.mp4" # smiling, dance or your video
+ get_each: 2 # get each frame from video
+ conditioning_scale: 0.75 # controlnet strength
+ controlnet_processor: "canny" # softedge, canny, depth
+ controlnet_pipeline: "models/StableDiffusion/stable-diffusion-v1-5"
+ controlnet_processor_path: "models/Controlnet/control_v11p_sd15_canny" # control_v11p_sd15_softedge, control_v11f1p_sd15_depth, control_v11p_sd15_canny
+ guess_mode: True
+
+ seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551]
+ steps: 25
+ guidance_scale: 7.5
+
+ prompt:
+ - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange"
+ - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal"
+ - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray"
+ - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown."
+
+ n_prompt:
+ - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration"
+ - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular"
+ - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome"
+ - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render"
diff --git a/configs/prompts/v2/3-RcnzCartoon-Controlnet.yaml b/configs/prompts/v2/3-RcnzCartoon-Controlnet.yaml
new file mode 100644
index 00000000..7fbf242d
--- /dev/null
+++ b/configs/prompts/v2/3-RcnzCartoon-Controlnet.yaml
@@ -0,0 +1,31 @@
+RcnzCartoon:
+ base: ""
+ path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
+ inference_config: "configs/inference/inference-v2.yaml"
+ motion_module:
+ - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+ control:
+ video_path: "./videos/smiling.mp4" # smiling, dance or your video
+ get_each: 2 # get each frame from video
+ conditioning_scale: 0.75 # controlnet strength
+ controlnet_processor: "softedge" # softedge, canny, depth
+ controlnet_pipeline: "models/StableDiffusion/stable-diffusion-v1-5"
+ controlnet_processor_path: "models/Controlnet/control_v11p_sd15_softedge" # control_v11p_sd15_softedge, control_v11p_sd15_canny, control_v11f1p_sd15_depth
+ guess_mode: True
+
+ seed: [16931037867122267877, 2094308009433392066, 4292543217695451092, 15572665120852309890]
+ steps: 25
+ guidance_scale: 7.5
+
+ prompt:
+ - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded"
+ - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face"
+ - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes"
+ - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering"
+
+ n_prompt:
+ - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
+ - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular"
+ - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face,"
+ - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand"
diff --git a/configs/prompts/v2/5-RealisticVision-Controlnet.yaml b/configs/prompts/v2/5-RealisticVision-Controlnet.yaml
new file mode 100644
index 00000000..31ea0cd4
--- /dev/null
+++ b/configs/prompts/v2/5-RealisticVision-Controlnet.yaml
@@ -0,0 +1,32 @@
+RealisticVision:
+ base: ""
+ path: "models/DreamBooth_LoRA/realisticVisionV51_v20Novae.safetensors"
+
+ inference_config: "configs/inference/inference-v2.yaml"
+ motion_module:
+ - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+ control:
+ video_path: "./videos/smiling.mp4" # smiling, dance or your video
+ get_each: 2 # get each frame from video
+ conditioning_scale: 0.75 # controlnet strength
+ controlnet_processor: "depth" # softedge, canny, depth
+ controlnet_pipeline: "models/StableDiffusion/stable-diffusion-v1-5"
+ controlnet_processor_path: "models/Controlnet/control_v11p_sd15_softedge" # control_v11p_sd15_softedge, control_v11f1p_sd15_depth, control_v11p_sd15_canny
+ guess_mode: True
+
+ seed: [13100322578370451493, 14752961627088720670, 9329399085567825781, 16987697414827649302]
+ steps: 25
+ guidance_scale: 7.5
+
+ prompt:
+ - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+ - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
+ - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+ - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
+
+ n_prompt:
+ - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+ - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+ - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+ - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
diff --git a/configs/prompts/v2/5-RealisticVision.yaml b/configs/prompts/v2/5-RealisticVision.yaml
new file mode 100644
index 00000000..7770b19b
--- /dev/null
+++ b/configs/prompts/v2/5-RealisticVision.yaml
@@ -0,0 +1,23 @@
+RealisticVision:
+ base: ""
+ path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+
+ inference_config: "configs/inference/inference-v2.yaml"
+ motion_module:
+ - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+ seed: [13100322578370451493, 14752961627088720670, 9329399085567825781, 16987697414827649302]
+ steps: 25
+ guidance_scale: 7.5
+
+ prompt:
+ - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+ - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
+ - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+ - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
+
+ n_prompt:
+ - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+ - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+ - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+ - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
diff --git a/configs/training/image_finetune.yaml b/configs/training/image_finetune.yaml
new file mode 100644
index 00000000..ea05fd14
--- /dev/null
+++ b/configs/training/image_finetune.yaml
@@ -0,0 +1,48 @@
+image_finetune: true
+
+output_dir: "outputs"
+pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+noise_scheduler_kwargs:
+ num_train_timesteps: 1000
+ beta_start: 0.00085
+ beta_end: 0.012
+ beta_schedule: "scaled_linear"
+ steps_offset: 1
+ clip_sample: false
+
+train_data:
+ csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
+ video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+ sample_size: 256
+
+validation_data:
+ prompts:
+ - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+ - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
+ - "Robot dancing in times square."
+ - "Pacific coast, carmel by the sea ocean and waves."
+ num_inference_steps: 25
+ guidance_scale: 8.
+
+trainable_modules:
+ - "."
+
+unet_checkpoint_path: ""
+
+learning_rate: 1.e-5
+train_batch_size: 50
+
+max_train_epoch: -1
+max_train_steps: 100
+checkpointing_epochs: -1
+checkpointing_steps: 60
+
+validation_steps: 5000
+validation_steps_tuple: [2, 50]
+
+global_seed: 42
+mixed_precision_training: true
+enable_xformers_memory_efficient_attention: True
+
+is_debug: False
diff --git a/configs/training/training.yaml b/configs/training/training.yaml
new file mode 100644
index 00000000..626f05c2
--- /dev/null
+++ b/configs/training/training.yaml
@@ -0,0 +1,66 @@
+image_finetune: false
+
+output_dir: "outputs"
+pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+unet_additional_kwargs:
+ use_motion_module : true
+ motion_module_resolutions : [ 1,2,4,8 ]
+ unet_use_cross_frame_attention : false
+ unet_use_temporal_attention : false
+
+ motion_module_type: Vanilla
+ motion_module_kwargs:
+ num_attention_heads : 8
+ num_transformer_block : 1
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
+ temporal_position_encoding : true
+ temporal_position_encoding_max_len : 24
+ temporal_attention_dim_div : 1
+ zero_initialize : true
+
+noise_scheduler_kwargs:
+ num_train_timesteps: 1000
+ beta_start: 0.00085
+ beta_end: 0.012
+ beta_schedule: "linear"
+ steps_offset: 1
+ clip_sample: false
+
+train_data:
+ csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
+ video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+ sample_size: 256
+ sample_stride: 4
+ sample_n_frames: 16
+
+validation_data:
+ prompts:
+ - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+ - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
+ - "Robot dancing in times square."
+ - "Pacific coast, carmel by the sea ocean and waves."
+ num_inference_steps: 25
+ guidance_scale: 8.
+
+trainable_modules:
+ - "motion_modules."
+
+unet_checkpoint_path: ""
+
+learning_rate: 1.e-4
+train_batch_size: 4
+
+max_train_epoch: -1
+max_train_steps: 100
+checkpointing_epochs: -1
+checkpointing_steps: 60
+
+validation_steps: 5000
+validation_steps_tuple: [2, 50]
+
+global_seed: 42
+mixed_precision_training: true
+enable_xformers_memory_efficient_attention: True
+
+is_debug: False
diff --git a/download_bashscripts/9-Controlnets.sh b/download_bashscripts/9-Controlnets.sh
new file mode 100644
index 00000000..424b6001
--- /dev/null
+++ b/download_bashscripts/9-Controlnets.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+git clone https://huggingface.co/lllyasviel/control_v11p_sd15_softedge models/Controlnet/control_v11p_sd15_softedge
+git clone https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth models/Controlnet/control_v11f1p_sd15_depth
+git clone https://huggingface.co/lllyasviel/control_v11p_sd15_canny models/Controlnet/control_v11p_sd15_canny
\ No newline at end of file
diff --git a/environment.yaml b/environment.yaml
index 03c2dd10..4921eaac 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,21 +1,25 @@
name: animatediff
channels:
- pytorch
- - xformers
+ - nvidia
dependencies:
- python=3.10
- - pytorch==1.12.1
- - torchvision==0.13.1
- - torchaudio==0.12.1
- - cudatoolkit=11.3
- - xformers
+ - pytorch=2.0.1
+ - torchvision=0.15.2
+ # - torchaudio=0.13.1
+ # - pytorch-cuda=11.7
- pip
- pip:
- - diffusers[torch]==0.11.1
- - transformers==4.25.1
+ - diffusers==0.20.2
+ - transformers==4.32.1
+ - xformers==0.0.21
+ - controlnet-aux==0.0.6
- imageio==2.27.0
+ - imageio[ffmpeg]
+ - decord==0.6.0
- gdown
- einops
- omegaconf
- safetensors
- gradio
+ - wandb
diff --git a/models/Controlnet/Put controlnet models repo here.txt b/models/Controlnet/Put controlnet models repo here.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/scripts/animate.py b/scripts/animate.py
index 8bb5dd74..2d551040 100644
--- a/scripts/animate.py
+++ b/scripts/animate.py
@@ -2,14 +2,10 @@
import datetime
import inspect
import os
-from omegaconf import OmegaConf
import torch
-
-import diffusers
+from omegaconf import OmegaConf
from diffusers import AutoencoderKL, DDIMScheduler
-
-from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from animatediff.models.unet import UNet3DConditionModel
@@ -17,13 +13,10 @@
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+from animatediff.controlnet.controlnet_module import ControlnetModule
from diffusers.utils.import_utils import is_xformers_available
-from einops import rearrange, repeat
-
-import csv, pdb, glob
from safetensors import safe_open
-import math
from pathlib import Path
@@ -34,7 +27,6 @@ def main(args):
time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
savedir = f"samples/{Path(args.config).stem}-{time_str}"
os.makedirs(savedir)
- inference_config = OmegaConf.load(args.inference_config)
config = OmegaConf.load(args.config)
samples = []
@@ -45,7 +37,8 @@ def main(args):
motion_modules = model_config.motion_module
motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
for motion_module in motion_modules:
-
+ inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
+
### >>> create validation pipeline >>> ###
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
@@ -105,6 +98,20 @@ def main(args):
pipeline.to("cuda")
### <<< create validation pipeline <<< ###
+ down_features, mid_features = None, None
+ controlnet = None
+ if 'control' in model_config:
+ controlnet_config = {
+ 'video_length': args.L,
+ 'img_h': args.H,
+ 'img_w': args.W,
+ 'guidance_scale': model_config.guidance_scale,
+ 'steps': model_config.steps,
+ 'device': 'cuda',
+ **model_config.control
+ }
+ controlnet = ControlnetModule(controlnet_config)
+
prompts = model_config.prompt
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
@@ -119,7 +126,10 @@ def main(args):
if random_seed != -1: torch.manual_seed(random_seed)
else: torch.seed()
config[config_key].random_seed.append(torch.initial_seed())
-
+
+ if controlnet is not None:
+ down_features, mid_features = controlnet(model_config.control.video_path, prompt, n_prompt, random_seed)
+
print(f"current seed: {torch.initial_seed()}")
print(f"sampling {prompt} ...")
sample = pipeline(
@@ -130,6 +140,8 @@ def main(args):
width = args.W,
height = args.H,
video_length = args.L,
+ down_block_control = down_features,
+ mid_block_control = mid_features,
).videos
samples.append(sample)
@@ -148,7 +160,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
- parser.add_argument("--inference_config", type=str, default="configs/inference/inference.yaml")
+ parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v1.yaml")
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--L", type=int, default=16 )
diff --git a/train.py b/train.py
new file mode 100644
index 00000000..094e419a
--- /dev/null
+++ b/train.py
@@ -0,0 +1,493 @@
+import os
+import math
+import wandb
+import random
+import logging
+import inspect
+import argparse
+import datetime
+import subprocess
+
+from pathlib import Path
+from tqdm.auto import tqdm
+from einops import rearrange
+from omegaconf import OmegaConf
+from safetensors import safe_open
+from typing import Dict, Optional, Tuple
+
+import torch
+import torchvision
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.optim.swa_utils import AveragedModel
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import diffusers
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.models import UNet2DConditionModel
+from diffusers.pipelines import StableDiffusionPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+
+import transformers
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from animatediff.data.dataset import WebVid10M
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid, zero_rank_print
+
+
+
+def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
+ """Initializes distributed environment."""
+ if launcher == 'pytorch':
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ local_rank = rank % num_gpus
+ torch.cuda.set_device(local_rank)
+ dist.init_process_group(backend=backend, **kwargs)
+
+ elif launcher == 'slurm':
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ local_rank = proc_id % num_gpus
+ torch.cuda.set_device(local_rank)
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['RANK'] = str(proc_id)
+ port = os.environ.get('PORT', port)
+ os.environ['MASTER_PORT'] = str(port)
+ dist.init_process_group(backend=backend)
+ zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
+
+ else:
+ raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')
+
+ return local_rank
+
+
+
+def main(
+ image_finetune: bool,
+
+ name: str,
+ use_wandb: bool,
+ launcher: str,
+
+ output_dir: str,
+ pretrained_model_path: str,
+
+ train_data: Dict,
+ validation_data: Dict,
+ cfg_random_null_text: bool = True,
+ cfg_random_null_text_ratio: float = 0.1,
+
+ unet_checkpoint_path: str = "",
+ unet_additional_kwargs: Dict = {},
+ ema_decay: float = 0.9999,
+ noise_scheduler_kwargs = None,
+
+ max_train_epoch: int = -1,
+ max_train_steps: int = 100,
+ validation_steps: int = 100,
+ validation_steps_tuple: Tuple = (-1,),
+
+ learning_rate: float = 3e-5,
+ scale_lr: bool = False,
+ lr_warmup_steps: int = 0,
+ lr_scheduler: str = "constant",
+
+ trainable_modules: Tuple[str] = (None, ),
+ num_workers: int = 32,
+ train_batch_size: int = 1,
+ adam_beta1: float = 0.9,
+ adam_beta2: float = 0.999,
+ adam_weight_decay: float = 1e-2,
+ adam_epsilon: float = 1e-08,
+ max_grad_norm: float = 1.0,
+ gradient_accumulation_steps: int = 1,
+ gradient_checkpointing: bool = False,
+ checkpointing_epochs: int = 5,
+ checkpointing_steps: int = -1,
+
+ mixed_precision_training: bool = True,
+ enable_xformers_memory_efficient_attention: bool = True,
+
+ global_seed: int = 42,
+ is_debug: bool = False,
+):
+ check_min_version("0.10.0.dev0")
+
+ # Initialize distributed training
+ local_rank = init_dist(launcher=launcher)
+ global_rank = dist.get_rank()
+ num_processes = dist.get_world_size()
+ is_main_process = global_rank == 0
+
+ seed = global_seed + global_rank
+ torch.manual_seed(seed)
+
+ # Logging folder
+ folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
+ output_dir = os.path.join(output_dir, folder_name)
+ if is_debug and os.path.exists(output_dir):
+ os.system(f"rm -rf {output_dir}")
+
+ *_, config = inspect.getargvalues(inspect.currentframe())
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+
+ if is_main_process and (not is_debug) and use_wandb:
+ run = wandb.init(project="animatediff", name=folder_name, config=config)
+
+ # Handle the output folder creation
+ if is_main_process:
+ os.makedirs(output_dir, exist_ok=True)
+ os.makedirs(f"{output_dir}/samples", exist_ok=True)
+ os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
+ os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+ if not image_finetune:
+ unet = UNet3DConditionModel.from_pretrained_2d(
+ pretrained_model_path, subfolder="unet",
+ unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
+ )
+ else:
+ unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
+
+ # Load pretrained unet weights
+ if unet_checkpoint_path != "":
+ zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
+ unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
+ if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
+ state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
+
+ m, u = unet.load_state_dict(state_dict, strict=False)
+ zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+ assert len(u) == 0
+
+ # Freeze vae and text_encoder
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # Set unet trainable parameters
+ unet.requires_grad_(False)
+ for name, param in unet.named_parameters():
+ for trainable_module_name in trainable_modules:
+ if trainable_module_name in name:
+ param.requires_grad = True
+ break
+
+ trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ optimizer = torch.optim.AdamW(
+ trainable_params,
+ lr=learning_rate,
+ betas=(adam_beta1, adam_beta2),
+ weight_decay=adam_weight_decay,
+ eps=adam_epsilon,
+ )
+
+ if is_main_process:
+ zero_rank_print(f"trainable params number: {len(trainable_params)}")
+ zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
+
+ # Enable xformers
+ if enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable gradient checkpointing
+ if gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Move models to GPU
+ vae.to(local_rank)
+ text_encoder.to(local_rank)
+
+ # Get the training dataset
+ train_dataset = WebVid10M(**train_data, is_image=image_finetune)
+ distributed_sampler = DistributedSampler(
+ train_dataset,
+ num_replicas=num_processes,
+ rank=global_rank,
+ shuffle=True,
+ seed=global_seed,
+ )
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=train_batch_size,
+ shuffle=False,
+ sampler=distributed_sampler,
+ num_workers=num_workers,
+ pin_memory=True,
+ drop_last=True,
+ )
+
+ # Get the training iteration
+ if max_train_steps == -1:
+ assert max_train_epoch != -1
+ max_train_steps = max_train_epoch * len(train_dataloader)
+
+ if checkpointing_steps == -1:
+ assert checkpointing_epochs != -1
+ checkpointing_steps = checkpointing_epochs * len(train_dataloader)
+
+ if scale_lr:
+ learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)
+
+ # Scheduler
+ lr_scheduler = get_scheduler(
+ lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
+ )
+
+ # Validation pipeline
+ if not image_finetune:
+ validation_pipeline = AnimationPipeline(
+ unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
+ ).to("cuda")
+ else:
+ validation_pipeline = StableDiffusionPipeline.from_pretrained(
+ pretrained_model_path,
+ unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
+ )
+ validation_pipeline.enable_vae_slicing()
+
+ # DDP warpper
+ unet.to(local_rank)
+ unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+ # Afterwards we recalculate our number of training epochs
+ num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
+
+ if is_main_process:
+ logging.info("***** Running training *****")
+ logging.info(f" Num examples = {len(train_dataset)}")
+ logging.info(f" Num Epochs = {num_train_epochs}")
+ logging.info(f" Instantaneous batch size per device = {train_batch_size}")
+ logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
+ logging.info(f" Total optimization steps = {max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
+ progress_bar.set_description("Steps")
+
+ # Support mixed-precision training
+ scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
+
+ for epoch in range(first_epoch, num_train_epochs):
+ train_dataloader.sampler.set_epoch(epoch)
+ unet.train()
+
+ for step, batch in enumerate(train_dataloader):
+ if cfg_random_null_text:
+ batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
+
+ # Data batch sanity check
+ if epoch == first_epoch and step == 0:
+ pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+ if not image_finetune:
+ pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+ for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+ pixel_value = pixel_value[None, ...]
+ save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
+ else:
+ for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+ pixel_value = pixel_value / 2. + 0.5
+ torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")
+
+ ### >>>> Training >>>> ###
+
+ # Convert videos to latent space
+ pixel_values = batch["pixel_values"].to(local_rank)
+ video_length = pixel_values.shape[1]
+ with torch.no_grad():
+ if not image_finetune:
+ pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
+ latents = vae.encode(pixel_values).latent_dist
+ latents = latents.sample()
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
+ else:
+ latents = vae.encode(pixel_values).latent_dist
+ latents = latents.sample()
+
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each video
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ with torch.no_grad():
+ prompt_ids = tokenizer(
+ batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ ).input_ids.to(latents.device)
+ encoder_hidden_states = text_encoder(prompt_ids)[0]
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ raise NotImplementedError
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ # Mixed-precision training
+ with torch.cuda.amp.autocast(enabled=mixed_precision_training):
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ optimizer.zero_grad()
+
+ # Backpropagate
+ if mixed_precision_training:
+ scaler.scale(loss).backward()
+ """ >>> gradient clipping >>> """
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
+ """ <<< gradient clipping <<< """
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ loss.backward()
+ """ >>> gradient clipping >>> """
+ torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
+ """ <<< gradient clipping <<< """
+ optimizer.step()
+
+ lr_scheduler.step()
+ progress_bar.update(1)
+ global_step += 1
+
+ ### <<<< Training <<<< ###
+
+ # Wandb logging
+ if is_main_process and (not is_debug) and use_wandb:
+ wandb.log({"train_loss": loss.item()}, step=global_step)
+
+ # Save checkpoint
+ if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
+ save_path = os.path.join(output_dir, f"checkpoints")
+ state_dict = {
+ "epoch": epoch,
+ "global_step": global_step,
+ "state_dict": unet.state_dict(),
+ }
+ if step == len(train_dataloader) - 1:
+ torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
+ else:
+ torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt"))
+ logging.info(f"Saved state to {save_path} (global_step: {global_step})")
+
+ # Periodically validation
+ if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
+ samples = []
+
+ generator = torch.Generator(device=latents.device)
+ generator.manual_seed(global_seed)
+
+ height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+ width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
+
+ prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
+
+ for idx, prompt in enumerate(prompts):
+ if not image_finetune:
+ sample = validation_pipeline(
+ prompt,
+ generator = generator,
+ video_length = train_data.sample_n_frames,
+ height = height,
+ width = width,
+ **validation_data,
+ ).videos
+ save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
+ samples.append(sample)
+
+ else:
+ sample = validation_pipeline(
+ prompt,
+ generator = generator,
+ height = height,
+ width = width,
+ num_inference_steps = validation_data.get("num_inference_steps", 25),
+ guidance_scale = validation_data.get("guidance_scale", 8.),
+ ).images[0]
+ sample = torchvision.transforms.functional.to_tensor(sample)
+ samples.append(sample)
+
+ if not image_finetune:
+ samples = torch.concat(samples)
+ save_path = f"{output_dir}/samples/sample-{global_step}.gif"
+ save_videos_grid(samples, save_path)
+
+ else:
+ samples = torch.stack(samples)
+ save_path = f"{output_dir}/samples/sample-{global_step}.png"
+ torchvision.utils.save_image(samples, save_path, nrow=4)
+
+ logging.info(f"Saved samples to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= max_train_steps:
+ break
+
+ dist.destroy_process_group()
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, required=True)
+ parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
+ parser.add_argument("--wandb", action="store_true")
+ args = parser.parse_args()
+
+ name = Path(args.config).stem
+ config = OmegaConf.load(args.config)
+
+ main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
diff --git a/videos/Put your short videos here.txt b/videos/Put your short videos here.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/videos/dance.mp4 b/videos/dance.mp4
new file mode 100644
index 00000000..1457de3b
Binary files /dev/null and b/videos/dance.mp4 differ
diff --git a/videos/smiling.mp4 b/videos/smiling.mp4
new file mode 100644
index 00000000..d8421c9c
Binary files /dev/null and b/videos/smiling.mp4 differ