# app.py — EN→BN MT API

import torch
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware  # optional
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# -------------------------
# Device + model name
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mt_pretrained_model_name = "shhossain/opus-mt-en-to-bn"

# -------------------------
# Load tokenizer/model with a clear error if it fails
# -------------------------
try:
    tokenizer = AutoTokenizer.from_pretrained(mt_pretrained_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(mt_pretrained_model_name).to(device)
    model.eval()  # inference only; disables dropout etc.
except Exception as e:
    raise RuntimeError(
        f"Failed to load model/tokenizer '{mt_pretrained_model_name}': {e}"
    )

# -------------------------
# FastAPI app + (optional) CORS
# -------------------------
app = FastAPI(title="EN→BN MT API", version="1.0.0")

# If you'll call the API from a browser (localhost dev or a web app), enable CORS:
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # replace with your domain(s) in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -------------------------
# Request/response schemas
# -------------------------
class TranslateIn(BaseModel):
    text: str = Field(..., description="English sentence")
    max_new_tokens: int = Field(128, ge=1, le=512)
    num_beams: int = Field(4, ge=1, le=10)
    do_sample: bool = Field(False, description="Use sampling instead of pure beam search")
    temperature: Optional[float] = Field(1.0, ge=0.1, le=5.0)
    top_p: Optional[float] = Field(1.0, ge=0.1, le=1.0)


class TranslateOut(BaseModel):
    translation: str


class BatchTranslateIn(BaseModel):
    texts: List[str] = Field(..., description="List of English sentences")
    max_new_tokens: int = Field(128, ge=1, le=512)
    num_beams: int = Field(4, ge=1, le=10)
    do_sample: bool = Field(False)
    temperature: Optional[float] = Field(1.0, ge=0.1, le=5.0)
    top_p: Optional[float] = Field(1.0, ge=0.1, le=1.0)


class BatchTranslateOut(BaseModel):
    translations: List[str]


MAX_INPUT_CHARS = 2000


def generate_translation(
    inputs: List[str],
    max_new_tokens: int,
    num_beams: int,
    do_sample: bool,
    temperature: Optional[float],
    top_p: Optional[float],
) -> List[str]:
    """Tokenize, generate, and decode a batch of English sentences."""
    # Input length guard: reject oversized requests before tokenization.
    for s in inputs:
        if len(s) > MAX_INPUT_CHARS:
            raise ValueError(f"Input too long (> {MAX_INPUT_CHARS} chars).")

    batch = tokenizer(
        inputs, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
    }
    # temperature/top_p only take effect when sampling, so pass them
    # conditionally to avoid needless generation warnings.
    if do_sample:
        if temperature is not None:
            gen_kwargs["temperature"] = float(temperature)
        if top_p is not None:
            gen_kwargs["top_p"] = float(top_p)

    with torch.no_grad():
        outputs = model.generate(**batch, **gen_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


@app.get("/greet")
def greet():
    return {
        "message": "Welcome to EN→BN MT API",
        "device": str(device),
        "model": mt_pretrained_model_name,
    }


@app.post("/translate", response_model=TranslateOut)
def translate(payload: TranslateIn):
    try:
        out = generate_translation(
            [payload.text],
            payload.max_new_tokens,
            payload.num_beams,
            payload.do_sample,
            payload.temperature,
            payload.top_p,
        )[0]
        return {"translation": out}
    except ValueError as e:
        # Bad input (e.g. over the length limit) is a client error, not a 500.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
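# Example request (a minimal sketch; assumes the server is running locally
# on port 8000, as in the __main__ block below):
#
#   curl -X POST http://localhost:8000/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "How are you?"}'
#
# Response shape: {"translation": "<Bengali text>"}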
@app.post("/translate_batch", response_model=BatchTranslateOut)
def translate_batch(payload: BatchTranslateIn):
    try:
        if not payload.texts:
            raise ValueError("texts list is empty.")
        outs = generate_translation(
            payload.texts,
            payload.max_new_tokens,
            payload.num_beams,
            payload.do_sample,
            payload.temperature,
            payload.top_p,
        )
        return {"translations": outs}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
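# Example batch request (same assumptions as above):
#
#   curl -X POST http://localhost:8000/translate_batch \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["Good morning.", "See you tomorrow."]}'
#
# Instead of the __main__ block, the server can also be started from the CLI:
#
#   uvicorn app:app --host 0.0.0.0 --port 8000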