Update app.py
app.py
CHANGED
@@ -84,13 +84,18 @@ def load_model(model_name: str):
     if model_name in _model_cache:
         return _model_cache[model_name]
 
-
+    token = os.getenv("hf_token")
+
+    tok = AutoTokenizer.from_pretrained(
+        "ai4bharat/indictrans2-en-indic-1B",
+        trust_remote_code=True, use_fast=True
+    )
     mdl = AutoModelForSeq2SeqLM.from_pretrained(
         model_name, trust_remote_code=True,
-        low_cpu_mem_usage=True, dtype=dtype
+        low_cpu_mem_usage=True, dtype=dtype, token=token
     ).to(device).eval()
 
-    # Fix vocab
+    # Fix vocab (some HF models have mismatched config.vocab_size)
     try:
         mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
     except Exception:
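This hunk wires a Hugging Face access token, read from a Space secret, into model loading. In isolation the pattern looks roughly like the sketch below; it is not the Space's exact code. The env-var name `hf_token` and the tokenizer repo ID are taken from the diff, while passing `token=` to the tokenizer as well is an extra step the commit itself does not take (it only authenticates the model download).

    import os
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # Read the token from an environment variable / Space secret.
    # If it is unset, token=None falls back to anonymous access.
    token = os.getenv("hf_token")

    tok = AutoTokenizer.from_pretrained(
        "ai4bharat/indictrans2-en-indic-1B",
        trust_remote_code=True, use_fast=True, token=token,
    )
    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        "ai4bharat/indictrans2-en-indic-1B",
        trust_remote_code=True, low_cpu_mem_usage=True, token=token,
    )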
@@ -109,56 +114,88 @@ def build_bad_words_ids_from_vocab(tok):
     ] + [f"<ID{i}>" for i in range(10)]
     out = []
     for c in candidates:
-        if c in vocab:
+        if c in vocab:
+            out.append([vocab[c]])
+            continue
         sp_c = "▁" + c
-        if sp_c in vocab:
+        if sp_c in vocab:
+            out.append([vocab[sp_c]])
     return out
 
-
+
+# --------------------- Streaming Translation ---------------------
+BATCH_SIZE = 6
+
 @torch.inference_mode()
-def
-
-
+def translate_dual_stream(text, model_choice, num_beams, max_new):
+    """
+    Generator that yields (hindi_accumulated_text, telugu_accumulated_text)
+    after each processed batch so the UI updates progressively.
+    """
+    if not text or not text.strip():
+        yield "", ""
+        return
+
+    # Prepare once
     tok, mdl = load_model(MODELS[model_choice])
     BAD_WORDS_IDS = build_bad_words_ids_from_vocab(tok)
-
     sentences = split_into_sentences(text)
-    full_trans = []
-
-    for i in range(0, len(sentences), batch_size):
-        batch = sentences[i:i+batch_size]
-        proc = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=tgt_lang)
-        enc = tok(proc, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
-
-        out = mdl.generate(
-            **enc, max_length=max_new, num_beams=num_beams,
-            early_stopping=True, no_repeat_ngram_size=3, use_cache=False,
-            bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
-        )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    hi_acc, te_acc = [], []
+
+    # Clear outputs immediately for a snappy feel
+    yield "", ""
+
+    for i in range(0, len(sentences), BATCH_SIZE):
+        batch = sentences[i:i + BATCH_SIZE]
+
+        # --- Hindi batch ---
+        try:
+            proc_hi = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=HI_CODE)
+            enc_hi = tok(
+                proc_hi, padding=True, truncation=True, max_length=256, return_tensors="pt"
+            ).to(device)
+            out_hi = mdl.generate(
+                **enc_hi,
+                max_length=max_new,  # keep semantics same as the original
+                num_beams=int(num_beams),
+                early_stopping=True,
+                no_repeat_ngram_size=3,
+                use_cache=False,
+                bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
+            )
+            dec_hi = tok.batch_decode(out_hi, skip_special_tokens=True)
+            dec_hi = [strip_lang_tags(t) for t in dec_hi]
+            post_hi = ip.postprocess_batch(dec_hi, lang=HI_CODE)
+            post_hi = [ensure_hindi_danda(x) for x in post_hi]
+            hi_acc.extend(p.strip() for p in post_hi)
+        except Exception as e:
+            hi_acc.append(f"⚠️ Hindi failed (batch {i//BATCH_SIZE+1}): {e}")
+
+        # --- Telugu batch ---
+        try:
+            proc_te = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=TE_CODE)
+            enc_te = tok(
+                proc_te, padding=True, truncation=True, max_length=256, return_tensors="pt"
+            ).to(device)
+            out_te = mdl.generate(
+                **enc_te,
+                max_length=max_new,
+                num_beams=int(num_beams),
+                early_stopping=True,
+                no_repeat_ngram_size=3,
+                use_cache=False,
+                bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
+            )
+            dec_te = tok.batch_decode(out_te, skip_special_tokens=True)
+            dec_te = [strip_lang_tags(t) for t in dec_te]
+            post_te = ip.postprocess_batch(dec_te, lang=TE_CODE)
+            te_acc.extend(p.strip() for p in post_te)
+        except Exception as e:
+            te_acc.append(f"⚠️ Telugu failed (batch {i//BATCH_SIZE+1}): {e}")
+
+        # Stream the accumulators so far
+        yield (" ".join(hi_acc), " ".join(te_acc))
 
 
 # --------------------- Dark Theme ---------------------
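For reference on the `bad_words_ids` shape this hunk builds: `generate()` expects a list of token-id sequences that must never be emitted, so each banned token is wrapped in its own one-element list. Below is a minimal, self-contained sketch of the same idea; `build_single_token_bans` is a hypothetical name, and the lookup via `tok.get_vocab()` is an assumption about how `vocab` is produced earlier in the file.

    from transformers import PreTrainedTokenizerBase

    def build_single_token_bans(tok: PreTrainedTokenizerBase, words: list[str]) -> list[list[int]]:
        """Wrap each banned token id in its own list, the shape generate() expects."""
        vocab = tok.get_vocab()  # str -> id mapping (assumed source of `vocab`)
        bans = []
        for w in words:
            for form in (w, "▁" + w):  # plain and SentencePiece-prefixed forms
                if form in vocab:
                    bans.append([vocab[form]])
                    break
        return bans

    # usage: mdl.generate(**enc, bad_words_ids=build_single_token_bans(tok, ["<ID0>"]) or None)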
@@ -169,7 +206,7 @@ THEME = gr.themes.Soft(
     body_text_color="#f3f4f6",
     block_background_fill="#111827",
     block_border_color="#1f2937",
-    block_title_text_color="#
+    block_title_text_color="#123456",
     button_primary_background_fill="#2563eb",
     button_primary_text_color="#ffffff",
 )
@@ -262,11 +299,13 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as
         num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd")
         max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens", elem_id="model_dd")
 
+    # Use streaming generator
     translate_btn.click(
-
+        translate_dual_stream,
         inputs=[src, model_choice, num_beams, max_new],
         outputs=[hi_out, te_out]
     )
     clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
 
+# Enable queue for streaming
 demo.queue(max_size=48).launch()
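The streaming behaviour comes from Gradio's generator support: when the handler passed to `.click()` is a generator, every `yield` pushes a fresh value to the output components, and the queue transports the intermediate updates. A self-contained toy version of the pattern, with illustrative names only:

    import time
    import gradio as gr

    def count_up(n):
        # A generator handler: each yield pushes a fresh value to the outputs.
        acc = []
        for i in range(int(n)):
            acc.append(str(i))
            time.sleep(0.3)      # stand-in for a slow translation batch
            yield " ".join(acc)  # each yield refreshes the textbox

    with gr.Blocks() as demo:
        n = gr.Slider(1, 10, value=5, step=1, label="Count to")
        out = gr.Textbox(label="Progress")
        gr.Button("Go").click(count_up, inputs=n, outputs=out)

    demo.queue(max_size=48).launch()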