# POET / app.py (Hugging Face Space file view: author xh365, commit a65a087 "update", 12.7 kB)
import gradio as gr
import numpy as np
import random
import spaces
import torch
import re
import transformers
# Optional: keep these utilities if your pipeline depends on them
from optim_utils import optimize_prompt
from utils import (
clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
get_refine_msg, clean_cache, get_personalize_message,
clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS # some may be unused after simplification
)
# =========================
# Constants / Defaults
# =========================
# CLIP model identifiers. Not referenced elsewhere in this file; presumably kept
# for the prompt-optimization utilities (TODO confirm, or remove if unused).
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
# Text-to-image checkpoint loaded once, at import time, by setup_model() below.
default_t2i_model = "black-forest-labs/FLUX.1-schnell" # "black-forest-labs/FLUX.1-dev"
# LLM checkpoint id; unused in this file (llm_pipe below is never loaded).
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
# Upper bound for the seeds drawn with random.randint (largest int32 value).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024  # not referenced in this file
NUM_IMAGES = 4  # images generated per round, one per refined prompt
MAX_ROUND = 5  # once the round counter reaches this, redesign() disables Redesign and enables Submit
# GPU + half precision when CUDA is available, otherwise CPU + float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Startup order matters: clear caches (project helper; presumably frees GPU
# memory, TODO confirm), load the pipeline that infer() calls, then release
# cached CUDA allocator blocks.
clean_cache()
selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
llm_pipe = None  # never assigned a model in this file
torch.cuda.empty_cache()
inverted_prompt = ""  # not referenced in this file
METHOD = "Experimental" # keep ONLY experimental
# Global states for a single-task, single-method flow
# NOTE(review): these module-level globals are mutated by the event handlers,
# so every browser session served by this process shares the same values.
counter = 1  # current round number (1-based); also the key into responses_memory[METHOD]
enable_submit = False  # set True once any round is rated Satisfied/Very Satisfied; reset by save_response
responses_memory = {METHOD: {}}  # METHOD -> {round number -> per-round record dict}
# Rows for the "Examples" panel. Each (prompt text, IMAGES scenario key) pair
# becomes [prompt, <that scenario's "ours" image collection>].
example_data = [
    [example_prompt, IMAGES[scenario]["ours"]]
    for example_prompt, scenario in (
        ("A futuristic city skyline at sunset", "Tourist promotion"),
        ("A fantasy castle in the clouds", "Fictional character generation"),
        ("A robot painting a portrait in a studio", "Interior Design"),
    )
]
print(example_data)
# =========================
# Image Generation Helpers
# =========================
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    """Render a single image for ``prompt`` with the module-level ``selected_pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed through unchanged.
        seed: Seed used only when ``randomize_seed`` is false.
        randomize_seed: When true, a fresh seed in ``[0, MAX_SEED]`` replaces ``seed``.
        width, height: Output size in pixels.
        guidance_scale, num_inference_steps: Forwarded to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``); not read in the body.

    Returns:
        The first image of the pipeline output.

    The ``spaces.GPU(duration=65)`` decorator presumably requests a ZeroGPU slot
    with a 65-second budget per call (TODO confirm against the ``spaces`` docs).
    """
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(effective_seed)
    pipe_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
    }
    # Inference only: no autograd bookkeeping.
    with torch.no_grad():
        output = selected_pipe(**pipe_kwargs)
    return output.images[0]
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Ask GPT-4o for ``num_prompts`` refined variants of ``prompt``.

    Draws a random request seed, builds the refinement messages with
    ``get_refine_msg``, and returns whatever ``clean_response_gpt`` extracts
    from the raw reply. ``generate_image`` indexes the result as a list of
    prompt strings.
    """
    request_seed = random.randint(0, MAX_SEED)
    api_client = init_gpt_api()
    refine_messages = get_refine_msg(prompt, num_prompts)
    raw_reply = call_gpt_api(
        refine_messages, api_client, "gpt-4o", request_seed, max_tokens, temperature, top_p
    )
    return clean_response_gpt(raw_reply)
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask GPT-4o to personalize ``prompt`` using earlier rounds and image feedback.

    Args:
        prompt: The user's current prompt.
        history: Earlier prompts, forwarded to ``get_personalize_message``.
        feedback: Earlier satisfaction ratings, forwarded likewise.
        like_image, dislike_image: The selected reference images (printed for debugging).

    Returns:
        The raw ``call_gpt_api`` output (not cleaned here).
    """
    request_seed = random.randint(0, MAX_SEED)
    api_client = init_gpt_api()
    print(like_image, dislike_image)  # debug trace of the chosen reference images
    personalize_messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    return call_gpt_api(
        personalize_messages,
        api_client,
        "gpt-4o",
        request_seed,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.9,
    )
# =========================
# UI Helper Functions
# =========================
def reset_gallery():
    """Return a fresh, empty value for a Gradio Gallery component."""
    return list()
def display_error_message(msg, duration=5):
    """Show ``msg`` as a Gradio warning toast for ``duration`` seconds."""
    toast_seconds = duration
    gr.Warning(msg, duration=toast_seconds)
def display_info_message(msg, duration=5):
    """Show ``msg`` as a Gradio info toast for ``duration`` seconds."""
    toast_seconds = duration
    gr.Info(msg, duration=toast_seconds)
def check_satisfaction(sim_radio):
    """Return a Gradio update that toggles whether the Submit button is clickable.

    Submit is enabled when the current rating is "Satisfied" or "Very Satisfied",
    when an earlier round already unlocked submission (module-level
    ``enable_submit``), or once the module-level ``counter`` exceeds ``MAX_ROUND``.
    Only reads module state; nothing is mutated.
    """
    rated_satisfied = sim_radio in ("Satisfied", "Very Satisfied")
    past_round_limit = counter > MAX_ROUND
    return gr.update(interactive=rated_satisfied or enable_submit or past_round_limit)
def select_image(like_radio, images_method, options=None):
    """Return the gallery image that corresponds to the chosen radio option.

    Args:
        like_radio: The selected option value (one of ``options``), or ``None``.
        images_method: Current Gallery value; each entry's first element is the
            image (e.g. ``(path, caption)``). May be ``None`` or empty before
            anything has been generated, or partially filled while
            ``generate_image`` is still streaming.
        options: Option labels to match against; defaults to ``IMAGE_OPTIONS``.
            Only the first four map to gallery slots, as in the 4-image grid.

    Returns:
        The selected image, or ``None`` when nothing matches or the slot is empty.
    """
    if options is None:
        options = IMAGE_OPTIONS
    choices = list(options)[:4]
    if like_radio not in choices:
        return None
    slot = choices.index(like_radio)
    # Guard against an empty or partially streamed gallery instead of raising
    # TypeError/IndexError inside the select handler.
    if not images_method or slot >= len(images_method):
        return None
    return images_method[slot][0]
def check_evaluation(sim_radio):
    """Return True if a satisfaction rating was given; otherwise warn and return False."""
    if sim_radio:
        return True
    display_error_message("❌ Please fill all evaluations before changing image or submitting.")
    return False
def generate_image(prompt, like_image, dislike_image):
    """Refine ``prompt`` with GPT and stream ``NUM_IMAGES`` generated images.

    Yields the growing gallery list after each image so the UI shows progress.

    Args:
        prompt: The user's prompt for this round.
        like_image, dislike_image: Accepted to match the event wiring. They are
            only needed by the personalization step, which is currently disabled.
    """
    # NOTE: LLM-based personalization is disabled; the user's prompt is refined
    # as-is. To re-enable it, replace the assignment below with:
    #   history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
    #   feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
    #   personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
    #   personalized = clean_refined_prompt_response_gpt(personalized)
    #   if "I'm sorry, I can't assist with" in personalized:
    #       personalized = prompt
    personalized = prompt

    # Assumes clean_response_gpt yields a sequence of prompt strings (TODO confirm).
    refined_prompts = list(call_gpt_refine_prompt(personalized) or [])
    # The LLM may return fewer variants than images requested. Fill the missing
    # slots with the prompt itself instead of failing with IndexError mid-stream.
    missing = NUM_IMAGES - len(refined_prompts)
    if missing > 0:
        refined_prompts.extend([personalized] * missing)

    gallery_images = []
    for refined in refined_prompts[:NUM_IMAGES]:
        gallery_images.append(infer(refined))
        yield gallery_images
def redesign(prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, like_image, dislike_image):
    """Archive the current round and prepare the UI for the next one.

    With a rating present, this:
    - stores the round in ``responses_memory``;
    - updates the ``enable_submit`` flag;
    - moves the current gallery into the history gallery;
    - refreshes the prompt-history examples;
    - re-enables the prompt box and Generate;
    - toggles Redesign/Submit depending on whether ``MAX_ROUND`` was reached;
    - advances ``counter``.

    Without a rating, only a warning is shown and outputs are left as-is via
    ``gr.skip()``. ``like_image`` / ``dislike_image`` are accepted for the event
    wiring but unused.
    """
    global counter, enable_submit, responses_memory
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    round_no = counter
    responses_memory[METHOD][round_no] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {round_no}, {like_radio}",
        "unsatisfied_img": f"round {round_no}, {dislike_radio}",
    }
    # Once any round is rated satisfied, submission stays unlocked.
    enable_submit = enable_submit or sim_radio in ("Satisfied", "Very Satisfied")

    # One-element rows, matching the [['']] sample format of the examples dataset.
    prompt_rows = [[record["prompt"]] for record in responses_memory[METHOD].values()]

    # Move this round's images into the history gallery.
    if not history_images:
        history_images = current_images
    elif current_images:
        history_images.extend(current_images)

    at_round_limit = round_no >= MAX_ROUND
    counter = round_no + 1
    return (
        None,  # sim_radio
        None,  # dislike_radio
        None,  # like_radio
        [],  # images_method (cleared for the next round)
        history_images,
        gr.update(samples=prompt_rows, visible=True),  # example.dataset
        gr.update(interactive=True),  # prompt
        gr.update(visible=True, interactive=True),  # next_btn
        gr.update(interactive=not at_round_limit),  # redesign_btn
        gr.update(interactive=at_round_limit or enable_submit),  # submit_btn
    )
def save_response(prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image):
    """Record the final round and lock the UI after Submit.

    With a rating present, the round is stored in ``responses_memory``, then:
    - ``counter`` and ``enable_submit`` are reset;
    - the three radios are cleared;
    - the prompt box, Generate, Redesign and Submit are all disabled;
    - a confirmation toast is shown.

    Without a rating, only a warning is shown and outputs are left untouched via
    ``gr.skip()``. ``like_image`` / ``dislike_image`` are accepted for the event
    wiring but unused.

    Returns:
        A 7-tuple for ``[sim_radio, dislike_radio, like_radio, prompt, next_btn,
        redesign_btn, submit_btn]``, or ``{submit_btn: gr.skip()}``.
    """
    global counter, enable_submit, responses_memory
    if not check_evaluation(sim_radio):
        return {submit_btn: gr.skip()}

    # Save the final round entry.
    responses_memory[METHOD][counter] = {
        "prompt": prompt,
        "sim_radio": sim_radio,
        "response": "",
        "satisfied_img": f"round {counter}, {like_radio}",
        "unsatisfied_img": f"round {counter}, {dislike_radio}",
    }
    # Reset per-task state.
    counter = 1
    enable_submit = False
    # Lock the controls.
    prompt_state = gr.update(interactive=False)
    next_state = gr.update(visible=False, interactive=False)
    submit_state = gr.update(interactive=False)
    redesign_state = gr.update(interactive=False)
    # Fixed: the toast previously showed mojibake ("βœ…") instead of the ✅ emoji.
    display_info_message("✅ Your answer is saved!")
    return None, None, None, prompt_state, next_state, redesign_state, submit_state
# =========================
# Interface (single tab, no participant/scenario/background)
# =========================
# Custom CSS passed to gr.Blocks(css=css). Each selector targets an elem_id in
# the layout:
#   #col-container   - centered column, at most 700px wide
#   #col-container2  - centered column, at most 1000px wide
#   #col-container3  - at most 300px wide; "margin: 0 0 auto auto" gives it an
#                      auto left margin, pushing it toward the right edge
#   #button-container - flex row with its buttons centered
#   #compact-row     - full-width row capped at 1000px, centered
# The string itself is runtime data, so comments cannot go inside it.
css = """
#col-container {
margin: 0 auto;
max-width: 700px;
}
#col-container2 {
margin: 0 auto;
max-width: 1000px;
}
#col-container3 {
margin: 0 0 auto auto;
max-width: 300px;
}
#button-container {
display: flex;
justify-content: center;
}
#compact-row {
width:100%;
max-width: 1000px;
margin: 0px auto;
}
"""
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    # NOTE(review): container nesting was reconstructed from a copy whose
    # indentation had been lost. The 700px title column and the 1000px rows and
    # columns are kept as siblings, so the wider sections are not clipped by the
    # 700px parent. Confirm against the intended layout.
    # Fixed: emoji in the title, the "Evaluation" header and the Submit label
    # were mojibake ("πŸ“Œ", "πŸ“", "βœ…").
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 📌 **POET**")
        instruction = gr.Markdown(" Supporting Prompting Creativity and Personalization with Automated Expansion of Text-to-Image Generation")

    with gr.Tab(""):
        # Prompt entry and the Generate button.
        with gr.Row(elem_id="compact-row"):
            prompt = gr.Textbox(
                label="🎨 Revise Prompt",
                max_lines=5,
                placeholder="Enter your prompt",
                scale=3,
                visible=True,
            )
            next_btn = gr.Button("Generate", variant="primary", scale=1)

        # Current round's images (700px) beside the chosen reference images (300px).
        with gr.Row(elem_id="compact-row"):
            with gr.Column(elem_id="col-container"):
                images_method = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")
            with gr.Column(elem_id="col-container3"):
                like_image = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")
                dislike_image = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath")

        # Per-round evaluation widgets.
        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 📝 Evaluation")
            sim_radio = gr.Radio(
                OPTIONS,
                label="How would you rate your satisfaction with the generated images?",
                type="value",
                elem_classes=["gradio-radio"]
            )
            like_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time favorite image (optional).",
                type="value",
                elem_classes=["gradio-radio"]
            )
            dislike_radio = gr.Radio(
                IMAGE_OPTIONS,
                label="Select your all-time least satisfactory image (optional).",
                type="value",
                elem_classes=["gradio-radio"]
            )

        # Prompt history (made visible and filled by redesign()) and earlier images.
        with gr.Column(elem_id="col-container2"):
            example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
            # NOTE(review): elem_id="gallery" is also used by images_method; HTML ids should be unique.
            history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")

        with gr.Row(elem_id="button-container"):
            redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
            submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)

        # Static examples. The hidden ex1..ex4 images exist only as Examples inputs.
        with gr.Column(elem_id="col-container2"):
            gr.Markdown("### 🌟 Examples")
            ex1 = gr.Image(label="Image 1", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex2 = gr.Image(label="Image 2", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex3 = gr.Image(label="Image 3", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            ex4 = gr.Image(label="Image 4", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
            gr.Examples(
                examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                inputs=[prompt, ex1, ex2, ex3, ex4]
            )

    # =========================
    # Wiring (event listeners must be registered inside the Blocks context)
    # =========================
    sim_radio.change(fn=check_satisfaction, inputs=[sim_radio], outputs=[submit_btn])
    dislike_radio.select(fn=select_image, inputs=[dislike_radio, images_method], outputs=[dislike_image])
    like_radio.select(fn=select_image, inputs=[like_radio, images_method], outputs=[like_image])
    # After a successful generation, lock Generate and the prompt until Redesign.
    next_btn.click(
        fn=generate_image,
        inputs=[prompt, like_image, dislike_image],
        outputs=[images_method]
    ).success(lambda: [gr.update(interactive=False), gr.update(interactive=False)], outputs=[next_btn, prompt])
    redesign_btn.click(
        fn=redesign,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, images_method, history_images, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn]
    )
    submit_btn.click(
        fn=save_response,
        inputs=[prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image],
        outputs=[sim_radio, dislike_radio, like_radio, prompt, next_btn, redesign_btn, submit_btn]
    )

if __name__ == "__main__":
    demo.launch()