lock with async wait
app.py
CHANGED
@@ -17,6 +17,8 @@ import uuid
 import logging
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
+from asyncio import Lock
+
 
 logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
 
@@ -24,7 +26,7 @@ MAX_SEED = np.iinfo(np.int32).max
 USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
 SPACE_ID = os.environ.get("SPACE_ID", "")
 DEV = os.environ.get("DEV", "0") == "1"
-os.environ[
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
 IMGS_PATH = DB_PATH / "imgs"
@@ -32,6 +34,7 @@ DB_PATH.mkdir(exist_ok=True, parents=True)
 IMGS_PATH.mkdir(exist_ok=True, parents=True)
 
 database = Database(DB_PATH)
+generate_lock = Lock()
 
 dtype = torch.bfloat16
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -120,21 +123,25 @@ app.add_middleware(
 
 
 @app.get("/image")
-async def generate_image(
+async def generate_image(
+    prompt: str, negative_prompt: str = "", seed: int = 2134213213
+):
     cached_img = database.check(prompt, negative_prompt, seed)
     if cached_img:
         logging.info(f"Image found in cache: {cached_img[0]}")
         return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
 
     logging.info(f"Image not found in cache, generating new image")
-    pil_image = generate(prompt, negative_prompt, seed)
-    img_id = str(uuid.uuid4())
-    img_path = IMGS_PATH / f"{img_id}.jpg"
-    pil_image.save(img_path)
-    img_io = io.BytesIO()
-    pil_image.save(img_io, "JPEG")
-    img_io.seek(0)
-    database.insert(prompt, negative_prompt, str(img_path), seed)
+    async with generate_lock:
+
+        pil_image = generate(prompt, negative_prompt, seed)
+        img_id = str(uuid.uuid4())
+        img_path = IMGS_PATH / f"{img_id}.jpg"
+        pil_image.save(img_path)
+        img_io = io.BytesIO()
+        pil_image.save(img_io, "JPEG")
+        img_io.seek(0)
+        database.insert(prompt, negative_prompt, str(img_path), seed)
 
     return StreamingResponse(img_io, media_type="image/jpeg")
 
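The pattern this commit introduces is a common way to serialize a GPU-bound endpoint under FastAPI's async model: one module-level `asyncio.Lock`, acquired with `async with`, so that concurrent requests queue cooperatively instead of running the pipeline at the same time. Below is a minimal, self-contained sketch of that pattern; `fake_generate` is a placeholder standing in for the Space's actual diffusion pipeline, not code from this repo.

```python
import asyncio
import io

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from PIL import Image

app = FastAPI()
# One lock per process: only one request at a time may run generation.
generate_lock = asyncio.Lock()


def fake_generate(prompt: str, seed: int) -> Image.Image:
    # Placeholder for the real GPU-bound pipeline call.
    return Image.new("RGB", (64, 64), color=(seed % 256, 0, 0))


@app.get("/image")
async def generate_image(prompt: str, seed: int = 0):
    # Concurrent requests queue here; awaiting the lock yields control
    # to the event loop, so waiters do not block other endpoints.
    async with generate_lock:
        pil_image = fake_generate(prompt, seed)
    # Serve the result from an in-memory buffer.
    img_io = io.BytesIO()
    pil_image.save(img_io, "JPEG")
    img_io.seek(0)
    return StreamingResponse(img_io, media_type="image/jpeg")
```

Note that the lock only makes the *waiting* asynchronous, which matches the commit title: if `generate()` itself is a synchronous GPU call, it still occupies the event loop for the duration of one generation. And because the lock lives in module scope, it serializes requests within a single worker process only; running multiple uvicorn workers would give each its own lock.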