# Spaces page status at capture time: Runtime error
#from flask import Flask, render_template, request
import argparse
import json
import logging
import math
import os
import traceback
from functools import lru_cache
from typing import Any

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware # Cross-origin Resource Sharing: when FE running in a browser has JS code that communicates with BE
from fastapi.responses import JSONResponse
from pydantic import BaseModel

#from search_online import OnlineSearcher
from search_online_demo_TEMPORARY import OnlineSearcher
# Module-level logger for this service.
logger = logging.getLogger(__name__)

# Metadata passed to the FastAPI constructor below.
TASK_DESCRIPTION = "Retrieval"
TASK_VERSION = "0.1.0"
description = "\nRetrieval inference.\n"

# A single searcher is built at import time from an empty argparse namespace;
# every request handled by this module goes through this one instance.
args = argparse.Namespace()
searcher = OnlineSearcher(args)
# The ASGI application; title/description/version come from the metadata above.
app = FastAPI(
    version=TASK_VERSION,
    title=TASK_DESCRIPTION,
    description=description,
)

# Cross-Origin Resource Sharing: lets browser front-ends hosted on another
# origin call this API. Every origin, method and header is accepted.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# very permissive combination — confirm that credentialed cross-site requests
# are really intended.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)

# Per-process request counter, incremented by the search endpoint.
counter = {"api": 0}
## Response
class RetrievalResponse(BaseModel):
    """Response model whose value is an arbitrary payload.

    Declared with a custom root type (``__root__``), so the model wraps the
    payload directly instead of exposing named fields. ``Any`` must be
    imported from ``typing`` at file level; without it, defining this class
    raises NameError at import time.
    """

    # NOTE(review): ``__root__`` custom root types are a pydantic v1 feature;
    # pydantic v2 rejects them in favour of ``RootModel`` — confirm which
    # pydantic version is pinned. The endpoints in this file build
    # JSONResponse directly and do not reference this model.
    __root__: Any
async def healthcheck() -> JSONResponse:
    """HealthCheck"""
    # Always answers HTTP 200 with a fixed plain-string body.
    # NOTE(review): no route decorator is visible for this coroutine —
    # confirm how it is registered on `app`.
    ok_message = "health check success"
    return JSONResponse(content=ok_message, status_code=200)
def api_search_query(query, k):
    """Run a retrieval query against the module-level ``searcher``.

    Args:
        query: User question, forwarded unchanged to ``searcher.search``.
        k: Number of passages to return. ``None`` means 10; other values are
            coerced with ``int()`` and clamped to the range [0, 100].

    Returns:
        A dict ``{"query": query, "topk": [...]}`` where each ``topk`` entry
        has the keys ``text``, ``pid``, ``rank``, ``score`` and ``prob``.
        Entries are ordered by descending score, then ascending pid.
        ``prob`` is a softmax over the returned scores, so the probs sum to 1
        (the list is empty when k is 0).
    """
    print(f"Query={query}")
    if k is None:
        k = 10
    # Clamp to [0, 100]. The searcher is always asked for 100 hits. Without
    # the lower bound, a negative k would make the slices below drop the last
    # |k| hits instead of returning none.
    k = max(0, min(int(k), 100))
    pids, ranks, scores = searcher.search(query, k=100)
    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]

    # Numerically stable softmax. Subtracting the max score leaves the result
    # unchanged but keeps math.exp from overflowing on large scores. The total
    # is computed once rather than once per element.
    max_score = max(scores, default=0.0)
    weights = [math.exp(score - max_score) for score in scores]
    total = sum(weights)
    probs = [weight / total for weight in weights]

    # NOTE(review): if search() returns numpy/torch scalars, JSON encoding of
    # this dict may fail downstream — confirm the element types.
    topk = [
        {'text': searcher.collection[pid], 'pid': pid, 'rank': rank, 'score': score, 'prob': prob}
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-1 * p['score'], p['pid']))
    return {"query": query, "topk": topk}
async def api_search(query: str, k: int = 10) -> JSONResponse:
    """
    Retrieval inference

    - query : user question (type str)
    - k : topK to retrieve (type int, default 10, capped at 100)

    Returns HTTP 200 with {"query": ..., "topk": [...]} on success, or
    HTTP 500 with {"error": {"code": "500", "message": ...}} on failure.
    """
    # NOTE(review): no route decorator is visible for this coroutine — confirm
    # how it is registered on `app`.
    counter["api"] += 1
    print("API request count:", counter["api"])
    try:
        # NOTE(review): api_search_query is synchronous, so the search blocks
        # the event loop while it runs; consider offloading it if concurrency
        # matters and the searcher is safe to call from another thread.
        response = api_search_query(query=query, k=k)
        # Constructed inside the try, so an error raised while building the
        # response is also turned into a 500.
        return JSONResponse(status_code=200, content=response)
    except Exception as e:
        # Top-level request boundary. logger.exception records the full
        # traceback server-side; previously it was only sent to the client.
        logger.exception("inference exception: %s", e)
        log_traceback = traceback.format_exc()
        # NOTE(review): echoing the traceback to the client exposes server
        # internals; consider dropping it outside of debugging.
        return JSONResponse(
            status_code=500,
            content={"error": {"code": "500", "message": f"{e}\n{log_traceback}"}},
        )
if __name__ == "__main__":
    import uvicorn  # before gunicorn, try with uvicorn for python-standalone debugging

    # The port comes from $PORT. Fall back to uvicorn's default (8000) rather
    # than crashing with int(None) -> TypeError when the variable is unset.
    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)