SagarVelamuri commited on
Commit
4151903
·
verified ·
1 Parent(s): 683b0a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -123
app.py CHANGED
@@ -1,177 +1,228 @@
1
- import os, traceback, types, torch, gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
 
4
- # Robust import for IndicProcessor
5
- try:
6
- from IndicTransToolkit import IndicProcessor
7
- except Exception:
8
- from IndicTransToolkit.IndicTransToolkit import IndicProcessor
9
-
10
-
11
- # -------- Config --------
12
- TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
13
- MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
14
- TOKENIZER_REV, MODEL_REV = os.getenv("TOKENIZER_REV"), os.getenv("MODEL_REV")
15
 
 
16
  SRC_CODE = "eng_Latn"
17
  HI_CODE = "hin_Deva"
18
  TE_CODE = "tel_Telu"
19
 
20
- # -------- Model Load --------
21
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
23
 
24
- tok_kwargs = dict(trust_remote_code=True, use_fast=True)
25
- if TOKENIZER_REV: tok_kwargs["revision"] = TOKENIZER_REV
26
- tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
 
 
 
 
27
 
28
- mdl_kwargs = dict(trust_remote_code=True, attn_implementation="eager",
29
- low_cpu_mem_usage=True, dtype=dtype)
30
- if MODEL_REV: mdl_kwargs["revision"] = MODEL_REV
31
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device).eval()
32
 
33
- # Ensure generation config is correct
34
- if getattr(model.generation_config, "pad_token_id", None) is None:
35
- model.generation_config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
36
- if getattr(model.generation_config, "eos_token_id", None) is None and tokenizer.eos_token_id is not None:
37
- model.generation_config.eos_token_id = tokenizer.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- def _ensure_vocab_consistency(md, tok):
40
- try:
41
- actual_vocab = md.get_output_embeddings().weight.shape[0]
42
- except Exception: actual_vocab = None
43
- if actual_vocab:
44
- md.config.vocab_size = actual_vocab
45
- md.generation_config.vocab_size = actual_vocab
46
- else:
47
- vs = getattr(tok, "vocab_size", len(tok) if hasattr(tok, "__len__") else 64000)
48
- md.config.vocab_size = vs
49
- md.generation_config.vocab_size = vs
50
- if not hasattr(md.config, "get_text_config"):
51
- md.config.get_text_config = types.MethodType(lambda self: self, md.config)
52
-
53
- _ensure_vocab_consistency(model, tokenizer)
54
-
55
- for obj in (model.config, model.generation_config):
56
- try: setattr(obj, "use_cache", False)
57
- except: pass
58
-
59
- # Processor
60
- ip = IndicProcessor(inference=True)
61
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- # -------- Inference --------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  @torch.inference_mode()
65
- def _translate_to_lang(text, tgt_code, num_beams, max_new_tokens, temperature, top_p, top_k):
66
- batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
67
- enc = tokenizer(batch, max_length=256, truncation=True, padding="longest",
68
- return_tensors="pt").to(device)
69
- do_sample = (temperature and float(temperature) > 0)
70
- out = model.generate(
71
- **enc,
72
- max_new_tokens=int(max_new_tokens),
73
- num_beams=int(num_beams),
74
- do_sample=do_sample,
75
- temperature=float(temperature) if do_sample else None,
76
- top_p=float(top_p) if do_sample else None,
77
- top_k=int(top_k) if do_sample else None,
78
- use_cache=False,
79
- )
80
- decoded = tokenizer.batch_decode(out, skip_special_tokens=True)
81
- final = ip.postprocess_batch(decoded, lang=tgt_code)
82
- return final[0].strip()
 
 
 
 
 
83
 
84
- def translate_dual(text, num_beams, max_new_tokens, temperature, top_p, top_k):
85
- text = text.strip()
86
- if not text: return "", ""
 
 
 
 
 
 
 
87
  try:
88
- hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
89
  except Exception as e:
90
- hi = f"⚠️ Hindi error: {type(e).__name__}: {str(e).splitlines()[-1]}"
91
  try:
92
- te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
93
  except Exception as e:
94
- te = f"⚠️ Telugu error: {type(e).__name__}: {str(e).splitlines()[-1]}"
95
  return hi, te
96
 
97
 
98
- # -------- Theme & Styling --------
99
- THEME = gr.themes.Base(
100
  primary_hue="blue", neutral_hue="slate"
101
  ).set(
102
- body_background_fill="#f9fafb",
103
- body_text_color="#111827",
104
- block_background_fill="#ffffff",
105
- block_border_color="#e5e7eb",
106
- block_title_text_color="#111827",
107
  button_primary_background_fill="#2563eb",
108
- button_primary_text_color="#ffffff"
109
  )
110
 
111
  CUSTOM_CSS = """
112
- #hdr {
113
- text-align:center; padding:16px; margin-bottom:16px;
114
- }
115
- #hdr h1 { font-size:24px; font-weight:700; margin:0; color:#111827; }
116
- #hdr p { font-size:14px; color:#6b7280; margin:4px 0 0; }
117
-
118
- .panel {
119
- border:1px solid #e5e7eb; border-radius:12px;
120
- background:white; box-shadow:0 1px 3px rgba(0,0,0,0.08);
121
- padding:12px; display:flex; flex-direction:column;
122
- }
123
- .panel h2 {
124
- font-size:16px; font-weight:600; margin-bottom:8px; color:#374151;
125
- }
126
- textarea {
127
- font-size:15px !important; line-height:1.55 !important;
128
- padding:10px 12px !important;
129
- border:1px solid #d1d5db !important; border-radius:8px !important;
130
- }
131
- button { font-weight:600 !important; border-radius:8px !important; }
132
- button:hover { opacity:0.9; transition:opacity 0.2s; }
133
  """
134
 
135
- # -------- UI --------
136
- with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → Hindi / Telugu Translator") as demo:
137
  with gr.Group(elem_id="hdr"):
138
  gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
139
- gr.Markdown("<p>Powered by IndicTrans2 · law-ai/InLegalTrans-En2Indic-1B</p>")
 
 
 
 
 
140
 
141
  with gr.Row():
142
- # Input Column
143
  with gr.Column(scale=2):
144
  with gr.Group(elem_classes="panel"):
145
  gr.Markdown("<h2>English Input</h2>")
146
- src = gr.Textbox(placeholder="Type English text...", lines=12, show_label=False)
147
-
148
  with gr.Row():
149
  translate_btn = gr.Button("👉 Translate", variant="primary")
150
  clear_btn = gr.Button("Clear", variant="secondary")
151
 
152
- # Output Column
153
  with gr.Column(scale=2):
154
  with gr.Group(elem_classes="panel"):
155
  gr.Markdown("<h2>Hindi Translation</h2>")
156
  hi_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
157
-
158
  with gr.Group(elem_classes="panel"):
159
  gr.Markdown("<h2>Telugu Translation</h2>")
160
  te_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
161
 
162
- # Settings Column
163
  with gr.Column(scale=1):
164
  with gr.Group(elem_classes="panel"):
165
- gr.Markdown("<h2>Advanced Settings</h2>")
166
- num_beams = gr.Slider(1, 8, value=4, step=1, label="Num Beams")
167
- max_new = gr.Slider(16, 512, value=128, step=8, label="Max Tokens")
168
- temperature = gr.Slider(0.0, 1.5, value=0.0, step=0.05, label="Temperature")
169
- top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Top-p")
170
- top_k = gr.Slider(0, 100, value=50, step=1, label="Top-k")
171
-
172
- # Wiring
173
- translate_btn.click(translate_dual, inputs=[src, num_beams, max_new, temperature, top_p, top_k],
174
- outputs=[hi_out, te_out])
175
  clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
176
 
177
  demo.queue(max_size=48).launch()
 
1
+ import os, re, types, traceback, torch, gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from IndicTransToolkit import IndicProcessor
4
 
5
+ # --------------------- Device ---------------------
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 
 
 
 
 
 
 
8
 
9
+ # --------------------- Languages ------------------
10
  SRC_CODE = "eng_Latn"
11
  HI_CODE = "hin_Deva"
12
  TE_CODE = "tel_Telu"
13
 
14
+ ip = IndicProcessor(inference=True)
 
 
15
 
16
+ # --------------------- Regex / Helpers ---------------------
17
+ TAG_REGEX = re.compile(
18
+ r"(?:_src\S+)|(?:tgt\S+)|"
19
+ r"(?:>>\s*\S+\s*<<)|"
20
+ r"\b(?:eng_Latn|hin_Deva|hin_deva|tel_Telu|tel_telu)\b|"
21
+ r"<ID\d*>"
22
+ )
23
 
24
+ def strip_lang_tags(text: str) -> str:
25
+ s = TAG_REGEX.sub(" ", text)
26
+ return re.sub(r"\s{2,}", " ", s).strip()
 
27
 
28
+ def ensure_hindi_danda(s: str) -> str:
29
+ s = re.sub(r"\.\s*$", "", s)
30
+ if not re.search(r"[।?!…]\s*$", s) and re.search(r"[\u0900-\u097F]\s*$", s):
31
+ s += ""
32
+ return s
33
+
34
+ # Sentence splitting (pysbd or fallback)
35
+ try:
36
+ import pysbd
37
+ _SEGMENTER = pysbd.Segmenter(language="en", clean=True)
38
+ except Exception:
39
+ _SEGMENTER = None
40
+
41
+ _LEGAL_JOIN_RE = re.compile(r'\b([A-Za-z]{1,6})\.\s*$')
42
+ _NEXT_CONT_RE = re.compile(r'^\s*(?:[\(\[\{]|\d|[a-z])')
43
+
44
+ def _merge_legal_abbrev_breaks(sents):
45
+ merged, i = [], 0
46
+ while i < len(sents):
47
+ cur = sents[i].strip()
48
+ while i + 1 < len(sents):
49
+ nxt = sents[i + 1].lstrip()
50
+ if _LEGAL_JOIN_RE.search(cur) and _NEXT_CONT_RE.match(nxt):
51
+ cur = f"{cur} {nxt}"
52
+ i += 1
53
+ else:
54
+ break
55
+ merged.append(cur)
56
+ i += 1
57
+ return [s for s in merged if s]
58
+
59
+ def split_into_sentences(text: str):
60
+ if _SEGMENTER is not None:
61
+ return _merge_legal_abbrev_breaks(_SEGMENTER.segment(text))
62
+ PLACEHOLDER = "\uE000"
63
+ protected = re.sub(
64
+ r'\b([A-Za-z]{1,6})\.(?=\s*(?:[\(\[\{]|\d|[a-z]))',
65
+ r'\1' + PLACEHOLDER, text.strip()
66
+ )
67
+ protected = re.sub(
68
+ r'\b([A-Za-z]{1,5})\.(?=\s+[A-Z])',
69
+ r'\1' + PLACEHOLDER, protected
70
+ )
71
+ parts = re.split(r'(?<=[.?!])\s+', protected)
72
+ return _merge_legal_abbrev_breaks([p.replace(PLACEHOLDER, '.') for p in parts if p.strip()])
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ # --------------------- Model Loader ---------------------
76
+ MODELS = {
77
+ "Default (Public)": "law-ai/InLegalTrans-En2Indic-1B",
78
+ "Fine-tuned (Private)": "SagarVelamuri/InLegalTrans-En2Indic-FineTuned-Tel-Hin"
79
+ }
80
+
81
+ _model_cache = {}
82
+
83
+ def load_model(model_name: str):
84
+ if model_name in _model_cache:
85
+ return _model_cache[model_name]
86
 
87
+ tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
88
+ mdl = AutoModelForSeq2SeqLM.from_pretrained(
89
+ model_name, trust_remote_code=True,
90
+ low_cpu_mem_usage=True, dtype=dtype
91
+ ).to(device).eval()
92
+
93
+ # Fix vocab
94
+ try:
95
+ mdl.config.vocab_size = mdl.get_output_embeddings().weight.shape[0]
96
+ except Exception:
97
+ pass
98
+
99
+ _model_cache[model_name] = (tok, mdl)
100
+ return tok, mdl
101
+
102
+
103
+ def build_bad_words_ids_from_vocab(tok):
104
+ vocab = tok.get_vocab()
105
+ candidates = [
106
+ "eng_Latn","hin_Deva","hin_deva","tel_Telu","tel_telu",
107
+ "_srceng_Latn","tgthin_Deva","tgt_tel_Telu",
108
+ ">>hin_Deva<<",">>tel_Telu<<",
109
+ ] + [f"<ID{i}>" for i in range(10)]
110
+ out = []
111
+ for c in candidates:
112
+ if c in vocab: out.append([vocab[c]]); continue
113
+ sp_c = "▁" + c
114
+ if sp_c in vocab: out.append([vocab[sp_c]])
115
+ return out
116
+
117
+ # --------------------- Translation ---------------------
118
  @torch.inference_mode()
119
+ def _translate(text: str, tgt_lang: str, model_choice: str,
120
+ num_beams=4, max_new=128, batch_size=3) -> str:
121
+
122
+ tok, mdl = load_model(MODELS[model_choice])
123
+ BAD_WORDS_IDS = build_bad_words_ids_from_vocab(tok)
124
+
125
+ sentences = split_into_sentences(text)
126
+ full_trans = []
127
+
128
+ for i in range(0, len(sentences), batch_size):
129
+ batch = sentences[i:i+batch_size]
130
+ proc = ip.preprocess_batch(batch, src_lang=SRC_CODE, tgt_lang=tgt_lang)
131
+ enc = tok(proc, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
132
+
133
+ out = mdl.generate(
134
+ **enc, max_length=max_new, num_beams=num_beams,
135
+ early_stopping=True, no_repeat_ngram_size=3, use_cache=False,
136
+ bad_words_ids=BAD_WORDS_IDS if BAD_WORDS_IDS else None
137
+ )
138
+
139
+ decoded = tok.batch_decode(out, skip_special_tokens=True)
140
+ decoded = [strip_lang_tags(t) for t in decoded]
141
+ post = ip.postprocess_batch(decoded, lang=tgt_lang)
142
 
143
+ if tgt_lang == HI_CODE:
144
+ post = [ensure_hindi_danda(x) for x in post]
145
+
146
+ full_trans.extend(p.strip() for p in post)
147
+
148
+ return " ".join(full_trans)
149
+
150
+
151
+ def translate_dual(text, model_choice, num_beams, max_new):
152
+ if not text.strip(): return "", ""
153
  try:
154
+ hi = _translate(text, HI_CODE, model_choice, num_beams=num_beams, max_new=max_new)
155
  except Exception as e:
156
+ hi = f"⚠️ Hindi failed: {e}"
157
  try:
158
+ te = _translate(text, TE_CODE, model_choice, num_beams=num_beams, max_new=max_new)
159
  except Exception as e:
160
+ te = f"⚠️ Telugu failed: {e}"
161
  return hi, te
162
 
163
 
164
+ # --------------------- Dark Theme ---------------------
165
+ THEME = gr.themes.Soft(
166
  primary_hue="blue", neutral_hue="slate"
167
  ).set(
168
+ body_background_fill="#0b0f19",
169
+ body_text_color="#f3f4f6",
170
+ block_background_fill="#111827",
171
+ block_border_color="#1f2937",
172
+ block_title_text_color="#e5e7eb",
173
  button_primary_background_fill="#2563eb",
174
+ button_primary_text_color="#ffffff",
175
  )
176
 
177
  CUSTOM_CSS = """
178
+ #hdr { text-align:center; padding:16px; }
179
+ #hdr h1 { font-size:24px; font-weight:700; color:#f9fafb; margin:0; }
180
+ #hdr p { font-size:14px; color:#9ca3af; margin-top:4px; }
181
+ .panel { border:1px solid #1f2937; border-radius:10px; padding:12px; background:#111827; box-shadow:0 1px 2px rgba(0,0,0,0.4);}
182
+ .panel h2 { font-size:16px; font-weight:600; margin-bottom:6px; color:#f3f4f6; }
183
+ textarea { background:#0b0f19 !important; color:#f9fafb !important; border-radius:8px !important; border:1px solid #374151 !important; font-size:15px !important; line-height:1.55; }
184
+ button { border-radius:8px !important; font-weight:600 !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  """
186
 
187
+ # --------------------- UI ---------------------
188
+ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN → HI/TE Translator") as demo:
189
  with gr.Group(elem_id="hdr"):
190
  gr.Markdown("<h1>English → Hindi & Telugu Translator</h1>")
191
+ gr.Markdown("<p>IndicTrans2 with batch sentence decomposition</p>")
192
+
193
+ model_choice = gr.Dropdown(
194
+ label="Choose Model", choices=list(MODELS.keys()),
195
+ value="Default (Public)"
196
+ )
197
 
198
  with gr.Row():
 
199
  with gr.Column(scale=2):
200
  with gr.Group(elem_classes="panel"):
201
  gr.Markdown("<h2>English Input</h2>")
202
+ src = gr.Textbox(lines=12, placeholder="Enter English...", show_label=False)
 
203
  with gr.Row():
204
  translate_btn = gr.Button("👉 Translate", variant="primary")
205
  clear_btn = gr.Button("Clear", variant="secondary")
206
 
 
207
  with gr.Column(scale=2):
208
  with gr.Group(elem_classes="panel"):
209
  gr.Markdown("<h2>Hindi Translation</h2>")
210
  hi_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
 
211
  with gr.Group(elem_classes="panel"):
212
  gr.Markdown("<h2>Telugu Translation</h2>")
213
  te_out = gr.Textbox(lines=6, show_copy_button=True, show_label=False)
214
 
 
215
  with gr.Column(scale=1):
216
  with gr.Group(elem_classes="panel"):
217
+ gr.Markdown("<h2>Settings</h2>")
218
+ num_beams = gr.Slider(1, 8, value=4, step=1, label="Beam Search")
219
+ max_new = gr.Slider(32, 512, value=128, step=16, label="Max New Tokens")
220
+
221
+ translate_btn.click(
222
+ translate_dual,
223
+ inputs=[src, model_choice, num_beams, max_new],
224
+ outputs=[hi_out, te_out]
225
+ )
 
226
  clear_btn.click(lambda: ("", "", ""), outputs=[src, hi_out, te_out])
227
 
228
  demo.queue(max_size=48).launch()