anisgtboi committed
Commit b08e390 · verified · 1 Parent(s): d3a831e

Update app.py

Files changed (1):
  1. app.py +75 -184
app.py CHANGED
@@ -1,206 +1,103 @@
 # app.py
-# English -> Bengali translation (facebook/nllb-200-distilled-600M) + FLUX.1 [schnell] image generation
-
 import os
 import re
-import traceback
 import random
-
 import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-# Diffusers + FluxPipeline
-try:
-    from diffusers import FluxPipeline
-    _FLUX_AVAILABLE = True
-except Exception:
-    FluxPipeline = None
-    _FLUX_AVAILABLE = False
-
-# -------- Configuration --------
-TRANSLATION_MODEL = os.environ.get("TRANSLATION_MODEL", "facebook/nllb-200-distilled-600M")
-SRC_LANG = os.environ.get("SRC_LANG", "eng_Latn")
-TGT_LANG = os.environ.get("TGT_LANG", "ben_Beng")
-MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
-
-FLUX_MODEL_ID = os.environ.get("FLUX_MODEL_ID", "black-forest-labs/FLUX.1-schnell")
-DEFAULT_IMAGE_STEPS = int(os.environ.get("DEFAULT_IMAGE_STEPS", "2"))
-
+# -------- Translation: Facebook NLLB --------
+TRANSLATION_MODEL = "facebook/nllb-200-distilled-600M"
+SRC_LANG = "eng_Latn"
+TGT_LANG = "ben_Beng"
+MAX_LENGTH = 512
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# -------- Globals --------
 _translation_tokenizer = None
 _translation_model = None
-_flux_pipe = None
-
-# -------- Helpers: translation --------
-
-def split_into_sentences(text: str):
-    if not text:
-        return []
-    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
-    return [s.strip() for s in sentences if s.strip()]
-
 
 def load_translation_model():
     global _translation_tokenizer, _translation_model
     if _translation_tokenizer is None or _translation_model is None:
-        try:
-            print(f"Loading translation model {TRANSLATION_MODEL} on {DEVICE}...")
-            _translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL, use_fast=False)
-            _translation_model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL).to(DEVICE)
-            print("Translation model loaded.")
-        except Exception as e:
-            _translation_tokenizer, _translation_model = None, None
-            raise RuntimeError(f"Failed to load translation model: {e}")
+        _translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL, use_fast=False)
+        _translation_model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL).to(DEVICE)
     return _translation_tokenizer, _translation_model
 
-
 def _get_forced_bos_token_id(tokenizer):
+    if hasattr(tokenizer, "lang_code_to_id") and TGT_LANG in tokenizer.lang_code_to_id:
+        return tokenizer.lang_code_to_id[TGT_LANG]
     try:
-        if hasattr(tokenizer, "lang_code_to_id") and isinstance(tokenizer.lang_code_to_id, dict):
-            if TGT_LANG in tokenizer.lang_code_to_id:
-                return tokenizer.lang_code_to_id[TGT_LANG]
-    except Exception:
-        pass
+        return tokenizer.convert_tokens_to_ids(TGT_LANG)
+    except:
+        return None
 
-    try:
-        tid = tokenizer.convert_tokens_to_ids(TGT_LANG)
-        if tid is not None and tid != tokenizer.unk_token_id:
-            return tid
-    except Exception:
-        pass
-
-    try:
-        candidate = f"<2{TGT_LANG}>"
-        tid = tokenizer.convert_tokens_to_ids(candidate)
-        if tid is not None and tid != tokenizer.unk_token_id:
-            return tid
-    except Exception:
-        pass
-
-    return None
-
-
-def translate_text(text: str, max_length: int = MAX_LENGTH):
-    if not text or not text.strip():
-        return ""
-
-    try:
-        tokenizer, model = load_translation_model()
-    except Exception as e:
-        tb = traceback.format_exc()
-        return f"Model load error: {e}\n{tb}"
+def split_into_sentences(text: str):
+    if not text: return []
+    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s.strip()]
 
+def translate_text(text: str):
+    if not text or not text.strip(): return ""
+    tokenizer, model = load_translation_model()
+    forced_bos = _get_forced_bos_token_id(tokenizer)
     sentences = split_into_sentences(text)
     translations = []
-    forced_bos = _get_forced_bos_token_id(tokenizer)
 
     for s in sentences:
-        if not s:
-            continue
-        try:
-            src_prefixed = f"{SRC_LANG} {s}"
-            inputs = tokenizer(src_prefixed, return_tensors="pt", truncation=True, max_length=max_length)
-            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-
-            gen_kwargs = dict(max_length=max_length + 64, num_beams=5, early_stopping=True)
-            if forced_bos is not None:
-                gen_kwargs["forced_bos_token_id"] = int(forced_bos)
-            elif getattr(model.config, "forced_bos_token_id", None) is not None:
-                gen_kwargs["forced_bos_token_id"] = int(model.config.forced_bos_token_id)
-
-            generated_tokens = model.generate(**inputs, **gen_kwargs)
-            decoded = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-            if decoded.startswith(TGT_LANG):
-                decoded = decoded[len(TGT_LANG):].strip()
-
-            translations.append(decoded)
-
-        except Exception as e:
-            translations.append(f"[Error translating sentence: {e}]")
-
+        inputs = tokenizer(f"{SRC_LANG} {s}", return_tensors="pt", truncation=True, max_length=MAX_LENGTH)
+        inputs = {k:v.to(DEVICE) for k,v in inputs.items()}
+        gen_kwargs = dict(max_length=MAX_LENGTH+64, num_beams=5, early_stopping=True)
+        if forced_bos is not None: gen_kwargs["forced_bos_token_id"] = int(forced_bos)
+        generated_tokens = model.generate(**inputs, **gen_kwargs)
+        decoded = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+        if decoded.startswith(TGT_LANG): decoded = decoded[len(TGT_LANG):].strip()
+        translations.append(decoded)
    return " ".join(translations)
 
-# -------- FLUX.1 Schnell image generation --------
-
-def load_flux_model(model_id: str = FLUX_MODEL_ID):
-    global _flux_pipe
-    if not _FLUX_AVAILABLE:
-        raise RuntimeError("FluxPipeline (diffusers) not available. Install a diffusers version that provides FluxPipeline.")
-
-    if _flux_pipe is None:
-        try:
-            # prefer bfloat16 on supported hardware for memory efficiency
-            dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else (torch.float16 if torch.cuda.is_available() else torch.float32)
-            print(f"Loading FLUX model {model_id} (dtype={dtype}) on {DEVICE}...")
-            _flux_pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype)
-
-            # Try enabling model CPU offload if available (reduces VRAM peak)
-            try:
-                _flux_pipe.enable_model_cpu_offload()
-            except Exception:
-                pass
-
-            # Move pipeline to device when appropriate (some pipelines handle devices internally)
-            try:
-                _flux_pipe.to(DEVICE)
-            except Exception:
-                pass
-
-            print("Flux model loaded.")
-        except Exception as e:
-            _flux_pipe = None
-            raise RuntimeError(f"Failed to load FLUX model: {e}")
-
-    return _flux_pipe
-
-
-def generate_flux_image(prompt: str, num_inference_steps: int = DEFAULT_IMAGE_STEPS, guidance_scale: float = 0.0):
-    if not prompt or not prompt.strip():
-        return None, "Please enter an image prompt."
-    if not _FLUX_AVAILABLE:
-        return None, "FluxPipeline is not available: install appropriate diffusers package to enable FLUX."
-
-    try:
-        pipe = load_flux_model()
-        # Use CPU generator by default for reproducibility; Flux pipeline will handle device placement
-        seed = random.randint(0, 2**31 - 1)
-        generator = torch.Generator(device="cpu").manual_seed(seed)
-
-        out = pipe(
-            prompt=prompt,
-            num_inference_steps=int(num_inference_steps),
-            guidance_scale=float(guidance_scale),
-            generator=generator,
-        )
-
-        image = out.images[0]
-        return image, f"FLUX.1 Schnell generated (seed={seed}) steps={num_inference_steps} guidance={guidance_scale}"
-    except Exception as e:
-        tb = traceback.format_exc()
-        return None, f"Error generating image: {e}\n{tb}"
-
-# -------- Gradio UI (no microphone / no speech) --------
-
+# -------- Image Generation: SANA-Sprint 0.6B --------
+from diffusers import DiffusionPipeline
+
+SANA_MODEL = "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers"
+DEFAULT_STEPS = 1
+DEFAULT_GUIDANCE = 1.0
+_sana_pipe = None
+
+def load_sana():
+    global _sana_pipe
+    if _sana_pipe is None:
+        dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else \
+                (torch.float16 if torch.cuda.is_available() else torch.float32)
+        _sana_pipe = DiffusionPipeline.from_pretrained(SANA_MODEL, torch_dtype=dtype)
+        try: _sana_pipe.enable_model_cpu_offload()
+        except: pass
+        _sana_pipe = _sana_pipe.to(DEVICE)
+    return _sana_pipe
+
+def generate_sana_image(prompt: str, steps: int = DEFAULT_STEPS, guidance: float = DEFAULT_GUIDANCE, seed: int = None):
+    if not prompt.strip(): return None, "Please enter an image prompt."
+    pipe = load_sana()
+    if seed is None: seed = random.randint(0, 2**31-1)
+    gen = torch.Generator(device=DEVICE).manual_seed(seed) if DEVICE.type=="cuda" else torch.Generator().manual_seed(seed)
+    out = pipe(prompt, num_inference_steps=int(steps), guidance_scale=float(guidance), generator=gen)
+    return out.images[0], f"SANA-Sprint generated (seed={seed}) steps={steps} guidance={guidance}"
+
+# -------- Gradio App --------
 css = """
 .gradio-container { max-width: 1100px !important; }
-.header { text-align: center; padding: 12px; border-radius: 8px; color: white; background: linear-gradient(90deg,#2563eb,#7c3aed); }
+.header { text-align:center; padding:12px; border-radius:8px; color:white; background:linear-gradient(90deg,#2563eb,#7c3aed); }
 """
 
-with gr.Blocks(title="NLLB → Bengali + FLUX.1 Schnell", css=css) as demo:
-    gr.Markdown("""<div class='header'><h2>Translation (NLLB) + Fast Image Generation (FLUX.1 Schnell)</h2></div>""")
-
+with gr.Blocks(title="NLLB → Bengali + SANA-Sprint", css=css) as demo:
+    gr.Markdown("<div class='header'><h2>English → Bengali Translation + Fast Image Generation (SANA-Sprint)</h2></div>")
+
     with gr.Tabs():
         with gr.TabItem("Translate"):
             with gr.Row():
                 with gr.Column(scale=6):
                     input_text = gr.Textbox(lines=6, label="English Text", placeholder="Type English text here...")
                     with gr.Row():
-                        quick_1 = gr.Button("Hello, how are you?")
-                        quick_2 = gr.Button("Thank you very much.")
-                        quick_3 = gr.Button("The weather is nice today.")
+                        quick1 = gr.Button("Hello, how are you?")
+                        quick2 = gr.Button("Thank you very much.")
+                        quick3 = gr.Button("The weather is nice today.")
                     translate_btn = gr.Button("Translate")
                 with gr.Column(scale=6):
                     output_text = gr.Textbox(lines=6, label="Bengali Translation", interactive=False)
@@ -211,30 +108,24 @@ with gr.Blocks(title="NLLB → Bengali + FLUX.1 Schnell", css=css) as demo:
                 with gr.Column(scale=6):
                     image_prompt = gr.Textbox(lines=4, label="Image Prompt", placeholder="Describe the image to generate...")
                     with gr.Row():
-                        generate_btn = gr.Button("Generate Image (FLUX)")
+                        generate_btn = gr.Button("Generate Image (SANA)")
                         clear_btn = gr.Button("Clear")
-                    steps_slider = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_IMAGE_STEPS, label="Inference Steps (1-4 recommended)")
-                    guidance_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=0.0, label="Guidance Scale (Schnell often uses low guidance)")
+                    steps_slider = gr.Slider(minimum=1, maximum=4, step=1, value=DEFAULT_STEPS, label="Inference Steps (1-4 fastest)")
+                    guidance_slider = gr.Slider(minimum=0.0, maximum=5.0, step=0.5, value=DEFAULT_GUIDANCE, label="Guidance Scale")
                 with gr.Column(scale=6):
                     output_image = gr.Image(label="Generated Image")
                     status = gr.Textbox(label="Status", interactive=False)
 
-    gr.Markdown("---")
-    gr.Markdown("*Notes: FLUX.1 [schnell] is designed for very low-step generation. A GPU with sufficient VRAM is strongly recommended. If you cannot run locally, consider a hosted API.*")
-
-    # Bind events
-    def _use_translation(t):
-        return t
-
-    quick_1.click(fn=lambda: "Hello, how are you?", inputs=None, outputs=input_text)
-    quick_2.click(fn=lambda: "Thank you very much.", inputs=None, outputs=input_text)
-    quick_3.click(fn=lambda: "The weather is nice today.", inputs=None, outputs=input_text)
-
+    # Quick phrase events
+    quick1.click(fn=lambda: "Hello, how are you?", inputs=None, outputs=input_text)
+    quick2.click(fn=lambda: "Thank you very much.", inputs=None, outputs=input_text)
+    quick3.click(fn=lambda: "The weather is nice today.", inputs=None, outputs=input_text)
+
     translate_btn.click(fn=translate_text, inputs=input_text, outputs=output_text)
-    use_for_image.click(fn=_use_translation, inputs=output_text, outputs=image_prompt)
-
-    generate_btn.click(fn=generate_flux_image, inputs=[image_prompt, steps_slider, guidance_slider], outputs=[output_image, status])
+    use_for_image.click(fn=lambda x: x, inputs=output_text, outputs=image_prompt)
+
+    generate_btn.click(fn=generate_sana_image, inputs=[image_prompt, steps_slider, guidance_slider], outputs=[output_image, status])
     clear_btn.click(fn=lambda: ["", None, ""], inputs=None, outputs=[image_prompt, output_image, status])
 
-if __name__ == '__main__':
-    demo.launch(server_name='0.0.0.0', server_port=int(os.environ.get('PORT', 7860)))
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
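
To sanity-check the new translation path without launching the UI, a minimal sketch along these lines should work; it assumes the updated app.py is importable from the working directory, that transformers, gradio, and diffusers are installed, and that the NLLB checkpoint can be fetched from the Hub. The script itself is illustrative and not part of the commit.

# smoke_translate.py (hypothetical) -- exercises translate_text() from the new app.py.
# Importing app builds the Gradio Blocks but does not start the server,
# since demo.launch() sits behind the __main__ guard.
from app import load_translation_model, _get_forced_bos_token_id, translate_text

tokenizer, _ = load_translation_model()      # downloads facebook/nllb-200-distilled-600M on first use
print(_get_forced_bos_token_id(tokenizer))   # should print a token id for "ben_Beng", not None
print(translate_text("Hello, how are you? The weather is nice today."))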
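The image path can be exercised the same way. This rough sketch assumes a CUDA GPU with enough VRAM for the bf16 SANA-Sprint checkpoint (the CPU fallback works but is very slow); the prompt and output filename are made up for illustration.

# smoke_image.py (hypothetical) -- exercises generate_sana_image() from the new app.py.
from app import generate_sana_image

image, status = generate_sana_image("a red bicycle leaning against a brick wall", steps=2, seed=42)
print(status)                     # e.g. "SANA-Sprint generated (seed=42) steps=2 guidance=1.0"
if image is not None:
    image.save("sana_test.png")   # the diffusers pipeline returns PIL images

Passing an explicit seed makes a render reproducible across runs; the UI path leaves seed=None, so generate_sana_image draws a fresh random seed per click and reports it in the status box.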