diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c5ec55892f4b264d53b7cc3d138fb5f0a2eb85b1 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,84 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/kisekaeichi_ref.png filter=lfs diff=lfs merge=lfs -text
+docs/kisekaeichi_result.png filter=lfs diff=lfs merge=lfs -text
+docs/kisekaeichi_start.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_000400_00_20250607144540_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_000900_00_20250607154029_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_001000_00_20250607155150_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_001100_00_20250607160313_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_001200_00_20250607161443_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_002400_00_20250607182338_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_002500_00_20250607183410_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_005700_00_20250608001933_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_005800_00_20250608003038_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_005900_00_20250608004117_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_006000_00_20250608005157_000.png filter=lfs diff=lfs merge=lfs -text
+ani_landscape_w14_outputs/sample/ani_landscape_w14_lora_e000000_00_20250607140357_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/.ipynb_checkpoints/ani_bright_landscape_w14_lora_e000000_00_20250614193330_000-checkpoint.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000100_00_20250614200312_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000100_00_20250614205429_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000200_00_20250614201444_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000200_00_20250614210603_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000300_00_20250614202631_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000300_00_20250614211751_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000400_00_20250614203815_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000500_00_20250614214104_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000600_00_20250614215236_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000700_00_20250614220428_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000800_00_20250614221624_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000900_00_20250614222805_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001000_00_20250614223938_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001100_00_20250614225110_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001200_00_20250614230254_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001300_00_20250614231431_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001500_00_20250614233807_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001600_00_20250614234934_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001700_00_20250615000056_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001800_00_20250615001248_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001900_00_20250615002428_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002000_00_20250615003609_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002100_00_20250615004756_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002200_00_20250615005932_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002300_00_20250615011124_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002400_00_20250615012246_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002500_00_20250615013434_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002600_00_20250615014624_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002700_00_20250615015805_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002800_00_20250615020938_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002900_00_20250615022111_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003000_00_20250615023241_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003100_00_20250615024417_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003200_00_20250615025601_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003300_00_20250615030731_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003400_00_20250615031919_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003500_00_20250615033107_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003600_00_20250615034258_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003700_00_20250615035416_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003800_00_20250615040611_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003900_00_20250615041741_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004000_00_20250615042926_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004100_00_20250615044128_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004200_00_20250615045316_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004300_00_20250615050438_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004400_00_20250615051620_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004500_00_20250615052749_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004600_00_20250615053935_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004700_00_20250615055120_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004800_00_20250615060257_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004900_00_20250615061430_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005000_00_20250615062622_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005100_00_20250615063804_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005200_00_20250615064945_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005300_00_20250615070047_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005400_00_20250615071255_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005500_00_20250615072436_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005600_00_20250615073627_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005700_00_20250615074808_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005800_00_20250615075935_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614193330_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614193614_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614194106_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614195132_000.png filter=lfs diff=lfs merge=lfs -text
+ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614204247_000.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a615a85d2a96fa5f831eca8e3fb21c362649e984
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+__pycache__/
+.venv
+venv/
+logs/
+uv.lock
\ No newline at end of file
diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d69a07fdb2ed93eb691a290b9bbc5c841f41cdc
--- /dev/null
+++ b/.ipynb_checkpoints/README-checkpoint.md
@@ -0,0 +1,127 @@
+# Anime Light Landscape Text-to-Video Generation
+
+This repository contains the necessary steps and scripts to generate anime-style videos using the Anime_Landscape text-to-video model with LoRA (Low-Rank Adaptation) weights. The model produces anime-style videos based on textual prompts with distinctive geometric and neon aesthetic.
+
+## Prerequisites
+
+Before proceeding, ensure that you have the following installed on your system:
+
+• **Ubuntu** (or a compatible Linux distribution)
+• **Python 3.x**
+• **pip** (Python package manager)
+• **Git**
+• **Git LFS** (Git Large File Storage)
+• **FFmpeg**
+
+## Installation
+
+1. **Update and Install Dependencies**
+
+ ```bash
+ sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
+ ```
+
+2. **Clone the Repository**
+
+ ```bash
+ git clone https://huggingface.co/svjack/Anime_Bright_Landscape_wan_2_1_14_B_text2video_lora
+ cd Anime_Bright_Landscape_wan_2_1_14_B_text2video_lora
+ ```
+
+3. **Install Python Dependencies**
+
+ ```bash
+ pip install torch torchvision
+ pip install -r requirements.txt
+ pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
+ pip install moviepy==1.0.3
+ pip install sageattention==1.0.6
+ ```
+
+4. **Download Model Weights**
+
+ ```bash
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
+ wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
+ ```
+
+## Usage
+
+To generate a video, use the `wan_generate_video.py` script with the appropriate parameters.
+
+#### Interactive Mode
+For experimenting with different prompts:
+```bash
+python wan_generate_video.py --fp8 --task t2v-14B --video_size 480 832 --video_length 81 --infer_steps 35 \
+--save_path save --output_type both \
+--dit wan2.1_t2v_14B_bf16.safetensors --vae Wan2.1_VAE.pth \
+--t5 models_t5_umt5-xxl-enc-bf16.pth \
+--attn_mode torch \
+--lora_weight ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00005500.safetensors \
+--lora_multiplier 1.0 \
+--interactive
+```
+
+```prompt
+```
+
+```prompt
+```
+
+```prompt
+```
+
+```prompt
+```
+
+```prompt
+```
+
+
+
+## Key Parameters
+
+* `--fp8`: Enable FP8 precision (recommended)
+* `--task`: Model version (`t2v-1.3B`)
+* `--video_size`: Output resolution (e.g., `480 832`)
+* `--video_length`: Number of frames (typically 81)
+* `--infer_steps`: Quality vs speed trade-off (35-50)
+* `--lora_weight`: Path to Kinich LoRA weights
+* `--lora_multiplier`: Strength of LoRA effect (1.0 recommended)
+* `--prompt`: Should include "In the style of Kinich" for best results
+
+## Style Characteristics
+
+For optimal results, prompts should describe:
+- Characters with geometric neon hair patterns
+- Black outfits with gold/teal designs
+- Futuristic or high-energy backgrounds
+- Vibrant color palettes with glowing elements
+- Dynamic poses and expressions
+
+## Output
+
+Generated videos and frames will be saved in the specified `save_path` directory with:
+- MP4 video file
+- Individual frames as PNG images
+
+## Troubleshooting
+
+• Verify all model weights are correctly downloaded
+• Ensure sufficient GPU memory (>=12GB recommended)
+• Check for version conflicts in Python packages
+
+## License
+
+This project is licensed under the MIT License.
+
+## Acknowledgments
+
+• **Hugging Face** for model hosting
+• **Wan-AI** for base models
+• **svjack** for LoRA adaptation
+
+For support, please open an issue in the repository.
\ No newline at end of file
diff --git a/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/.ipynb_checkpoints/Untitled-checkpoint.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8fae7352ecb282b10ec4bc718f40cf8263e82853
--- /dev/null
+++ b/.ipynb_checkpoints/Untitled-checkpoint.ipynb
@@ -0,0 +1,42 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a0d2adc9-e517-4906-b730-3fbc16a1a7e5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 35 \\\n",
+ "--save_path save --output_type both \\\n",
+ "--dit aniWan2114BFp8E4m3fn_t2v13B.safetensors --vae Wan2.1_VAE.pth \\\n",
+ "--t5 models_t5_umt5-xxl-enc-bf16.pth \\\n",
+ "--attn_mode torch \\\n",
+ "--lora_weight Kinich_w1_3_outputs/Kinich_w1_3_lora-000070.safetensors \\\n",
+ "--lora_multiplier 1.0 \\\n",
+ "--interactive\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "wan2gp",
+ "language": "python",
+ "name": "wan2gp"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000000000000000000000000000000000000..c8cfe3959183f8e9a50f83f54cd723f2dc9c252d
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.10
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d69a07fdb2ed93eb691a290b9bbc5c841f41cdc
--- /dev/null
+++ b/README.md
@@ -0,0 +1,127 @@
+# Anime Light Landscape Text-to-Video Generation
+
+This repository contains the necessary steps and scripts to generate anime-style videos using the Anime_Landscape text-to-video model with LoRA (Low-Rank Adaptation) weights. The model produces anime-style videos based on textual prompts with distinctive geometric and neon aesthetic.
+
+## Prerequisites
+
+Before proceeding, ensure that you have the following installed on your system:
+
+• **Ubuntu** (or a compatible Linux distribution)
+• **Python 3.x**
+• **pip** (Python package manager)
+• **Git**
+• **Git LFS** (Git Large File Storage)
+• **FFmpeg**
+
+## Installation
+
+1. **Update and Install Dependencies**
+
+ ```bash
+ sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
+ ```
+
+2. **Clone the Repository**
+
+ ```bash
+ git clone https://huggingface.co/svjack/Anime_Bright_Landscape_wan_2_1_14_B_text2video_lora
+ cd Anime_Bright_Landscape_wan_2_1_14_B_text2video_lora
+ ```
+
+3. **Install Python Dependencies**
+
+ ```bash
+ pip install torch torchvision
+ pip install -r requirements.txt
+ pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
+ pip install moviepy==1.0.3
+ pip install sageattention==1.0.6
+ ```
+
+4. **Download Model Weights**
+
+ ```bash
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
+ wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
+ ```
+
+## Usage
+
+To generate a video, use the `wan_generate_video.py` script with the appropriate parameters.
+
+#### Interactive Mode
+For experimenting with different prompts:
+```bash
+python wan_generate_video.py --fp8 --task t2v-14B --video_size 480 832 --video_length 81 --infer_steps 35 \
+--save_path save --output_type both \
+--dit wan2.1_t2v_14B_bf16.safetensors --vae Wan2.1_VAE.pth \
+--t5 models_t5_umt5-xxl-enc-bf16.pth \
+--attn_mode torch \
+--lora_weight ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00005500.safetensors \
+--lora_multiplier 1.0 \
+--interactive
+```
+
+```prompt
+```
+
+```prompt
+```
+
+```prompt
+```
+
+```prompt
+```
+
+```prompt
+```
+
+
+
+## Key Parameters
+
+* `--fp8`: Enable FP8 precision (recommended)
+* `--task`: Model version (`t2v-1.3B`)
+* `--video_size`: Output resolution (e.g., `480 832`)
+* `--video_length`: Number of frames (typically 81)
+* `--infer_steps`: Quality vs speed trade-off (35-50)
+* `--lora_weight`: Path to Kinich LoRA weights
+* `--lora_multiplier`: Strength of LoRA effect (1.0 recommended)
+* `--prompt`: Should include "In the style of Kinich" for best results
+
+## Style Characteristics
+
+For optimal results, prompts should describe:
+- Characters with geometric neon hair patterns
+- Black outfits with gold/teal designs
+- Futuristic or high-energy backgrounds
+- Vibrant color palettes with glowing elements
+- Dynamic poses and expressions
+
+## Output
+
+Generated videos and frames will be saved in the specified `save_path` directory with:
+- MP4 video file
+- Individual frames as PNG images
+
+## Troubleshooting
+
+• Verify all model weights are correctly downloaded
+• Ensure sufficient GPU memory (>=12GB recommended)
+• Check for version conflicts in Python packages
+
+## License
+
+This project is licensed under the MIT License.
+
+## Acknowledgments
+
+• **Hugging Face** for model hosting
+• **Wan-AI** for base models
+• **svjack** for LoRA adaptation
+
+For support, please open an issue in the repository.
\ No newline at end of file
diff --git a/Untitled.ipynb b/Untitled.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8fae7352ecb282b10ec4bc718f40cf8263e82853
--- /dev/null
+++ b/Untitled.ipynb
@@ -0,0 +1,42 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a0d2adc9-e517-4906-b730-3fbc16a1a7e5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 35 \\\n",
+ "--save_path save --output_type both \\\n",
+ "--dit aniWan2114BFp8E4m3fn_t2v13B.safetensors --vae Wan2.1_VAE.pth \\\n",
+ "--t5 models_t5_umt5-xxl-enc-bf16.pth \\\n",
+ "--attn_mode torch \\\n",
+ "--lora_weight Kinich_w1_3_outputs/Kinich_w1_3_lora-000070.safetensors \\\n",
+ "--lora_multiplier 1.0 \\\n",
+ "--interactive\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "wan2gp",
+ "language": "python",
+ "name": "wan2gp"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00000500.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00000500.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..e88828406b9b38dcbfb7ffbd23263b2dc398361f
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00000500.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8268e71a0ce3d721d00f18fee64216845eb671b9def5381cf97cdb43c275d3e9
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00001000.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00001000.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..a86bc6675f503891fd7bf9f50d3cd96c7ec8fe9b
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00001000.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70804ff7ce7c684d17f1082cc9e3d7a98ecf5099612bfb081e0b9b675048b2b2
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00001500.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00001500.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..4c9786a3a3bcea4b31f4718397360a52195fc586
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00001500.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:232c7e6ef6d4735d27ce4f53fc67aa67d5a4f1b79e87ce72928a1914ab0327d5
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00002000.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00002000.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..7ef8d704238be9c1567188bf511eefee006a92d5
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00002000.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:087d5da7ba98f7d7d0150c393946e3ae8cde6ef78ba152ee5c0993e0ba6b2d81
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00002500.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00002500.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..a1335459a06e7f47aaacaa2038aaef9e66c998d6
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00002500.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b1dd78a5d1f4a6ee7ec0806b7d4e2e7b0dafe0d831e58b71a9a9639d235d5957
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00003000.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00003000.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..a5258dd2523c434fdb301f04a026c8300b46f40a
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00003000.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1804b549e7be6ff6099e5040b22549c9a2f238ff753200ce76a4de39c5af6ba
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00003500.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00003500.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..6de3b42fe154a4eee50af98462021a7329675831
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00003500.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3786a74e3d3bb23dead2d3c65187eaa0d90c4cdb224b084b5b71e7e803265415
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00004000.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00004000.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..c2920f6971cb6a10ff68e30cbe1ea2d91ad046ef
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00004000.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad89c2cce43b3baf5a4f211a219503c869b69321226a460abd78b8547c0d9c2f
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00004500.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00004500.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..9e50f2b2cd07ba87d19b62ea372e771bcaa4f378
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00004500.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c62ff62e759c2432d403998945e4e740dd65e57fc7a6d0628fd83b7cf067e78e
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00005000.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00005000.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..624b0ceb5af5153bead38e28c2b6ee8c9ad9b60e
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00005000.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a3d94100772a310d08afb66b9b3654bf5785011c0628be852c8e856caad1aef
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00005500.safetensors b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00005500.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..2f03e001dda46057a9264bc5fba8776bdab86280
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/ani_bright_landscape_w14_lora-step00005500.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6aa5ad7f728324b71421dbae565198e54be0cc2b257d7e4406003a00af37345
+size 613557784
diff --git a/ani_bright_landscape_w14_outputs/sample/.ipynb_checkpoints/ani_bright_landscape_w14_lora_e000000_00_20250614193330_000-checkpoint.png b/ani_bright_landscape_w14_outputs/sample/.ipynb_checkpoints/ani_bright_landscape_w14_lora_e000000_00_20250614193330_000-checkpoint.png
new file mode 100644
index 0000000000000000000000000000000000000000..0a229afa6a1b1171b7816bdfb0dce06303c2efb5
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/.ipynb_checkpoints/ani_bright_landscape_w14_lora_e000000_00_20250614193330_000-checkpoint.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:523f8ffc816a3716bcf0e71ba0ff993897e3e3f3cde05cf00b90af4eacce9c29
+size 112817
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000100_00_20250614200312_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000100_00_20250614200312_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..661acd5c3a99d7e483234775ab27b1809fb307be
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000100_00_20250614200312_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ad86a4e92cf00dcf3613b69b5cdb72a33644e2e5cf85729cd8c790efa8677ee
+size 101659
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000100_00_20250614205429_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000100_00_20250614205429_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..4af2fb2d09be77acb256e15b614426d9a98a0dc0
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000100_00_20250614205429_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:692e2662cd15e2b0b3eef19010bd72521d02302f9c8aba135cce5a504f4f44b9
+size 116434
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000200_00_20250614201444_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000200_00_20250614201444_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..cb2dbf0f9eb8b73f6a4aff5614b9dab2e0762f34
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000200_00_20250614201444_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c89e174e5da55c5500dfb540bcabb53c21fe7d8bd7685c5a44ca8a7ff06af5e
+size 103245
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000200_00_20250614210603_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000200_00_20250614210603_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..3b0661bc04f691edadd8571b045c91d7021d5acc
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000200_00_20250614210603_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0114e39379946620c8fd19cb5868b68f2f2b4230a177b6c9d80c78d76c2f2c36
+size 101880
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000300_00_20250614202631_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000300_00_20250614202631_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..4de93aabb313f482992156ed39a82d7fd7325f82
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000300_00_20250614202631_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34b2612a09be5e8aebd5177a7df4bfd3991a5f440eed321baa130e1ee65927ad
+size 104359
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000300_00_20250614211751_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000300_00_20250614211751_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..f2fd7402ca92317beeeb427036e5462ba9537c4c
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000300_00_20250614211751_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8b9a8abe20ecee36550b5a765622a13de730ab2902f1f2eb9b1f2e650b7759f
+size 111789
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000400_00_20250614203815_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000400_00_20250614203815_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..ee1a7e29cf17bba051108203b929e11ce27760c3
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000400_00_20250614203815_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:272aa54a44980f46f48840d78c8417170c076f119c5e3d453be094508bda7515
+size 106373
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000400_00_20250614212935_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000400_00_20250614212935_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..5237da95ef13407c777e746d7b4552a4c20f787f
Binary files /dev/null and b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000400_00_20250614212935_000.png differ
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000500_00_20250614214104_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000500_00_20250614214104_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..74ba003fd1ce644ae8b5febb3e437007eff251fc
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000500_00_20250614214104_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c8648164cd9c0ac4c9d9b822e7dd38f50ef56cdc295076dc3ea6aa23058ac02
+size 115805
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000600_00_20250614215236_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000600_00_20250614215236_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..9e1d713a8152c36e74d74c000c242a952f82b5bb
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000600_00_20250614215236_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d966b0de82ed2ae5dc2858660c59f03721f42756045724b8fb7e8ed62e5368a
+size 118612
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000700_00_20250614220428_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000700_00_20250614220428_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..dfe903650cf7b7d2554e7b0016884325de228266
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000700_00_20250614220428_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:412616594f9b67566612beaee9d140ef8a15a07a06625b7d1f1497060a315383
+size 105721
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000800_00_20250614221624_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000800_00_20250614221624_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..f1ae77248b9c68d494c547671e9deea6147289e5
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000800_00_20250614221624_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6382e426ad328bbfebdc491004b2e14ef919d200dd4238f59658477d79905e49
+size 122355
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000900_00_20250614222805_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000900_00_20250614222805_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..f693ff475f19f02b71a48ffb81a778fe0c6af3f7
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_000900_00_20250614222805_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:870ce0d929e211546997cdc84cae6df2a3be7faa5c7b548c37a1721f109ab4cc
+size 111192
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001000_00_20250614223938_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001000_00_20250614223938_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..3ad6db491bd9f3eaaccddbf1527a457cb496ac4a
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001000_00_20250614223938_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0825458c4e35e06e8658cea68d8e3b616362714691ade9cb5b62a76a36bff32
+size 127524
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001100_00_20250614225110_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001100_00_20250614225110_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..943b11bb858193a83467cda940acd88c79405173
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001100_00_20250614225110_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c78d8f3040470e9783bd11df2fb67dc459846ed2ed13ee0f039d7bc882a1934f
+size 122696
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001200_00_20250614230254_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001200_00_20250614230254_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6823511926388d44939e0cb08bdd893840cd5e2
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001200_00_20250614230254_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf2132498c55028d4efb6f97c2b280da6c2a37ae68369d75a3d8797484b76d2c
+size 124948
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001300_00_20250614231431_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001300_00_20250614231431_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..9a52e52836ef2df46423f925143d27e86c5857c7
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001300_00_20250614231431_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa88a9fdf7b29eb036210c1b85e9b83d2dbe5ed5c1df77981344c056ea8088ed
+size 121463
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001400_00_20250614232604_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001400_00_20250614232604_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..caa93ca68410f828111348192e993793f0435ef9
Binary files /dev/null and b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001400_00_20250614232604_000.png differ
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001500_00_20250614233807_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001500_00_20250614233807_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..01ed44917bff7562951ecf6ab60fc7a8155466a7
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001500_00_20250614233807_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:984ad07a9c89d44b17c86f3b8c5ea5cddabf55bae55ac4c4e151496ed81ca5c6
+size 128884
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001600_00_20250614234934_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001600_00_20250614234934_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..611f5b1883e6832077e48cb582c72a368a8bd2a1
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001600_00_20250614234934_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06cbd8afd7075e95ac195e07fadf2a009fee4d4fc4252d94b32605894af30546
+size 120033
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001700_00_20250615000056_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001700_00_20250615000056_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..9bab4dd2c49ed7d4379f068e03bfcb2feee0ab58
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001700_00_20250615000056_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9acc04b3850bad79b4f9a0ba124205cd26d452a27e4e480d8678a3916e6b8ac7
+size 114804
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001800_00_20250615001248_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001800_00_20250615001248_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..cb858cfc1e1a1f745dec270a0aaf5dd992d46b2e
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001800_00_20250615001248_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e12e41b04a8571415de66966f63dcc86c22b4e59010a9b9ec683dfe98058f93
+size 108694
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001900_00_20250615002428_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001900_00_20250615002428_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..bd454e123cabba790d722e24abc1ad1720aa1cc5
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_001900_00_20250615002428_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:198ee84d7f38e1acedca658c43283b1ddf64e12804a39f205b175676be536183
+size 132829
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002000_00_20250615003609_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002000_00_20250615003609_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..d325c2d3418e922bea96afa3e673d5d5fa6749fe
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002000_00_20250615003609_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87ac31d7f505bc284bf882a39b4db5b06189b65bddee1d86907207c84c64bc2f
+size 117018
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002100_00_20250615004756_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002100_00_20250615004756_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..9cafde80ce15ce07dbd0264a61dbac7edd4cb6e6
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002100_00_20250615004756_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e2606003f15a8a6a9099e4121b5758b6d0a9632f82731d232260117c7940a2d
+size 129610
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002200_00_20250615005932_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002200_00_20250615005932_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..13c1da3543f3c078d5ae4c7f5af049626ae6055a
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002200_00_20250615005932_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1cc90c9a3d9b06f2cf159c398816d821a7a9fd51ae3e15a2eb4d82e57a473ea
+size 121122
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002300_00_20250615011124_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002300_00_20250615011124_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..7651a783e2a5645ef91afb9c5efa024d07768fc9
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002300_00_20250615011124_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1bbfa9ec059a1a99378519037781b33696707f4ee50327a9934264374b45da4b
+size 118534
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002400_00_20250615012246_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002400_00_20250615012246_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..9c380197607b5a98d91238470b471e18ff992b17
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002400_00_20250615012246_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cab0b726324ee288259842b1d65188c63c7f0bc4ff004312a4891d344d25aa25
+size 125692
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002500_00_20250615013434_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002500_00_20250615013434_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..2f4d9b748a2ac19bfc489492a98dea0035d2cae4
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002500_00_20250615013434_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ecd96e34fa0e4dd473c0f68265c2f944fc82044912a00ff3a53bbbd5819600c0
+size 124710
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002600_00_20250615014624_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002600_00_20250615014624_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..f7693712442c84c396baa2fc4c15ef5a504092f0
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002600_00_20250615014624_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b42741b0ac951091d96ed7a23792b0f470bf8761451356239227ea392c0f8e2c
+size 124000
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002700_00_20250615015805_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002700_00_20250615015805_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..efffaef1b5dc7cd41ea11980f77601dcad4212ef
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002700_00_20250615015805_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fa7d5f7faa05351c820bc5ef05ff82c4453d046c13ce35994c092bb2449b4d5
+size 123799
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002800_00_20250615020938_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002800_00_20250615020938_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..8da214a692e9792049f78d9e5bbed7a583b6fab1
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002800_00_20250615020938_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7458872fbb015b369b3a7cd75a0d205dfc20c153cb378fecb6d505efe29b254b
+size 134176
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002900_00_20250615022111_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002900_00_20250615022111_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..30685d730c094f9fb1a4ffa0d9d1c041f363aacf
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_002900_00_20250615022111_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4013c7ac587ea00e3ea409bf2e0b9fd9d9bb3cc018a780d9e0660c85b8421afd
+size 134661
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003000_00_20250615023241_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003000_00_20250615023241_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..9e9d127d2dcf80649eccef7fcd0191ba45e6f9f3
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003000_00_20250615023241_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:995ea30756019a9893f0fb248021e271c5c8f7805c96ba0186375eb875229bf9
+size 125542
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003100_00_20250615024417_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003100_00_20250615024417_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..d0bb8136e29e2664f0c4f198745f23e9b9d88d35
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003100_00_20250615024417_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3150f4773b93abe3aa4525656597b127880d72d3ec6721f16d77b8cf00435207
+size 117712
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003200_00_20250615025601_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003200_00_20250615025601_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..6a32e6b0745c8e23d8f0f6ba425ec179babd0d5a
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003200_00_20250615025601_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42fb6127250445499e8bfdea3d7830becffdb07253f036b5ae5e35275af9d562
+size 114743
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003300_00_20250615030731_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003300_00_20250615030731_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..30e7085122f8a79acf50d3dffad836eab6cdbb5b
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003300_00_20250615030731_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e239e6b1799911d401a35093fa4cf1ed44eed307de42e9dad4224ddac7e24a9
+size 126740
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003400_00_20250615031919_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003400_00_20250615031919_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..7a86414b772c2343634299a6e68f0ce942a90fbd
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003400_00_20250615031919_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ee3ad87b5bded9409363329399050cdb0532bad824149728f8afeedcb8daef5
+size 119057
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003500_00_20250615033107_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003500_00_20250615033107_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..00f1942f8cdfc4cb5015206969195f11f09ad3b9
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003500_00_20250615033107_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2120ae4b9e92c4a6df5c80f4722745232d9dce76d902900f3ca54343d79f88fc
+size 119802
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003600_00_20250615034258_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003600_00_20250615034258_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..20106c62771032e523189d5a15e212f53eb99ac3
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003600_00_20250615034258_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ec08700431901e3b5da9986e5d19281d593d1c92114c41971d71f64a6f2894b
+size 113388
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003700_00_20250615035416_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003700_00_20250615035416_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..752934bc41ce023183c095a4b2b603664704c1cc
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003700_00_20250615035416_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3406cbd49bbed12eeee6ad25f6ba471bb272731609e3fe2e1cbb2118faa3e780
+size 120372
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003800_00_20250615040611_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003800_00_20250615040611_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..dd9bd5ed4b0bcad09d20d3e2eedd0e8397292fbc
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003800_00_20250615040611_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13ca4716ad5de7e96d00314545525f57018550b67ed57c7be96e66f6fb47942b
+size 116311
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003900_00_20250615041741_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003900_00_20250615041741_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..dada78dd558c8197de98086281e6ed85c873f8b8
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_003900_00_20250615041741_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ce822eccf458a9afd60654069895ae424feba7171122530fae12c909cbce88a
+size 100023
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004000_00_20250615042926_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004000_00_20250615042926_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..36a2942bd93e3e369488ec1e0a677dd4f1f52d6b
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004000_00_20250615042926_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d0a05af570feb683e24b2f0b9b83880289f1f81604f8a6f41a2bdaa8c2a59f0
+size 126491
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004100_00_20250615044128_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004100_00_20250615044128_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..207c75c6664b2794a42043c313cb4aaca40f9dd7
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004100_00_20250615044128_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3729ab025ad4e6b8e9fdf007b48b8adb6a5a8f798b4c68aef5641c804b151e08
+size 133138
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004200_00_20250615045316_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004200_00_20250615045316_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..b198f373e169b9b1905d76da10889be5c4fbe569
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004200_00_20250615045316_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2488a1594fe1e9a0b0cef8161354d85876f21f2bd8c56d3f6e5bec4a88725a7
+size 130490
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004300_00_20250615050438_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004300_00_20250615050438_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..70713b353eac9bbc36f49bcff7059a64a8dc2961
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004300_00_20250615050438_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e0df3b7b440ade8272cbd9443b1f892c25a2be2e5eb20933ef18fbc2fdf9855
+size 128869
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004400_00_20250615051620_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004400_00_20250615051620_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..86dc6c3af37913d47e4225c60d226fb54c398718
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004400_00_20250615051620_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50b565fb102491b57c2ae068520861c8690e6b32863bf3008c663e5cb12ce498
+size 115605
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004500_00_20250615052749_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004500_00_20250615052749_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3bc5acbdb3022ec00890c78546c9d740617cc98
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004500_00_20250615052749_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a5b0ac983f29143c857b8d474e43e42de1f786b36880717b3e084968a04445c
+size 125685
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004600_00_20250615053935_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004600_00_20250615053935_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e4b0e3fac316a26cb540571c0f219b5311df73e
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004600_00_20250615053935_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:187e4bac3cdb15519653ee3ec23d89077c5af62a50073a68427f96d345fdc2de
+size 135787
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004700_00_20250615055120_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004700_00_20250615055120_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..83c4a688de58280c4badb25d928c953ff47e2976
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004700_00_20250615055120_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3878d1b6b128229013ccfe770f1d1f74ca9f3675caf189a0ca64e21ce1e9758e
+size 129182
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004800_00_20250615060257_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004800_00_20250615060257_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..1df7d44100c8ad568f62d8712625799634d42957
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004800_00_20250615060257_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c24cba0ea106559f0bed6cfeb03b86f6ee206c63df26bbf46e91cdaf1e8c5c9
+size 125511
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004900_00_20250615061430_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004900_00_20250615061430_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..2610358d37841348600bf8f29b3e4cb4a479e175
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_004900_00_20250615061430_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73125167e2393740cd205e1028a8a01f20d8129ea846b498252ead445c8cb85a
+size 132755
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005000_00_20250615062622_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005000_00_20250615062622_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..48823dae73d42094f99711d713e63d70822c5dee
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005000_00_20250615062622_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b4e327e3b364ef3b5291a13e7fead9867062f5d4fd8f72adf0ceb4adda58c06
+size 134415
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005100_00_20250615063804_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005100_00_20250615063804_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..b65416ed9b29c61705488a0d28fc75b7e4057c06
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005100_00_20250615063804_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc1ce4a94c445db88012d968e8bb69db49fb3648228b483f52e1dde7446f90fe
+size 143648
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005200_00_20250615064945_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005200_00_20250615064945_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab81d13e1f7aaea95bfc7a5b60913d77d7458d63
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005200_00_20250615064945_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eaac5165f606caf99690d344c47809bb8aec837bead2e0b876b50d4826326557
+size 115446
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005300_00_20250615070047_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005300_00_20250615070047_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..dacd35f3a94fc150e37d96bdc7789914188454a6
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005300_00_20250615070047_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:268bd949f83b606762248ded35fb9712a2c694ffe46cdd0ca229ce1bd8c814f4
+size 127306
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005400_00_20250615071255_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005400_00_20250615071255_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..74643a6b96870f0334469f778bfac06280250c06
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005400_00_20250615071255_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a452fd8bb880030d2284e20b32b255924e256c913fdacd448c94d21a4b82610
+size 121329
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005500_00_20250615072436_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005500_00_20250615072436_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..ee3945469cb3fbac4aae598464f8312aac54f29b
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005500_00_20250615072436_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e63b269dca3146a9e14a444f81d35282acfb3038a615bb482c9f13b9b202984e
+size 123891
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005600_00_20250615073627_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005600_00_20250615073627_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..d9da4b9e4793603e00a5f08fc9e58e9b36f18c27
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005600_00_20250615073627_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5a330b0ff1e5f533fe69aff9433057f7b973de3ae95fcb2b7aea5b54205a8fe
+size 133998
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005700_00_20250615074808_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005700_00_20250615074808_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..0dca424e79d4c929efa753f0b1fe7aa71b48e467
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005700_00_20250615074808_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c254085af71e716a0b15be2091c866a6c4c266b3e16d7e1cd5603dab06dddcef
+size 126486
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005800_00_20250615075935_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005800_00_20250615075935_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..6a6ccb8cfe0af8d359a9bc5fbc2576736bf0da82
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_005800_00_20250615075935_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08f3e3160b63643cf6be90ea31fca0d8fb563b8748d935f73d7bce99071307fb
+size 132391
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614193330_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614193330_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..0a229afa6a1b1171b7816bdfb0dce06303c2efb5
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614193330_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:523f8ffc816a3716bcf0e71ba0ff993897e3e3f3cde05cf00b90af4eacce9c29
+size 112817
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614193614_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614193614_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..2fcbf72fa3dee86b63143ae4efda77d71207e142
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614193614_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07a9f546ce3246ea15d6e0e582abf2127fe8084fc376accbc1b494b4a2956fab
+size 115510
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614194106_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614194106_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..67e1bdead6b560588246e520b8a0d65256fc37ff
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614194106_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d64117a721dd2254c642ac0c38afd3be8d1abc095382dbb94ae5a298f25ed1d
+size 117504
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614195132_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614195132_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..b69f34895dd07b94819ce9f3e6c6dcbd8ab02abe
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614195132_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:154be63fc0ae409d6579ad438e6569c5378f5495309eecc0840ee4425a0bb4c3
+size 112317
diff --git a/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614204247_000.png b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614204247_000.png
new file mode 100644
index 0000000000000000000000000000000000000000..b13f771901b70b9acd138862c701699f42200a3d
--- /dev/null
+++ b/ani_bright_landscape_w14_outputs/sample/ani_bright_landscape_w14_lora_e000000_00_20250614204247_000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c133c0056ccd7d16ead96c9bd4758310f346d55626dfa08595599827a845e50
+size 127100
diff --git a/cache_latents.py b/cache_latents.py
new file mode 100644
index 0000000000000000000000000000000000000000..15cae1237827009884dbf49f029f434f45a4dc97
--- /dev/null
+++ b/cache_latents.py
@@ -0,0 +1,339 @@
+import argparse
+import os
+import glob
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from dataset import config_utils
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+from PIL import Image
+
+import logging
+
+from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache, ARCHITECTURE_HUNYUAN_VIDEO
+from hunyuan_model.vae import load_vae
+from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+from utils.model_utils import str_to_dtype
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def show_image(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]]) -> int:
+ import cv2
+
+ imgs = (
+ [image]
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
+ else [image[0], image[-1]]
+ )
+ if len(imgs) > 1:
+ print(f"Number of images: {len(image)}")
+ for i, img in enumerate(imgs):
+ if len(imgs) > 1:
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
+ else:
+ print(f"Image: {img.shape}")
+ cv2_img = np.array(img) if isinstance(img, Image.Image) else img
+ cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
+ cv2.imshow("image", cv2_img)
+ k = cv2.waitKey(0)
+ cv2.destroyAllWindows()
+ if k == ord("q") or k == ord("d"):
+ return k
+ return k
+
+
+def show_console(
+ image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]],
+ width: int,
+ back: str,
+ interactive: bool = False,
+) -> int:
+ from ascii_magic import from_pillow_image, Back
+
+ back = None
+ if back is not None:
+ back = getattr(Back, back.upper())
+
+ k = None
+ imgs = (
+ [image]
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
+ else [image[0], image[-1]]
+ )
+ if len(imgs) > 1:
+ print(f"Number of images: {len(image)}")
+ for i, img in enumerate(imgs):
+ if len(imgs) > 1:
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
+ else:
+ print(f"Image: {img.shape}")
+ pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
+ ascii_img = from_pillow_image(pil_img)
+ ascii_img.to_terminal(columns=width, back=back)
+
+ if interactive:
+ k = input("Press q to quit, d to next dataset, other key to next: ")
+ if k == "q" or k == "d":
+ return ord(k)
+
+ if not interactive:
+ return ord(" ")
+ return ord(k) if k else ord(" ")
+
+
+def save_video(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]], cache_path: str, fps: int = 24):
+ import av
+
+ directory = os.path.dirname(cache_path)
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image):
+ # save image
+ image_path = cache_path.replace(".safetensors", ".jpg")
+ img = image if isinstance(image, Image.Image) else Image.fromarray(image)
+ img.save(image_path)
+ print(f"Saved image: {image_path}")
+ else:
+ imgs = image
+ print(f"Number of images: {len(imgs)}")
+ # save video
+ video_path = cache_path.replace(".safetensors", ".mp4")
+ height, width = imgs[0].shape[0:2]
+
+ # create output container
+ container = av.open(video_path, mode="w")
+
+ # create video stream
+ codec = "libx264"
+ pixel_format = "yuv420p"
+ stream = container.add_stream(codec, rate=fps)
+ stream.width = width
+ stream.height = height
+ stream.pix_fmt = pixel_format
+ stream.bit_rate = 1000000 # 1Mbit/s for preview quality
+
+ for frame_img in imgs:
+ if isinstance(frame_img, Image.Image):
+ frame = av.VideoFrame.from_image(frame_img)
+ else:
+ frame = av.VideoFrame.from_ndarray(frame_img, format="rgb24")
+ packets = stream.encode(frame)
+ for packet in packets:
+ container.mux(packet)
+
+ for packet in stream.encode():
+ container.mux(packet)
+
+ container.close()
+
+ print(f"Saved video: {video_path}")
+
+
+def show_datasets(
+ datasets: list[BaseDataset],
+ debug_mode: str,
+ console_width: int,
+ console_back: str,
+ console_num_images: Optional[int],
+ fps: int = 24,
+):
+ if debug_mode != "video":
+ print(f"d: next dataset, q: quit")
+
+ num_workers = max(1, os.cpu_count() - 1)
+ for i, dataset in enumerate(datasets):
+ print(f"Dataset [{i}]")
+ batch_index = 0
+ num_images_to_show = console_num_images
+ k = None
+ for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
+ print(f"bucket resolution: {key}, count: {len(batch)}")
+ for j, item_info in enumerate(batch):
+ item_info: ItemInfo
+ print(f"{batch_index}-{j}: {item_info}")
+ if debug_mode == "image":
+ k = show_image(item_info.content)
+ elif debug_mode == "console":
+ k = show_console(item_info.content, console_width, console_back, console_num_images is None)
+ if num_images_to_show is not None:
+ num_images_to_show -= 1
+ if num_images_to_show == 0:
+ k = ord("d") # next dataset
+ elif debug_mode == "video":
+ save_video(item_info.content, item_info.latent_cache_path, fps)
+ k = None # save next video
+
+ if k == ord("q"):
+ return
+ elif k == ord("d"):
+ break
+ if k == ord("d"):
+ break
+ batch_index += 1
+
+
+def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
+ if len(contents.shape) == 4:
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
+
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
+ contents = contents.to(vae.device, dtype=vae.dtype)
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
+
+ h, w = contents.shape[3], contents.shape[4]
+ if h < 8 or w < 8:
+ item = batch[0] # other items should have the same size
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
+
+ # print(f"encode batch: {contents.shape}")
+ with torch.no_grad():
+ latent = vae.encode(contents).latent_dist.sample()
+ # latent = latent * vae.config.scaling_factor
+
+ # # debug: decode and save
+ # with torch.no_grad():
+ # latent_to_decode = latent / vae.config.scaling_factor
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
+ # images = (images / 2 + 0.5).clamp(0, 1)
+ # images = images.cpu().float().numpy()
+ # images = (images * 255).astype(np.uint8)
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
+ # for b in range(images.shape[0]):
+ # for f in range(images.shape[1]):
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
+ # img = Image.fromarray(images[b, f])
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
+
+ for item, l in zip(batch, latent):
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
+ save_latent_cache(item, l)
+
+
+def encode_datasets(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
+ for i, dataset in enumerate(datasets):
+ logger.info(f"Encoding dataset [{i}]")
+ all_latent_cache_paths = []
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
+ all_latent_cache_paths.extend([item.latent_cache_path for item in batch])
+
+ if args.skip_existing:
+ filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
+ if len(filtered_batch) == 0:
+ continue
+ batch = filtered_batch
+
+ bs = args.batch_size if args.batch_size is not None else len(batch)
+ for i in range(0, len(batch), bs):
+ encode(batch[i : i + bs])
+
+ # normalize paths
+ all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
+ all_latent_cache_paths = set(all_latent_cache_paths)
+
+ # remove old cache files not in the dataset
+ all_cache_files = dataset.get_all_latent_cache_files()
+ for cache_file in all_cache_files:
+ if os.path.normpath(cache_file) not in all_latent_cache_paths:
+ if args.keep_cache:
+ logger.info(f"Keep cache file not in the dataset: {cache_file}")
+ else:
+ os.remove(cache_file)
+ logger.info(f"Removed old cache file: {cache_file}")
+
+
+def main(args):
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+ datasets = train_dataset_group.datasets
+
+ if args.debug_mode is not None:
+ show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
+ return
+
+ assert args.vae is not None, "vae checkpoint is required"
+
+ # Load VAE model: HunyuanVideo VAE model is float16
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
+ vae.eval()
+ logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
+
+ if args.vae_chunk_size is not None:
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
+ if args.vae_spatial_tile_sample_min_size is not None:
+ vae.enable_spatial_tiling(True)
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
+ elif args.vae_tiling:
+ vae.enable_spatial_tiling(True)
+
+ # Encode images
+ def encode(one_batch: list[ItemInfo]):
+ encode_and_save_batch(vae, one_batch)
+
+ encode_datasets(datasets, encode, args)
+
+
+def setup_parser_common() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
+ parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
+ parser.add_argument(
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
+ )
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
+ parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
+ parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console", "video"], help="debug mode")
+ parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
+ parser.add_argument(
+ "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
+ )
+ parser.add_argument(
+ "--console_num_images",
+ type=int,
+ default=None,
+ help="debug mode: not interactive, number of images to show for each dataset",
+ )
+ return parser
+
+
+def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ parser.add_argument(
+ "--vae_tiling",
+ action="store_true",
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
+ )
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser_common()
+ parser = hv_setup_parser(parser)
+
+ args = parser.parse_args()
+ main(args)
diff --git a/cache_text_encoder_outputs.py b/cache_text_encoder_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..491e8a05417bc33c90b53663f7952cc3d2b6ded7
--- /dev/null
+++ b/cache_text_encoder_outputs.py
@@ -0,0 +1,214 @@
+import argparse
+import os
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from dataset import config_utils
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+import accelerate
+
+from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, BaseDataset, ItemInfo, save_text_encoder_output_cache
+from hunyuan_model import text_encoder as text_encoder_module
+from hunyuan_model.text_encoder import TextEncoder
+
+import logging
+
+from utils.model_utils import str_to_dtype
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
+ data_type = "video" # video only, image is not supported
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
+
+ with torch.no_grad():
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
+
+ return prompt_outputs.hidden_state, prompt_outputs.attention_mask
+
+
+def encode_and_save_batch(
+ text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
+):
+ prompts = [item.caption for item in batch]
+ # print(prompts)
+
+ # encode prompt
+ if accelerator is not None:
+ with accelerator.autocast():
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
+ else:
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
+
+ # # convert to fp16 if needed
+ # if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
+ # prompt_embeds = prompt_embeds.to(text_encoder.dtype)
+
+ # save prompt cache
+ for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
+ save_text_encoder_output_cache(item, embed, mask, is_llm)
+
+
+def prepare_cache_files_and_paths(datasets: list[BaseDataset]):
+ all_cache_files_for_dataset = [] # exisiting cache files
+ all_cache_paths_for_dataset = [] # all cache paths in the dataset
+ for dataset in datasets:
+ all_cache_files = [os.path.normpath(file) for file in dataset.get_all_text_encoder_output_cache_files()]
+ all_cache_files = set(all_cache_files)
+ all_cache_files_for_dataset.append(all_cache_files)
+
+ all_cache_paths_for_dataset.append(set())
+ return all_cache_files_for_dataset, all_cache_paths_for_dataset
+
+
+def process_text_encoder_batches(
+ num_workers: Optional[int],
+ skip_existing: bool,
+ batch_size: int,
+ datasets: list[BaseDataset],
+ all_cache_files_for_dataset: list[set],
+ all_cache_paths_for_dataset: list[set],
+ encode: callable,
+):
+ num_workers = num_workers if num_workers is not None else max(1, os.cpu_count() - 1)
+ for i, dataset in enumerate(datasets):
+ logger.info(f"Encoding dataset [{i}]")
+ all_cache_files = all_cache_files_for_dataset[i]
+ all_cache_paths = all_cache_paths_for_dataset[i]
+ for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
+ # update cache files (it's ok if we update it multiple times)
+ all_cache_paths.update([os.path.normpath(item.text_encoder_output_cache_path) for item in batch])
+
+ # skip existing cache files
+ if skip_existing:
+ filtered_batch = [
+ item for item in batch if not os.path.normpath(item.text_encoder_output_cache_path) in all_cache_files
+ ]
+ # print(f"Filtered {len(batch) - len(filtered_batch)} existing cache files")
+ if len(filtered_batch) == 0:
+ continue
+ batch = filtered_batch
+
+ bs = batch_size if batch_size is not None else len(batch)
+ for i in range(0, len(batch), bs):
+ encode(batch[i : i + bs])
+
+
+def post_process_cache_files(
+ datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set], keep_cache: bool
+):
+ for i, dataset in enumerate(datasets):
+ all_cache_files = all_cache_files_for_dataset[i]
+ all_cache_paths = all_cache_paths_for_dataset[i]
+ for cache_file in all_cache_files:
+ if cache_file not in all_cache_paths:
+ if keep_cache:
+ logger.info(f"Keep cache file not in the dataset: {cache_file}")
+ else:
+ os.remove(cache_file)
+ logger.info(f"Removed old cache file: {cache_file}")
+
+
+def main(args):
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+ datasets = train_dataset_group.datasets
+
+ # define accelerator for fp8 inference
+ accelerator = None
+ if args.fp8_llm:
+ accelerator = accelerate.Accelerator(mixed_precision="fp16")
+
+ # prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
+ all_cache_files_for_dataset, all_cache_paths_for_dataset = prepare_cache_files_and_paths(datasets)
+
+ # Load Text Encoder 1
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
+ logger.info(f"loading text encoder 1: {args.text_encoder1}")
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
+ text_encoder_1.to(device=device)
+
+ # Encode with Text Encoder 1 (LLM)
+ logger.info("Encoding with Text Encoder 1")
+
+ def encode_for_text_encoder_1(batch: list[ItemInfo]):
+ encode_and_save_batch(text_encoder_1, batch, is_llm=True, accelerator=accelerator)
+
+ process_text_encoder_batches(
+ args.num_workers,
+ args.skip_existing,
+ args.batch_size,
+ datasets,
+ all_cache_files_for_dataset,
+ all_cache_paths_for_dataset,
+ encode_for_text_encoder_1,
+ )
+ del text_encoder_1
+
+ # Load Text Encoder 2
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
+ text_encoder_2.to(device=device)
+
+ # Encode with Text Encoder 2
+ logger.info("Encoding with Text Encoder 2")
+
+ def encode_for_text_encoder_2(batch: list[ItemInfo]):
+ encode_and_save_batch(text_encoder_2, batch, is_llm=False, accelerator=None)
+
+ process_text_encoder_batches(
+ args.num_workers,
+ args.skip_existing,
+ args.batch_size,
+ datasets,
+ all_cache_files_for_dataset,
+ all_cache_paths_for_dataset,
+ encode_for_text_encoder_2,
+ )
+ del text_encoder_2
+
+ # remove cache files not in dataset
+ post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
+
+
+def setup_parser_common():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
+ parser.add_argument(
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
+ )
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
+ parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
+ return parser
+
+
+def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser_common()
+ parser = hv_setup_parser(parser)
+
+ args = parser.parse_args()
+ main(args)
diff --git a/convert_lora.py b/convert_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd9b77d3ddfdd3e3cd802157c3abc2583d29709c
--- /dev/null
+++ b/convert_lora.py
@@ -0,0 +1,137 @@
+import argparse
+
+import torch
+from safetensors.torch import load_file, save_file
+from safetensors import safe_open
+from utils import model_utils
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def convert_from_diffusers(prefix, weights_sd):
+ # convert from diffusers(?) to default LoRA
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
+
+ # note: Diffusers has no alpha, so alpha is set to rank
+ new_weights_sd = {}
+ lora_dims = {}
+ for key, weight in weights_sd.items():
+ diffusers_prefix, key_body = key.split(".", 1)
+ if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
+ logger.warning(f"unexpected key: {key} in diffusers format")
+ continue
+
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
+ new_weights_sd[new_key] = weight
+
+ lora_name = new_key.split(".")[0] # before first dot
+ if lora_name not in lora_dims and "lora_down" in new_key:
+ lora_dims[lora_name] = weight.shape[0]
+
+ # add alpha with rank
+ for lora_name, dim in lora_dims.items():
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
+
+ return new_weights_sd
+
+
+def convert_to_diffusers(prefix, weights_sd):
+ # convert from default LoRA to diffusers
+
+ # get alphas
+ lora_alphas = {}
+ for key, weight in weights_sd.items():
+ if key.startswith(prefix):
+ lora_name = key.split(".", 1)[0] # before first dot
+ if lora_name not in lora_alphas and "alpha" in key:
+ lora_alphas[lora_name] = weight
+
+ new_weights_sd = {}
+ for key, weight in weights_sd.items():
+ if key.startswith(prefix):
+ if "alpha" in key:
+ continue
+
+ lora_name = key.split(".", 1)[0] # before first dot
+
+ module_name = lora_name[len(prefix) :] # remove "lora_unet_"
+ module_name = module_name.replace("_", ".") # replace "_" with "."
+ if ".cross.attn." in module_name or ".self.attn." in module_name:
+ # Wan2.1 lora name to module name: ugly but works
+ module_name = module_name.replace("cross.attn", "cross_attn") # fix cross attn
+ module_name = module_name.replace("self.attn", "self_attn") # fix self attn
+ module_name = module_name.replace("k.img", "k_img") # fix k img
+ module_name = module_name.replace("v.img", "v_img") # fix v img
+ else:
+ # HunyuanVideo lora name to module name: ugly but works
+ module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
+ module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
+ module_name = module_name.replace("img.", "img_") # fix img
+ module_name = module_name.replace("txt.", "txt_") # fix txt
+ module_name = module_name.replace("attn.", "attn_") # fix attn
+
+ diffusers_prefix = "diffusion_model"
+ if "lora_down" in key:
+ new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
+ dim = weight.shape[0]
+ elif "lora_up" in key:
+ new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
+ dim = weight.shape[1]
+ else:
+ logger.warning(f"unexpected key: {key} in default LoRA format")
+ continue
+
+ # scale weight by alpha
+ if lora_name in lora_alphas:
+ # we scale both down and up, so scale is sqrt
+ scale = lora_alphas[lora_name] / dim
+ scale = scale.sqrt()
+ weight = weight * scale
+ else:
+ logger.warning(f"missing alpha for {lora_name}")
+
+ new_weights_sd[new_key] = weight
+
+ return new_weights_sd
+
+
+def convert(input_file, output_file, target_format):
+ logger.info(f"loading {input_file}")
+ weights_sd = load_file(input_file)
+ with safe_open(input_file, framework="pt") as f:
+ metadata = f.metadata()
+
+ logger.info(f"converting to {target_format}")
+ prefix = "lora_unet_"
+ if target_format == "default":
+ new_weights_sd = convert_from_diffusers(prefix, weights_sd)
+ metadata = metadata or {}
+ model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
+ elif target_format == "other":
+ new_weights_sd = convert_to_diffusers(prefix, weights_sd)
+ else:
+ raise ValueError(f"unknown target format: {target_format}")
+
+ logger.info(f"saving to {output_file}")
+ save_file(new_weights_sd, output_file, metadata=metadata)
+
+ logger.info("done")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
+ parser.add_argument("--input", type=str, required=True, help="input model file")
+ parser.add_argument("--output", type=str, required=True, help="output model file")
+ parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ convert(args.input, args.output, args.target)
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dataset/config_utils.py b/dataset/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..95eb38c961fac1b8792fe2a6a059314e0a4a47a3
--- /dev/null
+++ b/dataset/config_utils.py
@@ -0,0 +1,400 @@
+import argparse
+from dataclasses import (
+ asdict,
+ dataclass,
+)
+import functools
+import random
+from textwrap import dedent, indent
+import json
+from pathlib import Path
+
+# from toolz import curry
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import toml
+import voluptuous
+from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
+
+from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+@dataclass
+class BaseDatasetParams:
+ resolution: Tuple[int, int] = (960, 544)
+ enable_bucket: bool = False
+ bucket_no_upscale: bool = False
+ caption_extension: Optional[str] = None
+ batch_size: int = 1
+ num_repeats: int = 1
+ cache_directory: Optional[str] = None
+ debug_dataset: bool = False
+ architecture: str = "no_default" # short style like "hv" or "wan"
+
+
+@dataclass
+class ImageDatasetParams(BaseDatasetParams):
+ image_directory: Optional[str] = None
+ image_jsonl_file: Optional[str] = None
+ control_directory: Optional[str] = None
+
+ # FramePack dependent parameters
+ fp_latent_window_size: Optional[int] = 9
+ fp_1f_clean_indices: Optional[Sequence[int]] = None
+ fp_1f_target_index: Optional[int] = None
+ fp_1f_no_post: Optional[bool] = False
+
+
+@dataclass
+class VideoDatasetParams(BaseDatasetParams):
+ video_directory: Optional[str] = None
+ video_jsonl_file: Optional[str] = None
+ control_directory: Optional[str] = None
+ target_frames: Sequence[int] = (1,)
+ frame_extraction: Optional[str] = "head"
+ frame_stride: Optional[int] = 1
+ frame_sample: Optional[int] = 1
+ max_frames: Optional[int] = 129
+ source_fps: Optional[float] = None
+
+ # FramePack dependent parameters
+ fp_latent_window_size: Optional[int] = 9
+
+
+@dataclass
+class DatasetBlueprint:
+ is_image_dataset: bool
+ params: Union[ImageDatasetParams, VideoDatasetParams]
+
+
+@dataclass
+class DatasetGroupBlueprint:
+ datasets: Sequence[DatasetBlueprint]
+
+
+@dataclass
+class Blueprint:
+ dataset_group: DatasetGroupBlueprint
+
+
+class ConfigSanitizer:
+ # @curry
+ @staticmethod
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
+ Schema(ExactSequence([klass, klass]))(value)
+ return tuple(value)
+
+ # @curry
+ @staticmethod
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
+ try:
+ Schema(klass)(value)
+ return (value, value)
+ except:
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
+
+ # datasets schema
+ DATASET_ASCENDABLE_SCHEMA = {
+ "caption_extension": str,
+ "batch_size": int,
+ "num_repeats": int,
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
+ "enable_bucket": bool,
+ "bucket_no_upscale": bool,
+ }
+ IMAGE_DATASET_DISTINCT_SCHEMA = {
+ "image_directory": str,
+ "image_jsonl_file": str,
+ "cache_directory": str,
+ "control_directory": str,
+ "fp_latent_window_size": int,
+ "fp_1f_clean_indices": [int],
+ "fp_1f_target_index": int,
+ "fp_1f_no_post": bool,
+ }
+ VIDEO_DATASET_DISTINCT_SCHEMA = {
+ "video_directory": str,
+ "video_jsonl_file": str,
+ "control_directory": str,
+ "target_frames": [int],
+ "frame_extraction": str,
+ "frame_stride": int,
+ "frame_sample": int,
+ "max_frames": int,
+ "cache_directory": str,
+ "source_fps": float,
+ }
+
+ # options handled by argparse but not handled by user config
+ ARGPARSE_SPECIFIC_SCHEMA = {
+ "debug_dataset": bool,
+ }
+
+ def __init__(self) -> None:
+ self.image_dataset_schema = self.__merge_dict(
+ self.DATASET_ASCENDABLE_SCHEMA,
+ self.IMAGE_DATASET_DISTINCT_SCHEMA,
+ )
+ self.video_dataset_schema = self.__merge_dict(
+ self.DATASET_ASCENDABLE_SCHEMA,
+ self.VIDEO_DATASET_DISTINCT_SCHEMA,
+ )
+
+ def validate_flex_dataset(dataset_config: dict):
+ if "video_directory" in dataset_config or "video_jsonl_file" in dataset_config:
+ return Schema(self.video_dataset_schema)(dataset_config)
+ else:
+ return Schema(self.image_dataset_schema)(dataset_config)
+
+ self.dataset_schema = validate_flex_dataset
+
+ self.general_schema = self.__merge_dict(
+ self.DATASET_ASCENDABLE_SCHEMA,
+ )
+ self.user_config_validator = Schema(
+ {
+ "general": self.general_schema,
+ "datasets": [self.dataset_schema],
+ }
+ )
+ self.argparse_schema = self.__merge_dict(
+ self.ARGPARSE_SPECIFIC_SCHEMA,
+ )
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
+
+ def sanitize_user_config(self, user_config: dict) -> dict:
+ try:
+ return self.user_config_validator(user_config)
+ except MultipleInvalid:
+ # TODO: clarify the error message
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
+ raise
+
+ # NOTE: In nature, argument parser result is not needed to be sanitize
+ # However this will help us to detect program bug
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
+ try:
+ return self.argparse_config_validator(argparse_namespace)
+ except MultipleInvalid:
+ # XXX: this should be a bug
+ logger.error(
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
+ )
+ raise
+
+ # NOTE: value would be overwritten by latter dict if there is already the same key
+ @staticmethod
+ def __merge_dict(*dict_list: dict) -> dict:
+ merged = {}
+ for schema in dict_list:
+ # merged |= schema
+ for k, v in schema.items():
+ merged[k] = v
+ return merged
+
+
+class BlueprintGenerator:
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
+
+ def __init__(self, sanitizer: ConfigSanitizer):
+ self.sanitizer = sanitizer
+
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
+
+ argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
+ general_config = sanitized_user_config.get("general", {})
+
+ dataset_blueprints = []
+ for dataset_config in sanitized_user_config.get("datasets", []):
+ is_image_dataset = "image_directory" in dataset_config or "image_jsonl_file" in dataset_config
+ if is_image_dataset:
+ dataset_params_klass = ImageDatasetParams
+ else:
+ dataset_params_klass = VideoDatasetParams
+
+ params = self.generate_params_by_fallbacks(
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
+ )
+ dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
+
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
+
+ return Blueprint(dataset_group_blueprint)
+
+ @staticmethod
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
+ search_value = BlueprintGenerator.search_value
+ default_params = asdict(param_klass())
+ param_names = default_params.keys()
+
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
+
+ return param_klass(**params)
+
+ @staticmethod
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
+ for cand in fallbacks:
+ value = cand.get(key)
+ if value is not None:
+ return value
+
+ return default_value
+
+
+# if training is True, it will return a dataset group for training, otherwise for caching
+def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
+ datasets: List[Union[ImageDataset, VideoDataset]] = []
+
+ for dataset_blueprint in dataset_group_blueprint.datasets:
+ if dataset_blueprint.is_image_dataset:
+ dataset_klass = ImageDataset
+ else:
+ dataset_klass = VideoDataset
+
+ dataset = dataset_klass(**asdict(dataset_blueprint.params))
+ datasets.append(dataset)
+
+ # assertion
+ cache_directories = [dataset.cache_directory for dataset in datasets]
+ num_of_unique_cache_directories = len(set(cache_directories))
+ if num_of_unique_cache_directories != len(cache_directories):
+ raise ValueError(
+ "cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)"
+ + " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)"
+ )
+
+ # print info
+ info = ""
+ for i, dataset in enumerate(datasets):
+ is_image_dataset = isinstance(dataset, ImageDataset)
+ info += dedent(
+ f"""\
+ [Dataset {i}]
+ is_image_dataset: {is_image_dataset}
+ resolution: {dataset.resolution}
+ batch_size: {dataset.batch_size}
+ num_repeats: {dataset.num_repeats}
+ caption_extension: "{dataset.caption_extension}"
+ enable_bucket: {dataset.enable_bucket}
+ bucket_no_upscale: {dataset.bucket_no_upscale}
+ cache_directory: "{dataset.cache_directory}"
+ debug_dataset: {dataset.debug_dataset}
+ """
+ )
+
+ if is_image_dataset:
+ info += indent(
+ dedent(
+ f"""\
+ image_directory: "{dataset.image_directory}"
+ image_jsonl_file: "{dataset.image_jsonl_file}"
+ fp_latent_window_size: {dataset.fp_latent_window_size}
+ fp_1f_clean_indices: {dataset.fp_1f_clean_indices}
+ fp_1f_target_index: {dataset.fp_1f_target_index}
+ fp_1f_no_post: {dataset.fp_1f_no_post}
+ \n"""
+ ),
+ " ",
+ )
+ else:
+ info += indent(
+ dedent(
+ f"""\
+ video_directory: "{dataset.video_directory}"
+ video_jsonl_file: "{dataset.video_jsonl_file}"
+ control_directory: "{dataset.control_directory}"
+ target_frames: {dataset.target_frames}
+ frame_extraction: {dataset.frame_extraction}
+ frame_stride: {dataset.frame_stride}
+ frame_sample: {dataset.frame_sample}
+ max_frames: {dataset.max_frames}
+ source_fps: {dataset.source_fps}
+ \n"""
+ ),
+ " ",
+ )
+ logger.info(f"{info}")
+
+ # make buckets first because it determines the length of dataset
+ # and set the same seed for all datasets
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
+ for i, dataset in enumerate(datasets):
+ # logger.info(f"[Dataset {i}]")
+ dataset.set_seed(seed)
+ if training:
+ dataset.prepare_for_training()
+
+ return DatasetGroup(datasets)
+
+
+def load_user_config(file: str) -> dict:
+ file: Path = Path(file)
+ if not file.is_file():
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
+
+ if file.name.lower().endswith(".json"):
+ try:
+ with open(file, "r", encoding="utf-8") as f:
+ config = json.load(f)
+ except Exception:
+ logger.error(
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+ )
+ raise
+ elif file.name.lower().endswith(".toml"):
+ try:
+ config = toml.load(file)
+ except Exception:
+ logger.error(
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+ )
+ raise
+ else:
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
+
+ return config
+
+
+# for config test
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("dataset_config")
+ config_args, remain = parser.parse_known_args()
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--debug_dataset", action="store_true")
+ argparse_namespace = parser.parse_args(remain)
+
+ logger.info("[argparse_namespace]")
+ logger.info(f"{vars(argparse_namespace)}")
+
+ user_config = load_user_config(config_args.dataset_config)
+
+ logger.info("")
+ logger.info("[user_config]")
+ logger.info(f"{user_config}")
+
+ sanitizer = ConfigSanitizer()
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
+
+ logger.info("")
+ logger.info("[sanitized_user_config]")
+ logger.info(f"{sanitized_user_config}")
+
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
+
+ logger.info("")
+ logger.info("[blueprint]")
+ logger.info(f"{blueprint}")
+
+ dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
diff --git a/dataset/dataset_config.md b/dataset/dataset_config.md
new file mode 100644
index 0000000000000000000000000000000000000000..a6d61978aac81e34a0a3768b05f3673adea34521
--- /dev/null
+++ b/dataset/dataset_config.md
@@ -0,0 +1,538 @@
+> 📝 Click on the language section to expand / 言語をクリックして展開
+
+## Dataset Configuration
+
+Please create a TOML file for dataset configuration.
+
+Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
+
+The cache directory must be different for each dataset.
+
+Each video is extracted frame by frame without additional processing and used for training. It is recommended to use videos with a frame rate of 24fps for HunyuanVideo, 16fps for Wan2.1 and 30fps for FramePack. You can check the videos that will be trained using `--debug_mode video` when caching latent (see [here](/README.md#latent-caching)).
+
+日本語
+
+データセットの設定を行うためのTOMLファイルを作成してください。
+
+画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。
+
+キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。
+
+動画は追加のプロセスなしでフレームごとに抽出され、学習に用いられます。そのため、HunyuanVideoは24fps、Wan2.1は16fps、FramePackは30fpsのフレームレートの動画を使用することをお勧めします。latentキャッシュ時の`--debug_mode video`を使用すると、学習される動画を確認できます([こちら](/README.ja.md#latentの事前キャッシュ)を参照)。
+
+
+### Sample for Image Dataset with Caption Text Files
+
+```toml
+# resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
+# otherwise, the default values will be used for each item
+
+# general configurations
+[general]
+resolution = [960, 544]
+caption_extension = ".txt"
+batch_size = 1
+enable_bucket = true
+bucket_no_upscale = false
+
+[[datasets]]
+image_directory = "/path/to/image_dir"
+cache_directory = "/path/to/cache_directory"
+num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
+
+# other datasets can be added here. each dataset can have different configurations
+```
+
+`cache_directory` is optional, default is None to use the same directory as the image directory. However, we recommend to set the cache directory to avoid accidental sharing of the cache files between different datasets.
+
+`num_repeats` is also available. It is optional, default is 1 (no repeat). It repeats the images (or videos) that many times to expand the dataset. For example, if `num_repeats = 2` and there are 20 images in the dataset, each image will be duplicated twice (with the same caption) to have a total of 40 images. It is useful to balance the multiple datasets with different sizes.
+
+
+日本語
+
+`cache_directory` はオプションです。デフォルトは画像ディレクトリと同じディレクトリに設定されます。ただし、異なるデータセット間でキャッシュファイルが共有されるのを防ぐために、明示的に別のキャッシュディレクトリを設定することをお勧めします。
+
+`num_repeats` はオプションで、デフォルトは 1 です(繰り返しなし)。画像(や動画)を、その回数だけ単純に繰り返してデータセットを拡張します。たとえば`num_repeats = 2`としたとき、画像20枚のデータセットなら、各画像が2枚ずつ(同一のキャプションで)計40枚存在した場合と同じになります。異なるデータ数のデータセット間でバランスを取るために使用可能です。
+
+resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
+
+`[[datasets]]`以下を追加することで、他のデータセットを追加できます。各データセットには異なる設定を持てます。
+
+
+### Sample for Image Dataset with Metadata JSONL File
+
+```toml
+# resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
+# caption_extension is not required for metadata jsonl file
+# cache_directory is required for each dataset with metadata jsonl file
+
+# general configurations
+[general]
+resolution = [960, 544]
+batch_size = 1
+enable_bucket = true
+bucket_no_upscale = false
+
+[[datasets]]
+image_jsonl_file = "/path/to/metadata.jsonl"
+cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
+num_repeats = 1 # optional, default is 1. Same as above.
+
+# other datasets can be added here. each dataset can have different configurations
+```
+
+JSONL file format for metadata:
+
+```json
+{"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
+{"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
+```
+
+
+日本語
+
+resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
+
+metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
+
+キャプションによるデータセットと同様に、複数のデータセットを追加できます。各データセットには異なる設定を持てます。
+
+
+
+### Sample for Video Dataset with Caption Text Files
+
+```toml
+# Common parameters (resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)
+# can be set in either general or datasets sections
+# Video-specific parameters (target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)
+# must be set in each datasets section
+
+# general configurations
+[general]
+resolution = [960, 544]
+caption_extension = ".txt"
+batch_size = 1
+enable_bucket = true
+bucket_no_upscale = false
+
+[[datasets]]
+video_directory = "/path/to/video_dir"
+cache_directory = "/path/to/cache_directory" # recommended to set cache directory
+target_frames = [1, 25, 45]
+frame_extraction = "head"
+source_fps = 30.0 # optional, source fps for videos in the directory, decimal number
+
+[[datasets]]
+video_directory = "/path/to/video_dir2"
+cache_directory = "/path/to/cache_directory2" # recommended to set cache directory
+frame_extraction = "full"
+max_frames = 45
+
+# other datasets can be added here. each dataset can have different configurations
+```
+
+__In HunyuanVideo and Wan2.1, the number of `target_frames` must be "N\*4+1" (N=0,1,2,...).__ Otherwise, it will be truncated to the nearest "N*4+1".
+
+In FramePack, it is recommended to set `frame_extraction` to `full` and `max_frames` to a sufficiently large value, as it can handle longer videos. However, if the video is too long, an Out of Memory error may occur during VAE encoding. The videos in FramePack are trimmed to "N * latent_window_size * 4 + 1" frames (for example, 37, 73, 109... if `latent_window_size` is 9).
+
+If the `source_fps` is specified, the videos in the directory are considered to be at this frame rate, and some frames will be skipped to match the model's frame rate (24 for HunyuanVideo and 16 for Wan2.1). __The value must be a decimal number, for example, `30.0` instead of `30`.__ The skipping is done automatically and does not consider the content of the images. Please check if the converted data is correct using `--debug_mode video`.
+
+If `source_fps` is not specified (default), all frames of the video will be used regardless of the video's frame rate.
+
+
+日本語
+
+共通パラメータ(resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)は、generalまたはdatasetsのいずれかに設定できます。
+動画固有のパラメータ(target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)は、各datasetsセクションに設定する必要があります。
+
+__HunyuanVideoおよびWan2.1では、target_framesの数値は「N\*4+1」である必要があります。__ これ以外の値の場合は、最も近いN\*4+1の値に切り捨てられます。
+
+FramePackでも同様ですが、FramePackでは動画が長くても学習可能なため、 `frame_extraction`に`full` を指定し、`max_frames`を十分に大きな値に設定することをお勧めします。ただし、あまりにも長すぎるとVAEのencodeでOut of Memoryエラーが発生する可能性があります。FramePackの動画は、「N * latent_window_size * 4 + 1」フレームにトリミングされます(latent_window_sizeが9の場合、37、73、109……)。
+
+`source_fps`を指定した場合、ディレクトリ内の動画をこのフレームレートとみなして、モデルのフレームレートにあうようにいくつかのフレームをスキップします(HunyuanVideoは24、Wan2.1は16)。__小数点を含む数値で指定してください。__ 例:`30`ではなく`30.0`。スキップは機械的に行われ、画像の内容は考慮しません。変換後のデータが正しいか、`--debug_mode video`で確認してください。
+
+`source_fps`を指定しない場合、動画のフレームは(動画自体のフレームレートに関係なく)すべて使用されます。
+
+他の注意事項は画像データセットと同様です。
+
+
+### Sample for Video Dataset with Metadata JSONL File
+
+```toml
+# Common parameters (resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)
+# can be set in either general or datasets sections
+# Video-specific parameters (target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)
+# must be set in each datasets section
+
+# caption_extension is not required for metadata jsonl file
+# cache_directory is required for each dataset with metadata jsonl file
+
+# general configurations
+[general]
+resolution = [960, 544]
+batch_size = 1
+enable_bucket = true
+bucket_no_upscale = false
+
+[[datasets]]
+video_jsonl_file = "/path/to/metadata.jsonl"
+target_frames = [1, 25, 45]
+frame_extraction = "head"
+cache_directory = "/path/to/cache_directory_head"
+source_fps = 30.0 # optional, source fps for videos in the jsonl file
+# same metadata jsonl file can be used for multiple datasets
+[[datasets]]
+video_jsonl_file = "/path/to/metadata.jsonl"
+target_frames = [1]
+frame_stride = 10
+cache_directory = "/path/to/cache_directory_stride"
+
+# other datasets can be added here. each dataset can have different configurations
+```
+
+JSONL file format for metadata:
+
+```json
+{"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
+{"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
+```
+
+`video_path` can be a directory containing multiple images.
+
+
+日本語
+metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
+
+`video_path`は、複数の画像を含むディレクトリのパスでも構いません。
+
+他の注意事項は今までのデータセットと同様です。
+
+
+### frame_extraction Options
+
+- `head`: Extract the first N frames from the video.
+- `chunk`: Extract frames by splitting the video into chunks of N frames.
+- `slide`: Extract frames from the video with a stride of `frame_stride`.
+- `uniform`: Extract `frame_sample` samples uniformly from the video.
+- `full`: Extract all frames from the video.
+
+In the case of `full`, the entire video is used, but it is trimmed to "N*4+1" frames. It is also trimmed to the `max_frames` if it exceeds that value. To avoid Out of Memory errors, please set `max_frames`.
+
+The frame extraction methods other than `full` are recommended when the video contains repeated actions. `full` is recommended when each video represents a single complete motion.
+
+For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
+
+
+日本語
+
+- `head`: 動画から最初のNフレームを抽出します。
+- `chunk`: 動画をNフレームずつに分割してフレームを抽出します。
+- `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。
+- `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。
+- `full`: 動画から全てのフレームを抽出します。
+
+`full`の場合、各動画の全体を用いますが、「N*4+1」のフレーム数にトリミングされます。また`max_frames`を超える場合もその値にトリミングされます。Out of Memoryエラーを避けるために、`max_frames`を設定してください。
+
+`full`以外の抽出方法は、動画が特定の動作を繰り返している場合にお勧めします。`full`はそれぞれの動画がひとつの完結したモーションの場合にお勧めします。
+
+例えば、40フレームの動画を例とした抽出について、以下の図で説明します。
+
+
+```
+Original Video, 40 frames: x = frame, o = no frame
+oooooooooooooooooooooooooooooooooooooooo
+
+head, target_frames = [1, 13, 25] -> extract head frames:
+xooooooooooooooooooooooooooooooooooooooo
+xxxxxxxxxxxxxooooooooooooooooooooooooooo
+xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
+
+chunk, target_frames = [13, 25] -> extract frames by splitting into chunks, into 13 and 25 frames:
+xxxxxxxxxxxxxooooooooooooooooooooooooooo
+oooooooooooooxxxxxxxxxxxxxoooooooooooooo
+ooooooooooooooooooooooooooxxxxxxxxxxxxxo
+xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
+
+NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
+注: frame_extraction "chunk" を使用する場合、target_frames に 1 を含めないでください。全てのフレームが抽出されてしまいます。
+
+slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
+xooooooooooooooooooooooooooooooooooooooo
+ooooooooooxooooooooooooooooooooooooooooo
+ooooooooooooooooooooxooooooooooooooooooo
+ooooooooooooooooooooooooooooooxooooooooo
+xxxxxxxxxxxxxooooooooooooooooooooooooooo
+ooooooooooxxxxxxxxxxxxxooooooooooooooooo
+ooooooooooooooooooooxxxxxxxxxxxxxooooooo
+xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
+ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
+
+uniform, target_frames =[1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
+xooooooooooooooooooooooooooooooooooooooo
+oooooooooooooxoooooooooooooooooooooooooo
+oooooooooooooooooooooooooxoooooooooooooo
+ooooooooooooooooooooooooooooooooooooooox
+xxxxxxxxxxxxxooooooooooooooooooooooooooo
+oooooooooxxxxxxxxxxxxxoooooooooooooooooo
+ooooooooooooooooooxxxxxxxxxxxxxooooooooo
+oooooooooooooooooooooooooooxxxxxxxxxxxxx
+xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
+oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
+ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
+oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
+
+Three Original Videos, 20, 25, 35 frames: x = frame, o = no frame
+
+full, max_frames = 31 -> extract all frames (trimmed to the maximum length):
+video1: xxxxxxxxxxxxxxxxx (trimmed to 17 frames)
+video2: xxxxxxxxxxxxxxxxxxxxxxxxx (25 frames)
+video3: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx (trimmed to 31 frames)
+```
+
+### Sample for Image Dataset with Control Images
+
+The dataset with control images. This is used for training the one frame training for FramePack.
+
+The dataset configuration with caption text files is similar to the image dataset, but with an additional `control_directory` parameter.
+
+The control images are used from the `control_directory` with the same filename (or different extension) as the image, for example, `image_dir/image1.jpg` and `control_dir/image1.png`. The images in `image_directory` should be the target images (the images to be generated during inference, the changed images). The `control_directory` should contain the starting images for inference. The captions should be stored in `image_directory`.
+
+If multiple control images are specified, the filenames of the control images should be numbered (excluding the extension). For example, specify `image_dir/image1.jpg` and `control_dir/image1_0.png`, `control_dir/image1_1.png`. You can also specify the numbers with four digits, such as `image1_0000.png`, `image1_0001.png`.
+
+The metadata JSONL file format is the same as the image dataset, but with an additional `control_path` parameter.
+
+```json
+{"image_path": "/path/to/image1.jpg", "control_path": "/path/to/control1.png", "caption": "A caption for image1"}
+{"image_path": "/path/to/image2.jpg", "control_path": "/path/to/control2.png", "caption": "A caption for image2"}
+
+If multiple control images are specified, the attribute names should be `control_path_0`, `control_path_1`, etc.
+
+```json
+{"image_path": "/path/to/image1.jpg", "control_path_0": "/path/to/control1_0.png", "control_path_1": "/path/to/control1_1.png", "caption": "A caption for image1"}
+{"image_path": "/path/to/image2.jpg", "control_path_0": "/path/to/control2_0.png", "control_path_1": "/path/to/control2_1.png", "caption": "A caption for image2"}
+```
+
+The control images can also have an alpha channel. In this case, the alpha channel of the image is used as a mask for the latent.
+
+
+日本語
+
+制御画像を持つデータセットです。現時点ではFramePackの単一フレーム学習に使用します。
+
+キャプションファイルを用いる場合は`control_directory`を追加で指定してください。制御画像は、画像と同じファイル名(または拡張子のみが異なるファイル名)の、`control_directory`にある画像が使用されます(例:`image_dir/image1.jpg`と`control_dir/image1.png`)。`image_directory`の画像は学習対象の画像(推論時に生成する画像、変化後の画像)としてください。`control_directory`には推論時の開始画像を格納してください。キャプションは`image_directory`へ格納してください。
+
+複数枚の制御画像が指定可能です。この場合、制御画像のファイル名(拡張子を除く)へ数字を付与してください。例えば、`image_dir/image1.jpg`と`control_dir/image1_0.png`, `control_dir/image1_1.png`のように指定します。`image1_0000.png`, `image1_0001.png`のように数字を4桁で指定することもできます。
+
+メタデータJSONLファイルを使用する場合は、`control_path`を追加してください。複数枚の制御画像を指定する場合は、`control_path_0`, `control_path_1`のように数字を付与してください。
+
+制御画像はアルファチャンネルを持つこともできます。この場合、画像のアルファチャンネルはlatentへのマスクとして使用されます。
+
+
+
+### Sample for Video Dataset with Control Images
+
+The dataset with control videos is used for training ControlNet models.
+
+The dataset configuration with caption text files is similar to the video dataset, but with an additional `control_directory` parameter.
+
+The control video for a video is used from the `control_directory` with the same filename (or different extension) as the video, for example, `video_dir/video1.mp4` and `control_dir/video1.mp4` or `control_dir/video1.mov`. The control video can also be a directory without an extension, for example, `video_dir/video1.mp4` and `control_dir/video1`.
+
+```toml
+[[datasets]]
+video_directory = "/path/to/video_dir"
+control_directory = "/path/to/control_dir" # required for dataset with control videos
+cache_directory = "/path/to/cache_directory" # recommended to set cache directory
+target_frames = [1, 25, 45]
+frame_extraction = "head"
+```
+
+The dataset configuration with metadata JSONL file is same as the video dataset, but metadata JSONL file must include the control video paths. The control video path can be a directory containing multiple images.
+
+```json
+{"video_path": "/path/to/video1.mp4", "control_path": "/path/to/control1.mp4", "caption": "A caption for video1"}
+{"video_path": "/path/to/video2.mp4", "control_path": "/path/to/control2.mp4", "caption": "A caption for video2"}
+```
+
+
+日本語
+
+制御動画を持つデータセットです。ControlNetモデルの学習に使用します。
+
+キャプションを用いる場合のデータセット設定は動画データセットと似ていますが、`control_directory`パラメータが追加されています。上にある例を参照してください。ある動画に対する制御用動画として、動画と同じファイル名(または拡張子のみが異なるファイル名)の、`control_directory`にある動画が使用されます(例:`video_dir/video1.mp4`と`control_dir/video1.mp4`または`control_dir/video1.mov`)。また、拡張子なしのディレクトリ内の、複数枚の画像を制御用動画として使用することもできます(例:`video_dir/video1.mp4`と`control_dir/video1`)。
+
+データセット設定でメタデータJSONLファイルを使用する場合は、動画と制御用動画のパスを含める必要があります。制御用動画のパスは、複数枚の画像を含むディレクトリのパスでも構いません。
+
+
+
+## Architecture-specific Settings / アーキテクチャ固有の設定
+
+The dataset configuration is shared across all architectures. However, some architectures may require additional settings or have specific requirements for the dataset.
+
+### FramePack
+
+For FramePack, you can set the latent window size for training. It is recommended to set it to 9 for FramePack training. The default value is 9, so you can usually omit this setting.
+
+```toml
+[[datasets]]
+fp_latent_window_size = 9
+```
+
+
+日本語
+
+学習時のlatent window sizeを指定できます。FramePackの学習においては、9を指定することを推奨します。省略時は9が使用されますので、通常は省略して構いません。
+
+
+
+### FramePack One Frame Training
+
+For the default one frame training of FramePack, you need to set the following parameters in the dataset configuration:
+
+```toml
+[[datasets]]
+fp_1f_clean_indices = [0]
+fp_1f_target_index = 9
+fp_1f_no_post = false
+```
+
+**Advanced Settings:**
+
+**Note that these parameters are still experimental, and the optimal values are not yet known.** The parameters may also change in the future.
+
+`fp_1f_clean_indices` sets the `clean_indices` value passed to the FramePack model. You can specify multiple indices. `fp_1f_target_index` sets the index of the frame to be trained (generated). `fp_1f_no_post` sets whether to add a zero value as `clean_latent_post`, default is `false` (add zero value).
+
+The number of control images should match the number of indices specified in `fp_1f_clean_indices`.
+
+The default values mean that the first image (control image) is at index `0`, and the target image (the changed image) is at index `9`.
+
+For training with 1f-mc, set `fp_1f_clean_indices` to `[0, 1]` and `fp_1f_target_index` to `9` (or another value). This allows you to use multiple control images to train a single generated image. The control images will be two in this case.
+
+```toml
+[[datasets]]
+fp_1f_clean_indices = [0, 1]
+fp_1f_target_index = 9
+fp_1f_no_post = false
+```
+
+For training with kisekaeichi, set `fp_1f_clean_indices` to `[0, 10]` and `fp_1f_target_index` to `1` (or another value). This allows you to use the starting image (the image just before the generation section) and the image following the generation section (equivalent to `clean_latent_post`) to train the first image of the generated video. The control images will be two in this case. `fp_1f_no_post` should be set to `true`.
+
+```toml
+[[datasets]]
+fp_1f_clean_indices = [0, 10]
+fp_1f_target_index = 1
+fp_1f_no_post = true
+```
+
+With `fp_1f_clean_indices` and `fp_1f_target_index`, you can specify any number of control images and any index of the target image for training.
+
+If you set `fp_1f_no_post` to `false`, the `clean_latent_post_index` will be `1 + fp1_latent_window_size`.
+
+You can also set the `no_2x` and `no_4x` options for cache scripts to disable the clean latents 2x and 4x.
+
+The 2x indices are `1 + fp1_latent_window_size + 1` for two indices (usually `11, 12`), and the 4x indices are `1 + fp1_latent_window_size + 1 + 2` for sixteen indices (usually `13, 14, ..., 28`), regardless of `fp_1f_no_post` and `no_2x`, `no_4x` settings.
+
+
+日本語
+
+※ **以下のパラメータは研究中で最適値はまだ不明です。** またパラメータ自体も変更される可能性があります。
+
+デフォルトの1フレーム学習を行う場合、`fp_1f_clean_indices`に`[0]`を、`fp_1f_target_index`に`9`(または5から15程度の値)を、`no_post`に`false`を設定してください。(記述例は英語版ドキュメントを参照、以降同じ。)
+
+**より高度な設定:**
+
+`fp_1f_clean_indices`は、FramePackモデルに渡される `clean_indices` の値を設定します。複数指定が可能です。`fp_1f_target_index`は、学習(生成)対象のフレームのインデックスを設定します。`fp_1f_no_post`は、`clean_latent_post` をゼロ値で追加するかどうかを設定します(デフォルトは`false`で、ゼロ値で追加します)。
+
+制御画像の枚数は`fp_1f_clean_indices`に指定したインデックスの数とあわせてください。
+
+デフォルトの1フレーム学習では、開始画像(制御画像)1枚をインデックス`0`、生成対象の画像(変化後の画像)をインデックス`9`に設定しています。
+
+1f-mcの学習を行う場合は、`fp_1f_clean_indices`に `[0, 1]`を、`fp_1f_target_index`に`9`を設定してください。これにより動画の先頭の2枚の制御画像を使用して、後続の1枚の生成画像を学習します。制御画像は2枚になります。
+
+kisekaeichiの学習を行う場合は、`fp_1f_clean_indices`に `[0, 10]`を、`fp_1f_target_index`に`1`(または他の値)を設定してください。これは、開始画像(生成セクションの直前の画像)(`clean_latent_pre`に相当)と、生成セクションに続く1枚の画像(`clean_latent_post`に相当)を使用して、生成動画の先頭の画像(`target_index=1`)を学習します。制御画像は2枚になります。`f1_1f_no_post`は`true`に設定してください。
+
+`fp_1f_clean_indices`と`fp_1f_target_index`を応用することで、任意の枚数の制御画像を、任意のインデックスを指定して学習することが可能です。
+
+`fp_1f_no_post`を`false`に設定すると、`clean_latent_post_index`は `1 + fp1_latent_window_size` になります。
+
+推論時の `no_2x`、`no_4x`に対応する設定は、キャッシュスクリプトの引数で行えます。なお、2xのindexは `1 + fp1_latent_window_size + 1` からの2個(通常は`11, 12`)、4xのindexは `1 + fp1_latent_window_size + 1 + 2` からの16個になります(通常は`13, 14, ..., 28`)です。これらの値は`fp_1f_no_post`や`no_2x`, `no_4x`の設定に関わらず、常に同じです。
+
+
+
+## Specifications
+
+```toml
+# general configurations
+[general]
+resolution = [960, 544] # optional, [W, H], default is [960, 544]. This is the default resolution for all datasets
+caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
+batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
+num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
+enable_bucket = true # optional, default is false. Enable bucketing for datasets
+bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
+
+### Image Dataset
+
+# sample image dataset with caption text files
+[[datasets]]
+image_directory = "/path/to/image_dir"
+caption_extension = ".txt" # required for caption text files, if general caption extension is not set
+resolution = [960, 544] # required if general resolution is not set
+batch_size = 4 # optional, overwrite the default batch size
+num_repeats = 1 # optional, overwrite the default num_repeats
+enable_bucket = false # optional, overwrite the default bucketing setting
+bucket_no_upscale = true # optional, overwrite the default bucketing setting
+cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
+control_directory = "/path/to/control_dir" # optional, required for dataset with control images
+
+# sample image dataset with metadata **jsonl** file
+[[datasets]]
+image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
+resolution = [960, 544] # required if general resolution is not set
+cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
+# caption_extension is not required for metadata jsonl file
+# batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
+
+### Video Dataset
+
+# sample video dataset with caption text files
+[[datasets]]
+video_directory = "/path/to/video_dir"
+caption_extension = ".txt" # required for caption text files, if general caption extension is not set
+resolution = [960, 544] # required if general resolution is not set
+
+control_directory = "/path/to/control_dir" # optional, required for dataset with control images
+
+# following configurations must be set in each [[datasets]] section for video datasets
+
+target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
+
+# NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
+
+frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
+frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
+frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
+max_frames = 129 # optional, default is 129. Maximum number of frames to extract, available for "full" frame extraction
+# batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
+
+# sample video dataset with metadata jsonl file
+[[datasets]]
+video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
+
+target_frames = [1, 79]
+
+cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
+# frame_extraction, frame_stride, frame_sample, max_frames are also available for metadata jsonl file
+```
+
+
+
+The metadata with .json file will be supported in the near future.
+
+
+
diff --git a/dataset/image_video_dataset.py b/dataset/image_video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e16de7fc0f5099afca7757dea3af2a84ee96fa8
--- /dev/null
+++ b/dataset/image_video_dataset.py
@@ -0,0 +1,1889 @@
+from concurrent.futures import ThreadPoolExecutor
+import glob
+import json
+import math
+import os
+import random
+import time
+from typing import Any, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from safetensors.torch import save_file, load_file
+from safetensors import safe_open
+from PIL import Image
+import cv2
+import av
+
+from utils import safetensors_utils
+from utils.model_utils import dtype_to_str
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
+
+try:
+ import pillow_avif
+
+ IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
+except:
+ pass
+
+# JPEG-XL on Linux
+try:
+ from jxlpy import JXLImagePlugin
+
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
+except:
+ pass
+
+# JPEG-XL on Windows
+try:
+ import pillow_jxl
+
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
+except:
+ pass
+
+VIDEO_EXTENSIONS = [
+ ".mp4",
+ ".webm",
+ ".avi",
+ ".mkv",
+ ".mov",
+ ".flv",
+ ".wmv",
+ ".m4v",
+ ".mpg",
+ ".mpeg",
+ ".MP4",
+ ".WEBM",
+ ".AVI",
+ ".MKV",
+ ".MOV",
+ ".FLV",
+ ".WMV",
+ ".M4V",
+ ".MPG",
+ ".MPEG",
+] # some of them are not tested
+
+ARCHITECTURE_HUNYUAN_VIDEO = "hv"
+ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
+ARCHITECTURE_WAN = "wan"
+ARCHITECTURE_WAN_FULL = "wan"
+ARCHITECTURE_FRAMEPACK = "fp"
+ARCHITECTURE_FRAMEPACK_FULL = "framepack"
+
+
+def glob_images(directory, base="*"):
+ img_paths = []
+ for ext in IMAGE_EXTENSIONS:
+ if base == "*":
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
+ else:
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
+ img_paths = list(set(img_paths)) # remove duplicates
+ img_paths.sort()
+ return img_paths
+
+
+def glob_videos(directory, base="*"):
+ video_paths = []
+ for ext in VIDEO_EXTENSIONS:
+ if base == "*":
+ video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
+ else:
+ video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
+ video_paths = list(set(video_paths)) # remove duplicates
+ video_paths.sort()
+ return video_paths
+
+
+def divisible_by(num: int, divisor: int) -> int:
+ return num - num % divisor
+
+
+def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
+ """
+ Resize the image to the bucket resolution.
+
+ bucket_reso: **(width, height)**
+ """
+ is_pil_image = isinstance(image, Image.Image)
+ if is_pil_image:
+ image_width, image_height = image.size
+ else:
+ image_height, image_width = image.shape[:2]
+
+ if bucket_reso == (image_width, image_height):
+ return np.array(image) if is_pil_image else image
+
+ bucket_width, bucket_height = bucket_reso
+
+ # resize the image to the bucket resolution to match the short side
+ scale_width = bucket_width / image_width
+ scale_height = bucket_height / image_height
+ scale = max(scale_width, scale_height)
+ image_width = int(image_width * scale + 0.5)
+ image_height = int(image_height * scale + 0.5)
+
+ if scale > 1:
+ image = Image.fromarray(image) if not is_pil_image else image
+ image = image.resize((image_width, image_height), Image.LANCZOS)
+ image = np.array(image)
+ else:
+ image = np.array(image) if is_pil_image else image
+ image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
+
+ # crop the image to the bucket resolution
+ crop_left = (image_width - bucket_width) // 2
+ crop_top = (image_height - bucket_height) // 2
+ image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
+ return image
+
+
+class ItemInfo:
+ def __init__(
+ self,
+ item_key: str,
+ caption: str,
+ original_size: tuple[int, int],
+ bucket_size: Optional[tuple[Any]] = None,
+ frame_count: Optional[int] = None,
+ content: Optional[np.ndarray] = None,
+ latent_cache_path: Optional[str] = None,
+ ) -> None:
+ self.item_key = item_key
+ self.caption = caption
+ self.original_size = original_size
+ self.bucket_size = bucket_size
+ self.frame_count = frame_count
+ self.content = content
+ self.latent_cache_path = latent_cache_path
+ self.text_encoder_output_cache_path: Optional[str] = None
+
+ # np.ndarray for video, list[np.ndarray] for image with multiple controls
+ self.control_content: Optional[Union[np.ndarray, list[np.ndarray]]] = None
+
+ # FramePack architecture specific
+ self.fp_latent_window_size: Optional[int] = None
+ self.fp_1f_clean_indices: Optional[list[int]] = None # indices of clean latents for 1f
+ self.fp_1f_target_index: Optional[int] = None # target index for 1f clean latents
+ self.fp_1f_no_post: Optional[bool] = None # whether to add zero values as clean latent post
+
+ def __str__(self) -> str:
+ return (
+ f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
+ + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
+ + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path}, content={self.content.shape if self.content is not None else None})"
+ )
+
+
+# We use simple if-else approach to support multiple architectures.
+# Maybe we can use a plugin system in the future.
+
+# the keys of the dict are `_FxHxW_` for latents
+# and `_` for other tensors
+
+
+def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
+ """HunyuanVideo architecture only. HunyuanVideo doesn't support I2V and control latents"""
+ assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
+
+ _, F, H, W = latent.shape
+ dtype_str = dtype_to_str(latent.dtype)
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
+
+ save_latent_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
+
+
+def save_latent_cache_wan(
+ item_info: ItemInfo,
+ latent: torch.Tensor,
+ clip_embed: Optional[torch.Tensor],
+ image_latent: Optional[torch.Tensor],
+ control_latent: Optional[torch.Tensor],
+):
+ """Wan architecture only"""
+ assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
+
+ _, F, H, W = latent.shape
+ dtype_str = dtype_to_str(latent.dtype)
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
+
+ if clip_embed is not None:
+ sd[f"clip_{dtype_str}"] = clip_embed.detach().cpu()
+
+ if image_latent is not None:
+ sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()
+
+ if control_latent is not None:
+ sd[f"latents_control_{F}x{H}x{W}_{dtype_str}"] = control_latent.detach().cpu()
+
+ save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
+
+
+def save_latent_cache_framepack(
+ item_info: ItemInfo,
+ latent: torch.Tensor,
+ latent_indices: torch.Tensor,
+ clean_latents: torch.Tensor,
+ clean_latent_indices: torch.Tensor,
+ clean_latents_2x: torch.Tensor,
+ clean_latent_2x_indices: torch.Tensor,
+ clean_latents_4x: torch.Tensor,
+ clean_latent_4x_indices: torch.Tensor,
+ image_embeddings: torch.Tensor,
+):
+ """FramePack architecture only"""
+ assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
+
+ _, F, H, W = latent.shape
+ dtype_str = dtype_to_str(latent.dtype)
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu().contiguous()}
+
+ # `latents_xxx` must have {F, H, W} suffix
+ indices_dtype_str = dtype_to_str(latent_indices.dtype)
+ sd[f"image_embeddings_{dtype_str}"] = image_embeddings.detach().cpu() # image embeddings dtype is same as latents dtype
+ sd[f"latent_indices_{indices_dtype_str}"] = latent_indices.detach().cpu()
+ sd[f"clean_latent_indices_{indices_dtype_str}"] = clean_latent_indices.detach().cpu()
+ sd[f"latents_clean_{F}x{H}x{W}_{dtype_str}"] = clean_latents.detach().cpu().contiguous()
+ if clean_latent_2x_indices is not None:
+ sd[f"clean_latent_2x_indices_{indices_dtype_str}"] = clean_latent_2x_indices.detach().cpu()
+ if clean_latents_2x is not None:
+ sd[f"latents_clean_2x_{F}x{H}x{W}_{dtype_str}"] = clean_latents_2x.detach().cpu().contiguous()
+ if clean_latent_4x_indices is not None:
+ sd[f"clean_latent_4x_indices_{indices_dtype_str}"] = clean_latent_4x_indices.detach().cpu()
+ if clean_latents_4x is not None:
+ sd[f"latents_clean_4x_{F}x{H}x{W}_{dtype_str}"] = clean_latents_4x.detach().cpu().contiguous()
+
+ # for key, value in sd.items():
+ # print(f"{key}: {value.shape}")
+ save_latent_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
+
+
+def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
+ metadata = {
+ "architecture": arch_fullname,
+ "width": f"{item_info.original_size[0]}",
+ "height": f"{item_info.original_size[1]}",
+ "format_version": "1.0.1",
+ }
+ if item_info.frame_count is not None:
+ metadata["frame_count"] = f"{item_info.frame_count}"
+
+ for key, value in sd.items():
+ # NaN check and show warning, replace NaN with 0
+ if torch.isnan(value).any():
+ logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
+ value[torch.isnan(value)] = 0
+
+ latent_dir = os.path.dirname(item_info.latent_cache_path)
+ os.makedirs(latent_dir, exist_ok=True)
+
+ save_file(sd, item_info.latent_cache_path, metadata=metadata)
+
+
+def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
+ """HunyuanVideo architecture only"""
+ assert (
+ embed.dim() == 1 or embed.dim() == 2
+ ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
+ assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
+
+ sd = {}
+ dtype_str = dtype_to_str(embed.dtype)
+ text_encoder_type = "llm" if is_llm else "clipL"
+ sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
+ if mask is not None:
+ sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
+
+ save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
+
+
+def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor):
+ """Wan architecture only. Wan2.1 only has a single text encoder"""
+
+ sd = {}
+ dtype_str = dtype_to_str(embed.dtype)
+ text_encoder_type = "t5"
+ sd[f"varlen_{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
+
+ save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
+
+
+def save_text_encoder_output_cache_framepack(
+ item_info: ItemInfo, llama_vec: torch.Tensor, llama_attention_mask: torch.Tensor, clip_l_pooler: torch.Tensor
+):
+ """FramePack architecture only."""
+ sd = {}
+ dtype_str = dtype_to_str(llama_vec.dtype)
+ sd[f"llama_vec_{dtype_str}"] = llama_vec.detach().cpu()
+ sd[f"llama_attention_mask"] = llama_attention_mask.detach().cpu()
+ dtype_str = dtype_to_str(clip_l_pooler.dtype)
+ sd[f"clip_l_pooler_{dtype_str}"] = clip_l_pooler.detach().cpu()
+
+ save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
+
+
+def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
+ for key, value in sd.items():
+ # NaN check and show warning, replace NaN with 0
+ if torch.isnan(value).any():
+ logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
+ value[torch.isnan(value)] = 0
+
+ metadata = {
+ "architecture": arch_fullname,
+ "caption1": item_info.caption,
+ "format_version": "1.0.1",
+ }
+
+ if os.path.exists(item_info.text_encoder_output_cache_path):
+ # load existing cache and update metadata
+ with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
+ existing_metadata = f.metadata()
+ for key in f.keys():
+ if key not in sd: # avoid overwriting by existing cache, we keep the new one
+ sd[key] = f.get_tensor(key)
+
+ assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
+ if existing_metadata["caption1"] != metadata["caption1"]:
+ logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
+ # TODO verify format_version
+
+ existing_metadata.pop("caption1", None)
+ existing_metadata.pop("format_version", None)
+ metadata.update(existing_metadata) # copy existing metadata except caption and format_version
+ else:
+ text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
+ os.makedirs(text_encoder_output_dir, exist_ok=True)
+
+ safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
+
+
+class BucketSelector:
+ RESOLUTION_STEPS_HUNYUAN = 16
+ RESOLUTION_STEPS_WAN = 16
+ RESOLUTION_STEPS_FRAMEPACK = 16
+
+ def __init__(
+ self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
+ ):
+ self.resolution = resolution
+ self.bucket_area = resolution[0] * resolution[1]
+ self.architecture = architecture
+
+ if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
+ elif self.architecture == ARCHITECTURE_WAN:
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
+ elif self.architecture == ARCHITECTURE_FRAMEPACK:
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_FRAMEPACK
+ else:
+ raise ValueError(f"Invalid architecture: {self.architecture}")
+
+ if not enable_bucket:
+ # only define one bucket
+ self.bucket_resolutions = [resolution]
+ self.no_upscale = False
+ else:
+ # prepare bucket resolution
+ self.no_upscale = no_upscale
+ sqrt_size = int(math.sqrt(self.bucket_area))
+ min_size = divisible_by(sqrt_size // 2, self.reso_steps)
+ self.bucket_resolutions = []
+ for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
+ h = divisible_by(self.bucket_area // w, self.reso_steps)
+ self.bucket_resolutions.append((w, h))
+ self.bucket_resolutions.append((h, w))
+
+ self.bucket_resolutions = list(set(self.bucket_resolutions))
+ self.bucket_resolutions.sort()
+
+ # calculate aspect ratio to find the nearest resolution
+ self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
+
+ def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
+ """
+ return the bucket resolution for the given image size, (width, height)
+ """
+ area = image_size[0] * image_size[1]
+ if self.no_upscale and area <= self.bucket_area:
+ w, h = image_size
+ w = divisible_by(w, self.reso_steps)
+ h = divisible_by(h, self.reso_steps)
+ return w, h
+
+ aspect_ratio = image_size[0] / image_size[1]
+ ar_errors = self.aspect_ratios - aspect_ratio
+ bucket_id = np.abs(ar_errors).argmin()
+ return self.bucket_resolutions[bucket_id]
+
+
+def load_video(
+ video_path: str,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+ bucket_reso: Optional[tuple[int, int]] = None,
+ source_fps: Optional[float] = None,
+ target_fps: Optional[float] = None,
+) -> list[np.ndarray]:
+ """
+ bucket_reso: if given, resize the video to the bucket resolution, (width, height)
+ """
+ if source_fps is None or target_fps is None:
+ if os.path.isfile(video_path):
+ container = av.open(video_path)
+ video = []
+ for i, frame in enumerate(container.decode(video=0)):
+ if start_frame is not None and i < start_frame:
+ continue
+ if end_frame is not None and i >= end_frame:
+ break
+ frame = frame.to_image()
+
+ if bucket_selector is not None and bucket_reso is None:
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size) # calc resolution from first frame
+
+ if bucket_reso is not None:
+ frame = resize_image_to_bucket(frame, bucket_reso)
+ else:
+ frame = np.array(frame)
+
+ video.append(frame)
+ container.close()
+ else:
+ # load images in the directory
+ image_files = glob_images(video_path)
+ image_files.sort()
+ video = []
+ for i in range(len(image_files)):
+ if start_frame is not None and i < start_frame:
+ continue
+ if end_frame is not None and i >= end_frame:
+ break
+
+ image_file = image_files[i]
+ image = Image.open(image_file).convert("RGB")
+
+ if bucket_selector is not None and bucket_reso is None:
+ bucket_reso = bucket_selector.get_bucket_resolution(image.size) # calc resolution from first frame
+ image = np.array(image)
+ if bucket_reso is not None:
+ image = resize_image_to_bucket(image, bucket_reso)
+
+ video.append(image)
+ else:
+ # drop frames to match the target fps TODO commonize this code with the above if this works
+ frame_index_delta = target_fps / source_fps # example: 16 / 30 = 0.5333
+ if os.path.isfile(video_path):
+ container = av.open(video_path)
+ video = []
+ frame_index_with_fraction = 0.0
+ previous_frame_index = -1
+ for i, frame in enumerate(container.decode(video=0)):
+ target_frame_index = int(frame_index_with_fraction)
+ frame_index_with_fraction += frame_index_delta
+
+ if target_frame_index == previous_frame_index: # drop this frame
+ continue
+
+ # accept this frame
+ previous_frame_index = target_frame_index
+
+ if start_frame is not None and target_frame_index < start_frame:
+ continue
+ if end_frame is not None and target_frame_index >= end_frame:
+ break
+ frame = frame.to_image()
+
+ if bucket_selector is not None and bucket_reso is None:
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size) # calc resolution from first frame
+
+ if bucket_reso is not None:
+ frame = resize_image_to_bucket(frame, bucket_reso)
+ else:
+ frame = np.array(frame)
+
+ video.append(frame)
+ container.close()
+ else:
+ # load images in the directory
+ image_files = glob_images(video_path)
+ image_files.sort()
+ video = []
+ frame_index_with_fraction = 0.0
+ previous_frame_index = -1
+ for i in range(len(image_files)):
+ target_frame_index = int(frame_index_with_fraction)
+ frame_index_with_fraction += frame_index_delta
+
+ if target_frame_index == previous_frame_index: # drop this frame
+ continue
+
+ # accept this frame
+ previous_frame_index = target_frame_index
+
+ if start_frame is not None and target_frame_index < start_frame:
+ continue
+ if end_frame is not None and target_frame_index >= end_frame:
+ break
+
+ image_file = image_files[i]
+ image = Image.open(image_file).convert("RGB")
+
+ if bucket_selector is not None and bucket_reso is None:
+ bucket_reso = bucket_selector.get_bucket_resolution(image.size) # calc resolution from first frame
+ image = np.array(image)
+ if bucket_reso is not None:
+ image = resize_image_to_bucket(image, bucket_reso)
+
+ video.append(image)
+
+ return video
+
+
+class BucketBatchManager:
+
+ def __init__(self, bucketed_item_info: dict[tuple[Any], list[ItemInfo]], batch_size: int):
+ self.batch_size = batch_size
+ self.buckets = bucketed_item_info
+ self.bucket_resos = list(self.buckets.keys())
+ self.bucket_resos.sort()
+
+ # indices for enumerating batches. each batch is reso + batch_idx. reso is (width, height) or (width, height, frames)
+ self.bucket_batch_indices: list[tuple[tuple[Any], int]] = []
+ for bucket_reso in self.bucket_resos:
+ bucket = self.buckets[bucket_reso]
+ num_batches = math.ceil(len(bucket) / self.batch_size)
+ for i in range(num_batches):
+ self.bucket_batch_indices.append((bucket_reso, i))
+
+ # do no shuffle here to avoid multiple datasets have different order
+ # self.shuffle()
+
+ def show_bucket_info(self):
+ for bucket_reso in self.bucket_resos:
+ bucket = self.buckets[bucket_reso]
+ logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
+
+ logger.info(f"total batches: {len(self)}")
+
+ def shuffle(self):
+ # shuffle each bucket
+ for bucket in self.buckets.values():
+ random.shuffle(bucket)
+
+ # shuffle the order of batches
+ random.shuffle(self.bucket_batch_indices)
+
+ def __len__(self):
+ return len(self.bucket_batch_indices)
+
+ def __getitem__(self, idx):
+ bucket_reso, batch_idx = self.bucket_batch_indices[idx]
+ bucket = self.buckets[bucket_reso]
+ start = batch_idx * self.batch_size
+ end = min(start + self.batch_size, len(bucket))
+
+ batch_tensor_data = {}
+ varlen_keys = set()
+ for item_info in bucket[start:end]:
+ sd_latent = load_file(item_info.latent_cache_path)
+ sd_te = load_file(item_info.text_encoder_output_cache_path)
+ sd = {**sd_latent, **sd_te}
+
+ # TODO refactor this
+ for key in sd.keys():
+ is_varlen_key = key.startswith("varlen_") # varlen keys are not stacked
+ content_key = key
+
+ if is_varlen_key:
+ content_key = content_key.replace("varlen_", "")
+
+ if content_key.endswith("_mask"):
+ pass
+ else:
+ content_key = content_key.rsplit("_", 1)[0] # remove dtype
+ if content_key.startswith("latents_"):
+ content_key = content_key.rsplit("_", 1)[0] # remove FxHxW
+
+ if content_key not in batch_tensor_data:
+ batch_tensor_data[content_key] = []
+ batch_tensor_data[content_key].append(sd[key])
+
+ if is_varlen_key:
+ varlen_keys.add(content_key)
+
+ for key in batch_tensor_data.keys():
+ if key not in varlen_keys:
+ batch_tensor_data[key] = torch.stack(batch_tensor_data[key])
+
+ return batch_tensor_data
+
+
+class ContentDatasource:
+ def __init__(self):
+ self.caption_only = False # set to True to only fetch caption for Text Encoder caching
+ self.has_control = False
+
+ def set_caption_only(self, caption_only: bool):
+ self.caption_only = caption_only
+
+ def is_indexable(self):
+ return False
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ """
+ Returns caption. May not be called if is_indexable() returns False.
+ """
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def __iter__(self):
+ raise NotImplementedError
+
+ def __next__(self):
+ raise NotImplementedError
+
+
+class ImageDatasource(ContentDatasource):
+ def __init__(self):
+ super().__init__()
+
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
+ """
+ Returns image data as a tuple of image path, image, and caption for the given index.
+ Key must be unique and valid as a file name.
+ May not be called if is_indexable() returns False.
+ """
+ raise NotImplementedError
+
+
+class ImageDirectoryDatasource(ImageDatasource):
+ def __init__(
+ self,
+ image_directory: str,
+ caption_extension: Optional[str] = None,
+ control_directory: Optional[str] = None,
+ control_count_per_image: int = 1,
+ ):
+ super().__init__()
+ self.image_directory = image_directory
+ self.caption_extension = caption_extension
+ self.control_directory = control_directory
+ self.control_count_per_image = control_count_per_image
+ self.current_idx = 0
+
+ # glob images
+ logger.info(f"glob images in {self.image_directory}")
+ self.image_paths = glob_images(self.image_directory)
+ logger.info(f"found {len(self.image_paths)} images")
+
+ # glob control images if specified
+ if self.control_directory is not None:
+ logger.info(f"glob control images in {self.control_directory}")
+ self.has_control = True
+ self.control_paths = {}
+ for image_path in self.image_paths:
+ image_basename = os.path.basename(image_path)
+ image_basename_no_ext = os.path.splitext(image_basename)[0]
+ potential_paths = glob.glob(os.path.join(self.control_directory, os.path.splitext(image_basename)[0] + "*.*"))
+ if potential_paths:
+ # sort by the digits (`_0000`) suffix, prefer the one without the suffix
+ def sort_key(path):
+ basename = os.path.basename(path)
+ basename_no_ext = os.path.splitext(basename)[0]
+ if image_basename_no_ext == basename_no_ext: # prefer the one without suffix
+ return 0
+ digits_suffix = basename_no_ext.rsplit("_", 1)[-1]
+ if not digits_suffix.isdigit():
+ raise ValueError(f"Invalid digits suffix in {basename_no_ext}")
+ return int(digits_suffix) + 1
+
+ potential_paths.sort(key=sort_key)
+ if len(potential_paths) < control_count_per_image:
+ logger.error(
+ f"Not enough control images for {image_path}: found {len(potential_paths)}, expected {control_count_per_image}"
+ )
+ raise ValueError(
+ f"Not enough control images for {image_path}: found {len(potential_paths)}, expected {control_count_per_image}"
+ )
+
+ # take the first `control_count_per_image` paths
+ self.control_paths[image_path] = potential_paths[:control_count_per_image]
+ logger.info(f"found {len(self.control_paths)} matching control images")
+
+ missing_controls = len(self.image_paths) - len(self.control_paths)
+ if missing_controls > 0:
+ missing_control_paths = set(self.image_paths) - set(self.control_paths.keys())
+ logger.error(f"Could not find matching control images for {missing_controls} images: {missing_control_paths}")
+ raise ValueError(f"Could not find matching control images for {missing_controls} images")
+
+ def is_indexable(self):
+ return True
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[Image.Image]]:
+ image_path = self.image_paths[idx]
+ image = Image.open(image_path).convert("RGB")
+
+ _, caption = self.get_caption(idx)
+
+ controls = None
+ if self.has_control:
+ controls = []
+ for control_path in self.control_paths[image_path]:
+ control = Image.open(control_path)
+ if control.mode != "RGB" and control.mode != "RGBA":
+ control = control.convert("RGB")
+ controls.append(control)
+
+ return image_path, image, caption, controls
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ image_path = self.image_paths[idx]
+ caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
+ with open(caption_path, "r", encoding="utf-8") as f:
+ caption = f.read().strip()
+ return image_path, caption
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self) -> callable:
+ """
+ Returns a fetcher function that returns image data.
+ """
+ if self.current_idx >= len(self.image_paths):
+ raise StopIteration
+
+ if self.caption_only:
+
+ def create_caption_fetcher(index):
+ return lambda: self.get_caption(index)
+
+ fetcher = create_caption_fetcher(self.current_idx)
+ else:
+
+ def create_image_fetcher(index):
+ return lambda: self.get_image_data(index)
+
+ fetcher = create_image_fetcher(self.current_idx)
+
+ self.current_idx += 1
+ return fetcher
+
+
+class ImageJsonlDatasource(ImageDatasource):
+ def __init__(self, image_jsonl_file: str, control_count_per_image: int = 1):
+ super().__init__()
+ self.image_jsonl_file = image_jsonl_file
+ self.control_count_per_image = control_count_per_image
+ self.current_idx = 0
+
+ # load jsonl
+ logger.info(f"load image jsonl from {self.image_jsonl_file}")
+ self.data = []
+ with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
+ for line in f:
+ try:
+ data = json.loads(line)
+ except json.JSONDecodeError:
+ logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
+ raise
+ self.data.append(data)
+ logger.info(f"loaded {len(self.data)} images")
+
+ # Normalize control paths
+ for item in self.data:
+ if "control_path" in item:
+ item["control_path_0"] = item.pop("control_path")
+
+ # Ensure control paths are named consistently, from control_path_0000 to control_path_0, control_path_1, etc.
+ control_path_keys = [key for key in item.keys() if key.startswith("control_path_")]
+ control_path_keys.sort(key=lambda x: int(x.split("_")[-1]))
+ for i, key in enumerate(control_path_keys):
+ if key != f"control_path_{i}":
+ item[f"control_path_{i}"] = item.pop(key)
+
+ # Check if there are control paths in the JSONL
+ self.has_control = any("control_path_0" in item for item in self.data)
+ if self.has_control:
+ missing_control_images = [
+ item["image_path"]
+ for item in self.data
+ if sum(f"control_path_{i}" not in item for i in range(self.control_count_per_image)) > 0
+ ]
+ if missing_control_images:
+ logger.error(f"Some images do not have control paths in JSONL data: {missing_control_images}")
+ raise ValueError(f"Some images do not have control paths in JSONL data: {missing_control_images}")
+ logger.info(f"found {len(self.data)} images with {self.control_count_per_image} control images per image in JSONL data")
+
+ def is_indexable(self):
+ return True
+
+ def __len__(self):
+ return len(self.data)
+
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[list[Image.Image]]]:
+ data = self.data[idx]
+ image_path = data["image_path"]
+ image = Image.open(image_path).convert("RGB")
+
+ caption = data["caption"]
+
+ controls = None
+ if self.has_control:
+ controls = []
+ for i in range(self.control_count_per_image):
+ control_path = data[f"control_path_{i}"]
+ control = Image.open(control_path)
+ if control.mode != "RGB" and control.mode != "RGBA":
+ control = control.convert("RGB")
+ controls.append(control)
+
+ return image_path, image, caption, controls
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ data = self.data[idx]
+ image_path = data["image_path"]
+ caption = data["caption"]
+ return image_path, caption
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self) -> callable:
+ if self.current_idx >= len(self.data):
+ raise StopIteration
+
+ if self.caption_only:
+
+ def create_caption_fetcher(index):
+ return lambda: self.get_caption(index)
+
+ fetcher = create_caption_fetcher(self.current_idx)
+
+ else:
+
+ def create_fetcher(index):
+ return lambda: self.get_image_data(index)
+
+ fetcher = create_fetcher(self.current_idx)
+
+ self.current_idx += 1
+ return fetcher
+
+
+class VideoDatasource(ContentDatasource):
+ def __init__(self):
+ super().__init__()
+
+ # None means all frames
+ self.start_frame = None
+ self.end_frame = None
+
+ self.bucket_selector = None
+
+ self.source_fps = None
+ self.target_fps = None
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def get_video_data_from_path(
+ self,
+ video_path: str,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+ ) -> tuple[str, list[Image.Image], str]:
+ # this method can resize the video if bucket_selector is given to reduce the memory usage
+
+ start_frame = start_frame if start_frame is not None else self.start_frame
+ end_frame = end_frame if end_frame is not None else self.end_frame
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
+
+ video = load_video(
+ video_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
+ )
+ return video
+
+ def get_control_data_from_path(
+ self,
+ control_path: str,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+ ) -> list[Image.Image]:
+ start_frame = start_frame if start_frame is not None else self.start_frame
+ end_frame = end_frame if end_frame is not None else self.end_frame
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
+
+ control = load_video(
+ control_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
+ )
+ return control
+
+ def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
+ self.start_frame = start_frame
+ self.end_frame = end_frame
+
+ def set_bucket_selector(self, bucket_selector: BucketSelector):
+ self.bucket_selector = bucket_selector
+
+ def set_source_and_target_fps(self, source_fps: Optional[float], target_fps: Optional[float]):
+ self.source_fps = source_fps
+ self.target_fps = target_fps
+
+ def __iter__(self):
+ raise NotImplementedError
+
+ def __next__(self):
+ raise NotImplementedError
+
+
+class VideoDirectoryDatasource(VideoDatasource):
+ def __init__(self, video_directory: str, caption_extension: Optional[str] = None, control_directory: Optional[str] = None):
+ super().__init__()
+ self.video_directory = video_directory
+ self.caption_extension = caption_extension
+ self.control_directory = control_directory # 新しく追加: コントロール画像ディレクトリ
+ self.current_idx = 0
+
+ # glob videos
+ logger.info(f"glob videos in {self.video_directory}")
+ self.video_paths = glob_videos(self.video_directory)
+ logger.info(f"found {len(self.video_paths)} videos")
+
+ # glob control images if specified
+ if self.control_directory is not None:
+ logger.info(f"glob control videos in {self.control_directory}")
+ self.has_control = True
+ self.control_paths = {}
+ for video_path in self.video_paths:
+ video_basename = os.path.basename(video_path)
+ # construct control path from video path
+ # for example: video_path = "vid/video.mp4" -> control_path = "control/video.mp4"
+ control_path = os.path.join(self.control_directory, video_basename)
+ if os.path.exists(control_path):
+ self.control_paths[video_path] = control_path
+ else:
+ # use the same base name for control path
+ base_name = os.path.splitext(video_basename)[0]
+
+ # directory with images. for example: video_path = "vid/video.mp4" -> control_path = "control/video"
+ potential_path = os.path.join(self.control_directory, base_name) # no extension
+ if os.path.isdir(potential_path):
+ self.control_paths[video_path] = potential_path
+ else:
+ # another extension for control path
+ # for example: video_path = "vid/video.mp4" -> control_path = "control/video.mov"
+ for ext in VIDEO_EXTENSIONS:
+ potential_path = os.path.join(self.control_directory, base_name + ext)
+ if os.path.exists(potential_path):
+ self.control_paths[video_path] = potential_path
+ break
+
+ logger.info(f"found {len(self.control_paths)} matching control videos/images")
+ # check if all videos have matching control paths, if not, raise an error
+ missing_controls = len(self.video_paths) - len(self.control_paths)
+ if missing_controls > 0:
+ # logger.warning(f"Could not find matching control videos/images for {missing_controls} videos")
+ missing_controls_videos = [video_path for video_path in self.video_paths if video_path not in self.control_paths]
+ logger.error(
+ f"Could not find matching control videos/images for {missing_controls} videos: {missing_controls_videos}"
+ )
+ raise ValueError(f"Could not find matching control videos/images for {missing_controls} videos")
+
+ def is_indexable(self):
+ return True
+
+ def __len__(self):
+ return len(self.video_paths)
+
+ def get_video_data(
+ self,
+ idx: int,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+ ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
+ video_path = self.video_paths[idx]
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
+
+ _, caption = self.get_caption(idx)
+
+ control = None
+ if self.control_directory is not None and video_path in self.control_paths:
+ control_path = self.control_paths[video_path]
+ control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)
+
+ return video_path, video, caption, control
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ video_path = self.video_paths[idx]
+ caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
+ with open(caption_path, "r", encoding="utf-8") as f:
+ caption = f.read().strip()
+ return video_path, caption
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self):
+ if self.current_idx >= len(self.video_paths):
+ raise StopIteration
+
+ if self.caption_only:
+
+ def create_caption_fetcher(index):
+ return lambda: self.get_caption(index)
+
+ fetcher = create_caption_fetcher(self.current_idx)
+
+ else:
+
+ def create_fetcher(index):
+ return lambda: self.get_video_data(index)
+
+ fetcher = create_fetcher(self.current_idx)
+
+ self.current_idx += 1
+ return fetcher
+
+
+class VideoJsonlDatasource(VideoDatasource):
+ def __init__(self, video_jsonl_file: str):
+ super().__init__()
+ self.video_jsonl_file = video_jsonl_file
+ self.current_idx = 0
+
+ # load jsonl
+ logger.info(f"load video jsonl from {self.video_jsonl_file}")
+ self.data = []
+ with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
+ for line in f:
+ data = json.loads(line)
+ self.data.append(data)
+ logger.info(f"loaded {len(self.data)} videos")
+
+ # Check if there are control paths in the JSONL
+ self.has_control = any("control_path" in item for item in self.data)
+ if self.has_control:
+ control_count = sum(1 for item in self.data if "control_path" in item)
+ if control_count < len(self.data):
+ missing_control_videos = [item["video_path"] for item in self.data if "control_path" not in item]
+ logger.error(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
+ raise ValueError(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
+ logger.info(f"found {control_count} control videos/images in JSONL data")
+
+ def is_indexable(self):
+ return True
+
+ def __len__(self):
+ return len(self.data)
+
+ def get_video_data(
+ self,
+ idx: int,
+ start_frame: Optional[int] = None,
+ end_frame: Optional[int] = None,
+ bucket_selector: Optional[BucketSelector] = None,
+ ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
+ data = self.data[idx]
+ video_path = data["video_path"]
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
+
+ caption = data["caption"]
+
+ control = None
+ if "control_path" in data and data["control_path"]:
+ control_path = data["control_path"]
+ control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)
+
+ return video_path, video, caption, control
+
+ def get_caption(self, idx: int) -> tuple[str, str]:
+ data = self.data[idx]
+ video_path = data["video_path"]
+ caption = data["caption"]
+ return video_path, caption
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self):
+ if self.current_idx >= len(self.data):
+ raise StopIteration
+
+ if self.caption_only:
+
+ def create_caption_fetcher(index):
+ return lambda: self.get_caption(index)
+
+ fetcher = create_caption_fetcher(self.current_idx)
+
+ else:
+
+ def create_fetcher(index):
+ return lambda: self.get_video_data(index)
+
+ fetcher = create_fetcher(self.current_idx)
+
+ self.current_idx += 1
+ return fetcher
+
+
+class BaseDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ resolution: Tuple[int, int] = (960, 544),
+ caption_extension: Optional[str] = None,
+ batch_size: int = 1,
+ num_repeats: int = 1,
+ enable_bucket: bool = False,
+ bucket_no_upscale: bool = False,
+ cache_directory: Optional[str] = None,
+ debug_dataset: bool = False,
+ architecture: str = "no_default",
+ ):
+ self.resolution = resolution
+ self.caption_extension = caption_extension
+ self.batch_size = batch_size
+ self.num_repeats = num_repeats
+ self.enable_bucket = enable_bucket
+ self.bucket_no_upscale = bucket_no_upscale
+ self.cache_directory = cache_directory
+ self.debug_dataset = debug_dataset
+ self.architecture = architecture
+ self.seed = None
+ self.current_epoch = 0
+
+ if not self.enable_bucket:
+ self.bucket_no_upscale = False
+
+ def get_metadata(self) -> dict:
+ metadata = {
+ "resolution": self.resolution,
+ "caption_extension": self.caption_extension,
+ "batch_size_per_device": self.batch_size,
+ "num_repeats": self.num_repeats,
+ "enable_bucket": bool(self.enable_bucket),
+ "bucket_no_upscale": bool(self.bucket_no_upscale),
+ }
+ return metadata
+
+ def get_all_latent_cache_files(self):
+ return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
+
+ def get_all_text_encoder_output_cache_files(self):
+ return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}_te.safetensors"))
+
+ def get_latent_cache_path(self, item_info: ItemInfo) -> str:
+ """
+ Returns the cache path for the latent tensor.
+
+ item_info: ItemInfo object
+
+ Returns:
+ str: cache path
+
+ cache_path is based on the item_key and the resolution.
+ """
+ w, h = item_info.original_size
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
+ return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors")
+
+ def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
+ return os.path.join(self.cache_directory, f"{basename}_{self.architecture}_te.safetensors")
+
+ def retrieve_latent_cache_batches(self, num_workers: int):
+ raise NotImplementedError
+
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
+ raise NotImplementedError
+
+ def prepare_for_training(self):
+ pass
+
+ def set_seed(self, seed: int):
+ self.seed = seed
+
+ def set_current_epoch(self, epoch):
+ if not self.current_epoch == epoch: # shuffle buckets when epoch is incremented
+ if epoch > self.current_epoch:
+ logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+ num_epochs = epoch - self.current_epoch
+ for _ in range(num_epochs):
+ self.current_epoch += 1
+ self.shuffle_buckets()
+ # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
+ else:
+ logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+ self.current_epoch = epoch
+
+ def set_current_step(self, step):
+ self.current_step = step
+
+ def set_max_train_steps(self, max_train_steps):
+ self.max_train_steps = max_train_steps
+
+ def shuffle_buckets(self):
+ raise NotImplementedError
+
+ def __len__(self):
+ return NotImplementedError
+
+ def __getitem__(self, idx):
+ raise NotImplementedError
+
+ def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
+ datasource.set_caption_only(True)
+ executor = ThreadPoolExecutor(max_workers=num_workers)
+
+ data: list[ItemInfo] = []
+ futures = []
+
+ def aggregate_future(consume_all: bool = False):
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
+ completed_futures = [future for future in futures if future.done()]
+ if len(completed_futures) == 0:
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
+ time.sleep(0.1)
+ continue
+ else:
+ break # submit batch if possible
+
+ for future in completed_futures:
+ item_key, caption = future.result()
+ item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
+ item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
+ data.append(item_info)
+
+ futures.remove(future)
+
+ def submit_batch(flush: bool = False):
+ nonlocal data
+ if len(data) >= batch_size or (len(data) > 0 and flush):
+ batch = data[0:batch_size]
+ if len(data) > batch_size:
+ data = data[batch_size:]
+ else:
+ data = []
+ return batch
+ return None
+
+ for fetch_op in datasource:
+ future = executor.submit(fetch_op)
+ futures.append(future)
+ aggregate_future()
+ while True:
+ batch = submit_batch()
+ if batch is None:
+ break
+ yield batch
+
+ aggregate_future(consume_all=True)
+ while True:
+ batch = submit_batch(flush=True)
+ if batch is None:
+ break
+ yield batch
+
+ executor.shutdown()
+
+
+class ImageDataset(BaseDataset):
+ def __init__(
+ self,
+ resolution: Tuple[int, int],
+ caption_extension: Optional[str],
+ batch_size: int,
+ num_repeats: int,
+ enable_bucket: bool,
+ bucket_no_upscale: bool,
+ image_directory: Optional[str] = None,
+ image_jsonl_file: Optional[str] = None,
+ control_directory: Optional[str] = None,
+ cache_directory: Optional[str] = None,
+ fp_latent_window_size: Optional[int] = 9,
+ fp_1f_clean_indices: Optional[list[int]] = None,
+ fp_1f_target_index: Optional[int] = None,
+ fp_1f_no_post: Optional[bool] = False,
+ debug_dataset: bool = False,
+ architecture: str = "no_default",
+ ):
+ super(ImageDataset, self).__init__(
+ resolution,
+ caption_extension,
+ batch_size,
+ num_repeats,
+ enable_bucket,
+ bucket_no_upscale,
+ cache_directory,
+ debug_dataset,
+ architecture,
+ )
+ self.image_directory = image_directory
+ self.image_jsonl_file = image_jsonl_file
+ self.control_directory = control_directory
+ self.fp_latent_window_size = fp_latent_window_size
+ self.fp_1f_clean_indices = fp_1f_clean_indices
+ self.fp_1f_target_index = fp_1f_target_index
+ self.fp_1f_no_post = fp_1f_no_post
+
+ control_count_per_image = 1
+ if fp_1f_clean_indices is not None:
+ control_count_per_image = len(fp_1f_clean_indices)
+
+ if image_directory is not None:
+ self.datasource = ImageDirectoryDatasource(
+ image_directory, caption_extension, control_directory, control_count_per_image
+ )
+ elif image_jsonl_file is not None:
+ self.datasource = ImageJsonlDatasource(image_jsonl_file, control_count_per_image)
+ else:
+ raise ValueError("image_directory or image_jsonl_file must be specified")
+
+ if self.cache_directory is None:
+ self.cache_directory = self.image_directory
+
+ self.batch_manager = None
+ self.num_train_items = 0
+ self.has_control = self.datasource.has_control
+
+ def get_metadata(self):
+ metadata = super().get_metadata()
+ if self.image_directory is not None:
+ metadata["image_directory"] = os.path.basename(self.image_directory)
+ if self.image_jsonl_file is not None:
+ metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
+ if self.control_directory is not None:
+ metadata["control_directory"] = os.path.basename(self.control_directory)
+ metadata["has_control"] = self.has_control
+ return metadata
+
+ def get_total_image_count(self):
+ return len(self.datasource) if self.datasource.is_indexable() else None
+
+ def retrieve_latent_cache_batches(self, num_workers: int):
+ buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
+ executor = ThreadPoolExecutor(max_workers=num_workers)
+
+ batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
+ futures = []
+
+ # aggregate futures and sort by bucket resolution
+ def aggregate_future(consume_all: bool = False):
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
+ completed_futures = [future for future in futures if future.done()]
+ if len(completed_futures) == 0:
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
+ time.sleep(0.1)
+ continue
+ else:
+ break # submit batch if possible
+
+ for future in completed_futures:
+ original_size, item_key, image, caption, controls = future.result()
+ bucket_height, bucket_width = image.shape[:2]
+ bucket_reso = (bucket_width, bucket_height)
+
+ item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
+ item_info.fp_latent_window_size = self.fp_latent_window_size
+ item_info.fp_1f_clean_indices = self.fp_1f_clean_indices
+ item_info.fp_1f_target_index = self.fp_1f_target_index
+ item_info.fp_1f_no_post = self.fp_1f_no_post
+
+ if self.architecture == ARCHITECTURE_FRAMEPACK:
+ # we need to split the bucket with latent window size and optional 1f clean indices, zero post
+ bucket_reso = list(bucket_reso) + [self.fp_latent_window_size]
+ if self.fp_1f_clean_indices is not None:
+ bucket_reso.append(len(self.fp_1f_clean_indices))
+ bucket_reso.append(self.fp_1f_no_post)
+ bucket_reso = tuple(bucket_reso)
+
+ if controls is not None:
+ item_info.control_content = controls
+
+ if bucket_reso not in batches:
+ batches[bucket_reso] = []
+ batches[bucket_reso].append(item_info)
+
+ futures.remove(future)
+
+ # submit batch if some bucket has enough items
+ def submit_batch(flush: bool = False):
+ for key in batches:
+ if len(batches[key]) >= self.batch_size or flush:
+ batch = batches[key][0 : self.batch_size]
+ if len(batches[key]) > self.batch_size:
+ batches[key] = batches[key][self.batch_size :]
+ else:
+ del batches[key]
+ return key, batch
+ return None, None
+
+ for fetch_op in self.datasource:
+
+ # fetch and resize image in a separate thread
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str, Optional[Image.Image]]:
+ image_key, image, caption, controls = op()
+ image: Image.Image
+ image_size = image.size
+
+ bucket_reso = buckset_selector.get_bucket_resolution(image_size)
+ image = resize_image_to_bucket(image, bucket_reso) # returns np.ndarray
+ resized_controls = None
+ if controls is not None:
+ resized_controls = []
+ for control in controls:
+ resized_control = resize_image_to_bucket(control, bucket_reso) # returns np.ndarray
+ resized_controls.append(resized_control)
+
+ return image_size, image_key, image, caption, resized_controls
+
+ future = executor.submit(fetch_and_resize, fetch_op)
+ futures.append(future)
+ aggregate_future()
+ while True:
+ key, batch = submit_batch()
+ if key is None:
+ break
+ yield key, batch
+
+ aggregate_future(consume_all=True)
+ while True:
+ key, batch = submit_batch(flush=True)
+ if key is None:
+ break
+ yield key, batch
+
+ executor.shutdown()
+
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
+
+ def prepare_for_training(self):
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
+
+ # glob cache files
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
+
+ # assign cache files to item info
+ bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
+ for cache_file in latent_cache_files:
+ tokens = os.path.basename(cache_file).split("_")
+
+ image_size = tokens[-2] # 0000x0000
+ image_width, image_height = map(int, image_size.split("x"))
+ image_size = (image_width, image_height)
+
+ item_key = "_".join(tokens[:-2])
+ text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
+ if not os.path.exists(text_encoder_output_cache_file):
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
+ continue
+
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
+
+ if self.architecture == ARCHITECTURE_FRAMEPACK:
+ # we need to split the bucket with latent window size and optional 1f clean indices, zero post
+ bucket_reso = list(bucket_reso) + [self.fp_latent_window_size]
+ if self.fp_1f_clean_indices is not None:
+ bucket_reso.append(len(self.fp_1f_clean_indices))
+ bucket_reso.append(self.fp_1f_no_post)
+ bucket_reso = tuple(bucket_reso)
+
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
+
+ bucket = bucketed_item_info.get(bucket_reso, [])
+ for _ in range(self.num_repeats):
+ bucket.append(item_info)
+ bucketed_item_info[bucket_reso] = bucket
+
+ # prepare batch manager
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
+ self.batch_manager.show_bucket_info()
+
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
+
+ def shuffle_buckets(self):
+ # set random seed for this epoch
+ random.seed(self.seed + self.current_epoch)
+ self.batch_manager.shuffle()
+
+ def __len__(self):
+ if self.batch_manager is None:
+ return 100 # dummy value
+ return len(self.batch_manager)
+
+ def __getitem__(self, idx):
+ return self.batch_manager[idx]
+
+
+class VideoDataset(BaseDataset):
+ TARGET_FPS_HUNYUAN = 24.0
+ TARGET_FPS_WAN = 16.0
+ TARGET_FPS_FRAMEPACK = 30.0
+
+ def __init__(
+ self,
+ resolution: Tuple[int, int],
+ caption_extension: Optional[str],
+ batch_size: int,
+ num_repeats: int,
+ enable_bucket: bool,
+ bucket_no_upscale: bool,
+ frame_extraction: Optional[str] = "head",
+ frame_stride: Optional[int] = 1,
+ frame_sample: Optional[int] = 1,
+ target_frames: Optional[list[int]] = None,
+ max_frames: Optional[int] = None,
+ source_fps: Optional[float] = None,
+ video_directory: Optional[str] = None,
+ video_jsonl_file: Optional[str] = None,
+ control_directory: Optional[str] = None,
+ cache_directory: Optional[str] = None,
+ fp_latent_window_size: Optional[int] = 9,
+ debug_dataset: bool = False,
+ architecture: str = "no_default",
+ ):
+ super(VideoDataset, self).__init__(
+ resolution,
+ caption_extension,
+ batch_size,
+ num_repeats,
+ enable_bucket,
+ bucket_no_upscale,
+ cache_directory,
+ debug_dataset,
+ architecture,
+ )
+ self.video_directory = video_directory
+ self.video_jsonl_file = video_jsonl_file
+ self.control_directory = control_directory
+ self.frame_extraction = frame_extraction
+ self.frame_stride = frame_stride
+ self.frame_sample = frame_sample
+ self.max_frames = max_frames
+ self.source_fps = source_fps
+ self.fp_latent_window_size = fp_latent_window_size
+
+ if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
+ self.target_fps = VideoDataset.TARGET_FPS_HUNYUAN
+ elif self.architecture == ARCHITECTURE_WAN:
+ self.target_fps = VideoDataset.TARGET_FPS_WAN
+ elif self.architecture == ARCHITECTURE_FRAMEPACK:
+ self.target_fps = VideoDataset.TARGET_FPS_FRAMEPACK
+ else:
+ raise ValueError(f"Unsupported architecture: {self.architecture}")
+
+ if target_frames is not None:
+ target_frames = list(set(target_frames))
+ target_frames.sort()
+
+ # round each value to N*4+1
+ rounded_target_frames = [(f - 1) // 4 * 4 + 1 for f in target_frames]
+ rouneded_target_frames = list(set(rounded_target_frames))
+ rouneded_target_frames.sort()
+
+ # if value is changed, warn
+ if target_frames != rounded_target_frames:
+ logger.warning(f"target_frames are rounded to {rounded_target_frames}")
+
+ target_frames = tuple(rounded_target_frames)
+
+ self.target_frames = target_frames
+
+ if video_directory is not None:
+ self.datasource = VideoDirectoryDatasource(video_directory, caption_extension, control_directory)
+ elif video_jsonl_file is not None:
+ self.datasource = VideoJsonlDatasource(video_jsonl_file)
+
+ if self.frame_extraction == "uniform" and self.frame_sample == 1:
+ self.frame_extraction = "head"
+ logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
+ if self.frame_extraction == "head":
+ # head extraction. we can limit the number of frames to be extracted
+ self.datasource.set_start_and_end_frame(0, max(self.target_frames))
+
+ if self.cache_directory is None:
+ self.cache_directory = self.video_directory
+
+ self.batch_manager = None
+ self.num_train_items = 0
+ self.has_control = self.datasource.has_control
+
+ def get_metadata(self):
+ metadata = super().get_metadata()
+ if self.video_directory is not None:
+ metadata["video_directory"] = os.path.basename(self.video_directory)
+ if self.video_jsonl_file is not None:
+ metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
+ if self.control_directory is not None:
+ metadata["control_directory"] = os.path.basename(self.control_directory)
+ metadata["frame_extraction"] = self.frame_extraction
+ metadata["frame_stride"] = self.frame_stride
+ metadata["frame_sample"] = self.frame_sample
+ metadata["target_frames"] = self.target_frames
+ metadata["max_frames"] = self.max_frames
+ metadata["source_fps"] = self.source_fps
+ metadata["has_control"] = self.has_control
+ return metadata
+
+ def retrieve_latent_cache_batches(self, num_workers: int):
+ buckset_selector = BucketSelector(self.resolution, architecture=self.architecture)
+ self.datasource.set_bucket_selector(buckset_selector)
+ if self.source_fps is not None:
+ self.datasource.set_source_and_target_fps(self.source_fps, self.target_fps)
+ else:
+ self.datasource.set_source_and_target_fps(None, None) # no conversion
+
+ executor = ThreadPoolExecutor(max_workers=num_workers)
+
+ # key: (width, height, frame_count) and optional latent_window_size, value: [ItemInfo]
+ batches: dict[tuple[Any], list[ItemInfo]] = {}
+ futures = []
+
+ def aggregate_future(consume_all: bool = False):
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
+ completed_futures = [future for future in futures if future.done()]
+ if len(completed_futures) == 0:
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
+ time.sleep(0.1)
+ continue
+ else:
+ break # submit batch if possible
+
+ for future in completed_futures:
+ original_frame_size, video_key, video, caption, control = future.result()
+
+ frame_count = len(video)
+ video = np.stack(video, axis=0)
+ height, width = video.shape[1:3]
+ bucket_reso = (width, height) # already resized
+
+ # process control images if available
+ control_video = None
+ if control is not None:
+ # set frame count to the same as video
+ if len(control) > frame_count:
+ control = control[:frame_count]
+ elif len(control) < frame_count:
+ # if control is shorter than video, repeat the last frame
+ last_frame = control[-1]
+ control.extend([last_frame] * (frame_count - len(control)))
+ control_video = np.stack(control, axis=0)
+
+ crop_pos_and_frames = []
+ if self.frame_extraction == "head":
+ for target_frame in self.target_frames:
+ if frame_count >= target_frame:
+ crop_pos_and_frames.append((0, target_frame))
+ elif self.frame_extraction == "chunk":
+ # split by target_frames
+ for target_frame in self.target_frames:
+ for i in range(0, frame_count, target_frame):
+ if i + target_frame <= frame_count:
+ crop_pos_and_frames.append((i, target_frame))
+ elif self.frame_extraction == "slide":
+ # slide window
+ for target_frame in self.target_frames:
+ if frame_count >= target_frame:
+ for i in range(0, frame_count - target_frame + 1, self.frame_stride):
+ crop_pos_and_frames.append((i, target_frame))
+ elif self.frame_extraction == "uniform":
+ # select N frames uniformly
+ for target_frame in self.target_frames:
+ if frame_count >= target_frame:
+ frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
+ for i in frame_indices:
+ crop_pos_and_frames.append((i, target_frame))
+ elif self.frame_extraction == "full":
+ # select all frames
+ target_frame = min(frame_count, self.max_frames)
+ target_frame = (target_frame - 1) // 4 * 4 + 1 # round to N*4+1
+ crop_pos_and_frames.append((0, target_frame))
+ else:
+ raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
+
+ for crop_pos, target_frame in crop_pos_and_frames:
+ cropped_video = video[crop_pos : crop_pos + target_frame]
+ body, ext = os.path.splitext(video_key)
+ item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
+ batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
+
+ if self.architecture == ARCHITECTURE_FRAMEPACK:
+ # add latent window size to bucket resolution
+ batch_key = (*batch_key, self.fp_latent_window_size)
+
+ # crop control video if available
+ cropped_control = None
+ if control_video is not None:
+ cropped_control = control_video[crop_pos : crop_pos + target_frame]
+
+ item_info = ItemInfo(
+ item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
+ )
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
+ item_info.control_content = cropped_control # None is allowed
+ item_info.fp_latent_window_size = self.fp_latent_window_size
+
+ batch = batches.get(batch_key, [])
+ batch.append(item_info)
+ batches[batch_key] = batch
+
+ futures.remove(future)
+
+ def submit_batch(flush: bool = False):
+ for key in batches:
+ if len(batches[key]) >= self.batch_size or flush:
+ batch = batches[key][0 : self.batch_size]
+ if len(batches[key]) > self.batch_size:
+ batches[key] = batches[key][self.batch_size :]
+ else:
+ del batches[key]
+ return key, batch
+ return None, None
+
+ for operator in self.datasource:
+
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str, Optional[list[np.ndarray]]]:
+ result = op()
+
+ if len(result) == 3: # for backward compatibility TODO remove this in the future
+ video_key, video, caption = result
+ control = None
+ else:
+ video_key, video, caption, control = result
+
+ video: list[np.ndarray]
+ frame_size = (video[0].shape[1], video[0].shape[0])
+
+ # resize if necessary
+ bucket_reso = buckset_selector.get_bucket_resolution(frame_size)
+ video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
+
+ # resize control if necessary
+ if control is not None:
+ control = [resize_image_to_bucket(frame, bucket_reso) for frame in control]
+
+ return frame_size, video_key, video, caption, control
+
+ future = executor.submit(fetch_and_resize, operator)
+ futures.append(future)
+ aggregate_future()
+ while True:
+ key, batch = submit_batch()
+ if key is None:
+ break
+ yield key, batch
+
+ aggregate_future(consume_all=True)
+ while True:
+ key, batch = submit_batch(flush=True)
+ if key is None:
+ break
+ yield key, batch
+
+ executor.shutdown()
+
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
+
+ def prepare_for_training(self):
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
+
+ # glob cache files
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
+
+ # assign cache files to item info
+ bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
+ for cache_file in latent_cache_files:
+ tokens = os.path.basename(cache_file).split("_")
+
+ image_size = tokens[-2] # 0000x0000
+ image_width, image_height = map(int, image_size.split("x"))
+ image_size = (image_width, image_height)
+
+ frame_pos, frame_count = tokens[-3].split("-")[:2] # "00000-000", or optional section index "00000-000-00"
+ frame_pos, frame_count = int(frame_pos), int(frame_count)
+
+ item_key = "_".join(tokens[:-3])
+ text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
+ if not os.path.exists(text_encoder_output_cache_file):
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
+ continue
+
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
+ bucket_reso = (*bucket_reso, frame_count)
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
+
+ bucket = bucketed_item_info.get(bucket_reso, [])
+ for _ in range(self.num_repeats):
+ bucket.append(item_info)
+ bucketed_item_info[bucket_reso] = bucket
+
+ # prepare batch manager
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
+ self.batch_manager.show_bucket_info()
+
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
+
+ def shuffle_buckets(self):
+ # set random seed for this epoch
+ random.seed(self.seed + self.current_epoch)
+ self.batch_manager.shuffle()
+
+ def __len__(self):
+ if self.batch_manager is None:
+ return 100 # dummy value
+ return len(self.batch_manager)
+
+ def __getitem__(self, idx):
+ return self.batch_manager[idx]
+
+
+class DatasetGroup(torch.utils.data.ConcatDataset):
+ def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
+ super().__init__(datasets)
+ self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
+ self.num_train_items = 0
+ for dataset in self.datasets:
+ self.num_train_items += dataset.num_train_items
+
+ def set_current_epoch(self, epoch):
+ for dataset in self.datasets:
+ dataset.set_current_epoch(epoch)
+
+ def set_current_step(self, step):
+ for dataset in self.datasets:
+ dataset.set_current_step(step)
+
+ def set_max_train_steps(self, max_train_steps):
+ for dataset in self.datasets:
+ dataset.set_max_train_steps(max_train_steps)
diff --git a/docs/advanced_config.md b/docs/advanced_config.md
new file mode 100644
index 0000000000000000000000000000000000000000..467a75d0b40c061b9fc1b61ca19f0065aa44d4b0
--- /dev/null
+++ b/docs/advanced_config.md
@@ -0,0 +1,316 @@
+> 📝 Click on the language section to expand / 言語をクリックして展開
+
+# Advanced configuration / 高度な設定
+
+## Table of contents / 目次
+
+- [How to specify `network_args`](#how-to-specify-network_args--network_argsの指定方法)
+- [LoRA+](#lora)
+- [Select the target modules of LoRA](#select-the-target-modules-of-lora--loraの対象モジュールを選択する)
+- [Save and view logs in TensorBoard format](#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照)
+- [Save and view logs in wandb](#save-and-view-logs-in-wandb--wandbでログの保存と参照)
+- [FP8 weight optimization for models](#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化)
+- [PyTorch Dynamo optimization for model training](#pytorch-dynamo-optimization-for-model-training--モデルの学習におけるpytorch-dynamoの最適化)
+
+## How to specify `network_args` / `network_args`の指定方法
+
+The `--network_args` option is an option for specifying detailed arguments to LoRA. Specify the arguments in the form of `key=value` in `--network_args`.
+
+
+日本語
+`--network_args`オプションは、LoRAへの詳細な引数を指定するためのオプションです。`--network_args`には、`key=value`の形式で引数を指定します。
+
+
+### Example / 記述例
+
+If you specify it on the command line, write as follows. / コマンドラインで指定する場合は以下のように記述します。
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
+ --network_module networks.lora --network_dim 32
+ --network_args "key1=value1" "key2=value2" ...
+```
+
+If you specify it in the configuration file, write as follows. / 設定ファイルで指定する場合は以下のように記述します。
+
+```toml
+network_args = ["key1=value1", "key2=value2", ...]
+```
+
+If you specify `"verbose=True"`, detailed information of LoRA will be displayed. / `"verbose=True"`を指定するとLoRAの詳細な情報が表示されます。
+
+```bash
+--network_args "verbose=True" "key1=value1" "key2=value2" ...
+```
+
+## LoRA+
+
+LoRA+ is a method to improve the training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiplier for the learning rate. The original paper recommends 16, but adjust as needed. It seems to be good to start from around 4. For details, please refer to the [related PR of sd-scripts](https://github.com/kohya-ss/sd-scripts/pull/1233).
+
+Specify `loraplus_lr_ratio` with `--network_args`.
+
+
+日本語
+
+LoRA+は、LoRAのUP側(LoRA-B)の学習率を上げることで学習速度を向上させる手法です。学習率に対する倍率を指定します。元論文では16を推奨していますが、必要に応じて調整してください。4程度から始めるとよいようです。詳細は[sd-scriptsの関連PR]https://github.com/kohya-ss/sd-scripts/pull/1233)を参照してください。
+
+`--network_args`で`loraplus_lr_ratio`を指定します。
+
+
+### Example / 記述例
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
+ --network_module networks.lora --network_dim 32 --network_args "loraplus_lr_ratio=4" ...
+```
+
+## Select the target modules of LoRA / LoRAの対象モジュールを選択する
+
+*This feature is highly experimental and the specification may change. / この機能は特に実験的なもので、仕様は変更される可能性があります。*
+
+By specifying `exclude_patterns` and `include_patterns` with `--network_args`, you can select the target modules of LoRA.
+
+`exclude_patterns` excludes modules that match the specified pattern. `include_patterns` targets only modules that match the specified pattern.
+
+Specify the values as a list. For example, `"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`.
+
+The pattern is a regular expression for the module name. The module name is in the form of `double_blocks.0.img_mod.linear` or `single_blocks.39.modulation.linear`. The regular expression is not a partial match but a complete match.
+
+The patterns are applied in the order of `exclude_patterns`→`include_patterns`. By default, the Linear layers of `img_mod`, `txt_mod`, and `modulation` of double blocks and single blocks are excluded.
+
+(`.*(img_mod|txt_mod|modulation).*` is specified.)
+
+
+日本語
+
+`--network_args`で`exclude_patterns`と`include_patterns`を指定することで、LoRAの対象モジュールを選択することができます。
+
+`exclude_patterns`は、指定したパターンに一致するモジュールを除外します。`include_patterns`は、指定したパターンに一致するモジュールのみを対象とします。
+
+値は、リストで指定します。`"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`のようになります。
+
+パターンは、モジュール名に対する正規表現です。モジュール名は、たとえば`double_blocks.0.img_mod.linear`や`single_blocks.39.modulation.linear`のような形式です。正規表現は部分一致ではなく完全一致です。
+
+パターンは、`exclude_patterns`→`include_patterns`の順で適用されます。デフォルトは、double blocksとsingle blocksのLinear層のうち、`img_mod`、`txt_mod`、`modulation`が除外されています。
+
+(`.*(img_mod|txt_mod|modulation).*`が指定されています。)
+
+
+### Example / 記述例
+
+Only the modules of double blocks / double blocksのモジュールのみを対象とする場合:
+
+```bash
+--network_args "exclude_patterns=[r'.*single_blocks.*']"
+```
+
+Only the modules of single blocks from the 10th / single blocksの10番目以降のLinearモジュールのみを対象とする場合:
+
+```bash
+--network_args "exclude_patterns=[r'.*']" "include_patterns=[r'.*single_blocks\.\d{2}\.linear.*']"
+```
+
+## Save and view logs in TensorBoard format / TensorBoard形式のログの保存と参照
+
+Specify the folder to save the logs with the `--logging_dir` option. Logs in TensorBoard format will be saved.
+
+For example, if you specify `--logging_dir=logs`, a `logs` folder will be created in the working folder, and logs will be saved in the date folder inside it.
+
+Also, if you specify the `--log_prefix` option, the specified string will be added before the date. For example, use `--logging_dir=logs --log_prefix=lora_setting1_` for identification.
+
+To view logs in TensorBoard, open another command prompt and activate the virtual environment. Then enter the following in the working folder.
+
+```powershell
+tensorboard --logdir=logs
+```
+
+(tensorboard installation is required.)
+
+Then open a browser and access http://localhost:6006/ to display it.
+
+
+日本語
+`--logging_dir`オプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
+
+たとえば`--logging_dir=logs`と指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。
+
+また`--log_prefix`オプションを指定すると、日時の前に指定した文字列が追加されます。`--logging_dir=logs --log_prefix=lora_setting1_`などとして識別用にお使いください。
+
+TensorBoardでログを確認するには、別のコマンドプロンプトを開き、仮想環境を有効にしてから、作業フォルダで以下のように入力します。
+
+```powershell
+tensorboard --logdir=logs
+```
+
+(tensorboardのインストールが必要です。)
+
+その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
+
+
+## Save and view logs in wandb / wandbでログの保存と参照
+
+`--log_with wandb` option is available to save logs in wandb format. `tensorboard` or `all` is also available. The default is `tensorboard`.
+
+Specify the project name with `--log_tracker_name` when using wandb.
+
+
+日本語
+`--log_with wandb`オプションを指定するとwandb形式でログを保存することができます。`tensorboard`や`all`も指定可能です。デフォルトは`tensorboard`です。
+
+wandbを使用する場合は、`--log_tracker_name`でプロジェクト名を指定してください。
+
+
+## FP8 weight optimization for models / モデルの重みのFP8への最適化
+
+The `--fp8_scaled` option is available to quantize the weights of the model to FP8 (E4M3) format with appropriate scaling. This reduces the VRAM usage while maintaining precision. Important weights are kept in FP16/BF16/FP32 format.
+
+The model weights must be in fp16 or bf16. Weights that have been pre-converted to float8_e4m3 cannot be used.
+
+Wan2.1 inference and training are supported.
+
+Specify the `--fp8_scaled` option in addition to the `--fp8` option during inference.
+
+Specify the `--fp8_scaled` option in addition to the `--fp8_base` option during training.
+
+Acknowledgments: This feature is based on the [implementation](https://github.com/Tencent/HunyuanVideo/blob/7df4a45c7e424a3f6cd7d653a7ff1f60cddc1eb1/hyvideo/modules/fp8_optimization.py) of [HunyuanVideo](https://github.com/Tencent/HunyuanVideo). The selection of high-precision modules is based on the [implementation](https://github.com/tdrussell/diffusion-pipe/blob/407c04fdae1c9ab5e67b54d33bef62c3e0a8dbc7/models/wan.py) of [diffusion-pipe](https://github.com/tdrussell/diffusion-pipe). I would like to thank these repositories.
+
+
+日本語
+重みを単純にFP8へcastするのではなく、適切なスケーリングでFP8形式に量子化することで、精度を維持しつつVRAM使用量を削減します。また、重要な重みはFP16/BF16/FP32形式で保持します。
+
+モデルの重みは、fp16またはbf16が必要です。あらかじめfloat8_e4m3に変換された重みは使用できません。
+
+Wan2.1の推論、学習のみ対応しています。
+
+推論時は`--fp8`オプションに加えて `--fp8_scaled`オプションを指定してください。
+
+学習時は`--fp8_base`オプションに加えて `--fp8_scaled`オプションを指定してください。
+
+謝辞:この機能は、[HunyuanVideo](https://github.com/Tencent/HunyuanVideo)の[実装](https://github.com/Tencent/HunyuanVideo/blob/7df4a45c7e424a3f6cd7d653a7ff1f60cddc1eb1/hyvideo/modules/fp8_optimization.py)を参考にしました。また、高精度モジュールの選択においては[diffusion-pipe](https://github.com/tdrussell/diffusion-pipe)の[実装](https://github.com/tdrussell/diffusion-pipe/blob/407c04fdae1c9ab5e67b54d33bef62c3e0a8dbc7/models/wan.py)を参考にしました。これらのリポジトリに感謝します。
+
+
+
+### Key features and implementation details / 主な特徴と実装の詳細
+
+- Implements FP8 (E4M3) weight quantization for Linear layers
+- Reduces VRAM requirements by using 8-bit weights for storage (slightly increased compared to existing `--fp8` `--fp8_base` options)
+- Quantizes weights to FP8 format with appropriate scaling instead of simple cast to FP8
+- Maintains computational precision by dequantizing to original precision (FP16/BF16/FP32) during forward pass
+- Preserves important weights in FP16/BF16/FP32 format
+
+The implementation:
+
+1. Quantizes weights to FP8 format with appropriate scaling
+2. Replaces weights by FP8 quantized weights and stores scale factors in model state dict
+3. Applies monkey patching to Linear layers for transparent dequantization during computation
+
+
+日本語
+
+- Linear層のFP8(E4M3)重み量子化を実装
+- 8ビットの重みを使用することでVRAM使用量を削減(既存の`--fp8` `--fp8_base` オプションに比べて微増)
+- 単純なFP8へのcastではなく、適切な値でスケールして重みをFP8形式に量子化
+- forward時に元の精度(FP16/BF16/FP32)に逆量子化して計算精度を維持
+- 精度が重要な重みはFP16/BF16/FP32のまま保持
+
+実装:
+
+1. 精度を維持できる適切な倍率で重みをFP8形式に量子化
+2. 重みをFP8量子化重みに置き換え、倍率をモデルのstate dictに保存
+3. Linear層にmonkey patchingすることでモデルを変更せずに逆量子化
+
+
+ ## PyTorch Dynamo optimization for model training / モデルの学習におけるPyTorch Dynamoの最適化
+
+The PyTorch Dynamo options are now available to optimize the training process. PyTorch Dynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster by using TorchInductor, a deep learning compiler. This integration allows for potential speedups in training while maintaining model accuracy.
+
+[PR #215](https://github.com/kohya-ss/musubi-tuner/pull/215) added this feature.
+
+Specify the `--dynamo_backend` option to enable Dynamo optimization with one of the available backends from the `DynamoBackend` enum.
+
+Additional options allow for fine-tuning the Dynamo behavior:
+- `--dynamo_mode`: Controls the optimization strategy
+- `--dynamo_fullgraph`: Enables fullgraph mode for potentially better optimization
+- `--dynamo_dynamic`: Enables dynamic shape handling
+
+The `--dynamo_dynamic` option has been reported to have many problems based on the validation in PR #215.
+
+### Available options:
+
+```
+--dynamo_backend {NO, INDUCTOR, NVFUSER, CUDAGRAPHS, CUDAGRAPHS_FALLBACK, etc.}
+ Specifies the Dynamo backend to use (default is NO, which disables Dynamo)
+
+--dynamo_mode {default, reduce-overhead, max-autotune}
+ Specifies the optimization mode (default is 'default')
+ - 'default': Standard optimization
+ - 'reduce-overhead': Focuses on reducing compilation overhead
+ - 'max-autotune': Performs extensive autotuning for potentially better performance
+
+--dynamo_fullgraph
+ Flag to enable fullgraph mode, which attempts to capture and optimize the entire model graph
+
+--dynamo_dynamic
+ Flag to enable dynamic shape handling for models with variable input shapes
+```
+
+### Usage example:
+
+```bash
+python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode default
+```
+
+For more aggressive optimization:
+```bash
+python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode max-autotune --dynamo_fullgraph
+```
+
+Note: The best combination of options may depend on your specific model and hardware. Experimentation may be necessary to find the optimal configuration.
+
+
+日本語
+PyTorch Dynamoオプションが学習プロセスを最適化するために追加されました。PyTorch Dynamoは、TorchInductor(ディープラーニングコンパイラ)を使用して、変更を加えることなくPyTorchプログラムを高速化するためのPythonレベルのJITコンパイラです。この統合により、モデルの精度を維持しながら学習の高速化が期待できます。
+
+[PR #215](https://github.com/kohya-ss/musubi-tuner/pull/215) で追加されました。
+
+`--dynamo_backend`オプションを指定して、`DynamoBackend`列挙型から利用可能なバックエンドの一つを選択することで、Dynamo最適化を有効にします。
+
+追加のオプションにより、Dynamoの動作を微調整できます:
+- `--dynamo_mode`:最適化戦略を制御します
+- `--dynamo_fullgraph`:より良い最適化の可能性のためにフルグラフモードを有効にします
+- `--dynamo_dynamic`:動的形状処理を有効にします
+
+PR #215での検証によると、`--dynamo_dynamic`には問題が多いことが報告されています。
+
+__利用可能なオプション:__
+
+```
+--dynamo_backend {NO, INDUCTOR, NVFUSER, CUDAGRAPHS, CUDAGRAPHS_FALLBACK, など}
+ 使用するDynamoバックエンドを指定します(デフォルトはNOで、Dynamoを無効にします)
+
+--dynamo_mode {default, reduce-overhead, max-autotune}
+ 最適化モードを指定します(デフォルトは 'default')
+ - 'default':標準的な最適化
+ - 'reduce-overhead':コンパイルのオーバーヘッド削減に焦点を当てる
+ - 'max-autotune':より良いパフォーマンスのために広範な自動調整を実行
+
+--dynamo_fullgraph
+ フルグラフモードを有効にするフラグ。モデルグラフ全体をキャプチャして最適化しようとします
+
+--dynamo_dynamic
+ 可変入力形状を持つモデルのための動的形状処理を有効にするフラグ
+```
+
+__使用例:__
+
+```bash
+python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode default
+```
+
+より積極的な最適化の場合:
+```bash
+python train_video_model.py --dynamo_backend INDUCTOR --dynamo_mode max-autotune --dynamo_fullgraph
+```
+
+注意:最適なオプションの組み合わせは、特定のモデルとハードウェアに依存する場合があります。最適な構成を見つけるために実験が必要かもしれません。
+
diff --git a/docs/framepack.md b/docs/framepack.md
new file mode 100644
index 0000000000000000000000000000000000000000..0e4e7df4620de3fea2529d0a878f7c03b8194626
--- /dev/null
+++ b/docs/framepack.md
@@ -0,0 +1,607 @@
+# FramePack
+
+## Overview / 概要
+
+This document describes the usage of the [FramePack](https://github.com/lllyasviel/FramePack) architecture within the Musubi Tuner framework. FramePack is a novel video generation architecture developed by lllyasviel.
+
+Key differences from HunyuanVideo:
+- FramePack only supports Image-to-Video (I2V) generation. Text-to-Video (T2V) is not supported.
+- It utilizes a different DiT model architecture and requires an additional Image Encoder. VAE is same as HunyuanVideo. Text Encoders seem to be the same as HunyuanVideo but we employ the original FramePack method to utilize them.
+- Caching and training scripts are specific to FramePack (`fpack_*.py`).
+- Due to its progressive generation nature, VRAM usage can be significantly lower, especially for longer videos, compared to other architectures.
+
+The official documentation does not provide detailed explanations on how to train the model, but it is based on the FramePack implementation and paper.
+
+This feature is experimental.
+
+For one-frame inference and training, see [here](./framepack_1f.md).
+
+
+日本語
+
+このドキュメントは、Musubi Tunerフレームワーク内での[FramePack](https://github.com/lllyasviel/FramePack) アーキテクチャの使用法について説明しています。FramePackは、lllyasviel氏にによって開発された新しいビデオ生成アーキテクチャです。
+
+HunyuanVideoとの主な違いは次のとおりです。
+- FramePackは、画像からビデオ(I2V)生成のみをサポートしています。テキストからビデオ(T2V)はサポートされていません。
+- 異なるDiTモデルアーキテクチャを使用し、追加の画像エンコーダーが必要です。VAEはHunyuanVideoと同じです。テキストエンコーダーはHunyuanVideoと同じと思われますが、FramePack公式と同じ方法で推論を行っています。
+- キャッシングと学習スクリプトはFramePack専用(`fpack_*.py`)です。
+- セクションずつ生成するため、他のアーキテクチャと比較して、特に長いビデオの場合、VRAM使用量が大幅に少なくなる可能性があります。
+
+学習方法について公式からは詳細な説明はありませんが、FramePackの実装と論文を参考にしています。
+
+この機能は実験的なものです。
+
+1フレーム推論、学習については[こちら](./framepack_1f.md)を参照してください。
+
+
+## Download the model / モデルのダウンロード
+
+You need to download the DiT, VAE, Text Encoder 1 (LLaMA), Text Encoder 2 (CLIP), and Image Encoder (SigLIP) models specifically for FramePack. Several download options are available for each component.
+
+***Note:** The weights are publicly available on the following page: [maybleMyers/framepack_h1111](https://huggingface.co/maybleMyers/framepack_h1111) (except for FramePack-F1). Thank you maybleMyers!
+
+### DiT Model
+
+Choose one of the following methods:
+
+1. **From lllyasviel's Hugging Face repo:** Download the three `.safetensors` files (starting with `diffusion_pytorch_model-00001-of-00003.safetensors`) from [lllyasviel/FramePackI2V_HY](https://huggingface.co/lllyasviel/FramePackI2V_HY). Specify the path to the first file (`...-00001-of-00003.safetensors`) as the `--dit` argument. For FramePack-F1, download from [lllyasviel/FramePack_F1_I2V_HY_20250503](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503).
+
+2. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the model might be downloaded locally. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--lllyasviel--FramePackI2V_HY/snapshots/`. FramePack-F1 is also available in the same way.
+
+3. **From Kijai's Hugging Face repo:** Download the single file `FramePackI2V_HY_bf16.safetensors` from [Kijai/HunyuanVideo_comfy](https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors). Specify the path to this file as the `--dit` argument. No FramePack-F1 model is available here currently.
+
+### VAE Model
+
+Choose one of the following methods:
+
+1. **Use official HunyuanVideo VAE:** Follow the instructions in the main [README.md](../README.md#model-download).
+2. **From hunyuanvideo-community Hugging Face repo:** Download `vae/diffusion_pytorch_model.safetensors` from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo).
+3. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the VAE might be downloaded locally within the HunyuanVideo community model snapshot. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/`.
+
+### Text Encoder 1 (LLaMA) Model
+
+Choose one of the following methods:
+
+1. **From Comfy-Org Hugging Face repo:** Download `split_files/text_encoders/llava_llama3_fp16.safetensors` from [Comfy-Org/HunyuanVideo_repackaged](https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged).
+2. **From hunyuanvideo-community Hugging Face repo:** Download the four `.safetensors` files (starting with `text_encoder/model-00001-of-00004.safetensors`) from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo). Specify the path to the first file (`...-00001-of-00004.safetensors`) as the `--text_encoder1` argument.
+3. **From local FramePack installation:** (Same as VAE) Specify the path to the HunyuanVideo community model snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/`.
+
+### Text Encoder 2 (CLIP) Model
+
+Choose one of the following methods:
+
+1. **From Comfy-Org Hugging Face repo:** Download `split_files/text_encoders/clip_l.safetensors` from [Comfy-Org/HunyuanVideo_repackaged](https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged).
+2. **From hunyuanvideo-community Hugging Face repo:** Download `text_encoder_2/model.safetensors` from [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo).
+3. **From local FramePack installation:** (Same as VAE) Specify the path to the HunyuanVideo community model snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--hunyuanvideo-community--HunyuanVideo/snapshots/`.
+
+### Image Encoder (SigLIP) Model
+
+Choose one of the following methods:
+
+1. **From Comfy-Org Hugging Face repo:** Download `sigclip_vision_patch14_384.safetensors` from [Comfy-Org/sigclip_vision_384](https://huggingface.co/Comfy-Org/sigclip_vision_384).
+2. **From lllyasviel's Hugging Face repo:** Download `image_encoder/model.safetensors` from [lllyasviel/flux_redux_bfl](https://huggingface.co/lllyasviel/flux_redux_bfl).
+3. **From local FramePack installation:** If you have cloned and run the official FramePack repository, the model might be downloaded locally. Specify the path to the snapshot directory, e.g., `path/to/FramePack/hf_download/hub/models--lllyasviel--flux_redux_bfl/snapshots/`.
+
+
+日本語
+
+※以下のページに重みが一括で公開されています(FramePack-F1を除く)。maybleMyers 氏に感謝いたします。: https://huggingface.co/maybleMyers/framepack_h1111
+
+DiT、VAE、テキストエンコーダー1(LLaMA)、テキストエンコーダー2(CLIP)、および画像エンコーダー(SigLIP)モデルは複数の方法でダウンロードできます。英語の説明を参考にして、ダウンロードしてください。
+
+FramePack公式のリポジトリをクローンして実行した場合、モデルはローカルにダウンロードされている可能性があります。スナップショットディレクトリへのパスを指定してください。例:`path/to/FramePack/hf_download/hub/models--lllyasviel--flux_redux_bfl/snapshots/`
+
+HunyuanVideoの推論をComfyUIですでに行っている場合、いくつかのモデルはすでにダウンロードされている可能性があります。
+
+
+## Pre-caching / 事前キャッシング
+
+The default resolution for FramePack is 640x640. See [the source code](../frame_pack/bucket_tools.py) for the default resolution of each bucket.
+
+The dataset for training must be a video dataset. Image datasets are not supported. You can train on videos of any length. Specify `frame_extraction` as `full` and set `max_frames` to a sufficiently large value. However, if the video is too long, you may run out of VRAM during VAE encoding.
+
+### Latent Pre-caching / latentの事前キャッシング
+
+Latent pre-caching uses a dedicated script for FramePack. You **must** provide the Image Encoder model.
+
+```bash
+python fpack_cache_latents.py \
+ --dataset_config path/to/toml \
+ --vae path/to/vae_model.safetensors \
+ --image_encoder path/to/image_encoder_model.safetensors \
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
+```
+
+Key differences from HunyuanVideo caching:
+- Uses `fpack_cache_latents.py`.
+- Requires the `--image_encoder` argument pointing to the downloaded SigLIP model.
+- The script generates multiple cache files per video, each corresponding to a different section, with the section index appended to the filename (e.g., `..._frame_pos-0000-count_...` becomes `..._frame_pos-0000-0000-count_...`, `..._frame_pos-0000-0001-count_...`, etc.).
+- Image embeddings are calculated using the Image Encoder and stored in the cache files alongside the latents.
+
+For VRAM savings during VAE decoding, consider using `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size`. If VRAM is overflowing and using shared memory, it is recommended to set `--vae_chunk_size` to 16 or 8, and `--vae_spatial_tile_sample_min_size` to 64 or 32.
+
+Specifying `--f1` is required for FramePack-F1 training. For one-frame training, specify `--one_frame`. If you change the presence of these options, please overwrite the existing cache without specifying `--skip_existing`.
+
+`--one_frame_no_2x` and `--one_frame_no_4x` options are available for one-frame training, described in the next section.
+
+**FramePack-F1 support:**
+You can apply the FramePack-F1 sampling method by specifying `--f1` during caching. The training script also requires specifying `--f1` to change the options during sample generation.
+
+By default, the sampling method used is Inverted anti-drifting (the same as during inference with the original FramePack model, using the latent and index in reverse order), described in the paper. You can switch to FramePack-F1 sampling (Vanilla sampling, using the temporally ordered latent and index) by specifying `--f1`.
+
+
+日本語
+
+FramePackのデフォルト解像度は640x640です。各バケットのデフォルト解像度については、[ソースコード](../frame_pack/bucket_tools.py)を参照してください。
+
+画像データセットでの学習は行えません。また動画の長さによらず学習可能です。 `frame_extraction` に `full` を指定して、`max_frames` に十分に大きな値を指定してください。ただし、あまりにも長いとVAEのencodeでVRAMが不足する可能性があります。
+
+latentの事前キャッシングはFramePack専用のスクリプトを使用します。画像エンコーダーモデルを指定する必要があります。
+
+HunyuanVideoのキャッシングとの主な違いは次のとおりです。
+- `fpack_cache_latents.py`を使用します。
+- ダウンロードしたSigLIPモデルを指す`--image_encoder`引数が必要です。
+- スクリプトは、各ビデオに対して複数のキャッシュファイルを生成します。各ファイルは異なるセクションに対応し、セクションインデックスがファイル名に追加されます(例:`..._frame_pos-0000-count_...`は`..._frame_pos-0000-0000-count_...`、`..._frame_pos-0000-0001-count_...`などになります)。
+- 画像埋め込みは画像エンコーダーを使用して計算され、latentとともにキャッシュファイルに保存されます。
+
+VAEのdecode時のVRAM節約のために、`--vae_chunk_size`と`--vae_spatial_tile_sample_min_size`を使用することを検討してください。VRAMがあふれて共有メモリを使用している場合には、`--vae_chunk_size`を16、8などに、`--vae_spatial_tile_sample_min_size`を64、32などに変更することをお勧めします。
+
+FramePack-F1の学習を行う場合は`--f1`を指定してください。これらのオプションの有無を変更する場合には、`--skip_existing`を指定せずに既存のキャッシュを上書きしてください。
+
+**FramePack-F1のサポート:**
+キャッシュ時のオプションに`--f1`を指定することで、FramePack-F1のサンプリング方法を適用できます。学習スクリプトについても`--f1`を指定してサンプル生成時のオプションを変更する必要があります。
+
+デフォルトでは、論文のサンプリング方法 Inverted anti-drifting (無印のFramePackの推論時と同じ、逆順の latent と index を使用)を使用します。`--f1`を指定すると FramePack-F1 の Vanilla sampling (時間順の latent と index を使用)に変更できます。
+
+
+### Text Encoder Output Pre-caching / テキストエンコーダー出力の事前キャッシング
+
+Text encoder output pre-caching also uses a dedicated script.
+
+```bash
+python fpack_cache_text_encoder_outputs.py \
+ --dataset_config path/to/toml \
+ --text_encoder1 path/to/text_encoder1 \
+ --text_encoder2 path/to/text_encoder2 \
+ --batch_size 16
+```
+
+Key differences from HunyuanVideo caching:
+- Uses `fpack_cache_text_encoder_outputs.py`.
+- Requires both `--text_encoder1` (LLaMA) and `--text_encoder2` (CLIP) arguments.
+- Uses `--fp8_llm` option to run the LLaMA Text Encoder 1 in fp8 mode for VRAM savings (similar to `--fp8_t5` in Wan2.1).
+- Saves LLaMA embeddings, attention mask, and CLIP pooler output to the cache file.
+
+
+日本語
+
+テキストエンコーダー出力の事前キャッシングも専用のスクリプトを使用します。
+
+HunyuanVideoのキャッシングとの主な違いは次のとおりです。
+- `fpack_cache_text_encoder_outputs.py`を使用します。
+- LLaMAとCLIPの両方の引数が必要です。
+- LLaMAテキストエンコーダー1をfp8モードで実行するための`--fp8_llm`オプションを使用します(Wan2.1の`--fp8_t5`に似ています)。
+- LLaMAの埋め込み、アテンションマスク、CLIPのプーラー出力をキャッシュファイルに保存します。
+
+
+
+
+## Training / 学習
+
+### Training
+
+Training uses a dedicated script `fpack_train_network.py`. Remember FramePack only supports I2V training.
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 fpack_train_network.py \
+ --dit path/to/dit_model \
+ --vae path/to/vae_model.safetensors \
+ --text_encoder1 path/to/text_encoder1 \
+ --text_encoder2 path/to/text_encoder2 \
+ --image_encoder path/to/image_encoder_model.safetensors \
+ --dataset_config path/to/toml \
+ --sdpa --mixed_precision bf16 \
+ --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
+ --timestep_sampling shift --weighting_scheme none --discrete_flow_shift 3.0 \
+ --max_data_loader_n_workers 2 --persistent_data_loader_workers \
+ --network_module networks.lora_framepack --network_dim 32 \
+ --max_train_epochs 16 --save_every_n_epochs 1 --seed 42 \
+ --output_dir path/to/output_dir --output_name name-of-lora
+```
+
+If you use the command prompt (Windows, not PowerShell), you may need to write them in a single line, or use `^` instead of `\` at the end of each line to continue the command.
+
+The maximum value for `--blocks_to_swap` is 36. The default resolution for FramePack is 640x640, which requires around 17GB of VRAM. If you run out of VRAM, consider lowering the dataset resolution.
+
+Key differences from HunyuanVideo training:
+- Uses `fpack_train_network.py`.
+- `--f1` option is available for FramePack-F1 model training. You need to specify the FramePack-F1 model as `--dit`. This option only changes the sample generation during training. The training process itself is the same as the original FramePack model.
+- **Requires** specifying `--vae`, `--text_encoder1`, `--text_encoder2`, and `--image_encoder`.
+- **Requires** specifying `--network_module networks.lora_framepack`.
+- Optional `--latent_window_size` argument (default 9, should match caching).
+- Memory saving options like `--fp8` (for DiT) and `--fp8_llm` (for Text Encoder 1) are available. `--fp8_scaled` is recommended when using `--fp8` for DiT.
+- `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size` options are available for the VAE to prevent out-of-memory during sampling (similar to caching).
+- `--gradient_checkpointing` is available for memory savings.
+- If you encounter an error when the batch size is greater than 1 (especially when specifying `--sdpa` or `--xformers`, it will always result in an error), please specify `--split_attn`.
+
+
+Training settings (learning rate, optimizers, etc.) are experimental. Feedback is welcome.
+
+
+日本語
+
+FramePackの学習は専用のスクリプト`fpack_train_network.py`を使用します。FramePackはI2V学習のみをサポートしています。
+
+コマンド記述例は英語版を参考にしてください。WindowsでPowerShellではなくコマンドプロンプトを使用している場合、コマンドを1行で記述するか、各行の末尾に`\`の代わりに`^`を付けてコマンドを続ける必要があります。
+
+`--blocks_to_swap`の最大値は36です。FramePackのデフォルト解像度(640x640)では、17GB程度のVRAMが必要です。VRAM容量が不足する場合は、データセットの解像度を下げてください。
+
+HunyuanVideoの学習との主な違いは次のとおりです。
+- `fpack_train_network.py`を使用します。
+- FramePack-F1モデルの学習時には`--f1`を指定してください。この場合、`--dit`にFramePack-F1モデルを指定する必要があります。このオプションは学習時のサンプル生成時のみに影響し、学習プロセス自体は元のFramePackモデルと同じです。
+- `--vae`、`--text_encoder1`、`--text_encoder2`、`--image_encoder`を指定する必要があります。
+- `--network_module networks.lora_framepack`を指定する必要があります。
+- 必要に応じて`--latent_window_size`引数(デフォルト9)を指定できます(キャッシング時と一致させる必要があります)。
+- `--fp8`(DiT用)や`--fp8_llm`(テキストエンコーダー1用)などのメモリ節約オプションが利用可能です。`--fp8_scaled`を使用することをお勧めします。
+- サンプル生成時にメモリ不足を防ぐため、VAE用の`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`オプションが利用可能です(キャッシング時と同様)。
+- メモリ節約のために`--gradient_checkpointing`が利用可能です。
+- バッチサイズが1より大きい場合にエラーが出た時には(特に`--sdpa`や`--xformers`を指定すると必ずエラーになります。)、`--split_attn`を指定してください。
+
+
+
+## Inference
+
+Inference uses a dedicated script `fpack_generate_video.py`.
+
+```bash
+python fpack_generate_video.py \
+ --dit path/to/dit_model \
+ --vae path/to/vae_model.safetensors \
+ --text_encoder1 path/to/text_encoder1 \
+ --text_encoder2 path/to/text_encoder2 \
+ --image_encoder path/to/image_encoder_model.safetensors \
+ --image_path path/to/start_image.jpg \
+ --prompt "A cat walks on the grass, realistic style." \
+ --video_size 512 768 --video_seconds 5 --fps 30 --infer_steps 25 \
+ --attn_mode sdpa --fp8_scaled \
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
+ --save_path path/to/save/dir --output_type both \
+ --seed 1234 --lora_multiplier 1.0 --lora_weight path/to/lora.safetensors
+```
+
+
+Key differences from HunyuanVideo inference:
+- Uses `fpack_generate_video.py`.
+- `--f1` option is available for FramePack-F1 model inference (forward generation). You need to specify the FramePack-F1 model as `--dit`.
+- **Requires** specifying `--vae`, `--text_encoder1`, `--text_encoder2`, and `--image_encoder`.
+- **Requires** specifying `--image_path` for the starting frame.
+- **Requires** specifying `--video_seconds` or `--video_sections`. `--video_seconds` specifies the length of the video in seconds, while `--video_sections` specifies the number of sections. If `--video_sections` is specified, `--video_seconds` is ignored.
+- `--video_size` is the size of the generated video, height and width are specified in that order.
+- `--prompt`: Prompt for generation.
+- Optional `--latent_window_size` argument (default 9, should match caching and training).
+- `--fp8_scaled` option is available for DiT to reduce memory usage. Quality may be slightly lower. `--fp8_llm` option is available to reduce memory usage of Text Encoder 1. `--fp8` alone is also an option for DiT but `--fp8_scaled` potentially offers better quality.
+- LoRA loading options (`--lora_weight`, `--lora_multiplier`, `--include_patterns`, `--exclude_patterns`) are available. `--lycoris` is also supported.
+- `--embedded_cfg_scale` (default 10.0) controls the distilled guidance scale.
+- `--guidance_scale` (default 1.0) controls the standard classifier-free guidance scale. **Changing this from 1.0 is generally not recommended for the base FramePack model.**
+- `--guidance_rescale` (default 0.0) is available but typically not needed.
+- `--bulk_decode` option can decode all frames at once, potentially faster but uses more VRAM during decoding. `--vae_chunk_size` and `--vae_spatial_tile_sample_min_size` options are recommended to prevent out-of-memory errors.
+- `--sample_solver` (default `unipc`) is available but only `unipc` is implemented.
+- `--save_merged_model` option is available to save the DiT model after merging LoRA weights. Inference is skipped if this is specified.
+- `--latent_paddings` option overrides the default padding for each section. Specify it as a comma-separated list of integers, e.g., `--latent_paddings 0,0,0,0`. This option is ignored if `--f1` is specified.
+- `--custom_system_prompt` option overrides the default system prompt for the LLaMA Text Encoder 1. Specify it as a string. See [here](../hunyuan_model/text_encoder.py#L152) for the default system prompt.
+- `--rope_scaling_timestep_threshold` option is the RoPE scaling timestep threshold, default is None (disabled). If set, RoPE scaling is applied only when the timestep exceeds the threshold. Start with around 800 and adjust as needed. This option is intended for one-frame inference and may not be suitable for other cases.
+- `--rope_scaling_factor` option is the RoPE scaling factor, default is 0.5, assuming a resolution of 2x. For 1.5x resolution, around 0.7 is recommended.
+
+Other options like `--video_size`, `--fps`, `--infer_steps`, `--save_path`, `--output_type`, `--seed`, `--attn_mode`, `--blocks_to_swap`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size` function similarly to HunyuanVideo/Wan2.1 where applicable.
+
+`--output_type` supports `latent_images` in addition to the options available in HunyuanVideo/Wan2.1. This option saves the latent and image files in the specified directory.
+
+The LoRA weights that can be specified in `--lora_weight` are not limited to the FramePack weights trained in this repository. You can also specify the HunyuanVideo LoRA weights from this repository and the HunyuanVideo LoRA weights from diffusion-pipe (automatic detection).
+
+The maximum value for `--blocks_to_swap` is 38.
+
+
+日本語
+
+FramePackの推論は専用のスクリプト`fpack_generate_video.py`を使用します。コマンド記述例は英語版を参考にしてください。
+
+HunyuanVideoの推論との主な違いは次のとおりです。
+- `fpack_generate_video.py`を使用します。
+- `--f1`を指定すると、FramePack-F1モデルの推論を行います(順方向で生成)。`--dit`にFramePack-F1モデルを指定する必要があります。
+- `--vae`、`--text_encoder1`、`--text_encoder2`、`--image_encoder`を指定する必要があります。
+- `--image_path`を指定する必要があります(開始フレーム)。
+- `--video_seconds` または `--video_sections` を指定する必要があります。`--video_seconds`は秒単位でのビデオの長さを指定し、`--video_sections`はセクション数を指定します。`--video_sections`を指定した場合、`--video_seconds`は無視されます。
+- `--video_size`は生成するビデオのサイズで、高さと幅をその順番で指定します。
+- `--prompt`: 生成用のプロンプトです。
+- 必要に応じて`--latent_window_size`引数(デフォルト9)を指定できます(キャッシング時、学習時と一致させる必要があります)。
+- DiTのメモリ使用量を削減するために、`--fp8_scaled`オプションを指定可能です。品質はやや低下する可能性があります。またText Encoder 1のメモリ使用量を削減するために、`--fp8_llm`オプションを指定可能です。DiT用に`--fp8`単独のオプションも用意されていますが、`--fp8_scaled`の方が品質が良い可能性があります。
+- LoRAの読み込みオプション(`--lora_weight`、`--lora_multiplier`、`--include_patterns`、`--exclude_patterns`)が利用可能です。LyCORISもサポートされています。
+- `--embedded_cfg_scale`(デフォルト10.0)は、蒸留されたガイダンススケールを制御します。通常は変更しないでください。
+- `--guidance_scale`(デフォルト1.0)は、標準の分類器フリーガイダンススケールを制御します。**FramePackモデルのベースモデルでは、通常1.0から変更しないことをお勧めします。**
+- `--guidance_rescale`(デフォルト0.0)も利用可能ですが、通常は必要ありません。
+- `--bulk_decode`オプションは、すべてのフレームを一度にデコードできるオプションです。高速ですが、デコード中にVRAMを多く使用します。VRAM不足エラーを防ぐために、`--vae_chunk_size`と`--vae_spatial_tile_sample_min_size`オプションを指定することをお勧めします。
+- `--sample_solver`(デフォルト`unipc`)は利用可能ですが、`unipc`のみが実装されています。
+- `--save_merged_model`オプションは、LoRAの重みをマージした後にDiTモデルを保存するためのオプションです。これを指定すると推論はスキップされます。
+- `--latent_paddings`オプションは、各セクションのデフォルトのパディングを上書きします。カンマ区切りの整数リストとして指定します。例:`--latent_paddings 0,0,0,0`。`--f1`を指定した場合は無視されます。
+- `--custom_system_prompt`オプションは、LLaMA Text Encoder 1のデフォルトのシステムプロンプトを上書きします。文字列として指定します。デフォルトのシステムプロンプトは[こちら](../hunyuan_model/text_encoder.py#L152)を参照してください。
+- `--rope_scaling_timestep_threshold`オプションはRoPEスケーリングのタイムステップ閾値で、デフォルトはNone(無効)です。設定すると、タイムステップが閾値以上の場合にのみRoPEスケーリングが適用されます。800程度から初めて調整してください。1フレーム推論時での使用を想定しており、それ以外の場合は想定していません。
+- `--rope_scaling_factor`オプションはRoPEスケーリング係数で、デフォルトは0.5で、解像度が2倍の場合を想定しています。1.5倍なら0.7程度が良いでしょう。
+
+`--video_size`、`--fps`、`--infer_steps`、`--save_path`、`--output_type`、`--seed`、`--attn_mode`、`--blocks_to_swap`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`などの他のオプションは、HunyuanVideo/Wan2.1と同様に機能します。
+
+`--lora_weight`に指定できるLoRAの重みは、当リポジトリで学習したFramePackの重み以外に、当リポジトリのHunyuanVideoのLoRA、diffusion-pipeのHunyuanVideoのLoRAが指定可能です(自動判定)。
+
+`--blocks_to_swap`の最大値は38です。
+
+
+## Batch and Interactive Modes / バッチモードとインタラクティブモード
+
+In addition to single video generation, FramePack now supports batch generation from file and interactive prompt input:
+
+### Batch Mode from File / ファイルからのバッチモード
+
+Generate multiple videos from prompts stored in a text file:
+
+```bash
+python fpack_generate_video.py --from_file prompts.txt
+--dit path/to/dit_model --vae path/to/vae_model.safetensors
+--text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2
+--image_encoder path/to/image_encoder_model.safetensors --save_path output_directory
+```
+
+The prompts file format:
+- One prompt per line
+- Empty lines and lines starting with # are ignored (comments)
+- Each line can include prompt-specific parameters using command-line style format:
+
+```
+A beautiful sunset over mountains --w 832 --h 480 --f 5 --d 42 --s 20 --i path/to/start_image.jpg
+A busy city street at night --w 480 --h 832 --i path/to/another_start.jpg
+```
+
+Supported inline parameters (if omitted, default values from the command line are used):
+- `--w`: Width
+- `--h`: Height
+- `--f`: Video seconds
+- `--d`: Seed
+- `--s`: Inference steps
+- `--g` or `--l`: Guidance scale
+- `--i`: Image path (for start image)
+- `--im`: Image mask path
+- `--n`: Negative prompt
+- `--vs`: Video sections
+- `--ei`: End image path
+- `--ci`: Control image path (explained in one-frame inference documentation)
+- `--cim`: Control image mask path (explained in one-frame inference documentation)
+- `--of`: One frame inference mode options (same as `--one_frame_inference` in the command line), options for one-frame inference
+
+In batch mode, models are loaded once and reused for all prompts, significantly improving overall generation time compared to multiple single runs.
+
+### Interactive Mode / インタラクティブモード
+
+Interactive command-line interface for entering prompts:
+
+```bash
+python fpack_generate_video.py --interactive
+--dit path/to/dit_model --vae path/to/vae_model.safetensors
+--text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2
+--image_encoder path/to/image_encoder_model.safetensors --save_path output_directory
+```
+
+In interactive mode:
+- Enter prompts directly at the command line
+- Use the same inline parameter format as batch mode
+- Use Ctrl+D (or Ctrl+Z on Windows) to exit
+- Models remain loaded between generations for efficiency
+
+
+日本語
+
+単一動画の生成に加えて、FramePackは現在、ファイルからのバッチ生成とインタラクティブなプロンプト入力をサポートしています。
+
+#### ファイルからのバッチモード
+
+テキストファイルに保存されたプロンプトから複数の動画を生成します:
+
+```bash
+python fpack_generate_video.py --from_file prompts.txt
+--dit path/to/dit_model --vae path/to/vae_model.safetensors
+--text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2
+--image_encoder path/to/image_encoder_model.safetensors --save_path output_directory
+```
+
+プロンプトファイルの形式(サンプルは英語ドキュメントを参照):
+- 1行に1つのプロンプト
+- 空行や#で始まる行は無視されます(コメント)
+- 各行にはコマンドライン形式でプロンプト固有のパラメータを含めることができます:
+
+サポートされているインラインパラメータ(省略した場合、コマンドラインのデフォルト値が使用されます)
+- `--w`: 幅
+- `--h`: 高さ
+- `--f`: 動画の秒数
+- `--d`: シード
+- `--s`: 推論ステップ
+- `--g` または `--l`: ガイダンススケール
+- `--i`: 画像パス(開始画像用)
+- `--im`: 画像マスクパス
+- `--n`: ネガティブプロンプト
+- `--vs`: 動画セクション数
+- `--ei`: 終了画像パス
+- `--ci`: 制御画像パス(1フレーム推論のドキュメントで解説)
+- `--cim`: 制御画像マスクパス(1フレーム推論のドキュメントで解説)
+- `--of`: 1フレーム推論モードオプション(コマンドラインの`--one_frame_inference`と同様、1フレーム推論のオプション)
+
+バッチモードでは、モデルは一度だけロードされ、すべてのプロンプトで再利用されるため、複数回の単一実行と比較して全体的な生成時間が大幅に改善されます。
+
+#### インタラクティブモード
+
+プロンプトを入力するためのインタラクティブなコマンドラインインターフェース:
+
+```bash
+python fpack_generate_video.py --interactive
+--dit path/to/dit_model --vae path/to/vae_model.safetensors
+--text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2
+--image_encoder path/to/image_encoder_model.safetensors --save_path output_directory
+```
+
+インタラクティブモードでは:
+- コマンドラインで直接プロンプトを入力
+- バッチモードと同じインラインパラメータ形式を使用
+- 終了するには Ctrl+D (Windowsでは Ctrl+Z) を使用
+- 効率のため、モデルは生成間で読み込まれたままになります
+
+
+## Advanced Video Control Features (Experimental) / 高度なビデオ制御機能(実験的)
+
+This section describes experimental features added to the `fpack_generate_video.py` script to provide finer control over the generated video content, particularly useful for longer videos or sequences requiring specific transitions or states. These features leverage the Inverted Anti-drifting sampling method inherent to FramePack.
+
+### **1. End Image Guidance (`--end_image_path`)**
+
+* **Functionality:** Guides the generation process to make the final frame(s) of the video resemble a specified target image.
+* **Usage:** `--end_image_path `
+* **Mechanism:** The provided image is encoded using the VAE. This latent representation is used as a target or starting point during the generation of the final video section (which is the first step in Inverted Anti-drifting).
+* **Use Cases:** Defining a clear ending for the video, such as a character striking a specific pose or a product appearing in a close-up.
+
+This option is ignored if `--f1` is specified. The end image is not used in the FramePack-F1 model.
+
+### **2. Section Start Image Guidance (`--image_path` Extended Format)**
+
+* **Functionality:** Guides specific sections within the video to start with a visual state close to a provided image.
+ * You can force the start image by setting `--latent_paddings` to `0,0,0,0` (specify the number of sections as a comma-separated list). If `latent_paddings` is set to 1 or more, the specified image will be used as a reference image (default behavior).
+* **Usage:** `--image_path "SECTION_SPEC:path/to/image.jpg;;;SECTION_SPEC:path/to/another.jpg;;;..."`
+ * `SECTION_SPEC`: Defines the target section(s). Rules:
+ * `0`: The first section of the video (generated last in Inverted Anti-drifting).
+ * `-1`: The last section of the video (generated first).
+ * `N` (non-negative integer): The N-th section (0-indexed).
+ * `-N` (negative integer): The N-th section from the end.
+ * `S-E` (range, e.g., `0-2`): Applies the same image guidance to sections S through E (inclusive).
+ * Use `;;;` as a separator between definitions.
+ * If no image is specified for a section, generation proceeds based on the prompt and preceding (future time) section context.
+* **Mechanism:** When generating a specific section, if a corresponding start image is provided, its VAE latent representation is strongly referenced as the "initial state" for that section. This guides the beginning of the section towards the specified image while attempting to maintain temporal consistency with the subsequent (already generated) section.
+* **Use Cases:** Defining clear starting points for scene changes, specifying character poses or attire at the beginning of certain sections.
+
+### **3. Section-Specific Prompts (`--prompt` Extended Format)**
+
+* **Functionality:** Allows providing different text prompts for different sections of the video, enabling more granular control over the narrative or action flow.
+* **Usage:** `--prompt "SECTION_SPEC:Prompt text for section(s);;;SECTION_SPEC:Another prompt;;;..."`
+ * `SECTION_SPEC`: Uses the same rules as `--image_path`.
+ * Use `;;;` as a separator.
+ * If a prompt for a specific section is not provided, the prompt associated with index `0` (or the closest specified applicable prompt) is typically used. Check behavior if defaults are critical.
+* **Mechanism:** During the generation of each section, the corresponding section-specific prompt is used as the primary textual guidance for the model.
+* **Prompt Content Recommendation** when using `--latent_paddings 0,0,0,0` without `--f1` (original FramePack model):
+ * Recall that FramePack uses Inverted Anti-drifting and references future context.
+ * It is recommended to describe "**the main content or state change that should occur in the current section, *and* the subsequent events or states leading towards the end of the video**" in the prompt for each section.
+ * Including the content of subsequent sections in the current section's prompt helps the model maintain context and overall coherence.
+ * Example: For section 1, the prompt might describe what happens in section 1 *and* briefly summarize section 2 (and beyond).
+ * However, based on observations (e.g., the `latent_paddings` comment), the model's ability to perfectly utilize very long-term context might be limited. Experimentation is key. Describing just the "goal for the current section" might also work. Start by trying the "section and onwards" approach.
+* Use the default prompt when `latent_paddings` is >= 1 or `--latent_paddings` is not specified, or when using `--f1` (FramePack-F1 model).
+* **Use Cases:** Describing evolving storylines, gradual changes in character actions or emotions, step-by-step processes over time.
+
+### **Combined Usage Example** (with `--f1` not specified)
+
+Generating a 3-section video of "A dog runs towards a thrown ball, catches it, and runs back":
+
+```bash
+python fpack_generate_video.py \
+ --prompt "0:A dog runs towards a thrown ball, catches it, and runs back;;;1:The dog catches the ball and then runs back towards the viewer;;;2:The dog runs back towards the viewer holding the ball" \
+ --image_path "0:./img_start_running.png;;;1:./img_catching.png;;;2:./img_running_back.png" \
+ --end_image_path ./img_returned.png \
+ --save_path ./output \
+ # ... other arguments
+```
+
+* **Generation Order:** Section 2 -> Section 1 -> Section 0
+* **Generating Section 2:**
+ * Prompt: "The dog runs back towards the viewer holding the ball"
+ * Start Image: `./img_running_back.png`
+ * End Image: `./img_returned.png` (Initial target)
+* **Generating Section 1:**
+ * Prompt: "The dog catches the ball and then runs back towards the viewer"
+ * Start Image: `./img_catching.png`
+ * Future Context: Generated Section 2 latent
+* **Generating Section 0:**
+ * Prompt: "A dog runs towards a thrown ball, catches it, and runs back"
+ * Start Image: `./img_start_running.png`
+ * Future Context: Generated Section 1 & 2 latents
+
+### **Important Considerations**
+
+* **Inverted Generation:** Always remember that generation proceeds from the end of the video towards the beginning. Section `-1` (the last section, `2` in the example) is generated first.
+* **Continuity vs. Guidance:** While start image guidance is powerful, drastically different images between sections might lead to unnatural transitions. Balance guidance strength with the need for smooth flow.
+* **Prompt Optimization:** The prompt content recommendation is a starting point. Fine-tune prompts based on observed model behavior and desired output quality.
+
+
+日本語
+
+### **高度な動画制御機能(実験的)**
+
+このセクションでは、`fpack_generate_video.py` スクリプトに追加された実験的な機能について説明します。これらの機能は、生成される動画の内容をより詳細に制御するためのもので、特に長い動画や特定の遷移・状態が必要なシーケンスに役立ちます。これらの機能は、FramePack固有のInverted Anti-driftingサンプリング方式を活用しています。
+
+#### **1. 終端画像ガイダンス (`--end_image_path`)**
+
+* **機能:** 動画の最後のフレーム(群)を指定したターゲット画像に近づけるように生成を誘導します。
+* **書式:** `--end_image_path <画像ファイルパス>`
+* **動作:** 指定された画像はVAEでエンコードされ、その潜在表現が動画の最終セクション(Inverted Anti-driftingでは最初に生成される)の生成時の目標または開始点として使用されます。
+* **用途:** キャラクターが特定のポーズで終わる、特定の商品がクローズアップで終わるなど、動画の結末を明確に定義する場合。
+
+このオプションは、`--f1`を指定した場合は無視されます。FramePack-F1モデルでは終端画像は使用されません。
+
+#### **2. セクション開始画像ガイダンス (`--image_path` 拡張書式)**
+
+* **機能:** 動画内の特定のセクションが、指定された画像に近い視覚状態から始まるように誘導します。
+ * `--latent_paddings`を`0,0,0,0`(カンマ区切りでセクション数だけ指定)に設定することで、セクションの開始画像を強制できます。`latent_paddings`が1以上の場合、指定された画像は参照画像として使用されます。
+* **書式:** `--image_path "セクション指定子:画像パス;;;セクション指定子:別の画像パス;;;..."`
+ * `セクション指定子`: 対象セクションを定義します。ルール:
+ * `0`: 動画の最初のセクション(Inverted Anti-driftingでは最後に生成)。
+ * `-1`: 動画の最後のセクション(最初に生成)。
+ * `N`(非負整数): N番目のセクション(0始まり)。
+ * `-N`(負整数): 最後からN番目のセクション。
+ * `S-E`(範囲, 例:`0-2`): セクションSからE(両端含む)に同じ画像を適用。
+ * 区切り文字は `;;;` です。
+ * セクションに画像が指定されていない場合、プロンプトと後続(未来時刻)セクションのコンテキストに基づいて生成されます。
+* **動作:** 特定セクションの生成時、対応する開始画像が指定されていれば、そのVAE潜在表現がそのセクションの「初期状態」として強く参照されます。これにより、後続(生成済み)セクションとの時間的連続性を維持しようとしつつ、セクションの始まりを指定画像に近づけます。
+* **用途:** シーン変更の起点を明確にする、特定のセクション開始時のキャラクターのポーズや服装を指定するなど。
+
+#### **3. セクション別プロンプト (`--prompt` 拡張書式)**
+
+* **機能:** 動画のセクションごとに異なるテキストプロンプトを与え、物語やアクションの流れをより細かく指示できます。
+* **書式:** `--prompt "セクション指定子:プロンプトテキスト;;;セクション指定子:別のプロンプト;;;..."`
+ * `セクション指定子`: `--image_path` と同じルールです。
+ * 区切り文字は `;;;` です。
+ * 特定セクションのプロンプトがない場合、通常はインデックス`0`に関連付けられたプロンプト(または最も近い適用可能な指定プロンプト)が使用されます。デフォルトの挙動が重要な場合は確認してください。
+* **動作:** 各セクションの生成時、対応するセクション別プロンプトがモデルへの主要なテキスト指示として使用されます。
+* `latent_paddings`に`0`を指定した場合(非F1モデル)の **プロンプト内容の推奨:**
+ * FramePackはInverted Anti-driftingを採用し、未来のコンテキストを参照することを思い出してください。
+ * 各セクションのプロンプトには、「**現在のセクションで起こるべき主要な内容や状態変化、*および*それに続く動画の終端までの内容**」を記述することを推奨します。
+ * 現在のセクションのプロンプトに後続セクションの内容を含めることで、モデルが全体的な文脈を把握し、一貫性を保つのに役立ちます。
+ * 例:セクション1のプロンプトには、セクション1の内容 *と* セクション2の簡単な要約を記述します。
+ * ただし、モデルの長期コンテキスト完全利用能力には限界がある可能性も示唆されています(例:`latent_paddings`コメント)。実験が鍵となります。「現在のセクションの目標」のみを記述するだけでも機能する場合があります。まずは「セクションと以降」アプローチを試すことをお勧めします。
+* 使用するプロンプトは、`latent_paddings`が`1`以上または指定されていない場合、または`--f1`(FramePack-F1モデル)を使用している場合は、通常のプロンプト内容を記述してください。
+* **用途:** 時間経過に伴うストーリーの変化、キャラクターの行動や感情の段階的な変化、段階的なプロセスなどを記述する場合。
+
+#### **組み合わせ使用例** (`--f1`未指定時)
+
+「投げられたボールに向かって犬が走り、それを捕まえ、走って戻ってくる」3セクション動画の生成:
+(コマンド記述例は英語版を参考にしてください)
+
+* **生成順序:** セクション2 → セクション1 → セクション0
+* **セクション2生成時:**
+ * プロンプト: "犬がボールを咥えてこちらに向かって走ってくる"
+ * 開始画像: `./img_running_back.png`
+ * 終端画像: `./img_returned.png` (初期目標)
+* **セクション1生成時:**
+ * プロンプト: "犬がボールを捕まえ、その後こちらに向かって走ってくる"
+ * 開始画像: `./img_catching.png`
+ * 未来コンテキスト: 生成済みセクション2の潜在表現
+* **セクション0生成時:**
+ * プロンプト: "犬が投げられたボールに向かって走り、それを捕まえ、走って戻ってくる"
+ * 開始画像: `./img_start_running.png`
+ * 未来コンテキスト: 生成済みセクション1 & 2の潜在表現
+
+#### **重要な考慮事項**
+
+* **逆順生成:** 生成は動画の終わりから始まりに向かって進むことを常に意識してください。セクション`-1`(最後のセクション、上の例では `2`)が最初に生成されます。
+* **連続性とガイダンスのバランス:** 開始画像ガイダンスは強力ですが、セクション間で画像が大きく異なると、遷移が不自然になる可能性があります。ガイダンスの強さとスムーズな流れの必要性のバランスを取ってください。
+* **プロンプトの最適化:** 推奨されるプロンプト内容はあくまでも参考です。モデルの観察された挙動と望ましい出力品質に基づいてプロンプトを微調整してください。
+
+
diff --git a/docs/framepack_1f.md b/docs/framepack_1f.md
new file mode 100644
index 0000000000000000000000000000000000000000..66c0fcb2e93c7ab9dba335d2d0c6b37dccc03b51
--- /dev/null
+++ b/docs/framepack_1f.md
@@ -0,0 +1,359 @@
+# FramePack One Frame (Single Frame) Inference and Training / FramePack 1フレーム推論と学習
+
+## Overview / 概要
+
+This document explains advanced inference and training methods using the FramePack model, particularly focusing on **"1-frame inference"** and its extensions. These features aim to leverage FramePack's flexibility to enable diverse image generation and editing tasks beyond simple video generation.
+
+### The Concept and Development of 1-Frame Inference
+
+While FramePack is originally a model for generating sequential video frames (or frame sections), it was discovered that by focusing on its internal structure, particularly how it handles temporal information with RoPE (Rotary Position Embedding), interesting control over single-frame generation is possible.
+
+1. **Basic 1-Frame Inference**:
+ * It takes an initial image and a prompt as input, limiting the number of generated frames to just one.
+ * In this process, by intentionally setting a large RoPE timestamp (`target_index`) for the single frame to be generated, a single static image can be obtained that reflects temporal and semantic changes from the initial image according to the prompt.
+ * This utilizes FramePack's characteristic of being highly sensitive to RoPE timestamps, as it supports bidirectional contexts like "Inverted anti-drifting." This allows for operations similar to natural language-based image editing, albeit in a limited capacity, without requiring additional training.
+
+2. **Kisekaeichi Method (Feature Merging via Post-Reference)**:
+ * This method, an extension of basic 1-frame inference, was **proposed by furusu**. In addition to the initial image, it also uses a reference image corresponding to a "next section-start image" (treated as `clean_latent_post`) as input.
+ * The RoPE timestamp (`target_index`) for the image to be generated is set to an intermediate value between the timestamps of the initial image and the section-end image.
+ * More importantly, masking (e.g., zeroing out specific regions) is applied to the latent representation of each reference image. For example, by setting masks to extract a character's face and body shape from the initial image and clothing textures from the reference image, an image can be generated that fuses the desired features of both, similar to a character "dress-up" or outfit swapping. This method can also be fundamentally achieved without additional training.
+
+3. **1f-mc (one frame multi-control) Method (Proximal Frame Blending)**:
+ * This method was **proposed by mattyamonaca**. It takes two reference images as input: an initial image (e.g., at `t=0`) and a subsequent image (e.g., at `t=1`, the first frame of a section), and generates a single image blending their features.
+ * Unlike Kisekaeichi, latent masking is typically not performed.
+ * To fully leverage this method, additional training using LoRA (Low-Rank Adaptation) is recommended. Through training, the model can better learn the relationship and blending method between the two input images to achieve specific editing effects.
+
+### Integration into a Generalized Control Framework
+
+The concepts utilized in the methods above—specifying reference images, manipulating timestamps, and applying latent masks—have been generalized to create a more flexible control framework.
+Users can arbitrarily specify the following elements for both inference and LoRA training:
+
+* **Control Images**: Any set of input images intended to influence the model.
+* **Clean Latent Index (Indices)**: Timestamps corresponding to each control image. These are treated as `clean latent index` internally by FramePack and can be set to any position on the time axis. This is specified as `control_index`.
+* **Latent Masks**: Masks applied to the latent representation of each control image, allowing selective control over which features from the control images are utilized. This is specified as `control_image_mask_path` or the alpha channel of the control image.
+* **Target Index**: The timestamp for the single frame to be generated.
+
+This generalized control framework, along with corresponding extensions to the inference and LoRA training tools, has enabled advanced applications such as:
+
+* Development of LoRAs that stabilize 1-frame inference effects (e.g., a camera orbiting effect) that were previously unstable with prompts alone.
+* Development of Kisekaeichi LoRAs that learn to perform desired feature merging under specific conditions (e.g., ignoring character information from a clothing reference image), thereby automating the masking process through learning.
+
+These features maximize FramePack's potential and open up new creative possibilities in static image generation and editing. Subsequent sections will detail the specific options for utilizing these functionalities.
+
+
+日本語
+
+このドキュメントでは、FramePackモデルを用いた高度な推論および学習手法、特に「1フレーム推論」とその拡張機能について解説します。これらの機能は、FramePackの柔軟性を活かし、動画生成に留まらない多様な画像生成・編集タスクを実現することを目的としています。
+
+### 1フレーム推論の発想と発展
+
+FramePackは本来、連続する動画フレーム(またはフレームセクション)を生成するモデルですが、その内部構造、特に時間情報を扱うRoPE (Rotary Position Embedding) の扱いに着目することで、単一フレームの生成においても興味深い制御が可能になることが発見されました。
+
+1. **基本的な1フレーム推論**:
+ * 開始画像とプロンプトを入力とし、生成するフレーム数を1フレームに限定します。
+ * この際、生成する1フレームに割り当てるRoPEのタイムスタンプ(`target_index`)を意図的に大きな値に設定することで、開始画像からプロンプトに従って時間的・意味的に変化した単一の静止画を得ることができます。
+ * これは、FramePackがInverted anti-driftingなどの双方向コンテキストに対応するため、RoPEのタイムスタンプに対して敏感に反応する特性を利用したものです。これにより、学習なしで限定的ながら自然言語による画像編集に近い操作が可能です。
+
+2. **kisekaeichi方式 (ポスト参照による特徴マージ)**:
+ * 基本的な1フレーム推論を発展させたこの方式は、**furusu氏により提案されました**。開始画像に加え、「次のセクションの開始画像」に相当する参照画像(`clean_latent_post`として扱われる)も入力として利用します。
+ * 生成する画像のRoPEタイムスタンプ(`target_index`)を、開始画像のタイムスタンプとセクション終端画像のタイムスタンプの中間的な値に設定します。
+ * さらに重要な点として、各参照画像のlatent表現に対してマスク処理(特定領域を0で埋めるなど)を施します。例えば、開始画像からはキャラクターの顔や体型を、参照画像からは服装のテクスチャを抽出するようにマスクを設定することで、キャラクターの「着せ替え」のような、両者の望ましい特徴を融合させた画像を生成できます。この手法も基本的には学習不要で実現可能です。
+
+3. **1f-mc (one frame multi-control) 方式 (近接フレームブレンド)**:
+ * この方式は、**mattyamonaca氏により提案されました**。開始画像(例: `t=0`)と、その直後の画像(例: `t=1`、セクションの最初のフレーム)の2つを参照画像として入力し、それらの特徴をブレンドした単一画像を生成します。
+ * kisekaeichiとは異なり、latentマスクは通常行いません。
+ * この方式の真価を発揮するには、LoRA (Low-Rank Adaptation) による追加学習が推奨されます。学習により、モデルは2つの入力画像間の関係性やブレンド方法をより適切に学習し、特定の編集効果を実現できます。
+
+### 汎用的な制御フレームワークへの統合
+
+上記の各手法で利用されていた「参照画像の指定」「タイムスタンプの操作」「latentマスクの適用」といった概念を一般化し、より柔軟な制御を可能にするための拡張が行われました。
+ユーザーは以下の要素を任意に指定して、推論およびLoRA学習を行うことができます。
+
+* **制御画像 (Control Images)**: モデルに影響を与えるための任意の入力画像群。
+* **Clean Latent Index (Indices)**: 各制御画像に対応するタイムスタンプ。FramePack内部の`clean latent index`として扱われ、時間軸上の任意の位置を指定可能です。`control_index`として指定します。
+* **Latentマスク (Latent Masks)**: 各制御画像のlatentに適用するマスク。これにより、制御画像から利用する特徴を選択的に制御します。`control_image_mask_path`または制御画像のアルファチャンネルとして指定します。
+* **Target Index**: 生成したい単一フレームのタイムスタンプ。
+
+この汎用的な制御フレームワークと、それに対応した推論ツールおよびLoRA学習ツールの拡張により、以下のような高度な応用が可能になりました。
+
+* プロンプトだけでは不安定だった1フレーム推論の効果(例: カメラ旋回)を安定化させるLoRAの開発。
+* マスク処理を手動で行う代わりに、特定の条件下(例: 服の参照画像からキャラクター情報を無視する)で望ましい特徴マージを行うように学習させたkisekaeichi LoRAの開発。
+
+これらの機能は、FramePackのポテンシャルを最大限に引き出し、静止画生成・編集における新たな創造の可能性を拓くものです。以降のセクションでは、これらの機能を実際に利用するための具体的なオプションについて説明します。
+
+
+
+## One Frame (Single Frame) Training / 1フレーム学習
+
+**This feature is experimental.** It trains in the same way as one frame inference.
+
+The dataset must be an image dataset. If you use caption files, you need to specify `control_directory` and place the **start images** in that directory. The `image_directory` should contain the images after the change. The filenames of both directories must match. Caption files should be placed in the `image_directory`.
+
+If you use JSONL files, specify them as `{"image_path": "/path/to/target_image1.jpg", "control_path": "/path/to/source_image1.jpg", "caption": "The object changes to red."}`. The `image_path` should point to the images after the change, and `control_path` should point to the starting images.
+
+For the dataset configuration, see [here](../dataset/dataset_config.md#sample-for-image-dataset-with-control-images) and [here](../dataset/dataset_config.md#framepack-one-frame-training). There are also examples for kisekaeichi and 1f-mc settings.
+
+For single frame training, specify `--one_frame` in `fpack_cache_latents.py` to create the cache. You can also use `--one_frame_no_2x` and `--one_frame_no_4x` options, which have the same meaning as `no_2x` and `no_4x` during inference. It is recommended to set these options to match the inference settings.
+
+If you change whether to use one frame training or these options, please overwrite the existing cache without specifying `--skip_existing`.
+
+Specify `--one_frame` in `fpack_train_network.py` to change the inference method during sample generation.
+
+The optimal training settings are currently unknown. Feedback is welcome.
+
+### Example of prompt file description for sample generation
+
+The command line options `--one_frame_inference` corresponds to `--of`, and `--control_image_path` corresponds to `--ci`.
+
+Note that `--ci` can be specified multiple times, but `--control_image_path` is specified as `--control_image_path img1.png img2.png`, while `--ci` is specified as `--ci img1.png --ci img2.png`.
+
+Normal single frame training:
+```
+The girl wears a school uniform. --i path/to/start.png --ci path/to/start.png --of no_2x,no_4x,target_index=1,control_index=0 --d 1111 --f 1 --s 10 --fs 7 --d 1234 --w 384 --h 576
+```
+
+Kisekaeichi training:
+```
+The girl wears a school uniform. --i path/to/start_with_alpha.png --ci path/to/ref_with_alpha.png --ci path/to/start_with_alpha.png --of no_post,no_2x,no_4x,target_index=5,control_index=0;10 --d 1111 --f 1 --s 10 --fs 7 --d 1234 --w 384 --h 576
+```
+
+
+日本語
+
+**この機能は実験的なものです。** 1フレーム推論と同様の方法で学習を行います。
+
+データセットは画像データセットである必要があります。キャプションファイルを用いる場合は、`control_directory`を追加で指定し、そのディレクトリに**開始画像**を格納してください。`image_directory`には変化後の画像を格納します。両者のファイル名は一致させる必要があります。キャプションファイルは`image_directory`に格納してください。
+
+JSONLファイルを用いる場合は、`{"image_path": "/path/to/target_image1.jpg", "control_path": "/path/to/source_image1.jpg", "caption": "The object changes to red"}`のように指定してください。`image_path`は変化後の画像、`control_path`は開始画像を指定します。
+
+データセットの設定については、[こちら](../dataset/dataset_config.md#sample-for-image-dataset-with-control-images)と[こちら](../dataset/dataset_config.md#framepack-one-frame-training)も参照してください。kisekaeichiと1f-mcの設定例もそちらにあります。
+
+1フレーム学習時は、`fpack_cache_latents.py`に`--one_frame`を指定してキャッシュを作成してください。また`--one_frame_no_2x`と`--one_frame_no_4x`オプションも利用可能です。推論時の`no_2x`、`no_4x`と同じ意味を持ちますので、推論時と同じ設定にすることをお勧めします。
+
+1フレーム学習か否かを変更する場合、またこれらのオプションを変更する場合は、`--skip_existing`を指定せずに既存のキャッシュを上書きしてください。
+
+また、`fpack_train_network.py`に`--one_frame`を指定してサンプル画像生成時の推論方法を変更してください。
+
+最適な学習設定は今のところ不明です。フィードバックを歓迎します。
+
+**サンプル生成のプロンプトファイル記述例**
+
+コマンドラインオプション`--one_frame_inference`に相当する `--of`と、`--control_image_path`に相当する`--ci`が用意されています。
+
+※ `--ci`は複数指定可能ですが、`--control_image_path`は`--control_image_path img1.png img2.png`のようにスペースで区切るのに対して、`--ci`は`--ci img1.png --ci img2.png`のように指定するので注意してください。
+
+通常の1フレーム学習:
+```
+The girl wears a school uniform. --i path/to/start.png --ci path/to/start.png --of no_2x,no_4x,target_index=1,control_index=0 --d 1111 --f 1 --s 10 --fs 7 --d 1234 --w 384 --h 576
+```
+
+kisekaeichi方式:
+```
+The girl wears a school uniform. --i path/to/start_with_alpha.png --ci path/to/ref_with_alpha.png --ci path/to/start_with_alpha.png --of no_post,no_2x,no_4x,target_index=5,control_index=0;10 --d 1111 --f 1 --s 10 --fs 7 --d 1234 --w 384 --h 576
+```
+
+
+
+## One (single) Frame Inference / 1フレーム推論
+
+**This feature is highly experimental** and not officially supported. It is intended for users who want to explore the potential of FramePack for one frame inference, which is not a standard feature of the model.
+
+This script also allows for one frame inference, which is not an official feature of FramePack but rather a custom implementation.
+
+Theoretically, it generates an image after a specified time from the starting image, following the prompt. This means that, although limited, it allows for natural language-based image editing.
+
+To perform one frame inference, specify some option in the `--one_frame_inference` option. Here is an example:
+
+```bash
+--video_sections 1 --output_type latent_images --one_frame_inference default
+```
+
+The `--one_frame_inference` option is recommended to be set to `default` or `no_2x,no_4x`. If you specify `--output_type` as `latent_images`, both the latent and image will be saved.
+
+You can specify the following strings in the `--one_frame_inference` option, separated by commas:
+
+- `no_2x`: Generates without passing clean latents 2x with zero vectors to the model. Slightly improves generation speed. The impact on generation results is unknown.
+- `no_4x`: Generates without passing clean latents 4x with zero vectors to the model. Slightly improves generation speed. The impact on generation results is unknown.
+- `no_post`: Generates without passing clean latents post with zero vectors to the model. Improves generation speed by about 20%, but may result in unstable generation.
+- `target_index=`: Specifies the index of the image to be generated. The default is the last frame (i.e., `latent_window_size`).
+
+For example, you can use `--one_frame_inference default` to pass clean latents 2x, clean latents 4x, and post to the model. `--one_frame_inference no_2x,no_4x` if you want to skip passing clean latents 2x and 4x to the model. `--one_frame_inference target_index=9` can be used to specify the target index for the generated image.
+
+The `--one_frame_inference` option also supports advanced inference, which is described in the next section. This option allows for more detailed control using additional parameters like `target_index` and `control_index` within this option.
+
+Normally, specify `--video_sections 1` to indicate only one section (one image).
+
+Increasing `target_index` from the default of 9 may result in larger changes. It has been confirmed that generation can be performed without breaking up to around 40.
+
+The `--end_image_path` is ignored for one frame inference.
+
+
+日本語
+
+**この機能は非常に実験的であり**、公式にはサポートされていません。FramePackを使用して1フレーム推論の可能性を試したいユーザーに向けたものです。
+
+このスクリプトでは、単一画像の推論を行うこともできます。FramePack公式の機能ではなく、独自の実装です。
+
+理論的には、開始画像から、プロンプトに従い、指定時間経過後の画像を生成します。つまり制限付きですが自然言語による画像編集を行うことができます。
+
+単一画像推論を行うには`--one_frame_inference`オプションに、何らかのオプションを指定してください。記述例は以下の通りです。
+
+```bash
+--video_sections 1 --output_type latent_images --one_frame_inference default
+```
+
+`--one_frame_inference`のオプションは、`default`または `no_2x,no_4x`を推奨します。`--output_type`に`latent_images`を指定するとlatentと画像の両方が保存されます。
+
+`--one_frame_inference`のオプションには、カンマ区切りで以下のオプションを任意個数指定できます。
+
+- `no_2x`: ゼロベクトルの clean latents 2xをモデルに渡さずに生成します。わずかに生成速度が向上します。生成結果への影響は不明です。
+- `no_4x`: ゼロベクトルの clean latents 4xをモデルに渡さずに生成します。わずかに生成速度が向上します。生成結果への影響は不明です。
+- `no_post`: ゼロベクトルの clean latents の post を渡さずに生成します。生成速度が20%程度向上しますが、生成結果が不安定になる場合があります。
+- `target_index=<整数>`: 生成する画像のindexを指定します。デフォルトは最後のフレームです(=latent_window_size)。
+
+たとえば、`--one_frame_inference default`を使用すると、clean latents 2x、clean latents 4x、postをモデルに渡します。`--one_frame_inference no_2x,no_4x`を使用すると、clean latents 2xと4xをモデルに渡すのをスキップします。`--one_frame_inference target_index=9`を使用して、生成する画像のターゲットインデックスを指定できます。
+
+後述の高度な推論では、このオプション内で `target_index`、`control_index` といった追加のパラメータを指定して、より詳細な制御が可能です。
+
+clean latents 2x、clean latents 4x、postをモデルに渡す場合でも値はゼロベクトルですが、値を渡すか否かで結果は変わります。特に`no_post`を指定すると、`latent_window_size`を大きくしたときに生成結果が不安定になる場合があります。
+
+通常は`--video_sections 1` として1セクションのみ(画像1枚)を指定してください。
+
+`target_index` をデフォルトの9から大きくすると、変化量が大きくなる可能性があります。40程度までは破綻なく生成されることを確認しています。
+
+`--end_image_path`は無視されます。
+
+
+
+## kisekaeichi method (Post Reference Options) and 1f-mc (Multi-Control) / kisekaeichi方式(ポスト参照オプション)と1f-mc(マルチコントロール)
+
+The `kisekaeichi` method was proposed by furusu. The `1f-mc` method was proposed by mattyamonaca in pull request [#304](https://github.com/kohya-ss/musubi-tuner/pull/304).
+
+In this repository, these methods have been integrated and can be specified with the `--one_frame_inference` option. This allows for specifying any number of control images as clean latents, along with indices. This means you can specify multiple starting images and multiple clean latent posts. Additionally, masks can be applied to each image.
+
+It is expected to work only with FramePack (non-F1 model) and not with F1 models.
+
+The following options have been added to `--one_frame_inference`. These can be used in conjunction with existing flags like `target_index`, `no_post`, `no_2x`, and `no_4x`.
+
+- `control_index=`: Specifies the index(es) of the clean latent for the control image(s). You must specify the same number of indices as the number of control images specified with `--control_image_path`.
+
+Additionally, the following command-line options have been added. These arguments are only valid when `--one_frame_inference` is specified.
+
+- `--control_image_path [ ...]` : Specifies the path(s) to control (reference) image(s) for one frame inference. Provide one or more paths separated by spaces. Images with an alpha channel can be specified. If an alpha channel is present, it is used as a mask for the clean latent.
+- `--control_image_mask_path [ ...]` : Specifies the path(s) to grayscale mask(s) to be applied to the control image(s). Provide one or more paths separated by spaces. Each mask is applied to the corresponding control image. The 255 areas are referenced, while the 0 areas are ignored.
+
+**Example of specifying kisekaeichi:**
+
+The kisekaeichi method works without training, but using a dedicated LoRA may yield better results.
+
+```bash
+--video_sections 1 --output_type latent_images --image_path start_image.png --control_image_path start_image.png clean_latent_post_image.png \
+--one_frame_inference target_index=1,control_index=0;10,no_post,no_2x,no_4x --control_image_mask_path ctrl_mask1.png ctrl_mask2.png
+```
+
+In this example, `start_image.png` (for `clean_latent_pre`) and `clean_latent_post_image.png` (for `clean_latent_post`) are the reference images. The `target_index` specifies the index of the generated image. The `control_index` specifies the clean latent index for each control image, so it will be `0;10`. The masks for the control images are specified with `--control_image_mask_path`.
+
+The optimal values for `target_index` and `control_index` are unknown. The `target_index` should be specified as 1 or higher. The `control_index` should be set to an appropriate value relative to `latent_window_size`. Specifying 1 for `target_index` results in less change from the starting image, but may introduce noise. Specifying 9 or 13 may reduce noise but result in larger changes from the original image.
+
+The `control_index` should be larger than `target_index`. Typically, it is set to `10`, but larger values (e.g., around `13-16`) may also work.
+
+Sample images and command lines for reproduction are as follows:
+
+```bash
+python fpack_generate_video.py --video_size 832 480 --video_sections 1 --infer_steps 25 \
+ --prompt "The girl in a school blazer in a classroom." --save_path path/to/output --output_type latent_images \
+ --dit path/to/dit --vae path/to/vae --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2 \
+ --image_encoder path/to/image_encoder --attn_mode sdpa --vae_spatial_tile_sample_min_size 128 --vae_chunk_size 32 \
+ --image_path path/to/kisekaeichi_start.png --control_image_path path/to/kisekaeichi_start.png path/to/kisekaeichi_ref.png
+ --one_frame_inference target_index=1,control_index=0;10,no_2x,no_4x,no_post
+ --control_image_mask_path path/to/kisekaeichi_start_mask.png path/to/kisekaeichi_ref_mask.png --seed 1234
+```
+
+Specify `--fp8_scaled` and `--blocks_to_swap` options according to your VRAM capacity.
+
+- [kisekaeichi_start.png](./kisekaeichi_start.png)
+- [kisekaeichi_ref.png](./kisekaeichi_ref.png)
+- [kisekaeichi_start_mask.png](./kisekaeichi_start_mask.png)
+- [kisekaeichi_ref_mask.png](./kisekaeichi_ref_mask.png)
+
+Generation result: [kisekaeichi_result.png](./kisekaeichi_result.png)
+
+
+**Example of 1f-mc (Multi-Control):**
+
+```bash
+--video_sections 1 --output_type latent_images --image_path start_image.png --control_image_path start_image.png 2nd_image.png \
+--one_frame_inference target_index=9,control_index=0;1,no_2x,no_4x
+```
+
+In this example, `start_image.png` is the starting image, and `2nd_image.png` is the reference image. The `target_index=9` specifies the index of the generated image, while `control_index=0;1` specifies the clean latent indices for each control image.
+
+1f-mc is intended to be used in combination with a trained LoRA, so adjust `target_index` and `control_index` according to the LoRA's description.
+
+
+日本語
+
+`kisekaeichi`方式はfurusu氏により提案されました。また`1f-mc`方式はmattyamonaca氏によりPR [#304](https://github.com/kohya-ss/musubi-tuner/pull/304) で提案されました。
+
+当リポジトリではこれらの方式を統合し、`--one_frame_inference`オプションで指定できるようにしました。これにより、任意の枚数の制御用画像を clean latentとして指定し、さらにインデックスを指定できます。つまり開始画像の複数枚指定やclean latent postの複数枚指定などが可能です。また、それぞれの画像にマスクを適用することもできます。
+
+なお、FramePack無印のみ動作し、F1モデルでは動作しないと思われます。
+
+`--one_frame_inference`に以下のオプションが追加されています。`target_index`、`no_post`、`no_2x`や`no_4x`など既存のフラグと併用できます。
+
+- `control_index=<整数またはセミコロン区切りの整数>`: 制御用画像のclean latentのインデックスを指定します。`--control_image_path`で指定した制御用画像の数と同じ数のインデックスを指定してください。
+
+またコマンドラインオプションに以下が追加されています。これらの引数は`--one_frame_inference`を指定した場合のみ有効です。
+
+- `--control_image_path <パス1> [<パス2> ...]` : 1フレーム推論用の制御用(参照)画像のパスを1つ以上、スペース区切りで指定します。アルファチャンネルを持つ画像が指定可能です。アルファチャンネルがある場合は、clean latentへのマスクとして利用されます。
+- `--control_image_mask_path <パス1> [<パス2> ...]` : 制御用画像に適用するグレースケールマスクのパスを1つ以上、スペース区切りで指定します。各マスクは対応する制御用画像に適用されます。255の部分が参照される部分、0の部分が無視される部分です。
+
+**kisekaeichiの指定例**:
+
+kisekaeichi方式は学習なしでも動作しますが、専用のLoRAを使用することで、より良い結果が得られる可能性があります。
+
+```bash
+--video_sections 1 --output_type latent_images --image_path start_image.png --control_image_path start_image.png clean_latent_post_image.png \
+--one_frame_inference target_index=1,control_index=0;10,no_post,no_2x,no_4x --control_image_mask_path ctrl_mask1.png ctrl_mask2.png
+```
+
+`start_image.png`(clean_latent_preに相当)と`clean_latent_post_image.png`は参照画像(clean_latent_postに相当)です。`target_index`は生成する画像のインデックスを指定します。`control_index`はそれぞれの制御用画像のclean latent indexを指定しますので、`0;10` になります。また`--control_image_mask_path`に制御用画像に適用するマスクを指定します。
+
+`target_index`、`control_index`の最適値は不明です。`target_index`は1以上を指定してください。`control_index`は`latent_window_size`に対して適切な値を指定してください。`target_index`に1を指定すると開始画像からの変化が少なくなりますが、ノイズが乗ったりすることが多いようです。9や13などを指定するとノイズは改善されるかもしれませんが、元の画像からの変化が大きくなります。
+
+`control_index`は`target_index`より大きい値を指定してください。通常は`10`ですが、これ以上大きな値、たとえば`13~16程度でも動作するようです。
+
+サンプル画像と再現のためのコマンドラインは以下のようになります。
+
+```bash
+python fpack_generate_video.py --video_size 832 480 --video_sections 1 --infer_steps 25 \
+ --prompt "The girl in a school blazer in a classroom." --save_path path/to/output --output_type latent_images \
+ --dit path/to/dit --vae path/to/vae --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2 \
+ --image_encoder path/to/image_encoder --attn_mode sdpa --vae_spatial_tile_sample_min_size 128 --vae_chunk_size 32 \
+ --image_path path/to/kisekaeichi_start.png --control_image_path path/to/kisekaeichi_start.png path/to/kisekaeichi_ref.png
+ --one_frame_inference target_index=1,control_index=0;10,no_2x,no_4x,no_post
+ --control_image_mask_path path/to/kisekaeichi_start_mask.png path/to/kisekaeichi_ref_mask.png --seed 1234
+```
+
+VRAM容量に応じて、`--fp8_scaled`や`--blocks_to_swap`等のオプションを調整してください。
+
+- [kisekaeichi_start.png](./kisekaeichi_start.png)
+- [kisekaeichi_ref.png](./kisekaeichi_ref.png)
+- [kisekaeichi_start_mask.png](./kisekaeichi_start_mask.png)
+- [kisekaeichi_ref_mask.png](./kisekaeichi_ref_mask.png)
+
+生成結果:
+- [kisekaeichi_result.png](./kisekaeichi_result.png)
+
+**1f-mcの指定例**:
+
+```bash
+--video_sections 1 --output_type latent_images --image_path start_image.png --control_image_path start_image.png 2nd_image.png \
+--one_frame_inference target_index=9,control_index=0;1,no_2x,no_4x
+```
+
+この例では、`start_image.png`が開始画像で、`2nd_image.png`が参照画像です。`target_index=9`は生成する画像のインデックスを指定し、`control_index=0;1`はそれぞれの制御用画像のclean latent indexを指定しています。
+
+1f-mcは学習したLoRAと組み合わせることを想定していますので、そのLoRAの説明に従って、`target_index`や`control_index`を調整してください。
+
+
\ No newline at end of file
diff --git a/docs/kisekaeichi_ref.png b/docs/kisekaeichi_ref.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3c97e632672364b360c6a26d2bc251dcf485dd8
--- /dev/null
+++ b/docs/kisekaeichi_ref.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5037f0a0cfb1a6b0a8d1f19fb462df75fb53384d0d9e654c359ca984fafa605
+size 583507
diff --git a/docs/kisekaeichi_ref_mask.png b/docs/kisekaeichi_ref_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..bb16684c04e48d169f4c96088dda7d622a71e75c
Binary files /dev/null and b/docs/kisekaeichi_ref_mask.png differ
diff --git a/docs/kisekaeichi_result.png b/docs/kisekaeichi_result.png
new file mode 100644
index 0000000000000000000000000000000000000000..52c13c39df20aee7b405e5faa64104ea0c4c641f
--- /dev/null
+++ b/docs/kisekaeichi_result.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:223dacb98ac834a442ee124641a6b852b1cde3bc1f11939e78192fc8be2f7b49
+size 408282
diff --git a/docs/kisekaeichi_start.png b/docs/kisekaeichi_start.png
new file mode 100644
index 0000000000000000000000000000000000000000..b31c37500db6083aa54b1c6b868cb8a56d928745
--- /dev/null
+++ b/docs/kisekaeichi_start.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beee4a910402ef2798b00aa4d193b0b7186380ed24928a4d39acc8635d2cfdaf
+size 1033975
diff --git a/docs/kisekaeichi_start_mask.png b/docs/kisekaeichi_start_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..c568f034c34492e71393a5d4fad7031c7e686492
Binary files /dev/null and b/docs/kisekaeichi_start_mask.png differ
diff --git a/docs/sampling_during_training.md b/docs/sampling_during_training.md
new file mode 100644
index 0000000000000000000000000000000000000000..e466331eb9e29a8ff64183fa037c632dcb671ac6
--- /dev/null
+++ b/docs/sampling_during_training.md
@@ -0,0 +1,116 @@
+> 📝 Click on the language section to expand / 言語をクリックして展開
+
+# Sampling during training / 学習中のサンプル画像生成
+
+By preparing a prompt file, you can generate sample images during training.
+
+Please be aware that it consumes a considerable amount of VRAM, so be careful when generating sample images for videos with a large number of frames. Also, since it takes time to generate, adjust the frequency of sample image generation as needed.
+
+
+日本語
+
+プロンプトファイルを用意することで、学習中にサンプル画像を生成することができます。
+
+VRAMをそれなりに消費しますので、特にフレーム数が多い動画を生成する場合は注意してください。また生成には時間がかかりますので、サンプル画像生成の頻度は適宜調整してください。
+
+
+## How to use / 使い方
+
+### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
+
+Example of command line options for training with sampling / 記述例:
+
+```bash
+--vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt
+--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
+--text_encoder1 path/to/ckpts/text_encoder
+--text_encoder2 path/to/ckpts/text_encoder_2
+--sample_prompts /path/to/prompt_file.txt
+--sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
+```
+
+`--vae`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size`, `--text_encoder1`, `--text_encoder2` are the same as when generating images, so please refer to [here](/README.md#inference) for details. `--fp8_llm` can also be specified.
+
+`--sample_prompts` specifies the path to the prompt file used for sample image generation. Details are described below.
+
+`--sample_every_n_epochs` specifies how often to generate sample images in epochs, and `--sample_every_n_steps` specifies how often to generate sample images in steps.
+
+`--sample_at_first` is specified when generating sample images at the beginning of training.
+
+Sample images and videos are saved in the `sample` directory in the directory specified by `--output_dir`. They are saved as `.png` for still images and `.mp4` for videos.
+
+
+日本語
+
+`--vae`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`、`--text_encoder1`、`--text_encoder2`は、画像生成時と同様ですので、詳細は[こちら](/README.ja.md#推論)を参照してください。`--fp8_llm`も指定可能です。
+
+`--sample_prompts`は、サンプル画像生成に使用するプロンプトファイルのパスを指定します。詳細は後述します。
+
+`--sample_every_n_epochs`は、何エポックごとにサンプル画像を生成するかを、`--sample_every_n_steps`は、何ステップごとにサンプル画像を生成するかを指定します。
+
+`--sample_at_first`は、学習開始時にサンプル画像を生成する場合に指定します。
+
+サンプル画像、動画は、`--output_dir`で指定したディレクトリ内の、`sample`ディレクトリに保存されます。静止画の場合は`.png`、動画の場合は`.mp4`で保存されます。
+
+
+### Prompt file / プロンプトファイル
+
+The prompt file is a text file that contains the prompts for generating sample images. The example is as follows. / プロンプトファイルは、サンプル画像生成のためのプロンプトを記述したテキストファイルです。例は以下の通りです。
+
+```
+# prompt 1: for generating a cat video
+A cat walks on the grass, realistic style. --w 640 --h 480 --f 25 --d 1 --s 20
+
+# prompt 2: for generating a dog image
+A dog runs on the beach, realistic style. --w 960 --h 544 --f 1 --d 2 --s 20
+```
+
+A line starting with `#` is a comment.
+
+* `--w` specifies the width of the generated image or video. The default is 256.
+* `--h` specifies the height. The default is 256.
+* `--f` specifies the number of frames. The default is 1, which generates a still image.
+* `--d` specifies the seed. The default is random.
+* `--s` specifies the number of steps in generation. The default is 20.
+* `--g` specifies the embedded guidance scale (not CFG scale). The default is 6.0 for HunyuanVideo, 10.0 for FramePack, which is the default value during inference of each architecture. Specify 1.0 for SkyReels V1 models. Ignore this option for Wan2.1 models.
+* `--fs` specifies the discrete flow shift. The default is 14.5, which corresponds to the number of steps 20. In the HunyuanVideo paper, 7.0 is recommended for 50 steps, and 17.0 is recommended for less than 20 steps (e.g. 10). Ignore this option for FramePack models (it uses 10.0).
+
+If you train I2V models, you must add the following option.
+
+* `--i path/to/image.png`: the image path for image2video inference.
+
+If you train Wan2.1-Fun-Control models, you must add the following option.
+
+* `--cn path/to/control_video_or_dir_of_images`: the path to the video or directory containing multiple images for control.
+
+If you train the model with classifier free guidance (such as Wan2.1), you can use the additional options below.
+
+*`--n negative prompt...`: the negative prompt for the classifier free guidance. The default prompt for each model is used if omitted.
+*`--l 6.0`: the classifier free guidance scale. Should be set to 6.0 for SkyReels V1 models. 5.0 is the default value for Wan2.1 (if omitted).
+
+
+日本語
+
+`#` で始まる行はコメントです。
+
+* `--w` 生成画像、動画の幅を指定します。省略時は256です。
+* `--h` 高さを指定します。省略時は256です。
+* `--f` フレーム数を指定します。省略時は1で、静止画を生成します。
+* `--d` シードを指定します。省略時はランダムです。
+* `--s` 生成におけるステップ数を指定します。省略時は20です。
+* `--g` embedded guidance scaleを指定します(CFG scaleではありません)。省略時はHunyuanVideoは6.0、FramePackは10.0で、各アーキテクチャの推論時のデフォルト値です。SkyReels V1モデルの場合は1.0を指定してください。Wan2.1モデルの場合はこのオプションは無視されます。
+* `--fs` discrete flow shiftを指定します。省略時は14.5で、ステップ数20の場合に対応した値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。FramePackモデルはこのオプションは無視され、10.0が使用されます。
+
+I2Vモデルを学習する場合、以下のオプションを追加してください。
+
+* `--i path/to/image.png`: image2video推論用の画像パス。
+
+Wan2.1-Fun-Controlモデルを学習する場合、以下のオプションを追加してください。
+
+* `--cn path/to/control_video_or_dir_of_images`: control用の動画または複数枚の画像を含むディレクトリのパス。
+
+classifier free guidance(ネガティブプロンプト)を必要とするモデル(Wan2.1など)を学習する場合、以下の追加オプションを使用できます。
+
+*`--n negative prompt...`: classifier free guidance用のネガティブプロンプト。省略時はモデルごとのデフォルトプロンプトが使用されます。
+*`--l 6.0`: classifier free guidance scale。SkyReels V1モデルの場合は6.0に設定してください。Wan2.1の場合はデフォルト値が5.0です(省略時)。
+
diff --git a/docs/wan.md b/docs/wan.md
new file mode 100644
index 0000000000000000000000000000000000000000..27a457a3977cbfe90641000c5b41f39d011979d8
--- /dev/null
+++ b/docs/wan.md
@@ -0,0 +1,531 @@
+> 📝 Click on the language section to expand / 言語をクリックして展開
+
+# Wan 2.1
+
+## Overview / 概要
+
+This is an unofficial training and inference script for [Wan2.1](https://github.com/Wan-Video/Wan2.1). The features are as follows.
+
+- fp8 support and memory reduction by block swap: Inference of a 720x1280x81frames videos with 24GB VRAM, training with 720x1280 images with 24GB VRAM
+- Inference without installing Flash attention (using PyTorch's scaled dot product attention)
+- Supports xformers and Sage attention
+
+This feature is experimental.
+
+
+日本語
+[Wan2.1](https://github.com/Wan-Video/Wan2.1) の非公式の学習および推論スクリプトです。
+
+以下の特徴があります。
+
+- fp8対応およびblock swapによる省メモリ化:720x1280x81framesの動画を24GB VRAMで推論可能、720x1280の画像での学習が24GB VRAMで可能
+- Flash attentionのインストールなしでの実行(PyTorchのscaled dot product attentionを使用)
+- xformersおよびSage attention対応
+
+この機能は実験的なものです。
+
+
+## Download the model / モデルのダウンロード
+
+Download the T5 `models_t5_umt5-xxl-enc-bf16.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main
+
+Download the VAE from the above page `Wan2.1_VAE.pth` or download `split_files/vae/wan_2.1_vae.safetensors` from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
+
+Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
+
+Wan2.1 Fun Control model weights can be downloaded from [here](https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control). Navigate to each weight page and download. The Fun Control model seems to support not only T2V but also I2V tasks.
+
+Please select the appropriate weights according to T2V, I2V, resolution, model size, etc.
+
+`fp16` and `bf16` models can be used, and `fp8_e4m3fn` models can be used if `--fp8` (or `--fp8_base`) is specified without specifying `--fp8_scaled`. **Please note that `fp8_scaled` models are not supported even with `--fp8_scaled`.**
+
+(Thanks to Comfy-Org for providing the repackaged weights.)
+
+### Model support matrix / モデルサポートマトリックス
+
+* columns: training dtype (行:学習時のデータ型)
+* rows: model dtype (列:モデルのデータ型)
+
+| model \ training |bf16|fp16|--fp8_base|--fp8base & --fp8_scaled|
+|--|--|--|--|--|
+|bf16|✓|--|✓|✓|
+|fp16|--|✓|✓|✓|
+|fp8_e4m3fn|--|--|✓|--|
+|fp8_scaled|--|--|--|--|
+
+
+日本語
+T5 `models_t5_umt5-xxl-enc-bf16.pth` およびCLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を、次のページからダウンロードしてください:https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main
+
+VAEは上のページから `Wan2.1_VAE.pth` をダウンロードするか、次のページから `split_files/vae/wan_2.1_vae.safetensors` をダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
+
+DiTの重みを次のページからダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
+
+Wan2.1 Fun Controlモデルの重みは、[こちら](https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control)から、それぞれの重みのページに遷移し、ダウンロードしてください。Fun ControlモデルはT2VだけでなくI2Vタスクにも対応しているようです。
+
+T2VやI2V、解像度、モデルサイズなどにより適切な重みを選択してください。
+
+`fp16` および `bf16` モデルを使用できます。また、`--fp8` (または`--fp8_base`)を指定し`--fp8_scaled`を指定をしないときには `fp8_e4m3fn` モデルを使用できます。**`fp8_scaled` モデルはいずれの場合もサポートされていませんのでご注意ください。**
+
+(repackaged版の重みを提供してくださっているComfy-Orgに感謝いたします。)
+
+
+## Pre-caching / 事前キャッシュ
+
+### Latent Pre-caching
+
+Latent pre-caching is almost the same as in HunyuanVideo. Create the cache using the following command:
+
+```bash
+python wan_cache_latents.py --dataset_config path/to/toml --vae path/to/wan_2.1_vae.safetensors
+```
+
+If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model. If not specified, the training will raise an error.
+
+If you're running low on VRAM, specify `--vae_cache_cpu` to use the CPU for the VAE internal cache, which will reduce VRAM usage somewhat.
+
+The control video settings are required for training the Fun-Control model. Please refer to [Dataset Settings](/dataset/dataset_config.md#sample-for-video-dataset-with-control-images) for details.
+
+
+日本語
+latentの事前キャッシングはHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
+
+I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。指定しないと学習時にエラーが発生します。
+
+VRAMが不足している場合は、`--vae_cache_cpu` を指定するとVAEの内部キャッシュにCPUを使うことで、使用VRAMを多少削減できます。
+
+Fun-Controlモデルを学習する場合は、制御用動画の設定が必要です。[データセット設定](/dataset/dataset_config.md#sample-for-video-dataset-with-control-images)を参照してください。
+
+
+### Text Encoder Output Pre-caching
+
+Text encoder output pre-caching is also almost the same as in HunyuanVideo. Create the cache using the following command:
+
+```bash
+python wan_cache_text_encoder_outputs.py --dataset_config path/to/toml --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --batch_size 16
+```
+
+Adjust `--batch_size` according to your available VRAM.
+
+For systems with limited VRAM (less than ~16GB), use `--fp8_t5` to run the T5 in fp8 mode.
+
+
+日本語
+テキストエンコーダ出力の事前キャッシングもHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
+
+使用可能なVRAMに合わせて `--batch_size` を調整してください。
+
+VRAMが限られているシステム(約16GB未満)の場合は、T5をfp8モードで実行するために `--fp8_t5` を使用してください。
+
+
+## Training / 学習
+
+### Training
+
+Start training using the following command (input as a single line):
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 wan_train_network.py
+ --task t2v-1.3B
+ --dit path/to/wan2.1_xxx_bf16.safetensors
+ --dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base
+ --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing
+ --max_data_loader_n_workers 2 --persistent_data_loader_workers
+ --network_module networks.lora_wan --network_dim 32
+ --timestep_sampling shift --discrete_flow_shift 3.0
+ --max_train_epochs 16 --save_every_n_epochs 1 --seed 42
+ --output_dir path/to/output_dir --output_name name-of-lora
+```
+The above is an example. The appropriate values for `timestep_sampling` and `discrete_flow_shift` need to be determined by experimentation.
+
+For additional options, use `python wan_train_network.py --help` (note that many options are unverified).
+
+`--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (for Wan2.1 official models), `t2v-1.3B-FC`, `t2v-14B-FC`, and `i2v-14B-FC` (for Wan2.1 Fun Control model). Specify the DiT weights for the task with `--dit`.
+
+Don't forget to specify `--network_module networks.lora_wan`.
+
+Other options are mostly the same as `hv_train_network.py`.
+
+Use `convert_lora.py` for converting the LoRA weights after training, as in HunyuanVideo.
+
+
+日本語
+`timestep_sampling`や`discrete_flow_shift`は一例です。どのような値が適切かは実験が必要です。
+
+その他のオプションについては `python wan_train_network.py --help` を使用してください(多くのオプションは未検証です)。
+
+`--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (これらはWan2.1公式モデル)、`t2v-1.3B-FC`, `t2v-14B-FC`, `i2v-14B-FC`(Wan2.1-Fun Controlモデル)を指定します。`--dit`に、taskに応じたDiTの重みを指定してください。
+
+ `--network_module` に `networks.lora_wan` を指定することを忘れないでください。
+
+その他のオプションは、ほぼ`hv_train_network.py`と同様です。
+
+学習後のLoRAの重みの変換は、HunyuanVideoと同様に`convert_lora.py`を使用してください。
+
+
+### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
+
+Example of command line options for training with sampling / 記述例:
+
+```bash
+--vae path/to/wan_2.1_vae.safetensors
+--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
+--sample_prompts /path/to/prompt_file.txt
+--sample_every_n_epochs 1 --sample_every_n_steps 1000 -- sample_at_first
+```
+Each option is the same as when generating images or as HunyuanVideo. Please refer to [here](/docs/sampling_during_training.md) for details.
+
+If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model.
+
+You can specify the initial image, the negative prompt and the control video (for Wan2.1-Fun-Control) in the prompt file. Please refer to [here](/docs/sampling_during_training.md#prompt-file--プロンプトファイル).
+
+
+日本語
+各オプションは推論時、およびHunyuanVideoの場合と同様です。[こちら](/docs/sampling_during_training.md)を参照してください。
+
+I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。
+
+プロンプトファイルで、初期画像やネガティブプロンプト、制御動画(Wan2.1-Fun-Control用)等を指定できます。[こちら](/docs/sampling_during_training.md#prompt-file--プロンプトファイル)を参照してください。
+
+
+
+## Inference / 推論
+
+### Inference Options Comparison / 推論オプション比較
+
+#### Speed Comparison (Faster → Slower) / 速度比較(速い→遅い)
+*Note: Results may vary depending on GPU type*
+
+fp8_fast > bf16/fp16 (no block swap) > fp8 > fp8_scaled > bf16/fp16 (block swap)
+
+#### Quality Comparison (Higher → Lower) / 品質比較(高→低)
+
+bf16/fp16 > fp8_scaled > fp8 >> fp8_fast
+
+### T2V Inference / T2V推論
+
+The following is an example of T2V inference (input as a single line):
+
+```bash
+python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 832 480 --video_length 81 --infer_steps 20
+--prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
+--dit path/to/wan2.1_t2v_1.3B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
+--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
+--attn_mode torch
+```
+
+`--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (these are Wan2.1 official models), `t2v-1.3B-FC`, `t2v-14B-FC` and `i2v-14B-FC` (for Wan2.1-Fun Control model).
+
+`--attn_mode` is `torch`, `sdpa` (same as `torch`), `xformers`, `sageattn`,`flash2`, `flash` (same as `flash2`) or `flash3`. `torch` is the default. Other options require the corresponding library to be installed. `flash3` (Flash attention 3) is not tested.
+
+Specifying `--fp8` runs DiT in fp8 mode. fp8 can significantly reduce memory consumption but may impact output quality.
+
+`--fp8_scaled` can be specified in addition to `--fp8` to run the model in fp8 weights optimization. This increases memory consumption and speed slightly but improves output quality. See [here](advanced_config.md#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化) for details.
+
+`--fp8_fast` option is also available for faster inference on RTX 40x0 GPUs. This option requires `--fp8_scaled` option. **This option seems to degrade the output quality.**
+
+`--fp8_t5` can be used to specify the T5 model in fp8 format. This option reduces memory usage for the T5 model.
+
+`--negative_prompt` can be used to specify a negative prompt. If omitted, the default negative prompt is used.
+
+`--flow_shift` can be used to specify the flow shift (default 3.0 for I2V with 480p, 5.0 for others).
+
+`--guidance_scale` can be used to specify the guidance scale for classifier free guidance (default 5.0).
+
+`--blocks_to_swap` is the number of blocks to swap during inference. The default value is None (no block swap). The maximum value is 39 for 14B model and 29 for 1.3B model.
+
+`--vae_cache_cpu` enables VAE cache in main memory. This reduces VRAM usage slightly but processing is slower.
+
+`--compile` enables torch.compile. See [here](/README.md#inference) for details.
+
+`--trim_tail_frames` can be used to trim the tail frames when saving. The default is 0.
+
+`--cfg_skip_mode` specifies the mode for skipping CFG in different steps. The default is `none` (all steps).`--cfg_apply_ratio` specifies the ratio of steps where CFG is applied. See below for details.
+
+`--include_patterns` and `--exclude_patterns` can be used to specify which LoRA modules to apply or exclude during training. If not specified, all modules are applied by default. These options accept regular expressions.
+
+`--include_patterns` specifies the modules to be applied, and `--exclude_patterns` specifies the modules to be excluded. The regular expression is matched against the LoRA key name, and include takes precedence.
+
+The key name to be searched is in sd-scripts format (`lora_unet_`). For example, `lora_unet_blocks_9_cross_attn_k`.
+
+For example, if you specify `--exclude_patterns "blocks_[23]\d_"`, it will exclude modules containing `blocks_20` to `blocks_39`. If you specify `--include_patterns "cross_attn" --exclude_patterns "blocks_(0|1|2|3|4)_"`, it will apply LoRA to modules containing `cross_attn` and not containing `blocks_0` to `blocks_4`.
+
+If you specify multiple LoRA weights, please specify them with multiple arguments. For example: `--include_patterns "cross_attn" ".*" --exclude_patterns "dummy_do_not_exclude" "blocks_(0|1|2|3|4)"`. `".*"` is a regex that matches everything. `dummy_do_not_exclude` is a dummy regex that does not match anything.
+
+`--cpu_noise` generates initial noise on the CPU. This may result in the same results as ComfyUI with the same seed (depending on other settings).
+
+If you are using the Fun Control model, specify the control video with `--control_path`. You can specify a video file or a folder containing multiple image files. The number of frames in the video file (or the number of images) should be at least the number specified in `--video_length` (plus 1 frame if you specify `--end_image_path`).
+
+Please try to match the aspect ratio of the control video with the aspect ratio specified in `--video_size` (there may be some deviation from the initial image of I2V due to the use of bucketing processing).
+
+Other options are same as `hv_generate_video.py` (some options are not supported, please check the help).
+
+
+日本語
+`--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` (これらはWan2.1公式モデル)、`t2v-1.3B-FC`, `t2v-14B-FC`, `i2v-14B-FC`(Wan2.1-Fun Controlモデル)を指定します。
+
+`--attn_mode` には `torch`, `sdpa`(`torch`と同じ)、`xformers`, `sageattn`, `flash2`, `flash`(`flash2`と同じ), `flash3` のいずれかを指定します。デフォルトは `torch` です。その他のオプションを使用する場合は、対応するライブラリをインストールする必要があります。`flash3`(Flash attention 3)は未テストです。
+
+`--fp8` を指定するとDiTモデルをfp8形式で実行します。fp8はメモリ消費を大幅に削減できますが、出力品質に影響を与える可能性があります。
+
+`--fp8_scaled` を `--fp8` と併用すると、fp8への重み量子化を行います。メモリ消費と速度はわずかに悪化しますが、出力品質が向上します。詳しくは[こちら](advanced_config.md#fp8-weight-optimization-for-models--モデルの重みのfp8への最適化)を参照してください。
+
+`--fp8_fast` オプションはRTX 40x0 GPUでの高速推論に使用されるオプションです。このオプションは `--fp8_scaled` オプションが必要です。**出力品質が劣化するようです。**
+
+`--fp8_t5` を指定するとT5モデルをfp8形式で実行します。T5モデル呼び出し時のメモリ使用量を削減します。
+
+`--negative_prompt` でネガティブプロンプトを指定できます。省略した場合はデフォルトのネガティブプロンプトが使用されます。
+
+`--flow_shift` でflow shiftを指定できます(480pのI2Vの場合はデフォルト3.0、それ以外は5.0)。
+
+`--guidance_scale` でclassifier free guianceのガイダンススケールを指定できます(デフォルト5.0)。
+
+`--blocks_to_swap` は推論時のblock swapの数です。デフォルト値はNone(block swapなし)です。最大値は14Bモデルの場合39、1.3Bモデルの場合29です。
+
+`--vae_cache_cpu` を有効にすると、VAEのキャッシュをメインメモリに保持します。VRAM使用量が多少減りますが、処理は遅くなります。
+
+`--compile`でtorch.compileを有効にします。詳細については[こちら](/README.md#inference)を参照してください。
+
+`--trim_tail_frames` で保存時に末尾のフレームをトリミングできます。デフォルトは0です。
+
+`--cfg_skip_mode` は異なるステップでCFGをスキップするモードを指定します。デフォルトは `none`(全ステップ)。`--cfg_apply_ratio` はCFGが適用されるステップの割合を指定します。詳細は後述します。
+
+LoRAのどのモジュールを適用するかを、`--include_patterns`と`--exclude_patterns`で指定できます(未指定時・デフォルトは全モジュール適用されます
+)。これらのオプションには、正規表現を指定します。`--include_patterns`は適用するモジュール、`--exclude_patterns`は適用しないモジュールを指定します。正規表現がLoRAのキー名に含まれるかどうかで判断され、includeが優先されます。
+
+検索対象となるキー名は sd-scripts 形式(`lora_unet_<モジュール名のドットを_に置換したもの>`)です。例:`lora_unet_blocks_9_cross_attn_k`
+
+たとえば `--exclude_patterns "blocks_[23]\d_"`のみを指定すると、`blocks_20`から`blocks_39`を含むモジュールが除外されます。`--include_patterns "cross_attn" --exclude_patterns "blocks_(0|1|2|3|4)_"`のようにincludeとexcludeを指定すると、`cross_attn`を含むモジュールで、かつ`blocks_0`から`blocks_4`を含まないモジュールにLoRAが適用されます。
+
+複数のLoRAの重みを指定する場合は、複数個の引数で指定してください。例:`--include_patterns "cross_attn" ".*" --exclude_patterns "dummy_do_not_exclude" "blocks_(0|1|2|3|4)"` `".*"`は全てにマッチする正規表現です。`dummy_do_not_exclude`は何にもマッチしないダミーの正規表現です。
+
+`--cpu_noise`を指定すると初期ノイズをCPUで生成します。これにより同一seed時の結果がComfyUIと同じになる可能性があります(他の設定にもよります)。
+
+Fun Controlモデルを使用する場合は、`--control_path`で制御用の映像を指定します。動画ファイル、または複数枚の画像ファイルを含んだフォルダを指定できます。動画ファイルのフレーム数(または画像の枚数)は、`--video_length`で指定したフレーム数以上にしてください(後述の`--end_image_path`を指定した場合は、さらに+1フレーム)。
+
+制御用の映像のアスペクト比は、`--video_size`で指定したアスペクト比とできるかぎり合わせてください(bucketingの処理を流用しているためI2Vの初期画像とズレる場合があります)。
+
+その他のオプションは `hv_generate_video.py` と同じです(一部のオプションはサポートされていないため、ヘルプを確認してください)。
+
+
+#### CFG Skip Mode / CFGスキップモード
+
+ These options allow you to balance generation speed against prompt accuracy. More skipped steps results in faster generation with potential quality degradation.
+
+Setting `--cfg_apply_ratio` to 0.5 speeds up the denoising loop by up to 25%.
+
+`--cfg_skip_mode` specified one of the following modes:
+
+- `early`: Skips CFG in early steps for faster generation, applying guidance mainly in later refinement steps
+- `late`: Skips CFG in later steps, applying guidance during initial structure formation
+- `middle`: Skips CFG in middle steps, applying guidance in both early and later steps
+- `early_late`: Skips CFG in both early and late steps, applying only in middle steps
+- `alternate`: Applies CFG in alternate steps based on the specified ratio
+- `none`: Applies CFG at all steps (default)
+
+`--cfg_apply_ratio` specifies a value from 0.0 to 1.0 controlling the proportion of steps where CFG is applied. For example, setting 0.5 means CFG will be applied in only 50% of the steps.
+
+If num_steps is 10, the following table shows the steps where CFG is applied based on the `--cfg_skip_mode` option (A means CFG is applied, S means it is skipped, `--cfg_apply_ratio` is 0.6):
+
+| skip mode | CFG apply pattern |
+|---|---|
+| early | SSSSAAAAAA |
+| late | AAAAAASSSS |
+| middle | AAASSSSAAA |
+| early_late | SSAAAAAASS |
+| alternate | SASASAASAS |
+
+The appropriate settings are unknown, but you may want to try `late` or `early_late` mode with a ratio of around 0.3 to 0.5.
+
+日本語
+これらのオプションは、生成速度とプロンプトの精度のバランスを取ることができます。スキップされるステップが多いほど、生成速度が速くなりますが、品質が低下する可能性があります。
+
+ratioに0.5を指定することで、デノイジングのループが最大25%程度、高速化されます。
+
+`--cfg_skip_mode` は次のモードのいずれかを指定します:
+
+- `early`:初期のステップでCFGをスキップして、主に終盤の精細化のステップで適用します
+- `late`:終盤のステップでCFGをスキップし、初期の構造が決まる段階で適用します
+- `middle`:中間のステップでCFGをスキップし、初期と終盤のステップの両方で適用します
+- `early_late`:初期と終盤のステップの両方でCFGをスキップし、中間のステップのみ適用します
+- `alternate`:指定された割合に基づいてCFGを適用します
+
+`--cfg_apply_ratio` は、CFGが適用されるステップの割合を0.0から1.0の値で指定します。たとえば、0.5に設定すると、CFGはステップの50%のみで適用されます。
+
+具体的なパターンは上のテーブルを参照してください。
+
+適切な設定は不明ですが、モードは`late`または`early_late`、ratioは0.3~0.5程度から試してみると良いかもしれません。
+
+
+#### Skip Layer Guidance
+
+Skip Layer Guidance is a feature that uses the output of a model with some blocks skipped as the unconditional output of classifier free guidance. It was originally proposed in [SD 3.5](https://github.com/comfyanonymous/ComfyUI/pull/5404) and first applied in Wan2GP in [this PR](https://github.com/deepbeepmeep/Wan2GP/pull/61). It may improve the quality of generated videos.
+
+The implementation of SD 3.5 is [here](https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py), and the implementation of Wan2GP (the PR mentioned above) has some different specifications. This inference script allows you to choose between the two methods.
+
+*The SD3.5 method applies slg output in addition to cond and uncond (slows down the speed). The Wan2GP method uses only cond and slg output.*
+
+The following arguments are available:
+
+- `--slg_mode`: Specifies the SLG mode. `original` for SD 3.5 method, `uncond` for Wan2GP method. Default is None (no SLG).
+- `--slg_layers`: Specifies the indices of the blocks (layers) to skip in SLG, separated by commas. Example: `--slg_layers 4,5,6`. Default is empty (no skip). If this option is not specified, `--slg_mode` is ignored.
+- `--slg_scale`: Specifies the scale of SLG when `original`. Default is 3.0.
+- `--slg_start`: Specifies the start step of SLG application in inference steps from 0.0 to 1.0. Default is 0.0 (applied from the beginning).
+- `--slg_end`: Specifies the end step of SLG application in inference steps from 0.0 to 1.0. Default is 0.3 (applied up to 30% from the beginning).
+
+Appropriate settings are unknown, but you may want to try `original` mode with a scale of around 3.0 and a start ratio of 0.0 and an end ratio of 0.5, with layers 4, 5, and 6 skipped.
+
+
+日本語
+Skip Layer Guidanceは、一部のblockをスキップしたモデル出力をclassifier free guidanceのunconditional出力に使用する機能です。元々は[SD 3.5](https://github.com/comfyanonymous/ComfyUI/pull/5404)で提案されたもので、Wan2.1には[Wan2GPのこちらのPR](https://github.com/deepbeepmeep/Wan2GP/pull/61)で初めて適用されました。生成動画の品質が向上する可能性があります。
+
+SD 3.5の実装は[こちら](https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py)で、Wan2GPの実装(前述のPR)は一部仕様が異なります。この推論スクリプトでは両者の方式を選択できるようになっています。
+
+※SD3.5方式はcondとuncondに加えてslg outputを適用します(速度が低下します)。Wan2GP方式はcondとslg outputのみを使用します。
+
+以下の引数があります。
+
+- `--slg_mode`:SLGのモードを指定します。`original`でSD 3.5の方式、`uncond`でWan2GPの方式です。デフォルトはNoneで、SLGを使用しません。
+- `--slg_layers`:SLGでスキップするblock (layer)のインデクスをカンマ区切りで指定します。例:`--slg_layers 4,5,6`。デフォルトは空(スキップしない)です。このオプションを指定しないと`--slg_mode`は無視されます。
+- `--slg_scale`:`original`のときのSLGのスケールを指定します。デフォルトは3.0です。
+- `--slg_start`:推論ステップのSLG適用開始ステップを0.0から1.0の割合で指定します。デフォルトは0.0です(最初から適用)。
+- `--slg_end`:推論ステップのSLG適用終了ステップを0.0から1.0の割合で指定します。デフォルトは0.3です(最初から30%まで適用)。
+
+適切な設定は不明ですが、`original`モードでスケールを3.0程度、開始割合を0.0、終了割合を0.5程度に設定し、4, 5, 6のlayerをスキップする設定から始めると良いかもしれません。
+
+
+### I2V Inference / I2V推論
+
+The following is an example of I2V inference (input as a single line):
+
+```bash
+python wan_generate_video.py --fp8 --task i2v-14B --video_size 832 480 --video_length 81 --infer_steps 20
+--prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
+--dit path/to/wan2.1_i2v_480p_14B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
+--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
+--attn_mode torch --image_path path/to/image.jpg
+```
+
+Add `--clip` to specify the CLIP model. `--image_path` is the path to the image to be used as the initial frame.
+
+`--end_image_path` can be used to specify the end image. This option is experimental. When this option is specified, the saved video will be slightly longer than the specified number of frames and will have noise, so it is recommended to specify `--trim_tail_frames 3` to trim the tail frames.
+
+You can also use the Fun Control model for I2V inference. Specify the control video with `--control_path`.
+
+Other options are same as T2V inference.
+
+
+日本語
+`--clip` を追加してCLIPモデルを指定します。`--image_path` は初期フレームとして使用する画像のパスです。
+
+`--end_image_path` で終了画像を指定できます。このオプションは実験的なものです。このオプションを指定すると、保存される動画が指定フレーム数よりもやや多くなり、かつノイズが乗るため、`--trim_tail_frames 3` などを指定して末尾のフレームをトリミングすることをお勧めします。
+
+I2V推論でもFun Controlモデルが使用できます。`--control_path` で制御用の映像を指定します。
+
+その他のオプションはT2V推論と同じです。
+
+
+### New Batch and Interactive Modes / 新しいバッチモードとインタラクティブモード
+
+In addition to single video generation, Wan 2.1 now supports batch generation from file and interactive prompt input:
+
+#### Batch Mode from File / ファイルからのバッチモード
+
+Generate multiple videos from prompts stored in a text file:
+
+```bash
+python wan_generate_video.py --from_file prompts.txt --task t2v-14B
+--dit path/to/model.safetensors --vae path/to/vae.safetensors
+--t5 path/to/t5_model.pth --save_path output_directory
+```
+
+The prompts file format:
+- One prompt per line
+- Empty lines and lines starting with # are ignored (comments)
+- Each line can include prompt-specific parameters using command-line style format:
+
+```
+A beautiful sunset over mountains --w 832 --h 480 --f 81 --d 42 --s 20
+A busy city street at night --w 480 --h 832 --g 7.5 --n low quality, blurry
+```
+
+Supported inline parameters (if ommitted, default values from the command line are used):
+- `--w`: Width
+- `--h`: Height
+- `--f`: Frame count
+- `--d`: Seed
+- `--s`: Inference steps
+- `--g` or `--l`: Guidance scale
+- `--fs`: Flow shift
+- `--i`: Image path (for I2V)
+- `--cn`: Control path (for Fun Control)
+- `--n`: Negative prompt
+
+In batch mode, models are loaded once and reused for all prompts, significantly improving overall generation time compared to multiple single runs.
+
+#### Interactive Mode / インタラクティブモード
+
+Interactive command-line interface for entering prompts:
+
+```bash
+python wan_generate_video.py --interactive --task t2v-14B
+--dit path/to/model.safetensors --vae path/to/vae.safetensors
+--t5 path/to/t5_model.pth --save_path output_directory
+```
+
+In interactive mode:
+- Enter prompts directly at the command line
+- Use the same inline parameter format as batch mode
+- Use Ctrl+D (or Ctrl+Z on Windows) to exit
+- Models remain loaded between generations for efficiency
+
+
+日本語
+単一動画の生成に加えて、Wan 2.1は現在、ファイルからのバッチ生成とインタラクティブなプロンプト入力をサポートしています。
+
+#### ファイルからのバッチモード
+
+テキストファイルに保存されたプロンプトから複数の動画を生成します:
+
+```bash
+python wan_generate_video.py --from_file prompts.txt --task t2v-14B
+--dit path/to/model.safetensors --vae path/to/vae.safetensors
+--t5 path/to/t5_model.pth --save_path output_directory
+```
+
+プロンプトファイルの形式:
+- 1行に1つのプロンプト
+- 空行や#で始まる行は無視されます(コメント)
+- 各行にはコマンドライン形式でプロンプト固有のパラメータを含めることができます:
+
+サポートされているインラインパラメータ(省略した場合、コマンドラインのデフォルト値が使用されます)
+- `--w`: 幅
+- `--h`: 高さ
+- `--f`: フレーム数
+- `--d`: シード
+- `--s`: 推論ステップ
+- `--g` または `--l`: ガイダンススケール
+- `--fs`: フローシフト
+- `--i`: 画像パス(I2V用)
+- `--cn`: コントロールパス(Fun Control用)
+- `--n`: ネガティブプロンプト
+
+バッチモードでは、モデルは一度だけロードされ、すべてのプロンプトで再利用されるため、複数回の単一実行と比較して全体的な生成時間が大幅に改善されます。
+
+#### インタラクティブモード
+
+プロンプトを入力するためのインタラクティブなコマンドラインインターフェース:
+
+```bash
+python wan_generate_video.py --interactive --task t2v-14B
+--dit path/to/model.safetensors --vae path/to/vae.safetensors
+--t5 path/to/t5_model.pth --save_path output_directory
+```
+
+インタラクティブモードでは:
+- コマンドラインで直接プロンプトを入力
+- バッチモードと同じインラインパラメータ形式を使用
+- 終了するには Ctrl+D (Windowsでは Ctrl+Z) を使用
+- 効率のため、モデルは生成間で読み込まれたままになります
+
+
diff --git a/fpack_cache_latents.py b/fpack_cache_latents.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d09ac066e17b312bd7ea3bb317be0c73b19841a
--- /dev/null
+++ b/fpack_cache_latents.py
@@ -0,0 +1,524 @@
+import argparse
+import logging
+import math
+import os
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+from transformers import SiglipImageProcessor, SiglipVisionModel
+from PIL import Image
+
+from dataset import config_utils
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache_framepack, ARCHITECTURE_FRAMEPACK
+from frame_pack import hunyuan
+from frame_pack.framepack_utils import load_image_encoders, load_vae
+from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+from frame_pack.clip_vision import hf_clip_vision_encode
+import cache_latents
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def encode_and_save_batch(
+ vae: AutoencoderKLCausal3D,
+ feature_extractor: SiglipImageProcessor,
+ image_encoder: SiglipVisionModel,
+ batch: List[ItemInfo],
+ vanilla_sampling: bool = False,
+ one_frame: bool = False,
+ one_frame_no_2x: bool = False,
+ one_frame_no_4x: bool = False,
+):
+ """Encode a batch of original RGB videos and save FramePack section caches."""
+ if one_frame:
+ encode_and_save_batch_one_frame(
+ vae, feature_extractor, image_encoder, batch, vanilla_sampling, one_frame_no_2x, one_frame_no_4x
+ )
+ return
+
+ latent_window_size = batch[0].fp_latent_window_size # all items should have the same window size
+
+ # Stack batch into tensor (B,C,F,H,W) in RGB order
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
+ if len(contents.shape) == 4:
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
+
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
+ contents = contents.to(vae.device, dtype=vae.dtype)
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
+
+ height, width = contents.shape[3], contents.shape[4]
+ if height < 8 or width < 8:
+ item = batch[0] # other items should have the same size
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
+
+ # calculate latent frame count from original frame count (4n+1)
+ latent_f = (batch[0].frame_count - 1) // 4 + 1
+
+ # calculate the total number of sections (excluding the first frame, divided by window size)
+ total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
+ if total_latent_sections < 1:
+ min_frames_needed = latent_window_size * 4 + 1
+ raise ValueError(
+ f"Not enough frames for FramePack: {batch[0].frame_count} frames ({latent_f} latent frames), minimum required: {min_frames_needed} frames ({latent_window_size+1} latent frames)"
+ )
+
+ # actual latent frame count (aligned to section boundaries)
+ latent_f_aligned = total_latent_sections * latent_window_size + 1 if not one_frame else 1
+
+ # actual video frame count
+ frame_count_aligned = (latent_f_aligned - 1) * 4 + 1
+ if frame_count_aligned != batch[0].frame_count:
+ logger.info(
+ f"Frame count mismatch: required={frame_count_aligned} != actual={batch[0].frame_count}, trimming to {frame_count_aligned}"
+ )
+ contents = contents[:, :, :frame_count_aligned, :, :]
+
+ latent_f = latent_f_aligned # Update to the aligned value
+
+ # VAE encode (list of tensor -> stack)
+ latents = hunyuan.vae_encode(contents, vae) # include scaling factor
+ latents = latents.to("cpu") # (B, C, latent_f, H/8, W/8)
+
+ # Vision encoding per‑item (once)
+ images = np.stack([item.content[0] for item in batch], axis=0) # B, H, W, C
+
+ # encode image with image encoder
+ image_embeddings = []
+ with torch.no_grad():
+ for image in images:
+ image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
+ image_embeddings.append(image_encoder_output.last_hidden_state)
+ image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
+ image_embeddings = image_embeddings.to("cpu") # Save memory
+
+ if not vanilla_sampling:
+ # padding is reversed for inference (future to past)
+ latent_paddings = list(reversed(range(total_latent_sections)))
+ # Note: The padding trick for inference. See the paper for details.
+ if total_latent_sections > 4:
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
+
+ for b, item in enumerate(batch):
+ original_latent_cache_path = item.latent_cache_path
+ video_lat = latents[b : b + 1] # keep batch dim, 1, C, F, H, W
+
+ # emulate inference step (history latents)
+ # Note: In inference, history_latents stores *generated* future latents.
+ # Here, for caching, we just need its shape and type for clean_* tensors.
+ # The actual content doesn't matter much as clean_* will be overwritten.
+ history_latents = torch.zeros(
+ (1, video_lat.shape[1], 1 + 2 + 16, video_lat.shape[3], video_lat.shape[4]), dtype=video_lat.dtype
+ ) # C=16 for HY
+
+ latent_f_index = latent_f - latent_window_size # Start from the last section
+ section_index = total_latent_sections - 1
+
+ for latent_padding in latent_paddings:
+ is_last_section = section_index == 0 # the last section in inference order == the first section in time
+ latent_padding_size = latent_padding * latent_window_size
+ if is_last_section:
+ assert latent_f_index == 1, "Last section should be starting from frame 1"
+
+ # indices generation (same as inference)
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
+ (
+ clean_latent_indices_pre, # Index for start_latent
+ blank_indices, # Indices for padding (future context in inference)
+ latent_indices, # Indices for the target latents to predict
+ clean_latent_indices_post, # Index for the most recent history frame
+ clean_latent_2x_indices, # Indices for the next 2 history frames
+ clean_latent_4x_indices, # Indices for the next 16 history frames
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
+
+ # Indices for clean_latents (start + recent history)
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
+
+ # clean latents preparation (emulating inference)
+ clean_latents_pre = video_lat[:, :, 0:1, :, :] # Always the first frame (start_latent)
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
+ [1, 2, 16], dim=2
+ )
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) # Combine start frame + placeholder
+
+ # Target latents for this section (ground truth)
+ target_latents = video_lat[:, :, latent_f_index : latent_f_index + latent_window_size, :, :]
+
+ # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
+ item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
+ save_latent_cache_framepack(
+ item_info=item,
+ latent=target_latents.squeeze(0), # Ground truth for this section
+ latent_indices=latent_indices.squeeze(0), # Indices for the ground truth section
+ clean_latents=clean_latents.squeeze(0), # Start frame + history placeholder
+ clean_latent_indices=clean_latent_indices.squeeze(0), # Indices for start frame + history placeholder
+ clean_latents_2x=clean_latents_2x.squeeze(0), # History placeholder
+ clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0), # Indices for history placeholder
+ clean_latents_4x=clean_latents_4x.squeeze(0), # History placeholder
+ clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0), # Indices for history placeholder
+ image_embeddings=image_embeddings[b],
+ )
+
+ if is_last_section: # If this was the first section generated in inference (time=0)
+ # History gets the start frame + the generated first section
+ generated_latents_for_history = video_lat[:, :, : latent_window_size + 1, :, :]
+ else:
+ # History gets the generated current section
+ generated_latents_for_history = target_latents # Use true latents as stand-in for generated
+
+ history_latents = torch.cat([generated_latents_for_history, history_latents], dim=2)
+
+ section_index -= 1
+ latent_f_index -= latent_window_size
+
+ else:
+ # Vanilla Sampling Logic
+ for b, item in enumerate(batch):
+ original_latent_cache_path = item.latent_cache_path
+ video_lat = latents[b : b + 1] # Keep batch dim: 1, C, F_aligned, H, W
+ img_emb = image_embeddings[b] # LEN, 1152
+
+ for section_index in range(total_latent_sections):
+ target_start_f = section_index * latent_window_size + 1
+ target_end_f = target_start_f + latent_window_size
+ target_latents = video_lat[:, :, target_start_f:target_end_f, :, :]
+ start_latent = video_lat[:, :, 0:1, :, :]
+
+ # Clean latents preparation (Vanilla)
+ clean_latents_total_count = 1 + 2 + 16
+ history_latents = torch.zeros(
+ size=(1, 16, clean_latents_total_count, video_lat.shape[-2], video_lat.shape[-1]),
+ device=video_lat.device,
+ dtype=video_lat.dtype,
+ )
+
+ history_start_f = 0
+ video_start_f = target_start_f - clean_latents_total_count
+ copy_count = clean_latents_total_count
+ if video_start_f < 0:
+ history_start_f = -video_start_f
+ copy_count = clean_latents_total_count - history_start_f
+ video_start_f = 0
+ if copy_count > 0:
+ history_latents[:, :, history_start_f:] = video_lat[:, :, video_start_f : video_start_f + copy_count, :, :]
+
+ # indices generation (Vanilla): copy from FramePack-F1
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
+ (
+ clean_latent_indices_start,
+ clean_latent_4x_indices,
+ clean_latent_2x_indices,
+ clean_latent_1x_indices,
+ latent_indices,
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
+
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents.split([16, 2, 1], dim=2)
+ clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
+
+ # Save cache
+ item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
+ save_latent_cache_framepack(
+ item_info=item,
+ latent=target_latents.squeeze(0),
+ latent_indices=latent_indices.squeeze(0), # Indices for target section i
+ clean_latents=clean_latents.squeeze(0), # Past clean frames
+ clean_latent_indices=clean_latent_indices.squeeze(0), # Indices for clean_latents_pre/post
+ clean_latents_2x=clean_latents_2x.squeeze(0), # Past clean frames (2x)
+ clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0), # Indices for clean_latents_2x
+ clean_latents_4x=clean_latents_4x.squeeze(0), # Past clean frames (4x)
+ clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0), # Indices for clean_latents_4x
+ image_embeddings=img_emb,
+ # Note: We don't explicitly save past_offset_indices,
+ # but its size influences the absolute values in other indices.
+ )
+
+
+def encode_and_save_batch_one_frame(
+ vae: AutoencoderKLCausal3D,
+ feature_extractor: SiglipImageProcessor,
+ image_encoder: SiglipVisionModel,
+ batch: List[ItemInfo],
+ vanilla_sampling: bool = False,
+ one_frame_no_2x: bool = False,
+ one_frame_no_4x: bool = False,
+):
+ # item.content: target image (H, W, C)
+ # item.control_content: list of images (H, W, C)
+
+ # Stack batch into tensor (B,F,H,W,C) in RGB order. The numbers of control content for each item are the same.
+ contents = []
+ content_masks: list[list[Optional[torch.Tensor]]] = []
+ for item in batch:
+ item_contents = item.control_content + [item.content]
+
+ item_masks = []
+ for i, c in enumerate(item_contents):
+ if c.shape[-1] == 4: # RGBA
+ item_contents[i] = c[..., :3] # remove alpha channel from content
+
+ alpha = c[..., 3] # extract alpha channel
+ mask_image = Image.fromarray(alpha, mode="L")
+ width, height = mask_image.size
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
+ mask_image = np.array(mask_image) # PIL to numpy, HWC
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
+ mask_image = mask_image.squeeze(-1) # HWC -> HW
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (BCFHW)
+ mask_image = mask_image.to(torch.float32)
+ content_mask = mask_image
+ else:
+ content_mask = None
+
+ item_masks.append(content_mask)
+
+ item_contents = [torch.from_numpy(c) for c in item_contents]
+ contents.append(torch.stack(item_contents, dim=0)) # list of [F, H, W, C]
+ content_masks.append(item_masks)
+
+ contents = torch.stack(contents, dim=0) # B, F, H, W, C. F is control frames + target frame
+
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
+ contents = contents.to(vae.device, dtype=vae.dtype)
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
+
+ height, width = contents.shape[-2], contents.shape[-1]
+ if height < 8 or width < 8:
+ item = batch[0] # other items should have the same size
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
+
+ # VAE encode: we need to encode one frame at a time because VAE encoder has stride=4 for the time dimension except for the first frame.
+ latents = [hunyuan.vae_encode(contents[:, :, idx : idx + 1], vae).to("cpu") for idx in range(contents.shape[2])]
+ latents = torch.cat(latents, dim=2) # B, C, F, H/8, W/8
+
+ # apply alphas to latents
+ for b, item in enumerate(batch):
+ for i, content_mask in enumerate(content_masks[b]):
+ if content_mask is not None:
+ # apply mask to the latents
+ # print(f"Applying content mask for item {item.item_key}, frame {i}")
+ latents[b : b + 1, :, i : i + 1] *= content_mask
+
+ # Vision encoding per‑item (once): use control content because it is the start image
+ images = [item.control_content[0] for item in batch] # list of [H, W, C]
+
+ # encode image with image encoder
+ image_embeddings = []
+ with torch.no_grad():
+ for image in images:
+ image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
+ image_embeddings.append(image_encoder_output.last_hidden_state)
+ image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
+ image_embeddings = image_embeddings.to("cpu") # Save memory
+
+ # save cache for each item in the batch
+ for b, item in enumerate(batch):
+ # indices generation (same as inference): each item may have different clean_latent_indices, so we generate them per item
+ clean_latent_indices = item.fp_1f_clean_indices # list of indices for clean latents
+ if clean_latent_indices is None or len(clean_latent_indices) == 0:
+ logger.warning(
+ f"Item {item.item_key} has no clean_latent_indices defined, using default indices for one frame training."
+ )
+ clean_latent_indices = [0]
+
+ if not item.fp_1f_no_post:
+ clean_latent_indices = clean_latent_indices + [1 + item.fp_latent_window_size]
+ clean_latent_indices = torch.Tensor(clean_latent_indices).long() # N
+
+ latent_index = torch.Tensor([item.fp_1f_target_index]).long() # 1
+
+ # zero values is not needed to cache even if one_frame_no_2x or 4x is False
+ clean_latents_2x = None
+ clean_latents_4x = None
+
+ if one_frame_no_2x:
+ clean_latent_2x_indices = None
+ else:
+ index = 1 + item.fp_latent_window_size + 1
+ clean_latent_2x_indices = torch.arange(index, index + 2) # 2
+
+ if one_frame_no_4x:
+ clean_latent_4x_indices = None
+ else:
+ index = 1 + item.fp_latent_window_size + 1 + 2
+ clean_latent_4x_indices = torch.arange(index, index + 16) # 16
+
+ # clean latents preparation (emulating inference)
+ clean_latents = latents[b, :, :-1] # C, F, H, W
+ if not item.fp_1f_no_post:
+ # If zero post is enabled, we need to add a zero frame at the end
+ clean_latents = F.pad(clean_latents, (0, 0, 0, 0, 0, 1), value=0.0) # C, F+1, H, W
+
+ # Target latents for this section (ground truth)
+ target_latents = latents[b, :, -1:] # C, 1, H, W
+
+ print(f"Saving cache for item {item.item_key} at {item.latent_cache_path}. no_post: {item.fp_1f_no_post}")
+ print(f" Clean latent indices: {clean_latent_indices}, latent index: {latent_index}")
+ print(f" Clean latents: {clean_latents.shape}, target latents: {target_latents.shape}")
+ print(f" Clean latents 2x indices: {clean_latent_2x_indices}, clean latents 4x indices: {clean_latent_4x_indices}")
+ print(
+ f" Clean latents 2x: {clean_latents_2x.shape if clean_latents_2x is not None else 'None'}, "
+ f"Clean latents 4x: {clean_latents_4x.shape if clean_latents_4x is not None else 'None'}"
+ )
+ print(f" Image embeddings: {image_embeddings[b].shape}")
+
+ # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
+ save_latent_cache_framepack(
+ item_info=item,
+ latent=target_latents, # Ground truth for this section
+ latent_indices=latent_index, # Indices for the ground truth section
+ clean_latents=clean_latents, # Start frame + history placeholder
+ clean_latent_indices=clean_latent_indices, # Indices for start frame + history placeholder
+ clean_latents_2x=clean_latents_2x, # History placeholder
+ clean_latent_2x_indices=clean_latent_2x_indices, # Indices for history placeholder
+ clean_latents_4x=clean_latents_4x, # History placeholder
+ clean_latent_4x_indices=clean_latent_4x_indices, # Indices for history placeholder
+ image_embeddings=image_embeddings[b],
+ )
+
+
+def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
+ parser.add_argument(
+ "--f1",
+ action="store_true",
+ help="Generate cache for F1 model (vanilla (autoregressive) sampling) instead of Inverted anti-drifting (plain FramePack)",
+ )
+ parser.add_argument(
+ "--one_frame",
+ action="store_true",
+ help="Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
+ )
+ parser.add_argument(
+ "--one_frame_no_2x",
+ action="store_true",
+ help="Do not use clean_latents_2x and clean_latent_2x_indices for one frame training.",
+ )
+ parser.add_argument(
+ "--one_frame_no_4x",
+ action="store_true",
+ help="Do not use clean_latents_4x and clean_latent_4x_indices for one frame training.",
+ )
+ return parser
+
+
+def main(args: argparse.Namespace):
+ device = args.device if hasattr(args, "device") and args.device else ("cuda" if torch.cuda.is_available() else "cpu")
+ device = torch.device(device)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+ datasets = train_dataset_group.datasets
+
+ if args.debug_mode is not None:
+ cache_latents.show_datasets(
+ datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
+ )
+ return
+
+ assert args.vae is not None, "vae checkpoint is required"
+
+ logger.info(f"Loading VAE model from {args.vae}")
+ vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device=device)
+ vae.to(device)
+
+ logger.info(f"Loading image encoder from {args.image_encoder}")
+ feature_extractor, image_encoder = load_image_encoders(args)
+ image_encoder.eval()
+ image_encoder.to(device)
+
+ logger.info(f"Cache generation mode: {'Vanilla Sampling' if args.f1 else 'Inference Emulation'}")
+
+ # encoding closure
+ def encode(batch: List[ItemInfo]):
+ encode_and_save_batch(
+ vae, feature_extractor, image_encoder, batch, args.f1, args.one_frame, args.one_frame_no_2x, args.one_frame_no_4x
+ )
+
+ # reuse core loop from cache_latents with no change
+ encode_datasets_framepack(datasets, encode, args)
+
+
+def append_section_idx_to_latent_cache_path(latent_cache_path: str, section_idx: int) -> str:
+ tokens = latent_cache_path.split("_")
+ tokens[-3] = f"{tokens[-3]}-{section_idx:04d}" # append section index to "frame_pos-count"
+ return "_".join(tokens)
+
+
+def encode_datasets_framepack(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
+ for i, dataset in enumerate(datasets):
+ logger.info(f"Encoding dataset [{i}]")
+ all_latent_cache_paths = []
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
+ batch: list[ItemInfo] = batch # type: ignore
+
+ # latent_cache_path is "{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors"
+ # For video dataset,we expand it to "{basename}_{section_idx:04d}_{w:04d}x{h:04d}_{self.architecture}.safetensors"
+ filtered_batch = []
+ for item in batch:
+ if item.frame_count is None:
+ # image dataset
+ all_latent_cache_paths.append(item.latent_cache_path)
+ all_existing = os.path.exists(item.latent_cache_path)
+ else:
+ latent_f = (item.frame_count - 1) // 4 + 1
+ num_sections = max(1, math.floor((latent_f - 1) / item.fp_latent_window_size)) # min 1 section
+ all_existing = True
+ for sec in range(num_sections):
+ p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
+ all_latent_cache_paths.append(p)
+ all_existing = all_existing and os.path.exists(p)
+
+ if not all_existing: # if any section cache is missing
+ filtered_batch.append(item)
+
+ if args.skip_existing:
+ if len(filtered_batch) == 0: # all sections exist
+ logger.info(f"All sections exist for {batch[0].item_key}, skipping")
+ continue
+ batch = filtered_batch # update batch to only missing sections
+
+ bs = args.batch_size if args.batch_size is not None else len(batch)
+ for i in range(0, len(batch), bs):
+ encode(batch[i : i + bs])
+
+ # normalize paths
+ all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
+ all_latent_cache_paths = set(all_latent_cache_paths)
+
+ # remove old cache files not in the dataset
+ all_cache_files = dataset.get_all_latent_cache_files()
+ for cache_file in all_cache_files:
+ if os.path.normpath(cache_file) not in all_latent_cache_paths:
+ if args.keep_cache:
+ logger.info(f"Keep cache file not in the dataset: {cache_file}")
+ else:
+ os.remove(cache_file)
+ logger.info(f"Removed old cache file: {cache_file}")
+
+
+if __name__ == "__main__":
+ parser = cache_latents.setup_parser_common()
+ parser = cache_latents.hv_setup_parser(parser) # VAE
+ parser = framepack_setup_parser(parser)
+
+ args = parser.parse_args()
+
+ if args.vae_dtype is not None:
+ raise ValueError("VAE dtype is not supported in FramePack")
+ # if args.batch_size != 1:
+ # args.batch_size = 1
+ # logger.info("Batch size is set to 1 for FramePack.")
+
+ main(args)
diff --git a/fpack_cache_text_encoder_outputs.py b/fpack_cache_text_encoder_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a1ed9fbd943f0927bb64aea92409e6db7b28a79
--- /dev/null
+++ b/fpack_cache_text_encoder_outputs.py
@@ -0,0 +1,110 @@
+import argparse
+import os
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers import LlamaTokenizerFast, LlamaModel, CLIPTokenizer, CLIPTextModel
+from dataset import config_utils
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ItemInfo, save_text_encoder_output_cache_framepack
+import cache_text_encoder_outputs
+from frame_pack import hunyuan
+from frame_pack.framepack_utils import load_text_encoder1, load_text_encoder2
+
+import logging
+
+from frame_pack.utils import crop_or_pad_yield_mask
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def encode_and_save_batch(
+ tokenizer1: LlamaTokenizerFast,
+ text_encoder1: LlamaModel,
+ tokenizer2: CLIPTokenizer,
+ text_encoder2: CLIPTextModel,
+ batch: list[ItemInfo],
+ device: torch.device,
+):
+ prompts = [item.caption for item in batch]
+
+ # encode prompt
+ # FramePack's encode_prompt_conds only supports single prompt, so we need to encode each prompt separately
+ list_of_llama_vec = []
+ list_of_llama_attention_mask = []
+ list_of_clip_l_pooler = []
+ for prompt in prompts:
+ with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
+ # llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompts, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
+ llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
+
+ list_of_llama_vec.append(llama_vec.squeeze(0))
+ list_of_llama_attention_mask.append(llama_attention_mask.squeeze(0))
+ list_of_clip_l_pooler.append(clip_l_pooler.squeeze(0))
+
+ # save prompt cache
+ for item, llama_vec, llama_attention_mask, clip_l_pooler in zip(
+ batch, list_of_llama_vec, list_of_llama_attention_mask, list_of_clip_l_pooler
+ ):
+ # save llama_vec and clip_l_pooler to cache
+ save_text_encoder_output_cache_framepack(item, llama_vec, llama_attention_mask, clip_l_pooler)
+
+
+def main(args):
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+ datasets = train_dataset_group.datasets
+
+ # prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
+ all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)
+
+ # load text encoder
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
+ text_encoder2.to(device)
+
+ # Encode with Text Encoders
+ logger.info("Encoding with Text Encoders")
+
+ def encode_for_text_encoder(batch: list[ItemInfo]):
+ encode_and_save_batch(tokenizer1, text_encoder1, tokenizer2, text_encoder2, batch, device)
+
+ cache_text_encoder_outputs.process_text_encoder_batches(
+ args.num_workers,
+ args.skip_existing,
+ args.batch_size,
+ datasets,
+ all_cache_files_for_dataset,
+ all_cache_paths_for_dataset,
+ encode_for_text_encoder,
+ )
+
+ # remove cache files not in dataset
+ cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
+
+
+def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
+ return parser
+
+
+if __name__ == "__main__":
+ parser = cache_text_encoder_outputs.setup_parser_common()
+ parser = framepack_setup_parser(parser)
+
+ args = parser.parse_args()
+ main(args)
diff --git a/fpack_generate_video.py b/fpack_generate_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1cfe8b46805d7111827b525e7e9f0a660f890a0
--- /dev/null
+++ b/fpack_generate_video.py
@@ -0,0 +1,1832 @@
+import argparse
+from datetime import datetime
+import gc
+import json
+import random
+import os
+import re
+import time
+import math
+import copy
+from typing import Tuple, Optional, List, Union, Any, Dict
+
+import torch
+from safetensors.torch import load_file, save_file
+from safetensors import safe_open
+from PIL import Image
+import cv2
+import numpy as np
+import torchvision.transforms.functional as TF
+from transformers import LlamaModel
+from tqdm import tqdm
+
+from networks import lora_framepack
+from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+from frame_pack import hunyuan
+from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
+from frame_pack.utils import crop_or_pad_yield_mask, resize_and_center_crop, soft_append_bcthw
+from frame_pack.bucket_tools import find_nearest_bucket
+from frame_pack.clip_vision import hf_clip_vision_encode
+from frame_pack.k_diffusion_hunyuan import sample_hunyuan
+from dataset import image_video_dataset
+
+try:
+ from lycoris.kohya import create_network_from_weights
+except:
+ pass
+
+from utils.device_utils import clean_memory_on_device
+from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
+from wan_generate_video import merge_lora_weights
+from frame_pack.framepack_utils import load_vae, load_text_encoder1, load_text_encoder2, load_image_encoders
+from dataset.image_video_dataset import load_video
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+class GenerationSettings:
+ def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None):
+ self.device = device
+ self.dit_weight_dtype = dit_weight_dtype # not used currently because model may be optimized
+
+
+def parse_args() -> argparse.Namespace:
+ """parse command line arguments"""
+ parser = argparse.ArgumentParser(description="Wan 2.1 inference script")
+
+ # WAN arguments
+ # parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
+ parser.add_argument(
+ "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
+ )
+
+ parser.add_argument("--dit", type=str, default=None, help="DiT directory or path")
+ parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory or path")
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory or path")
+ parser.add_argument("--image_encoder", type=str, required=True, help="Image Encoder directory or path")
+ parser.add_argument("--f1", action="store_true", help="Use F1 sampling method")
+
+ # LoRA
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
+ parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
+ parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
+ parser.add_argument(
+ "--save_merged_model",
+ type=str,
+ default=None,
+ help="Save merged model to path. If specified, no inference will be performed.",
+ )
+
+ # inference
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default=None,
+ help="prompt for generation. If `;;;` is used, it will be split into sections. Example: `section_index:prompt` or "
+ "`section_index:prompt;;;section_index:prompt;;;...`, section_index can be `0` or `-1` or `0-2`, `-1` means last section, `0-2` means from 0 to 2 (inclusive).",
+ )
+ parser.add_argument(
+ "--negative_prompt",
+ type=str,
+ default=None,
+ help="negative prompt for generation, default is empty string. should not change.",
+ )
+ parser.add_argument(
+ "--custom_system_prompt",
+ type=str,
+ default=None,
+ help="Custom system prompt for LLM. If specified, it will override the default system prompt. See hunyuan_model/text_encoder.py for the default system prompt.",
+ )
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
+ parser.add_argument("--video_seconds", type=float, default=5.0, help="video length, default is 5.0 seconds")
+ parser.add_argument(
+ "--video_sections",
+ type=int,
+ default=None,
+ help="number of video sections, Default is None (auto calculate from video seconds)",
+ )
+ parser.add_argument(
+ "--one_frame_inference",
+ type=str,
+ default=None,
+ help="one frame inference, default is None, comma separated values from 'no_2x', 'no_4x', 'no_post', 'control_indices' and 'target_index'.",
+ )
+ parser.add_argument(
+ "--control_image_path", type=str, default=None, nargs="*", help="path to control (reference) image for one frame inference."
+ )
+ parser.add_argument(
+ "--control_image_mask_path",
+ type=str,
+ default=None,
+ nargs="*",
+ help="path to control (reference) image mask for one frame inference.",
+ )
+ parser.add_argument("--fps", type=int, default=30, help="video fps, default is 30")
+ parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25")
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
+ # parser.add_argument(
+ # "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
+ # )
+ parser.add_argument("--latent_window_size", type=int, default=9, help="latent window size, default is 9. should not change.")
+ parser.add_argument(
+ "--embedded_cfg_scale", type=float, default=10.0, help="Embeded CFG scale (distilled CFG Scale), default is 10.0"
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=1.0,
+ help="Guidance scale for classifier free guidance. Default is 1.0 (no guidance), should not change.",
+ )
+ parser.add_argument("--guidance_rescale", type=float, default=0.0, help="CFG Re-scale, default is 0.0. Should not change.")
+ # parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
+ parser.add_argument(
+ "--image_path",
+ type=str,
+ default=None,
+ help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
+ )
+ parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
+ parser.add_argument(
+ "--latent_paddings",
+ type=str,
+ default=None,
+ help="latent paddings for each section, comma separated values. default is None (FramePack default paddings)",
+ )
+ # parser.add_argument(
+ # "--control_path",
+ # type=str,
+ # default=None,
+ # help="path to control video for inference with controlnet. video file or directory with images",
+ # )
+ # parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
+
+ # # Flow Matching
+ # parser.add_argument(
+ # "--flow_shift",
+ # type=float,
+ # default=None,
+ # help="Shift factor for flow matching schedulers. Default depends on task.",
+ # )
+
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
+ # parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
+ parser.add_argument(
+ "--rope_scaling_factor", type=float, default=0.5, help="RoPE scaling factor for high resolution (H/W), default is 0.5"
+ )
+ parser.add_argument(
+ "--rope_scaling_timestep_threshold",
+ type=int,
+ default=None,
+ help="RoPE scaling timestep threshold, default is None (disable), if set, RoPE scaling will be applied only for timesteps >= threshold, around 800 is good starting point",
+ )
+
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
+ parser.add_argument(
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
+ )
+ parser.add_argument(
+ "--attn_mode",
+ type=str,
+ default="torch",
+ choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3",
+ help="attention mode",
+ )
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once")
+ parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
+ parser.add_argument(
+ "--output_type",
+ type=str,
+ default="video",
+ choices=["video", "images", "latent", "both", "latent_images"],
+ help="output type",
+ )
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
+ # parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
+ # parser.add_argument(
+ # "--compile_args",
+ # nargs=4,
+ # metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
+ # default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
+ # help="Torch.compile settings",
+ # )
+
+ # New arguments for batch and interactive modes
+ parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
+
+ args = parser.parse_args()
+
+ # Validate arguments
+ if args.from_file and args.interactive:
+ raise ValueError("Cannot use both --from_file and --interactive at the same time")
+
+ if args.latent_path is None or len(args.latent_path) == 0:
+ if args.prompt is None and not args.from_file and not args.interactive:
+ raise ValueError("Either --prompt, --from_file or --interactive must be specified")
+
+ return args
+
+
+def parse_prompt_line(line: str) -> Dict[str, Any]:
+ """Parse a prompt line into a dictionary of argument overrides
+
+ Args:
+ line: Prompt line with options
+
+ Returns:
+ Dict[str, Any]: Dictionary of argument overrides
+ """
+ # TODO common function with hv_train_network.line_to_prompt_dict
+ parts = line.split(" --")
+ prompt = parts[0].strip()
+
+ # Create dictionary of overrides
+ overrides = {"prompt": prompt}
+ # Initialize control_image_path and control_image_mask_path as a list to accommodate multiple paths
+ overrides["control_image_path"] = []
+ overrides["control_image_mask_path"] = []
+
+ for part in parts[1:]:
+ if not part.strip():
+ continue
+ option_parts = part.split(" ", 1)
+ option = option_parts[0].strip()
+ value = option_parts[1].strip() if len(option_parts) > 1 else ""
+
+ # Map options to argument names
+ if option == "w":
+ overrides["video_size_width"] = int(value)
+ elif option == "h":
+ overrides["video_size_height"] = int(value)
+ elif option == "f":
+ overrides["video_seconds"] = float(value)
+ elif option == "d":
+ overrides["seed"] = int(value)
+ elif option == "s":
+ overrides["infer_steps"] = int(value)
+ elif option == "g" or option == "l":
+ overrides["guidance_scale"] = float(value)
+ # elif option == "fs":
+ # overrides["flow_shift"] = float(value)
+ elif option == "i":
+ overrides["image_path"] = value
+ # elif option == "im":
+ # overrides["image_mask_path"] = value
+ # elif option == "cn":
+ # overrides["control_path"] = value
+ elif option == "n":
+ overrides["negative_prompt"] = value
+ elif option == "vs": # video_sections
+ overrides["video_sections"] = int(value)
+ elif option == "ei": # end_image_path
+ overrides["end_image_path"] = value
+ elif option == "ci": # control_image_path
+ overrides["control_image_path"].append(value)
+ elif option == "cim": # control_image_mask_path
+ overrides["control_image_mask_path"].append(value)
+ elif option == "of": # one_frame_inference
+ overrides["one_frame_inference"] = value
+
+ # If no control_image_path was provided, remove the empty list
+ if not overrides["control_image_path"]:
+ del overrides["control_image_path"]
+ if not overrides["control_image_mask_path"]:
+ del overrides["control_image_mask_path"]
+
+ return overrides
+
+
+def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
+ """Apply overrides to args
+
+ Args:
+ args: Original arguments
+ overrides: Dictionary of overrides
+
+ Returns:
+ argparse.Namespace: New arguments with overrides applied
+ """
+ args_copy = copy.deepcopy(args)
+
+ for key, value in overrides.items():
+ if key == "video_size_width":
+ args_copy.video_size[1] = value
+ elif key == "video_size_height":
+ args_copy.video_size[0] = value
+ else:
+ setattr(args_copy, key, value)
+
+ return args_copy
+
+
+def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]:
+ """Validate video size and length
+
+ Args:
+ args: command line arguments
+
+ Returns:
+ Tuple[int, int, float]: (height, width, video_seconds)
+ """
+ height = args.video_size[0]
+ width = args.video_size[1]
+
+ video_seconds = args.video_seconds
+ if args.video_sections is not None:
+ video_seconds = (args.video_sections * (args.latent_window_size * 4) + 1) / args.fps
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ return height, width, video_seconds
+
+
+# region DiT model
+
+
+def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVideoTransformer3DModelPacked:
+ """load DiT model
+
+ Args:
+ args: command line arguments
+ device: device to use
+ dit_dtype: data type for the model
+ dit_weight_dtype: data type for the model weights. None for as-is
+
+ Returns:
+ HunyuanVideoTransformer3DModelPacked: DiT model
+ """
+ loading_device = "cpu"
+ if args.blocks_to_swap == 0 and not args.fp8_scaled and args.lora_weight is None:
+ loading_device = device
+
+ # do not fp8 optimize because we will merge LoRA weights
+ model = load_packed_model(device, args.dit, args.attn_mode, loading_device)
+
+ # apply RoPE scaling factor
+ if args.rope_scaling_timestep_threshold is not None:
+ logger.info(
+ f"Applying RoPE scaling factor {args.rope_scaling_factor} for timesteps >= {args.rope_scaling_timestep_threshold}"
+ )
+ model.enable_rope_scaling(args.rope_scaling_timestep_threshold, args.rope_scaling_factor)
+ return model
+
+
+def optimize_model(model: HunyuanVideoTransformer3DModelPacked, args: argparse.Namespace, device: torch.device) -> None:
+ """optimize the model (FP8 conversion, device move etc.)
+
+ Args:
+ model: dit model
+ args: command line arguments
+ device: device to use
+ """
+ if args.fp8_scaled:
+ # load state dict as-is and optimize to fp8
+ state_dict = model.state_dict()
+
+ # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
+ move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
+ state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False) # args.fp8_fast)
+
+ info = model.load_state_dict(state_dict, strict=True, assign=True)
+ logger.info(f"Loaded FP8 optimized weights: {info}")
+
+ if args.blocks_to_swap == 0:
+ model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.)
+ else:
+ # simple cast to dit_dtype
+ target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
+ target_device = None
+
+ if args.fp8:
+ target_dtype = torch.float8e4m3fn
+
+ if args.blocks_to_swap == 0:
+ logger.info(f"Move model to device: {device}")
+ target_device = device
+
+ if target_device is not None and target_dtype is not None:
+ model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
+
+ # if args.compile:
+ # compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
+ # logger.info(
+ # f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
+ # )
+ # torch._dynamo.config.cache_size_limit = 32
+ # for i in range(len(model.blocks)):
+ # model.blocks[i] = torch.compile(
+ # model.blocks[i],
+ # backend=compile_backend,
+ # mode=compile_mode,
+ # dynamic=compile_dynamic.lower() in "true",
+ # fullgraph=compile_fullgraph.lower() in "true",
+ # )
+
+ if args.blocks_to_swap > 0:
+ logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
+ model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
+ model.move_to_device_except_swap_blocks(device)
+ model.prepare_block_swap_before_forward()
+ else:
+ # make sure the model is on the right device
+ model.to(device)
+
+ model.eval().requires_grad_(False)
+ clean_memory_on_device(device)
+
+
+# endregion
+
+
+def decode_latent(
+ latent_window_size: int,
+ total_latent_sections: int,
+ bulk_decode: bool,
+ vae: AutoencoderKLCausal3D,
+ latent: torch.Tensor,
+ device: torch.device,
+ one_frame_inference_mode: bool = False,
+) -> torch.Tensor:
+ logger.info(f"Decoding video...")
+ if latent.ndim == 4:
+ latent = latent.unsqueeze(0) # add batch dimension
+
+ vae.to(device)
+ if not bulk_decode and not one_frame_inference_mode:
+ latent_window_size = latent_window_size # default is 9
+ # total_latent_sections = (args.video_seconds * 30) / (latent_window_size * 4)
+ # total_latent_sections = int(max(round(total_latent_sections), 1))
+ num_frames = latent_window_size * 4 - 3
+
+ latents_to_decode = []
+ latent_frame_index = 0
+ for i in range(total_latent_sections - 1, -1, -1):
+ is_last_section = i == total_latent_sections - 1
+ generated_latent_frames = (num_frames + 3) // 4 + (1 if is_last_section else 0)
+ section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
+
+ section_latent = latent[:, :, latent_frame_index : latent_frame_index + section_latent_frames, :, :]
+ if section_latent.shape[2] > 0:
+ latents_to_decode.append(section_latent)
+
+ latent_frame_index += generated_latent_frames
+
+ latents_to_decode = latents_to_decode[::-1] # reverse the order of latents to decode
+
+ history_pixels = None
+ for latent in tqdm(latents_to_decode):
+ if history_pixels is None:
+ history_pixels = hunyuan.vae_decode(latent, vae).cpu()
+ else:
+ overlapped_frames = latent_window_size * 4 - 3
+ current_pixels = hunyuan.vae_decode(latent, vae).cpu()
+ history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
+ clean_memory_on_device(device)
+ else:
+ # bulk decode
+ logger.info(f"Bulk decoding or one frame inference")
+ if not one_frame_inference_mode:
+ history_pixels = hunyuan.vae_decode(latent, vae).cpu() # normal
+ else:
+ # one frame inference
+ history_pixels = [hunyuan.vae_decode(latent[:, :, i : i + 1, :, :], vae).cpu() for i in range(latent.shape[2])]
+ history_pixels = torch.cat(history_pixels, dim=2)
+
+ vae.to("cpu")
+
+ logger.info(f"Decoded. Pixel shape {history_pixels.shape}")
+ return history_pixels[0] # remove batch dimension
+
+
+def prepare_i2v_inputs(
+ args: argparse.Namespace,
+ device: torch.device,
+ vae: AutoencoderKLCausal3D,
+ shared_models: Optional[Dict] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
+ """Prepare inputs for I2V
+
+ Args:
+ args: command line arguments
+ config: model configuration
+ device: device to use
+ vae: VAE model, used for image encoding
+ shared_models: dictionary containing pre-loaded models
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
+ (noise, context, context_null, y, (arg_c, arg_null))
+ """
+
+ height, width, video_seconds = check_inputs(args)
+
+ # define parsing function
+ def parse_section_strings(input_string: str) -> dict[int, str]:
+ section_strings = {}
+ if ";;;" in input_string:
+ split_section_strings = input_string.split(";;;")
+ for section_str in split_section_strings:
+ if ":" not in section_str:
+ start = end = 0
+ section_str = section_str.strip()
+ else:
+ index_str, section_str = section_str.split(":", 1)
+ index_str = index_str.strip()
+ section_str = section_str.strip()
+
+ m = re.match(r"^(-?\d+)(-\d+)?$", index_str)
+ if m:
+ start = int(m.group(1))
+ end = int(m.group(2)[1:]) if m.group(2) is not None else start
+ else:
+ start = end = 0
+ section_str = section_str.strip()
+ for i in range(start, end + 1):
+ section_strings[i] = section_str
+ else:
+ section_strings[0] = input_string
+
+ # assert 0 in section_prompts, "Section prompts must contain section 0"
+ if 0 not in section_strings:
+ # use smallest section index. prefer positive index over negative index
+ # if all section indices are negative, use the smallest negative index
+ indices = list(section_strings.keys())
+ if all(i < 0 for i in indices):
+ section_index = min(indices)
+ else:
+ section_index = min(i for i in indices if i >= 0)
+ section_strings[0] = section_strings[section_index]
+ return section_strings
+
+ # prepare image
+ def preprocess_image(image_path: str):
+ image = Image.open(image_path)
+ if image.mode == "RGBA":
+ alpha = image.split()[-1]
+ else:
+ alpha = None
+ image = image.convert("RGB")
+
+ image_np = np.array(image) # PIL to numpy, HWC
+
+ image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
+ image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0 # -1 to 1.0, HWC
+ image_tensor = image_tensor.permute(2, 0, 1)[None, :, None] # HWC -> CHW -> NCFHW, N=1, C=3, F=1
+ return image_tensor, image_np, alpha
+
+ section_image_paths = parse_section_strings(args.image_path)
+
+ section_images = {}
+ for index, image_path in section_image_paths.items():
+ img_tensor, img_np, _ = preprocess_image(image_path)
+ section_images[index] = (img_tensor, img_np)
+
+ # check end image
+ if args.end_image_path is not None:
+ end_image_tensor, _, _ = preprocess_image(args.end_image_path)
+ else:
+ end_image_tensor = None
+
+ # check end images
+ if args.control_image_path is not None and len(args.control_image_path) > 0:
+ control_image_tensors = []
+ control_mask_images = []
+ for ctrl_image_path in args.control_image_path:
+ control_image_tensor, _, control_mask = preprocess_image(ctrl_image_path)
+ control_image_tensors.append(control_image_tensor)
+ control_mask_images.append(control_mask)
+ else:
+ control_image_tensors = None
+ control_mask_images = None
+
+ # configure negative prompt
+ n_prompt = args.negative_prompt if args.negative_prompt else ""
+
+ # parse section prompts
+ section_prompts = parse_section_strings(args.prompt)
+
+ # load text encoder
+ if shared_models is not None:
+ tokenizer1, text_encoder1 = shared_models["tokenizer1"], shared_models["text_encoder1"]
+ tokenizer2, text_encoder2 = shared_models["tokenizer2"], shared_models["text_encoder2"]
+ text_encoder1.to(device)
+ else:
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
+ text_encoder2.to(device)
+
+ logger.info(f"Encoding prompt")
+ llama_vecs = {}
+ llama_attention_masks = {}
+ clip_l_poolers = {}
+ with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
+ for index, prompt in section_prompts.items():
+ llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(
+ prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2, custom_system_prompt=args.custom_system_prompt
+ )
+ llama_vec = llama_vec.cpu()
+ clip_l_pooler = clip_l_pooler.cpu()
+
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
+
+ llama_vecs[index] = llama_vec
+ llama_attention_masks[index] = llama_attention_mask
+ clip_l_poolers[index] = clip_l_pooler
+
+ if args.guidance_scale == 1.0:
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vecs[0]), torch.zeros_like(clip_l_poolers[0])
+ else:
+ with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
+ llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(
+ n_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2, custom_system_prompt=args.custom_system_prompt
+ )
+ llama_vec_n = llama_vec_n.cpu()
+ clip_l_pooler_n = clip_l_pooler_n.cpu()
+
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
+
+ # free text encoder and clean memory
+ if shared_models is not None: # if shared models are used, do not free them but move to CPU
+ text_encoder1.to("cpu")
+ text_encoder2.to("cpu")
+ del tokenizer1, text_encoder1, tokenizer2, text_encoder2 # do not free shared models
+ clean_memory_on_device(device)
+
+ # load image encoder
+ if shared_models is not None:
+ feature_extractor, image_encoder = shared_models["feature_extractor"], shared_models["image_encoder"]
+ else:
+ feature_extractor, image_encoder = load_image_encoders(args)
+ image_encoder.to(device)
+
+ # encode image with image encoder
+
+ section_image_encoder_last_hidden_states = {}
+ for index, (img_tensor, img_np) in section_images.items():
+ with torch.no_grad():
+ image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state.cpu()
+ section_image_encoder_last_hidden_states[index] = image_encoder_last_hidden_state
+
+ # free image encoder and clean memory
+ if shared_models is not None:
+ image_encoder.to("cpu")
+ del image_encoder, feature_extractor
+ clean_memory_on_device(device)
+
+ # VAE encoding
+ logger.info(f"Encoding image to latent space")
+ vae.to(device)
+
+ section_start_latents = {}
+ for index, (img_tensor, img_np) in section_images.items():
+ start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
+ section_start_latents[index] = start_latent
+
+ end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu() if end_image_tensor is not None else None
+
+ control_latents = None
+ if control_image_tensors is not None:
+ control_latents = []
+ for ctrl_image_tensor in control_image_tensors:
+ control_latent = hunyuan.vae_encode(ctrl_image_tensor, vae).cpu()
+ control_latents.append(control_latent)
+
+ vae.to("cpu") # move VAE to CPU to save memory
+ clean_memory_on_device(device)
+
+ # prepare model input arguments
+ arg_c = {}
+ arg_null = {}
+ for index in llama_vecs.keys():
+ llama_vec = llama_vecs[index]
+ llama_attention_mask = llama_attention_masks[index]
+ clip_l_pooler = clip_l_poolers[index]
+ arg_c_i = {
+ "llama_vec": llama_vec,
+ "llama_attention_mask": llama_attention_mask,
+ "clip_l_pooler": clip_l_pooler,
+ "prompt": section_prompts[index], # for debugging
+ }
+ arg_c[index] = arg_c_i
+
+ arg_null = {
+ "llama_vec": llama_vec_n,
+ "llama_attention_mask": llama_attention_mask_n,
+ "clip_l_pooler": clip_l_pooler_n,
+ }
+
+ arg_c_img = {}
+ for index in section_images.keys():
+ image_encoder_last_hidden_state = section_image_encoder_last_hidden_states[index]
+ start_latent = section_start_latents[index]
+ arg_c_img_i = {
+ "image_encoder_last_hidden_state": image_encoder_last_hidden_state,
+ "start_latent": start_latent,
+ "image_path": section_image_paths[index],
+ }
+ arg_c_img[index] = arg_c_img_i
+
+ return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latent, control_latents, control_mask_images
+
+
+# def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
+# """setup scheduler for sampling
+
+# Args:
+# args: command line arguments
+# config: model configuration
+# device: device to use
+
+# Returns:
+# Tuple[Any, torch.Tensor]: (scheduler, timesteps)
+# """
+# if args.sample_solver == "unipc":
+# scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
+# scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
+# timesteps = scheduler.timesteps
+# elif args.sample_solver == "dpm++":
+# scheduler = FlowDPMSolverMultistepScheduler(
+# num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
+# )
+# sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
+# timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
+# elif args.sample_solver == "vanilla":
+# scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
+# scheduler.set_timesteps(args.infer_steps, device=device)
+# timesteps = scheduler.timesteps
+
+# # FlowMatchDiscreteScheduler does not support generator argument in step method
+# org_step = scheduler.step
+
+# def step_wrapper(
+# model_output: torch.Tensor,
+# timestep: Union[int, torch.Tensor],
+# sample: torch.Tensor,
+# return_dict: bool = True,
+# generator=None,
+# ):
+# return org_step(model_output, timestep, sample, return_dict=return_dict)
+
+# scheduler.step = step_wrapper
+# else:
+# raise NotImplementedError("Unsupported solver.")
+
+# return scheduler, timesteps
+
+
+def convert_lora_for_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ # Check the format of the LoRA file
+ keys = list(lora_sd.keys())
+ if keys[0].startswith("lora_unet_"):
+ # logging.info(f"Musubi Tuner LoRA detected")
+ pass
+
+ else:
+ transformer_prefixes = ["diffusion_model", "transformer"] # to ignore Text Encoder modules
+ lora_suffix = None
+ prefix = None
+ for key in keys:
+ if lora_suffix is None and "lora_A" in key:
+ lora_suffix = "lora_A"
+ if prefix is None:
+ pfx = key.split(".")[0]
+ if pfx in transformer_prefixes:
+ prefix = pfx
+ if lora_suffix is not None and prefix is not None:
+ break
+
+ if lora_suffix == "lora_A" and prefix is not None:
+ logging.info(f"Diffusion-pipe (?) LoRA detected, converting to the default LoRA format")
+ lora_sd = convert_lora_from_diffusion_pipe_or_something(lora_sd, "lora_unet_")
+
+ else:
+ logging.info(f"LoRA file format not recognized. Using it as-is.")
+
+ # Check LoRA is for FramePack or for HunyuanVideo
+ is_hunyuan = False
+ for key in lora_sd.keys():
+ if "double_blocks" in key or "single_blocks" in key:
+ is_hunyuan = True
+ break
+ if is_hunyuan:
+ logging.info("HunyuanVideo LoRA detected, converting to FramePack format")
+ lora_sd = convert_hunyuan_to_framepack(lora_sd)
+
+ return lora_sd
+
+
+def convert_lora_from_diffusion_pipe_or_something(lora_sd: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
+ """
+ Convert LoRA weights to the format used by the diffusion pipeline to Musubi Tuner.
+ Copy from Musubi Tuner repo.
+ """
+ # convert from diffusers(?) to default LoRA
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
+
+ # note: Diffusers has no alpha, so alpha is set to rank
+ new_weights_sd = {}
+ lora_dims = {}
+ for key, weight in lora_sd.items():
+ diffusers_prefix, key_body = key.split(".", 1)
+ if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
+ print(f"unexpected key: {key} in diffusers format")
+ continue
+
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
+ new_weights_sd[new_key] = weight
+
+ lora_name = new_key.split(".")[0] # before first dot
+ if lora_name not in lora_dims and "lora_down" in new_key:
+ lora_dims[lora_name] = weight.shape[0]
+
+ # add alpha with rank
+ for lora_name, dim in lora_dims.items():
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
+
+ return new_weights_sd
+
+
+def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Convert HunyuanVideo LoRA weights to FramePack format.
+ """
+ new_lora_sd = {}
+ for key, weight in lora_sd.items():
+ if "double_blocks" in key:
+ key = key.replace("double_blocks", "transformer_blocks")
+ key = key.replace("img_mod_linear", "norm1_linear")
+ key = key.replace("img_attn_qkv", "attn_to_QKV") # split later
+ key = key.replace("img_attn_proj", "attn_to_out_0")
+ key = key.replace("img_mlp_fc1", "ff_net_0_proj")
+ key = key.replace("img_mlp_fc2", "ff_net_2")
+ key = key.replace("txt_mod_linear", "norm1_context_linear")
+ key = key.replace("txt_attn_qkv", "attn_add_QKV_proj") # split later
+ key = key.replace("txt_attn_proj", "attn_to_add_out")
+ key = key.replace("txt_mlp_fc1", "ff_context_net_0_proj")
+ key = key.replace("txt_mlp_fc2", "ff_context_net_2")
+ elif "single_blocks" in key:
+ key = key.replace("single_blocks", "single_transformer_blocks")
+ key = key.replace("linear1", "attn_to_QKVM") # split later
+ key = key.replace("linear2", "proj_out")
+ key = key.replace("modulation_linear", "norm_linear")
+ else:
+ print(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported")
+ continue
+
+ if "QKVM" in key:
+ # split QKVM into Q, K, V, M
+ key_q = key.replace("QKVM", "q")
+ key_k = key.replace("QKVM", "k")
+ key_v = key.replace("QKVM", "v")
+ key_m = key.replace("attn_to_QKVM", "proj_mlp")
+ if "_down" in key or "alpha" in key:
+ # copy QKVM weight or alpha to Q, K, V, M
+ assert "alpha" in key or weight.size(1) == 3072, f"QKVM weight size mismatch: {key}. {weight.size()}"
+ new_lora_sd[key_q] = weight
+ new_lora_sd[key_k] = weight
+ new_lora_sd[key_v] = weight
+ new_lora_sd[key_m] = weight
+ elif "_up" in key:
+ # split QKVM weight into Q, K, V, M
+ assert weight.size(0) == 21504, f"QKVM weight size mismatch: {key}. {weight.size()}"
+ new_lora_sd[key_q] = weight[:3072]
+ new_lora_sd[key_k] = weight[3072 : 3072 * 2]
+ new_lora_sd[key_v] = weight[3072 * 2 : 3072 * 3]
+ new_lora_sd[key_m] = weight[3072 * 3 :] # 21504 - 3072 * 3 = 12288
+ else:
+ print(f"Unsupported module name: {key}")
+ continue
+ elif "QKV" in key:
+ # split QKV into Q, K, V
+ key_q = key.replace("QKV", "q")
+ key_k = key.replace("QKV", "k")
+ key_v = key.replace("QKV", "v")
+ if "_down" in key or "alpha" in key:
+ # copy QKV weight or alpha to Q, K, V
+ assert "alpha" in key or weight.size(1) == 3072, f"QKV weight size mismatch: {key}. {weight.size()}"
+ new_lora_sd[key_q] = weight
+ new_lora_sd[key_k] = weight
+ new_lora_sd[key_v] = weight
+ elif "_up" in key:
+ # split QKV weight into Q, K, V
+ assert weight.size(0) == 3072 * 3, f"QKV weight size mismatch: {key}. {weight.size()}"
+ new_lora_sd[key_q] = weight[:3072]
+ new_lora_sd[key_k] = weight[3072 : 3072 * 2]
+ new_lora_sd[key_v] = weight[3072 * 2 :]
+ else:
+ print(f"Unsupported module name: {key}")
+ continue
+ else:
+ # no split needed
+ new_lora_sd[key] = weight
+
+ return new_lora_sd
+
+
+def generate(
+ args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None
+) -> tuple[AutoencoderKLCausal3D, torch.Tensor]:
+ """main function for generation
+
+ Args:
+ args: command line arguments
+ shared_models: dictionary containing pre-loaded models
+
+ Returns:
+ tuple: (AutoencoderKLCausal3D model (vae), torch.Tensor generated latent)
+ """
+ device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype)
+
+ # prepare seed
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+ args.seed = seed # set seed to args for saving
+
+ # Check if we have shared models
+ if shared_models is not None:
+ # Use shared models and encoded data
+ vae = shared_models.get("vae")
+ height, width, video_seconds, context, context_null, context_img, end_latent, control_latents, control_mask_images = (
+ prepare_i2v_inputs(args, device, vae, shared_models)
+ )
+ else:
+ # prepare inputs without shared models
+ vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
+ height, width, video_seconds, context, context_null, context_img, end_latent, control_latents, control_mask_images = (
+ prepare_i2v_inputs(args, device, vae)
+ )
+
+ if shared_models is None or "model" not in shared_models:
+ # load DiT model
+ model = load_dit_model(args, device)
+
+ # merge LoRA weights
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
+ # ugly hack to common merge_lora_weights function
+ merge_lora_weights(lora_framepack, model, args, device, convert_lora_for_framepack)
+
+ # if we only want to save the model, we can skip the rest
+ if args.save_merged_model:
+ return None, None
+
+ # optimize model: fp8 conversion, block swap etc.
+ optimize_model(model, args, device)
+
+ if shared_models is not None:
+ shared_models["model"] = model
+ else:
+ # use shared model
+ model: HunyuanVideoTransformer3DModelPacked = shared_models["model"]
+ model.move_to_device_except_swap_blocks(device)
+ model.prepare_block_swap_before_forward()
+
+ # sampling
+ latent_window_size = args.latent_window_size # default is 9
+ # ex: (5s * 30fps) / (9 * 4) = 4.16 -> 4 sections, 60s -> 1800 / 36 = 50 sections
+ total_latent_sections = (video_seconds * 30) / (latent_window_size * 4)
+ total_latent_sections = int(max(round(total_latent_sections), 1))
+
+ # set random generator
+ seed_g = torch.Generator(device="cpu")
+ seed_g.manual_seed(seed)
+ num_frames = latent_window_size * 4 - 3
+
+ logger.info(
+ f"Video size: {height}x{width}@{video_seconds} (HxW@seconds), fps: {args.fps}, num sections: {total_latent_sections}, "
+ f"infer_steps: {args.infer_steps}, frames per generation: {num_frames}"
+ )
+
+ # video generation ######
+ f1_mode = args.f1
+ one_frame_inference = None
+ if args.one_frame_inference is not None:
+ one_frame_inference = set()
+ for mode in args.one_frame_inference.split(","):
+ one_frame_inference.add(mode.strip())
+
+ if one_frame_inference is not None:
+ real_history_latents = generate_with_one_frame_inference(
+ args,
+ model,
+ context,
+ context_null,
+ context_img,
+ control_latents,
+ control_mask_images,
+ latent_window_size,
+ height,
+ width,
+ device,
+ seed_g,
+ one_frame_inference,
+ )
+ else:
+ # prepare history latents
+ history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
+ if end_latent is not None and not f1_mode:
+ logger.info(f"Use end image(s): {args.end_image_path}")
+ history_latents[:, :, :1] = end_latent.to(history_latents)
+
+ # prepare clean latents and indices
+ if not f1_mode:
+ # Inverted Anti-drifting
+ total_generated_latent_frames = 0
+ latent_paddings = reversed(range(total_latent_sections))
+
+ if total_latent_sections > 4 and one_frame_inference is None:
+ # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
+ # items looks better than expanding it when total_latent_sections > 4
+ # One can try to remove below trick and just
+ # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
+ # 4 sections: 3, 2, 1, 0. 50 sections: 3, 2, 2, ... 2, 1, 0
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
+
+ if args.latent_paddings is not None:
+ # parse user defined latent paddings
+ user_latent_paddings = [int(x) for x in args.latent_paddings.split(",")]
+ if len(user_latent_paddings) < total_latent_sections:
+ print(
+ f"User defined latent paddings length {len(user_latent_paddings)} does not match total sections {total_latent_sections}."
+ )
+ print(f"Use default paddings instead for unspecified sections.")
+ latent_paddings[: len(user_latent_paddings)] = user_latent_paddings
+ elif len(user_latent_paddings) > total_latent_sections:
+ print(
+ f"User defined latent paddings length {len(user_latent_paddings)} is greater than total sections {total_latent_sections}."
+ )
+ print(f"Use only first {total_latent_sections} paddings instead.")
+ latent_paddings = user_latent_paddings[:total_latent_sections]
+ else:
+ latent_paddings = user_latent_paddings
+ else:
+ start_latent = context_img[0]["start_latent"]
+ history_latents = torch.cat([history_latents, start_latent], dim=2)
+ total_generated_latent_frames = 1 # a bit hacky, but we employ the same logic as in official code
+ latent_paddings = [0] * total_latent_sections # dummy paddings for F1 mode
+
+ latent_paddings = list(latent_paddings) # make sure it's a list
+ for loop_index in range(total_latent_sections):
+ latent_padding = latent_paddings[loop_index]
+
+ if not f1_mode:
+ # Inverted Anti-drifting
+ section_index_reverse = loop_index # 0, 1, 2, 3
+ section_index = total_latent_sections - 1 - section_index_reverse # 3, 2, 1, 0
+ section_index_from_last = -(section_index_reverse + 1) # -1, -2, -3, -4
+
+ is_last_section = section_index == 0
+ is_first_section = section_index_reverse == 0
+ latent_padding_size = latent_padding * latent_window_size
+
+ logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
+ else:
+ section_index = loop_index # 0, 1, 2, 3
+ section_index_from_last = section_index - total_latent_sections # -4, -3, -2, -1
+ is_last_section = loop_index == total_latent_sections - 1
+ is_first_section = loop_index == 0
+ latent_padding_size = 0 # dummy padding for F1 mode
+
+ # select start latent
+ if section_index_from_last in context_img:
+ image_index = section_index_from_last
+ elif section_index in context_img:
+ image_index = section_index
+ else:
+ image_index = 0
+
+ start_latent = context_img[image_index]["start_latent"]
+ image_path = context_img[image_index]["image_path"]
+ if image_index != 0: # use section image other than section 0
+ logger.info(
+ f"Apply experimental section image, latent_padding_size = {latent_padding_size}, image_path = {image_path}"
+ )
+
+ if not f1_mode:
+ # Inverted Anti-drifting
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
+ (
+ clean_latent_indices_pre,
+ blank_indices,
+ latent_indices,
+ clean_latent_indices_post,
+ clean_latent_2x_indices,
+ clean_latent_4x_indices,
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
+
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
+
+ clean_latents_pre = start_latent.to(history_latents)
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
+ [1, 2, 16], dim=2
+ )
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
+
+ else:
+ # F1 mode
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
+ (
+ clean_latent_indices_start,
+ clean_latent_4x_indices,
+ clean_latent_2x_indices,
+ clean_latent_1x_indices,
+ latent_indices,
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
+
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
+ [16, 2, 1], dim=2
+ )
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
+
+ # if use_teacache:
+ # transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
+ # else:
+ # transformer.initialize_teacache(enable_teacache=False)
+
+ # prepare conditioning inputs
+ if section_index_from_last in context:
+ prompt_index = section_index_from_last
+ elif section_index in context:
+ prompt_index = section_index
+ else:
+ prompt_index = 0
+
+ context_for_index = context[prompt_index]
+ # if args.section_prompts is not None:
+ logger.info(f"Section {section_index}: {context_for_index['prompt']}")
+
+ llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
+ llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
+ clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
+
+ image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(
+ device, dtype=torch.bfloat16
+ )
+
+ llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
+ llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
+ clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
+
+ generated_latents = sample_hunyuan(
+ transformer=model,
+ sampler=args.sample_solver,
+ width=width,
+ height=height,
+ frames=num_frames,
+ real_guidance_scale=args.guidance_scale,
+ distilled_guidance_scale=args.embedded_cfg_scale,
+ guidance_rescale=args.guidance_rescale,
+ # shift=3.0,
+ num_inference_steps=args.infer_steps,
+ generator=seed_g,
+ prompt_embeds=llama_vec,
+ prompt_embeds_mask=llama_attention_mask,
+ prompt_poolers=clip_l_pooler,
+ negative_prompt_embeds=llama_vec_n,
+ negative_prompt_embeds_mask=llama_attention_mask_n,
+ negative_prompt_poolers=clip_l_pooler_n,
+ device=device,
+ dtype=torch.bfloat16,
+ image_embeddings=image_encoder_last_hidden_state,
+ latent_indices=latent_indices,
+ clean_latents=clean_latents,
+ clean_latent_indices=clean_latent_indices,
+ clean_latents_2x=clean_latents_2x,
+ clean_latent_2x_indices=clean_latent_2x_indices,
+ clean_latents_4x=clean_latents_4x,
+ clean_latent_4x_indices=clean_latent_4x_indices,
+ )
+
+ # concatenate generated latents
+ total_generated_latent_frames += int(generated_latents.shape[2])
+ if not f1_mode:
+ # Inverted Anti-drifting: prepend generated latents to history latents
+ if is_last_section:
+ generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
+ total_generated_latent_frames += 1
+
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
+ else:
+ # F1 mode: append generated latents to history latents
+ history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
+
+ logger.info(f"Generated. Latent shape {real_history_latents.shape}")
+
+ # # TODO support saving intermediate video
+ # clean_memory_on_device(device)
+ # vae.to(device)
+ # if history_pixels is None:
+ # history_pixels = hunyuan.vae_decode(real_history_latents, vae).cpu()
+ # else:
+ # section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
+ # overlapped_frames = latent_window_size * 4 - 3
+ # current_pixels = hunyuan.vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
+ # history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
+ # vae.to("cpu")
+ # # if not is_last_section:
+ # # # save intermediate video
+ # # save_video(history_pixels[0], args, total_generated_latent_frames)
+ # print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")
+
+ # Only clean up shared models if they were created within this function
+ if shared_models is None:
+ del model # free memory
+ synchronize_device(device)
+ else:
+ # move model to CPU to save memory
+ model.to("cpu")
+
+ # wait for 5 seconds until block swap is done
+ if args.blocks_to_swap > 0:
+ logger.info("Waiting for 5 seconds to finish block swap")
+ time.sleep(5)
+
+ gc.collect()
+ clean_memory_on_device(device)
+
+ return vae, real_history_latents
+
+
+def generate_with_one_frame_inference(
+ args: argparse.Namespace,
+ model: HunyuanVideoTransformer3DModelPacked,
+ context: Dict[int, Dict[str, torch.Tensor]],
+ context_null: Dict[str, torch.Tensor],
+ context_img: Dict[int, Dict[str, torch.Tensor]],
+ control_latents: Optional[List[torch.Tensor]],
+ control_mask_images: Optional[List[Optional[Image.Image]]],
+ latent_window_size: int,
+ height: int,
+ width: int,
+ device: torch.device,
+ seed_g: torch.Generator,
+ one_frame_inference: set[str],
+) -> torch.Tensor:
+ # one frame inference
+ sample_num_frames = 1
+ latent_indices = torch.zeros((1, 1), dtype=torch.int64) # 1x1 latent index for target image
+ latent_indices[:, 0] = latent_window_size # last of latent_window
+
+ def get_latent_mask(mask_image: Image.Image) -> torch.Tensor:
+ if mask_image.mode != "L":
+ mask_image = mask_image.convert("L")
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
+ mask_image = np.array(mask_image) # PIL to numpy, HWC
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
+ mask_image = mask_image.squeeze(-1) # HWC -> HW
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (BCFHW)
+ mask_image = mask_image.to(torch.float32)
+ return mask_image
+
+ if control_latents is None or len(control_latents) == 0:
+ logger.info(f"No control images provided for one frame inference. Use zero latents for control images.")
+ control_latents = [torch.zeros(1, 16, 1, height // 8, width // 8, dtype=torch.float32)]
+
+ if "no_post" not in one_frame_inference:
+ # add zero latents as clean latents post
+ control_latents.append(torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32))
+ logger.info(f"Add zero latents as clean latents post for one frame inference.")
+
+ # kisekaeichi and 1f-mc: both are using control images, but indices are different
+ clean_latents = torch.cat(control_latents, dim=2) # (1, 16, num_control_images, H//8, W//8)
+ clean_latent_indices = torch.zeros((1, len(control_latents)), dtype=torch.int64)
+ if "no_post" not in one_frame_inference:
+ clean_latent_indices[:, -1] = 1 + latent_window_size # default index for clean latents post
+
+ for i in range(len(control_latents)):
+ mask_image = None
+ if args.control_image_mask_path is not None and i < len(args.control_image_mask_path):
+ mask_image = get_latent_mask(Image.open(args.control_image_mask_path[i]))
+ logger.info(
+ f"Apply mask for clean latents 1x for {i + 1}: {args.control_image_mask_path[i]}, shape: {mask_image.shape}"
+ )
+ elif control_mask_images is not None and i < len(control_mask_images) and control_mask_images[i] is not None:
+ mask_image = get_latent_mask(control_mask_images[i])
+ logger.info(f"Apply mask for clean latents 1x for {i + 1} with alpha channel: {mask_image.shape}")
+ if mask_image is not None:
+ clean_latents[:, :, i : i + 1, :, :] = clean_latents[:, :, i : i + 1, :, :] * mask_image
+
+ for one_frame_param in one_frame_inference:
+ if one_frame_param.startswith("target_index="):
+ target_index = int(one_frame_param.split("=")[1])
+ latent_indices[:, 0] = target_index
+ logger.info(f"Set index for target: {target_index}")
+ elif one_frame_param.startswith("control_index="):
+ control_indices = one_frame_param.split("=")[1].split(";")
+ i = 0
+ while i < len(control_indices) and i < clean_latent_indices.shape[1]:
+ control_index = int(control_indices[i])
+ clean_latent_indices[:, i] = control_index
+ i += 1
+ logger.info(f"Set index for clean latent 1x: {control_indices}")
+
+ # "default" option does nothing, so we can skip it
+ if "default" in one_frame_inference:
+ pass
+
+ if "no_2x" in one_frame_inference:
+ clean_latents_2x = None
+ clean_latent_2x_indices = None
+ logger.info(f"No clean_latents_2x")
+ else:
+ clean_latents_2x = torch.zeros((1, 16, 2, height // 8, width // 8), dtype=torch.float32)
+ index = 1 + latent_window_size + 1
+ clean_latent_2x_indices = torch.arange(index, index + 2).unsqueeze(0) # 2
+
+ if "no_4x" in one_frame_inference:
+ clean_latents_4x = None
+ clean_latent_4x_indices = None
+ logger.info(f"No clean_latents_4x")
+ else:
+ clean_latents_4x = torch.zeros((1, 16, 16, height // 8, width // 8), dtype=torch.float32)
+ index = 1 + latent_window_size + 1 + 2
+ clean_latent_4x_indices = torch.arange(index, index + 16).unsqueeze(0) # 16
+
+ logger.info(
+ f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
+ )
+
+ # prepare conditioning inputs
+ prompt_index = 0
+ image_index = 0
+
+ context_for_index = context[prompt_index]
+ logger.info(f"Prompt: {context_for_index['prompt']}")
+
+ llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
+ llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
+ clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
+
+ image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(device, dtype=torch.bfloat16)
+
+ llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
+ llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
+ clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
+
+ generated_latents = sample_hunyuan(
+ transformer=model,
+ sampler=args.sample_solver,
+ width=width,
+ height=height,
+ frames=1,
+ real_guidance_scale=args.guidance_scale,
+ distilled_guidance_scale=args.embedded_cfg_scale,
+ guidance_rescale=args.guidance_rescale,
+ # shift=3.0,
+ num_inference_steps=args.infer_steps,
+ generator=seed_g,
+ prompt_embeds=llama_vec,
+ prompt_embeds_mask=llama_attention_mask,
+ prompt_poolers=clip_l_pooler,
+ negative_prompt_embeds=llama_vec_n,
+ negative_prompt_embeds_mask=llama_attention_mask_n,
+ negative_prompt_poolers=clip_l_pooler_n,
+ device=device,
+ dtype=torch.bfloat16,
+ image_embeddings=image_encoder_last_hidden_state,
+ latent_indices=latent_indices,
+ clean_latents=clean_latents,
+ clean_latent_indices=clean_latent_indices,
+ clean_latents_2x=clean_latents_2x,
+ clean_latent_2x_indices=clean_latent_2x_indices,
+ clean_latents_4x=clean_latents_4x,
+ clean_latent_4x_indices=clean_latent_4x_indices,
+ )
+
+ real_history_latents = generated_latents.to(clean_latents)
+ return real_history_latents
+
+
+def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
+ """Save latent to file
+
+ Args:
+ latent: Latent tensor
+ args: command line arguments
+ height: height of frame
+ width: width of frame
+
+ Returns:
+ str: Path to saved latent file
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
+
+ seed = args.seed
+ video_seconds = args.video_seconds
+ latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
+
+ if args.no_metadata:
+ metadata = None
+ else:
+ metadata = {
+ "seeds": f"{seed}",
+ "prompt": f"{args.prompt}",
+ "height": f"{height}",
+ "width": f"{width}",
+ "video_seconds": f"{video_seconds}",
+ "infer_steps": f"{args.infer_steps}",
+ "guidance_scale": f"{args.guidance_scale}",
+ "latent_window_size": f"{args.latent_window_size}",
+ "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
+ "guidance_rescale": f"{args.guidance_rescale}",
+ "sample_solver": f"{args.sample_solver}",
+ "latent_window_size": f"{args.latent_window_size}",
+ "fps": f"{args.fps}",
+ }
+ if args.negative_prompt is not None:
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
+
+ sd = {"latent": latent.contiguous()}
+ save_file(sd, latent_path, metadata=metadata)
+ logger.info(f"Latent saved to: {latent_path}")
+
+ return latent_path
+
+
+def save_video(
+ video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None, latent_frames: Optional[int] = None
+) -> str:
+ """Save video to file
+
+ Args:
+ video: Video tensor
+ args: command line arguments
+ original_base_name: Original base name (if latents are loaded from files)
+
+ Returns:
+ str: Path to saved video file
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
+
+ seed = args.seed
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
+ latent_frames = "" if latent_frames is None else f"_{latent_frames}"
+ video_path = f"{save_path}/{time_flag}_{seed}{original_name}{latent_frames}.mp4"
+
+ video = video.unsqueeze(0)
+ save_videos_grid(video, video_path, fps=args.fps, rescale=True)
+ logger.info(f"Video saved to: {video_path}")
+
+ return video_path
+
+
+def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
+ """Save images to directory
+
+ Args:
+ sample: Video tensor
+ args: command line arguments
+ original_base_name: Original base name (if latents are loaded from files)
+
+ Returns:
+ str: Path to saved images directory
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
+
+ seed = args.seed
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
+ image_name = f"{time_flag}_{seed}{original_name}"
+ sample = sample.unsqueeze(0)
+ one_frame_mode = args.one_frame_inference is not None
+ save_images_grid(sample, save_path, image_name, rescale=True, create_subdir=not one_frame_mode)
+ logger.info(f"Sample images saved to: {save_path}/{image_name}")
+
+ return f"{save_path}/{image_name}"
+
+
+def save_output(
+ args: argparse.Namespace,
+ vae: AutoencoderKLCausal3D,
+ latent: torch.Tensor,
+ device: torch.device,
+ original_base_names: Optional[List[str]] = None,
+) -> None:
+ """save output
+
+ Args:
+ args: command line arguments
+ vae: VAE model
+ latent: latent tensor
+ device: device to use
+ original_base_names: original base names (if latents are loaded from files)
+ """
+ height, width = latent.shape[-2], latent.shape[-1] # BCTHW
+ height *= 8
+ width *= 8
+ # print(f"Saving output. Latent shape {latent.shape}; pixel shape {height}x{width}")
+ if args.output_type == "latent" or args.output_type == "both" or args.output_type == "latent_images":
+ # save latent
+ save_latent(latent, args, height, width)
+ if args.output_type == "latent":
+ return
+
+ total_latent_sections = (args.video_seconds * 30) / (args.latent_window_size * 4)
+ total_latent_sections = int(max(round(total_latent_sections), 1))
+ video = decode_latent(
+ args.latent_window_size, total_latent_sections, args.bulk_decode, vae, latent, device, args.one_frame_inference is not None
+ )
+
+ if args.output_type == "video" or args.output_type == "both":
+ # save video
+ original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
+ save_video(video, args, original_name)
+
+ elif args.output_type == "images" or args.output_type == "latent_images":
+ # save images
+ original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
+ save_images(video, args, original_name)
+
+
+def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
+ """Process multiple prompts for batch mode
+
+ Args:
+ prompt_lines: List of prompt lines
+ base_args: Base command line arguments
+
+ Returns:
+ List[Dict]: List of prompt data dictionaries
+ """
+ prompts_data = []
+
+ for line in prompt_lines:
+ line = line.strip()
+ if not line or line.startswith("#"): # Skip empty lines and comments
+ continue
+
+ # Parse prompt line and create override dictionary
+ prompt_data = parse_prompt_line(line)
+ logger.info(f"Parsed prompt data: {prompt_data}")
+ prompts_data.append(prompt_data)
+
+ return prompts_data
+
+
+def load_shared_models(args: argparse.Namespace) -> Dict:
+ """Load shared models for batch processing or interactive mode.
+ Models are loaded to CPU to save memory.
+
+ Args:
+ args: Base command line arguments
+
+ Returns:
+ Dict: Dictionary of shared models
+ """
+ shared_models = {}
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, "cpu")
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
+ feature_extractor, image_encoder = load_image_encoders(args)
+ vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, "cpu")
+ shared_models["tokenizer1"] = tokenizer1
+ shared_models["text_encoder1"] = text_encoder1
+ shared_models["tokenizer2"] = tokenizer2
+ shared_models["text_encoder2"] = text_encoder2
+ shared_models["feature_extractor"] = feature_extractor
+ shared_models["image_encoder"] = image_encoder
+ shared_models["vae"] = vae
+
+ return shared_models
+
+
+def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
+ """Process multiple prompts with model reuse
+
+ Args:
+ prompts_data: List of prompt data dictionaries
+ args: Base command line arguments
+ """
+ if not prompts_data:
+ logger.warning("No valid prompts found")
+ return
+
+ # 1. Load configuration
+ gen_settings = get_generation_settings(args)
+ device = gen_settings.device
+
+ # 2. Load models to CPU in advance except for VAE and DiT
+ shared_models = load_shared_models(args)
+
+ # 3. Generate for each prompt
+ all_latents = []
+ all_prompt_args = []
+
+ with torch.no_grad():
+ for prompt_data in prompts_data:
+ prompt = prompt_data["prompt"]
+ prompt_args = apply_overrides(args, prompt_data)
+ logger.info(f"Processing prompt: {prompt}")
+
+ try:
+ vae, latent = generate(prompt_args, gen_settings, shared_models)
+
+ # Save latent if needed
+ if args.output_type == "latent" or args.output_type == "both" or args.output_type == "latent_images":
+ height, width = latent.shape[-2], latent.shape[-1] # BCTHW
+ height *= 8
+ width *= 8
+ save_latent(latent, prompt_args, height, width)
+
+ all_latents.append(latent)
+ all_prompt_args.append(prompt_args)
+ except Exception as e:
+ logger.error(f"Error processing prompt: {prompt}. Error: {e}")
+ continue
+
+ # 4. Free models
+ if "model" in shared_models:
+ del shared_models["model"]
+ del shared_models["tokenizer1"]
+ del shared_models["text_encoder1"]
+ del shared_models["tokenizer2"]
+ del shared_models["text_encoder2"]
+ del shared_models["feature_extractor"]
+ del shared_models["image_encoder"]
+
+ clean_memory_on_device(device)
+ synchronize_device(device)
+
+ # 5. Decode latents if needed
+ if args.output_type != "latent":
+ logger.info("Decoding latents to videos/images")
+ vae.to(device)
+
+ for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
+ logger.info(f"Decoding output {i+1}/{len(all_latents)}")
+
+ # avoid saving latents again (ugly hack)
+ if prompt_args.output_type == "both":
+ prompt_args.output_type = "video"
+ elif prompt_args.output_type == "latent_images":
+ prompt_args.output_type = "images"
+
+ save_output(prompt_args, vae, latent[0], device)
+
+
+def process_interactive(args: argparse.Namespace) -> None:
+ """Process prompts in interactive mode
+
+ Args:
+ args: Base command line arguments
+ """
+ gen_settings = get_generation_settings(args)
+ device = gen_settings.device
+ shared_models = load_shared_models(args)
+
+ print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
+
+ try:
+ while True:
+ try:
+ line = input("> ")
+ if not line.strip():
+ continue
+
+ # Parse prompt
+ prompt_data = parse_prompt_line(line)
+ prompt_args = apply_overrides(args, prompt_data)
+
+ # Generate latent
+ vae, latent = generate(prompt_args, gen_settings, shared_models)
+
+ # Save latent and video
+ save_output(prompt_args, vae, latent[0], device)
+
+ except KeyboardInterrupt:
+ print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
+ continue
+
+ except EOFError:
+ print("\nExiting interactive mode")
+
+
+def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
+ device = torch.device(args.device)
+
+ dit_weight_dtype = None # default
+ if args.fp8_scaled:
+ dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
+ elif args.fp8:
+ dit_weight_dtype = torch.float8_e4m3fn
+
+ logger.info(f"Using device: {device}, DiT weight weight precision: {dit_weight_dtype}")
+
+ gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype)
+ return gen_settings
+
+
+def main():
+ # Parse arguments
+ args = parse_args()
+
+ # Check if latents are provided
+ latents_mode = args.latent_path is not None and len(args.latent_path) > 0
+
+ # Set device
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+ logger.info(f"Using device: {device}")
+ args.device = device
+
+ if latents_mode:
+ # Original latent decode mode
+ original_base_names = []
+ latents_list = []
+ seeds = []
+
+ # assert len(args.latent_path) == 1, "Only one latent path is supported for now"
+
+ for latent_path in args.latent_path:
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
+ seed = 0
+
+ if os.path.splitext(latent_path)[1] != ".safetensors":
+ latents = torch.load(latent_path, map_location="cpu")
+ else:
+ latents = load_file(latent_path)["latent"]
+ with safe_open(latent_path, framework="pt") as f:
+ metadata = f.metadata()
+ if metadata is None:
+ metadata = {}
+ logger.info(f"Loaded metadata: {metadata}")
+
+ if "seeds" in metadata:
+ seed = int(metadata["seeds"])
+ if "height" in metadata and "width" in metadata:
+ height = int(metadata["height"])
+ width = int(metadata["width"])
+ args.video_size = [height, width]
+ if "video_seconds" in metadata:
+ args.video_seconds = float(metadata["video_seconds"])
+
+ seeds.append(seed)
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
+
+ if latents.ndim == 5: # [BCTHW]
+ latents = latents.squeeze(0) # [CTHW]
+
+ latents_list.append(latents)
+
+ # latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
+
+ for i, latent in enumerate(latents_list):
+ args.seed = seeds[i]
+
+ vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
+ save_output(args, vae, latent, device, original_base_names)
+
+ elif args.from_file:
+ # Batch mode from file
+
+ # Read prompts from file
+ with open(args.from_file, "r", encoding="utf-8") as f:
+ prompt_lines = f.readlines()
+
+ # Process prompts
+ prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
+ process_batch_prompts(prompts_data, args)
+
+ elif args.interactive:
+ # Interactive mode
+ process_interactive(args)
+
+ else:
+ # Single prompt mode (original behavior)
+
+ # Generate latent
+ gen_settings = get_generation_settings(args)
+ vae, latent = generate(args, gen_settings)
+ # print(f"Generated latent shape: {latent.shape}")
+ if args.save_merged_model:
+ return
+
+ # Save latent and video
+ save_output(args, vae, latent[0], device)
+
+ logger.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fpack_train_network.py b/fpack_train_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7bee4defcff1418ca49e124d022ec12f08074d2
--- /dev/null
+++ b/fpack_train_network.py
@@ -0,0 +1,617 @@
+import argparse
+import gc
+import math
+import time
+from typing import Optional
+from PIL import Image
+
+
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from tqdm import tqdm
+from accelerate import Accelerator, init_empty_weights
+
+from dataset import image_video_dataset
+from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ARCHITECTURE_FRAMEPACK_FULL, load_video
+from fpack_generate_video import decode_latent
+from frame_pack import hunyuan
+from frame_pack.clip_vision import hf_clip_vision_encode
+from frame_pack.framepack_utils import load_image_encoders, load_text_encoder1, load_text_encoder2
+from frame_pack.framepack_utils import load_vae as load_framepack_vae
+from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
+from frame_pack.k_diffusion_hunyuan import sample_hunyuan
+from frame_pack.utils import crop_or_pad_yield_mask
+from dataset.image_video_dataset import resize_image_to_bucket
+from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+from utils import model_utils
+from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
+
+
+class FramePackNetworkTrainer(NetworkTrainer):
+ def __init__(self):
+ super().__init__()
+
+ # region model specific
+
+ @property
+ def architecture(self) -> str:
+ return ARCHITECTURE_FRAMEPACK
+
+ @property
+ def architecture_full_name(self) -> str:
+ return ARCHITECTURE_FRAMEPACK_FULL
+
+ def handle_model_specific_args(self, args):
+ self._i2v_training = True
+ self._control_training = False
+ self.default_guidance_scale = 10.0 # embeded guidance scale
+
+ def process_sample_prompts(
+ self,
+ args: argparse.Namespace,
+ accelerator: Accelerator,
+ sample_prompts: str,
+ ):
+ device = accelerator.device
+
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
+ prompts = load_prompts(sample_prompts)
+
+ # load text encoder
+ tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
+ tokenizer2, text_encoder2 = load_text_encoder2(args)
+ text_encoder2.to(device)
+
+ sample_prompts_te_outputs = {} # (prompt) -> (t1 embeds, t1 mask, t2 embeds)
+ for prompt_dict in prompts:
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
+ if p is None or p in sample_prompts_te_outputs:
+ continue
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
+ with torch.amp.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
+ llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(p, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
+
+ llama_vec = llama_vec.to("cpu")
+ llama_attention_mask = llama_attention_mask.to("cpu")
+ clip_l_pooler = clip_l_pooler.to("cpu")
+ sample_prompts_te_outputs[p] = (llama_vec, llama_attention_mask, clip_l_pooler)
+ del text_encoder1, text_encoder2
+ clean_memory_on_device(device)
+
+ # image embedding for I2V training
+ feature_extractor, image_encoder = load_image_encoders(args)
+ image_encoder.to(device)
+
+ # encode image with image encoder
+ sample_prompts_image_embs = {}
+ for prompt_dict in prompts:
+ image_path = prompt_dict.get("image_path", None)
+ assert image_path is not None, "image_path should be set for I2V training"
+ if image_path in sample_prompts_image_embs:
+ continue
+
+ logger.info(f"Encoding image to image encoder context: {image_path}")
+
+ height = prompt_dict.get("height", 256)
+ width = prompt_dict.get("width", 256)
+
+ img = Image.open(image_path).convert("RGB")
+ img_np = np.array(img) # PIL to numpy, HWC
+ img_np = image_video_dataset.resize_image_to_bucket(img_np, (width, height)) # returns a numpy array
+
+ with torch.no_grad():
+ image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
+
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to("cpu")
+ sample_prompts_image_embs[image_path] = image_encoder_last_hidden_state
+
+ del image_encoder
+ clean_memory_on_device(device)
+
+ # prepare sample parameters
+ sample_parameters = []
+ for prompt_dict in prompts:
+ prompt_dict_copy = prompt_dict.copy()
+
+ p = prompt_dict.get("prompt", "")
+ llama_vec, llama_attention_mask, clip_l_pooler = sample_prompts_te_outputs[p]
+ prompt_dict_copy["llama_vec"] = llama_vec
+ prompt_dict_copy["llama_attention_mask"] = llama_attention_mask
+ prompt_dict_copy["clip_l_pooler"] = clip_l_pooler
+
+ p = prompt_dict.get("negative_prompt", "")
+ llama_vec, llama_attention_mask, clip_l_pooler = sample_prompts_te_outputs[p]
+ prompt_dict_copy["negative_llama_vec"] = llama_vec
+ prompt_dict_copy["negative_llama_attention_mask"] = llama_attention_mask
+ prompt_dict_copy["negative_clip_l_pooler"] = clip_l_pooler
+
+ p = prompt_dict.get("image_path", None)
+ prompt_dict_copy["image_encoder_last_hidden_state"] = sample_prompts_image_embs[p]
+
+ sample_parameters.append(prompt_dict_copy)
+
+ clean_memory_on_device(accelerator.device)
+ return sample_parameters
+
+ def do_inference(
+ self,
+ accelerator,
+ args,
+ sample_parameter,
+ vae,
+ dit_dtype,
+ transformer,
+ discrete_flow_shift,
+ sample_steps,
+ width,
+ height,
+ frame_count,
+ generator,
+ do_classifier_free_guidance,
+ guidance_scale,
+ cfg_scale,
+ image_path=None,
+ control_video_path=None,
+ ):
+ """architecture dependent inference"""
+ model: HunyuanVideoTransformer3DModelPacked = transformer
+ device = accelerator.device
+ if cfg_scale is None:
+ cfg_scale = 1.0
+ do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0
+
+ # prepare parameters
+ one_frame_mode = args.one_frame
+ if one_frame_mode:
+ one_frame_inference = set()
+ for mode in sample_parameter["one_frame"].split(","):
+ one_frame_inference.add(mode.strip())
+ else:
+ one_frame_inference = None
+
+ latent_window_size = args.latent_window_size # default is 9
+ latent_f = (frame_count - 1) // 4 + 1
+ total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
+ if total_latent_sections < 1 and not one_frame_mode:
+ logger.warning(f"Not enough frames for FramePack: {latent_f}, minimum: {latent_window_size*4+1}")
+ return None
+
+ latent_f = total_latent_sections * latent_window_size + 1
+ actual_frame_count = (latent_f - 1) * 4 + 1
+ if actual_frame_count != frame_count:
+ logger.info(f"Frame count mismatch: {actual_frame_count} != {frame_count}, trimming to {actual_frame_count}")
+ frame_count = actual_frame_count
+ num_frames = latent_window_size * 4 - 3
+
+ # prepare start and control latent
+ def encode_image(path):
+ image = Image.open(path)
+ if image.mode == "RGBA":
+ alpha = image.split()[-1]
+ image = image.convert("RGB")
+ else:
+ alpha = None
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).unsqueeze(0).float() # 1, C, 1, H, W
+ image = image / 127.5 - 1 # -1 to 1
+ return hunyuan.vae_encode(image, vae).to("cpu"), alpha
+
+ # VAE encoding
+ logger.info(f"Encoding image to latent space")
+ vae.to(device)
+
+ start_latent, _ = (
+ encode_image(image_path) if image_path else torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32)
+ )
+
+ if one_frame_mode:
+ control_latents = []
+ control_alphas = []
+ if "control_image_path" in sample_parameter:
+ for control_image_path in sample_parameter["control_image_path"]:
+ control_latent, control_alpha = encode_image(control_image_path)
+ control_latents.append(control_latent)
+ control_alphas.append(control_alpha)
+ else:
+ control_latents = None
+ control_alphas = None
+
+ vae.to("cpu") # move VAE to CPU to save memory
+ clean_memory_on_device(device)
+
+ # sampilng
+ if not one_frame_mode:
+ f1_mode = args.f1
+ history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
+
+ if not f1_mode:
+ total_generated_latent_frames = 0
+ latent_paddings = reversed(range(total_latent_sections))
+ else:
+ total_generated_latent_frames = 1
+ history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
+ latent_paddings = [0] * total_latent_sections
+
+ if total_latent_sections > 4:
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
+
+ latent_paddings = list(latent_paddings)
+ for loop_index in range(total_latent_sections):
+ latent_padding = latent_paddings[loop_index]
+
+ if not f1_mode:
+ is_last_section = latent_padding == 0
+ latent_padding_size = latent_padding * latent_window_size
+
+ logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
+
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
+ (
+ clean_latent_indices_pre,
+ blank_indices,
+ latent_indices,
+ clean_latent_indices_post,
+ clean_latent_2x_indices,
+ clean_latent_4x_indices,
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
+
+ clean_latents_pre = start_latent.to(history_latents)
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
+ [1, 2, 16], dim=2
+ )
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
+ else:
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
+ (
+ clean_latent_indices_start,
+ clean_latent_4x_indices,
+ clean_latent_2x_indices,
+ clean_latent_1x_indices,
+ latent_indices,
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
+
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
+ [16, 2, 1], dim=2
+ )
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
+
+ # if use_teacache:
+ # transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
+ # else:
+ # transformer.initialize_teacache(enable_teacache=False)
+
+ llama_vec = sample_parameter["llama_vec"].to(device, dtype=torch.bfloat16)
+ llama_attention_mask = sample_parameter["llama_attention_mask"].to(device)
+ clip_l_pooler = sample_parameter["clip_l_pooler"].to(device, dtype=torch.bfloat16)
+ if cfg_scale == 1.0:
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
+ else:
+ llama_vec_n = sample_parameter["negative_llama_vec"].to(device, dtype=torch.bfloat16)
+ llama_attention_mask_n = sample_parameter["negative_llama_attention_mask"].to(device)
+ clip_l_pooler_n = sample_parameter["negative_clip_l_pooler"].to(device, dtype=torch.bfloat16)
+ image_encoder_last_hidden_state = sample_parameter["image_encoder_last_hidden_state"].to(
+ device, dtype=torch.bfloat16
+ )
+
+ generated_latents = sample_hunyuan(
+ transformer=model,
+ sampler=args.sample_solver,
+ width=width,
+ height=height,
+ frames=num_frames,
+ real_guidance_scale=cfg_scale,
+ distilled_guidance_scale=guidance_scale,
+ guidance_rescale=0.0,
+ # shift=3.0,
+ num_inference_steps=sample_steps,
+ generator=generator,
+ prompt_embeds=llama_vec,
+ prompt_embeds_mask=llama_attention_mask,
+ prompt_poolers=clip_l_pooler,
+ negative_prompt_embeds=llama_vec_n,
+ negative_prompt_embeds_mask=llama_attention_mask_n,
+ negative_prompt_poolers=clip_l_pooler_n,
+ device=device,
+ dtype=torch.bfloat16,
+ image_embeddings=image_encoder_last_hidden_state,
+ latent_indices=latent_indices,
+ clean_latents=clean_latents,
+ clean_latent_indices=clean_latent_indices,
+ clean_latents_2x=clean_latents_2x,
+ clean_latent_2x_indices=clean_latent_2x_indices,
+ clean_latents_4x=clean_latents_4x,
+ clean_latent_4x_indices=clean_latent_4x_indices,
+ )
+
+ total_generated_latent_frames += int(generated_latents.shape[2])
+ if not f1_mode:
+ if is_last_section:
+ generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
+ total_generated_latent_frames += 1
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
+ else:
+ history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
+
+ logger.info(f"Generated. Latent shape {real_history_latents.shape}")
+ else:
+ # one frame mode
+ sample_num_frames = 1
+ latent_indices = torch.zeros((1, 1), dtype=torch.int64) # 1x1 latent index for target image
+ latent_indices[:, 0] = latent_window_size # last of latent_window
+
+ def get_latent_mask(mask_image: Image.Image):
+ mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
+ mask_image = np.array(mask_image) # PIL to numpy, HWC
+ mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
+ mask_image = mask_image.squeeze(-1) # HWC -> HW
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0) # HW -> 111HW (B, C, F, H, W)
+ mask_image = mask_image.to(torch.float32)
+ return mask_image
+
+ if control_latents is None or len(control_latents) == 0:
+ logger.info(f"No control images provided for one frame inference. Use zero latents for control images.")
+ control_latents = [torch.zeros(1, 16, 1, height // 8, width // 8, dtype=torch.float32)]
+
+ if "no_post" not in one_frame_inference:
+ # add zero latents as clean latents post
+ control_latents.append(torch.zeros((1, 16, 1, height // 8, width // 8), dtype=torch.float32))
+ logger.info(f"Add zero latents as clean latents post for one frame inference.")
+
+ # kisekaeichi and 1f-mc: both are using control images, but indices are different
+ clean_latents = torch.cat(control_latents, dim=2) # (1, 16, num_control_images, H//8, W//8)
+ clean_latent_indices = torch.zeros((1, len(control_latents)), dtype=torch.int64)
+ if "no_post" not in one_frame_inference:
+ clean_latent_indices[:, -1] = 1 + latent_window_size # default index for clean latents post
+
+ # apply mask for control latents (clean latents)
+ for i in range(len(control_alphas)):
+ control_alpha = control_alphas[i]
+ if control_alpha is not None:
+ latent_mask = get_latent_mask(control_alpha)
+ logger.info(
+ f"Apply mask for clean latents 1x for {i+1}: shape: {latent_mask.shape}"
+ )
+ clean_latents[:, :, i : i + 1, :, :] = clean_latents[:, :, i : i + 1, :, :] * latent_mask
+
+ for one_frame_param in one_frame_inference:
+ if one_frame_param.startswith("target_index="):
+ target_index = int(one_frame_param.split("=")[1])
+ latent_indices[:, 0] = target_index
+ logger.info(f"Set index for target: {target_index}")
+ elif one_frame_param.startswith("control_index="):
+ control_indices = one_frame_param.split("=")[1].split(";")
+ i = 0
+ while i < len(control_indices) and i < clean_latent_indices.shape[1]:
+ control_index = int(control_indices[i])
+ clean_latent_indices[:, i] = control_index
+ i += 1
+ logger.info(f"Set index for clean latent 1x: {control_indices}")
+
+ if "no_2x" in one_frame_inference:
+ clean_latents_2x = None
+ clean_latent_2x_indices = None
+ logger.info(f"No clean_latents_2x")
+ else:
+ clean_latents_2x = torch.zeros((1, 16, 2, height // 8, width // 8), dtype=torch.float32)
+ index = 1 + latent_window_size + 1
+ clean_latent_2x_indices = torch.arange(index, index + 2) # 2
+
+ if "no_4x" in one_frame_inference:
+ clean_latents_4x = None
+ clean_latent_4x_indices = None
+ logger.info(f"No clean_latents_4x")
+ else:
+ index = 1 + latent_window_size + 1 + 2
+ clean_latent_4x_indices = torch.arange(index, index + 16) # 16
+
+ logger.info(
+ f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
+ )
+
+ # prepare conditioning inputs
+ llama_vec = sample_parameter["llama_vec"].to(device, dtype=torch.bfloat16)
+ llama_attention_mask = sample_parameter["llama_attention_mask"].to(device)
+ clip_l_pooler = sample_parameter["clip_l_pooler"].to(device, dtype=torch.bfloat16)
+ if cfg_scale == 1.0:
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
+ else:
+ llama_vec_n = sample_parameter["negative_llama_vec"].to(device, dtype=torch.bfloat16)
+ llama_attention_mask_n = sample_parameter["negative_llama_attention_mask"].to(device)
+ clip_l_pooler_n = sample_parameter["negative_clip_l_pooler"].to(device, dtype=torch.bfloat16)
+ image_encoder_last_hidden_state = sample_parameter["image_encoder_last_hidden_state"].to(
+ device, dtype=torch.bfloat16
+ )
+
+ generated_latents = sample_hunyuan(
+ transformer=model,
+ sampler=args.sample_solver,
+ width=width,
+ height=height,
+ frames=1,
+ real_guidance_scale=cfg_scale,
+ distilled_guidance_scale=guidance_scale,
+ guidance_rescale=0.0,
+ # shift=3.0,
+ num_inference_steps=sample_steps,
+ generator=generator,
+ prompt_embeds=llama_vec,
+ prompt_embeds_mask=llama_attention_mask,
+ prompt_poolers=clip_l_pooler,
+ negative_prompt_embeds=llama_vec_n,
+ negative_prompt_embeds_mask=llama_attention_mask_n,
+ negative_prompt_poolers=clip_l_pooler_n,
+ device=device,
+ dtype=torch.bfloat16,
+ image_embeddings=image_encoder_last_hidden_state,
+ latent_indices=latent_indices,
+ clean_latents=clean_latents,
+ clean_latent_indices=clean_latent_indices,
+ clean_latents_2x=clean_latents_2x,
+ clean_latent_2x_indices=clean_latent_2x_indices,
+ clean_latents_4x=clean_latents_4x,
+ clean_latent_4x_indices=clean_latent_4x_indices,
+ )
+
+ real_history_latents = generated_latents.to(clean_latents)
+
+ # wait for 5 seconds until block swap is done
+ logger.info("Waiting for 5 seconds to finish block swap")
+ time.sleep(5)
+
+ gc.collect()
+ clean_memory_on_device(device)
+
+ video = decode_latent(
+ latent_window_size, total_latent_sections, args.bulk_decode, vae, real_history_latents, device, one_frame_mode
+ )
+ video = video.to("cpu", dtype=torch.float32).unsqueeze(0) # add batch dimension
+ video = (video / 2 + 0.5).clamp(0, 1) # -1 to 1 -> 0 to 1
+ clean_memory_on_device(device)
+
+ return video
+
+ def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
+ vae_path = args.vae
+ logger.info(f"Loading VAE model from {vae_path}")
+ vae = load_framepack_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, "cpu")
+ return vae
+
+ def load_transformer(
+ self,
+ accelerator: Accelerator,
+ args: argparse.Namespace,
+ dit_path: str,
+ attn_mode: str,
+ split_attn: bool,
+ loading_device: str,
+ dit_weight_dtype: Optional[torch.dtype],
+ ):
+ logger.info(f"Loading DiT model from {dit_path}")
+ device = accelerator.device
+ model = load_packed_model(device, dit_path, attn_mode, loading_device, args.fp8_scaled, split_attn)
+ return model
+
+ def scale_shift_latents(self, latents):
+ # FramePack VAE includes scaling
+ return latents
+
+ def call_dit(
+ self,
+ args: argparse.Namespace,
+ accelerator: Accelerator,
+ transformer,
+ latents: torch.Tensor,
+ batch: dict[str, torch.Tensor],
+ noise: torch.Tensor,
+ noisy_model_input: torch.Tensor,
+ timesteps: torch.Tensor,
+ network_dtype: torch.dtype,
+ ):
+ model: HunyuanVideoTransformer3DModelPacked = transformer
+ device = accelerator.device
+ batch_size = latents.shape[0]
+
+ # maybe model.dtype is better than network_dtype...
+ distilled_guidance = torch.tensor([args.guidance_scale * 1000.0] * batch_size).to(device=device, dtype=network_dtype)
+ latents = latents.to(device=accelerator.device, dtype=network_dtype)
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
+ # for k, v in batch.items():
+ # if isinstance(v, torch.Tensor):
+ # print(f"{k}: {v.shape} {v.dtype} {v.device}")
+ with accelerator.autocast():
+ clean_latent_2x_indices = batch["clean_latent_2x_indices"] if "clean_latent_2x_indices" in batch else None
+ if clean_latent_2x_indices is not None:
+ clean_latent_2x = batch["latents_clean_2x"] if "latents_clean_2x" in batch else None
+ if clean_latent_2x is None:
+ clean_latent_2x = torch.zeros(
+ (batch_size, 16, 2, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device
+ )
+ else:
+ clean_latent_2x = None
+
+ clean_latent_4x_indices = batch["clean_latent_4x_indices"] if "clean_latent_4x_indices" in batch else None
+ if clean_latent_4x_indices is not None:
+ clean_latent_4x = batch["latents_clean_4x"] if "latents_clean_4x" in batch else None
+ if clean_latent_4x is None:
+ clean_latent_4x = torch.zeros(
+ (batch_size, 16, 16, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device
+ )
+ else:
+ clean_latent_4x = None
+
+ model_pred = model(
+ hidden_states=noisy_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=batch["llama_vec"],
+ encoder_attention_mask=batch["llama_attention_mask"],
+ pooled_projections=batch["clip_l_pooler"],
+ guidance=distilled_guidance,
+ latent_indices=batch["latent_indices"],
+ clean_latents=batch["latents_clean"],
+ clean_latent_indices=batch["clean_latent_indices"],
+ clean_latents_2x=clean_latent_2x,
+ clean_latent_2x_indices=clean_latent_2x_indices,
+ clean_latents_4x=clean_latent_4x,
+ clean_latent_4x_indices=clean_latent_4x_indices,
+ image_embeddings=batch["image_embeddings"],
+ return_dict=False,
+ )
+ model_pred = model_pred[0] # returns tuple (model_pred, )
+
+ # flow matching loss
+ target = noise - latents
+
+ return model_pred, target
+
+ # endregion model specific
+
+
+def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ """FramePack specific parser setup"""
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
+ parser.add_argument("--latent_window_size", type=int, default=9, help="FramePack latent window size (default 9)")
+ parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once in sample generation")
+ parser.add_argument("--f1", action="store_true", help="Use F1 sampling method for sample generation")
+ parser.add_argument("--one_frame", action="store_true", help="Use one frame sampling method for sample generation")
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser_common()
+ parser = framepack_setup_parser(parser)
+
+ args = parser.parse_args()
+ args = read_config_from_file(args, parser)
+
+ assert (
+ args.vae_dtype is None or args.vae_dtype == "float16"
+ ), "VAE dtype must be float16 / VAEのdtypeはfloat16でなければなりません"
+ args.vae_dtype = "float16" # fixed
+ args.dit_dtype = "bfloat16" # fixed
+ args.sample_solver = "unipc" # for sample generation, fixed to unipc
+
+ trainer = FramePackNetworkTrainer()
+ trainer.train(args)
diff --git a/frame_pack/__init__.py b/frame_pack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/frame_pack/bucket_tools.py b/frame_pack/bucket_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc13fdeb11f9ac87c64dda049a06b968360e7c3f
--- /dev/null
+++ b/frame_pack/bucket_tools.py
@@ -0,0 +1,30 @@
+bucket_options = {
+ 640: [
+ (416, 960),
+ (448, 864),
+ (480, 832),
+ (512, 768),
+ (544, 704),
+ (576, 672),
+ (608, 640),
+ (640, 608),
+ (672, 576),
+ (704, 544),
+ (768, 512),
+ (832, 480),
+ (864, 448),
+ (960, 416),
+ ],
+}
+
+
+def find_nearest_bucket(h, w, resolution=640):
+ min_metric = float('inf')
+ best_bucket = None
+ for (bucket_h, bucket_w) in bucket_options[resolution]:
+ metric = abs(h * bucket_w - w * bucket_h)
+ if metric <= min_metric:
+ min_metric = metric
+ best_bucket = (bucket_h, bucket_w)
+ return best_bucket
+
diff --git a/frame_pack/clip_vision.py b/frame_pack/clip_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c919296b23084ac00e3e4440657d368df1ee86e
--- /dev/null
+++ b/frame_pack/clip_vision.py
@@ -0,0 +1,14 @@
+import numpy as np
+
+
+def hf_clip_vision_encode(image, feature_extractor, image_encoder):
+ assert isinstance(image, np.ndarray)
+ assert image.ndim == 3 and image.shape[2] == 3
+ assert image.dtype == np.uint8
+
+ preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(
+ device=image_encoder.device, dtype=image_encoder.dtype
+ )
+ image_encoder_output = image_encoder(**preprocessed)
+
+ return image_encoder_output
diff --git a/frame_pack/framepack_utils.py b/frame_pack/framepack_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a57364aa6bb0d67c0491b532646daa2fa6d36fc1
--- /dev/null
+++ b/frame_pack/framepack_utils.py
@@ -0,0 +1,273 @@
+import os
+import logging
+from types import SimpleNamespace
+from typing import Optional, Union
+
+import accelerate
+from accelerate import Accelerator, init_empty_weights
+import torch
+from safetensors.torch import load_file
+from transformers import (
+ LlamaTokenizerFast,
+ LlamaConfig,
+ LlamaModel,
+ CLIPTokenizer,
+ CLIPTextModel,
+ CLIPConfig,
+ SiglipImageProcessor,
+ SiglipVisionModel,
+ SiglipVisionConfig,
+)
+
+from utils.safetensors_utils import load_split_weights
+from hunyuan_model.vae import load_vae as hunyuan_load_vae
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def load_vae(
+ vae_path: str, vae_chunk_size: Optional[int], vae_spatial_tile_sample_min_size: Optional[int], device: Union[str, torch.device]
+):
+ # single file and directory (contains 'vae') support
+ if os.path.isdir(vae_path):
+ vae_path = os.path.join(vae_path, "vae", "diffusion_pytorch_model.safetensors")
+ else:
+ vae_path = vae_path
+
+ vae_dtype = torch.float16 # if vae_dtype is None else str_to_dtype(vae_dtype)
+ vae, _, s_ratio, t_ratio = hunyuan_load_vae(vae_dtype=vae_dtype, device=device, vae_path=vae_path)
+ vae.eval()
+ # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
+
+ # set chunk_size to CausalConv3d recursively
+ chunk_size = vae_chunk_size
+ if chunk_size is not None:
+ vae.set_chunk_size_for_causal_conv_3d(chunk_size)
+ logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
+
+ if vae_spatial_tile_sample_min_size is not None:
+ vae.enable_spatial_tiling(True)
+ vae.tile_sample_min_size = vae_spatial_tile_sample_min_size
+ vae.tile_latent_min_size = vae_spatial_tile_sample_min_size // 8
+ logger.info(f"Enabled spatial tiling with min size {vae_spatial_tile_sample_min_size}")
+ # elif vae_tiling:
+ else:
+ vae.enable_spatial_tiling(True)
+
+ return vae
+
+
+# region Text Encoders
+
+# Text Encoder configs are copied from HunyuanVideo repo
+
+LLAMA_CONFIG = {
+ "architectures": ["LlamaModel"],
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ "bos_token_id": 128000,
+ "eos_token_id": 128001,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 8192,
+ "mlp_bias": False,
+ "model_type": "llama",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 8,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": None,
+ "rope_theta": 500000.0,
+ "tie_word_embeddings": False,
+ "torch_dtype": "float16",
+ "transformers_version": "4.46.3",
+ "use_cache": True,
+ "vocab_size": 128320,
+}
+
+CLIP_CONFIG = {
+ # "_name_or_path": "/raid/aryan/llava-llama-3-8b-v1_1-extracted/text_encoder_2",
+ "architectures": ["CLIPTextModel"],
+ "attention_dropout": 0.0,
+ "bos_token_id": 0,
+ "dropout": 0.0,
+ "eos_token_id": 2,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 768,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 77,
+ "model_type": "clip_text_model",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 1,
+ "projection_dim": 768,
+ "torch_dtype": "float16",
+ "transformers_version": "4.48.0.dev0",
+ "vocab_size": 49408,
+}
+
+
+def load_text_encoder1(
+ args, fp8_llm: Optional[bool] = False, device: Optional[Union[str, torch.device]] = None
+) -> tuple[LlamaTokenizerFast, LlamaModel]:
+ # single file, split file and directory (contains 'text_encoder') support
+ logger.info(f"Loading text encoder 1 tokenizer")
+ tokenizer1 = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer")
+
+ logger.info(f"Loading text encoder 1 from {args.text_encoder1}")
+ if os.path.isdir(args.text_encoder1):
+ # load from directory, configs are in the directory
+ text_encoder1 = LlamaModel.from_pretrained(args.text_encoder1, subfolder="text_encoder", torch_dtype=torch.float16)
+ else:
+ # load from file, we create the model with the appropriate config
+ config = LlamaConfig(**LLAMA_CONFIG)
+ with init_empty_weights():
+ text_encoder1 = LlamaModel._from_config(config, torch_dtype=torch.float16)
+
+ state_dict = load_split_weights(args.text_encoder1)
+
+ # support weights from ComfyUI
+ if "model.embed_tokens.weight" in state_dict:
+ for key in list(state_dict.keys()):
+ if key.startswith("model."):
+ new_key = key.replace("model.", "")
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+ if "tokenizer" in state_dict:
+ state_dict.pop("tokenizer")
+ if "lm_head.weight" in state_dict:
+ state_dict.pop("lm_head.weight")
+
+ # # support weights from ComfyUI
+ # if "tokenizer" in state_dict:
+ # state_dict.pop("tokenizer")
+
+ text_encoder1.load_state_dict(state_dict, strict=True, assign=True)
+
+ if fp8_llm:
+ org_dtype = text_encoder1.dtype
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
+ text_encoder1.to(device=device, dtype=torch.float8_e4m3fn)
+
+ # prepare LLM for fp8
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
+ def forward_hook(module):
+ def forward(hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
+
+ return forward
+
+ for module in llama_model.modules():
+ if module.__class__.__name__ in ["Embedding"]:
+ # print("set", module.__class__.__name__, "to", target_dtype)
+ module.to(target_dtype)
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
+ # print("set", module.__class__.__name__, "hooks")
+ module.forward = forward_hook(module)
+
+ prepare_fp8(text_encoder1, org_dtype)
+ else:
+ text_encoder1.to(device)
+
+ text_encoder1.eval()
+ return tokenizer1, text_encoder1
+
+
+def load_text_encoder2(args) -> tuple[CLIPTokenizer, CLIPTextModel]:
+ # single file and directory (contains 'text_encoder_2') support
+ logger.info(f"Loading text encoder 2 tokenizer")
+ tokenizer2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2")
+
+ logger.info(f"Loading text encoder 2 from {args.text_encoder2}")
+ if os.path.isdir(args.text_encoder2):
+ # load from directory, configs are in the directory
+ text_encoder2 = CLIPTextModel.from_pretrained(args.text_encoder2, subfolder="text_encoder_2", torch_dtype=torch.float16)
+ else:
+ # we only have one file, so we can load it directly
+ config = CLIPConfig(**CLIP_CONFIG)
+ with init_empty_weights():
+ text_encoder2 = CLIPTextModel._from_config(config, torch_dtype=torch.float16)
+
+ state_dict = load_file(args.text_encoder2)
+
+ text_encoder2.load_state_dict(state_dict, strict=True, assign=True)
+
+ text_encoder2.eval()
+ return tokenizer2, text_encoder2
+
+
+# endregion
+
+# region image encoder
+
+# Siglip configs are copied from FramePack repo
+FEATURE_EXTRACTOR_CONFIG = {
+ "do_convert_rgb": None,
+ "do_normalize": True,
+ "do_rescale": True,
+ "do_resize": True,
+ "image_mean": [0.5, 0.5, 0.5],
+ "image_processor_type": "SiglipImageProcessor",
+ "image_std": [0.5, 0.5, 0.5],
+ "processor_class": "SiglipProcessor",
+ "resample": 3,
+ "rescale_factor": 0.00392156862745098,
+ "size": {"height": 384, "width": 384},
+}
+IMAGE_ENCODER_CONFIG = {
+ "_name_or_path": "/home/lvmin/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-Redux-dev/snapshots/1282f955f706b5240161278f2ef261d2a29ad649/image_encoder",
+ "architectures": ["SiglipVisionModel"],
+ "attention_dropout": 0.0,
+ "hidden_act": "gelu_pytorch_tanh",
+ "hidden_size": 1152,
+ "image_size": 384,
+ "intermediate_size": 4304,
+ "layer_norm_eps": 1e-06,
+ "model_type": "siglip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 27,
+ "patch_size": 14,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.46.2",
+}
+
+
+def load_image_encoders(args):
+ logger.info(f"Loading image encoder feature extractor")
+ feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)
+
+ # single file, split file and directory (contains 'image_encoder') support
+ logger.info(f"Loading image encoder from {args.image_encoder}")
+ if os.path.isdir(args.image_encoder):
+ # load from directory, configs are in the directory
+ image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder, subfolder="image_encoder", torch_dtype=torch.float16)
+ else:
+ # load from file, we create the model with the appropriate config
+ config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
+ with init_empty_weights():
+ image_encoder = SiglipVisionModel._from_config(config, torch_dtype=torch.float16)
+
+ state_dict = load_file(args.image_encoder)
+
+ image_encoder.load_state_dict(state_dict, strict=True, assign=True)
+
+ image_encoder.eval()
+ return feature_extractor, image_encoder
+
+
+# endregion
diff --git a/frame_pack/hunyuan.py b/frame_pack/hunyuan.py
new file mode 100644
index 0000000000000000000000000000000000000000..2349d8cf91198373a1f1299915a6d85e2f5b0b44
--- /dev/null
+++ b/frame_pack/hunyuan.py
@@ -0,0 +1,134 @@
+# original code: https://github.com/lllyasviel/FramePack
+# original license: Apache-2.0
+
+import torch
+
+# from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
+# from diffusers_helper.utils import crop_or_pad_yield_mask
+from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+from hunyuan_model.text_encoder import PROMPT_TEMPLATE
+
+
+@torch.no_grad()
+def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256, custom_system_prompt=None):
+ assert isinstance(prompt, str)
+
+ prompt = [prompt]
+
+ # LLAMA
+
+ # We can verify crop_start by checking the token count of the prompt:
+ # custom_system_prompt = (
+ # "Describe the video by detailing the following aspects: "
+ # "1. The main content and theme of the video."
+ # "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ # "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ # "4. background environment, light, style and atmosphere."
+ # "5. camera angles, movements, and transitions used in the video:"
+ # )
+ if custom_system_prompt is None:
+ prompt_llama = [PROMPT_TEMPLATE["dit-llm-encode-video"]["template"].format(p) for p in prompt]
+ crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"]["crop_start"]
+ else:
+ # count tokens for custom_system_prompt
+ full_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{custom_system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
+ print(f"Custom system prompt: {full_prompt}")
+ system_prompt_tokens = tokenizer(full_prompt, return_tensors="pt", truncation=True).input_ids[0].shape[0]
+ print(f"Custom system prompt token count: {system_prompt_tokens}")
+ prompt_llama = [full_prompt + p + "<|eot_id|>" for p in prompt]
+ crop_start = system_prompt_tokens
+
+ llama_inputs = tokenizer(
+ prompt_llama,
+ padding="max_length",
+ max_length=max_length + crop_start,
+ truncation=True,
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ )
+
+ llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
+ llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
+ llama_attention_length = int(llama_attention_mask.sum())
+
+ llama_outputs = text_encoder(
+ input_ids=llama_input_ids,
+ attention_mask=llama_attention_mask,
+ output_hidden_states=True,
+ )
+
+ llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
+ # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
+ llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
+
+ assert torch.all(llama_attention_mask.bool())
+
+ # CLIP
+
+ clip_l_input_ids = tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ ).input_ids
+ clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
+
+ return llama_vec, clip_l_pooler
+
+
+@torch.no_grad()
+def vae_decode_fake(latents):
+ latent_rgb_factors = [
+ [-0.0395, -0.0331, 0.0445],
+ [0.0696, 0.0795, 0.0518],
+ [0.0135, -0.0945, -0.0282],
+ [0.0108, -0.0250, -0.0765],
+ [-0.0209, 0.0032, 0.0224],
+ [-0.0804, -0.0254, -0.0639],
+ [-0.0991, 0.0271, -0.0669],
+ [-0.0646, -0.0422, -0.0400],
+ [-0.0696, -0.0595, -0.0894],
+ [-0.0799, -0.0208, -0.0375],
+ [0.1166, 0.1627, 0.0962],
+ [0.1165, 0.0432, 0.0407],
+ [-0.2315, -0.1920, -0.1355],
+ [-0.0270, 0.0401, -0.0821],
+ [-0.0616, -0.0997, -0.0727],
+ [0.0249, -0.0469, -0.1703],
+ ] # From comfyui
+
+ latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
+
+ weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
+ bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
+
+ images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
+ images = images.clamp(0.0, 1.0)
+
+ return images
+
+
+@torch.no_grad()
+def vae_decode(latents, vae, image_mode=False) -> torch.Tensor:
+ latents = latents / vae.config.scaling_factor
+
+ if not image_mode:
+ image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
+ else:
+ latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
+ image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
+ image = torch.cat(image, dim=2)
+
+ return image
+
+
+@torch.no_grad()
+def vae_encode(image, vae: AutoencoderKLCausal3D) -> torch.Tensor:
+ latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+ return latents
diff --git a/frame_pack/hunyuan_video_packed.py b/frame_pack/hunyuan_video_packed.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4cee4a7d8c3325dccb980e7f89a3d921e6fca8c
--- /dev/null
+++ b/frame_pack/hunyuan_video_packed.py
@@ -0,0 +1,2038 @@
+# original code: https://github.com/lllyasviel/FramePack
+# original license: Apache-2.0
+
+import glob
+import math
+import numbers
+import os
+from types import SimpleNamespace
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import einops
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from modules.custom_offloading_utils import ModelOffloader
+from utils.safetensors_utils import load_split_weights
+from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8
+from accelerate import init_empty_weights
+
+try:
+ # raise NotImplementedError
+ from xformers.ops import memory_efficient_attention as xformers_attn_func
+
+ print("Xformers is installed!")
+except:
+ print("Xformers is not installed!")
+ xformers_attn_func = None
+
+try:
+ # raise NotImplementedError
+ from flash_attn import flash_attn_varlen_func, flash_attn_func
+
+ print("Flash Attn is installed!")
+except:
+ print("Flash Attn is not installed!")
+ flash_attn_varlen_func = None
+ flash_attn_func = None
+
+try:
+ # raise NotImplementedError
+ from sageattention import sageattn_varlen, sageattn
+
+ print("Sage Attn is installed!")
+except:
+ print("Sage Attn is not installed!")
+ sageattn_varlen = None
+ sageattn = None
+
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+# region diffusers
+
+# copied from diffusers with some modifications to minimize dependencies
+# original code: https://github.com/huggingface/diffusers/
+# original license: Apache-2.0
+
+ACT2CLS = {
+ "swish": nn.SiLU,
+ "silu": nn.SiLU,
+ "mish": nn.Mish,
+ "gelu": nn.GELU,
+ "relu": nn.ReLU,
+}
+
+
+def get_activation(act_fn: str) -> nn.Module:
+ """Helper function to get activation function from string.
+
+ Args:
+ act_fn (str): Name of activation function.
+
+ Returns:
+ nn.Module: Activation function.
+ """
+
+ act_fn = act_fn.lower()
+ if act_fn in ACT2CLS:
+ return ACT2CLS[act_fn]()
+ else:
+ raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ Args
+ timesteps (torch.Tensor):
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ embedding_dim (int):
+ the dimension of the output.
+ flip_sin_to_cos (bool):
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+ downscale_freq_shift (float):
+ Controls the delta between frequencies between dimensions
+ scale (float):
+ Scaling factor applied to the embeddings.
+ max_period (int):
+ Controls the maximum frequency of the embeddings
+ Returns
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ sample_proj_bias=True,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ self.act = get_activation(act_fn)
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+
+ if post_act_fn is None:
+ self.post_act = None
+ else:
+ self.post_act = get_activation(post_act_fn)
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.scale = scale
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ scale=self.scale,
+ )
+ return t_emb
+
+
+class FP32SiLU(nn.Module):
+ r"""
+ SiLU activation function with input upcasted to torch.float32.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ # if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+ # # fp16 gelu not supported on mps before torch 2.0
+ # return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+ return F.gelu(gate, approximate=self.approximate)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class PixArtAlphaTextProjection(nn.Module):
+ """
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
+ super().__init__()
+ if out_features is None:
+ out_features = hidden_size
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+ if act_fn == "gelu_tanh":
+ self.act_1 = nn.GELU(approximate="tanh")
+ elif act_fn == "silu":
+ self.act_1 = nn.SiLU()
+ elif act_fn == "silu_fp32":
+ self.act_1 = FP32SiLU()
+ else:
+ raise ValueError(f"Unknown activation function: {act_fn}")
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
+
+ def forward(self, caption):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class LayerNormFramePack(nn.LayerNorm):
+ # casting to dtype of input tensor is added
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
+
+
+class FP32LayerNormFramePack(nn.LayerNorm):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ origin_dtype = x.dtype
+ return torch.nn.functional.layer_norm(
+ x.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ ).to(origin_dtype)
+
+
+class RMSNormFramePack(nn.Module):
+ r"""
+ RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+
+ Args:
+ dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
+ eps (`float`): Small value to use when calculating the reciprocal of the square-root.
+ elementwise_affine (`bool`, defaults to `True`):
+ Boolean flag to denote if affine transformation should be applied.
+ bias (`bool`, defaults to False): If also training the `bias` param.
+ """
+
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
+ super().__init__()
+
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+
+ if isinstance(dim, numbers.Integral):
+ dim = (dim,)
+
+ self.dim = torch.Size(dim)
+
+ self.weight = None
+ self.bias = None
+
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(dim))
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+ if self.weight is None:
+ return hidden_states.to(input_dtype)
+
+ return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
+
+
+class AdaLayerNormContinuousFramePack(nn.Module):
+ r"""
+ Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+ Args:
+ embedding_dim (`int`): Embedding dimension to use during projection.
+ conditioning_embedding_dim (`int`): Dimension of the input condition.
+ elementwise_affine (`bool`, defaults to `True`):
+ Boolean flag to denote if affine transformation should be applied.
+ eps (`float`, defaults to 1e-5): Epsilon factor.
+ bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+ norm_type (`str`, defaults to `"layer_norm"`):
+ Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+ # However, this is how it was implemented in the original code, and it's rather likely you should
+ # set `elementwise_affine` to False.
+ elementwise_affine=True,
+ eps=1e-5,
+ bias=True,
+ norm_type="layer_norm",
+ ):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNormFramePack(embedding_dim, eps, elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x, conditioning_embedding):
+ emb = self.linear(self.silu(conditioning_embedding))
+ scale, shift = emb.chunk(2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
+class LinearActivation(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
+ super().__init__()
+
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.activation = get_activation(activation)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ return self.activation(hidden_states)
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ inner_dim=None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ # if activation_fn == "gelu":
+ # act_fn = GELU(dim, inner_dim, bias=bias)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ # elif activation_fn == "geglu":
+ # act_fn = GEGLU(dim, inner_dim, bias=bias)
+ # elif activation_fn == "geglu-approximate":
+ # act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+ # elif activation_fn == "swiglu":
+ # act_fn = SwiGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "linear-silu":
+ act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
+ else:
+ raise ValueError(f"Unknown activation function: {activation_fn}")
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ # deprecate("scale", "1.0.0", deprecation_message)
+ raise ValueError("scale is not supported in this version. Please remove it.")
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+# @maybe_allow_in_graph
+class Attention(nn.Module):
+ r"""
+ Minimal copy of Attention class from diffusers.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ bias: bool = False,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ eps: float = 1e-5,
+ processor: Optional[any] = None,
+ out_dim: int = None,
+ context_pre_only=None,
+ pre_only=False,
+ ):
+ super().__init__()
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim # if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.out_context_dim = query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+
+ self.scale = dim_head**-0.5
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNormFramePack(dim_head, eps=eps)
+ self.norm_k = RMSNormFramePack(dim_head, eps=eps)
+ else:
+ raise ValueError(
+ f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
+ )
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+
+ self.added_proj_bias = True # added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ else:
+ self.add_q_proj = None
+ self.add_k_proj = None
+ self.add_v_proj = None
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=True))
+ # self.to_out.append(nn.Dropout(dropout))
+ self.to_out.append(nn.Identity()) # dropout=0.0
+ else:
+ self.to_out = None
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=True)
+ else:
+ self.to_add_out = None
+
+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "rms_norm":
+ self.norm_added_q = RMSNormFramePack(dim_head, eps=eps)
+ self.norm_added_k = RMSNormFramePack(dim_head, eps=eps)
+ else:
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`")
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ if processor is None:
+ processor = AttnProcessor2_0()
+ self.set_processor(processor)
+
+ def set_processor(self, processor: any) -> None:
+ self.processor = processor
+
+ def get_processor(self) -> any:
+ return self.processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ r"""
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ The attention mask to prepare.
+ target_length (`int`):
+ The target length of the attention mask. This is the length of the attention mask after padding.
+ batch_size (`int`):
+ The batch size, which is used to repeat the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`):
+ The output dimension of the attention mask. Can be either `3` or `4`.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0, output_size=attention_mask.shape[0] * head_size)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1, output_size=attention_mask.shape[1] * head_size)
+
+ return attention_mask
+
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+ query_dtype = query.dtype # store dtype before potentially deleting query
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
+ del query, key, value, attention_mask # free memory
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query_dtype) # use stored dtype
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states
+
+
+# endregion diffusers
+
+
+def pad_for_3d_conv(x, kernel_size):
+ b, c, t, h, w = x.shape
+ pt, ph, pw = kernel_size
+ pad_t = (pt - (t % pt)) % pt
+ pad_h = (ph - (h % ph)) % ph
+ pad_w = (pw - (w % pw)) % pw
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
+
+
+def center_down_sample_3d(x, kernel_size):
+ # pt, ph, pw = kernel_size
+ # cp = (pt * ph * pw) // 2
+ # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
+ # xc = xp[cp]
+ # return xc
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
+
+
+def get_cu_seqlens(text_mask, img_len):
+ batch_size = text_mask.shape[0]
+ text_len = text_mask.sum(dim=1)
+ max_len = text_mask.shape[1] + img_len
+
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device) # ensure device match
+
+ for i in range(batch_size):
+ s = text_len[i] + img_len
+ s1 = i * max_len + s
+ s2 = (i + 1) * max_len
+ cu_seqlens[2 * i + 1] = s1
+ cu_seqlens[2 * i + 2] = s2
+
+ return cu_seqlens
+
+
+def apply_rotary_emb_transposed(x, freqs_cis):
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
+ del freqs_cis
+ x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ del x_real, x_imag
+ return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+
+def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=None, split_attn=False):
+ if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
+ if attn_mode == "sageattn" or attn_mode is None and sageattn is not None:
+ x = sageattn(q, k, v, tensor_layout="NHD")
+ return x
+
+ if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None:
+ x = flash_attn_func(q, k, v)
+ return x
+
+ if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None:
+ x = xformers_attn_func(q, k, v)
+ return x
+
+ x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(
+ 1, 2
+ )
+ return x
+ if split_attn:
+ if attn_mode == "sageattn" or attn_mode is None and sageattn is not None:
+ x = torch.empty_like(q)
+ for i in range(q.size(0)):
+ x[i : i + 1] = sageattn(q[i : i + 1], k[i : i + 1], v[i : i + 1], tensor_layout="NHD")
+ return x
+
+ if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None:
+ x = torch.empty_like(q)
+ for i in range(q.size(0)):
+ x[i : i + 1] = flash_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1])
+ return x
+
+ if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None:
+ x = torch.empty_like(q)
+ for i in range(q.size(0)):
+ x[i : i + 1] = xformers_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1])
+ return x
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ x = torch.empty_like(q)
+ for i in range(q.size(0)):
+ x[i : i + 1] = torch.nn.functional.scaled_dot_product_attention(q[i : i + 1], k[i : i + 1], v[i : i + 1])
+ x = x.transpose(1, 2)
+ return x
+
+ batch_size = q.shape[0]
+ q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
+ k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
+ v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
+ if attn_mode == "sageattn" or attn_mode is None and sageattn_varlen is not None:
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+ del q, k, v # free memory
+ elif attn_mode == "flash" or attn_mode is None and flash_attn_varlen_func is not None:
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+ del q, k, v # free memory
+ else:
+ raise NotImplementedError("No Attn Installed or batch_size > 1 is not supported in this configuration. Try `--split_attn`.")
+ x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
+ return x
+
+
+class HunyuanAttnProcessorFlashAttnDouble:
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ image_rotary_emb,
+ attn_mode: Optional[str] = None,
+ split_attn: Optional[bool] = False,
+ ):
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
+
+ # Project image latents
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ del hidden_states # free memory
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = apply_rotary_emb_transposed(query, image_rotary_emb)
+ key = apply_rotary_emb_transposed(key, image_rotary_emb)
+ del image_rotary_emb # free memory
+
+ # Project context (text/encoder) embeddings
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+ txt_length = encoder_hidden_states.shape[1] # store length before deleting
+ del encoder_hidden_states # free memory
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ # Concatenate image and context q, k, v
+ query = torch.cat([query, encoder_query], dim=1)
+ key = torch.cat([key, encoder_key], dim=1)
+ value = torch.cat([value, encoder_value], dim=1)
+ del encoder_query, encoder_key, encoder_value # free memory
+
+ hidden_states_attn = attn_varlen_func(
+ query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn
+ )
+ del query, key, value # free memory
+ hidden_states_attn = hidden_states_attn.flatten(-2)
+
+ hidden_states, encoder_hidden_states = hidden_states_attn[:, :-txt_length], hidden_states_attn[:, -txt_length:]
+ del hidden_states_attn # free memory
+
+ # Apply output projections
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states) # Dropout/Identity
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanAttnProcessorFlashAttnSingle:
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ image_rotary_emb,
+ attn_mode: Optional[str] = None,
+ split_attn: Optional[bool] = False,
+ ):
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
+ txt_length = encoder_hidden_states.shape[1] # Store text length
+
+ # Concatenate image and context inputs
+ hidden_states_cat = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+ del hidden_states, encoder_hidden_states # free memory
+
+ # Project concatenated inputs
+ query = attn.to_q(hidden_states_cat)
+ key = attn.to_k(hidden_states_cat)
+ value = attn.to_v(hidden_states_cat)
+ del hidden_states_cat # free memory
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
+ key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
+ del image_rotary_emb # free memory
+
+ hidden_states = attn_varlen_func(
+ query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn
+ )
+ del query, key, value # free memory
+ hidden_states = hidden_states.flatten(-2)
+
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
+
+ return hidden_states, encoder_hidden_states
+
+
+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, guidance, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
+
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
+
+ time_guidance_emb = timesteps_emb + guidance_emb
+
+ pooled_projections = self.text_embedder(pooled_projection)
+ conditioning = time_guidance_emb + pooled_projections
+
+ return conditioning
+
+
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
+
+ pooled_projections = self.text_embedder(pooled_projection)
+
+ conditioning = timesteps_emb + pooled_projections
+
+ return conditioning
+
+
+class HunyuanVideoAdaNorm(nn.Module):
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+ super().__init__()
+
+ out_features = out_features or 2 * in_features
+ self.linear = nn.Linear(in_features, out_features)
+ self.nonlinearity = nn.SiLU()
+
+ def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ temb = self.linear(self.nonlinearity(temb))
+ gate_msa, gate_mlp = temb.chunk(2, dim=-1)
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+ return gate_msa, gate_mlp
+
+
+class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ )
+
+ self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
+
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # Self-attention
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ )
+ del norm_hidden_states # free memory
+
+ gate_msa, gate_mlp = self.norm_out(temb)
+ hidden_states = hidden_states + attn_output * gate_msa
+ del attn_output, gate_msa # free memory
+
+ ff_output = self.ff(self.norm2(hidden_states))
+ hidden_states = hidden_states + ff_output * gate_mlp
+ del ff_output, gate_mlp # free memory
+
+ return hidden_states
+
+
+class HunyuanVideoIndividualTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.refiner_blocks = nn.ModuleList(
+ [
+ HunyuanVideoIndividualTokenRefinerBlock(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ self_attn_mask = None
+ if attention_mask is not None:
+ batch_size = attention_mask.shape[0]
+ seq_len = attention_mask.shape[1]
+ attention_mask = attention_mask.to(hidden_states.device).bool()
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+ self_attn_mask[:, :, :, 0] = True
+
+ for block in self.refiner_blocks:
+ hidden_states = block(hidden_states, temb, self_attn_mask)
+
+ return hidden_states
+
+
+class HunyuanVideoTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(embedding_dim=hidden_size, pooled_projection_dim=in_channels)
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_layers=num_layers,
+ mlp_width_ratio=mlp_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+ if attention_mask is None:
+ pooled_projections = hidden_states.mean(dim=1)
+ else:
+ original_dtype = hidden_states.dtype
+ mask_float = attention_mask.float().unsqueeze(-1)
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+ pooled_projections = pooled_projections.to(original_dtype)
+
+ temb = self.time_text_embed(timestep, pooled_projections)
+ del pooled_projections # free memory
+
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
+ del temb, attention_mask # free memory
+
+ return hidden_states
+
+
+class HunyuanVideoRotaryPosEmbed(nn.Module):
+ def __init__(self, rope_dim, theta):
+ super().__init__()
+ self.DT, self.DY, self.DX = rope_dim
+ self.theta = theta
+ self.h_w_scaling_factor = 1.0
+
+ @torch.no_grad()
+ def get_frequency(self, dim, pos):
+ T, H, W = pos.shape
+ freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
+ freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
+ return freqs.cos(), freqs.sin()
+
+ @torch.no_grad()
+ def forward_inner(self, frame_indices, height, width, device):
+ GT, GY, GX = torch.meshgrid(
+ frame_indices.to(device=device, dtype=torch.float32),
+ torch.arange(0, height, device=device, dtype=torch.float32) * self.h_w_scaling_factor,
+ torch.arange(0, width, device=device, dtype=torch.float32) * self.h_w_scaling_factor,
+ indexing="ij",
+ )
+
+ FCT, FST = self.get_frequency(self.DT, GT)
+ del GT # free memory
+ FCY, FSY = self.get_frequency(self.DY, GY)
+ del GY # free memory
+ FCX, FSX = self.get_frequency(self.DX, GX)
+ del GX # free memory
+
+ result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
+ del FCT, FCY, FCX, FST, FSY, FSX # free memory
+
+ # Return result already on the correct device
+ return result # Shape (2 * total_dim / 2, T, H, W) -> (total_dim, T, H, W)
+
+ @torch.no_grad()
+ def forward(self, frame_indices, height, width, device):
+ frame_indices = frame_indices.unbind(0)
+ results = [self.forward_inner(f, height, width, device) for f in frame_indices]
+ results = torch.stack(results, dim=0)
+ return results
+
+
+class AdaLayerNormZero(nn.Module):
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(
+ self, x: torch.Tensor, emb: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = emb.unsqueeze(-2)
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class AdaLayerNormZeroSingle(nn.Module):
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ emb = emb.unsqueeze(-2)
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
+ return x, gate_msa
+
+
+class AdaLayerNormContinuous(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ elementwise_affine=True,
+ eps=1e-5,
+ bias=True,
+ norm_type="layer_norm",
+ ):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ emb = emb.unsqueeze(-2)
+ emb = self.linear(self.silu(emb))
+ scale, shift = emb.chunk(2, dim=-1)
+ del emb # free memory
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class HunyuanVideoSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ qk_norm: str = "rms_norm",
+ attn_mode: Optional[str] = None,
+ split_attn: Optional[bool] = False,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+ mlp_dim = int(hidden_size * mlp_ratio)
+ self.attn_mode = attn_mode
+ self.split_attn = split_attn
+
+ # Attention layer (pre_only=True means no output projection in Attention module itself)
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ bias=True,
+ processor=HunyuanAttnProcessorFlashAttnSingle(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ pre_only=True, # Crucial: Attn processor will return raw attention output
+ )
+
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+ del encoder_hidden_states # free memory
+
+ residual = hidden_states
+
+ # 1. Input normalization
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ norm_hidden_states, norm_encoder_hidden_states = (
+ norm_hidden_states[:, :-text_seq_length, :],
+ norm_hidden_states[:, -text_seq_length:, :],
+ )
+
+ # 2. Attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ attn_mode=self.attn_mode,
+ split_attn=self.split_attn,
+ )
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+ del norm_hidden_states, norm_encoder_hidden_states, context_attn_output # free memory
+ del image_rotary_emb
+
+ # 3. Modulation and residual connection
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ del attn_output, mlp_hidden_states # free memory
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = hidden_states + residual
+
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, :-text_seq_length, :],
+ hidden_states[:, -text_seq_length:, :],
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float,
+ qk_norm: str = "rms_norm",
+ attn_mode: Optional[str] = None,
+ split_attn: Optional[bool] = False,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+ self.attn_mode = attn_mode
+ self.split_attn = split_attn
+
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ added_kv_proj_dim=hidden_size,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ context_pre_only=False,
+ bias=True,
+ processor=HunyuanAttnProcessorFlashAttnDouble(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+
+ self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ self.norm2_context = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=freqs_cis,
+ attn_mode=self.attn_mode,
+ split_attn=self.split_attn,
+ )
+ del norm_hidden_states, norm_encoder_hidden_states, freqs_cis # free memory
+
+ # 3. Modulation and residual connection
+ hidden_states = hidden_states + attn_output * gate_msa
+ del attn_output, gate_msa # free memory
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
+ del context_attn_output, c_gate_msa # free memory
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+ del shift_mlp, scale_mlp # free memory
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
+ del c_shift_mlp, c_scale_mlp # free memory
+
+ # 4. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ del norm_hidden_states # free memory
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ del norm_encoder_hidden_states # free memory
+
+ hidden_states = hidden_states + gate_mlp * ff_output
+ del ff_output, gate_mlp # free memory
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
+ del context_ff_output, c_gate_mlp # free memory
+
+ return hidden_states, encoder_hidden_states
+
+
+class ClipVisionProjection(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.up = nn.Linear(in_channels, out_channels * 3)
+ self.down = nn.Linear(out_channels * 3, out_channels)
+
+ def forward(self, x):
+ projected_x = self.down(nn.functional.silu(self.up(x)))
+ return projected_x
+
+
+class HunyuanVideoPatchEmbed(nn.Module):
+ def __init__(self, patch_size, in_chans, embed_dim):
+ super().__init__()
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+
+class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
+ def __init__(self, inner_dim):
+ super().__init__()
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
+
+ @torch.no_grad()
+ def initialize_weight_from_another_conv3d(self, another_layer):
+ weight = another_layer.weight.detach().clone()
+ bias = another_layer.bias.detach().clone()
+
+ sd = {
+ "proj.weight": weight.clone(),
+ "proj.bias": bias.clone(),
+ "proj_2x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=2, hk=2, wk=2) / 8.0,
+ "proj_2x.bias": bias.clone(),
+ "proj_4x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=4, hk=4, wk=4) / 64.0,
+ "proj_4x.bias": bias.clone(),
+ }
+
+ sd = {k: v.clone() for k, v in sd.items()}
+
+ self.load_state_dict(sd)
+ return
+
+
+class HunyuanVideoTransformer3DModelPacked(nn.Module): # (PreTrainedModelMixin, GenerationMixin,
+ # ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ # @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ num_attention_heads: int = 24,
+ attention_head_dim: int = 128,
+ num_layers: int = 20,
+ num_single_layers: int = 40,
+ num_refiner_layers: int = 2,
+ mlp_ratio: float = 4.0,
+ patch_size: int = 2,
+ patch_size_t: int = 1,
+ qk_norm: str = "rms_norm",
+ guidance_embeds: bool = True,
+ text_embed_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ rope_theta: float = 256.0,
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
+ has_image_proj=False,
+ image_proj_dim=1152,
+ has_clean_x_embedder=False,
+ attn_mode: Optional[str] = None,
+ split_attn: Optional[bool] = False,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+ self.config_patch_size = patch_size
+ self.config_patch_size_t = patch_size_t
+
+ # 1. Latent and condition embedders
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+ self.context_embedder = HunyuanVideoTokenRefiner(
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+ )
+ self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
+
+ self.clean_x_embedder = None
+ self.image_projection = None
+
+ # 2. RoPE
+ self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
+
+ # 3. Dual stream transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideoTransformerBlock(
+ num_attention_heads,
+ attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ attn_mode=attn_mode,
+ split_attn=split_attn,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Single stream transformer blocks
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideoSingleTransformerBlock(
+ num_attention_heads,
+ attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ attn_mode=attn_mode,
+ split_attn=split_attn,
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ # 5. Output projection
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+ self.inner_dim = inner_dim
+ self.use_gradient_checkpointing = False
+ self.enable_teacache = False
+
+ # if has_image_proj:
+ # self.install_image_projection(image_proj_dim)
+ self.image_projection = ClipVisionProjection(in_channels=image_proj_dim, out_channels=self.inner_dim)
+ # self.config["has_image_proj"] = True
+ # self.config["image_proj_dim"] = in_channels
+
+ # if has_clean_x_embedder:
+ # self.install_clean_x_embedder()
+ self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
+ # self.config["has_clean_x_embedder"] = True
+
+ self.high_quality_fp32_output_for_inference = True # False # change default to True
+
+ # Block swapping attributes (initialized to None)
+ self.blocks_to_swap = None
+ self.offloader_double = None
+ self.offloader_single = None
+
+ # RoPE scaling
+ self.rope_scaling_timestep_threshold: Optional[int] = None # scale RoPE above this timestep
+ self.rope_scaling_factor: float = 1.0 # RoPE scaling factor
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def enable_gradient_checkpointing(self):
+ self.use_gradient_checkpointing = True
+ print("Gradient checkpointing enabled for HunyuanVideoTransformer3DModelPacked.") # Logging
+
+ def disable_gradient_checkpointing(self):
+ self.use_gradient_checkpointing = False
+ print("Gradient checkpointing disabled for HunyuanVideoTransformer3DModelPacked.") # Logging
+
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
+ self.enable_teacache = enable_teacache
+ self.cnt = 0
+ self.num_steps = num_steps
+ self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.previous_residual = None
+ self.teacache_rescale_func = np.poly1d([7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02])
+ if enable_teacache:
+ print(f"TeaCache enabled: num_steps={num_steps}, rel_l1_thresh={rel_l1_thresh}")
+ else:
+ print("TeaCache disabled.")
+
+ def gradient_checkpointing_method(self, block, *args):
+ if self.use_gradient_checkpointing:
+ result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
+ else:
+ result = block(*args)
+ return result
+
+ def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
+ self.blocks_to_swap = num_blocks
+ self.num_double_blocks = len(self.transformer_blocks)
+ self.num_single_blocks = len(self.single_transformer_blocks)
+ double_blocks_to_swap = num_blocks // 2
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
+
+ assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
+ f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
+ )
+
+ self.offloader_double = ModelOffloader(
+ "double",
+ self.transformer_blocks,
+ self.num_double_blocks,
+ double_blocks_to_swap,
+ supports_backward,
+ device,
+ # debug=True # Optional debugging
+ )
+ self.offloader_single = ModelOffloader(
+ "single",
+ self.single_transformer_blocks,
+ self.num_single_blocks,
+ single_blocks_to_swap,
+ supports_backward,
+ device, # , debug=True
+ )
+ print(
+ f"HunyuanVideoTransformer3DModelPacked: Block swap enabled. Swapping {num_blocks} blocks, "
+ + f"double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}, supports_backward: {supports_backward}."
+ )
+
+ def switch_block_swap_for_inference(self):
+ if self.blocks_to_swap and self.blocks_to_swap > 0:
+ self.offloader_double.set_forward_only(True)
+ self.offloader_single.set_forward_only(True)
+ self.prepare_block_swap_before_forward()
+ print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward only.")
+
+ def switch_block_swap_for_training(self):
+ if self.blocks_to_swap and self.blocks_to_swap > 0:
+ self.offloader_double.set_forward_only(False)
+ self.offloader_single.set_forward_only(False)
+ self.prepare_block_swap_before_forward()
+ print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward and backward.")
+
+ def move_to_device_except_swap_blocks(self, device: torch.device):
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
+ if self.blocks_to_swap:
+ saved_double_blocks = self.transformer_blocks
+ saved_single_blocks = self.single_transformer_blocks
+ self.transformer_blocks = None
+ self.single_transformer_blocks = None
+
+ self.to(device)
+
+ if self.blocks_to_swap:
+ self.transformer_blocks = saved_double_blocks
+ self.single_transformer_blocks = saved_single_blocks
+
+ def prepare_block_swap_before_forward(self):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+ self.offloader_double.prepare_block_devices_before_forward(self.transformer_blocks)
+ self.offloader_single.prepare_block_devices_before_forward(self.single_transformer_blocks)
+
+ def enable_rope_scaling(self, timestep_threshold: Optional[int], rope_scaling_factor: float = 1.0):
+ if timestep_threshold is not None and rope_scaling_factor > 0:
+ self.rope_scaling_timestep_threshold = timestep_threshold
+ self.rope_scaling_factor = rope_scaling_factor
+ logger.info(f"RoPE scaling enabled: threshold={timestep_threshold}, scaling_factor={rope_scaling_factor}.")
+ else:
+ self.rope_scaling_timestep_threshold = None
+ self.rope_scaling_factor = 1.0
+ self.rope.h_w_scaling_factor = 1.0 # reset to default
+ logger.info("RoPE scaling disabled.")
+
+ def process_input_hidden_states(
+ self,
+ latents,
+ latent_indices=None,
+ clean_latents=None,
+ clean_latent_indices=None,
+ clean_latents_2x=None,
+ clean_latent_2x_indices=None,
+ clean_latents_4x=None,
+ clean_latent_4x_indices=None,
+ ):
+ hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
+ B, C, T, H, W = hidden_states.shape
+
+ if latent_indices is None:
+ latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
+
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
+ rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
+
+ if clean_latents is not None and clean_latent_indices is not None:
+ clean_latents = clean_latents.to(hidden_states)
+ clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
+ clean_latents = clean_latents.flatten(2).transpose(1, 2)
+
+ clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
+ clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
+
+ hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
+ rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
+
+ if clean_latents_2x is not None and clean_latent_2x_indices is not None:
+ clean_latents_2x = clean_latents_2x.to(hidden_states)
+ clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
+ clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
+
+ clean_latent_2x_rope_freqs = self.rope(
+ frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device
+ )
+ clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
+ clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
+ clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
+
+ hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
+ rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
+
+ if clean_latents_4x is not None and clean_latent_4x_indices is not None:
+ clean_latents_4x = clean_latents_4x.to(hidden_states)
+ clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
+ clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
+
+ clean_latent_4x_rope_freqs = self.rope(
+ frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device
+ )
+ clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
+ clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
+ clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
+
+ hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
+ rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
+
+ return hidden_states, rope_freqs
+
+ def forward(
+ self,
+ hidden_states,
+ timestep,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ pooled_projections,
+ guidance,
+ latent_indices=None,
+ clean_latents=None,
+ clean_latent_indices=None,
+ clean_latents_2x=None,
+ clean_latent_2x_indices=None,
+ clean_latents_4x=None,
+ clean_latent_4x_indices=None,
+ image_embeddings=None,
+ attention_kwargs=None,
+ return_dict=True,
+ ):
+
+ if attention_kwargs is None:
+ attention_kwargs = {}
+
+ # RoPE scaling: must be done before processing hidden states
+ if self.rope_scaling_timestep_threshold is not None:
+ if timestep >= self.rope_scaling_timestep_threshold:
+ self.rope.h_w_scaling_factor = self.rope_scaling_factor
+ else:
+ self.rope.h_w_scaling_factor = 1.0
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p, p_t = self.config_patch_size, self.config_patch_size_t
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p
+ post_patch_width = width // p
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
+
+ hidden_states, rope_freqs = self.process_input_hidden_states(
+ hidden_states,
+ latent_indices,
+ clean_latents,
+ clean_latent_indices,
+ clean_latents_2x,
+ clean_latent_2x_indices,
+ clean_latents_4x,
+ clean_latent_4x_indices,
+ )
+ del (
+ latent_indices,
+ clean_latents,
+ clean_latent_indices,
+ clean_latents_2x,
+ clean_latent_2x_indices,
+ clean_latents_4x,
+ clean_latent_4x_indices,
+ ) # free memory
+
+ temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
+ encoder_hidden_states = self.gradient_checkpointing_method(
+ self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask
+ )
+
+ if self.image_projection is not None:
+ assert image_embeddings is not None, "You must use image embeddings!"
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
+ extra_attention_mask = torch.ones(
+ (batch_size, extra_encoder_hidden_states.shape[1]),
+ dtype=encoder_attention_mask.dtype,
+ device=encoder_attention_mask.device,
+ )
+
+ # must cat before (not after) encoder_hidden_states, due to attn masking
+ encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
+ encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
+ del extra_encoder_hidden_states, extra_attention_mask # free memory
+
+ with torch.no_grad():
+ if batch_size == 1:
+ # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
+ # If they are not same, then their impls are wrong. Ours are always the correct one.
+ text_len = encoder_attention_mask.sum().item()
+ encoder_hidden_states = encoder_hidden_states[:, :text_len]
+ attention_mask = None, None, None, None
+ else:
+ img_seq_len = hidden_states.shape[1]
+ txt_seq_len = encoder_hidden_states.shape[1]
+
+ cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
+ cu_seqlens_kv = cu_seqlens_q
+ max_seqlen_q = img_seq_len + txt_seq_len
+ max_seqlen_kv = max_seqlen_q
+
+ attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
+ del cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv # free memory
+ del encoder_attention_mask # free memory
+
+ if self.enable_teacache:
+ modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
+
+ if self.cnt == 0 or self.cnt == self.num_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ curr_rel_l1 = (
+ ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())
+ .cpu()
+ .item()
+ )
+ self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
+ should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
+
+ if should_calc:
+ self.accumulated_rel_l1_distance = 0
+
+ self.previous_modulated_input = modulated_inp
+ self.cnt += 1
+
+ if self.cnt == self.num_steps:
+ self.cnt = 0
+
+ if not should_calc:
+ hidden_states = hidden_states + self.previous_residual
+ else:
+ ori_hidden_states = hidden_states.clone()
+
+ for block_id, block in enumerate(self.transformer_blocks):
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
+ )
+
+ for block_id, block in enumerate(self.single_transformer_blocks):
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
+ )
+
+ self.previous_residual = hidden_states - ori_hidden_states
+ del ori_hidden_states # free memory
+ else:
+ for block_id, block in enumerate(self.transformer_blocks):
+ if self.blocks_to_swap:
+ self.offloader_double.wait_for_block(block_id)
+
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
+ )
+
+ if self.blocks_to_swap:
+ self.offloader_double.submit_move_blocks_forward(self.transformer_blocks, block_id)
+
+ for block_id, block in enumerate(self.single_transformer_blocks):
+ if self.blocks_to_swap:
+ self.offloader_single.wait_for_block(block_id)
+
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
+ )
+
+ if self.blocks_to_swap:
+ self.offloader_single.submit_move_blocks_forward(self.single_transformer_blocks, block_id)
+
+ del attention_mask, rope_freqs # free memory
+ del encoder_hidden_states # free memory
+
+ hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
+
+ hidden_states = hidden_states[:, -original_context_length:, :]
+
+ if self.high_quality_fp32_output_for_inference:
+ hidden_states = hidden_states.to(dtype=torch.float32)
+ if self.proj_out.weight.dtype != torch.float32:
+ self.proj_out.to(dtype=torch.float32)
+
+ hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
+
+ hidden_states = einops.rearrange(
+ hidden_states,
+ "b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)",
+ t=post_patch_num_frames,
+ h=post_patch_height,
+ w=post_patch_width,
+ pt=p_t,
+ ph=p,
+ pw=p,
+ )
+
+ if return_dict:
+ # return Transformer2DModelOutput(sample=hidden_states)
+ return SimpleNamespace(sample=hidden_states)
+
+ return (hidden_states,)
+
+ def fp8_optimization(
+ self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False
+ ) -> dict[str, torch.Tensor]: # Return type hint added
+ """
+ Optimize the model state_dict with fp8.
+
+ Args:
+ state_dict (dict[str, torch.Tensor]):
+ The state_dict of the model.
+ device (torch.device):
+ The device to calculate the weight.
+ move_to_device (bool):
+ Whether to move the weight to the device after optimization.
+ use_scaled_mm (bool):
+ Whether to use scaled matrix multiplication for FP8.
+ """
+ TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
+ EXCLUDE_KEYS = ["norm"] # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8
+
+ # inplace optimization
+ state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device)
+
+ # apply monkey patching
+ apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm)
+
+ return state_dict
+
+
+def load_packed_model(
+ device: Union[str, torch.device],
+ dit_path: str,
+ attn_mode: str,
+ loading_device: Union[str, torch.device],
+ fp8_scaled: bool = False,
+ split_attn: bool = False,
+) -> HunyuanVideoTransformer3DModelPacked:
+ # TODO support split_attn
+ device = torch.device(device)
+ loading_device = torch.device(loading_device)
+
+ if os.path.isdir(dit_path):
+ # we don't support from_pretrained for now, so loading safetensors directly
+ safetensor_files = glob.glob(os.path.join(dit_path, "*.safetensors"))
+ if len(safetensor_files) == 0:
+ raise ValueError(f"Cannot find safetensors file in {dit_path}")
+ # sort by name and take the first one
+ safetensor_files.sort()
+ dit_path = safetensor_files[0]
+
+ with init_empty_weights():
+ logger.info(f"Creating HunyuanVideoTransformer3DModelPacked")
+ model = HunyuanVideoTransformer3DModelPacked(
+ attention_head_dim=128,
+ guidance_embeds=True,
+ has_clean_x_embedder=True,
+ has_image_proj=True,
+ image_proj_dim=1152,
+ in_channels=16,
+ mlp_ratio=4.0,
+ num_attention_heads=24,
+ num_layers=20,
+ num_refiner_layers=2,
+ num_single_layers=40,
+ out_channels=16,
+ patch_size=2,
+ patch_size_t=1,
+ pooled_projection_dim=768,
+ qk_norm="rms_norm",
+ rope_axes_dim=(16, 56, 56),
+ rope_theta=256.0,
+ text_embed_dim=4096,
+ attn_mode=attn_mode,
+ split_attn=split_attn,
+ )
+
+ # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others)
+ dit_loading_device = torch.device("cpu") if fp8_scaled else loading_device
+ logger.info(f"Loading DiT model from {dit_path}, device={dit_loading_device}")
+
+ # load model weights with the specified dtype or as is
+ sd = load_split_weights(dit_path, device=dit_loading_device, disable_mmap=True)
+
+ if fp8_scaled:
+ # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap)
+ logger.info(f"Optimizing model weights to fp8. This may take a while.")
+ sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu")
+
+ if loading_device.type != "cpu":
+ # make sure all the model weights are on the loading_device
+ logger.info(f"Moving weights to {loading_device}")
+ for key in sd.keys():
+ sd[key] = sd[key].to(loading_device)
+
+ info = model.load_state_dict(sd, strict=True, assign=True)
+ logger.info(f"Loaded DiT model from {dit_path}, info={info}")
+
+ return model
diff --git a/frame_pack/k_diffusion_hunyuan.py b/frame_pack/k_diffusion_hunyuan.py
new file mode 100644
index 0000000000000000000000000000000000000000..60524eae0d6c9571ee90164ce520ba41cd8d3d20
--- /dev/null
+++ b/frame_pack/k_diffusion_hunyuan.py
@@ -0,0 +1,128 @@
+# original code: https://github.com/lllyasviel/FramePack
+# original license: Apache-2.0
+
+import torch
+import math
+
+# from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
+# from diffusers_helper.k_diffusion.wrapper import fm_wrapper
+# from diffusers_helper.utils import repeat_to_batch_size
+from frame_pack.uni_pc_fm import sample_unipc
+from frame_pack.wrapper import fm_wrapper
+from frame_pack.utils import repeat_to_batch_size
+
+
+def flux_time_shift(t, mu=1.15, sigma=1.0):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
+ k = (y2 - y1) / (x2 - x1)
+ b = y1 - k * x1
+ mu = k * context_length + b
+ mu = min(mu, math.log(exp_max))
+ return mu
+
+
+def get_flux_sigmas_from_mu(n, mu):
+ sigmas = torch.linspace(1, 0, steps=n + 1)
+ sigmas = flux_time_shift(sigmas, mu=mu)
+ return sigmas
+
+
+# @torch.inference_mode()
+def sample_hunyuan(
+ transformer,
+ sampler="unipc",
+ initial_latent=None,
+ concat_latent=None,
+ strength=1.0,
+ width=512,
+ height=512,
+ frames=16,
+ real_guidance_scale=1.0,
+ distilled_guidance_scale=6.0,
+ guidance_rescale=0.0,
+ shift=None,
+ num_inference_steps=25,
+ batch_size=None,
+ generator=None,
+ prompt_embeds=None,
+ prompt_embeds_mask=None,
+ prompt_poolers=None,
+ negative_prompt_embeds=None,
+ negative_prompt_embeds_mask=None,
+ negative_prompt_poolers=None,
+ dtype=torch.bfloat16,
+ device=None,
+ negative_kwargs=None,
+ callback=None,
+ **kwargs,
+):
+ device = device or transformer.device
+
+ if batch_size is None:
+ batch_size = int(prompt_embeds.shape[0])
+
+ latents = torch.randn(
+ (batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device
+ ).to(device=device, dtype=torch.float32)
+
+ B, C, T, H, W = latents.shape
+ seq_length = T * H * W // 4 # 9*80*80//4 = 14400
+
+ if shift is None:
+ mu = calculate_flux_mu(seq_length, exp_max=7.0) # 1.9459... if seq_len is large, mu is clipped.
+ else:
+ mu = math.log(shift)
+
+ sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
+
+ k_model = fm_wrapper(transformer)
+
+ if initial_latent is not None:
+ sigmas = sigmas * strength
+ first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
+ initial_latent = initial_latent.to(device=device, dtype=torch.float32)
+ latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
+
+ if concat_latent is not None:
+ concat_latent = concat_latent.to(latents)
+
+ distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
+
+ prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
+ prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
+ prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
+ negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
+ negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
+ negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
+ concat_latent = repeat_to_batch_size(concat_latent, batch_size)
+
+ sampler_kwargs = dict(
+ dtype=dtype,
+ cfg_scale=real_guidance_scale,
+ cfg_rescale=guidance_rescale,
+ concat_latent=concat_latent,
+ positive=dict(
+ pooled_projections=prompt_poolers,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_embeds_mask,
+ guidance=distilled_guidance,
+ **kwargs,
+ ),
+ negative=dict(
+ pooled_projections=negative_prompt_poolers,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_embeds_mask,
+ guidance=distilled_guidance,
+ **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
+ ),
+ )
+
+ if sampler == "unipc":
+ results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
+ else:
+ raise NotImplementedError(f"Sampler {sampler} is not supported.")
+
+ return results
diff --git a/frame_pack/uni_pc_fm.py b/frame_pack/uni_pc_fm.py
new file mode 100644
index 0000000000000000000000000000000000000000..43a198f9f1c408b8c84b47a675c871aaf71bc418
--- /dev/null
+++ b/frame_pack/uni_pc_fm.py
@@ -0,0 +1,142 @@
+# Better Flow Matching UniPC by Lvmin Zhang
+# (c) 2025
+# CC BY-SA 4.0
+# Attribution-ShareAlike 4.0 International Licence
+
+
+import torch
+
+from tqdm.auto import trange
+
+
+def expand_dims(v, dims):
+ return v[(...,) + (None,) * (dims - 1)]
+
+
+class FlowMatchUniPC:
+ def __init__(self, model, extra_args, variant='bh1'):
+ self.model = model
+ self.variant = variant
+ self.extra_args = extra_args
+
+ def model_fn(self, x, t):
+ return self.model(x, t, **self.extra_args)
+
+ def update_fn(self, x, model_prev_list, t_prev_list, t, order):
+ assert order <= len(model_prev_list)
+ dims = x.dim()
+
+ t_prev_0 = t_prev_list[-1]
+ lambda_prev_0 = - torch.log(t_prev_0)
+ lambda_t = - torch.log(t)
+ model_prev_0 = model_prev_list[-1]
+
+ h = lambda_t - lambda_prev_0
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ t_prev_i = t_prev_list[-(i + 1)]
+ model_prev_i = model_prev_list[-(i + 1)]
+ lambda_prev_i = - torch.log(t_prev_i)
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
+ rks.append(rk)
+ D1s.append((model_prev_i - model_prev_0) / rk)
+
+ rks.append(1.)
+ rks = torch.tensor(rks, device=x.device)
+
+ R = []
+ b = []
+
+ hh = -h[0]
+ h_phi_1 = torch.expm1(hh)
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.variant == 'bh1':
+ B_h = hh
+ elif self.variant == 'bh2':
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError('Bad variant!')
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= (i + 1)
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=x.device)
+
+ use_predictor = len(D1s) > 0
+
+ if use_predictor:
+ D1s = torch.stack(D1s, dim=1)
+ if order == 2:
+ rhos_p = torch.tensor([0.5], device=b.device)
+ else:
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+ else:
+ D1s = None
+ rhos_p = None
+
+ if order == 1:
+ rhos_c = torch.tensor([0.5], device=b.device)
+ else:
+ rhos_c = torch.linalg.solve(R, b)
+
+ x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
+
+ if use_predictor:
+ pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
+ else:
+ pred_res = 0
+
+ x_t = x_t_ - expand_dims(B_h, dims) * pred_res
+ model_t = self.model_fn(x_t, t)
+
+ if D1s is not None:
+ corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
+ else:
+ corr_res = 0
+
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+
+ return x_t, model_t
+
+ def sample(self, x, sigmas, callback=None, disable_pbar=False):
+ order = min(3, len(sigmas) - 2)
+ model_prev_list, t_prev_list = [], []
+ for i in trange(len(sigmas) - 1, disable=disable_pbar):
+ vec_t = sigmas[i].expand(x.shape[0])
+
+ with torch.no_grad():
+ if i == 0:
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ elif i < order:
+ init_order = i
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
+ model_prev_list.append(model_x)
+ t_prev_list.append(vec_t)
+ else:
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
+ model_prev_list.append(model_x)
+ t_prev_list.append(vec_t)
+
+ model_prev_list = model_prev_list[-order:]
+ t_prev_list = t_prev_list[-order:]
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
+
+ return model_prev_list[-1]
+
+
+def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
+ assert variant in ['bh1', 'bh2']
+ return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
diff --git a/frame_pack/utils.py b/frame_pack/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69bd5fea25b2e8a6c80a774c0bd8eeb5926d0a7
--- /dev/null
+++ b/frame_pack/utils.py
@@ -0,0 +1,617 @@
+import os
+import cv2
+import json
+import random
+import glob
+import torch
+import einops
+import numpy as np
+import datetime
+import torchvision
+
+import safetensors.torch as sf
+from PIL import Image
+
+
+def min_resize(x, m):
+ if x.shape[0] < x.shape[1]:
+ s0 = m
+ s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
+ else:
+ s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
+ s1 = m
+ new_max = max(s1, s0)
+ raw_max = max(x.shape[0], x.shape[1])
+ if new_max < raw_max:
+ interpolation = cv2.INTER_AREA
+ else:
+ interpolation = cv2.INTER_LANCZOS4
+ y = cv2.resize(x, (s1, s0), interpolation=interpolation)
+ return y
+
+
+def d_resize(x, y):
+ H, W, C = y.shape
+ new_min = min(H, W)
+ raw_min = min(x.shape[0], x.shape[1])
+ if new_min < raw_min:
+ interpolation = cv2.INTER_AREA
+ else:
+ interpolation = cv2.INTER_LANCZOS4
+ y = cv2.resize(x, (W, H), interpolation=interpolation)
+ return y
+
+
+def resize_and_center_crop(image, target_width, target_height):
+ if target_height == image.shape[0] and target_width == image.shape[1]:
+ return image
+
+ pil_image = Image.fromarray(image)
+ original_width, original_height = pil_image.size
+ scale_factor = max(target_width / original_width, target_height / original_height)
+ resized_width = int(round(original_width * scale_factor))
+ resized_height = int(round(original_height * scale_factor))
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
+ left = (resized_width - target_width) / 2
+ top = (resized_height - target_height) / 2
+ right = (resized_width + target_width) / 2
+ bottom = (resized_height + target_height) / 2
+ cropped_image = resized_image.crop((left, top, right, bottom))
+ return np.array(cropped_image)
+
+
+def resize_and_center_crop_pytorch(image, target_width, target_height):
+ B, C, H, W = image.shape
+
+ if H == target_height and W == target_width:
+ return image
+
+ scale_factor = max(target_width / W, target_height / H)
+ resized_width = int(round(W * scale_factor))
+ resized_height = int(round(H * scale_factor))
+
+ resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
+
+ top = (resized_height - target_height) // 2
+ left = (resized_width - target_width) // 2
+ cropped = resized[:, :, top : top + target_height, left : left + target_width]
+
+ return cropped
+
+
+def resize_without_crop(image, target_width, target_height):
+ if target_height == image.shape[0] and target_width == image.shape[1]:
+ return image
+
+ pil_image = Image.fromarray(image)
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
+ return np.array(resized_image)
+
+
+def just_crop(image, w, h):
+ if h == image.shape[0] and w == image.shape[1]:
+ return image
+
+ original_height, original_width = image.shape[:2]
+ k = min(original_height / h, original_width / w)
+ new_width = int(round(w * k))
+ new_height = int(round(h * k))
+ x_start = (original_width - new_width) // 2
+ y_start = (original_height - new_height) // 2
+ cropped_image = image[y_start : y_start + new_height, x_start : x_start + new_width]
+ return cropped_image
+
+
+def write_to_json(data, file_path):
+ temp_file_path = file_path + ".tmp"
+ with open(temp_file_path, "wt", encoding="utf-8") as temp_file:
+ json.dump(data, temp_file, indent=4)
+ os.replace(temp_file_path, file_path)
+ return
+
+
+def read_from_json(file_path):
+ with open(file_path, "rt", encoding="utf-8") as file:
+ data = json.load(file)
+ return data
+
+
+def get_active_parameters(m):
+ return {k: v for k, v in m.named_parameters() if v.requires_grad}
+
+
+def cast_training_params(m, dtype=torch.float32):
+ result = {}
+ for n, param in m.named_parameters():
+ if param.requires_grad:
+ param.data = param.to(dtype)
+ result[n] = param
+ return result
+
+
+def separate_lora_AB(parameters, B_patterns=None):
+ parameters_normal = {}
+ parameters_B = {}
+
+ if B_patterns is None:
+ B_patterns = [".lora_B.", "__zero__"]
+
+ for k, v in parameters.items():
+ if any(B_pattern in k for B_pattern in B_patterns):
+ parameters_B[k] = v
+ else:
+ parameters_normal[k] = v
+
+ return parameters_normal, parameters_B
+
+
+def set_attr_recursive(obj, attr, value):
+ attrs = attr.split(".")
+ for name in attrs[:-1]:
+ obj = getattr(obj, name)
+ setattr(obj, attrs[-1], value)
+ return
+
+
+def print_tensor_list_size(tensors):
+ total_size = 0
+ total_elements = 0
+
+ if isinstance(tensors, dict):
+ tensors = tensors.values()
+
+ for tensor in tensors:
+ total_size += tensor.nelement() * tensor.element_size()
+ total_elements += tensor.nelement()
+
+ total_size_MB = total_size / (1024**2)
+ total_elements_B = total_elements / 1e9
+
+ print(f"Total number of tensors: {len(tensors)}")
+ print(f"Total size of tensors: {total_size_MB:.2f} MB")
+ print(f"Total number of parameters: {total_elements_B:.3f} billion")
+ return
+
+
+@torch.no_grad()
+def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
+ batch_size = a.size(0)
+
+ if b is None:
+ b = torch.zeros_like(a)
+
+ if mask_a is None:
+ mask_a = torch.rand(batch_size) < probability_a
+
+ mask_a = mask_a.to(a.device)
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
+ result = torch.where(mask_a, a, b)
+ return result
+
+
+@torch.no_grad()
+def zero_module(module):
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+@torch.no_grad()
+def supress_lower_channels(m, k, alpha=0.01):
+ data = m.weight.data.clone()
+
+ assert int(data.shape[1]) >= k
+
+ data[:, :k] = data[:, :k] * alpha
+ m.weight.data = data.contiguous().clone()
+ return m
+
+
+def freeze_module(m):
+ if not hasattr(m, "_forward_inside_frozen_module"):
+ m._forward_inside_frozen_module = m.forward
+ m.requires_grad_(False)
+ m.forward = torch.no_grad()(m.forward)
+ return m
+
+
+def get_latest_safetensors(folder_path):
+ safetensors_files = glob.glob(os.path.join(folder_path, "*.safetensors"))
+
+ if not safetensors_files:
+ raise ValueError("No file to resume!")
+
+ latest_file = max(safetensors_files, key=os.path.getmtime)
+ latest_file = os.path.abspath(os.path.realpath(latest_file))
+ return latest_file
+
+
+def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
+ tags = tags_str.split(", ")
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
+ prompt = ", ".join(tags)
+ return prompt
+
+
+def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
+ numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
+ if round_to_int:
+ numbers = np.round(numbers).astype(int)
+ return numbers.tolist()
+
+
+def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
+ edges = np.linspace(0, 1, n + 1)
+ points = np.random.uniform(edges[:-1], edges[1:])
+ numbers = inclusive + (exclusive - inclusive) * points
+ if round_to_int:
+ numbers = np.round(numbers).astype(int)
+ return numbers.tolist()
+
+
+def soft_append_bcthw(history, current, overlap=0):
+ if overlap <= 0:
+ return torch.cat([history, current], dim=2)
+
+ assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
+ assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
+
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
+
+ return output.to(history)
+
+
+def save_bcthw_as_mp4(x, output_filename, fps=10):
+ b, c, t, h, w = x.shape
+
+ per_row = b
+ for p in [6, 5, 4, 3, 2]:
+ if b % p == 0:
+ per_row = p
+ break
+
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
+ x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
+ x = x.detach().cpu().to(torch.uint8)
+ x = einops.rearrange(x, "(m n) c t h w -> t (m h) (n w) c", n=per_row)
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec="libx264", options={"crf": "0"})
+
+ # write tensor as .pt file
+ torch.save(x, output_filename.replace(".mp4", ".pt"))
+
+ return x
+
+
+def save_bcthw_as_png(x, output_filename):
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
+ x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
+ x = x.detach().cpu().to(torch.uint8)
+ x = einops.rearrange(x, "b c t h w -> c (b h) (t w)")
+ torchvision.io.write_png(x, output_filename)
+ return output_filename
+
+
+def save_bchw_as_png(x, output_filename):
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
+ x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
+ x = x.detach().cpu().to(torch.uint8)
+ x = einops.rearrange(x, "b c h w -> c h (b w)")
+ torchvision.io.write_png(x, output_filename)
+ return output_filename
+
+
+def add_tensors_with_padding(tensor1, tensor2):
+ if tensor1.shape == tensor2.shape:
+ return tensor1 + tensor2
+
+ shape1 = tensor1.shape
+ shape2 = tensor2.shape
+
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
+
+ padded_tensor1 = torch.zeros(new_shape)
+ padded_tensor2 = torch.zeros(new_shape)
+
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
+
+ result = padded_tensor1 + padded_tensor2
+ return result
+
+
+def print_free_mem():
+ torch.cuda.empty_cache()
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
+ free_mem_mb = free_mem / (1024**2)
+ total_mem_mb = total_mem / (1024**2)
+ print(f"Free memory: {free_mem_mb:.2f} MB")
+ print(f"Total memory: {total_mem_mb:.2f} MB")
+ return
+
+
+def print_gpu_parameters(device, state_dict, log_count=1):
+ summary = {"device": device, "keys_count": len(state_dict)}
+
+ logged_params = {}
+ for i, (key, tensor) in enumerate(state_dict.items()):
+ if i >= log_count:
+ break
+ logged_params[key] = tensor.flatten()[:3].tolist()
+
+ summary["params"] = logged_params
+
+ print(str(summary))
+ return
+
+
+def visualize_txt_as_img(width, height, text, font_path="font/DejaVuSans.ttf", size=18):
+ from PIL import Image, ImageDraw, ImageFont
+
+ txt = Image.new("RGB", (width, height), color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype(font_path, size=size)
+
+ if text == "":
+ return np.array(txt)
+
+ # Split text into lines that fit within the image width
+ lines = []
+ words = text.split()
+ current_line = words[0]
+
+ for word in words[1:]:
+ line_with_word = f"{current_line} {word}"
+ if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
+ current_line = line_with_word
+ else:
+ lines.append(current_line)
+ current_line = word
+
+ lines.append(current_line)
+
+ # Draw the text line by line
+ y = 0
+ line_height = draw.textbbox((0, 0), "A", font=font)[3]
+
+ for line in lines:
+ if y + line_height > height:
+ break # stop drawing if the next line will be outside the image
+ draw.text((0, y), line, fill="black", font=font)
+ y += line_height
+
+ return np.array(txt)
+
+
+def blue_mark(x):
+ x = x.copy()
+ c = x[:, :, 2]
+ b = cv2.blur(c, (9, 9))
+ x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
+ return x
+
+
+def green_mark(x):
+ x = x.copy()
+ x[:, :, 2] = -1
+ x[:, :, 0] = -1
+ return x
+
+
+def frame_mark(x):
+ x = x.copy()
+ x[:64] = -1
+ x[-64:] = -1
+ x[:, :8] = 1
+ x[:, -8:] = 1
+ return x
+
+
+@torch.inference_mode()
+def pytorch2numpy(imgs):
+ results = []
+ for x in imgs:
+ y = x.movedim(0, -1)
+ y = y * 127.5 + 127.5
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
+ results.append(y)
+ return results
+
+
+@torch.inference_mode()
+def numpy2pytorch(imgs):
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
+ h = h.movedim(-1, 1)
+ return h
+
+
+@torch.no_grad()
+def duplicate_prefix_to_suffix(x, count, zero_out=False):
+ if zero_out:
+ return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
+ else:
+ return torch.cat([x, x[:count]], dim=0)
+
+
+def weighted_mse(a, b, weight):
+ return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
+
+
+def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
+ x = (x - x_min) / (x_max - x_min)
+ x = max(0.0, min(x, 1.0))
+ x = x**sigma
+ return y_min + x * (y_max - y_min)
+
+
+def expand_to_dims(x, target_dims):
+ return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
+
+
+def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
+ if tensor is None:
+ return None
+
+ first_dim = tensor.shape[0]
+
+ if first_dim == batch_size:
+ return tensor
+
+ if batch_size % first_dim != 0:
+ raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
+
+ repeat_times = batch_size // first_dim
+
+ return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
+
+
+def dim5(x):
+ return expand_to_dims(x, 5)
+
+
+def dim4(x):
+ return expand_to_dims(x, 4)
+
+
+def dim3(x):
+ return expand_to_dims(x, 3)
+
+
+def crop_or_pad_yield_mask(x, length):
+ B, F, C = x.shape
+ device = x.device
+ dtype = x.dtype
+
+ if F < length:
+ y = torch.zeros((B, length, C), dtype=dtype, device=device)
+ mask = torch.zeros((B, length), dtype=torch.bool, device=device)
+ y[:, :F, :] = x
+ mask[:, :F] = True
+ return y, mask
+
+ return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
+
+
+def extend_dim(x, dim, minimal_length, zero_pad=False):
+ original_length = int(x.shape[dim])
+
+ if original_length >= minimal_length:
+ return x
+
+ if zero_pad:
+ padding_shape = list(x.shape)
+ padding_shape[dim] = minimal_length - original_length
+ padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
+ else:
+ idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
+ last_element = x[idx]
+ padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
+
+ return torch.cat([x, padding], dim=dim)
+
+
+def lazy_positional_encoding(t, repeats=None):
+ if not isinstance(t, list):
+ t = [t]
+
+ from diffusers.models.embeddings import get_timestep_embedding
+
+ te = torch.tensor(t)
+ te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
+
+ if repeats is None:
+ return te
+
+ te = te[:, None, :].expand(-1, repeats, -1)
+
+ return te
+
+
+def state_dict_offset_merge(A, B, C=None):
+ result = {}
+ keys = A.keys()
+
+ for key in keys:
+ A_value = A[key]
+ B_value = B[key].to(A_value)
+
+ if C is None:
+ result[key] = A_value + B_value
+ else:
+ C_value = C[key].to(A_value)
+ result[key] = A_value + B_value - C_value
+
+ return result
+
+
+def state_dict_weighted_merge(state_dicts, weights):
+ if len(state_dicts) != len(weights):
+ raise ValueError("Number of state dictionaries must match number of weights")
+
+ if not state_dicts:
+ return {}
+
+ total_weight = sum(weights)
+
+ if total_weight == 0:
+ raise ValueError("Sum of weights cannot be zero")
+
+ normalized_weights = [w / total_weight for w in weights]
+
+ keys = state_dicts[0].keys()
+ result = {}
+
+ for key in keys:
+ result[key] = state_dicts[0][key] * normalized_weights[0]
+
+ for i in range(1, len(state_dicts)):
+ state_dict_value = state_dicts[i][key].to(result[key])
+ result[key] += state_dict_value * normalized_weights[i]
+
+ return result
+
+
+def group_files_by_folder(all_files):
+ grouped_files = {}
+
+ for file in all_files:
+ folder_name = os.path.basename(os.path.dirname(file))
+ if folder_name not in grouped_files:
+ grouped_files[folder_name] = []
+ grouped_files[folder_name].append(file)
+
+ list_of_lists = list(grouped_files.values())
+ return list_of_lists
+
+
+def generate_timestamp():
+ now = datetime.datetime.now()
+ timestamp = now.strftime("%y%m%d_%H%M%S")
+ milliseconds = f"{int(now.microsecond / 1000):03d}"
+ random_number = random.randint(0, 9999)
+ return f"{timestamp}_{milliseconds}_{random_number}"
+
+
+def write_PIL_image_with_png_info(image, metadata, path):
+ from PIL.PngImagePlugin import PngInfo
+
+ png_info = PngInfo()
+ for key, value in metadata.items():
+ png_info.add_text(key, value)
+
+ image.save(path, "PNG", pnginfo=png_info)
+ return image
+
+
+def torch_safe_save(content, path):
+ torch.save(content, path + "_tmp")
+ os.replace(path + "_tmp", path)
+ return path
+
+
+def move_optimizer_to_device(optimizer, device):
+ for state in optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(device)
diff --git a/frame_pack/wrapper.py b/frame_pack/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc420da4db1134deca30648077923021b35f82d1
--- /dev/null
+++ b/frame_pack/wrapper.py
@@ -0,0 +1,51 @@
+import torch
+
+
+def append_dims(x, target_dims):
+ return x[(...,) + (None,) * (target_dims - x.ndim)]
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
+ if guidance_rescale == 0:
+ return noise_cfg
+
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+def fm_wrapper(transformer, t_scale=1000.0):
+ def k_model(x, sigma, **extra_args):
+ dtype = extra_args['dtype']
+ cfg_scale = extra_args['cfg_scale']
+ cfg_rescale = extra_args['cfg_rescale']
+ concat_latent = extra_args['concat_latent']
+
+ original_dtype = x.dtype
+ sigma = sigma.float()
+
+ x = x.to(dtype)
+ timestep = (sigma * t_scale).to(dtype)
+
+ if concat_latent is None:
+ hidden_states = x
+ else:
+ hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
+
+ pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
+
+ if cfg_scale == 1.0:
+ pred_negative = torch.zeros_like(pred_positive)
+ else:
+ pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
+
+ pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
+ pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
+
+ x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
+
+ return x0.to(dtype=original_dtype)
+
+ return k_model
diff --git a/hunyuan_model/__init__.py b/hunyuan_model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hunyuan_model/activation_layers.py b/hunyuan_model/activation_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8774c26ceef6081482ca0dbbf930b207d4ac03b
--- /dev/null
+++ b/hunyuan_model/activation_layers.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+
+
+def get_activation_layer(act_type):
+ """get activation layer
+
+ Args:
+ act_type (str): the activation type
+
+ Returns:
+ torch.nn.functional: the activation layer
+ """
+ if act_type == "gelu":
+ return lambda: nn.GELU()
+ elif act_type == "gelu_tanh":
+ # Approximate `tanh` requires torch >= 1.13
+ return lambda: nn.GELU(approximate="tanh")
+ elif act_type == "relu":
+ return nn.ReLU
+ elif act_type == "silu":
+ return nn.SiLU
+ else:
+ raise ValueError(f"Unknown activation type: {act_type}")
diff --git a/hunyuan_model/attention.py b/hunyuan_model/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..e94253df0aceb11e4f5812b728df75b9d38bf8c2
--- /dev/null
+++ b/hunyuan_model/attention.py
@@ -0,0 +1,295 @@
+import importlib.metadata
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ import flash_attn
+ from flash_attn.flash_attn_interface import _flash_attn_forward
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
+ from flash_attn.flash_attn_interface import flash_attn_func
+except ImportError:
+ flash_attn = None
+ flash_attn_varlen_func = None
+ _flash_attn_forward = None
+ flash_attn_func = None
+
+try:
+ print(f"Trying to import sageattention")
+ from sageattention import sageattn_varlen, sageattn
+
+ print("Successfully imported sageattention")
+except ImportError:
+ print(f"Failed to import sageattention")
+ sageattn_varlen = None
+ sageattn = None
+
+try:
+ import xformers.ops as xops
+except ImportError:
+ xops = None
+
+MEMORY_LAYOUT = {
+ "flash": (
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
+ lambda x: x,
+ ),
+ "flash_fixlen": (
+ lambda x: x,
+ lambda x: x,
+ ),
+ "sageattn": (
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
+ lambda x: x,
+ ),
+ "sageattn_fixlen": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+ "torch": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+ "xformers": (
+ lambda x: x,
+ lambda x: x,
+ ),
+ "vanilla": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+}
+
+
+def get_cu_seqlens(text_mask, img_len):
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
+
+ Args:
+ text_mask (torch.Tensor): the mask of text
+ img_len (int): the length of image
+
+ Returns:
+ torch.Tensor: the calculated cu_seqlens for flash attention
+ """
+ batch_size = text_mask.shape[0]
+ text_len = text_mask.sum(dim=1)
+ max_len = text_mask.shape[1] + img_len
+
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
+
+ for i in range(batch_size):
+ s = text_len[i] + img_len
+ s1 = i * max_len + s
+ s2 = (i + 1) * max_len
+ cu_seqlens[2 * i + 1] = s1
+ cu_seqlens[2 * i + 2] = s2
+
+ return cu_seqlens
+
+
+def attention(
+ q_or_qkv_list,
+ k=None,
+ v=None,
+ mode="flash",
+ drop_rate=0,
+ attn_mask=None,
+ total_len=None,
+ causal=False,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
+ max_seqlen_q=None,
+ max_seqlen_kv=None,
+ batch_size=1,
+):
+ """
+ Perform QKV self attention.
+
+ Args:
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
+ mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
+ drop_rate (float): Dropout rate in attention map. (default: 0)
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
+ (default: None)
+ causal (bool): Whether to use causal attention. (default: False)
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+ used to index into q.
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+ used to index into kv.
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
+
+ Returns:
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
+ """
+ q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
+ if type(q_or_qkv_list) == list:
+ q_or_qkv_list.clear()
+ split_attn = total_len is not None
+ if split_attn and mode == "sageattn":
+ mode = "sageattn_fixlen"
+ elif split_attn and mode == "flash":
+ mode = "flash_fixlen"
+ # print(f"Attention mode: {mode}, split_attn: {split_attn}")
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
+
+ # trim the sequence length to the actual length instead of attn_mask
+ if split_attn:
+ trimmed_len = q.shape[1] - total_len
+ q = [q[i : i + 1, : total_len[i]] for i in range(len(q))]
+ k = [k[i : i + 1, : total_len[i]] for i in range(len(k))]
+ v = [v[i : i + 1, : total_len[i]] for i in range(len(v))]
+ q = [pre_attn_layout(q_i) for q_i in q]
+ k = [pre_attn_layout(k_i) for k_i in k]
+ v = [pre_attn_layout(v_i) for v_i in v]
+ # print(
+ # f"Trimming the sequence length to {total_len},trimmed_len: {trimmed_len}, q.shape: {[q_i.shape for q_i in q]}, mode: {mode}"
+ # )
+ else:
+ q = pre_attn_layout(q)
+ k = pre_attn_layout(k)
+ v = pre_attn_layout(v)
+
+ if mode == "torch":
+ if split_attn:
+ x = []
+ for i in range(len(q)):
+ x_i = F.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate, is_causal=causal)
+ q[i], k[i], v[i] = None, None, None
+ x.append(x_i)
+ del q, k, v
+ else:
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(q.dtype)
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
+ del q, k, v
+ del attn_mask
+
+ elif mode == "xformers":
+ # B, M, H, K: M is the sequence length, H is the number of heads, K is the dimension of the heads -> it is same as input dimension
+ # currently only support batch_size = 1
+ assert split_attn, "Xformers only supports splitting"
+ x = []
+ for i in range(len(q)):
+ x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) # , causal=causal)
+ q[i], k[i], v[i] = None, None, None
+ x.append(x_i)
+ del q, k, v
+
+ elif mode == "flash":
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+ del q, k, v
+ # x with shape [(bxs), a, d]
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
+ elif mode == "flash_fixlen":
+ x = []
+ for i in range(len(q)):
+ # q: (batch_size, seqlen, nheads, headdim), k: (batch_size, seqlen, nheads_k, headdim), v: (batch_size, seqlen, nheads_k, headdim)
+ x_i = flash_attn_func(q[i], k[i], v[i], dropout_p=drop_rate, causal=causal)
+ q[i], k[i], v[i] = None, None, None
+ x.append(x_i)
+ del q, k, v
+ elif mode == "sageattn":
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+ del q, k, v
+ # x with shape [(bxs), a, d]
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
+ elif mode == "sageattn_fixlen":
+ x = []
+ for i in range(len(q)):
+ # HND seems to cause an error
+ x_i = sageattn(q[i], k[i], v[i]) # (batch_size, seq_len, head_num, head_dim)
+ q[i], k[i], v[i] = None, None, None
+ x.append(x_i)
+ del q, k, v
+ elif mode == "vanilla":
+ assert not split_attn, "Vanilla attention does not support trimming"
+ scale_factor = 1 / math.sqrt(q.size(-1))
+
+ b, a, s, _ = q.shape
+ s1 = k.size(2)
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+ if causal:
+ # Only applied to self attention
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+ attn_bias.to(q.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
+ attn += attn_bias
+ attn = attn.softmax(dim=-1)
+ attn = torch.dropout(attn, p=drop_rate, train=True)
+ x = attn @ v
+ else:
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
+
+ if split_attn:
+ x = [post_attn_layout(x_i) for x_i in x]
+ for i in range(len(x)):
+ x[i] = F.pad(x[i], (0, 0, 0, 0, 0, trimmed_len[i]))
+ x = torch.cat(x, dim=0)
+ else:
+ x = post_attn_layout(x)
+
+ b, s, a, d = x.shape
+ out = x.reshape(b, s, -1)
+ return out
+
+
+def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
+ attn1 = hybrid_seq_parallel_attn(
+ None,
+ q[:, :img_q_len, :, :],
+ k[:, :img_kv_len, :, :],
+ v[:, :img_kv_len, :, :],
+ dropout_p=0.0,
+ causal=False,
+ joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
+ joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
+ joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
+ joint_strategy="rear",
+ )
+ if flash_attn.__version__ >= "2.7.0":
+ attn2, *_ = _flash_attn_forward(
+ q[:, cu_seqlens_q[1] :],
+ k[:, cu_seqlens_kv[1] :],
+ v[:, cu_seqlens_kv[1] :],
+ dropout_p=0.0,
+ softmax_scale=q.shape[-1] ** (-0.5),
+ causal=False,
+ window_size_left=-1,
+ window_size_right=-1,
+ softcap=0.0,
+ alibi_slopes=None,
+ return_softmax=False,
+ )
+ else:
+ attn2, *_ = _flash_attn_forward(
+ q[:, cu_seqlens_q[1] :],
+ k[:, cu_seqlens_kv[1] :],
+ v[:, cu_seqlens_kv[1] :],
+ dropout_p=0.0,
+ softmax_scale=q.shape[-1] ** (-0.5),
+ causal=False,
+ window_size=(-1, -1),
+ softcap=0.0,
+ alibi_slopes=None,
+ return_softmax=False,
+ )
+ attn = torch.cat([attn1, attn2], dim=1)
+ b, s, a, d = attn.shape
+ attn = attn.reshape(b, s, -1)
+
+ return attn
diff --git a/hunyuan_model/autoencoder_kl_causal_3d.py b/hunyuan_model/autoencoder_kl_causal_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7e70737325a50e1ee1fbbee96b4a0aafbdcd241
--- /dev/null
+++ b/hunyuan_model/autoencoder_kl_causal_3d.py
@@ -0,0 +1,609 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+# try:
+# # This diffusers is modified and packed in the mirror.
+# from diffusers.loaders import FromOriginalVAEMixin
+# except ImportError:
+# # Use this to be compatible with the original diffusers.
+# from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
+
+
+@dataclass
+class DecoderOutput2(BaseOutput):
+ sample: torch.FloatTensor
+ posterior: Optional[DiagonalGaussianDistribution] = None
+
+
+class AutoencoderKLCausal3D(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ sample_size: int = 32,
+ sample_tsize: int = 64,
+ scaling_factor: float = 0.18215,
+ force_upcast: float = True,
+ spatial_compression_ratio: int = 8,
+ time_compression_ratio: int = 4,
+ mid_block_add_attention: bool = True,
+ ):
+ super().__init__()
+
+ self.time_compression_ratio = time_compression_ratio
+
+ self.encoder = EncoderCausal3D(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ time_compression_ratio=time_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ self.decoder = DecoderCausal3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ time_compression_ratio=time_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
+
+ self.use_slicing = False
+ self.use_spatial_tiling = False
+ self.use_temporal_tiling = False
+
+ # only relevant if vae tiling is enabled
+ self.tile_sample_min_tsize = sample_tsize
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
+
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
+ module.gradient_checkpointing = value
+
+ def enable_temporal_tiling(self, use_tiling: bool = True):
+ self.use_temporal_tiling = use_tiling
+
+ def disable_temporal_tiling(self):
+ self.enable_temporal_tiling(False)
+
+ def enable_spatial_tiling(self, use_tiling: bool = True):
+ self.use_spatial_tiling = use_tiling
+
+ def disable_spatial_tiling(self):
+ self.enable_spatial_tiling(False)
+
+ def enable_tiling(self, use_tiling: bool = True):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger videos.
+ """
+ self.enable_spatial_tiling(use_tiling)
+ self.enable_temporal_tiling(use_tiling)
+
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.disable_spatial_tiling()
+ self.disable_temporal_tiling()
+
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
+ # set chunk_size to CausalConv3d recursively
+ def set_chunk_size(module):
+ if hasattr(module, "chunk_size"):
+ module.chunk_size = chunk_size
+
+ self.apply(set_chunk_size)
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images/videos into latents.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images/videos.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
+
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
+
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
+
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self.encoder(x)
+
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
+
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
+
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
+
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images/videos.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
+ return b
+
+ def spatial_tiled_encode(
+ self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
+ ) -> AutoencoderKLOutput:
+ r"""Encode a batch of images/videos using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images/videos.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+ `tuple` is returned.
+ """
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split video into tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[-2], overlap_size):
+ row = []
+ for j in range(0, x.shape[-1], overlap_size):
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ moments = torch.cat(result_rows, dim=-2)
+ if return_moments:
+ return moments
+
+ posterior = DiagonalGaussianDistribution(moments)
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Decode a batch of images/videos using a tiled decoder.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[-2], overlap_size):
+ row = []
+ for j in range(0, z.shape[-1], overlap_size):
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=-2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+
+ B, C, T, H, W = x.shape
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
+ t_limit = self.tile_latent_min_tsize - blend_extent
+
+ # Split the video into tiles and encode them separately.
+ row = []
+ for i in range(0, T, overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
+ if self.use_spatial_tiling and (
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
+ ):
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
+ else:
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ if i > 0:
+ tile = tile[:, :, 1:, :, :]
+ row.append(tile)
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :t_limit, :, :])
+ else:
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
+
+ moments = torch.cat(result_row, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ # Split z into overlapping tiles and decode them separately.
+
+ B, C, T, H, W = z.shape
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
+ t_limit = self.tile_sample_min_tsize - blend_extent
+
+ row = []
+ for i in range(0, T, overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
+ if self.use_spatial_tiling and (
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
+ ):
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
+ else:
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ if i > 0:
+ decoded = decoded[:, :, 1:, :, :]
+ row.append(decoded)
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :t_limit, :, :])
+ else:
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
+
+ dec = torch.cat(result_row, dim=2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ return_posterior: bool = False,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ if return_posterior:
+ return (dec, posterior)
+ else:
+ return (dec,)
+ if return_posterior:
+ return DecoderOutput2(sample=dec, posterior=posterior)
+ else:
+ return DecoderOutput2(sample=dec)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
diff --git a/hunyuan_model/embed_layers.py b/hunyuan_model/embed_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e31ba9cc58d1aa05e0f17b919762f69bd693b5c0
--- /dev/null
+++ b/hunyuan_model/embed_layers.py
@@ -0,0 +1,132 @@
+import collections
+import math
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from .helpers import to_2tuple
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding
+
+ Image to Patch Embedding using Conv2d
+
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
+
+ Based on the impl in https://github.com/google-research/vision_transformer
+
+ Hacked together by / Copyright 2020 Ross Wightman
+
+ Remove the _assert function in forward function to be compatible with multi-resolution images.
+ """
+
+ def __init__(
+ self,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ norm_layer=None,
+ flatten=True,
+ bias=True,
+ dtype=None,
+ device=None,
+ ):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+ self.flatten = flatten
+
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
+ if bias:
+ nn.init.zeros_(self.proj.bias)
+
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x
+
+
+class TextProjection(nn.Module):
+ """
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
+ self.act_1 = act_layer()
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
+
+ def forward(self, caption):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ Args:
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ dim (int): the dimension of the output.
+ max_period (int): controls the minimum frequency of the embeddings.
+
+ Returns:
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
+
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ """
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ act_layer,
+ frequency_embedding_size=256,
+ max_period=10000,
+ out_size=None,
+ dtype=None,
+ device=None,
+ ):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.frequency_embedding_size = frequency_embedding_size
+ self.max_period = max_period
+ if out_size is None:
+ out_size = hidden_size
+
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
+ act_layer(),
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
+ )
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
+
+ def forward(self, t):
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
diff --git a/hunyuan_model/fp8_optimization.py b/hunyuan_model/fp8_optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..90b978baca8cd9a3401b8b66a6575c0c3c29c991
--- /dev/null
+++ b/hunyuan_model/fp8_optimization.py
@@ -0,0 +1,39 @@
+#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
+#further borrowed from HunyuanVideoWrapper for Musubi Tuner
+import torch
+import torch.nn as nn
+
+def fp8_linear_forward(cls, original_dtype, input):
+ weight_dtype = cls.weight.dtype
+ if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+ if len(input.shape) == 3:
+ target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn
+ inn = input.reshape(-1, input.shape[2]).to(target_dtype)
+ w = cls.weight.t()
+
+ scale = torch.ones((1), device=input.device, dtype=torch.float32)
+ bias = cls.bias.to(original_dtype) if cls.bias is not None else None
+
+ if bias is not None:
+ o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale)
+ else:
+ o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale)
+
+ if isinstance(o, tuple):
+ o = o[0]
+
+ return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
+ else:
+ return cls.original_forward(input.to(original_dtype))
+ else:
+ return cls.original_forward(input)
+
+def convert_fp8_linear(module, original_dtype, params_to_keep={}):
+ setattr(module, "fp8_matmul_enabled", True)
+
+ for name, module in module.named_modules():
+ if not any(keyword in name for keyword in params_to_keep):
+ if isinstance(module, nn.Linear):
+ original_forward = module.forward
+ setattr(module, "original_forward", original_forward)
+ setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
diff --git a/hunyuan_model/helpers.py b/hunyuan_model/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..72ab8cb1feba4ce7782f1ea841fd42c71be7b0d1
--- /dev/null
+++ b/hunyuan_model/helpers.py
@@ -0,0 +1,40 @@
+import collections.abc
+
+from itertools import repeat
+
+
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ x = tuple(x)
+ if len(x) == 1:
+ x = tuple(repeat(x[0], n))
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+
+
+def as_tuple(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ if x is None or isinstance(x, (int, float, str)):
+ return (x,)
+ else:
+ raise ValueError(f"Unknown type {type(x)}")
+
+
+def as_list_of_2tuple(x):
+ x = as_tuple(x)
+ if len(x) == 1:
+ x = (x[0], x[0])
+ assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
+ lst = []
+ for i in range(0, len(x), 2):
+ lst.append((x[i], x[i + 1]))
+ return lst
diff --git a/hunyuan_model/mlp_layers.py b/hunyuan_model/mlp_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcc9547a6a0ba80ab19a472a9ea7aef525f46613
--- /dev/null
+++ b/hunyuan_model/mlp_layers.py
@@ -0,0 +1,118 @@
+# Modified from timm library:
+# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from .modulate_layers import modulate
+from .helpers import to_2tuple
+
+
+class MLP(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ device=None,
+ dtype=None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ out_features = out_features or in_channels
+ hidden_channels = hidden_channels or in_channels
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(
+ in_channels, hidden_channels, bias=bias[0], **factory_kwargs
+ )
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = (
+ norm_layer(hidden_channels, **factory_kwargs)
+ if norm_layer is not None
+ else nn.Identity()
+ )
+ self.fc2 = linear_layer(
+ hidden_channels, out_features, bias=bias[1], **factory_kwargs
+ )
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.norm(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+#
+class MLPEmbedder(nn.Module):
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class FinalLayer(nn.Module):
+ """The final layer of DiT."""
+
+ def __init__(
+ self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ # Just use LayerNorm for the final layer
+ self.norm_final = nn.LayerNorm(
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
+ )
+ if isinstance(patch_size, int):
+ self.linear = nn.Linear(
+ hidden_size,
+ patch_size * patch_size * out_channels,
+ bias=True,
+ **factory_kwargs
+ )
+ else:
+ self.linear = nn.Linear(
+ hidden_size,
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
+ bias=True,
+ )
+ nn.init.zeros_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ # Here we don't distinguish between the modulate types. Just use the simple one.
+ self.adaLN_modulation = nn.Sequential(
+ act_layer(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
+ )
+ # Zero-initialize the modulation
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
+ x = self.linear(x)
+ return x
diff --git a/hunyuan_model/models.py b/hunyuan_model/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..20921f4beb93f35d241020a4f14474e29bcf485a
--- /dev/null
+++ b/hunyuan_model/models.py
@@ -0,0 +1,1044 @@
+import os
+from typing import Any, List, Tuple, Optional, Union, Dict
+import accelerate
+from einops import rearrange
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from .activation_layers import get_activation_layer
+from .norm_layers import get_norm_layer
+from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
+from .attention import attention, parallel_attention, get_cu_seqlens
+from .posemb_layers import apply_rotary_emb
+from .mlp_layers import MLP, MLPEmbedder, FinalLayer
+from .modulate_layers import ModulateDiT, modulate, apply_gate
+from .token_refiner import SingleTokenRefiner
+from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
+from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
+
+from utils.safetensors_utils import MemoryEfficientSafeOpen
+
+
+class MMDoubleStreamBlock(nn.Module):
+ """
+ A multimodal dit block with seperate modulation for
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
+ (Flux.1): https://github.com/black-forest-labs/flux
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ heads_num: int,
+ mlp_width_ratio: float,
+ mlp_act_type: str = "gelu_tanh",
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ qkv_bias: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ attn_mode: str = "flash",
+ split_attn: bool = False,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.attn_mode = attn_mode
+ self.split_attn = split_attn
+
+ self.deterministic = False
+ self.heads_num = heads_num
+ head_dim = hidden_size // heads_num
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+
+ self.img_mod = ModulateDiT(
+ hidden_size,
+ factor=6,
+ act_layer=get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
+ qk_norm_layer = get_norm_layer(qk_norm_type)
+ self.img_attn_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.img_attn_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.img_mlp = MLP(
+ hidden_size,
+ mlp_hidden_dim,
+ act_layer=get_activation_layer(mlp_act_type),
+ bias=True,
+ **factory_kwargs,
+ )
+
+ self.txt_mod = ModulateDiT(
+ hidden_size,
+ factor=6,
+ act_layer=get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
+ self.txt_attn_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.txt_attn_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.txt_mlp = MLP(
+ hidden_size,
+ mlp_hidden_dim,
+ act_layer=get_activation_layer(mlp_act_type),
+ bias=True,
+ **factory_kwargs,
+ )
+ self.hybrid_seq_parallel_attn = None
+
+ self.gradient_checkpointing = False
+
+ def enable_deterministic(self):
+ self.deterministic = True
+
+ def disable_deterministic(self):
+ self.deterministic = False
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing = False
+
+ def _forward(
+ self,
+ img: torch.Tensor,
+ txt: torch.Tensor,
+ vec: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ total_len: Optional[torch.Tensor] = None,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_kv: Optional[int] = None,
+ freqs_cis: tuple = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
+ 6, dim=-1
+ )
+ (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
+ 6, dim=-1
+ )
+
+ # Prepare image for attention.
+ img_modulated = self.img_norm1(img)
+ img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
+ img_qkv = self.img_attn_qkv(img_modulated)
+ img_modulated = None
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ img_qkv = None
+ # Apply QK-Norm if needed
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
+
+ # Apply RoPE if needed.
+ if freqs_cis is not None:
+ img_q_shape = img_q.shape
+ img_k_shape = img_k.shape
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
+ assert (
+ img_q.shape == img_q_shape and img_k.shape == img_k_shape
+ ), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}"
+ # img_q, img_k = img_qq, img_kk
+
+ # Prepare txt for attention.
+ txt_modulated = self.txt_norm1(txt)
+ txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
+ txt_modulated = None
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ txt_qkv = None
+ # Apply QK-Norm if needed.
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
+
+ # Run actual attention.
+ img_q_len = img_q.shape[1]
+ img_kv_len = img_k.shape[1]
+ batch_size = img_k.shape[0]
+ q = torch.cat((img_q, txt_q), dim=1)
+ img_q = txt_q = None
+ k = torch.cat((img_k, txt_k), dim=1)
+ img_k = txt_k = None
+ v = torch.cat((img_v, txt_v), dim=1)
+ img_v = txt_v = None
+
+ assert (
+ cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
+ ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
+
+ # attention computation start
+ if not self.hybrid_seq_parallel_attn:
+ l = [q, k, v]
+ q = k = v = None
+ attn = attention(
+ l,
+ mode=self.attn_mode,
+ attn_mask=attn_mask,
+ total_len=total_len,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_kv=max_seqlen_kv,
+ batch_size=batch_size,
+ )
+ else:
+ attn = parallel_attention(
+ self.hybrid_seq_parallel_attn,
+ q,
+ k,
+ v,
+ img_q_len=img_q_len,
+ img_kv_len=img_kv_len,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ )
+
+ # attention computation end
+
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
+ attn = None
+
+ # Calculate the img bloks.
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
+ img_attn = None
+ img = img + apply_gate(
+ self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
+ gate=img_mod2_gate,
+ )
+
+ # Calculate the txt bloks.
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
+ txt_attn = None
+ txt = txt + apply_gate(
+ self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
+ gate=txt_mod2_gate,
+ )
+
+ return img, txt
+
+ # def forward(
+ # self,
+ # img: torch.Tensor,
+ # txt: torch.Tensor,
+ # vec: torch.Tensor,
+ # attn_mask: Optional[torch.Tensor] = None,
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
+ # max_seqlen_q: Optional[int] = None,
+ # max_seqlen_kv: Optional[int] = None,
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
+ # ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(self, *args, **kwargs):
+ if self.training and self.gradient_checkpointing:
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
+ else:
+ return self._forward(*args, **kwargs)
+
+
+class MMSingleStreamBlock(nn.Module):
+ """
+ A DiT block with parallel linear layers as described in
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
+ (Flux.1): https://github.com/black-forest-labs/flux
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ heads_num: int,
+ mlp_width_ratio: float = 4.0,
+ mlp_act_type: str = "gelu_tanh",
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ qk_scale: float = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ attn_mode: str = "flash",
+ split_attn: bool = False,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.attn_mode = attn_mode
+ self.split_attn = split_attn
+
+ self.deterministic = False
+ self.hidden_size = hidden_size
+ self.heads_num = heads_num
+ head_dim = hidden_size // heads_num
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+ self.mlp_hidden_dim = mlp_hidden_dim
+ self.scale = qk_scale or head_dim**-0.5
+
+ # qkv and mlp_in
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
+ # proj and mlp_out
+ self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
+
+ qk_norm_layer = get_norm_layer(qk_norm_type)
+ self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.mlp_act = get_activation_layer(mlp_act_type)()
+ self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
+ self.hybrid_seq_parallel_attn = None
+
+ self.gradient_checkpointing = False
+
+ def enable_deterministic(self):
+ self.deterministic = True
+
+ def disable_deterministic(self):
+ self.deterministic = False
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing = False
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ vec: torch.Tensor,
+ txt_len: int,
+ attn_mask: Optional[torch.Tensor] = None,
+ total_len: Optional[torch.Tensor] = None,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_kv: Optional[int] = None,
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
+ ) -> torch.Tensor:
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+ x_mod = None
+ # mlp = mlp.to("cpu", non_blocking=True)
+ # clean_memory_on_device(x.device)
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ qkv = None
+
+ # Apply QK-Norm if needed.
+ q = self.q_norm(q).to(v)
+ k = self.k_norm(k).to(v)
+
+ # Apply RoPE if needed.
+ if freqs_cis is not None:
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
+ q = k = None
+ img_q_shape = img_q.shape
+ img_k_shape = img_k.shape
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
+ assert (
+ img_q.shape == img_q_shape and img_k_shape == img_k.shape
+ ), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}"
+ # img_q, img_k = img_qq, img_kk
+ # del img_qq, img_kk
+ q = torch.cat((img_q, txt_q), dim=1)
+ k = torch.cat((img_k, txt_k), dim=1)
+ del img_q, txt_q, img_k, txt_k
+
+ # Compute attention.
+ assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
+
+ # attention computation start
+ if not self.hybrid_seq_parallel_attn:
+ l = [q, k, v]
+ q = k = v = None
+ attn = attention(
+ l,
+ mode=self.attn_mode,
+ attn_mask=attn_mask,
+ total_len=total_len,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_kv=max_seqlen_kv,
+ batch_size=x.shape[0],
+ )
+ else:
+ attn = parallel_attention(
+ self.hybrid_seq_parallel_attn,
+ q,
+ k,
+ v,
+ img_q_len=img_q.shape[1],
+ img_kv_len=img_k.shape[1],
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ )
+ # attention computation end
+
+ # Compute activation in mlp stream, cat again and run second linear layer.
+ # mlp = mlp.to(x.device)
+ mlp = self.mlp_act(mlp)
+ attn_mlp = torch.cat((attn, mlp), 2)
+ attn = None
+ mlp = None
+ output = self.linear2(attn_mlp)
+ attn_mlp = None
+ return x + apply_gate(output, gate=mod_gate)
+
+ # def forward(
+ # self,
+ # x: torch.Tensor,
+ # vec: torch.Tensor,
+ # txt_len: int,
+ # attn_mask: Optional[torch.Tensor] = None,
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
+ # max_seqlen_q: Optional[int] = None,
+ # max_seqlen_kv: Optional[int] = None,
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
+ # ) -> torch.Tensor:
+ def forward(self, *args, **kwargs):
+ if self.training and self.gradient_checkpointing:
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
+ else:
+ return self._forward(*args, **kwargs)
+
+
+class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
+ """
+ HunyuanVideo Transformer backbone
+
+ Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
+
+ Reference:
+ [1] Flux.1: https://github.com/black-forest-labs/flux
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
+
+ Parameters
+ ----------
+ args: argparse.Namespace
+ The arguments parsed by argparse.
+ patch_size: list
+ The size of the patch.
+ in_channels: int
+ The number of input channels.
+ out_channels: int
+ The number of output channels.
+ hidden_size: int
+ The hidden size of the transformer backbone.
+ heads_num: int
+ The number of attention heads.
+ mlp_width_ratio: float
+ The ratio of the hidden size of the MLP in the transformer block.
+ mlp_act_type: str
+ The activation function of the MLP in the transformer block.
+ depth_double_blocks: int
+ The number of transformer blocks in the double blocks.
+ depth_single_blocks: int
+ The number of transformer blocks in the single blocks.
+ rope_dim_list: list
+ The dimension of the rotary embedding for t, h, w.
+ qkv_bias: bool
+ Whether to use bias in the qkv linear layer.
+ qk_norm: bool
+ Whether to use qk norm.
+ qk_norm_type: str
+ The type of qk norm.
+ guidance_embed: bool
+ Whether to use guidance embedding for distillation.
+ text_projection: str
+ The type of the text projection, default is single_refiner.
+ use_attention_mask: bool
+ Whether to use attention mask for text encoder.
+ dtype: torch.dtype
+ The dtype of the model.
+ device: torch.device
+ The device of the model.
+ attn_mode: str
+ The mode of the attention, default is flash.
+ split_attn: bool
+ Whether to use split attention (make attention as batch size 1).
+ """
+
+ # @register_to_config
+ def __init__(
+ self,
+ text_states_dim: int,
+ text_states_dim_2: int,
+ patch_size: list = [1, 2, 2],
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
+ out_channels: int = None,
+ hidden_size: int = 3072,
+ heads_num: int = 24,
+ mlp_width_ratio: float = 4.0,
+ mlp_act_type: str = "gelu_tanh",
+ mm_double_blocks_depth: int = 20,
+ mm_single_blocks_depth: int = 40,
+ rope_dim_list: List[int] = [16, 56, 56],
+ qkv_bias: bool = True,
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ guidance_embed: bool = False, # For modulation.
+ text_projection: str = "single_refiner",
+ use_attention_mask: bool = True,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ attn_mode: str = "flash",
+ split_attn: bool = False,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.unpatchify_channels = self.out_channels
+ self.guidance_embed = guidance_embed
+ self.rope_dim_list = rope_dim_list
+
+ # Text projection. Default to linear projection.
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
+ self.use_attention_mask = use_attention_mask
+ self.text_projection = text_projection
+
+ self.text_states_dim = text_states_dim
+ self.text_states_dim_2 = text_states_dim_2
+
+ if hidden_size % heads_num != 0:
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
+ pe_dim = hidden_size // heads_num
+ if sum(rope_dim_list) != pe_dim:
+ raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
+ self.hidden_size = hidden_size
+ self.heads_num = heads_num
+
+ self.attn_mode = attn_mode
+ self.split_attn = split_attn
+ print(f"Using {self.attn_mode} attention mode, split_attn: {self.split_attn}")
+
+ # image projection
+ self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
+
+ # text projection
+ if self.text_projection == "linear":
+ self.txt_in = TextProjection(
+ self.text_states_dim,
+ self.hidden_size,
+ get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+ elif self.text_projection == "single_refiner":
+ self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
+ else:
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
+
+ # time modulation
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
+
+ # text modulation
+ self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
+
+ # guidance modulation
+ self.guidance_in = (
+ TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
+ )
+
+ # double blocks
+ self.double_blocks = nn.ModuleList(
+ [
+ MMDoubleStreamBlock(
+ self.hidden_size,
+ self.heads_num,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_act_type=mlp_act_type,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ qkv_bias=qkv_bias,
+ attn_mode=attn_mode,
+ split_attn=split_attn,
+ **factory_kwargs,
+ )
+ for _ in range(mm_double_blocks_depth)
+ ]
+ )
+
+ # single blocks
+ self.single_blocks = nn.ModuleList(
+ [
+ MMSingleStreamBlock(
+ self.hidden_size,
+ self.heads_num,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_act_type=mlp_act_type,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ attn_mode=attn_mode,
+ split_attn=split_attn,
+ **factory_kwargs,
+ )
+ for _ in range(mm_single_blocks_depth)
+ ]
+ )
+
+ self.final_layer = FinalLayer(
+ self.hidden_size,
+ self.patch_size,
+ self.out_channels,
+ get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+
+ self.gradient_checkpointing = False
+ self.blocks_to_swap = None
+ self.offloader_double = None
+ self.offloader_single = None
+ self._enable_img_in_txt_in_offloading = False
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ self.txt_in.enable_gradient_checkpointing()
+
+ for block in self.double_blocks + self.single_blocks:
+ block.enable_gradient_checkpointing()
+
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing = False
+
+ self.txt_in.disable_gradient_checkpointing()
+
+ for block in self.double_blocks + self.single_blocks:
+ block.disable_gradient_checkpointing()
+
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing disabled.")
+
+ def enable_img_in_txt_in_offloading(self):
+ self._enable_img_in_txt_in_offloading = True
+
+ def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
+ self.blocks_to_swap = num_blocks
+ self.num_double_blocks = len(self.double_blocks)
+ self.num_single_blocks = len(self.single_blocks)
+ double_blocks_to_swap = num_blocks // 2
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
+
+ assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
+ f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
+ )
+
+ self.offloader_double = ModelOffloader(
+ "double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
+ )
+ self.offloader_single = ModelOffloader(
+ "single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
+ )
+ print(
+ f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
+ )
+
+ def switch_block_swap_for_inference(self):
+ if self.blocks_to_swap:
+ self.offloader_double.set_forward_only(True)
+ self.offloader_single.set_forward_only(True)
+ self.prepare_block_swap_before_forward()
+ print(f"HYVideoDiffusionTransformer: Block swap set to forward only.")
+
+ def switch_block_swap_for_training(self):
+ if self.blocks_to_swap:
+ self.offloader_double.set_forward_only(False)
+ self.offloader_single.set_forward_only(False)
+ self.prepare_block_swap_before_forward()
+ print(f"HYVideoDiffusionTransformer: Block swap set to forward and backward.")
+
+ def move_to_device_except_swap_blocks(self, device: torch.device):
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
+ if self.blocks_to_swap:
+ save_double_blocks = self.double_blocks
+ save_single_blocks = self.single_blocks
+ self.double_blocks = None
+ self.single_blocks = None
+
+ self.to(device)
+
+ if self.blocks_to_swap:
+ self.double_blocks = save_double_blocks
+ self.single_blocks = save_single_blocks
+
+ def prepare_block_swap_before_forward(self):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
+
+ def enable_deterministic(self):
+ for block in self.double_blocks:
+ block.enable_deterministic()
+ for block in self.single_blocks:
+ block.enable_deterministic()
+
+ def disable_deterministic(self):
+ for block in self.double_blocks:
+ block.disable_deterministic()
+ for block in self.single_blocks:
+ block.disable_deterministic()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor, # Should be in range(0, 1000).
+ text_states: torch.Tensor = None,
+ text_mask: torch.Tensor = None, # Now we don't use it.
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
+ freqs_cos: Optional[torch.Tensor] = None,
+ freqs_sin: Optional[torch.Tensor] = None,
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ out = {}
+ img = x
+ txt = text_states
+ _, _, ot, oh, ow = x.shape
+ tt, th, tw = (
+ ot // self.patch_size[0],
+ oh // self.patch_size[1],
+ ow // self.patch_size[2],
+ )
+
+ # Prepare modulation vectors.
+ vec = self.time_in(t)
+
+ # text modulation
+ vec = vec + self.vector_in(text_states_2)
+
+ # guidance modulation
+ if self.guidance_embed:
+ if guidance is None:
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
+
+ # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
+ vec = vec + self.guidance_in(guidance)
+
+ # Embed image and text.
+ if self._enable_img_in_txt_in_offloading:
+ self.img_in.to(x.device, non_blocking=True)
+ self.txt_in.to(x.device, non_blocking=True)
+ synchronize_device(x.device)
+
+ img = self.img_in(img)
+ if self.text_projection == "linear":
+ txt = self.txt_in(txt)
+ elif self.text_projection == "single_refiner":
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
+ else:
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
+
+ if self._enable_img_in_txt_in_offloading:
+ self.img_in.to(torch.device("cpu"), non_blocking=True)
+ self.txt_in.to(torch.device("cpu"), non_blocking=True)
+ synchronize_device(x.device)
+ clean_memory_on_device(x.device)
+
+ txt_seq_len = txt.shape[1]
+ img_seq_len = img.shape[1]
+
+ # Compute cu_squlens and max_seqlen for flash attention
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
+ cu_seqlens_kv = cu_seqlens_q
+ max_seqlen_q = img_seq_len + txt_seq_len
+ max_seqlen_kv = max_seqlen_q
+
+ attn_mask = total_len = None
+ if self.split_attn or self.attn_mode == "torch":
+ # calculate text length and total length
+ text_len = text_mask.sum(dim=1) # (bs, )
+ total_len = img_seq_len + text_len # (bs, )
+ if self.attn_mode == "torch" and not self.split_attn:
+ # initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
+ bs = img.shape[0]
+ attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
+
+ # set attention mask with total_len
+ for i in range(bs):
+ attn_mask[i, :, : total_len[i], : total_len[i]] = True
+ total_len = None # means we don't use split_attn
+
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
+ # --------------------- Pass through DiT blocks ------------------------
+ for block_idx, block in enumerate(self.double_blocks):
+ double_block_args = [
+ img,
+ txt,
+ vec,
+ attn_mask,
+ total_len,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ freqs_cis,
+ ]
+
+ if self.blocks_to_swap:
+ self.offloader_double.wait_for_block(block_idx)
+
+ img, txt = block(*double_block_args)
+
+ if self.blocks_to_swap:
+ self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)
+
+ # Merge txt and img to pass through single stream blocks.
+ x = torch.cat((img, txt), 1)
+ if self.blocks_to_swap:
+ # delete img, txt to reduce memory usage
+ del img, txt
+ clean_memory_on_device(x.device)
+
+ if len(self.single_blocks) > 0:
+ for block_idx, block in enumerate(self.single_blocks):
+ single_block_args = [
+ x,
+ vec,
+ txt_seq_len,
+ attn_mask,
+ total_len,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ freqs_cis,
+ ]
+ if self.blocks_to_swap:
+ self.offloader_single.wait_for_block(block_idx)
+
+ x = block(*single_block_args)
+
+ if self.blocks_to_swap:
+ self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)
+
+ img = x[:, :img_seq_len, ...]
+ x = None
+
+ # ---------------------------- Final layer ------------------------------
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+
+ img = self.unpatchify(img, tt, th, tw)
+ if return_dict:
+ out["x"] = img
+ return out
+ return img
+
+ def unpatchify(self, x, t, h, w):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.unpatchify_channels
+ pt, ph, pw = self.patch_size
+ assert t * h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
+ x = torch.einsum("nthwcopq->nctohpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
+
+ return imgs
+
+ def params_count(self):
+ counts = {
+ "double": sum(
+ [
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
+ + sum(p.numel() for p in block.img_mlp.parameters())
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
+ + sum(p.numel() for p in block.txt_mlp.parameters())
+ for block in self.double_blocks
+ ]
+ ),
+ "single": sum(
+ [
+ sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
+ for block in self.single_blocks
+ ]
+ ),
+ "total": sum(p.numel() for p in self.parameters()),
+ }
+ counts["attn+mlp"] = counts["double"] + counts["single"]
+ return counts
+
+
+#################################################################################
+# HunyuanVideo Configs #
+#################################################################################
+
+HUNYUAN_VIDEO_CONFIG = {
+ "HYVideo-T/2": {
+ "mm_double_blocks_depth": 20,
+ "mm_single_blocks_depth": 40,
+ "rope_dim_list": [16, 56, 56],
+ "hidden_size": 3072,
+ "heads_num": 24,
+ "mlp_width_ratio": 4,
+ },
+ "HYVideo-T/2-cfgdistill": {
+ "mm_double_blocks_depth": 20,
+ "mm_single_blocks_depth": 40,
+ "rope_dim_list": [16, 56, 56],
+ "hidden_size": 3072,
+ "heads_num": 24,
+ "mlp_width_ratio": 4,
+ "guidance_embed": True,
+ },
+}
+
+
+def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
+ """load hunyuan video model
+
+ NOTE: Only support HYVideo-T/2-cfgdistill now.
+
+ Args:
+ text_state_dim (int): text state dimension
+ text_state_dim_2 (int): text state dimension 2
+ in_channels (int): input channels number
+ out_channels (int): output channels number
+ factor_kwargs (dict): factor kwargs
+
+ Returns:
+ model (nn.Module): The hunyuan video model
+ """
+ # if args.model in HUNYUAN_VIDEO_CONFIG.keys():
+ model = HYVideoDiffusionTransformer(
+ text_states_dim=text_states_dim,
+ text_states_dim_2=text_states_dim_2,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
+ **factor_kwargs,
+ )
+ return model
+ # else:
+ # raise NotImplementedError()
+
+
+def load_state_dict(model, model_path):
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
+
+ load_key = "module"
+ if load_key in state_dict:
+ state_dict = state_dict[load_key]
+ else:
+ raise KeyError(
+ f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
+ f"are: {list(state_dict.keys())}."
+ )
+ model.load_state_dict(state_dict, strict=True, assign=True)
+ return model
+
+
+def load_transformer(dit_path, attn_mode, split_attn, device, dtype, in_channels=16) -> HYVideoDiffusionTransformer:
+ # =========================== Build main model ===========================
+ factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode, "split_attn": split_attn}
+ latent_channels = 16
+ out_channels = latent_channels
+
+ with accelerate.init_empty_weights():
+ transformer = load_dit_model(
+ text_states_dim=4096,
+ text_states_dim_2=768,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ factor_kwargs=factor_kwargs,
+ )
+
+ if os.path.splitext(dit_path)[-1] == ".safetensors":
+ # loading safetensors: may be already fp8
+ with MemoryEfficientSafeOpen(dit_path) as f:
+ state_dict = {}
+ for k in f.keys():
+ tensor = f.get_tensor(k)
+ tensor = tensor.to(device=device, dtype=dtype)
+ # TODO support comfy model
+ # if k.startswith("model.model."):
+ # k = convert_comfy_model_key(k)
+ state_dict[k] = tensor
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
+ else:
+ transformer = load_state_dict(transformer, dit_path)
+
+ return transformer
+
+
+def get_rotary_pos_embed_by_shape(model, latents_size):
+ target_ndim = 3
+ ndim = 5 - 2
+
+ if isinstance(model.patch_size, int):
+ assert all(s % model.patch_size == 0 for s in latents_size), (
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
+ f"but got {latents_size}."
+ )
+ rope_sizes = [s // model.patch_size for s in latents_size]
+ elif isinstance(model.patch_size, list):
+ assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
+ f"but got {latents_size}."
+ )
+ rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
+
+ if len(rope_sizes) != target_ndim:
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
+ head_dim = model.hidden_size // model.heads_num
+ rope_dim_list = model.rope_dim_list
+ if rope_dim_list is None:
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
+
+ rope_theta = 256
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+ rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
+ )
+ return freqs_cos, freqs_sin
+
+
+def get_rotary_pos_embed(vae_name, model, video_length, height, width):
+ # 884
+ if "884" in vae_name:
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
+ elif "888" in vae_name:
+ latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
+ else:
+ latents_size = [video_length, height // 8, width // 8]
+
+ return get_rotary_pos_embed_by_shape(model, latents_size)
diff --git a/hunyuan_model/modulate_layers.py b/hunyuan_model/modulate_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a57c6d2fdc0fca9bf44aeee6996bf1d8a05901
--- /dev/null
+++ b/hunyuan_model/modulate_layers.py
@@ -0,0 +1,76 @@
+from typing import Callable
+
+import torch
+import torch.nn as nn
+
+
+class ModulateDiT(nn.Module):
+ """Modulation layer for DiT."""
+ def __init__(
+ self,
+ hidden_size: int,
+ factor: int,
+ act_layer: Callable,
+ dtype=None,
+ device=None,
+ ):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.act = act_layer()
+ self.linear = nn.Linear(
+ hidden_size, factor * hidden_size, bias=True, **factory_kwargs
+ )
+ # Zero-initialize the modulation
+ nn.init.zeros_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.linear(self.act(x))
+
+
+def modulate(x, shift=None, scale=None):
+ """modulate by shift and scale
+
+ Args:
+ x (torch.Tensor): input tensor.
+ shift (torch.Tensor, optional): shift tensor. Defaults to None.
+ scale (torch.Tensor, optional): scale tensor. Defaults to None.
+
+ Returns:
+ torch.Tensor: the output tensor after modulate.
+ """
+ if scale is None and shift is None:
+ return x
+ elif shift is None:
+ return x * (1 + scale.unsqueeze(1))
+ elif scale is None:
+ return x + shift.unsqueeze(1)
+ else:
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+def apply_gate(x, gate=None, tanh=False):
+ """AI is creating summary for apply_gate
+
+ Args:
+ x (torch.Tensor): input tensor.
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
+
+ Returns:
+ torch.Tensor: the output tensor after apply gate.
+ """
+ if gate is None:
+ return x
+ if tanh:
+ return x * gate.unsqueeze(1).tanh()
+ else:
+ return x * gate.unsqueeze(1)
+
+
+def ckpt_wrapper(module):
+ def ckpt_forward(*inputs):
+ outputs = module(*inputs)
+ return outputs
+
+ return ckpt_forward
diff --git a/hunyuan_model/norm_layers.py b/hunyuan_model/norm_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53d167436b6971d3aabf5cfe51c0b9d6dfc022f
--- /dev/null
+++ b/hunyuan_model/norm_layers.py
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+
+
+class RMSNorm(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ elementwise_affine=True,
+ eps: float = 1e-6,
+ device=None,
+ dtype=None,
+ ):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.eps = eps
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ if hasattr(self, "weight"):
+ # output = output * self.weight
+ # support fp8
+ output = output * self.weight.to(output.dtype)
+ return output
+
+
+def get_norm_layer(norm_layer):
+ """
+ Get the normalization layer.
+
+ Args:
+ norm_layer (str): The type of normalization layer.
+
+ Returns:
+ norm_layer (nn.Module): The normalization layer.
+ """
+ if norm_layer == "layer":
+ return nn.LayerNorm
+ elif norm_layer == "rms":
+ return RMSNorm
+ else:
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
diff --git a/hunyuan_model/pipeline_hunyuan_video.py b/hunyuan_model/pipeline_hunyuan_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1293161e13a47ae7dcedfef2c55e3baefc655f4
--- /dev/null
+++ b/hunyuan_model/pipeline_hunyuan_video.py
@@ -0,0 +1,1100 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+import torch
+import torch.distributed as dist
+import numpy as np
+from dataclasses import dataclass
+from packaging import version
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import BaseOutput
+
+from ...constants import PRECISION_TO_TYPE
+from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+from ...text_encoder import TextEncoder
+from ...modules import HYVideoDiffusionTransformer
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """"""
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
+ )
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = (
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ )
+ return noise_cfg
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError(
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+ )
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+@dataclass
+class HunyuanVideoPipelineOutput(BaseOutput):
+ videos: Union[torch.Tensor, np.ndarray]
+
+
+class HunyuanVideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using HunyuanVideo.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`TextEncoder`]):
+ Frozen text-encoder.
+ text_encoder_2 ([`TextEncoder`]):
+ Frozen text-encoder_2.
+ transformer ([`HYVideoDiffusionTransformer`]):
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = ["text_encoder_2"]
+ _exclude_from_cpu_offload = ["transformer"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: TextEncoder,
+ transformer: HYVideoDiffusionTransformer,
+ scheduler: KarrasDiffusionSchedulers,
+ text_encoder_2: Optional[TextEncoder] = None,
+ progress_bar_config: Dict[str, Any] = None,
+ args=None,
+ ):
+ super().__init__()
+
+ # ==========================================================================================
+ if progress_bar_config is None:
+ progress_bar_config = {}
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ self._progress_bar_config.update(progress_bar_config)
+
+ self.args = args
+ # ==========================================================================================
+
+ if (
+ hasattr(scheduler.config, "steps_offset")
+ and scheduler.config.steps_offset != 1
+ ):
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate(
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if (
+ hasattr(scheduler.config, "clip_sample")
+ and scheduler.config.clip_sample is True
+ ):
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate(
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_attention_mask: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ text_encoder: Optional[TextEncoder] = None,
+ data_type: Optional[str] = "image",
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_videos_per_prompt (`int`):
+ number of videos that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ attention_mask (`torch.Tensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_attention_mask (`torch.Tensor`, *optional*):
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ text_encoder (TextEncoder, *optional*):
+ data_type (`str`, *optional*):
+ """
+ if text_encoder is None:
+ text_encoder = self.text_encoder
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
+ else:
+ scale_lora_layers(text_encoder.model, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
+
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
+
+ if clip_skip is None:
+ prompt_outputs = text_encoder.encode(
+ text_inputs, data_type=data_type, device=device
+ )
+ prompt_embeds = prompt_outputs.hidden_state
+ else:
+ prompt_outputs = text_encoder.encode(
+ text_inputs,
+ output_hidden_states=True,
+ data_type=data_type,
+ device=device,
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(
+ prompt_embeds
+ )
+
+ attention_mask = prompt_outputs.attention_mask
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(device)
+ bs_embed, seq_len = attention_mask.shape
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
+ attention_mask = attention_mask.view(
+ bs_embed * num_videos_per_prompt, seq_len
+ )
+
+ if text_encoder is not None:
+ prompt_embeds_dtype = text_encoder.dtype
+ elif self.transformer is not None:
+ prompt_embeds_dtype = self.transformer.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ if prompt_embeds.ndim == 2:
+ bs_embed, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
+ else:
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(
+ bs_embed * num_videos_per_prompt, seq_len, -1
+ )
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(
+ uncond_tokens, text_encoder.tokenizer
+ )
+
+ # max_length = prompt_embeds.shape[1]
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
+
+ negative_prompt_outputs = text_encoder.encode(
+ uncond_input, data_type=data_type, device=device
+ )
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
+
+ negative_attention_mask = negative_prompt_outputs.attention_mask
+ if negative_attention_mask is not None:
+ negative_attention_mask = negative_attention_mask.to(device)
+ _, seq_len = negative_attention_mask.shape
+ negative_attention_mask = negative_attention_mask.repeat(
+ 1, num_videos_per_prompt
+ )
+ negative_attention_mask = negative_attention_mask.view(
+ batch_size * num_videos_per_prompt, seq_len
+ )
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=prompt_embeds_dtype, device=device
+ )
+
+ if negative_prompt_embeds.ndim == 2:
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
+ 1, num_videos_per_prompt
+ )
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_videos_per_prompt, -1
+ )
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
+ 1, num_videos_per_prompt, 1
+ )
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_videos_per_prompt, seq_len, -1
+ )
+
+ if text_encoder is not None:
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(text_encoder.model, lora_scale)
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ attention_mask,
+ negative_attention_mask,
+ )
+
+ def decode_latents(self, latents, enable_tiling=True):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ if enable_tiling:
+ self.vae.enable_tiling()
+ image = self.vae.decode(latents, return_dict=False)[0]
+ else:
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ if image.ndim == 4:
+ image = image.cpu().permute(0, 2, 3, 1).float()
+ else:
+ image = image.cpu().float()
+ return image
+
+ def prepare_extra_func_kwargs(self, func, kwargs):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ extra_step_kwargs = {}
+
+ for k, v in kwargs.items():
+ accepts = k in set(inspect.signature(func).parameters.keys())
+ if accepts:
+ extra_step_kwargs[k] = v
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ video_length,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ vae_ver="88-4c-sd",
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+ )
+
+ if video_length is not None:
+ if "884" in vae_ver:
+ if video_length != 1 and (video_length - 1) % 4 != 0:
+ raise ValueError(
+ f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
+ )
+ elif "888" in vae_ver:
+ if video_length != 1 and (video_length - 1) % 8 != 0:
+ raise ValueError(
+ f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
+ )
+
+ if callback_steps is not None and (
+ not isinstance(callback_steps, int) or callback_steps <= 0
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs
+ for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (
+ not isinstance(prompt, str) and not isinstance(prompt, list)
+ ):
+ raise ValueError(
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ video_length,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ video_length,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(
+ shape, generator=generator, device=device, dtype=dtype
+ )
+ else:
+ latents = latents.to(device)
+
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self,
+ w: torch.Tensor,
+ embedding_dim: int = 512,
+ dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int,
+ width: int,
+ video_length: int,
+ data_type: str = "video",
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[
+ Union[
+ Callable[[int, int, Dict], None],
+ PipelineCallback,
+ MultiPipelineCallbacks,
+ ]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
+ vae_ver: str = "88-4c-sd",
+ enable_tiling: bool = False,
+ n_tokens: Optional[int] = None,
+ embedded_guidance_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`):
+ The height in pixels of the generated image.
+ width (`int`):
+ The width in pixels of the generated image.
+ video_length (`int`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. Default height and width to unet
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hooks
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ video_length,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ vae_ver=vae_ver,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None)
+ if self.cross_attention_kwargs is not None
+ else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_mask,
+ negative_prompt_mask,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ attention_mask=attention_mask,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_attention_mask=negative_attention_mask,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ data_type=data_type,
+ )
+ if self.text_encoder_2 is not None:
+ (
+ prompt_embeds_2,
+ negative_prompt_embeds_2,
+ prompt_mask_2,
+ negative_prompt_mask_2,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=None,
+ attention_mask=None,
+ negative_prompt_embeds=None,
+ negative_attention_mask=None,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ text_encoder=self.text_encoder_2,
+ data_type=data_type,
+ )
+ else:
+ prompt_embeds_2 = None
+ negative_prompt_embeds_2 = None
+ prompt_mask_2 = None
+ negative_prompt_mask_2 = None
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if prompt_mask is not None:
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
+ if prompt_embeds_2 is not None:
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
+ if prompt_mask_2 is not None:
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
+
+
+ # 4. Prepare timesteps
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ **extra_set_timesteps_kwargs,
+ )
+
+ if "884" in vae_ver:
+ video_length = (video_length - 1) // 4 + 1
+ elif "888" in vae_ver:
+ video_length = (video_length - 1) // 8 + 1
+ else:
+ video_length = video_length
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ video_length,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
+ self.scheduler.step,
+ {"generator": generator, "eta": eta},
+ )
+
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
+ autocast_enabled = (
+ target_dtype != torch.float32
+ ) and not self.args.disable_autocast
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
+ vae_autocast_enabled = (
+ vae_dtype != torch.float32
+ ) and not self.args.disable_autocast
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ # if is_progress_bar:
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents] * 2)
+ if self.do_classifier_free_guidance
+ else latents
+ )
+ latent_model_input = self.scheduler.scale_model_input(
+ latent_model_input, t
+ )
+
+ t_expand = t.repeat(latent_model_input.shape[0])
+ guidance_expand = (
+ torch.tensor(
+ [embedded_guidance_scale] * latent_model_input.shape[0],
+ dtype=torch.float32,
+ device=device,
+ ).to(target_dtype)
+ * 1000.0
+ if embedded_guidance_scale is not None
+ else None
+ )
+
+ # predict the noise residual
+ with torch.autocast(
+ device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
+ ):
+ noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
+ latent_model_input, # [2, 16, 33, 24, 42]
+ t_expand, # [2]
+ text_states=prompt_embeds, # [2, 256, 4096]
+ text_mask=prompt_mask, # [2, 256]
+ text_states_2=prompt_embeds_2, # [2, 768]
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
+ guidance=guidance_expand,
+ return_dict=True,
+ )[
+ "x"
+ ]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred,
+ noise_pred_text,
+ guidance_rescale=self.guidance_rescale,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop(
+ "negative_prompt_embeds", negative_prompt_embeds
+ )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ if progress_bar is not None:
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ expand_temporal_dim = False
+ if len(latents.shape) == 4:
+ if isinstance(self.vae, AutoencoderKLCausal3D):
+ latents = latents.unsqueeze(2)
+ expand_temporal_dim = True
+ elif len(latents.shape) == 5:
+ pass
+ else:
+ raise ValueError(
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
+ )
+
+ if (
+ hasattr(self.vae.config, "shift_factor")
+ and self.vae.config.shift_factor
+ ):
+ latents = (
+ latents / self.vae.config.scaling_factor
+ + self.vae.config.shift_factor
+ )
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ with torch.autocast(
+ device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
+ ):
+ if enable_tiling:
+ self.vae.enable_tiling()
+ image = self.vae.decode(
+ latents, return_dict=False, generator=generator
+ )[0]
+ else:
+ image = self.vae.decode(
+ latents, return_dict=False, generator=generator
+ )[0]
+
+ if expand_temporal_dim or image.shape[2] == 1:
+ image = image.squeeze(2)
+
+ else:
+ image = latents
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().float()
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return image
+
+ return HunyuanVideoPipelineOutput(videos=image)
diff --git a/hunyuan_model/posemb_layers.py b/hunyuan_model/posemb_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfce82c690540d17a55a51b7997ee7ceb0bdbf44
--- /dev/null
+++ b/hunyuan_model/posemb_layers.py
@@ -0,0 +1,310 @@
+import torch
+from typing import Union, Tuple, List
+
+
+def _to_tuple(x, dim=2):
+ if isinstance(x, int):
+ return (x,) * dim
+ elif len(x) == dim:
+ return x
+ else:
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
+
+
+def get_meshgrid_nd(start, *args, dim=2):
+ """
+ Get n-D meshgrid with start, stop and num.
+
+ Args:
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
+ n-tuples.
+ *args: See above.
+ dim (int): Dimension of the meshgrid. Defaults to 2.
+
+ Returns:
+ grid (np.ndarray): [dim, ...]
+ """
+ if len(args) == 0:
+ # start is grid_size
+ num = _to_tuple(start, dim=dim)
+ start = (0,) * dim
+ stop = num
+ elif len(args) == 1:
+ # start is start, args[0] is stop, step is 1
+ start = _to_tuple(start, dim=dim)
+ stop = _to_tuple(args[0], dim=dim)
+ num = [stop[i] - start[i] for i in range(dim)]
+ elif len(args) == 2:
+ # start is start, args[0] is stop, args[1] is num
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
+ else:
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
+ axis_grid = []
+ for i in range(dim):
+ a, b, n = start[i], stop[i], num[i]
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
+ axis_grid.append(g)
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
+
+ return grid
+
+
+#################################################################################
+# Rotary Positional Embedding Functions #
+#################################################################################
+# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
+
+
+def reshape_for_broadcast(
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ x: torch.Tensor,
+ head_first=False,
+):
+ """
+ Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ Notes:
+ When using FlashMHAModified, head_first should be False.
+ When using Attention, head_first should be True.
+
+ Args:
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+ head_first (bool): head dimension first (except batch dim) or not.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+
+ Raises:
+ AssertionError: If the frequency tensor doesn't match the expected shape.
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+ """
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+
+ if isinstance(freqs_cis, tuple):
+ # freqs_cis: (cos, sin) in real space
+ if head_first:
+ assert freqs_cis[0].shape == (
+ x.shape[-2],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+ shape = [
+ d if i == ndim - 2 or i == ndim - 1 else 1
+ for i, d in enumerate(x.shape)
+ ]
+ else:
+ assert freqs_cis[0].shape == (
+ x.shape[1],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+ else:
+ # freqs_cis: values in complex space
+ if head_first:
+ assert freqs_cis.shape == (
+ x.shape[-2],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+ shape = [
+ d if i == ndim - 2 or i == ndim - 1 else 1
+ for i, d in enumerate(x.shape)
+ ]
+ else:
+ assert freqs_cis.shape == (
+ x.shape[1],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def rotate_half(x):
+ x_real, x_imag = (
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
+ ) # [B, S, H, D//2]
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ head_first: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor.
+
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+ returned as real tensors.
+
+ Args:
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
+ head_first (bool): head dimension first (except batch dim) or not.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+ """
+ xk_out = None
+ if isinstance(freqs_cis, tuple):
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
+ # real * cos - imag * sin
+ # imag * cos + real * sin
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+ else:
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
+ xq_ = torch.view_as_complex(
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
+ ) # [B, S, H, D//2]
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
+ xq.device
+ ) # [S, D//2] --> [1, S, 1, D//2]
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
+ xk_ = torch.view_as_complex(
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
+ ) # [B, S, H, D//2]
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
+
+ return xq_out, xk_out
+
+
+def get_nd_rotary_pos_embed(
+ rope_dim_list,
+ start,
+ *args,
+ theta=10000.0,
+ use_real=False,
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
+ interpolation_factor: Union[float, List[float]] = 1.0,
+):
+ """
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
+
+ Args:
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
+ sum(rope_dim_list) should equal to head_dim of attention layer.
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
+ *args: See above.
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
+ part and an imaginary part separately.
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
+
+ Returns:
+ pos_embed (torch.Tensor): [HW, D/2]
+ """
+
+ grid = get_meshgrid_nd(
+ start, *args, dim=len(rope_dim_list)
+ ) # [3, W, H, D] / [2, W, H]
+
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
+ assert len(theta_rescale_factor) == len(
+ rope_dim_list
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
+
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
+ assert len(interpolation_factor) == len(
+ rope_dim_list
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
+
+ # use 1/ndim of dimensions to encode grid_axis
+ embs = []
+ for i in range(len(rope_dim_list)):
+ emb = get_1d_rotary_pos_embed(
+ rope_dim_list[i],
+ grid[i].reshape(-1),
+ theta,
+ use_real=use_real,
+ theta_rescale_factor=theta_rescale_factor[i],
+ interpolation_factor=interpolation_factor[i],
+ ) # 2 x [WHD, rope_dim_list[i]]
+ embs.append(emb)
+
+ if use_real:
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
+ return cos, sin
+ else:
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
+ return emb
+
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[torch.FloatTensor, int],
+ theta: float = 10000.0,
+ use_real: bool = False,
+ theta_rescale_factor: float = 1.0,
+ interpolation_factor: float = 1.0,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
+
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool, optional): If True, return real part and imaginary part separately.
+ Otherwise, return complex numbers.
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
+
+ Returns:
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
+ """
+ if isinstance(pos, int):
+ pos = torch.arange(pos).float()
+
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+ # has some connection to NTK literature
+ if theta_rescale_factor != 1.0:
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+ freqs = 1.0 / (
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+ ) # [D/2]
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
+ if use_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ freqs_cis = torch.polar(
+ torch.ones_like(freqs), freqs
+ ) # complex64 # [S, D/2]
+ return freqs_cis
diff --git a/hunyuan_model/text_encoder.py b/hunyuan_model/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b424c880c07d4344548b2e8bb01c397ad4d448a
--- /dev/null
+++ b/hunyuan_model/text_encoder.py
@@ -0,0 +1,710 @@
+from dataclasses import dataclass
+import json
+import os
+from typing import Optional, Tuple, Union
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ AutoTokenizer,
+ AutoModel,
+ CLIPConfig,
+ LlamaForCausalLM,
+ LlamaConfig,
+)
+from transformers.utils import ModelOutput
+from transformers.models.llama import LlamaModel
+from safetensors.torch import load_file
+from accelerate import init_empty_weights
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+CLIP_L_HUGGINGFACE_MODEL_ID = "openai/clip-vit-large-patch14"
+LLAVA_HUGGINGFACE_MODEL_ID = "xtuner/llava-llama-3-8b-v1_1-transformers"
+
+CLIP_CONFIG = {
+ "_name_or_path": "clip-vit-large-patch14/",
+ "architectures": ["CLIPModel"],
+ "initializer_factor": 1.0,
+ "logit_scale_init_value": 2.6592,
+ "model_type": "clip",
+ "projection_dim": 768,
+ # "text_config": {
+ "_name_or_path": "",
+ "add_cross_attention": False,
+ "architectures": None,
+ "attention_dropout": 0.0,
+ "bad_words_ids": None,
+ "bos_token_id": 0,
+ "chunk_size_feed_forward": 0,
+ "cross_attention_hidden_size": None,
+ "decoder_start_token_id": None,
+ "diversity_penalty": 0.0,
+ "do_sample": False,
+ "dropout": 0.0,
+ "early_stopping": False,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": 2,
+ "finetuning_task": None,
+ "forced_bos_token_id": None,
+ "forced_eos_token_id": None,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 768,
+ "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "is_decoder": False,
+ "is_encoder_decoder": False,
+ "label2id": {"LABEL_0": 0, "LABEL_1": 1},
+ "layer_norm_eps": 1e-05,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "max_position_embeddings": 77,
+ "min_length": 0,
+ "model_type": "clip_text_model",
+ "no_repeat_ngram_size": 0,
+ "num_attention_heads": 12,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_hidden_layers": 12,
+ "num_return_sequences": 1,
+ "output_attentions": False,
+ "output_hidden_states": False,
+ "output_scores": False,
+ "pad_token_id": 1,
+ "prefix": None,
+ "problem_type": None,
+ "projection_dim": 768,
+ "pruned_heads": {},
+ "remove_invalid_values": False,
+ "repetition_penalty": 1.0,
+ "return_dict": True,
+ "return_dict_in_generate": False,
+ "sep_token_id": None,
+ "task_specific_params": None,
+ "temperature": 1.0,
+ "tie_encoder_decoder": False,
+ "tie_word_embeddings": True,
+ "tokenizer_class": None,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": None,
+ "torchscript": False,
+ "transformers_version": "4.16.0.dev0",
+ "use_bfloat16": False,
+ "vocab_size": 49408,
+ # },
+ # "text_config_dict": {
+ "hidden_size": 768,
+ "intermediate_size": 3072,
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "projection_dim": 768,
+ # },
+ # "torch_dtype": "float32",
+ # "transformers_version": null
+}
+
+LLAMA_CONFIG = {
+ "architectures": ["LlamaForCausalLM"],
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ "bos_token_id": 128000,
+ "eos_token_id": 128001,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 8192,
+ "mlp_bias": False,
+ "model_type": "llama",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 8,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": None,
+ "rope_theta": 500000.0,
+ "tie_word_embeddings": False,
+ "torch_dtype": "float16",
+ "transformers_version": "4.46.3",
+ "use_cache": True,
+ "vocab_size": 128320,
+}
+
+# When using decoder-only models, we must provide a prompt template to instruct the text encoder
+# on how to generate the text.
+# --------------------------------------------------------------------
+PROMPT_TEMPLATE_ENCODE = (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+)
+PROMPT_TEMPLATE_ENCODE_VIDEO = (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+)
+
+NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
+
+PROMPT_TEMPLATE = {
+ "dit-llm-encode": {
+ "template": PROMPT_TEMPLATE_ENCODE,
+ "crop_start": 36,
+ },
+ "dit-llm-encode-video": {
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
+ "crop_start": 95,
+ },
+}
+
+
+def use_default(value, default):
+ return value if value is not None else default
+
+
+def load_clip_l(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
+ if os.path.isdir(text_encoder_path):
+ # load from directory, configs are in the directory
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)
+ else:
+ # load from file, we create the model with the appropriate config
+ config = CLIPConfig(**CLIP_CONFIG)
+ with init_empty_weights():
+ text_encoder = CLIPTextModel._from_config(config, torch_dtype=dtype)
+
+ state_dict = load_file(text_encoder_path)
+
+ text_encoder.load_state_dict(state_dict, strict=True, assign=True)
+ # if dtype is not None:
+ # text_encoder.to(dtype=dtype)
+
+ return text_encoder
+
+
+def load_clip_l_tokenizer(tokenizer_path: str):
+ if os.path.isdir(tokenizer_path):
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
+ else:
+ # load from Hugging Face
+ logger.info(f"Loading tokenizer from Hugging Face: {CLIP_L_HUGGINGFACE_MODEL_ID}")
+ tokenizer = CLIPTokenizer.from_pretrained(CLIP_L_HUGGINGFACE_MODEL_ID, max_length=77)
+
+ return tokenizer
+
+
+def load_llm(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
+ if os.path.isdir(text_encoder_path):
+ # load from directory, configs are in the directory
+ text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)
+ else:
+ # load from file, we create the model with the appropriate config
+ config = LlamaConfig(**LLAMA_CONFIG)
+ with init_empty_weights():
+ text_encoder = LlamaForCausalLM._from_config(config, torch_dtype=dtype)
+
+ state_dict = load_file(text_encoder_path)
+
+ # support weights from ComfyUI
+ if "tokenizer" in state_dict:
+ state_dict.pop("tokenizer")
+
+ text_encoder.load_state_dict(state_dict, strict=True, assign=True)
+
+ return text_encoder
+
+
+def load_llm_tokenizer(tokenizer_path: str, padding_side="right"):
+ if os.path.isdir(tokenizer_path):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ else:
+ # load from Hugging Face
+ logger.info(f"Loading tokenizer from Hugging Face: {LLAVA_HUGGINGFACE_MODEL_ID}")
+ tokenizer = AutoTokenizer.from_pretrained(LLAVA_HUGGINGFACE_MODEL_ID, padding_side=padding_side)
+
+ return tokenizer
+
+
+def load_text_encoder(
+ text_encoder_type: str,
+ text_encoder_path: str,
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
+):
+ logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
+
+ # reduce peak memory usage by specifying the dtype of the model
+ dtype = text_encoder_dtype
+ if text_encoder_type == "clipL":
+ text_encoder = load_clip_l(text_encoder_path, dtype=dtype)
+ text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
+ elif text_encoder_type == "llm":
+ text_encoder = load_llm(text_encoder_path, dtype=dtype)
+ if hasattr(text_encoder, "norm"):
+ text_encoder.final_layer_norm = text_encoder.norm # by from_pretrained
+ else:
+ text_encoder.final_layer_norm = text_encoder.model.norm # by _from_config
+ else:
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
+ # from_pretrained will ensure that the model is in eval mode.
+
+ if dtype is not None:
+ text_encoder = text_encoder.to(dtype=dtype)
+
+ text_encoder.requires_grad_(False)
+
+ logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
+ return text_encoder, text_encoder_path
+
+
+def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
+ logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
+
+ if tokenizer_type == "clipL":
+ tokenizer = load_clip_l_tokenizer(tokenizer_path)
+ elif tokenizer_type == "llm":
+ tokenizer = load_llm_tokenizer(tokenizer_path, padding_side=padding_side)
+ else:
+ raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
+
+ return tokenizer, tokenizer_path
+
+
+@dataclass
+class TextEncoderModelOutput(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+ hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
+ List of decoded texts.
+ """
+
+ hidden_state: torch.FloatTensor = None
+ attention_mask: Optional[torch.LongTensor] = None
+ hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
+ text_outputs: Optional[list] = None
+
+
+class TextEncoder(nn.Module):
+ def __init__(
+ self,
+ text_encoder_type: str,
+ max_length: int,
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
+ text_encoder_path: Optional[str] = None,
+ tokenizer_type: Optional[str] = None,
+ tokenizer_path: Optional[str] = None,
+ output_key: Optional[str] = None,
+ use_attention_mask: bool = True,
+ input_max_length: Optional[int] = None,
+ prompt_template: Optional[dict] = None,
+ prompt_template_video: Optional[dict] = None,
+ hidden_state_skip_layer: Optional[int] = None,
+ apply_final_norm: bool = False,
+ reproduce: bool = False,
+ ):
+ super().__init__()
+ self.text_encoder_type = text_encoder_type
+ self.max_length = max_length
+ # self.precision = text_encoder_precision
+ self.model_path = text_encoder_path
+ self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
+ self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
+ self.use_attention_mask = use_attention_mask
+ if prompt_template_video is not None:
+ assert use_attention_mask is True, "Attention mask is True required when training videos."
+ self.input_max_length = input_max_length if input_max_length is not None else max_length
+ self.prompt_template = prompt_template
+ self.prompt_template_video = prompt_template_video
+ self.hidden_state_skip_layer = hidden_state_skip_layer
+ self.apply_final_norm = apply_final_norm
+ self.reproduce = reproduce
+
+ self.use_template = self.prompt_template is not None
+ if self.use_template:
+ assert (
+ isinstance(self.prompt_template, dict) and "template" in self.prompt_template
+ ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
+ assert "{}" in str(self.prompt_template["template"]), (
+ "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
+ f"got {self.prompt_template['template']}"
+ )
+
+ self.use_video_template = self.prompt_template_video is not None
+ if self.use_video_template:
+ if self.prompt_template_video is not None:
+ assert (
+ isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
+ ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
+ assert "{}" in str(self.prompt_template_video["template"]), (
+ "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
+ f"got {self.prompt_template_video['template']}"
+ )
+
+ if "t5" in text_encoder_type:
+ self.output_key = output_key or "last_hidden_state"
+ elif "clip" in text_encoder_type:
+ self.output_key = output_key or "pooler_output"
+ elif "llm" in text_encoder_type or "glm" in text_encoder_type:
+ self.output_key = output_key or "last_hidden_state"
+ else:
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
+
+ self.model, self.model_path = load_text_encoder(
+ text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
+ )
+ self.dtype = self.model.dtype
+
+ self.tokenizer, self.tokenizer_path = load_tokenizer(
+ tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
+ )
+
+ def __repr__(self):
+ return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
+
+ @property
+ def device(self):
+ return self.model.device
+
+ @staticmethod
+ def apply_text_to_template(text, template, prevent_empty_text=True):
+ """
+ Apply text to template.
+
+ Args:
+ text (str): Input text.
+ template (str or list): Template string or list of chat conversation.
+ prevent_empty_text (bool): If Ture, we will prevent the user text from being empty
+ by adding a space. Defaults to True.
+ """
+ if isinstance(template, str):
+ # Will send string to tokenizer. Used for llm
+ return template.format(text)
+ else:
+ raise TypeError(f"Unsupported template type: {type(template)}")
+
+ def text2tokens(self, text, data_type="image"):
+ """
+ Tokenize the input text.
+
+ Args:
+ text (str or list): Input text.
+ """
+ tokenize_input_type = "str"
+ if self.use_template:
+ if data_type == "image":
+ prompt_template = self.prompt_template["template"]
+ elif data_type == "video":
+ prompt_template = self.prompt_template_video["template"]
+ else:
+ raise ValueError(f"Unsupported data type: {data_type}")
+ if isinstance(text, (list, tuple)):
+ text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
+ if isinstance(text[0], list):
+ tokenize_input_type = "list"
+ elif isinstance(text, str):
+ text = self.apply_text_to_template(text, prompt_template)
+ if isinstance(text, list):
+ tokenize_input_type = "list"
+ else:
+ raise TypeError(f"Unsupported text type: {type(text)}")
+
+ kwargs = dict(
+ truncation=True,
+ max_length=self.max_length,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ if tokenize_input_type == "str":
+ return self.tokenizer(
+ text,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ **kwargs,
+ )
+ elif tokenize_input_type == "list":
+ return self.tokenizer.apply_chat_template(
+ text,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ **kwargs,
+ )
+ else:
+ raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
+
+ def encode(
+ self,
+ batch_encoding,
+ use_attention_mask=None,
+ output_hidden_states=False,
+ do_sample=None,
+ hidden_state_skip_layer=None,
+ return_texts=False,
+ data_type="image",
+ device=None,
+ ):
+ """
+ Args:
+ batch_encoding (dict): Batch encoding from tokenizer.
+ use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
+ Defaults to None.
+ output_hidden_states (bool): Whether to output hidden states. If False, return the value of
+ self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
+ output_hidden_states will be set True. Defaults to False.
+ do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
+ When self.produce is False, do_sample is set to True by default.
+ hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
+ If None, self.output_key will be used. Defaults to None.
+ return_texts (bool): Whether to return the decoded texts. Defaults to False.
+ """
+ device = self.model.device if device is None else device
+ use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
+ hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
+ do_sample = use_default(do_sample, not self.reproduce)
+ attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
+ outputs = self.model(
+ input_ids=batch_encoding["input_ids"].to(device),
+ attention_mask=attention_mask,
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
+ )
+ if hidden_state_skip_layer is not None:
+ last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
+ # Real last hidden state already has layer norm applied. So here we only apply it
+ # for intermediate layers.
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
+ else:
+ last_hidden_state = outputs[self.output_key]
+
+ # Remove hidden states of instruction tokens, only keep prompt tokens.
+ if self.use_template:
+ if data_type == "image":
+ crop_start = self.prompt_template.get("crop_start", -1)
+ elif data_type == "video":
+ crop_start = self.prompt_template_video.get("crop_start", -1)
+ else:
+ raise ValueError(f"Unsupported data type: {data_type}")
+ if crop_start > 0:
+ last_hidden_state = last_hidden_state[:, crop_start:]
+ attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
+
+ if output_hidden_states:
+ return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
+
+ def forward(
+ self,
+ text,
+ use_attention_mask=None,
+ output_hidden_states=False,
+ do_sample=False,
+ hidden_state_skip_layer=None,
+ return_texts=False,
+ ):
+ batch_encoding = self.text2tokens(text)
+ return self.encode(
+ batch_encoding,
+ use_attention_mask=use_attention_mask,
+ output_hidden_states=output_hidden_states,
+ do_sample=do_sample,
+ hidden_state_skip_layer=hidden_state_skip_layer,
+ return_texts=return_texts,
+ )
+
+
+# region HunyanVideo architecture
+
+
+def load_text_encoder_1(
+ text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
+) -> TextEncoder:
+ text_encoder_dtype = dtype or torch.float16
+ text_encoder_type = "llm"
+ text_len = 256
+ hidden_state_skip_layer = 2
+ apply_final_norm = False
+ reproduce = False
+
+ prompt_template = "dit-llm-encode"
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
+ prompt_template_video = "dit-llm-encode-video"
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]
+
+ crop_start = prompt_template_video["crop_start"] # .get("crop_start", 0)
+ max_length = text_len + crop_start
+
+ text_encoder_1 = TextEncoder(
+ text_encoder_type=text_encoder_type,
+ max_length=max_length,
+ text_encoder_dtype=text_encoder_dtype,
+ text_encoder_path=text_encoder_dir,
+ tokenizer_type=text_encoder_type,
+ prompt_template=prompt_template,
+ prompt_template_video=prompt_template_video,
+ hidden_state_skip_layer=hidden_state_skip_layer,
+ apply_final_norm=apply_final_norm,
+ reproduce=reproduce,
+ )
+ text_encoder_1.eval()
+
+ if fp8_llm:
+ org_dtype = text_encoder_1.dtype
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
+ text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)
+
+ # prepare LLM for fp8
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
+ def forward_hook(module):
+ def forward(hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
+
+ return forward
+
+ for module in llama_model.modules():
+ if module.__class__.__name__ in ["Embedding"]:
+ # print("set", module.__class__.__name__, "to", target_dtype)
+ module.to(target_dtype)
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
+ # print("set", module.__class__.__name__, "hooks")
+ module.forward = forward_hook(module)
+
+ prepare_fp8(text_encoder_1.model, org_dtype)
+ else:
+ text_encoder_1.to(device=device)
+
+ return text_encoder_1
+
+
+def load_text_encoder_2(
+ text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
+) -> TextEncoder:
+ text_encoder_dtype = dtype or torch.float16
+ reproduce = False
+
+ text_encoder_2_type = "clipL"
+ text_len_2 = 77
+
+ text_encoder_2 = TextEncoder(
+ text_encoder_type=text_encoder_2_type,
+ max_length=text_len_2,
+ text_encoder_dtype=text_encoder_dtype,
+ text_encoder_path=text_encoder_dir,
+ tokenizer_type=text_encoder_2_type,
+ reproduce=reproduce,
+ )
+ text_encoder_2.eval()
+
+ text_encoder_2.to(device=device)
+
+ return text_encoder_2
+
+
+# endregion
+
+
+if __name__ == "__main__":
+ import argparse
+ from utils.model_utils import str_to_dtype
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("type", type=str, help="Text Encoder type")
+ parser.add_argument("path1", type=str, help="Text Encoder directory or file 1")
+ parser.add_argument("path2", type=str, help="Text Encoder directory or file 2")
+ parser.add_argument("--dtype", type=str, default=None, help="Data type for Text Encoder")
+ args = parser.parse_args()
+
+ dtype = str_to_dtype(args.dtype) if args.dtype is not None else torch.float16
+
+ """
+ if args.type == "clipL":
+ text_encoder_1st = load_clip_l(args.path1, dtype=dtype)
+ tokenizer_1st = load_clip_l_tokenizer(args.path1)
+ text_encoder_2nd = load_clip_l(args.path2, dtype=dtype)
+ tokenizer_2nd = load_clip_l_tokenizer(args.path2)
+ elif args.type == "llm":
+ text_encoder_1st = load_llm(args.path1, dtype=dtype)
+ tokenizer_1st = load_llm_tokenizer(args.path1)
+ text_encoder_2nd = load_llm(args.path2, dtype=dtype)
+ tokenizer_2nd = load_llm_tokenizer(args.path2)
+
+ print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
+ print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")
+
+ text_encoder_1st.to(device=device)
+ text_encoder_2nd.to(device=device)
+
+ test_text = "A cat sitting on a table"
+ token_ids_1st = tokenizer_1st(test_text, return_tensors="pt")["input_ids"]
+ token_ids_2nd = tokenizer_2nd(test_text, return_tensors="pt")["input_ids"]
+ assert torch.allclose(token_ids_1st, token_ids_2nd)
+ print(f"Token IDs are the same: {token_ids_1st}")
+
+ with torch.no_grad():
+ text_encoder_1st_output = text_encoder_1st(token_ids_1st.to(device), output_hidden_states=True)
+ text_encoder_2nd_output = text_encoder_2nd(token_ids_2nd.to(device), output_hidden_states=True)
+ print(f"1st Text Encoder output keys: {text_encoder_1st_output.keys()}")
+ print(f"2nd Text Encoder output keys: {text_encoder_2nd_output.keys()}")
+ for key in text_encoder_1st_output:
+ print(f"Checking output: {key}")
+ assert key in text_encoder_2nd_output, f"Key {key} not in 2nd Text Encoder output"
+ assert torch.allclose(text_encoder_1st_output[key], text_encoder_2nd_output[key])
+ print(f"Outputs are the same: {key}")
+ print("All outputs are the same.")
+ """
+
+ if args.type == "clipL":
+ text_encoder_1st = load_text_encoder_2(args.path1, device, dtype)
+ text_encoder_2nd = load_text_encoder_2(args.path2, device, dtype)
+ elif args.type == "llm":
+ text_encoder_1st = load_text_encoder_1(args.path1, device, False, dtype)
+ text_encoder_2nd = load_text_encoder_1(args.path2, device, False, dtype)
+ print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
+ print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")
+
+ prompt = "A cat sitting on a table"
+ data_type = "video" # video only, image is not supported
+ text_inputs_1st = text_encoder_1st.text2tokens(prompt, data_type=data_type)
+ text_inputs_2nd = text_encoder_2nd.text2tokens(prompt, data_type=data_type)
+ print(text_inputs_1st)
+ assert torch.allclose(text_inputs_1st["input_ids"], text_inputs_2nd["input_ids"])
+
+ with torch.no_grad():
+ prompt_outputs_1st = text_encoder_1st.encode(text_inputs_1st, data_type=data_type)
+ prompt_outputs_2nd = text_encoder_2nd.encode(text_inputs_1st, data_type=data_type)
+
+ # prompt_outputs.hidden_state, prompt_outputs.attention_mask
+ assert torch.allclose(prompt_outputs_1st.hidden_state, prompt_outputs_2nd.hidden_state)
+ print("Hidden states are the same.")
+ assert torch.allclose(prompt_outputs_1st.attention_mask, prompt_outputs_2nd.attention_mask)
+ print("Attention masks are the same.")
+ print("All outputs are the same.")
diff --git a/hunyuan_model/token_refiner.py b/hunyuan_model/token_refiner.py
new file mode 100644
index 0000000000000000000000000000000000000000..378bbab7d5b5483f552bc37699650506dc6f790c
--- /dev/null
+++ b/hunyuan_model/token_refiner.py
@@ -0,0 +1,245 @@
+from typing import Optional
+
+from einops import rearrange
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from .activation_layers import get_activation_layer
+from .attention import attention
+from .norm_layers import get_norm_layer
+from .embed_layers import TimestepEmbedder, TextProjection
+from .mlp_layers import MLP
+from .modulate_layers import modulate, apply_gate
+
+
+class IndividualTokenRefinerBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ heads_num,
+ mlp_width_ratio: str = 4.0,
+ mlp_drop_rate: float = 0.0,
+ act_type: str = "silu",
+ qk_norm: bool = False,
+ qk_norm_type: str = "layer",
+ qkv_bias: bool = True,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.heads_num = heads_num
+ head_dim = hidden_size // heads_num
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
+ self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
+ qk_norm_layer = get_norm_layer(qk_norm_type)
+ self.self_attn_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.self_attn_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
+ act_layer = get_activation_layer(act_type)
+ self.mlp = MLP(
+ in_channels=hidden_size,
+ hidden_channels=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=mlp_drop_rate,
+ **factory_kwargs,
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ act_layer(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
+ )
+ # Zero-initialize the modulation
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+ self.gradient_checkpointing = False
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing = False
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
+ attn_mask: torch.Tensor = None,
+ ):
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
+
+ norm_x = self.norm1(x)
+ qkv = self.self_attn_qkv(norm_x)
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ # Apply QK-Norm if needed
+ q = self.self_attn_q_norm(q).to(v)
+ k = self.self_attn_k_norm(k).to(v)
+
+ # Self-Attention
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
+
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
+
+ # FFN Layer
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
+
+ return x
+
+ def forward(self, *args, **kwargs):
+ if self.training and self.gradient_checkpointing:
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
+ else:
+ return self._forward(*args, **kwargs)
+
+
+class IndividualTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ heads_num,
+ depth,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ act_type: str = "silu",
+ qk_norm: bool = False,
+ qk_norm_type: str = "layer",
+ qkv_bias: bool = True,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.blocks = nn.ModuleList(
+ [
+ IndividualTokenRefinerBlock(
+ hidden_size=hidden_size,
+ heads_num=heads_num,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ act_type=act_type,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ qkv_bias=qkv_bias,
+ **factory_kwargs,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ def enable_gradient_checkpointing(self):
+ for block in self.blocks:
+ block.enable_gradient_checkpointing()
+
+ def disable_gradient_checkpointing(self):
+ for block in self.blocks:
+ block.disable_gradient_checkpointing()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ c: torch.LongTensor,
+ mask: Optional[torch.Tensor] = None,
+ ):
+ self_attn_mask = None
+ if mask is not None:
+ batch_size = mask.shape[0]
+ seq_len = mask.shape[1]
+ mask = mask.to(x.device)
+ # batch_size x 1 x seq_len x seq_len
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ # batch_size x 1 x seq_len x seq_len
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+ # avoids self-attention weight being NaN for padding tokens
+ self_attn_mask[:, :, :, 0] = True
+
+ for block in self.blocks:
+ x = block(x, c, self_attn_mask)
+ return x
+
+
+class SingleTokenRefiner(nn.Module):
+ """
+ A single token refiner block for llm text embedding refine.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ hidden_size,
+ heads_num,
+ depth,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ act_type: str = "silu",
+ qk_norm: bool = False,
+ qk_norm_type: str = "layer",
+ qkv_bias: bool = True,
+ attn_mode: str = "torch",
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.attn_mode = attn_mode
+ assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
+
+ self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
+
+ act_layer = get_activation_layer(act_type)
+ # Build timestep embedding layer
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
+ # Build context embedding layer
+ self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
+
+ self.individual_token_refiner = IndividualTokenRefiner(
+ hidden_size=hidden_size,
+ heads_num=heads_num,
+ depth=depth,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ act_type=act_type,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ qkv_bias=qkv_bias,
+ **factory_kwargs,
+ )
+
+ def enable_gradient_checkpointing(self):
+ self.individual_token_refiner.enable_gradient_checkpointing()
+
+ def disable_gradient_checkpointing(self):
+ self.individual_token_refiner.disable_gradient_checkpointing()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.LongTensor,
+ mask: Optional[torch.LongTensor] = None,
+ ):
+ timestep_aware_representations = self.t_embedder(t)
+
+ if mask is None:
+ context_aware_representations = x.mean(dim=1)
+ else:
+ mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
+ context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+ context_aware_representations = self.c_embedder(context_aware_representations)
+ c = timestep_aware_representations + context_aware_representations
+
+ x = self.input_embedder(x)
+
+ x = self.individual_token_refiner(x, c, mask)
+
+ return x
diff --git a/hunyuan_model/vae.py b/hunyuan_model/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ae718a5634e98e53a0c0dec85254228229a01c3
--- /dev/null
+++ b/hunyuan_model/vae.py
@@ -0,0 +1,446 @@
+from dataclasses import dataclass
+import json
+from typing import Optional, Tuple, Union
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from diffusers.utils import BaseOutput, is_torch_version
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.models.attention_processor import SpatialNorm
+from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+SCALING_FACTOR = 0.476986
+VAE_VER = "884-16c-hy" # We don't support other versions currently
+
+
+def load_vae(
+ vae_type: str = "884-16c-hy",
+ vae_dtype: Optional[Union[str, torch.dtype]] = None,
+ sample_size: tuple = None,
+ vae_path: str = None,
+ device=None,
+):
+ """the fucntion to load the 3D VAE model
+
+ Args:
+ vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
+ vae_precision (str, optional): the precision to load vae. Defaults to None.
+ sample_size (tuple, optional): the tiling size. Defaults to None.
+ vae_path (str, optional): the path to vae. Defaults to None.
+ logger (_type_, optional): logger. Defaults to None.
+ device (_type_, optional): device to load vae. Defaults to None.
+ """
+ if vae_path is None:
+ vae_path = VAE_PATH[vae_type]
+
+ logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
+
+ # use fixed config for Hunyuan's VAE
+ CONFIG_JSON = """{
+ "_class_name": "AutoencoderKLCausal3D",
+ "_diffusers_version": "0.4.2",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D",
+ "DownEncoderBlockCausal3D"
+ ],
+ "in_channels": 3,
+ "latent_channels": 16,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "sample_tsize": 64,
+ "up_block_types": [
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D",
+ "UpDecoderBlockCausal3D"
+ ],
+ "scaling_factor": 0.476986,
+ "time_compression_ratio": 4,
+ "mid_block_add_attention": true
+ }"""
+
+ # config = AutoencoderKLCausal3D.load_config(vae_path)
+ config = json.loads(CONFIG_JSON)
+
+ # import here to avoid circular import
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+
+ if sample_size:
+ vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
+ else:
+ vae = AutoencoderKLCausal3D.from_config(config)
+
+ # vae_ckpt = Path(vae_path) / "pytorch_model.pt"
+ # assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
+
+ if vae_path.endswith(".safetensors"):
+ from safetensors.torch import load_file
+ ckpt = load_file(vae_path)
+ else:
+ ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
+ if "state_dict" in ckpt:
+ ckpt = ckpt["state_dict"]
+ if any(k.startswith("vae.") for k in ckpt.keys()):
+ ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
+ vae.load_state_dict(ckpt)
+
+ spatial_compression_ratio = vae.config.spatial_compression_ratio
+ time_compression_ratio = vae.config.time_compression_ratio
+
+ if vae_dtype is not None:
+ vae = vae.to(vae_dtype)
+
+ vae.requires_grad_(False)
+
+ logger.info(f"VAE to dtype: {vae.dtype}")
+
+ if device is not None:
+ vae = vae.to(device)
+
+ vae.eval()
+
+ return vae, vae_path, spatial_compression_ratio, time_compression_ratio
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ r"""
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The decoded output sample from the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class EncoderCausal3D(nn.Module):
+ r"""
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
+ block_out_channels: Tuple[int, ...] = (64,),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ double_z: bool = True,
+ mid_block_add_attention=True,
+ time_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 8,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
+
+ if time_compression_ratio == 4:
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
+ add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
+ else:
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
+
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
+ down_block = get_down_block3d(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
+ downsample_stride=downsample_stride,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockCausal3D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ add_attention=mid_block_add_attention,
+ )
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
+
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ r"""The forward method of the `EncoderCausal3D` class."""
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
+
+ sample = self.conv_in(sample)
+
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class DecoderCausal3D(nn.Module):
+ r"""
+ The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
+ block_out_channels: Tuple[int, ...] = (64,),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ norm_type: str = "group", # group, spatial
+ mid_block_add_attention=True,
+ time_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 8,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ temb_channels = in_channels if norm_type == "spatial" else None
+
+ # mid
+ self.mid_block = UNetMidBlockCausal3D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=temb_channels,
+ add_attention=mid_block_add_attention,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
+
+ if time_compression_ratio == 4:
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
+ add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
+ else:
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
+
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
+ up_block = get_up_block3d(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
+ upsample_scale_factor=upsample_scale_factor,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=temb_channels,
+ resnet_time_scale_shift=norm_type,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_type == "spatial":
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+ else:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ latent_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ r"""The forward method of the `DecoderCausal3D` class."""
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
+
+ sample = self.conv_in(sample)
+
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ else:
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+ else:
+ # middle
+ sample = self.mid_block(sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample, latent_embeds)
+
+ # post-process
+ if latent_embeds is None:
+ sample = self.conv_norm_out(sample)
+ else:
+ sample = self.conv_norm_out(sample, latent_embeds)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+ if parameters.ndim == 3:
+ dim = 2 # (B, L, C)
+ elif parameters.ndim == 5 or parameters.ndim == 4:
+ dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
+ else:
+ raise NotImplementedError
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ # make sure sample is on the same device as the parameters and has same dtype
+ sample = randn_tensor(
+ self.mean.shape,
+ generator=generator,
+ device=self.parameters.device,
+ dtype=self.parameters.dtype,
+ )
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ reduce_dim = list(range(1, self.mean.ndim))
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=reduce_dim,
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=reduce_dim,
+ )
+
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self) -> torch.Tensor:
+ return self.mean
diff --git a/hv_generate_video.py b/hv_generate_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b57335f3000d8a8292cb4ca619bfa4fcc92bba3
--- /dev/null
+++ b/hv_generate_video.py
@@ -0,0 +1,936 @@
+import argparse
+from datetime import datetime
+from pathlib import Path
+import random
+import sys
+import os
+import time
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torchvision
+import accelerate
+from diffusers.utils.torch_utils import randn_tensor
+from transformers.models.llama import LlamaModel
+from tqdm import tqdm
+import av
+from einops import rearrange
+from safetensors.torch import load_file, save_file
+from safetensors import safe_open
+from PIL import Image
+
+from hunyuan_model import vae
+from hunyuan_model.text_encoder import TextEncoder
+from hunyuan_model.text_encoder import PROMPT_TEMPLATE
+from hunyuan_model.vae import load_vae
+from hunyuan_model.models import load_transformer, get_rotary_pos_embed
+from hunyuan_model.fp8_optimization import convert_fp8_linear
+from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
+from networks import lora
+
+try:
+ from lycoris.kohya import create_network_from_weights
+except:
+ pass
+
+from utils.model_utils import str_to_dtype
+from utils.safetensors_utils import mem_eff_save_file
+from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def clean_memory_on_device(device):
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ elif device.type == "cpu":
+ pass
+ elif device.type == "mps": # not tested
+ torch.mps.empty_cache()
+
+
+def synchronize_device(device: torch.device):
+ if device.type == "cuda":
+ torch.cuda.synchronize()
+ elif device.type == "xpu":
+ torch.xpu.synchronize()
+ elif device.type == "mps":
+ torch.mps.synchronize()
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
+ """save videos by video tensor
+ copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
+
+ Args:
+ videos (torch.Tensor): video tensor predicted by the model
+ path (str): path to save video
+ rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False.
+ n_rows (int, optional): Defaults to 1.
+ fps (int, optional): video save fps. Defaults to 8.
+ """
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = torch.clamp(x, 0, 1)
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(x)
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+
+ # # save video with av
+ # container = av.open(path, "w")
+ # stream = container.add_stream("libx264", rate=fps)
+ # for x in outputs:
+ # frame = av.VideoFrame.from_ndarray(x, format="rgb24")
+ # packet = stream.encode(frame)
+ # container.mux(packet)
+ # packet = stream.encode(None)
+ # container.mux(packet)
+ # container.close()
+
+ height, width, _ = outputs[0].shape
+
+ # create output container
+ container = av.open(path, mode="w")
+
+ # create video stream
+ codec = "libx264"
+ pixel_format = "yuv420p"
+ stream = container.add_stream(codec, rate=fps)
+ stream.width = width
+ stream.height = height
+ stream.pix_fmt = pixel_format
+ stream.bit_rate = 4000000 # 4Mbit/s
+
+ for frame_array in outputs:
+ frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
+ packets = stream.encode(frame)
+ for packet in packets:
+ container.mux(packet)
+
+ for packet in stream.encode():
+ container.mux(packet)
+
+ container.close()
+
+
+def save_images_grid(
+ videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True
+):
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = torch.clamp(x, 0, 1)
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(x)
+
+ if create_subdir:
+ output_dir = os.path.join(parent_dir, image_name)
+ else:
+ output_dir = parent_dir
+
+ os.makedirs(output_dir, exist_ok=True)
+ for i, x in enumerate(outputs):
+ image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png")
+ image = Image.fromarray(x)
+ image.save(image_path)
+
+
+# region Encoding prompt
+
+
+def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_videos_per_prompt (`int`):
+ number of videos that should be generated per prompt
+ text_encoder (TextEncoder):
+ text encoder to be used for encoding the prompt
+ """
+ # LoRA and Textual Inversion are not supported in this script
+ # negative prompt and prompt embedding are not supported in this script
+ # clip_skip is not supported in this script because it is not used in the original script
+ data_type = "video" # video only, image is not supported
+
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
+
+ with torch.no_grad():
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
+ prompt_embeds = prompt_outputs.hidden_state
+
+ attention_mask = prompt_outputs.attention_mask
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(device)
+ bs_embed, seq_len = attention_mask.shape
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
+ attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
+
+ prompt_embeds_dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ if prompt_embeds.ndim == 2:
+ bs_embed, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
+ else:
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, attention_mask
+
+
+def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None):
+ # constants
+ prompt_template_video = "dit-llm-encode-video"
+ prompt_template = "dit-llm-encode"
+ text_encoder_dtype = torch.float16
+ text_encoder_type = "llm"
+ text_len = 256
+ hidden_state_skip_layer = 2
+ apply_final_norm = False
+ reproduce = False
+
+ text_encoder_2_type = "clipL"
+ text_len_2 = 77
+
+ num_videos = 1
+
+ # if args.prompt_template_video is not None:
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
+ # elif args.prompt_template is not None:
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
+ # else:
+ # crop_start = 0
+ crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
+ max_length = text_len + crop_start
+
+ # prompt_template
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
+
+ # prompt_template_video
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None
+
+ # load text encoders
+ logger.info(f"loading text encoder: {args.text_encoder1}")
+ text_encoder = TextEncoder(
+ text_encoder_type=text_encoder_type,
+ max_length=max_length,
+ text_encoder_dtype=text_encoder_dtype,
+ text_encoder_path=args.text_encoder1,
+ tokenizer_type=text_encoder_type,
+ prompt_template=prompt_template,
+ prompt_template_video=prompt_template_video,
+ hidden_state_skip_layer=hidden_state_skip_layer,
+ apply_final_norm=apply_final_norm,
+ reproduce=reproduce,
+ )
+ text_encoder.eval()
+ if fp8_llm:
+ org_dtype = text_encoder.dtype
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
+ text_encoder.to(device=device, dtype=torch.float8_e4m3fn)
+
+ # prepare LLM for fp8
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
+ def forward_hook(module):
+ def forward(hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
+
+ return forward
+
+ for module in llama_model.modules():
+ if module.__class__.__name__ in ["Embedding"]:
+ # print("set", module.__class__.__name__, "to", target_dtype)
+ module.to(target_dtype)
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
+ # print("set", module.__class__.__name__, "hooks")
+ module.forward = forward_hook(module)
+
+ prepare_fp8(text_encoder.model, org_dtype)
+
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
+ text_encoder_2 = TextEncoder(
+ text_encoder_type=text_encoder_2_type,
+ max_length=text_len_2,
+ text_encoder_dtype=text_encoder_dtype,
+ text_encoder_path=args.text_encoder2,
+ tokenizer_type=text_encoder_2_type,
+ reproduce=reproduce,
+ )
+ text_encoder_2.eval()
+
+ # encode prompt
+ logger.info(f"Encoding prompt with text encoder 1")
+ text_encoder.to(device=device)
+ if fp8_llm:
+ with accelerator.autocast():
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
+ else:
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
+ text_encoder = None
+ clean_memory_on_device(device)
+
+ logger.info(f"Encoding prompt with text encoder 2")
+ text_encoder_2.to(device=device)
+ prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)
+
+ prompt_embeds = prompt_embeds.to("cpu")
+ prompt_mask = prompt_mask.to("cpu")
+ prompt_embeds_2 = prompt_embeds_2.to("cpu")
+ prompt_mask_2 = prompt_mask_2.to("cpu")
+
+ text_encoder_2 = None
+ clean_memory_on_device(device)
+
+ return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
+
+
+# endregion
+
+
+def prepare_vae(args, device):
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
+ vae.eval()
+ # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
+
+ # set chunk_size to CausalConv3d recursively
+ chunk_size = args.vae_chunk_size
+ if chunk_size is not None:
+ vae.set_chunk_size_for_causal_conv_3d(chunk_size)
+ logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
+
+ if args.vae_spatial_tile_sample_min_size is not None:
+ vae.enable_spatial_tiling(True)
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
+ # elif args.vae_tiling:
+ else:
+ vae.enable_spatial_tiling(True)
+
+ return vae, vae_dtype
+
+
+def encode_to_latents(args, video, device):
+ vae, vae_dtype = prepare_vae(args, device)
+
+ video = video.to(device=device, dtype=vae_dtype)
+ video = video * 2 - 1 # 0, 1 -> -1, 1
+ with torch.no_grad():
+ latents = vae.encode(video).latent_dist.sample()
+
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
+ latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
+ else:
+ latents = latents * vae.config.scaling_factor
+
+ return latents
+
+
+def decode_latents(args, latents, device):
+ vae, vae_dtype = prepare_vae(args, device)
+
+ expand_temporal_dim = False
+ if len(latents.shape) == 4:
+ latents = latents.unsqueeze(2)
+ expand_temporal_dim = True
+ elif len(latents.shape) == 5:
+ pass
+ else:
+ raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
+
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
+ latents = latents / vae.config.scaling_factor + vae.config.shift_factor
+ else:
+ latents = latents / vae.config.scaling_factor
+
+ latents = latents.to(device=device, dtype=vae_dtype)
+ with torch.no_grad():
+ image = vae.decode(latents, return_dict=False)[0]
+
+ if expand_temporal_dim:
+ image = image.squeeze(2)
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().float()
+
+ return image
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
+
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
+ parser.add_argument(
+ "--dit_in_channels",
+ type=int,
+ default=None,
+ help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others",
+ )
+ parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
+
+ # LoRA
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
+ parser.add_argument(
+ "--save_merged_model",
+ type=str,
+ default=None,
+ help="Save merged model to path. If specified, no inference will be performed.",
+ )
+ parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights")
+
+ # inference
+ parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
+ parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation")
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
+ parser.add_argument("--video_length", type=int, default=129, help="video length")
+ parser.add_argument("--fps", type=int, default=24, help="video fps")
+ parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=1.0,
+ help="Guidance scale for classifier free guidance. Default is 1.0 (means no guidance)",
+ )
+ parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.")
+ parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
+ parser.add_argument(
+ "--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model"
+ )
+ parser.add_argument(
+ "--split_uncond",
+ action="store_true",
+ help="split unconditional call for classifier free guidance, slower but less memory usage",
+ )
+ parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference")
+
+ # Flow Matching
+ parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")
+
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
+ parser.add_argument(
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
+ )
+ parser.add_argument(
+ "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode"
+ )
+ parser.add_argument(
+ "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True"
+ )
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
+ parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
+ parser.add_argument(
+ "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
+ )
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
+ parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arthimetic(RTX 4XXX+)")
+ parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
+ parser.add_argument(
+ "--compile_args",
+ nargs=4,
+ metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
+ default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
+ help="Torch.compile settings",
+ )
+
+ args = parser.parse_args()
+
+ assert (args.latent_path is None or len(args.latent_path) == 0) or (
+ args.output_type == "images" or args.output_type == "video"
+ ), "latent_path is only supported for images or video output"
+
+ # update dit_weight based on model_base if not exists
+
+ if args.fp8_fast and not args.fp8:
+ raise ValueError("--fp8_fast requires --fp8")
+
+ return args
+
+
+def check_inputs(args):
+ height = args.video_size[0]
+ width = args.video_size[1]
+ video_length = args.video_length
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ return height, width, video_length
+
+
+def main():
+ args = parse_args()
+
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+ dit_dtype = torch.bfloat16
+ dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
+ logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
+
+ original_base_names = None
+ if args.latent_path is not None and len(args.latent_path) > 0:
+ original_base_names = []
+ latents_list = []
+ seeds = []
+ for latent_path in args.latent_path:
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
+ seed = 0
+
+ if os.path.splitext(latent_path)[1] != ".safetensors":
+ latents = torch.load(latent_path, map_location="cpu")
+ else:
+ latents = load_file(latent_path)["latent"]
+ with safe_open(latent_path, framework="pt") as f:
+ metadata = f.metadata()
+ if metadata is None:
+ metadata = {}
+ logger.info(f"Loaded metadata: {metadata}")
+
+ if "seeds" in metadata:
+ seed = int(metadata["seeds"])
+
+ seeds.append(seed)
+ latents_list.append(latents)
+
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
+ latents = torch.stack(latents_list, dim=0)
+ else:
+ # prepare accelerator
+ mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
+ accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
+
+ # load prompt
+ prompt = args.prompt # TODO load prompts from file
+ assert prompt is not None, "prompt is required"
+
+ # check inputs: may be height, width, video_length etc will be changed for each generation in future
+ height, width, video_length = check_inputs(args)
+
+ # encode prompt with LLM and Text Encoder
+ logger.info(f"Encoding prompt: {prompt}")
+
+ do_classifier_free_guidance = args.guidance_scale != 1.0
+ if do_classifier_free_guidance:
+ negative_prompt = args.negative_prompt
+ if negative_prompt is None:
+ logger.info("Negative prompt is not provided, using empty prompt")
+ negative_prompt = ""
+ logger.info(f"Encoding negative prompt: {negative_prompt}")
+ prompt = [negative_prompt, prompt]
+ else:
+ if args.negative_prompt is not None:
+ logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.")
+
+ prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
+ prompt, args, device, args.fp8_llm, accelerator
+ )
+
+ # encode latents for video2video inference
+ video_latents = None
+ if args.video_path is not None:
+ # v2v inference
+ logger.info(f"Video2Video inference: {args.video_path}")
+ video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames
+ if len(video) < video_length:
+ raise ValueError(f"Video length is less than {video_length}")
+ video = np.stack(video, axis=0) # F, H, W, C
+ video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W
+ video = video / 255.0
+
+ logger.info(f"Encoding video to latents")
+ video_latents = encode_to_latents(args, video, device)
+ video_latents = video_latents.to(device=device, dtype=dit_dtype)
+
+ clean_memory_on_device(device)
+
+ # encode latents for image2video inference
+ image_latents = None
+ if args.image_path is not None:
+ # i2v inference
+ logger.info(f"Image2Video inference: {args.image_path}")
+
+ image = Image.open(args.image_path)
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W
+ image = image / 255.0
+
+ logger.info(f"Encoding image to latents")
+ image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W
+ image_latents = image_latents.to(device=device, dtype=dit_dtype)
+
+ clean_memory_on_device(device)
+
+ # load DiT model
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
+ loading_device = "cpu" # if blocks_to_swap > 0 else device
+
+ logger.info(f"Loading DiT model from {args.dit}")
+ if args.attn_mode == "sdpa":
+ args.attn_mode = "torch"
+
+ # if image_latents is given, the model should be I2V model, so the in_channels should be 32
+ dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16)
+
+ # if we use LoRA, weigths should be bf16 instead of fp8, because merging should be done in bf16
+ # the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway
+ # on the fly merging will be a solution for this issue for .safetenors files (not implemented yet)
+ transformer = load_transformer(
+ args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels
+ )
+ transformer.eval()
+
+ # load LoRA weights
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
+ for i, lora_weight in enumerate(args.lora_weight):
+ if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
+ lora_multiplier = args.lora_multiplier[i]
+ else:
+ lora_multiplier = 1.0
+
+ logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
+ weights_sd = load_file(lora_weight)
+
+ # Filter to exclude keys that are part of single_blocks
+ if args.exclude_single_blocks:
+ filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k}
+ weights_sd = filtered_weights
+
+ if args.lycoris:
+ lycoris_net, _ = create_network_from_weights(
+ multiplier=lora_multiplier,
+ file=None,
+ weights_sd=weights_sd,
+ unet=transformer,
+ text_encoder=None,
+ vae=None,
+ for_inference=True,
+ )
+ else:
+ network = lora.create_arch_network_from_weights(
+ lora_multiplier, weights_sd, unet=transformer, for_inference=True
+ )
+ logger.info("Merging LoRA weights to DiT model")
+
+ # try:
+ # network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
+ # info = network.load_state_dict(weights_sd, strict=True)
+ # logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
+ # network.eval()
+ # network.to(device)
+ # except Exception as e:
+ if args.lycoris:
+ lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device)
+ else:
+ network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
+
+ synchronize_device(device)
+
+ logger.info("LoRA weights loaded")
+
+ # save model here before casting to dit_weight_dtype
+ if args.save_merged_model:
+ logger.info(f"Saving merged model to {args.save_merged_model}")
+ mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory
+ logger.info("Merged model saved")
+ return
+
+ logger.info(f"Casting model to {dit_weight_dtype}")
+ transformer.to(dtype=dit_weight_dtype)
+
+ if args.fp8_fast:
+ logger.info("Enabling FP8 acceleration")
+ params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"}
+ for name, param in transformer.named_parameters():
+ dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype
+ param.to(dtype=dtype_to_use)
+ convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep)
+
+ if args.compile:
+ compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
+ logger.info(
+ f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
+ )
+ torch._dynamo.config.cache_size_limit = 32
+ for i, block in enumerate(transformer.single_blocks):
+ compiled_block = torch.compile(
+ block,
+ backend=compile_backend,
+ mode=compile_mode,
+ dynamic=compile_dynamic.lower() in "true",
+ fullgraph=compile_fullgraph.lower() in "true",
+ )
+ transformer.single_blocks[i] = compiled_block
+ for i, block in enumerate(transformer.double_blocks):
+ compiled_block = torch.compile(
+ block,
+ backend=compile_backend,
+ mode=compile_mode,
+ dynamic=compile_dynamic.lower() in "true",
+ fullgraph=compile_fullgraph.lower() in "true",
+ )
+ transformer.double_blocks[i] = compiled_block
+
+ if blocks_to_swap > 0:
+ logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
+ transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
+ transformer.move_to_device_except_swap_blocks(device)
+ transformer.prepare_block_swap_before_forward()
+ else:
+ logger.info(f"Moving model to {device}")
+ transformer.to(device=device)
+ if args.img_in_txt_in_offloading:
+ logger.info("Enable offloading img_in and txt_in to CPU")
+ transformer.enable_img_in_txt_in_offloading()
+
+ # load scheduler
+ logger.info(f"Loading scheduler")
+ scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")
+
+ # Prepare timesteps
+ num_inference_steps = args.infer_steps
+ scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler
+ timesteps = scheduler.timesteps
+
+ # Prepare generator
+ num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size
+ seed = args.seed
+ if seed is None:
+ seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)]
+ elif isinstance(seed, int):
+ seeds = [seed + i for i in range(num_videos_per_prompt)]
+ else:
+ raise ValueError(f"Seed must be an integer or None, got {seed}.")
+ generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]
+
+ # Prepare noisy latents
+ num_channels_latents = 16 # transformer.config.in_channels
+ vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4
+
+ vae_ver = vae.VAE_VER
+ if "884" in vae_ver:
+ latent_video_length = (video_length - 1) // 4 + 1
+ elif "888" in vae_ver:
+ latent_video_length = (video_length - 1) // 8 + 1
+ else:
+ latent_video_length = video_length
+
+ # shape = (
+ # num_videos_per_prompt,
+ # num_channels_latents,
+ # latent_video_length,
+ # height // vae_scale_factor,
+ # width // vae_scale_factor,
+ # )
+ # latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)
+
+ # make first N frames to be the same if the given seed is same
+ shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
+ latents = []
+ for i in range(latent_video_length):
+ latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype))
+ latents = torch.cat(latents, dim=2)
+
+ # pad image_latents to match the length of video_latents
+ if image_latents is not None:
+ zero_latents = torch.zeros_like(latents)
+ zero_latents[:, :, :1, :, :] = image_latents
+ image_latents = zero_latents
+
+ if args.video_path is not None:
+ # v2v inference
+ noise = latents
+ assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}"
+
+ num_inference_steps = int(num_inference_steps * args.strength)
+ timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength, less inference steps and more start time
+ t = timestep_start / 1000.0
+ latents = noise * t + video_latents * (1 - t)
+
+ timesteps = timesteps[-num_inference_steps:]
+
+ logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}")
+
+ # FlowMatchDiscreteScheduler does not have init_noise_sigma
+
+ # Denoising loop
+ embedded_guidance_scale = args.embedded_cfg_scale
+ if embedded_guidance_scale is not None:
+ guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
+ guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
+ if do_classifier_free_guidance:
+ guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0)
+ else:
+ guidance_expand = None
+ freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width)
+ # n_tokens = freqs_cos.shape[0]
+
+ # move and cast all inputs to the correct device and dtype
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
+ prompt_mask = prompt_mask.to(device=device)
+ prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
+ prompt_mask_2 = prompt_mask_2.to(device=device)
+
+ freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
+ freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference
+
+ # assert split_uncond and split_attn
+ if args.split_attn and do_classifier_free_guidance and not args.split_uncond:
+ logger.warning("split_attn is enabled, split_uncond will be enabled as well.")
+ args.split_uncond = True
+
+ # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
+ with tqdm(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latents = scheduler.scale_model_input(latents, t)
+
+ # predict the noise residual
+ with torch.no_grad(), accelerator.autocast():
+ latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0)
+ if image_latents is not None:
+ latents_image_input = (
+ image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
+ )
+ latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W
+
+ batch_size = 1 if args.split_uncond else latents_input.shape[0]
+
+ noise_pred_list = []
+ for j in range(0, latents_input.shape[0], batch_size):
+ noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256)
+ latents_input[j : j + batch_size], # [1, 16, 33, 24, 42]
+ t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1]
+ text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096]
+ text_mask=prompt_mask[j : j + batch_size], # [1, 256]
+ text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768]
+ freqs_cos=freqs_cos, # [seqlen, head_dim]
+ freqs_sin=freqs_sin, # [seqlen, head_dim]
+ guidance=guidance_expand[j : j + batch_size], # [1]
+ return_dict=True,
+ )["x"]
+ noise_pred_list.append(noise_pred)
+ noise_pred = torch.cat(noise_pred_list, dim=0)
+
+ # perform classifier free guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ # # SkyReels' rescale noise config is omitted for now
+ # if guidance_rescale > 0.0:
+ # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # noise_pred = rescale_noise_cfg(
+ # noise_pred,
+ # noise_pred_cond,
+ # guidance_rescale=self.guidance_rescale,
+ # )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
+ if progress_bar is not None:
+ progress_bar.update()
+
+ # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
+ # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
+
+ latents = latents.detach().cpu()
+ transformer = None
+ clean_memory_on_device(device)
+
+ # Save samples
+ output_type = args.output_type
+ save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
+
+ if output_type == "latent" or output_type == "both":
+ # save latent
+ for i, latent in enumerate(latents):
+ latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors"
+
+ if args.no_metadata:
+ metadata = None
+ else:
+ metadata = {
+ "seeds": f"{seeds[i]}",
+ "prompt": f"{args.prompt}",
+ "height": f"{height}",
+ "width": f"{width}",
+ "video_length": f"{video_length}",
+ "infer_steps": f"{num_inference_steps}",
+ "guidance_scale": f"{args.guidance_scale}",
+ "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
+ }
+ if args.negative_prompt is not None:
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
+ sd = {"latent": latent}
+ save_file(sd, latent_path, metadata=metadata)
+
+ logger.info(f"Latent save to: {latent_path}")
+ if output_type == "video" or output_type == "both":
+ # save video
+ videos = decode_latents(args, latents, device)
+ for i, sample in enumerate(videos):
+ original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
+ sample = sample.unsqueeze(0)
+ video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4"
+ save_videos_grid(sample, video_path, fps=args.fps)
+ logger.info(f"Sample save to: {video_path}")
+ elif output_type == "images":
+ # save images
+ videos = decode_latents(args, latents, device)
+ for i, sample in enumerate(videos):
+ original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
+ sample = sample.unsqueeze(0)
+ image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}"
+ save_images_grid(sample, save_path, image_name)
+ logger.info(f"Sample images save to: {save_path}/{image_name}")
+
+ logger.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/hv_train.py b/hv_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..501f1648b3bd0df086f0ffc395ea58b146adc8ea
--- /dev/null
+++ b/hv_train.py
@@ -0,0 +1,1721 @@
+import ast
+import asyncio
+from datetime import timedelta
+import gc
+import importlib
+import argparse
+import math
+import os
+import pathlib
+import re
+import sys
+import random
+import time
+import json
+from multiprocessing import Value
+from typing import Any, Dict, List, Optional
+import accelerate
+import numpy as np
+from packaging.version import Version
+
+import huggingface_hub
+import toml
+
+import torch
+from tqdm import tqdm
+from accelerate.utils import set_seed
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+from safetensors.torch import load_file, save_file
+import transformers
+from diffusers.optimization import (
+ SchedulerType as DiffusersSchedulerType,
+ TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
+)
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+
+from dataset import config_utils
+from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape
+import hunyuan_model.text_encoder as text_encoder_module
+from hunyuan_model.vae import load_vae
+import hunyuan_model.vae as vae_module
+from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
+import networks.lora as lora_module
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO
+
+import logging
+
+from utils import huggingface_utils, model_utils, train_utils, sai_model_spec
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video"
+
+# TODO make separate file for some functions to commonize with other scripts
+
+
+def clean_memory_on_device(device: torch.device):
+ r"""
+ Clean memory on the specified device, will be called from training scripts.
+ """
+ gc.collect()
+
+ # device may "cuda" or "cuda:0", so we need to check the type of device
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ if device.type == "xpu":
+ torch.xpu.empty_cache()
+ if device.type == "mps":
+ torch.mps.empty_cache()
+
+
+# for collate_fn: epoch and step is multiprocessing.Value
+class collator_class:
+ def __init__(self, epoch, step, dataset):
+ self.current_epoch = epoch
+ self.current_step = step
+ self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
+
+ def __call__(self, examples):
+ worker_info = torch.utils.data.get_worker_info()
+ # worker_info is None in the main process
+ if worker_info is not None:
+ dataset = worker_info.dataset
+ else:
+ dataset = self.dataset
+
+ # set epoch and step
+ dataset.set_current_epoch(self.current_epoch.value)
+ dataset.set_current_step(self.current_step.value)
+ return examples[0]
+
+
+def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
+ """
+ DeepSpeed is not supported in this script currently.
+ """
+ if args.logging_dir is None:
+ logging_dir = None
+ else:
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
+
+ if args.log_with is None:
+ if logging_dir is not None:
+ log_with = "tensorboard"
+ else:
+ log_with = None
+ else:
+ log_with = args.log_with
+ if log_with in ["tensorboard", "all"]:
+ if logging_dir is None:
+ raise ValueError(
+ "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
+ )
+ if log_with in ["wandb", "all"]:
+ try:
+ import wandb
+ except ImportError:
+ raise ImportError("No wandb / wandb がインストールされていないようです")
+ if logging_dir is not None:
+ os.makedirs(logging_dir, exist_ok=True)
+ os.environ["WANDB_DIR"] = logging_dir
+ if args.wandb_api_key is not None:
+ wandb.login(key=args.wandb_api_key)
+
+ kwargs_handlers = [
+ (
+ InitProcessGroupKwargs(
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+ init_method=(
+ "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
+ ),
+ timeout=timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
+ )
+ if torch.cuda.device_count() > 1
+ else None
+ ),
+ (
+ DistributedDataParallelKwargs(
+ gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
+ )
+ if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
+ else None
+ ),
+ ]
+ kwargs_handlers = [i for i in kwargs_handlers if i is not None]
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=log_with,
+ project_dir=logging_dir,
+ kwargs_handlers=kwargs_handlers,
+ )
+ print("accelerator device:", accelerator.device)
+ return accelerator
+
+
+def line_to_prompt_dict(line: str) -> dict:
+ # subset of gen_img_diffusers
+ prompt_args = line.split(" --")
+ prompt_dict = {}
+ prompt_dict["prompt"] = prompt_args[0]
+
+ for parg in prompt_args:
+ try:
+ m = re.match(r"w (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["width"] = int(m.group(1))
+ continue
+
+ m = re.match(r"h (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["height"] = int(m.group(1))
+ continue
+
+ m = re.match(r"f (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["frame_count"] = int(m.group(1))
+ continue
+
+ m = re.match(r"d (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["seed"] = int(m.group(1))
+ continue
+
+ m = re.match(r"s (\d+)", parg, re.IGNORECASE)
+ if m: # steps
+ prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
+ continue
+
+ # m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
+ # if m: # scale
+ # prompt_dict["scale"] = float(m.group(1))
+ # continue
+ # m = re.match(r"n (.+)", parg, re.IGNORECASE)
+ # if m: # negative prompt
+ # prompt_dict["negative_prompt"] = m.group(1)
+ # continue
+
+ except ValueError as ex:
+ logger.error(f"Exception in parsing / 解析エラー: {parg}")
+ logger.error(ex)
+
+ return prompt_dict
+
+
+def load_prompts(prompt_file: str) -> list[Dict]:
+ # read prompts
+ if prompt_file.endswith(".txt"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+ prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
+ elif prompt_file.endswith(".toml"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ data = toml.load(f)
+ prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
+ elif prompt_file.endswith(".json"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ prompts = json.load(f)
+
+ # preprocess prompts
+ for i in range(len(prompts)):
+ prompt_dict = prompts[i]
+ if isinstance(prompt_dict, str):
+ prompt_dict = line_to_prompt_dict(prompt_dict)
+ prompts[i] = prompt_dict
+ assert isinstance(prompt_dict, dict)
+
+ # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
+ prompt_dict["enum"] = i
+ prompt_dict.pop("subset", None)
+
+ return prompts
+
+
+def compute_density_for_timestep_sampling(
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+):
+ """Compute the density for sampling the timesteps when doing SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "logit_normal":
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+ u = torch.nn.functional.sigmoid(u)
+ elif weighting_scheme == "mode":
+ u = torch.rand(size=(batch_size,), device="cpu")
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+ else:
+ u = torch.rand(size=(batch_size,), device="cpu")
+ return u
+
+
+def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
+ timesteps = timesteps.to(device)
+
+ # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
+ if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
+ # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
+ # round to nearest timestep
+ logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
+ step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
+ else:
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+
+def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
+ """Computes loss weighting scheme for SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
+ if weighting_scheme == "sigma_sqrt":
+ weighting = (sigmas**-2.0).float()
+ else:
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
+ weighting = 2 / (math.pi * bot)
+ else:
+ weighting = None # torch.ones_like(sigmas)
+ return weighting
+
+
+class FineTuningTrainer:
+ def __init__(self):
+ pass
+
+ def process_sample_prompts(
+ self,
+ args: argparse.Namespace,
+ accelerator: Accelerator,
+ sample_prompts: str,
+ text_encoder1: str,
+ text_encoder2: str,
+ fp8_llm: bool,
+ ):
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
+ prompts = load_prompts(sample_prompts)
+
+ def encode_for_text_encoder(text_encoder, is_llm=True):
+ sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
+ with accelerator.autocast(), torch.no_grad():
+ for prompt_dict in prompts:
+ for p in [prompt_dict.get("prompt", "")]:
+ if p not in sample_prompts_te_outputs:
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
+
+ data_type = "video"
+ text_inputs = text_encoder.text2tokens(p, data_type=data_type)
+
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
+ sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)
+
+ return sample_prompts_te_outputs
+
+ # Load Text Encoder 1 and encode
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
+ logger.info(f"loading text encoder 1: {text_encoder1}")
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)
+
+ logger.info("encoding with Text Encoder 1")
+ te_outputs_1 = encode_for_text_encoder(text_encoder_1)
+ del text_encoder_1
+
+ # Load Text Encoder 2 and encode
+ logger.info(f"loading text encoder 2: {text_encoder2}")
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)
+
+ logger.info("encoding with Text Encoder 2")
+ te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False)
+ del text_encoder_2
+
+ # prepare sample parameters
+ sample_parameters = []
+ for prompt_dict in prompts:
+ prompt_dict_copy = prompt_dict.copy()
+ p = prompt_dict.get("prompt", "")
+ prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
+ prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
+ prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
+ prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
+ sample_parameters.append(prompt_dict_copy)
+
+ clean_memory_on_device(accelerator.device)
+
+ return sample_parameters
+
+ def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
+ # adamw, adamw8bit, adafactor
+
+ optimizer_type = args.optimizer_type.lower()
+
+ # split optimizer_type and optimizer_args
+ optimizer_kwargs = {}
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
+ for arg in args.optimizer_args:
+ key, value = arg.split("=")
+ value = ast.literal_eval(value)
+ optimizer_kwargs[key] = value
+
+ lr = args.learning_rate
+ optimizer = None
+ optimizer_class = None
+
+ if optimizer_type.endswith("8bit".lower()):
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+
+ if optimizer_type == "AdamW8bit".lower():
+ logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
+ optimizer_class = bnb.optim.AdamW8bit
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ elif optimizer_type == "Adafactor".lower():
+ # Adafactor: check relative_step and warmup_init
+ if "relative_step" not in optimizer_kwargs:
+ optimizer_kwargs["relative_step"] = True # default
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
+ logger.info(
+ f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
+ )
+ optimizer_kwargs["relative_step"] = True
+ logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
+
+ if optimizer_kwargs["relative_step"]:
+ logger.info(f"relative_step is true / relative_stepがtrueです")
+ if lr != 0.0:
+ logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
+ args.learning_rate = None
+
+ if args.lr_scheduler != "adafactor":
+ logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+ args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
+
+ lr = None
+ else:
+ if args.max_grad_norm != 0.0:
+ logger.warning(
+ f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
+ )
+ if args.lr_scheduler != "constant_with_warmup":
+ logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
+ logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
+
+ optimizer_class = transformers.optimization.Adafactor
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ elif optimizer_type == "AdamW".lower():
+ logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
+ optimizer_class = torch.optim.AdamW
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ if optimizer is None:
+ # 任意のoptimizerを使う
+ case_sensitive_optimizer_type = args.optimizer_type # not lower
+ logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
+
+ if "." not in case_sensitive_optimizer_type: # from torch.optim
+ optimizer_module = torch.optim
+ else: # from other library
+ values = case_sensitive_optimizer_type.split(".")
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
+ case_sensitive_optimizer_type = values[-1]
+
+ optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ # for logging
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
+
+ # get train and eval functions
+ if hasattr(optimizer, "train") and callable(optimizer.train):
+ train_fn = optimizer.train
+ eval_fn = optimizer.eval
+ else:
+ train_fn = lambda: None
+ eval_fn = lambda: None
+
+ return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
+
+ def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
+ return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
+
+ def get_dummy_scheduler(optimizer: torch.optim.Optimizer) -> Any:
+ # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
+ # this scheduler is used for logging only.
+ # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
+ class DummyScheduler:
+ def __init__(self, optimizer: torch.optim.Optimizer):
+ self.optimizer = optimizer
+
+ def step(self):
+ pass
+
+ def get_last_lr(self):
+ return [group["lr"] for group in self.optimizer.param_groups]
+
+ return DummyScheduler(optimizer)
+
+ def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
+ """
+ Unified API to get any scheduler from its name.
+ """
+ # if schedulefree optimizer, return dummy scheduler
+ if self.is_schedulefree_optimizer(optimizer, args):
+ return self.get_dummy_scheduler(optimizer)
+
+ name = args.lr_scheduler
+ num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
+ num_warmup_steps: Optional[int] = (
+ int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+ )
+ num_decay_steps: Optional[int] = (
+ int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+ )
+ num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
+ num_cycles = args.lr_scheduler_num_cycles
+ power = args.lr_scheduler_power
+ timescale = args.lr_scheduler_timescale
+ min_lr_ratio = args.lr_scheduler_min_lr_ratio
+
+ lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
+ if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
+ for arg in args.lr_scheduler_args:
+ key, value = arg.split("=")
+ value = ast.literal_eval(value)
+ lr_scheduler_kwargs[key] = value
+
+ def wrap_check_needless_num_warmup_steps(return_vals):
+ if num_warmup_steps is not None and num_warmup_steps != 0:
+ raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
+ return return_vals
+
+ # using any lr_scheduler from other library
+ if args.lr_scheduler_type:
+ lr_scheduler_type = args.lr_scheduler_type
+ logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+ if "." not in lr_scheduler_type: # default to use torch.optim
+ lr_scheduler_module = torch.optim.lr_scheduler
+ else:
+ values = lr_scheduler_type.split(".")
+ lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
+ lr_scheduler_type = values[-1]
+ lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
+ lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+ return lr_scheduler
+
+ if name.startswith("adafactor"):
+ assert (
+ type(optimizer) == transformers.optimization.Adafactor
+ ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
+ initial_lr = float(name.split(":")[1])
+ # logger.info(f"adafactor scheduler init lr {initial_lr}")
+ return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
+
+ if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
+ name = DiffusersSchedulerType(name)
+ schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
+ return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
+
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+ if name == SchedulerType.CONSTANT:
+ return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
+
+ if name == SchedulerType.INVERSE_SQRT:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ **lr_scheduler_kwargs,
+ )
+
+ if name == SchedulerType.POLYNOMIAL:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ power=power,
+ **lr_scheduler_kwargs,
+ )
+
+ if name == SchedulerType.COSINE_WITH_MIN_LR:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles / 2,
+ min_lr_rate=min_lr_ratio,
+ **lr_scheduler_kwargs,
+ )
+
+ # these schedulers do not require `num_decay_steps`
+ if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ **lr_scheduler_kwargs,
+ )
+
+ # All other schedulers require `num_decay_steps`
+ if num_decay_steps is None:
+ raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+ if name == SchedulerType.WARMUP_STABLE_DECAY:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_stable_steps=num_stable_steps,
+ num_decay_steps=num_decay_steps,
+ num_cycles=num_cycles / 2,
+ min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+ **lr_scheduler_kwargs,
+ )
+
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_decay_steps=num_decay_steps,
+ **lr_scheduler_kwargs,
+ )
+
+ def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
+ if not args.resume:
+ return False
+
+ if not args.resume_from_huggingface:
+ logger.info(f"resume training from local state: {args.resume}")
+ accelerator.load_state(args.resume)
+ return True
+
+ logger.info(f"resume training from huggingface state: {args.resume}")
+ repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
+ path_in_repo = "/".join(args.resume.split("/")[2:])
+ revision = None
+ repo_type = None
+ if ":" in path_in_repo:
+ divided = path_in_repo.split(":")
+ if len(divided) == 2:
+ path_in_repo, revision = divided
+ repo_type = "model"
+ else:
+ path_in_repo, revision, repo_type = divided
+ logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
+
+ list_files = huggingface_utils.list_dir(
+ repo_id=repo_id,
+ subfolder=path_in_repo,
+ revision=revision,
+ token=args.huggingface_token,
+ repo_type=repo_type,
+ )
+
+ async def download(filename) -> str:
+ def task():
+ return huggingface_hub.hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ revision=revision,
+ repo_type=repo_type,
+ token=args.huggingface_token,
+ )
+
+ return await asyncio.get_event_loop().run_in_executor(None, task)
+
+ loop = asyncio.get_event_loop()
+ results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
+ if len(results) == 0:
+ raise ValueError(
+ "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
+ )
+ dirname = os.path.dirname(results[0])
+ accelerator.load_state(dirname)
+
+ return True
+
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters):
+ pass
+
+ def get_noisy_model_input_and_timesteps(
+ self,
+ args: argparse.Namespace,
+ noise: torch.Tensor,
+ latents: torch.Tensor,
+ noise_scheduler: FlowMatchDiscreteScheduler,
+ device: torch.device,
+ dtype: torch.dtype,
+ ):
+ batch_size = noise.shape[0]
+
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
+ # Simple random t-based noise sampling
+ if args.timestep_sampling == "sigmoid":
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
+ else:
+ t = torch.rand((batch_size,), device=device)
+
+ elif args.timestep_sampling == "shift":
+ shift = args.discrete_flow_shift
+ logits_norm = torch.randn(batch_size, device=device)
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
+ t = logits_norm.sigmoid()
+ t = (t * shift) / (1 + (shift - 1) * t)
+
+ t_min = args.min_timestep if args.min_timestep is not None else 0
+ t_max = args.max_timestep if args.max_timestep is not None else 1000.0
+ t_min /= 1000.0
+ t_max /= 1000.0
+ t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1]
+
+ timesteps = t * 1000.0
+ t = t.view(-1, 1, 1, 1, 1)
+ noisy_model_input = (1 - t) * latents + t * noise
+
+ timesteps += 1 # 1 to 1000
+ else:
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=batch_size,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ # indices = (u * noise_scheduler.config.num_train_timesteps).long()
+ t_min = args.min_timestep if args.min_timestep is not None else 0
+ t_max = args.max_timestep if args.max_timestep is not None else 1000
+ indices = (u * (t_max - t_min) + t_min).long()
+
+ timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000
+
+ # Add noise according to flow matching.
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
+
+ return noisy_model_input, timesteps
+
+ def train(self, args):
+ if args.seed is None:
+ args.seed = random.randint(0, 2**32)
+ set_seed(args.seed)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
+
+ current_epoch = Value("i", 0)
+ current_step = Value("i", 0)
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+ collator = collator_class(current_epoch, current_step, ds_for_collator)
+
+ # prepare accelerator
+ logger.info("preparing accelerator")
+ accelerator = prepare_accelerator(args)
+ is_main_process = accelerator.is_main_process
+
+ # prepare dtype
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # HunyuanVideo specific
+ vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)
+
+ # get embedding for sampling images
+ sample_parameters = vae = None
+ if args.sample_prompts:
+ sample_parameters = self.process_sample_prompts(
+ args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm
+ )
+
+ # Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae)
+ vae.requires_grad_(False)
+ vae.eval()
+
+ if args.vae_chunk_size is not None:
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
+ if args.vae_spatial_tile_sample_min_size is not None:
+ vae.enable_spatial_tiling(True)
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
+ elif args.vae_tiling:
+ vae.enable_spatial_tiling(True)
+
+ # load DiT model
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
+ loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
+
+ logger.info(f"Loading DiT model from {args.dit}")
+ if args.sdpa:
+ attn_mode = "torch"
+ elif args.flash_attn:
+ attn_mode = "flash"
+ elif args.sage_attn:
+ attn_mode = "sageattn"
+ elif args.xformers:
+ attn_mode = "xformers"
+ else:
+ raise ValueError(
+ f"either --sdpa, --flash-attn, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --sage-attn, --xformersのいずれかを指定してください"
+ )
+ transformer = load_transformer(
+ args.dit, attn_mode, args.split_attn, loading_device, None, in_channels=args.dit_in_channels
+ ) # load as is
+
+ if blocks_to_swap > 0:
+ logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
+ transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
+ transformer.move_to_device_except_swap_blocks(accelerator.device)
+ if args.img_in_txt_in_offloading:
+ logger.info("Enable offloading img_in and txt_in to CPU")
+ transformer.enable_img_in_txt_in_offloading()
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ # prepare optimizer, data loader etc.
+ accelerator.print("prepare optimizer, data loader etc.")
+
+ transformer.requires_grad_(False)
+ if accelerator.is_main_process:
+ accelerator.print(f"Trainable modules '{args.trainable_modules}'.")
+ for name, param in transformer.named_parameters():
+ for trainable_module_name in args.trainable_modules:
+ if trainable_module_name in name:
+ param.requires_grad = True
+ break
+
+ total_params = list(transformer.parameters())
+ trainable_params = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+ logger.info(
+ f"number of trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6} M, total paramters: {sum(p.numel() for p in total_params) / 1e6} M"
+ )
+ optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
+ args, trainable_params
+ )
+
+ # prepare dataloader
+
+ # num workers for data loader: if 0, persistent_workers is not available
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset_group,
+ batch_size=1,
+ shuffle=True,
+ collate_fn=collator,
+ num_workers=n_workers,
+ persistent_workers=args.persistent_data_loader_workers,
+ )
+
+ # calculate max_train_steps
+ if args.max_train_epochs is not None:
+ args.max_train_steps = args.max_train_epochs * math.ceil(
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+ )
+ accelerator.print(
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+ )
+
+ # send max_train_steps to train_dataset_group
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
+
+ # prepare lr_scheduler
+ lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes)
+
+ # prepare training model. accelerator does some magic here
+
+ # experimental feature: train the model with gradients in fp16/bf16
+ dit_dtype = torch.float32
+ if args.full_fp16:
+ assert (
+ args.mixed_precision == "fp16"
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+ accelerator.print("enable full fp16 training.")
+ dit_weight_dtype = torch.float16
+ elif args.full_bf16:
+ assert (
+ args.mixed_precision == "bf16"
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+ accelerator.print("enable full bf16 training.")
+ dit_weight_dtype = torch.bfloat16
+ else:
+ dit_weight_dtype = torch.float32
+
+ # TODO add fused optimizer and stochastic rounding
+
+ # cast model to dit_weight_dtype
+ # if dit_dtype != dit_weight_dtype:
+ logger.info(f"casting model to {dit_weight_dtype}")
+ transformer.to(dit_weight_dtype)
+
+ if blocks_to_swap > 0:
+ transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
+ accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
+ accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
+ else:
+ transformer = accelerator.prepare(transformer)
+
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+
+ transformer.train()
+
+ if args.full_fp16:
+ # patch accelerator for fp16 training
+ # def patch_accelerator_for_fp16_training(accelerator):
+ org_unscale_grads = accelerator.scaler._unscale_grads_
+
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
+
+ # resume from local or huggingface. accelerator.step is set
+ self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume)
+
+ # epoch数を計算する
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # 学習する
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ accelerator.print("running training / 学習開始")
+ accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
+ accelerator.print(
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
+ )
+ # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+ if accelerator.is_main_process:
+ init_kwargs = {}
+ if args.wandb_run_name:
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
+ if args.log_tracker_config is not None:
+ init_kwargs = toml.load(args.log_tracker_config)
+ accelerator.init_trackers(
+ "hunyuan_video_ft" if args.log_tracker_name is None else args.log_tracker_name,
+ config=train_utils.get_sanitized_config_or_none(args),
+ init_kwargs=init_kwargs,
+ )
+
+ # TODO skip until initial step
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+
+ epoch_to_start = 0
+ global_step = 0
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
+
+ loss_recorder = train_utils.LossRecorder()
+ del train_dataset_group
+
+ # function for saving/removing
+ def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
+ os.makedirs(args.output_dir, exist_ok=True)
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
+
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
+
+ title = args.metadata_title if args.metadata_title is not None else args.output_name
+ if args.min_timestep is not None or args.max_timestep is not None:
+ min_time_step = args.min_timestep if args.min_timestep is not None else 0
+ max_time_step = args.max_timestep if args.max_timestep is not None else 1000
+ md_timesteps = (min_time_step, max_time_step)
+ else:
+ md_timesteps = None
+
+ sai_metadata = sai_model_spec.build_metadata(
+ None,
+ ARCHITECTURE_HUNYUAN_VIDEO,
+ time.time(),
+ title,
+ None,
+ args.metadata_author,
+ args.metadata_description,
+ args.metadata_license,
+ args.metadata_tags,
+ timesteps=md_timesteps,
+ is_lora=False,
+ )
+
+ save_file(unwrapped_nw.state_dict(), ckpt_file, sai_metadata)
+ if args.huggingface_repo_id is not None:
+ huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
+
+ def remove_model(old_ckpt_name):
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
+ if os.path.exists(old_ckpt_file):
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
+ os.remove(old_ckpt_file)
+
+ # For --sample_at_first
+ optimizer_eval_fn()
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters)
+ optimizer_train_fn()
+ if len(accelerator.trackers) > 0:
+ # log empty object to commit the sample images to wandb
+ accelerator.log({}, step=0)
+
+ # training loop
+
+ # log device and dtype for each model
+ logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
+
+ clean_memory_on_device(accelerator.device)
+
+ pos_embed_cache = {}
+
+ for epoch in range(epoch_to_start, num_train_epochs):
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
+ current_epoch.value = epoch + 1
+
+ for step, batch in enumerate(train_dataloader):
+ latents, llm_embeds, llm_mask, clip_embeds = batch
+ bsz = latents.shape[0]
+ current_step.value = global_step
+
+ with accelerator.accumulate(transformer):
+ latents = latents * vae_module.SCALING_FACTOR
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+
+ # calculate model input and timesteps
+ noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
+ args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
+ )
+
+ weighting = compute_loss_weighting_for_sd3(
+ args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
+ )
+
+ # ensure guidance_scale in args is float
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype)
+
+ # ensure the hidden state will require grad
+ if args.gradient_checkpointing:
+ noisy_model_input.requires_grad_(True)
+ guidance_vec.requires_grad_(True)
+
+ pos_emb_shape = latents.shape[1:]
+ if pos_emb_shape not in pos_embed_cache:
+ freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(
+ accelerator.unwrap_model(transformer), latents.shape[2:]
+ )
+ # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
+ # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
+ pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
+ else:
+ freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]
+
+ # call DiT
+ latents = latents.to(device=accelerator.device, dtype=dit_dtype)
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=dit_dtype)
+ # timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype)
+ # llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype)
+ # llm_mask = llm_mask.to(device=accelerator.device)
+ # clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype)
+ with accelerator.autocast():
+ model_pred = transformer(
+ noisy_model_input,
+ timesteps,
+ text_states=llm_embeds,
+ text_mask=llm_mask,
+ text_states_2=clip_embeds,
+ freqs_cos=freqs_cos,
+ freqs_sin=freqs_sin,
+ guidance=guidance_vec,
+ return_dict=False,
+ )
+
+ # flow matching loss
+ target = noise - latents
+
+ loss = torch.nn.functional.mse_loss(model_pred.to(dit_dtype), target, reduction="none")
+
+ if weighting is not None:
+ loss = loss * weighting
+ # loss = loss.mean([1, 2, 3])
+ # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
+ # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
+
+ loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ # self.all_reduce_network(accelerator, network) # sync DDP grad manually
+ state = accelerate.PartialState()
+ if state.distributed_type != accelerate.DistributedType.NO:
+ for param in transformer.parameters():
+ if param.grad is not None:
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
+
+ if args.max_grad_norm != 0.0:
+ params_to_clip = accelerator.unwrap_model(transformer).parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ optimizer_eval_fn()
+ self.sample_images(
+ accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters
+ )
+
+ # 指定ステップごとにモデルを保存
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
+ save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch)
+
+ if args.save_state:
+ train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)
+
+ remove_step_no = train_utils.get_remove_step_no(args, global_step)
+ if remove_step_no is not None:
+ remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
+ remove_model(remove_ckpt_name)
+ optimizer_train_fn()
+
+ current_loss = loss.detach().item()
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+ avr_loss: float = loss_recorder.moving_average
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if len(accelerator.trackers) > 0:
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if len(accelerator.trackers) > 0:
+ logs = {"loss/epoch": loss_recorder.moving_average}
+ accelerator.log(logs, step=epoch + 1)
+
+ accelerator.wait_for_everyone()
+
+ # 指定エポックごとにモデルを保存
+ optimizer_eval_fn()
+ if args.save_every_n_epochs is not None:
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
+ if is_main_process and saving:
+ ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
+ save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch + 1)
+
+ remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
+ if remove_epoch_no is not None:
+ remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
+ remove_model(remove_ckpt_name)
+
+ if args.save_state:
+ train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
+
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters)
+ optimizer_train_fn()
+
+ # end of epoch
+
+ if is_main_process:
+ transformer = accelerator.unwrap_model(transformer)
+
+ accelerator.end_training()
+ optimizer_eval_fn()
+
+ if args.save_state or args.save_state_on_train_end:
+ train_utils.save_state_on_train_end(args, accelerator)
+
+ if is_main_process:
+ ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
+ save_model(ckpt_name, transformer, global_step, num_train_epochs, force_sync_upload=True)
+
+ logger.info("model saved.")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+ def int_or_float(value):
+ if value.endswith("%"):
+ try:
+ return float(value[:-1]) / 100.0
+ except ValueError:
+ raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+ try:
+ float_value = float(value)
+ if float_value >= 1 and float_value.is_integer():
+ return int(value)
+ return float(value)
+ except ValueError:
+ raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
+ parser = argparse.ArgumentParser()
+
+ # general settings
+ parser.add_argument(
+ "--config_file",
+ type=str,
+ default=None,
+ help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
+ )
+ parser.add_argument(
+ "--dataset_config",
+ type=pathlib.Path,
+ default=None,
+ required=True,
+ help="config file for dataset / データセットの設定ファイル",
+ )
+
+ # training settings
+ parser.add_argument(
+ "--sdpa",
+ action="store_true",
+ help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
+ )
+ parser.add_argument(
+ "--flash_attn",
+ action="store_true",
+ help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
+ )
+ parser.add_argument(
+ "--sage_attn",
+ action="store_true",
+ help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
+ )
+ parser.add_argument(
+ "--xformers",
+ action="store_true",
+ help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要",
+ )
+ parser.add_argument(
+ "--split_attn",
+ action="store_true",
+ help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)"
+ " / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)",
+ )
+
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
+ parser.add_argument(
+ "--max_train_epochs",
+ type=int,
+ default=None,
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
+ )
+ parser.add_argument(
+ "--max_data_loader_n_workers",
+ type=int,
+ default=8,
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
+ )
+ parser.add_argument(
+ "--persistent_data_loader_workers",
+ action="store_true",
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
+ parser.add_argument(
+ "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help="use mixed precision / 混合精度を使う場合、その精度",
+ )
+ parser.add_argument("--trainable_modules", nargs="+", default=".", help="Enter a list of trainable modules")
+
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default=None,
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
+ )
+ parser.add_argument(
+ "--log_with",
+ type=str,
+ default=None,
+ choices=["tensorboard", "wandb", "all"],
+ help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
+ )
+ parser.add_argument(
+ "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
+ )
+ parser.add_argument(
+ "--log_tracker_name",
+ type=str,
+ default=None,
+ help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
+ )
+ parser.add_argument(
+ "--wandb_run_name",
+ type=str,
+ default=None,
+ help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
+ )
+ parser.add_argument(
+ "--log_tracker_config",
+ type=str,
+ default=None,
+ help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
+ )
+ parser.add_argument(
+ "--wandb_api_key",
+ type=str,
+ default=None,
+ help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
+ )
+ parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
+
+ parser.add_argument(
+ "--ddp_timeout",
+ type=int,
+ default=None,
+ help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
+ )
+ parser.add_argument(
+ "--ddp_gradient_as_bucket_view",
+ action="store_true",
+ help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
+ )
+ parser.add_argument(
+ "--ddp_static_graph",
+ action="store_true",
+ help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
+ )
+
+ parser.add_argument(
+ "--sample_every_n_steps",
+ type=int,
+ default=None,
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
+ )
+ parser.add_argument(
+ "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
+ )
+ parser.add_argument(
+ "--sample_every_n_epochs",
+ type=int,
+ default=None,
+ help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
+ )
+ parser.add_argument(
+ "--sample_prompts",
+ type=str,
+ default=None,
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
+ )
+
+ # optimizer and lr scheduler settings
+ parser.add_argument(
+ "--optimizer_type",
+ type=str,
+ default="",
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
+ "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc. / ",
+ )
+ parser.add_argument(
+ "--optimizer_args",
+ type=str,
+ default=None,
+ nargs="*",
+ help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
+ )
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
+ parser.add_argument(
+ "--max_grad_norm",
+ default=1.0,
+ type=float,
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
+ )
+
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps",
+ type=int_or_float,
+ default=0,
+ help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
+ " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
+ )
+ parser.add_argument(
+ "--lr_decay_steps",
+ type=int_or_float,
+ default=0,
+ help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
+ " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
+ )
+ parser.add_argument(
+ "--lr_scheduler_num_cycles",
+ type=int,
+ default=1,
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
+ )
+ parser.add_argument(
+ "--lr_scheduler_power",
+ type=float,
+ default=1,
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
+ )
+ parser.add_argument(
+ "--lr_scheduler_timescale",
+ type=int,
+ default=None,
+ help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`"
+ + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
+ )
+ parser.add_argument(
+ "--lr_scheduler_min_lr_ratio",
+ type=float,
+ default=None,
+ help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
+ + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
+ )
+ parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
+ parser.add_argument(
+ "--lr_scheduler_args",
+ type=str,
+ default=None,
+ nargs="*",
+ help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")',
+ )
+
+ # model settings
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス")
+ parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16")
+ parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
+ parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス")
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
+ parser.add_argument(
+ "--vae_tiling",
+ action="store_true",
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled."
+ " / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。",
+ )
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
+ parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
+ parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
+
+ parser.add_argument(
+ "--blocks_to_swap",
+ type=int,
+ default=None,
+ help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX",
+ )
+ parser.add_argument(
+ "--img_in_txt_in_offloading",
+ action="store_true",
+ help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする",
+ )
+
+ # parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers")
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale.")
+ parser.add_argument(
+ "--timestep_sampling",
+ choices=["sigma", "uniform", "sigmoid", "shift"],
+ default="sigma",
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
+ )
+ parser.add_argument(
+ "--discrete_flow_shift",
+ type=float,
+ default=1.0,
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
+ )
+ parser.add_argument(
+ "--sigmoid_scale",
+ type=float,
+ default=1.0,
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。',
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"],
+ help="weighting scheme for timestep distribution. Default is none"
+ " / タイムステップ分布の重み付けスキーム、デフォルトはnone",
+ )
+ parser.add_argument(
+ "--logit_mean",
+ type=float,
+ default=0.0,
+ help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
+ )
+ parser.add_argument(
+ "--logit_std",
+ type=float,
+ default=1.0,
+ help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
+ )
+ parser.add_argument(
+ "--min_timestep",
+ type=int,
+ default=None,
+ help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
+ )
+ parser.add_argument(
+ "--max_timestep",
+ type=int,
+ default=None,
+ help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
+ )
+
+ # save and load settings
+ parser.add_argument(
+ "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
+ )
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ default=None,
+ required=True,
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名",
+ )
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
+
+ parser.add_argument(
+ "--save_every_n_epochs",
+ type=int,
+ default=None,
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
+ )
+ parser.add_argument(
+ "--save_every_n_steps",
+ type=int,
+ default=None,
+ help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
+ )
+ parser.add_argument(
+ "--save_last_n_epochs",
+ type=int,
+ default=None,
+ help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
+ )
+ parser.add_argument(
+ "--save_last_n_epochs_state",
+ type=int,
+ default=None,
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
+ )
+ parser.add_argument(
+ "--save_last_n_steps",
+ type=int,
+ default=None,
+ help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
+ )
+ parser.add_argument(
+ "--save_last_n_steps_state",
+ type=int,
+ default=None,
+ help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
+ )
+ parser.add_argument(
+ "--save_state",
+ action="store_true",
+ help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
+ )
+ parser.add_argument(
+ "--save_state_on_train_end",
+ action="store_true",
+ help="save training state (including optimizer states etc.) on train end even if --save_state is not specified"
+ " / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する",
+ )
+
+ # SAI Model spec
+ parser.add_argument(
+ "--metadata_title",
+ type=str,
+ default=None,
+ help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
+ )
+ parser.add_argument(
+ "--metadata_author",
+ type=str,
+ default=None,
+ help="author name for model metadata / メタデータに書き込まれるモデル作者名",
+ )
+ parser.add_argument(
+ "--metadata_description",
+ type=str,
+ default=None,
+ help="description for model metadata / メタデータに書き込まれるモデル説明",
+ )
+ parser.add_argument(
+ "--metadata_license",
+ type=str,
+ default=None,
+ help="license for model metadata / メタデータに書き込まれるモデルライセンス",
+ )
+ parser.add_argument(
+ "--metadata_tags",
+ type=str,
+ default=None,
+ help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
+ )
+
+ # huggingface settings
+ parser.add_argument(
+ "--huggingface_repo_id",
+ type=str,
+ default=None,
+ help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
+ )
+ parser.add_argument(
+ "--huggingface_repo_type",
+ type=str,
+ default=None,
+ help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
+ )
+ parser.add_argument(
+ "--huggingface_path_in_repo",
+ type=str,
+ default=None,
+ help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
+ )
+ parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
+ parser.add_argument(
+ "--huggingface_repo_visibility",
+ type=str,
+ default=None,
+ help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
+ )
+ parser.add_argument(
+ "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
+ )
+ parser.add_argument(
+ "--resume_from_huggingface",
+ action="store_true",
+ help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
+ )
+ parser.add_argument(
+ "--async_upload",
+ action="store_true",
+ help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
+ )
+
+ return parser
+
+
+def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
+ if not args.config_file:
+ return args
+
+ config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
+
+ if not os.path.exists(config_path):
+ logger.info(f"{config_path} not found.")
+ exit(1)
+
+ logger.info(f"Loading settings from {config_path}...")
+ with open(config_path, "r", encoding="utf-8") as f:
+ config_dict = toml.load(f)
+
+ # combine all sections into one
+ ignore_nesting_dict = {}
+ for section_name, section_dict in config_dict.items():
+ # if value is not dict, save key and value as is
+ if not isinstance(section_dict, dict):
+ ignore_nesting_dict[section_name] = section_dict
+ continue
+
+ # if value is dict, save all key and value into one dict
+ for key, value in section_dict.items():
+ ignore_nesting_dict[key] = value
+
+ config_args = argparse.Namespace(**ignore_nesting_dict)
+ args = parser.parse_args(namespace=config_args)
+ args.config_file = os.path.splitext(args.config_file)[0]
+ logger.info(args.config_file)
+
+ return args
+
+
+if __name__ == "__main__":
+ parser = setup_parser()
+
+ args = parser.parse_args()
+ args = read_config_from_file(args, parser)
+
+ trainer = FineTuningTrainer()
+ trainer.train(args)
diff --git a/hv_train_network.py b/hv_train_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..43615f087a1e8460e38690d2fdabe16a3def19ed
--- /dev/null
+++ b/hv_train_network.py
@@ -0,0 +1,2693 @@
+import ast
+import asyncio
+from datetime import timedelta
+import gc
+import importlib
+import argparse
+import math
+import os
+import pathlib
+import re
+import sys
+import random
+import time
+import json
+from multiprocessing import Value
+from typing import Any, Dict, List, Optional
+import accelerate
+import numpy as np
+from packaging.version import Version
+from PIL import Image
+
+import huggingface_hub
+import toml
+
+import torch
+from tqdm import tqdm
+from accelerate.utils import TorchDynamoPlugin, set_seed, DynamoBackend
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
+from safetensors.torch import load_file
+import transformers
+from diffusers.optimization import (
+ SchedulerType as DiffusersSchedulerType,
+ TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
+)
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+
+from dataset import config_utils
+from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape, HYVideoDiffusionTransformer
+import hunyuan_model.text_encoder as text_encoder_module
+from hunyuan_model.vae import load_vae, VAE_VER
+import hunyuan_model.vae as vae_module
+from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
+import networks.lora as lora_module
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, ARCHITECTURE_HUNYUAN_VIDEO_FULL
+from hv_generate_video import save_images_grid, save_videos_grid, resize_image_to_bucket, encode_to_latents
+
+import logging
+
+from utils import huggingface_utils, model_utils, train_utils, sai_model_spec
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
+SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
+SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
+SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
+SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"
+
+SS_METADATA_MINIMUM_KEYS = [
+ SS_METADATA_KEY_BASE_MODEL_VERSION,
+ SS_METADATA_KEY_NETWORK_MODULE,
+ SS_METADATA_KEY_NETWORK_DIM,
+ SS_METADATA_KEY_NETWORK_ALPHA,
+ SS_METADATA_KEY_NETWORK_ARGS,
+]
+
+
+def clean_memory_on_device(device: torch.device):
+ r"""
+ Clean memory on the specified device, will be called from training scripts.
+ """
+ gc.collect()
+
+ # device may "cuda" or "cuda:0", so we need to check the type of device
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ if device.type == "xpu":
+ torch.xpu.empty_cache()
+ if device.type == "mps":
+ torch.mps.empty_cache()
+
+
+# for collate_fn: epoch and step is multiprocessing.Value
+class collator_class:
+ def __init__(self, epoch, step, dataset):
+ self.current_epoch = epoch
+ self.current_step = step
+ self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
+
+ def __call__(self, examples):
+ worker_info = torch.utils.data.get_worker_info()
+ # worker_info is None in the main process
+ if worker_info is not None:
+ dataset = worker_info.dataset
+ else:
+ dataset = self.dataset
+
+ # set epoch and step
+ dataset.set_current_epoch(self.current_epoch.value)
+ dataset.set_current_step(self.current_step.value)
+ return examples[0]
+
+
+def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
+ """
+ DeepSpeed is not supported in this script currently.
+ """
+ if args.logging_dir is None:
+ logging_dir = None
+ else:
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
+
+ if args.log_with is None:
+ if logging_dir is not None:
+ log_with = "tensorboard"
+ else:
+ log_with = None
+ else:
+ log_with = args.log_with
+ if log_with in ["tensorboard", "all"]:
+ if logging_dir is None:
+ raise ValueError(
+ "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
+ )
+ if log_with in ["wandb", "all"]:
+ try:
+ import wandb
+ except ImportError:
+ raise ImportError("No wandb / wandb がインストールされていないようです")
+ if logging_dir is not None:
+ os.makedirs(logging_dir, exist_ok=True)
+ os.environ["WANDB_DIR"] = logging_dir
+ if args.wandb_api_key is not None:
+ wandb.login(key=args.wandb_api_key)
+
+ kwargs_handlers = [
+ (
+ InitProcessGroupKwargs(
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+ init_method=(
+ "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
+ ),
+ timeout=timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
+ )
+ if torch.cuda.device_count() > 1
+ else None
+ ),
+ (
+ DistributedDataParallelKwargs(
+ gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
+ )
+ if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
+ else None
+ ),
+ ]
+ kwargs_handlers = [i for i in kwargs_handlers if i is not None]
+
+ dynamo_plugin = None
+ if args.dynamo_backend.upper() != "NO":
+ dynamo_plugin = TorchDynamoPlugin(
+ backend=DynamoBackend(args.dynamo_backend.upper()),
+ mode=args.dynamo_mode,
+ fullgraph=args.dynamo_fullgraph,
+ dynamic=args.dynamo_dynamic,
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=log_with,
+ project_dir=logging_dir,
+ dynamo_plugin=dynamo_plugin,
+ kwargs_handlers=kwargs_handlers,
+ )
+ print("accelerator device:", accelerator.device)
+ return accelerator
+
+
+def line_to_prompt_dict(line: str) -> dict:
+ # subset of gen_img_diffusers
+ prompt_args = line.split(" --")
+ prompt_dict = {}
+ prompt_dict["prompt"] = prompt_args[0]
+
+ for parg in prompt_args:
+ try:
+ m = re.match(r"w (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["width"] = int(m.group(1))
+ continue
+
+ m = re.match(r"h (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["height"] = int(m.group(1))
+ continue
+
+ m = re.match(r"f (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["frame_count"] = int(m.group(1))
+ continue
+
+ m = re.match(r"d (\d+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["seed"] = int(m.group(1))
+ continue
+
+ m = re.match(r"s (\d+)", parg, re.IGNORECASE)
+ if m: # steps
+ prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
+ continue
+
+ m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
+ if m: # scale
+ prompt_dict["guidance_scale"] = float(m.group(1))
+ continue
+
+ m = re.match(r"fs ([\d\.]+)", parg, re.IGNORECASE)
+ if m: # scale
+ prompt_dict["discrete_flow_shift"] = float(m.group(1))
+ continue
+
+ m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
+ if m: # scale
+ prompt_dict["cfg_scale"] = float(m.group(1))
+ continue
+
+ m = re.match(r"n (.+)", parg, re.IGNORECASE)
+ if m: # negative prompt
+ prompt_dict["negative_prompt"] = m.group(1)
+ continue
+
+ m = re.match(r"i (.+)", parg, re.IGNORECASE)
+ if m: # negative prompt
+ prompt_dict["image_path"] = m.group(1)
+ continue
+
+ m = re.match(r"cn (.+)", parg, re.IGNORECASE)
+ if m:
+ prompt_dict["control_video_path"] = m.group(1)
+ continue
+
+ m = re.match(r"ci (.+)", parg, re.IGNORECASE)
+ if m:
+ # can be multiple control images
+ control_image_path = m.group(1)
+ if "control_image_path" not in prompt_dict:
+ prompt_dict["control_image_path"] = []
+ prompt_dict["control_image_path"].append(control_image_path)
+ continue
+
+ m = re.match(r"of (.+)", parg, re.IGNORECASE)
+ if m: # output folder
+ prompt_dict["one_frame"] = m.group(1)
+ continue
+
+ except ValueError as ex:
+ logger.error(f"Exception in parsing / 解析エラー: {parg}")
+ logger.error(ex)
+
+ return prompt_dict
+
+
+def load_prompts(prompt_file: str) -> list[Dict]:
+ # read prompts
+ if prompt_file.endswith(".txt"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+ prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
+ elif prompt_file.endswith(".toml"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ data = toml.load(f)
+ prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
+ elif prompt_file.endswith(".json"):
+ with open(prompt_file, "r", encoding="utf-8") as f:
+ prompts = json.load(f)
+
+ # preprocess prompts
+ for i in range(len(prompts)):
+ prompt_dict = prompts[i]
+ if isinstance(prompt_dict, str):
+ prompt_dict = line_to_prompt_dict(prompt_dict)
+ prompts[i] = prompt_dict
+ assert isinstance(prompt_dict, dict)
+
+ # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
+ prompt_dict["enum"] = i
+ prompt_dict.pop("subset", None)
+
+ return prompts
+
+
+def compute_density_for_timestep_sampling(
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+):
+ """Compute the density for sampling the timesteps when doing SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "logit_normal":
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+ u = torch.nn.functional.sigmoid(u)
+ elif weighting_scheme == "mode":
+ u = torch.rand(size=(batch_size,), device="cpu")
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+ else:
+ u = torch.rand(size=(batch_size,), device="cpu")
+ return u
+
+
+def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
+ timesteps = timesteps.to(device)
+
+ # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
+ if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
+ # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
+ # round to nearest timestep
+ logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
+ step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
+ else:
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+
+def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
+ """Computes loss weighting scheme for SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
+ if weighting_scheme == "sigma_sqrt":
+ weighting = (sigmas**-2.0).float()
+ else:
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
+ weighting = 2 / (math.pi * bot)
+ else:
+ weighting = None # torch.ones_like(sigmas)
+ return weighting
+
+
+def should_sample_images(args, steps, epoch=None):
+ if steps == 0:
+ if not args.sample_at_first:
+ return False
+ else:
+ should_sample_by_steps = args.sample_every_n_steps is not None and steps % args.sample_every_n_steps == 0
+ should_sample_by_epochs = (
+ args.sample_every_n_epochs is not None and epoch is not None and epoch % args.sample_every_n_epochs == 0
+ )
+ if not should_sample_by_steps and not should_sample_by_epochs:
+ return False
+ return True
+
+
+class NetworkTrainer:
+ def __init__(self):
+ self.blocks_to_swap = None
+
+ # TODO 他のスクリプトと共通化する
+ def generate_step_logs(
+ self,
+ args: argparse.Namespace,
+ current_loss,
+ avr_loss,
+ lr_scheduler,
+ lr_descriptions,
+ optimizer=None,
+ keys_scaled=None,
+ mean_norm=None,
+ maximum_norm=None,
+ ):
+ network_train_unet_only = True
+ logs = {"loss/current": current_loss, "loss/average": avr_loss}
+
+ if keys_scaled is not None:
+ logs["max_norm/keys_scaled"] = keys_scaled
+ logs["max_norm/average_key_norm"] = mean_norm
+ logs["max_norm/max_key_norm"] = maximum_norm
+
+ lrs = lr_scheduler.get_last_lr()
+ for i, lr in enumerate(lrs):
+ if lr_descriptions is not None:
+ lr_desc = lr_descriptions[i]
+ else:
+ idx = i - (0 if network_train_unet_only else -1)
+ if idx == -1:
+ lr_desc = "textencoder"
+ else:
+ if len(lrs) > 2:
+ lr_desc = f"group{idx}"
+ else:
+ lr_desc = "unet"
+
+ logs[f"lr/{lr_desc}"] = lr
+
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+ # tracking d*lr value
+ logs[f"lr/d*lr/{lr_desc}"] = (
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
+ )
+ if (
+ args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
+ ): # tracking d*lr value of unet.
+ logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+ else:
+ idx = 0
+ if not network_train_unet_only:
+ logs["lr/textencoder"] = float(lrs[0])
+ idx = 1
+
+ for i in range(idx, len(lrs)):
+ logs[f"lr/group{i}"] = float(lrs[i])
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+ logs[f"lr/d*lr/group{i}"] = (
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
+ )
+ if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
+ logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+
+ return logs
+
+ def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
+ # adamw, adamw8bit, adafactor
+
+ optimizer_type = args.optimizer_type.lower()
+
+ # split optimizer_type and optimizer_args
+ optimizer_kwargs = {}
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
+ for arg in args.optimizer_args:
+ key, value = arg.split("=")
+ value = ast.literal_eval(value)
+ optimizer_kwargs[key] = value
+
+ lr = args.learning_rate
+ optimizer = None
+ optimizer_class = None
+
+ if optimizer_type.endswith("8bit".lower()):
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+
+ if optimizer_type == "AdamW8bit".lower():
+ logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
+ optimizer_class = bnb.optim.AdamW8bit
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ elif optimizer_type == "Adafactor".lower():
+ # Adafactor: check relative_step and warmup_init
+ if "relative_step" not in optimizer_kwargs:
+ optimizer_kwargs["relative_step"] = True # default
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
+ logger.info(
+ f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
+ )
+ optimizer_kwargs["relative_step"] = True
+ logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
+
+ if optimizer_kwargs["relative_step"]:
+ logger.info(f"relative_step is true / relative_stepがtrueです")
+ if lr != 0.0:
+ logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
+ args.learning_rate = None
+
+ if args.lr_scheduler != "adafactor":
+ logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+ args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
+
+ lr = None
+ else:
+ if args.max_grad_norm != 0.0:
+ logger.warning(
+ f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
+ )
+ if args.lr_scheduler != "constant_with_warmup":
+ logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
+ logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
+
+ optimizer_class = transformers.optimization.Adafactor
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ elif optimizer_type == "AdamW".lower():
+ logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
+ optimizer_class = torch.optim.AdamW
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ if optimizer is None:
+ # 任意のoptimizerを使う
+ case_sensitive_optimizer_type = args.optimizer_type # not lower
+ logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
+
+ if "." not in case_sensitive_optimizer_type: # from torch.optim
+ optimizer_module = torch.optim
+ else: # from other library
+ values = case_sensitive_optimizer_type.split(".")
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
+ case_sensitive_optimizer_type = values[-1]
+
+ optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+ # for logging
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
+
+ # get train and eval functions
+ if hasattr(optimizer, "train") and callable(optimizer.train):
+ train_fn = optimizer.train
+ eval_fn = optimizer.eval
+ else:
+ train_fn = lambda: None
+ eval_fn = lambda: None
+
+ return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
+
+ def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
+ return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
+
+ def get_dummy_scheduler(optimizer: torch.optim.Optimizer) -> Any:
+ # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
+ # this scheduler is used for logging only.
+ # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
+ class DummyScheduler:
+ def __init__(self, optimizer: torch.optim.Optimizer):
+ self.optimizer = optimizer
+
+ def step(self):
+ pass
+
+ def get_last_lr(self):
+ return [group["lr"] for group in self.optimizer.param_groups]
+
+ return DummyScheduler(optimizer)
+
+ def get_lr_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
+ """
+ Unified API to get any scheduler from its name.
+ """
+ # if schedulefree optimizer, return dummy scheduler
+ if self.is_schedulefree_optimizer(optimizer, args):
+ return self.get_dummy_scheduler(optimizer)
+
+ name = args.lr_scheduler
+ num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
+ num_warmup_steps: Optional[int] = (
+ int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+ )
+ num_decay_steps: Optional[int] = (
+ int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+ )
+ num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
+ num_cycles = args.lr_scheduler_num_cycles
+ power = args.lr_scheduler_power
+ timescale = args.lr_scheduler_timescale
+ min_lr_ratio = args.lr_scheduler_min_lr_ratio
+
+ lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
+ if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
+ for arg in args.lr_scheduler_args:
+ key, value = arg.split("=")
+ value = ast.literal_eval(value)
+ lr_scheduler_kwargs[key] = value
+
+ def wrap_check_needless_num_warmup_steps(return_vals):
+ if num_warmup_steps is not None and num_warmup_steps != 0:
+ raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
+ return return_vals
+
+ # using any lr_scheduler from other library
+ if args.lr_scheduler_type:
+ lr_scheduler_type = args.lr_scheduler_type
+ logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+ if "." not in lr_scheduler_type: # default to use torch.optim
+ lr_scheduler_module = torch.optim.lr_scheduler
+ else:
+ values = lr_scheduler_type.split(".")
+ lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
+ lr_scheduler_type = values[-1]
+ lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
+ lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+ return lr_scheduler
+
+ if name.startswith("adafactor"):
+ assert (
+ type(optimizer) == transformers.optimization.Adafactor
+ ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
+ initial_lr = float(name.split(":")[1])
+ # logger.info(f"adafactor scheduler init lr {initial_lr}")
+ return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
+
+ if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
+ name = DiffusersSchedulerType(name)
+ schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
+ return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
+
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+ if name == SchedulerType.CONSTANT:
+ return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
+
+ if name == SchedulerType.INVERSE_SQRT:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ **lr_scheduler_kwargs,
+ )
+
+ if name == SchedulerType.POLYNOMIAL:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ power=power,
+ **lr_scheduler_kwargs,
+ )
+
+ if name == SchedulerType.COSINE_WITH_MIN_LR:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles / 2,
+ min_lr_rate=min_lr_ratio,
+ **lr_scheduler_kwargs,
+ )
+
+ # these schedulers do not require `num_decay_steps`
+ if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ **lr_scheduler_kwargs,
+ )
+
+ # All other schedulers require `num_decay_steps`
+ if num_decay_steps is None:
+ raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+ if name == SchedulerType.WARMUP_STABLE_DECAY:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_stable_steps=num_stable_steps,
+ num_decay_steps=num_decay_steps,
+ num_cycles=num_cycles / 2,
+ min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+ **lr_scheduler_kwargs,
+ )
+
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_decay_steps=num_decay_steps,
+ **lr_scheduler_kwargs,
+ )
+
+ def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
+ if not args.resume:
+ return False
+
+ if not args.resume_from_huggingface:
+ logger.info(f"resume training from local state: {args.resume}")
+ accelerator.load_state(args.resume)
+ return True
+
+ logger.info(f"resume training from huggingface state: {args.resume}")
+ repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
+ path_in_repo = "/".join(args.resume.split("/")[2:])
+ revision = None
+ repo_type = None
+ if ":" in path_in_repo:
+ divided = path_in_repo.split(":")
+ if len(divided) == 2:
+ path_in_repo, revision = divided
+ repo_type = "model"
+ else:
+ path_in_repo, revision, repo_type = divided
+ logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
+
+ list_files = huggingface_utils.list_dir(
+ repo_id=repo_id,
+ subfolder=path_in_repo,
+ revision=revision,
+ token=args.huggingface_token,
+ repo_type=repo_type,
+ )
+
+ async def download(filename) -> str:
+ def task():
+ return huggingface_hub.hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ revision=revision,
+ repo_type=repo_type,
+ token=args.huggingface_token,
+ )
+
+ return await asyncio.get_event_loop().run_in_executor(None, task)
+
+ loop = asyncio.get_event_loop()
+ results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
+ if len(results) == 0:
+ raise ValueError(
+ "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
+ )
+ dirname = os.path.dirname(results[0])
+ accelerator.load_state(dirname)
+
+ return True
+
+ def get_noisy_model_input_and_timesteps(
+ self,
+ args: argparse.Namespace,
+ noise: torch.Tensor,
+ latents: torch.Tensor,
+ noise_scheduler: FlowMatchDiscreteScheduler,
+ device: torch.device,
+ dtype: torch.dtype,
+ ):
+ batch_size = noise.shape[0]
+
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
+ # Simple random t-based noise sampling
+ if args.timestep_sampling == "sigmoid":
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
+ else:
+ t = torch.rand((batch_size,), device=device)
+
+ elif args.timestep_sampling == "shift":
+ shift = args.discrete_flow_shift
+ logits_norm = torch.randn(batch_size, device=device)
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
+ t = logits_norm.sigmoid()
+ t = (t * shift) / (1 + (shift - 1) * t)
+
+ t_min = args.min_timestep if args.min_timestep is not None else 0
+ t_max = args.max_timestep if args.max_timestep is not None else 1000.0
+ t_min /= 1000.0
+ t_max /= 1000.0
+ t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1]
+
+ timesteps = t * 1000.0
+ t = t.view(-1, 1, 1, 1, 1)
+ noisy_model_input = (1 - t) * latents + t * noise
+
+ timesteps += 1 # 1 to 1000
+ else:
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=batch_size,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ # indices = (u * noise_scheduler.config.num_train_timesteps).long()
+ t_min = args.min_timestep if args.min_timestep is not None else 0
+ t_max = args.max_timestep if args.max_timestep is not None else 1000
+ indices = (u * (t_max - t_min) + t_min).long()
+
+ timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000
+
+ # Add noise according to flow matching.
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
+
+ return noisy_model_input, timesteps
+
+ def show_timesteps(self, args: argparse.Namespace):
+ N_TRY = 100000
+ BATCH_SIZE = 1000
+ CONSOLE_WIDTH = 64
+ N_TIMESTEPS_PER_LINE = 25
+
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
+ # print(f"Noise scheduler timesteps: {noise_scheduler.timesteps}")
+
+ latents = torch.zeros(BATCH_SIZE, 1, 1, 1, 1, dtype=torch.float16)
+ noise = torch.ones_like(latents)
+
+ # sample timesteps
+ sampled_timesteps = [0] * noise_scheduler.config.num_train_timesteps
+ for i in tqdm(range(N_TRY // BATCH_SIZE)):
+ # we use noise=1, so retured noisy_model_input is same as timestep, because `noisy_model_input = (1 - t) * latents + t * noise`
+ actual_timesteps, _ = self.get_noisy_model_input_and_timesteps(
+ args, noise, latents, noise_scheduler, "cpu", torch.float16
+ )
+ actual_timesteps = actual_timesteps[:, 0, 0, 0, 0] * 1000
+ for t in actual_timesteps:
+ t = int(t.item())
+ sampled_timesteps[t] += 1
+
+ # sample weighting
+ sampled_weighting = [0] * noise_scheduler.config.num_train_timesteps
+ for i in tqdm(range(len(sampled_weighting))):
+ timesteps = torch.tensor([i + 1], device="cpu")
+ weighting = compute_loss_weighting_for_sd3(args.weighting_scheme, noise_scheduler, timesteps, "cpu", torch.float16)
+ if weighting is None:
+ weighting = torch.tensor(1.0, device="cpu")
+ elif torch.isinf(weighting).any():
+ weighting = torch.tensor(1.0, device="cpu")
+ sampled_weighting[i] = weighting.item()
+
+ # show results
+ if args.show_timesteps == "image":
+ # show timesteps with matplotlib
+ import matplotlib.pyplot as plt
+
+ plt.figure(figsize=(10, 5))
+ plt.subplot(1, 2, 1)
+ plt.bar(range(len(sampled_timesteps)), sampled_timesteps, width=1.0)
+ plt.title("Sampled timesteps")
+ plt.xlabel("Timestep")
+ plt.ylabel("Count")
+
+ plt.subplot(1, 2, 2)
+ plt.bar(range(len(sampled_weighting)), sampled_weighting, width=1.0)
+ plt.title("Sampled loss weighting")
+ plt.xlabel("Timestep")
+ plt.ylabel("Weighting")
+
+ plt.tight_layout()
+ plt.show()
+
+ else:
+ sampled_timesteps = np.array(sampled_timesteps)
+ sampled_weighting = np.array(sampled_weighting)
+
+ # average per line
+ sampled_timesteps = sampled_timesteps.reshape(-1, N_TIMESTEPS_PER_LINE).mean(axis=1)
+ sampled_weighting = sampled_weighting.reshape(-1, N_TIMESTEPS_PER_LINE).mean(axis=1)
+
+ max_count = max(sampled_timesteps)
+ print(f"Sampled timesteps: max count={max_count}")
+ for i, t in enumerate(sampled_timesteps):
+ line = f"{(i)*N_TIMESTEPS_PER_LINE:4d}-{(i+1)*N_TIMESTEPS_PER_LINE-1:4d}: "
+ line += "#" * int(t / max_count * CONSOLE_WIDTH)
+ print(line)
+
+ max_weighting = max(sampled_weighting)
+ print(f"Sampled loss weighting: max weighting={max_weighting}")
+ for i, w in enumerate(sampled_weighting):
+ line = f"{i*N_TIMESTEPS_PER_LINE:4d}-{(i+1)*N_TIMESTEPS_PER_LINE-1:4d}: {w:8.2f} "
+ line += "#" * int(w / max_weighting * CONSOLE_WIDTH)
+ print(line)
+
+ def sample_images(self, accelerator, args, epoch, steps, vae, transformer, sample_parameters, dit_dtype):
+ """architecture independent sample images"""
+ if not should_sample_images(args, steps, epoch):
+ return
+
+ logger.info("")
+ logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
+ if sample_parameters is None:
+ logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
+ return
+
+ distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
+
+ # Use the unwrapped model
+ transformer = accelerator.unwrap_model(transformer)
+ transformer.switch_block_swap_for_inference()
+
+ # Create a directory to save the samples
+ save_dir = os.path.join(args.output_dir, "sample")
+ os.makedirs(save_dir, exist_ok=True)
+
+ # save random state to restore later
+ rng_state = torch.get_rng_state()
+ cuda_rng_state = None
+ try:
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+ except Exception:
+ pass
+
+ if distributed_state.num_processes <= 1:
+ # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
+ with torch.no_grad(), accelerator.autocast():
+ for sample_parameter in sample_parameters:
+ self.sample_image_inference(
+ accelerator, args, transformer, dit_dtype, vae, save_dir, sample_parameter, epoch, steps
+ )
+ clean_memory_on_device(accelerator.device)
+ else:
+ # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
+ # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
+ per_process_params = [] # list of lists
+ for i in range(distributed_state.num_processes):
+ per_process_params.append(sample_parameters[i :: distributed_state.num_processes])
+
+ with torch.no_grad():
+ with distributed_state.split_between_processes(per_process_params) as sample_parameter_lists:
+ for sample_parameter in sample_parameter_lists[0]:
+ self.sample_image_inference(
+ accelerator, args, transformer, dit_dtype, vae, save_dir, sample_parameter, epoch, steps
+ )
+ clean_memory_on_device(accelerator.device)
+
+ torch.set_rng_state(rng_state)
+ if cuda_rng_state is not None:
+ torch.cuda.set_rng_state(cuda_rng_state)
+
+ transformer.switch_block_swap_for_training()
+ clean_memory_on_device(accelerator.device)
+
+ def sample_image_inference(self, accelerator, args, transformer, dit_dtype, vae, save_dir, sample_parameter, epoch, steps):
+ """architecture independent sample images"""
+ sample_steps = sample_parameter.get("sample_steps", 20)
+ width = sample_parameter.get("width", 256) # make smaller for faster and memory saving inference
+ height = sample_parameter.get("height", 256)
+ frame_count = sample_parameter.get("frame_count", 1)
+ guidance_scale = sample_parameter.get("guidance_scale", self.default_guidance_scale)
+ discrete_flow_shift = sample_parameter.get("discrete_flow_shift", 14.5)
+ seed = sample_parameter.get("seed")
+ prompt: str = sample_parameter.get("prompt", "")
+ cfg_scale = sample_parameter.get("cfg_scale", None) # None for architecture default
+ negative_prompt = sample_parameter.get("negative_prompt", None)
+
+ frame_count = (frame_count - 1) // 4 * 4 + 1 # 1, 5, 9, 13, ... For HunyuanVideo and Wan2.1
+
+ if self.i2v_training:
+ image_path = sample_parameter.get("image_path", None)
+ if image_path is None:
+ logger.error("No image_path for i2v model / i2vモデルのサンプル画像生成にはimage_pathが必要です")
+ return
+ else:
+ image_path = None
+
+ if self.control_training:
+ control_video_path = sample_parameter.get("control_video_path", None)
+ if control_video_path is None:
+ logger.error(
+ "No control_video_path for control model / controlモデルのサンプル画像生成にはcontrol_video_pathが必要です"
+ )
+ return
+ else:
+ control_video_path = None
+
+ device = accelerator.device
+ if seed is not None:
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ generator = torch.Generator(device=device).manual_seed(seed)
+ else:
+ # True random sample image generation
+ torch.seed()
+ torch.cuda.seed()
+ generator = torch.Generator(device=device).manual_seed(torch.initial_seed())
+
+ logger.info(f"prompt: {prompt}")
+ logger.info(f"height: {height}")
+ logger.info(f"width: {width}")
+ logger.info(f"frame count: {frame_count}")
+ logger.info(f"sample steps: {sample_steps}")
+ logger.info(f"guidance scale: {guidance_scale}")
+ logger.info(f"discrete flow shift: {discrete_flow_shift}")
+ if seed is not None:
+ logger.info(f"seed: {seed}")
+
+ do_classifier_free_guidance = False
+ if negative_prompt is not None:
+ do_classifier_free_guidance = True
+ logger.info(f"negative prompt: {negative_prompt}")
+ logger.info(f"cfg scale: {cfg_scale}")
+
+ if self.i2v_training:
+ logger.info(f"image path: {image_path}")
+ if self.control_training:
+ logger.info(f"control video path: {control_video_path}")
+
+ # inference: architecture dependent
+ video = self.do_inference(
+ accelerator,
+ args,
+ sample_parameter,
+ vae,
+ dit_dtype,
+ transformer,
+ discrete_flow_shift,
+ sample_steps,
+ width,
+ height,
+ frame_count,
+ generator,
+ do_classifier_free_guidance,
+ guidance_scale,
+ cfg_scale,
+ image_path=image_path,
+ control_video_path=control_video_path,
+ )
+
+ # Save video
+ if video is None:
+ logger.error("No video generated / 生成された動画がありません")
+ return
+
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
+ seed_suffix = "" if seed is None else f"_{seed}"
+ prompt_idx = sample_parameter.get("enum", 0)
+ save_path = (
+ f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{prompt_idx:02d}_{ts_str}{seed_suffix}"
+ )
+ if video.shape[2] == 1:
+ save_images_grid(video, save_dir, save_path, create_subdir=False)
+ else:
+ save_videos_grid(video, os.path.join(save_dir, save_path) + ".mp4")
+
+ # Move models back to initial state
+ vae.to("cpu")
+ clean_memory_on_device(device)
+
+ # region model specific
+
+ @property
+ def architecture(self) -> str:
+ return ARCHITECTURE_HUNYUAN_VIDEO
+
+ @property
+ def architecture_full_name(self) -> str:
+ return ARCHITECTURE_HUNYUAN_VIDEO_FULL
+
+ def handle_model_specific_args(self, args: argparse.Namespace):
+ self.pos_embed_cache = {}
+
+ self._i2v_training = args.dit_in_channels == 32 # may be changed in the future
+ if self._i2v_training:
+ logger.info("I2V training mode")
+
+ self._control_training = False # HunyuanVideo does not support control training yet
+
+ self.default_guidance_scale = 6.0
+
+ @property
+ def i2v_training(self) -> bool:
+ return self._i2v_training
+
+ @property
+ def control_training(self) -> bool:
+ return self._control_training
+
+ def process_sample_prompts(
+ self,
+ args: argparse.Namespace,
+ accelerator: Accelerator,
+ sample_prompts: str,
+ ):
+ text_encoder1, text_encoder2, fp8_llm = args.text_encoder1, args.text_encoder2, args.fp8_llm
+
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
+ prompts = load_prompts(sample_prompts)
+
+ def encode_for_text_encoder(text_encoder, is_llm=True):
+ sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
+ with accelerator.autocast(), torch.no_grad():
+ for prompt_dict in prompts:
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", None)]:
+ if p is None:
+ continue
+ if p not in sample_prompts_te_outputs:
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
+
+ data_type = "video"
+ text_inputs = text_encoder.text2tokens(p, data_type=data_type)
+
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
+ sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)
+
+ return sample_prompts_te_outputs
+
+ # Load Text Encoder 1 and encode
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
+ logger.info(f"loading text encoder 1: {text_encoder1}")
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)
+
+ logger.info("encoding with Text Encoder 1")
+ te_outputs_1 = encode_for_text_encoder(text_encoder_1)
+ del text_encoder_1
+
+ # Load Text Encoder 2 and encode
+ logger.info(f"loading text encoder 2: {text_encoder2}")
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)
+
+ logger.info("encoding with Text Encoder 2")
+ te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False)
+ del text_encoder_2
+
+ # prepare sample parameters
+ sample_parameters = []
+ for prompt_dict in prompts:
+ prompt_dict_copy = prompt_dict.copy()
+
+ p = prompt_dict.get("prompt", "")
+ prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
+ prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
+ prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
+ prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
+
+ p = prompt_dict.get("negative_prompt", None)
+ if p is not None:
+ prompt_dict_copy["negative_llm_embeds"] = te_outputs_1[p][0]
+ prompt_dict_copy["negative_llm_mask"] = te_outputs_1[p][1]
+ prompt_dict_copy["negative_clipL_embeds"] = te_outputs_2[p][0]
+ prompt_dict_copy["negative_clipL_mask"] = te_outputs_2[p][1]
+
+ sample_parameters.append(prompt_dict_copy)
+
+ clean_memory_on_device(accelerator.device)
+
+ return sample_parameters
+
+ def do_inference(
+ self,
+ accelerator,
+ args,
+ sample_parameter,
+ vae,
+ dit_dtype,
+ transformer,
+ discrete_flow_shift,
+ sample_steps,
+ width,
+ height,
+ frame_count,
+ generator,
+ do_classifier_free_guidance,
+ guidance_scale,
+ cfg_scale,
+ image_path=None,
+ control_video_path=None,
+ ):
+ """architecture dependent inference"""
+ device = accelerator.device
+ if cfg_scale is None:
+ cfg_scale = 1.0
+ do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0
+
+ # Prepare scheduler for each prompt
+ scheduler = FlowMatchDiscreteScheduler(shift=discrete_flow_shift, reverse=True, solver="euler")
+
+ # Number of inference steps for sampling
+ scheduler.set_timesteps(sample_steps, device=device)
+ timesteps = scheduler.timesteps
+
+ # Calculate latent video length based on VAE version
+ if "884" in VAE_VER:
+ latent_video_length = (frame_count - 1) // 4 + 1
+ elif "888" in VAE_VER:
+ latent_video_length = (frame_count - 1) // 8 + 1
+ else:
+ latent_video_length = frame_count
+
+ # Get embeddings
+ prompt_embeds = sample_parameter["llm_embeds"].to(device=device, dtype=dit_dtype)
+ prompt_mask = sample_parameter["llm_mask"].to(device=device)
+ prompt_embeds_2 = sample_parameter["clipL_embeds"].to(device=device, dtype=dit_dtype)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = sample_parameter["negative_llm_embeds"].to(device=device, dtype=dit_dtype)
+ negative_prompt_mask = sample_parameter["negative_llm_mask"].to(device=device)
+ negative_prompt_embeds_2 = sample_parameter["negative_clipL_embeds"].to(device=device, dtype=dit_dtype)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask], dim=0)
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2], dim=0)
+
+ num_channels_latents = 16 # transformer.config.in_channels
+ vae_scale_factor = 2 ** (4 - 1) # Assuming 4 VAE blocks
+
+ # Initialize latents
+ shape_or_frame = (
+ 1,
+ num_channels_latents,
+ 1,
+ height // vae_scale_factor,
+ width // vae_scale_factor,
+ )
+ latents = []
+ for _ in range(latent_video_length):
+ latents.append(torch.randn(shape_or_frame, generator=generator, device=device, dtype=dit_dtype))
+ latents = torch.cat(latents, dim=2)
+
+ if self.i2v_training:
+ # Move VAE to the appropriate device for sampling
+ vae.to(device)
+ vae.eval()
+
+ image = Image.open(image_path)
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W
+ image = image / 255.0
+
+ logger.info(f"Encoding image to latents")
+ image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W
+ image_latents = image_latents.to(device=device, dtype=dit_dtype)
+
+ vae.to("cpu")
+ clean_memory_on_device(device)
+
+ zero_latents = torch.zeros_like(latents)
+ zero_latents[:, :, :1, :, :] = image_latents
+ image_latents = zero_latents
+ else:
+ image_latents = None
+
+ # Guidance scale
+ guidance_expand = torch.tensor([guidance_scale * 1000.0], dtype=torch.float32, device=device).to(dit_dtype)
+
+ # Get rotary positional embeddings
+ freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(transformer, latents.shape[2:])
+ freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
+ freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
+
+ # Wrap the inner loop with tqdm to track progress over timesteps
+ prompt_idx = sample_parameter.get("enum", 0)
+ with torch.no_grad():
+ for i, t in enumerate(tqdm(timesteps, desc=f"Sampling timesteps for prompt {prompt_idx+1}")):
+ latents_input = scheduler.scale_model_input(latents, t)
+
+ if do_classifier_free_guidance:
+ latents_input = torch.cat([latents_input, latents_input], dim=0) # 2, C, F, H, W
+
+ if image_latents is not None:
+ latents_image_input = (
+ image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
+ )
+ latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W
+
+ noise_pred = transformer(
+ latents_input,
+ t.repeat(latents.shape[0]).to(device=device, dtype=dit_dtype),
+ text_states=prompt_embeds,
+ text_mask=prompt_mask,
+ text_states_2=prompt_embeds_2,
+ freqs_cos=freqs_cos,
+ freqs_sin=freqs_sin,
+ guidance=guidance_expand,
+ return_dict=True,
+ )["x"]
+
+ # perform classifier free guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
+
+ # Compute the previous noisy sample x_t -> x_t-1
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Move VAE to the appropriate device for sampling
+ vae.to(device)
+ vae.eval()
+
+ # Decode latents to video
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
+ latents = latents / vae.config.scaling_factor + vae.config.shift_factor
+ else:
+ latents = latents / vae.config.scaling_factor
+
+ latents = latents.to(device=device, dtype=vae.dtype)
+ with torch.no_grad():
+ video = vae.decode(latents, return_dict=False)[0]
+ video = (video / 2 + 0.5).clamp(0, 1)
+ video = video.cpu().float()
+
+ return video
+
+ def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=vae_path)
+
+ if args.vae_chunk_size is not None:
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
+ if args.vae_spatial_tile_sample_min_size is not None:
+ vae.enable_spatial_tiling(True)
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
+ elif args.vae_tiling:
+ vae.enable_spatial_tiling(True)
+
+ return vae
+
+ def load_transformer(
+ self,
+ accelerator: Accelerator,
+ args: argparse.Namespace,
+ dit_path: str,
+ attn_mode: str,
+ split_attn: bool,
+ loading_device: str,
+ dit_weight_dtype: Optional[torch.dtype],
+ ):
+ transformer = load_transformer(dit_path, attn_mode, split_attn, loading_device, dit_weight_dtype, args.dit_in_channels)
+
+ if args.img_in_txt_in_offloading:
+ logger.info("Enable offloading img_in and txt_in to CPU")
+ transformer.enable_img_in_txt_in_offloading()
+
+ return transformer
+
+ def scale_shift_latents(self, latents):
+ latents = latents * vae_module.SCALING_FACTOR
+ return latents
+
+ def call_dit(
+ self,
+ args: argparse.Namespace,
+ accelerator: Accelerator,
+ transformer_arg,
+ latents: torch.Tensor,
+ batch: dict[str, torch.Tensor],
+ noise: torch.Tensor,
+ noisy_model_input: torch.Tensor,
+ timesteps: torch.Tensor,
+ network_dtype: torch.dtype,
+ ):
+ transformer: HYVideoDiffusionTransformer = transformer_arg
+ bsz = latents.shape[0]
+
+ # I2V training
+ if self.i2v_training:
+ image_latents = torch.zeros_like(latents)
+ image_latents[:, :, :1, :, :] = latents[:, :, :1, :, :]
+ noisy_model_input = torch.cat([noisy_model_input, image_latents], dim=1) # concat along channel dim
+
+ # ensure guidance_scale in args is float
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype)
+
+ # ensure the hidden state will require grad
+ if args.gradient_checkpointing:
+ noisy_model_input.requires_grad_(True)
+ guidance_vec.requires_grad_(True)
+
+ pos_emb_shape = latents.shape[1:]
+ if pos_emb_shape not in self.pos_embed_cache:
+ freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(transformer, latents.shape[2:])
+ # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
+ # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
+ self.pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
+ else:
+ freqs_cos, freqs_sin = self.pos_embed_cache[pos_emb_shape]
+
+ # call DiT
+ latents = latents.to(device=accelerator.device, dtype=network_dtype)
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
+ with accelerator.autocast():
+ model_pred = transformer(
+ noisy_model_input,
+ timesteps,
+ text_states=batch["llm"],
+ text_mask=batch["llm_mask"],
+ text_states_2=batch["clipL"],
+ freqs_cos=freqs_cos,
+ freqs_sin=freqs_sin,
+ guidance=guidance_vec,
+ return_dict=False,
+ )
+
+ # flow matching loss
+ target = noise - latents
+
+ return model_pred, target
+
+ # endregion model specific
+
+ def train(self, args):
+ # check required arguments
+ if args.dataset_config is None:
+ raise ValueError("dataset_config is required / dataset_configが必要です")
+ if args.dit is None:
+ raise ValueError("path to DiT model is required / DiTモデルのパスが必要です")
+ assert not args.fp8_scaled or args.fp8_base, "fp8_scaled requires fp8_base / fp8_scaledはfp8_baseが必要です"
+
+ if args.sage_attn:
+ raise ValueError(
+ "SageAttention doesn't support training currently. Please use `--sdpa` or `--xformers` etc. instead."
+ " / SageAttentionは現在学習をサポートしていないようです。`--sdpa`や`--xformers`などの他のオプションを使ってください"
+ )
+
+ # check model specific arguments
+ self.handle_model_specific_args(args)
+
+ # show timesteps for debugging
+ if args.show_timesteps:
+ self.show_timesteps(args)
+ return
+
+ session_id = random.randint(0, 2**32)
+ training_started_at = time.time()
+ # setup_logging(args, reset=True)
+
+ if args.seed is None:
+ args.seed = random.randint(0, 2**32)
+ set_seed(args.seed)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args, architecture=self.architecture)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
+
+ current_epoch = Value("i", 0)
+ current_step = Value("i", 0)
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+ collator = collator_class(current_epoch, current_step, ds_for_collator)
+
+ # prepare accelerator
+ logger.info("preparing accelerator")
+ accelerator = prepare_accelerator(args)
+ is_main_process = accelerator.is_main_process
+
+ # prepare dtype
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # HunyuanVideo: bfloat16 or float16, Wan2.1: bfloat16
+ dit_dtype = torch.bfloat16 if args.dit_dtype is None else model_utils.str_to_dtype(args.dit_dtype)
+ dit_weight_dtype = (None if args.fp8_scaled else torch.float8_e4m3fn) if args.fp8_base else dit_dtype
+ logger.info(f"DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
+
+ # get embedding for sampling images
+ vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)
+ sample_parameters = None
+ vae = None
+ if args.sample_prompts:
+ sample_parameters = self.process_sample_prompts(args, accelerator, args.sample_prompts)
+
+ # Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
+ vae = self.load_vae(args, vae_dtype=vae_dtype, vae_path=args.vae)
+ vae.requires_grad_(False)
+ vae.eval()
+
+ # load DiT model
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
+ self.blocks_to_swap = blocks_to_swap
+ loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
+
+ logger.info(f"Loading DiT model from {args.dit}")
+ if args.sdpa:
+ attn_mode = "torch"
+ elif args.flash_attn:
+ attn_mode = "flash"
+ elif args.sage_attn:
+ attn_mode = "sageattn"
+ elif args.xformers:
+ attn_mode = "xformers"
+ elif args.flash3:
+ attn_mode = "flash3"
+ else:
+ raise ValueError(
+ f"either --sdpa, --flash-attn, --flash3, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --flash3, --sage-attn, --xformersのいずれかを指定してください"
+ )
+ transformer = self.load_transformer(
+ accelerator, args, args.dit, attn_mode, args.split_attn, loading_device, dit_weight_dtype
+ )
+ transformer.eval()
+ transformer.requires_grad_(False)
+
+ if blocks_to_swap > 0:
+ logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
+ transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
+ transformer.move_to_device_except_swap_blocks(accelerator.device)
+
+ # load network model for differential training
+ sys.path.append(os.path.dirname(__file__))
+ accelerator.print("import network module:", args.network_module)
+ network_module: lora_module = importlib.import_module(args.network_module) # actual module may be different
+
+ if args.base_weights is not None:
+ # if base_weights is specified, merge the weights to DiT model
+ for i, weight_path in enumerate(args.base_weights):
+ if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
+ multiplier = 1.0
+ else:
+ multiplier = args.base_weights_multiplier[i]
+
+ accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
+
+ weights_sd = load_file(weight_path)
+ module = network_module.create_arch_network_from_weights(
+ multiplier, weights_sd, unet=transformer, for_inference=True
+ )
+ module.merge_to(None, transformer, weights_sd, weight_dtype, "cpu")
+
+ accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
+
+ # prepare network
+ net_kwargs = {}
+ if args.network_args is not None:
+ for net_arg in args.network_args:
+ key, value = net_arg.split("=")
+ net_kwargs[key] = value
+
+ if args.dim_from_weights:
+ logger.info(f"Loading network from weights: {args.dim_from_weights}")
+ weights_sd = load_file(args.dim_from_weights)
+ network, _ = network_module.create_arch_network_from_weights(1, weights_sd, unet=transformer)
+ else:
+ # We use the name create_arch_network for compatibility with LyCORIS
+ if hasattr(network_module, "create_arch_network"):
+ network = network_module.create_arch_network(
+ 1.0,
+ args.network_dim,
+ args.network_alpha,
+ vae,
+ None,
+ transformer,
+ neuron_dropout=args.network_dropout,
+ **net_kwargs,
+ )
+ else:
+ # LyCORIS compatibility
+ network = network_module.create_network(
+ 1.0,
+ args.network_dim,
+ args.network_alpha,
+ vae,
+ None,
+ transformer,
+ **net_kwargs,
+ )
+ if network is None:
+ return
+
+ if hasattr(network_module, "prepare_network"):
+ network.prepare_network(args)
+
+ # apply network to DiT
+ network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
+
+ if args.network_weights is not None:
+ # FIXME consider alpha of weights: this assumes that the alpha is not changed
+ info = network.load_weights(args.network_weights)
+ accelerator.print(f"load network weights from {args.network_weights}: {info}")
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ network.enable_gradient_checkpointing() # may have no effect
+
+ # prepare optimizer, data loader etc.
+ accelerator.print("prepare optimizer, data loader etc.")
+
+ trainable_params, lr_descriptions = network.prepare_optimizer_params(unet_lr=args.learning_rate)
+ optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
+ args, trainable_params
+ )
+
+ # prepare dataloader
+
+ # num workers for data loader: if 0, persistent_workers is not available
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset_group,
+ batch_size=1,
+ shuffle=True,
+ collate_fn=collator,
+ num_workers=n_workers,
+ persistent_workers=args.persistent_data_loader_workers,
+ )
+
+ # calculate max_train_steps
+ if args.max_train_epochs is not None:
+ args.max_train_steps = args.max_train_epochs * math.ceil(
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+ )
+ accelerator.print(
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+ )
+
+ # send max_train_steps to train_dataset_group
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
+
+ # prepare lr_scheduler
+ lr_scheduler = self.get_lr_scheduler(args, optimizer, accelerator.num_processes)
+
+ # prepare training model. accelerator does some magic here
+
+ # experimental feature: train the model with gradients in fp16/bf16
+ network_dtype = torch.float32
+ args.full_fp16 = args.full_bf16 = False # temporary disabled because stochastic rounding is not supported yet
+ if args.full_fp16:
+ assert (
+ args.mixed_precision == "fp16"
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+ accelerator.print("enable full fp16 training.")
+ network_dtype = weight_dtype
+ network.to(network_dtype)
+ elif args.full_bf16:
+ assert (
+ args.mixed_precision == "bf16"
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+ accelerator.print("enable full bf16 training.")
+ network_dtype = weight_dtype
+ network.to(network_dtype)
+
+ if dit_weight_dtype != dit_dtype and dit_weight_dtype is not None:
+ logger.info(f"casting model to {dit_weight_dtype}")
+ transformer.to(dit_weight_dtype)
+
+ if blocks_to_swap > 0:
+ transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
+ accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
+ accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
+ else:
+ transformer = accelerator.prepare(transformer)
+
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
+ training_model = network
+
+ if args.gradient_checkpointing:
+ transformer.train()
+ else:
+ transformer.eval()
+
+ accelerator.unwrap_model(network).prepare_grad_etc(transformer)
+
+ if args.full_fp16:
+ # patch accelerator for fp16 training
+ # def patch_accelerator_for_fp16_training(accelerator):
+ org_unscale_grads = accelerator.scaler._unscale_grads_
+
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
+
+ # before resuming make hook for saving/loading to save/load the network weights only
+ def save_model_hook(models, weights, output_dir):
+ # pop weights of other models than network to save only network weights
+ # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
+ if accelerator.is_main_process: # or args.deepspeed:
+ remove_indices = []
+ for i, model in enumerate(models):
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
+ remove_indices.append(i)
+ for i in reversed(remove_indices):
+ if len(weights) > i:
+ weights.pop(i)
+ # print(f"save model hook: {len(weights)} weights will be saved")
+
+ def load_model_hook(models, input_dir):
+ # remove models except network
+ remove_indices = []
+ for i, model in enumerate(models):
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
+ remove_indices.append(i)
+ for i in reversed(remove_indices):
+ models.pop(i)
+ # print(f"load model hook: {len(models)} models will be loaded")
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # resume from local or huggingface. accelerator.step is set
+ self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume)
+
+ # epoch数を計算する
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # 学習する
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ accelerator.print("running training / 学習開始")
+ accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
+ accelerator.print(
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
+ )
+ # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+ # TODO refactor metadata creation and move to util
+ metadata = {
+ "ss_" "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
+ "ss_training_started_at": training_started_at, # unix timestamp
+ "ss_output_name": args.output_name,
+ "ss_learning_rate": args.learning_rate,
+ "ss_num_train_items": train_dataset_group.num_train_items,
+ "ss_num_batches_per_epoch": len(train_dataloader),
+ "ss_num_epochs": num_train_epochs,
+ "ss_gradient_checkpointing": args.gradient_checkpointing,
+ "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
+ "ss_max_train_steps": args.max_train_steps,
+ "ss_lr_warmup_steps": args.lr_warmup_steps,
+ "ss_lr_scheduler": args.lr_scheduler,
+ SS_METADATA_KEY_BASE_MODEL_VERSION: self.architecture_full_name,
+ # "ss_network_module": args.network_module,
+ # "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
+ # "ss_network_alpha": args.network_alpha, # some networks may not have alpha
+ SS_METADATA_KEY_NETWORK_MODULE: args.network_module,
+ SS_METADATA_KEY_NETWORK_DIM: args.network_dim,
+ SS_METADATA_KEY_NETWORK_ALPHA: args.network_alpha,
+ "ss_network_dropout": args.network_dropout, # some networks may not have dropout
+ "ss_mixed_precision": args.mixed_precision,
+ "ss_seed": args.seed,
+ "ss_training_comment": args.training_comment, # will not be updated after training
+ # "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
+ "ss_max_grad_norm": args.max_grad_norm,
+ "ss_fp8_base": bool(args.fp8_base),
+ # "ss_fp8_llm": bool(args.fp8_llm), # remove this because this is only for HuanyuanVideo TODO set architecure dependent metadata
+ "ss_full_fp16": bool(args.full_fp16),
+ "ss_full_bf16": bool(args.full_bf16),
+ "ss_weighting_scheme": args.weighting_scheme,
+ "ss_logit_mean": args.logit_mean,
+ "ss_logit_std": args.logit_std,
+ "ss_mode_scale": args.mode_scale,
+ "ss_guidance_scale": args.guidance_scale,
+ "ss_timestep_sampling": args.timestep_sampling,
+ "ss_sigmoid_scale": args.sigmoid_scale,
+ "ss_discrete_flow_shift": args.discrete_flow_shift,
+ }
+
+ datasets_metadata = []
+ # tag_frequency = {} # merge tag frequency for metadata editor # TODO support tag frequency
+ for dataset in train_dataset_group.datasets:
+ dataset_metadata = dataset.get_metadata()
+ datasets_metadata.append(dataset_metadata)
+
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
+
+ # add extra args
+ if args.network_args:
+ # metadata["ss_network_args"] = json.dumps(net_kwargs)
+ metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(net_kwargs)
+
+ # model name and hash
+ # calculate hash takes time, so we omit it for now
+ if args.dit is not None:
+ # logger.info(f"calculate hash for DiT model: {args.dit}")
+ logger.info(f"set DiT model name for metadata: {args.dit}")
+ sd_model_name = args.dit
+ if os.path.exists(sd_model_name):
+ # metadata["ss_sd_model_hash"] = model_utils.model_hash(sd_model_name)
+ # metadata["ss_new_sd_model_hash"] = model_utils.calculate_sha256(sd_model_name)
+ sd_model_name = os.path.basename(sd_model_name)
+ metadata["ss_sd_model_name"] = sd_model_name
+
+ if args.vae is not None:
+ # logger.info(f"calculate hash for VAE model: {args.vae}")
+ logger.info(f"set VAE model name for metadata: {args.vae}")
+ vae_name = args.vae
+ if os.path.exists(vae_name):
+ # metadata["ss_vae_hash"] = model_utils.model_hash(vae_name)
+ # metadata["ss_new_vae_hash"] = model_utils.calculate_sha256(vae_name)
+ vae_name = os.path.basename(vae_name)
+ metadata["ss_vae_name"] = vae_name
+
+ metadata = {k: str(v) for k, v in metadata.items()}
+
+ # make minimum metadata for filtering
+ minimum_metadata = {}
+ for key in SS_METADATA_MINIMUM_KEYS:
+ if key in metadata:
+ minimum_metadata[key] = metadata[key]
+
+ if accelerator.is_main_process:
+ init_kwargs = {}
+ if args.wandb_run_name:
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
+ if args.log_tracker_config is not None:
+ init_kwargs = toml.load(args.log_tracker_config)
+ accelerator.init_trackers(
+ "network_train" if args.log_tracker_name is None else args.log_tracker_name,
+ config=train_utils.get_sanitized_config_or_none(args),
+ init_kwargs=init_kwargs,
+ )
+
+ # TODO skip until initial step
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+
+ epoch_to_start = 0
+ global_step = 0
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
+
+ loss_recorder = train_utils.LossRecorder()
+ del train_dataset_group
+
+ # function for saving/removing
+ save_dtype = dit_dtype
+
+ def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
+ os.makedirs(args.output_dir, exist_ok=True)
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
+
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
+ metadata["ss_training_finished_at"] = str(time.time())
+ metadata["ss_steps"] = str(steps)
+ metadata["ss_epoch"] = str(epoch_no)
+
+ metadata_to_save = minimum_metadata if args.no_metadata else metadata
+
+ title = args.metadata_title if args.metadata_title is not None else args.output_name
+ if args.min_timestep is not None or args.max_timestep is not None:
+ min_time_step = args.min_timestep if args.min_timestep is not None else 0
+ max_time_step = args.max_timestep if args.max_timestep is not None else 1000
+ md_timesteps = (min_time_step, max_time_step)
+ else:
+ md_timesteps = None
+
+ sai_metadata = sai_model_spec.build_metadata(
+ None,
+ self.architecture,
+ time.time(),
+ title,
+ None,
+ args.metadata_author,
+ args.metadata_description,
+ args.metadata_license,
+ args.metadata_tags,
+ timesteps=md_timesteps,
+ )
+
+ metadata_to_save.update(sai_metadata)
+
+ unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
+ if args.huggingface_repo_id is not None:
+ huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
+
+ def remove_model(old_ckpt_name):
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
+ if os.path.exists(old_ckpt_file):
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
+ os.remove(old_ckpt_file)
+
+ # For --sample_at_first
+ if should_sample_images(args, global_step, epoch=0):
+ optimizer_eval_fn()
+ self.sample_images(accelerator, args, 0, global_step, vae, transformer, sample_parameters, dit_dtype)
+ optimizer_train_fn()
+ if len(accelerator.trackers) > 0:
+ # log empty object to commit the sample images to wandb
+ accelerator.log({}, step=0)
+
+ # training loop
+
+ # log device and dtype for each model
+ logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
+
+ clean_memory_on_device(accelerator.device)
+
+ for epoch in range(epoch_to_start, num_train_epochs):
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
+ current_epoch.value = epoch + 1
+
+ metadata["ss_epoch"] = str(epoch + 1)
+
+ accelerator.unwrap_model(network).on_epoch_start(transformer)
+
+ for step, batch in enumerate(train_dataloader):
+ latents = batch["latents"]
+ bsz = latents.shape[0]
+ current_step.value = global_step
+
+ with accelerator.accumulate(training_model):
+ accelerator.unwrap_model(network).on_step_start()
+
+ latents = self.scale_shift_latents(latents)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+
+ # calculate model input and timesteps
+ noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
+ args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
+ )
+
+ weighting = compute_loss_weighting_for_sd3(
+ args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
+ )
+
+ model_pred, target = self.call_dit(
+ args, accelerator, transformer, latents, batch, noise, noisy_model_input, timesteps, network_dtype
+ )
+ loss = torch.nn.functional.mse_loss(model_pred.to(network_dtype), target, reduction="none")
+
+ if weighting is not None:
+ loss = loss * weighting
+ # loss = loss.mean([1, 2, 3])
+ # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
+ # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
+
+ loss = loss.mean() # mean loss over all elements in batch
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ # self.all_reduce_network(accelerator, network) # sync DDP grad manually
+ state = accelerate.PartialState()
+ if state.distributed_type != accelerate.DistributedType.NO:
+ for param in network.parameters():
+ if param.grad is not None:
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
+
+ if args.max_grad_norm != 0.0:
+ params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ if args.scale_weight_norms:
+ keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
+ args.scale_weight_norms, accelerator.device
+ )
+ max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
+ else:
+ keys_scaled, mean_norm, maximum_norm = None, None, None
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ # to avoid calling optimizer_eval_fn() too frequently, we call it only when we need to sample images or save the model
+ should_sampling = should_sample_images(args, global_step, epoch=None)
+ should_saving = args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0
+
+ if should_sampling or should_saving:
+ optimizer_eval_fn()
+ if should_sampling:
+ self.sample_images(accelerator, args, None, global_step, vae, transformer, sample_parameters, dit_dtype)
+
+ if should_saving:
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
+
+ if args.save_state:
+ train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)
+
+ remove_step_no = train_utils.get_remove_step_no(args, global_step)
+ if remove_step_no is not None:
+ remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
+ remove_model(remove_ckpt_name)
+ optimizer_train_fn()
+
+ current_loss = loss.detach().item()
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+ avr_loss: float = loss_recorder.moving_average
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if args.scale_weight_norms:
+ progress_bar.set_postfix(**{**max_mean_logs, **logs})
+
+ if len(accelerator.trackers) > 0:
+ logs = self.generate_step_logs(
+ args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
+ )
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if len(accelerator.trackers) > 0:
+ logs = {"loss/epoch": loss_recorder.moving_average}
+ accelerator.log(logs, step=epoch + 1)
+
+ accelerator.wait_for_everyone()
+
+ # save model at the end of epoch if needed
+ optimizer_eval_fn()
+ if args.save_every_n_epochs is not None:
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
+ if is_main_process and saving:
+ ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
+
+ remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
+ if remove_epoch_no is not None:
+ remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
+ remove_model(remove_ckpt_name)
+
+ if args.save_state:
+ train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
+
+ self.sample_images(accelerator, args, epoch + 1, global_step, vae, transformer, sample_parameters, dit_dtype)
+ optimizer_train_fn()
+
+ # end of epoch
+
+ # metadata["ss_epoch"] = str(num_train_epochs)
+ metadata["ss_training_finished_at"] = str(time.time())
+
+ if is_main_process:
+ network = accelerator.unwrap_model(network)
+
+ accelerator.end_training()
+ optimizer_eval_fn()
+
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
+ train_utils.save_state_on_train_end(args, accelerator)
+
+ if is_main_process:
+ ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
+ save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
+
+ logger.info("model saved.")
+
+
+def setup_parser_common() -> argparse.ArgumentParser:
+ def int_or_float(value):
+ if value.endswith("%"):
+ try:
+ return float(value[:-1]) / 100.0
+ except ValueError:
+ raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+ try:
+ float_value = float(value)
+ if float_value >= 1 and float_value.is_integer():
+ return int(value)
+ return float(value)
+ except ValueError:
+ raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
+ parser = argparse.ArgumentParser()
+
+ # general settings
+ parser.add_argument(
+ "--config_file",
+ type=str,
+ default=None,
+ help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
+ )
+ parser.add_argument(
+ "--dataset_config",
+ type=pathlib.Path,
+ default=None,
+ help="config file for dataset / データセットの設定ファイル",
+ )
+
+ # training settings
+ parser.add_argument(
+ "--sdpa",
+ action="store_true",
+ help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
+ )
+ parser.add_argument(
+ "--flash_attn",
+ action="store_true",
+ help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
+ )
+ parser.add_argument(
+ "--sage_attn",
+ action="store_true",
+ help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
+ )
+ parser.add_argument(
+ "--xformers",
+ action="store_true",
+ help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要",
+ )
+ parser.add_argument(
+ "--flash3",
+ action="store_true",
+ help="use FlashAttention 3 for CrossAttention, requires FlashAttention 3, HunyuanVideo does not support this yet"
+ " / CrossAttentionにFlashAttention 3を使う、FlashAttention 3が必要。HunyuanVideoは未対応。",
+ )
+ parser.add_argument(
+ "--split_attn",
+ action="store_true",
+ help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)"
+ " / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)",
+ )
+
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
+ parser.add_argument(
+ "--max_train_epochs",
+ type=int,
+ default=None,
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
+ )
+ parser.add_argument(
+ "--max_data_loader_n_workers",
+ type=int,
+ default=8,
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
+ )
+ parser.add_argument(
+ "--persistent_data_loader_workers",
+ action="store_true",
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
+ parser.add_argument(
+ "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help="use mixed precision / 混合精度を使う場合、その精度",
+ )
+
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default=None,
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
+ )
+ parser.add_argument(
+ "--log_with",
+ type=str,
+ default=None,
+ choices=["tensorboard", "wandb", "all"],
+ help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
+ )
+ parser.add_argument(
+ "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
+ )
+ parser.add_argument(
+ "--log_tracker_name",
+ type=str,
+ default=None,
+ help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
+ )
+ parser.add_argument(
+ "--wandb_run_name",
+ type=str,
+ default=None,
+ help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
+ )
+ parser.add_argument(
+ "--log_tracker_config",
+ type=str,
+ default=None,
+ help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
+ )
+ parser.add_argument(
+ "--wandb_api_key",
+ type=str,
+ default=None,
+ help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
+ )
+ parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
+
+ parser.add_argument(
+ "--ddp_timeout",
+ type=int,
+ default=None,
+ help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
+ )
+ parser.add_argument(
+ "--ddp_gradient_as_bucket_view",
+ action="store_true",
+ help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
+ )
+ parser.add_argument(
+ "--ddp_static_graph",
+ action="store_true",
+ help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
+ )
+
+ parser.add_argument(
+ "--sample_every_n_steps",
+ type=int,
+ default=None,
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
+ )
+ parser.add_argument(
+ "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
+ )
+ parser.add_argument(
+ "--sample_every_n_epochs",
+ type=int,
+ default=None,
+ help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
+ )
+ parser.add_argument(
+ "--sample_prompts",
+ type=str,
+ default=None,
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
+ )
+
+ # optimizer and lr scheduler settings
+ parser.add_argument(
+ "--optimizer_type",
+ type=str,
+ default="",
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
+ "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc. / ",
+ )
+ parser.add_argument(
+ "--optimizer_args",
+ type=str,
+ default=None,
+ nargs="*",
+ help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
+ )
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
+ parser.add_argument(
+ "--max_grad_norm",
+ default=1.0,
+ type=float,
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
+ )
+
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps",
+ type=int_or_float,
+ default=0,
+ help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
+ " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
+ )
+ parser.add_argument(
+ "--lr_decay_steps",
+ type=int_or_float,
+ default=0,
+ help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
+ " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
+ )
+ parser.add_argument(
+ "--lr_scheduler_num_cycles",
+ type=int,
+ default=1,
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
+ )
+ parser.add_argument(
+ "--lr_scheduler_power",
+ type=float,
+ default=1,
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
+ )
+ parser.add_argument(
+ "--lr_scheduler_timescale",
+ type=int,
+ default=None,
+ help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`"
+ + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
+ )
+ parser.add_argument(
+ "--lr_scheduler_min_lr_ratio",
+ type=float,
+ default=None,
+ help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
+ + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
+ )
+ parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
+ parser.add_argument(
+ "--lr_scheduler_args",
+ type=str,
+ default=None,
+ nargs="*",
+ help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")',
+ )
+
+ parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
+ # parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
+ # parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
+
+ parser.add_argument(
+ "--dynamo_backend",
+ type=str,
+ default="NO",
+ choices=[e.value for e in DynamoBackend],
+ help="dynamo backend type (default is None) / dynamoのbackendの種類(デフォルトは None)",
+ )
+
+ parser.add_argument(
+ "--dynamo_mode",
+ type=str,
+ default=None,
+ choices=["default", "reduce-overhead", "max-autotune"],
+ help="dynamo mode (default is default) / dynamoのモード(デフォルトは default)",
+ )
+
+ parser.add_argument(
+ "--dynamo_fullgraph",
+ action="store_true",
+ help="use fullgraph mode for dynamo / dynamoのfullgraphモードを使う",
+ )
+
+ parser.add_argument(
+ "--dynamo_dynamic",
+ action="store_true",
+ help="use dynamic mode for dynamo / dynamoのdynamicモードを使う",
+ )
+
+ parser.add_argument(
+ "--blocks_to_swap",
+ type=int,
+ default=None,
+ help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX",
+ )
+ parser.add_argument(
+ "--img_in_txt_in_offloading",
+ action="store_true",
+ help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする",
+ )
+
+ # parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers")
+ parser.add_argument(
+ "--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale (HunyuanVideo only)."
+ )
+ parser.add_argument(
+ "--timestep_sampling",
+ choices=["sigma", "uniform", "sigmoid", "shift"],
+ default="sigma",
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
+ )
+ parser.add_argument(
+ "--discrete_flow_shift",
+ type=float,
+ default=1.0,
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
+ )
+ parser.add_argument(
+ "--sigmoid_scale",
+ type=float,
+ default=1.0,
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。',
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"],
+ help="weighting scheme for timestep distribution. Default is none"
+ " / タイムステップ分布の重み付けスキーム、デフォルトはnone",
+ )
+ parser.add_argument(
+ "--logit_mean",
+ type=float,
+ default=0.0,
+ help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
+ )
+ parser.add_argument(
+ "--logit_std",
+ type=float,
+ default=1.0,
+ help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
+ )
+ parser.add_argument(
+ "--min_timestep",
+ type=int,
+ default=None,
+ help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
+ )
+ parser.add_argument(
+ "--max_timestep",
+ type=int,
+ default=None,
+ help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
+ )
+
+ parser.add_argument(
+ "--show_timesteps",
+ type=str,
+ default=None,
+ choices=["image", "console"],
+ help="show timesteps in image or console, and return to console / タイムステップを画像またはコンソールに表示し、コンソールに戻る",
+ )
+
+ # network settings
+ parser.add_argument(
+ "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
+ )
+ parser.add_argument(
+ "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
+ )
+ parser.add_argument(
+ "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
+ )
+ parser.add_argument(
+ "--network_dim",
+ type=int,
+ default=None,
+ help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
+ )
+ parser.add_argument(
+ "--network_alpha",
+ type=float,
+ default=1,
+ help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
+ )
+ parser.add_argument(
+ "--network_dropout",
+ type=float,
+ default=None,
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
+ )
+ parser.add_argument(
+ "--network_args",
+ type=str,
+ default=None,
+ nargs="*",
+ help="additional arguments for network (key=value) / ネットワークへの追加の引数",
+ )
+ parser.add_argument(
+ "--training_comment",
+ type=str,
+ default=None,
+ help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
+ )
+ parser.add_argument(
+ "--dim_from_weights",
+ action="store_true",
+ help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
+ )
+ parser.add_argument(
+ "--scale_weight_norms",
+ type=float,
+ default=None,
+ help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
+ )
+ parser.add_argument(
+ "--base_weights",
+ type=str,
+ default=None,
+ nargs="*",
+ help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
+ )
+ parser.add_argument(
+ "--base_weights_multiplier",
+ type=float,
+ default=None,
+ nargs="*",
+ help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
+ )
+
+ # save and load settings
+ parser.add_argument(
+ "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
+ )
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ default=None,
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名",
+ )
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
+
+ parser.add_argument(
+ "--save_every_n_epochs",
+ type=int,
+ default=None,
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
+ )
+ parser.add_argument(
+ "--save_every_n_steps",
+ type=int,
+ default=None,
+ help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
+ )
+ parser.add_argument(
+ "--save_last_n_epochs",
+ type=int,
+ default=None,
+ help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
+ )
+ parser.add_argument(
+ "--save_last_n_epochs_state",
+ type=int,
+ default=None,
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
+ )
+ parser.add_argument(
+ "--save_last_n_steps",
+ type=int,
+ default=None,
+ help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
+ )
+ parser.add_argument(
+ "--save_last_n_steps_state",
+ type=int,
+ default=None,
+ help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
+ )
+ parser.add_argument(
+ "--save_state",
+ action="store_true",
+ help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
+ )
+ parser.add_argument(
+ "--save_state_on_train_end",
+ action="store_true",
+ help="save training state (including optimizer states etc.) on train end even if --save_state is not specified"
+ " / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する",
+ )
+
+ # SAI Model spec
+ parser.add_argument(
+ "--metadata_title",
+ type=str,
+ default=None,
+ help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
+ )
+ parser.add_argument(
+ "--metadata_author",
+ type=str,
+ default=None,
+ help="author name for model metadata / メタデータに書き込まれるモデル作者名",
+ )
+ parser.add_argument(
+ "--metadata_description",
+ type=str,
+ default=None,
+ help="description for model metadata / メタデータに書き込まれるモデル説明",
+ )
+ parser.add_argument(
+ "--metadata_license",
+ type=str,
+ default=None,
+ help="license for model metadata / メタデータに書き込まれるモデルライセンス",
+ )
+ parser.add_argument(
+ "--metadata_tags",
+ type=str,
+ default=None,
+ help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
+ )
+
+ # huggingface settings
+ parser.add_argument(
+ "--huggingface_repo_id",
+ type=str,
+ default=None,
+ help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
+ )
+ parser.add_argument(
+ "--huggingface_repo_type",
+ type=str,
+ default=None,
+ help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
+ )
+ parser.add_argument(
+ "--huggingface_path_in_repo",
+ type=str,
+ default=None,
+ help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
+ )
+ parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
+ parser.add_argument(
+ "--huggingface_repo_visibility",
+ type=str,
+ default=None,
+ help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
+ )
+ parser.add_argument(
+ "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
+ )
+ parser.add_argument(
+ "--resume_from_huggingface",
+ action="store_true",
+ help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
+ )
+ parser.add_argument(
+ "--async_upload",
+ action="store_true",
+ help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
+ )
+
+ parser.add_argument("--dit", type=str, help="DiT checkpoint path / DiTのチェックポイントのパス")
+ parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス")
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
+
+ return parser
+
+
+def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
+ if not args.config_file:
+ return args
+
+ config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
+
+ if not os.path.exists(config_path):
+ logger.info(f"{config_path} not found.")
+ exit(1)
+
+ logger.info(f"Loading settings from {config_path}...")
+ with open(config_path, "r", encoding="utf-8") as f:
+ config_dict = toml.load(f)
+
+ # combine all sections into one
+ ignore_nesting_dict = {}
+ for section_name, section_dict in config_dict.items():
+ # if value is not dict, save key and value as is
+ if not isinstance(section_dict, dict):
+ ignore_nesting_dict[section_name] = section_dict
+ continue
+
+ # if value is dict, save all key and value into one dict
+ for key, value in section_dict.items():
+ ignore_nesting_dict[key] = value
+
+ config_args = argparse.Namespace(**ignore_nesting_dict)
+ args = parser.parse_args(namespace=config_args)
+ args.config_file = os.path.splitext(args.config_file)[0]
+ logger.info(args.config_file)
+
+ return args
+
+
+def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ """HunyuanVideo specific parser setup"""
+ # model settings
+ parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16")
+ parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
+ parser.add_argument(
+ "--vae_tiling",
+ action="store_true",
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled."
+ " / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。",
+ )
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
+ parser.add_argument(
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser_common()
+ parser = hv_setup_parser(parser)
+
+ args = parser.parse_args()
+ args = read_config_from_file(args, parser)
+
+ args.fp8_scaled = False # HunyuanVideo does not support this yet
+
+ trainer = NetworkTrainer()
+ trainer.train(args)
diff --git a/merge_lora.py b/merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..8528774eb748a37884cbc423b35ff62db90c43cb
--- /dev/null
+++ b/merge_lora.py
@@ -0,0 +1,63 @@
+import argparse
+import logging
+import torch
+from safetensors.torch import load_file
+from networks import lora
+from utils.safetensors_utils import mem_eff_save_file
+from hunyuan_model.models import load_transformer
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="HunyuanVideo model merger script")
+
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
+ parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier (can specify multiple values)")
+ parser.add_argument("--save_merged_model", type=str, required=True, help="Path to save the merged model")
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for merging")
+
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ device = torch.device(args.device)
+ logger.info(f"Using device: {device}")
+
+ # Load DiT model
+ logger.info(f"Loading DiT model from {args.dit}")
+ transformer = load_transformer(args.dit, "torch", False, "cpu", torch.bfloat16, in_channels=args.dit_in_channels)
+ transformer.eval()
+
+ # Load LoRA weights and merge
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
+ for i, lora_weight in enumerate(args.lora_weight):
+ # Use the corresponding lora_multiplier or default to 1.0
+ if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
+ lora_multiplier = args.lora_multiplier[i]
+ else:
+ lora_multiplier = 1.0
+
+ logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
+ weights_sd = load_file(lora_weight)
+ network = lora.create_arch_network_from_weights(
+ lora_multiplier, weights_sd, unet=transformer, for_inference=True
+ )
+ logger.info("Merging LoRA weights to DiT model")
+ network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
+
+ logger.info("LoRA weights loaded")
+
+ # Save the merged model
+ logger.info(f"Saving merged model to {args.save_merged_model}")
+ mem_eff_save_file(transformer.state_dict(), args.save_merged_model)
+ logger.info("Merged model saved")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/custom_offloading_utils.py b/modules/custom_offloading_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d813575af2ce4fcccf4a305c1002bf618844e591
--- /dev/null
+++ b/modules/custom_offloading_utils.py
@@ -0,0 +1,266 @@
+from concurrent.futures import ThreadPoolExecutor
+import gc
+import time
+from typing import Optional
+import torch
+import torch.nn as nn
+
+
+def clean_memory_on_device(device: torch.device):
+ r"""
+ Clean memory on the specified device, will be called from training scripts.
+ """
+ gc.collect()
+
+ # device may "cuda" or "cuda:0", so we need to check the type of device
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ if device.type == "xpu":
+ torch.xpu.empty_cache()
+ if device.type == "mps":
+ torch.mps.empty_cache()
+
+
+def synchronize_device(device: torch.device):
+ if device.type == "cuda":
+ torch.cuda.synchronize()
+ elif device.type == "xpu":
+ torch.xpu.synchronize()
+ elif device.type == "mps":
+ torch.mps.synchronize()
+
+
+def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+ weight_swap_jobs = []
+
+ # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
+ # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+ # print(module_to_cpu.__class__, module_to_cuda.__class__)
+ # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+ # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
+ modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
+ for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
+ if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
+ module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
+ if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+ else:
+ if module_to_cuda.weight.data.device.type != device.type:
+ # print(
+ # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
+ # )
+ module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
+
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
+
+ stream = torch.cuda.Stream()
+ with torch.cuda.stream(stream):
+ # cuda to cpu
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+ cuda_data_view.record_stream(stream)
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+
+ stream.synchronize()
+
+ # cpu to cuda
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+ module_to_cuda.weight.data = cuda_data_view
+
+ stream.synchronize()
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
+
+
+def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+ """
+ not tested
+ """
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+ weight_swap_jobs = []
+ for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+ if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
+ # device to cpu
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+
+ synchronize_device()
+
+ # cpu to device
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+ module_to_cuda.weight.data = cuda_data_view
+
+ synchronize_device()
+
+
+def weighs_to_device(layer: nn.Module, device: torch.device):
+ for module in layer.modules():
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data = module.weight.data.to(device, non_blocking=True)
+
+
+class Offloader:
+ """
+ common offloading class
+ """
+
+ def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
+ self.block_type = block_type
+ self.num_blocks = num_blocks
+ self.blocks_to_swap = blocks_to_swap
+ self.device = device
+ self.debug = debug
+
+ self.thread_pool = ThreadPoolExecutor(max_workers=1)
+ self.futures = {}
+ self.cuda_available = device.type == "cuda"
+
+ def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
+ if self.cuda_available:
+ swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
+ else:
+ swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
+
+ def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
+ def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
+ if self.debug:
+ start_time = time.perf_counter()
+ print(
+ f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}"
+ )
+
+ self.swap_weight_devices(block_to_cpu, block_to_cuda)
+
+ if self.debug:
+ print(f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
+ return bidx_to_cpu, bidx_to_cuda # , event
+
+ block_to_cpu = blocks[block_idx_to_cpu]
+ block_to_cuda = blocks[block_idx_to_cuda]
+
+ self.futures[block_idx_to_cuda] = self.thread_pool.submit(
+ move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
+ )
+
+ def _wait_blocks_move(self, block_idx):
+ if block_idx not in self.futures:
+ return
+
+ if self.debug:
+ print(f"[{self.block_type}] Wait for block {block_idx}")
+ start_time = time.perf_counter()
+
+ future = self.futures.pop(block_idx)
+ _, bidx_to_cuda = future.result()
+
+ assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
+
+ if self.debug:
+ print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
+
+
+class ModelOffloader(Offloader):
+ """
+ supports forward offloading
+ """
+
+ def __init__(
+ self,
+ block_type: str,
+ blocks: list[nn.Module],
+ num_blocks: int,
+ blocks_to_swap: int,
+ supports_backward: bool,
+ device: torch.device,
+ debug: bool = False,
+ ):
+ super().__init__(block_type, num_blocks, blocks_to_swap, device, debug)
+
+ self.supports_backward = supports_backward
+ self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference
+
+ if self.supports_backward:
+ # register backward hooks
+ self.remove_handles = []
+ for i, block in enumerate(blocks):
+ hook = self.create_backward_hook(blocks, i)
+ if hook is not None:
+ handle = block.register_full_backward_hook(hook)
+ self.remove_handles.append(handle)
+
+ def set_forward_only(self, forward_only: bool):
+ self.forward_only = forward_only
+
+ def __del__(self):
+ if self.supports_backward:
+ for handle in self.remove_handles:
+ handle.remove()
+
+ def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
+ # -1 for 0-based index
+ num_blocks_propagated = self.num_blocks - block_index - 1
+ swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
+ waiting = block_index > 0 and block_index <= self.blocks_to_swap
+
+ if not swapping and not waiting:
+ return None
+
+ # create hook
+ block_idx_to_cpu = self.num_blocks - num_blocks_propagated
+ block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
+ block_idx_to_wait = block_index - 1
+
+ def backward_hook(module, grad_input, grad_output):
+ if self.debug:
+ print(f"Backward hook for block {block_index}")
+
+ if swapping:
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
+ if waiting:
+ self._wait_blocks_move(block_idx_to_wait)
+ return None
+
+ return backward_hook
+
+ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+
+ if self.debug:
+ print(f"[{self.block_type}] Prepare block devices before forward")
+
+ for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
+ b.to(self.device)
+ weighs_to_device(b, self.device) # make sure weights are on device
+
+ for b in blocks[self.num_blocks - self.blocks_to_swap :]:
+ b.to(self.device) # move block to device first
+ weighs_to_device(b, "cpu") # make sure weights are on cpu
+
+ synchronize_device(self.device)
+ clean_memory_on_device(self.device)
+
+ def wait_for_block(self, block_idx: int):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+ self._wait_blocks_move(block_idx)
+
+ def submit_move_blocks_forward(self, blocks: list[nn.Module], block_idx: int):
+ # check if blocks_to_swap is enabled
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+
+ # if supports_backward and backward is enabled, we swap blocks more than blocks_to_swap in backward pass
+ if not self.forward_only and block_idx >= self.blocks_to_swap:
+ return
+
+ block_idx_to_cpu = block_idx
+ block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
+ block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
diff --git a/modules/fp8_optimization_utils.py b/modules/fp8_optimization_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4ac4f8d11200cb715ebeb5b9ff55e26aae76ff
--- /dev/null
+++ b/modules/fp8_optimization_utils.py
@@ -0,0 +1,356 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import logging
+
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+from utils.device_utils import clean_memory_on_device
+
+
+def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
+ """
+ Calculate the maximum representable value in FP8 format.
+ Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign).
+
+ Args:
+ exp_bits (int): Number of exponent bits
+ mantissa_bits (int): Number of mantissa bits
+ sign_bits (int): Number of sign bits (0 or 1)
+
+ Returns:
+ float: Maximum value representable in FP8 format
+ """
+ assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"
+
+ # Calculate exponent bias
+ bias = 2 ** (exp_bits - 1) - 1
+
+ # Calculate maximum mantissa value
+ mantissa_max = 1.0
+ for i in range(mantissa_bits - 1):
+ mantissa_max += 2 ** -(i + 1)
+
+ # Calculate maximum value
+ max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
+
+ return max_value
+
+
+def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None):
+ """
+ Quantize a tensor to FP8 format.
+
+ Args:
+ tensor (torch.Tensor): Tensor to quantize
+ scale (float or torch.Tensor): Scale factor
+ exp_bits (int): Number of exponent bits
+ mantissa_bits (int): Number of mantissa bits
+ sign_bits (int): Number of sign bits
+
+ Returns:
+ tuple: (quantized_tensor, scale_factor)
+ """
+ # Create scaled tensor
+ scaled_tensor = tensor / scale
+
+ # Calculate FP8 parameters
+ bias = 2 ** (exp_bits - 1) - 1
+
+ if max_value is None:
+ # Calculate max and min values
+ max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits)
+ min_value = -max_value if sign_bits > 0 else 0.0
+
+ # Clamp tensor to range
+ clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value)
+
+ # Quantization process
+ abs_values = torch.abs(clamped_tensor)
+ nonzero_mask = abs_values > 0
+
+ # Calculate log scales (only for non-zero elements)
+ log_scales = torch.zeros_like(clamped_tensor)
+ if nonzero_mask.any():
+ log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach()
+
+ # Limit log scales and calculate quantization factor
+ log_scales = torch.clamp(log_scales, min=1.0)
+ quant_factor = 2.0 ** (log_scales - mantissa_bits - bias)
+
+ # Quantize and dequantize
+ quantized = torch.round(clamped_tensor / quant_factor) * quant_factor
+
+ return quantized, scale
+
+
+def optimize_state_dict_with_fp8(
+ state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, move_to_device=False
+):
+ """
+ Optimize Linear layer weights in a model's state dict to FP8 format.
+
+ Args:
+ state_dict (dict): State dict to optimize, replaced in-place
+ calc_device (str): Device to quantize tensors on
+ target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers)
+ exclude_layer_keys (list, optional): Layer key patterns to exclude
+ exp_bits (int): Number of exponent bits
+ mantissa_bits (int): Number of mantissa bits
+ move_to_device (bool): Move optimized tensors to the calculating device
+
+ Returns:
+ dict: FP8 optimized state dict
+ """
+ if exp_bits == 4 and mantissa_bits == 3:
+ fp8_dtype = torch.float8_e4m3fn
+ elif exp_bits == 5 and mantissa_bits == 2:
+ fp8_dtype = torch.float8_e5m2
+ else:
+ raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
+
+ # Calculate FP8 max value
+ max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
+ min_value = -max_value # this function supports only signed FP8
+
+ # Create optimized state dict
+ optimized_count = 0
+
+ # Enumerate tarket keys
+ target_state_dict_keys = []
+ for key in state_dict.keys():
+ # Check if it's a weight key and matches target patterns
+ is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
+ is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
+ is_target = is_target and not is_excluded
+
+ if is_target and isinstance(state_dict[key], torch.Tensor):
+ target_state_dict_keys.append(key)
+
+ # Process each key
+ for key in tqdm(target_state_dict_keys):
+ value = state_dict[key]
+
+ # Save original device and dtype
+ original_device = value.device
+ original_dtype = value.dtype
+
+ # Move to calculation device
+ if calc_device is not None:
+ value = value.to(calc_device)
+
+ # Calculate scale factor
+ scale = torch.max(torch.abs(value.flatten())) / max_value
+ # print(f"Optimizing {key} with scale: {scale}")
+
+ # Quantize weight to FP8
+ quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value)
+
+ # Add to state dict using original key for weight and new key for scale
+ fp8_key = key # Maintain original key
+ scale_key = key.replace(".weight", ".scale_weight")
+
+ quantized_weight = quantized_weight.to(fp8_dtype)
+
+ if not move_to_device:
+ quantized_weight = quantized_weight.to(original_device)
+
+ scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device)
+
+ state_dict[fp8_key] = quantized_weight
+ state_dict[scale_key] = scale_tensor
+
+ optimized_count += 1
+
+ if calc_device is not None: # optimized_count % 10 == 0 and
+ # free memory on calculation device
+ clean_memory_on_device(calc_device)
+
+ logger.info(f"Number of optimized Linear layers: {optimized_count}")
+ return state_dict
+
+
+def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None):
+ """
+ Patched forward method for Linear layers with FP8 weights.
+
+ Args:
+ self: Linear layer instance
+ x (torch.Tensor): Input tensor
+ use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
+ max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor.
+
+ Returns:
+ torch.Tensor: Result of linear transformation
+ """
+ if use_scaled_mm:
+ input_dtype = x.dtype
+ original_weight_dtype = self.scale_weight.dtype
+ weight_dtype = self.weight.dtype
+ target_dtype = torch.float8_e5m2
+ assert weight_dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported"
+ assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"
+
+ if max_value is None:
+ # no input quantization
+ scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device)
+ else:
+ # calculate scale factor for input tensor
+ scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)
+
+ # quantize input tensor to FP8: this seems to consume a lot of memory
+ x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value)
+
+ original_shape = x.shape
+ x = x.reshape(-1, x.shape[2]).to(target_dtype)
+
+ weight = self.weight.t()
+ scale_weight = self.scale_weight.to(torch.float32)
+
+ if self.bias is not None:
+ # float32 is not supported with bias in scaled_mm
+ o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight)
+ else:
+ o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
+
+ return o.reshape(original_shape[0], original_shape[1], -1).to(input_dtype)
+
+ else:
+ # Dequantize the weight
+ original_dtype = self.scale_weight.dtype
+ dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
+
+ # Perform linear transformation
+ if self.bias is not None:
+ output = F.linear(x, dequantized_weight, self.bias)
+ else:
+ output = F.linear(x, dequantized_weight)
+
+ return output
+
+
+def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
+ """
+ Apply monkey patching to a model using FP8 optimized state dict.
+
+ Args:
+ model (nn.Module): Model instance to patch
+ optimized_state_dict (dict): FP8 optimized state dict
+ use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
+
+ Returns:
+ nn.Module: The patched model (same instance, modified in-place)
+ """
+ # # Calculate FP8 float8_e5m2 max value
+ # max_value = calculate_fp8_maxval(5, 2)
+ max_value = None # do not quantize input tensor
+
+ # Find all scale keys to identify FP8-optimized layers
+ scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")]
+
+ # Enumerate patched layers
+ patched_module_paths = set()
+ for scale_key in scale_keys:
+ # Extract module path from scale key (remove .scale_weight)
+ module_path = scale_key.rsplit(".scale_weight", 1)[0]
+ patched_module_paths.add(module_path)
+
+ patched_count = 0
+
+ # Apply monkey patch to each layer with FP8 weights
+ for name, module in model.named_modules():
+ # Check if this module has a corresponding scale_weight
+ has_scale = name in patched_module_paths
+
+ # Apply patch if it's a Linear layer with FP8 scale
+ if isinstance(module, nn.Linear) and has_scale:
+ # register the scale_weight as a buffer to load the state_dict
+ module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
+
+ # Create a new forward method with the patched version.
+ def new_forward(self, x):
+ return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
+
+ # Bind method to module
+ module.forward = new_forward.__get__(module, type(module))
+
+ patched_count += 1
+
+ logger.info(f"Number of monkey-patched Linear layers: {patched_count}")
+ return model
+
+
+# Example usage
+def example_usage():
+ # Small test model
+ class TestModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ fc1 = nn.Linear(768, 3072)
+ act1 = nn.GELU()
+ fc2 = nn.Linear(3072, 768)
+ act2 = nn.GELU()
+ fc3 = nn.Linear(768, 768)
+
+ # Set layer names for testing
+ self.single_blocks = nn.ModuleList([fc1, act1, fc2, act2, fc3])
+
+ self.fc4 = nn.Linear(768, 128)
+
+ def forward(self, x):
+ for layer in self.single_blocks:
+ x = layer(x)
+ x = self.fc4(x)
+ return x
+
+ # Instantiate model
+ test_model = TestModel()
+ test_model.to(torch.float16) # convert to FP16 for testing
+
+ # Test input tensor
+ test_input = torch.randn(1, 768, dtype=torch.float16)
+
+ # Calculate output before optimization
+ with torch.no_grad():
+ original_output = test_model(test_input)
+ print("original output", original_output[0, :5])
+
+ # Get state dict
+ state_dict = test_model.state_dict()
+
+ # Apply FP8 optimization to state dict
+ cuda_device = torch.device("cuda")
+ optimized_state_dict = optimize_state_dict_with_fp8(state_dict, cuda_device, ["single_blocks"], ["2"])
+
+ # Apply monkey patching to the model
+ optimized_model = TestModel() # re-instantiate model
+ optimized_model.to(torch.float16) # convert to FP16 for testing
+ apply_fp8_monkey_patch(optimized_model, optimized_state_dict)
+
+ # Load optimized state dict
+ optimized_model.load_state_dict(optimized_state_dict, strict=True, assign=True) # assign=True to load buffer
+
+ # Calculate output after optimization
+ with torch.no_grad():
+ optimized_output = optimized_model(test_input)
+ print("optimized output", optimized_output[0, :5])
+
+ # Compare accuracy
+ error = torch.mean(torch.abs(original_output - optimized_output))
+ print(f"Mean absolute error: {error.item()}")
+
+ # Check memory usage
+ original_params = sum(p.nelement() * p.element_size() for p in test_model.parameters()) / (1024 * 1024)
+ print(f"Model parameter memory: {original_params:.2f} MB")
+ optimized_params = sum(p.nelement() * p.element_size() for p in optimized_model.parameters()) / (1024 * 1024)
+ print(f"Optimized model parameter memory: {optimized_params:.2f} MB")
+
+ return test_model
+
+
+if __name__ == "__main__":
+ example_usage()
diff --git a/modules/scheduling_flow_match_discrete.py b/modules/scheduling_flow_match_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..c507ec4eb050463188e250c20aec8d1fde2c4a5d
--- /dev/null
+++ b/modules/scheduling_flow_match_discrete.py
@@ -0,0 +1,257 @@
+# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class FlowMatchDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Euler scheduler.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ shift (`float`, defaults to 1.0):
+ The shift value for the timestep schedule.
+ reverse (`bool`, defaults to `True`):
+ Whether to reverse the timestep schedule.
+ """
+
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ reverse: bool = True,
+ solver: str = "euler",
+ n_tokens: Optional[int] = None,
+ ):
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
+
+ if not reverse:
+ sigmas = sigmas.flip(0)
+
+ self.sigmas = sigmas
+ # the value fed to model
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
+
+ self._step_index = None
+ self._begin_index = None
+
+ self.supported_solver = ["euler"]
+ if solver not in self.supported_solver:
+ raise ValueError(
+ f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
+ )
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ n_tokens: int = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ n_tokens (`int`, *optional*):
+ Number of tokens in the input sequence.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
+ sigmas = self.sd3_time_shift(sigmas)
+
+ if not self.config.reverse:
+ sigmas = 1 - sigmas
+
+ self.sigmas = sigmas
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
+ dtype=torch.float32, device=device
+ )
+
+ # Reset step index
+ self._step_index = None
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def scale_model_input(
+ self, sample: torch.Tensor, timestep: Optional[int] = None
+ ) -> torch.Tensor:
+ return sample
+
+ def sd3_time_shift(self, t: torch.Tensor):
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ n_tokens (`int`, *optional*):
+ Number of tokens in the input sequence.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
+ tuple.
+
+ Returns:
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
+
+ if self.config.solver == "euler":
+ prev_sample = sample + model_output.to(torch.float32) * dt
+ else:
+ raise ValueError(
+ f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
+ )
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/modules/unet_causal_3d_blocks.py b/modules/unet_causal_3d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..27d544170ece6a370cdacfe9e31367b884c2e516
--- /dev/null
+++ b/modules/unet_causal_3d_blocks.py
@@ -0,0 +1,818 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Modified from diffusers==0.29.2
+#
+# ==============================================================================
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from einops import rearrange
+
+from diffusers.utils import logging
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import SpatialNorm
+from diffusers.models.attention_processor import Attention
+from diffusers.models.normalization import AdaGroupNorm
+from diffusers.models.normalization import RMSNorm
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
+ seq_len = n_frame * n_hw
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+ for i in range(seq_len):
+ i_frame = i // n_hw
+ mask[i, : (i_frame + 1) * n_hw] = 0
+ if batch_size is not None:
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+ return mask
+
+
+class CausalConv3d(nn.Module):
+ """
+ Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
+ This maintains temporal causality in video generation tasks.
+ """
+
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ dilation: Union[int, Tuple[int, int, int]] = 1,
+ pad_mode="replicate",
+ chunk_size=0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.pad_mode = pad_mode
+ padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
+ self.time_causal_padding = padding
+ self.chunk_size = chunk_size
+
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def original_forward(self, x):
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(x)
+
+ def forward(self, x):
+ if self.chunk_size == 0:
+ return self.original_forward(x)
+
+ # if not large, call original forward
+ if x.shape[4] < self.chunk_size * 1.5:
+ return self.original_forward(x)
+
+ # # debug: verify the original forward is the same as chunked forward
+ # orig_forwarded_value = None
+ # if x.shape[4] < self.chunk_size * 4:
+ # orig_forwarded_value = self.original_forward(x)
+
+ # get the kernel size
+ kernel_size = self.conv.kernel_size[0] # assume cubic kernel
+ assert kernel_size == self.conv.kernel_size[1] == self.conv.kernel_size[2], "Only cubic kernels are supported"
+ padding_size = kernel_size // 2 # 1 for kernel_size=3, 0 for kernel_size=1
+
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+
+ B, C, D, H, W = orig_shape = x.shape
+ chunk_size = self.chunk_size
+ chunk_size -= chunk_size % self.conv.stride[2] # make sure the chunk size is divisible by stride
+ # print(f"chunked forward: {x.shape}, chunk_size: {chunk_size}")
+
+ # calculate the indices for chunking with overlap and padding by kernel size and stride
+ indices = []
+ i = 0
+ while i < W - padding_size:
+ start_idx = i - padding_size
+ end_idx = min(i + chunk_size + padding_size, W)
+ if i == 0:
+ start_idx = 0
+ end_idx += padding_size # to make sure the first chunk is divisible by stride
+ if W - end_idx < chunk_size // 2: # small chunk at the end
+ end_idx = W
+ indices.append((start_idx, end_idx))
+ i = end_idx - padding_size
+ # print(f"chunked forward: {x.shape}, chunked indices: {indices}")
+
+ chunks = []
+ for start_idx, end_idx in indices:
+ chunk = x[:, :, :, :, start_idx:end_idx]
+ chunk_output = self.conv(chunk)
+ # print(chunk.shape, chunk_output.shape)
+ chunks.append(chunk_output)
+
+ # concatenate the chunks
+ x = torch.cat(chunks, dim=4)
+
+ assert (
+ x.shape[2] == ((D - padding_size * 2) + self.conv.stride[0] - 1) // self.conv.stride[0]
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
+ assert (
+ x.shape[3] == ((H - padding_size * 2) + self.conv.stride[1] - 1) // self.conv.stride[1]
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
+ assert (
+ x.shape[4] == ((W - padding_size * 2) + self.conv.stride[2] - 1) // self.conv.stride[2]
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
+
+ # # debug: verify the original forward is the same as chunked forward
+ # if orig_forwarded_value is not None:
+ # assert torch.allclose(
+ # orig_forwarded_value, x, rtol=1e-4, atol=1e-2
+ # ), f"Chunked forward is different from original forward. {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}, {self.conv.kernel_size}"
+
+ return x
+
+
+class UpsampleCausal3D(nn.Module):
+ """
+ A 3D upsampling layer with an optional convolution.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ use_conv_transpose: bool = False,
+ out_channels: Optional[int] = None,
+ name: str = "conv",
+ kernel_size: Optional[int] = None,
+ padding=1,
+ norm_type=None,
+ eps=None,
+ elementwise_affine=None,
+ bias=True,
+ interpolate=True,
+ upsample_factor=(2, 2, 2),
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+ self.interpolate = interpolate
+ self.upsample_factor = upsample_factor
+
+ if norm_type == "ln_norm":
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
+ elif norm_type is None:
+ self.norm = None
+ else:
+ raise ValueError(f"unknown norm_type: {norm_type}")
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ if kernel_size is None:
+ kernel_size = 3
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
+
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ output_size: Optional[int] = None,
+ scale: float = 1.0,
+ ) -> torch.FloatTensor:
+ assert hidden_states.shape[1] == self.channels
+
+ if self.norm is not None:
+ raise NotImplementedError
+
+ if self.use_conv_transpose:
+ return self.conv(hidden_states)
+
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if self.interpolate:
+ B, C, T, H, W = hidden_states.shape
+ first_h, other_h = hidden_states.split((1, T - 1), dim=2)
+ if output_size is None:
+ if T > 1:
+ other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
+
+ first_h = first_h.squeeze(2)
+ first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
+ first_h = first_h.unsqueeze(2)
+ else:
+ raise NotImplementedError
+
+ if T > 1:
+ hidden_states = torch.cat((first_h, other_h), dim=2)
+ else:
+ hidden_states = first_h
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class DownsampleCausal3D(nn.Module):
+ """
+ A 3D downsampling layer with an optional convolution.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ name: str = "conv",
+ kernel_size=3,
+ norm_type=None,
+ eps=None,
+ elementwise_affine=None,
+ bias=True,
+ stride=2,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = stride
+ self.name = name
+
+ if norm_type == "ln_norm":
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
+ elif norm_type is None:
+ self.norm = None
+ else:
+ raise ValueError(f"unknown norm_type: {norm_type}")
+
+ if use_conv:
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
+ else:
+ raise NotImplementedError
+
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ assert hidden_states.shape[1] == self.channels
+
+ if self.norm is not None:
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+ assert hidden_states.shape[1] == self.channels
+
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlockCausal3D(nn.Module):
+ r"""
+ A Resnet block.
+ """
+
+ def __init__(
+ self,
+ *,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ conv_shortcut: bool = False,
+ dropout: float = 0.0,
+ temb_channels: int = 512,
+ groups: int = 32,
+ groups_out: Optional[int] = None,
+ pre_norm: bool = True,
+ eps: float = 1e-6,
+ non_linearity: str = "swish",
+ skip_time_act: bool = False,
+ # default, scale_shift, ada_group, spatial
+ time_embedding_norm: str = "default",
+ kernel: Optional[torch.FloatTensor] = None,
+ output_scale_factor: float = 1.0,
+ use_in_shortcut: Optional[bool] = None,
+ up: bool = False,
+ down: bool = False,
+ conv_shortcut_bias: bool = True,
+ conv_3d_out_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+ self.time_embedding_norm = time_embedding_norm
+ self.skip_time_act = skip_time_act
+
+ linear_cls = nn.Linear
+
+ if groups_out is None:
+ groups_out = groups
+
+ if self.time_embedding_norm == "ada_group":
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+ elif self.time_embedding_norm == "spatial":
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
+ else:
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
+ elif self.time_embedding_norm == "scale_shift":
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ self.time_emb_proj = None
+ else:
+ raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
+ else:
+ self.time_emb_proj = None
+
+ if self.time_embedding_norm == "ada_group":
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+ elif self.time_embedding_norm == "spatial":
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
+ else:
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+
+ self.dropout = torch.nn.Dropout(dropout)
+ conv_3d_out_channels = conv_3d_out_channels or out_channels
+ self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.upsample = self.downsample = None
+ if self.up:
+ self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
+ elif self.down:
+ self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
+
+ self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = CausalConv3d(
+ in_channels,
+ conv_3d_out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=conv_shortcut_bias,
+ )
+
+ def forward(
+ self,
+ input_tensor: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ scale: float = 1.0,
+ ) -> torch.FloatTensor:
+ hidden_states = input_tensor
+
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ hidden_states = self.norm1(hidden_states, temb)
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ input_tensor = input_tensor.contiguous()
+ hidden_states = hidden_states.contiguous()
+ input_tensor = self.upsample(input_tensor, scale=scale)
+ hidden_states = self.upsample(hidden_states, scale=scale)
+ elif self.downsample is not None:
+ input_tensor = self.downsample(input_tensor, scale=scale)
+ hidden_states = self.downsample(hidden_states, scale=scale)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if self.time_emb_proj is not None:
+ if not self.skip_time_act:
+ temb = self.nonlinearity(temb)
+ temb = self.time_emb_proj(temb, scale)[:, :, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ hidden_states = self.norm2(hidden_states, temb)
+ else:
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+def get_down_block3d(
+ down_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ add_downsample: bool,
+ downsample_stride: int,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ downsample_padding: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ downsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+):
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+ logger.warn(
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownEncoderBlockCausal3D":
+ return DownEncoderBlockCausal3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ downsample_stride=downsample_stride,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block3d(
+ up_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ add_upsample: bool,
+ upsample_scale_factor: Tuple,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resolution_idx: Optional[int] = None,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ upsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+) -> nn.Module:
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+ logger.warn(
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpDecoderBlockCausal3D":
+ return UpDecoderBlockCausal3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ upsample_scale_factor=upsample_scale_factor,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlockCausal3D(nn.Module):
+ """
+ A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ attn_groups: Optional[int] = None,
+ resnet_pre_norm: bool = True,
+ add_attention: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+ self.add_attention = add_attention
+
+ if attn_groups is None:
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warn(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+ )
+ attention_head_dim = in_channels
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(
+ Attention(
+ in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=attn_groups,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+ else:
+ attentions.append(None)
+
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ B, C, T, H, W = hidden_states.shape
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
+ attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
+ hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class DownEncoderBlockCausal3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_stride: int = 2,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ DownsampleCausal3D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ stride=downsample_stride,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale)
+
+ return hidden_states
+
+
+class UpDecoderBlockCausal3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ upsample_scale_factor=(2, 2, 2),
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ UpsampleCausal3D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ upsample_factor=upsample_scale_factor,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+ ) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
diff --git a/networks/__init__.py b/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/networks/lora.py b/networks/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..971828958abc6d7a47e4bf3103f75abaf7299700
--- /dev/null
+++ b/networks/lora.py
@@ -0,0 +1,913 @@
+# LoRA network module: currently conv2d is not fully supported
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+
+import ast
+import math
+import os
+import re
+from typing import Dict, List, Optional, Type, Union
+from transformers import CLIPTextModel
+import numpy as np
+import torch
+import torch.nn as nn
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+HUNYUAN_TARGET_REPLACE_MODULES = ["MMDoubleStreamBlock", "MMSingleStreamBlock"]
+
+
+class LoRAModule(torch.nn.Module):
+ """
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
+ """
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ split_dims: Optional[List[int]] = None,
+ ):
+ """
+ if alpha == 0 or None, alpha is rank (no scaling).
+
+ split_dims is used to mimic the split qkv of multi-head attention.
+ """
+ super().__init__()
+ self.lora_name = lora_name
+
+ if org_module.__class__.__name__ == "Conv2d":
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+
+ self.lora_dim = lora_dim
+ self.split_dims = split_dims
+
+ if split_dims is None:
+ if org_module.__class__.__name__ == "Conv2d":
+ kernel_size = org_module.kernel_size
+ stride = org_module.stride
+ padding = org_module.padding
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+ else:
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ torch.nn.init.zeros_(self.lora_up.weight)
+ else:
+ # conv2d not supported
+ assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
+ assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
+ # print(f"split_dims: {split_dims}")
+ self.lora_down = torch.nn.ModuleList(
+ [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
+ )
+ self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
+ for lora_down in self.lora_down:
+ torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
+ for lora_up in self.lora_up:
+ torch.nn.init.zeros_(lora_up.weight)
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+ self.register_buffer("alpha", torch.tensor(alpha)) # for save/load
+
+ # same as microsoft's
+ self.multiplier = multiplier
+ self.org_module = org_module # remove in applying
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ def apply_to(self):
+ self.org_forward = self.org_module.forward
+ self.org_module.forward = self.forward
+ del self.org_module
+
+ def forward(self, x):
+ org_forwarded = self.org_forward(x)
+
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return org_forwarded
+
+ if self.split_dims is None:
+ lx = self.lora_down(x)
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+ if len(lx.size()) == 3:
+ mask = mask.unsqueeze(1) # for Text Encoder
+ elif len(lx.size()) == 4:
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
+ lx = lx * mask
+
+ # scaling for rank dropout: treat as if the rank is changed
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lx = self.lora_up(lx)
+
+ return org_forwarded + lx * self.multiplier * scale
+ else:
+ lxs = [lora_down(x) for lora_down in self.lora_down]
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
+ for i in range(len(lxs)):
+ if len(lx.size()) == 3:
+ masks[i] = masks[i].unsqueeze(1)
+ elif len(lx.size()) == 4:
+ masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
+ lxs[i] = lxs[i] * masks[i]
+
+ # scaling for rank dropout: treat as if the rank is changed
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+
+ return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
+
+
+class LoRAInfModule(LoRAModule):
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ **kwargs,
+ ):
+ # no dropout for inference
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
+ self.org_module_ref = [org_module] # for reference
+ self.enabled = True
+ self.network: LoRANetwork = None
+
+ def set_network(self, network):
+ self.network = network
+
+ # merge weight to org_module
+ # def merge_to(self, sd, dtype, device, non_blocking=False):
+ # if torch.cuda.is_available():
+ # stream = torch.cuda.Stream(device=device)
+ # with torch.cuda.stream(stream):
+ # print(f"merge_to {self.lora_name}")
+ # self._merge_to(sd, dtype, device, non_blocking)
+ # torch.cuda.synchronize(device=device)
+ # print(f"merge_to {self.lora_name} done")
+ # torch.cuda.empty_cache()
+ # else:
+ # self._merge_to(sd, dtype, device, non_blocking)
+
+ def merge_to(self, sd, dtype, device, non_blocking=False):
+ # extract weight from org_module
+ org_sd = self.org_module.state_dict()
+ weight = org_sd["weight"]
+ org_dtype = weight.dtype
+ org_device = weight.device
+ weight = weight.to(device, dtype=torch.float, non_blocking=non_blocking) # for calculation
+
+ if dtype is None:
+ dtype = org_dtype
+ if device is None:
+ device = org_device
+
+ if self.split_dims is None:
+ # get up/down weight
+ down_weight = sd["lora_down.weight"].to(device, dtype=torch.float, non_blocking=non_blocking)
+ up_weight = sd["lora_up.weight"].to(device, dtype=torch.float, non_blocking=non_blocking)
+
+ # merge weight
+ if len(weight.size()) == 2:
+ # linear
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ weight
+ + self.multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * self.scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+ weight = weight + self.multiplier * conved * self.scale
+
+ # set weight to org_module
+ org_sd["weight"] = weight.to(org_device, dtype=dtype) # back to CPU without non_blocking
+ self.org_module.load_state_dict(org_sd)
+ else:
+ # split_dims
+ total_dims = sum(self.split_dims)
+ for i in range(len(self.split_dims)):
+ # get up/down weight
+ down_weight = sd[f"lora_down.{i}.weight"].to(device, torch.float, non_blocking=non_blocking) # (rank, in_dim)
+ up_weight = sd[f"lora_up.{i}.weight"].to(device, torch.float, non_blocking=non_blocking) # (split dim, rank)
+
+ # pad up_weight -> (total_dims, rank)
+ padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
+ padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
+
+ # merge weight
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+
+ # set weight to org_module
+ org_sd["weight"] = weight.to(org_device, dtype) # back to CPU without non_blocking
+ self.org_module.load_state_dict(org_sd)
+
+ # return weight for merge
+ def get_weight(self, multiplier=None):
+ if multiplier is None:
+ multiplier = self.multiplier
+
+ # get up/down weight from module
+ up_weight = self.lora_up.weight.to(torch.float)
+ down_weight = self.lora_down.weight.to(torch.float)
+
+ # pre-calculated weight
+ if len(down_weight.size()) == 2:
+ # linear
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ self.multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * self.scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ weight = self.multiplier * conved * self.scale
+
+ return weight
+
+ def default_forward(self, x):
+ # logger.info(f"default_forward {self.lora_name} {x.size()}")
+ if self.split_dims is None:
+ lx = self.lora_down(x)
+ lx = self.lora_up(lx)
+ return self.org_forward(x) + lx * self.multiplier * self.scale
+ else:
+ lxs = [lora_down(x) for lora_down in self.lora_down]
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
+ return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
+
+ def forward(self, x):
+ if not self.enabled:
+ return self.org_forward(x)
+ return self.default_forward(x)
+
+
+def create_arch_network(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ vae: nn.Module,
+ text_encoders: List[nn.Module],
+ unet: nn.Module,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ # add default exclude patterns
+ exclude_patterns = kwargs.get("exclude_patterns", None)
+ if exclude_patterns is None:
+ exclude_patterns = []
+ else:
+ exclude_patterns = ast.literal_eval(exclude_patterns)
+
+ # exclude if 'img_mod', 'txt_mod' or 'modulation' in the name
+ exclude_patterns.append(r".*(img_mod|txt_mod|modulation).*")
+
+ kwargs["exclude_patterns"] = exclude_patterns
+
+ return create_network(
+ HUNYUAN_TARGET_REPLACE_MODULES,
+ "lora_unet",
+ multiplier,
+ network_dim,
+ network_alpha,
+ vae,
+ text_encoders,
+ unet,
+ neuron_dropout=neuron_dropout,
+ **kwargs,
+ )
+
+
+def create_network(
+ target_replace_modules: List[str],
+ prefix: str,
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ vae: nn.Module,
+ text_encoders: List[nn.Module],
+ unet: nn.Module,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ """ architecture independent network creation """
+ if network_dim is None:
+ network_dim = 4 # default
+ if network_alpha is None:
+ network_alpha = 1.0
+
+ # extract dim/alpha for conv2d, and block dim
+ conv_dim = kwargs.get("conv_dim", None)
+ conv_alpha = kwargs.get("conv_alpha", None)
+ if conv_dim is not None:
+ conv_dim = int(conv_dim)
+ if conv_alpha is None:
+ conv_alpha = 1.0
+ else:
+ conv_alpha = float(conv_alpha)
+
+ # TODO generic rank/dim setting with regular expression
+
+ # rank/module dropout
+ rank_dropout = kwargs.get("rank_dropout", None)
+ if rank_dropout is not None:
+ rank_dropout = float(rank_dropout)
+ module_dropout = kwargs.get("module_dropout", None)
+ if module_dropout is not None:
+ module_dropout = float(module_dropout)
+
+ # verbose
+ verbose = kwargs.get("verbose", False)
+ if verbose is not None:
+ verbose = True if verbose == "True" else False
+
+ # regular expression for module selection: exclude and include
+ exclude_patterns = kwargs.get("exclude_patterns", None)
+ if exclude_patterns is not None and isinstance(exclude_patterns, str):
+ exclude_patterns = ast.literal_eval(exclude_patterns)
+ include_patterns = kwargs.get("include_patterns", None)
+ if include_patterns is not None and isinstance(include_patterns, str):
+ include_patterns = ast.literal_eval(include_patterns)
+
+ # too many arguments ( ^ω^)・・・
+ network = LoRANetwork(
+ target_replace_modules,
+ prefix,
+ text_encoders,
+ unet,
+ multiplier=multiplier,
+ lora_dim=network_dim,
+ alpha=network_alpha,
+ dropout=neuron_dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ conv_lora_dim=conv_dim,
+ conv_alpha=conv_alpha,
+ exclude_patterns=exclude_patterns,
+ include_patterns=include_patterns,
+ verbose=verbose,
+ )
+
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+ # loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+ # loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+ # loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+ # loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+ if loraplus_lr_ratio is not None: # or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio) # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
+ return network
+
+
+class LoRANetwork(torch.nn.Module):
+ # only supports U-Net (DiT), Text Encoders are not supported
+
+ def __init__(
+ self,
+ target_replace_modules: List[str],
+ prefix: str,
+ text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
+ unet: nn.Module,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: float = 1,
+ dropout: Optional[float] = None,
+ rank_dropout: Optional[float] = None,
+ module_dropout: Optional[float] = None,
+ conv_lora_dim: Optional[int] = None,
+ conv_alpha: Optional[float] = None,
+ module_class: Type[object] = LoRAModule,
+ modules_dim: Optional[Dict[str, int]] = None,
+ modules_alpha: Optional[Dict[str, int]] = None,
+ exclude_patterns: Optional[List[str]] = None,
+ include_patterns: Optional[List[str]] = None,
+ verbose: Optional[bool] = False,
+ ) -> None:
+ super().__init__()
+ self.multiplier = multiplier
+
+ self.lora_dim = lora_dim
+ self.alpha = alpha
+ self.conv_lora_dim = conv_lora_dim
+ self.conv_alpha = conv_alpha
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+ self.target_replace_modules = target_replace_modules
+ self.prefix = prefix
+
+ self.loraplus_lr_ratio = None
+ # self.loraplus_unet_lr_ratio = None
+ # self.loraplus_text_encoder_lr_ratio = None
+
+ if modules_dim is not None:
+ logger.info(f"create LoRA network from weights")
+ else:
+ logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+ logger.info(
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+ )
+ # if self.conv_lora_dim is not None:
+ # logger.info(
+ # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+ # )
+ # if train_t5xxl:
+ # logger.info(f"train T5XXL as well")
+
+ # compile regular expression if specified
+ exclude_re_patterns = []
+ if exclude_patterns is not None:
+ for pattern in exclude_patterns:
+ try:
+ re_pattern = re.compile(pattern)
+ except re.error as e:
+ logger.error(f"Invalid exclude pattern '{pattern}': {e}")
+ continue
+ exclude_re_patterns.append(re_pattern)
+
+ include_re_patterns = []
+ if include_patterns is not None:
+ for pattern in include_patterns:
+ try:
+ re_pattern = re.compile(pattern)
+ except re.error as e:
+ logger.error(f"Invalid include pattern '{pattern}': {e}")
+ continue
+ include_re_patterns.append(re_pattern)
+
+ # create module instances
+ def create_modules(
+ is_unet: bool,
+ pfx: str,
+ root_module: torch.nn.Module,
+ target_replace_mods: Optional[List[str]] = None,
+ filter: Optional[str] = None,
+ default_dim: Optional[int] = None,
+ ) -> List[LoRAModule]:
+ loras = []
+ skipped = []
+ for name, module in root_module.named_modules():
+ if target_replace_mods is None or module.__class__.__name__ in target_replace_mods:
+ if target_replace_mods is None: # dirty hack for all modules
+ module = root_module # search all modules
+
+ for child_name, child_module in module.named_modules():
+ is_linear = child_module.__class__.__name__ == "Linear"
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+ if is_linear or is_conv2d:
+ original_name = (name + "." if name else "") + child_name
+ lora_name = f"{pfx}.{original_name}".replace(".", "_")
+
+ # exclude/include filter
+ excluded = False
+ for pattern in exclude_re_patterns:
+ if pattern.match(original_name):
+ excluded = True
+ break
+ included = False
+ for pattern in include_re_patterns:
+ if pattern.match(original_name):
+ included = True
+ break
+ if excluded and not included:
+ if verbose:
+ logger.info(f"exclude: {original_name}")
+ continue
+
+ # filter by name (not used in the current implementation)
+ if filter is not None and not filter in lora_name:
+ continue
+
+ dim = None
+ alpha = None
+
+ if modules_dim is not None:
+ # モジュール指定あり
+ if lora_name in modules_dim:
+ dim = modules_dim[lora_name]
+ alpha = modules_alpha[lora_name]
+ else:
+ # 通常、すべて対象とする
+ if is_linear or is_conv2d_1x1:
+ dim = default_dim if default_dim is not None else self.lora_dim
+ alpha = self.alpha
+ elif self.conv_lora_dim is not None:
+ dim = self.conv_lora_dim
+ alpha = self.conv_alpha
+
+ if dim is None or dim == 0:
+ # skipした情報を出力
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
+ skipped.append(lora_name)
+ continue
+
+ lora = module_class(
+ lora_name,
+ child_module,
+ self.multiplier,
+ dim,
+ alpha,
+ dropout=dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ )
+ loras.append(lora)
+
+ if target_replace_mods is None:
+ break # all modules are searched
+ return loras, skipped
+
+ # # create LoRA for text encoder
+ # # it is redundant to create LoRA modules even if they are not used
+
+ self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
+ # skipped_te = []
+ # for i, text_encoder in enumerate(text_encoders):
+ # index = i
+ # if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
+ # break
+ # logger.info(f"create LoRA for Text Encoder {index+1}:")
+ # text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+ # logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
+ # self.text_encoder_loras.extend(text_encoder_loras)
+ # skipped_te += skipped
+
+ # create LoRA for U-Net
+ self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
+ self.unet_loras, skipped_un = create_modules(True, prefix, unet, target_replace_modules)
+
+ logger.info(f"create LoRA for U-Net/DiT: {len(self.unet_loras)} modules.")
+ if verbose:
+ for lora in self.unet_loras:
+ logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
+
+ skipped = skipped_un
+ if verbose and len(skipped) > 0:
+ logger.warning(
+ f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+ )
+ for name in skipped:
+ logger.info(f"\t{name}")
+
+ # assertion
+ names = set()
+ for lora in self.text_encoder_loras + self.unet_loras:
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+ names.add(lora.lora_name)
+
+ def prepare_network(self, args):
+ """
+ called after the network is created
+ """
+ pass
+
+ def set_multiplier(self, multiplier):
+ self.multiplier = multiplier
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.multiplier = self.multiplier
+
+ def set_enabled(self, is_enabled):
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.enabled = is_enabled
+
+ def load_weights(self, file):
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+
+ info = self.load_state_dict(weights_sd, False)
+ return info
+
+ def apply_to(
+ self,
+ text_encoders: Optional[nn.Module],
+ unet: Optional[nn.Module],
+ apply_text_encoder: bool = True,
+ apply_unet: bool = True,
+ ):
+ if apply_text_encoder:
+ logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
+ else:
+ self.text_encoder_loras = []
+
+ if apply_unet:
+ logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
+ else:
+ self.unet_loras = []
+
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.apply_to()
+ self.add_module(lora.lora_name, lora)
+
+ # マージできるかどうかを返す
+ def is_mergeable(self):
+ return True
+
+ # TODO refactor to common function with apply_to
+ def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None, non_blocking=False):
+ from concurrent.futures import ThreadPoolExecutor
+
+ with ThreadPoolExecutor(max_workers=2) as executor: # 2 workers is enough
+ futures = []
+ for lora in self.text_encoder_loras + self.unet_loras:
+ sd_for_lora = {}
+ for key in weights_sd.keys():
+ if key.startswith(lora.lora_name):
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+ if len(sd_for_lora) == 0:
+ logger.info(f"no weight for {lora.lora_name}")
+ continue
+
+ # lora.merge_to(sd_for_lora, dtype, device)
+ futures.append(executor.submit(lora.merge_to, sd_for_lora, dtype, device, non_blocking))
+
+ for future in futures:
+ future.result()
+
+ logger.info(f"weights are merged")
+
+ def set_loraplus_lr_ratio(self, loraplus_lr_ratio): # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+ self.loraplus_lr_ratio = loraplus_lr_ratio
+
+ logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_lr_ratio}")
+ # logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
+
+ def prepare_optimizer_params(self, unet_lr: float = 1e-4, **kwargs):
+ self.requires_grad_(True)
+
+ all_params = []
+ lr_descriptions = []
+
+ def assemble_params(loras, lr, loraplus_ratio):
+ param_groups = {"lora": {}, "plus": {}}
+ for lora in loras:
+ for name, param in lora.named_parameters():
+ if loraplus_ratio is not None and "lora_up" in name:
+ param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+ else:
+ param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+ params = []
+ descriptions = []
+ for key in param_groups.keys():
+ param_data = {"params": param_groups[key].values()}
+
+ if len(param_data["params"]) == 0:
+ continue
+
+ if lr is not None:
+ if key == "plus":
+ param_data["lr"] = lr * loraplus_ratio
+ else:
+ param_data["lr"] = lr
+
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+ logger.info("NO LR skipping!")
+ continue
+
+ params.append(param_data)
+ descriptions.append("plus" if key == "plus" else "")
+
+ return params, descriptions
+
+ if self.unet_loras:
+ params, descriptions = assemble_params(self.unet_loras, unet_lr, self.loraplus_lr_ratio)
+ all_params.extend(params)
+ lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
+
+ return all_params, lr_descriptions
+
+ def enable_gradient_checkpointing(self):
+ # not supported
+ pass
+
+ def prepare_grad_etc(self, unet):
+ self.requires_grad_(True)
+
+ def on_epoch_start(self, unet):
+ self.train()
+
+ def on_step_start(self):
+ pass
+
+ def get_trainable_params(self):
+ return self.parameters()
+
+ def save_weights(self, file, dtype, metadata):
+ if metadata is not None and len(metadata) == 0:
+ metadata = None
+
+ state_dict = self.state_dict()
+
+ if dtype is not None:
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(dtype)
+ state_dict[key] = v
+
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import save_file
+ from utils import model_utils
+
+ # Precalculate model hashes to save time on indexing
+ if metadata is None:
+ metadata = {}
+ model_hash, legacy_hash = model_utils.precalculate_safetensors_hashes(state_dict, metadata)
+ metadata["sshs_model_hash"] = model_hash
+ metadata["sshs_legacy_hash"] = legacy_hash
+
+ save_file(state_dict, file, metadata)
+ else:
+ torch.save(state_dict, file)
+
+ def backup_weights(self):
+ # 重みのバックアップを行う
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ if not hasattr(org_module, "_lora_org_weight"):
+ sd = org_module.state_dict()
+ org_module._lora_org_weight = sd["weight"].detach().clone()
+ org_module._lora_restored = True
+
+ def restore_weights(self):
+ # 重みのリストアを行う
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ if not org_module._lora_restored:
+ sd = org_module.state_dict()
+ sd["weight"] = org_module._lora_org_weight
+ org_module.load_state_dict(sd)
+ org_module._lora_restored = True
+
+ def pre_calculation(self):
+ # 事前計算を行う
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ sd = org_module.state_dict()
+
+ org_weight = sd["weight"]
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+ sd["weight"] = org_weight + lora_weight
+ assert sd["weight"].shape == org_weight.shape
+ org_module.load_state_dict(sd)
+
+ org_module._lora_restored = False
+ lora.enabled = False
+
+ def apply_max_norm_regularization(self, max_norm_value, device):
+ downkeys = []
+ upkeys = []
+ alphakeys = []
+ norms = []
+ keys_scaled = 0
+
+ state_dict = self.state_dict()
+ for key in state_dict.keys():
+ if "lora_down" in key and "weight" in key:
+ downkeys.append(key)
+ upkeys.append(key.replace("lora_down", "lora_up"))
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
+
+ for i in range(len(downkeys)):
+ down = state_dict[downkeys[i]].to(device)
+ up = state_dict[upkeys[i]].to(device)
+ alpha = state_dict[alphakeys[i]].to(device)
+ dim = down.shape[0]
+ scale = alpha / dim
+
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+ else:
+ updown = up @ down
+
+ updown *= scale
+
+ norm = updown.norm().clamp(min=max_norm_value / 2)
+ desired = torch.clamp(norm, max=max_norm_value)
+ ratio = desired.cpu() / norm.cpu()
+ sqrt_ratio = ratio**0.5
+ if ratio != 1:
+ keys_scaled += 1
+ state_dict[upkeys[i]] *= sqrt_ratio
+ state_dict[downkeys[i]] *= sqrt_ratio
+ scalednorm = updown.norm() * ratio
+ norms.append(scalednorm.item())
+
+ return keys_scaled, sum(norms) / len(norms), max(norms)
+
+
+def create_arch_network_from_weights(
+ multiplier: float,
+ weights_sd: Dict[str, torch.Tensor],
+ text_encoders: Optional[List[nn.Module]] = None,
+ unet: Optional[nn.Module] = None,
+ for_inference: bool = False,
+ **kwargs,
+) -> LoRANetwork:
+ return create_network_from_weights(
+ HUNYUAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
+ )
+
+
+# Create network from weights for inference, weights are not loaded here (because can be merged)
+def create_network_from_weights(
+ target_replace_modules: List[str],
+ multiplier: float,
+ weights_sd: Dict[str, torch.Tensor],
+ text_encoders: Optional[List[nn.Module]] = None,
+ unet: Optional[nn.Module] = None,
+ for_inference: bool = False,
+ **kwargs,
+) -> LoRANetwork:
+ # get dim/alpha mapping
+ modules_dim = {}
+ modules_alpha = {}
+ for key, value in weights_sd.items():
+ if "." not in key:
+ continue
+
+ lora_name = key.split(".")[0]
+ if "alpha" in key:
+ modules_alpha[lora_name] = value
+ elif "lora_down" in key:
+ dim = value.shape[0]
+ modules_dim[lora_name] = dim
+ # logger.info(lora_name, value.size(), dim)
+
+ module_class = LoRAInfModule if for_inference else LoRAModule
+
+ network = LoRANetwork(
+ target_replace_modules,
+ "lora_unet",
+ text_encoders,
+ unet,
+ multiplier=multiplier,
+ modules_dim=modules_dim,
+ modules_alpha=modules_alpha,
+ module_class=module_class,
+ )
+ return network
diff --git a/networks/lora_framepack.py b/networks/lora_framepack.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b627d4d5188257f5ceca9e467e7c0964e4dd5e8
--- /dev/null
+++ b/networks/lora_framepack.py
@@ -0,0 +1,65 @@
+# LoRA module for FramePack
+
+import ast
+from typing import Dict, List, Optional
+import torch
+import torch.nn as nn
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+import networks.lora as lora
+
+
+FRAMEPACK_TARGET_REPLACE_MODULES = ["HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock"]
+
+
+def create_arch_network(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ vae: nn.Module,
+ text_encoders: List[nn.Module],
+ unet: nn.Module,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ # add default exclude patterns
+ exclude_patterns = kwargs.get("exclude_patterns", None)
+ if exclude_patterns is None:
+ exclude_patterns = []
+ else:
+ exclude_patterns = ast.literal_eval(exclude_patterns)
+
+ # exclude if 'norm' in the name of the module
+ exclude_patterns.append(r".*(norm).*")
+
+ kwargs["exclude_patterns"] = exclude_patterns
+
+ return lora.create_network(
+ FRAMEPACK_TARGET_REPLACE_MODULES,
+ "lora_unet",
+ multiplier,
+ network_dim,
+ network_alpha,
+ vae,
+ text_encoders,
+ unet,
+ neuron_dropout=neuron_dropout,
+ **kwargs,
+ )
+
+
+def create_arch_network_from_weights(
+ multiplier: float,
+ weights_sd: Dict[str, torch.Tensor],
+ text_encoders: Optional[List[nn.Module]] = None,
+ unet: Optional[nn.Module] = None,
+ for_inference: bool = False,
+ **kwargs,
+) -> lora.LoRANetwork:
+ return lora.create_network_from_weights(
+ FRAMEPACK_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
+ )
diff --git a/networks/lora_wan.py b/networks/lora_wan.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b171a741d317a551f17d1f45046e7eed6b161e
--- /dev/null
+++ b/networks/lora_wan.py
@@ -0,0 +1,65 @@
+# LoRA module for Wan2.1
+
+import ast
+from typing import Dict, List, Optional
+import torch
+import torch.nn as nn
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+import networks.lora as lora
+
+
+WAN_TARGET_REPLACE_MODULES = ["WanAttentionBlock"]
+
+
+def create_arch_network(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ vae: nn.Module,
+ text_encoders: List[nn.Module],
+ unet: nn.Module,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ # add default exclude patterns
+ exclude_patterns = kwargs.get("exclude_patterns", None)
+ if exclude_patterns is None:
+ exclude_patterns = []
+ else:
+ exclude_patterns = ast.literal_eval(exclude_patterns)
+
+ # exclude if 'img_mod', 'txt_mod' or 'modulation' in the name
+ exclude_patterns.append(r".*(patch_embedding|text_embedding|time_embedding|time_projection|norm|head).*")
+
+ kwargs["exclude_patterns"] = exclude_patterns
+
+ return lora.create_network(
+ WAN_TARGET_REPLACE_MODULES,
+ "lora_unet",
+ multiplier,
+ network_dim,
+ network_alpha,
+ vae,
+ text_encoders,
+ unet,
+ neuron_dropout=neuron_dropout,
+ **kwargs,
+ )
+
+
+def create_arch_network_from_weights(
+ multiplier: float,
+ weights_sd: Dict[str, torch.Tensor],
+ text_encoders: Optional[List[nn.Module]] = None,
+ unet: Optional[nn.Module] = None,
+ for_inference: bool = False,
+ **kwargs,
+) -> lora.LoRANetwork:
+ return lora.create_network_from_weights(
+ WAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
+ )
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..479e30c6558a1ec72dacca7c98d5931c85728b69
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,42 @@
+[project]
+name = "musubi-tuner"
+version = "0.1.0"
+description = "Musubi Tuner by kohya_ss"
+readme = "README.md"
+requires-python = ">=3.10, <3.11"
+dependencies = [
+ "accelerate>=1.6.0",
+ "ascii-magic==2.3.0",
+ "av==14.0.1",
+ "bitsandbytes>=0.45.0",
+ "diffusers>=0.32.1",
+ "easydict==1.13",
+ "einops>=0.7.0",
+ "ftfy==6.3.1",
+ "huggingface-hub>=0.30.0",
+ "matplotlib>=3.10.0",
+ "opencv-python>=4.10.0.84",
+ "pillow>=10.2.0",
+ "safetensors>=0.4.5",
+ "sageattention>=1.0.6",
+ "tensorboard>=2.18.0",
+ "toml>=0.10.2",
+ "torch>=2.5.1",
+ "torchvision>=0.20.1",
+ "tqdm>=4.66.5",
+ "transformers>=4.46.3",
+ "voluptuous>=0.15.2",
+]
+
+[tool.uv.sources]
+torch = [
+ { index = "pytorch-cu124" },
+]
+torchvision = [
+ { index = "pytorch-cu124" },
+]
+
+[[tool.uv.index]]
+name = "pytorch-cu124"
+url = "https://download.pytorch.org/whl/cu124"
+explicit = true
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1f2e932e7e019521a5e657088dcd78666224009e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+accelerate==1.6.0
+av==14.0.1
+bitsandbytes==0.45.4
+diffusers==0.32.1
+einops==0.7.0
+huggingface-hub==0.30.0
+opencv-python==4.10.0.84
+pillow
+safetensors==0.4.5
+toml==0.10.2
+tqdm==4.67.1
+transformers==4.46.3
+voluptuous==0.15.2
+
+# Wan2.1
+ftfy==6.3.1
+easydict==1.13
+
+# optional dependencies
+# ascii-magic==2.3.0
+# matplotlib==3.10.0
+# tensorboard
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/device_utils.py b/utils/device_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14803e499d7b92acebf8d8bddc3426d178695c4
--- /dev/null
+++ b/utils/device_utils.py
@@ -0,0 +1,19 @@
+import torch
+
+
+def clean_memory_on_device(device):
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ elif device.type == "cpu":
+ pass
+ elif device.type == "mps": # not tested
+ torch.mps.empty_cache()
+
+
+def synchronize_device(device: torch.device):
+ if device.type == "cuda":
+ torch.cuda.synchronize()
+ elif device.type == "xpu":
+ torch.xpu.synchronize()
+ elif device.type == "mps":
+ torch.mps.synchronize()
diff --git a/utils/huggingface_utils.py b/utils/huggingface_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dc7bd7dbb2ef70e0b6244b9db686aae00f46408
--- /dev/null
+++ b/utils/huggingface_utils.py
@@ -0,0 +1,89 @@
+import threading
+from typing import Union, BinaryIO
+from huggingface_hub import HfApi
+from pathlib import Path
+import argparse
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def fire_in_thread(f, *args, **kwargs):
+ threading.Thread(target=f, args=args, kwargs=kwargs).start()
+
+
+def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
+ api = HfApi(
+ token=token,
+ )
+ try:
+ api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
+ return True
+ except:
+ return False
+
+
+def upload(
+ args: argparse.Namespace,
+ src: Union[str, Path, bytes, BinaryIO],
+ dest_suffix: str = "",
+ force_sync_upload: bool = False,
+):
+ repo_id = args.huggingface_repo_id
+ repo_type = args.huggingface_repo_type
+ token = args.huggingface_token
+ path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
+ private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
+ api = HfApi(token=token)
+ if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
+ try:
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
+ except Exception as e: # RepositoryNotFoundError or something else
+ logger.error("===========================================")
+ logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
+ logger.error("===========================================")
+
+ is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
+
+ def uploader():
+ try:
+ if is_folder:
+ api.upload_folder(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ folder_path=src,
+ path_in_repo=path_in_repo,
+ )
+ else:
+ api.upload_file(
+ repo_id=repo_id,
+ repo_type=repo_type,
+ path_or_fileobj=src,
+ path_in_repo=path_in_repo,
+ )
+ except Exception as e: # RuntimeError or something else
+ logger.error("===========================================")
+ logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
+ logger.error("===========================================")
+
+ if args.async_upload and not force_sync_upload:
+ fire_in_thread(uploader)
+ else:
+ uploader()
+
+
+def list_dir(
+ repo_id: str,
+ subfolder: str,
+ repo_type: str,
+ revision: str = "main",
+ token: str = None,
+):
+ api = HfApi(
+ token=token,
+ )
+ repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
+ file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
+ return file_list
diff --git a/utils/model_utils.py b/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5beed8ec4e09f433ba2e84556a6c8f342a2903f5
--- /dev/null
+++ b/utils/model_utils.py
@@ -0,0 +1,151 @@
+import hashlib
+from io import BytesIO
+from typing import Optional
+
+import safetensors.torch
+import torch
+
+
+def model_hash(filename):
+ """Old model hash used by stable-diffusion-webui"""
+ try:
+ with open(filename, "rb") as file:
+ m = hashlib.sha256()
+
+ file.seek(0x100000)
+ m.update(file.read(0x10000))
+ return m.hexdigest()[0:8]
+ except FileNotFoundError:
+ return "NOFILE"
+ except IsADirectoryError: # Linux?
+ return "IsADirectory"
+ except PermissionError: # Windows
+ return "IsADirectory"
+
+
+def calculate_sha256(filename):
+ """New model hash used by stable-diffusion-webui"""
+ try:
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+ except FileNotFoundError:
+ return "NOFILE"
+ except IsADirectoryError: # Linux?
+ return "IsADirectory"
+ except PermissionError: # Windows
+ return "IsADirectory"
+
+
+def addnet_hash_legacy(b):
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+ m = hashlib.sha256()
+
+ b.seek(0x100000)
+ m.update(b.read(0x10000))
+ return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def precalculate_safetensors_hashes(tensors, metadata):
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
+ save time on indexing the model later."""
+
+ # Because writing user metadata to the file can change the result of
+ # sd_models.model_hash(), only retain the training metadata for purposes of
+ # calculating the hash, as they are meant to be immutable
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+ bytes = safetensors.torch.save(tensors, metadata)
+ b = BytesIO(bytes)
+
+ model_hash = addnet_hash_safetensors(b)
+ legacy_hash = addnet_hash_legacy(b)
+ return model_hash, legacy_hash
+
+
+def dtype_to_str(dtype: torch.dtype) -> str:
+ # get name of the dtype
+ dtype_name = str(dtype).split(".")[-1]
+ return dtype_name
+
+
+def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
+ """
+ Convert a string to a torch.dtype
+
+ Args:
+ s: string representation of the dtype
+ default_dtype: default dtype to return if s is None
+
+ Returns:
+ torch.dtype: the corresponding torch.dtype
+
+ Raises:
+ ValueError: if the dtype is not supported
+
+ Examples:
+ >>> str_to_dtype("float32")
+ torch.float32
+ >>> str_to_dtype("fp32")
+ torch.float32
+ >>> str_to_dtype("float16")
+ torch.float16
+ >>> str_to_dtype("fp16")
+ torch.float16
+ >>> str_to_dtype("bfloat16")
+ torch.bfloat16
+ >>> str_to_dtype("bf16")
+ torch.bfloat16
+ >>> str_to_dtype("fp8")
+ torch.float8_e4m3fn
+ >>> str_to_dtype("fp8_e4m3fn")
+ torch.float8_e4m3fn
+ >>> str_to_dtype("fp8_e4m3fnuz")
+ torch.float8_e4m3fnuz
+ >>> str_to_dtype("fp8_e5m2")
+ torch.float8_e5m2
+ >>> str_to_dtype("fp8_e5m2fnuz")
+ torch.float8_e5m2fnuz
+ """
+ if s is None:
+ return default_dtype
+ if s in ["bf16", "bfloat16"]:
+ return torch.bfloat16
+ elif s in ["fp16", "float16"]:
+ return torch.float16
+ elif s in ["fp32", "float32", "float"]:
+ return torch.float32
+ elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
+ return torch.float8_e4m3fn
+ elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
+ return torch.float8_e4m3fnuz
+ elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
+ return torch.float8_e5m2
+ elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
+ return torch.float8_e5m2fnuz
+ elif s in ["fp8", "float8"]:
+ return torch.float8_e4m3fn # default fp8
+ else:
+ raise ValueError(f"Unsupported dtype: {s}")
diff --git a/utils/safetensors_utils.py b/utils/safetensors_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d33b3c25ddee119212de80332b609f9dbd6b251d
--- /dev/null
+++ b/utils/safetensors_utils.py
@@ -0,0 +1,221 @@
+import os
+import re
+import torch
+import json
+import struct
+from typing import Dict, Any, Union, Optional
+
+from safetensors.torch import load_file
+
+
+def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
+ """
+ memory efficient save file
+ """
+
+ _TYPES = {
+ torch.float64: "F64",
+ torch.float32: "F32",
+ torch.float16: "F16",
+ torch.bfloat16: "BF16",
+ torch.int64: "I64",
+ torch.int32: "I32",
+ torch.int16: "I16",
+ torch.int8: "I8",
+ torch.uint8: "U8",
+ torch.bool: "BOOL",
+ getattr(torch, "float8_e5m2", None): "F8_E5M2",
+ getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
+ }
+ _ALIGN = 256
+
+ def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
+ validated = {}
+ for key, value in metadata.items():
+ if not isinstance(key, str):
+ raise ValueError(f"Metadata key must be a string, got {type(key)}")
+ if not isinstance(value, str):
+ print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
+ validated[key] = str(value)
+ else:
+ validated[key] = value
+ return validated
+
+ # print(f"Using memory efficient save file: {filename}")
+
+ header = {}
+ offset = 0
+ if metadata:
+ header["__metadata__"] = validate_metadata(metadata)
+ for k, v in tensors.items():
+ if v.numel() == 0: # empty tensor
+ header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
+ else:
+ size = v.numel() * v.element_size()
+ header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
+ offset += size
+
+ hjson = json.dumps(header).encode("utf-8")
+ hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
+
+ with open(filename, "wb") as f:
+ f.write(struct.pack(" Dict[str, str]:
+ return self.header.get("__metadata__", {})
+
+ def get_tensor(self, key):
+ if key not in self.header:
+ raise KeyError(f"Tensor '{key}' not found in the file")
+
+ metadata = self.header[key]
+ offset_start, offset_end = metadata["data_offsets"]
+
+ if offset_start == offset_end:
+ tensor_bytes = None
+ else:
+ # adjust offset by header size
+ self.file.seek(self.header_size + 8 + offset_start)
+ tensor_bytes = self.file.read(offset_end - offset_start)
+
+ return self._deserialize_tensor(tensor_bytes, metadata)
+
+ def _read_header(self):
+ header_size = struct.unpack(" dict[str, torch.Tensor]:
+ if disable_mmap:
+ # return safetensors.torch.load(open(path, "rb").read())
+ # use experimental loader
+ # logger.info(f"Loading without mmap (experimental)")
+ state_dict = {}
+ with MemoryEfficientSafeOpen(path) as f:
+ for key in f.keys():
+ state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
+ return state_dict
+ else:
+ try:
+ state_dict = load_file(path, device=device)
+ except:
+ state_dict = load_file(path) # prevent device invalid Error
+ if dtype is not None:
+ for key in state_dict.keys():
+ state_dict[key] = state_dict[key].to(dtype=dtype)
+ return state_dict
+
+
+def load_split_weights(
+ file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False
+) -> Dict[str, torch.Tensor]:
+ """
+ Load split weights from a file. If the file name ends with 00001-of-00004 etc, it will load all files with the same prefix.
+ dtype is as is, no conversion is done.
+ """
+ device = torch.device(device)
+
+ # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
+ basename = os.path.basename(file_path)
+ match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
+ if match:
+ prefix = basename[: match.start(2)]
+ count = int(match.group(3))
+ state_dict = {}
+ for i in range(count):
+ filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors"
+ filepath = os.path.join(os.path.dirname(file_path), filename)
+ if os.path.exists(filepath):
+ state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap))
+ else:
+ raise FileNotFoundError(f"File {filepath} not found")
+ else:
+ state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap)
+ return state_dict
diff --git a/utils/sai_model_spec.py b/utils/sai_model_spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..264340cf532166922849db9cf520a23c133cca99
--- /dev/null
+++ b/utils/sai_model_spec.py
@@ -0,0 +1,286 @@
+# based on https://github.com/Stability-AI/ModelSpec
+import datetime
+import hashlib
+from io import BytesIO
+import os
+from typing import List, Optional, Tuple, Union
+import safetensors
+import logging
+
+from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, ARCHITECTURE_WAN, ARCHITECTURE_FRAMEPACK
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+r"""
+# Metadata Example
+metadata = {
+ # === Must ===
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
+ "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
+ "modelspec.implementation": "sgm",
+ "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
+ # === Should ===
+ "modelspec.author": "Example Corp", # Your name or company name
+ "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
+ "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
+ # === Can ===
+ "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
+}
+"""
+
+BASE_METADATA = {
+ # === Must ===
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
+ "modelspec.architecture": None,
+ "modelspec.implementation": None,
+ "modelspec.title": None,
+ "modelspec.resolution": None,
+ # === Should ===
+ "modelspec.description": None,
+ "modelspec.author": None,
+ "modelspec.date": None,
+ # === Can ===
+ "modelspec.license": None,
+ "modelspec.tags": None,
+ "modelspec.merged_from": None,
+ "modelspec.prediction_type": None,
+ "modelspec.timestep_range": None,
+ "modelspec.encoder_layer": None,
+}
+
+# 別に使うやつだけ定義
+MODELSPEC_TITLE = "modelspec.title"
+
+ARCH_HUNYUAN_VIDEO = "hunyuan-video"
+
+# Official Wan2.1 weights does not have sai_model_spec, so we use this as an architecture name
+ARCH_WAN = "wan2.1"
+
+ARCH_FRAMEPACK = "framepack"
+
+ADAPTER_LORA = "lora"
+
+IMPL_HUNYUAN_VIDEO = "https://github.com/Tencent/HunyuanVideo"
+IMPL_WAN = "https://github.com/Wan-Video/Wan2.1"
+IMPL_FRAMEPACK = "https://github.com/lllyasviel/FramePack"
+
+PRED_TYPE_EPSILON = "epsilon"
+# PRED_TYPE_V = "v"
+
+
+def load_bytes_in_safetensors(tensors):
+ bytes = safetensors.torch.save(tensors)
+ b = BytesIO(bytes)
+
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+
+ offset = n + 8
+ b.seek(offset)
+
+ return b.read()
+
+
+def precalculate_safetensors_hashes(state_dict):
+ # calculate each tensor one by one to reduce memory usage
+ hash_sha256 = hashlib.sha256()
+ for tensor in state_dict.values():
+ single_tensor_sd = {"tensor": tensor}
+ bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
+ hash_sha256.update(bytes_for_tensor)
+
+ return f"0x{hash_sha256.hexdigest()}"
+
+
+def update_hash_sha256(metadata: dict, state_dict: dict):
+ raise NotImplementedError
+
+
+def build_metadata(
+ state_dict: Optional[dict],
+ architecture: str,
+ timestamp: float,
+ title: Optional[str] = None,
+ reso: Optional[Union[int, Tuple[int, int]]] = None,
+ author: Optional[str] = None,
+ description: Optional[str] = None,
+ license: Optional[str] = None,
+ tags: Optional[str] = None,
+ merged_from: Optional[str] = None,
+ timesteps: Optional[Tuple[int, int]] = None,
+ is_lora: bool = True,
+):
+ metadata = {}
+ metadata.update(BASE_METADATA)
+
+ # TODO implement if we can calculate hash without loading all tensors
+ # if state_dict is not None:
+ # hash = precalculate_safetensors_hashes(state_dict)
+ # metadata["modelspec.hash_sha256"] = hash
+
+ # arch = ARCH_HUNYUAN_VIDEO
+ if architecture == ARCHITECTURE_HUNYUAN_VIDEO:
+ arch = ARCH_HUNYUAN_VIDEO
+ impl = IMPL_HUNYUAN_VIDEO
+ elif architecture == ARCHITECTURE_WAN:
+ arch = ARCH_WAN
+ impl = IMPL_WAN
+ elif architecture == ARCHITECTURE_FRAMEPACK:
+ arch = ARCH_FRAMEPACK
+ impl = IMPL_FRAMEPACK
+ else:
+ raise ValueError(f"Unknown architecture: {architecture}")
+
+ if is_lora:
+ arch += f"/{ADAPTER_LORA}"
+ metadata["modelspec.architecture"] = arch
+
+ metadata["modelspec.implementation"] = impl
+
+ if title is None:
+ title = "LoRA" if is_lora else "Hunyuan-Video"
+ title += f"@{timestamp}"
+ metadata[MODELSPEC_TITLE] = title
+
+ if author is not None:
+ metadata["modelspec.author"] = author
+ else:
+ del metadata["modelspec.author"]
+
+ if description is not None:
+ metadata["modelspec.description"] = description
+ else:
+ del metadata["modelspec.description"]
+
+ if merged_from is not None:
+ metadata["modelspec.merged_from"] = merged_from
+ else:
+ del metadata["modelspec.merged_from"]
+
+ if license is not None:
+ metadata["modelspec.license"] = license
+ else:
+ del metadata["modelspec.license"]
+
+ if tags is not None:
+ metadata["modelspec.tags"] = tags
+ else:
+ del metadata["modelspec.tags"]
+
+ # remove microsecond from time
+ int_ts = int(timestamp)
+
+ # time to iso-8601 compliant date
+ date = datetime.datetime.fromtimestamp(int_ts).isoformat()
+ metadata["modelspec.date"] = date
+
+ if reso is not None:
+ # comma separated to tuple
+ if isinstance(reso, str):
+ reso = tuple(map(int, reso.split(",")))
+ if len(reso) == 1:
+ reso = (reso[0], reso[0])
+ else:
+ # resolution is defined in dataset, so use default
+ reso = (1280, 720)
+ if isinstance(reso, int):
+ reso = (reso, reso)
+
+ metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
+
+ # metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
+ del metadata["modelspec.prediction_type"]
+
+ if timesteps is not None:
+ if isinstance(timesteps, str) or isinstance(timesteps, int):
+ timesteps = (timesteps, timesteps)
+ if len(timesteps) == 1:
+ timesteps = (timesteps[0], timesteps[0])
+ metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
+ else:
+ del metadata["modelspec.timestep_range"]
+
+ # if clip_skip is not None:
+ # metadata["modelspec.encoder_layer"] = f"{clip_skip}"
+ # else:
+ del metadata["modelspec.encoder_layer"]
+
+ # # assert all values are filled
+ # assert all([v is not None for v in metadata.values()]), metadata
+ if not all([v is not None for v in metadata.values()]):
+ logger.error(f"Internal error: some metadata values are None: {metadata}")
+
+ return metadata
+
+
+# region utils
+
+
+def get_title(metadata: dict) -> Optional[str]:
+ return metadata.get(MODELSPEC_TITLE, None)
+
+
+def load_metadata_from_safetensors(model: str) -> dict:
+ if not model.endswith(".safetensors"):
+ return {}
+
+ with safetensors.safe_open(model, framework="pt") as f:
+ metadata = f.metadata()
+ if metadata is None:
+ metadata = {}
+ return metadata
+
+
+def build_merged_from(models: List[str]) -> str:
+ def get_title(model: str):
+ metadata = load_metadata_from_safetensors(model)
+ title = metadata.get(MODELSPEC_TITLE, None)
+ if title is None:
+ title = os.path.splitext(os.path.basename(model))[0] # use filename
+ return title
+
+ titles = [get_title(model) for model in models]
+ return ", ".join(titles)
+
+
+# endregion
+
+
+r"""
+if __name__ == "__main__":
+ import argparse
+ import torch
+ from safetensors.torch import load_file
+ from library import train_util
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt", type=str, required=True)
+ args = parser.parse_args()
+
+ print(f"Loading {args.ckpt}")
+ state_dict = load_file(args.ckpt)
+
+ print(f"Calculating metadata")
+ metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
+ print(metadata)
+ del state_dict
+
+ # by reference implementation
+ with open(args.ckpt, mode="rb") as file_data:
+ file_hash = hashlib.sha256()
+ head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
+ header = json.loads(file_data.read(head_len[0])) # header itself, json string
+ content = (
+ file_data.read()
+ ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
+ file_hash.update(content)
+ # ===== Update the hash for modelspec =====
+ by_ref = f"0x{file_hash.hexdigest()}"
+ print(by_ref)
+ print("is same?", by_ref == metadata["modelspec.hash_sha256"])
+
+"""
diff --git a/utils/train_utils.py b/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2176e999b2309af1935654d8e02894f64b0d80e
--- /dev/null
+++ b/utils/train_utils.py
@@ -0,0 +1,178 @@
+import argparse
+import logging
+import os
+import shutil
+
+import accelerate
+import torch
+
+from utils import huggingface_utils
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+# checkpointファイル名
+EPOCH_STATE_NAME = "{}-{:06d}-state"
+EPOCH_FILE_NAME = "{}-{:06d}"
+EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
+LAST_STATE_NAME = "{}-state"
+STEP_STATE_NAME = "{}-step{:08d}-state"
+STEP_FILE_NAME = "{}-step{:08d}"
+STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
+
+
+def get_sanitized_config_or_none(args: argparse.Namespace):
+ # if `--log_config` is enabled, return args for logging. if not, return None.
+ # when `--log_config is enabled, filter out sensitive values from args
+ # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe
+
+ if not args.log_config:
+ return None
+
+ sensitive_args = ["wandb_api_key", "huggingface_token"]
+ sensitive_path_args = [
+ "dit",
+ "vae",
+ "text_encoder1",
+ "text_encoder2",
+ "image_encoder",
+ "base_weights",
+ "network_weights",
+ "output_dir",
+ "logging_dir",
+ ]
+ filtered_args = {}
+ for k, v in vars(args).items():
+ # filter out sensitive values and convert to string if necessary
+ if k not in sensitive_args + sensitive_path_args:
+ # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
+ if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
+ filtered_args[k] = v
+ # accelerate does not support lists
+ elif isinstance(v, list):
+ filtered_args[k] = f"{v}"
+ # accelerate does not support objects
+ elif isinstance(v, object):
+ filtered_args[k] = f"{v}"
+
+ return filtered_args
+
+
+class LossRecorder:
+ def __init__(self):
+ self.loss_list: list[float] = []
+ self.loss_total: float = 0.0
+
+ def add(self, *, epoch: int, step: int, loss: float) -> None:
+ if epoch == 0:
+ self.loss_list.append(loss)
+ else:
+ while len(self.loss_list) <= step:
+ self.loss_list.append(0.0)
+ self.loss_total -= self.loss_list[step]
+ self.loss_list[step] = loss
+ self.loss_total += loss
+
+ @property
+ def moving_average(self) -> float:
+ return self.loss_total / len(self.loss_list)
+
+
+def get_epoch_ckpt_name(model_name, epoch_no: int):
+ return EPOCH_FILE_NAME.format(model_name, epoch_no) + ".safetensors"
+
+
+def get_step_ckpt_name(model_name, step_no: int):
+ return STEP_FILE_NAME.format(model_name, step_no) + ".safetensors"
+
+
+def get_last_ckpt_name(model_name):
+ return model_name + ".safetensors"
+
+
+def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
+ if args.save_last_n_epochs is None:
+ return None
+
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
+ if remove_epoch_no < 0:
+ return None
+ return remove_epoch_no
+
+
+def get_remove_step_no(args: argparse.Namespace, step_no: int):
+ if args.save_last_n_steps is None:
+ return None
+
+ # calculate the step number to remove from the last_n_steps and save_every_n_steps
+ # e.g. if save_every_n_steps=10, save_last_n_steps=30, at step 50, keep 30 steps and remove step 10
+ remove_step_no = step_no - args.save_last_n_steps - 1
+ remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
+ if remove_step_no < 0:
+ return None
+ return remove_step_no
+
+
+def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator: accelerate.Accelerator, epoch_no: int):
+ model_name = args.output_name
+
+ logger.info("")
+ logger.info(f"saving state at epoch {epoch_no}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
+ accelerator.save_state(state_dir)
+ if args.save_state_to_huggingface:
+ logger.info("uploading state to huggingface.")
+ huggingface_utils.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
+
+ last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
+ if last_n_epochs is not None:
+ remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
+ state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
+ if os.path.exists(state_dir_old):
+ logger.info(f"removing old state: {state_dir_old}")
+ shutil.rmtree(state_dir_old)
+
+
+def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator: accelerate.Accelerator, step_no: int):
+ model_name = args.output_name
+
+ logger.info("")
+ logger.info(f"saving state at step {step_no}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
+ accelerator.save_state(state_dir)
+ if args.save_state_to_huggingface:
+ logger.info("uploading state to huggingface.")
+ huggingface_utils.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
+
+ last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps
+ if last_n_steps is not None:
+ # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
+ remove_step_no = step_no - last_n_steps - 1
+ remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
+
+ if remove_step_no > 0:
+ state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no))
+ if os.path.exists(state_dir_old):
+ logger.info(f"removing old state: {state_dir_old}")
+ shutil.rmtree(state_dir_old)
+
+
+def save_state_on_train_end(args: argparse.Namespace, accelerator: accelerate.Accelerator):
+ model_name = args.output_name
+
+ logger.info("")
+ logger.info("saving last state.")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
+ accelerator.save_state(state_dir)
+
+ if args.save_state_to_huggingface:
+ logger.info("uploading last state to huggingface.")
+ huggingface_utils.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
+
diff --git a/wan/__init__.py b/wan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f7ed4df10cae220744639f079b4e11d985f9d05
--- /dev/null
+++ b/wan/__init__.py
@@ -0,0 +1 @@
+# from . import configs, distributed, modules
diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e257028bf7fc31eb99d95f4fb584344bc33b43a
--- /dev/null
+++ b/wan/configs/__init__.py
@@ -0,0 +1,69 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import copy
+import os
+import torch
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+from .wan_i2v_14B import i2v_14B
+from .wan_t2v_1_3B import t2v_1_3B
+from .wan_t2v_14B import t2v_14B
+
+# the config of t2i_14B is the same as t2v_14B
+t2i_14B = copy.deepcopy(t2v_14B)
+t2i_14B.__name__ = "Config: Wan T2I 14B"
+
+# support Fun models: deepcopy and change some configs. FC denotes Fun Control
+t2v_1_3B_FC = copy.deepcopy(t2v_1_3B)
+t2v_1_3B_FC.__name__ = "Config: Wan-Fun-Control T2V 1.3B"
+t2v_1_3B_FC.i2v = True # this is strange, but Fun-Control model needs this because it has img cross-attention
+t2v_1_3B_FC.in_dim = 48
+t2v_1_3B_FC.is_fun_control = True
+
+t2v_14B_FC = copy.deepcopy(t2v_14B)
+t2v_14B_FC.__name__ = "Config: Wan-Fun-Control T2V 14B"
+t2v_14B_FC.i2v = True # this is strange, but Fun-Control model needs this because it has img cross-attention
+t2v_14B_FC.in_dim = 48 # same as i2v_14B, use zeros for image latents
+t2v_14B_FC.is_fun_control = True
+
+i2v_14B_FC = copy.deepcopy(i2v_14B)
+i2v_14B_FC.__name__ = "Config: Wan-Fun-Control I2V 14B"
+i2v_14B_FC.in_dim = 48
+i2v_14B_FC.is_fun_control = True
+
+WAN_CONFIGS = {
+ "t2v-14B": t2v_14B,
+ "t2v-1.3B": t2v_1_3B,
+ "i2v-14B": i2v_14B,
+ "t2i-14B": t2i_14B,
+ # Fun Control models
+ "t2v-1.3B-FC": t2v_1_3B_FC,
+ "t2v-14B-FC": t2v_14B_FC,
+ "i2v-14B-FC": i2v_14B_FC,
+}
+
+SIZE_CONFIGS = {
+ "720*1280": (720, 1280),
+ "1280*720": (1280, 720),
+ "480*832": (480, 832),
+ "832*480": (832, 480),
+ "1024*1024": (1024, 1024),
+}
+
+MAX_AREA_CONFIGS = {
+ "720*1280": 720 * 1280,
+ "1280*720": 1280 * 720,
+ "480*832": 480 * 832,
+ "832*480": 832 * 480,
+}
+
+SUPPORTED_SIZES = {
+ "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"),
+ "t2v-1.3B": ("480*832", "832*480"),
+ "i2v-14B": ("720*1280", "1280*720", "480*832", "832*480"),
+ "t2i-14B": tuple(SIZE_CONFIGS.keys()),
+ # Fun Control models
+ "t2v-1.3B-FC": ("480*832", "832*480"),
+ "t2v-14B-FC": ("720*1280", "1280*720", "480*832", "832*480"),
+ "i2v-14B-FC": ("720*1280", "1280*720", "480*832", "832*480"),
+}
diff --git a/wan/configs/shared_config.py b/wan/configs/shared_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff603d52244336acc864835c2cd30c1c6110e39b
--- /dev/null
+++ b/wan/configs/shared_config.py
@@ -0,0 +1,20 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+from easydict import EasyDict
+
+#------------------------ Wan shared config ------------------------#
+wan_shared_cfg = EasyDict()
+
+# t5
+wan_shared_cfg.t5_model = 'umt5_xxl'
+wan_shared_cfg.t5_dtype = torch.bfloat16
+wan_shared_cfg.text_len = 512
+
+# transformer
+wan_shared_cfg.param_dtype = torch.bfloat16
+wan_shared_cfg.out_dim = 16
+
+# inference
+wan_shared_cfg.num_train_timesteps = 1000
+wan_shared_cfg.sample_fps = 16
+wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py
new file mode 100644
index 0000000000000000000000000000000000000000..434f59c3d1dd75c9cdc816c5f976afed0ef08631
--- /dev/null
+++ b/wan/configs/wan_i2v_14B.py
@@ -0,0 +1,38 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+from easydict import EasyDict
+
+from .shared_config import wan_shared_cfg
+
+# ------------------------ Wan I2V 14B ------------------------#
+
+i2v_14B = EasyDict(__name__="Config: Wan I2V 14B")
+i2v_14B.update(wan_shared_cfg)
+i2v_14B.i2v = True
+i2v_14B.is_fun_control = False
+
+i2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
+i2v_14B.t5_tokenizer = "google/umt5-xxl"
+
+# clip
+i2v_14B.clip_model = "clip_xlm_roberta_vit_h_14"
+i2v_14B.clip_dtype = torch.float16
+i2v_14B.clip_checkpoint = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
+i2v_14B.clip_tokenizer = "xlm-roberta-large"
+
+# vae
+i2v_14B.vae_checkpoint = "Wan2.1_VAE.pth"
+i2v_14B.vae_stride = (4, 8, 8)
+
+# transformer
+i2v_14B.patch_size = (1, 2, 2)
+i2v_14B.dim = 5120
+i2v_14B.ffn_dim = 13824
+i2v_14B.freq_dim = 256
+i2v_14B.in_dim = 36
+i2v_14B.num_heads = 40
+i2v_14B.num_layers = 40
+i2v_14B.window_size = (-1, -1)
+i2v_14B.qk_norm = True
+i2v_14B.cross_attn_norm = True
+i2v_14B.eps = 1e-6
diff --git a/wan/configs/wan_t2v_14B.py b/wan/configs/wan_t2v_14B.py
new file mode 100644
index 0000000000000000000000000000000000000000..76433f058b159d74ce41a539e69a1bcd8bb9901e
--- /dev/null
+++ b/wan/configs/wan_t2v_14B.py
@@ -0,0 +1,32 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from easydict import EasyDict
+
+from .shared_config import wan_shared_cfg
+
+# ------------------------ Wan T2V 14B ------------------------#
+
+t2v_14B = EasyDict(__name__="Config: Wan T2V 14B")
+t2v_14B.update(wan_shared_cfg)
+t2v_14B.i2v = False
+t2v_14B.is_fun_control = False
+
+# t5
+t2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
+t2v_14B.t5_tokenizer = "google/umt5-xxl"
+
+# vae
+t2v_14B.vae_checkpoint = "Wan2.1_VAE.pth"
+t2v_14B.vae_stride = (4, 8, 8)
+
+# transformer
+t2v_14B.patch_size = (1, 2, 2)
+t2v_14B.dim = 5120
+t2v_14B.ffn_dim = 13824
+t2v_14B.freq_dim = 256
+t2v_14B.in_dim = 16
+t2v_14B.num_heads = 40
+t2v_14B.num_layers = 40
+t2v_14B.window_size = (-1, -1)
+t2v_14B.qk_norm = True
+t2v_14B.cross_attn_norm = True
+t2v_14B.eps = 1e-6
diff --git a/wan/configs/wan_t2v_1_3B.py b/wan/configs/wan_t2v_1_3B.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb9e10ef41cf249004e1e46d22591471e284882
--- /dev/null
+++ b/wan/configs/wan_t2v_1_3B.py
@@ -0,0 +1,32 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from easydict import EasyDict
+
+from .shared_config import wan_shared_cfg
+
+# ------------------------ Wan T2V 1.3B ------------------------#
+
+t2v_1_3B = EasyDict(__name__="Config: Wan T2V 1.3B")
+t2v_1_3B.update(wan_shared_cfg)
+t2v_1_3B.i2v = False
+t2v_1_3B.is_fun_control = False
+
+# t5
+t2v_1_3B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
+t2v_1_3B.t5_tokenizer = "google/umt5-xxl"
+
+# vae
+t2v_1_3B.vae_checkpoint = "Wan2.1_VAE.pth"
+t2v_1_3B.vae_stride = (4, 8, 8)
+
+# transformer
+t2v_1_3B.patch_size = (1, 2, 2)
+t2v_1_3B.dim = 1536
+t2v_1_3B.ffn_dim = 8960
+t2v_1_3B.freq_dim = 256
+t2v_1_3B.in_dim = 16
+t2v_1_3B.num_heads = 12
+t2v_1_3B.num_layers = 30
+t2v_1_3B.window_size = (-1, -1)
+t2v_1_3B.qk_norm = True
+t2v_1_3B.cross_attn_norm = True
+t2v_1_3B.eps = 1e-6
diff --git a/wan/image2video.py b/wan/image2video.py
new file mode 100644
index 0000000000000000000000000000000000000000..711d10ac5037325e0b855c9418595469202c7ca2
--- /dev/null
+++ b/wan/image2video.py
@@ -0,0 +1,410 @@
+# Modified from official implementation
+
+# Original source:
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import logging
+import os
+import random
+import sys
+from typing import Optional, Union
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from tqdm import tqdm
+from accelerate import Accelerator, init_empty_weights
+from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
+from utils.safetensors_utils import load_safetensors
+
+# from .distributed.fsdp import shard_model
+from .modules.clip import CLIPModel
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .modules.vae import WanVAE
+from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+from utils.device_utils import clean_memory_on_device, synchronize_device
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+class WanI2V:
+
+ def __init__(
+ self,
+ config,
+ checkpoint_dir,
+ device_id=0,
+ rank=0,
+ t5_fsdp=False,
+ dit_fsdp=False,
+ use_usp=False,
+ t5_cpu=False,
+ init_on_cpu=True,
+ device=None,
+ dtype=None,
+ dit_path=None,
+ dit_attn_mode=None,
+ t5_path=None,
+ clip_path=None,
+ t5_fp8=False,
+ ):
+ r"""
+ Initializes the image-to-video generation model components.
+
+ Args:
+ config (EasyDict):
+ Object containing model parameters initialized from config.py
+ checkpoint_dir (`str`):
+ Path to directory containing model checkpoints
+ device_id (`int`, *optional*, defaults to 0) **IGNORED**:
+ Id of target GPU device
+ rank (`int`, *optional*, defaults to 0) **IGNORED**:
+ Process rank for distributed training
+ t5_fsdp (`bool`, *optional*, defaults to False) **IGNORED**:
+ Enable FSDP sharding for T5 model
+ dit_fsdp (`bool`, *optional*, defaults to False) **IGNORED**:
+ Enable FSDP sharding for DiT model
+ use_usp (`bool`, *optional*, defaults to False) **IGNORED**:
+ Enable distribution strategy of USP.
+ t5_cpu (`bool`, *optional*, defaults to False) **IGNORED**:
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
+ init_on_cpu (`bool`, *optional*, defaults to True) **IGNORED**:
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+
+ device (`torch.device`, *optional*, defaults to None):
+ Device to place the model on. If None, use the default device (cuda)
+ dtype (`torch.dtype`, *optional*, defaults to None):
+ Data type for DiT model parameters. If None, use the default parameter data type from config
+ dit_path (`str`, *optional*, defaults to None):
+ Path to DiT model checkpoint. checkpoint_dir is used if None.
+ dit_attn_mode (`str`, *optional*, defaults to None):
+ Attention mode for DiT model. If None, use "torch" attention mode.
+ t5_path (`str`, *optional*, defaults to None):
+ Path to T5 model checkpoint. checkpoint_dir is used if None.
+ clip_path (`str`, *optional*, defaults to None):
+ Path to CLIP model checkpoint. checkpoint_dir is used if None.
+ t5_fp8 (`bool`, *optional*, defaults to False):
+ Enable FP8 quantization for T5 model
+ """
+ self.device = device if device is not None else torch.device("cuda")
+ self.config = config
+ self.rank = rank
+ self.t5_cpu = t5_cpu
+ self.t5_fp8 = t5_fp8
+
+ self.num_train_timesteps = config.num_train_timesteps
+ self.param_dtype = config.param_dtype
+
+ # shard_fn = partial(shard_model, device_id=device_id)
+ checkpoint_path = None if checkpoint_dir is None else os.path.join(checkpoint_dir, config.t5_checkpoint)
+ tokenizer_path = None if checkpoint_dir is None else os.path.join(checkpoint_dir, config.t5_tokenizer)
+ self.text_encoder = T5EncoderModel(
+ text_len=config.text_len,
+ dtype=config.t5_dtype,
+ device=device,
+ checkpoint_path=checkpoint_path,
+ tokenizer_path=tokenizer_path,
+ weight_path=t5_path,
+ fp8=t5_fp8,
+ # shard_fn=shard_fn if t5_fsdp else None,
+ )
+
+ self.vae_stride = config.vae_stride
+ self.patch_size = config.patch_size
+
+ self.checkpoint_dir = checkpoint_dir
+ self.dit_path = dit_path
+ self.dit_dtype = dtype if dtype is not None else config.param_dtype
+ self.dit_attn_mode = dit_attn_mode
+ self.clip_path = clip_path
+
+ self.sample_neg_prompt = config.sample_neg_prompt
+
+ def generate(
+ self,
+ accelerator: Accelerator,
+ merge_lora: Optional[callable],
+ dit_loading_dtype: Optional[torch.dtype],
+ input_prompt,
+ img,
+ size=(1280, 720),
+ frame_num=81,
+ shift=5.0,
+ sample_solver="unipc",
+ sampling_steps=40,
+ guide_scale=5.0,
+ n_prompt="",
+ seed=-1,
+ blocks_to_swap=0,
+ vae: WanVAE = None,
+ ):
+ r"""
+ Generates video frames from input image and text prompt using diffusion process.
+
+ Args:
+ input_prompt (`str`):
+ Text prompt for content generation.
+ img (PIL.Image.Image):
+ Input image tensor. Shape: [3, H, W]
+ max_area (`int`, *optional*, defaults to 720*1280):
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
+ frame_num (`int`, *optional*, defaults to 81):
+ How many frames to sample from a video. The number should be 4n+1
+ shift (`float`, *optional*, defaults to 5.0):
+ Noise schedule shift parameter. Affects temporal dynamics
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
+ Solver used to sample the video.
+ sampling_steps (`int`, *optional*, defaults to 40):
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
+ guide_scale (`float`, *optional*, defaults 5.0):
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
+ n_prompt (`str`, *optional*, defaults to ""):
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+ seed (`int`, *optional*, defaults to -1):
+ Random seed for noise generation. If -1, use random seed.
+ blocks_to_swap (`int`, *optional*, defaults to 0):
+ Number of blocks to swap (offload) to CPU. If 0, no blocks are offloaded.
+
+ Returns:
+ torch.Tensor:
+ Generated video frames tensor. Dimensions: (C, N H, W) where:
+ - C: Color channels (3 for RGB)
+ - N: Number of frames (81)
+ - H: Frame height (from size)
+ - W: Frame width from size)
+ """
+ max_area = size[0] * size[1]
+
+ # save original image as numpy array
+ img_cv2 = np.array(img) # PIL to numpy
+ img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
+
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) # -1 to 1
+
+ F = frame_num # number of frames
+ h, w = img.shape[1:]
+ aspect_ratio = h / w
+ lat_h = round(np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1])
+ lat_w = round(np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2])
+ h = lat_h * self.vae_stride[1]
+ w = lat_w * self.vae_stride[2]
+ lat_f = (F - 1) // self.vae_stride[0] + 1 # size of latent frames
+ max_seq_len = lat_f * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2])
+
+ # set seed
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+ seed_g = torch.Generator(device=self.device)
+ seed_g.manual_seed(seed)
+
+ # Generate noise for the required number of frames only
+ noise = torch.randn(16, lat_f, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device)
+
+ if n_prompt == "":
+ n_prompt = self.sample_neg_prompt
+
+ # preprocess
+ self.text_encoder.model.to(self.device)
+ with torch.no_grad():
+ if self.t5_fp8:
+ with accelerator.autocast():
+ context = self.text_encoder([input_prompt], self.device)
+ context_null = self.text_encoder([n_prompt], self.device)
+ else:
+ context = self.text_encoder([input_prompt], self.device)
+ context_null = self.text_encoder([n_prompt], self.device)
+
+ del self.text_encoder
+ clean_memory_on_device(self.device)
+
+ # load CLIP model
+ checkpoint_path = None if self.checkpoint_dir is None else os.path.join(self.checkpoint_dir, self.config.clip_checkpoint)
+ tokenizer_path = None if self.checkpoint_dir is None else os.path.join(self.checkpoint_dir, self.config.clip_tokenizer)
+ clip = CLIPModel(
+ dtype=self.config.clip_dtype,
+ device=self.device,
+ checkpoint_path=checkpoint_path,
+ tokenizer_path=tokenizer_path,
+ weight_path=self.clip_path,
+ )
+
+ clip.model.to(self.device)
+ logger.info(f"Encoding image to CLIP context")
+ # use torch.amp.autocast istead of accelerator.autocast, becuase CLIP dtype is not bfloat16
+ with torch.amp.autocast(device_type=self.device.type, dtype=torch.float16), torch.no_grad():
+ clip_context = clip.visual([img[:, None, :, :]])
+ logger.info(f"Encoding complete")
+
+ del clip
+ clean_memory_on_device(self.device)
+
+ # y should be encoded with 81 frames, and trim to lat_f frames? encoding F frames causes invalid results?
+ logger.info(f"Encoding image to latent space")
+ vae.to_device(self.device)
+
+ # resize image for the first frame. INTER_AREA is the best for downsampling
+ interpolation = cv2.INTER_AREA if h < img_cv2.shape[0] else cv2.INTER_CUBIC
+ img_resized = cv2.resize(img_cv2, (w, h), interpolation=interpolation)
+ img_resized = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
+ img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(self.device) # -1 to 1, CHW
+ img_resized = img_resized.unsqueeze(1) # CFHW
+
+ # Create mask for the required number of frames
+ msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
+ msk[:, 1:] = 0
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+ msk = msk.transpose(1, 2)[0]
+
+ with accelerator.autocast(), torch.no_grad():
+ # Zero padding for the required number of frames only
+ padding_frames = F - 1 # The first frame is the input image
+ img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, h, w, device=self.device)], dim=1)
+ y = vae.encode([img_resized])[0]
+
+ y = y[:, :lat_f] # may be not needed
+ y = torch.concat([msk, y])
+ logger.info(f"Encoding complete")
+
+ vae.to_device("cpu")
+ clean_memory_on_device(self.device)
+
+ # load DiT model
+ dit_loading_dtype = dit_loading_dtype if dit_loading_dtype is not None else self.dit_dtype
+ with init_empty_weights():
+ # if self.checkpoint_dir is not None:
+ # logger.info(f"Creating WanModel from {self.checkpoint_dir}")
+ # self.model = WanModel.from_pretrained(self.checkpoint_dir)
+ # self.model = WanModel.from_config(config)
+ # else:
+ logger.info(f"Creating WanModel")
+ self.model = WanModel(
+ model_type="i2v",
+ dim=self.config.dim,
+ eps=self.config.eps,
+ ffn_dim=self.config.ffn_dim,
+ freq_dim=self.config.freq_dim,
+ in_dim=36,
+ num_heads=self.config.num_heads,
+ num_layers=self.config.num_layers,
+ out_dim=16,
+ text_len=512,
+ attn_mode=self.dit_attn_mode,
+ )
+ self.model.to(dit_loading_dtype)
+
+ # if LoRA is enabled, load the model on CPU with bfloat16
+ loading_device = self.device if (blocks_to_swap == 0 and merge_lora is None) else "cpu"
+ logger.info(f"Loading DiT model from {self.dit_path}, device={loading_device}, dtype={dit_loading_dtype }")
+ sd = load_safetensors(self.dit_path, loading_device, disable_mmap=True, dtype=dit_loading_dtype)
+ info = self.model.load_state_dict(sd, strict=True, assign=True)
+ logger.info(f"Loaded DiT model from {self.dit_path}, info={info}")
+
+ if merge_lora is not None:
+ # merge LoRA to the model, cast and move to the device
+ merge_lora(self.model)
+ if blocks_to_swap == 0:
+ self.model.to(self.device, self.dit_dtype)
+ else:
+ self.model.to(self.dit_dtype)
+
+ if blocks_to_swap > 0:
+ logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {self.device}")
+ self.model.enable_block_swap(blocks_to_swap, self.device, supports_backward=False)
+ self.model.move_to_device_except_swap_blocks(self.device)
+ self.model.prepare_block_swap_before_forward()
+ else:
+ # make sure the model is on the right device
+ self.model.to(self.device)
+
+ self.model.eval().requires_grad_(False)
+ clean_memory_on_device(self.device)
+
+ # evaluation mode
+ with torch.no_grad():
+
+ if sample_solver == "unipc":
+ sample_scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False
+ )
+ sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
+ timesteps = sample_scheduler.timesteps
+ elif sample_solver == "dpm++":
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
+ num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False
+ )
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+ timesteps, _ = retrieve_timesteps(sample_scheduler, device=self.device, sigmas=sampling_sigmas)
+ elif sample_solver == "vanilla":
+ sample_scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=self.num_train_timesteps, shift=shift)
+ sample_scheduler.set_timesteps(sampling_steps, device=self.device)
+ timesteps = sample_scheduler.timesteps
+
+ org_step = sample_scheduler.step
+
+ def step_wrapper(
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ return_dict: bool = True,
+ generator=None,
+ ):
+ return org_step(model_output, timestep, sample, return_dict=return_dict)
+
+ sample_scheduler.step = step_wrapper
+ else:
+ raise NotImplementedError("Unsupported solver.")
+
+ # sample videos
+ latent = noise # on device
+ del noise
+
+ arg_c = {
+ "context": [context[0]],
+ "clip_fea": clip_context,
+ "seq_len": max_seq_len,
+ "y": [y],
+ }
+
+ arg_null = {
+ "context": context_null,
+ "clip_fea": clip_context,
+ "seq_len": max_seq_len,
+ "y": [y],
+ }
+
+ # self.model.to(self.device)
+ for _, t in enumerate(tqdm(timesteps)):
+ latent_model_input = [latent.to(self.device)]
+ latent = latent.to("cpu")
+ timestep = [t]
+
+ timestep = torch.stack(timestep).to(self.device)
+
+ with accelerator.autocast():
+ noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0].to("cpu")
+ noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0].to("cpu")
+
+ latent_model_input = None
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
+
+ temp_x0 = sample_scheduler.step(
+ noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=seed_g
+ )[0]
+ latent = temp_x0.squeeze(0)
+
+ # x0 = [latent.to(self.device)]
+ del latent_model_input, timestep
+
+ del sample_scheduler
+ del self.model
+ synchronize_device(self.device)
+ clean_memory_on_device(self.device)
+ return latent
diff --git a/wan/modules/__init__.py b/wan/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8935bbb45ab4e3f349d203b673102f7cfc07553
--- /dev/null
+++ b/wan/modules/__init__.py
@@ -0,0 +1,16 @@
+from .attention import flash_attention
+from .model import WanModel
+from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
+from .tokenizers import HuggingfaceTokenizer
+from .vae import WanVAE
+
+__all__ = [
+ 'WanVAE',
+ 'WanModel',
+ 'T5Model',
+ 'T5Encoder',
+ 'T5Decoder',
+ 'T5EncoderModel',
+ 'HuggingfaceTokenizer',
+ 'flash_attention',
+]
diff --git a/wan/modules/attention.py b/wan/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..7653f7c7c1ceee172f6fd32686fa038dff3472dc
--- /dev/null
+++ b/wan/modules/attention.py
@@ -0,0 +1,312 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from typing import Optional
+import torch
+
+try:
+ import flash_attn_interface
+
+ FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_3_AVAILABLE = False
+
+try:
+ import flash_attn
+
+ FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_2_AVAILABLE = False
+
+try:
+ import sageattention
+
+ SAGE_ATTN_AVAILABLE = True
+except ModuleNotFoundError:
+ SAGE_ATTN_AVAILABLE = False
+
+try:
+ import xformers.ops as xops
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ XFORMERS_AVAILABLE = False
+
+
+import warnings
+
+__all__ = [
+ "flash_attention",
+ "attention",
+]
+
+
+def flash_attention(
+ qkv,
+ q_lens=None,
+ k_lens=None,
+ dropout_p=0.0,
+ softmax_scale=None,
+ q_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ deterministic=False,
+ dtype=torch.bfloat16,
+ version=None,
+ attn_mode: Optional[str] = "torch",
+ split_attn: bool = False,
+):
+ """
+ q: [B, Lq, Nq, C1].
+ k: [B, Lk, Nk, C1].
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+ q_lens: [B].
+ k_lens: [B].
+ dropout_p: float. Dropout probability.
+ softmax_scale: float. The scaling of QK^T before applying softmax.
+ causal: bool. Whether to apply causal attention mask.
+ window_size: (left right). If not (-1, -1), apply sliding window local attention.
+ deterministic: bool. If True, slightly slower and uses more memory.
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
+ """
+ q, k, v = qkv
+ qkv.clear()
+
+ half_dtypes = (torch.float16, torch.bfloat16)
+ assert dtype in half_dtypes
+ # assert q.device.type == "cuda" and q.size(-1) <= 256
+
+ # params
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+
+ def half(x):
+ return x if x.dtype in half_dtypes else x.to(dtype)
+
+ # We cannot test Flash attention 3 in musubi tuner, so keep the original code.
+ # Customized code (except for flash attention 3) is not supported q_lens and k_lens.
+ if attn_mode != "flash3" and attn_mode != "sageattn":
+ assert q_lens is None, "q_lens is not supported except for flash attention 3."
+ assert k_lens is None or (
+ min(k_lens) == max(k_lens) and k_lens[0] == lk
+ ), "k_lens is not supported except for flash attention 3."
+
+ # SDPA
+ if attn_mode == "torch" or attn_mode == "sdpa":
+ assert not deterministic, "deterministic is not supported in scaled_dot_product_attention."
+ if q_scale is not None:
+ q = q * q_scale
+ q = half(q.transpose(1, 2))
+ k = half(k.transpose(1, 2))
+ v = half(v.transpose(1, 2))
+
+ if not split_attn:
+ q = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, is_causal=causal, dropout_p=dropout_p, scale=softmax_scale
+ )
+ x = q
+ else:
+ x = torch.empty_like(q)
+ for i in range(q.size(0)):
+ x[i : i + 1] = torch.nn.functional.scaled_dot_product_attention(
+ q[i : i + 1], k[i : i + 1], v[i : i + 1], is_causal=causal, dropout_p=dropout_p, scale=softmax_scale
+ )
+
+ del q, k, v
+ x = x.transpose(1, 2).contiguous()
+ return x.type(out_dtype)
+
+ # flash attention 2
+ if attn_mode == "flash" or attn_mode == "flash2":
+ if q_scale is not None:
+ q = q * q_scale
+ q = half(q)
+ k = half(k)
+ v = half(v)
+
+ if not split_attn:
+ q = flash_attn.flash_attn_func(q, k, v, dropout_p, softmax_scale, causal, window_size, deterministic=deterministic)
+ x = q
+ else:
+ x = torch.empty_like(q)
+ for i in range(q.size(0)):
+ x[i : i + 1] = flash_attn.flash_attn_func(
+ q[i : i + 1],
+ k[i : i + 1],
+ v[i : i + 1],
+ dropout_p,
+ softmax_scale,
+ causal,
+ window_size,
+ deterministic=deterministic,
+ )
+ del q, k, v
+ return x.type(out_dtype)
+
+ # xformers
+ if attn_mode == "xformers":
+ assert not deterministic, "deterministic is not supported in xformers."
+ assert not causal, "causal is not supported in xformers."
+ if q_scale is not None:
+ q = q * q_scale
+ q = half(q)
+ k = half(k)
+ v = half(v)
+
+ if not split_attn:
+ q = xops.memory_efficient_attention(q, k, v, p=dropout_p, scale=softmax_scale)
+ x = q
+ else:
+ x = torch.empty_like(q)
+ for i in range(q.size(0)):
+ x[i : i + 1] = xops.memory_efficient_attention(
+ q[i : i + 1], k[i : i + 1], v[i : i + 1], p=dropout_p, scale=softmax_scale
+ )
+
+ del q, k, v
+ return x.type(out_dtype)
+
+ # sage attention with fixed length seems to cause NaN in I2V inference.
+ # # sage attention
+ # if attn_mode == "sageattn":
+ # print("Using sage attention")
+ # assert not deterministic, "deterministic is not supported in sage attention."
+ # if q_scale is not None:
+ # q = q * q_scale
+ # q, k, v = half(q), half(k), half(v)
+ # x = sageattention.sageattn(q, k, v, "NHD", is_causal=causal, sm_scale=softmax_scale)
+ # del q, k, v
+ # return x.type(out_dtype)
+
+ assert not split_attn, "split_attn is not supported in flash attention 3 or sage attention."
+
+ # preprocess query: in Wan 2.1, q_lens is always None.
+ if q_lens is None:
+ q = half(q.flatten(0, 1))
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
+ else:
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+ # preprocess key, value
+ if k_lens is None:
+ k = half(k.flatten(0, 1))
+ v = half(v.flatten(0, 1))
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
+ else:
+ # Note: in Wan 2.1, all k_lens are same if we have same image size in the batch.
+ if min(k_lens) == max(k_lens) and k.shape[1] == k_lens[0]:
+ # B, L, N, C -> BN, L, C
+ k = half(k.flatten(0, 1))
+ v = half(v.flatten(0, 1))
+ else:
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+ q = q.to(v.dtype)
+ k = k.to(v.dtype)
+
+ if q_scale is not None:
+ q = q * q_scale
+
+ # if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
+ # warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.")
+
+ # apply attention
+ # if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
+ if attn_mode == "flash3":
+ # Not tested yet in musubi tuner.
+ # Note: dropout_p, window_size are not supported in FA3 now.
+ x = flash_attn_interface.flash_attn_varlen_func(
+ q=q,
+ k=k,
+ v=v,
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+ seqused_q=None,
+ seqused_k=None,
+ max_seqlen_q=lq,
+ max_seqlen_k=lk,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ deterministic=deterministic,
+ )[0].unflatten(0, (b, lq))
+ # elif (version is None or version == 2) and FLASH_ATTN_2_AVAILABLE:
+ # # assert FLASH_ATTN_2_AVAILABLE
+ # x = flash_attn.flash_attn_varlen_func(
+ # q=q,
+ # k=k,
+ # v=v,
+ # cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+ # cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+ # max_seqlen_q=lq,
+ # max_seqlen_k=lk,
+ # dropout_p=dropout_p,
+ # softmax_scale=softmax_scale,
+ # causal=causal,
+ # window_size=window_size,
+ # deterministic=deterministic,
+ # ).unflatten(0, (b, lq))
+ # elif version is None and SAGE_ATTN_AVAILABLE:
+ elif attn_mode == "sageattn":
+ # print("Using sage attention")
+ assert not causal, "SAGE attention does not support causal attention."
+ x = sageattention.sageattn_varlen(
+ q=q,
+ k=k,
+ v=v,
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+ max_seqlen_q=lq,
+ max_seqlen_k=lk,
+ sm_scale=softmax_scale,
+ ).unflatten(0, (b, lq))
+ else:
+ raise ValueError(f"Unknown attention mode: {attn_mode}")
+
+ # output
+ return x.type(out_dtype)
+
+
+def attention(
+ q,
+ k,
+ v,
+ q_lens=None,
+ k_lens=None,
+ dropout_p=0.0,
+ softmax_scale=None,
+ q_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ deterministic=False,
+ dtype=torch.bfloat16,
+ fa_version=None,
+):
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+ return flash_attention(
+ q=q,
+ k=k,
+ v=v,
+ q_lens=q_lens,
+ k_lens=k_lens,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ q_scale=q_scale,
+ causal=causal,
+ window_size=window_size,
+ deterministic=deterministic,
+ dtype=dtype,
+ version=fa_version,
+ )
+ else:
+ if q_lens is not None or k_lens is not None:
+ warnings.warn(
+ "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
+ )
+ attn_mask = None
+
+ q = q.transpose(1, 2).to(dtype)
+ k = k.transpose(1, 2).to(dtype)
+ v = v.transpose(1, 2).to(dtype)
+
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
+
+ out = out.transpose(1, 2).contiguous()
+ return out
diff --git a/wan/modules/clip.py b/wan/modules/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fbd867678e1d75d402583c91ea97bba74194c52
--- /dev/null
+++ b/wan/modules/clip.py
@@ -0,0 +1,546 @@
+# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging
+import math
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+from accelerate import init_empty_weights
+
+from .attention import flash_attention
+from .tokenizers import HuggingfaceTokenizer
+from .xlm_roberta import XLMRoberta
+
+from utils.safetensors_utils import load_safetensors
+
+__all__ = [
+ "XLMRobertaCLIP",
+ "clip_xlm_roberta_vit_h_14",
+ "CLIPModel",
+]
+
+
+def pos_interpolate(pos, seq_len):
+ if pos.size(1) == seq_len:
+ return pos
+ else:
+ src_grid = int(math.sqrt(pos.size(1)))
+ tar_grid = int(math.sqrt(seq_len))
+ n = pos.size(1) - src_grid * src_grid
+ return torch.cat(
+ [
+ pos[:, :n],
+ F.interpolate(
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
+ size=(tar_grid, tar_grid),
+ mode="bicubic",
+ align_corners=False,
+ )
+ .flatten(2)
+ .transpose(1, 2),
+ ],
+ dim=1,
+ )
+
+
+class QuickGELU(nn.Module):
+
+ def forward(self, x):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerNorm(nn.LayerNorm):
+
+ def forward(self, x):
+ return super().forward(x.float()).type_as(x)
+
+
+class SelfAttention(nn.Module):
+
+ def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.causal = causal
+ self.attn_dropout = attn_dropout
+ self.proj_dropout = proj_dropout
+
+ # layers
+ self.to_qkv = nn.Linear(dim, dim * 3)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
+
+ # compute attention
+ p = self.attn_dropout if self.training else 0.0
+ # x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
+ # print(q.shape, k.shape, v.shape)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=p, is_causal=self.causal)
+ # print(x.shape)
+ x = x.transpose(1, 2).contiguous()
+ x = x.reshape(b, s, c)
+
+ # output
+ x = self.proj(x)
+ x = F.dropout(x, self.proj_dropout, self.training)
+ return x
+
+
+class SwiGLU(nn.Module):
+
+ def __init__(self, dim, mid_dim):
+ super().__init__()
+ self.dim = dim
+ self.mid_dim = mid_dim
+
+ # layers
+ self.fc1 = nn.Linear(dim, mid_dim)
+ self.fc2 = nn.Linear(dim, mid_dim)
+ self.fc3 = nn.Linear(mid_dim, dim)
+
+ def forward(self, x):
+ x = F.silu(self.fc1(x)) * self.fc2(x)
+ x = self.fc3(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ mlp_ratio,
+ num_heads,
+ post_norm=False,
+ causal=False,
+ activation="quick_gelu",
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ norm_eps=1e-5,
+ ):
+ assert activation in ["quick_gelu", "gelu", "swi_glu"]
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.post_norm = post_norm
+ self.causal = causal
+ self.norm_eps = norm_eps
+
+ # layers
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
+ if activation == "swi_glu":
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, int(dim * mlp_ratio)),
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
+ nn.Linear(int(dim * mlp_ratio), dim),
+ nn.Dropout(proj_dropout),
+ )
+
+ def forward(self, x):
+ if self.post_norm:
+ x = x + self.norm1(self.attn(x))
+ x = x + self.norm2(self.mlp(x))
+ else:
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class AttentionPool(nn.Module):
+
+ def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.proj_dropout = proj_dropout
+ self.norm_eps = norm_eps
+
+ # layers
+ gain = 1.0 / math.sqrt(dim)
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+ self.to_q = nn.Linear(dim, dim)
+ self.to_kv = nn.Linear(dim, dim * 2)
+ self.proj = nn.Linear(dim, dim)
+ self.norm = LayerNorm(dim, eps=norm_eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, int(dim * mlp_ratio)),
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
+ nn.Linear(int(dim * mlp_ratio), dim),
+ nn.Dropout(proj_dropout),
+ )
+
+ def forward(self, x):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
+
+ # compute attention
+ # this line is never used because pool_type="token" in Wan2.1
+ x = flash_attention(q, k, v, version=2)
+ x = x.reshape(b, 1, c)
+
+ # output
+ x = self.proj(x)
+ x = F.dropout(x, self.proj_dropout, self.training)
+
+ # mlp
+ x = x + self.mlp(self.norm(x))
+ return x[:, 0]
+
+
+class VisionTransformer(nn.Module):
+
+ def __init__(
+ self,
+ image_size=224,
+ patch_size=16,
+ dim=768,
+ mlp_ratio=4,
+ out_dim=512,
+ num_heads=12,
+ num_layers=12,
+ pool_type="token",
+ pre_norm=True,
+ post_norm=False,
+ activation="quick_gelu",
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5,
+ ):
+ if image_size % patch_size != 0:
+ print("[WARNING] image_size is not divisible by patch_size", flush=True)
+ assert pool_type in ("token", "token_fc", "attn_pool")
+ out_dim = out_dim or dim
+ super().__init__()
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = (image_size // patch_size) ** 2
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.pool_type = pool_type
+ self.post_norm = post_norm
+ self.norm_eps = norm_eps
+
+ # embeddings
+ gain = 1.0 / math.sqrt(dim)
+ self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
+ if pool_type in ("token", "token_fc"):
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+ self.pos_embedding = nn.Parameter(
+ gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim)
+ )
+ self.dropout = nn.Dropout(embedding_dropout)
+
+ # transformer
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
+ self.transformer = nn.Sequential(
+ *[
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps)
+ for _ in range(num_layers)
+ ]
+ )
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
+
+ # head
+ if pool_type == "token":
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+ elif pool_type == "token_fc":
+ self.head = nn.Linear(dim, out_dim)
+ elif pool_type == "attn_pool":
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
+
+ def forward(self, x, interpolation=False, use_31_block=False):
+ b = x.size(0)
+
+ # embeddings
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+ if self.pool_type in ("token", "token_fc"):
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
+ if interpolation:
+ e = pos_interpolate(self.pos_embedding, x.size(1))
+ else:
+ e = self.pos_embedding
+ x = self.dropout(x + e)
+ if self.pre_norm is not None:
+ x = self.pre_norm(x)
+
+ # transformer
+ if use_31_block:
+ x = self.transformer[:-1](x)
+ return x
+ else:
+ x = self.transformer(x)
+ return x
+
+
+class XLMRobertaWithHead(XLMRoberta):
+
+ def __init__(self, **kwargs):
+ self.out_dim = kwargs.pop("out_dim")
+ super().__init__(**kwargs)
+
+ # head
+ mid_dim = (self.dim + self.out_dim) // 2
+ self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False))
+
+ def forward(self, ids):
+ # xlm-roberta
+ x = super().forward(ids)
+
+ # average pooling
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
+
+ # head
+ x = self.head(x)
+ return x
+
+
+class XLMRobertaCLIP(nn.Module):
+
+ def __init__(
+ self,
+ embed_dim=1024,
+ image_size=224,
+ patch_size=14,
+ vision_dim=1280,
+ vision_mlp_ratio=4,
+ vision_heads=16,
+ vision_layers=32,
+ vision_pool="token",
+ vision_pre_norm=True,
+ vision_post_norm=False,
+ activation="gelu",
+ vocab_size=250002,
+ max_text_len=514,
+ type_size=1,
+ pad_id=1,
+ text_dim=1024,
+ text_heads=16,
+ text_layers=24,
+ text_post_norm=True,
+ text_dropout=0.1,
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.vision_dim = vision_dim
+ self.vision_mlp_ratio = vision_mlp_ratio
+ self.vision_heads = vision_heads
+ self.vision_layers = vision_layers
+ self.vision_pre_norm = vision_pre_norm
+ self.vision_post_norm = vision_post_norm
+ self.activation = activation
+ self.vocab_size = vocab_size
+ self.max_text_len = max_text_len
+ self.type_size = type_size
+ self.pad_id = pad_id
+ self.text_dim = text_dim
+ self.text_heads = text_heads
+ self.text_layers = text_layers
+ self.text_post_norm = text_post_norm
+ self.norm_eps = norm_eps
+
+ # models
+ self.visual = VisionTransformer(
+ image_size=image_size,
+ patch_size=patch_size,
+ dim=vision_dim,
+ mlp_ratio=vision_mlp_ratio,
+ out_dim=embed_dim,
+ num_heads=vision_heads,
+ num_layers=vision_layers,
+ pool_type=vision_pool,
+ pre_norm=vision_pre_norm,
+ post_norm=vision_post_norm,
+ activation=activation,
+ attn_dropout=attn_dropout,
+ proj_dropout=proj_dropout,
+ embedding_dropout=embedding_dropout,
+ norm_eps=norm_eps,
+ )
+ self.textual = XLMRobertaWithHead(
+ vocab_size=vocab_size,
+ max_seq_len=max_text_len,
+ type_size=type_size,
+ pad_id=pad_id,
+ dim=text_dim,
+ out_dim=embed_dim,
+ num_heads=text_heads,
+ num_layers=text_layers,
+ post_norm=text_post_norm,
+ dropout=text_dropout,
+ )
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
+
+ def forward(self, imgs, txt_ids):
+ """
+ imgs: [B, 3, H, W] of torch.float32.
+ - mean: [0.48145466, 0.4578275, 0.40821073]
+ - std: [0.26862954, 0.26130258, 0.27577711]
+ txt_ids: [B, L] of torch.long.
+ Encoded by data.CLIPTokenizer.
+ """
+ xi = self.visual(imgs)
+ xt = self.textual(txt_ids)
+ return xi, xt
+
+ def param_groups(self):
+ groups = [
+ {"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")], "weight_decay": 0.0},
+ {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
+ ]
+ return groups
+
+
+def _clip(
+ pretrained=False,
+ pretrained_name=None,
+ model_cls=XLMRobertaCLIP,
+ return_transforms=False,
+ return_tokenizer=False,
+ tokenizer_padding="eos",
+ dtype=torch.float32,
+ device="cpu",
+ **kwargs,
+):
+ # # init a model on device
+ # with torch.device(device):
+ model = model_cls(**kwargs)
+
+ # # set device
+ # model = model.to(dtype=dtype, device=device)
+ output = (model,)
+
+ # init transforms
+ if return_transforms:
+ # mean and std
+ if "siglip" in pretrained_name.lower():
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+ else:
+ mean = [0.48145466, 0.4578275, 0.40821073]
+ std = [0.26862954, 0.26130258, 0.27577711]
+
+ # transforms
+ transforms = T.Compose(
+ [
+ T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=mean, std=std),
+ ]
+ )
+ output += (transforms,)
+ return output[0] if len(output) == 1 else output
+
+
+def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
+ cfg = dict(
+ embed_dim=1024,
+ image_size=224,
+ patch_size=14,
+ vision_dim=1280,
+ vision_mlp_ratio=4,
+ vision_heads=16,
+ vision_layers=32,
+ vision_pool="token",
+ activation="gelu",
+ vocab_size=250002,
+ max_text_len=514,
+ type_size=1,
+ pad_id=1,
+ text_dim=1024,
+ text_heads=16,
+ text_layers=24,
+ text_post_norm=True,
+ text_dropout=0.1,
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ )
+ cfg.update(**kwargs)
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
+
+
+class CLIPModel:
+
+ def __init__(self, dtype, device, checkpoint_path=None, tokenizer_path=None, weight_path=None):
+ self.dtype = dtype
+ self.device = device
+ self.checkpoint_path = checkpoint_path
+ self.tokenizer_path = tokenizer_path
+ self.weight_path = weight_path
+
+ # init model
+ with init_empty_weights():
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
+ pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device
+ )
+ self.model = self.model.eval().requires_grad_(False)
+
+ logging.info(f"loading {weight_path}")
+ if os.path.splitext(weight_path)[-1] == ".safetensors":
+ sd = load_safetensors(weight_path, device=device, disable_mmap=True, dtype=dtype)
+ else:
+ sd = torch.load(weight_path, map_location=device, weights_only=True)
+ info = self.model.load_state_dict(sd, strict=True, assign=True)
+ self.model = self.model.to(dtype=dtype, device=device)
+ logging.info(f"weights loaded from {weight_path}: {info}")
+
+ # init tokenizer
+ if tokenizer_path is None:
+ tokenizer_path = "Wan-AI/Wan2.1-I2V-14B-720P"
+ subfolder = "xlm-roberta-large"
+ else:
+ subfolder = None
+
+ self.tokenizer = HuggingfaceTokenizer(
+ name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace", subfolder=subfolder
+ )
+
+ def visual(self, videos):
+ # preprocess
+ size = (self.model.image_size,) * 2
+ videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
+
+ # forward
+ # with torch.cuda.amp.autocast(dtype=self.dtype):
+ out = self.model.visual(videos, use_31_block=True)
+ return out
diff --git a/wan/modules/model.py b/wan/modules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c49561aead4fbd4af77dc3e8fc15ad86e62c2d1c
--- /dev/null
+++ b/wan/modules/model.py
@@ -0,0 +1,933 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import math
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from accelerate import init_empty_weights
+
+import logging
+
+from utils.safetensors_utils import MemoryEfficientSafeOpen, load_safetensors
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+from utils.device_utils import clean_memory_on_device
+
+from .attention import flash_attention
+from utils.device_utils import clean_memory_on_device
+from modules.custom_offloading_utils import ModelOffloader
+from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8
+
+__all__ = ["WanModel"]
+
+
+def sinusoidal_embedding_1d(dim, position):
+ # preprocess
+ assert dim % 2 == 0
+ half = dim // 2
+ position = position.type(torch.float64)
+
+ # calculation
+ sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x
+
+
+# @amp.autocast(enabled=False)
+# no autocast is needed for rope_apply, because it is already in float64
+def rope_params(max_seq_len, dim, theta=10000):
+ assert dim % 2 == 0
+ freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs
+
+
+# @amp.autocast(enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+ device_type = x.device.type
+ with torch.amp.autocast(device_type=device_type, enabled=False):
+ n, c = x.size(2), x.size(3) // 2
+
+ # split freqs
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+ # loop over samples
+ output = []
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+ seq_len = f * h * w
+
+ # precompute multipliers
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
+ freqs_i = torch.cat(
+ [
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+ ],
+ dim=-1,
+ ).reshape(seq_len, 1, -1)
+
+ # apply rotary embedding
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
+ x_i = torch.cat([x_i, x[i, seq_len:]])
+
+ # append to collection
+ output.append(x_i)
+ return torch.stack(output).float()
+
+
+def calculate_freqs_i(fhw, c, freqs):
+ f, h, w = fhw
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+ freqs_i = torch.cat(
+ [
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+ ],
+ dim=-1,
+ ).reshape(f * h * w, 1, -1)
+ return freqs_i
+
+
+# inplace version of rope_apply
+def rope_apply_inplace_cached(x, grid_sizes, freqs_list):
+ # with torch.amp.autocast(device_type=device_type, enabled=False):
+ rope_dtype = torch.float64 # float32 does not reduce memory usage significantly
+
+ n, c = x.size(2), x.size(3) // 2
+
+ # loop over samples
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+ seq_len = f * h * w
+
+ # precompute multipliers
+ x_i = torch.view_as_complex(x[i, :seq_len].to(rope_dtype).reshape(seq_len, n, -1, 2))
+ freqs_i = freqs_list[i]
+
+ # apply rotary embedding
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
+ # x_i = torch.cat([x_i, x[i, seq_len:]])
+
+ # inplace update
+ x[i, :seq_len] = x_i.to(x.dtype)
+
+ return x
+
+
+class WanRMSNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ # return self._norm(x.float()).type_as(x) * self.weight
+ # support fp8
+ return self._norm(x.float()).type_as(x) * self.weight.to(x.dtype)
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+ # def forward(self, x):
+ # r"""
+ # Args:
+ # x(Tensor): Shape [B, L, C]
+ # """
+ # # inplace version, also supports fp8 -> does not have significant performance improvement
+ # original_dtype = x.dtype
+ # x = x.float()
+ # y = x.pow(2).mean(dim=-1, keepdim=True)
+ # y.add_(self.eps)
+ # y.rsqrt_()
+ # x *= y
+ # x = x.to(original_dtype)
+ # x *= self.weight.to(original_dtype)
+ # return x
+
+
+class WanLayerNorm(nn.LayerNorm):
+
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class WanSelfAttention(nn.Module):
+
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, attn_mode="torch", split_attn=False):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.eps = eps
+ self.attn_mode = attn_mode
+ self.split_attn = split_attn
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, seq_lens, grid_sizes, freqs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
+ seq_lens(Tensor): Shape [B]
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ # # query, key, value function
+ # def qkv_fn(x):
+ # q = self.norm_q(self.q(x)).view(b, s, n, d)
+ # k = self.norm_k(self.k(x)).view(b, s, n, d)
+ # v = self.v(x).view(b, s, n, d)
+ # return q, k, v
+ # q, k, v = qkv_fn(x)
+ # del x
+ # query, key, value function
+
+ q = self.q(x)
+ k = self.k(x)
+ v = self.v(x)
+ del x
+ q = self.norm_q(q)
+ k = self.norm_k(k)
+ q = q.view(b, s, n, d)
+ k = k.view(b, s, n, d)
+ v = v.view(b, s, n, d)
+
+ rope_apply_inplace_cached(q, grid_sizes, freqs)
+ rope_apply_inplace_cached(k, grid_sizes, freqs)
+ qkv = [q, k, v]
+ del q, k, v
+ x = flash_attention(
+ qkv, k_lens=seq_lens, window_size=self.window_size, attn_mode=self.attn_mode, split_attn=self.split_attn
+ )
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanT2VCrossAttention(WanSelfAttention):
+
+ def forward(self, x, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ # q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ # k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ # v = self.v(context).view(b, -1, n, d)
+ q = self.q(x)
+ del x
+ k = self.k(context)
+ v = self.v(context)
+ del context
+ q = self.norm_q(q)
+ k = self.norm_k(k)
+ q = q.view(b, -1, n, d)
+ k = k.view(b, -1, n, d)
+ v = v.view(b, -1, n, d)
+
+ # compute attention
+ qkv = [q, k, v]
+ del q, k, v
+ x = flash_attention(qkv, k_lens=context_lens, attn_mode=self.attn_mode, split_attn=self.split_attn)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanI2VCrossAttention(WanSelfAttention):
+
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, attn_mode="torch", split_attn=False):
+ super().__init__(dim, num_heads, window_size, qk_norm, eps, attn_mode, split_attn)
+
+ self.k_img = nn.Linear(dim, dim)
+ self.v_img = nn.Linear(dim, dim)
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ context_img = context[:, :257]
+ context = context[:, 257:]
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x)
+ del x
+ q = self.norm_q(q)
+ q = q.view(b, -1, n, d)
+ k = self.k(context)
+ k = self.norm_k(k).view(b, -1, n, d)
+ v = self.v(context).view(b, -1, n, d)
+ del context
+
+ # compute attention
+ qkv = [q, k, v]
+ del k, v
+ x = flash_attention(qkv, k_lens=context_lens, attn_mode=self.attn_mode, split_attn=self.split_attn)
+
+ # compute query, key, value
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
+ v_img = self.v_img(context_img).view(b, -1, n, d)
+ del context_img
+
+ # compute attention
+ qkv = [q, k_img, v_img]
+ del q, k_img, v_img
+ img_x = flash_attention(qkv, k_lens=None, attn_mode=self.attn_mode, split_attn=self.split_attn)
+
+ # output
+ x = x.flatten(2)
+ img_x = img_x.flatten(2)
+ if self.training:
+ x = x + img_x # avoid inplace
+ else:
+ x += img_x
+ del img_x
+
+ x = self.o(x)
+ return x
+
+
+WAN_CROSSATTENTION_CLASSES = {
+ "t2v_cross_attn": WanT2VCrossAttention,
+ "i2v_cross_attn": WanI2VCrossAttention,
+}
+
+
+class WanAttentionBlock(nn.Module):
+
+ def __init__(
+ self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ attn_mode="torch",
+ split_attn=False,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # layers
+ self.norm1 = WanLayerNorm(dim, eps)
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps, attn_mode, split_attn)
+ self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps, attn_mode, split_attn)
+ self.norm2 = WanLayerNorm(dim, eps)
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing = False
+
+ def _forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ e(Tensor): Shape [B, 6, C]
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ assert e.dtype == torch.float32
+ # with amp.autocast(dtype=torch.float32):
+ # e = (self.modulation + e).chunk(6, dim=1)
+ # support fp8
+ e = self.modulation.to(torch.float32) + e
+ e = e.chunk(6, dim=1)
+ assert e[0].dtype == torch.float32
+
+ # self-attention
+ y = self.self_attn(self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs)
+ # with amp.autocast(dtype=torch.float32):
+ # x = x + y * e[2]
+ x = x + y.to(torch.float32) * e[2]
+ del y
+
+ # cross-attention & ffn function
+ # def cross_attn_ffn(x, context, context_lens, e):
+ # x += self.cross_attn(self.norm3(x), context, context_lens)
+ # y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
+ # # with amp.autocast(dtype=torch.float32):
+ # # x = x + y * e[5]
+ # x += y.to(torch.float32) * e[5]
+ # return x
+ # x = cross_attn_ffn(x, context, context_lens, e)
+
+ # x += self.cross_attn(self.norm3(x), context, context_lens) # backward error
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
+ del context
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
+ x = x + y.to(torch.float32) * e[5]
+ del y
+ return x
+
+ def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens):
+ if self.training and self.gradient_checkpointing:
+ return checkpoint(self._forward, x, e, seq_lens, grid_sizes, freqs, context, context_lens, use_reentrant=False)
+ return self._forward(x, e, seq_lens, grid_sizes, freqs, context, context_lens)
+
+
+class Head(nn.Module):
+
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
+ super().__init__()
+ self.dim = dim
+ self.out_dim = out_dim
+ self.patch_size = patch_size
+ self.eps = eps
+
+ # layers
+ out_dim = math.prod(patch_size) * out_dim
+ self.norm = WanLayerNorm(dim, eps)
+ self.head = nn.Linear(dim, out_dim)
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+ def forward(self, x, e):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ e(Tensor): Shape [B, C]
+ """
+ assert e.dtype == torch.float32
+ # with amp.autocast(dtype=torch.float32):
+ # e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+ # x = self.head(self.norm(x) * (1 + e[1]) + e[0])
+ # support fp8
+ e = (self.modulation.to(torch.float32) + e.unsqueeze(1)).chunk(2, dim=1)
+ x = self.head(self.norm(x) * (1 + e[1]) + e[0])
+ return x
+
+
+class MLPProj(torch.nn.Module):
+
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.LayerNorm(in_dim),
+ torch.nn.Linear(in_dim, in_dim),
+ torch.nn.GELU(),
+ torch.nn.Linear(in_dim, out_dim),
+ torch.nn.LayerNorm(out_dim),
+ )
+
+ def forward(self, image_embeds):
+ clip_extra_context_tokens = self.proj(image_embeds)
+ return clip_extra_context_tokens
+
+
+class WanModel(nn.Module): # ModelMixin, ConfigMixin):
+ r"""
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
+ """
+
+ ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
+ _no_split_modules = ["WanAttentionBlock"]
+
+ # @register_to_config
+ def __init__(
+ self,
+ model_type="t2v",
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ attn_mode=None,
+ split_attn=False,
+ ):
+ r"""
+ Initialize the diffusion model backbone.
+
+ Args:
+ model_type (`str`, *optional*, defaults to 't2v'):
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
+ text_len (`int`, *optional*, defaults to 512):
+ Fixed length for text embeddings
+ in_dim (`int`, *optional*, defaults to 16):
+ Input video channels (C_in)
+ dim (`int`, *optional*, defaults to 2048):
+ Hidden dimension of the transformer
+ ffn_dim (`int`, *optional*, defaults to 8192):
+ Intermediate dimension in feed-forward network
+ freq_dim (`int`, *optional*, defaults to 256):
+ Dimension for sinusoidal time embeddings
+ text_dim (`int`, *optional*, defaults to 4096):
+ Input dimension for text embeddings
+ out_dim (`int`, *optional*, defaults to 16):
+ Output video channels (C_out)
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads
+ num_layers (`int`, *optional*, defaults to 32):
+ Number of transformer blocks
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
+ Window size for local attention (-1 indicates global attention)
+ qk_norm (`bool`, *optional*, defaults to True):
+ Enable query/key normalization
+ cross_attn_norm (`bool`, *optional*, defaults to False):
+ Enable cross-attention normalization
+ eps (`float`, *optional*, defaults to 1e-6):
+ Epsilon value for normalization layers
+ """
+
+ super().__init__()
+
+ assert model_type in ["t2v", "i2v"]
+ self.model_type = model_type
+
+ self.patch_size = patch_size
+ self.text_len = text_len
+ self.in_dim = in_dim
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.freq_dim = freq_dim
+ self.text_dim = text_dim
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+ self.attn_mode = attn_mode if attn_mode is not None else "torch"
+ self.split_attn = split_attn
+
+ # embeddings
+ self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
+ self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
+
+ self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
+
+ # blocks
+ cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
+ self.blocks = nn.ModuleList(
+ [
+ WanAttentionBlock(
+ cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, attn_mode, split_attn
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # head
+ self.head = Head(dim, out_dim, patch_size, eps)
+
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
+ d = dim // num_heads
+ self.freqs = torch.cat(
+ [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], dim=1
+ )
+ self.freqs_fhw = {}
+
+ if model_type == "i2v":
+ self.img_emb = MLPProj(1280, dim)
+
+ # initialize weights
+ self.init_weights()
+
+ self.gradient_checkpointing = False
+
+ # offloading
+ self.blocks_to_swap = None
+ self.offloader = None
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def fp8_optimization(
+ self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False
+ ) -> int:
+ """
+ Optimize the model state_dict with fp8.
+
+ Args:
+ state_dict (dict[str, torch.Tensor]):
+ The state_dict of the model.
+ device (torch.device):
+ The device to calculate the weight.
+ move_to_device (bool):
+ Whether to move the weight to the device after optimization.
+ """
+ TARGET_KEYS = ["blocks"]
+ EXCLUDE_KEYS = [
+ "norm",
+ "patch_embedding",
+ "text_embedding",
+ "time_embedding",
+ "time_projection",
+ "head",
+ "modulation",
+ "img_emb",
+ ]
+
+ # inplace optimization
+ state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device)
+
+ # apply monkey patching
+ apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm)
+
+ return state_dict
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing = True
+
+ for block in self.blocks:
+ block.enable_gradient_checkpointing()
+
+ print(f"WanModel: Gradient checkpointing enabled.")
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing = False
+
+ for block in self.blocks:
+ block.disable_gradient_checkpointing()
+
+ print(f"WanModel: Gradient checkpointing disabled.")
+
+ def enable_block_swap(self, blocks_to_swap: int, device: torch.device, supports_backward: bool):
+ self.blocks_to_swap = blocks_to_swap
+ self.num_blocks = len(self.blocks)
+
+ assert (
+ self.blocks_to_swap <= self.num_blocks - 1
+ ), f"Cannot swap more than {self.num_blocks - 1} blocks. Requested {self.blocks_to_swap} blocks to swap."
+
+ self.offloader = ModelOffloader(
+ "wan_attn_block", self.blocks, self.num_blocks, self.blocks_to_swap, supports_backward, device # , debug=True
+ )
+ print(
+ f"WanModel: Block swap enabled. Swapping {self.blocks_to_swap} blocks out of {self.num_blocks} blocks. Supports backward: {supports_backward}"
+ )
+
+ def switch_block_swap_for_inference(self):
+ if self.blocks_to_swap:
+ self.offloader.set_forward_only(True)
+ self.prepare_block_swap_before_forward()
+ print(f"WanModel: Block swap set to forward only.")
+
+ def switch_block_swap_for_training(self):
+ if self.blocks_to_swap:
+ self.offloader.set_forward_only(False)
+ self.prepare_block_swap_before_forward()
+ print(f"WanModel: Block swap set to forward and backward.")
+
+ def move_to_device_except_swap_blocks(self, device: torch.device):
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
+ if self.blocks_to_swap:
+ save_blocks = self.blocks
+ self.blocks = None
+
+ self.to(device)
+
+ if self.blocks_to_swap:
+ self.blocks = save_blocks
+
+ def prepare_block_swap_before_forward(self):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+ self.offloader.prepare_block_devices_before_forward(self.blocks)
+
+ def forward(self, x, t, context, seq_len, clip_fea=None, y=None, skip_block_indices=None):
+ r"""
+ Forward pass through the diffusion model
+
+ Args:
+ x (List[Tensor]):
+ List of input video tensors, each with shape [C_in, F, H, W]
+ t (Tensor):
+ Diffusion timesteps tensor of shape [B]
+ context (List[Tensor]):
+ List of text embeddings each with shape [L, C]
+ seq_len (`int`):
+ Maximum sequence length for positional encoding
+ clip_fea (Tensor, *optional*):
+ CLIP image features for image-to-video mode
+ y (List[Tensor], *optional*):
+ Conditional video inputs for image-to-video mode, same shape as x
+
+ Returns:
+ List[Tensor]:
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
+ """
+ # remove assertions to work with Fun-Control T2V
+ # if self.model_type == "i2v":
+ # assert clip_fea is not None and y is not None
+ # params
+ device = self.patch_embedding.weight.device
+ if self.freqs.device != device:
+ self.freqs = self.freqs.to(device)
+
+ if y is not None:
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+ y = None
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+
+ freqs_list = []
+ for fhw in grid_sizes:
+ fhw = tuple(fhw.tolist())
+ if fhw not in self.freqs_fhw:
+ c = self.dim // self.num_heads // 2
+ self.freqs_fhw[fhw] = calculate_freqs_i(fhw, c, self.freqs)
+ freqs_list.append(self.freqs_fhw[fhw])
+
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+ assert seq_lens.max() <= seq_len, f"Sequence length exceeds maximum allowed length {seq_len}. Got {seq_lens.max()}"
+ x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
+
+ # time embeddings
+ # with amp.autocast(dtype=torch.float32):
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float32):
+ e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+ # context
+ context_lens = None
+ if type(context) is list:
+ context = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
+ context = self.text_embedding(context)
+
+ if clip_fea is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context = torch.concat([context_clip, context], dim=1)
+ clip_fea = None
+ context_clip = None
+
+ # arguments
+ kwargs = dict(e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=freqs_list, context=context, context_lens=context_lens)
+
+ if self.blocks_to_swap:
+ clean_memory_on_device(device)
+
+ # print(f"x: {x.shape}, e: {e0.shape}, context: {context.shape}, seq_lens: {seq_lens}")
+ for block_idx, block in enumerate(self.blocks):
+ is_block_skipped = skip_block_indices is not None and block_idx in skip_block_indices
+
+ if self.blocks_to_swap and not is_block_skipped:
+ self.offloader.wait_for_block(block_idx)
+
+ if not is_block_skipped:
+ x = block(x, **kwargs)
+
+ if self.blocks_to_swap:
+ self.offloader.submit_move_blocks_forward(self.blocks, block_idx)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return [u.float() for u in x]
+
+ def unpatchify(self, x, grid_sizes):
+ r"""
+ Reconstruct video tensors from patch embeddings.
+
+ Args:
+ x (List[Tensor]):
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
+ grid_sizes (Tensor):
+ Original spatial-temporal grid dimensions before patching,
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
+
+ Returns:
+ List[Tensor]:
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+ """
+
+ c = self.out_dim
+ out = []
+ for u, v in zip(x, grid_sizes.tolist()):
+ u = u[: math.prod(v)].view(*v, *self.patch_size, c)
+ u = torch.einsum("fhwpqrc->cfphqwr", u)
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
+ out.append(u)
+ return out
+
+ def init_weights(self):
+ r"""
+ Initialize model parameters using Xavier initialization.
+ """
+
+ # basic init
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ # init embeddings
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
+ for m in self.text_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=0.02)
+ for m in self.time_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=0.02)
+
+ # init output layer
+ nn.init.zeros_(self.head.head.weight)
+
+
+def detect_wan_sd_dtype(path: str) -> torch.dtype:
+ # get dtype from model weights
+ with MemoryEfficientSafeOpen(path) as f:
+ keys = set(f.keys())
+ key1 = "model.diffusion_model.blocks.0.cross_attn.k.weight" # 1.3B
+ key2 = "blocks.0.cross_attn.k.weight" # 14B
+ if key1 in keys:
+ dit_dtype = f.get_tensor(key1).dtype
+ elif key2 in keys:
+ dit_dtype = f.get_tensor(key2).dtype
+ else:
+ raise ValueError(f"Could not find the dtype in the model weights: {path}")
+ logger.info(f"Detected DiT dtype: {dit_dtype}")
+ return dit_dtype
+
+
+def load_wan_model(
+ config: any,
+ device: Union[str, torch.device],
+ dit_path: str,
+ attn_mode: str,
+ split_attn: bool,
+ loading_device: Union[str, torch.device],
+ dit_weight_dtype: Optional[torch.dtype],
+ fp8_scaled: bool = False,
+) -> WanModel:
+ # dit_weight_dtype is None for fp8_scaled
+ assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
+
+ device = torch.device(device)
+ loading_device = torch.device(loading_device)
+
+ with init_empty_weights():
+ logger.info(f"Creating WanModel")
+ model = WanModel(
+ model_type="i2v" if config.i2v else "t2v",
+ dim=config.dim,
+ eps=config.eps,
+ ffn_dim=config.ffn_dim,
+ freq_dim=config.freq_dim,
+ in_dim=config.in_dim,
+ num_heads=config.num_heads,
+ num_layers=config.num_layers,
+ out_dim=config.out_dim,
+ text_len=config.text_len,
+ attn_mode=attn_mode,
+ split_attn=split_attn,
+ )
+ if dit_weight_dtype is not None:
+ model.to(dit_weight_dtype)
+
+ # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others)
+ wan_loading_device = torch.device("cpu") if fp8_scaled else loading_device
+ logger.info(f"Loading DiT model from {dit_path}, device={wan_loading_device}, dtype={dit_weight_dtype}")
+
+ # load model weights with the specified dtype or as is
+ sd = load_safetensors(dit_path, wan_loading_device, disable_mmap=True, dtype=dit_weight_dtype)
+
+ # remove "model.diffusion_model." prefix: 1.3B model has this prefix
+ for key in list(sd.keys()):
+ if key.startswith("model.diffusion_model."):
+ sd[key[22:]] = sd.pop(key)
+
+ if fp8_scaled:
+ # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap)
+ logger.info(f"Optimizing model weights to fp8. This may take a while.")
+ sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu")
+
+ if loading_device.type != "cpu":
+ # make sure all the model weights are on the loading_device
+ logger.info(f"Moving weights to {loading_device}")
+ for key in sd.keys():
+ sd[key] = sd[key].to(loading_device)
+
+ info = model.load_state_dict(sd, strict=True, assign=True)
+ logger.info(f"Loaded DiT model from {dit_path}, info={info}")
+
+ return model
diff --git a/wan/modules/t5.py b/wan/modules/t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbc89c8342ae9c799fc4674e51fa5661131e38b4
--- /dev/null
+++ b/wan/modules/t5.py
@@ -0,0 +1,514 @@
+# Modified from transformers.models.t5.modeling_t5
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+# import logging
+import math
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .tokenizers import HuggingfaceTokenizer
+from accelerate import init_empty_weights
+from safetensors.torch import load_file
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+__all__ = [
+ "T5Model",
+ "T5Encoder",
+ "T5Decoder",
+ "T5EncoderModel",
+]
+
+
+def fp16_clamp(x):
+ if x.dtype == torch.float16 and torch.isinf(x).any():
+ clamp = torch.finfo(x.dtype).max - 1000
+ x = torch.clamp(x, min=-clamp, max=clamp)
+ return x
+
+
+def init_weights(m):
+ if isinstance(m, T5LayerNorm):
+ nn.init.ones_(m.weight)
+ elif isinstance(m, T5Model):
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
+ elif isinstance(m, T5FeedForward):
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
+ elif isinstance(m, T5Attention):
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
+ elif isinstance(m, T5RelativeEmbedding):
+ nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
+
+
+class GELU(nn.Module):
+
+ def forward(self, x):
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5LayerNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-6):
+ super(T5LayerNorm, self).__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ x = x.type_as(self.weight)
+ return self.weight * x
+
+
+class T5Attention(nn.Module):
+
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
+ assert dim_attn % num_heads == 0
+ super(T5Attention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.num_heads = num_heads
+ self.head_dim = dim_attn // num_heads
+
+ # layers
+ self.q = nn.Linear(dim, dim_attn, bias=False)
+ self.k = nn.Linear(dim, dim_attn, bias=False)
+ self.v = nn.Linear(dim, dim_attn, bias=False)
+ self.o = nn.Linear(dim_attn, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, context=None, mask=None, pos_bias=None):
+ """
+ x: [B, L1, C].
+ context: [B, L2, C] or None.
+ mask: [B, L2] or [B, L1, L2] or None.
+ """
+ # check inputs
+ context = x if context is None else context
+ b, n, c = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x).view(b, -1, n, c)
+ k = self.k(context).view(b, -1, n, c)
+ v = self.v(context).view(b, -1, n, c)
+
+ # attention bias
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
+ if pos_bias is not None:
+ attn_bias += pos_bias
+ if mask is not None:
+ assert mask.ndim in [2, 3]
+ mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
+
+ # compute attention (T5 does not use scaling)
+ attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+ x = torch.einsum("bnij,bjnc->binc", attn, v)
+
+ # output
+ x = x.reshape(b, -1, n * c)
+ x = self.o(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5FeedForward(nn.Module):
+
+ def __init__(self, dim, dim_ffn, dropout=0.1):
+ super(T5FeedForward, self).__init__()
+ self.dim = dim
+ self.dim_ffn = dim_ffn
+
+ # layers
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = self.fc1(x) * self.gate(x)
+ x = self.dropout(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5SelfAttention(nn.Module):
+
+ def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
+ super(T5SelfAttention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.norm1 = T5LayerNorm(dim)
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm2 = T5LayerNorm(dim)
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
+
+ def forward(self, x, mask=None, pos_bias=None):
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
+ return x
+
+
+class T5CrossAttention(nn.Module):
+
+ def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
+ super(T5CrossAttention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.norm1 = T5LayerNorm(dim)
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm2 = T5LayerNorm(dim)
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm3 = T5LayerNorm(dim)
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
+
+ def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
+ x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
+ return x
+
+
+class T5RelativeEmbedding(nn.Module):
+
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+ super(T5RelativeEmbedding, self).__init__()
+ self.num_buckets = num_buckets
+ self.num_heads = num_heads
+ self.bidirectional = bidirectional
+ self.max_dist = max_dist
+
+ # layers
+ self.embedding = nn.Embedding(num_buckets, num_heads)
+
+ def forward(self, lq, lk):
+ device = self.embedding.weight.device
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
+ # torch.arange(lq).unsqueeze(1).to(device)
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
+ rel_pos = self._relative_position_bucket(rel_pos)
+ rel_pos_embeds = self.embedding(rel_pos)
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
+ return rel_pos_embeds.contiguous()
+
+ def _relative_position_bucket(self, rel_pos):
+ # preprocess
+ if self.bidirectional:
+ num_buckets = self.num_buckets // 2
+ rel_buckets = (rel_pos > 0).long() * num_buckets
+ rel_pos = torch.abs(rel_pos)
+ else:
+ num_buckets = self.num_buckets
+ rel_buckets = 0
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
+
+ # embeddings for small and large positions
+ max_exact = num_buckets // 2
+ rel_pos_large = (
+ max_exact
+ + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long()
+ )
+ rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
+ return rel_buckets
+
+
+class T5Encoder(nn.Module):
+
+ def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
+ super(T5Encoder, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
+ self.dropout = nn.Dropout(dropout)
+ self.blocks = nn.ModuleList(
+ [T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)]
+ )
+ self.norm = T5LayerNorm(dim)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def prepare_fp8(self, target_dtype=torch.bfloat16):
+ def forward_hook(module):
+ def forward(hidden_states):
+ hidden_gelu = module.act(module.wi_0(hidden_states))
+ hidden_linear = module.wi_1(hidden_states)
+ hidden_states = hidden_gelu * hidden_linear
+ hidden_states = module.dropout(hidden_states)
+
+ hidden_states = module.wo(hidden_states)
+ return hidden_states
+
+ return forward
+
+ for module in self.modules():
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
+ # print("set", module.__class__.__name__, "to", target_dtype)
+ module.to(target_dtype)
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
+ # print("set", module.__class__.__name__, "hooks")
+ module.forward = forward_hook(module)
+
+ def forward(self, ids, mask=None):
+ x = self.token_embedding(ids)
+ x = self.dropout(x)
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
+ for block in self.blocks:
+ x = block(x, mask, pos_bias=e)
+ x = self.norm(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5Decoder(nn.Module):
+
+ def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
+ super(T5Decoder, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
+ self.dropout = nn.Dropout(dropout)
+ self.blocks = nn.ModuleList(
+ [T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)]
+ )
+ self.norm = T5LayerNorm(dim)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
+ b, s = ids.size()
+
+ # causal mask
+ if mask is None:
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
+ elif mask.ndim == 2:
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
+
+ # layers
+ x = self.token_embedding(ids)
+ x = self.dropout(x)
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
+ for block in self.blocks:
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
+ x = self.norm(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5Model(nn.Module):
+
+ def __init__(
+ self,
+ vocab_size,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ encoder_layers,
+ decoder_layers,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1,
+ ):
+ super(T5Model, self).__init__()
+ self.vocab_size = vocab_size
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.encoder_layers = encoder_layers
+ self.decoder_layers = decoder_layers
+ self.num_buckets = num_buckets
+
+ # layers
+ self.token_embedding = nn.Embedding(vocab_size, dim)
+ self.encoder = T5Encoder(
+ self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout
+ )
+ self.decoder = T5Decoder(
+ self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout
+ )
+ self.head = nn.Linear(dim, vocab_size, bias=False)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
+ x = self.encoder(encoder_ids, encoder_mask)
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
+ x = self.head(x)
+ return x
+
+
+def _t5(
+ name,
+ encoder_only=False,
+ decoder_only=False,
+ return_tokenizer=False,
+ tokenizer_kwargs={},
+ **kwargs,
+):
+ # dtype=torch.float32,
+ # device="cpu",
+ # sanity check
+ assert not (encoder_only and decoder_only)
+
+ # params
+ if encoder_only:
+ model_cls = T5Encoder
+ kwargs["vocab"] = kwargs.pop("vocab_size")
+ kwargs["num_layers"] = kwargs.pop("encoder_layers")
+ _ = kwargs.pop("decoder_layers")
+ elif decoder_only:
+ model_cls = T5Decoder
+ kwargs["vocab"] = kwargs.pop("vocab_size")
+ kwargs["num_layers"] = kwargs.pop("decoder_layers")
+ _ = kwargs.pop("encoder_layers")
+ else:
+ model_cls = T5Model
+
+ # # init model
+ # with torch.device(device):
+ model = model_cls(**kwargs)
+
+ # # set device
+ # model = model.to(dtype=dtype, device=device)
+
+ # init tokenizer
+ if return_tokenizer:
+ from .tokenizers import HuggingfaceTokenizer
+
+ tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
+ return model, tokenizer
+ else:
+ return model
+
+
+def umt5_xxl(**kwargs):
+ cfg = dict(
+ vocab_size=256384,
+ dim=4096,
+ dim_attn=4096,
+ dim_ffn=10240,
+ num_heads=64,
+ encoder_layers=24,
+ decoder_layers=24,
+ num_buckets=32,
+ shared_pos=False,
+ dropout=0.1,
+ )
+ cfg.update(**kwargs)
+ return _t5("umt5-xxl", **cfg)
+
+
+class T5EncoderModel:
+
+ def __init__(
+ self,
+ text_len,
+ dtype=torch.bfloat16,
+ device=torch.cuda.current_device(),
+ checkpoint_path=None,
+ tokenizer_path=None,
+ shard_fn=None,
+ weight_path=None,
+ fp8=False,
+ ):
+ self.text_len = text_len
+ self.dtype = dtype if not fp8 else torch.float8_e4m3fn
+ self.device = device
+ self.checkpoint_path = checkpoint_path
+ self.tokenizer_path = tokenizer_path
+
+ # init model
+ with init_empty_weights():
+ model = umt5_xxl(encoder_only=True, return_tokenizer=False)
+
+ model = model.eval().requires_grad_(False)
+ if checkpoint_path is not None:
+ logger.info(f"loading {checkpoint_path}")
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
+ else:
+ logger.info(f"loading weights from {weight_path}")
+ if os.path.splitext(weight_path)[1] == ".safetensors":
+ sd = load_file(weight_path)
+ else:
+ sd = torch.load(weight_path, map_location="cpu", weights_only=True)
+ # remove prefix "encoder." from the state dict
+ sd = {k.replace("encoder.", ""): v for k, v in sd.items()}
+ model.load_state_dict(sd, strict=True, assign=True)
+
+ logger.info(f"moving model to {device} and casting to {self.dtype}")
+ model = model.to(device, dtype=self.dtype)
+
+ if fp8:
+ logger.info("preparing model for fp8")
+ model.prepare_fp8(dtype)
+
+ self.model = model
+ # if shard_fn is not None:
+ # self.model = shard_fn(self.model, sync_module_states=False)
+ # else:
+ # self.model.to(self.device)
+ # init tokenizer
+ if tokenizer_path is None:
+ tokenizer_path = "Wan-AI/Wan2.1-T2V-14B"
+ subfolder = "google/umt5-xxl"
+ else:
+ subfolder = None
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace", subfolder=subfolder)
+
+ def __call__(self, texts, device):
+ ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
+ ids = ids.to(device)
+ mask = mask.to(device)
+ seq_lens = mask.gt(0).sum(dim=1).long()
+ context = self.model(ids, mask)
+ return [u[:v] for u, v in zip(context, seq_lens)]
diff --git a/wan/modules/tokenizers.py b/wan/modules/tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..121e591c48f82f82daa51a6ce38ae9a27beea8d2
--- /dev/null
+++ b/wan/modules/tokenizers.py
@@ -0,0 +1,82 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import html
+import string
+
+import ftfy
+import regex as re
+from transformers import AutoTokenizer
+
+__all__ = ['HuggingfaceTokenizer']
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+def canonicalize(text, keep_punctuation_exact_string=None):
+ text = text.replace('_', ' ')
+ if keep_punctuation_exact_string:
+ text = keep_punctuation_exact_string.join(
+ part.translate(str.maketrans('', '', string.punctuation))
+ for part in text.split(keep_punctuation_exact_string))
+ else:
+ text = text.translate(str.maketrans('', '', string.punctuation))
+ text = text.lower()
+ text = re.sub(r'\s+', ' ', text)
+ return text.strip()
+
+
+class HuggingfaceTokenizer:
+
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
+ self.name = name
+ self.seq_len = seq_len
+ self.clean = clean
+
+ # init tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
+ self.vocab_size = self.tokenizer.vocab_size
+
+ def __call__(self, sequence, **kwargs):
+ return_mask = kwargs.pop('return_mask', False)
+
+ # arguments
+ _kwargs = {'return_tensors': 'pt'}
+ if self.seq_len is not None:
+ _kwargs.update({
+ 'padding': 'max_length',
+ 'truncation': True,
+ 'max_length': self.seq_len
+ })
+ _kwargs.update(**kwargs)
+
+ # tokenization
+ if isinstance(sequence, str):
+ sequence = [sequence]
+ if self.clean:
+ sequence = [self._clean(u) for u in sequence]
+ ids = self.tokenizer(sequence, **_kwargs)
+
+ # output
+ if return_mask:
+ return ids.input_ids, ids.attention_mask
+ else:
+ return ids.input_ids
+
+ def _clean(self, text):
+ if self.clean == 'whitespace':
+ text = whitespace_clean(basic_clean(text))
+ elif self.clean == 'lower':
+ text = whitespace_clean(basic_clean(text)).lower()
+ elif self.clean == 'canonicalize':
+ text = canonicalize(basic_clean(text))
+ return text
diff --git a/wan/modules/vae.py b/wan/modules/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..02f5254a3d4bc0352098c0f0dfd533ed93354163
--- /dev/null
+++ b/wan/modules/vae.py
@@ -0,0 +1,760 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging
+import os
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from safetensors.torch import load_file
+
+__all__ = [
+ "WanVAE",
+]
+
+CACHE_T = 2
+
+
+class CausalConv3d(nn.Conv3d):
+ """
+ Causal 3d convolusion.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+
+ return super().forward(x)
+
+
+class RMS_norm(nn.Module):
+
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class Upsample(nn.Upsample):
+
+ def forward(self, x):
+ """
+ Fix bfloat16 support for nearest neighbor interpolation.
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class Resample(nn.Module):
+
+ def __init__(self, dim, mode):
+ assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == "upsample2d":
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ )
+ elif mode == "upsample3d":
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ )
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == "downsample2d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == "downsample3d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ self.cache_device = None
+
+ def set_cache_device(self, device):
+ self.cache_device = device
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ cache_device = self.cache_device if self.cache_device is not None else x.device
+
+ b, c, t, h, w = x.size()
+ if self.mode == "upsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = "Rep"
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device)
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+ if feat_cache[idx] == "Rep":
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None)
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.resample(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+
+ if self.mode == "downsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone().to(cache_device)
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -1:, :, :].clone().to(cache_device)
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
+ # # cache last frame of last two chunk
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :].to(x.device), x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+ def init_weight(self, conv):
+ conv_weight = conv.weight
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ one_matrix = torch.eye(c1, c2)
+ init_matrix = one_matrix
+ nn.init.zeros_(conv_weight)
+ # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def init_weight2(self, conv):
+ conv_weight = conv.weight.data
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ init_matrix = torch.eye(c1 // 2, c2)
+ # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
+ conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
+ conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+
+class ResidualBlock(nn.Module):
+
+ def __init__(self, in_dim, out_dim, dropout=0.0):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # layers
+ self.residual = nn.Sequential(
+ RMS_norm(in_dim, images=False),
+ nn.SiLU(),
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
+ RMS_norm(out_dim, images=False),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
+ )
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+ self.cache_device = None
+
+ def set_cache_device(self, device):
+ self.cache_device = device
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ cache_device = self.cache_device if self.cache_device is not None else x.device
+
+ h = self.shortcut(x)
+ for layer in self.residual:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device)
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = layer(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None)
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ Causal self-attention with a single head.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = RMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ # zero out the last layer params
+ nn.init.zeros_(self.proj.weight)
+
+ def forward(self, x):
+ identity = x
+ b, c, t, h, w = x.size()
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.norm(x)
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ )
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
+
+ # output
+ x = self.proj(x)
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
+ return x + identity
+
+
+class Encoder3d(nn.Module):
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ downsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ downsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+ downsamples.append(Resample(out_dim, mode=mode))
+ scale /= 2.0
+ self.downsamples = nn.Sequential(*downsamples)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
+ )
+
+ # output blocks
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1))
+
+ self.cache_device = None
+
+ def set_cache_device(self, device):
+ self.cache_device = device
+
+ # set cache device for all layers
+ for layer in self.downsamples + self.middle + self.head:
+ if isinstance(layer, Resample) or isinstance(layer, ResidualBlock):
+ layer.set_cache_device(device)
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ cache_device = self.cache_device if self.cache_device is not None else x.device
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device)
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv1(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None)
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## downsamples
+ for layer in self.downsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device)
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = layer(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None)
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+class Decoder3d(nn.Module):
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+ # init block
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
+ )
+
+ # upsample blocks
+ upsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i == 1 or i == 2 or i == 3:
+ in_dim = in_dim // 2
+ for _ in range(num_res_blocks + 1):
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ upsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # upsample block
+ if i != len(dim_mult) - 1:
+ mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+ upsamples.append(Resample(out_dim, mode=mode))
+ scale *= 2.0
+ self.upsamples = nn.Sequential(*upsamples)
+
+ # output blocks
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
+
+ self.cache_device = None
+
+ def set_cache_device(self, device):
+ self.cache_device = device
+
+ # set cache device for all layers
+ for layer in self.middle + self.upsamples + self.head:
+ if isinstance(layer, Resample) or isinstance(layer, ResidualBlock):
+ layer.set_cache_device(device)
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ cache_device = self.cache_device if self.cache_device is not None else x.device
+
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device)
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv1(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None)
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## upsamples
+ for layer in self.upsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone().to(cache_device)
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = layer(x, feat_cache[idx].to(x.device) if feat_cache[idx] is not None else None)
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+def count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, CausalConv3d):
+ count += 1
+ return count
+
+
+class WanVAE_(nn.Module):
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ # modules
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
+
+ self.cache_device = None
+
+ @property
+ def dtype(self):
+ return self.conv1.weight.dtype
+
+ @property
+ def device(self):
+ return self.conv1.weight.device
+
+ def set_cache_device(self, device):
+ # set cache device
+ self.cache_device = device
+ self.encoder.set_cache_device(device)
+ self.decoder.set_cache_device(device)
+
+ def forward(self, x):
+ mu, log_var = self.encode(x)
+ z = self.reparameterize(mu, log_var)
+ x_recon = self.decode(z)
+ return x_recon, mu, log_var
+
+ def encode(self, x, scale):
+ self.clear_cache()
+ ## cache
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+ # ## 对encode输入的x,按时间拆分为1、4、4、4....
+
+ # if self.cache_device is None:
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
+ )
+ out = torch.cat([out, out_], 2)
+ # else:
+ # # VRAM optimization
+ # device = x.device
+ # clean_memory_on_device(device)
+ # outs = []
+ # for i in range(iter_):
+ # self._enc_conv_idx = [0]
+ # if i == 0:
+ # out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ # else:
+ # out = self.encoder(
+ # x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
+ # )
+ # outs.append(out.to(self.cache_device))
+ # out = torch.cat(outs, 2).to(device)
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
+ if isinstance(scale[0], torch.Tensor):
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
+ else:
+ mu = (mu - scale[0]) * scale[1]
+ self.clear_cache()
+ return mu
+
+ def decode(self, z, scale):
+ self.clear_cache()
+ # z: [b,c,t,h,w]
+ if isinstance(scale[0], torch.Tensor):
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
+ else:
+ z = z / scale[1] + scale[0]
+ iter_ = z.shape[2]
+ x = self.conv2(z)
+
+ # if self.cache_device is None:
+ for i in range(iter_):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+ # else:
+ # # VRAM optimization
+ # device = z.device
+ # x = x.to("cpu")
+ # clean_memory_on_device(device)
+ # outs = []
+ # for i in range(iter_):
+ # self._conv_idx = [0]
+ # out = self.decoder(x[:, :, i : i + 1, :, :].to(device), feat_cache=self._feat_map, feat_idx=self._conv_idx).to(
+ # self.cache_device
+ # )
+ # outs.append(out)
+ # out = torch.cat(outs, 2) # on cache_device
+ self.clear_cache()
+ return out
+
+ def reparameterize(self, mu, log_var):
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+ def sample(self, imgs, deterministic=False):
+ mu, log_var = self.encode(imgs)
+ if deterministic:
+ return mu
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
+ return mu + std * torch.randn_like(std)
+
+ def clear_cache(self):
+ self._conv_num = count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ # cache encode
+ self._enc_conv_num = count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
+ """
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
+ """
+ # params
+ cfg = dict(
+ dim=96,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[False, True, True],
+ dropout=0.0,
+ )
+ cfg.update(**kwargs)
+
+ # init model
+ with torch.device("meta"):
+ model = WanVAE_(**cfg)
+
+ # load checkpoint
+ logging.info(f"loading {pretrained_path}")
+ if os.path.splitext(pretrained_path)[-1] == ".safetensors":
+ sd = load_file(pretrained_path)
+ model.load_state_dict(sd, strict=False, assign=True)
+ else:
+ model.load_state_dict(torch.load(pretrained_path, map_location=device, weights_only=True), assign=True)
+
+ return model
+
+
+class WanVAE:
+
+ def __init__(self, z_dim=16, vae_path="cache/vae_step_411000.pth", dtype=torch.float, device="cuda", cache_device=None):
+ self.dtype = dtype
+ self.device = device
+
+ mean = [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921,
+ ]
+ std = [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.9160,
+ ]
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
+ self.std = torch.tensor(std, dtype=dtype, device=device)
+ self.scale = [self.mean, 1.0 / self.std]
+
+ # init model
+ self.model = (
+ _video_vae(
+ pretrained_path=vae_path,
+ z_dim=z_dim,
+ )
+ .eval()
+ .requires_grad_(False)
+ .to(device, dtype=dtype)
+ )
+ if cache_device is not None:
+ self.model.set_cache_device(torch.device(cache_device))
+
+ def to_device(self, device):
+ self.device = device
+ self.model.to(device)
+ self.mean = self.mean.to(device)
+ self.std = self.std.to(device)
+ self.scale = [t.to(device) for t in self.scale]
+
+ def to_dtype(self, dtype):
+ self.dtype = dtype
+ self.model.to(dtype=dtype)
+ self.mean = self.mean.to(dtype)
+ self.std = self.std.to(dtype)
+ self.scale = [t.to(dtype) for t in self.scale]
+
+ def eval(self):
+ self.model.eval()
+
+ def train(self, mode: bool = True):
+ self.model.train(mode)
+
+ def requires_grad_(self, requires_grad: bool = True):
+ self.model.requires_grad_(requires_grad)
+
+ def to(self, device_or_dtype: Union[torch.device, torch.dtype, str], dtype: Optional[torch.dtype] = None):
+ """
+ Add nn.Module.to() support for device and dtype.
+ """
+ if isinstance(device_or_dtype, str) or isinstance(device_or_dtype, torch.device):
+ self.to_device(device_or_dtype)
+ else:
+ self.to_dtype(device_or_dtype)
+
+ if dtype is not None:
+ self.to_dtype(dtype)
+
+ def encode(self, videos):
+ """
+ videos: A list of videos each with shape [C, T, H, W].
+ """
+ # with amp.autocast(dtype=self.dtype):
+ return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+
+ def decode(self, zs):
+ # with amp.autocast(dtype=self.dtype):
+ return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs]
diff --git a/wan/modules/xlm_roberta.py b/wan/modules/xlm_roberta.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bd38c1016fdaec90b77a6222d75d01c38c1291c
--- /dev/null
+++ b/wan/modules/xlm_roberta.py
@@ -0,0 +1,170 @@
+# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['XLMRoberta', 'xlm_roberta_large']
+
+
+class SelfAttention(nn.Module):
+
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, mask):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+
+ # compute attention
+ p = self.dropout.p if self.training else 0.0
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
+
+ # output
+ x = self.o(x)
+ x = self.dropout(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.post_norm = post_norm
+ self.eps = eps
+
+ # layers
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
+ nn.Dropout(dropout))
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
+
+ def forward(self, x, mask):
+ if self.post_norm:
+ x = self.norm1(x + self.attn(x, mask))
+ x = self.norm2(x + self.ffn(x))
+ else:
+ x = x + self.attn(self.norm1(x), mask)
+ x = x + self.ffn(self.norm2(x))
+ return x
+
+
+class XLMRoberta(nn.Module):
+ """
+ XLMRobertaModel with no pooler and no LM head.
+ """
+
+ def __init__(self,
+ vocab_size=250002,
+ max_seq_len=514,
+ type_size=1,
+ pad_id=1,
+ dim=1024,
+ num_heads=16,
+ num_layers=24,
+ post_norm=True,
+ dropout=0.1,
+ eps=1e-5):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.max_seq_len = max_seq_len
+ self.type_size = type_size
+ self.pad_id = pad_id
+ self.dim = dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.post_norm = post_norm
+ self.eps = eps
+
+ # embeddings
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
+ self.type_embedding = nn.Embedding(type_size, dim)
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
+ self.dropout = nn.Dropout(dropout)
+
+ # blocks
+ self.blocks = nn.ModuleList([
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
+ for _ in range(num_layers)
+ ])
+
+ # norm layer
+ self.norm = nn.LayerNorm(dim, eps=eps)
+
+ def forward(self, ids):
+ """
+ ids: [B, L] of torch.LongTensor.
+ """
+ b, s = ids.shape
+ mask = ids.ne(self.pad_id).long()
+
+ # embeddings
+ x = self.token_embedding(ids) + \
+ self.type_embedding(torch.zeros_like(ids)) + \
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
+ if self.post_norm:
+ x = self.norm(x)
+ x = self.dropout(x)
+
+ # blocks
+ mask = torch.where(
+ mask.view(b, 1, 1, s).gt(0), 0.0,
+ torch.finfo(x.dtype).min)
+ for block in self.blocks:
+ x = block(x, mask)
+
+ # output
+ if not self.post_norm:
+ x = self.norm(x)
+ return x
+
+
+def xlm_roberta_large(pretrained=False,
+ return_tokenizer=False,
+ device='cpu',
+ **kwargs):
+ """
+ XLMRobertaLarge adapted from Huggingface.
+ """
+ # params
+ cfg = dict(
+ vocab_size=250002,
+ max_seq_len=514,
+ type_size=1,
+ pad_id=1,
+ dim=1024,
+ num_heads=16,
+ num_layers=24,
+ post_norm=True,
+ dropout=0.1,
+ eps=1e-5)
+ cfg.update(**kwargs)
+
+ # init a model on device
+ with torch.device(device):
+ model = XLMRoberta(**cfg)
+ return model
diff --git a/wan/text2video.py b/wan/text2video.py
new file mode 100644
index 0000000000000000000000000000000000000000..67b3dd26b32e721e96279dae44b0f94a9a912a03
--- /dev/null
+++ b/wan/text2video.py
@@ -0,0 +1,331 @@
+# Modified from official implementation
+
+# Original source:
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import logging
+import math
+import os
+import random
+import sys
+from typing import Optional, Union
+
+import torch
+from tqdm import tqdm
+from accelerate import Accelerator, init_empty_weights
+from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
+from utils.safetensors_utils import load_safetensors
+
+# from .distributed.fsdp import shard_model
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+from utils.device_utils import clean_memory_on_device, synchronize_device
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+class WanT2V:
+
+ def __init__(
+ self,
+ config,
+ checkpoint_dir,
+ device_id=0,
+ rank=0,
+ t5_fsdp=False,
+ dit_fsdp=False,
+ use_usp=False,
+ t5_cpu=False,
+ device=None,
+ dtype=None,
+ dit_path=None,
+ dit_attn_mode=None,
+ t5_path=None,
+ t5_fp8=False,
+ ):
+ r"""
+ Initializes the Wan text-to-video generation model components.
+
+ Args:
+ config (EasyDict):
+ Object containing model parameters initialized from config.py
+ checkpoint_dir (`str`):
+ Path to directory containing model checkpoints
+ device_id (`int`, *optional*, defaults to 0) **IGNORED**:
+ Id of target GPU device
+ rank (`int`, *optional*, defaults to 0) **IGNORED**:
+ Process rank for distributed training
+ t5_fsdp (`bool`, *optional*, defaults to False) **IGNORED**:
+ Enable FSDP sharding for T5 model
+ dit_fsdp (`bool`, *optional*, defaults to False) **IGNORED**:
+ Enable FSDP sharding for DiT model
+ use_usp (`bool`, *optional*, defaults to False) **IGNORED**:
+ Enable distribution strategy of USP.
+ t5_cpu (`bool`, *optional*, defaults to False) **IGNORED**:
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
+ device (`torch.device`, *optional*, defaults to None):
+ Device to place the model on. If None, use the default device (cuda)
+ dtype (`torch.dtype`, *optional*, defaults to None):
+ Data type for DiT model parameters. If None, use the default parameter data type from config
+ dit_path (`str`, *optional*, defaults to None):
+ Path to DiT model checkpoint. checkpoint_dir is used if None.
+ dit_attn_mode (`str`, *optional*, defaults to None):
+ Attention mode for DiT model. If None, use "torch" attention mode.
+ t5_path (`str`, *optional*, defaults to None):
+ Path to T5 model checkpoint. checkpoint_dir is used if None.
+ t5_fp8 (`bool`, *optional*, defaults to False):
+ Enable FP8 quantization for T5 model
+ """
+ self.device = device if device is not None else torch.device("cuda")
+ self.config = config
+ self.rank = rank
+ self.t5_cpu = t5_cpu
+ self.t5_fp8 = t5_fp8
+
+ self.num_train_timesteps = config.num_train_timesteps
+ self.param_dtype = config.param_dtype
+
+ # shard_fn = partial(shard_model, device_id=device_id)
+ checkpoint_path = None if checkpoint_dir is None else os.path.join(checkpoint_dir, config.t5_checkpoint)
+ tokenizer_path = None if checkpoint_dir is None else os.path.join(checkpoint_dir, config.t5_tokenizer)
+ self.text_encoder = T5EncoderModel(
+ text_len=config.text_len,
+ dtype=config.t5_dtype,
+ device=device,
+ checkpoint_path=checkpoint_path,
+ tokenizer_path=tokenizer_path,
+ weight_path=t5_path,
+ fp8=t5_fp8,
+ # shard_fn=shard_fn if t5_fsdp else None,
+ )
+
+ self.vae_stride = config.vae_stride
+ self.patch_size = config.patch_size
+
+ self.checkpoint_dir = checkpoint_dir
+ self.dit_path = dit_path
+ self.dit_dtype = dtype if dtype is not None else config.param_dtype
+ self.dit_attn_mode = dit_attn_mode
+
+ self.sample_neg_prompt = config.sample_neg_prompt
+
+ def generate(
+ self,
+ accelerator: Accelerator,
+ merge_lora: Optional[callable],
+ dit_loading_dtype: Optional[torch.dtype],
+ input_prompt,
+ size=(1280, 720),
+ frame_num=81,
+ shift=5.0,
+ sample_solver="unipc",
+ sampling_steps=50,
+ guide_scale=5.0,
+ n_prompt="",
+ seed=-1,
+ blocks_to_swap=0,
+ ):
+ r"""
+ Generates video frames from text prompt using diffusion process.
+
+ Args:
+ input_prompt (`str`):
+ Text prompt for content generation
+ size (tupele[`int`], *optional*, defaults to (1280,720)):
+ Controls video resolution, (width,height).
+ frame_num (`int`, *optional*, defaults to 81):
+ How many frames to sample from a video. The number should be 4n+1
+ shift (`float`, *optional*, defaults to 5.0):
+ Noise schedule shift parameter. Affects temporal dynamics
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
+ Solver used to sample the video.
+ sampling_steps (`int`, *optional*, defaults to 40):
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
+ guide_scale (`float`, *optional*, defaults 5.0):
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
+ n_prompt (`str`, *optional*, defaults to ""):
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+ seed (`int`, *optional*, defaults to -1):
+ Random seed for noise generation. If -1, use random seed.
+ blocks_to_swap (`int`, *optional*, defaults to 0):
+ Number of blocks to swap (offload) to CPU. If 0, no blocks are offloaded.
+
+ Returns:
+ torch.Tensor:
+ Generated video frames tensor. Dimensions: (C, N H, W) where:
+ - C: Color channels (3 for RGB)
+ - N: Number of frames (81)
+ - H: Frame height (from size)
+ - W: Frame width from size)
+ """
+ # preprocess
+ F = frame_num
+ # self.vae.model.z_dim == 16
+ target_shape = (16, (F - 1) // self.vae_stride[0] + 1, size[1] // self.vae_stride[1], size[0] // self.vae_stride[2])
+
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.patch_size[1] * self.patch_size[2]) * target_shape[1])
+
+ if n_prompt == "":
+ n_prompt = self.sample_neg_prompt
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+ seed_g = torch.Generator(device=self.device)
+ seed_g.manual_seed(seed)
+
+ self.text_encoder.model.to(self.device)
+ with torch.no_grad():
+ if self.t5_fp8:
+ with accelerator.autocast():
+ context = self.text_encoder([input_prompt], self.device)
+ context_null = self.text_encoder([n_prompt], self.device)
+ else:
+ context = self.text_encoder([input_prompt], self.device)
+ context_null = self.text_encoder([n_prompt], self.device)
+
+ del self.text_encoder
+ clean_memory_on_device(self.device)
+
+ # load DiT model
+ dit_loading_dtype = dit_loading_dtype if dit_loading_dtype is not None else self.dit_dtype
+ with init_empty_weights():
+ # if self.checkpoint_dir is not None:
+ # logger.info(f"Creating WanModel from {self.checkpoint_dir}")
+ # self.model = WanModel.from_pretrained(self.checkpoint_dir)
+ # self.model = WanModel.from_config(config)
+ # else:
+ logger.info(f"Creating WanModel")
+ self.model = WanModel(
+ dim=self.config.dim,
+ eps=self.config.eps,
+ ffn_dim=self.config.ffn_dim,
+ freq_dim=self.config.freq_dim,
+ in_dim=16,
+ num_heads=self.config.num_heads,
+ num_layers=self.config.num_layers,
+ out_dim=16,
+ text_len=512,
+ attn_mode=self.dit_attn_mode,
+ )
+ self.model.to(dit_loading_dtype)
+
+ # if LoRA is enabled, load the model on CPU with bfloat16
+ loading_device = self.device if (blocks_to_swap == 0 and merge_lora is None) else "cpu"
+ logger.info(f"Loading DiT model from {self.dit_path}, device={loading_device}, dtype={dit_loading_dtype}")
+ sd = load_safetensors(self.dit_path, loading_device, disable_mmap=True, dtype=dit_loading_dtype)
+
+ # remove "model.diffusion_model." prefix: 1.3B model has this prefix
+ for key in list(sd.keys()):
+ if key.startswith("model.diffusion_model."):
+ sd[key[22:]] = sd.pop(key)
+
+ info = self.model.load_state_dict(sd, strict=True, assign=True)
+ logger.info(f"Loaded DiT model from {self.dit_path}, info={info}")
+
+ if merge_lora is not None:
+ # merge LoRA to the model, cast and move to the device
+ merge_lora(self.model)
+ if blocks_to_swap == 0:
+ self.model.to(self.device, self.dit_dtype)
+ else:
+ self.model.to(self.dit_dtype)
+
+ if blocks_to_swap > 0:
+ logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {self.device}")
+ self.model.enable_block_swap(blocks_to_swap, self.device, supports_backward=False)
+ self.model.move_to_device_except_swap_blocks(self.device)
+ self.model.prepare_block_swap_before_forward()
+ else:
+ # make sure the model is on the right device
+ self.model.to(self.device)
+
+ self.model.eval().requires_grad_(False)
+ clean_memory_on_device(self.device)
+
+ noise = [
+ torch.randn(
+ target_shape[0],
+ target_shape[1],
+ target_shape[2],
+ target_shape[3],
+ dtype=torch.float32,
+ device=self.device,
+ generator=seed_g,
+ )
+ ]
+
+ # evaluation mode
+ # with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
+ with accelerator.autocast(), torch.no_grad():
+ if sample_solver == "unipc":
+ sample_scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False
+ )
+ sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
+ timesteps = sample_scheduler.timesteps
+ elif sample_solver == "dpm++":
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
+ num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False
+ )
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+ timesteps, _ = retrieve_timesteps(sample_scheduler, device=self.device, sigmas=sampling_sigmas)
+ elif sample_solver == "vanilla":
+ sample_scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=self.num_train_timesteps, shift=shift)
+ sample_scheduler.set_timesteps(sampling_steps, device=self.device)
+ timesteps = sample_scheduler.timesteps
+
+ org_step = sample_scheduler.step
+
+ def step_wrapper(
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ return_dict: bool = True,
+ generator=None,
+ ):
+ return org_step(model_output, timestep, sample, return_dict=return_dict)
+
+ sample_scheduler.step = step_wrapper
+ else:
+ raise NotImplementedError("Unsupported solver.")
+
+ # sample videos
+ latents = noise
+ del noise
+
+ arg_c = {"context": context, "seq_len": seq_len}
+ arg_null = {"context": context_null, "seq_len": seq_len}
+
+ for _, t in enumerate(tqdm(timesteps)):
+ latent_model_input = latents
+ timestep = [t]
+
+ timestep = torch.stack(timestep)
+
+ noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0]
+ noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0]
+
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
+ del noise_pred_cond, noise_pred_uncond
+
+ temp_x0 = sample_scheduler.step(
+ noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=seed_g
+ )[0]
+ del noise_pred
+ latents = [temp_x0.squeeze(0)]
+ del temp_x0
+
+ x0 = latents
+
+ del latents
+ del sample_scheduler
+ del self.model
+ synchronize_device(self.device)
+ clean_memory_on_device(self.device)
+
+ # return latents
+ return x0[0]
diff --git a/wan/utils/__init__.py b/wan/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e9a339e69fd55dd226d3ce242613c19bd690522
--- /dev/null
+++ b/wan/utils/__init__.py
@@ -0,0 +1,8 @@
+from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
+ retrieve_timesteps)
+from .fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+__all__ = [
+ 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
+ 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
+]
diff --git a/wan/utils/fm_solvers.py b/wan/utils/fm_solvers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c908969e24849ce1381a8df9d5eb401dccf66524
--- /dev/null
+++ b/wan/utils/fm_solvers.py
@@ -0,0 +1,857 @@
+# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+# Convert dpm solver for flow matching
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import inspect
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
+ SchedulerMixin,
+ SchedulerOutput)
+from diffusers.utils import deprecate, is_scipy_available
+from diffusers.utils.torch_utils import randn_tensor
+
+if is_scipy_available():
+ pass
+
+
+def get_sampling_sigmas(sampling_steps, shift):
+ sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
+ sigma = (shift * sigma / (1 + (shift - 1) * sigma))
+
+ return sigma
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps=None,
+ device=None,
+ timesteps=None,
+ sigmas=None,
+ **kwargs,
+):
+ if timesteps is not None and sigmas is not None:
+ raise ValueError(
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+ )
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model. This determines the resolution of the diffusion process.
+ solver_order (`int`, defaults to 2):
+ The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
+ sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored
+ and used in multistep updates.
+ prediction_type (`str`, defaults to "flow_prediction"):
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
+ the flow of the diffusion process.
+ shift (`float`, *optional*, defaults to 1.0):
+ A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling
+ process.
+ use_dynamic_shifting (`bool`, defaults to `False`):
+ Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is
+ applied on the fly.
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent
+ saturation and improve photorealism.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++"`.
+ algorithm_type (`str`, defaults to `dpmsolver++`):
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
+ paper, and the `dpmsolver++` type implements the algorithms in the
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+ solver_type (`str`, defaults to `midpoint`):
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+ lower_order_final (`bool`, defaults to `True`):
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+ euler_at_final (`bool`, defaults to `False`):
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
+ steps, but sometimes may result in blurring.
+ final_sigmas_type (`str`, *optional*, defaults to "zero"):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ lambda_min_clipped (`float`, defaults to `-inf`):
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
+ cosine (`squaredcos_cap_v2`) noise schedule.
+ variance_type (`str`, *optional*):
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
+ contains the predicted Gaussian variance.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ solver_order: int = 2,
+ prediction_type: str = "flow_prediction",
+ shift: Optional[float] = 1.0,
+ use_dynamic_shifting=False,
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ euler_at_final: bool = False,
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ lambda_min_clipped: float = -float("inf"),
+ variance_type: Optional[str] = None,
+ invert_sigmas: bool = False,
+ ):
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0",
+ deprecation_message)
+
+ # settings for DPM-Solver
+ if algorithm_type not in [
+ "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"
+ ]:
+ if algorithm_type == "deis":
+ self.register_to_config(algorithm_type="dpmsolver++")
+ else:
+ raise NotImplementedError(
+ f"{algorithm_type} is not implemented for {self.__class__}")
+
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+ raise NotImplementedError(
+ f"{solver_type} is not implemented for {self.__class__}")
+
+ if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"
+ ] and final_sigmas_type == "zero":
+ raise ValueError(
+ f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+ )
+
+ # setable values
+ self.num_inference_steps = None
+ alphas = np.linspace(1, 1 / num_train_timesteps,
+ num_train_timesteps)[::-1].copy()
+ sigmas = 1.0 - alphas
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
+
+ if not use_dynamic_shifting:
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+ sigmas = shift * sigmas / (1 +
+ (shift - 1) * sigmas) # pyright: ignore
+
+ self.sigmas = sigmas
+ self.timesteps = sigmas * num_train_timesteps
+
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+ self._step_index = None
+ self._begin_index = None
+
+ # self.sigmas = self.sigmas.to(
+ # "cpu") # to avoid too much CPU/GPU communication
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
+ def set_timesteps(
+ self,
+ num_inference_steps: Union[int, None] = None,
+ device: Union[str, torch.device] = None,
+ sigmas: Optional[List[float]] = None,
+ mu: Optional[Union[float, None]] = None,
+ shift: Optional[Union[float, None]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Args:
+ num_inference_steps (`int`):
+ Total number of the spacing of the time steps.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+
+ if self.config.use_dynamic_shifting and mu is None:
+ raise ValueError(
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
+ )
+
+ if sigmas is None:
+ sigmas = np.linspace(self.sigma_max, self.sigma_min,
+ num_inference_steps +
+ 1).copy()[:-1] # pyright: ignore
+
+ if self.config.use_dynamic_shifting:
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
+ else:
+ if shift is None:
+ shift = self.config.shift
+ sigmas = shift * sigmas / (1 +
+ (shift - 1) * sigmas) # pyright: ignore
+
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = ((1 - self.alphas_cumprod[0]) /
+ self.alphas_cumprod[0])**0.5
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ timesteps = sigmas * self.config.num_train_timesteps
+ sigmas = np.concatenate([sigmas, [sigma_last]
+ ]).astype(np.float32) # pyright: ignore
+
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(
+ device=device, dtype=torch.int64)
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ self._step_index = None
+ self._begin_index = None
+ # self.sigmas = self.sigmas.to(
+ # "cpu") # to avoid too much CPU/GPU communication
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float(
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(
+ 1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(
+ sample, -s, s
+ ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ return 1 - sigma, sigma
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
+ integral of the data prediction model.
+
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
+ prediction and data prediction models.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ Returns:
+ `torch.Tensor`:
+ The converted model output.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError(
+ "missing `sample` as a required keyward argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ epsilon = sample - (1 - sigma_t) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
+ x0_pred = self._threshold_sample(x0_pred)
+ epsilon = model_output + x0_pred
+
+ return epsilon
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the first-order DPMSolver (equivalent to DDIM).
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
+ "prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(
+ " missing `sample` as a required keyward argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[
+ self.step_index] # pyright: ignore
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
+
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t /
+ sigma_s) * sample - (alpha_t *
+ (torch.exp(-h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t /
+ alpha_s) * sample - (sigma_t *
+ (torch.exp(h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample +
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output +
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
+ elif self.config.algorithm_type == "sde-dpmsolver":
+ assert noise is not None
+ x_t = ((alpha_t / alpha_s) * sample - 2.0 *
+ (sigma_t * (torch.exp(h) - 1.0)) * model_output +
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
+ return x_t # pyright: ignore
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.Tensor],
+ *args,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the second-order multistep DPMSolver.
+ Args:
+ model_output_list (`List[torch.Tensor]`):
+ The direct outputs from learned diffusion model at current and latter timesteps.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop(
+ "timestep_list", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
+ "prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(
+ " missing `sample` as a required keyward argument")
+ if timestep_list is not None:
+ deprecate(
+ "timestep_list",
+ "1.0.0",
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s0, sigma_s1 = (
+ self.sigmas[self.step_index + 1], # pyright: ignore
+ self.sigmas[self.step_index],
+ self.sigmas[self.step_index - 1], # pyright: ignore
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = ((sigma_t / sigma_s0) * sample -
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 *
+ (alpha_t * (torch.exp(-h) - 1.0)) * D1)
+ elif self.config.solver_type == "heun":
+ x_t = ((sigma_t / sigma_s0) * sample -
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = ((alpha_t / alpha_s0) * sample -
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 *
+ (sigma_t * (torch.exp(h) - 1.0)) * D1)
+ elif self.config.solver_type == "heun":
+ x_t = ((alpha_t / alpha_s0) * sample -
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 -
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 *
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 +
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
+ elif self.config.solver_type == "heun":
+ x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 +
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) /
+ (-2.0 * h) + 1.0)) * D1 +
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
+ elif self.config.algorithm_type == "sde-dpmsolver":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 -
+ (sigma_t * (torch.exp(h) - 1.0)) * D1 +
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
+ elif self.config.solver_type == "heun":
+ x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 *
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
+ return x_t # pyright: ignore
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
+ def multistep_dpm_solver_third_order_update(
+ self,
+ model_output_list: List[torch.Tensor],
+ *args,
+ sample: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the third-order multistep DPMSolver.
+ Args:
+ model_output_list (`List[torch.Tensor]`):
+ The direct outputs from learned diffusion model at current and latter timesteps.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by diffusion process.
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop(
+ "timestep_list", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
+ "prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(
+ " missing`sample` as a required keyward argument")
+ if timestep_list is not None:
+ deprecate(
+ "timestep_list",
+ "1.0.0",
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
+ self.sigmas[self.step_index + 1], # pyright: ignore
+ self.sigmas[self.step_index],
+ self.sigmas[self.step_index - 1], # pyright: ignore
+ self.sigmas[self.step_index - 2], # pyright: ignore
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
+
+ m0, m1, m2 = model_output_list[-1], model_output_list[
+ -2], model_output_list[-3]
+
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
+ D0 = m0
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = ((sigma_t / sigma_s0) * sample -
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 -
+ (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2)
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = ((alpha_t / alpha_s0) * sample - (sigma_t *
+ (torch.exp(h) - 1.0)) * D0 -
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 -
+ (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2)
+ return x_t # pyright: ignore
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ """
+ Initialize the step_index counter for the scheduler.
+ """
+
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+ the multistep DPMSolver.
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`LEdits++`].
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Improve numerical stability for small number of steps
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
+ self.config.euler_at_final or
+ (self.config.lower_order_final and len(self.timesteps) < 15) or
+ self.config.final_sigmas_type == "zero")
+ lower_order_second = ((self.step_index == len(self.timesteps) - 2) and
+ self.config.lower_order_final and
+ len(self.timesteps) < 15)
+
+ model_output = self.convert_model_output(model_output, sample=sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"
+ ] and variance_noise is None:
+ noise = randn_tensor(
+ model_output.shape,
+ generator=generator,
+ device=model_output.device,
+ dtype=torch.float32)
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+ noise = variance_noise.to(
+ device=model_output.device,
+ dtype=torch.float32) # pyright: ignore
+ else:
+ noise = None
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.dpm_solver_first_order_update(
+ model_output, sample=sample, noise=noise)
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ prev_sample = self.multistep_dpm_solver_second_order_update(
+ self.model_outputs, sample=sample, noise=noise)
+ else:
+ prev_sample = self.multistep_dpm_solver_third_order_update(
+ self.model_outputs, sample=sample)
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ # Cast sample back to expected dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1 # pyright: ignore
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
+ def scale_model_input(self, sample: torch.Tensor, *args,
+ **kwargs) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(
+ device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(
+ timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(
+ original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(
+ original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [
+ self.index_for_timestep(t, schedule_timesteps)
+ for t in timesteps
+ ]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timesteps.shape[0]
+ else:
+ # add noise is called before first denoising step to create initial latent(img2img)
+ step_indices = [self.begin_index] * timesteps.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/wan/utils/fm_solvers_unipc.py b/wan/utils/fm_solvers_unipc.py
new file mode 100644
index 0000000000000000000000000000000000000000..57321baa35359782b33143321cd31c8d934a7b29
--- /dev/null
+++ b/wan/utils/fm_solvers_unipc.py
@@ -0,0 +1,800 @@
+# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
+# Convert unipc for flow matching
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
+ SchedulerMixin,
+ SchedulerOutput)
+from diffusers.utils import deprecate, is_scipy_available
+
+if is_scipy_available():
+ import scipy.stats
+
+
+class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ solver_order (`int`, default `2`):
+ The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
+ due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
+ unconditional sampling.
+ prediction_type (`str`, defaults to "flow_prediction"):
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
+ the flow of the diffusion process.
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
+ predict_x0 (`bool`, defaults to `True`):
+ Whether to use the updating algorithm on the predicted x0.
+ solver_type (`str`, default `bh2`):
+ Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
+ otherwise.
+ lower_order_final (`bool`, default `True`):
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+ disable_corrector (`list`, default `[]`):
+ Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
+ and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
+ usually disabled during the first few steps.
+ solver_p (`SchedulerMixin`, default `None`):
+ Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
+ the sigmas are determined according to a sequence of noise levels {σi}.
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ final_sigmas_type (`str`, defaults to `"zero"`):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ solver_order: int = 2,
+ prediction_type: str = "flow_prediction",
+ shift: Optional[float] = 1.0,
+ use_dynamic_shifting=False,
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ predict_x0: bool = True,
+ solver_type: str = "bh2",
+ lower_order_final: bool = True,
+ disable_corrector: List[int] = [],
+ solver_p: SchedulerMixin = None,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ ):
+
+ if solver_type not in ["bh1", "bh2"]:
+ if solver_type in ["midpoint", "heun", "logrho"]:
+ self.register_to_config(solver_type="bh2")
+ else:
+ raise NotImplementedError(
+ f"{solver_type} is not implemented for {self.__class__}")
+
+ self.predict_x0 = predict_x0
+ # setable values
+ self.num_inference_steps = None
+ alphas = np.linspace(1, 1 / num_train_timesteps,
+ num_train_timesteps)[::-1].copy()
+ sigmas = 1.0 - alphas
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
+
+ if not use_dynamic_shifting:
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+ sigmas = shift * sigmas / (1 +
+ (shift - 1) * sigmas) # pyright: ignore
+
+ self.sigmas = sigmas
+ self.timesteps = sigmas * num_train_timesteps
+
+ self.model_outputs = [None] * solver_order
+ self.timestep_list = [None] * solver_order
+ self.lower_order_nums = 0
+ self.disable_corrector = disable_corrector
+ self.solver_p = solver_p
+ self.last_sample = None
+ self._step_index = None
+ self._begin_index = None
+
+ self.sigmas = self.sigmas.to(
+ "cpu") # to avoid too much CPU/GPU communication
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
+ def set_timesteps(
+ self,
+ num_inference_steps: Union[int, None] = None,
+ device: Union[str, torch.device] = None,
+ sigmas: Optional[List[float]] = None,
+ mu: Optional[Union[float, None]] = None,
+ shift: Optional[Union[float, None]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Args:
+ num_inference_steps (`int`):
+ Total number of the spacing of the time steps.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+
+ if self.config.use_dynamic_shifting and mu is None:
+ raise ValueError(
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
+ )
+
+ if sigmas is None:
+ sigmas = np.linspace(self.sigma_max, self.sigma_min,
+ num_inference_steps +
+ 1).copy()[:-1] # pyright: ignore
+
+ if self.config.use_dynamic_shifting:
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
+ else:
+ if shift is None:
+ shift = self.config.shift
+ sigmas = shift * sigmas / (1 +
+ (shift - 1) * sigmas) # pyright: ignore
+
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = ((1 - self.alphas_cumprod[0]) /
+ self.alphas_cumprod[0])**0.5
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ timesteps = sigmas * self.config.num_train_timesteps
+ sigmas = np.concatenate([sigmas, [sigma_last]
+ ]).astype(np.float32) # pyright: ignore
+
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(
+ device=device, dtype=torch.int64)
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+ self.last_sample = None
+ if self.solver_p:
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
+
+ # add an index counter for schedulers that allow duplicated timesteps
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to(
+ "cpu") # to avoid too much CPU/GPU communication
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float(
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(
+ 1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(
+ sample, -s, s
+ ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ return 1 - sigma, sigma
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
+
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ r"""
+ Convert the model output to the corresponding type the UniPC algorithm needs.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The converted model output.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError(
+ "missing `sample` as a required keyward argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma = self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+
+ if self.predict_x0:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+ else:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ epsilon = sample - (1 - sigma_t) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
+ x0_pred = self._threshold_sample(x0_pred)
+ epsilon = model_output + x0_pred
+
+ return epsilon
+
+ def multistep_uni_p_bh_update(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ order: int = None, # pyright: ignore
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model at the current timestep.
+ prev_timestep (`int`):
+ The previous discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ order (`int`):
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
+ "prev_timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError(
+ " missing `sample` as a required keyward argument")
+ if order is None:
+ if len(args) > 2:
+ order = args[2]
+ else:
+ raise ValueError(
+ " missing `order` as a required keyward argument")
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+ model_output_list = self.model_outputs
+
+ s0 = self.timestep_list[-1]
+ m0 = model_output_list[-1]
+ x = sample
+
+ if self.solver_p:
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
+ return x_t
+
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
+ self.step_index] # pyright: ignore
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+
+ h = lambda_t - lambda_s0
+ device = sample.device
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ si = self.step_index - i # pyright: ignore
+ mi = model_output_list[-(i + 1)]
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ D1s.append((mi - m0) / rk) # pyright: ignore
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ # for order 2, we use a simplified version
+ if order == 2:
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_p = torch.linalg.solve(R[:-1, :-1],
+ b[:-1]).to(device).to(x.dtype)
+ else:
+ D1s = None
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
+ D1s) # pyright: ignore
+ else:
+ pred_res = 0
+ x_t = x_t_ - alpha_t * B_h * pred_res
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
+ D1s) # pyright: ignore
+ else:
+ pred_res = 0
+ x_t = x_t_ - sigma_t * B_h * pred_res
+
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def multistep_uni_c_bh_update(
+ self,
+ this_model_output: torch.Tensor,
+ *args,
+ last_sample: torch.Tensor = None,
+ this_sample: torch.Tensor = None,
+ order: int = None, # pyright: ignore
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the UniC (B(h) version).
+
+ Args:
+ this_model_output (`torch.Tensor`):
+ The model outputs at `x_t`.
+ this_timestep (`int`):
+ The current timestep `t`.
+ last_sample (`torch.Tensor`):
+ The generated sample before the last predictor `x_{t-1}`.
+ this_sample (`torch.Tensor`):
+ The generated sample after the last predictor `x_{t}`.
+ order (`int`):
+ The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
+
+ Returns:
+ `torch.Tensor`:
+ The corrected sample tensor at the current timestep.
+ """
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop(
+ "this_timestep", None)
+ if last_sample is None:
+ if len(args) > 1:
+ last_sample = args[1]
+ else:
+ raise ValueError(
+ " missing`last_sample` as a required keyward argument")
+ if this_sample is None:
+ if len(args) > 2:
+ this_sample = args[2]
+ else:
+ raise ValueError(
+ " missing`this_sample` as a required keyward argument")
+ if order is None:
+ if len(args) > 3:
+ order = args[3]
+ else:
+ raise ValueError(
+ " missing`order` as a required keyward argument")
+ if this_timestep is not None:
+ deprecate(
+ "this_timestep",
+ "1.0.0",
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ model_output_list = self.model_outputs
+
+ m0 = model_output_list[-1]
+ x = last_sample
+ x_t = this_sample
+ model_t = this_model_output
+
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
+ self.step_index - 1] # pyright: ignore
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+
+ h = lambda_t - lambda_s0
+ device = this_sample.device
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ si = self.step_index - (i + 1) # pyright: ignore
+ mi = model_output_list[-(i + 1)]
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ D1s.append((mi - m0) / rk) # pyright: ignore
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1)
+ else:
+ D1s = None
+
+ # for order 1, we use a simplified version
+ if order == 1:
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+ def _init_step_index(self, timestep):
+ """
+ Initialize the step_index counter for the scheduler.
+ """
+
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(self,
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ return_dict: bool = True,
+ generator=None) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+ the multistep UniPC.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ use_corrector = (
+ self.step_index > 0 and
+ self.step_index - 1 not in self.disable_corrector and
+ self.last_sample is not None # pyright: ignore
+ )
+
+ model_output_convert = self.convert_model_output(
+ model_output, sample=sample)
+ if use_corrector:
+ sample = self.multistep_uni_c_bh_update(
+ this_model_output=model_output_convert,
+ last_sample=self.last_sample,
+ this_sample=sample,
+ order=self.this_order,
+ )
+
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.timestep_list[i] = self.timestep_list[i + 1]
+
+ self.model_outputs[-1] = model_output_convert
+ self.timestep_list[-1] = timestep # pyright: ignore
+
+ if self.config.lower_order_final:
+ this_order = min(self.config.solver_order,
+ len(self.timesteps) -
+ self.step_index) # pyright: ignore
+ else:
+ this_order = self.config.solver_order
+
+ self.this_order = min(this_order,
+ self.lower_order_nums + 1) # warmup for multistep
+ assert self.this_order > 0
+
+ self.last_sample = sample
+ prev_sample = self.multistep_uni_p_bh_update(
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
+ sample=sample,
+ order=self.this_order,
+ )
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ # upon completion increase step index by one
+ self._step_index += 1 # pyright: ignore
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.Tensor, *args,
+ **kwargs) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(
+ device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(
+ timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(
+ original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(
+ original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [
+ self.index_for_timestep(t, schedule_timesteps)
+ for t in timesteps
+ ]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timesteps.shape[0]
+ else:
+ # add noise is called before first denoising step to create initial latent(img2img)
+ step_indices = [self.begin_index] * timesteps.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/wan/utils/utils.py b/wan/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d72599967f0a5a491e722e7d7a942efe5137b210
--- /dev/null
+++ b/wan/utils/utils.py
@@ -0,0 +1,118 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import argparse
+import binascii
+import os
+import os.path as osp
+
+import imageio
+import torch
+import torchvision
+
+__all__ = ['cache_video', 'cache_image', 'str2bool']
+
+
+def rand_name(length=8, suffix=''):
+ name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
+ if suffix:
+ if not suffix.startswith('.'):
+ suffix = '.' + suffix
+ name += suffix
+ return name
+
+
+def cache_video(tensor,
+ save_file=None,
+ fps=30,
+ suffix='.mp4',
+ nrow=8,
+ normalize=True,
+ value_range=(-1, 1),
+ retry=5):
+ # cache file
+ cache_file = osp.join('/tmp', rand_name(
+ suffix=suffix)) if save_file is None else save_file
+
+ # save to cache
+ error = None
+ for _ in range(retry):
+ try:
+ # preprocess
+ tensor = tensor.clamp(min(value_range), max(value_range))
+ tensor = torch.stack([
+ torchvision.utils.make_grid(
+ u, nrow=nrow, normalize=normalize, value_range=value_range)
+ for u in tensor.unbind(2)
+ ],
+ dim=1).permute(1, 2, 3, 0)
+ tensor = (tensor * 255).type(torch.uint8).cpu()
+
+ # write video
+ writer = imageio.get_writer(
+ cache_file, fps=fps, codec='libx264', quality=8)
+ for frame in tensor.numpy():
+ writer.append_data(frame)
+ writer.close()
+ return cache_file
+ except Exception as e:
+ error = e
+ continue
+ else:
+ print(f'cache_video failed, error: {error}', flush=True)
+ return None
+
+
+def cache_image(tensor,
+ save_file,
+ nrow=8,
+ normalize=True,
+ value_range=(-1, 1),
+ retry=5):
+ # cache file
+ suffix = osp.splitext(save_file)[1]
+ if suffix.lower() not in [
+ '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
+ ]:
+ suffix = '.png'
+
+ # save to cache
+ error = None
+ for _ in range(retry):
+ try:
+ tensor = tensor.clamp(min(value_range), max(value_range))
+ torchvision.utils.save_image(
+ tensor,
+ save_file,
+ nrow=nrow,
+ normalize=normalize,
+ value_range=value_range)
+ return save_file
+ except Exception as e:
+ error = e
+ continue
+
+
+def str2bool(v):
+ """
+ Convert a string to a boolean.
+
+ Supported true values: 'yes', 'true', 't', 'y', '1'
+ Supported false values: 'no', 'false', 'f', 'n', '0'
+
+ Args:
+ v (str): String to convert.
+
+ Returns:
+ bool: Converted boolean value.
+
+ Raises:
+ argparse.ArgumentTypeError: If the value cannot be converted to boolean.
+ """
+ if isinstance(v, bool):
+ return v
+ v_lower = v.lower()
+ if v_lower in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v_lower in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
diff --git a/wan_cache_latents.py b/wan_cache_latents.py
new file mode 100644
index 0000000000000000000000000000000000000000..15b09e444e396b279aabb6a3e1d24bc947c29aca
--- /dev/null
+++ b/wan_cache_latents.py
@@ -0,0 +1,177 @@
+import argparse
+import os
+import glob
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from dataset import config_utils
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+from PIL import Image
+
+import logging
+
+from dataset.image_video_dataset import ItemInfo, save_latent_cache_wan, ARCHITECTURE_WAN
+from utils.model_utils import str_to_dtype
+from wan.configs import wan_i2v_14B
+from wan.modules.vae import WanVAE
+from wan.modules.clip import CLIPModel
+import cache_latents
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def encode_and_save_batch(vae: WanVAE, clip: Optional[CLIPModel], batch: list[ItemInfo]):
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
+ if len(contents.shape) == 4:
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
+
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
+ contents = contents.to(vae.device, dtype=vae.dtype)
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
+
+ h, w = contents.shape[3], contents.shape[4]
+ if h < 8 or w < 8:
+ item = batch[0] # other items should have the same size
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
+
+ # print(f"encode batch: {contents.shape}")
+ with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
+ latent = vae.encode(contents) # list of Tensor[C, F, H, W]
+ latent = torch.stack(latent, dim=0) # B, C, F, H, W
+ latent = latent.to(vae.dtype) # convert to bfloat16, we are not sure if this is correct
+
+ if clip is not None:
+ # extract first frame of contents
+ images = contents[:, :, 0:1, :, :] # B, C, F, H, W, non contiguous view is fine
+
+ with torch.amp.autocast(device_type=clip.device.type, dtype=torch.float16), torch.no_grad():
+ clip_context = clip.visual(images)
+ clip_context = clip_context.to(torch.float16) # convert to fp16
+
+ # encode image latent for I2V
+ B, _, _, lat_h, lat_w = latent.shape
+ F = contents.shape[2]
+
+ # Create mask for the required number of frames
+ msk = torch.ones(1, F, lat_h, lat_w, dtype=vae.dtype, device=vae.device)
+ msk[:, 1:] = 0
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+ msk = msk.transpose(1, 2) # 1, F, 4, H, W -> 1, 4, F, H, W
+ msk = msk.repeat(B, 1, 1, 1, 1) # B, 4, F, H, W
+
+ # Zero padding for the required number of frames only
+ padding_frames = F - 1 # The first frame is the input image
+ images_resized = torch.concat([images, torch.zeros(B, 3, padding_frames, h, w, device=vae.device)], dim=2)
+ with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
+ y = vae.encode(images_resized)
+ y = torch.stack(y, dim=0) # B, C, F, H, W
+
+ y = y[:, :, :F] # may be not needed
+ y = y.to(vae.dtype) # convert to bfloat16
+ y = torch.concat([msk, y], dim=1) # B, 4 + C, F, H, W
+
+ else:
+ clip_context = None
+ y = None
+
+ # control videos
+ if batch[0].control_content is not None:
+ control_contents = torch.stack([torch.from_numpy(item.control_content) for item in batch])
+ if len(control_contents.shape) == 4:
+ control_contents = control_contents.unsqueeze(1)
+ control_contents = control_contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
+ control_contents = control_contents.to(vae.device, dtype=vae.dtype)
+ control_contents = control_contents / 127.5 - 1.0 # normalize to [-1, 1]
+ with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
+ control_latent = vae.encode(control_contents) # list of Tensor[C, F, H, W]
+ control_latent = torch.stack(control_latent, dim=0) # B, C, F, H, W
+ control_latent = control_latent.to(vae.dtype) # convert to bfloat16
+ else:
+ control_latent = None
+
+ # # debug: decode and save
+ # with torch.no_grad():
+ # latent_to_decode = latent / vae.config.scaling_factor
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
+ # images = (images / 2 + 0.5).clamp(0, 1)
+ # images = images.cpu().float().numpy()
+ # images = (images * 255).astype(np.uint8)
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
+ # for b in range(images.shape[0]):
+ # for f in range(images.shape[1]):
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
+ # img = Image.fromarray(images[b, f])
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
+
+ for i, item in enumerate(batch):
+ l = latent[i]
+ cctx = clip_context[i] if clip is not None else None
+ y_i = y[i] if clip is not None else None
+ control_latent_i = control_latent[i] if control_latent is not None else None
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
+ save_latent_cache_wan(item, l, cctx, y_i, control_latent_i)
+
+
+def main(args):
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+ datasets = train_dataset_group.datasets
+
+ if args.debug_mode is not None:
+ cache_latents.show_datasets(
+ datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
+ )
+ return
+
+ assert args.vae is not None, "vae checkpoint is required"
+
+ vae_path = args.vae
+
+ logger.info(f"Loading VAE model from {vae_path}")
+ vae_dtype = torch.bfloat16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
+ cache_device = torch.device("cpu") if args.vae_cache_cpu else None
+ vae = WanVAE(vae_path=vae_path, device=device, dtype=vae_dtype, cache_device=cache_device)
+
+ if args.clip is not None:
+ clip_dtype = wan_i2v_14B.i2v_14B["clip_dtype"]
+ clip = CLIPModel(dtype=clip_dtype, device=device, weight_path=args.clip)
+ else:
+ clip = None
+
+ # Encode images
+ def encode(one_batch: list[ItemInfo]):
+ encode_and_save_batch(vae, clip, one_batch)
+
+ cache_latents.encode_datasets(datasets, encode, args)
+
+
+def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
+ parser.add_argument(
+ "--clip",
+ type=str,
+ default=None,
+ help="text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required",
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ parser = cache_latents.setup_parser_common()
+ parser = wan_setup_parser(parser)
+
+ args = parser.parse_args()
+ main(args)
diff --git a/wan_cache_text_encoder_outputs.py b/wan_cache_text_encoder_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..f10871eec315038c6c11f6efe3e55e33301c4ec4
--- /dev/null
+++ b/wan_cache_text_encoder_outputs.py
@@ -0,0 +1,107 @@
+import argparse
+import os
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from dataset import config_utils
+from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
+import accelerate
+
+from dataset.image_video_dataset import ARCHITECTURE_WAN, ItemInfo, save_text_encoder_output_cache_wan
+
+# for t5 config: all Wan2.1 models have the same config for t5
+from wan.configs import wan_t2v_14B
+
+import cache_text_encoder_outputs
+import logging
+
+from utils.model_utils import str_to_dtype
+from wan.modules.t5 import T5EncoderModel
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def encode_and_save_batch(
+ text_encoder: T5EncoderModel, batch: list[ItemInfo], device: torch.device, accelerator: Optional[accelerate.Accelerator]
+):
+ prompts = [item.caption for item in batch]
+ # print(prompts)
+
+ # encode prompt
+ with torch.no_grad():
+ if accelerator is not None:
+ with accelerator.autocast():
+ context = text_encoder(prompts, device)
+ else:
+ context = text_encoder(prompts, device)
+
+ # save prompt cache
+ for item, ctx in zip(batch, context):
+ save_text_encoder_output_cache_wan(item, ctx)
+
+
+def main(args):
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Load dataset config
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
+ logger.info(f"Load dataset config from {args.dataset_config}")
+ user_config = config_utils.load_user_config(args.dataset_config)
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+
+ datasets = train_dataset_group.datasets
+
+ # define accelerator for fp8 inference
+ config = wan_t2v_14B.t2v_14B # all Wan2.1 models have the same config for t5
+ accelerator = None
+ if args.fp8_t5:
+ accelerator = accelerate.Accelerator(mixed_precision="bf16" if config.t5_dtype == torch.bfloat16 else "fp16")
+
+ # prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
+ all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)
+
+ # Load T5
+ logger.info(f"Loading T5: {args.t5}")
+ text_encoder = T5EncoderModel(
+ text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=args.t5, fp8=args.fp8_t5
+ )
+
+ # Encode with T5
+ logger.info("Encoding with T5")
+
+ def encode_for_text_encoder(batch: list[ItemInfo]):
+ encode_and_save_batch(text_encoder, batch, device, accelerator)
+
+ cache_text_encoder_outputs.process_text_encoder_batches(
+ args.num_workers,
+ args.skip_existing,
+ args.batch_size,
+ datasets,
+ all_cache_files_for_dataset,
+ all_cache_paths_for_dataset,
+ encode_for_text_encoder,
+ )
+ del text_encoder
+
+ # remove cache files not in dataset
+ cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
+
+
+def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ parser.add_argument("--t5", type=str, default=None, required=True, help="text encoder (T5) checkpoint path")
+ parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
+ return parser
+
+
+if __name__ == "__main__":
+ parser = cache_text_encoder_outputs.setup_parser_common()
+ parser = wan_setup_parser(parser)
+
+ args = parser.parse_args()
+ main(args)
diff --git a/wan_generate_video.py b/wan_generate_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b7afc26219355ed985afa65f2ad30d35af62034
--- /dev/null
+++ b/wan_generate_video.py
@@ -0,0 +1,1902 @@
+import argparse
+from datetime import datetime
+import gc
+import random
+import os
+import re
+import time
+import math
+import copy
+from types import ModuleType, SimpleNamespace
+from typing import Tuple, Optional, List, Union, Any, Dict
+
+import torch
+import accelerate
+from accelerate import Accelerator
+from safetensors.torch import load_file, save_file
+from safetensors import safe_open
+from PIL import Image
+import cv2
+import numpy as np
+import torchvision.transforms.functional as TF
+from tqdm import tqdm
+
+from networks import lora_wan
+from utils.safetensors_utils import mem_eff_save_file, load_safetensors
+from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES
+import wan
+from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype
+from wan.modules.vae import WanVAE
+from wan.modules.t5 import T5EncoderModel
+from wan.modules.clip import CLIPModel
+from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
+from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
+from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+try:
+ from lycoris.kohya import create_network_from_weights
+except:
+ pass
+
+from utils.model_utils import str_to_dtype
+from utils.device_utils import clean_memory_on_device
+from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
+from dataset.image_video_dataset import load_video
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+class GenerationSettings:
+ def __init__(
+ self, device: torch.device, cfg, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype], vae_dtype: torch.dtype
+ ):
+ self.device = device
+ self.cfg = cfg
+ self.dit_dtype = dit_dtype
+ self.dit_weight_dtype = dit_weight_dtype
+ self.vae_dtype = vae_dtype
+
+
+def parse_args() -> argparse.Namespace:
+ """parse command line arguments"""
+ parser = argparse.ArgumentParser(description="Wan 2.1 inference script")
+
+ # WAN arguments
+ parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
+ parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
+ parser.add_argument(
+ "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
+ )
+
+ parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path")
+ parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path")
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16")
+ parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
+ parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
+ parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path")
+ # LoRA
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
+ parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
+ parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
+ parser.add_argument(
+ "--save_merged_model",
+ type=str,
+ default=None,
+ help="Save merged model to path. If specified, no inference will be performed.",
+ )
+
+ # inference
+ parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
+ parser.add_argument(
+ "--negative_prompt",
+ type=str,
+ default=None,
+ help="negative prompt for generation, use default negative prompt if not specified",
+ )
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
+ parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task")
+ parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16")
+ parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps")
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
+ parser.add_argument(
+ "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=5.0,
+ help="Guidance scale for classifier free guidance. Default is 5.0.",
+ )
+ parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
+ parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference")
+ parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
+ parser.add_argument(
+ "--control_path",
+ type=str,
+ default=None,
+ help="path to control video for inference with controlnet. video file or directory with images",
+ )
+ parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
+ parser.add_argument(
+ "--cfg_skip_mode",
+ type=str,
+ default="none",
+ choices=["early", "late", "middle", "early_late", "alternate", "none"],
+ help="CFG skip mode. each mode skips different parts of the CFG. "
+ " early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)",
+ )
+ parser.add_argument(
+ "--cfg_apply_ratio",
+ type=float,
+ default=None,
+ help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).",
+ )
+ parser.add_argument(
+ "--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated"
+ )
+ parser.add_argument(
+ "--slg_scale",
+ type=float,
+ default=3.0,
+ help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond",
+ )
+ parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.")
+ parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.")
+ parser.add_argument(
+ "--slg_mode",
+ type=str,
+ default=None,
+ choices=["original", "uncond"],
+ help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred",
+ )
+
+ # Flow Matching
+ parser.add_argument(
+ "--flow_shift",
+ type=float,
+ default=None,
+ help="Shift factor for flow matching schedulers. Default depends on task.",
+ )
+
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
+ parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
+ parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
+ parser.add_argument(
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
+ )
+ parser.add_argument(
+ "--attn_mode",
+ type=str,
+ default="torch",
+ choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"],
+ help="attention mode",
+ )
+ parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
+ parser.add_argument(
+ "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
+ )
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
+ parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
+ parser.add_argument(
+ "--compile_args",
+ nargs=4,
+ metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
+ default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
+ help="Torch.compile settings",
+ )
+
+ # New arguments for batch and interactive modes
+ parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
+
+ args = parser.parse_args()
+
+ # Validate arguments
+ if args.from_file and args.interactive:
+ raise ValueError("Cannot use both --from_file and --interactive at the same time")
+
+ if args.prompt is None and not args.from_file and not args.interactive and args.latent_path is None:
+ raise ValueError("Either --prompt, --from_file, --interactive, or --latent_path must be specified")
+
+ assert (args.latent_path is None or len(args.latent_path) == 0) or (
+ args.output_type == "images" or args.output_type == "video"
+ ), "latent_path is only supported for images or video output"
+
+ return args
+
+
+def parse_prompt_line(line: str) -> Dict[str, Any]:
+ """Parse a prompt line into a dictionary of argument overrides
+
+ Args:
+ line: Prompt line with options
+
+ Returns:
+ Dict[str, Any]: Dictionary of argument overrides
+ """
+ # TODO common function with hv_train_network.line_to_prompt_dict
+ parts = line.split(" --")
+ prompt = parts[0].strip()
+
+ # Create dictionary of overrides
+ overrides = {"prompt": prompt}
+
+ for part in parts[1:]:
+ if not part.strip():
+ continue
+ option_parts = part.split(" ", 1)
+ option = option_parts[0].strip()
+ value = option_parts[1].strip() if len(option_parts) > 1 else ""
+
+ # Map options to argument names
+ if option == "w":
+ overrides["video_size_width"] = int(value)
+ elif option == "h":
+ overrides["video_size_height"] = int(value)
+ elif option == "f":
+ overrides["video_length"] = int(value)
+ elif option == "d":
+ overrides["seed"] = int(value)
+ elif option == "s":
+ overrides["infer_steps"] = int(value)
+ elif option == "g" or option == "l":
+ overrides["guidance_scale"] = float(value)
+ elif option == "fs":
+ overrides["flow_shift"] = float(value)
+ elif option == "i":
+ overrides["image_path"] = value
+ elif option == "cn":
+ overrides["control_path"] = value
+ elif option == "n":
+ overrides["negative_prompt"] = value
+
+ return overrides
+
+
+def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
+ """Apply overrides to args
+
+ Args:
+ args: Original arguments
+ overrides: Dictionary of overrides
+
+ Returns:
+ argparse.Namespace: New arguments with overrides applied
+ """
+ args_copy = copy.deepcopy(args)
+
+ for key, value in overrides.items():
+ if key == "video_size_width":
+ args_copy.video_size[1] = value
+ elif key == "video_size_height":
+ args_copy.video_size[0] = value
+ else:
+ setattr(args_copy, key, value)
+
+ return args_copy
+
+
+def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]:
+ """Return default values for each task
+
+ Args:
+ task: task name (t2v, t2i, i2v etc.)
+ size: size of the video (width, height)
+
+ Returns:
+ Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip)
+ """
+ width, height = size if size else (0, 0)
+
+ if "t2i" in task:
+ return 50, 5.0, 1, False
+ elif "i2v" in task:
+ flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0
+ return 40, flow_shift, 81, True
+ else: # t2v or default
+ return 50, 5.0, 81, False
+
+
+def setup_args(args: argparse.Namespace) -> argparse.Namespace:
+ """Validate and set default values for optional arguments
+
+ Args:
+ args: command line arguments
+
+ Returns:
+ argparse.Namespace: updated arguments
+ """
+ # Get default values for the task
+ infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size))
+
+ # Apply default values to unset arguments
+ if args.infer_steps is None:
+ args.infer_steps = infer_steps
+ if args.flow_shift is None:
+ args.flow_shift = flow_shift
+ if args.video_length is None:
+ args.video_length = video_length
+
+ # Force video_length to 1 for t2i tasks
+ if "t2i" in args.task:
+ assert args.video_length == 1, f"video_length should be 1 for task {args.task}"
+
+ # parse slg_layers
+ if args.slg_layers is not None:
+ args.slg_layers = list(map(int, args.slg_layers.split(",")))
+
+ return args
+
+
+def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]:
+ """Validate video size and length
+
+ Args:
+ args: command line arguments
+
+ Returns:
+ Tuple[int, int, int]: (height, width, video_length)
+ """
+ height = args.video_size[0]
+ width = args.video_size[1]
+ size = f"{width}*{height}"
+
+ if size not in SUPPORTED_SIZES[args.task]:
+ logger.warning(f"Size {size} is not supported for task {args.task}. Supported sizes are {SUPPORTED_SIZES[args.task]}.")
+
+ video_length = args.video_length
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ return height, width, video_length
+
+
+def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]:
+ """calculate dimensions for the generation
+
+ Args:
+ video_size: video frame size (height, width)
+ video_length: number of frames in the video
+ config: model configuration
+
+ Returns:
+ Tuple[Tuple[int, int, int, int], int]:
+ ((channels, frames, height, width), seq_len)
+ """
+ height, width = video_size
+ frames = video_length
+
+ # calculate latent space dimensions
+ lat_f = (frames - 1) // config.vae_stride[0] + 1
+ lat_h = height // config.vae_stride[1]
+ lat_w = width // config.vae_stride[2]
+
+ # calculate sequence length
+ seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f)
+
+ return ((16, lat_f, lat_h, lat_w), seq_len)
+
+
+def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE:
+ """load VAE model
+
+ Args:
+ args: command line arguments
+ config: model configuration
+ device: device to use
+ dtype: data type for the model
+
+ Returns:
+ WanVAE: loaded VAE model
+ """
+ vae_path = args.vae if args.vae is not None else os.path.join(args.ckpt_dir, config.vae_checkpoint)
+
+ logger.info(f"Loading VAE model from {vae_path}")
+ cache_device = torch.device("cpu") if args.vae_cache_cpu else None
+ vae = WanVAE(vae_path=vae_path, device=device, dtype=dtype, cache_device=cache_device)
+ return vae
+
+
+def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel:
+ """load text encoder (T5) model
+
+ Args:
+ args: command line arguments
+ config: model configuration
+ device: device to use
+
+ Returns:
+ T5EncoderModel: loaded text encoder model
+ """
+ checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint)
+ tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer)
+
+ text_encoder = T5EncoderModel(
+ text_len=config.text_len,
+ dtype=config.t5_dtype,
+ device=device,
+ checkpoint_path=checkpoint_path,
+ tokenizer_path=tokenizer_path,
+ weight_path=args.t5,
+ fp8=args.fp8_t5,
+ )
+
+ return text_encoder
+
+
+def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel:
+ """load CLIP model (for I2V only)
+
+ Args:
+ args: command line arguments
+ config: model configuration
+ device: device to use
+
+ Returns:
+ CLIPModel: loaded CLIP model
+ """
+ checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint)
+ tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer)
+
+ clip = CLIPModel(
+ dtype=config.clip_dtype,
+ device=device,
+ checkpoint_path=checkpoint_path,
+ tokenizer_path=tokenizer_path,
+ weight_path=args.clip,
+ )
+
+ return clip
+
+
+def load_dit_model(
+ args: argparse.Namespace,
+ config,
+ device: torch.device,
+ dit_dtype: torch.dtype,
+ dit_weight_dtype: Optional[torch.dtype] = None,
+ is_i2v: bool = False,
+) -> WanModel:
+ """load DiT model
+
+ Args:
+ args: command line arguments
+ config: model configuration
+ device: device to use
+ dit_dtype: data type for the model
+ dit_weight_dtype: data type for the model weights. None for as-is
+ is_i2v: I2V mode
+
+ Returns:
+ WanModel: loaded DiT model
+ """
+ loading_device = "cpu"
+ if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled:
+ loading_device = device
+
+ loading_weight_dtype = dit_weight_dtype
+ if args.fp8_scaled or args.lora_weight is not None:
+ loading_weight_dtype = dit_dtype # load as-is
+
+ # do not fp8 optimize because we will merge LoRA weights
+ model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False)
+
+ return model
+
+
+def merge_lora_weights(
+ lora_module: ModuleType,
+ model: torch.nn.Module,
+ args: argparse.Namespace,
+ device: torch.device,
+ converter: Optional[callable] = None,
+) -> None:
+ """merge LoRA weights to the model
+
+ Args:
+ lora_module: LoRA module, e.g. lora_wan
+ model: DiT model
+ args: command line arguments
+ device: device to use
+ converter: Optional callable to convert weights
+ """
+ if args.lora_weight is None or len(args.lora_weight) == 0:
+ return
+
+ for i, lora_weight in enumerate(args.lora_weight):
+ if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
+ lora_multiplier = args.lora_multiplier[i]
+ else:
+ lora_multiplier = 1.0
+
+ logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
+ weights_sd = load_file(lora_weight)
+ if converter is not None:
+ weights_sd = converter(weights_sd)
+
+ # apply include/exclude patterns
+ original_key_count = len(weights_sd.keys())
+ if args.include_patterns is not None and len(args.include_patterns) > i:
+ include_pattern = args.include_patterns[i]
+ regex_include = re.compile(include_pattern)
+ weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
+ logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
+ if args.exclude_patterns is not None and len(args.exclude_patterns) > i:
+ original_key_count_ex = len(weights_sd.keys())
+ exclude_pattern = args.exclude_patterns[i]
+ regex_exclude = re.compile(exclude_pattern)
+ weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
+ logger.info(
+ f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}"
+ )
+ if len(weights_sd) != original_key_count:
+ remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
+ remaining_keys.sort()
+ logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
+ if len(weights_sd) == 0:
+ logger.warning(f"No keys left after filtering.")
+
+ if args.lycoris:
+ lycoris_net, _ = create_network_from_weights(
+ multiplier=lora_multiplier,
+ file=None,
+ weights_sd=weights_sd,
+ unet=model,
+ text_encoder=None,
+ vae=None,
+ for_inference=True,
+ )
+ lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device)
+ else:
+ network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True)
+ network.merge_to(None, model, weights_sd, device=device, non_blocking=True)
+
+ synchronize_device(device)
+ logger.info("LoRA weights loaded")
+
+ # save model here before casting to dit_weight_dtype
+ if args.save_merged_model:
+ logger.info(f"Saving merged model to {args.save_merged_model}")
+ mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory
+ logger.info("Merged model saved")
+
+
+def optimize_model(
+ model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype
+) -> None:
+ """optimize the model (FP8 conversion, device move etc.)
+
+ Args:
+ model: dit model
+ args: command line arguments
+ device: device to use
+ dit_dtype: dtype for the model
+ dit_weight_dtype: dtype for the model weights
+ """
+ if args.fp8_scaled:
+ # load state dict as-is and optimize to fp8
+ state_dict = model.state_dict()
+
+ # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
+ move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
+ state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
+
+ info = model.load_state_dict(state_dict, strict=True, assign=True)
+ logger.info(f"Loaded FP8 optimized weights: {info}")
+
+ if args.blocks_to_swap == 0:
+ model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.)
+ else:
+ # simple cast to dit_dtype
+ target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
+ target_device = None
+
+ if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled
+ logger.info(f"Convert model to {dit_weight_dtype}")
+ target_dtype = dit_weight_dtype
+
+ if args.blocks_to_swap == 0:
+ logger.info(f"Move model to device: {device}")
+ target_device = device
+
+ model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
+
+ if args.compile:
+ compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
+ logger.info(
+ f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
+ )
+ torch._dynamo.config.cache_size_limit = 32
+ for i in range(len(model.blocks)):
+ model.blocks[i] = torch.compile(
+ model.blocks[i],
+ backend=compile_backend,
+ mode=compile_mode,
+ dynamic=compile_dynamic.lower() in "true",
+ fullgraph=compile_fullgraph.lower() in "true",
+ )
+
+ if args.blocks_to_swap > 0:
+ logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
+ model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
+ model.move_to_device_except_swap_blocks(device)
+ model.prepare_block_swap_before_forward()
+ else:
+ # make sure the model is on the right device
+ model.to(device)
+
+ model.eval().requires_grad_(False)
+ clean_memory_on_device(device)
+
+
+def prepare_t2v_inputs(
+ args: argparse.Namespace,
+ config,
+ accelerator: Accelerator,
+ device: torch.device,
+ vae: Optional[WanVAE] = None,
+ encoded_context: Optional[Dict] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
+ """Prepare inputs for T2V
+
+ Args:
+ args: command line arguments
+ config: model configuration
+ accelerator: Accelerator instance
+ device: device to use
+ vae: VAE model for control video encoding
+ encoded_context: Pre-encoded text context
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
+ (noise, context, context_null, (arg_c, arg_null))
+ """
+ # Prepare inputs for T2V
+ # calculate dimensions and sequence length
+ height, width = args.video_size
+ frames = args.video_length
+ (_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config)
+ target_shape = (16, lat_f, lat_h, lat_w)
+
+ # configure negative prompt
+ n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
+
+ # set seed
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+ if not args.cpu_noise:
+ seed_g = torch.Generator(device=device)
+ seed_g.manual_seed(seed)
+ else:
+ # ComfyUI compatible noise
+ seed_g = torch.manual_seed(seed)
+
+ if encoded_context is None:
+ # load text encoder
+ text_encoder = load_text_encoder(args, config, device)
+ text_encoder.model.to(device)
+
+ # encode prompt
+ with torch.no_grad():
+ if args.fp8_t5:
+ with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
+ context = text_encoder([args.prompt], device)
+ context_null = text_encoder([n_prompt], device)
+ else:
+ context = text_encoder([args.prompt], device)
+ context_null = text_encoder([n_prompt], device)
+
+ # free text encoder and clean memory
+ del text_encoder
+ clean_memory_on_device(device)
+ else:
+ # Use pre-encoded context
+ context = encoded_context["context"]
+ context_null = encoded_context["context_null"]
+
+ # Fun-Control: encode control video to latent space
+ if config.is_fun_control:
+ # TODO use same resizing as for image
+ logger.info(f"Encoding control video to latent space")
+ # C, F, H, W
+ control_video = load_control_video(args.control_path, frames, height, width).to(device)
+ vae.to_device(device)
+ with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
+ control_latent = vae.encode([control_video])[0]
+ y = torch.concat([control_latent, torch.zeros_like(control_latent)], dim=0) # add control video latent
+ vae.to_device("cpu")
+ else:
+ y = None
+
+ # generate noise
+ noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu")
+ noise = noise.to(device)
+
+ # prepare model input arguments
+ arg_c = {"context": context, "seq_len": seq_len}
+ arg_null = {"context": context_null, "seq_len": seq_len}
+ if y is not None:
+ arg_c["y"] = [y]
+ arg_null["y"] = [y]
+
+ return noise, context, context_null, (arg_c, arg_null)
+
+
+def prepare_i2v_inputs(
+ args: argparse.Namespace,
+ config,
+ accelerator: Accelerator,
+ device: torch.device,
+ vae: WanVAE,
+ encoded_context: Optional[Dict] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
+ """Prepare inputs for I2V
+
+ Args:
+ args: command line arguments
+ config: model configuration
+ accelerator: Accelerator instance
+ device: device to use
+ vae: VAE model, used for image encoding
+ encoded_context: Pre-encoded text context
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
+ (noise, context, context_null, y, (arg_c, arg_null))
+ """
+ # get video dimensions
+ height, width = args.video_size
+ frames = args.video_length
+ max_area = width * height
+
+ # load image
+ img = Image.open(args.image_path).convert("RGB")
+
+ # convert to numpy
+ img_cv2 = np.array(img) # PIL to numpy
+
+ # convert to tensor (-1 to 1)
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
+
+ # end frame image
+ if args.end_image_path is not None:
+ end_img = Image.open(args.end_image_path).convert("RGB")
+ end_img_cv2 = np.array(end_img) # PIL to numpy
+ else:
+ end_img = None
+ end_img_cv2 = None
+ has_end_image = end_img is not None
+
+ # calculate latent dimensions: keep aspect ratio
+ height, width = img_tensor.shape[1:]
+ aspect_ratio = height / width
+ lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
+ lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
+ height = lat_h * config.vae_stride[1]
+ width = lat_w * config.vae_stride[2]
+ lat_f = (frames - 1) // config.vae_stride[0] + 1 # size of latent frames
+ max_seq_len = (lat_f + (1 if has_end_image else 0)) * lat_h * lat_w // (config.patch_size[1] * config.patch_size[2])
+
+ # set seed
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+ if not args.cpu_noise:
+ seed_g = torch.Generator(device=device)
+ seed_g.manual_seed(seed)
+ else:
+ # ComfyUI compatible noise
+ seed_g = torch.manual_seed(seed)
+
+ # generate noise
+ noise = torch.randn(
+ 16,
+ lat_f + (1 if has_end_image else 0),
+ lat_h,
+ lat_w,
+ dtype=torch.float32,
+ generator=seed_g,
+ device=device if not args.cpu_noise else "cpu",
+ )
+ noise = noise.to(device)
+
+ # configure negative prompt
+ n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
+
+ if encoded_context is None:
+ # load text encoder
+ text_encoder = load_text_encoder(args, config, device)
+ text_encoder.model.to(device)
+
+ # encode prompt
+ with torch.no_grad():
+ if args.fp8_t5:
+ with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
+ context = text_encoder([args.prompt], device)
+ context_null = text_encoder([n_prompt], device)
+ else:
+ context = text_encoder([args.prompt], device)
+ context_null = text_encoder([n_prompt], device)
+
+ # free text encoder and clean memory
+ del text_encoder
+ clean_memory_on_device(device)
+
+ # load CLIP model
+ clip = load_clip_model(args, config, device)
+ clip.model.to(device)
+
+ # encode image to CLIP context
+ logger.info(f"Encoding image to CLIP context")
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
+ logger.info(f"Encoding complete")
+
+ # free CLIP model and clean memory
+ del clip
+ clean_memory_on_device(device)
+ else:
+ # Use pre-encoded context
+ context = encoded_context["context"]
+ context_null = encoded_context["context_null"]
+ clip_context = encoded_context["clip_context"]
+
+ # encode image to latent space with VAE
+ logger.info(f"Encoding image to latent space")
+ vae.to_device(device)
+
+ # resize image
+ interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC
+ img_resized = cv2.resize(img_cv2, (width, height), interpolation=interpolation)
+ img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
+ img_resized = img_resized.unsqueeze(1) # CFHW
+
+ if has_end_image:
+ interpolation = cv2.INTER_AREA if height < end_img_cv2.shape[1] else cv2.INTER_CUBIC
+ end_img_resized = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation)
+ end_img_resized = TF.to_tensor(end_img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
+ end_img_resized = end_img_resized.unsqueeze(1) # CFHW
+
+ # create mask for the first frame
+ msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device)
+ msk[:, 0] = 1
+ if has_end_image:
+ msk[:, -1] = 1
+
+ # encode image to latent space
+ with accelerator.autocast(), torch.no_grad():
+ # padding to match the required number of frames
+ padding_frames = frames - 1 # the first frame is image
+ img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, height, width, device=device)], dim=1)
+ y = vae.encode([img_resized])[0]
+
+ if has_end_image:
+ y_end = vae.encode([end_img_resized])[0]
+ y = torch.concat([y, y_end], dim=1) # add end frame
+
+ y = torch.concat([msk, y])
+ logger.info(f"Encoding complete")
+
+ # Fun-Control: encode control video to latent space
+ if config.is_fun_control:
+ # TODO use same resizing as for image
+ logger.info(f"Encoding control video to latent space")
+ # C, F, H, W
+ control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device)
+ with accelerator.autocast(), torch.no_grad():
+ control_latent = vae.encode([control_video])[0]
+ y = y[msk.shape[0] :] # remove mask because Fun-Control does not need it
+ if has_end_image:
+ y[:, 1:-1] = 0 # remove image latent except first and last frame. according to WanVideoWrapper, this doesn't work
+ else:
+ y[:, 1:] = 0 # remove image latent except first frame
+ y = torch.concat([control_latent, y], dim=0) # add control video latent
+
+ # prepare model input arguments
+ arg_c = {
+ "context": [context[0]],
+ "clip_fea": clip_context,
+ "seq_len": max_seq_len,
+ "y": [y],
+ }
+
+ arg_null = {
+ "context": context_null,
+ "clip_fea": clip_context,
+ "seq_len": max_seq_len,
+ "y": [y],
+ }
+
+ vae.to_device("cpu") # move VAE to CPU to save memory
+ clean_memory_on_device(device)
+
+ return noise, context, context_null, y, (arg_c, arg_null)
+
+
+def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor:
+ """load control video to latent space
+
+ Args:
+ control_path: path to control video
+ frames: number of frames in the video
+ height: height of the video
+ width: width of the video
+
+ Returns:
+ torch.Tensor: control video latent, CFHW
+ """
+ logger.info(f"Load control video from {control_path}")
+ video = load_video(control_path, 0, frames, bucket_reso=(width, height)) # list of frames
+ if len(video) < frames:
+ raise ValueError(f"Video length is less than {frames}")
+ # video = np.stack(video, axis=0) # F, H, W, C
+ video = torch.stack([TF.to_tensor(frame).sub_(0.5).div_(0.5) for frame in video], dim=0) # F, C, H, W, -1 to 1
+ video = video.permute(1, 0, 2, 3) # C, F, H, W
+ return video
+
+
+def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
+ """setup scheduler for sampling
+
+ Args:
+ args: command line arguments
+ config: model configuration
+ device: device to use
+
+ Returns:
+ Tuple[Any, torch.Tensor]: (scheduler, timesteps)
+ """
+ if args.sample_solver == "unipc":
+ scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
+ scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
+ timesteps = scheduler.timesteps
+ elif args.sample_solver == "dpm++":
+ scheduler = FlowDPMSolverMultistepScheduler(
+ num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
+ )
+ sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
+ timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
+ elif args.sample_solver == "vanilla":
+ scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
+ scheduler.set_timesteps(args.infer_steps, device=device)
+ timesteps = scheduler.timesteps
+
+ # FlowMatchDiscreteScheduler does not support generator argument in step method
+ org_step = scheduler.step
+
+ def step_wrapper(
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ return_dict: bool = True,
+ generator=None,
+ ):
+ return org_step(model_output, timestep, sample, return_dict=return_dict)
+
+ scheduler.step = step_wrapper
+ else:
+ raise NotImplementedError("Unsupported solver.")
+
+ return scheduler, timesteps
+
+
+def run_sampling(
+ model: WanModel,
+ noise: torch.Tensor,
+ scheduler: Any,
+ timesteps: torch.Tensor,
+ args: argparse.Namespace,
+ inputs: Tuple[dict, dict],
+ device: torch.device,
+ seed_g: torch.Generator,
+ accelerator: Accelerator,
+ is_i2v: bool = False,
+ use_cpu_offload: bool = True,
+) -> torch.Tensor:
+ """run sampling
+ Args:
+ model: dit model
+ noise: initial noise
+ scheduler: scheduler for sampling
+ timesteps: time steps for sampling
+ args: command line arguments
+ inputs: model input (arg_c, arg_null)
+ device: device to use
+ seed_g: random generator
+ accelerator: Accelerator instance
+ is_i2v: I2V mode (False means T2V mode)
+ use_cpu_offload: Whether to offload tensors to CPU during processing
+ Returns:
+ torch.Tensor: generated latent
+ """
+ arg_c, arg_null = inputs
+
+ latent = noise
+ latent_storage_device = device if not use_cpu_offload else "cpu"
+ latent = latent.to(latent_storage_device)
+
+ # cfg skip
+ apply_cfg_array = []
+ num_timesteps = len(timesteps)
+
+ if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None:
+ # Calculate thresholds based on cfg_apply_ratio
+ apply_steps = int(num_timesteps * args.cfg_apply_ratio)
+
+ if args.cfg_skip_mode == "early":
+ # Skip CFG in early steps, apply in late steps
+ start_index = num_timesteps - apply_steps
+ end_index = num_timesteps
+ elif args.cfg_skip_mode == "late":
+ # Skip CFG in late steps, apply in early steps
+ start_index = 0
+ end_index = apply_steps
+ elif args.cfg_skip_mode == "early_late":
+ # Skip CFG in early and late steps, apply in middle steps
+ start_index = (num_timesteps - apply_steps) // 2
+ end_index = start_index + apply_steps
+ elif args.cfg_skip_mode == "middle":
+ # Skip CFG in middle steps, apply in early and late steps
+ skip_steps = num_timesteps - apply_steps
+ middle_start = (num_timesteps - skip_steps) // 2
+ middle_end = middle_start + skip_steps
+
+ w = 0.0
+ for step_idx in range(num_timesteps):
+ if args.cfg_skip_mode == "alternate":
+ # accumulate w and apply CFG when w >= 1.0
+ w += args.cfg_apply_ratio
+ apply = w >= 1.0
+ if apply:
+ w -= 1.0
+ elif args.cfg_skip_mode == "middle":
+ # Skip CFG in early and late steps, apply in middle steps
+ apply = step_idx < middle_start or step_idx >= middle_end
+ else:
+ # Apply CFG on some steps based on ratio
+ apply = step_idx >= start_index and step_idx < end_index
+
+ apply_cfg_array.append(apply)
+
+ pattern = ["A" if apply else "S" for apply in apply_cfg_array]
+ pattern = "".join(pattern)
+ logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, pattern: {pattern}")
+ else:
+ # Apply CFG on all steps
+ apply_cfg_array = [True] * num_timesteps
+
+ # SLG original implementation is based on https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py
+ slg_start_step = int(args.slg_start * num_timesteps)
+ slg_end_step = int(args.slg_end * num_timesteps)
+
+ for i, t in enumerate(tqdm(timesteps)):
+ # latent is on CPU if use_cpu_offload is True
+ latent_model_input = [latent.to(device)]
+ timestep = torch.stack([t]).to(device)
+
+ with accelerator.autocast(), torch.no_grad():
+ noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to(latent_storage_device)
+
+ apply_cfg = apply_cfg_array[i] # apply CFG or not
+ if apply_cfg:
+ apply_slg = i >= slg_start_step and i < slg_end_step
+ # print(f"Applying SLG: {apply_slg}, i: {i}, slg_start_step: {slg_start_step}, slg_end_step: {slg_end_step}")
+ if args.slg_mode == "original" and apply_slg:
+ noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
+
+ # apply guidance
+ # SD3 formula: scaled = neg_out + (pos_out - neg_out) * cond_scale
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ # calculate skip layer out
+ skip_layer_out = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
+ latent_storage_device
+ )
+
+ # apply skip layer guidance
+ # SD3 formula: scaled = scaled + (pos_out - skip_layer_out) * self.slg
+ noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out)
+ elif args.slg_mode == "uncond" and apply_slg:
+ # noise_pred_uncond is skip layer out
+ noise_pred_uncond = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
+ latent_storage_device
+ )
+
+ # apply guidance
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ else:
+ # normal guidance
+ noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
+
+ # apply guidance
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = noise_pred_cond
+
+ # step
+ latent_input = latent.unsqueeze(0)
+ temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent_input, return_dict=False, generator=seed_g)[0]
+
+ # update latent
+ latent = temp_x0.squeeze(0)
+
+ return latent
+
+
+def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> torch.Tensor:
+ """main function for generation
+
+ Args:
+ args: command line arguments
+ shared_models: dictionary containing pre-loaded models and encoded data
+
+ Returns:
+ torch.Tensor: generated latent
+ """
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
+ gen_settings.device,
+ gen_settings.cfg,
+ gen_settings.dit_dtype,
+ gen_settings.dit_weight_dtype,
+ gen_settings.vae_dtype,
+ )
+
+ # prepare accelerator
+ mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
+ accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
+
+ # I2V or T2V
+ is_i2v = "i2v" in args.task
+
+ # prepare seed
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+ args.seed = seed # set seed to args for saving
+
+ # Check if we have shared models
+ if shared_models is not None:
+ # Use shared models and encoded data
+ vae = shared_models.get("vae")
+ model = shared_models.get("model")
+ encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt)
+
+ # prepare inputs
+ if is_i2v:
+ # I2V
+ noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
+ else:
+ # T2V
+ noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
+ else:
+ # prepare inputs without shared models
+ if is_i2v:
+ # I2V: need text encoder, VAE and CLIP
+ vae = load_vae(args, cfg, device, vae_dtype)
+ noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae)
+ # vae is on CPU after prepare_i2v_inputs
+ else:
+ # T2V: need text encoder
+ vae = None
+ if cfg.is_fun_control:
+ # Fun-Control: need VAE for encoding control video
+ vae = load_vae(args, cfg, device, vae_dtype)
+ noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae)
+
+ # load DiT model
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
+
+ # merge LoRA weights
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
+ merge_lora_weights(lora_wan, model, args, device)
+
+ # if we only want to save the model, we can skip the rest
+ if args.save_merged_model:
+ return None
+
+ # optimize model: fp8 conversion, block swap etc.
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
+
+ # setup scheduler
+ scheduler, timesteps = setup_scheduler(args, cfg, device)
+
+ # set random generator
+ seed_g = torch.Generator(device=device)
+ seed_g.manual_seed(seed)
+
+ # run sampling
+ latent = run_sampling(model, noise, scheduler, timesteps, args, inputs, device, seed_g, accelerator, is_i2v)
+
+ # Only clean up shared models if they were created within this function
+ if shared_models is None:
+ # free memory
+ del model
+ del scheduler
+ synchronize_device(device)
+
+ # wait for 5 seconds until block swap is done
+ logger.info("Waiting for 5 seconds to finish block swap")
+ time.sleep(5)
+
+ gc.collect()
+ clean_memory_on_device(device)
+
+ # save VAE model for decoding
+ if vae is None:
+ args._vae = None
+ else:
+ args._vae = vae
+
+ return latent
+
+
+def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor:
+ """decode latent
+
+ Args:
+ latent: latent tensor
+ args: command line arguments
+ cfg: model configuration
+
+ Returns:
+ torch.Tensor: decoded video or image
+ """
+ device = torch.device(args.device)
+
+ # load VAE model or use the one from the generation
+ vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16
+ if hasattr(args, "_vae") and args._vae is not None:
+ vae = args._vae
+ else:
+ vae = load_vae(args, cfg, device, vae_dtype)
+
+ vae.to_device(device)
+
+ logger.info(f"Decoding video from latents: {latent.shape}")
+ x0 = latent.to(device)
+
+ with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad():
+ videos = vae.decode(x0)
+
+ # some tail frames may be corrupted when end frame is used, we add an option to remove them
+ if args.trim_tail_frames:
+ videos[0] = videos[0][:, : -args.trim_tail_frames]
+
+ logger.info(f"Decoding complete")
+ video = videos[0]
+ del videos
+ video = video.to(torch.float32).cpu()
+
+ return video
+
+
+def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
+ """Save latent to file
+
+ Args:
+ latent: latent tensor
+ args: command line arguments
+ height: height of frame
+ width: width of frame
+
+ Returns:
+ str: Path to saved latent file
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
+
+ seed = args.seed
+ video_length = args.video_length
+ latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
+
+ if args.no_metadata:
+ metadata = None
+ else:
+ metadata = {
+ "seeds": f"{seed}",
+ "prompt": f"{args.prompt}",
+ "height": f"{height}",
+ "width": f"{width}",
+ "video_length": f"{video_length}",
+ "infer_steps": f"{args.infer_steps}",
+ "guidance_scale": f"{args.guidance_scale}",
+ }
+ if args.negative_prompt is not None:
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
+
+ sd = {"latent": latent}
+ save_file(sd, latent_path, metadata=metadata)
+ logger.info(f"Latent saved to: {latent_path}")
+
+ return latent_path
+
+
+def save_video(video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
+ """Save video to file
+
+ Args:
+ video: Video tensor
+ args: command line arguments
+ original_base_name: Original base name (if latents are loaded from files)
+
+ Returns:
+ str: Path to saved video file
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
+
+ seed = args.seed
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
+ video_path = f"{save_path}/{time_flag}_{seed}{original_name}.mp4"
+
+ video = video.unsqueeze(0)
+ save_videos_grid(video, video_path, fps=args.fps, rescale=True)
+ logger.info(f"Video saved to: {video_path}")
+
+ return video_path
+
+
+def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
+ """Save images to directory
+
+ Args:
+ sample: Video tensor
+ args: command line arguments
+ original_base_name: Original base name (if latents are loaded from files)
+
+ Returns:
+ str: Path to saved images directory
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
+
+ seed = args.seed
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
+ image_name = f"{time_flag}_{seed}{original_name}"
+ sample = sample.unsqueeze(0)
+ save_images_grid(sample, save_path, image_name, rescale=True)
+ logger.info(f"Sample images saved to: {save_path}/{image_name}")
+
+ return f"{save_path}/{image_name}"
+
+
+def save_output(
+ latent: torch.Tensor, args: argparse.Namespace, cfg, height: int, width: int, original_base_names: Optional[List[str]] = None
+) -> None:
+ """save output
+
+ Args:
+ latent: latent tensor
+ args: command line arguments
+ cfg: model configuration
+ height: height of frame
+ width: width of frame
+ original_base_names: original base names (if latents are loaded from files)
+ """
+ if args.output_type == "latent" or args.output_type == "both":
+ # save latent
+ save_latent(latent, args, height, width)
+
+ if args.output_type == "video" or args.output_type == "both":
+ # save video
+ sample = decode_latent(latent.unsqueeze(0), args, cfg)
+ original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
+ save_video(sample, args, original_name)
+
+ elif args.output_type == "images":
+ # save images
+ sample = decode_latent(latent.unsqueeze(0), args, cfg)
+ original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
+ save_images(sample, args, original_name)
+
+
+def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
+ """Process multiple prompts for batch mode
+
+ Args:
+ prompt_lines: List of prompt lines
+ base_args: Base command line arguments
+
+ Returns:
+ List[Dict]: List of prompt data dictionaries
+ """
+ prompts_data = []
+
+ for line in prompt_lines:
+ line = line.strip()
+ if not line or line.startswith("#"): # Skip empty lines and comments
+ continue
+
+ # Parse prompt line and create override dictionary
+ prompt_data = parse_prompt_line(line)
+ logger.info(f"Parsed prompt data: {prompt_data}")
+ prompts_data.append(prompt_data)
+
+ return prompts_data
+
+
+def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
+ """Process multiple prompts with model reuse
+
+ Args:
+ prompts_data: List of prompt data dictionaries
+ args: Base command line arguments
+ """
+ if not prompts_data:
+ logger.warning("No valid prompts found")
+ return
+
+ # 1. Load configuration
+ gen_settings = get_generation_settings(args)
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
+ gen_settings.device,
+ gen_settings.cfg,
+ gen_settings.dit_dtype,
+ gen_settings.dit_weight_dtype,
+ gen_settings.vae_dtype,
+ )
+ is_i2v = "i2v" in args.task
+
+ # 2. Encode all prompts
+ logger.info("Loading text encoder to encode all prompts")
+ text_encoder = load_text_encoder(args, cfg, device)
+ text_encoder.model.to(device)
+
+ encoded_contexts = {}
+
+ with torch.no_grad():
+ for prompt_data in prompts_data:
+ prompt = prompt_data["prompt"]
+ prompt_args = apply_overrides(args, prompt_data)
+ n_prompt = prompt_data.get(
+ "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
+ )
+
+ if args.fp8_t5:
+ with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
+ context = text_encoder([prompt], device)
+ context_null = text_encoder([n_prompt], device)
+ else:
+ context = text_encoder([prompt], device)
+ context_null = text_encoder([n_prompt], device)
+
+ encoded_contexts[prompt] = {"context": context, "context_null": context_null}
+
+ # Free text encoder and clean memory
+ del text_encoder
+ clean_memory_on_device(device)
+
+ # 3. Process I2V additional encodings if needed
+ vae = None
+ if is_i2v:
+ logger.info("Loading VAE and CLIP for I2V preprocessing")
+ vae = load_vae(args, cfg, device, vae_dtype)
+ vae.to_device(device)
+
+ clip = load_clip_model(args, cfg, device)
+ clip.model.to(device)
+
+ # Process each image and encode with CLIP
+ for prompt_data in prompts_data:
+ if "image_path" not in prompt_data:
+ continue
+
+ prompt_args = apply_overrides(args, prompt_data)
+ if not os.path.exists(prompt_args.image_path):
+ logger.warning(f"Image path not found: {prompt_args.image_path}")
+ continue
+
+ # Load and encode image with CLIP
+ img = Image.open(prompt_args.image_path).convert("RGB")
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
+
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
+
+ encoded_contexts[prompt_data["prompt"]]["clip_context"] = clip_context
+
+ # Free CLIP and clean memory
+ del clip
+ clean_memory_on_device(device)
+
+ # Keep VAE in CPU memory for later use
+ vae.to_device("cpu")
+ elif cfg.is_fun_control:
+ # For Fun-Control, we need VAE but keep it on CPU
+ vae = load_vae(args, cfg, device, vae_dtype)
+ vae.to_device("cpu")
+
+ # 4. Load DiT model
+ logger.info("Loading DiT model")
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
+
+ # 5. Merge LoRA weights if needed
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
+ merge_lora_weights(lora_wan, model, args, device)
+ if args.save_merged_model:
+ logger.info("Model merged and saved. Exiting.")
+ return
+
+ # 6. Optimize model
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
+
+ # Create shared models dict for generate function
+ shared_models = {"vae": vae, "model": model, "encoded_contexts": encoded_contexts}
+
+ # 7. Generate for each prompt
+ all_latents = []
+ all_prompt_args = []
+
+ for i, prompt_data in enumerate(prompts_data):
+ logger.info(f"Processing prompt {i+1}/{len(prompts_data)}: {prompt_data['prompt'][:50]}...")
+
+ # Apply overrides for this prompt
+ prompt_args = apply_overrides(args, prompt_data)
+
+ # Generate latent
+ latent = generate(prompt_args, gen_settings, shared_models)
+
+ # Save latent if needed
+ height, width, _ = check_inputs(prompt_args)
+ if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
+ save_latent(latent, prompt_args, height, width)
+
+ all_latents.append(latent)
+ all_prompt_args.append(prompt_args)
+
+ # 8. Free DiT model
+ del model
+ clean_memory_on_device(device)
+ synchronize_device(device)
+
+ # wait for 5 seconds until block swap is done
+ logger.info("Waiting for 5 seconds to finish block swap")
+ time.sleep(5)
+
+ gc.collect()
+ clean_memory_on_device(device)
+
+ # 9. Decode latents if needed
+ if args.output_type != "latent":
+ logger.info("Decoding latents to videos/images")
+
+ if vae is None:
+ vae = load_vae(args, cfg, device, vae_dtype)
+
+ vae.to_device(device)
+
+ for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
+ logger.info(f"Decoding output {i+1}/{len(all_latents)}")
+
+ # Decode latent
+ video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
+
+ # Save as video or images
+ if prompt_args.output_type == "video" or prompt_args.output_type == "both":
+ save_video(video, prompt_args)
+ elif prompt_args.output_type == "images":
+ save_images(video, prompt_args)
+
+ # Free VAE
+ del vae
+
+ clean_memory_on_device(device)
+ gc.collect()
+
+
+def process_interactive(args: argparse.Namespace) -> None:
+ """Process prompts in interactive mode
+
+ Args:
+ args: Base command line arguments
+ """
+ gen_settings = get_generation_settings(args)
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
+ gen_settings.device,
+ gen_settings.cfg,
+ gen_settings.dit_dtype,
+ gen_settings.dit_weight_dtype,
+ gen_settings.vae_dtype,
+ )
+ is_i2v = "i2v" in args.task
+
+ # Initialize models to None
+ text_encoder = None
+ vae = None
+ model = None
+ clip = None
+
+ print("Interactive mode. Enter prompts (Ctrl+D to exit):")
+
+ try:
+ while True:
+ try:
+ line = input("> ")
+ if not line.strip():
+ continue
+
+ # Parse prompt
+ prompt_data = parse_prompt_line(line)
+ prompt_args = apply_overrides(args, prompt_data)
+
+ # Ensure we have all the models we need
+
+ # 1. Load text encoder if not already loaded
+ if text_encoder is None:
+ logger.info("Loading text encoder")
+ text_encoder = load_text_encoder(args, cfg, device)
+
+ text_encoder.model.to(device)
+
+ # Encode prompt
+ n_prompt = prompt_data.get(
+ "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
+ )
+
+ with torch.no_grad():
+ if args.fp8_t5:
+ with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
+ context = text_encoder([prompt_data["prompt"]], device)
+ context_null = text_encoder([n_prompt], device)
+ else:
+ context = text_encoder([prompt_data["prompt"]], device)
+ context_null = text_encoder([n_prompt], device)
+
+ encoded_context = {"context": context, "context_null": context_null}
+
+ # Move text encoder to CPU after use
+ text_encoder.model.to("cpu")
+
+ # 2. For I2V, we need CLIP and VAE
+ if is_i2v:
+ if clip is None:
+ logger.info("Loading CLIP model")
+ clip = load_clip_model(args, cfg, device)
+
+ clip.model.to(device)
+
+ # Encode image with CLIP if there's an image path
+ if prompt_args.image_path and os.path.exists(prompt_args.image_path):
+ img = Image.open(prompt_args.image_path).convert("RGB")
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
+
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
+
+ encoded_context["clip_context"] = clip_context
+
+ # Move CLIP to CPU after use
+ clip.model.to("cpu")
+
+ # Load VAE if needed
+ if vae is None:
+ logger.info("Loading VAE model")
+ vae = load_vae(args, cfg, device, vae_dtype)
+ elif cfg.is_fun_control and vae is None:
+ # For Fun-Control, we need VAE
+ logger.info("Loading VAE model for Fun-Control")
+ vae = load_vae(args, cfg, device, vae_dtype)
+
+ # 3. Load DiT model if not already loaded
+ if model is None:
+ logger.info("Loading DiT model")
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
+
+ # Merge LoRA weights if needed
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
+ merge_lora_weights(lora_wan, model, args, device)
+
+ # Optimize model
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
+ else:
+ # Move model to GPU if it was offloaded
+ model.to(device)
+
+ # Create shared models dict
+ shared_models = {"vae": vae, "model": model, "encoded_contexts": {prompt_data["prompt"]: encoded_context}}
+
+ # Generate latent
+ latent = generate(prompt_args, gen_settings, shared_models)
+
+ # Move model to CPU after generation
+ model.to("cpu")
+
+ # Save latent if needed
+ height, width, _ = check_inputs(prompt_args)
+ if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
+ save_latent(latent, prompt_args, height, width)
+
+ # Decode and save output
+ if prompt_args.output_type != "latent":
+ if vae is None:
+ vae = load_vae(args, cfg, device, vae_dtype)
+
+ vae.to_device(device)
+ video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
+
+ if prompt_args.output_type == "video" or prompt_args.output_type == "both":
+ save_video(video, prompt_args)
+ elif prompt_args.output_type == "images":
+ save_images(video, prompt_args)
+
+ # Move VAE to CPU after use
+ vae.to_device("cpu")
+
+ clean_memory_on_device(device)
+
+ except KeyboardInterrupt:
+ print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
+ continue
+
+ except EOFError:
+ print("\nExiting interactive mode")
+
+ # Clean up all models
+ if text_encoder is not None:
+ del text_encoder
+ if clip is not None:
+ del clip
+ if vae is not None:
+ del vae
+ if model is not None:
+ del model
+
+ clean_memory_on_device(device)
+ gc.collect()
+
+
+def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
+ device = torch.device(args.device)
+
+ cfg = WAN_CONFIGS[args.task]
+
+ # select dtype
+ dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16
+ if dit_dtype.itemsize == 1:
+ # if weight is in fp8, use bfloat16 for DiT (input/output)
+ dit_dtype = torch.bfloat16
+ if args.fp8_scaled:
+ raise ValueError(
+ "DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
+ )
+
+ dit_weight_dtype = dit_dtype # default
+ if args.fp8_scaled:
+ dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
+ elif args.fp8:
+ dit_weight_dtype = torch.float8_e4m3fn
+
+ vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else dit_dtype
+ logger.info(
+ f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}, VAE precision: {vae_dtype}"
+ )
+
+ gen_settings = GenerationSettings(
+ device=device,
+ cfg=cfg,
+ dit_dtype=dit_dtype,
+ dit_weight_dtype=dit_weight_dtype,
+ vae_dtype=vae_dtype,
+ )
+ return gen_settings
+
+
+def main():
+ # Parse arguments
+ args = parse_args()
+
+ # Check if latents are provided
+ latents_mode = args.latent_path is not None and len(args.latent_path) > 0
+
+ # Set device
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+ logger.info(f"Using device: {device}")
+ args.device = device
+
+ if latents_mode:
+ # Original latent decode mode
+ cfg = WAN_CONFIGS[args.task] # any task is fine
+ original_base_names = []
+ latents_list = []
+ seeds = []
+
+ assert len(args.latent_path) == 1, "Only one latent path is supported for now"
+
+ for latent_path in args.latent_path:
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
+ seed = 0
+
+ if os.path.splitext(latent_path)[1] != ".safetensors":
+ latents = torch.load(latent_path, map_location="cpu")
+ else:
+ latents = load_file(latent_path)["latent"]
+ with safe_open(latent_path, framework="pt") as f:
+ metadata = f.metadata()
+ if metadata is None:
+ metadata = {}
+ logger.info(f"Loaded metadata: {metadata}")
+
+ if "seeds" in metadata:
+ seed = int(metadata["seeds"])
+ if "height" in metadata and "width" in metadata:
+ height = int(metadata["height"])
+ width = int(metadata["width"])
+ args.video_size = [height, width]
+ if "video_length" in metadata:
+ args.video_length = int(metadata["video_length"])
+
+ seeds.append(seed)
+ latents_list.append(latents)
+
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
+
+ latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
+
+ height = latents.shape[-2]
+ width = latents.shape[-1]
+ height *= cfg.patch_size[1] * cfg.vae_stride[1]
+ width *= cfg.patch_size[2] * cfg.vae_stride[2]
+ video_length = latents.shape[1]
+ video_length = (video_length - 1) * cfg.vae_stride[0] + 1
+ args.seed = seeds[0]
+
+ # Decode and save
+ save_output(latent[0], args, cfg, height, width, original_base_names)
+
+ elif args.from_file:
+ # Batch mode from file
+ args = setup_args(args)
+
+ # Read prompts from file
+ with open(args.from_file, "r", encoding="utf-8") as f:
+ prompt_lines = f.readlines()
+
+ # Process prompts
+ prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
+ process_batch_prompts(prompts_data, args)
+
+ elif args.interactive:
+ # Interactive mode
+ args = setup_args(args)
+ process_interactive(args)
+
+ else:
+ # Single prompt mode (original behavior)
+ args = setup_args(args)
+ height, width, video_length = check_inputs(args)
+
+ logger.info(
+ f"Video size: {height}x{width}@{video_length} (HxW@F), fps: {args.fps}, "
+ f"infer_steps: {args.infer_steps}, flow_shift: {args.flow_shift}"
+ )
+
+ # Generate latent
+ gen_settings = get_generation_settings(args)
+ latent = generate(args, gen_settings)
+
+ # Make sure the model is freed from GPU memory
+ gc.collect()
+ clean_memory_on_device(args.device)
+
+ # Save latent and video
+ if args.save_merged_model:
+ return
+
+ # Add batch dimension
+ latent = latent.unsqueeze(0)
+ save_output(latent[0], args, WAN_CONFIGS[args.task], height, width)
+
+ logger.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/wan_train_network.py b/wan_train_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..4061659df12aafeadad5fc516e673d4943ab388d
--- /dev/null
+++ b/wan_train_network.py
@@ -0,0 +1,444 @@
+import argparse
+from typing import Optional
+from PIL import Image
+
+
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from tqdm import tqdm
+from accelerate import Accelerator, init_empty_weights
+
+from dataset.image_video_dataset import ARCHITECTURE_WAN, ARCHITECTURE_WAN_FULL, load_video
+from hv_generate_video import resize_image_to_bucket
+from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file
+
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+from utils import model_utils
+from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
+from wan.configs import WAN_CONFIGS
+from wan.modules.clip import CLIPModel
+from wan.modules.model import WanModel, detect_wan_sd_dtype, load_wan_model
+from wan.modules.t5 import T5EncoderModel
+from wan.modules.vae import WanVAE
+from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+
+class WanNetworkTrainer(NetworkTrainer):
+ def __init__(self):
+ super().__init__()
+
+ # region model specific
+
+ @property
+ def architecture(self) -> str:
+ return ARCHITECTURE_WAN
+
+ @property
+ def architecture_full_name(self) -> str:
+ return ARCHITECTURE_WAN_FULL
+
+ def handle_model_specific_args(self, args):
+ self.config = WAN_CONFIGS[args.task]
+ self._i2v_training = "i2v" in args.task # we cannot use config.i2v because Fun-Control T2V has i2v flag TODO refactor this
+ self._control_training = self.config.is_fun_control
+
+ self.dit_dtype = detect_wan_sd_dtype(args.dit)
+
+ if self.dit_dtype == torch.float16:
+ assert args.mixed_precision in ["fp16", "no"], "DiT weights are in fp16, mixed precision must be fp16 or no"
+ elif self.dit_dtype == torch.bfloat16:
+ assert args.mixed_precision in ["bf16", "no"], "DiT weights are in bf16, mixed precision must be bf16 or no"
+
+ if args.fp8_scaled and self.dit_dtype.itemsize == 1:
+ raise ValueError(
+ "DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
+ )
+
+ # dit_dtype cannot be fp8, so we select the appropriate dtype
+ if self.dit_dtype.itemsize == 1:
+ self.dit_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+
+ args.dit_dtype = model_utils.dtype_to_str(self.dit_dtype)
+
+ self.default_guidance_scale = 1.0 # not used
+
+ def process_sample_prompts(
+ self,
+ args: argparse.Namespace,
+ accelerator: Accelerator,
+ sample_prompts: str,
+ ):
+ config = self.config
+ device = accelerator.device
+ t5_path, clip_path, fp8_t5 = args.t5, args.clip, args.fp8_t5
+
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
+ prompts = load_prompts(sample_prompts)
+
+ def encode_for_text_encoder(text_encoder):
+ sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
+ # with accelerator.autocast(), torch.no_grad(): # this causes NaN if dit_dtype is fp16
+ t5_dtype = config.t5_dtype
+ with torch.amp.autocast(device_type=device.type, dtype=t5_dtype), torch.no_grad():
+ for prompt_dict in prompts:
+ if "negative_prompt" not in prompt_dict:
+ prompt_dict["negative_prompt"] = self.config["sample_neg_prompt"]
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", None)]:
+ if p is None:
+ continue
+ if p not in sample_prompts_te_outputs:
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
+
+ prompt_outputs = text_encoder([p], device)
+ sample_prompts_te_outputs[p] = prompt_outputs
+
+ return sample_prompts_te_outputs
+
+ # Load Text Encoder 1 and encode
+ logger.info(f"loading T5: {t5_path}")
+ t5 = T5EncoderModel(text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=t5_path, fp8=fp8_t5)
+
+ logger.info("encoding with Text Encoder 1")
+ te_outputs_1 = encode_for_text_encoder(t5)
+ del t5
+
+ # load CLIP and encode image (for I2V training)
+ # Note: VAE encoding is done in do_inference() for I2V training, because we have VAE in the pipeline. Control video is also done in do_inference()
+ sample_prompts_image_embs = {}
+ for prompt_dict in prompts:
+ if prompt_dict.get("image_path", None) is not None and self.i2v_training:
+ sample_prompts_image_embs[prompt_dict["image_path"]] = None # this will be replaced with CLIP context
+
+ if len(sample_prompts_image_embs) > 0:
+ logger.info(f"loading CLIP: {clip_path}")
+ assert clip_path is not None, "CLIP path is required for I2V training / I2V学習にはCLIPのパスが必要です"
+ clip = CLIPModel(dtype=config.clip_dtype, device=device, weight_path=clip_path)
+ clip.model.to(device)
+
+ logger.info(f"Encoding image to CLIP context")
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
+ for image_path in sample_prompts_image_embs:
+ logger.info(f"Encoding image: {image_path}")
+ img = Image.open(image_path).convert("RGB")
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) # -1 to 1
+ clip_context = clip.visual([img[:, None, :, :]])
+ sample_prompts_image_embs[image_path] = clip_context
+
+ del clip
+ clean_memory_on_device(device)
+
+ # prepare sample parameters
+ sample_parameters = []
+ for prompt_dict in prompts:
+ prompt_dict_copy = prompt_dict.copy()
+
+ p = prompt_dict.get("prompt", "")
+ prompt_dict_copy["t5_embeds"] = te_outputs_1[p][0]
+
+ p = prompt_dict.get("negative_prompt", None)
+ if p is not None:
+ prompt_dict_copy["negative_t5_embeds"] = te_outputs_1[p][0]
+
+ p = prompt_dict.get("image_path", None)
+ if p is not None and self.i2v_training:
+ prompt_dict_copy["clip_embeds"] = sample_prompts_image_embs[p]
+
+ sample_parameters.append(prompt_dict_copy)
+
+ clean_memory_on_device(accelerator.device)
+
+ return sample_parameters
+
+ def do_inference(
+ self,
+ accelerator,
+ args,
+ sample_parameter,
+ vae,
+ dit_dtype,
+ transformer,
+ discrete_flow_shift,
+ sample_steps,
+ width,
+ height,
+ frame_count,
+ generator,
+ do_classifier_free_guidance,
+ guidance_scale,
+ cfg_scale,
+ image_path=None,
+ control_video_path=None,
+ ):
+ """architecture dependent inference"""
+ model: WanModel = transformer
+ device = accelerator.device
+ if cfg_scale is None:
+ cfg_scale = 5.0
+ do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0
+
+ # Calculate latent video length based on VAE version
+ latent_video_length = (frame_count - 1) // self.config["vae_stride"][0] + 1
+
+ # Get embeddings
+ context = sample_parameter["t5_embeds"].to(device=device)
+ if do_classifier_free_guidance:
+ context_null = sample_parameter["negative_t5_embeds"].to(device=device)
+ else:
+ context_null = None
+
+ num_channels_latents = 16 # model.in_dim
+ vae_scale_factor = self.config["vae_stride"][1]
+
+ # Initialize latents
+ lat_h = height // vae_scale_factor
+ lat_w = width // vae_scale_factor
+ shape_or_frame = (1, num_channels_latents, 1, lat_h, lat_w)
+ latents = []
+ for _ in range(latent_video_length):
+ latents.append(torch.randn(shape_or_frame, generator=generator, device=device, dtype=torch.float32))
+ latents = torch.cat(latents, dim=2)
+
+ image_latents = None
+ if self.i2v_training or self.control_training:
+ # Move VAE to the appropriate device for sampling: consider to cache image latents in CPU in advance
+ vae.to(device)
+ vae.eval()
+
+ if self.i2v_training:
+ image = Image.open(image_path)
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).float() # C, 1, H, W
+ image = image / 127.5 - 1 # -1 to 1
+
+ # Create mask for the required number of frames
+ msk = torch.ones(1, frame_count, lat_h, lat_w, device=device)
+ msk[:, 1:] = 0
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+ msk = msk.transpose(1, 2) # B, C, T, H, W
+
+ with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
+ # Zero padding for the required number of frames only
+ padding_frames = frame_count - 1 # The first frame is the input image
+ image = torch.concat([image, torch.zeros(3, padding_frames, height, width)], dim=1).to(device=device)
+ y = vae.encode([image])[0]
+
+ y = y[:, :latent_video_length] # may be not needed
+ y = y.unsqueeze(0) # add batch dim
+ image_latents = torch.concat([msk, y], dim=1)
+
+ if self.control_training:
+ # Control video
+ video = load_video(control_video_path, 0, frame_count, bucket_reso=(width, height)) # list of frames
+ video = np.stack(video, axis=0) # F, H, W, C
+ video = torch.from_numpy(video).permute(3, 0, 1, 2).float() # C, F, H, W
+ video = video / 127.5 - 1 # -1 to 1
+ video = video.to(device=device)
+
+ with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
+ control_latents = vae.encode([video])[0]
+ control_latents = control_latents[:, :latent_video_length]
+ control_latents = control_latents.unsqueeze(0) # add batch dim
+
+ # We supports Wan2.1-Fun-Control only
+ if image_latents is not None:
+ image_latents = image_latents[:, 4:] # remove mask for Wan2.1-Fun-Control
+ image_latents[:, :, 1:] = 0 # remove except the first frame
+ else:
+ image_latents = torch.zeros_like(control_latents) # B, C, F, H, W
+
+ image_latents = torch.concat([control_latents, image_latents], dim=1) # B, C, F, H, W
+
+ vae.to("cpu")
+ clean_memory_on_device(device)
+
+ # use the default value for num_train_timesteps (1000)
+ scheduler = FlowUniPCMultistepScheduler(shift=1, use_dynamic_shifting=False)
+ scheduler.set_timesteps(sample_steps, device=device, shift=discrete_flow_shift)
+ timesteps = scheduler.timesteps
+
+ # Generate noise for the required number of frames only
+ noise = torch.randn(16, latent_video_length, lat_h, lat_w, dtype=torch.float32, generator=generator, device=device).to(
+ "cpu"
+ )
+
+ # prepare the model input
+ max_seq_len = latent_video_length * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
+ arg_c = {"context": [context], "seq_len": max_seq_len}
+ arg_null = {"context": [context_null], "seq_len": max_seq_len}
+
+ if self.i2v_training:
+ arg_c["clip_fea"] = sample_parameter["clip_embeds"].to(device=device, dtype=dit_dtype)
+ arg_null["clip_fea"] = arg_c["clip_fea"]
+ if self.i2v_training or self.control_training:
+ arg_c["y"] = image_latents
+ arg_null["y"] = image_latents
+
+ # Wrap the inner loop with tqdm to track progress over timesteps
+ prompt_idx = sample_parameter.get("enum", 0)
+ latent = noise
+ with torch.no_grad():
+ for i, t in enumerate(tqdm(timesteps, desc=f"Sampling timesteps for prompt {prompt_idx+1}")):
+ latent_model_input = [latent.to(device=device)]
+ timestep = t.unsqueeze(0)
+
+ with accelerator.autocast():
+ noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to("cpu")
+ if do_classifier_free_guidance:
+ noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to("cpu")
+ else:
+ noise_pred_uncond = None
+
+ if do_classifier_free_guidance:
+ noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = noise_pred_cond
+
+ temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator)[0]
+ latent = temp_x0.squeeze(0)
+
+ # Move VAE to the appropriate device for sampling
+ vae.to(device)
+ vae.eval()
+
+ # Decode latents to video
+ logger.info(f"Decoding video from latents: {latent.shape}")
+ latent = latent.unsqueeze(0) # add batch dim
+ latent = latent.to(device=device)
+
+ with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
+ video = vae.decode(latent)[0] # vae returns list
+ video = video.unsqueeze(0) # add batch dim
+ del latent
+
+ logger.info(f"Decoding complete")
+ video = video.to(torch.float32).cpu()
+ video = (video / 2 + 0.5).clamp(0, 1) # -1 to 1 -> 0 to 1
+
+ vae.to("cpu")
+ clean_memory_on_device(device)
+
+ return video
+
+ def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
+ vae_path = args.vae
+
+ logger.info(f"Loading VAE model from {vae_path}")
+ cache_device = torch.device("cpu") if args.vae_cache_cpu else None
+ vae = WanVAE(vae_path=vae_path, device="cpu", dtype=vae_dtype, cache_device=cache_device)
+ return vae
+
+ def load_transformer(
+ self,
+ accelerator: Accelerator,
+ args: argparse.Namespace,
+ dit_path: str,
+ attn_mode: str,
+ split_attn: bool,
+ loading_device: str,
+ dit_weight_dtype: Optional[torch.dtype],
+ ):
+ model = load_wan_model(
+ self.config, accelerator.device, dit_path, attn_mode, split_attn, loading_device, dit_weight_dtype, args.fp8_scaled
+ )
+ return model
+
+ def scale_shift_latents(self, latents):
+ return latents
+
+ def call_dit(
+ self,
+ args: argparse.Namespace,
+ accelerator: Accelerator,
+ transformer,
+ latents: torch.Tensor,
+ batch: dict[str, torch.Tensor],
+ noise: torch.Tensor,
+ noisy_model_input: torch.Tensor,
+ timesteps: torch.Tensor,
+ network_dtype: torch.dtype,
+ ):
+ model: WanModel = transformer
+
+ # I2V training and Control training
+ image_latents = None
+ clip_fea = None
+ if self.i2v_training:
+ image_latents = batch["latents_image"]
+ image_latents = image_latents.to(device=accelerator.device, dtype=network_dtype)
+ clip_fea = batch["clip"]
+ clip_fea = clip_fea.to(device=accelerator.device, dtype=network_dtype)
+ if self.control_training:
+ control_latents = batch["latents_control"]
+ control_latents = control_latents.to(device=accelerator.device, dtype=network_dtype)
+ if image_latents is not None:
+ image_latents = image_latents[:, 4:] # remove mask for Wan2.1-Fun-Control
+ image_latents[:, :, 1:] = 0 # remove except the first frame
+ else:
+ image_latents = torch.zeros_like(control_latents) # B, C, F, H, W
+ image_latents = torch.concat([control_latents, image_latents], dim=1) # B, C, F, H, W
+ control_latents = None
+
+ context = [t.to(device=accelerator.device, dtype=network_dtype) for t in batch["t5"]]
+
+ # ensure the hidden state will require grad
+ if args.gradient_checkpointing:
+ noisy_model_input.requires_grad_(True)
+ for t in context:
+ t.requires_grad_(True)
+ if image_latents is not None:
+ image_latents.requires_grad_(True)
+ if clip_fea is not None:
+ clip_fea.requires_grad_(True)
+
+ # call DiT
+ lat_f, lat_h, lat_w = latents.shape[2:5]
+ seq_len = lat_f * lat_h * lat_w // (self.config.patch_size[0] * self.config.patch_size[1] * self.config.patch_size[2])
+ latents = latents.to(device=accelerator.device, dtype=network_dtype)
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
+ with accelerator.autocast():
+ model_pred = model(noisy_model_input, t=timesteps, context=context, clip_fea=clip_fea, seq_len=seq_len, y=image_latents)
+ model_pred = torch.stack(model_pred, dim=0) # list to tensor
+
+ # flow matching loss
+ target = noise - latents
+
+ return model_pred, target
+
+ # endregion model specific
+
+
+def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ """Wan2.1 specific parser setup"""
+ parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
+ parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
+ parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
+ parser.add_argument(
+ "--clip",
+ type=str,
+ default=None,
+ help="text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required",
+ )
+ parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
+ return parser
+
+
+if __name__ == "__main__":
+ parser = setup_parser_common()
+ parser = wan_setup_parser(parser)
+
+ args = parser.parse_args()
+ args = read_config_from_file(args, parser)
+
+ args.dit_dtype = None # automatically detected
+ if args.vae_dtype is None:
+ args.vae_dtype = "bfloat16" # make bfloat16 as default for VAE
+
+ trainer = WanNetworkTrainer()
+ trainer.train(args)