Spaces:
Running
Running
| import json | |
| from functools import partial | |
| from .utils import * | |
| from .vote_utils import ( | |
| upvote_last_response_t2s as upvote_last_response, | |
| downvote_last_response_t2s as downvote_last_response, | |
| flag_last_response_t2s as flag_last_response, | |
| ) | |
| from .inference import( | |
| sample_prompt, | |
| generate_t2s | |
| ) | |
| from constants import TEXT_PROMPT_PATH | |
# Load the prompt pool once at import time; build_single_model_ui hands it to
# the sampling and generation handlers. JSON text is UTF-8 by specification
# (RFC 8259), so decode explicitly instead of relying on the platform default
# encoding, which would break non-ASCII prompts on e.g. Windows.
with open(TEXT_PROMPT_PATH, "r", encoding="utf-8") as f:
    prompt_list = json.load(f)
def build_single_model_ui(models):
    """Build the single-model generation tab of the arena UI.

    Lays out a model selector, a prompt box with Sample/Send buttons, two
    output images ("Normal" and "RGB"), Clear/Regenerate buttons and three
    rows of Upvote/Downvote/Flag buttons, then registers all event handlers.
    Must be called inside an active ``gr.Blocks()`` context, because it
    creates components and attaches event listeners.

    Args:
        models: Model manager exposing ``get_t2s_models()`` (its result is
            used as the dropdown choices) and the ``inference_parallel`` /
            ``render_parallel`` callables that are pre-bound into
            ``generate_t2s``.

    Returns:
        None. Components are created for their effect on the enclosing
        Blocks layout.
    """
    # NOTE(review): this string is never .format()-ed, so "{promotion}" is
    # rendered literally in the UI -- confirm whether a promotion text should
    # be substituted here.
    notice_markdown = """
# ποΈ Play with Image Generation Models
{promotion}
## π€ Choose any model to generate
"""
    model_list = models.get_t2s_models()
    # Pre-bind the backends so Gradio only supplies the per-request arguments
    # listed in the event inputs below.
    gen_func = partial(generate_t2s, models.inference_parallel, models.render_parallel)

    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=model_list,
            value=model_list[0] if len(model_list) > 0 else "",
            interactive=True,
            show_label=False,
        )

    with gr.Row():
        with gr.Accordion("π Expand to see all Arena players", open=False):
            model_description_md = get_model_description_md(model_list)
            gr.Markdown(model_description_md, elem_id="model_description_markdown")

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="π Enter your prompt or Sample a random prompt, and press ENTER",
            container=True,
            elem_id="input_box",
        )
        sample_btn = gr.Button(value="π² Sample", variant="primary", scale=0)
        send_btn = gr.Button(value="π€ Send", variant="primary", scale=0)

    with gr.Row():
        normal = gr.Image(width=512, label="Normal", show_copy_button=True)
        rgb = gr.Image(width=512, label="RGB", show_copy_button=True)

    with gr.Row():
        clear_btn = gr.Button(value="ποΈ Clear", interactive=False)
        regenerate_btn = gr.Button(value="π Regenerate", interactive=False)

    with gr.Row(elem_id="Geometry Quality"):
        geo_upvote_btn = gr.Button(value="π Upvote", interactive=False)
        geo_downvote_btn = gr.Button(value="π Downvote", interactive=False)
        geo_flag_btn = gr.Button(value="β οΈ Flag", interactive=False)
    with gr.Row(elem_id="Texture Quality"):
        text_upvote_btn = gr.Button(value="π Upvote", interactive=False)
        text_downvote_btn = gr.Button(value="π Downvote", interactive=False)
        text_flag_btn = gr.Button(value="β οΈ Flag", interactive=False)
    with gr.Row(elem_id="Alignment Quality"):
        align_upvote_btn = gr.Button(value="π Upvote", interactive=False)
        align_downvote_btn = gr.Button(value="π Downvote", interactive=False)
        align_flag_btn = gr.Button(value="β οΈ Flag", interactive=False)

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    state = gr.State()
    # Gradio event inputs must be components; a raw Python list/dict in an
    # inputs list fails when the listener is registered. Wrapping the prompt
    # pool in a State keeps it as the same positional argument the handlers
    # were written to receive.
    prompt_state = gr.State(prompt_list)

    geo_btn_list = [geo_upvote_btn, geo_downvote_btn, geo_flag_btn]
    text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
    align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]

    # NOTE(review): all three rows (geometry / texture / alignment) call the
    # same vote handlers with identical inputs, so the handler cannot tell
    # which quality dimension was voted on -- confirm this is intended.
    for btn_list in (geo_btn_list, text_btn_list, align_btn_list):
        upvote_btn, downvote_btn, flag_btn = btn_list
        upvote_btn.click(upvote_last_response, [state, model_selector], [textbox] + btn_list)
        downvote_btn.click(downvote_last_response, [state, model_selector], [textbox] + btn_list)
        flag_btn.click(flag_last_response, [state, model_selector], [textbox] + btn_list)

    sample_btn.click(
        sample_prompt,
        [state, model_selector, prompt_state],
        [state, textbox],  # was `state + [textbox]`, a TypeError at build time
        api_name="sample_btn_single",
    )

    def _wire_generation(trigger, api_name):
        """Run ``gen_func`` when *trigger* fires, then enable the action buttons."""
        trigger(
            gen_func,
            [state, textbox, model_selector, prompt_state],
            [state, normal, rgb],
            api_name=api_name,
            show_progress="full",
        ).then(
            enable_buttons,
            None,
            geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn],
        )

    # Registration order is kept identical to the original wiring.
    _wire_generation(textbox.submit, "submit_btn_single")
    _wire_generation(send_btn.click, "send_btn_single")

    clear_btn.click(
        clear_history,
        None,
        [state, textbox, normal, rgb],
        api_name="clear_history_single",
        show_progress="full",
    ).then(
        disable_buttons,
        None,
        geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn],
    )

    _wire_generation(regenerate_btn.click, "regenerate_btn_single")