import logging
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from torch import cuda
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

model = None
tokenizer = None


class TextInput(BaseModel):
    text: str
    min_length: int = 3
    # Apertus by default supports a context length of up to 65,536 tokens.
    max_length: int = 65536


class ModelResponse(BaseModel):
    text: str
    confidence: float
    processing_time: float


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transformer model on startup and release it on shutdown."""
    global model, tokenizer
    try:
        logger.info("Loading Apertus model...")
        # TODO: make this configurable
        model_name = "swiss-ai/Apertus-8B-Instruct-2509"

        # Automatically select the device based on availability
        device = "cuda" if cuda.is_available() else "cpu"

        # Load the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    yield

    # Release resources when the app is stopped
    model = None
    tokenizer = None
    if cuda.is_available():
        cuda.empty_cache()


# Set up our app
app = FastAPI(
    title="Apertus API",
    description="REST API for serving Apertus models via Hugging Face transformers",
    version="0.1.0",
    lifespan=lifespan,
)


@app.get("/predict", response_model=ModelResponse)
async def predict(q: str):
    """Generate a model response for input text."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        start_time = time.time()
        input_data = TextInput(text=q)

        # Truncate the text if it is too long. Note that max_length counts
        # characters here, only a rough proxy for the model's token limit.
        text = input_data.text[: input_data.max_length]
        if len(text) == input_data.max_length:
            logger.warning("Input text truncated to %d characters", input_data.max_length)
        if len(text) < input_data.min_length:
            raise HTTPException(status_code=400, detail="Input text too short")

        # Prepare the model input using the chat template
        messages = [{"role": "user", "content": text}]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

        # Generate the output
        generated_ids = model.generate(**model_inputs, max_new_tokens=32768)

        # Strip the prompt tokens and decode only the newly generated ones
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        result = tokenizer.decode(output_ids, skip_special_tokens=True)

        processing_time = time.time() - start_time

        return ModelResponse(
            text=result,
            # generate() does not produce a confidence score, so report a placeholder
            confidence=1.0,
            processing_time=processing_time,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Evaluation error: {e}")
        raise HTTPException(status_code=500, detail="Evaluation failed")


@app.get("/health")
async def health_check():
    """Health check and basic configuration."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": cuda.is_available(),
    }
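
# ---------------------------------------------------------------------------
# Usage sketch (assumptions: this file is saved as main.py and uvicorn is
# installed; the "main:app" module path follows from that filename, which is
# not specified in the source):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl "http://localhost:8000/health"
#   curl "http://localhost:8000/predict?q=What%20is%20the%20capital%20of%20Switzerland%3F"
# ---------------------------------------------------------------------------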