# djessica's picture
# Rename app-todo.py to app.py
# 06bb5bc verified
import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import FlowMatchEulerDiscreteScheduler
from optimization import optimize_pipeline_
from diffusers import QwenImageEditPlusPipeline
import math
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
import os
import gradio as gr
from gradio_client import Client, handle_file
import tempfile
from typing import Optional, Tuple, Any
# --- Model Loading ---
# Weights are loaded in bfloat16; CUDA is used whenever torch reports it.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Configuration handed to FlowMatchEulerDiscreteScheduler.from_config below.
# base_shift and max_shift are both set to log(3).
scheduler_config = dict(
    base_image_seq_len=256,
    base_shift=math.log(3),
    invert_sigmas=False,
    max_image_seq_len=8192,
    max_shift=math.log(3),
    num_train_timesteps=1000,
    shift=1.0,
    shift_terminal=None,
    stochastic_sampling=False,
    time_shift_type="exponential",
    use_beta_sigmas=False,
    use_dynamic_shifting=True,
    use_exponential_sigmas=False,
    use_karras_sigmas=False,
)
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
# ================================================================================================
# added by DJ
# TEMP DIRECTORY MANAGEMENT
# System temp directory as reported by tempfile.gettempdir(); the file
# listing / cleanup helpers below search under this path.
TEMP_DIR = tempfile.gettempdir()
# Used to track generated videos.
# NOTE(review): nothing in the visible code appends to this list;
# clear_temp_files() only empties it.
generated_files = []
def find_gradio_temp_files(base_dirs: Optional[list] = None) -> list:
    """Collect image and video files found under the temp directories.

    Args:
        base_dirs: Directories to walk recursively. When ``None`` (the
            default, which is what the existing callers use) the candidates
            are ``TEMP_DIR/gradio``, ``TEMP_DIR/tmp`` and ``TEMP_DIR`` itself.

    Returns:
        Paths of the files whose extension (compared case-insensitively) is
        one of the known image or video types, in discovery order. Each file
        is listed exactly once even when the base directories overlap.
    """
    # Only image and video files are reported.
    media_extensions = (
        '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp',
        '.mp4', '.avi', '.mov', '.mkv',
    )

    if base_dirs is None:
        # Gradio usually creates a "gradio" sub-directory under the temp dir.
        base_dirs = [
            os.path.join(TEMP_DIR, "gradio"),
            os.path.join(TEMP_DIR, "tmp"),
            TEMP_DIR,
        ]

    found = []
    # The default candidates are nested (TEMP_DIR contains the other two), so
    # the same file is reached by more than one walk; remember what has been
    # reported already so counts shown to the user are not inflated.
    seen = set()
    for base_dir in base_dirs:
        if not os.path.exists(base_dir):
            continue
        try:
            for root, _dirs, files in os.walk(base_dir):
                for name in files:
                    if not name.lower().endswith(media_extensions):
                        continue
                    filepath = os.path.join(root, name)
                    key = os.path.normcase(os.path.abspath(filepath))
                    if key in seen:
                        continue
                    seen.add(key)
                    found.append(filepath)
        except PermissionError:
            # Best effort: skip a directory tree that cannot be read.
            continue
    return found
def list_temp_files():
    """Build a text report of the images and videos found in the temp dirs.

    The files come from ``find_gradio_temp_files()`` and are split by
    extension into images and videos. Every video is reported under the
    "generated videos" heading, whatever produced it. Paths are shown
    relative to ``TEMP_DIR`` together with their size in MB.

    Returns:
        The formatted report, or an error message string if anything raised
        while building it.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp')
    video_exts = ('.mp4', '.avi', '.mov', '.mkv')
    try:
        lines = [
            f"📁 臨時目錄路徑: {TEMP_DIR}\n",
            "=" * 60 + "\n\n",
        ]

        # Split the discovered files into images and videos.
        images, videos = [], []
        for path in find_gradio_temp_files():
            suffix = os.path.splitext(os.path.basename(path))[1].lower()
            if suffix in image_exts:
                images.append(path)
            elif suffix in video_exts:
                videos.append(path)

        # (files, heading when non-empty, text when empty) for each section.
        sections = (
            (
                images,
                f"📤 Gradio 暫存的圖片 ({len(images)} 個):\n",
                "📤 Gradio 暫存的圖片: 無\n\n",
            ),
            (
                videos,
                f"🎬 生成的影片 ({len(videos)} 個):\n",
                "🎬 生成的影片: 無\n\n",
            ),
        )
        for paths, heading, empty_text in sections:
            if not paths:
                lines.append(empty_text)
                continue
            lines.append(heading)
            lines.append("-" * 60 + "\n")
            for index, path in enumerate(paths, 1):
                # A file may have disappeared since it was discovered.
                if not os.path.exists(path):
                    continue
                megabytes = os.path.getsize(path) / (1024 * 1024)
                # Relative path makes the entry easier to identify.
                relative = os.path.relpath(path, TEMP_DIR)
                lines.append(f"{index}. {os.path.basename(path)}\n")
                lines.append(f" 📍 {relative} ({megabytes:.2f} MB)\n")
            lines.append("\n")

        # Summary line.
        total = len(images) + len(videos)
        lines.append("=" * 60 + "\n")
        lines.append(
            f"📊 統計: 圖片 {len(images)} 個 | 影片 {len(videos)} 個 | 總計 {total} 個"
        )
        return "".join(lines)
    except Exception as e:
        return f"❌ 錯誤: 無法讀取檔案列表 - {str(e)}"
def clear_temp_files():
    """Delete every file reported by ``find_gradio_temp_files()``.

    Each deletion is attempted independently; a failure is printed and
    counted, and does not stop the remaining deletions. The
    ``generated_files`` tracking list is emptied afterwards.

    Returns:
        A text summary with the number of deleted and failed files, or an
        error message string if the operation as a whole raised.
    """
    try:
        removed = 0
        failed = 0
        for target in find_gradio_temp_files():
            try:
                if not os.path.exists(target):
                    continue
                os.remove(target)
                removed += 1
            except Exception as exc:
                failed += 1
                print(f"無法刪除 {target}: {str(exc)}")

        # Reset the list of tracked generated files.
        generated_files.clear()

        report = [
            "🗑️ 清除完成!\n",
            "=" * 60 + "\n",
            f"✅ 成功刪除: {removed} 個檔案\n",
        ]
        if failed > 0:
            report.append(f"⚠️ 刪除失敗: {failed} 個檔案\n")
        report.append(f"\n📁 臨時目錄: {TEMP_DIR}\n")
        report.append("\n💡 提示: Gradio 暫存圖片和生成的影片已清除")
        return "".join(report)
    except Exception as e:
        return f"❌ 錯誤: 清除失敗 - {str(e)}"
# ================================================================================================
# Load the Qwen-Image-Edit-2509 pipeline with the scheduler built above, in
# `dtype`, and move it to `device`.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509",
    scheduler=scheduler,
    torch_dtype=dtype
).to(device)
# Load two LoRA adapters: "fast" (Qwen-Image-Lightning, 8-step weights) and
# "angles" (the Light-Migration weights).
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors",
    adapter_name="fast"
)
pipe.load_lora_weights(
    "dx8152/Qwen-Edit-2509-Light-Migration",
    weight_name="参考色调.safetensors",
    adapter_name="angles"
)
# Activate and fuse each adapter in turn at scale 1.0 ("angles" first, then
# "fast"), then unload the LoRA weights. The order of these calls matters.
pipe.set_adapters(["angles"], adapter_weights=[1.])
pipe.fuse_lora(adapter_names=["angles"], lora_scale=1.)
pipe.set_adapters(["fast"], adapter_weights=[1.])
pipe.fuse_lora(adapter_names=["fast"], lora_scale=1.)
pipe.unload_lora_weights()
#spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Qwen-Image", variant="fa3")
pipe.transformer.set_attention_backend("_flash_3_hub")
# optimize_pipeline_ comes from the local `optimization` module; it is called
# with two blank 1024x1024 RGB images and a placeholder prompt.
# NOTE(review): presumably these are example inputs for a warm-up/compile
# step — confirm against optimization.py.
optimize_pipeline_(
    pipe,
    image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))],
    prompt="prompt"
)
# Largest signed 32-bit integer; used as the upper bound for seeds.
MAX_SEED = np.iinfo(np.int32).max
# Default prompt for light migration
# (English: "Reference tone: remove the original lighting of image 1 and
# relight image 1 using the lighting and color tone of image 2.")
DEFAULT_PROMPT = "参考色调,移除图1原有的光照并参考图2的光照和色调对图1重新照明"
def _to_rgb(value: Any, label: str) -> Image.Image:
    """Return *value* as an RGB PIL image, or raise ``gr.Error`` if unsupported."""
    if isinstance(value, Image.Image):
        return value.convert("RGB")
    if hasattr(value, "name"):
        # Objects exposing ``.name`` are opened from that path
        # (presumably file-style uploads).
        return Image.open(value.name).convert("RGB")
    raise gr.Error(f"Unsupported input for {label}: {type(value).__name__}")


@spaces.GPU
def infer_light_migration(
    image: Optional[Image.Image] = None,
    light_source: Optional[Image.Image] = None,
    prompt: str = DEFAULT_PROMPT,
    seed: int = 0,
    randomize_seed: bool = True,
    true_guidance_scale: float = 1.0,
    num_inference_steps: int = 8,
    height: Optional[int] = None,
    width: Optional[int] = None,
    progress: Optional[gr.Progress] = gr.Progress(track_tqdm=True)
) -> Tuple[Image.Image, int]:
    """
    Transfer lighting and color tones from a reference image to a source image
    using Qwen Image Edit 2509 with the Light Migration LoRA.

    Args:
        image (PIL.Image.Image | None, optional):
            The source image to relight. Defaults to None.
        light_source (PIL.Image.Image | None, optional):
            The reference image providing the lighting and color tones.
            Defaults to None.
        prompt (str, optional):
            The prompt describing the lighting transfer operation.
            Defaults to the Chinese prompt for light migration.
        seed (int, optional):
            Random seed for the generation. Ignored if `randomize_seed=True`.
            Defaults to 0.
        randomize_seed (bool, optional):
            If True, a random seed (0..MAX_SEED) is chosen per call.
            Defaults to True.
        true_guidance_scale (float, optional):
            Passed to the pipeline as `true_cfg_scale`. Defaults to 1.0.
        num_inference_steps (int, optional):
            Number of inference steps. Defaults to 8.
        height (int, optional):
            Output image height. If 0 or None, `None` is passed to the
            pipeline so it chooses the size. Defaults to None.
        width (int, optional):
            Output image width. If 0 or None, `None` is passed to the
            pipeline so it chooses the size. Defaults to None.
        progress (gr.Progress, optional):
            Gradio progress tracker created with `track_tqdm=True`. It is not
            referenced in the function body.

    Returns:
        Tuple[PIL.Image.Image, int]:
            - The relit output image.
            - The actual seed used for generation.

    Raises:
        gr.Error: If either image is missing or is of an unsupported type.
    """
    if image is None:
        raise gr.Error("Please upload a source image (Image 1).")
    if light_source is None:
        raise gr.Error("Please upload a light source reference image (Image 2).")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed() expects an integer; coerce in case the UI delivers a float.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Prepare images - Image 1 is source, Image 2 is light reference.
    # Both must be present: an unsupported input raises instead of being
    # silently dropped from the list.
    pil_images = [
        _to_rgb(image, "Image 1"),
        _to_rgb(light_source, "Image 2"),
    ]

    result = pipe(
        image=pil_images,
        prompt=prompt,
        height=height or None,
        width=width or None,
        num_inference_steps=num_inference_steps,
        generator=generator,
        true_cfg_scale=true_guidance_scale,
        num_images_per_prompt=1,
    ).images[0]
    return result, seed
def update_dimensions_on_upload(
    image: Optional[Image.Image],
    *,
    max_side: int = 1024,
    min_side: int = 256,
    multiple: int = 8,
) -> Tuple[int, int]:
    """
    Compute recommended (width, height) for the output resolution when an
    image is uploaded while preserving the aspect ratio.

    The longer side is set to `max_side`; the other side is scaled by the
    aspect ratio, rounded down to a multiple of `multiple`, and never allowed
    to drop below `min_side` (the minimum of the width/height sliders), so
    very elongated images cannot yield a tiny or zero dimension.

    Args:
        image (PIL.Image.Image | None):
            The uploaded image. If `None`, defaults to
            (`max_side`, `max_side`).
        max_side (int, optional):
            Length given to the longer side. Defaults to 1024.
        min_side (int, optional):
            Lower bound applied to both sides. Defaults to 256.
        multiple (int, optional):
            Both sides are rounded down to a multiple of this. Defaults to 8.

    Returns:
        Tuple[int, int]:
            The new (width, height).
    """
    if image is None:
        return max_side, max_side

    original_width, original_height = image.size
    # A degenerate (zero-sized) image has no aspect ratio to preserve.
    if original_width <= 0 or original_height <= 0:
        return max_side, max_side

    def _fit(length):
        # Round down to the required multiple, then enforce the lower bound.
        snapped = (int(length) // multiple) * multiple
        return max(min_side, snapped)

    if original_width > original_height:
        new_width = max_side
        new_height = max_side * (original_height / original_width)
    else:
        new_height = max_side
        new_width = max_side * (original_width / original_height)

    return _fit(new_width), _fit(new_height)
# --- UI ---
# Custom CSS; passed to demo.launch() at the bottom of the file. The selectors
# match the elem_id / elem_classes values used by the components below.
css = '''
#col-container { max-width: 1000px; margin: 0 auto; }
.dark .progress-text { color: white !important }
#examples { max-width: 1000px; margin: 0 auto; }
.image-container { min-height: 300px; }
'''
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## 💡 Qwen Image Edit — Light Migration")
        gr.Markdown("""
Transfer lighting and color tones from a reference image to your source image ✨
Using [dx8152's Qwen-Edit-2509-Light-Migration LoRA](https://huggingface.co/dx8152/Qwen-Edit-2509-Light-Migration)
and [lightx2v/Qwen-Image-Lightning](https://huggingface.co/lightx2v/Qwen-Image-Lightning/tree/main) for 8-step inference 💨
""")
        with gr.Row():
            # Left column: the two input images, the run button and settings.
            with gr.Column():
                with gr.Row():
                    image = gr.Image(
                        label="Image 1 (Source - to be relit)",
                        type="pil",
                        elem_classes="image-container"
                    )
                    light_source = gr.Image(
                        label="Image 2 (Light Reference)",
                        type="pil",
                        elem_classes="image-container"
                    )
                run_btn = gr.Button("✨ Transfer Lighting", variant="primary", size="lg")
                # These controls map one-to-one onto the parameters of
                # infer_light_migration (see `inputs` below).
                with gr.Accordion("Advanced Settings", open=False):
                    prompt = gr.Textbox(
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                        placeholder="Enter prompt for light migration...",
                        lines=2
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0
                    )
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed",
                        value=True
                    )
                    true_guidance_scale = gr.Slider(
                        label="True Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.1,
                        value=1.0
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=1,
                        maximum=40,
                        step=1,
                        value=8
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=2048,
                        step=8,
                        value=1024
                    )
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=2048,
                        step=8,
                        value=1024
                    )
                # ================================================================================================
                # added by DJ
                # File management section: "Print" lists and "Clear" deletes
                # the image/video files found in the temp directories.
                gr.Markdown("---")
                gr.Markdown("### 📁 檔案管理")
                with gr.Row():
                    print_button = gr.Button("📋 Print", variant="secondary")
                    clear_button = gr.Button("🗑️ Clear", variant="stop")
                # Read-only text box that shows the output of both buttons.
                file_list_textbox = gr.Textbox(
                    label="列表",
                    lines=12,
                    placeholder="點擊 Print 按鈕查看檔案列表...\n\n📤 會顯示 Gradio 暫存的圖片\n🎬 會顯示生成的影片",
                    interactive=False
                )
                # ================================================================================================
            # Right column: the generated image.
            with gr.Column():
                result = gr.Image(label="Output Image", interactive=False)
                # output_seed = gr.Number(label="Seed Used", interactive=False, visible=False)
        # Example (source image, light reference) pairs. Only the two images
        # are supplied, so infer_light_migration runs with its defaults for
        # everything else; results are cached lazily.
        gr.Examples(
            examples=[
                # Character 1 with 3 different lights
                ["character_1.png", "light_1.png"],
                ["character_1.png", "light_3.jpeg"],
                ["character_1.png", "light_5.png"],
                # Character 2 with 3 different lights
                ["character_2.png", "light_2.png"],
                ["character_2.png", "light_4.png"],
                ["character_2.png", "light_6.png"],
                # Place 1 with 3 different lights
                ["place_1.png", "light_1.png"],
                ["place_1.png", "light_4.png"],
                ["place_1.png", "light_6.png"],
            ],
            inputs=[
                image, light_source
            ],
            outputs=[result, seed],
            fn=infer_light_migration,
            cache_examples=True,
            cache_mode="lazy",
            elem_id="examples"
        )
    # Component order must match the positional parameter order of
    # infer_light_migration.
    inputs = [
        image, light_source, prompt,
        seed, randomize_seed, true_guidance_scale,
        num_inference_steps, height, width
    ]
    # The seed slider is also an output, so it shows the seed actually used.
    outputs = [result, seed]
    # Run button click
    run_btn.click(
        fn=infer_light_migration,
        inputs=inputs,
        outputs=outputs
    )
    # ================================================================================================
    # added by DJ
    # Both file-management buttons take no inputs and write their text
    # result into the list textbox.
    print_button.click(fn=list_temp_files, inputs=None, outputs=file_list_textbox)
    clear_button.click(fn=clear_temp_files, inputs=None, outputs=file_list_textbox)
    # ================================================================================================
    # Image upload triggers dimension update
    # (update_dimensions_on_upload returns (width, height), matching the
    # order of `outputs` here.)
    image.upload(
        fn=update_dimensions_on_upload,
        inputs=[image],
        outputs=[width, height]
    )
    # API endpoint
    # gr.api(infer_light_migration, api_name="infer_light_migration")
# ================================================================================================
# modified by DJ
# The earlier launch call (kept below, commented out) also passed mcp_server=True.
# demo.launch(mcp_server=True, theme=gr.themes.Citrus(), css=css, footer_links=["api", "gradio", "settings"])
demo.launch(theme=gr.themes.Citrus(), css=css, footer_links=["api", "gradio", "settings"])
# ================================================================================================