Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,177 +1,228 @@
|
|
| 1 |
-
import os,
|
| 2 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
| 3 |
|
| 4 |
-
#
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
except Exception:
|
| 8 |
-
from IndicTransToolkit.IndicTransToolkit import IndicProcessor
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
# -------- Config --------
|
| 12 |
-
TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
|
| 13 |
-
MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
|
| 14 |
-
TOKENIZER_REV, MODEL_REV = os.getenv("TOKENIZER_REV"), os.getenv("MODEL_REV")
|
| 15 |
|
|
|
|
| 16 |
SRC_CODE = "eng_Latn"
|
| 17 |
HI_CODE = "hin_Deva"
|
| 18 |
TE_CODE = "tel_Telu"
|
| 19 |
|
| 20 |
-
|
| 21 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 22 |
-
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device).eval()
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
def _ensure_vocab_consistency(md, tok):
    """Make the model's advertised vocab size agree with its output layer.

    The size is read from the first dimension of the output-embedding
    weight. When that cannot be read (or is falsy), the tokenizer's
    ``vocab_size`` attribute is used, falling back to ``len(tok)`` and
    finally to 64000. The value is written to both ``md.config`` and
    ``md.generation_config``.

    A ``get_text_config`` method returning the config itself is attached
    when the config does not already provide one.
    """
    try:
        size = md.get_output_embeddings().weight.shape[0]
    except Exception:
        size = None

    if not size:
        # The fallback is computed eagerly, exactly as a getattr default would be.
        fallback = len(tok) if hasattr(tok, "__len__") else 64000
        size = getattr(tok, "vocab_size", fallback)

    md.config.vocab_size = size
    md.generation_config.vocab_size = size

    if not hasattr(md.config, "get_text_config"):
        md.config.get_text_config = types.MethodType(lambda self: self, md.config)
|
| 52 |
-
|
| 53 |
-
_ensure_vocab_consistency(model, tokenizer)
|
| 54 |
-
|
| 55 |
-
for obj in (model.config, model.generation_config):
|
| 56 |
-
try: setattr(obj, "use_cache", False)
|
| 57 |
-
except: pass
|
| 58 |
-
|
| 59 |
-
# Processor
|
| 60 |
-
ip = IndicProcessor(inference=True)
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
@torch.inference_mode()
|
| 65 |
-
def
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
try:
|
| 88 |
-
hi =
|
| 89 |
except Exception as e:
|
| 90 |
-
hi = f"⚠️ Hindi
|
| 91 |
try:
|
| 92 |
-
te =
|
| 93 |
except Exception as e:
|
| 94 |
-
te = f"⚠️ Telugu
|
| 95 |
return hi, te
|
| 96 |
|
| 97 |
|
| 98 |
-
#
|
| 99 |
-
THEME = gr.themes.
|
| 100 |
primary_hue="blue", neutral_hue="slate"
|
| 101 |
).set(
|
| 102 |
-
body_background_fill="#
|
| 103 |
-
body_text_color="#
|
| 104 |
-
block_background_fill="#
|
| 105 |
-
block_border_color="#
|
| 106 |
-
block_title_text_color="#
|
| 107 |
button_primary_background_fill="#2563eb",
|
| 108 |
-
button_primary_text_color="#ffffff"
|
| 109 |
)
|
| 110 |
|
| 111 |
CUSTOM_CSS = """
|
| 112 |
-
#hdr {
|
| 113 |
-
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
border:1px solid #e5e7eb; border-radius:12px;
|
| 120 |
-
background:white; box-shadow:0 1px 3px rgba(0,0,0,0.08);
|
| 121 |
-
padding:12px; display:flex; flex-direction:column;
|
| 122 |
-
}
|
| 123 |
-
.panel h2 {
|
| 124 |
-
font-size:16px; font-weight:600; margin-bottom:8px; color:#374151;
|
| 125 |
-
}
|
| 126 |
-
textarea {
|
| 127 |
-
font-size:15px !important; line-height:1.55 !important;
|
| 128 |
-
padding:10px 12px !important;
|
| 129 |
-
border:1px solid #d1d5db !important; border-radius:8px !important;
|
| 130 |
-
}
|
| 131 |
-
button { font-weight:600 !important; border-radius:8px !important; }
|
| 132 |
-
button:hover { opacity:0.9; transition:opacity 0.2s; }
|
| 133 |
"""
|
| 134 |
|
| 135 |
-
#
|
| 136 |
-
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN →
|
| 137 |
with gr.Group(elem_id="hdr"):
|
| 138 |
gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
|
| 139 |
-
gr.Markdown("<p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
with gr.Row():
|
| 142 |
-
# Input Column
|
| 143 |
with gr.Column(scale=2):
|
| 144 |
with gr.Group(elem_classes="panel"):
|
| 145 |
gr.Markdown("<h2>English Input</h2>")
|
| 146 |
-
src = gr.Textbox(placeholder="
|
| 147 |
-
|
| 148 |
with gr.Row():
|
| 149 |
translate_btn = gr.Button("👉 Translate", variant="primary")
|
| 150 |
clear_btn = gr.Button("Clear", variant="secondary")
|
| 151 |
|
| 152 |
-
# Output Column
|
| 153 |
with gr.Column(scale=2):
|
| 154 |
with gr.Group(elem_classes="panel"):
|
| 155 |
gr.Markdown("<h2>Hindi Translation</h2>")
|
| 156 |
hi_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
|
| 157 |
-
|
| 158 |
with gr.Group(elem_classes="panel"):
|
| 159 |
gr.Markdown("<h2>Telugu Translation</h2>")
|
| 160 |
te_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
|
| 161 |
|
| 162 |
-
# Settings Column
|
| 163 |
with gr.Column(scale=1):
|
| 164 |
with gr.Group(elem_classes="panel"):
|
| 165 |
-
gr.Markdown("<h2>
|
| 166 |
-
num_beams = gr.Slider(1, 8, value=4, step=1, label="
|
| 167 |
-
max_new = gr.Slider(
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
outputs=[hi_out, te_out])
|
| 175 |
clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
|
| 176 |
|
| 177 |
demo.queue(max_size=48).launch()
|
|
|
|
| 1 |
+
import os, re, types, traceback, torch, gradio as gr
|
| 2 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 3 |
+
from IndicTransToolkit import IndicProcessor
|
| 4 |
|
| 5 |
+
# --------------------- Device ---------------------
|
| 6 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 7 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
# --------------------- Languages ------------------
|
| 10 |
SRC_CODE = "eng_Latn"
|
| 11 |
HI_CODE = "hin_Deva"
|
| 12 |
TE_CODE = "tel_Telu"
|
| 13 |
|
| 14 |
+
ip = IndicProcessor(inference=True)
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
# --------------------- Regex / Helpers ---------------------
|
| 17 |
+
# Matches language-tag artefacts that can leak into decoded model output.
TAG_REGEX = re.compile(
    # tokens beginning with "_src" or "tgt", up to the next whitespace
    r"(?:_src\S+)|(?:tgt\S+)|"
    # a single token wrapped as >> token <<
    r"(?:>>\s*\S+\s*<<)|"
    # bare language codes as whole words
    r"\b(?:eng_Latn|hin_Deva|hin_deva|tel_Telu|tel_telu)\b|"
    # <ID> markers with optional digits
    r"<ID\d*>"
)


def strip_lang_tags(text: str) -> str:
    """Remove language-tag artefacts from *text* and tidy the whitespace.

    Each ``TAG_REGEX`` match is replaced by a space, runs of two or more
    whitespace characters are collapsed to a single space, and the result
    is stripped at both ends.
    """
    without_tags = TAG_REGEX.sub(" ", text)
    collapsed = re.sub(r"\s{2,}", " ", without_tags)
    return collapsed.strip()
|
|
|
|
| 27 |
|
| 28 |
+
def ensure_hindi_danda(s: str) -> str:
    """Make a Hindi sentence end with a danda instead of a Latin full stop.

    Rules, applied after trailing whitespace is removed:

    * a single trailing ``.`` becomes ``।``; an ellipsis written as
      ``...`` is left untouched rather than turned into ``..।``;
    * text that already ends in ``।``, ``॥``, ``?``, ``!`` or ``…`` is
      returned as is;
    * otherwise, text ending in a Devanagari character gets ``।`` appended.

    Text ending in anything else (for example Latin letters or digits) is
    returned without an added terminator.
    """
    # Strip first so the danda is never placed after trailing spaces.
    s = s.rstrip()

    # Only a lone final period is a sentence terminator; ".." / "..." are not.
    s = re.sub(r"(?<!\.)\.$", "।", s)

    # "॥" sits inside the Devanagari block, so it must be listed here or the
    # range check below would append a second terminator after it.
    if s.endswith(("।", "॥", "?", "!", "…")):
        return s

    if re.search(r"[\u0900-\u097F]$", s):
        s += "।"
    return s
|
| 33 |
+
|
| 34 |
+
# Sentence splitting: use pysbd when it can be imported and constructed,
# otherwise leave _SEGMENTER as None so callers use the regex fallback.
_SEGMENTER = None
try:
    import pysbd
    _SEGMENTER = pysbd.Segmenter(language="en", clean=True)
except Exception:
    # Deliberate best effort: any import or construction failure keeps
    # _SEGMENTER at None.
    pass
|
| 40 |
+
|
| 41 |
+
# A chunk ending in a short (1-6 letter) word plus a period, e.g. "Sec." or "Art."
_LEGAL_JOIN_RE = re.compile(r'\b([A-Za-z]{1,6})\.\s*$')
# A chunk that starts with an opening bracket, a digit or a lowercase letter
_NEXT_CONT_RE = re.compile(r'^\s*(?:[\(\[\{]|\d|[a-z])')


def _merge_legal_abbrev_breaks(sents):
    """Re-join sentences that were split after an abbreviation-like word.

    A chunk is appended to the previous one (separated by one space) when
    the previous chunk ends in a short word followed by a period and the
    chunk itself starts with a bracket, digit or lowercase letter. Merging
    chains, so several consecutive chunks can collapse into one.

    Returns the merged chunks with empty entries removed.
    """
    pieces = []
    for raw in sents:
        tail = raw.lstrip()
        if pieces and _LEGAL_JOIN_RE.search(pieces[-1]) and _NEXT_CONT_RE.match(tail):
            # Continuation: extend the sentence currently being built.
            pieces[-1] = f"{pieces[-1]} {tail}"
        else:
            # Start of a new sentence; empties are kept for now so they
            # still act as a merge barrier, and are dropped at the end.
            pieces.append(raw.strip())
    return [piece for piece in pieces if piece]
|
| 58 |
+
|
| 59 |
+
def split_into_sentences(text: str):
    """Split English *text* into a list of sentences.

    When a pysbd segmenter is available its output is used; otherwise a
    regex fallback splits on whitespace that follows ``.``, ``?`` or ``!``.
    In both cases the chunks are passed through
    ``_merge_legal_abbrev_breaks`` before being returned.

    The fallback first shields periods that look like abbreviations by
    swapping them for a private-use marker (U+E000), so the split does not
    break on them, and restores them afterwards.
    """
    if _SEGMENTER is not None:
        segments = _SEGMENTER.segment(text)
        return _merge_legal_abbrev_breaks(segments)

    marker = "\uE000"
    shield = r'\1' + marker
    shielded = text.strip()
    for pattern in (
        # short word + "." followed by a bracket, digit or lowercase letter
        r'\b([A-Za-z]{1,6})\.(?=\s*(?:[\(\[\{]|\d|[a-z]))',
        # short word + "." followed by whitespace and a capital letter
        r'\b([A-Za-z]{1,5})\.(?=\s+[A-Z])',
    ):
        shielded = re.sub(pattern, shield, shielded)

    restored = []
    for chunk in re.split(r'(?<=[.?!])\s+', shielded):
        if chunk.strip():
            restored.append(chunk.replace(marker, '.'))
    return _merge_legal_abbrev_breaks(restored)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
# --------------------- Model Loader ---------------------
|
| 76 |
+
MODELS = {
|
| 77 |
+
"Default (Public)": "law-ai/InLegalTrans-En2Indic-1B",
|
| 78 |
+
"Fine-tuned (Private)": "SagarVelamuri/InLegalTrans-En2Indic-FineTuned-Tel-Hin"
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
_model_cache = {}
|
| 82 |
+
|
| 83 |
+
def load_model(model_name: str):
    """Return the ``(tokenizer, model)`` pair for *model_name*, loading it once.

    Pairs are memoised in ``_model_cache`` keyed by *model_name*, so repeat
    calls reuse the already-loaded objects. A freshly loaded model is moved
    to ``device`` and put in eval mode.
    """
    cached = _model_cache.get(model_name)
    if cached is not None:
        return cached

    tok = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True, use_fast=True
    )
    # NOTE(review): `dtype=` looks like the newer spelling of `torch_dtype=`;
    # confirm the pinned transformers version accepts it.
    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        dtype=dtype,
    )
    mdl = mdl.to(device).eval()

    # Align config.vocab_size with the real output-embedding size. This is
    # best effort: any failure leaves the config untouched.
    try:
        mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
    except Exception:
        pass

    pair = (tok, mdl)
    _model_cache[model_name] = pair
    return pair
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def build_bad_words_ids_from_vocab(tok):
    """Collect single-token id lists for tag tokens present in the vocab.

    Each candidate tag is looked up in ``tok.get_vocab()`` as is, and if
    absent with a leading ``▁`` prefix. Every hit contributes a
    one-element list ``[token_id]``; candidates with no hit are skipped.

    Returns a list of one-element lists, possibly empty, in candidate order.
    """
    vocab = tok.get_vocab()

    tag_tokens = [
        "eng_Latn", "hin_Deva", "hin_deva", "tel_Telu", "tel_telu",
        "_srceng_Latn", "tgthin_Deva", "tgt_tel_Telu",
        ">>hin_Deva<<", ">>tel_Telu<<",
    ]
    # NOTE(review): if the preprocessing step relies on <IDn> placeholders
    # surviving generation, banning them here would defeat that; confirm.
    tag_tokens.extend(f"<ID{n}>" for n in range(10))

    banned = []
    for token in tag_tokens:
        # The plain form wins; the "▁"-prefixed form is only a fallback.
        for variant in (token, "▁" + token):
            if variant in vocab:
                banned.append([vocab[variant]])
                break
    return banned
|
| 116 |
+
|
| 117 |
+
# --------------------- Translation ---------------------
|
| 118 |
@torch.inference_mode()
def _translate(text: str, tgt_lang: str, model_choice: str,
               num_beams=4, max_new=128, batch_size=3) -> str:
    """Translate English *text* into *tgt_lang*, sentence by sentence.

    Args:
        text: Source text; it is split into sentences before translation.
        tgt_lang: Target language code passed to the processor
            (e.g. ``HI_CODE`` or ``TE_CODE``).
        model_choice: Key into ``MODELS`` selecting the checkpoint.
        num_beams: Beam width for generation.
        max_new: Maximum number of newly generated tokens per sentence.
        batch_size: Number of sentences translated per ``generate`` call.

    Returns:
        The translated sentences joined by single spaces, or an empty
        string when no sentences were found.

    Raises:
        KeyError: If *model_choice* is not a key of ``MODELS``.
    """
    tok, mdl = load_model(MODELS[model_choice])
    bad_words_ids = build_bad_words_ids_from_vocab(tok) or None

    # UI sliders may deliver floats, while generate() needs integers.
    # A batch size below 1 would make range() raise.
    num_beams = int(num_beams)
    max_new = int(max_new)
    batch_size = max(1, int(batch_size))

    sentences = split_into_sentences(text)
    full_trans = []

    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        proc = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=tgt_lang)
        # Inputs longer than 256 tokens are truncated by the tokenizer.
        enc = tok(
            proc, padding=True, truncation=True, max_length=256,
            return_tensors="pt",
        ).to(device)

        out = mdl.generate(
            **enc,
            # `max_new` is a budget of new tokens, so pass it as
            # max_new_tokens, not as the total-length cap max_length.
            max_new_tokens=max_new,
            num_beams=num_beams,
            early_stopping=True,
            no_repeat_ngram_size=3,
            use_cache=False,
            bad_words_ids=bad_words_ids,
        )

        decoded = [
            strip_lang_tags(t)
            for t in tok.batch_decode(out, skip_special_tokens=True)
        ]
        post = ip.postprocess_batch(decoded, lang=tgt_lang)

        if tgt_lang == HI_CODE:
            post = [ensure_hindi_danda(x) for x in post]

        full_trans.extend(p.strip() for p in post)

    return " ".join(full_trans)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def translate_dual(text, model_choice, num_beams, max_new):
    """Translate *text* into Hindi and Telugu for the UI.

    Args:
        text: English input; ``None``, empty or whitespace-only input
            gives two empty strings without touching the model.
        model_choice: Key into ``MODELS`` selecting the checkpoint.
        num_beams: Beam width forwarded to ``_translate``.
        max_new: Token budget forwarded to ``_translate``.

    Returns:
        A ``(hindi, telugu)`` tuple of strings. The two languages are
        translated independently: if one fails, its slot holds a
        ``"⚠️ <Language> failed: ..."`` message and the other is still
        attempted.
    """
    # `not text` also covers None, which would otherwise crash on .strip().
    if not text or not text.strip():
        return "", ""

    results = []
    for label, code in (("Hindi", HI_CODE), ("Telugu", TE_CODE)):
        try:
            results.append(
                _translate(text, code, model_choice,
                           num_beams=num_beams, max_new=max_new)
            )
        except Exception as e:
            # Keep the full traceback in the server log; the user only
            # sees the short message.
            traceback.print_exc()
            results.append(f"⚠️ {label} failed: {e}")
    return tuple(results)
|
| 162 |
|
| 163 |
|
| 164 |
+
# --------------------- Dark Theme ---------------------
|
| 165 |
+
THEME = gr.themes.Soft(
|
| 166 |
primary_hue="blue", neutral_hue="slate"
|
| 167 |
).set(
|
| 168 |
+
body_background_fill="#0b0f19",
|
| 169 |
+
body_text_color="#f3f4f6",
|
| 170 |
+
block_background_fill="#111827",
|
| 171 |
+
block_border_color="#1f2937",
|
| 172 |
+
block_title_text_color="#e5e7eb",
|
| 173 |
button_primary_background_fill="#2563eb",
|
| 174 |
+
button_primary_text_color="#ffffff",
|
| 175 |
)
|
| 176 |
|
| 177 |
CUSTOM_CSS = """
|
| 178 |
+
#hdr { text-align:center; padding:16px; }
|
| 179 |
+
#hdr h1 { font-size:24px; font-weight:700; color:#f9fafb; margin:0; }
|
| 180 |
+
#hdr p { font-size:14px; color:#9ca3af; margin-top:4px; }
|
| 181 |
+
.panel { border:1px solid #1f2937; border-radius:10px; padding:12px; background:#111827; box-shadow:0 1px 2px rgba(0,0,0,0.4);}
|
| 182 |
+
.panel h2 { font-size:16px; font-weight:600; margin-bottom:6px; color:#f3f4f6; }
|
| 183 |
+
textarea { background:#0b0f19 !important; color:#f9fafb !important; border-radius:8px !important; border:1px solid #374151 !important; font-size:15px !important; line-height:1.55; }
|
| 184 |
+
button { border-radius:8px !important; font-weight:600 !important; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
"""
|
| 186 |
|
| 187 |
+
# --------------------- UI ---------------------
|
| 188 |
+
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
|
| 189 |
with gr.Group(elem_id="hdr"):
|
| 190 |
gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
|
| 191 |
+
gr.Markdown("<p>IndicTrans2 with batch sentence decomposition</p>")
|
| 192 |
+
|
| 193 |
+
model_choice = gr.Dropdown(
|
| 194 |
+
label="Choose Model", choices=list(MODELS.keys()),
|
| 195 |
+
value="Default (Public)"
|
| 196 |
+
)
|
| 197 |
|
| 198 |
with gr.Row():
|
|
|
|
| 199 |
with gr.Column(scale=2):
|
| 200 |
with gr.Group(elem_classes="panel"):
|
| 201 |
gr.Markdown("<h2>English Input</h2>")
|
| 202 |
+
src = gr.Textbox(lines=12, placeholder="Enter English...", show_label=False)
|
|
|
|
| 203 |
with gr.Row():
|
| 204 |
translate_btn = gr.Button("👉 Translate", variant="primary")
|
| 205 |
clear_btn = gr.Button("Clear", variant="secondary")
|
| 206 |
|
|
|
|
| 207 |
with gr.Column(scale=2):
|
| 208 |
with gr.Group(elem_classes="panel"):
|
| 209 |
gr.Markdown("<h2>Hindi Translation</h2>")
|
| 210 |
hi_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
|
|
|
|
| 211 |
with gr.Group(elem_classes="panel"):
|
| 212 |
gr.Markdown("<h2>Telugu Translation</h2>")
|
| 213 |
te_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
|
| 214 |
|
|
|
|
| 215 |
with gr.Column(scale=1):
|
| 216 |
with gr.Group(elem_classes="panel"):
|
| 217 |
+
gr.Markdown("<h2>Settings</h2>")
|
| 218 |
+
num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search")
|
| 219 |
+
max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens")
|
| 220 |
+
|
| 221 |
+
translate_btn.click(
|
| 222 |
+
translate_dual,
|
| 223 |
+
inputs=[src, model_choice, num_beams, max_new],
|
| 224 |
+
outputs=[hi_out, te_out]
|
| 225 |
+
)
|
|
|
|
| 226 |
clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
|
| 227 |
|
| 228 |
demo.queue(max_size=48).launch()
|