Commit 70f494f · Parent(s): 13ebcb9
xh365 committed

update interface

__pycache__/live_preview_helpers.cpython-310.pyc CHANGED
Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ
 
__pycache__/optim_utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/optim_utils.cpython-310.pyc and b/__pycache__/optim_utils.cpython-310.pyc differ
 
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -6,6 +6,7 @@ import spaces
 import torch
 import re
 import transformers
+import open_clip
 
 # Optional: keep these utilities if your pipeline depends on them
 from optim_utils import optimize_prompt
@@ -33,17 +34,15 @@ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 clean_cache()
 
 selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
+clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
 llm_pipe = None
-torch.cuda.empty_cache()
 inverted_prompt = ""
+torch.cuda.empty_cache()
 
 METHOD = "Experimental" # keep ONLY experimental
-
-# Global states for a single-task, single-method flow
 counter = 1
 enable_submit = False
 responses_memory = {METHOD: {}}
-
 example_data = [
     [
         PROMPTS["Tourist promotion"],
@@ -58,7 +57,6 @@ example_data = [
         IMAGES["Interior Design"]["ours"]
     ],
 ]
-print(example_data)
 
 # =========================
 # Image Generation Helpers
@@ -103,11 +101,31 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0
 def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
     seed = random.randint(0, MAX_SEED)
     client = init_gpt_api()
-    print(like_image, dislike_image)
     messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
     outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
     return outputs
 
+@spaces.GPU(duration=100)
+def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2):
+    global inverted_prompt
+    text_params = {
+        "iter": iter,
+        "lr": lr,
+        "batch_size": batch_size,
+        "prompt_len": prompt_len,
+        "weight_decay": 0.1,
+        "prompt_bs": 1,
+        "loss_weight": 1.0,
+        "print_step": 100,
+        "clip_model": CLIP_MODEL,
+        "clip_pretrain": PRETRAINED_CLIP,
+    }
+    inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
+    print(inverted_prompt)
+
+    # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
+    # return learned_prompt
+
 # =========================
 # UI Helper Functions
 # =========================
@@ -241,29 +259,98 @@ css = """
     display: flex;
     justify-content: center;
 }
+#compact-compact-row {
+    width:100%;
+    max-width: 800px;
+    margin: 0px auto;
+}
 #compact-row {
     width:100%;
     max-width: 1000px;
     margin: 0px auto;
 }
+.header-section {
+    text-align: center;
+    margin-bottom: 2rem;
+}
+.abstract-text {
+    text-align: justify;
+    line-height: 1.6;
+    margin: 0.5rem 0;
+    padding: 0.5rem;
+    background-color: rgba(0, 0, 0, 0.05);
+    border-radius: 8px;
+    border-left: 4px solid #3498db;
+}
+.paper-link {
+    display: inline-block;
+    margin: 0rem 0;
+    padding: 0rem 0rem;
+    background-color: #3498db;
+    color: white;
+    text-decoration: none;
+    border-radius: 5px;
+    font-weight: 500;
+}
+.paper-link:hover {
+    background-color: #2980b9;
+    text-decoration: none;
+}
+.authors-section {
+    text-align: center;
+    margin: 0 0;
+    font-style: italic;
+    color: #666;
+}
+.authors-title {
+    font-weight: bold;
+    margin-bottom: 0rem;
+    color: #333;
+}
 """
 
 with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
-    with gr.Column(elem_id="col-container"):
+    with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
         gr.Markdown("# 📌 **POET**")
-        instruction = gr.Markdown("Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
-        instruction = gr.Markdown("Images generated by POET is more diverse in multiple aspects; e.g., background, race, ...")
+        gr.Markdown("## Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
+
+        # <strong>Abstract:</strong> State-of-the-art visual generative AI tools hold immense potential to assist users in the early ideation stages of creative tasks — offering the ability to generate (rather than search for) novel and unprecedented (instead of existing) images of considerable quality that also adhere to boundless combinations of user specifications. However, many large-scale text-to-image systems are designed for broad applicability, yielding conventional output that may limit creative exploration. They also employ interaction methods that may be difficult for beginners. #
+        gr.Markdown("""
+        <div class="abstract-text">
+        <strong>Abstract:</strong> Given that creative end-users often operate in diverse, context-specific ways that are often unpredictable, more variation and personalization are necessary. We introduce POET, a real-time interactive tool that (1) automatically discovers dimensions of homogeneity in text-to-image generative models, (2) expands these dimensions to diversify the output space of generated images, and (3) learns from user feedback to personalize expansions. Focusing on visual creativity, POET offers a first glimpse of how interaction techniques of future text-to-image generation tools may support and align with more pluralistic values and the needs of end-users during the ideation stages of their work.
+        </div>
+        """, elem_classes=["abstract-text"])
+
+        # Paper Link
+        gr.HTML("""
+        <div style="text-align: center;">
+            <a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
+                📄 Read the Full Paper
+            </a>
+        </div>
+        """)
+
+        # Authors
+        gr.Markdown("""
+        <div class="authors-section">
+        Evans Han, Alice Qian Zhang, Haiyi Zhu, Hong Shen, Paul Pu Liang, Jane Hsieh
+        </div>
+        """, elem_classes=["authors-section"])
+
+        # gr.Markdown("---")
 
     with gr.Tab(""):
         with gr.Row(elem_id="compact-row"):
-            prompt = gr.Textbox(
-                label="🎨 Prompt",
-                max_lines=5,
-                placeholder="Enter your prompt",
-                scale=3,
-                visible=True,
-            )
-            next_btn = gr.Button("Generate", variant="primary", scale=1)
+            with gr.Column(elem_id="col-container"):
+                with gr.Row():
+                    prompt = gr.Textbox(
+                        label="🎨 Prompt",
+                        max_lines=5,
+                        placeholder="Enter your prompt",
+                        visible=True,
+                    )
+                with gr.Column(elem_id="col-container3"):
+                    next_btn = gr.Button("Generate", variant="primary", scale=1)
 
         with gr.Row(elem_id="compact-row"):
             with gr.Column(elem_id="col-container"):
@@ -313,6 +400,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
         examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
         inputs=[prompt, ex1, ex2, ex3, ex4]
     )
+
 # =========================
 # Wiring
 # =========================
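Note: invert_prompt is defined in this commit but not wired to any UI event in these hunks, and both it and the model setup assume CLIP_MODEL and PRETRAINED_CLIP constants defined elsewhere in the repo. A minimal sketch of how the pieces might fit together; the constant values and the button wiring below are illustrative assumptions, not code from this commit:

# Sketch only: "ViT-H-14" / "laion2b_s32b_b79k" are the defaults used by the
# PEZ prompt-inversion code that optimize_prompt appears to come from; this
# Space may define different values.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"

# open_clip returns (model, train_preprocess, eval_preprocess); the training
# transform is discarded, as in the commit.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
)

# Hypothetical wiring inside the gr.Blocks context: invert_btn and gallery are
# illustrative names, not identifiers from app.py. invert_prompt returns None
# and stores its result in the global inverted_prompt, so no outputs are bound.
invert_btn.click(fn=invert_prompt, inputs=[prompt, gallery], outputs=None)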
optim_utils.py CHANGED
@@ -19,9 +19,6 @@ def nn_project(curr_embeds, embedding_layer, print_hits=False):
     with torch.no_grad():
         bsz,seq_len,emb_dim = curr_embeds.shape
 
-        # Using the sentence transformers semantic search which is
-        # a dot product exact kNN search between a set of
-        # query vectors and a corpus of vectors
         curr_embeds = curr_embeds.reshape((-1,emb_dim))
         curr_embeds = normalize_embeddings(curr_embeds) # queries
 
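The comment removed above described the step that follows it: an exact dot-product kNN search (sentence-transformers semantic_search) between the current soft-prompt embeddings as queries and the token-embedding matrix as the corpus. A standalone sketch of that lookup, with toy shapes as assumptions:

import torch
from sentence_transformers.util import semantic_search, normalize_embeddings, dot_score

# Toy shapes for illustration: 2 sequences of 5 soft tokens in a 768-d space,
# projected onto a 1000-entry vocabulary. In nn_project the real tensors are
# curr_embeds and the weight matrix of embedding_layer.
curr_embeds = torch.randn(2, 5, 768)
vocab = normalize_embeddings(torch.randn(1000, 768))  # corpus

queries = normalize_embeddings(curr_embeds.reshape(-1, 768))

# Exact kNN via dot product: for each query, the single closest vocabulary row.
hits = semantic_search(queries, vocab, top_k=1, score_function=dot_score)
nn_ids = torch.tensor([h[0]["corpus_id"] for h in hits]).reshape(2, 5)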