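# FastAPI service around Stable Cascade: text prompts are turned into images with
# the diffusers prior/decoder pipelines, cached on disk, and streamed back as JPEG.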
import numpy as np
import PIL.Image
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
import io
import os
from pathlib import Path
from db import Database
import uuid
import logging
from fastapi import FastAPI, Request, HTTPException
from asyncio import Lock

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

MAX_SEED = np.iinfo(np.int32).max
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
SPACE_ID = os.environ.get("SPACE_ID", "")
DEV = os.environ.get("DEV", "0") == "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
IMGS_PATH = DB_PATH / "imgs"
DB_PATH.mkdir(exist_ok=True, parents=True)
IMGS_PATH.mkdir(exist_ok=True, parents=True)

database = Database(DB_PATH)
generate_lock = Lock()

dtype = torch.bfloat16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
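
# Stable Cascade is a two-stage pipeline: the prior turns the text prompt into
# image embeddings, and the decoder turns those embeddings into pixels. Both
# stages are loaded onto the GPU only when one is available.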
if torch.cuda.is_available():
    prior_pipeline = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
    ).to(device)
    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
    ).to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(
            prior_pipeline.prior, mode="reduce-overhead", fullgraph=True
        )
        decoder_pipeline.decoder = torch.compile(
            decoder_pipeline.decoder, mode="max-autotune", fullgraph=True
        )
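

# Runs the prior stage, then feeds its image embeddings to the decoder and
# returns the first generated PIL image.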
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 20,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 10,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 1,
) -> PIL.Image.Image:
    generator = torch.Generator().manual_seed(seed)
    prior_output = prior_pipeline(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=prior_num_inference_steps,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
    )
    decoder_output = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=decoder_num_inference_steps,
        # timesteps=decoder_timesteps,
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        generator=generator,
        output_type="pil",
    ).images
    return decoder_output[0]
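

# The API is meant to be called from huggingface.co pages only, so CORS (and the
# referer check in the middleware below) are restricted to those origins.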
app = FastAPI()

origins = [
    "https://huggingface.co",
    "http://huggingface.co",
    "https://huggingface.co/",
    "http://huggingface.co/",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def validate_origin(request: Request, call_next):
    if DEV:
        return await call_next(request)
    if request.headers.get("referer") not in origins:
        raise HTTPException(status_code=403, detail="Forbidden")
    return await call_next(request)
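

# Results are cached on disk keyed by (prompt, negative_prompt, seed); the asyncio
# lock serializes generation so only one request uses the GPU at a time.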
@app.get("/image")  # route path assumed
async def generate_image(
    prompt: str, negative_prompt: str = "", seed: int = 2134213213
):
    cached_img = database.check(prompt, negative_prompt, seed)
    if cached_img:
        logging.info(f"Image found in cache: {cached_img[0]}")
        return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")

    logging.info("Image not found in cache, generating new image")
    async with generate_lock:
        pil_image = generate(prompt, negative_prompt, seed)
    img_id = str(uuid.uuid4())
    img_path = IMGS_PATH / f"{img_id}.jpg"
    pil_image.save(img_path)
    img_io = io.BytesIO()
    pil_image.save(img_io, "JPEG")
    img_io.seek(0)
    database.insert(prompt, negative_prompt, str(img_path), seed)
    return StreamingResponse(img_io, media_type="image/jpeg")


@app.get("/")
async def main():
    # redirect to https://huggingface.co/spaces/multimodalart/stable-cascade
    return RedirectResponse(
        "https://multimodalart-stable-cascade.hf.space/?__theme=system"
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
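
# Example request against a local run (DEV=1 skips the referer check); the /image
# path matches the assumed route above, and "app.py" assumes this file's name:
#   DEV=1 python app.py
#   curl "http://localhost:7860/image?prompt=an%20astronaut%20riding%20a%20horse" -o out.jpg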