update personalization
Browse files- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +242 -122
- utils.py +31 -1
    	
        __pycache__/utils.cpython-310.pyc
    CHANGED
    
    | Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ | 
|  | 
    	
        app.py
    CHANGED
    
    | @@ -1,4 +1,3 @@ | |
| 1 | 
            -
             | 
| 2 | 
             
            import gradio as gr
         | 
| 3 | 
             
            import numpy as np
         | 
| 4 | 
             
            import random
         | 
| @@ -11,7 +10,7 @@ import open_clip | |
| 11 | 
             
            from optim_utils import optimize_prompt
         | 
| 12 | 
             
            from utils import (
         | 
| 13 | 
             
                clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
         | 
| 14 | 
            -
                get_refine_msg, clean_cache, get_personalize_message,
         | 
| 15 | 
             
                clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
         | 
| 16 | 
             
                INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS  
         | 
| 17 | 
             
            )
         | 
| @@ -41,6 +40,7 @@ torch.cuda.empty_cache() | |
| 41 | 
             
            METHOD = "Experimental" 
         | 
| 42 | 
             
            counter = 1
         | 
| 43 | 
             
            enable_submit = False
         | 
|  | |
| 44 | 
             
            responses_memory = {METHOD: {}}
         | 
| 45 | 
             
            example_data = [
         | 
| 46 | 
             
                [
         | 
| @@ -100,7 +100,8 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0 | |
| 100 | 
             
            def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
         | 
| 101 | 
             
                seed = random.randint(0, MAX_SEED)
         | 
| 102 | 
             
                client = init_gpt_api()
         | 
| 103 | 
            -
                messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
         | 
|  | |
| 104 | 
             
                outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
         | 
| 105 | 
             
                return outputs
         | 
| 106 |  | 
| @@ -121,12 +122,12 @@ def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2) | |
| 121 | 
             
                }
         | 
| 122 | 
             
                inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
         | 
| 123 |  | 
| 124 | 
            -
                # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
         | 
| 125 | 
            -
                # return learned_prompt
         | 
| 126 | 
            -
             | 
| 127 | 
             
            # =========================
         | 
| 128 | 
             
            # UI Helper Functions
         | 
| 129 | 
             
            # =========================
         | 
|  | |
|  | |
|  | |
| 130 | 
             
            def reset_gallery():
         | 
| 131 | 
             
                return []
         | 
| 132 |  | 
| @@ -136,105 +137,106 @@ def display_error_message(msg, duration=5): | |
| 136 | 
             
            def display_info_message(msg, duration=5):
         | 
| 137 | 
             
                gr.Info(msg, duration=duration)
         | 
| 138 |  | 
| 139 | 
            -
            def  | 
| 140 | 
            -
                 | 
| 141 | 
            -
                fully_satisfied_option = ["Satisfied", "Very Satisfied"]
         | 
| 142 | 
            -
                if_submit = (sim_radio in fully_satisfied_option) or enable_submit or (counter > MAX_ROUND)
         | 
| 143 | 
            -
                return gr.update(interactive=if_submit)
         | 
| 144 | 
            -
             | 
| 145 | 
            -
            def select_image(like_radio, images_method):
         | 
| 146 | 
            -
                if like_radio == IMAGE_OPTIONS[0]:
         | 
| 147 | 
            -
                    return images_method[0][0]
         | 
| 148 | 
            -
                elif like_radio == IMAGE_OPTIONS[1]:
         | 
| 149 | 
            -
                    return images_method[1][0]
         | 
| 150 | 
            -
                elif like_radio == IMAGE_OPTIONS[2]:
         | 
| 151 | 
            -
                    return images_method[2][0]
         | 
| 152 | 
            -
                elif like_radio == IMAGE_OPTIONS[3]:
         | 
| 153 | 
            -
                    return images_method[3][0]
         | 
| 154 | 
            -
                else:
         | 
| 155 | 
            -
                    return None
         | 
| 156 | 
            -
             | 
| 157 | 
            -
            def check_evaluation(sim_radio):
         | 
| 158 | 
            -
                if not sim_radio:
         | 
| 159 | 
             
                    display_error_message("โ Please fill all evaluations before changing image or submitting.")
         | 
| 160 | 
             
                    return False
         | 
| 161 | 
             
                return True
         | 
| 162 |  | 
| 163 | 
             
            def generate_image(prompt, like_image, dislike_image):
         | 
| 164 | 
            -
                global responses_memory
         | 
| 165 | 
             
                history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
         | 
| 166 | 
             
                feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
         | 
| 167 | 
            -
                 | 
| 168 | 
            -
                 | 
| 169 | 
            -
             | 
| 170 | 
            -
                 | 
| 171 | 
            -
             | 
| 172 | 
             
                gallery_images = []
         | 
|  | |
| 173 | 
             
                refined_prompts = call_gpt_refine_prompt(personalized)
         | 
| 174 | 
             
                for i in range(NUM_IMAGES):
         | 
| 175 | 
             
                    img = infer(refined_prompts[i])
         | 
| 176 | 
             
                    gallery_images.append(img)
         | 
|  | |
| 177 | 
             
                    yield gallery_images
         | 
| 178 |  | 
| 179 | 
            -
            def  | 
| 180 | 
            -
                 | 
| 181 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 182 | 
             
                    responses_memory[METHOD][counter] = {
         | 
| 183 | 
             
                        "prompt": prompt,
         | 
| 184 | 
             
                        "sim_radio": sim_radio,
         | 
| 185 | 
             
                        "response": "",
         | 
| 186 | 
            -
                        "satisfied_img": f"round {counter},  | 
| 187 | 
            -
                        "unsatisfied_img": f"round {counter},  | 
| 188 | 
             
                    }
         | 
| 189 |  | 
| 190 | 
            -
                    enable_submit = True if sim_radio in ["Satisfied", "Very Satisfied"] or enable_submit else False
         | 
| 191 | 
            -
             | 
| 192 | 
             
                    history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]
         | 
|  | |
|  | |
| 193 | 
             
                    if not history_images:
         | 
| 194 | 
            -
                        history_images = current_images
         | 
| 195 | 
             
                    elif current_images:
         | 
| 196 | 
             
                        history_images.extend(current_images)
         | 
|  | |
| 197 | 
             
                    current_images = []
         | 
| 198 |  | 
| 199 | 
             
                    examples_state = gr.update(samples=history_prompts, visible=True)
         | 
| 200 | 
             
                    prompt_state = gr.update(interactive=True)
         | 
| 201 | 
             
                    next_state = gr.update(visible=True, interactive=True)
         | 
| 202 | 
             
                    redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
         | 
| 203 | 
            -
                     | 
| 204 | 
            -
             | 
| 205 | 
             
                    counter += 1
         | 
|  | |
| 206 |  | 
| 207 | 
            -
                     | 
| 208 | 
            -
                else:
         | 
| 209 | 
            -
                    return {submit_btn: gr.skip()}
         | 
| 210 |  | 
| 211 | 
            -
             | 
| 212 | 
            -
                global counter, enable_submit, responses_memory
         | 
| 213 | 
            -
             | 
| 214 | 
            -
                if check_evaluation(sim_radio):
         | 
| 215 | 
            -
                    # Save the final round entry
         | 
| 216 | 
            -
                    responses_memory[METHOD][counter] = {
         | 
| 217 | 
            -
                        "prompt": prompt,
         | 
| 218 | 
            -
                        "sim_radio": sim_radio,
         | 
| 219 | 
            -
                        "response": "",
         | 
| 220 | 
            -
                        "satisfied_img": f"round {counter}, {like_radio}",
         | 
| 221 | 
            -
                        "unsatisfied_img": f"round {counter}, {dislike_radio}",
         | 
| 222 | 
            -
                    }
         | 
| 223 | 
            -
             | 
| 224 | 
            -
                    # Reset states
         | 
| 225 | 
            -
                    counter = 1
         | 
| 226 | 
            -
                    enable_submit = False
         | 
| 227 | 
            -
             | 
| 228 | 
            -
                    # Reset buttons
         | 
| 229 | 
            -
                    prompt_state = gr.update(interactive=False)
         | 
| 230 | 
            -
                    next_state = gr.update(visible=False, interactive=False)
         | 
| 231 | 
            -
                    submit_state = gr.update(interactive=False)
         | 
| 232 | 
            -
                    redesign_state = gr.update(interactive=False)
         | 
| 233 | 
            -
             | 
| 234 | 
            -
                    display_info_message("โ
 Your answer is saved!")
         | 
| 235 | 
            -
                    return None, None, None, prompt_state, next_state, redesign_state, submit_state
         | 
| 236 | 
             
                else:
         | 
| 237 | 
            -
                    return  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 238 |  | 
| 239 | 
             
            # =========================
         | 
| 240 | 
             
            # Interface (single tab, no participant/scenario/background)
         | 
| @@ -256,6 +258,7 @@ css = """ | |
| 256 | 
             
            #button-container {
         | 
| 257 | 
             
                display: flex;
         | 
| 258 | 
             
                justify-content: center;
         | 
|  | |
| 259 | 
             
            }
         | 
| 260 | 
             
            #compact-compact-row {
         | 
| 261 | 
             
                width:100%;
         | 
| @@ -315,9 +318,56 @@ css = """ | |
| 315 | 
             
                max-width: 150px;
         | 
| 316 | 
             
                display: inline-block;
         | 
| 317 | 
             
            }
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 318 | 
             
            """
         | 
| 319 |  | 
| 320 | 
             
            with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
         | 
|  | |
|  | |
|  | |
| 321 | 
             
                with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
         | 
| 322 | 
             
                    gr.HTML('<div class="logo-container"><img src="https://huggingface.co/spaces/PAI-GEN/POET/resolve/main/images/icon.png" alt="POET Logo"></div>')
         | 
| 323 | 
             
                    gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
         | 
| @@ -325,7 +375,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), " | |
| 325 | 
             
                    gr.HTML("""
         | 
| 326 | 
             
                    <div style="text-align: center;">
         | 
| 327 | 
             
                        <a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
         | 
| 328 | 
            -
                            ๐ Read the Full Paper | 
| 329 | 
             
                        </a>
         | 
| 330 | 
             
                    </div>
         | 
| 331 | 
             
                    """)
         | 
| @@ -337,13 +387,15 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), " | |
| 337 |  | 
| 338 | 
             
                    gr.Markdown("""
         | 
| 339 | 
             
                    <div class="authors-section">
         | 
| 340 | 
            -
                        <a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>,  | 
| 341 | 
            -
                        <a href="https:// | 
| 342 | 
            -
                        <a href="https:// | 
|  | |
|  | |
|  | |
| 343 | 
             
                    </div>
         | 
| 344 | 
             
                    """, elem_classes=["authors-section"])
         | 
| 345 |  | 
| 346 | 
            -
                    # gr.Markdown("---")
         | 
| 347 |  | 
| 348 | 
             
                with gr.Tab(""):
         | 
| 349 | 
             
                    with gr.Row(elem_id="compact-row"):
         | 
| @@ -360,47 +412,99 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), " | |
| 360 |  | 
| 361 | 
             
                    with gr.Row(elem_id="compact-row"):
         | 
| 362 | 
             
                        with gr.Column(elem_id="col-container"):
         | 
| 363 | 
            -
                            images_method = gr.Gallery( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 364 |  | 
| 365 | 
             
                        with gr.Column(elem_id="col-container3"):
         | 
| 366 | 
            -
                             | 
| 367 | 
            -
                             | 
| 368 | 
            -
             | 
| 369 | 
            -
             | 
| 370 | 
            -
             | 
| 371 | 
            -
             | 
| 372 | 
            -
             | 
| 373 | 
            -
             | 
| 374 | 
            -
                             | 
| 375 | 
            -
                             | 
| 376 | 
            -
             | 
| 377 | 
            -
             | 
| 378 | 
            -
             | 
| 379 | 
            -
             | 
| 380 | 
            -
             | 
| 381 | 
            -
             | 
| 382 | 
            -
             | 
| 383 | 
            -
             | 
| 384 | 
            -
                             | 
| 385 | 
            -
             | 
| 386 | 
            -
             | 
| 387 | 
            -
             | 
| 388 | 
            -
             | 
| 389 | 
            -
             | 
| 390 | 
            -
             | 
| 391 | 
            -
             | 
| 392 | 
            -
             | 
| 393 | 
            -
             | 
| 394 | 
            -
             | 
| 395 | 
            -
             | 
| 396 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 397 |  | 
| 398 | 
             
                    with gr.Column(elem_id="col-container2"):
         | 
| 399 | 
             
                        gr.Markdown("### ๐ Examples")
         | 
| 400 | 
            -
                        ex1 = gr.Image(label="Image 1", width=200, height=200,  | 
| 401 | 
            -
                        ex2 = gr.Image(label="Image 2", width=200, height=200,  | 
| 402 | 
            -
                        ex3 = gr.Image(label="Image 3", width=200, height=200,  | 
| 403 | 
            -
                        ex4 = gr.Image(label="Image 4", width=200, height=200,  | 
| 404 |  | 
| 405 | 
             
                        gr.Examples(
         | 
| 406 | 
             
                            examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
         | 
| @@ -410,28 +514,44 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), " | |
| 410 | 
             
            # =========================
         | 
| 411 | 
             
            # Wiring
         | 
| 412 | 
             
            # =========================
         | 
| 413 | 
            -
                 | 
| 414 | 
            -
             | 
| 415 | 
            -
             | 
| 416 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 417 |  | 
| 418 | 
             
                next_btn.click(
         | 
| 419 | 
             
                    fn=generate_image,
         | 
| 420 | 
             
                    inputs=[prompt, like_image, dislike_image],
         | 
| 421 | 
             
                    outputs=[images_method]
         | 
| 422 | 
            -
                ).success(lambda: [gr.update(interactive=True), gr.update(interactive=True) | 
|  | |
| 423 |  | 
| 424 | 
             
                redesign_btn.click(
         | 
| 425 | 
             
                    fn=redesign,
         | 
| 426 | 
            -
                    inputs=[prompt, sim_radio,  | 
| 427 | 
            -
                    outputs=[sim_radio,  | 
| 428 | 
             
                )
         | 
| 429 |  | 
| 430 | 
             
                submit_btn.click(
         | 
| 431 | 
             
                    fn=save_response,
         | 
| 432 | 
            -
                    inputs=[prompt, sim_radio,  | 
| 433 | 
            -
                    outputs=[sim_radio,  | 
| 434 | 
             
                )
         | 
| 435 |  | 
| 436 | 
             
            if __name__ == "__main__":
         | 
| 437 | 
            -
                demo.launch()
         | 
|  | |
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
             
            import numpy as np
         | 
| 3 | 
             
            import random
         | 
|  | |
| 10 | 
             
            from optim_utils import optimize_prompt
         | 
| 11 | 
             
            from utils import (
         | 
| 12 | 
             
                clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
         | 
| 13 | 
            +
                get_refine_msg, clean_cache, get_personalize_message, get_personalized_simplified,
         | 
| 14 | 
             
                clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
         | 
| 15 | 
             
                INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS  
         | 
| 16 | 
             
            )
         | 
|  | |
| 40 | 
             
            METHOD = "Experimental" 
         | 
| 41 | 
             
            counter = 1
         | 
| 42 | 
             
            enable_submit = False
         | 
| 43 | 
            +
            redesign_flag = False
         | 
| 44 | 
             
            responses_memory = {METHOD: {}}
         | 
| 45 | 
             
            example_data = [
         | 
| 46 | 
             
                [
         | 
|  | |
| 100 | 
             
            def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
         | 
| 101 | 
             
                seed = random.randint(0, MAX_SEED)
         | 
| 102 | 
             
                client = init_gpt_api()
         | 
| 103 | 
            +
                # messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
         | 
| 104 | 
            +
                messages = get_personalized_simplified(prompt, like_image, dislike_image)
         | 
| 105 | 
             
                outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
         | 
| 106 | 
             
                return outputs
         | 
| 107 |  | 
|  | |
| 122 | 
             
                }
         | 
| 123 | 
             
                inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
         | 
| 124 |  | 
|  | |
|  | |
|  | |
| 125 | 
             
            # =========================
         | 
| 126 | 
             
            # UI Helper Functions
         | 
| 127 | 
             
            # =========================
         | 
| 128 | 
            +
            # Store generated images for selection
         | 
| 129 | 
            +
            current_generated_images = []
         | 
| 130 | 
            +
             | 
| 131 | 
             
            def reset_gallery():
         | 
| 132 | 
             
                return []
         | 
| 133 |  | 
|  | |
| 137 | 
             
            def display_info_message(msg, duration=5):
         | 
| 138 | 
             
                gr.Info(msg, duration=duration)
         | 
| 139 |  | 
| 140 | 
            +
            def check_evaluation(sim_radio, like_image, dislike_image):
         | 
| 141 | 
            +
                if not sim_radio or not like_image or not dislike_image:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 142 | 
             
                    display_error_message("โ Please fill all evaluations before changing image or submitting.")
         | 
| 143 | 
             
                    return False
         | 
| 144 | 
             
                return True
         | 
| 145 |  | 
| 146 | 
             
            def generate_image(prompt, like_image, dislike_image):
         | 
| 147 | 
            +
                global responses_memory, current_generated_images
         | 
| 148 | 
             
                history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
         | 
| 149 | 
             
                feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
         | 
| 150 | 
            +
                print(feedback, like_image, dislike_image)
         | 
| 151 | 
            +
                if like_image and dislike_image and feedback:
         | 
| 152 | 
            +
                    personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
         | 
| 153 | 
            +
                else:
         | 
| 154 | 
            +
                    personalized = prompt
         | 
| 155 | 
             
                gallery_images = []
         | 
| 156 | 
            +
                current_generated_images = []  # Reset the stored images
         | 
| 157 | 
             
                refined_prompts = call_gpt_refine_prompt(personalized)
         | 
| 158 | 
             
                for i in range(NUM_IMAGES):
         | 
| 159 | 
             
                    img = infer(refined_prompts[i])
         | 
| 160 | 
             
                    gallery_images.append(img)
         | 
| 161 | 
            +
                    current_generated_images.append(img)  # Store for selection
         | 
| 162 | 
             
                    yield gallery_images
         | 
| 163 |  | 
| 164 | 
            +
            def on_gallery_select(evt: gr.SelectData):
         | 
| 165 | 
            +
                """Handle gallery image selection and return the selected image"""
         | 
| 166 | 
            +
                global current_generated_images
         | 
| 167 | 
            +
                if current_generated_images and evt.index < len(current_generated_images):
         | 
| 168 | 
            +
                    return current_generated_images[evt.index]
         | 
| 169 | 
            +
                return None
         | 
| 170 | 
            +
             | 
| 171 | 
            +
            def handle_like_drag(selected_image):
         | 
| 172 | 
            +
                """Handle setting an image as liked"""
         | 
| 173 | 
            +
                return selected_image
         | 
| 174 | 
            +
             | 
| 175 | 
            +
            def handle_dislike_drag(selected_image):
         | 
| 176 | 
            +
                """Handle setting an image as disliked"""
         | 
| 177 | 
            +
                return selected_image
         | 
| 178 | 
            +
             | 
| 179 | 
            +
            def redesign(prompt, sim_radio, current_images, history_images, like_image, dislike_image):
         | 
| 180 | 
            +
                global counter, responses_memory, redesign_flag
         | 
| 181 | 
            +
                
         | 
| 182 | 
            +
                if check_evaluation(sim_radio, like_image, dislike_image):
         | 
| 183 | 
             
                    responses_memory[METHOD][counter] = {
         | 
| 184 | 
             
                        "prompt": prompt,
         | 
| 185 | 
             
                        "sim_radio": sim_radio,
         | 
| 186 | 
             
                        "response": "",
         | 
| 187 | 
            +
                        "satisfied_img": f"round {counter}, liked image",
         | 
| 188 | 
            +
                        "unsatisfied_img": f"round {counter}, disliked image",
         | 
| 189 | 
             
                    }
         | 
| 190 |  | 
|  | |
|  | |
| 191 | 
             
                    history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]
         | 
| 192 | 
            +
                    
         | 
| 193 | 
            +
                    # Update history images
         | 
| 194 | 
             
                    if not history_images:
         | 
| 195 | 
            +
                        history_images = current_images.copy() if current_images else []
         | 
| 196 | 
             
                    elif current_images:
         | 
| 197 | 
             
                        history_images.extend(current_images)
         | 
| 198 | 
            +
                    
         | 
| 199 | 
             
                    current_images = []
         | 
| 200 |  | 
| 201 | 
             
                    examples_state = gr.update(samples=history_prompts, visible=True)
         | 
| 202 | 
             
                    prompt_state = gr.update(interactive=True)
         | 
| 203 | 
             
                    next_state = gr.update(visible=True, interactive=True)
         | 
| 204 | 
             
                    redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
         | 
| 205 | 
            +
                    
         | 
|  | |
| 206 | 
             
                    counter += 1
         | 
| 207 | 
            +
                    redesign_flag = True
         | 
| 208 |  | 
| 209 | 
            +
                    display_info_message(f"โ
 Round {counter-1} feedback saved! You can continue redesigning or restart.")
         | 
|  | |
|  | |
| 210 |  | 
| 211 | 
            +
                    return None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 212 | 
             
                else:
         | 
| 213 | 
            +
                    return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            def save_response(prompt, sim_radio, like_image, dislike_image):
         | 
| 216 | 
            +
                global counter, responses_memory, redesign_flag, current_generated_images
         | 
| 217 | 
            +
                
         | 
| 218 | 
            +
                # Reset all global variables
         | 
| 219 | 
            +
                responses_memory[METHOD] = {}
         | 
| 220 | 
            +
                counter = 1
         | 
| 221 | 
            +
                redesign_flag = False
         | 
| 222 | 
            +
                current_generated_images = []
         | 
| 223 | 
            +
                
         | 
| 224 | 
            +
                # Reset UI states
         | 
| 225 | 
            +
                prompt_state = gr.update(value="", interactive=True)
         | 
| 226 | 
            +
                next_state = gr.update(visible=True, interactive=True)
         | 
| 227 | 
            +
                redesign_state = gr.update(interactive=False)
         | 
| 228 | 
            +
                submit_state = gr.update(interactive=False)
         | 
| 229 | 
            +
                sim_radio_state = gr.update(value=None)
         | 
| 230 | 
            +
                like_image_state = gr.update(value=None)
         | 
| 231 | 
            +
                dislike_image_state = gr.update(value=None)
         | 
| 232 | 
            +
                gallery_state = []
         | 
| 233 | 
            +
                history_gallery_state = []
         | 
| 234 | 
            +
                examples_state = gr.update(samples=[['']], visible=True)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                display_info_message("๐ Session restarted! You can begin with a new prompt.")
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                return (sim_radio_state, prompt_state, next_state, redesign_state, 
         | 
| 239 | 
            +
                        like_image_state, dislike_image_state, gallery_state, history_gallery_state, examples_state)
         | 
| 240 |  | 
| 241 | 
             
            # =========================
         | 
| 242 | 
             
            # Interface (single tab, no participant/scenario/background)
         | 
|  | |
| 258 | 
             
            #button-container {
         | 
| 259 | 
             
                display: flex;
         | 
| 260 | 
             
                justify-content: center;
         | 
| 261 | 
            +
                gap: 10px;
         | 
| 262 | 
             
            }
         | 
| 263 | 
             
            #compact-compact-row {
         | 
| 264 | 
             
                width:100%;
         | 
|  | |
| 318 | 
             
                max-width: 150px;
         | 
| 319 | 
             
                display: inline-block;
         | 
| 320 | 
             
            }
         | 
| 321 | 
            +
            .instruction-box {
         | 
| 322 | 
            +
                background: linear-gradient(135deg, #e8f4fd 0%, #f0f8ff 100%);
         | 
| 323 | 
            +
                border: 2px solid #3498db;
         | 
| 324 | 
            +
                border-radius: 12px;
         | 
| 325 | 
            +
                padding: 20px;
         | 
| 326 | 
            +
                margin: 15px 0;
         | 
| 327 | 
            +
                color: #2c3e50;
         | 
| 328 | 
            +
            }
         | 
| 329 | 
            +
            .instruction-title {
         | 
| 330 | 
            +
                font-size: 1.2em;
         | 
| 331 | 
            +
                font-weight: bold;
         | 
| 332 | 
            +
                margin-bottom: 15px;
         | 
| 333 | 
            +
                color: #2c3e50;
         | 
| 334 | 
            +
                display: flex;
         | 
| 335 | 
            +
                align-items: center;
         | 
| 336 | 
            +
                gap: 8px;
         | 
| 337 | 
            +
            }
         | 
| 338 | 
            +
            .step-list {
         | 
| 339 | 
            +
                list-style: none;
         | 
| 340 | 
            +
                padding: 0;
         | 
| 341 | 
            +
                margin: 0;
         | 
| 342 | 
            +
            }
         | 
| 343 | 
            +
            .step-item {
         | 
| 344 | 
            +
                background: rgba(52, 152, 219, 0.1);
         | 
| 345 | 
            +
                border-radius: 8px;
         | 
| 346 | 
            +
                padding: 12px 16px;
         | 
| 347 | 
            +
                margin: 8px 0;
         | 
| 348 | 
            +
                border-left: 4px solid #3498db;
         | 
| 349 | 
            +
            }
         | 
| 350 | 
            +
            .step-number {
         | 
| 351 | 
            +
                font-weight: bold;
         | 
| 352 | 
            +
                color: #3498db;
         | 
| 353 | 
            +
                margin-right: 8px;
         | 
| 354 | 
            +
            }
         | 
| 355 | 
            +
            .personalization-header {
         | 
| 356 | 
            +
                background: linear-gradient(135deg, #ff6b6b, #ee5a24);
         | 
| 357 | 
            +
                color: white;
         | 
| 358 | 
            +
                padding: 15px;
         | 
| 359 | 
            +
                border-radius: 10px 10px 0 0;
         | 
| 360 | 
            +
                margin: -10px -10px 15px -10px;
         | 
| 361 | 
            +
                text-align: center;
         | 
| 362 | 
            +
                font-weight: bold;
         | 
| 363 | 
            +
                font-size: 1.1em;
         | 
| 364 | 
            +
            }
         | 
| 365 | 
             
            """
         | 
| 366 |  | 
| 367 | 
             
            with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
         | 
| 368 | 
            +
                # State variable to hold selected image
         | 
| 369 | 
            +
                selected_image = gr.State(None)
         | 
| 370 | 
            +
                
         | 
| 371 | 
             
                with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
         | 
| 372 | 
             
                    gr.HTML('<div class="logo-container"><img src="https://huggingface.co/spaces/PAI-GEN/POET/resolve/main/images/icon.png" alt="POET Logo"></div>')
         | 
| 373 | 
             
                    gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
         | 
|  | |
| 375 | 
             
                    gr.HTML("""
         | 
| 376 | 
             
                    <div style="text-align: center;">
         | 
| 377 | 
             
                        <a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
         | 
| 378 | 
            +
                            ๐ Read the Full Paper
         | 
| 379 | 
             
                        </a>
         | 
| 380 | 
             
                    </div>
         | 
| 381 | 
             
                    """)
         | 
|  | |
| 387 |  | 
| 388 | 
             
                    gr.Markdown("""
         | 
| 389 | 
             
                    <div class="authors-section">
         | 
| 390 | 
            +
                        <a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>, 
         | 
| 391 | 
            +
                        <a href="https://www.aliceqian.com/">Alice Qian Zhang</a>, 
         | 
| 392 | 
            +
                        <a href="https://haiyizhu.com/">Haiyi Zhu</a>, 
         | 
| 393 | 
            +
                        <a href="https://www.andrew.cmu.edu/user/hongs/">Hong Shen</a>, 
         | 
| 394 | 
            +
                        <a href="https://pliang279.github.io/">Paul Pu Liang</a>, 
         | 
| 395 | 
            +
                        <a href="https://janeon.github.io/">Jane Hsieh</a>
         | 
| 396 | 
             
                    </div>
         | 
| 397 | 
             
                    """, elem_classes=["authors-section"])
         | 
| 398 |  | 
|  | |
| 399 |  | 
| 400 | 
             
                with gr.Tab(""):
         | 
| 401 | 
             
                    with gr.Row(elem_id="compact-row"):
         | 
|  | |
| 412 |  | 
| 413 | 
             
                    with gr.Row(elem_id="compact-row"):
         | 
| 414 | 
             
                        with gr.Column(elem_id="col-container"):
         | 
| 415 | 
            +
                            images_method = gr.Gallery(
         | 
| 416 | 
            +
                                label="Generated Images (Click to select, then set to Like/Dislike image)", 
         | 
| 417 | 
            +
                                columns=[4], 
         | 
| 418 | 
            +
                                rows=[1], 
         | 
| 419 | 
            +
                                height=400, 
         | 
| 420 | 
            +
                                interactive=False,
         | 
| 421 | 
            +
                                elem_id="gallery", 
         | 
| 422 | 
            +
                                format="png"
         | 
| 423 | 
            +
                            )
         | 
| 424 |  | 
| 425 | 
             
                        with gr.Column(elem_id="col-container3"):
         | 
| 426 | 
            +
                            like_btn = gr.Button("👍 Set as Liked (Optional for personalization)", size="sm", variant="secondary")
         | 
| 427 | 
            +
                            like_image = gr.Image(
         | 
| 428 | 
            +
                                label="Satisfied Image", 
         | 
| 429 | 
            +
                                width=150, 
         | 
| 430 | 
            +
                                height=150, 
         | 
| 431 | 
            +
                                interactive=False,
         | 
| 432 | 
            +
                                format="png", 
         | 
| 433 | 
            +
                                type="filepath"
         | 
| 434 | 
            +
                            )
         | 
| 435 | 
            +
                            dislike_btn = gr.Button("👎 Set as Disliked (Optional for personalization)", size="sm", variant="secondary")
         | 
| 436 | 
            +
                            dislike_image = gr.Image(
         | 
| 437 | 
            +
                                label="Unsatisfied Image", 
         | 
| 438 | 
            +
                                width=150, 
         | 
| 439 | 
            +
                                height=150, 
         | 
| 440 | 
            +
                                interactive=False,
         | 
| 441 | 
            +
                                format="png", 
         | 
| 442 | 
            +
                                type="filepath"
         | 
| 443 | 
            +
                            )
         | 
| 444 | 
            +
                            
         | 
| 445 | 
            +
                    with gr.Accordion("🎯 Advanced: Personalized Image Redesign", open=False, elem_id="col-container2"):
         | 
| 446 | 
            +
                        gr.HTML("""
         | 
| 447 | 
            +
                        <div class="instruction-box">
         | 
| 448 | 
            +
                            <div class="instruction-title">
         | 
| 449 | 
            +
                                ๐ How to Use Personalized Redesign
         | 
| 450 | 
            +
                            </div>
         | 
| 451 | 
            +
                            <div class="step-list">
         | 
| 452 | 
            +
                                <div class="step-item">
         | 
| 453 | 
            +
                                    <span class="step-number">1.</span>
         | 
| 454 | 
            +
                                    <strong>Rate Your Satisfaction:</strong> Provide a satisfaction score for the current generated images
         | 
| 455 | 
            +
                                </div>
         | 
| 456 | 
            +
                                <div class="step-item">
         | 
| 457 | 
            +
                                    <span class="step-number">2.</span>
         | 
| 458 | 
            +
                                    <strong>Select Preferences:</strong> Choose your most liked and disliked images
         | 
| 459 | 
            +
                                </div>
         | 
| 460 | 
            +
                                <div class="step-item">
         | 
| 461 | 
            +
                                    <span class="step-number">3.</span>
         | 
| 462 | 
            +
                                    <strong>Save & Iterate:</strong> Click "Save Personalization Data" before redesigning your prompt and clicking "Generate"
         | 
| 463 | 
            +
                                </div>
         | 
| 464 | 
            +
                                <div class="step-item">
         | 
| 465 | 
            +
                                    <span class="step-number">4.</span>
         | 
| 466 | 
            +
                                    <strong>Restart Anytime:</strong> Use the "Restart" button to begin a fresh session
         | 
| 467 | 
            +
                                </div>
         | 
| 468 | 
            +
                            </div>
         | 
| 469 | 
            +
                        </div>
         | 
| 470 | 
            +
                        """)
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                        with gr.Column(elem_id="col-container2"):
         | 
| 473 | 
            +
                            gr.Markdown("### ๐ Rate Current Generation")
         | 
| 474 | 
            +
                            with gr.Row():
         | 
| 475 | 
            +
                                sim_radio = gr.Radio(
         | 
| 476 | 
            +
                                    OPTIONS,
         | 
| 477 | 
            +
                                    label="How satisfied are you with the current generated images?",
         | 
| 478 | 
            +
                                    type="value",
         | 
| 479 | 
            +
                                    show_label=True,
         | 
| 480 | 
            +
                                    container=True,
         | 
| 481 | 
            +
                                    scale=1
         | 
| 482 | 
            +
                                )
         | 
| 483 | 
            +
                            
         | 
| 484 | 
            +
                            with gr.Row(elem_id="button-container"):
         | 
| 485 | 
            +
                                with gr.Column(scale=1):
         | 
| 486 | 
            +
                                    redesign_btn = gr.Button("💾 Save Personalization Data", variant="primary", size="lg")
         | 
| 487 | 
            +
                                with gr.Column(scale=1):
         | 
| 488 | 
            +
                                    submit_btn = gr.Button("๐ Restart Session", variant="secondary", size="lg")
         | 
| 489 | 
            +
             | 
| 490 | 
            +
             | 
| 491 | 
            +
                        with gr.Column(elem_id="col-container2"):
         | 
| 492 | 
            +
                            example = gr.Examples([['']], prompt, label="๐ Prompt History", visible=True)
         | 
| 493 | 
            +
                            history_images = gr.Gallery(
         | 
| 494 | 
            +
                                label="๐๏ธ Generation History", 
         | 
| 495 | 
            +
                                columns=[4], 
         | 
| 496 | 
            +
                                rows=[1], 
         | 
| 497 | 
            +
                                elem_id="gallery", 
         | 
| 498 | 
            +
                                format="png", 
         | 
| 499 | 
            +
                                interactive=False,
         | 
| 500 | 
            +
                            )
         | 
| 501 |  | 
| 502 | 
             
                    with gr.Column(elem_id="col-container2"):
         | 
| 503 | 
             
                        gr.Markdown("### ๐ Examples")
         | 
| 504 | 
            +
                        ex1 = gr.Image(label="Image 1", width=200, height=200, format="png", type="filepath", visible=False)
         | 
| 505 | 
            +
                        ex2 = gr.Image(label="Image 2", width=200, height=200, format="png", type="filepath", visible=False)
         | 
| 506 | 
            +
                        ex3 = gr.Image(label="Image 3", width=200, height=200, format="png", type="filepath", visible=False)
         | 
| 507 | 
            +
                        ex4 = gr.Image(label="Image 4", width=200, height=200, format="png", type="filepath", visible=False)
         | 
| 508 |  | 
| 509 | 
             
                        gr.Examples(
         | 
| 510 | 
             
                            examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
         | 
|  | |
| 514 | 
             
            # =========================
         | 
| 515 | 
             
            # Wiring
         | 
| 516 | 
             
            # =========================
         | 
| 517 | 
            +
                # Handle gallery selection
         | 
| 518 | 
            +
                images_method.select(
         | 
| 519 | 
            +
                    fn=on_gallery_select,
         | 
| 520 | 
            +
                    inputs=[],
         | 
| 521 | 
            +
                    outputs=[selected_image]
         | 
| 522 | 
            +
                )
         | 
| 523 | 
            +
                
         | 
| 524 | 
            +
                # Handle like/dislike button clicks
         | 
| 525 | 
            +
                like_btn.click(
         | 
| 526 | 
            +
                    fn=handle_like_drag,
         | 
| 527 | 
            +
                    inputs=[selected_image],
         | 
| 528 | 
            +
                    outputs=[like_image]
         | 
| 529 | 
            +
                )
         | 
| 530 | 
            +
                
         | 
| 531 | 
            +
                dislike_btn.click(
         | 
| 532 | 
            +
                    fn=handle_dislike_drag,
         | 
| 533 | 
            +
                    inputs=[selected_image],
         | 
| 534 | 
            +
                    outputs=[dislike_image]
         | 
| 535 | 
            +
                )
         | 
| 536 |  | 
| 537 | 
             
                next_btn.click(
         | 
| 538 | 
             
                    fn=generate_image,
         | 
| 539 | 
             
                    inputs=[prompt, like_image, dislike_image],
         | 
| 540 | 
             
                    outputs=[images_method]
         | 
| 541 | 
            +
                ).success(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)], 
         | 
| 542 | 
            +
                outputs=[next_btn, prompt, redesign_btn, submit_btn])
         | 
| 543 |  | 
| 544 | 
             
                redesign_btn.click(
         | 
| 545 | 
             
                    fn=redesign,
         | 
| 546 | 
            +
                    inputs=[prompt, sim_radio, images_method, history_images, like_image, dislike_image],
         | 
| 547 | 
            +
                    outputs=[sim_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn]
         | 
| 548 | 
             
                )
         | 
| 549 |  | 
| 550 | 
             
                submit_btn.click(
         | 
| 551 | 
             
                    fn=save_response,
         | 
| 552 | 
            +
                    inputs=[prompt, sim_radio, like_image, dislike_image],
         | 
| 553 | 
            +
                    outputs=[sim_radio, prompt, next_btn, redesign_btn, like_image, dislike_image, images_method, history_images, example.dataset]
         | 
| 554 | 
             
                )
         | 
| 555 |  | 
| 556 | 
             
            if __name__ == "__main__":
         | 
| 557 | 
            +
                demo.launch()
         | 
    	
        utils.py
    CHANGED
    
    | @@ -171,7 +171,37 @@ def get_personalize_message(prompt, history_prompts, history_feedback, like_imag | |
| 171 | 
             
                                    "url": f"data:image/png;base64,{dislike_image_base64}",
         | 
| 172 | 
             
                                },
         | 
| 173 | 
             
                            })
         | 
| 174 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 175 | 
             
                return messages
         | 
| 176 |  | 
| 177 | 
             
            @spaces.GPU 
         | 
|  | |
| 171 | 
             
                                    "url": f"data:image/png;base64,{dislike_image_base64}",
         | 
| 172 | 
             
                                },
         | 
| 173 | 
             
                            })
         | 
| 174 | 
            +
                return messages
         | 
| 175 | 
            +
             | 
| 176 | 
            +
            def get_personalized_simplified(prompt, like_image, dislike_image):
                """Build the chat messages asking the model to refine ``prompt``.

                The request consists of a fixed system message plus a single user
                message. The user message carries the instruction text (including
                the prompt to refine) followed by whichever preference images were
                supplied, liked image first and disliked image second.

                Args:
                    prompt: Text prompt the user wants refined.
                    like_image: Liked image, handed to ``encode_image``; a falsy
                        value means no liked image was chosen.
                    dislike_image: Disliked image, handed to ``encode_image``; a
                        falsy value means no disliked image was chosen.

                Returns:
                    A list of two message dicts: the system message and the user
                    message whose ``content`` is a list starting with one text part
                    and followed by zero, one or two ``image_url`` parts.
                """
                system_text = (
                    "You are a prompt refinement assistant. Your task is to improve a "
                    "user's text prompt based on his liked and disliked images. Your "
                    "goal is to refine the prompt while maintaining its original "
                    "meaning, improving clarity, specificity, and alignment with user "
                    "preferences."
                )

                # Only number the images when both are attached; with a single image
                # "first"/"second" would point the model at the wrong picture.
                both_given = bool(like_image) and bool(dislike_image)
                like_label = "The first given image" if both_given else "The given image"
                dislike_label = "The second given image" if both_given else "The given image"

                instructions = []
                if like_image:
                    instructions.append(
                        f"{like_label} is user's liked image, refine prompt with style, "
                        "and content user likes."
                    )
                if dislike_image:
                    instructions.append(
                        f"{dislike_label} is user's disliked image, refine prompt to "
                        "avoid those elements or style of this image."
                    )

                # The prompt itself must be part of the request, otherwise the model
                # has nothing to refine.
                text = " ".join(instructions)
                text = f"{text}\nPrompt to refine: {prompt}" if text else f"Prompt to refine: {prompt}"

                content = [{"type": "text", "text": text}]
                # Attachment order (liked, then disliked) matches the labels above.
                for image in (like_image, dislike_image):
                    if image:
                        content.append({
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{encode_image(image)}",
                            },
                        })

                return [
                    {"role": "system", "content": system_text},
                    {"role": "user", "content": content},
                ]
         | 
| 206 |  | 
| 207 | 
             
            @spaces.GPU 
         | 
