evansh666 committed
Commit 9cf98ec · 1 Parent(s): 4c26858

first commit

.DS_Store ADDED
Binary file (8.2 kB).
 
app.py CHANGED
@@ -1,7 +1,517 @@
 import gradio as gr

-def greet(name):
-    return "Hello " + name + "!!"

-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
1
  import gradio as gr
2
+ from gradio.themes.base import Base
3
+ import numpy as np
4
+ import random
5
+ import spaces
6
+ import torch
7
+ import re
8
+ import open_clip
9
+ from optim_utils import optimize_prompt
10
+ from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache
11
+ from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION
12
+ import spaces #[uncomment to use ZeroGPU]
13
+ import transformers
14
+ import gspread
15
+ import asyncio
16
+ from datetime import datetime
17
 
18
+ CLIP_MODEL = "ViT-H-14"
19
+ PRETRAINED_CLIP = "laion2b_s32b_b79k"
20
+ default_t2i_model = "black-forest-labs/FLUX.1-dev" # "black-forest-labs/FLUX.1-dev"
21
+ default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # "meta-llama/Meta-Llama-3-8B-Instruct"
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+ MAX_IMAGE_SIZE = 1024
24
+ NUM_IMAGES=4
25
 
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
28
+ clean_cache()
29
+
30
+ selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
31
+ # clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
32
+ llm_pipe = None
33
+ torch.cuda.empty_cache()
34
+ inverted_prompt = ""
35
+
36
+ VERBAL_MSG = "Please verbally describe key differences found in the image pair."
37
+ DEFAULT_SCENARIO = "Product advertisement"
38
+ METHODS = ["Method 1", "Method 2"]
39
+ MAX_ROUND = 5
40
+ # in-memory session state (persists between button clicks)
41
+ counter1, counter2 = 1, 1
42
+ responses_memory = {}
43
+ assigned_scenarios = list(SCENARIOS.keys())[:2]
44
+ current_task1, current_task2 = METHODS # methods currently assigned to Task A and Task B
45
+ task1_success, task2_success = False, False
46
+
47
+ ########################################################################################################
48
+ # Generating images with two methods
49
+ ########################################################################################################
50
+
51
+
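+ # @spaces.GPU(duration=65) requests a ZeroGPU slot for up to 65 seconds per call to this function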
52
+ @spaces.GPU(duration=65)
53
+ def infer(
54
+ prompt,
55
+ negative_prompt="",
56
+ seed=42,
57
+ randomize_seed=True,
58
+ width=256,
59
+ height=256,
60
+ guidance_scale=5,
61
+ num_inference_steps=18,
62
+ progress=gr.Progress(track_tqdm=True),
63
+ ):
64
+ if randomize_seed:
65
+ seed = random.randint(0, MAX_SEED)
66
+
67
+ generator = torch.Generator().manual_seed(seed)
68
+ with torch.no_grad():
69
+ image = selected_pipe(
70
+ prompt=prompt,
71
+ negative_prompt=negative_prompt,
72
+ guidance_scale=guidance_scale,
73
+ num_inference_steps=num_inference_steps,
74
+ width=width,
75
+ height=height,
76
+ generator=generator,
77
+ ).images[0]
78
+
79
+ return image
80
+
81
+ async def infer_async(prompt):
82
+ return infer(prompt)
83
+ # generate a batch of images in parallel
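+ # note: infer() is a regular (blocking) function, so asyncio.gather here only sequences the
+ # calls; real concurrency would need something like asyncio.to_thread(infer, p) per prompt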
84
+ async def generate_batch(prompts):
85
+ tasks = [infer_async(p) for p in prompts]
86
+ images = await asyncio.gather(*tasks) # Run all in parallel
87
+ return images
88
+
89
+ @spaces.GPU
90
+ def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
91
+ print(f"loading {default_llm_model}")
92
+ global llm_pipe
93
+ if not llm_pipe:
94
+ llm_pipe = transformers.pipeline("text-generation", model=default_llm_model, model_kwargs={"torch_dtype": torch_dtype}, device_map="auto")
95
+
96
+ messages = get_refine_msg(prompt, num_prompts)
97
+ terminators = [
98
+ llm_pipe.tokenizer.eos_token_id,
99
+ llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")
100
+ ]
101
+ outputs = llm_pipe(
102
+ messages,
103
+ max_new_tokens=max_tokens,
104
+ eos_token_id=terminators,
105
+ do_sample=True,
106
+ temperature=temperature,
107
+ top_p=top_p,
108
+ )
109
+ prompt_list = clean_response_gpt(outputs[0]["generated_text"][-1]["content"])
110
+ return prompt_list
111
+
112
+ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
113
+ seed = random.randint(0, MAX_SEED)
114
+ client = init_gpt_api()
115
+ messages = get_refine_msg(prompt, num_prompts)
116
+ outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens, temperature, top_p)
117
+ prompt_list = clean_response_gpt(outputs)
118
+ return prompt_list
119
+
120
+ def refine_prompt(gallery_state, prompt):
121
+ modified_prompts = call_gpt_refine_prompt(prompt)
122
+ return modified_prompts
123
+
124
+ # eval(prompt, inverted_prompt, gallery_state, clip_model, preprocess)
125
+
126
+ @spaces.GPU(duration=100)
127
+ def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
128
+ text_params = {
129
+ "iter": iter,
130
+ "lr": lr,
131
+ "batch_size": batch_size,
132
+ "prompt_len": prompt_len,
133
+ "weight_decay": 0.1,
134
+ "prompt_bs": 1,
135
+ "loss_weight": 1.0,
136
+ "print_step": 100,
137
+ "clip_model": CLIP_MODEL,
138
+ "clip_pretrain": PRETRAINED_CLIP,
139
+ }
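+ # note: this relies on clip_model / preprocess, whose open_clip.create_model_and_transforms
+ # setup above is currently commented out, so this helper is inactive as committed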
140
+ inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
141
+
142
+ # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
143
+ return inverted_prompt
144
+
145
+
146
+ def eval(prompt, optimized_prompt, optimized_images, clip_model, preprocess):
147
+ torch.cuda.empty_cache()
148
+ tokenizer = open_clip.get_tokenizer(CLIP_MODEL)
149
+ images = [preprocess(i).unsqueeze(0) for i in optimized_images]
150
+ images = torch.concatenate(images).to(device)
151
+
152
+ with torch.no_grad():
153
+ image_feat = clip_model.encode_image(images)
154
+ text_feat = clip_model.encode_text(tokenizer([prompt]).to(device))
155
+ optimized_text_feat = clip_model.encode_text(tokenizer([optimized_prompt]).to(device))
156
+
157
+ image_feat /= image_feat.norm(dim=-1, keepdim=True)
158
+ text_feat /= text_feat.norm(dim=-1, keepdim=True)
159
+ optimized_text_feat /= optimized_text_feat.norm(dim=-1, keepdim=True)
160
+
161
+ similarity = text_feat.cpu().numpy() @ image_feat.cpu().numpy().T
162
+ similarity_optimized = optimized_text_feat.cpu().numpy() @ image_feat.cpu().numpy().T
+ return similarity, similarity_optimized
163
+
164
+
165
+ ########################################################################################################
166
+ # Button-related functions
167
+ ########################################################################################################
168
+
169
+ def reset_gallery():
170
+ return []
171
+
172
+ def display_error_message(msg, duration=5):
173
+ gr.Warning(msg, duration=duration)
174
+
175
+ def display_info_message(msg, duration=5):
176
+ gr.Info(msg, duration=duration)
177
+
178
+ def switch_tab(active_tab):
179
+ print("switching tab")
180
+ if active_tab == "Task A":
181
+ return gr.Tabs(selected="Task B")
182
+ else:
183
+ return gr.Tabs(selected="Task A")
184
+
185
+ def set_user(participant):
186
+ global responses_memory, assigned_scenarios  # assigned_scenarios is also read in save_response
187
+ responses_memory[participant] = {METHODS[0]:{}, METHODS[1]:{}}
188
+
189
+ id = re.findall(r'\d+', participant)
190
+ if len(id) == 0 or int(id[0]) % 2 == 0: # invalid or even id: assign the first half of the scenarios
191
+ assigned_scenarios = list(SCENARIOS.keys())[:2]
192
+ else:
193
+ assigned_scenarios = list(SCENARIOS.keys())[2:]
194
+ return assigned_scenarios[0]
195
+
196
+ def display_scenario(participant, choice):
197
+ # reset in-memory state when the scenario changes
198
+ global counter1, counter2, responses_memory, current_task1, current_task2, task1_success, task2_success
199
+
200
+ task1_success, task2_success = False, False
201
+ counter1, counter2 = 1, 1
202
+
203
+ if check_participant(participant):
204
+ responses_memory[participant] = {METHODS[0]:{}, METHODS[1]:{}}
205
+
206
+ [current_task1, current_task2] = random.sample(METHODS, 2)
207
+ if current_task1 == METHODS[0]:
208
+ initial_images1 = IMAGES[choice]["baseline"]
209
+ initial_images2 = IMAGES[choice]["ours"]
210
+ else:
211
+ initial_images1 = IMAGES[choice]["ours"]
212
+ initial_images2 = IMAGES[choice]["baseline"]
213
+
214
+ res = {
215
+ scenario_content: SCENARIOS.get(choice, ""),
216
+ prompt: PROMPTS.get(choice, ""),
217
+ prompt1: "",
218
+ prompt2: "",
219
+ images_method1: initial_images1,
220
+ images_method2: initial_images2,
221
+ gallery_state1: initial_images1,
222
+ gallery_state2: initial_images2,
223
+ sim_radio1: None,
224
+ sim_radio2: None,
225
+ response1: VERBAL_MSG,
226
+ response2: VERBAL_MSG,
227
+ next_btn1: gr.update(interactive=False),
228
+ next_btn2: gr.update(interactive=False),
229
+ redesign_btn1: gr.update(interactive=True),
230
+ redesign_btn2: gr.update(interactive=True),
231
+ submit_btn1: gr.update(interactive=False),
232
+ submit_btn2: gr.update(interactive=False),
233
+ }
234
+ return res
235
+
236
+ def generate_image(participant, scenario, prompt, gallery_state, active_tab):
237
+ if not check_participant(participant): return [], []
238
+ global current_task1, current_task2
239
+
240
+ method = current_task1 if active_tab == "Task A" else current_task2
241
+
242
+ if method == METHODS[0]:
243
+ for i in range(NUM_IMAGES):
244
+ img = infer(prompt)
245
+ gallery_state.append(img)
246
+ yield gallery_state
247
+ else:
248
+ refined_prompts = refine_prompt(gallery_state, prompt)
249
+ for i in range(NUM_IMAGES):
250
+ img = infer(refined_prompts[i])
251
+ gallery_state.append(img)
252
+ yield gallery_state
253
+
254
+ def check_satisfaction(sim_radio, active_tab):
255
+ global counter1, counter2, current_task1, current_task2
256
+ method = current_task1 if active_tab == "Task A" else current_task2
257
+ counter = counter1 if method == METHODS[0] else counter2
258
+
259
+ fully_satisfied_option = ["Satisfied", "Very Satisfied"] # The value to trigger submit
260
+ enable_submit = sim_radio in fully_satisfied_option or counter >= MAX_ROUND
261
+
262
+ return gr.update(interactive=enable_submit), gr.update(interactive=(not enable_submit))
263
+
264
+ def check_participant(participant):
265
+ if participant == "":
266
+ display_error_message("Please fill your participant id!")
267
+ return False
268
+ return True
269
+
270
+ def check_evaluation(sim_radio, response):
271
+ if not sim_radio :
272
+ display_error_message("❌ Please complete all evaluations before redesigning or submitting.")
273
+ return False
274
+
275
+ return True
276
+
277
+ def redesign(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
278
+ global counter1, counter2, responses_memory, current_task1, current_task2
279
+ method = current_task1 if active_tab == "Task A" else current_task2
280
+
281
+ if check_evaluation(sim_radio, response) and check_participant(participant):
282
+ if method == METHODS[0]:
283
+ counter1 += 1
284
+ counter = counter1
285
+ else:
286
+ counter2 += 1
287
+ counter = counter2
288
+
289
+ responses_memory[participant][method][counter-1] = {}
290
+ responses_memory[participant][method][counter-1]["prompt"] = prompt
291
+ responses_memory[participant][method][counter-1]["sim_radio"] = sim_radio
292
+ responses_memory[participant][method][counter-1]["response"] = response
293
+
294
+ prompt_state = gr.update(visible=True)
295
+ next_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(visible=True, interactive=True)
296
+ redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
297
+ submit_state = gr.update(interactive=True) if counter >= MAX_ROUND else gr.update(interactive=False)
298
+
299
+ return [], None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state
300
+ else:
301
+ return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
302
+
303
+ def show_message(selected_option):
304
+ if selected_option:
305
+ return "Click \"Redesign\" and revise your prompt to create images that may more closely match your expectations."
306
+ return ""
307
+
308
+ def save_response(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
309
+ global current_task1, current_task2, counter1, counter2, responses_memory, task1_success, task2_success, assigned_scenarios
310
+ method = current_task1 if active_tab == "Task A" else current_task2
311
+
312
+ if check_evaluation(sim_radio, response) and check_participant(participant):
313
+ counter = counter1 if method == METHODS[0] else counter2
314
+ # image_paths = [save_image(img, "method", i) for i, img in enumerate(images_method)]
315
+
316
+ responses_memory[participant][method][counter] = {}
317
+ responses_memory[participant][method][counter]["prompt"] = prompt
318
+ responses_memory[participant][method][counter]["sim_radio"] = sim_radio
319
+ responses_memory[participant][method][counter]["response"] = response
320
+ prompt_state = gr.update(visible=False)
321
+ next_state = gr.update(visible=False, interactive=False)
322
+ submit_state = gr.update(interactive=False)
323
+ redesign_state = gr.update(interactive=False)
324
+
325
+ try:
326
+ gc = gspread.service_account(filename='credentials.json')
327
+ sheet = gc.open("DiverseGen-phase2").sheet1
328
+
329
+ for i, entry in responses_memory[participant][method].items():
330
+ sheet.append_row([participant, scenario, method, i, entry["prompt"], entry["sim_radio"], entry["response"]])
331
+
332
+ display_info_message("✅ Your answer is saved!")
333
+
334
+ # reset counter and update success indicator
335
+ if method == METHODS[0]:
336
+ counter1 = 1
337
+ else:
338
+ counter2 = 1
339
+
340
+ if active_tab == "Task A":
341
+ task1_success = True
342
+ else:
343
+ task2_success = True
344
+
345
+ tabs = switch_tab(active_tab)
346
+ next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
347
+ return [], [], None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state, tabs, next_scenario
348
+ except Exception as e:
349
+ display_error_message(f"❌ Error saving response: {str(e)}")
350
+ return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
351
+ else:
352
+ return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
353
+
354
+
355
+ ########################################################################################################
356
+ # Interface
357
+ ########################################################################################################
358
+
359
+ css="""
360
+ #col-container {
361
+ margin: 0 auto;
362
+ max-width: 700px;
363
+ }
364
+
365
+ #col-container2 {
366
+ margin: 0 auto;
367
+ max-width: 1000px;
368
+ }
369
+
370
+ #button-container {
371
+ display: flex;
372
+ justify-content: center; /* Centers the buttons horizontally */
373
+ }
374
+ """
375
+
376
+ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
377
+ with gr.Column(elem_id="col-container"):
378
+ gr.Markdown(" # 📌 **Diverse Text-to-Image Generation**")
379
+
380
+ with gr.Row():
381
+ participant = gr.Textbox(
382
+ label="🧑‍💼 Participant ID", placeholder="Please enter your participant ID"
383
+ )
384
+ scenario = gr.Dropdown(
385
+ choices=list(SCENARIOS.keys()),
386
+ # value=DEFAULT_SCENARIO,
387
+ value=None,
388
+ label="Scenario",
389
+ interactive=False,
390
+ )
391
+ scenario_content = gr.Textbox(
392
+ label="📖 Background",
393
+ interactive=False,
394
+ # value=SCENARIOS[DEFAULT_SCENARIO]
395
+ )
396
+ prompt = gr.Textbox(
397
+ label="🎨 Prompt",
398
+ max_lines=1,
399
+ # value=PROMPTS[DEFAULT_SCENARIO],
400
+ interactive=False
401
+ )
402
+ active_tab = gr.State("Task A")
403
+ instruction = gr.Markdown(INSTRUCTION)
404
+
405
+ with gr.Tabs() as tabs:
406
+ with gr.TabItem("Task A", id="Task A") as task1_tab:
407
+ task1_tab.select(lambda: "Task A", outputs=[active_tab])
408
+ with gr.Column(elem_id="col-container"):
409
+ # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
410
+ with gr.Row():
411
+ prompt1 = gr.Textbox(
412
+ label="🎨 Revise Prompt",
413
+ max_lines=1,
414
+ placeholder="Enter your prompt",
415
+ # value=PROMPTS[DEFAULT_SCENARIO],
416
+ scale=4,
417
+ visible=False
418
+ )
419
+ next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
420
+
421
+ with gr.Column(elem_id="col-container"):
422
+ gallery_state1 = gr.State(IMAGES[DEFAULT_SCENARIO]["baseline"])
423
+ images_method1 = gr.Gallery(show_label=False, columns=[4], rows=[1], elem_id="gallery")
424
+ with gr.Column(elem_id="col-container2"):
425
+ gr.Markdown("### 📝 Evaluation")
426
+ sim_radio1 = gr.Radio(
427
+ OPTIONS,
428
+ label="How would you evaluate your satisfaction with the generated images, based on your expectations for the specified scenario?",
429
+ type="value",
430
+ elem_classes=["gradio-radio"]
431
+ )
432
+ response1 = gr.Textbox(
433
+ label="Verbally describe key differences found in the image pair.",
434
+ max_lines=1,
435
+ interactive=False,
436
+ container=False,
437
+ value=VERBAL_MSG
438
+ )
439
+
440
+ with gr.Row(elem_id="button-container"):
441
+ redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
442
+ submit_btn1 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
443
+
444
+
445
+ with gr.TabItem("Task B", id="Task B") as task2_tab:
446
+ task2_tab.select(lambda: "Task B", outputs=[active_tab])
447
+ with gr.Column(elem_id="col-container"):
448
+ # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
449
+ with gr.Row():
450
+ prompt2 = gr.Textbox(
451
+ label="🎨 Revise Prompt",
452
+ max_lines=1,
453
+ placeholder="Enter your prompt",
454
+ # value=PROMPTS[DEFAULT_SCENARIO],
455
+ scale=4,
456
+ visible=False
457
+ )
458
+
459
+ next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
460
+
461
+ with gr.Column(elem_id="col-container"):
462
+ gallery_state2 = gr.State(IMAGES[DEFAULT_SCENARIO]["ours"])
463
+ images_method2 = gr.Gallery(show_label=False, columns=[4], rows=[1], elem_id="gallery")
464
+
465
+ with gr.Column(elem_id="col-container2"):
466
+ gr.Markdown("### 📝 Evaluation")
467
+ sim_radio2 = gr.Radio(
468
+ OPTIONS,
469
+ label="How would you evaluate your satisfaction with the generated images, based on your expectations for the specified scenario?",
470
+ type="value",
471
+ elem_classes=["gradio-radio"]
472
+ )
473
+ response2 = gr.Textbox(
474
+ label="Verbally describe key differences found in the image pair.",
475
+ max_lines=1,
476
+ interactive=False,
477
+ container=False,
478
+ value=VERBAL_MSG
479
+ )
480
+ with gr.Row(elem_id="button-container"):
481
+ redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
482
+ submit_btn2 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
483
+
484
+
485
+ ########################################################################################################
486
+ # Button Function Setup
487
+ ########################################################################################################
488
+
489
+ participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
490
+ scenario.change(display_scenario, inputs=[participant, scenario], outputs=[scenario_content, prompt, prompt1, prompt2, images_method1, images_method2, gallery_state1, gallery_state2, sim_radio1, sim_radio2, response1, response2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
491
+ prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
492
+ prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
493
+ next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, gallery_state1, active_tab], outputs=[images_method1])
494
+ next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, gallery_state2, active_tab], outputs=[images_method2])
495
+ sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
496
+ sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])
497
+ redesign_btn1.click(
498
+ fn=redesign,
499
+ inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
500
+ outputs=[gallery_state1, sim_radio1, response1, prompt1, next_btn1, redesign_btn1, submit_btn1]
501
+ )
502
+ redesign_btn2.click(
503
+ fn=redesign,
504
+ inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
505
+ outputs=[gallery_state2, sim_radio2, response2, prompt2, next_btn2, redesign_btn2, submit_btn2]
506
+ )
507
+ submit_btn1.click(fn=save_response,
508
+ inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
509
+ outputs=[images_method1, gallery_state1, sim_radio1, prompt1, response1, next_btn1, redesign_btn1, submit_btn1, tabs, scenario])
510
+
511
+ submit_btn2.click(fn=save_response,
512
+ inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
513
+ outputs=[images_method2, gallery_state2, sim_radio2, prompt2, response2, next_btn2, redesign_btn2, submit_btn2, tabs, scenario])
514
+
515
+
516
+ if __name__ == "__main__":
517
+ demo.launch()
images/.DS_Store ADDED
Binary file (6.15 kB).
 
images/scenario1_base1.png ADDED
images/scenario1_base2.png ADDED
images/scenario1_base3.png ADDED
images/scenario1_base4.png ADDED
images/scenario1_our1.png ADDED
images/scenario1_our2.png ADDED
images/scenario1_our3.png ADDED
images/scenario1_our4.png ADDED
images/scenario2_base1.png ADDED
images/scenario2_base2.png ADDED
images/scenario2_base3.png ADDED
images/scenario2_base4.png ADDED
images/scenario2_our1.png ADDED
images/scenario2_our2.png ADDED
images/scenario2_our3.png ADDED
images/scenario2_our4.png ADDED
images/scenario3_base1.png ADDED
images/scenario3_base2.png ADDED
images/scenario3_base3.png ADDED
images/scenario3_base4.png ADDED
images/scenario3_our1.png ADDED
images/scenario3_our2.png ADDED
images/scenario3_our3.png ADDED
images/scenario3_our4.png ADDED
images/scenario4_base1.png ADDED
images/scenario4_base2.png ADDED
images/scenario4_base3.png ADDED
images/scenario4_base4.png ADDED
images/scenario4_our1.png ADDED
images/scenario4_our2.png ADDED
images/scenario4_our3.png ADDED
images/scenario4_our4.png ADDED
images/scenario5_base1.png ADDED
images/scenario5_base2.png ADDED
images/scenario5_base3.png ADDED
images/scenario5_base4.png ADDED
images/scenario5_our1.png ADDED
images/scenario5_our2.png ADDED
images/scenario5_our3.png ADDED
images/scenario5_our4.png ADDED
live_preview_helpers.py ADDED
@@ -0,0 +1,172 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import Any, Dict, List, Optional, Union
4
+ from diffusers import FluxPipeline
5
+
6
+ # Helper functions
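+ # calculate_shift: linearly interpolates the scheduler's timestep-shift value (mu) between
+ # base_shift and max_shift as the image token sequence length grows (FLUX dynamic shifting)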
7
+ def calculate_shift(
8
+ image_seq_len,
9
+ base_seq_len: int = 256,
10
+ max_seq_len: int = 4096,
11
+ base_shift: float = 0.5,
12
+ max_shift: float = 1.16,
13
+ ):
14
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
15
+ b = base_shift - m * base_seq_len
16
+ mu = image_seq_len * m + b
17
+ return mu
18
+
19
+ def retrieve_timesteps(
20
+ scheduler,
21
+ num_inference_steps: Optional[int] = None,
22
+ device: Optional[Union[str, torch.device]] = None,
23
+ timesteps: Optional[List[int]] = None,
24
+ sigmas: Optional[List[float]] = None,
25
+ **kwargs,
26
+ ):
27
+ if timesteps is not None and sigmas is not None:
28
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
29
+ if timesteps is not None:
30
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
31
+ timesteps = scheduler.timesteps
32
+ num_inference_steps = len(timesteps)
33
+ elif sigmas is not None:
34
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
35
+ timesteps = scheduler.timesteps
36
+ num_inference_steps = len(timesteps)
37
+ else:
38
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
39
+ timesteps = scheduler.timesteps
40
+ return timesteps, num_inference_steps
41
+
42
+
43
+ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
44
+ """
45
+ Extends the FluxPipeline to yield intermediate images during the denoising process
46
+ with progressively increasing resolution for faster generation.
47
+ """
48
+ # FLUX pipeline function
49
+ @torch.inference_mode()
50
+ def generate_images(
51
+ self,
52
+ prompt: Union[str, List[str]] = None,
53
+ prompt_2: Optional[Union[str, List[str]]] = None,
54
+ height: Optional[int] = None,
55
+ width: Optional[int] = None,
56
+ num_inference_steps: int = 28,
57
+ timesteps: List[int] = None,
58
+ guidance_scale: float = 3.5,
59
+ num_images_per_prompt: Optional[int] = 1,
60
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
61
+ latents: Optional[torch.FloatTensor] = None,
62
+ prompt_embeds: Optional[torch.FloatTensor] = None,
63
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
64
+ output_type: Optional[str] = "pil",
65
+ return_dict: bool = True,
66
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
67
+ max_sequence_length: int = 512,
68
+ ):
69
+ height = height or self.default_sample_size * self.vae_scale_factor
70
+ width = width or self.default_sample_size * self.vae_scale_factor
71
+
72
+ # 1. Check inputs
73
+ self.check_inputs(
74
+ prompt,
75
+ prompt_2,
76
+ height,
77
+ width,
78
+ prompt_embeds=prompt_embeds,
79
+ pooled_prompt_embeds=pooled_prompt_embeds,
80
+ max_sequence_length=max_sequence_length,
81
+ )
82
+
83
+ self._guidance_scale = guidance_scale
84
+ self._joint_attention_kwargs = joint_attention_kwargs
85
+ self._interrupt = False
86
+
87
+ # 2. Define call parameters
88
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
89
+ device = self._execution_device
90
+
91
+ # 3. Encode prompt
92
+ lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
93
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
94
+ prompt=prompt,
95
+ prompt_2=prompt_2,
96
+ prompt_embeds=prompt_embeds,
97
+ pooled_prompt_embeds=pooled_prompt_embeds,
98
+ device=device,
99
+ num_images_per_prompt=num_images_per_prompt,
100
+ max_sequence_length=max_sequence_length,
101
+ lora_scale=lora_scale,
102
+ )
103
+ # 4. Prepare latent variables
104
+ num_channels_latents = self.transformer.config.in_channels // 4
105
+ latents, latent_image_ids = self.prepare_latents(
106
+ batch_size * num_images_per_prompt,
107
+ num_channels_latents,
108
+ height,
109
+ width,
110
+ prompt_embeds.dtype,
111
+ device,
112
+ generator,
113
+ latents,
114
+ )
115
+ # 5. Prepare timesteps
116
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
117
+ image_seq_len = latents.shape[1]
118
+ mu = calculate_shift(
119
+ image_seq_len,
120
+ self.scheduler.config.base_image_seq_len,
121
+ self.scheduler.config.max_image_seq_len,
122
+ self.scheduler.config.base_shift,
123
+ self.scheduler.config.max_shift,
124
+ )
125
+ timesteps, num_inference_steps = retrieve_timesteps(
126
+ self.scheduler,
127
+ num_inference_steps,
128
+ device,
129
+ timesteps,
130
+ sigmas,
131
+ mu=mu,
132
+ )
133
+ self._num_timesteps = len(timesteps)
134
+
135
+ # Handle guidance
136
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
137
+
138
+ # 6. Denoising loop
139
+ for i, t in enumerate(timesteps):
140
+ if self.interrupt:
141
+ continue
142
+
143
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
144
+
145
+ noise_pred = self.transformer(
146
+ hidden_states=latents,
147
+ timestep=timestep / 1000,
148
+ guidance=guidance,
149
+ pooled_projections=pooled_prompt_embeds,
150
+ encoder_hidden_states=prompt_embeds,
151
+ txt_ids=text_ids,
152
+ img_ids=latent_image_ids,
153
+ joint_attention_kwargs=self.joint_attention_kwargs,
154
+ return_dict=False,
155
+ )[0]
156
+
157
+ # Yield intermediate result
158
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
159
+ yield self._decode_latents_to_image(latents, height, width, output_type)
160
+ torch.cuda.empty_cache()
161
+
162
+ # Final image
163
+ self.maybe_free_model_hooks()
164
+ torch.cuda.empty_cache()
165
+
166
+ def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
167
+ """Decodes the given latents into an image."""
168
+ vae = vae or self.vae
169
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
170
+ latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
171
+ image = vae.decode(latents, return_dict=False)[0]
172
+ return self.image_processor.postprocess(image, output_type=output_type)[0]
optim_utils.py ADDED
@@ -0,0 +1,239 @@
1
+ import random
2
+ import numpy as np
3
+ import requests
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ from statistics import mean
7
+ import copy
8
+ import json
9
+ from typing import Any, Mapping
10
+ import open_clip
11
+ import torch
12
+
13
+ from sentence_transformers.util import (semantic_search,
14
+ dot_score,
15
+ normalize_embeddings)
16
+
17
+
18
+ def nn_project(curr_embeds, embedding_layer, print_hits=False):
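+ # Projects each soft prompt embedding onto its nearest neighbour in the token-embedding matrix
+ # (dot-product search over normalized embeddings) and returns the projected embeddings and token ids.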
19
+ with torch.no_grad():
20
+ bsz,seq_len,emb_dim = curr_embeds.shape
21
+
22
+ # Using the sentence transformers semantic search which is
23
+ # a dot product exact kNN search between a set of
24
+ # query vectors and a corpus of vectors
25
+ curr_embeds = curr_embeds.reshape((-1,emb_dim))
26
+ curr_embeds = normalize_embeddings(curr_embeds) # queries
27
+
28
+ embedding_matrix = embedding_layer.weight
29
+ embedding_matrix = normalize_embeddings(embedding_matrix)
30
+
31
+ hits = semantic_search(curr_embeds, embedding_matrix,
32
+ query_chunk_size=curr_embeds.shape[0],
33
+ top_k=1,
34
+ score_function=dot_score)
35
+
36
+ if print_hits:
37
+ all_hits = []
38
+ for hit in hits:
39
+ all_hits.append(hit[0]["score"])
40
+ print(f"mean hits:{mean(all_hits)}")
41
+
42
+ nn_indices = torch.tensor([hit[0]["corpus_id"] for hit in hits], device=curr_embeds.device)
43
+ nn_indices = nn_indices.reshape((bsz,seq_len))
44
+
45
+ projected_embeds = embedding_layer(nn_indices)
46
+
47
+ return projected_embeds, nn_indices
48
+
49
+ def decode_ids(input_ids, tokenizer, by_token=False):
50
+ input_ids = input_ids.detach().cpu().numpy()
51
+
52
+ texts = []
53
+
54
+ if by_token:
55
+ for input_ids_i in input_ids:
56
+ curr_text = []
57
+ for tmp in input_ids_i:
58
+ curr_text.append(tokenizer.decode([tmp]))
59
+
60
+ texts.append('|'.join(curr_text))
61
+ else:
62
+ for input_ids_i in input_ids:
63
+ texts.append(tokenizer.decode(input_ids_i))
64
+
65
+ return texts
66
+
67
+ def get_target_feature(model, preprocess, tokenizer_funct, device, target_images=None, target_prompts=None):
68
+ if target_images is not None:
69
+ with torch.no_grad():
70
+ curr_images = [preprocess(i).unsqueeze(0) for i in target_images]
71
+ curr_images = torch.concatenate(curr_images).to(device)
72
+ all_target_features = model.encode_image(curr_images)
73
+ else:
74
+ texts = tokenizer_funct(target_prompts).to(device)
75
+ all_target_features = model.encode_text(texts)
76
+
77
+ return all_target_features
78
+
79
+ def encode_text_embedding(model, text_embedding, ids, avg_text=False):
80
+ cast_dtype = model.transformer.get_cast_dtype()
81
+
82
+ x = text_embedding + model.positional_embedding.to(cast_dtype)
83
+ x = x.permute(1, 0, 2) # NLD -> LND
84
+ x = model.transformer(x, attn_mask=model.attn_mask)
85
+ x = x.permute(1, 0, 2) # LND -> NLD
86
+ x = model.ln_final(x)
87
+
88
+ # x.shape = [batch_size, n_ctx, transformer.width]
89
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
90
+ if avg_text:
91
+ x = x[torch.arange(x.shape[0]), :ids.argmax(dim=-1)]
92
+ x[:, 1:-1]
93
+ x = x.mean(dim=1) @ model.text_projection
94
+ else:
95
+ x = x[torch.arange(x.shape[0]), ids.argmax(dim=-1)] @ model.text_projection
96
+
97
+ return x
98
+
99
+ def forward_text_embedding(model, embeddings, ids, image_features, avg_text=False, return_feature=False):
100
+ text_features = encode_text_embedding(model, embeddings, ids, avg_text=avg_text)
101
+
102
+ if return_feature:
103
+ return text_features
104
+
105
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
106
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
107
+
108
+ logits_per_image = image_features @ text_features.t()
109
+ logits_per_text = logits_per_image.t()
110
+
111
+ return logits_per_image, logits_per_text
112
+
113
+ def initialize_prompt(tokenizer, token_embedding, args, device, original_prompt):
114
+ prompt_len = args["prompt_len"]
115
+
116
+ # randomly optimize prompt embeddings
117
+ tokens = tokenizer.encode(original_prompt)
118
+ if len(tokens) > prompt_len:
119
+ tokens = tokens[:prompt_len]
120
+ if len(tokens) < prompt_len:
121
+ tokens += [0] * (prompt_len - len(tokens))
122
+
123
+ prompt_ids = torch.tensor([tokens] * args["prompt_bs"]).to(device)
124
+ # prompt_ids = torch.randint(len(tokenizer.encoder), (args.prompt_bs, prompt_len)).to(device)
125
+ prompt_embeds = token_embedding(prompt_ids).detach()
126
+ prompt_embeds.requires_grad = True
127
+
128
+ # initialize the template
129
+ template_text = "{}"
130
+ padded_template_text = template_text.format(" ".join(["<start_of_text>"] * prompt_len))
131
+ dummy_ids = tokenizer.encode(padded_template_text)
132
+
133
+ # -1 for optimized tokens
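+ # (49406 / 49407 are the CLIP tokenizer's <start_of_text> / <end_of_text> ids; ids are padded to the 77-token context)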
134
+ dummy_ids = [i if i != 49406 else -1 for i in dummy_ids]
135
+ dummy_ids = [49406] + dummy_ids + [49407]
136
+ dummy_ids += [0] * (77 - len(dummy_ids))
137
+ dummy_ids = torch.tensor([dummy_ids] * args["prompt_bs"]).to(device)
138
+
139
+ # for getting dummy embeds; -1 won't work for token_embedding
140
+ tmp_dummy_ids = copy.deepcopy(dummy_ids)
141
+ tmp_dummy_ids[tmp_dummy_ids == -1] = 0
142
+ dummy_embeds = token_embedding(tmp_dummy_ids).detach()
143
+ dummy_embeds.requires_grad = False
144
+
145
+ return prompt_embeds, dummy_embeds, dummy_ids
146
+
147
+ def optimize_prompt_loop(model, tokenizer, token_embedding, all_target_features, args, device, original_prompt):
148
+ opt_iters = args["iter"]
149
+ lr = args["lr"]
150
+ weight_decay = args["weight_decay"]
151
+ print_step = args["print_step"]
152
+ batch_size = args["batch_size"]
153
+ print_new_best = True
154
+
155
+ # initialize prompt
156
+ prompt_embeds, dummy_embeds, dummy_ids = initialize_prompt(tokenizer, token_embedding, args, device, original_prompt)
157
+ p_bs, p_len, p_dim = prompt_embeds.shape
158
+
159
+ # get optimizer
160
+ input_optimizer = torch.optim.AdamW([prompt_embeds], lr=lr, weight_decay=weight_decay)
161
+
162
+ best_sim = -1000 * args["loss_weight"]
163
+ best_text = ""
164
+
165
+ for step in range(opt_iters):
166
+ # randomly sample images and get features
167
+ if batch_size is None:
168
+ target_features = all_target_features
169
+ else:
170
+ curr_indx = torch.randperm(len(all_target_features))
171
+ target_features = all_target_features[curr_indx][0:batch_size]
172
+
173
+ universal_target_features = all_target_features
174
+
175
+ # forward projection
176
+ projected_embeds, nn_indices = nn_project(prompt_embeds, token_embedding, print_hits=False)
177
+
178
+ # get cosine similarity score with all target features
179
+ with torch.no_grad():
180
+ # padded_embeds = copy.deepcopy(dummy_embeds)
181
+ padded_embeds = dummy_embeds.detach().clone()
182
+ padded_embeds[dummy_ids == -1] = projected_embeds.reshape(-1, p_dim)
183
+ logits_per_image, _ = forward_text_embedding(model, padded_embeds, dummy_ids, universal_target_features)
184
+ scores_per_prompt = logits_per_image.mean(dim=0)
185
+ universal_cosim_score = scores_per_prompt.max().item()
186
+ best_indx = scores_per_prompt.argmax().item()
187
+
188
+ # tmp_embeds = copy.deepcopy(prompt_embeds)
189
+ tmp_embeds = prompt_embeds.detach().clone()
190
+ tmp_embeds.data = projected_embeds.data
191
+ tmp_embeds.requires_grad = True
192
+
193
+ # padding
194
+ # padded_embeds = copy.deepcopy(dummy_embeds)
195
+ padded_embeds = dummy_embeds.detach().clone()
196
+ padded_embeds[dummy_ids == -1] = tmp_embeds.reshape(-1, p_dim)
197
+
198
+ logits_per_image, _ = forward_text_embedding(model, padded_embeds, dummy_ids, target_features)
199
+ cosim_scores = logits_per_image
200
+ loss = 1 - cosim_scores.mean()
201
+ loss = loss * args["loss_weight"]
202
+
203
+ prompt_embeds.grad, = torch.autograd.grad(loss, [tmp_embeds])
204
+
205
+ input_optimizer.step()
206
+ input_optimizer.zero_grad()
207
+
208
+ curr_lr = input_optimizer.param_groups[0]["lr"]
209
+ cosim_scores = cosim_scores.mean().item()
210
+
211
+ decoded_text = decode_ids(nn_indices, tokenizer)[best_indx]
212
+ if print_step is not None and (step % print_step == 0 or step == opt_iters-1):
213
+ per_step_message = f"step: {step}, lr: {curr_lr}"
214
+ # if not print_new_best:
215
+ # per_step_message = f"\n{per_step_message}, cosim: {universal_cosim_score:.3f}, text: {decoded_text}"
216
+ # print(per_step_message)
217
+
218
+ if best_sim * args["loss_weight"] < universal_cosim_score * args["loss_weight"]:
219
+ best_sim = universal_cosim_score
220
+ best_text = decoded_text
221
+ if print_new_best:
222
+ print(f"step: {step}, new best cosine sim: {best_sim}, new best prompt: {best_text}")
223
+
224
+ if print_step is not None:
225
+ print(f"best cosine sim: {best_sim}, best prompt: {best_text}")
226
+
227
+ return best_text
228
+
229
+
230
+ def optimize_prompt(model, preprocess, args, device, target_images=None, target_prompts=None):
231
+ token_embedding = model.token_embedding
232
+ tokenizer = open_clip.tokenizer._tokenizer
233
+ tokenizer_funct = open_clip.get_tokenizer(args["clip_model"])
234
+
235
+ all_target_features = get_target_feature(model, preprocess, tokenizer_funct, device, target_images=target_images)
236
+ learned_prompt = optimize_prompt_loop(model, tokenizer, token_embedding, all_target_features, args, device, target_prompts)
237
+
238
+ return learned_prompt
239
+
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ accelerate
2
+ diffusers
3
+ torch
4
+ transformers
5
+ git+https://github.com/huggingface/diffusers.git
6
+ sentencepiece
7
+ openai
8
+ huggingface_hub
9
+ sentence-transformers
10
+ ftfy
11
+ mediapy
12
+ open-clip-torch==2.24.0
13
+ gspread
utils.py ADDED
@@ -0,0 +1,104 @@
1
+ import re
2
+ from diffusers import DiffusionPipeline, FluxPipeline
3
+ from live_preview_helpers import FLUXPipelineWithIntermediateOutputs
4
+ import torch
5
+ import os
6
+ from openai import OpenAI
7
+ import subprocess
8
+
9
+ T2I_MODELS = {
10
+ "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
11
+ "SDXL-Turbo": "stabilityai/sdxl-turbo",
12
+ "Stable Diffusion v3.5-medium": "stabilityai/stable-diffusion-3.5-medium", # Default
13
+ "Flux.1-dev": "black-forest-labs/FLUX.1-dev",
14
+ }
15
+
16
+ SCENARIOS = {
17
+ "Product advertisement": "You are designing an advertising campaign for a new line of coffee machines. To ensure the campaign resonates with a wider audience, you use generative models to create marketing images that showcase a variety of users interacting with the product.",
18
+ "Tourist promotion": "You are creating a travel campaign to attract a diverse range of visitors to a specific destination. To make the promotional materials more engaging and inclusive, you use generative models to design posters that highlight a broader array of experiences.",
19
+ "Fictional character generation": "You are creating a narrative superhero game where the player often interacts with multiple other non-player characters in the story. To test how different characters would affect the experience of gameplay, you decide to use generative models to help construct characters for (play)testing.",
20
+ "Interior Design": "You have a one-bedroom apartment and want to arrange your bed, desk, and dresser in the best way possible. You love the color white and want to ensure your space feels bright and open. To make a decision, you need a way to visualize different furniture placements before setting everything up.",
21
+ # "Education & accessibility": "You are a grade school teacher and the lesson of the day is about teamwork. Some of your students may have a difficult time visualizing what teamwork looks like because they are either (1) too young, (2) English is not their first language, or (3) they may have cognitive impairments that make it difficult for them to visualize concepts (e.g. aphantasia).."
22
+ }
23
+
24
+ PROMPTS = {
25
+ "Product advertisement": "Design a marketing advertisement image for a coffee machine.",
26
+ "Tourist promotion": "Design a travel promotional poster to showcase the beauty and attractions of a tourist destination.",
27
+ "Fictional character generation": "Generate a character of a superhero.",
28
+ "Interior Design": "Generate a one-bedroom apartment interior design.",
29
+ # "Education & accessibility": "Generate an image of grade school students building a sandcastle together on the beach."
30
+ }
31
+
32
+ IMAGES = {
33
+ "Product advertisement": {"baseline": ["images/scenario1_base1.png","images/scenario1_base2.png","images/scenario1_base3.png","images/scenario1_base4.png"],
34
+ "ours": ["images/scenario1_our1.png","images/scenario1_our2.png","images/scenario1_our3.png","images/scenario1_our4.png"]},
35
+ "Tourist promotion": {"baseline": ["images/scenario5_base1.png","images/scenario5_base2.png","images/scenario5_base3.png","images/scenario5_base4.png"],
36
+ "ours": ["images/scenario5_our1.png","images/scenario5_our2.png","images/scenario5_our3.png","images/scenario5_our4.png"]},
37
+ "Fictional character generation": {"baseline": ["images/scenario2_base1.png","images/scenario2_base2.png","images/scenario2_base3.png","images/scenario2_base4.png"],
38
+ "ours": ["images/scenario2_our1.png","images/scenario2_our2.png","images/scenario2_our3.png","images/scenario2_our4.png"]},
39
+ "Interior Design": {"baseline": ["images/scenario3_base1.png","images/scenario3_base2.png","images/scenario3_base3.png","images/scenario3_base4.png"],
40
+ "ours": ["images/scenario3_our1.png","images/scenario3_our2.png","images/scenario3_our3.png","images/scenario3_our4.png"]},
41
+ # "Education & accessibility": {"baseline": ["images/scenario4_base1.png","images/scenario4_base2.png","images/scenario4_base3.png","images/scenario4_base4.png"],
42
+ # "ours": ["images/scenario4_our1.png","images/scenario4_our2.png","images/scenario4_our3.png","images/scenario4_our4.png"]},
43
+ }
44
+
45
+ OPTIONS = ["Very Unsatisfied", "Unsatisfied", "Slightly Unsatisfied", "Neutral", "Slightly Satisfied", "Satisfied", "Very Satisfied"]
46
+
47
+ INSTRUCTION = "📌 **Instruction**: We want to understand your satisfaction with the generated images. <br /> 📌 Step 1: Start by evaluating the images below based on the given prompt. <br /> 📌 Step 2: Then modify the prompt according to your expectations for the given scenario background, and answer the evaluation question **until you are satisfied** with at least one of the generated images. If you are not satisfied with the generated images, you may revise the prompt up to **5 times**."
48
+ def clean_cache():
49
+ subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
50
+ if torch.cuda.is_available():
51
+ torch.cuda.empty_cache()
52
+
53
+ def setup_model(t2i_model_repo, torch_dtype, device):
54
+ if t2i_model_repo == "stabilityai/sdxl-turbo" or t2i_model_repo == "stabilityai/stable-diffusion-3.5-medium" or t2i_model_repo == "stabilityai/stable-diffusion-2-1":
55
+ pipe = DiffusionPipeline.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
56
+ elif t2i_model_repo == "black-forest-labs/FLUX.1-dev":
57
+ # pipe = FluxPipeline.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
58
+ pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
59
+ torch.cuda.empty_cache()
60
+
61
+ return pipe
62
+
63
+ def init_gpt_api():
64
+ return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
65
+
66
+ def call_gpt_api(messages, client, model, seed, max_tokens, temperature, top_p):
67
+ completion = client.chat.completions.create(
68
+ model=model,
69
+ messages=messages,
70
+ seed=seed,
71
+ max_tokens=max_tokens,
72
+ temperature=temperature,
73
+ top_p=top_p,
74
+ )
75
+ return completion.choices[0].message.content
76
+
77
+ def clean_response_gpt(res: str):
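+ # extracts numbered prompts from the LLM reply, e.g.
+ # '1. "a red coffee machine"\n2. "a barista pouring latte art"' -> ['a red coffee machine', 'a barista pouring latte art']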
78
+ prompts = re.findall(r'\d+\.\s"?(.*?)"?(?=\n|$)', res)
79
+ return prompts
80
+
81
+
82
+ def get_refine_msg(prompt, num_prompts):
83
+ messages = [{"role": "system", "content": f"You are a helpful, respectful and precise assistant. You will be asked to generate {num_prompts} refined prompts. Only respond with those refined prompts"}]
84
+
85
+ message = f"""Given a prompt, modify the prompt for me to explore variations in subject attributes, actions, and contextual details, while retaining the semantic consistency of the original description.
86
+
87
+ Follow the following refinement instruction:
88
+ 1. Subject: refine broad terms into specific subsets, focusing on, but not restricted to, ethnicity, gender, and age of human subjects.
89
+ 2. Object: modify the brand or color of object(s) only if it is not specified in the prompt.
90
+ 3. Setting: add details to the background environment, such as changes in temporal or spatial details (e.g., day to night, indoor to outdoor).
91
+ 4. Action: add more details to the action or specify the object or goal of the action.
92
+
93
+ For example, given this prompt: a person is drinking a coffee in a coffee shop, the refined prompts could be:
94
+ 'an elderly woman is drinking a coffee in a coffee shop' (subject adjective)
95
+ 'a young Asian woman is drinking a coffee in a coffee shop' (subject adjective)
96
+ 'a young woman is drinking a hot coffee with her left hand in a coffee shop' (action details)
97
+ 'a woman is drinking a coffee in an outdoor coffee shop in the garden' (setting details)
98
+ If there is no human in the sentence, you do not need to add a person intentionally.
99
+ If you use adjectives, they should be visual. So don't use something like 'interesting'. Please also vary the number of modifications but do not change the number of subjects/objects that have been specified in the prompt. Remember not to change the predefined concepts that have been specified in the prompt. e.g. don't change a boy to several boys.
100
+
101
+ Can you give me {num_prompts} modified prompts for the prompt '{prompt}', please?"""
102
+
103
+ messages.append({"role": "user", "content": f"{message}"})
104
+ return messages