evansh666 committed · 9cf98ec
Parent(s): 4c26858

first commit
Browse files
- .DS_Store +0 -0
- app.py +514 -4
- images/.DS_Store +0 -0
- images/scenario1_base1.png +0 -0
- images/scenario1_base2.png +0 -0
- images/scenario1_base3.png +0 -0
- images/scenario1_base4.png +0 -0
- images/scenario1_our1.png +0 -0
- images/scenario1_our2.png +0 -0
- images/scenario1_our3.png +0 -0
- images/scenario1_our4.png +0 -0
- images/scenario2_base1.png +0 -0
- images/scenario2_base2.png +0 -0
- images/scenario2_base3.png +0 -0
- images/scenario2_base4.png +0 -0
- images/scenario2_our1.png +0 -0
- images/scenario2_our2.png +0 -0
- images/scenario2_our3.png +0 -0
- images/scenario2_our4.png +0 -0
- images/scenario3_base1.png +0 -0
- images/scenario3_base2.png +0 -0
- images/scenario3_base3.png +0 -0
- images/scenario3_base4.png +0 -0
- images/scenario3_our1.png +0 -0
- images/scenario3_our2.png +0 -0
- images/scenario3_our3.png +0 -0
- images/scenario3_our4.png +0 -0
- images/scenario4_base1.png +0 -0
- images/scenario4_base2.png +0 -0
- images/scenario4_base3.png +0 -0
- images/scenario4_base4.png +0 -0
- images/scenario4_our1.png +0 -0
- images/scenario4_our2.png +0 -0
- images/scenario4_our3.png +0 -0
- images/scenario4_our4.png +0 -0
- images/scenario5_base1.png +0 -0
- images/scenario5_base2.png +0 -0
- images/scenario5_base3.png +0 -0
- images/scenario5_base4.png +0 -0
- images/scenario5_our1.png +0 -0
- images/scenario5_our2.png +0 -0
- images/scenario5_our3.png +0 -0
- images/scenario5_our4.png +0 -0
- live_preview_helpers.py +172 -0
- optim_utils.py +239 -0
- requirements.txt +13 -0
- utils.py +104 -0

.DS_Store ADDED
Binary file (8.2 kB).
    	
app.py CHANGED
@@ -1,7 +1,517 @@
import gradio as gr
from gradio.themes.base import Base
import numpy as np
import random
import spaces  # [uncomment the @spaces.GPU decorators to use ZeroGPU]
import torch
import re
import open_clip
from optim_utils import optimize_prompt
from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION
import transformers
import gspread
import asyncio
from datetime import datetime

CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
default_t2i_model = "black-forest-labs/FLUX.1-dev"
default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # "meta-llama/Meta-Llama-3-8B-Instruct"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
NUM_IMAGES = 4

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
clean_cache()

selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
# clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
llm_pipe = None
torch.cuda.empty_cache()
inverted_prompt = ""

VERBAL_MSG = "Please verbally describe key differences found in the image pair."
DEFAULT_SCENARIO = "Product advertisement"
METHODS = ["Method 1", "Method 2"]
MAX_ROUND = 5
# in-session memory for the current participant
counter1, counter2 = 1, 1
responses_memory = {}
assigned_scenarios = list(SCENARIOS.keys())[:2]
current_task1, current_task2 = METHODS  # methods currently assigned to tab 1 and tab 2
task1_success, task2_success = False, False

########################################################################################################
# Generating images with two methods
########################################################################################################


@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=True,
    width=256,
    height=256,
    guidance_scale=5,
    num_inference_steps=18,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)
    with torch.no_grad():
        image = selected_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    return image

async def infer_async(prompt):
    return infer(prompt)

# generate a batch of images (infer() is synchronous, so the gathered calls still run one after another)
async def generate_batch(prompts):
    tasks = [infer_async(p) for p in prompts]
    images = await asyncio.gather(*tasks)
    return images

@spaces.GPU
def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    print(f"loading {default_llm_model}")
    global llm_pipe
    if not llm_pipe:
        llm_pipe = transformers.pipeline("text-generation", model=default_llm_model, model_kwargs={"torch_dtype": torch_dtype}, device_map="auto")

    messages = get_refine_msg(prompt, num_prompts)
    terminators = [
        llm_pipe.tokenizer.eos_token_id,
        llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    outputs = llm_pipe(
        messages,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    prompt_list = clean_response_gpt(outputs[0]["generated_text"][-1]["content"])
    return prompt_list

def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    seed = random.randint(0, MAX_SEED)
    client = init_gpt_api()
    messages = get_refine_msg(prompt, num_prompts)
    outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens, temperature, top_p)
    prompt_list = clean_response_gpt(outputs)
    return prompt_list

def refine_prompt(gallery_state, prompt):
    modified_prompts = call_gpt_refine_prompt(prompt)
    return modified_prompts
    # eval(prompt, inverted_prompt, gallery_state, clip_model, preprocess)

@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
    # NOTE: relies on the open_clip model/preprocess created above (currently commented out)
    text_params = {
        "iter": iter,
        "lr": lr,
        "batch_size": batch_size,
        "prompt_len": prompt_len,
        "weight_decay": 0.1,
        "prompt_bs": 1,
        "loss_weight": 1.0,
        "print_step": 100,
        "clip_model": CLIP_MODEL,
        "clip_pretrain": PRETRAINED_CLIP,
    }
    inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
    # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
    return inverted_prompt


def eval(prompt, optimized_prompt, optimized_images, clip_model, preprocess):
    torch.cuda.empty_cache()
    tokenizer = open_clip.get_tokenizer(CLIP_MODEL)
    images = [preprocess(i).unsqueeze(0) for i in optimized_images]
    images = torch.concatenate(images).to(device)

    with torch.no_grad():
        image_feat = clip_model.encode_image(images)
        text_feat = clip_model.encode_text(tokenizer([prompt]).to(device))
        optimized_text_feat = clip_model.encode_text(tokenizer([optimized_prompt]).to(device))

    image_feat /= image_feat.norm(dim=-1, keepdim=True)
    text_feat /= text_feat.norm(dim=-1, keepdim=True)
    optimized_text_feat /= optimized_text_feat.norm(dim=-1, keepdim=True)

    similarity = text_feat.cpu().numpy() @ image_feat.cpu().numpy().T
    similarity_optimized = optimized_text_feat.cpu().numpy() @ image_feat.cpu().numpy().T
    return similarity, similarity_optimized


########################################################################################################
# Button-related functions
########################################################################################################

def reset_gallery():
    return []

def display_error_message(msg, duration=5):
    gr.Warning(msg, duration=duration)

def display_info_message(msg, duration=5):
    gr.Info(msg, duration=duration)

def switch_tab(active_tab):
    print("switching tab")
    if active_tab == "Task A":
        return gr.Tabs(selected="Task B")
    else:
        return gr.Tabs(selected="Task A")

def set_user(participant):
    global responses_memory, assigned_scenarios
    responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}

    id = re.findall(r'\d+', participant)
    if len(id) == 0 or int(id[0]) % 2 == 0:  # id missing or even: assign the first half of the scenarios
        assigned_scenarios = list(SCENARIOS.keys())[:2]
    else:
        assigned_scenarios = list(SCENARIOS.keys())[2:]
    return assigned_scenarios[0]

def display_scenario(participant, choice):
    # reset in-session storage when the scenario changes
    global counter1, counter2, responses_memory, current_task1, current_task2, task1_success, task2_success

    task1_success, task2_success = False, False
    counter1, counter2 = 1, 1

    if check_participant(participant):
        responses_memory[participant] = {METHODS[0]: {}, METHODS[1]: {}}

    [current_task1, current_task2] = random.sample(METHODS, 2)
    if current_task1 == METHODS[0]:
        initial_images1 = IMAGES[choice]["baseline"]
        initial_images2 = IMAGES[choice]["ours"]
    else:
        initial_images1 = IMAGES[choice]["ours"]
        initial_images2 = IMAGES[choice]["baseline"]

    res = {
        scenario_content: SCENARIOS.get(choice, ""),
        prompt: PROMPTS.get(choice, ""),
        prompt1: "",
        prompt2: "",
        images_method1: initial_images1,
        images_method2: initial_images2,
        gallery_state1: initial_images1,
        gallery_state2: initial_images2,
        sim_radio1: None,
        sim_radio2: None,
        response1: VERBAL_MSG,
        response2: VERBAL_MSG,
        next_btn1: gr.update(interactive=False),
        next_btn2: gr.update(interactive=False),
        redesign_btn1: gr.update(interactive=True),
        redesign_btn2: gr.update(interactive=True),
        submit_btn1: gr.update(interactive=False),
        submit_btn2: gr.update(interactive=False),
    }
    return res

def generate_image(participant, scenario, prompt, gallery_state, active_tab):
    if not check_participant(participant):
        return
    global current_task1, current_task2

    method = current_task1 if active_tab == "Task A" else current_task2

    if method == METHODS[0]:
        for i in range(NUM_IMAGES):
            img = infer(prompt)
            gallery_state.append(img)
            yield gallery_state
    else:
        refined_prompts = refine_prompt(gallery_state, prompt)
        for i in range(NUM_IMAGES):
            img = infer(refined_prompts[i])
            gallery_state.append(img)
            yield gallery_state

def check_satisfaction(sim_radio, active_tab):
    global counter1, counter2, current_task1, current_task2
    method = current_task1 if active_tab == "Task A" else current_task2
    counter = counter1 if method == METHODS[0] else counter2

    fully_satisfied_options = ["Satisfied", "Very Satisfied"]  # answers that unlock the Submit button
    enable_submit = sim_radio in fully_satisfied_options or counter >= MAX_ROUND

    return gr.update(interactive=enable_submit), gr.update(interactive=(not enable_submit))

def check_participant(participant):
    if participant == "":
        display_error_message("Please fill in your participant id!")
        return False
    return True

def check_evaluation(sim_radio, response):
    if not sim_radio:
        display_error_message("❌ Please fill in all evaluations before redesigning or submitting.")
        return False

    return True

def redesign(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
    global counter1, counter2, responses_memory, current_task1, current_task2
    method = current_task1 if active_tab == "Task A" else current_task2

    if check_evaluation(sim_radio, response) and check_participant(participant):
        if method == METHODS[0]:
            counter1 += 1
            counter = counter1
        else:
            counter2 += 1
            counter = counter2

        responses_memory[participant][method][counter-1] = {}
        responses_memory[participant][method][counter-1]["prompt"] = prompt
        responses_memory[participant][method][counter-1]["sim_radio"] = sim_radio
        responses_memory[participant][method][counter-1]["response"] = response

        prompt_state = gr.update(visible=True)
        next_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(visible=True, interactive=True)
        redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
        submit_state = gr.update(interactive=True) if counter >= MAX_ROUND else gr.update(interactive=False)

        return [], None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state
    else:
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}

def show_message(selected_option):
    if selected_option:
        return "Click \"Redesign\" and revise your prompt to create images that may more closely match your expectations."
    return ""

def save_response(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
    global current_task1, current_task2, counter1, counter2, responses_memory, task1_success, task2_success, assigned_scenarios
    method = current_task1 if active_tab == "Task A" else current_task2

    if check_evaluation(sim_radio, response) and check_participant(participant):
        counter = counter1 if method == METHODS[0] else counter2
        # image_paths = [save_image(img, "method", i) for i, img in enumerate(images_method)]

        responses_memory[participant][method][counter] = {}
        responses_memory[participant][method][counter]["prompt"] = prompt
        responses_memory[participant][method][counter]["sim_radio"] = sim_radio
        responses_memory[participant][method][counter]["response"] = response
        prompt_state = gr.update(visible=False)
        next_state = gr.update(visible=False, interactive=False)
        submit_state = gr.update(interactive=False)
        redesign_state = gr.update(interactive=False)

        try:
            gc = gspread.service_account(filename='credentials.json')
            sheet = gc.open("DiverseGen-phase2").sheet1

            for i, entry in responses_memory[participant][method].items():
                sheet.append_row([participant, scenario, method, i, entry["prompt"], entry["sim_radio"], entry["response"]])

            display_info_message("✅ Your answer is saved!")

            # reset counter and update success indicator
            if method == METHODS[0]:
                counter1 = 1
            else:
                counter2 = 1

            if active_tab == "Task A":
                task1_success = True
            else:
                task2_success = True

            tabs = switch_tab(active_tab)
            next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
            return [], [], None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state, tabs, next_scenario
        except Exception as e:
            display_error_message(f"❌ Error saving response: {str(e)}")
            return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
    else:
        return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}


########################################################################################################
# Interface
########################################################################################################

css = """
#col-container {
    margin: 0 auto;
    max-width: 700px;
}

#col-container2 {
    margin: 0 auto;
    max-width: 1000px;
}

#button-container {
    display: flex;
    justify-content: center; /* centers the buttons horizontally */
}
"""

with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 📌 **Diverse Text-to-Image Generation**")

        with gr.Row():
            participant = gr.Textbox(
                label="🧑‍💼 Participant ID", placeholder="Please enter your participant id"
            )
            scenario = gr.Dropdown(
                choices=list(SCENARIOS.keys()),
                # value=DEFAULT_SCENARIO,
                value=None,
                label="Scenario",
                interactive=False,
            )
        scenario_content = gr.Textbox(
            label="📖 Background",
            interactive=False,
            # value=SCENARIOS[DEFAULT_SCENARIO]
        )
        prompt = gr.Textbox(
            label="🎨 Prompt",
            max_lines=1,
            # value=PROMPTS[DEFAULT_SCENARIO],
            interactive=False
        )
        active_tab = gr.State("Task A")
        instruction = gr.Markdown(INSTRUCTION)

    with gr.Tabs() as tabs:
        with gr.TabItem("Task A", id="Task A") as task1_tab:
            task1_tab.select(lambda: "Task A", outputs=[active_tab])
            with gr.Column(elem_id="col-container"):
                # gr.Markdown("### Step 2: This is the prompt used to generate images; you may modify it after the first round of evaluation")
                with gr.Row():
                    prompt1 = gr.Textbox(
                        label="🎨 Revise Prompt",
                        max_lines=1,
                        placeholder="Enter your prompt",
                        # value=PROMPTS[DEFAULT_SCENARIO],
                        scale=4,
                        visible=False
                    )
                    next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)

            with gr.Column(elem_id="col-container"):
                gallery_state1 = gr.State(IMAGES[DEFAULT_SCENARIO]["baseline"])
                images_method1 = gr.Gallery(show_label=False, columns=[4], rows=[1], elem_id="gallery")
            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📝 Evaluation")
                sim_radio1 = gr.Radio(
                    OPTIONS,
                    label="How would you evaluate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response1 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )

                with gr.Row(elem_id="button-container"):
                    redesign_btn1 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn1 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)


        with gr.TabItem("Task B", id="Task B") as task2_tab:
            task2_tab.select(lambda: "Task B", outputs=[active_tab])
            with gr.Column(elem_id="col-container"):
                with gr.Row():
                    prompt2 = gr.Textbox(
                        label="🎨 Revise Prompt",
                        max_lines=1,
                        placeholder="Enter your prompt",
                        # value=PROMPTS[DEFAULT_SCENARIO],
                        scale=4,
                        visible=False
                    )
                    next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)

            with gr.Column(elem_id="col-container"):
                gallery_state2 = gr.State(IMAGES[DEFAULT_SCENARIO]["ours"])
                images_method2 = gr.Gallery(show_label=False, columns=[4], rows=[1], elem_id="gallery")

            with gr.Column(elem_id="col-container2"):
                gr.Markdown("### 📝 Evaluation")
                sim_radio2 = gr.Radio(
                    OPTIONS,
                    label="How would you evaluate your satisfaction with the generated images, based on your expectations for the specified scenario?",
                    type="value",
                    elem_classes=["gradio-radio"]
                )
                response2 = gr.Textbox(
                    label="Verbally describe key differences found in the image pair.",
                    max_lines=1,
                    interactive=False,
                    container=False,
                    value=VERBAL_MSG
                )
                with gr.Row(elem_id="button-container"):
                    redesign_btn2 = gr.Button("🎨 Redesign", variant="primary", scale=0)
                    submit_btn2 = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)


########################################################################################################
# Button Function Setup
########################################################################################################

    participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
    scenario.change(display_scenario, inputs=[participant, scenario], outputs=[scenario_content, prompt, prompt1, prompt2, images_method1, images_method2, gallery_state1, gallery_state2, sim_radio1, sim_radio2, response1, response2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
    prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
    prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
    next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, gallery_state1, active_tab], outputs=[images_method1])
    next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, gallery_state2, active_tab], outputs=[images_method2])
    sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
    sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])
    redesign_btn1.click(
        fn=redesign,
        inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
        outputs=[gallery_state1, sim_radio1, response1, prompt1, next_btn1, redesign_btn1, submit_btn1]
    )
    redesign_btn2.click(
        fn=redesign,
        inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
        outputs=[gallery_state2, sim_radio2, response2, prompt2, next_btn2, redesign_btn2, submit_btn2]
    )
    submit_btn1.click(fn=save_response,
        inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
        outputs=[images_method1, gallery_state1, sim_radio1, response1, prompt1, next_btn1, redesign_btn1, submit_btn1, tabs, scenario])

    submit_btn2.click(fn=save_response,
        inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
        outputs=[images_method2, gallery_state2, sim_radio2, response2, prompt2, next_btn2, redesign_btn2, submit_btn2, tabs, scenario])


if __name__ == "__main__":
    demo.launch()
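Note on the streaming behavior: generate_image() above is a Python generator, and each `yield gallery_state` pushes the partially filled list to the gr.Gallery output, so participants see images appear one at a time rather than all at once after NUM_IMAGES finish. Below is a minimal, self-contained sketch of that pattern; it is not part of this commit, and stream_results and the gr.JSON output are stand-ins for the app's infer() call and gallery.

import time
import gradio as gr

def stream_results(prompt):
    items = []
    for i in range(4):
        time.sleep(0.5)                  # stand-in for one synchronous diffusion call
        items.append(f"{prompt} ({i + 1}/4)")
        yield items                      # each yield re-renders the output with results so far

with gr.Blocks() as sketch:
    prompt_box = gr.Textbox(label="Prompt")
    results = gr.JSON(label="Streamed results")
    gr.Button("Generate").click(stream_results, inputs=prompt_box, outputs=results)

if __name__ == "__main__":
    sketch.launch()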
    	
images/.DS_Store ADDED
Binary file (6.15 kB).
    	
images/scenario{1-5}_base{1-4}.png ADDED
images/scenario{1-5}_our{1-4}.png ADDED
(40 binary image files, as listed above)
    	
        live_preview_helpers.py
    ADDED
    
    | @@ -0,0 +1,172 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            from typing import Any, Dict, List, Optional, Union
         | 
| 4 | 
            +
            from diffusers import FluxPipeline
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Helper functions
         | 
| 7 | 
            +
            def calculate_shift(
         | 
| 8 | 
            +
                image_seq_len,
         | 
| 9 | 
            +
                base_seq_len: int = 256,
         | 
| 10 | 
            +
                max_seq_len: int = 4096,
         | 
| 11 | 
            +
                base_shift: float = 0.5,
         | 
| 12 | 
            +
                max_shift: float = 1.16,
         | 
| 13 | 
            +
            ):
         | 
| 14 | 
            +
                m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
         | 
| 15 | 
            +
                b = base_shift - m * base_seq_len
         | 
| 16 | 
            +
                mu = image_seq_len * m + b
         | 
| 17 | 
            +
                return mu
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            def retrieve_timesteps(
         | 
| 20 | 
            +
                scheduler,
         | 
| 21 | 
            +
                num_inference_steps: Optional[int] = None,
         | 
| 22 | 
            +
                device: Optional[Union[str, torch.device]] = None,
         | 
| 23 | 
            +
                timesteps: Optional[List[int]] = None,
         | 
| 24 | 
            +
                sigmas: Optional[List[float]] = None,
         | 
| 25 | 
            +
                **kwargs,
         | 
| 26 | 
            +
            ):
         | 
| 27 | 
            +
                if timesteps is not None and sigmas is not None:
         | 
| 28 | 
            +
                    raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
         | 
| 29 | 
            +
                if timesteps is not None:
         | 
| 30 | 
            +
                    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         | 
| 31 | 
            +
                    timesteps = scheduler.timesteps
         | 
| 32 | 
            +
                    num_inference_steps = len(timesteps)
         | 
| 33 | 
            +
                elif sigmas is not None:
         | 
| 34 | 
            +
                    scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
         | 
| 35 | 
            +
                    timesteps = scheduler.timesteps
         | 
| 36 | 
            +
                    num_inference_steps = len(timesteps)
         | 
| 37 | 
            +
                else:
         | 
| 38 | 
            +
                    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         | 
| 39 | 
            +
                    timesteps = scheduler.timesteps
         | 
| 40 | 
            +
                return timesteps, num_inference_steps
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
         | 
| 44 | 
            +
                """
         | 
| 45 | 
            +
    Extends the FluxPipeline to yield a decoded intermediate image after each
         | 
| 46 | 
            +
    denoising step, so callers can display a live preview while generation runs.
         | 
| 47 | 
            +
                """
         | 
| 48 | 
            +
                # FLUX pipeline function
         | 
| 49 | 
            +
                @torch.inference_mode()
         | 
| 50 | 
            +
                def generate_images(
         | 
| 51 | 
            +
                    self,
         | 
| 52 | 
            +
                    prompt: Union[str, List[str]] = None,
         | 
| 53 | 
            +
                    prompt_2: Optional[Union[str, List[str]]] = None,
         | 
| 54 | 
            +
                    height: Optional[int] = None,
         | 
| 55 | 
            +
                    width: Optional[int] = None,
         | 
| 56 | 
            +
                    num_inference_steps: int = 28,
         | 
| 57 | 
            +
                    timesteps: List[int] = None,
         | 
| 58 | 
            +
                    guidance_scale: float = 3.5,
         | 
| 59 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 60 | 
            +
                    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         | 
| 61 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 62 | 
            +
                    prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 63 | 
            +
                    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 64 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 65 | 
            +
                    return_dict: bool = True,
         | 
| 66 | 
            +
                    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         | 
| 67 | 
            +
                    max_sequence_length: int = 512,
         | 
| 68 | 
            +
                ):
         | 
| 69 | 
            +
                    height = height or self.default_sample_size * self.vae_scale_factor
         | 
| 70 | 
            +
                    width = width or self.default_sample_size * self.vae_scale_factor
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    # 1. Check inputs
         | 
| 73 | 
            +
                    self.check_inputs(
         | 
| 74 | 
            +
                        prompt,
         | 
| 75 | 
            +
                        prompt_2,
         | 
| 76 | 
            +
                        height,
         | 
| 77 | 
            +
                        width,
         | 
| 78 | 
            +
                        prompt_embeds=prompt_embeds,
         | 
| 79 | 
            +
                        pooled_prompt_embeds=pooled_prompt_embeds,
         | 
| 80 | 
            +
                        max_sequence_length=max_sequence_length,
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    self._guidance_scale = guidance_scale
         | 
| 84 | 
            +
                    self._joint_attention_kwargs = joint_attention_kwargs
         | 
| 85 | 
            +
                    self._interrupt = False
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    # 2. Define call parameters
         | 
| 88 | 
            +
                    batch_size = 1 if isinstance(prompt, str) else len(prompt)
         | 
| 89 | 
            +
                    device = self._execution_device
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    # 3. Encode prompt
         | 
| 92 | 
            +
                    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
         | 
| 93 | 
            +
                    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
         | 
| 94 | 
            +
                        prompt=prompt,
         | 
| 95 | 
            +
                        prompt_2=prompt_2,
         | 
| 96 | 
            +
                        prompt_embeds=prompt_embeds,
         | 
| 97 | 
            +
                        pooled_prompt_embeds=pooled_prompt_embeds,
         | 
| 98 | 
            +
                        device=device,
         | 
| 99 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 100 | 
            +
                        max_sequence_length=max_sequence_length,
         | 
| 101 | 
            +
                        lora_scale=lora_scale,
         | 
| 102 | 
            +
                    )
         | 
| 103 | 
            +
                    # 4. Prepare latent variables
         | 
| 104 | 
            +
                    num_channels_latents = self.transformer.config.in_channels // 4
         | 
| 105 | 
            +
                    latents, latent_image_ids = self.prepare_latents(
         | 
| 106 | 
            +
                        batch_size * num_images_per_prompt,
         | 
| 107 | 
            +
                        num_channels_latents,
         | 
| 108 | 
            +
                        height,
         | 
| 109 | 
            +
                        width,
         | 
| 110 | 
            +
                        prompt_embeds.dtype,
         | 
| 111 | 
            +
                        device,
         | 
| 112 | 
            +
                        generator,
         | 
| 113 | 
            +
                        latents,
         | 
| 114 | 
            +
                    )
         | 
| 115 | 
            +
                    # 5. Prepare timesteps
         | 
| 116 | 
            +
                    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         | 
| 117 | 
            +
                    image_seq_len = latents.shape[1]
         | 
| 118 | 
            +
                    mu = calculate_shift(
         | 
| 119 | 
            +
                        image_seq_len,
         | 
| 120 | 
            +
                        self.scheduler.config.base_image_seq_len,
         | 
| 121 | 
            +
                        self.scheduler.config.max_image_seq_len,
         | 
| 122 | 
            +
                        self.scheduler.config.base_shift,
         | 
| 123 | 
            +
                        self.scheduler.config.max_shift,
         | 
| 124 | 
            +
                    )
         | 
| 125 | 
            +
                    timesteps, num_inference_steps = retrieve_timesteps(
         | 
| 126 | 
            +
                        self.scheduler,
         | 
| 127 | 
            +
                        num_inference_steps,
         | 
| 128 | 
            +
                        device,
         | 
| 129 | 
            +
                        timesteps,
         | 
| 130 | 
            +
                        sigmas,
         | 
| 131 | 
            +
                        mu=mu,
         | 
| 132 | 
            +
                    )
         | 
| 133 | 
            +
                    self._num_timesteps = len(timesteps)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # Handle guidance
         | 
| 136 | 
            +
                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    # 6. Denoising loop
         | 
| 139 | 
            +
                    for i, t in enumerate(timesteps):
         | 
| 140 | 
            +
                        if self.interrupt:
         | 
| 141 | 
            +
                            continue
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                        timestep = t.expand(latents.shape[0]).to(latents.dtype)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                        noise_pred = self.transformer(
         | 
| 146 | 
            +
                            hidden_states=latents,
         | 
| 147 | 
            +
                            timestep=timestep / 1000,
         | 
| 148 | 
            +
                            guidance=guidance,
         | 
| 149 | 
            +
                            pooled_projections=pooled_prompt_embeds,
         | 
| 150 | 
            +
                            encoder_hidden_states=prompt_embeds,
         | 
| 151 | 
            +
                            txt_ids=text_ids,
         | 
| 152 | 
            +
                            img_ids=latent_image_ids,
         | 
| 153 | 
            +
                            joint_attention_kwargs=self.joint_attention_kwargs,
         | 
| 154 | 
            +
                            return_dict=False,
         | 
| 155 | 
            +
                        )[0]
         | 
| 156 | 
            +
                        
         | 
| 157 | 
            +
                        # Yield intermediate result
         | 
| 158 | 
            +
                        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
         | 
| 159 | 
            +
                        yield self._decode_latents_to_image(latents, height, width, output_type)
         | 
| 160 | 
            +
                        torch.cuda.empty_cache()
         | 
| 161 | 
            +
             | 
| 162 | 
            +
        # Cleanup after the final yielded image
         | 
| 163 | 
            +
                    self.maybe_free_model_hooks()
         | 
| 164 | 
            +
                    torch.cuda.empty_cache()
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
         | 
| 167 | 
            +
                    """Decodes the given latents into an image."""
         | 
| 168 | 
            +
                    vae = vae or self.vae
         | 
| 169 | 
            +
                    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         | 
| 170 | 
            +
                    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
         | 
| 171 | 
            +
                    image = vae.decode(latents, return_dict=False)[0]
         | 
| 172 | 
            +
                    return self.image_processor.postprocess(image, output_type=output_type)[0]
         | 
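Since generate_images is a generator that yields a freshly decoded image after every scheduler step, callers iterate over it rather than calling it once. A minimal consumption sketch, assuming a CUDA device and local access to the FLUX.1-dev weights; the prompt, step count, and file name are illustrative only:

    import torch
    from live_preview_helpers import FLUXPipelineWithIntermediateOutputs

    pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Each iteration yields the current decoded PIL image, so a UI can
    # refresh its preview while denoising is still in progress.
    for preview in pipe.generate_images(
        prompt="a watercolor fox",  # illustrative prompt
        num_inference_steps=8,
        height=512,
        width=512,
    ):
        preview.save("latest_preview.png")  # overwrite with the newest frame

The last image yielded is the fully denoised result, so no separate final call is needed.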
    	
        optim_utils.py
    ADDED
    
    | @@ -0,0 +1,239 @@ | |
| 1 | 
            +
            import random
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import requests
         | 
| 4 | 
            +
            from io import BytesIO
         | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            from statistics import mean
         | 
| 7 | 
            +
            import copy
         | 
| 8 | 
            +
            import json
         | 
| 9 | 
            +
            from typing import Any, Mapping
         | 
| 10 | 
            +
            import open_clip
         | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from sentence_transformers.util import (semantic_search, 
         | 
| 14 | 
            +
                                                    dot_score, 
         | 
| 15 | 
            +
                                                    normalize_embeddings)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def nn_project(curr_embeds, embedding_layer, print_hits=False):
         | 
| 19 | 
            +
                with torch.no_grad():
         | 
| 20 | 
            +
                    bsz,seq_len,emb_dim = curr_embeds.shape
         | 
| 21 | 
            +
                    
         | 
| 22 | 
            +
                    # Using the sentence transformers semantic search which is 
         | 
| 23 | 
            +
                    # a dot product exact kNN search between a set of 
         | 
| 24 | 
            +
                    # query vectors and a corpus of vectors
         | 
| 25 | 
            +
                    curr_embeds = curr_embeds.reshape((-1,emb_dim))
         | 
| 26 | 
            +
                    curr_embeds = normalize_embeddings(curr_embeds) # queries
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    embedding_matrix = embedding_layer.weight
         | 
| 29 | 
            +
                    embedding_matrix = normalize_embeddings(embedding_matrix)
         | 
| 30 | 
            +
                    
         | 
| 31 | 
            +
                    hits = semantic_search(curr_embeds, embedding_matrix, 
         | 
| 32 | 
            +
                                            query_chunk_size=curr_embeds.shape[0], 
         | 
| 33 | 
            +
                                            top_k=1,
         | 
| 34 | 
            +
                                            score_function=dot_score)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    if print_hits:
         | 
| 37 | 
            +
                        all_hits = []
         | 
| 38 | 
            +
                        for hit in hits:
         | 
| 39 | 
            +
                            all_hits.append(hit[0]["score"])
         | 
| 40 | 
            +
                        print(f"mean hits:{mean(all_hits)}")
         | 
| 41 | 
            +
                    
         | 
| 42 | 
            +
                    nn_indices = torch.tensor([hit[0]["corpus_id"] for hit in hits], device=curr_embeds.device)
         | 
| 43 | 
            +
                    nn_indices = nn_indices.reshape((bsz,seq_len))
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    projected_embeds = embedding_layer(nn_indices)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                return projected_embeds, nn_indices
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            def decode_ids(input_ids, tokenizer, by_token=False):
         | 
| 50 | 
            +
                input_ids = input_ids.detach().cpu().numpy()
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                texts = []
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                if by_token:
         | 
| 55 | 
            +
                    for input_ids_i in input_ids:
         | 
| 56 | 
            +
                        curr_text = []
         | 
| 57 | 
            +
                        for tmp in input_ids_i:
         | 
| 58 | 
            +
                            curr_text.append(tokenizer.decode([tmp]))
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                        texts.append('|'.join(curr_text))
         | 
| 61 | 
            +
                else:
         | 
| 62 | 
            +
                    for input_ids_i in input_ids:
         | 
| 63 | 
            +
                        texts.append(tokenizer.decode(input_ids_i))
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                return texts
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            def get_target_feature(model, preprocess, tokenizer_funct, device, target_images=None, target_prompts=None):
         | 
| 68 | 
            +
                if target_images is not None:
         | 
| 69 | 
            +
                    with torch.no_grad():
         | 
| 70 | 
            +
                        curr_images = [preprocess(i).unsqueeze(0) for i in target_images]
         | 
| 71 | 
            +
                        curr_images = torch.concatenate(curr_images).to(device)
         | 
| 72 | 
            +
                        all_target_features = model.encode_image(curr_images)
         | 
| 73 | 
            +
                else:
         | 
| 74 | 
            +
                    texts = tokenizer_funct(target_prompts).to(device)
         | 
| 75 | 
            +
                    all_target_features = model.encode_text(texts)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                return all_target_features
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            def encode_text_embedding(model, text_embedding, ids, avg_text=False):
         | 
| 80 | 
            +
                    cast_dtype = model.transformer.get_cast_dtype()
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    x = text_embedding + model.positional_embedding.to(cast_dtype)
         | 
| 83 | 
            +
                    x = x.permute(1, 0, 2)  # NLD -> LND
         | 
| 84 | 
            +
                    x = model.transformer(x, attn_mask=model.attn_mask)
         | 
| 85 | 
            +
                    x = x.permute(1, 0, 2)  # LND -> NLD
         | 
| 86 | 
            +
                    x = model.ln_final(x)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    # x.shape = [batch_size, n_ctx, transformer.width]
         | 
| 89 | 
            +
                    # take features from the eot embedding (eot_token is the highest number in each sequence)
         | 
| 90 | 
            +
                    if avg_text:
         | 
| 91 | 
            +
                        x = x[torch.arange(x.shape[0]), :ids.argmax(dim=-1)]
         | 
| 92 | 
            +
            x = x[:, 1:-1]  # drop the boundary token positions before averaging
         | 
| 93 | 
            +
                        x = x.mean(dim=1) @ model.text_projection
         | 
| 94 | 
            +
                    else:
         | 
| 95 | 
            +
                        x = x[torch.arange(x.shape[0]), ids.argmax(dim=-1)] @ model.text_projection
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    return x
         | 
| 98 | 
            +
                    
         | 
| 99 | 
            +
            def forward_text_embedding(model, embeddings, ids, image_features, avg_text=False, return_feature=False):
         | 
| 100 | 
            +
                text_features = encode_text_embedding(model, embeddings, ids, avg_text=avg_text)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                if return_feature:
         | 
| 103 | 
            +
                    return text_features
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                image_features = image_features / image_features.norm(dim=1, keepdim=True)
         | 
| 106 | 
            +
                text_features = text_features / text_features.norm(dim=1, keepdim=True)
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                logits_per_image = image_features @ text_features.t()
         | 
| 109 | 
            +
                logits_per_text = logits_per_image.t()
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                return logits_per_image, logits_per_text
         | 
| 112 | 
            +
                
         | 
| 113 | 
            +
            def initialize_prompt(tokenizer, token_embedding, args, device, original_prompt):
         | 
| 114 | 
            +
                prompt_len = args["prompt_len"]
         | 
| 115 | 
            +
             | 
| 116 | 
            +
    # initialize prompt embeddings from the original prompt's tokens
         | 
| 117 | 
            +
                tokens = tokenizer.encode(original_prompt)
         | 
| 118 | 
            +
                if len(tokens) > prompt_len:
         | 
| 119 | 
            +
                    tokens = tokens[:prompt_len]
         | 
| 120 | 
            +
                if len(tokens) < prompt_len:
         | 
| 121 | 
            +
                    tokens += [0] * (prompt_len - len(tokens))
         | 
| 122 | 
            +
                
         | 
| 123 | 
            +
                prompt_ids = torch.tensor([tokens] * args["prompt_bs"]).to(device)
         | 
| 124 | 
            +
                # prompt_ids = torch.randint(len(tokenizer.encoder), (args.prompt_bs, prompt_len)).to(device)
         | 
| 125 | 
            +
                prompt_embeds = token_embedding(prompt_ids).detach()
         | 
| 126 | 
            +
                prompt_embeds.requires_grad = True
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                # initialize the template
         | 
| 129 | 
            +
                template_text = "{}"
         | 
| 130 | 
            +
                padded_template_text = template_text.format(" ".join(["<start_of_text>"] * prompt_len))
         | 
| 131 | 
            +
                dummy_ids = tokenizer.encode(padded_template_text)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                # -1 for optimized tokens
         | 
| 134 | 
            +
                dummy_ids = [i if i != 49406 else -1 for i in dummy_ids]
         | 
| 135 | 
            +
                dummy_ids = [49406] + dummy_ids + [49407]
         | 
| 136 | 
            +
                dummy_ids += [0] * (77 - len(dummy_ids))
         | 
| 137 | 
            +
                dummy_ids = torch.tensor([dummy_ids] * args["prompt_bs"]).to(device)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                # for getting dummy embeds; -1 won't work for token_embedding
         | 
| 140 | 
            +
                tmp_dummy_ids = copy.deepcopy(dummy_ids)
         | 
| 141 | 
            +
                tmp_dummy_ids[tmp_dummy_ids == -1] = 0
         | 
| 142 | 
            +
                dummy_embeds = token_embedding(tmp_dummy_ids).detach()
         | 
| 143 | 
            +
                dummy_embeds.requires_grad = False
         | 
| 144 | 
            +
                
         | 
| 145 | 
            +
                return prompt_embeds, dummy_embeds, dummy_ids
         | 
| 146 | 
            +
             | 
| 147 | 
            +
            def optimize_prompt_loop(model, tokenizer, token_embedding, all_target_features, args, device, original_prompt):
         | 
| 148 | 
            +
                opt_iters = args["iter"]
         | 
| 149 | 
            +
                lr = args["lr"]
         | 
| 150 | 
            +
                weight_decay = args["weight_decay"]
         | 
| 151 | 
            +
                print_step = args["print_step"]
         | 
| 152 | 
            +
                batch_size = args["batch_size"]
         | 
| 153 | 
            +
                print_new_best = True
         | 
| 154 | 
            +
                
         | 
| 155 | 
            +
                # initialize prompt
         | 
| 156 | 
            +
                prompt_embeds, dummy_embeds, dummy_ids = initialize_prompt(tokenizer, token_embedding, args, device, original_prompt)
         | 
| 157 | 
            +
                p_bs, p_len, p_dim = prompt_embeds.shape
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                # get optimizer
         | 
| 160 | 
            +
                input_optimizer = torch.optim.AdamW([prompt_embeds], lr=lr, weight_decay=weight_decay)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                best_sim = -1000 * args["loss_weight"]
         | 
| 163 | 
            +
                best_text = ""
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                for step in range(opt_iters):
         | 
| 166 | 
            +
        # randomly sample target images and get features
         | 
| 167 | 
            +
                    if batch_size is None:
         | 
| 168 | 
            +
                        target_features = all_target_features
         | 
| 169 | 
            +
                    else:
         | 
| 170 | 
            +
                        curr_indx = torch.randperm(len(all_target_features))
         | 
| 171 | 
            +
                        target_features = all_target_features[curr_indx][0:batch_size]
         | 
| 172 | 
            +
                        
         | 
| 173 | 
            +
                    universal_target_features = all_target_features
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    # forward projection
         | 
| 176 | 
            +
                    projected_embeds, nn_indices = nn_project(prompt_embeds, token_embedding, print_hits=False)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    # get cosine similarity score with all target features
         | 
| 179 | 
            +
                    with torch.no_grad():
         | 
| 180 | 
            +
                        # padded_embeds = copy.deepcopy(dummy_embeds)
         | 
| 181 | 
            +
                        padded_embeds = dummy_embeds.detach().clone()
         | 
| 182 | 
            +
                        padded_embeds[dummy_ids == -1] = projected_embeds.reshape(-1, p_dim)
         | 
| 183 | 
            +
                        logits_per_image, _ = forward_text_embedding(model, padded_embeds, dummy_ids, universal_target_features)
         | 
| 184 | 
            +
                        scores_per_prompt = logits_per_image.mean(dim=0)
         | 
| 185 | 
            +
                        universal_cosim_score = scores_per_prompt.max().item()
         | 
| 186 | 
            +
                        best_indx = scores_per_prompt.argmax().item()
         | 
| 187 | 
            +
                    
         | 
| 188 | 
            +
                    # tmp_embeds = copy.deepcopy(prompt_embeds)
         | 
| 189 | 
            +
                    tmp_embeds = prompt_embeds.detach().clone()
         | 
| 190 | 
            +
                    tmp_embeds.data = projected_embeds.data
         | 
| 191 | 
            +
                    tmp_embeds.requires_grad = True
         | 
| 192 | 
            +
                    
         | 
| 193 | 
            +
                    # padding
         | 
| 194 | 
            +
                    # padded_embeds = copy.deepcopy(dummy_embeds)
         | 
| 195 | 
            +
                    padded_embeds = dummy_embeds.detach().clone()
         | 
| 196 | 
            +
                    padded_embeds[dummy_ids == -1] = tmp_embeds.reshape(-1, p_dim)
         | 
| 197 | 
            +
                    
         | 
| 198 | 
            +
                    logits_per_image, _ = forward_text_embedding(model, padded_embeds, dummy_ids, target_features)
         | 
| 199 | 
            +
                    cosim_scores = logits_per_image
         | 
| 200 | 
            +
                    loss = 1 - cosim_scores.mean()
         | 
| 201 | 
            +
                    loss = loss * args["loss_weight"]
         | 
| 202 | 
            +
                    
         | 
| 203 | 
            +
                    prompt_embeds.grad, = torch.autograd.grad(loss, [tmp_embeds])
         | 
| 204 | 
            +
                    
         | 
| 205 | 
            +
                    input_optimizer.step()
         | 
| 206 | 
            +
                    input_optimizer.zero_grad()
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    curr_lr = input_optimizer.param_groups[0]["lr"]
         | 
| 209 | 
            +
                    cosim_scores = cosim_scores.mean().item()
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    decoded_text = decode_ids(nn_indices, tokenizer)[best_indx]
         | 
| 212 | 
            +
                    if print_step is not None and (step % print_step == 0 or step == opt_iters-1):
         | 
| 213 | 
            +
                        per_step_message = f"step: {step}, lr: {curr_lr}"
         | 
| 214 | 
            +
                        # if not print_new_best:
         | 
| 215 | 
            +
                            # per_step_message = f"\n{per_step_message}, cosim: {universal_cosim_score:.3f}, text: {decoded_text}"
         | 
| 216 | 
            +
                        # print(per_step_message)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    if best_sim * args["loss_weight"] < universal_cosim_score * args["loss_weight"]:
         | 
| 219 | 
            +
                        best_sim = universal_cosim_score
         | 
| 220 | 
            +
                        best_text = decoded_text
         | 
| 221 | 
            +
                        if print_new_best:
         | 
| 222 | 
            +
                            print(f"step: {step}, new best cosine sim: {best_sim}, new best prompt: {best_text}")
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                if print_step is not None:
         | 
| 225 | 
            +
                    print(f"best cosine sim: {best_sim}, best prompt: {best_text}")
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                return best_text
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
            def optimize_prompt(model, preprocess, args, device, target_images=None, target_prompts=None):
         | 
| 231 | 
            +
                token_embedding = model.token_embedding
         | 
| 232 | 
            +
                tokenizer = open_clip.tokenizer._tokenizer
         | 
| 233 | 
            +
                tokenizer_funct = open_clip.get_tokenizer(args["clip_model"])
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                all_target_features = get_target_feature(model, preprocess, tokenizer_funct, device, target_images=target_images)
         | 
| 236 | 
            +
                learned_prompt = optimize_prompt_loop(model, tokenizer, token_embedding, all_target_features, args, device, target_prompts)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                return learned_prompt
         | 
| 239 | 
            +
                
         | 
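optimize_prompt expects its hyperparameters as a plain dict; the keys below are exactly the ones the functions above read (clip_model, prompt_len, prompt_bs, iter, lr, weight_decay, print_step, batch_size, loss_weight). A minimal sketch of inverting a target image into a hard prompt, assuming the ViT-H-14 OpenCLIP checkpoint tag named here is available; every value is illustrative, not tuned:

    import open_clip
    import torch
    from PIL import Image
    from optim_utils import optimize_prompt

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Keys mirror the lookups in initialize_prompt / optimize_prompt_loop.
    args = {
        "clip_model": "ViT-H-14",
        "prompt_len": 16,      # number of optimized tokens
        "prompt_bs": 1,        # number of candidate prompts
        "iter": 200,           # optimization steps
        "lr": 0.1,
        "weight_decay": 0.1,
        "print_step": 50,
        "batch_size": None,    # None = score against all target features
        "loss_weight": 1.0,
    }

    model, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-H-14", pretrained="laion2b_s32b_b79k"  # assumed checkpoint tag
    )
    model = model.to(device)

    targets = [Image.open("target.png").convert("RGB")]  # illustrative path
    best_prompt = optimize_prompt(model, preprocess, args, device,
                                  target_images=targets,
                                  target_prompts="a photo")
    print(best_prompt)

Note that target_prompts doubles as the initialization string: initialize_prompt tokenizes it to seed the optimized embeddings, so it must be provided even when optimizing against images.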
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,13 @@ | |
| 1 | 
            +
            accelerate
         | 
| 3 | 
            +
            torch
         | 
| 4 | 
            +
            transformers
         | 
| 5 | 
            +
            git+https://github.com/huggingface/diffusers.git
         | 
| 6 | 
            +
            sentencepiece
         | 
| 7 | 
            +
            openai
         | 
| 8 | 
            +
            huggingface_hub
         | 
| 9 | 
            +
            sentence-transformers
         | 
| 10 | 
            +
            ftfy
         | 
| 11 | 
            +
            mediapy
         | 
| 12 | 
            +
            open-clip-torch==2.24.0
         | 
| 13 | 
            +
            gspread
         | 
    	
        utils.py
    ADDED
    
    | @@ -0,0 +1,104 @@ | |
| 1 | 
            +
            import re
         | 
| 2 | 
            +
            from diffusers import DiffusionPipeline, FluxPipeline
         | 
| 3 | 
            +
            from live_preview_helpers import FLUXPipelineWithIntermediateOutputs 
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            from openai import OpenAI
         | 
| 7 | 
            +
            import subprocess
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            T2I_MODELS = {
         | 
| 10 | 
            +
                "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1", 
         | 
| 11 | 
            +
                "SDXL-Turbo": "stabilityai/sdxl-turbo",
         | 
| 12 | 
            +
                "Stable Diffusion v3.5-medium": "stabilityai/stable-diffusion-3.5-medium", # Default
         | 
| 13 | 
            +
                "Flux.1-dev": "black-forest-labs/FLUX.1-dev",
         | 
| 14 | 
            +
            }
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            SCENARIOS = {
         | 
| 17 | 
            +
              "Product advertisement": "You are designing an advertising campaign for a new line of coffee machines. To ensure the campaign resonates with a wider audience, you use generative models to create marketing images that showcase a variety of users interacting with the product.",
         | 
| 18 | 
            +
              "Tourist promotion": "You are creating a travel campaign to attract a diverse range of visitors to a specific destination. To make the promotional materials more engaging and inclusive, you use generative models to design posters that highlight a broader array of experiences.",
         | 
| 19 | 
            +
              "Fictional character generation": "You are creating a narrative superhero game where the player often interacts with multiple other non-player characters in the story. To test how different characters would affect the experience of gameplay, you decide to use generative models to help construct characters for (play)testing.",
         | 
| 20 | 
            +
              "Interior Design": "You have a one-bedroom apartment and want to arrange your bed, desk, and dresser in the best way possible. You love the color white and want to ensure your space feels bright and open. To make a decision, you need a way to visualize different furniture placements before setting everything up.",
         | 
| 21 | 
            +
            #   "Education & accessibility": "You are a grade school teacher and the lesson of the day is about teamwork. Some of your students may have a difficult time visualizing what teamwork looks like because they are either (1) too young, (2) English is not their first language, or (3) they may have cognitive impairments that make it difficult for them to visualize concepts (e.g. aphantasia).."
         | 
| 22 | 
            +
            }
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            PROMPTS = {
         | 
| 25 | 
            +
                "Product advertisement": "Design a marketing advertisement image for a coffee machine.",
         | 
| 26 | 
            +
                "Tourist promotion": "Design a travel promotional poster to showcase the beauty and attractions of a tourist destination.",
         | 
| 27 | 
            +
                "Fictional character generation": "Generate a character of a superhero.",
         | 
| 28 | 
            +
                "Interior Design": "Generate an one-bedroom apartment interior design.",
         | 
| 29 | 
            +
                # "Education & accessibility": "Generate an image of grade school students buildind a sandcastle together on the beach."
         | 
| 30 | 
            +
            }
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            IMAGES = {
         | 
| 33 | 
            +
                "Product advertisement": {"baseline": ["images/scenario1_base1.png","images/scenario1_base2.png","images/scenario1_base3.png","images/scenario1_base4.png"], 
         | 
| 34 | 
            +
                                                        "ours": ["images/scenario1_our1.png","images/scenario1_our2.png","images/scenario1_our3.png","images/scenario1_our4.png"]},
         | 
| 35 | 
            +
                "Tourist promotion": {"baseline": ["images/scenario5_base1.png","images/scenario5_base2.png","images/scenario5_base3.png","images/scenario5_base4.png"], 
         | 
| 36 | 
            +
                                                        "ours": ["images/scenario5_our1.png","images/scenario5_our2.png","images/scenario5_our3.png","images/scenario5_our4.png"]},
         | 
| 37 | 
            +
                "Fictional character generation": {"baseline": ["images/scenario2_base1.png","images/scenario2_base2.png","images/scenario2_base3.png","images/scenario2_base4.png"], 
         | 
| 38 | 
            +
                                                        "ours": ["images/scenario2_our1.png","images/scenario2_our2.png","images/scenario2_our3.png","images/scenario2_our4.png"]},
         | 
| 39 | 
            +
                "Interior Design": {"baseline": ["images/scenario3_base1.png","images/scenario3_base2.png","images/scenario3_base3.png","images/scenario3_base4.png"], 
         | 
| 40 | 
            +
                                                        "ours": ["images/scenario3_our1.png","images/scenario3_our2.png","images/scenario3_our3.png","images/scenario3_our4.png"]},
         | 
| 41 | 
            +
                # "Education & accessibility": {"baseline": ["images/scenario4_base1.png","images/scenario4_base2.png","images/scenario4_base3.png","images/scenario4_base4.png"], 
         | 
| 42 | 
            +
                                                        # "ours": ["images/scenario4_our1.png","images/scenario4_our2.png","images/scenario4_our3.png","images/scenario4_our4.png"]},
         | 
| 43 | 
            +
            }
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            OPTIONS = ["Very Unsatisfied", "Unsatisfied", "Slightly Unsatisfied", "Neutral", "Slightly Satisfied", "Satisfied", "Very Satisfied"]
         | 
| 46 | 
            +
             | 
| 47 | 
            +
INSTRUCTION = "📌 **Instruction**: Now, we want to understand your satisfaction with the generated images. <br /> 📌 Step 1: Start by evaluating the following images against the given prompt. <br /> 📌 Step 2: Then modify the prompt according to your expectations for the given scenario background, and answer the evaluation question **until you are satisfied** with at least one of the images generated below. If you are not satisfied with the generated images, you may modify the prompt up to **5 times**."
         | 
| 48 | 
            +
            def clean_cache():
         | 
| 49 | 
            +
                subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
         | 
| 50 | 
            +
                if torch.cuda.is_available():
         | 
| 51 | 
            +
                    torch.cuda.empty_cache()
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            def setup_model(t2i_model_repo, torch_dtype, device):
         | 
| 54 | 
            +
    if t2i_model_repo in ("stabilityai/sdxl-turbo", "stabilityai/stable-diffusion-3.5-medium", "stabilityai/stable-diffusion-2-1"):
         | 
| 55 | 
            +
                    pipe = DiffusionPipeline.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
         | 
| 56 | 
            +
                elif t2i_model_repo == "black-forest-labs/FLUX.1-dev":
         | 
| 57 | 
            +
                    # pipe = FluxPipeline.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
         | 
| 58 | 
            +
                    pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
         | 
| 59 | 
            +
                torch.cuda.empty_cache()
         | 
| 60 | 
            +
                
         | 
| 61 | 
            +
                return pipe
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            def init_gpt_api():
         | 
| 64 | 
            +
                return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            def call_gpt_api(messages, client, model, seed, max_tokens, temperature, top_p):
         | 
| 67 | 
            +
                completion = client.chat.completions.create(
         | 
| 68 | 
            +
                    model=model,
         | 
| 69 | 
            +
                    messages=messages,
         | 
| 70 | 
            +
                    seed=seed,
         | 
| 71 | 
            +
                    max_tokens=max_tokens,
         | 
| 72 | 
            +
                    temperature=temperature,
         | 
| 73 | 
            +
                    top_p=top_p,
         | 
| 74 | 
            +
                )
         | 
| 75 | 
            +
                return completion.choices[0].message.content
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            def clean_response_gpt(res: str):
         | 
| 78 | 
            +
                prompts = re.findall(r'\d+\.\s"?(.*?)"?(?=\n|$)', res)
         | 
| 79 | 
            +
                return prompts
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            def get_refine_msg(prompt, num_prompts):
         | 
| 83 | 
            +
                messages = [{"role": "system", "content": f"You are a helpful, respectful and precise assistant. You will be asked to generate {num_prompts} refined prompts. Only respond with those refined prompts"}]
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                message = f"""Given a prompt, modify the prompt for me to explore variations in subject attributes, actions, and contextual details, while retaining the semantic consistency of the original description. 
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            Follow the following refinement instruction: 
         | 
| 88 | 
            +
1. Subject: refine broad terms into specific subsets, focusing on, but not restricted to, the ethnicity, gender, and age of humans. 
         | 
| 89 | 
            +
2. Object: modify the brand or color of object(s) only if they are not specified in the prompt. 
         | 
| 90 | 
            +
3. Setting: add details to the background environment, such as temporal or spatial changes (e.g., day to night, indoor to outdoor). 
         | 
| 91 | 
            +
            4. Action: add more details to the action or specify the object or goal of the action. 
         | 
| 92 | 
            +
             | 
| 93 | 
            +
            For example, given this prompt: a person is drinking a coffee in a coffee shop, the refined prompts could be: 
         | 
| 94 | 
            +
            'an elderly woman is drinking a coffee in a coffee shop' (subject adjective) 
         | 
| 95 | 
            +
'a young Asian woman is drinking a coffee in a coffee shop' (subject adjective) 
         | 
| 96 | 
            +
            'a young woman is drinking a hot coffee with her left hand in a coffee shop' (action details) 
         | 
| 97 | 
            +
            'a woman is drinking a coffee in an outdoor coffee shop in the garden' (setting details) 
         | 
| 98 | 
            +
If there is no human in the sentence, you do not need to add a person intentionally. 
         | 
| 99 | 
            +
If you use adjectives, they should be visual, so do not use something like 'interesting'. Please also vary the number of modifications, but do not change the number of subjects/objects that have been specified in the prompt. Remember not to change the predefined concepts that have been specified in the prompt, e.g., don't change a boy to several boys.
         | 
| 100 | 
            +
             | 
| 101 | 
            +
Can you give me {num_prompts} modified prompts for the prompt '{prompt}', please?"""
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                messages.append({"role": "user", "content": f"{message}"})
         | 
| 104 | 
            +
                return messages
         |
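Putting the GPT helpers together: a minimal sketch of the prompt-refinement round trip, assuming OPENAI_API_KEY is set in the environment and that the model name and sampling values shown are placeholders rather than the app's actual settings:

    from utils import (PROMPTS, call_gpt_api, clean_response_gpt,
                       get_refine_msg, init_gpt_api)

    client = init_gpt_api()
    messages = get_refine_msg(PROMPTS["Tourist promotion"], num_prompts=4)

    # Placeholder sampling settings; the app may use different values.
    raw = call_gpt_api(messages, client, model="gpt-4o-mini", seed=0,
                       max_tokens=512, temperature=1.0, top_p=1.0)

    # clean_response_gpt extracts the numbered prompts from the raw reply.
    for i, refined in enumerate(clean_response_gpt(raw), start=1):
        print(i, refined)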