# app.py — EN→BN MT API (cleaned)
import torch
from typing import List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware  # optional
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# -------------------------
# Device + model name
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mt_pretrained_model_name = "shhossain/opus-mt-en-to-bn"
# -------------------------
# Load tokenizer/model with a clear error if it fails
# -------------------------
try:
    tokenizer = AutoTokenizer.from_pretrained(mt_pretrained_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(mt_pretrained_model_name).to(device)
    model.eval()
except Exception as e:
    raise RuntimeError(f"Failed to load model/tokenizer '{mt_pretrained_model_name}': {e}") from e
# -------------------------
# FastAPI app + (optional) CORS
# -------------------------
app = FastAPI(title="EN→BN MT API", version="1.0.0")

# If you’ll call from a browser (localhost dev or a web app), enable CORS:
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # replace with your domain(s) in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
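# A tighter production setup is the same call with explicit values (a sketch;
# the origin below is a hypothetical placeholder, not a real deployment):
#
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=["https://your-frontend.example.com"],
#       allow_credentials=True,
#       allow_methods=["GET", "POST"],
#       allow_headers=["Content-Type"],
#   )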
# -------------------------
# Schemas
# -------------------------
class TranslateIn(BaseModel):
    text: str = Field(..., description="English sentence")
    max_new_tokens: int = Field(128, ge=1, le=512)
    num_beams: int = Field(4, ge=1, le=10)
    do_sample: bool = Field(False, description="Use sampling instead of pure beam search")
    temperature: Optional[float] = Field(1.0, ge=0.1, le=5.0)
    top_p: Optional[float] = Field(1.0, ge=0.1, le=1.0)


class TranslateOut(BaseModel):
    translation: str


class BatchTranslateIn(BaseModel):
    texts: List[str] = Field(..., description="List of English sentences")
    max_new_tokens: int = Field(128, ge=1, le=512)
    num_beams: int = Field(4, ge=1, le=10)
    do_sample: bool = Field(False)
    temperature: Optional[float] = Field(1.0, ge=0.1, le=5.0)
    top_p: Optional[float] = Field(1.0, ge=0.1, le=1.0)


class BatchTranslateOut(BaseModel):
    translations: List[str]
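# Example request bodies for the input schemas above (values are illustrative;
# omitted fields fall back to the defaults declared in the models):
#   TranslateIn:       {"text": "How are you?", "num_beams": 4}
#   BatchTranslateIn:  {"texts": ["Good morning.", "Thank you."], "max_new_tokens": 64}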
MAX_INPUT_CHARS = 2000
def generate_translation(
    inputs: List[str],
    max_new_tokens: int,
    num_beams: int,
    do_sample: bool,
    temperature: Optional[float],
    top_p: Optional[float],
) -> List[str]:
    # Input length guard: reject oversized inputs before tokenization.
    for s in inputs:
        if len(s) > MAX_INPUT_CHARS:
            raise ValueError(f"Input too long (> {MAX_INPUT_CHARS} chars).")
    batch = tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
    }
    # temperature/top_p are only meaningful when sampling, so forward them only then.
    if do_sample:
        if temperature is not None:
            gen_kwargs["temperature"] = float(temperature)
        if top_p is not None:
            gen_kwargs["top_p"] = float(top_p)
    with torch.no_grad():
        outputs = model.generate(**batch, **gen_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
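# Calling the helper directly, bypassing HTTP (a minimal sketch; the output
# depends on the model checkpoint, so none is shown here):
#   generate_translation(
#       ["Hello, world!"],
#       max_new_tokens=128, num_beams=4,
#       do_sample=False, temperature=None, top_p=None,
#   )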
# -------------------------
# Routes
# -------------------------
@app.get("/")
def greet():
    return {
        "message": "Welcome to EN→BN MT API",
        "device": str(device),
        "model": mt_pretrained_model_name,
    }
@app.post("/translate", response_model=TranslateOut)
def translate(payload: TranslateIn):
    try:
        out = generate_translation(
            [payload.text],
            payload.max_new_tokens,
            payload.num_beams,
            payload.do_sample,
            payload.temperature,
            payload.top_p,
        )[0]
        return {"translation": out}
    except ValueError as e:
        # Bad input (e.g. over the length limit) is a client error, not a 500.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/translate_batch", response_model=BatchTranslateOut)
def translate_batch(payload: BatchTranslateIn):
    try:
        if not payload.texts:
            raise ValueError("texts list is empty.")
        outs = generate_translation(
            payload.texts,
            payload.max_new_tokens,
            payload.num_beams,
            payload.do_sample,
            payload.temperature,
            payload.top_p,
        )
        return {"translations": outs}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
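# ---- Example calls (assumes the server is running on localhost:8000 and the
# paths match the route decorators above) ----
#   curl http://localhost:8000/
#
#   curl -X POST http://localhost:8000/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello, how are you?"}'
#
#   curl -X POST http://localhost:8000/translate_batch \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["Good morning.", "Thank you."]}'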