# text2image / app.py
# LauraWWJ's picture
# Update app.py
# 887f7bc verified
import gradio as gr
import numpy as np
import random
import torch
import os
from diffusers import DiffusionPipeline
from openai import OpenAI
from PIL import Image, ImageDraw
# Build the OpenAI client (the API key is read from the environment).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

model_repo_id = "stabilityai/sdxl-turbo"

# Use the GPU with half precision when CUDA is available; otherwise fall back
# to the CPU with full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Load the diffusion pipeline once at import time and move it to the device.
pipe = DiffusionPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch_dtype,
).to(device)

# Upper bound for the seed slider (largest signed 32-bit integer) and for the
# image width/height sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# 1️⃣ Chinese topic → English drawing prompt
def generate_prompt(chinese_text):
    """Ask the chat model to turn a Chinese health topic into an English image prompt.

    The system prompt instructs the model to answer with a single English
    prompt; its worked example appends the negative terms after a ``--no``
    marker, so the returned text may contain such a clause.

    Args:
        chinese_text: The topic typed by the user; sent verbatim as the user
            message.

    Returns:
        The model's reply with surrounding whitespace removed.

    Raises:
        RuntimeError: If the API response carries no text (no choices, or a
            message whose ``content`` is ``None``).
    """
    system_prompt = (
        "你是一位專業的海報設計師和健康新知專家。你的任務是將使用者提供的健康主題,轉換為一個專為 SDXL 模型優化的提示詞(Prompt),以生成一張視覺上引人入勝且資訊清晰的健康新知海報。"
        "請確保提示詞包含以下元素:"
        "1. **海報風格**:必須明確指定為「專業的健康海報設計(professional health poster design)」、「資訊圖表(infographic)」或類似的風格。"
        "2. **核心主題**:將使用者提供的健康主題轉換為具體的視覺元素。例如,如果主題是「每日喝水的重要性」,請將其轉化為「一個水杯中充滿水的插畫(an illustration of a water glass full of water)」。"
        "3. **視覺細節**:加入能提升海報專業感的細節,例如「色彩鮮明(vibrant colors)」、「簡潔的設計(minimalist design)」、「高解析度(high resolution)」、「清晰的排版(clear typography)」。"
        "4. **畫面構圖**:利用提示詞引導構圖,例如「中央特寫(center focus)」或「清晰的背景(clean background)」。"
        "5. **負面提示詞**:自動加入常見的負面提示詞,以避免生成低品質、不專業的圖像。例如:`blurry, low quality, deformed, messy, amateur, text, watermark`。"
        "**格式要求**:"
        "* 只輸出一個完整的提示詞,不要有任何解釋或額外文字。"
        "* 請將正向提示詞與負面提示詞分開,中間用逗號隔開。"
        "* 所有提示詞都使用英文。"
        "**範例**:"
        "如果使用者輸入:`每日喝水的重要性`"
        "請輸出:`a professional health poster design about daily hydration, an illustration of a water glass full of water, clean and vibrant colors, minimalist style, high resolution, clear typography --no blurry, low quality, deformed, messy, amateur, text, watermark`"
    )

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": chinese_text},
        ],
        max_tokens=150,
        temperature=0.7,
    )

    # The message content is optional in the API response; fail with a clear
    # message instead of an AttributeError/IndexError further down.
    choices = response.choices
    content = choices[0].message.content if choices else None
    if content is None:
        raise RuntimeError("The chat model returned no prompt text.")
    return content.strip()
# 2️⃣ Main image-generation function
# Terms that are always sent as the negative prompt.
_BASE_NEGATIVE_PROMPT = (
    "text, words, letters, typography, logo, watermark, signature, "
    "messy details, clutter, objects in the center, elements overlapping the central blank area"
)


def _split_negative_clause(generated_prompt):
    """Split ``"<positive> --no <negative>"`` into ``(positive, negative)``.

    When the marker is missing (or nothing precedes it) the prompt is returned
    unchanged together with an empty negative part.
    """
    positive, marker, negative = generated_prompt.partition("--no ")
    positive = positive.strip().rstrip(",").rstrip()
    if not marker or not positive:
        return generated_prompt, ""
    return positive, negative.strip().strip(",").strip()


def _add_center_panel(image, panel_width, panel_height, alpha, radius):
    """Return *image* (RGBA) with a translucent white rounded box at its centre, as RGB."""
    # Size the overlay from the image itself so alpha_composite never sees
    # mismatching dimensions.
    img_width, img_height = image.size
    overlay = Image.new("RGBA", (img_width, img_height), (0, 0, 0, 0))

    center_x, center_y = img_width // 2, img_height // 2
    half_w, half_h = panel_width // 2, panel_height // 2
    ImageDraw.Draw(overlay).rounded_rectangle(
        [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h],
        radius=radius,
        fill=(255, 255, 255, int(alpha)),  # white + opacity
    )
    return Image.alpha_composite(image, overlay).convert("RGB")


def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    white_width,
    white_height,
    white_alpha,  # opacity
    corner_radius,  # rounded-corner radius
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a poster background for a topic and overlay a central white panel.

    Args:
        prompt: Topic text from the UI; forwarded to ``generate_prompt``.
        negative_prompt: Optional extra negative terms from the user.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, a random seed in ``[0, MAX_SEED]`` replaces ``seed``.
        width, height: Requested image size in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps passed to the pipeline.
        white_width, white_height: Size of the centred white panel in pixels.
        white_alpha: Panel opacity, used as the alpha channel of the fill colour.
        corner_radius: Corner radius of the panel.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Returns:
        A tuple ``(image, seed, final_prompt)``: the composited RGB image, the
        seed actually used, and the full text returned by ``generate_prompt``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # UI sliders may deliver numeric values as floats; the generator seed and
    # the pipeline arguments below need real integers.
    seed = int(seed)
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)
    generator = torch.Generator().manual_seed(seed)

    # 🔹 Let GPT produce the full English prompt.
    final_prompt = generate_prompt(prompt)

    # The generated text may end with a "--no ..." clause. Passing that clause
    # as part of the positive prompt would ask the model for exactly the
    # unwanted content, so it is moved to the negative prompt instead.
    positive_prompt, generated_negative = _split_negative_clause(final_prompt)

    # Reinforced negative prompt: fixed terms, then the user's, then GPT's.
    negative_terms = [_BASE_NEGATIVE_PROMPT, negative_prompt, generated_negative]
    strong_negative = ", ".join(term for term in negative_terms if term)

    # Generate the image.
    image = pipe(
        prompt=positive_prompt,
        negative_prompt=strong_negative,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0].convert("RGBA")

    # 🔹 Draw the translucent white panel on a transparent layer and composite it.
    image = _add_center_panel(image, white_width, white_height, white_alpha, corner_radius)

    # The full GPT output (including any "--no" clause) is returned for display.
    return image, seed, final_prompt
# Example topics offered below the input box.
examples = [
    "多吃水果促進健康",
    "運動對小朋友的好處",
    "保持醫院清潔的重要性",
]

# Page CSS: centre the main column and cap its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
# 3️⃣ Gradio UI
# NOTE(review): the original indentation was lost; the nesting below follows
# the conventional layout (all widgets inside the column, event wiring inside
# the Blocks context) — confirm against the deployed app.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # 健康布告欄產生器 🩺")

        # Topic input and the button that triggers generation.
        with gr.Row():
            prompt = gr.Text(
                label="輸入中文主題",
                max_lines=2,
                placeholder="輸入布告欄主題,例如:多吃水果促進健康",
            )
            run_button = gr.Button("產生圖像", variant="primary")

        # Outputs: the generated image and the AI-written English prompt.
        result = gr.Image(label="Result", show_label=False)
        prompt_box = gr.Textbox(
            label="英文繪圖 Prompt(AI 產生)",
            interactive=False,
        )

        # Advanced settings, collapsed by default.
        with gr.Accordion("進階設定", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="額外要避免的內容",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="隨機種子", value=True)

            # Image size.
            with gr.Row():
                width = gr.Slider(
                    label="寬度",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=384,
                )
                height = gr.Slider(
                    label="高度",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=384,
                )

            # Sampler settings.
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=15.0,
                    step=0.5,
                    value=8,
                )
                num_inference_steps = gr.Slider(
                    label="步數",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=25,
                )

            # Size of the central white panel.
            with gr.Row():
                white_width = gr.Slider(
                    label="白底寬度",
                    minimum=50,
                    maximum=800,
                    step=10,
                    value=300,
                )
                white_height = gr.Slider(
                    label="白底高度",
                    minimum=50,
                    maximum=800,
                    step=10,
                    value=300,
                )

            # Opacity and corner radius of the white panel.
            with gr.Row():
                white_alpha = gr.Slider(
                    label="白底透明度 (0透明-255不透明)",
                    minimum=0,
                    maximum=255,
                    step=5,
                    value=210,
                )
                corner_radius = gr.Slider(
                    label="圓角大小",
                    minimum=0,
                    maximum=400,
                    step=5,
                    value=50,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    # Components passed to infer(), in the order of its parameters.
    infer_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        white_width,
        white_height,
        white_alpha,
        corner_radius,
    ]
    # Components that receive infer()'s three return values.
    infer_outputs = [result, seed, prompt_box]

    # Run on button click or when the topic textbox is submitted.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=infer_inputs,
        outputs=infer_outputs,
    )

if __name__ == "__main__":
    demo.launch()