Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| import datetime | |
| import json | |
| import os | |
| import pathlib | |
| import tempfile | |
| from typing import Any | |
| import gradio as gr | |
| from gradio_client import Client | |
| from scheduler import ParquetScheduler | |
# Required settings: a missing variable raises KeyError at import time.
HF_TOKEN = os.environ["HF_TOKEN"]
UPLOAD_REPO_ID = os.environ["UPLOAD_REPO_ID"]
# Optional settings with fallbacks.
UPLOAD_FREQUENCY = int(os.environ.get("UPLOAD_FREQUENCY", "15"))
# Only the exact string "1" makes the upload repo public; unset or any other value keeps it private.
USE_PUBLIC_REPO = os.environ.get("USE_PUBLIC_REPO", "") == "1"
# Markdown rendered inside the "About this Space" accordion.
# The generation backend link must match the Space actually passed to `Client(...)`;
# the original stabilityai/stable-diffusion Space is paused, so runwayml's v1.5 Space is used.
ABOUT_THIS_SPACE = """
This Space is a sample Space that collects user preferences for the results generated by a diffusion model.

This demo calls the [Stable Diffusion v1.5 Space](https://huggingface.co/spaces/runwayml/stable-diffusion-v1-5) with the [`gradio_client`](https://pypi.org/project/gradio-client/) library.

The user preference data is periodically archived in parquet format and uploaded to [this dataset repo](https://huggingface.co/datasets/hysts-samples/sample-user-preferences).

The periodic upload is done using [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler).

See [this Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) for more general usage.
"""
# Collects preference rows appended by save_preference() and uploads them to UPLOAD_REPO_ID.
# NOTE(review): `every` is forwarded as-is; presumably minutes, as in
# huggingface_hub.CommitScheduler — confirm against scheduler.ParquetScheduler.
scheduler = ParquetScheduler(
    repo_id=UPLOAD_REPO_ID,
    token=HF_TOKEN,
    private=not USE_PUBLIC_REPO,
    every=UPLOAD_FREQUENCY,
)

# NOTE: the original backend, the "stabilityai/stable-diffusion" Space, is paused,
# so the runwayml Stable Diffusion v1.5 Space is used instead.
client = Client("runwayml/stable-diffusion-v1-5")
def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for *prompt* on the remote Space and record the settings used.

    Args:
        prompt: Text prompt forwarded to the remote Space.

    Returns:
        A ``(config_path, image_paths)`` tuple:
        ``config_path`` is a temporary JSON file holding the generation settings
        (it is passed through the UI and read back by ``save_preference``), and
        ``image_paths`` are the keys of ``captions.json`` in the directory
        returned by the remote call.
    """
    negative_prompt = ""
    guidance_scale = 9.0
    # NOTE(review): the runwayml endpoint (fn_index=1) is called with the prompt only,
    # so negative_prompt and guidance_scale are recorded below but NOT applied to the
    # generation. They are kept so the uploaded dataset schema stays unchanged.
    out_dir = client.predict(prompt, fn_index=1)
    config = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
    }
    # delete=False: the file must outlive this call because only its path is returned
    # and the file is reopened later. Nothing in this function removes it afterwards.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", encoding="utf-8", delete=False
    ) as config_file:
        json.dump(config, config_file)
    captions_path = pathlib.Path(out_dir) / "captions.json"
    captions = json.loads(captions_path.read_text(encoding="utf-8"))
    # Iterating a dict yields its keys in insertion order, i.e. the listed image paths.
    return config_file.name, list(captions)
def get_selected_index(evt: gr.SelectData) -> int:
    """Return the position of the gallery item carried by a Gradio select event.

    The ``gr.SelectData`` annotation is kept as-is so Gradio can inject the event data.
    """
    clicked_position = evt.index
    return clicked_position
def save_preference(config_path: str, gallery: list[dict[str, Any]], selected_index: int) -> None:
    """Append one user-preference record to the upload scheduler.

    The record is the generation settings loaded from *config_path*, extended with
    the selected index, a UTC timestamp, and one ``image_NNN`` entry per gallery item.

    Args:
        config_path: Path of the JSON settings file written by ``generate``.
        gallery: Gallery value from the Gradio event. Each item's ``"name"`` entry is
            stored; assumed to be the image file path (Gradio 3.x gallery format —
            confirm if Gradio is upgraded).
        selected_index: Index of the gallery image the user selected.
    """
    with open(config_path, encoding="utf-8") as f:
        data = json.load(f)

    data["selected_index"] = selected_index
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC time and drop
    # tzinfo so isoformat() keeps the previous naive format (no "+00:00" suffix) and the
    # timestamp column stays consistent with rows already uploaded.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    data["timestamp"] = now_utc.isoformat()

    # Zero-padded keys (image_000, image_001, ...) keep the columns lexically ordered.
    data.update({f"image_{index:03d}": item["name"] for index, item in enumerate(gallery)})

    scheduler.append(data)
def update_save_button(selected_index: int) -> dict:
    """Make the save button clickable only when an image is selected (-1 means none)."""
    has_selection = selected_index != -1
    return gr.update(interactive=has_selection)
def clear() -> tuple[dict, dict, dict]:
    """Return updates that reset (config_path, gallery, selected_index) after a save."""
    empty_config = gr.update(value=None)
    empty_gallery = gr.update(value=None)
    no_selection = gr.update(value=-1)
    return empty_config, empty_gallery, no_selection
with gr.Blocks(css="style.css") as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder="Prompt")
        gallery = gr.Gallery(
            show_label=False,
            columns=2,
            rows=2,
            height="600px",
            object_fit="scale-down",
            allow_preview=False,
        )
    save_preference_button = gr.Button("Save preference", interactive=False)
    # Hidden per-session state shared between the events below.
    # Path of the settings JSON returned by generate() and read by save_preference().
    config_path = gr.Text(visible=False)
    # Index of the selected gallery image; -1 means "nothing selected".
    selected_index = gr.Number(visible=False, precision=0, value=-1)
    with gr.Accordion(label="About this Space", open=False):
        gr.Markdown(ABOUT_THIS_SPACE)

    # Generate new images, then drop any selection carried over from the previous
    # gallery. Without this reset, an index picked on the old images survives the new
    # generation and the save button stays enabled. A click would then record a choice
    # the user never made for the new images. Setting -1 goes through
    # selected_index.change below, which disables the button.
    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
        api_name=False,
    ).then(
        fn=lambda: gr.update(value=-1),
        outputs=selected_index,
        queue=False,
        api_name=False,
    )
    # Keep the save button enabled only while an image is selected.
    selected_index.change(
        fn=update_save_button,
        inputs=selected_index,
        outputs=save_preference_button,
        queue=False,
        api_name=False,
    )
    # Store the position of the clicked gallery image.
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
        api_name=False,
    )
    # Save the preference, then reset the state. The reset sets selected_index back
    # to -1, which disables the button, so the same result cannot be saved twice.
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    ).then(
        fn=clear,
        outputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    )
if __name__ == "__main__":
    # Enable the request queue before launching. concurrency_count and api_open are
    # Gradio 3.x queue() options (concurrency_count was removed in Gradio 4 — confirm
    # the pinned version if this fails at startup).
    queued_app = demo.queue(concurrency_count=5, api_open=False)
    queued_app.launch()