import os
from typing import List, Literal, Optional

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
from pydantic import BaseModel

# -------------------- Config --------------------
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_API_TOKEN")
MODEL_ID = "google/gemma-2-2b-it"  # conversational model

if HF_TOKEN is None:
    raise RuntimeError(
        "HF_TOKEN or HF_API_TOKEN is not set. "
        "Go to Space → Settings → Variables & secrets and add one."
    )

hf_client = InferenceClient(token=HF_TOKEN)

DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, friendly AI assistant. "
    "Answer clearly and concisely."
)

# -------------------- FastAPI setup --------------------
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: float = 0.7
    max_new_tokens: int = 256
    system_prompt: Optional[str] = None


# -------------------- Helpers --------------------
def convert_messages(
    messages: List[ChatMessage], system_prompt: Optional[str]
) -> List[dict]:
    """Convert our internal message format into OpenAI-style message dicts,
    prepending the system prompt."""
    sys_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
    out = [{"role": "system", "content": sys_prompt}]
    for m in messages:
        out.append({"role": m.role, "content": m.content})
    return out


def call_llm(req: ChatRequest) -> str:
    """Call Gemma through the OpenAI-compatible chat.completions API."""
    msgs = convert_messages(req.messages, req.system_prompt)
    try:
        resp = hf_client.chat.completions.create(
            model=MODEL_ID,
            messages=msgs,
            max_tokens=req.max_new_tokens,
            temperature=req.temperature,
        )
        # The message is an object, not a dict: use attribute access,
        # and guard against a None content field before stripping.
        return (resp.choices[0].message.content or "").strip()
    except Exception as e:
        raise RuntimeError(f"Inference API error: {e}") from e


# -------------------- Routes --------------------
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/chat")
async def chat_endpoint(payload: ChatRequest):
    if not payload.messages:
        return JSONResponse({"error": "No messages provided"}, status_code=400)
    try:
        reply = call_llm(payload)
        return {"reply": reply}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
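
# -------------------- Local dev entry point (sketch) --------------------
# A minimal way to run and exercise the app outside the Space. This block is
# an assumption, not part of the Space config above: it presumes uvicorn is
# installed, and it uses port 7860, the port Hugging Face Spaces expects a
# Docker-based app to listen on. On a managed Space the platform's own launch
# command may start the server instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Once the server is up, the /chat endpoint can be tested with curl
# (assuming the local port above; the payload follows the ChatRequest schema):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
#
# A successful response has the shape {"reply": "..."}.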