xh365 commited on
Commit
a65a087
·
1 Parent(s): 185470f
__pycache__/live_preview_helpers.cpython-310.pyc CHANGED
Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ
 
__pycache__/optim_utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/optim_utils.cpython-310.pyc and b/__pycache__/optim_utils.cpython-310.pyc differ
 
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -21,7 +21,7 @@ from utils import (
21
  # =========================
22
  CLIP_MODEL = "ViT-H-14"
23
  PRETRAINED_CLIP = "laion2b_s32b_b79k"
24
- default_t2i_model = "black-forest-labs/FLUX.1-dev"
25
  default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 1024
@@ -37,7 +37,6 @@ llm_pipe = None
37
  torch.cuda.empty_cache()
38
  inverted_prompt = ""
39
 
40
- VERBAL_MSG = "Please explain your rating of satisfaction in few words or sentences."
41
  METHOD = "Experimental" # keep ONLY experimental
42
 
43
  # Global states for a single-task, single-method flow
@@ -45,6 +44,22 @@ counter = 1
45
  enable_submit = False
46
  responses_memory = {METHOD: {}}
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  # =========================
49
  # Image Generation Helpers
50
  # =========================
@@ -88,6 +103,7 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0
88
  def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
89
  seed = random.randint(0, MAX_SEED)
90
  client = init_gpt_api()
 
91
  messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
92
  outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
93
  return outputs
@@ -128,21 +144,16 @@ def check_evaluation(sim_radio):
128
  return False
129
  return True
130
 
131
- # =========================
132
- # Core Actions (single method)
133
- # =========================
134
  def generate_image(prompt, like_image, dislike_image):
135
  global responses_memory
136
  history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
137
  feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
138
-
139
- personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
140
- personalized = clean_refined_prompt_response_gpt(personalized)
141
- if "I'm sorry, I can't assist with" in personalized:
142
- personalized = prompt
143
-
144
  gallery_images = []
145
- # Experimental method refines prompts first
146
  refined_prompts = call_gpt_refine_prompt(personalized)
147
  for i in range(NUM_IMAGES):
148
  img = infer(refined_prompts[i])
@@ -239,19 +250,19 @@ css = """
239
 
240
  with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
241
  with gr.Column(elem_id="col-container"):
242
- gr.Markdown("# 📌 **PAI-GEN — Experimental Only**")
243
- instruction = gr.Markdown(INSTRUCTION)
244
 
245
- with gr.Tab("Task"):
246
  with gr.Row(elem_id="compact-row"):
247
  prompt = gr.Textbox(
248
  label="🎨 Revise Prompt",
249
  max_lines=5,
250
  placeholder="Enter your prompt",
251
- scale=4,
252
  visible=True,
253
  )
254
- next_btn = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
255
 
256
  with gr.Row(elem_id="compact-row"):
257
  with gr.Column(elem_id="col-container"):
@@ -282,14 +293,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
282
  elem_classes=["gradio-radio"]
283
  )
284
 
285
- response = gr.Textbox(
286
- label="Briefly explain your rating.",
287
- max_lines=1,
288
- interactive=False,
289
- container=False,
290
- value=VERBAL_MSG
291
- )
292
-
293
  with gr.Column(elem_id="col-container2"):
294
  example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
295
  history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")
@@ -298,6 +301,17 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
298
  redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
299
  submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
300
 
 
 
 
 
 
 
 
 
 
 
 
301
  # =========================
302
  # Wiring
303
  # =========================
 
21
  # =========================
22
  CLIP_MODEL = "ViT-H-14"
23
  PRETRAINED_CLIP = "laion2b_s32b_b79k"
24
+ default_t2i_model = "black-forest-labs/FLUX.1-schnell" # "black-forest-labs/FLUX.1-dev"
25
  default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 1024
 
37
  torch.cuda.empty_cache()
38
  inverted_prompt = ""
39
 
 
40
  METHOD = "Experimental" # keep ONLY experimental
41
 
42
  # Global states for a single-task, single-method flow
 
44
  enable_submit = False
45
  responses_memory = {METHOD: {}}
46
 
47
# Example (prompt, images) pairs backing the UI's "Examples" panel.
# Each entry pairs a starter prompt with pre-rendered images taken from
# the project-level IMAGES mapping (task name -> {"ours": ...}).
# NOTE(review): removed the leftover `print(example_data)` debug statement
# that dumped the whole table to stdout at import time.
example_data = [
    [
        "A futuristic city skyline at sunset",
        IMAGES["Tourist promotion"]["ours"],
    ],
    [
        "A fantasy castle in the clouds",
        IMAGES["Fictional character generation"]["ours"],
    ],
    [
        "A robot painting a portrait in a studio",
        IMAGES["Interior Design"]["ours"],
    ],
]
62
+
63
  # =========================
64
  # Image Generation Helpers
65
  # =========================
 
103
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
    """Ask GPT-4o to rewrite `prompt`, personalized to the user's history.

    Args:
        prompt: The user's current text-to-image prompt.
        history: Previously submitted prompts (list of str).
        feedback: Satisfaction ratings aligned one-to-one with `history`.
        like_image: Reference image the user liked — passed through to
            the message builder (exact format defined by
            get_personalize_message; presumably a path/PIL image — TODO confirm).
        dislike_image: Reference image the user disliked (same format).

    Returns:
        The raw GPT API response text containing the personalized prompt.
    """
    # Fresh random seed per call so repeated requests are not identical.
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
    # NOTE(review): removed leftover debug `print(like_image, dislike_image)`
    # — it dumped raw image payloads to stdout on every call.
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
    return outputs
 
144
  return False
145
  return True
146
 
 
 
 
147
  def generate_image(prompt, like_image, dislike_image):
148
  global responses_memory
149
  history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
150
  feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
151
+ personalized = prompt
152
+ # personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
153
+ # personalized = clean_refined_prompt_response_gpt(personalized)
154
+ # if "I'm sorry, I can't assist with" in personalized:
155
+ # personalized = prompt
 
156
  gallery_images = []
 
157
  refined_prompts = call_gpt_refine_prompt(personalized)
158
  for i in range(NUM_IMAGES):
159
  img = infer(refined_prompts[i])
 
250
 
251
  with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
252
  with gr.Column(elem_id="col-container"):
253
+ gr.Markdown("# 📌 **POET**")
254
+ instruction = gr.Markdown(" Supporting Prompting Creativity and Personalization with Automated Expansion of Text-to-Image Generation")
255
 
256
+ with gr.Tab(""):
257
  with gr.Row(elem_id="compact-row"):
258
  prompt = gr.Textbox(
259
  label="🎨 Revise Prompt",
260
  max_lines=5,
261
  placeholder="Enter your prompt",
262
+ scale=3,
263
  visible=True,
264
  )
265
+ next_btn = gr.Button("Generate", variant="primary", scale=1)
266
 
267
  with gr.Row(elem_id="compact-row"):
268
  with gr.Column(elem_id="col-container"):
 
293
  elem_classes=["gradio-radio"]
294
  )
295
 
 
 
 
 
 
 
 
 
296
  with gr.Column(elem_id="col-container2"):
297
  example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
298
  history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")
 
301
  redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
302
  submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
303
 
304
+ with gr.Column(elem_id="col-container2"):
305
+ gr.Markdown("### 🌟 Examples")
306
+ ex1 = gr.Image(label="Image 1", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
307
+ ex2 = gr.Image(label="Image 2", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
308
+ ex3 = gr.Image(label="Image 3", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
309
+ ex4 = gr.Image(label="Image 4", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
310
+
311
+ gr.Examples(
312
+ examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
313
+ inputs=[prompt, ex1, ex2, ex3, ex4]
314
+ )
315
  # =========================
316
  # Wiring
317
  # =========================
utils.py CHANGED
@@ -52,7 +52,7 @@ def clean_cache():
52
  def setup_model(t2i_model_repo, torch_dtype, device):
53
  if t2i_model_repo == "stabilityai/sdxl-turbo" or t2i_model_repo == "stabilityai/stable-diffusion-3.5-medium" or t2i_model_repo == "stabilityai/stable-diffusion-2-1":
54
  pipe = DiffusionPipeline.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
55
- elif t2i_model_repo == "black-forest-labs/FLUX.1-dev":
56
  # pipe = FluxPipeline.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
57
  pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
58
  torch.cuda.empty_cache()
@@ -171,7 +171,7 @@ def get_personalize_message(prompt, history_prompts, history_feedback, like_imag
171
  "url": f"data:image/png;base64,{dislike_image_base64}",
172
  },
173
  })
174
-
175
  return messages
176
 
177
  @spaces.GPU
 
52
  def setup_model(t2i_model_repo, torch_dtype, device):
53
  if t2i_model_repo == "stabilityai/sdxl-turbo" or t2i_model_repo == "stabilityai/stable-diffusion-3.5-medium" or t2i_model_repo == "stabilityai/stable-diffusion-2-1":
54
  pipe = DiffusionPipeline.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
55
+ elif t2i_model_repo in ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"):
56
  # pipe = FluxPipeline.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
57
  pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
58
  torch.cuda.empty_cache()
 
171
  "url": f"data:image/png;base64,{dislike_image_base64}",
172
  },
173
  })
174
+ # NOTE(review): debug print removed — `messages` may embed large base64 image payloads.
175
  return messages
176
 
177
  @spaces.GPU