import gradio as gr
import asyncio
from threading import RLock
from pathlib import Path
import os
from typing import Union

HF_TOKEN = os.getenv("HF_TOKEN", None)
server_timeout = 600     # seconds the Inference API client waits for the server
inference_timeout = 600  # seconds allowed per generation task
lock = RLock()
loaded_models = {}

def rename_image(image_path: Union[str, None], model_name: str, save_path: Union[str, None] = None):
    import shutil
    from datetime import datetime, timezone, timedelta
    if image_path is None: return None
    dt_now = datetime.now(timezone(timedelta(hours=9)))
    filename = f"{model_name.split('/')[-1]}_{dt_now.strftime('%Y%m%d_%H%M%S')}.png"
    try:
        if Path(image_path).exists():
            png_path = "image.png"
            if str(Path(image_path).resolve()) != str(Path(png_path).resolve()): shutil.copy(image_path, png_path)
            if save_path is not None:
                new_path = str(Path(png_path).resolve().rename(Path(save_path).resolve()))
            else:
                new_path = str(Path(png_path).resolve().rename(Path(filename).resolve()))
            return new_path
        else:
            return None
    except Exception as e:
        print(e)
        return None
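
# Usage sketch (illustrative, not from the original file): the helper copies the source
# file to ./image.png, then renames it to save_path or "<model>_<JST timestamp>.png" and
# returns the final absolute path (None on failure). The path and model id are placeholders.
#   final_path = rename_image("image.png", "user/model-a")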

# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
from typing import Literal

def load_from_model(model_name: str, hf_token: Union[str, Literal[False], None] = None):
    import httpx
    import huggingface_hub
    from gradio.exceptions import ModelNotFoundError, TooManyRequestsError
    model_url = f"https://huggingface.co/{model_name}"
    api_url = f"/static-proxy?url=https%3A%2F%2Fapi-inference.huggingface.co%2Fmodels%2F%7Bmodel_name%7D"
    print(f"Fetching model from: {model_url}")
    headers = {} if hf_token in [False, None] else {"Authorization": f"Bearer {hf_token}"}
    response = httpx.request("GET", api_url, headers=headers)
    if response.status_code != 200:
        raise ModelNotFoundError(
            f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
        )
    p = response.json().get("pipeline_tag")
    if p != "text-to-image": raise ModelNotFoundError(f"This model is not a text-to-image model or is unsupported: {model_name}.")
    headers["X-Wait-For-Model"] = "true"
    kwargs = {}
    if hf_token is not None: kwargs["token"] = hf_token
    client = huggingface_hub.InferenceClient(model=model_name, headers=headers, timeout=server_timeout, **kwargs)
    inputs = gr.components.Textbox(label="Input")
    outputs = gr.components.Image(label="Output")
    fn = client.text_to_image
    def query_huggingface_inference_endpoints(*data, **kwargs):
        try:
            data = fn(*data, **kwargs)  # type: ignore
        except huggingface_hub.utils.HfHubHTTPError as e:
            print(e)
            if "429" in str(e): raise TooManyRequestsError() from e
            raise Exception() from e  # re-raise; otherwise the raw input args leak back as the "result"
        except Exception as e:
            print(e)
            raise Exception() from e
        return data
    interface_info = {
        "fn": query_huggingface_inference_endpoints,
        "inputs": inputs,
        "outputs": outputs,
        "title": model_name,
    }
    return gr.Interface(**interface_info)
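
# Usage sketch (an assumption, not part of the original flow): the returned gr.Interface
# proxies the serverless Inference API, so it can be launched directly. The model id is a
# hypothetical placeholder.
#   demo = load_from_model("user/some-text-to-image-model", hf_token=HF_TOKEN)
#   demo.launch()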

def load_model(model_name: str):
    global loaded_models
    if model_name in loaded_models.keys(): return loaded_models[model_name]
    try:
        loaded_models[model_name] = load_from_model(model_name, hf_token=HF_TOKEN)
        print(f"Loaded: {model_name}")
    except Exception as e:
        if model_name in loaded_models.keys(): del loaded_models[model_name]
        print(f"Failed to load: {model_name}")
        print(e)
        return None
    return loaded_models[model_name]

def load_models(models: list):
    for model in models:
        load_model(model)

def warm_model(model_name: str):
    model = load_model(model_name)
    if model:
        try:
            print(f"Warming model: {model_name}")
            infer_body(model, model_name, " ")
        except Exception as e:
            print(e)

def warm_models(models: list[str]):
    # Fire-and-forget warm-up threads (the original spawned one event loop per model
    # via asyncio.new_event_loop().run_in_executor and never closed them).
    from threading import Thread
    for model in models:
        Thread(target=warm_model, args=(model,), daemon=True).start()
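
# Usage sketch (illustrative): pre-load and warm a set of models at startup so the first
# request avoids the cold-start delay. The model ids are placeholders.
#   models = ["user/model-a", "user/model-b"]
#   load_models(models)
#   warm_models(models)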

# https://huggingface.co/docs/api-inference/detailed_parameters
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
def infer_body(client: Union[gr.Interface, object], model_str: str, prompt: str, neg_prompt: str = "",
               height: int = 0, width: int = 0, steps: int = 0, cfg: int = 0, seed: int = -1):
    png_path = "image.png"
    kwargs = {}
    if height > 0: kwargs["height"] = height
    if width > 0: kwargs["width"] = width
    if steps > 0: kwargs["num_inference_steps"] = steps
    if cfg > 0: kwargs["guidance_scale"] = cfg
    if seed == -1: kwargs["seed"] = randomize_seed()
    else: kwargs["seed"] = seed
    try:
        if isinstance(client, gr.Interface): image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
        else: return None
        if isinstance(image, tuple): return None
        return save_image(image, png_path, model_str, prompt, neg_prompt, height, width, steps, cfg, seed)
    except Exception as e:
        print(e)
        raise Exception(e) from e

async def infer(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
                steps: int = 0, cfg: int = 0, seed: int = -1,
                save_path: Union[str, None] = None, timeout: float = inference_timeout):
    model = load_model(model_name)
    if not model: return None
    task = asyncio.create_task(asyncio.to_thread(infer_body, model, model_name, prompt, neg_prompt,
                                                 height, width, steps, cfg, seed))
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError as e:
        print(e)
        print(f"Task timed out: {model_name}")
        if not task.done(): task.cancel()
        result = None
        raise Exception(f"Task timed out: {model_name}") from e
    except Exception as e:
        print(e)
        if not task.done(): task.cancel()
        result = None
        raise Exception(e) from e
    if task.done() and result is not None:
        with lock:
            image = rename_image(result, model_name, save_path)
        return image
    return None
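
# Usage sketch (illustrative): infer() is a coroutine, so a synchronous caller needs an
# event loop; asyncio.run() creates and closes one. The model id is a placeholder.
#   path = asyncio.run(infer("user/model-a", "a watercolor fox", steps=28, cfg=7))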

def save_image(image, savefile, modelname, prompt, nprompt, height=0, width=0, steps=0, cfg=0, seed=-1):
    from PIL import Image, PngImagePlugin
    import json
    try:
        metadata = {"prompt": prompt, "negative_prompt": nprompt, "Model": {"Model": modelname.split("/")[-1]}}
        if steps > 0: metadata["num_inference_steps"] = steps
        if cfg > 0: metadata["guidance_scale"] = cfg
        if seed != -1: metadata["seed"] = seed
        if width > 0 and height > 0: metadata["resolution"] = f"{width} x {height}"
        metadata_str = json.dumps(metadata)
        info = PngImagePlugin.PngInfo()
        info.add_text("metadata", metadata_str)
        image.save(savefile, "PNG", pnginfo=info)
        return str(Path(savefile).resolve())
    except Exception as e:
        print(f"Failed to save image file: {e}")
        raise Exception(f"Failed to save image file: {e}") from e

def randomize_seed():
    from random import seed, randint
    MAX_SEED = 2**32 - 1
    seed()
    rseed = randint(0, MAX_SEED)
    return rseed

def gen_image(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
              steps: int = 0, cfg: int = 0, seed: int = -1):
    if model_name in ["NA", ""]: return gr.update()
    # Always run on a fresh event loop: run_until_complete() raises on a loop that is
    # already running, and closing a shared running loop in finally would break it.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
                                               steps, cfg, seed, None, inference_timeout))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}, Error: {e}")
        result = None
        raise gr.Error(f"Task aborted: {model_name}, Error: {e}")
    finally:
        loop.close()
    return result

def generate_image_hf(model_name: str, prompt: str, negative_prompt: str, use_defaults: bool, resolution: str,
                      guidance_scale: float, num_inference_steps: int, seed: int, randomize_seed: bool, progress=gr.Progress()):
    if randomize_seed: seed = -1
    if use_defaults:
        prompt = f"{prompt}, best quality, amazing quality, very aesthetic"
        negative_prompt = f"nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], {negative_prompt}"
    width, height = map(int, resolution.split('x'))
    # Pass the seed through; the original call dropped it, so a user-specified seed was ignored.
    image = gen_image(model_name, prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, seed)
    metadata_text = f"{prompt}\nNegative prompt: {negative_prompt}\nSteps: {num_inference_steps}, Sampler: Euler a, Size: {width}x{height}, Seed: {seed}, CFG scale: {guidance_scale}"
    return image, seed, metadata_text
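
# Minimal launch sketch (an assumption about how this module is wired up, not part of the
# original file): expose generate_image_hf through a simple Blocks UI. All labels,
# resolution choices, and slider ranges below are illustrative placeholders.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        model_name = gr.Textbox(label="Model repo id")
        prompt = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative prompt")
        use_defaults = gr.Checkbox(label="Add default quality tags", value=True)
        resolution = gr.Dropdown(choices=["1024x1024", "832x1216", "1216x832"], value="1024x1024", label="Resolution")
        guidance_scale = gr.Slider(1.0, 12.0, value=7.0, label="CFG scale")
        num_inference_steps = gr.Slider(1, 50, value=28, step=1, label="Steps")
        seed = gr.Number(value=-1, label="Seed")
        randomize = gr.Checkbox(label="Randomize seed", value=True)
        image = gr.Image(label="Output")
        used_seed = gr.Number(label="Used seed")
        metadata = gr.Textbox(label="Metadata")
        gr.Button("Generate").click(
            generate_image_hf,
            inputs=[model_name, prompt, negative_prompt, use_defaults, resolution,
                    guidance_scale, num_inference_steps, seed, randomize],
            outputs=[image, used_seed, metadata],
        )
    demo.launch()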