SagarVelamuri commited on
Commit
cd9cc66
·
verified ·
1 Parent(s): 448b0e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -26
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, re, types, traceback, torch, gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from IndicTransToolkit import IndicProcessor
4
  import spacy
@@ -14,8 +14,7 @@ TE_CODE = "tel_Telu"
14
 
15
  ip = IndicProcessor(inference=True)
16
 
17
- # --------------------- Sentence Splitting (spaCy) ---------------------
18
- import spacy
19
  try:
20
  nlp = spacy.load("en_core_web_sm")
21
  except OSError:
@@ -23,13 +22,44 @@ except OSError:
23
  download("en_core_web_sm")
24
  nlp = spacy.load("en_core_web_sm")
25
 
26
-
27
  def split_into_sentences(text):
28
  """Split English text into sentences using spaCy."""
29
  doc = nlp(text.strip())
30
  return [sent.text.strip() for sent in doc.sents if sent.text.strip()]
31
 
32
- # --------------------- Cleanup Helper ---------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def clean_translation(text):
34
  """Remove unresolved placeholder tags such as <ID1>, <ID2>."""
35
  return re.sub(r"<ID\d+>", "", text).strip()
@@ -57,7 +87,6 @@ def load_model(model_name: str):
57
  low_cpu_mem_usage=True, dtype=dtype, token=token
58
  ).to(device).eval()
59
 
60
- # Fix vocab mismatch if any
61
  try:
62
  mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
63
  except Exception:
@@ -66,23 +95,25 @@ def load_model(model_name: str):
66
  _model_cache[model_name] = (tok, mdl)
67
  return tok, mdl
68
 
69
- # --------------------- Streaming Translation ---------------------
70
  @torch.inference_mode()
71
  def translate_dual_stream(text, model_choice, num_beams, max_new):
72
- """Generator that yields progressive Hindi & Telugu translations one sentence at a time."""
73
  if not text or not text.strip():
74
  yield "", ""
75
  return
76
 
77
  tok, mdl = load_model(MODELS[model_choice])
 
 
 
78
  sentences = split_into_sentences(text)
79
  hi_acc, te_acc = [], []
80
 
81
- # Yield empty for immediate UI update
82
- yield "", ""
83
 
84
  for i, sentence in enumerate(sentences, 1):
85
- # --- Hindi Translation ---
86
  try:
87
  batch_hi = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=HI_CODE)
88
  enc_hi = tok(batch_hi, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
@@ -97,11 +128,17 @@ def translate_dual_stream(text, model_choice, num_beams, max_new):
97
  )
98
  dec_hi = tok.batch_decode(out_hi, skip_special_tokens=True, clean_up_tokenization_spaces=True)
99
  post_hi = ip.postprocess_batch(dec_hi, lang=HI_CODE)
100
- hi_acc.append(clean_translation(post_hi[0]))
 
 
 
 
 
 
101
  except Exception as e:
102
  hi_acc.append(f"⚠️ Hindi failed (sentence {i}): {e}")
103
 
104
- # --- Telugu Translation ---
105
  try:
106
  batch_te = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=TE_CODE)
107
  enc_te = tok(batch_te, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
@@ -120,7 +157,6 @@ def translate_dual_stream(text, model_choice, num_beams, max_new):
120
  except Exception as e:
121
  te_acc.append(f"⚠️ Telugu failed (sentence {i}): {e}")
122
 
123
- # Stream progressive output
124
  yield (" ".join(hi_acc), " ".join(te_acc))
125
 
126
  # --------------------- Dark Theme ---------------------
@@ -148,7 +184,7 @@ CUSTOM_CSS = """
148
  textarea { background:#0b0f19 !important; color:#f9fafb !important; border-radius:8px !important; border:1px solid #374151 !important; font-size:15px !important; line-height:1.55; }
149
  button { border-radius:8px !important; font-weight:600 !important; }
150
 
151
- /* Make all component labels readable on dark bg */
152
  .gradio-container label,
153
  .gradio-container .label,
154
  .gradio-container .block-title,
@@ -157,7 +193,7 @@ button { border-radius:8px !important; font-weight:600 !important; }
157
  color:#093999 !important;
158
  }
159
 
160
- /* --- Dropdown: dark text on white field/menu --- */
161
  #model_dd .wrap,
162
  #model_dd .container {
163
  background:#111827 !important;
@@ -169,19 +205,19 @@ button { border-radius:8px !important; font-weight:600 !important; }
169
  #model_dd ::placeholder,
170
  #model_dd select,
171
  #model_dd option {
172
- color: #ffffff!important; /* dark text */
173
  background:#111827 !important;
174
  }
175
  #model_dd .options,
176
  #model_dd .options .item {
177
  background:#111827 !important;
178
- color: #ffffff !important;
179
  }
180
- #model_dd label { /* the component's own label */
181
  color:#efe4b0 !important;
182
  }
183
 
184
- /* Sliders: keep labels visible */
185
  .gradio-container .range-block label,
186
  .gradio-container .gr-slider label {
187
  color:#efe4b0 !important;
@@ -192,7 +228,7 @@ button { border-radius:8px !important; font-weight:600 !important; }
192
  with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
193
  with gr.Group(elem_id="hdr"):
194
  gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
195
- gr.Markdown("<p>IndicTrans2 with simplified preprocessing and sentence-wise translation</p>")
196
 
197
  model_choice = gr.Dropdown(
198
  label="Choose Model",
@@ -205,10 +241,10 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as
205
  with gr.Column(scale=2):
206
  with gr.Group(elem_classes="panel"):
207
  gr.Markdown("<h2>English Input</h2>")
208
- src = gr.Textbox(lines=12, placeholder="Enter English...", show_label=False)
209
  with gr.Row():
210
  translate_btn = gr.Button("Translate", variant="primary")
211
- clear_btn = gr.Button("Clear", variant="secondary")
212
 
213
  with gr.Column(scale=2):
214
  with gr.Group(elem_classes="panel"):
@@ -222,15 +258,14 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as
222
  with gr.Group(elem_classes="panel"):
223
  gr.Markdown("<h2>Settings</h2>")
224
  num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd")
225
- max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens", elem_id="model_dd")
226
 
227
- # Stream generator connection
228
  translate_btn.click(
229
  translate_dual_stream,
230
  inputs=[src, model_choice, num_beams, max_new],
231
  outputs=[hi_out, te_out]
232
  )
 
233
  clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
234
 
235
- # Enable queue for streaming
236
  demo.queue(max_size=48).launch()
 
1
+ import os, re, torch, gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from IndicTransToolkit import IndicProcessor
4
  import spacy
 
14
 
15
  ip = IndicProcessor(inference=True)
16
 
17
+ # --------------------- spaCy Sentence Splitter ---------------------
 
18
  try:
19
  nlp = spacy.load("en_core_web_sm")
20
  except OSError:
 
22
  download("en_core_web_sm")
23
  nlp = spacy.load("en_core_web_sm")
24
 
 
25
  def split_into_sentences(text):
26
  """Split English text into sentences using spaCy."""
27
  doc = nlp(text.strip())
28
  return [sent.text.strip() for sent in doc.sents if sent.text.strip()]
29
 
30
+ # --------------------- Abbreviation Expansion ---------------------
31
+ ABBREVIATION_MAP = {
32
+ "subs.": "subsection",
33
+ "cl.": "clause",
34
+ "art.": "article",
35
+ "sec.": "section",
36
+ "s.": "section",
37
+ "no.": "number",
38
+ "sch.": "schedule",
39
+ "para.": "paragraph",
40
+ "r.": "rule",
41
+ "reg.": "regulation",
42
+ "dept.": "department",
43
+ }
44
+
45
+ _ABBR_PATTERN = re.compile(
46
+ r'(?<![A-Za-z])(' + '|'.join(re.escape(k) for k in ABBREVIATION_MAP.keys()) + r')(?=\s*(?:\(|\d|[A-Z]|[a-z]))',
47
+ flags=re.IGNORECASE
48
+ )
49
+
50
+ def expand_abbreviations(text: str) -> str:
51
+ """Replace known abbreviations with full forms safely (without affecting natural words)."""
52
+ def replacer(match):
53
+ key = match.group(0)
54
+ repl = ABBREVIATION_MAP.get(key.lower(), key)
55
+ if key.isupper():
56
+ return repl.upper()
57
+ elif key[0].isupper():
58
+ return repl.capitalize()
59
+ return repl
60
+ return _ABBR_PATTERN.sub(replacer, text)
61
+
62
+ # --------------------- Clean Up Placeholder Tags ---------------------
63
  def clean_translation(text):
64
  """Remove unresolved placeholder tags such as <ID1>, <ID2>."""
65
  return re.sub(r"<ID\d+>", "", text).strip()
 
87
  low_cpu_mem_usage=True, dtype=dtype, token=token
88
  ).to(device).eval()
89
 
 
90
  try:
91
  mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
92
  except Exception:
 
95
  _model_cache[model_name] = (tok, mdl)
96
  return tok, mdl
97
 
98
+ # --------------------- Translation ---------------------
99
  @torch.inference_mode()
100
  def translate_dual_stream(text, model_choice, num_beams, max_new):
101
+ """Stream Hindi and Telugu translations, one sentence at a time."""
102
  if not text or not text.strip():
103
  yield "", ""
104
  return
105
 
106
  tok, mdl = load_model(MODELS[model_choice])
107
+
108
+ # Expand known abbreviations
109
+ text = expand_abbreviations(text)
110
  sentences = split_into_sentences(text)
111
  hi_acc, te_acc = [], []
112
 
113
+ yield "", "" # Clear UI early
 
114
 
115
  for i, sentence in enumerate(sentences, 1):
116
+ # --- Hindi ---
117
  try:
118
  batch_hi = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=HI_CODE)
119
  enc_hi = tok(batch_hi, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
 
128
  )
129
  dec_hi = tok.batch_decode(out_hi, skip_special_tokens=True, clean_up_tokenization_spaces=True)
130
  post_hi = ip.postprocess_batch(dec_hi, lang=HI_CODE)
131
+ hi_text = clean_translation(post_hi[0])
132
+
133
+ # Optionally ensure danda for Hindi if missing
134
+ if not re.search(r"[।?!…]$", hi_text):
135
+ hi_text += "।"
136
+
137
+ hi_acc.append(hi_text)
138
  except Exception as e:
139
  hi_acc.append(f"⚠️ Hindi failed (sentence {i}): {e}")
140
 
141
+ # --- Telugu ---
142
  try:
143
  batch_te = ip.preprocess_batch([sentence], src_lang=SRC_CODE, tgt_lang=TE_CODE)
144
  enc_te = tok(batch_te, max_length=256, truncation=True, padding=True, return_tensors="pt").to(device)
 
157
  except Exception as e:
158
  te_acc.append(f"⚠️ Telugu failed (sentence {i}): {e}")
159
 
 
160
  yield (" ".join(hi_acc), " ".join(te_acc))
161
 
162
  # --------------------- Dark Theme ---------------------
 
184
  textarea { background:#0b0f19 !important; color:#f9fafb !important; border-radius:8px !important; border:1px solid #374151 !important; font-size:15px !important; line-height:1.55; }
185
  button { border-radius:8px !important; font-weight:600 !important; }
186
 
187
+ /* Labels */
188
  .gradio-container label,
189
  .gradio-container .label,
190
  .gradio-container .block-title,
 
193
  color:#093999 !important;
194
  }
195
 
196
+ /* Dropdown Styling */
197
  #model_dd .wrap,
198
  #model_dd .container {
199
  background:#111827 !important;
 
205
  #model_dd ::placeholder,
206
  #model_dd select,
207
  #model_dd option {
208
+ color:#ffffff!important;
209
  background:#111827 !important;
210
  }
211
  #model_dd .options,
212
  #model_dd .options .item {
213
  background:#111827 !important;
214
+ color:#ffffff !important;
215
  }
216
+ #model_dd label {
217
  color:#efe4b0 !important;
218
  }
219
 
220
+ /* Slider labels */
221
  .gradio-container .range-block label,
222
  .gradio-container .gr-slider label {
223
  color:#efe4b0 !important;
 
228
  with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
229
  with gr.Group(elem_id="hdr"):
230
  gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
231
+ gr.Markdown("<p>IndicTrans2 with abbreviation expansion and sentence-wise translation</p>")
232
 
233
  model_choice = gr.Dropdown(
234
  label="Choose Model",
 
241
  with gr.Column(scale=2):
242
  with gr.Group(elem_classes="panel"):
243
  gr.Markdown("<h2>English Input</h2>")
244
+ src = gr.Textbox(lines=12, placeholder="Enter English text...", show_label=False)
245
  with gr.Row():
246
  translate_btn = gr.Button("Translate", variant="primary")
247
+ clear_btn = gr.Button("Clear", variant="secondary")
248
 
249
  with gr.Column(scale=2):
250
  with gr.Group(elem_classes="panel"):
 
258
  with gr.Group(elem_classes="panel"):
259
  gr.Markdown("<h2>Settings</h2>")
260
  num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search", elem_id="model_dd")
261
+ max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens", elem_id="model_dd")
262
 
 
263
  translate_btn.click(
264
  translate_dual_stream,
265
  inputs=[src, model_choice, num_beams, max_new],
266
  outputs=[hi_out, te_out]
267
  )
268
+
269
  clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
270
 
 
271
  demo.queue(max_size=48).launch()