# Hugging Face Space status: Running on Zero (ZeroGPU).
import os
import subprocess
import sys

# PyTorch 2.8 (temporary hack): upgrade torch (pinned below 2.9, pre-releases
# allowed, extra index = PyTorch nightly cu126 wheels) and `spaces` before they
# are imported further down in this file.
# `sys.executable -m pip` installs into the interpreter that runs this app, and
# an argument list avoids the shell (so "torch<2.9" needs no quoting).
# A failed install is reported but is not fatal: the app then continues with
# whatever versions are already installed.
_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "--upgrade", "--pre",
        "--extra-index-url", "https://download.pytorch.org/whl/nightly/cu126",
        "torch<2.9", "spaces",
    ],
    check=False,
)
if _pip_result.returncode != 0:
    print(
        f"WARNING: pip install of torch<2.9/spaces exited with code "
        f"{_pip_result.returncode}; continuing with the installed versions."
    )
# --- 1. Model Download and Setup (Diffusers Backend) --- | |
import spaces | |
import torch | |
from diffusers import FlowMatchEulerDiscreteScheduler | |
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline | |
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel | |
from diffusers.utils.export_utils import export_to_video | |
import gradio as gr | |
import tempfile | |
import numpy as np | |
from PIL import Image | |
import random | |
import gc | |
from gradio_client import Client, handle_file # Import for API call | |
# Import the optimization function from the separate file | |
from optimization import optimize_pipeline_ | |
# --- Constants and Model Loading ---
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

# --- Output resolution bounds (pixels) ---
MAX_DIMENSION = 832        # upper bound for the longest side
MIN_DIMENSION = 480        # lower bound for the shortest side
DIMENSION_MULTIPLE = 16    # both sides are rounded to a multiple of this
SQUARE_SIZE = 480          # fixed side length used for square inputs

# Largest seed selectable in the UI (the int32 maximum).
MAX_SEED = np.iinfo(np.int32).max

# --- Timing: frame rate, frame-count bounds, and the durations they imply ---
FIXED_FPS = 16
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 81
# Durations in seconds, rounded to one decimal (used as slider bounds).
MIN_DURATION, MAX_DURATION = (
    round(MIN_FRAMES_MODEL / FIXED_FPS, 1),
    round(MAX_FRAMES_MODEL / FIXED_FPS, 1),
)

# Default negative prompt (Chinese). Roughly: garish colors, overexposed, static,
# blurry details, subtitles, style, artwork, painting, frame, still, grayish
# overall, worst quality, low quality, JPEG artifacts, ugly, mutilated, extra
# fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed
# limbs, fused fingers, motionless frame, cluttered background, three legs,
# crowded background, walking backwards, overexposed.
default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
print("Loading models into memory. This may take a few minutes...")


def _load_wan_transformer(subfolder):
    """Load one bf16 Wan2.2 I2V transformer from the cbensimon repo straight onto CUDA.

    Args:
        subfolder: Repository subfolder holding the weights
            ('transformer' or 'transformer_2').

    Returns:
        The loaded ``WanTransformer3DModel``.
    """
    return WanTransformer3DModel.from_pretrained(
        'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
        subfolder=subfolder,
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    )


# The pipeline takes two transformers (`transformer`, `transformer_2`); both
# are replaced with the bf16 weights above. The UI labels their guidance
# scales "high noise" and "low noise" respectively.
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=_load_wan_transformer('transformer'),
    transformer_2=_load_wan_transformer('transformer_2'),
    torch_dtype=torch.bfloat16,
)
# Same scheduler config, but with the flow-matching shift overridden to 8.0.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=8.0)
pipe.to('cuda')

print("Optimizing pipeline...")
# Free Python garbage and cached CUDA memory before optimization; repeated
# three times as in the original setup.
for _ in range(3):
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# `optimize_pipeline_` lives in optimization.py (not shown here). It is given a
# blank landscape image at the largest size (MAX_DIMENSION x MIN_DIMENSION) and
# the maximum frame count as its example inputs.
optimize_pipeline_(pipe,
    image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),
    prompt='prompt',
    height=MIN_DIMENSION,
    width=MAX_DIMENSION,
    num_frames=MAX_FRAMES_MODEL,
)
print("All models loaded and optimized. Gradio app is ready.")
# --- 2. Image Processing and Application Logic --- | |
def generate_end_frame(start_img, gen_prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate an end frame from *start_img* by calling an external Gradio Space.

    The start frame is written to a temporary PNG and sent, together with
    *gen_prompt* and the server's ``HF_TOKEN``, to the
    ``/unified_image_generator`` endpoint of ``multimodalart/nano-banana-private``.
    The temporary file is always deleted afterwards.

    Args:
        start_img: Start frame as a PIL image; ``None`` is rejected.
        gen_prompt: Edit instruction forwarded to the remote image model.
        progress: Gradio progress tracker.

    Returns:
        The raw value returned by ``Client.predict`` (its exact form depends on
        the remote endpoint; presumably something the output ``gr.Image`` can
        display — TODO confirm).

    Raises:
        gr.Error: If no start frame is given or ``HF_TOKEN`` is not set.
    """
    if start_img is None:
        raise gr.Error("Please provide a Start Frame first.")
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise gr.Error("HF_TOKEN not found in environment variables. Please set it in your Space secrets.")

    # Only reserve a unique path here; the handle is closed before PIL writes
    # to it (re-opening an open NamedTemporaryFile by name fails on Windows).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        tmp_path = tmpfile.name
    # Everything after the file exists is inside try/finally, so the PNG is
    # removed even when saving, connecting to the Space, or predicting fails.
    try:
        start_img.save(tmp_path)
        progress(0.1, desc="Connecting to image generation API...")
        client = Client("multimodalart/nano-banana-private")
        progress(0.5, desc=f"Generating with prompt: '{gen_prompt}'...")
        result = client.predict(
            prompt=gen_prompt,
            images=[
                {"image": handle_file(tmp_path)}
            ],
            manual_token=hf_token,
            api_name="/unified_image_generator"
        )
    finally:
        os.remove(tmp_path)
    progress(1.0, desc="Done!")
    print(result)
    return result
def switch_to_upload_tab():
    """Build a ``gr.Tabs`` update that selects the "Upload" tab (id ``upload_tab``).

    Used as the first step of the "Generate scene 5 seconds in the future"
    chain, so the end-frame image component is the visible tab when it is filled.
    """
    upload_tab_id = "upload_tab"
    return gr.Tabs(selected=upload_tab_id)
def _round_to_multiple(value):
    """Round *value* to the nearest multiple of DIMENSION_MULTIPLE, as an int."""
    return int(round(value / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)


def _clamp_dimension(value):
    """Clamp a side length into the inclusive range [MIN_DIMENSION, MAX_DIMENSION]."""
    return min(max(value, MIN_DIMENSION), MAX_DIMENSION)


def process_image_for_video(image: Image.Image) -> Image.Image:
    """
    Resize an image to dimensions suitable for video generation.

    Rules:
    1. The longest side is scaled down to MAX_DIMENSION if it is larger.
    2. The shortest side is scaled up to MIN_DIMENSION if it is smaller.
    3. Both sides are rounded to the nearest multiple of DIMENSION_MULTIPLE.
    4. Square images are resized to a fixed SQUARE_SIZE x SQUARE_SIZE.

    The aspect ratio is preserved as closely as the bounds allow. When the
    long/short ratio exceeds MAX_DIMENSION / MIN_DIMENSION, rules 1 and 2
    cannot both hold. Each side is then clamped to
    [MIN_DIMENSION, MAX_DIMENSION], which squeezes the image along its long
    axis instead of letting it exceed MAX_DIMENSION.

    Args:
        image: Source PIL image.

    Returns:
        A new PIL image resized with LANCZOS resampling.
    """
    width, height = image.size

    # Rule 4: square inputs get a fixed size.
    if width == height:
        return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)

    is_landscape = width > height
    new_width, new_height = width, height

    # Rule 1: scale down so the longest side fits within MAX_DIMENSION.
    if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
        scale = MAX_DIMENSION / (new_width if is_landscape else new_height)
        new_width *= scale
        new_height *= scale

    # Rule 2: scale up so the shortest side reaches MIN_DIMENSION.
    if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
        scale = MIN_DIMENSION / (new_height if is_landscape else new_width)
        new_width *= scale
        new_height *= scale

    # Rule 3, then enforce the bounds. The clamp only changes anything when
    # rule 2 pushed the long side past MAX_DIMENSION; for example, a
    # 1920x1080 input would otherwise end up at 848x480.
    final_width = _clamp_dimension(_round_to_multiple(new_width))
    final_height = _clamp_dimension(_round_to_multiple(new_height))
    return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
def resize_and_crop_to_match(target_image, reference_image):
    """Scale *target_image* to cover *reference_image*'s size, then center-crop to exactly that size.

    One uniform scale factor (the larger of the two per-axis ratios) is used,
    so the target's aspect ratio is kept and any excess along one axis is
    trimmed equally from both sides.

    Args:
        target_image: PIL image to adapt (here, the end frame).
        reference_image: PIL image whose (width, height) the result must match.

    Returns:
        A new PIL image whose size equals ``reference_image.size``.
    """
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    # Round rather than truncate, and never drop below the reference size.
    # Float error can make target * scale land just under the reference
    # value; int() would truncate that to ref - 1. The crop box below would
    # then extend past the image edge, and PIL fills that region instead of
    # using real pixels.
    new_width = max(ref_width, round(target_width * scale))
    new_height = max(ref_height, round(target_height * scale))
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left = (new_width - ref_width) // 2
    top = (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))
# ZeroGPU only grants a GPU inside functions decorated with `spaces.GPU`; the
# decorator is documented to have no effect outside ZeroGPU.
# NOTE(review): this uses the default GPU time budget; long requests (many
# steps / maximum duration) may need `@spaces.GPU(duration=...)` — TODO tune.
@spaces.GPU
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=2.1,
    steps=8,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Generate a video that moves from a start image to an end image, guided by a
    text prompt, using the diffusers Wan2.2 image-to-video pipeline (`pipe`).

    The start image is resized by `process_image_for_video`. The end image is
    then resized and center-cropped to exactly the same size.

    Args:
        start_image_pil: First frame (PIL image).
        end_image_pil: Last frame (PIL image).
        prompt: Text describing the transition.
        negative_prompt: Passed to the pipeline as `negative_prompt`.
        duration_seconds: Requested length in seconds. It is converted to
            frames at FIXED_FPS and clamped to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL].
        steps: Number of inference steps.
        guidance_scale: Passed as `guidance_scale` ("high noise" in the UI).
        guidance_scale_2: Passed as `guidance_scale_2` ("low noise" in the UI).
        seed: Seed used when `randomize_seed` is False.
        randomize_seed: If True, draw a random seed in [0, MAX_SEED].
        progress: Gradio progress tracker.

    Returns:
        tuple: (path of the exported .mp4 file, seed that was actually used).

    Raises:
        gr.Error: If either image is missing.
    """
    if start_image_pil is None or end_image_pil is None:
        raise gr.Error("Please upload both a start and an end image.")
    progress(0.1, desc="Preprocessing images...")
    # The processed start image fixes the output resolution; the end image is
    # then fitted to exactly the same width and height.
    processed_start_image = process_image_for_video(start_image_pil)
    processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)
    target_height, target_width = processed_start_image.height, processed_start_image.width

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # np.clip returns a NumPy scalar; int() hands the pipeline a plain Python int.
    num_frames = int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")
    output_frames_list = pipe(
        image=processed_start_image,
        last_image=processed_end_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=target_height,
        width=target_width,
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        guidance_scale_2=float(guidance_scale_2),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0]

    progress(0.9, desc="Encoding and saving video...")
    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
    progress(1.0, desc="Done!")
    return video_path, current_seed
# --- 3. Gradio User Interface --- | |
# Custom CSS passed to gr.Blocks below. It caps the layout width at 1100px,
# lets the start/end-frame group overflow its container, and pulls the
# Upload/Generate tab bar up over the group with an absolutely positioned,
# negative-margin wrapper.
# NOTE(review): `#component-9-button` / `#component-11-button` look like
# Gradio's auto-generated `component-N` ids (presumably the two tab buttons),
# which likely depend on component creation order. Adding or reordering UI
# components may silently break these rules; prefer explicit elem_ids — TODO.
css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
#general_items{margin-top: 2em}
#group_all{overflow:visible}
#group_all .styler{overflow:visible}
#group_tabs .tabitem{padding: 0}
.tab-wrapper{margin-top: -33px;z-index: 999;position: absolute;width: 100%;background-color: var(--block-background-fill);padding: 0;}
#component-9-button{width: 50%;justify-content: center}
#component-11-button{width: 50%;justify-content: center}
#or_item{text-align: center; padding-top: 1em; padding-bottom: 1em; font-size: 1.1em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
#fivesec{margin-top: 5em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
'''
# Gradio UI. Left column: start frame, end frame (uploaded or generated),
# prompt, advanced settings and the Generate button. Right column: the
# output video.
# NOTE(review): component creation order here appears to determine the
# auto-generated `component-N` ids targeted by `css`; adding, removing or
# reordering components may break that styling.
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Based on the [Wan 2.2 First/Last Frame workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/), applied to 🧨 Diffusers + [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA")
    with gr.Row(elem_id="general_items"):
        with gr.Column():
            with gr.Group(elem_id="group_all"):
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"])
                    # Keep the Tabs component in `tabs` so switch_to_upload_tab can select
                    # a tab by id ("upload_tab").
                    with gr.Tabs(elem_id="group_tabs") as tabs:
                        with gr.TabItem("Upload", id="upload_tab"):
                            end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])
                        with gr.TabItem("Generate", id="generate_tab"):
                            generate_5seconds = gr.Button("Generate scene 5 seconds in the future", elem_id="fivesec")
                            gr.Markdown("Generate a custom end-frame with an edit model like [Nano Banana](https://huggingface.co/spaces/multimodalart/nano-banana) or [Qwen Image Edit](https://huggingface.co/spaces/multimodalart/Qwen-Image-Edit-Fast)", elem_id="or_item")
            prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
            with gr.Accordion("Advanced Settings", open=False):
                # Slider bounds come from the frame limits at FIXED_FPS; generate_video
                # clamps the resulting frame count again.
                duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=2.1, label="Video Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise")
                with gr.Row():
                    seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)
    # Main video generation button
    # Order must match generate_video's positional parameters.
    ui_inputs = [
        start_image,
        end_image,
        prompt,
        negative_prompt_input,
        duration_seconds_input,
        steps_slider,
        guidance_scale_input,
        guidance_scale_2_input,
        seed_input,
        randomize_seed_checkbox
    ]
    # generate_video returns (video_path, seed); writing the seed back shows the
    # value actually used when it was randomized.
    ui_outputs = [output_video, seed_input]
    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=ui_outputs
    )
    # "5 seconds in the future" chain:
    # 1) switch to the Upload tab so end_image is visible;
    # 2) generate the end frame from start_image via the external API;
    # 3) `.success`: generate the video only if step 2 did not raise.
    # The lambda fixes the edit prompt. It has no gr.Progress parameter, so
    # progress for step 2 is presumably not tracked — TODO confirm.
    generate_5seconds.click(
        fn=switch_to_upload_tab,
        inputs=None,
        outputs=[tabs]
    ).then(
        fn=lambda img: generate_end_frame(img, "this image is a still frame from a movie. generate a new frame with what happens on this scene 5 seconds in the future"),
        inputs=[start_image],
        outputs=[end_image]
    ).success(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=ui_outputs
    )
    # Examples only supply start image, end image and prompt; generate_video
    # falls back to its own defaults for the remaining parameters.
    # cache_examples="lazy" presumably computes each example's output on first
    # use and caches it afterwards.
    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
            ["capyabara_zoomed.png", "capyabara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=ui_outputs,
        fn=generate_video,
        cache_examples="lazy",
    )
if __name__ == "__main__":
    # share=True is intentionally not passed. On Hugging Face Spaces Gradio
    # ignores it (with a warning). On a local run it would publish a public
    # share URL whose visitors could trigger generate_end_frame, which spends
    # this server's HF_TOKEN. Pass share=True explicitly if a public link is
    # really wanted.
    app.launch()