import os
import shutil
import random
import sys
import tempfile
from typing import Sequence, Mapping, Any, Union

import spaces
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

# Note: `from comfy import model_management` is deferred until after the ComfyUI
# directory has been added to sys.path (see the model-loading section below).
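
# Helper: download a checkpoint from the Hugging Face Hub and expose it under the
# models/<type>/ directory layout that ComfyUI expects. A symlink into the HF cache is
# used instead of a copy, so each multi-gigabyte file exists only once on disk.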
def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    base_filename = os.path.basename(filename)
    target_path = os.path.join(local_dir, base_filename)
    if os.path.exists(target_path) or os.path.islink(target_path):
        os.remove(target_path)
    os.symlink(downloaded_path, target_path)
    return target_path


# --- Model Downloads ---
print("Downloading models from Hugging Face Hub...")
text_encoder_repo = hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
print(text_encoder_repo)
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/vae/wan_2.1_vae.safetensors", local_dir="models/vae")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/clip_vision/clip_vision_h.safetensors", local_dir="models/clip_vision")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", local_dir="models/loras")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
print("Downloads complete.")


# --- Image Processing Functions ---
def calculate_video_dimensions(width, height, max_size=832, min_size=480):
    """
    Calculate video dimensions based on input image size.
    Larger dimension becomes max_size, smaller becomes proportional.
    If square, use min_size x min_size.
    Results are rounded to nearest multiple of 16.
    """
    # Handle square images
    if width == height:
        video_width = min_size
        video_height = min_size
    else:
        # Calculate aspect ratio
        aspect_ratio = width / height
        if width > height:
            # Landscape orientation
            video_width = max_size
            video_height = int(max_size / aspect_ratio)
        else:
            # Portrait orientation
            video_height = max_size
            video_width = int(max_size * aspect_ratio)
    # Round to nearest multiple of 16
    video_width = round(video_width / 16) * 16
    video_height = round(video_height / 16) * 16
    # Ensure minimum size
    video_width = max(video_width, 16)
    video_height = max(video_height, 16)
    return video_width, video_height
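
# Illustrative values with the defaults max_size=832 / min_size=480 (derived from the
# rounding rules above, shown for orientation only):
#   calculate_video_dimensions(1920, 1080) -> (832, 464)   # landscape
#   calculate_video_dimensions(1080, 1920) -> (464, 832)   # portrait
#   calculate_video_dimensions(1024, 1024) -> (480, 480)   # square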


def resize_and_crop_to_match(target_image, reference_image):
    """
    Resize and center crop target_image to match reference_image dimensions.
    """
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    # Calculate scaling factor to ensure target covers reference dimensions
    scale = max(ref_width / target_width, ref_height / target_height)
    # Resize target image
    new_width = int(target_width * scale)
    new_height = int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    # Center crop to match reference dimensions
    left = (new_width - ref_width) // 2
    top = (new_height - ref_height) // 2
    right = left + ref_width
    bottom = top + ref_height
    cropped = resized.crop((left, top, right, bottom))
    return cropped
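
# Example (assumed values, just to illustrate the math): with an 832x464 reference and
# a 1024x1024 target, scale = max(832/1024, 464/1024) = 0.8125, so the target is
# resized to 832x832 and then center-cropped to 832x464.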


# --- Boilerplate code from the original script ---
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Returns the value at the given index of a sequence or mapping.

    If the object is a sequence (like a list or string), returns the value at the given index.
    If the object is a mapping (like a dictionary), returns the value at the index-th key.
    Some custom nodes return a dictionary; in that case we look for the "result" key.

    Args:
        obj (Union[Sequence, Mapping]): The object to retrieve the value from.
        index (int): The index of the value to retrieve.

    Returns:
        Any: The value at the given index.

    Raises:
        IndexError: If the index is out of bounds for the object and the object is not a mapping.
    """
    try:
        return obj[index]
    except KeyError:
        # Fallback for custom node outputs that return a dictionary
        if isinstance(obj, Mapping) and "result" in obj:
            return obj["result"][index]
        raise
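
# ComfyUI node methods typically return their outputs as tuples (occasionally dicts
# with a "result" key), so every node output below is unpacked positionally, e.g.:
#   image_tensor = get_value_at_index(loadimage.load_image(image=path), 0)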


def find_path(name: str, path: str = None) -> str:
    """
    Recursively looks at parent folders starting from the given path until it finds the given name.
    Returns the path as a string if found, or None otherwise.
    """
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"'{name}' found: {path_name}")
        return path_name
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None
    return find_path(name, parent_directory)


def add_comfyui_directory_to_sys_path() -> None:
    """
    Add 'ComfyUI' to the sys.path
    """
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
    else:
        print("Could not find ComfyUI directory. Please run from a parent folder of ComfyUI.")


def add_extra_model_paths() -> None:
    """
    Parse the optional extra_model_paths.yaml file and register the extra model paths it defines.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        print(
            "Could not import load_extra_path_config from main.py. This might be okay if you don't use it."
        )
        return
    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find an optional 'extra_model_paths.yaml' config file.")


def import_custom_nodes() -> None:
    """Find all custom nodes in the custom_nodes folder and add them to NODE_CLASS_MAPPINGS.

    This function sets up a new asyncio event loop, initializes the PromptServer,
    creates a PromptQueue, and initializes the custom nodes.
    """
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    # Many custom nodes expect a PromptServer/PromptQueue instance to exist at import
    # time, so minimal instances are created here even though no web server is started.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    loop.run_until_complete(init_extra_nodes(init_custom_nodes=True))


# --- Model Loading and Caching ---
MODELS_AND_NODES = {}

print("Setting up ComfyUI paths...")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()

print("Importing custom nodes...")
import_custom_nodes()

# Now that paths are set up, we can import from nodes
from nodes import NODE_CLASS_MAPPINGS
import folder_paths  # module-level import, so it is globally accessible
from comfy import model_management  # deferred from the top of the file (needs ComfyUI on sys.path)
| print("Loading models into memory. This may take a few minutes...") | |
| # Load Text-to-Image models (CLIP, UNETs, VAE) | |
| cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]() | |
| MODELS_AND_NODES["clip"] = cliploader.load_clip( | |
| clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu" | |
| ) | |
| unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]() | |
| unet_low_noise = unetloader.load_unet( | |
| unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", | |
| weight_dtype="default", | |
| ) | |
| unet_high_noise = unetloader.load_unet( | |
| unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", | |
| weight_dtype="default", | |
| ) | |
| vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]() | |
| MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors") | |

# Load LoRAs
loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
MODELS_AND_NODES["model_low_noise"] = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
    strength_model=0.8,
    model=get_value_at_index(unet_low_noise, 0),
)
MODELS_AND_NODES["model_high_noise"] = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
    strength_model=0.8,
    model=get_value_at_index(unet_high_noise, 0),
)

# Load Vision model
clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(
    clip_name="clip_vision_h.safetensors"
)

# Instantiate all required node classes
MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
MODELS_AND_NODES["ModelSamplingSD3"] = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
MODELS_AND_NODES["PathchSageAttentionKJ"] = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()  # "Pathch" typo is the node's registered name in ComfyUI-KJNodes
MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()
| print("Pre-loading main models onto GPU...") | |
| model_loaders = [ | |
| MODELS_AND_NODES["clip"], | |
| MODELS_AND_NODES["vae"], | |
| MODELS_AND_NODES["model_low_noise"], # This is the UNET + LoRA | |
| MODELS_AND_NODES["model_high_noise"], # This is the other UNET + LoRA | |
| MODELS_AND_NODES["clip_vision"], | |
| ] | |
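
# Node loaders return (model,) tuples; some wrapped objects (e.g. the CLIP object) hold
# the underlying ModelPatcher in a `.patcher` attribute, which is what
# model_management.load_models_gpu expects, so it is unwrapped here when present.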
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
])
print("All models loaded successfully!")


# --- Main Video Generation Logic ---
@spaces.GPU(duration=120)  # ZeroGPU: allocate a GPU for each call (the 120 s budget is an assumed value)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    # The default negative prompt is the standard Chinese Wan prompt, roughly: garish colors,
    # overexposed, static, blurry details, subtitles, style, artwork, painting, still image,
    # overall gray, worst/low quality, JPEG artifacts, ugly, mutilated, extra fingers,
    # poorly drawn hands/face, deformed, disfigured, malformed limbs, fused fingers,
    # motionless frame, cluttered background, three legs, crowded background, walking backwards.
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
    duration=33,
    progress=gr.Progress(track_tqdm=True),
):
    """
    The main function to generate a video based on user inputs.
    This function is called every time the user clicks the 'Generate' button.
    """
    FPS = 16

    # Process images: resize and crop second image to match first
    # The first image determines the dimensions
    processed_start_image = start_image_pil.copy()
    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)

    # Calculate video dimensions based on the first image
    video_width, video_height = calculate_video_dimensions(
        processed_start_image.width,
        processed_start_image.height
    )
    print(f"Input image size: {processed_start_image.width}x{processed_start_image.height}")
    print(f"Video dimensions: {video_width}x{video_height}")

    clip = MODELS_AND_NODES["clip"]
    vae = MODELS_AND_NODES["vae"]
    model_low_noise = MODELS_AND_NODES["model_low_noise"]
    model_high_noise = MODELS_AND_NODES["model_high_noise"]
    clip_vision = MODELS_AND_NODES["clip_vision"]

    cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
    loadimage = MODELS_AND_NODES["LoadImage"]
    clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
    modelsamplingsd3 = MODELS_AND_NODES["ModelSamplingSD3"]
    pathchsageattentionkj = MODELS_AND_NODES["PathchSageAttentionKJ"]
    wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
    ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
    vaedecode = MODELS_AND_NODES["VAEDecode"]
    createvideo = MODELS_AND_NODES["CreateVideo"]
    savevideo = MODELS_AND_NODES["SaveVideo"]

    # Save processed images to temporary files
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as start_file, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as end_file:
        processed_start_image.save(start_file.name)
        processed_end_image.save(end_file.name)
        start_image_path = start_file.name
        end_image_path = end_file.name

    with torch.inference_mode():
        progress(0.1, desc="Encoding text and images...")

        # --- Workflow execution ---
        positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
        negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))

        start_image_loaded = loadimage.load_image(image=start_image_path)
        end_image_loaded = loadimage.load_image(image=end_image_path)

        clip_vision_encoded_start = clipvisionencode.encode(
            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0)
        )
        clip_vision_encoded_end = clipvisionencode.encode(
            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0)
        )

        progress(0.2, desc="Preparing initial latents...")
        initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
            width=video_width, height=video_height, length=duration, batch_size=1,
            positive=get_value_at_index(positive_conditioning, 0),
            negative=get_value_at_index(negative_conditioning, 0),
            vae=get_value_at_index(vae, 0),
            clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
            clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
            start_image=get_value_at_index(start_image_loaded, 0),
            end_image=get_value_at_index(end_image_loaded, 0),
        )

        progress(0.3, desc="Patching models...")
        model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_noise, 0))
        model_low_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
        model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_noise, 0))
        model_high_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))

        progress(0.5, desc="Running KSampler (Step 1/2)...")
        # First pass: the high-noise expert denoises steps 0-4 and keeps the leftover noise
        latent_step1 = ksampleradvanced.sample(
            add_noise="enable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
            return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
            positive=get_value_at_index(initial_latents, 0),
            negative=get_value_at_index(initial_latents, 1),
            latent_image=get_value_at_index(initial_latents, 2),
        )

        progress(0.7, desc="Running KSampler (Step 2/2)...")
        # Second pass: the low-noise expert finishes denoising from step 4 onward
        latent_step2 = ksampleradvanced.sample(
            add_noise="disable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
            return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
            positive=get_value_at_index(initial_latents, 0),
            negative=get_value_at_index(initial_latents, 1),
            latent_image=get_value_at_index(latent_step1, 0),
        )

        progress(0.8, desc="Decoding VAE...")
        decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))

        progress(0.9, desc="Creating and saving video...")
        video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))

        # Save the video to ComfyUI's output directory
        save_result = savevideo.save_video(
            filename_prefix="GradioVideo", format="mp4", codec="h264",
            video=get_value_at_index(video_data, 0),
        )

        progress(1.0, desc="Done!")
        return f"output/{save_result['ui']['images'][0]['filename']}"


css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''

with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) and the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) fast LoRA on ZeroGPU")
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame")
                    end_image = gr.Image(type="pil", label="End Frame")
                prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
            with gr.Accordion("Advanced Settings", open=False, visible=False):
                duration = gr.Radio(
                    [("Short (2s)", 33), ("Mid (4s)", 66)],
                    value=33,
                    label="Video Duration",
                    visible=False
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
                    visible=False
                )
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

    generate_button.click(
        fn=generate_video,
        inputs=[start_image, end_image, prompt, negative_prompt, duration],
        outputs=output_video
    )

    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["capyabara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=output_video,
        fn=generate_video,
        cache_examples="lazy",
    )


if __name__ == "__main__":
    app.launch(share=True)