from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, List
import time
from datetime import datetime, timezone
import os
import warnings
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager
import uvicorn
from dotenv import load_dotenv
import shutil
import joblib
from pathlib import Path
from transformers import BertTokenizer, BertModel
from utils.model_classes import MHSA_GRU, MultiHeadSelfAttention

load_dotenv()
warnings.filterwarnings('ignore')

# ========================= CONFIGURATION =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
API_VERSION = "1.0.0"
MODEL_VERSION = "MHSA-GRU-Transformer-v1.0"

# Model repository configuration
MODEL_REPO = {
    "repo_id": "camlas/toxicity",
    "files": {
        "classifier": "mhsa_gru_classifier.pth",
        "scaler": "scaler.pkl",
        "config": "config.json",
        "model_weights": "model.safetensors",
        "vocab": "vocab.txt",
        "tokenizer_config": "tokenizer_config.json",
        "special_tokens_map": "special_tokens_map.json"
    }
}

# Global model variables
classifier = None
scaler = None
transformer_model = None
transformer_tokenizer = None
EMBEDDING_TYPE = "Bert"
MODEL_NAME = "ProtBERT"

# ========================= PYDANTIC MODELS =========================
class SequenceRequest(BaseModel):
    sequence: str

class BatchSequenceRequest(BaseModel):
    sequences: List[str]

class PredictionResponse(BaseModel):
    status_code: int
    status: str
    success: bool
    data: Optional[dict] = None
    error: Optional[str] = None
    error_code: Optional[str] = None
    timestamp: str
    api_version: str
    processing_time_ms: float

class HealthResponse(BaseModel):
    status_code: int
    status: str
    service: str
    api_version: str
    model_version: str
    models_loaded: bool
    models_loaded_count: int
    total_models_required: int
    model_sources: dict
    repository_info: dict
    device: str
    timestamp: str

# ========================= HELPER FUNCTIONS =========================
def create_kmers(sequence, k=6):
    """Convert a DNA sequence to k-mer tokens (for DNABERT).

    Note: not used by the ProtBERT pipeline below; kept for DNA-based variants.
    """
    kmers = []
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i+k]
        kmers.append(kmer)
    return ' '.join(kmers)

def ensure_models_directory():
    models_dir = "models"
    if not os.path.exists(models_dir):
        os.makedirs(models_dir)
        print(f"āœ… Created {models_dir} directory")
    return models_dir

def download_model_from_hub(model_name: str) -> Optional[str]:
    """Download individual model files from HuggingFace Hub"""
    try:
        if model_name not in MODEL_REPO["files"]:
            raise ValueError(f"Unknown model: {model_name}")

        filename = MODEL_REPO["files"][model_name]
        repo_id = MODEL_REPO["repo_id"]
        models_dir = ensure_models_directory()
        local_path = os.path.join(models_dir, filename)

        if os.path.exists(local_path):
            print(f"āœ… Found {model_name} in local models directory: {local_path}")
            return local_path

        print(f"šŸ“„ Downloading {model_name} ({filename}) from {repo_id}...")
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
        if not token:
            print("āš ļø Warning: No HF token found. This may fail for private repositories.")

        temp_model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="model",
            token=token
        )
        shutil.copy2(temp_model_path, local_path)
        print(f"āœ… {model_name} downloaded and stored!")
        return local_path
    except Exception as e:
        print(f"āŒ Error downloading {model_name}: {e}")
        return None
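
# For reference, after a successful first start the local cache populated by
# download_model_from_hub should contain one file per entry in
# MODEL_REPO["files"]:
#
#   models/
#   ā”œā”€ā”€ mhsa_gru_classifier.pth
#   ā”œā”€ā”€ scaler.pkl
#   ā”œā”€ā”€ config.json
#   ā”œā”€ā”€ model.safetensors
#   ā”œā”€ā”€ vocab.txt
#   ā”œā”€ā”€ tokenizer_config.json
#   └── special_tokens_map.json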

def extract_features_from_sequence(sequence: str):
    """Extract features from a sequence using ProtBERT"""
    global transformer_model, transformer_tokenizer

    if transformer_model is None or transformer_tokenizer is None:
        raise ValueError("ProtBERT model not loaded")

    # ProtBERT expects sequences with spaces between amino acids:
    # "MKTAYIAKQR" becomes "M K T A Y I A K Q R"
    processed_seq = ' '.join(list(sequence.upper()))

    # Tokenize
    inputs = transformer_tokenizer(
        processed_seq,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Extract features
    with torch.no_grad():
        outputs = transformer_model(**inputs)

    # Use the [CLS] token embedding
    cls_embeddings = outputs.last_hidden_state[:, 0, :]
    return cls_embeddings.cpu().numpy()
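
# A minimal sketch of the inference path this helper feeds into (see the
# /predict endpoint below); the sequence is an arbitrary example and the
# models must already be loaded:
#
#   features = extract_features_from_sequence("MKTAYIAKQRQISFVKSHFSRQLE")
#   # ProtBERT's [CLS] embedding has 1024 dimensions, so features.shape == (1, 1024)
#   scaled = scaler.transform(features)
#   with torch.no_grad():
#       probability = classifier(torch.FloatTensor(scaled).to(device)).cpu().numpy()[0, 0]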

def load_all_models():
    """Load all models from HuggingFace Hub"""
    global classifier, scaler, transformer_model, transformer_tokenizer

    models_dir = ensure_models_directory()
    models_loaded = {
        "classifier": False,
        "scaler": False,
        "transformer_model": False,
        "transformer_tokenizer": False
    }

    print(f"šŸš€ Loading models from {MODEL_REPO['repo_id']}...")
    print("=" * 60)

    try:
        # Download all necessary files
        print("šŸ“„ Downloading ProtBERT model files...")
        files_to_download = ["config", "model_weights", "vocab",
                             "tokenizer_config", "special_tokens_map"]
        for file_key in files_to_download:
            download_model_from_hub(file_key)

        # Load ProtBERT tokenizer
        print("šŸ”„ Loading ProtBERT tokenizer...")
        try:
            transformer_tokenizer = BertTokenizer.from_pretrained(
                models_dir,
                do_lower_case=False,
                local_files_only=True
            )
            models_loaded["transformer_tokenizer"] = True
            print("āœ… ProtBERT tokenizer loaded!")
        except Exception as e:
            print(f"āŒ Error loading tokenizer: {e}")
            # Fall back to loading the tokenizer directly from HuggingFace
            print("šŸ”„ Trying to load tokenizer directly from HuggingFace...")
            token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
            transformer_tokenizer = BertTokenizer.from_pretrained(
                MODEL_REPO["repo_id"],
                do_lower_case=False,
                token=token
            )
            models_loaded["transformer_tokenizer"] = True
            print("āœ… ProtBERT tokenizer loaded from HuggingFace!")

        # Load ProtBERT model
        print("šŸ”„ Loading ProtBERT model...")
        try:
            transformer_model = BertModel.from_pretrained(
                models_dir,
                local_files_only=True
            )
            models_loaded["transformer_model"] = True
            print("āœ… ProtBERT model loaded!")
        except Exception as e:
            print(f"āŒ Error loading model: {e}")
            # Fall back to loading the model directly from HuggingFace
            print("šŸ”„ Trying to load model directly from HuggingFace...")
            token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
            transformer_model = BertModel.from_pretrained(
                MODEL_REPO["repo_id"],
                token=token
            )
            models_loaded["transformer_model"] = True
            print("āœ… ProtBERT model loaded from HuggingFace!")

        transformer_model.to(device)
        transformer_model.eval()

        # Load classifier
        print("šŸ”„ Loading classifier (MHSA-GRU)...")
        clf_path = os.path.join(models_dir, MODEL_REPO["files"]["classifier"])
        if not os.path.exists(clf_path):
            print("šŸ“„ Classifier not found locally, downloading...")
            clf_path = download_model_from_hub("classifier")

        if clf_path and os.path.exists(clf_path):
            checkpoint = torch.load(clf_path, map_location=device, weights_only=False)

            # Handle different checkpoint formats
            if 'input_dim' in checkpoint:
                input_dim = checkpoint['input_dim']
            else:
                # ProtBERT embedding size is 1024
                input_dim = 1024

            classifier = MHSA_GRU(input_dim, hidden_dim=256)

            # Load state dict
            if 'model_state_dict' in checkpoint:
                classifier.load_state_dict(checkpoint['model_state_dict'])
            else:
                classifier.load_state_dict(checkpoint)

            classifier.to(device)
            classifier.eval()
            models_loaded["classifier"] = True
            print(f"āœ… Classifier loaded! (input_dim: {input_dim})")

        # Load scaler
        print("šŸ”„ Loading feature scaler...")
        scaler_path = os.path.join(models_dir, MODEL_REPO["files"]["scaler"])
        if not os.path.exists(scaler_path):
            print("šŸ“„ Scaler not found locally, downloading...")
            scaler_path = download_model_from_hub("scaler")

        if scaler_path and os.path.exists(scaler_path):
            scaler = joblib.load(scaler_path)
            models_loaded["scaler"] = True
            print("āœ… Scaler loaded!")

        loaded_count = sum(models_loaded.values())
        total_count = len(models_loaded)
        print(f"\nšŸ“Š Model Loading Summary:")
        print(f"   • Successfully loaded: {loaded_count}/{total_count}")
        print(f"   • Repository: {MODEL_REPO['repo_id']}")
        print(f"   • Embedding Model: {MODEL_NAME}")
        print(f"   • Device: {device}")

        critical_models = ["classifier", "scaler", "transformer_model", "transformer_tokenizer"]
        critical_loaded = all(models_loaded[m] for m in critical_models)

        if critical_loaded:
            print("šŸŽ‰ All critical models loaded successfully!")
            return True
        else:
            print("āš ļø Some critical models failed to load")
            print(f"   Models status: {models_loaded}")
            return False
    except Exception as e:
        print(f"āŒ Error loading models: {e}")
        import traceback
        traceback.print_exc()
        return False

# ========================= FASTAPI APPLICATION =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    print("šŸš€ Starting Toxicity Prediction API...")
    success = load_all_models()
    if not success:
        print("āš ļø Warning: Not all models loaded successfully")
    yield
    # Shutdown
    print("šŸ”„ Shutting down API...")

app = FastAPI(
    title="Toxicity Prediction API",
    description="API for toxicity prediction using MHSA-GRU with Transformer embeddings",
    version="1.0.0",
    lifespan=lifespan
)

@app.get("/")
async def root():
    return {
        "message": "Toxicity Prediction API",
        "version": API_VERSION,
        "endpoints": {
            "/predict": "POST - Predict toxicity for a single sequence",
            "/predict/batch": "POST - Predict toxicity for multiple sequences",
            "/example": "GET - Try the API with a hardcoded example sequence",
            "/health": "GET - Check API health and model status"
        },
        "example_usage": {
            "single": {
                "method": "POST",
                "url": "/predict",
                "body": {"sequence": "MKTAYIAKQRQISFVKSHFSRQLE"}
            },
            "batch": {
                "method": "POST",
                "url": "/predict/batch",
                "body": {
                    "sequences": [
                        "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
                        "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK"
                    ]
                }
            },
            "example": {
                "method": "GET",
                "url": "/example",
                "description": "No input needed - just call this endpoint"
            }
        }
    }
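
# Example request against a locally running server (host/port taken from the
# __main__ block at the bottom of this file):
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"sequence": "MKTAYIAKQRQISFVKSHFSRQLE"}'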

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: SequenceRequest):
    start_time = time.time()
    timestamp = datetime.now(timezone.utc).isoformat()

    try:
        if not request.sequence or len(request.sequence) == 0:
            raise HTTPException(
                status_code=400,
                detail={
                    "status_code": 400,
                    "status": "error",
                    "success": False,
                    "error": "No sequence provided",
                    "error_code": "MISSING_SEQUENCE",
                    "timestamp": timestamp,
                    "api_version": API_VERSION,
                    "processing_time_ms": round((time.time() - start_time) * 1000, 2)
                }
            )

        # Check if models are loaded
        if classifier is None or scaler is None or transformer_model is None:
            raise HTTPException(
                status_code=503,
                detail={
                    "status_code": 503,
                    "status": "error",
                    "success": False,
                    "error": "Models not loaded properly",
                    "error_code": "MODEL_NOT_LOADED",
                    "timestamp": timestamp,
                    "api_version": API_VERSION,
                    "processing_time_ms": round((time.time() - start_time) * 1000, 2)
                }
            )

        # Validate sequence
        sequence = request.sequence.upper().strip()
        if len(sequence) < 10:
            raise HTTPException(
                status_code=400,
                detail={
                    "status_code": 400,
                    "status": "error",
                    "success": False,
                    "error": "Sequence too short (minimum 10 characters)",
                    "error_code": "SEQUENCE_TOO_SHORT",
                    "timestamp": timestamp,
                    "api_version": API_VERSION,
                    "processing_time_ms": round((time.time() - start_time) * 1000, 2)
                }
            )

        # Step 1: Extract features using ProtBERT
        features = extract_features_from_sequence(sequence)

        # Step 2: Scale features
        scaled_features = scaler.transform(features)

        # Step 3: Predict using MHSA-GRU
        features_tensor = torch.FloatTensor(scaled_features).to(device)
        with torch.no_grad():
            probability = classifier(features_tensor).cpu().numpy()[0, 0]

        # Determine prediction
        prediction_class = 1 if probability > 0.5 else 0
        predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic"
        confidence = float(abs(probability - 0.5) * 2)

        # Determine confidence level
        if confidence > 0.8:
            confidence_level = "high"
        elif confidence > 0.6:
            confidence_level = "medium"
        else:
            confidence_level = "low"

        processing_time = round((time.time() - start_time) * 1000, 2)

        return PredictionResponse(
            status_code=200,
            status="success",
            success=True,
            data={
                "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
                "sequence_length": len(sequence),
                "prediction": {
                    "predicted_class": predicted_label,
                    "confidence": confidence,
                    "confidence_level": confidence_level,
                    "toxicity_score": float(probability),
                    "non_toxicity_score": float(1 - probability)
                },
                "metadata": {
                    "embedding_model": MODEL_NAME,
                    "embedding_type": EMBEDDING_TYPE,
                    "model_version": MODEL_VERSION,
                    "device": str(device)
                }
            },
            timestamp=timestamp,
            api_version=API_VERSION,
            processing_time_ms=processing_time
        )
    except HTTPException:
        raise
    except Exception as e:
        processing_time = round((time.time() - start_time) * 1000, 2)
        raise HTTPException(
            status_code=500,
            detail={
                "status_code": 500,
                "status": "error",
                "success": False,
                "error": f"Internal server error: {str(e)}",
                "error_code": "INTERNAL_ERROR",
                "timestamp": timestamp,
                "api_version": API_VERSION,
                "processing_time_ms": processing_time
            }
        )
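
# Worked example of the confidence mapping used above: a raw probability of
# 0.92 gives confidence = |0.92 - 0.5| * 2 = 0.84, which lands in the "high"
# band (> 0.8); a probability of 0.55 gives 0.10, i.e. "low". A probability
# near 0.5 therefore always reads as low confidence, regardless of the
# predicted class.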

@app.post("/predict/batch", response_model=PredictionResponse)
async def predict_batch(request: BatchSequenceRequest):
    """
    Predict toxicity for multiple sequences at once.

    Example request body:
    {
        "sequences": [
            "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
            "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK"
        ]
    }
    """
    start_time = time.time()
    timestamp = datetime.now(timezone.utc).isoformat()

    try:
        if not request.sequences or len(request.sequences) == 0:
            raise HTTPException(
                status_code=400,
                detail={
                    "status_code": 400,
                    "status": "error",
                    "success": False,
                    "error": "No sequences provided",
                    "error_code": "MISSING_SEQUENCES",
                    "timestamp": timestamp,
                    "api_version": API_VERSION,
                    "processing_time_ms": round((time.time() - start_time) * 1000, 2)
                }
            )

        # Check if models are loaded
        if classifier is None or scaler is None or transformer_model is None:
            raise HTTPException(
                status_code=503,
                detail={
                    "status_code": 503,
                    "status": "error",
                    "success": False,
                    "error": "Models not loaded properly",
                    "error_code": "MODEL_NOT_LOADED",
                    "timestamp": timestamp,
                    "api_version": API_VERSION,
                    "processing_time_ms": round((time.time() - start_time) * 1000, 2)
                }
            )

        results = []
        for idx, seq in enumerate(request.sequences, 1):
            try:
                sequence = seq.upper().strip()

                # Validate sequence length
                if len(sequence) < 10:
                    results.append({
                        "sequence_index": idx,
                        "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
                        "sequence_length": len(sequence),
                        "error": "Sequence too short (minimum 10 characters)",
                        "predicted_class": None,
                        "toxicity_score": None,
                        "confidence": None
                    })
                    continue

                # Extract features using ProtBERT
                features = extract_features_from_sequence(sequence)
                scaled_features = scaler.transform(features)
                features_tensor = torch.FloatTensor(scaled_features).to(device)
                with torch.no_grad():
                    probability = classifier(features_tensor).cpu().numpy()[0, 0]

                prediction_class = 1 if probability > 0.5 else 0
                predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic"
                confidence = float(abs(probability - 0.5) * 2)

                # Determine confidence level
                if confidence > 0.8:
                    confidence_level = "high"
                elif confidence > 0.6:
                    confidence_level = "medium"
                else:
                    confidence_level = "low"

                results.append({
                    "sequence_index": idx,
                    "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
                    "sequence_length": len(sequence),
                    "predicted_class": predicted_label,
                    "toxicity_score": float(probability),
                    "non_toxicity_score": float(1 - probability),
                    "confidence": confidence,
                    "confidence_level": confidence_level,
                    "error": None
                })
            except Exception as e:
                # Handle individual sequence errors without stopping the batch
                results.append({
                    "sequence_index": idx,
                    "sequence": seq[:100] + "..." if len(seq) > 100 else seq,
                    "sequence_length": len(seq),
                    "error": f"Error processing sequence: {str(e)}",
                    "predicted_class": None,
                    "toxicity_score": None,
                    "confidence": None
                })

        processing_time = round((time.time() - start_time) * 1000, 2)

        # Count successful predictions
        successful_predictions = sum(1 for r in results if r.get("predicted_class") is not None)

        return PredictionResponse(
            status_code=200,
            status="success",
            success=True,
            data={
                "total_sequences": len(request.sequences),
                "successful_predictions": successful_predictions,
                "failed_predictions": len(request.sequences) - successful_predictions,
                "results": results,
                "metadata": {
                    "embedding_model": MODEL_NAME,
                    "embedding_type": EMBEDDING_TYPE,
                    "model_version": MODEL_VERSION,
                    "device": str(device)
                }
            },
            timestamp=timestamp,
            api_version=API_VERSION,
            processing_time_ms=processing_time
        )
    except HTTPException:
        raise
    except Exception as e:
        processing_time = round((time.time() - start_time) * 1000, 2)
        raise HTTPException(
            status_code=500,
            detail={
                "status_code": 500,
                "status": "error",
                "success": False,
                "error": f"Internal server error: {str(e)}",
                "error_code": "INTERNAL_ERROR",
                "timestamp": timestamp,
                "api_version": API_VERSION,
                "processing_time_ms": processing_time
            }
        )
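
# Example batch request, using the same sequences as the root endpoint's
# example_usage:
#
#   curl -X POST http://localhost:8000/predict/batch \
#        -H "Content-Type: application/json" \
#        -d '{"sequences": ["MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES", "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK"]}'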

@app.get("/example", response_model=PredictionResponse)
async def predict_example():
    """
    Predict using a hardcoded example protein sequence.

    No input required - just call this endpoint to see how the API works.

    Example sequence: MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES
    """
    start_time = time.time()
    timestamp = datetime.now(timezone.utc).isoformat()

    # Hardcoded example sequence
    EXAMPLE_SEQUENCE = "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES"

    try:
        # Check if models are loaded
        if classifier is None or scaler is None or transformer_model is None:
            raise HTTPException(
                status_code=503,
                detail={
                    "status_code": 503,
                    "status": "error",
                    "success": False,
                    "error": "Models not loaded properly",
                    "error_code": "MODEL_NOT_LOADED",
                    "timestamp": timestamp,
                    "api_version": API_VERSION,
                    "processing_time_ms": round((time.time() - start_time) * 1000, 2)
                }
            )

        sequence = EXAMPLE_SEQUENCE.upper().strip()

        # Step 1: Extract features using ProtBERT
        features = extract_features_from_sequence(sequence)

        # Step 2: Scale features
        scaled_features = scaler.transform(features)

        # Step 3: Predict using MHSA-GRU
        features_tensor = torch.FloatTensor(scaled_features).to(device)
        with torch.no_grad():
            probability = classifier(features_tensor).cpu().numpy()[0, 0]

        # Determine prediction
        prediction_class = 1 if probability > 0.5 else 0
        predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic"
        confidence = float(abs(probability - 0.5) * 2)

        # Determine confidence level
        if confidence > 0.8:
            confidence_level = "high"
        elif confidence > 0.6:
            confidence_level = "medium"
        else:
            confidence_level = "low"

        processing_time = round((time.time() - start_time) * 1000, 2)

        return PredictionResponse(
            status_code=200,
            status="success",
            success=True,
            data={
                "note": "This is an example prediction using a hardcoded sequence",
                "sequence": sequence,
                "sequence_length": len(sequence),
                "prediction": {
                    "predicted_class": predicted_label,
                    "confidence": confidence,
                    "confidence_level": confidence_level,
                    "toxicity_score": float(probability),
                    "non_toxicity_score": float(1 - probability)
                },
                "metadata": {
                    "embedding_model": MODEL_NAME,
                    "embedding_type": EMBEDDING_TYPE,
                    "model_version": MODEL_VERSION,
                    "device": str(device),
                    "source": "hardcoded_example"
                }
            },
            timestamp=timestamp,
            api_version=API_VERSION,
            processing_time_ms=processing_time
        )
    except HTTPException:
        raise
    except Exception as e:
        processing_time = round((time.time() - start_time) * 1000, 2)
        raise HTTPException(
            status_code=500,
            detail={
                "status_code": 500,
                "status": "error",
                "success": False,
                "error": f"Internal server error: {str(e)}",
                "error_code": "INTERNAL_ERROR",
                "timestamp": timestamp,
                "api_version": API_VERSION,
                "processing_time_ms": processing_time
            }
        )
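
# Quick smoke test once the server is up:
#
#   curl http://localhost:8000/example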

@app.get("/health", response_model=HealthResponse)
async def health_check():
    models_loaded = all([
        classifier is not None,
        scaler is not None,
        transformer_model is not None,
        transformer_tokenizer is not None
    ])

    model_sources = {
        "classifier": {
            "loaded": classifier is not None,
            "source": "huggingface_hub",
            "repository": MODEL_REPO["repo_id"]
        },
        "scaler": {
            "loaded": scaler is not None,
            "source": "huggingface_hub",
            "repository": MODEL_REPO["repo_id"]
        },
        "transformer_model": {
            "loaded": transformer_model is not None,
            "model_name": MODEL_NAME,
            "source": "huggingface_hub",
            "repository": MODEL_REPO["repo_id"]
        },
        # Report the tokenizer too, so models_loaded_count and
        # total_models_required agree with the four components checked above.
        "transformer_tokenizer": {
            "loaded": transformer_tokenizer is not None,
            "model_name": MODEL_NAME,
            "source": "huggingface_hub",
            "repository": MODEL_REPO["repo_id"]
        }
    }

    repository_info = {
        "repository_id": MODEL_REPO["repo_id"],
        "embedding_type": EMBEDDING_TYPE,
        "model_name": MODEL_NAME,
        "total_models": len(MODEL_REPO["files"])
    }

    return HealthResponse(
        status_code=200 if models_loaded else 503,
        status="healthy" if models_loaded else "unhealthy",
        service="Toxicity Prediction API",
        api_version=API_VERSION,
        model_version=MODEL_VERSION,
        models_loaded=models_loaded,
        models_loaded_count=sum(1 for source in model_sources.values() if source["loaded"]),
        total_models_required=4,
        model_sources=model_sources,
        repository_info=repository_info,
        device=str(device),
        timestamp=datetime.now(timezone.utc).isoformat()
    )

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
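
# During development, an alternative to the __main__ block above is uvicorn's
# CLI with auto-reload (assuming this file is saved as main.py; adjust the
# module name otherwise):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000 --reload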