Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import re | |
| import time | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from diffusers import StableDiffusionPipeline, LCMScheduler | |
| import random | |
# --- Configuration ---
# Translation checkpoint and the NLLB language codes for the
# English -> Bengali direction used throughout the app.
TRANSLATION_MODEL = "facebook/nllb-200-distilled-600M"
SRC_LANG = "eng_Latn"
TGT_LANG = "ben_Beng"
# Token limit used when truncating each sentence before translation.
MAX_LENGTH = 512
# Run on the GPU whenever torch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# --- Globals for caching ---
# Filled in lazily by the loader functions so each model is loaded only once
# and then reused by every request.
translation_tokenizer = None
translation_model = None
image_pipe = None
| # --- Translation Functions --- | |
def load_translation_model():
    """Load (once) and return the NLLB tokenizer and translation model.

    The pair is cached in the module globals ``translation_tokenizer`` and
    ``translation_model``; subsequent calls return the cached objects.

    Returns:
        tuple: ``(tokenizer, model)``, with the model moved to ``DEVICE``.

    Raises:
        RuntimeError: if either object fails to load. Both globals are reset
            to ``None`` so the next call retries, and the underlying
            exception is chained as ``__cause__``.
    """
    global translation_tokenizer, translation_model
    if translation_tokenizer is None or translation_model is None:
        print(f"Loading translation model {TRANSLATION_MODEL} on {DEVICE} ...")
        try:
            # Passing src_lang lets the NLLB tokenizer add the source-language
            # token itself, so callers can feed plain sentences.
            tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL, src_lang=SRC_LANG)
            model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL).to(DEVICE)
        except Exception as e:
            translation_tokenizer, translation_model = None, None
            raise RuntimeError(f"Failed to load translation model: {e}") from e
        # Publish to the cache only once both pieces loaded successfully.
        translation_tokenizer, translation_model = tokenizer, model
        print("Translation model loaded successfully.")
    return translation_tokenizer, translation_model
def split_into_sentences(text: str):
    """Break *text* into sentences.

    A sentence boundary is any run of whitespace that directly follows
    ``.``, ``!`` or ``?``. Each sentence is stripped, blank pieces are
    discarded, and falsy input (``""`` or ``None``) gives an empty list.
    """
    sentences = []
    if not text:
        return sentences
    for chunk in re.split(r'(?<=[.!?])\s+', text.strip()):
        cleaned = chunk.strip()
        if cleaned:
            sentences.append(cleaned)
    return sentences
def translate_text(text: str, max_length: int = MAX_LENGTH):
    """Translate English *text* into Bengali, one sentence at a time.

    Args:
        text: English input. Blank or empty input returns ``""``.
        max_length: token limit for each sentence's tokenized input;
            generation is allowed up to ``max_length + 64`` tokens.

    Returns:
        The translated sentences joined by single spaces. Because this is a
        UI handler, failures are reported as text instead of raised:
        a model-load failure or a ``RuntimeError`` during generation returns
        an error message in place of the whole translation, and any other
        per-sentence failure is replaced inline by a bracketed error note.
    """
    if not text or not text.strip():
        return ""
    try:
        tokenizer, model = load_translation_model()
    except Exception as e:
        return f"Model load error: {e}"

    # NLLB takes the source language from the tokenizer, which prepends the
    # language token itself. Setting it here (instead of prefixing the text
    # with the code) avoids feeding the model a duplicated language token.
    tokenizer.src_lang = SRC_LANG
    # Forcing the first decoder token selects the target language.
    target_lang_id = tokenizer.convert_tokens_to_ids(TGT_LANG)

    translations = []
    for sentence in split_into_sentences(text):
        try:
            inputs = tokenizer(
                sentence,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding=False,
            ).to(DEVICE)
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=target_lang_id,
                max_length=max_length + 64,
                num_beams=5,
                early_stopping=True,
            )
            decoded = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            # Defensive: drop a language code if one survives decoding.
            if decoded.startswith(TGT_LANG):
                decoded = decoded[len(TGT_LANG):].strip()
            translations.append(decoded)
        except RuntimeError as re_err:
            # Generation-level failure: abort the whole request.
            return f"Runtime error during generation: {re_err}"
        except Exception as e:
            translations.append(f"[Error translating sentence: {e}]")
    return " ".join(translations)
| # --- Faster Image Generation Functions --- | |
def load_image_model():
    """Load (once) and return the LCM-accelerated Stable Diffusion pipeline.

    Loads the ``lykon/dreamshaper-7`` checkpoint, applies the LCM-LoRA
    weights and switches to ``LCMScheduler`` so only a few denoising steps
    are needed. The finished pipeline is cached in the global ``image_pipe``.

    Half precision is used only on CUDA. float16 Stable Diffusion inference
    is generally unsupported or very slow on CPU, so the CPU path loads
    float32 weights instead.

    Raises:
        RuntimeError: if any loading step fails. ``image_pipe`` stays/returns
            to ``None`` so the next call retries, and the original error is
            chained as ``__cause__``.
    """
    global image_pipe
    if image_pipe is None:
        print("Loading faster image generation model...")
        model_id = "lykon/dreamshaper-7"
        lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
        use_fp16 = DEVICE.type == "cuda"
        try:
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if use_fp16 else torch.float32,
                variant="fp16" if use_fp16 else None,
            )
            # LCM-LoRA + LCMScheduler is what makes 2-8 step sampling viable.
            pipe.load_lora_weights(lcm_lora_id)
            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
            pipe = pipe.to(DEVICE)
        except Exception as e:
            image_pipe = None
            raise RuntimeError(f"Failed to load image model: {e}") from e
        # Publish only the fully configured pipeline to the cache.
        image_pipe = pipe
        print("Fast image generation model loaded successfully.")
    return image_pipe
def generate_image(prompt: str, num_inference_steps: int = 4):  # Only 4 steps needed with LCM!
    """Generate one image for *prompt* with the LCM pipeline.

    Args:
        prompt: text prompt; blank or empty prompts are rejected.
        num_inference_steps: number of denoising steps (4 by default). It is
            coerced to ``int`` because a UI slider value may arrive as a
            float, which the scheduler's timestep computation cannot use.

    Returns:
        ``(image, status)``: the first image produced by the pipeline (or
        ``None`` on failure) and a human-readable status message. Errors are
        reported in the status string rather than raised.
    """
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt to generate an image."
    try:
        steps = int(num_inference_steps)
        pipe = load_image_model()
        # Fresh random seed per request so repeated prompts give new images.
        seed = random.randint(0, 1000000)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        image = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=1.0,  # Low guidance scale for LCM
            generator=generator,
        ).images[0]
        return image, f"Successfully generated image in {steps} steps!"
    except Exception as e:
        return None, f"Error generating image: {str(e)}"
| # --- Gradio UI with Real-time Features --- | |
# Stylesheet handed to gr.Blocks(css=...):
#   .gradio-container -> caps the overall page width
#   .header           -> gradient banner <div> rendered by the header markdown
#   .quick-btn        -> spacing for the quick-phrase buttons (elem_classes)
css = """
.gradio-container {
max-width: 1200px !important;
}
.header {
text-align: center;
background: linear-gradient(45deg, #667eea, #764ba2);
padding: 20px;
border-radius: 10px;
color: white;
margin-bottom: 20px;
}
.quick-btn {
margin: 5px;
padding: 8px 15px;
}
"""
with gr.Blocks(title="Fast Bengali Translator & Image Generator", css=css) as demo:
    # Header banner; the `header` class is styled by the module-level `css`.
    gr.Markdown("""
<div class="header">
<h1>⚡ Fast Bengali Translator & Image Generator</h1>
<p>Real-time speech input with fast translation and image generation</p>
</div>
""")
    with gr.Tabs():
        with gr.TabItem("🌐 Translation"):
            gr.Markdown("## English to Bengali Translation")
            with gr.Row():
                with gr.Column():
                    # Speech input section (transcription is a placeholder;
                    # see transcribe_audio below).
                    gr.Markdown("### 🎤 Speak or Type")
                    audio_input = gr.Audio(type="filepath", label="Record your voice")
                    transcribe_btn = gr.Button("Transcribe Speech", variant="primary")
                    # Text input
                    input_text = gr.Textbox(
                        label="English Text",
                        placeholder="Type or paste English text here...",
                        lines=5,
                    )
                    # Quick phrase buttons; each fills input_text with a
                    # canned sentence (wired up below).
                    gr.Markdown("### 💬 Quick Phrases")
                    with gr.Row():
                        quick_hello = gr.Button("Hello, how are you?", elem_classes="quick-btn")
                        quick_weather = gr.Button("Nice weather today", elem_classes="quick-btn")
                        quick_thanks = gr.Button("Thank you very much", elem_classes="quick-btn")
                    translate_btn = gr.Button("Translate", variant="primary")
                    gr.Examples(
                        examples=[
                            ["Hello, how are you? I hope you are doing well today."],
                            ["The weather is beautiful today. Let's go for a walk in the park."],
                        ],
                        inputs=input_text,
                        fn=None,
                        cache_examples=False,
                        label="Example Texts",
                    )
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Bengali Translation",
                        lines=5,
                        interactive=False,
                    )
                    copy_btn = gr.Button("Copy Translation", variant="secondary")
                    gr.Markdown("### 🎨 Generate Image from Translation")
                    use_for_image_btn = gr.Button("Use for Image Generation", variant="primary")
        with gr.TabItem("🎨 Fast Image Generation"):
            gr.Markdown("## AI Image Generation (Optimized for Speed)")
            with gr.Row():
                with gr.Column():
                    image_prompt = gr.Textbox(
                        label="Image Prompt",
                        placeholder="Describe the image you want to generate...",
                        lines=3,
                    )
                    with gr.Row():
                        generate_btn = gr.Button("Generate Image (Fast!)", variant="primary")
                        clear_btn = gr.Button("Clear", variant="secondary")
                    steps_slider = gr.Slider(
                        minimum=2,
                        maximum=8,
                        value=4,
                        step=1,
                        label="Inference Steps (4 is usually enough with LCM)",
                    )
                with gr.Column():
                    output_image = gr.Image(label="Generated Image", interactive=False)
                    status_message = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Tips for Fast Generation")
            gr.Markdown("- Use 4 steps for the best speed/quality balance")
            gr.Markdown("- Simple prompts work best with fast models")
    # Footer
    gr.Markdown("---")
    gr.Markdown("""
<div style="text-align: center">
<p>This optimized app uses faster models for better performance on Hugging Face Spaces.</p>
</div>
""")

    # Event handlers (defined inside the Blocks context because Gradio
    # requires listeners to be registered there).
    def transcribe_audio(audio_path):
        """Placeholder speech-to-text handler.

        No speech recognition is performed yet: returns a reminder when no
        audio was recorded, otherwise a fixed message asking the user to type.
        """
        if audio_path is None:
            return "Please record audio first."
        # TODO: plug in a real speech-recognition model here.
        return "I heard your voice. Please type your text for translation."

    def copy_to_clipboard(text):
        """Return *text* unchanged.

        NOTE(review): this does NOT copy anything to the user's clipboard; a
        server-side Python handler cannot reach the browser clipboard. Real
        copying needs client-side JS or the textbox's built-in copy button
        (API depends on the installed Gradio version).
        """
        return text

    # Connect events
    transcribe_btn.click(
        fn=transcribe_audio,
        inputs=audio_input,
        outputs=input_text,
    )
    translate_btn.click(
        fn=translate_text,
        inputs=input_text,
        outputs=output_text,
    )
    # Currently just writes the translation back into its own textbox
    # (see copy_to_clipboard).
    copy_btn.click(
        fn=copy_to_clipboard,
        inputs=output_text,
        outputs=output_text,
    )
    # Stable Diffusion's text encoder is trained on English captions and
    # cannot meaningfully condition on Bengali script, so the image prompt is
    # filled with the English source text (the content that was translated)
    # rather than the Bengali output.
    use_for_image_btn.click(
        fn=lambda text: text,
        inputs=input_text,
        outputs=image_prompt,
    )
    generate_btn.click(
        fn=generate_image,
        inputs=[image_prompt, steps_slider],
        outputs=[output_image, status_message],
    )
    clear_btn.click(
        fn=lambda: [None, None, None],
        inputs=None,
        outputs=[image_prompt, output_image, status_message],
    )
    # Quick phrase buttons
    quick_hello.click(
        fn=lambda: "Hello, how are you?",
        inputs=None,
        outputs=input_text,
    )
    quick_weather.click(
        fn=lambda: "The weather is nice today.",
        inputs=None,
        outputs=input_text,
    )
    quick_thanks.click(
        fn=lambda: "Thank you very much for your help.",
        inputs=None,
        outputs=input_text,
    )

if __name__ == "__main__":
    demo.launch(debug=True)