# app.py — EN→BN MT API (cleaned)
import os
import torch
from typing import List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware # optional
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# -------------------------
# Device + model name
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mt_pretrained_model_name = "shhossain/opus-mt-en-to-bn"
# -------------------------
# Load tokenizer/model with clear error if it fails
# -------------------------
try:
    tokenizer = AutoTokenizer.from_pretrained(mt_pretrained_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(mt_pretrained_model_name).to(device)
    model.eval()
except Exception as e:
    raise RuntimeError(f"Failed to load model/tokenizer '{mt_pretrained_model_name}': {e}") from e
# -------------------------
# FastAPI app + (optional) CORS
# -------------------------
app = FastAPI(title="EN→BN MT API", version="1.0.0")
# If you’ll call from a browser (localhost dev or a web app), enable CORS:
app.add_middleware(
    CORSMiddleware,
    # "*" is convenient for local dev, but browsers reject a wildcard origin
    # when credentials are allowed, so list explicit domains in production.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
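# A locked-down production setup might look like this instead (the domain is
# a placeholder, not part of this project):
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://your-frontend.example.com"],
#     allow_credentials=True,
#     allow_methods=["GET", "POST"],
#     allow_headers=["*"],
# )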
# -------------------------
# Request/response schemas
# -------------------------
class TranslateIn(BaseModel):
    text: str = Field(..., description="English sentence")
    max_new_tokens: int = Field(128, ge=1, le=512)
    num_beams: int = Field(4, ge=1, le=10)
    do_sample: bool = Field(False, description="Use sampling instead of pure beam search")
    temperature: Optional[float] = Field(1.0, ge=0.1, le=5.0)
    top_p: Optional[float] = Field(1.0, ge=0.1, le=1.0)

class TranslateOut(BaseModel):
    translation: str

class BatchTranslateIn(BaseModel):
    texts: List[str] = Field(..., description="List of English sentences")
    max_new_tokens: int = Field(128, ge=1, le=512)
    num_beams: int = Field(4, ge=1, le=10)
    do_sample: bool = Field(False)
    temperature: Optional[float] = Field(1.0, ge=0.1, le=5.0)
    top_p: Optional[float] = Field(1.0, ge=0.1, le=1.0)

class BatchTranslateOut(BaseModel):
    translations: List[str]
MAX_INPUT_CHARS = 2000
def generate_translation(
    inputs: List[str],
    max_new_tokens: int,
    num_beams: int,
    do_sample: bool,
    temperature: Optional[float],
    top_p: Optional[float],
) -> List[str]:
    # Input length guard: fail fast before tokenizing.
    for s in inputs:
        if len(s) > MAX_INPUT_CHARS:
            raise ValueError(f"Input too long (> {MAX_INPUT_CHARS} chars).")
    batch = tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
    }
    # temperature/top_p only take effect when sampling is enabled.
    if do_sample:
        if temperature is not None:
            gen_kwargs["temperature"] = float(temperature)
        if top_p is not None:
            gen_kwargs["top_p"] = float(top_p)
    with torch.no_grad():
        outputs = model.generate(**batch, **gen_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
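# Quick smoke test without starting the server (the sentence is illustrative;
# the exact Bengali output depends on the model):
#
#   >>> generate_translation(["Hello, how are you?"], max_new_tokens=64,
#   ...                      num_beams=4, do_sample=False,
#   ...                      temperature=None, top_p=None)
#   ['<Bengali translation>']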
@app.get("/greet")
def greet():
    return {
        "message": "Welcome to EN→BN MT API",
        "device": str(device),  # report the device the model was actually loaded on
        "model": mt_pretrained_model_name,
    }
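# Example health check (assuming a local server on port 8000):
#   curl http://localhost:8000/greet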
@app.post("/translate", response_model=TranslateOut)
def translate(payload: TranslateIn):
try:
out = generate_translation(
[payload.text],
payload.max_new_tokens,
payload.num_beams,
payload.do_sample,
payload.temperature,
payload.top_p,
)[0]
return {"translation": out}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
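# Example request (text and parameter values are illustrative):
#   curl -X POST http://localhost:8000/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Good morning", "num_beams": 4}'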
@app.post("/translate_batch", response_model=BatchTranslateOut)
def translate_batch(payload: BatchTranslateIn):
try:
if not payload.texts:
raise ValueError("texts list is empty.")
outs = generate_translation(
payload.texts,
payload.max_new_tokens,
payload.num_beams,
payload.do_sample,
payload.temperature,
payload.top_p,
)
return {"translations": outs}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
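# Example batch request (sentences are illustrative):
#   curl -X POST http://localhost:8000/translate_batch \
#     -H "Content-Type: application/json" \
#     -d '{"texts": ["Good morning", "Thank you"]}'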
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
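# For development, the same server can also be started with auto-reload
# (assumes this file is named app.py, per the header comment):
#   uvicorn app:app --reload --host 0.0.0.0 --port 8000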