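"""FastAPI service for serving Apertus models via Hugging Face transformers."""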
from contextlib import asynccontextmanager
import logging
import time

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from torch import cuda
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model handles, populated by the lifespan hook below
model = None
tokenizer = None
class TextInput(BaseModel):
    text: str
    min_length: int = 3
    # Apertus by default supports a context length up to 65,536 tokens.
    max_length: int = 65536

class ModelResponse(BaseModel):
    text: str
    confidence: float
    processing_time: float
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and release it on shutdown"""
    global model, tokenizer
    try:
        logger.info("Loading Apertus model...")
        # TODO: make this configurable
        model_name = "swiss-ai/Apertus-8B-Instruct-2509"
        # Automatically select device based on availability
        device = "cuda" if cuda.is_available() else "cpu"
        # Load the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
        ).to(device)
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    yield
    # Release resources when the app is stopped
    model = None
    tokenizer = None
    if cuda.is_available():
        cuda.empty_cache()
# Set up our app
app = FastAPI(
    title="Apertus API",
    description="REST API for serving Apertus models via Hugging Face transformers",
    version="0.1.0",
    lifespan=lifespan,
)
@app.post("/predict")
async def predict(q: str):
    """Generate a model response for input text"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        start_time = time.time()
        input_data = TextInput(text=q)
        # Truncate text if too long (a character-level guard; the limit is
        # nominally in tokens, so this is a conservative cut-off)
        text = input_data.text[: input_data.max_length]
        if len(input_data.text) > input_data.max_length:
            logger.warning("Input text truncated")
        if len(text) < input_data.min_length:
            raise HTTPException(status_code=400, detail="Input text too short")
        # Prepare the model input
        messages_think = [
            {"role": "user", "content": text}
        ]
        text = tokenizer.apply_chat_template(
            messages_think,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        # Generate the output
        generated_ids = model.generate(**model_inputs, max_new_tokens=32768)
        # Strip the prompt tokens and decode only the newly generated ones
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        result = tokenizer.decode(output_ids, skip_special_tokens=True)
        processing_time = time.time() - start_time
        return ModelResponse(
            text=result,
            # generate() does not return a score; report a fixed placeholder
            # to keep the response schema intact
            confidence=1.0,
            processing_time=processing_time,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Evaluation error: {e}")
        raise HTTPException(status_code=500, detail="Evaluation failed")
@app.get("/health")
async def health_check():
    """Health check and basic configuration"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": cuda.is_available(),
    }
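
# A minimal local entry point, assuming this file is saved as app.py and
# uvicorn is installed; alternatively run `uvicorn app:app` directly.
# Example requests once the server is up:
#   curl -X POST "http://localhost:8000/predict?q=Hello"
#   curl "http://localhost:8000/health"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)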