update interface
- __pycache__/live_preview_helpers.cpython-310.pyc +0 -0
- __pycache__/optim_utils.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +105 -17
- optim_utils.py +0 -3
    	
__pycache__/live_preview_helpers.cpython-310.pyc CHANGED

Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ

__pycache__/optim_utils.cpython-310.pyc CHANGED

Binary files a/__pycache__/optim_utils.cpython-310.pyc and b/__pycache__/optim_utils.cpython-310.pyc differ

__pycache__/utils.cpython-310.pyc CHANGED

Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
    	
app.py CHANGED

@@ -6,6 +6,7 @@ import spaces
 import torch
 import re
 import transformers
+import open_clip
 
 # Optional: keep these utilities if your pipeline depends on them
 from optim_utils import optimize_prompt

@@ -33,17 +34,15 @@ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 clean_cache()
 
 selected_pipe = setup_model(default_t2i_model, torch_dtype, device)
+clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device)
 llm_pipe = None
-torch.cuda.empty_cache()
 inverted_prompt = ""
+torch.cuda.empty_cache()
 
 METHOD = "Experimental"  # keep ONLY experimental
-
-# Global states for a single-task, single-method flow
 counter = 1
 enable_submit = False
 responses_memory = {METHOD: {}}
-
 example_data = [
     [
         PROMPTS["Tourist promotion"],

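For context, `open_clip.create_model_and_transforms` returns a model plus training and evaluation preprocessing transforms, which is why the middle value is discarded. A standalone sketch of this setup, where the `CLIP_MODEL` and `PRETRAINED_CLIP` values are assumptions (the actual constants are defined elsewhere in the repo; PEZ-style prompt inversion commonly uses ViT-H-14 / laion2b_s32b_b79k):

```python
import torch
import open_clip

# Assumed values for constants that app.py imports from elsewhere in the repo.
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"

device = "cuda" if torch.cuda.is_available() else "cpu"
# Returns (model, train_preprocess, eval_preprocess); only the eval
# transform is kept for inference.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device
)
```
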
@@ -58,7 +57,6 @@ example_data = [
         IMAGES["Interior Design"]["ours"]
     ],
 ]
-print(example_data)
 
 # =========================
 # Image Generation Helpers

@@ -103,11 +101,31 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0
 def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
     seed = random.randint(0, MAX_SEED)
     client = init_gpt_api()
-    print(like_image, dislike_image)
     messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
     outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
     return outputs
 
+@spaces.GPU(duration=100)
+def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2):
+    global inverted_prompt
+    text_params = {
+        "iter": iter,
+        "lr": lr,
+        "batch_size": batch_size,
+        "prompt_len": prompt_len,
+        "weight_decay": 0.1,
+        "prompt_bs": 1,
+        "loss_weight": 1.0,
+        "print_step": 100,
+        "clip_model": CLIP_MODEL,
+        "clip_pretrain": PRETRAINED_CLIP,
+    }
+    inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
+    print(inverted_prompt)
+
+    # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
+    # return learned_prompt
+
 # =========================
 # UI Helper Functions
 # =========================

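The new `invert_prompt` helper wraps `optimize_prompt` (hard-prompt inversion against the CLIP model set up earlier) and writes its result into the module-level `inverted_prompt` global rather than returning it, since the `return` is commented out. A sketch of how a caller inside app.py might use it, with hypothetical image paths:

```python
from PIL import Image

# Hypothetical inputs: images the user liked, plus their original prompt.
liked_images = [Image.open("liked_1.png"), Image.open("liked_2.png")]

invert_prompt("a cozy reading nook, warm light", liked_images)
print(inverted_prompt)  # the learned hard prompt, set as a side effect
```
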
@@ -241,29 +259,98 @@ css = """
     display: flex;
     justify-content: center;
 }
+#compact-compact-row {
+    width:100%;
+    max-width: 800px;
+    margin: 0px auto;
+}
 #compact-row {
     width:100%;
     max-width: 1000px;
     margin: 0px auto;
 }
+.header-section {
+    text-align: center;
+    margin-bottom: 2rem;
+}
+.abstract-text {
+    text-align: justify;
+    line-height: 1.6;
+    margin: 0.5rem 0;
+    padding: 0.5rem;
+    background-color: rgba(0, 0, 0, 0.05);
+    border-radius: 8px;
+    border-left: 4px solid #3498db;
+}
+.paper-link {
+    display: inline-block;
+    margin: 0rem 0;
+    padding: 0rem 0rem;
+    background-color: #3498db;
+    color: white;
+    text-decoration: none;
+    border-radius: 5px;
+    font-weight: 500;
+}
+.paper-link:hover {
+    background-color: #2980b9;
+    text-decoration: none;
+}
+.authors-section {
+    text-align: center;
+    margin: 0 0;
+    font-style: italic;
+    color: #666;
+}
+.authors-title {
+    font-weight: bold;
+    margin-bottom: 0rem;
+    color: #333;
+}
 """
 
 with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
-    with gr.Column(elem_id="col-container"):
+    with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
         gr.Markdown("# 📌 **POET**")
-
-
+        gr.Markdown("## Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
+
+        # <strong>Abstract:</strong> State-of-the-art visual generative AI tools hold immense potential to assist users in the early ideation stages of creative tasks — offering the ability to generate (rather than search for) novel and unprecedented (instead of existing) images of considerable quality that also adhere to boundless combinations of user specifications. However, many large-scale text-to-image systems are designed for broad applicability, yielding conventional output that may limit creative exploration. They also employ interaction methods that may be difficult for beginners.
+        gr.Markdown("""
+        <div class="abstract-text">
+        <strong>Abstract:</strong> Given that creative end-users often operate in diverse, context-specific ways that are often unpredictable, more variation and personalization are necessary. We introduce POET, a real-time interactive tool that (1) automatically discovers dimensions of homogeneity in text-to-image generative models, (2) expands these dimensions to diversify the output space of generated images, and (3) learns from user feedback to personalize expansions. Focusing on visual creativity, POET offers a first glimpse of how interaction techniques of future text-to-image generation tools may support and align with more pluralistic values and the needs of end-users during the ideation stages of their work.
+        </div>
+        """, elem_classes=["abstract-text"])
+
+        # Paper Link
+        gr.HTML("""
+        <div style="text-align: center;">
+            <a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
+                📄 Read the Full Paper
+            </a>
+        </div>
+        """)
+
+        # Authors
+        gr.Markdown("""
+        <div class="authors-section">
+            Evans Han, Alice Qian Zhang, Haiyi Zhu, Hong Shen, Paul Pu Liang, Jane Hsieh
+        </div>
+        """, elem_classes=["authors-section"])
+
+        # gr.Markdown("---")
 
     with gr.Tab(""):
         with gr.Row(elem_id="compact-row"):
-
-
-
-
-
-
-
-
+            with gr.Column(elem_id="col-container"):
+                with gr.Row():
+                    prompt = gr.Textbox(
+                        label="🎨 Prompt",
+                        max_lines=5,
+                        placeholder="Enter your prompt",
+                        visible=True,
+                    )
+            with gr.Column(elem_id="col-container3"):
+                next_btn = gr.Button("Generate", variant="primary", scale=1)
 
     with gr.Row(elem_id="compact-row"):
         with gr.Column(elem_id="col-container"):

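The CSS additions follow Gradio's standard hook mechanism: rules in the `css` string target the ids and classes that Gradio attaches to components via `elem_id` and `elem_classes`. A minimal self-contained sketch of the pattern:

```python
import gradio as gr

css = """
.header-section { text-align: center; margin-bottom: 2rem; }
#compact-row { max-width: 1000px; margin: 0 auto; }
"""

with gr.Blocks(css=css) as demo:
    # elem_classes maps to CSS class selectors, elem_id to an id selector.
    with gr.Column(elem_classes=["header-section"]):
        gr.Markdown("# Title")
    with gr.Row(elem_id="compact-row"):
        gr.Textbox(label="Prompt")

if __name__ == "__main__":
    demo.launch()
```
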
@@ -313,6 +400,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
                 examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
                 inputs=[prompt, ex1, ex2, ex3, ex4]
             )
+
 # =========================
 # Wiring
 # =========================

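For reference, the `gr.Examples` context above fans each `example_data` row out into a prompt plus four example images. A minimal sketch of that wiring, with hypothetical file paths standing in for the real example images:

```python
import gradio as gr

# Hypothetical example_data row: [prompt_text, [four image paths]].
example_data = [["a sunny beach at golden hour", ["a.png", "b.png", "c.png", "d.png"]]]

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    ex1, ex2, ex3, ex4 = (gr.Image() for _ in range(4))
    gr.Examples(
        examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
        inputs=[prompt, ex1, ex2, ex3, ex4],
    )
```
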
    	
optim_utils.py CHANGED

@@ -19,9 +19,6 @@ def nn_project(curr_embeds, embedding_layer, print_hits=False):
     with torch.no_grad():
         bsz,seq_len,emb_dim = curr_embeds.shape
 
-        # Using the sentence transformers semantic search which is
-        # a dot product exact kNN search between a set of
-        # query vectors and a corpus of vectors
         curr_embeds = curr_embeds.reshape((-1,emb_dim))
         curr_embeds = normalize_embeddings(curr_embeds) # queries

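The deleted comment documented the step that follows: `nn_project` performs an exact dot-product kNN between the current soft embeddings (queries) and the vocabulary embedding table (corpus) via `sentence_transformers.util.semantic_search`. A minimal sketch of that primitive, with random tensors standing in for the real embeddings:

```python
import torch
from sentence_transformers.util import semantic_search, normalize_embeddings, dot_score

queries = normalize_embeddings(torch.randn(4, 512))    # e.g. current prompt embeddings
corpus = normalize_embeddings(torch.randn(1000, 512))  # e.g. the token embedding table

# Exact kNN; with normalized vectors, dot product equals cosine similarity.
hits = semantic_search(queries, corpus, score_function=dot_score, top_k=1)
nearest_ids = [h[0]["corpus_id"] for h in hits]  # index of nearest corpus vector
```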