from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import os
import time
import joblib
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional
from contextlib import asynccontextmanager
from dotenv import load_dotenv
import shutil
from huggingface_hub import hf_hub_download

# Transformers imports specifically for ProtBERT
from transformers import BertTokenizer, BertModel

# Import your custom model structure
from utils.model_classes import MHSA_GRU

load_dotenv()

# ========================= CONFIGURATION =========================

# Repository details (where your trained classifier/scaler live)
MODEL_REPO = {
    "repo_id": "camlas/toxicity",
    "files": {
        "classifier": "mhsa_gru_classifier.pth",
        "scaler": "scaler.pkl"
    }
}

# Feature extraction config - updated for ProtBERT
TRANSFORMER_CONFIG = {
    "model_name": "Rostlab/prot_bert",
    "model_type": "ProtBERT",
    "tokenizer_class": BertTokenizer,
    "model_class": BertModel
}

CLASSES = ["Non-Toxic", "Toxic"]
API_VERSION = "2.0.0-protbert"
MODEL_VERSION = "ProtBERT-MHSA-GRU-v1"

# Global registry holding the loaded models
models = {
    "transformer": None,
    "tokenizer": None,
    "classifier": None,
    "scaler": None
}

# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========================= HELPER FUNCTIONS =========================


def ensure_models_directory():
    models_dir = "models"
    Path(models_dir).mkdir(exist_ok=True)
    return models_dir


def download_model_from_hub(model_key: str) -> Optional[str]:
    """Download custom trained models (classifier/scaler) from the private HF repo."""
    try:
        filename = MODEL_REPO["files"][model_key]
        repo_id = MODEL_REPO["repo_id"]
        models_dir = ensure_models_directory()
        local_path = os.path.join(models_dir, filename)

        # If the file exists locally, use it
        if os.path.exists(local_path):
            print(f"✅ Found {model_key} locally: {local_path}")
            return local_path

        print(f"đŸ“Ĩ Downloading {model_key} from {repo_id}...")
        token = os.getenv("HF_TOKEN")
        if not token:
            print("âš ī¸ Warning: HF_TOKEN not found in .env. Private repos will fail.")

        temp_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="model",
            token=token
        )
        shutil.copy2(temp_path, local_path)
        return local_path
    except Exception as e:
        print(f"❌ Error downloading {model_key}: {e}")
        return None


def load_feature_extractor():
    """Load the ProtBERT model from Hugging Face."""
    print(f"🔄 Loading Transformer: {TRANSFORMER_CONFIG['model_name']}...")
    try:
        # ProtBERT uses an uppercase amino-acid vocabulary, so keep do_lower_case=False
        tokenizer = TRANSFORMER_CONFIG["tokenizer_class"].from_pretrained(
            TRANSFORMER_CONFIG["model_name"],
            do_lower_case=False
        )
        model = TRANSFORMER_CONFIG["model_class"].from_pretrained(
            TRANSFORMER_CONFIG["model_name"]
        )
        model.to(device)
        model.eval()

        models["tokenizer"] = tokenizer
        models["transformer"] = model
        print("✅ ProtBERT Transformer loaded successfully")
        return True
    except Exception as e:
        print(f"❌ Error loading Transformer: {e}")
        return False


def load_classifier_and_scaler():
    """Load the custom MHSA-GRU classifier and the feature scaler."""
    try:
        # 1. Load scaler
        scaler_path = download_model_from_hub("scaler")
        if scaler_path:
            models["scaler"] = joblib.load(scaler_path)
            print("✅ Scaler loaded")

        # 2. Load classifier
        clf_path = download_model_from_hub("classifier")
        if clf_path:
            # ProtBERT output dimension is 1024
            input_dim = 1024
            print(f"â„šī¸ Initializing MHSA_GRU with input_dim={input_dim} (ProtBERT)")

            classifier = MHSA_GRU(
                input_dim=input_dim,
                hidden_dim=256,  # Matching your training code
                num_heads=8,
                num_gru_layers=2,
                dropout=0.3
            )

            state_dict = torch.load(clf_path, map_location=device)
            classifier.load_state_dict(state_dict)
            classifier.to(device)
            classifier.eval()

            models["classifier"] = classifier
            print("✅ Classifier loaded")

        return models["scaler"] is not None and models["classifier"] is not None
    except Exception as e:
        print(f"❌ Error loading custom models: {e}")
        return False
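
# For reference, a minimal sketch of the interface this module assumes from
# utils.model_classes.MHSA_GRU. It is inferred from the constructor call above
# and from how the classifier is used in /predict; the real implementation in
# utils/model_classes.py is authoritative and may differ internally:
#
#     class MHSA_GRU(nn.Module):
#         def __init__(self, input_dim, hidden_dim, num_heads,
#                      num_gru_layers, dropout): ...
#
#         def forward(self, x):
#             # x: (batch_size, input_dim) float tensor of scaled features
#             # returns sigmoid probabilities in [0, 1], one per sample
#             ...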


def preprocess_sequence(sequence: str):
    """
    Preprocess a sequence for ProtBERT.
    ProtBERT expects spaces between amino acids: 'M K T A Y...'
    """
    # Strip ALL whitespace (spaces, tabs, newlines) and uppercase, so that
    # pasted or FASTA-wrapped input doesn't corrupt the residue spacing below
    sequence = "".join(sequence.split()).upper()
    # Add spaces between residues
    spaced_sequence = " ".join(list(sequence))
    return spaced_sequence


def extract_features(sequence: str):
    """Run a sequence through ProtBERT and return the [CLS] embedding."""
    tokenizer = models["tokenizer"]
    model = models["transformer"]

    processed_seq = preprocess_sequence(sequence)

    inputs = tokenizer(
        [processed_seq],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512  # ProtBERT max length
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        # Extract the [CLS] token embedding (index 0)
        # shape: (batch_size, hidden_dim) -> (1, 1024)
        features = outputs.last_hidden_state[:, 0, :]

    return features.cpu().numpy()
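
# A quick sanity check for the feature pipeline (illustrative only; run it in
# a Python shell once the models are loaded). The expected output shape
# follows from ProtBERT's 1024-dimensional hidden states:
#
#     >>> preprocess_sequence("mkt ay\n")
#     'M K T A Y'
#     >>> extract_features("MKTAY").shape
#     (1, 1024)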

# ========================= FASTAPI LIFESPAN =========================


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("🚀 Starting Toxicity Detection API (ProtBERT Edition)...")

    # Check that utils/model_classes.py exists
    if not os.path.exists("utils/model_classes.py"):
        print("❌ Error: utils/model_classes.py not found. Please create it.")

    success_tf = load_feature_extractor()
    success_custom = load_classifier_and_scaler()

    if not (success_tf and success_custom):
        print("âš ī¸ Warning: Not all models loaded successfully")

    yield
    print("🔄 Shutting down API...")


app = FastAPI(
    title="Peptide Toxicity Detection API",
    description="API using ProtBERT features + MHSA-GRU classifier",
    version=API_VERSION,
    lifespan=lifespan
)

# ========================= PYDANTIC MODELS =========================


class SequenceRequest(BaseModel):
    sequence: str


class PredictionResponse(BaseModel):
    # Allow the "model_used" field name (avoids Pydantic v2's protected
    # "model_" namespace warning; harmlessly ignored under Pydantic v1)
    model_config = {"protected_namespaces": ()}

    sequence_preview: str
    is_toxic: bool
    label: str
    score: float
    confidence_level: str
    model_used: str
    processing_time_ms: float
    timestamp: str

# ========================= ENDPOINTS =========================


@app.get("/")
async def root():
    return {"message": "Toxicity Detection API is running. Use /predict to analyze sequences."}


@app.get("/health")
async def health_check():
    loaded = all(v is not None for v in models.values())
    return {
        "status": "healthy" if loaded else "degraded",
        "models_loaded": {k: v is not None for k, v in models.items()},
        "device": str(device),
        "model_version": MODEL_VERSION,
        "feature_extractor": TRANSFORMER_CONFIG["model_name"]
    }


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: SequenceRequest):
    start_time = time.time()

    if not all(models.values()):
        raise HTTPException(status_code=503, detail="Models are not fully initialized.")

    # Reject empty or whitespace-only input
    if not request.sequence.strip():
        raise HTTPException(status_code=400, detail="Empty sequence provided.")

    try:
        # 1. Extract features (ProtBERT [CLS] token)
        # This handles the 'M K T' spacing internally
        raw_features = extract_features(request.sequence)

        # 2. Scale features
        # Use the scaler loaded from your repo
        scaled_features = models["scaler"].transform(raw_features)

        # 3. Predict (MHSA-GRU)
        features_tensor = torch.FloatTensor(scaled_features).to(device)

        with torch.no_grad():
            # Get probability (sigmoid output)
            probability = models["classifier"](features_tensor).item()

        # 4. Interpret results (decision threshold 0.5)
        prediction_class = 1 if probability > 0.5 else 0
        predicted_label = CLASSES[prediction_class]

        # Confidence: distance from the decision boundary, rescaled to [0, 1]
        confidence_score = abs(probability - 0.5) * 2
        if confidence_score > 0.8:
            confidence_level = "High"
        elif confidence_score > 0.5:
            confidence_level = "Medium"
        else:
            confidence_level = "Low"

        processing_time = round((time.time() - start_time) * 1000, 2)

        return PredictionResponse(
            sequence_preview=(request.sequence[:20] + "...") if len(request.sequence) > 20 else request.sequence,
            is_toxic=(prediction_class == 1),
            label=predicted_label,
            score=probability,
            confidence_level=confidence_level,
            model_used="ProtBERT + MHSA-GRU",
            processing_time_ms=processing_time,
            timestamp=datetime.now(timezone.utc).isoformat()
        )
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
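
# Example request (a minimal sketch; assumes the server is running locally on
# port 8000, and the peptide below is an arbitrary illustrative sequence):
#
#     curl -X POST http://localhost:8000/predict \
#          -H "Content-Type: application/json" \
#          -d '{"sequence": "MKTAYIAKQRQISFVKSHFSRQ"}'
#
# A successful response is a JSON object matching PredictionResponse, with
# fields such as "label" ("Toxic"/"Non-Toxic"), "score" (the sigmoid
# probability), and "confidence_level" ("High"/"Medium"/"Low").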