from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from torch import cuda
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from hashlib import sha256
from huggingface_hub import login
from dotenv import load_dotenv
import os
import uvicorn
import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Required for access to a gated model
load_dotenv()
hf_token = os.getenv("HF_TOKEN", None)
if hf_token is not None:
    login(token=hf_token)

# Configurable model identifier and quantization level
model_name = os.getenv("HF_MODEL", "swiss-ai/Apertus-8B-Instruct-2509")
model_quantization = int(os.getenv("QUANTIZE", 0))  # 8, 4, or 0 = no quantization (default)

# Maximum number of tokens to generate per request
MAX_NEW_TOKENS = 4096

# Load the base system prompt from a text file
system_prompt = None
if int(os.getenv("USE_SYSTEM_PROMPT", 1)):
    with open('system_prompt.md', 'r') as file:
        system_prompt = file.read()

# Model and tokenizer are kept in module state for the lifetime of the app
model = None
tokenizer = None


class TextInput(BaseModel):
    text: str = ""
    min_length: int = 3
    # Apertus supports a context length of up to 65,536 tokens by default.
    max_length: int = 65536


class ModelResponse(BaseModel):
    text: str
    confidence: float
    processing_time: float


class ChatMessage(BaseModel):
    role: str = "user"
    content: str = ""


class Completion(BaseModel):
    model: str = "apertus"
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.1
    top_p: Optional[float] = 0.9


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transformer model on startup and release it on shutdown."""
    global model, tokenizer
    try:
        logger.info(f"Loading model: {model_name}")

        # Automatically select the device based on availability
        device = "cuda" if cuda.is_available() else "cpu"

        # Load the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Optional quantization setting
        bnb_config = None
        if model_quantization == 8:
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        elif model_quantization == 4:
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)

        if bnb_config is not None:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",               # Automatically splits the model across CPU/GPU
                offload_folder="offload",        # Temporary offload to disk
                low_cpu_mem_usage=True,          # Avoids unnecessary CPU memory duplication
                quantization_config=bnb_config,  # Reduces memory usage and overhead
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",               # Automatically splits the model across CPU/GPU
                offload_folder="offload",        # Temporary offload to disk
            )

        logger.info(f"Model loaded successfully! ({device})")
({device})") except Exception as e: logger.error(f"Failed to load model: {e}") raise e # Release resources when the app is stopped yield del model del tokenizer cuda.empty_cache() # Setup our app app = FastAPI( title="Apertus API", description="REST API for serving Apertus models via Hugging Face transformers", version="0.1.0", docs_url="/", lifespan=lifespan ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) def fit_to_length(text, min_length=3, max_length=100): """Truncate text if too long.""" text = text[:max_length] if len(text) == max_length: logger.warning("Warning: text truncated") if len(text) < min_length: logger.warning("Warning: empty text, aborting") return None return text def get_completion_text(messages_think: List[ChatMessage]): txt = "" for cm in messages_think: txt = " ".join((txt, cm.content)) return txt def get_message_id(txt: str): return sha256(str(txt).encode()).hexdigest() def get_model_reponse(messages_think: List[ChatMessage]): """Process the text content.""" # Apply the system template has_system = False for m in messages_think: if m.role == 'system': has_system = True if not has_system and system_prompt: cm = ChatMessage(role='system', content=system_prompt) messages_think.insert(0, cm) print(messages_think) # Prepare the model input text = tokenizer.apply_chat_template( messages_think, tokenize=False, add_generation_prompt=True, top_p=0.9, temperature=0.8, ) model_inputs = tokenizer( [text], return_tensors="pt", add_special_tokens=False ).to(model.device) # Generate the output generated_ids = model.generate( **model_inputs, max_new_tokens=MAX_NEW_TOKENS ) # Get and decode the output output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :] # Decode the text message return tokenizer.decode(output_ids, skip_special_tokens=True) @app.post("/v1/models/apertus") async def completion(data: Completion): """Generate an OpenAPI-style completion""" if model is None or tokenizer is None: raise HTTPException(status_code=503, detail="Model not loaded") try: mt = data.messages text = get_completion_text(mt) result = get_model_reponse(mt) # Standard formatted object return { "id": get_message_id(text), "object": "chat.completion", "created": time.time(), "model": data.model, "choices": [{ "message": ChatMessage(role="assistant", content=result) }], "usage": { "prompt_tokens": len(text), "completion_tokens": len(result), "total_tokens": len(text) + len(result) } } except Exception as e: logger.warning(e) raise HTTPException(status_code=400, detail="Could not process") from e @app.get("/predict", response_model=ModelResponse) async def predict(q: str): """Generate a model response for input text""" if model is None or tokenizer is None: raise HTTPException(status_code=503, detail="Model not loaded") try: start_time = time.time() input_data = TextInput(text=q) text = fit_to_length(input_data.text, input_data.min_length, input_data.max_length) messages_think = [ {"role": "user", "content": text} ] result = get_model_reponse(messages_think) # Checkpoint processing_time = time.time() - start_time return ModelResponse( text=result, #['label'], confidence=0, #result['score'], processing_time=processing_time ) except Exception as e: logger.warning(e) raise HTTPException(status_code=500, detail="Evaluation failed") @app.get("/health") async def health_check(): """Health check and basic configuration""" return { "status": "healthy", "model_loaded": model is not None, "gpu_available": cuda.is_available() } 
if __name__ == '__main__':
    uvicorn.run('app:app', reload=True)
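

# --- Example client (sketch) ---
# A minimal sketch of how the endpoints above can be exercised from Python once the
# server is running. The base URL and the use of the third-party `requests` package
# are assumptions; adjust them to match your deployment.
def example_client():
    import requests

    base_url = "http://127.0.0.1:8000"  # assumed default uvicorn host/port

    # Simple single-turn generation via GET /predict
    r = requests.get(f"{base_url}/predict", params={"q": "Say hello"})
    print(r.json()["text"])

    # OpenAI-style chat completion via POST /v1/models/apertus
    payload = {
        "model": "apertus",
        "messages": [{"role": "user", "content": "Say hello"}],
    }
    r = requests.post(f"{base_url}/v1/models/apertus", json=payload)
    print(r.json()["choices"][0]["message"]["content"])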