# app.py — EN→BN MT API (cleaned)

import os
import torch
from typing import List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware  # optional
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# -------------------------
# Device + model name
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mt_pretrained_model_name = "shhossain/opus-mt-en-to-bn"

# -------------------------
# Load tokenizer/model with clear error if it fails
# -------------------------
try:
    tokenizer = AutoTokenizer.from_pretrained(mt_pretrained_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(mt_pretrained_model_name).to(device)
    model.eval()
except Exception as e:
    raise RuntimeError(f"Failed to load model/tokenizer '{mt_pretrained_model_name}': {e}") from e


# -------------------------
# FastAPI app + (optional) CORS
# -------------------------
app = FastAPI(title="EN→BN MT API", version="1.0.0")

# If you’ll call from a browser (localhost dev or a web app), enable CORS:
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],   # replace with your domain(s) in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
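
# For production, a restricted configuration might look like this (example domain only):
#   allow_origins=["https://your-frontend.example.com"]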

# -------------------------
# Request/response schemas
# -------------------------
class TranslateIn(BaseModel):
    text: str = Field(..., description="English sentence")
    max_new_tokens: int = Field(128, ge=1, le=512)
    num_beams: int = Field(4, ge=1, le=10)
    do_sample: bool = Field(False, description="Use sampling instead of pure beam search")
    temperature: Optional[float] = Field(1.0, ge=0.1, le=5.0)
    top_p: Optional[float] = Field(1.0, ge=0.1, le=1.0)

class TranslateOut(BaseModel):
    translation: str

class BatchTranslateIn(BaseModel):
    texts: List[str] = Field(..., description="List of English sentences")
    max_new_tokens: int = Field(128, ge=1, le=512)
    num_beams: int = Field(4, ge=1, le=10)
    do_sample: bool = Field(False)
    temperature: Optional[float] = Field(1.0, ge=0.1, le=5.0)
    top_p: Optional[float] = Field(1.0, ge=0.1, le=1.0)

class BatchTranslateOut(BaseModel):
    translations: List[str]
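
# Illustrative JSON shapes for the schemas above (example values only; omitted
# fields fall back to the Field defaults defined above):
#   TranslateIn       -> {"text": "Hello!", "num_beams": 4, "do_sample": false}
#   TranslateOut      -> {"translation": "<Bengali text>"}
#   BatchTranslateIn  -> {"texts": ["Hello!", "See you tomorrow."]}
#   BatchTranslateOut -> {"translations": ["<Bengali text>", "<Bengali text>"]}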


MAX_INPUT_CHARS = 2000

def generate_translation(
    inputs: List[str],
    max_new_tokens: int,
    num_beams: int,
    do_sample: bool,
    temperature: Optional[float],
    top_p: Optional[float],
) -> List[str]:
    # input length guard
    for s in inputs:
        if len(s) > MAX_INPUT_CHARS:
            raise ValueError(f"Input too long (> {MAX_INPUT_CHARS} chars).")

    batch = tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
    }
    if do_sample:
        if temperature is not None:
            gen_kwargs["temperature"] = float(temperature)
        if top_p is not None:
            gen_kwargs["top_p"] = float(top_p)

    with torch.no_grad():
        outputs = model.generate(**batch, **gen_kwargs)

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


@app.get("/greet")
def greet():
    return {
        "message": "Welcome to EN→BN MT API",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "model": mt_pretrained_model_name,
    }

@app.post("/translate", response_model=TranslateOut)
def translate(payload: TranslateIn):
    try:
        out = generate_translation(
            [payload.text],
            payload.max_new_tokens,
            payload.num_beams,
            payload.do_sample,
            payload.temperature,
            payload.top_p,
        )[0]
        return {"translation": out}
    except ValueError as e:
        # Client-side problem (e.g. input over MAX_INPUT_CHARS) -> 400, not 500
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/translate_batch", response_model=BatchTranslateOut)
def translate_batch(payload: BatchTranslateIn):
    try:
        if not payload.texts:
            raise ValueError("texts list is empty.")
        outs = generate_translation(
            payload.texts,
            payload.max_new_tokens,
            payload.num_beams,
            payload.do_sample,
            payload.temperature,
            payload.top_p,
        )
        return {"translations": outs}
    except ValueError as e:
        # Client-side problem (empty list or over-long input) -> 400, not 500
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
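

# -------------------------
# Example usage (sketch): start the server with `python app.py` (or
# `uvicorn app:app --host 0.0.0.0 --port 8000`), then call the endpoints, e.g.:
# -------------------------
#   curl -X POST http://localhost:8000/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "How are you today?", "num_beams": 4}'
#
#   curl -X POST http://localhost:8000/translate_batch \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["Good morning.", "Thank you very much."]}'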