SagarVelamuri committed · verified
Commit 3b1a54d · Parent: eacfc7e

Update app.py

Files changed (1): app.py (+86 -47)
app.py CHANGED
@@ -84,13 +84,18 @@ def load_model(model_name: str):
     if model_name in _model_cache:
         return _model_cache[model_name]
 
-    tok = AutoTokenizer.from_pretrained("ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True, use_fast=True)
+    token = os.getenv("hf_token")
+
+    tok = AutoTokenizer.from_pretrained(
+        "ai4bharat/indictrans2-en-indic-1B",
+        trust_remote_code=True, use_fast=True
+    )
     mdl = AutoModelForSeq2SeqLM.from_pretrained(
         model_name, trust_remote_code=True,
-        low_cpu_mem_usage=True, dtype=dtype
+        low_cpu_mem_usage=True, dtype=dtype, token=token
     ).to(device).eval()
 
-    # Fix vocab
+    # Fix vocab (some HF models have mismatched config.vocab_size)
     try:
         mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
     except Exception:
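A note on the token lookup added above: `os.getenv` is case-sensitive and returns `None` when the variable is absent, and `from_pretrained(token=None)` simply falls back to anonymous access. A minimal sketch of a slightly more forgiving lookup (the helper name and the uppercase fallback are assumptions, not part of this commit):

```python
import os
from typing import Optional

def get_hf_token() -> Optional[str]:
    # Hypothetical helper, not in app.py: try the lowercase name used by
    # this commit, then the conventional uppercase spelling. os.getenv is
    # case-sensitive, so a Space secret saved as "HF_TOKEN" would not be
    # found by os.getenv("hf_token") alone.
    return os.getenv("hf_token") or os.getenv("HF_TOKEN")

# from_pretrained(..., token=get_hf_token()) degrades to anonymous
# access when both variables are unset (token=None).
```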
@@ -109,56 +114,88 @@ def build_bad_words_ids_from_vocab(tok):
     ] + [f"<ID{i}>" for i in range(10)]
     out = []
     for c in candidates:
-        if c in vocab: out.append([vocab[c]]); continue
+        if c in vocab:
+            out.append([vocab[c]])
+            continue
         sp_c = "▁" + c
-        if sp_c in vocab: out.append([vocab[sp_c]])
+        if sp_c in vocab:
+            out.append([vocab[sp_c]])
     return out
 
-# --------------------- Translation ---------------------
+
+# --------------------- Streaming Translation ---------------------
+BATCH_SIZE = 6
+
 @torch.inference_mode()
-def _translate(text: str, tgt_lang: str, model_choice: str,
-               num_beams=4, max_new=128, batch_size=3) -> str:
-
+def translate_dual_stream(text, model_choice, num_beams, max_new):
+    """
+    Generator that yields (hindi_accumulated_text, telugu_accumulated_text)
+    after each processed batch so the UI updates progressively.
+    """
+    if not text or not text.strip():
+        yield "", ""
+        return
+
+    # Prepare once
     tok, mdl = load_model(MODELS[model_choice])
     BAD_WORDS_IDS = build_bad_words_ids_from_vocab(tok)
-
     sentences = split_into_sentences(text)
-    full_trans = []
-
-    for i in range(0, len(sentences), batch_size):
-        batch = sentences[i:i+batch_size]
-        proc = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=tgt_lang)
-        enc = tok(proc, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
-
-        out = mdl.generate(
-            **enc, max_length=max_new, num_beams=num_beams,
-            early_stopping=True, no_repeat_ngram_size=3, use_cache=False,
-            bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
-        )
 
-        decoded = tok.batch_decode(out, skip_special_tokens=True)
-        decoded = [strip_lang_tags(t) for t in decoded]
-        post = ip.postprocess_batch(decoded, lang=tgt_lang)
-
-        if tgt_lang == HI_CODE:
-            post = [ensure_hindi_danda(x) for x in post]
-
-        full_trans.extend(p.strip() for p in post)
-
-    return " ".join(full_trans)
-
-
-def translate_dual(text, model_choice, num_beams, max_new):
-    if not text.strip(): return "", ""
-    try:
-        hi = _translate(text, HI_CODE, model_choice, num_beams=num_beams, max_new=max_new)
-    except Exception as e:
-        hi = f"⚠️ Hindi failed: {e}"
-    try:
-        te = _translate(text, TE_CODE, model_choice, num_beams=num_beams, max_new=max_new)
-    except Exception as e:
-        te = f"⚠️ Telugu failed: {e}"
-    return hi, te
+    hi_acc, te_acc = [], []
+
+    # Clear outputs immediately for a snappy feel
+    yield "", ""
+
+    for i in range(0, len(sentences), BATCH_SIZE):
+        batch = sentences[i:i + BATCH_SIZE]
+
+        # --- Hindi batch ---
+        try:
+            proc_hi = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=HI_CODE)
+            enc_hi = tok(
+                proc_hi, padding=True, truncation=True, max_length=256, return_tensors="pt"
+            ).to(device)
+            out_hi = mdl.generate(
+                **enc_hi,
+                max_length=max_new,  # keeps the semantics of the original
+                num_beams=int(num_beams),
+                early_stopping=True,
+                no_repeat_ngram_size=3,
+                use_cache=False,
+                bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
+            )
+            dec_hi = tok.batch_decode(out_hi, skip_special_tokens=True)
+            dec_hi = [strip_lang_tags(t) for t in dec_hi]
+            post_hi = ip.postprocess_batch(dec_hi, lang=HI_CODE)
+            post_hi = [ensure_hindi_danda(x) for x in post_hi]
+            hi_acc.extend(p.strip() for p in post_hi)
+        except Exception as e:
+            hi_acc.append(f"⚠️ Hindi failed (batch {i//BATCH_SIZE+1}): {e}")
+
+        # --- Telugu batch ---
+        try:
+            proc_te = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=TE_CODE)
+            enc_te = tok(
+                proc_te, padding=True, truncation=True, max_length=256, return_tensors="pt"
+            ).to(device)
+            out_te = mdl.generate(
+                **enc_te,
+                max_length=max_new,
+                num_beams=int(num_beams),
+                early_stopping=True,
+                no_repeat_ngram_size=3,
+                use_cache=False,
+                bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
+            )
+            dec_te = tok.batch_decode(out_te, skip_special_tokens=True)
+            dec_te = [strip_lang_tags(t) for t in dec_te]
+            post_te = ip.postprocess_batch(dec_te, lang=TE_CODE)
+            te_acc.extend(p.strip() for p in post_te)
+        except Exception as e:
+            te_acc.append(f"⚠️ Telugu failed (batch {i//BATCH_SIZE+1}): {e}")
+
+        # Stream the accumulators so far
+        yield (" ".join(hi_acc), " ".join(te_acc))
 
 
 # --------------------- Dark Theme ---------------------
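Both sides of the first change in this hunk do a two-way vocabulary lookup: SentencePiece vocabularies store word-initial tokens with a leading "▁", so each candidate is tried verbatim and then with that prefix. A self-contained toy illustration (the vocab here is made up; the real one comes from the IndicTrans2 tokenizer's `get_vocab()`):

```python
# Toy vocab standing in for tok.get_vocab(); IDs are arbitrary.
vocab = {"▁hello": 7, "world": 8, "▁<ID0>": 9}

def toy_bad_words_ids(candidates, vocab):
    out = []
    for c in candidates:
        if c in vocab:           # exact match first
            out.append([vocab[c]])
            continue
        sp_c = "▁" + c           # then the SentencePiece word-initial form
        if sp_c in vocab:
            out.append([vocab[sp_c]])
    return out

print(toy_bad_words_ids(["hello", "world", "<ID0>", "absent"], vocab))
# -> [[7], [8], [9]]; each inner list is one banned token sequence
#    in the format generate(bad_words_ids=...) expects.
```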
@@ -169,7 +206,7 @@ THEME = gr.themes.Soft(
     body_text_color="#f3f4f6",
     block_background_fill="#111827",
     block_border_color="#1f2937",
-    block_title_text_color="#e5e7eb",
+    block_title_text_color="#123456",
     button_primary_background_fill="#2563eb",
     button_primary_text_color="#ffffff",
 )
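One carry-over from the translation hunk further above, flagged by its in-diff comment: `generate(..., max_length=max_new)` bounds the total decoder length, while the slider feeding it is labeled "Max New Tokens". In `transformers`, the knob that counts only freshly generated tokens is `max_new_tokens`. A runnable sketch of the difference, using a small public model rather than IndicTrans2:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# t5-small is only a stand-in so the example runs anywhere.
tok = AutoTokenizer.from_pretrained("t5-small")
mdl = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
enc = tok("translate English to German: How are you?", return_tensors="pt")

# max_length caps the TOTAL decoder sequence length:
a = mdl.generate(**enc, max_length=32)

# max_new_tokens counts only newly generated tokens -- the reading
# a "Max New Tokens" slider suggests:
b = mdl.generate(**enc, max_new_tokens=32)

print(tok.batch_decode(a, skip_special_tokens=True))
print(tok.batch_decode(b, skip_special_tokens=True))
```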
@@ -262,11 +299,13 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as
     num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd")
     max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens", elem_id="model_dd")
 
+    # Use streaming generator
     translate_btn.click(
-        translate_dual,
+        translate_dual_stream,
         inputs=[src, model_choice, num_beams, max_new],
         outputs=[hi_out, te_out]
     )
     clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
 
+    # Enable queue for streaming
     demo.queue(max_size=48).launch()
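This hunk relies on Gradio's generator protocol: when a click handler is a generator function, every `yield` replaces the current values of the output components, and a queue must be enabled for the updates to stream to the browser. A minimal self-contained sketch of the same pattern (the component names here are illustrative, not the app's):

```python
import time
import gradio as gr

def count_stream(n):
    # Each yield overwrites the Textbox, so the UI shows
    # "1", then "1 2", ... instead of waiting for the end.
    acc = []
    for i in range(1, int(n) + 1):
        acc.append(str(i))
        time.sleep(0.2)
        yield " ".join(acc)

with gr.Blocks() as demo:
    n = gr.Slider(1, 10, value=5, step=1, label="Count to")
    out = gr.Textbox(label="Stream")
    gr.Button("Go").click(count_stream, inputs=n, outputs=out)

demo.queue(max_size=48).launch()
```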
 