anisgtboi's picture
Update app.py
98734d9 verified
raw
history blame
11.3 kB
import gradio as gr
import re
import time
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from diffusers import StableDiffusionPipeline, LCMScheduler
import random
# --- Configuration ---
# Hugging Face model id loaded by load_translation_model().
TRANSLATION_MODEL = "facebook/nllb-200-distilled-600M"
# Language codes for the translation direction (English -> Bengali per the UI labels).
SRC_LANG = "eng_Latn"
TGT_LANG = "ben_Beng"
# Default token limit translate_text() uses when truncating each input sentence.
MAX_LENGTH = 512
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Globals for caching ---
# Filled in lazily on first use by load_translation_model() / load_image_model().
translation_tokenizer = None
translation_model = None
image_pipe = None
# --- Translation Functions ---
def load_translation_model():
    """Load the translation tokenizer and model once and cache them globally.

    On the first call the tokenizer and seq2seq model named by
    ``TRANSLATION_MODEL`` are loaded and the model is moved to ``DEVICE``.
    Later calls return the cached objects.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        RuntimeError: if loading fails. The original exception is attached
            as ``__cause__`` so the real traceback is not lost.
    """
    global translation_tokenizer, translation_model
    if translation_tokenizer is None or translation_model is None:
        try:
            print(f"Loading translation model {TRANSLATION_MODEL} on {DEVICE} ...")
            # Load into locals first so the globals are never left half-initialised.
            tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
            model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL).to(DEVICE)
        except Exception as e:
            translation_tokenizer, translation_model = None, None
            raise RuntimeError(f"Failed to load translation model: {e}") from e
        translation_tokenizer, translation_model = tokenizer, model
        print("Translation model loaded successfully.")
    return translation_tokenizer, translation_model
def split_into_sentences(text: str):
    """Split *text* into a list of trimmed, non-empty sentences.

    A sentence boundary is any run of whitespace that directly follows
    ``.``, ``!`` or ``?``. Falsy input yields an empty list.
    """
    if not text:
        return []
    pieces = []
    for chunk in re.split(r'(?<=[.!?])\s+', text.strip()):
        trimmed = chunk.strip()
        if trimmed:
            pieces.append(trimmed)
    return pieces
def translate_text(text: str, max_length: int = MAX_LENGTH):
    """Translate *text* from ``SRC_LANG`` to ``TGT_LANG`` sentence by sentence.

    Args:
        text: Input text; blank or empty input returns ``""``.
        max_length: Token limit used to truncate each input sentence. The
            generation limit is ``max_length + 64``.

    Returns:
        The translated sentences joined by single spaces. Errors are reported
        in the returned string rather than raised: a model-load failure
        returns ``"Model load error: ..."``, a ``RuntimeError`` during
        generation aborts and returns ``"Runtime error during generation: ..."``,
        and any other per-sentence error is embedded as a bracketed
        placeholder in the output.
    """
    if not text or not text.strip():
        return ""
    try:
        tokenizer, model = load_translation_model()
    except Exception as e:
        return f"Model load error: {e}"

    # The NLLB tokenizer inserts the source-language token itself, based on
    # `src_lang`. Prepending the code to the text would put it in the input
    # a second time, so set the attribute and tokenize the plain sentence.
    tokenizer.src_lang = SRC_LANG
    # Same for every sentence, so look it up once.
    target_lang_id = tokenizer.convert_tokens_to_ids(TGT_LANG)

    translations = []
    # split_into_sentences() already drops empty pieces.
    for s in split_into_sentences(text):
        try:
            inputs = tokenizer(
                s,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding=False,
            ).to(DEVICE)
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=target_lang_id,
                max_length=max_length + 64,
                num_beams=5,
                early_stopping=True,
            )
            decoded = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            # Defensive: strip the target-language code if it leaked into the output.
            if decoded.startswith(TGT_LANG):
                decoded = decoded[len(TGT_LANG):].strip()
            translations.append(decoded)
        except RuntimeError as re_err:
            # Typically device or memory problems; later sentences would fail the same way.
            return f"Runtime error during generation: {re_err}"
        except Exception as e:
            translations.append(f"[Error translating sentence: {e}]")
    return " ".join(translations)
# --- Faster Image Generation Functions ---
def load_image_model():
    """Load the Stable Diffusion pipeline with LCM-LoRA once and cache it.

    The pipeline is built from ``lykon/dreamshaper-7``, LCM-LoRA weights are
    loaded on top, the scheduler is replaced by ``LCMScheduler`` and the
    pipeline is moved to ``DEVICE``. Later calls return the cached pipeline.

    Returns:
        The cached pipeline object.

    Raises:
        RuntimeError: if any loading step fails. The original exception is
            attached as ``__cause__``.
    """
    global image_pipe
    if image_pipe is None:
        try:
            print("Loading faster image generation model...")
            # Using a much faster model with LCM-LoRA
            model_id = "lykon/dreamshaper-7"
            lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
            # Half precision is only appropriate on GPU; on CPU many ops are
            # missing or very slow in float16, so run in float32 there.
            dtype = torch.float16 if DEVICE.type == "cuda" else torch.float32
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                variant="fp16",
            )
            # Load LCM-LoRA for faster inference
            pipe.load_lora_weights(lcm_lora_id)
            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
            # Publish to the global only after every step has succeeded.
            image_pipe = pipe.to(DEVICE)
            print("Fast image generation model loaded successfully.")
        except Exception as e:
            image_pipe = None
            raise RuntimeError(f"Failed to load image model: {e}") from e
    return image_pipe
def generate_image(prompt: str, num_inference_steps: int = 4):
    """Generate one image for *prompt* with the cached image pipeline.

    Args:
        prompt: Text description of the image; blank input is rejected.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        ``(image, status_message)``. ``image`` is ``None`` when the prompt is
        blank or when any error occurs; the message then explains why.
    """
    prompt_is_blank = not prompt or not prompt.strip()
    if prompt_is_blank:
        return None, "Please enter a prompt to generate an image."

    try:
        pipeline = load_image_model()
        # A fresh random seed per call, so repeated prompts give different images.
        seed_value = random.randint(0, 1000000)
        rng = torch.Generator(device=DEVICE).manual_seed(seed_value)
        result = pipeline(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=1.0,  # Low guidance scale for LCM
            generator=rng,
        )
        picture = result.images[0]
        return picture, f"Successfully generated image in {num_inference_steps} steps!"
    except Exception as err:
        return None, f"Error generating image: {str(err)}"
# --- Gradio UI with Real-time Features ---
# Custom stylesheet passed to gr.Blocks(css=...): caps the page width, styles the
# gradient ".header" banner, and spaces out the ".quick-btn" quick-phrase buttons.
css = """
.gradio-container {
max-width: 1200px !important;
}
.header {
text-align: center;
background: linear-gradient(45deg, #667eea, #764ba2);
padding: 20px;
border-radius: 10px;
color: white;
margin-bottom: 20px;
}
.quick-btn {
margin: 5px;
padding: 8px 15px;
}
"""
# Build the UI: a header banner, a "Translation" tab, an image generation tab
# and a footer. Components are only declared here; their click handlers are
# attached in the "Connect events" section further down.
with gr.Blocks(title="Fast Bengali Translator & Image Generator", css=css) as demo:
# Header
gr.Markdown("""
<div class="header">
<h1>⚡ Fast Bengali Translator & Image Generator</h1>
<p>Real-time speech input with fast translation and image generation</p>
</div>
""")
with gr.Tabs():
# Tab 1: English text in (left column), Bengali translation out (right column).
with gr.TabItem("🌐 Translation"):
gr.Markdown("## English to Bengali Translation")
with gr.Row():
with gr.Column():
# Speech input section
gr.Markdown("### 🎤 Speak or Type")
audio_input = gr.Audio(type="filepath", label="Record your voice")
transcribe_btn = gr.Button("Transcribe Speech", variant="primary")
# Text input
input_text = gr.Textbox(
label="English Text",
placeholder="Type or paste English text here...",
lines=5
)
# Quick phrases buttons
gr.Markdown("### 💬 Quick Phrases")
with gr.Row():
quick_hello = gr.Button("Hello, how are you?", elem_classes="quick-btn")
quick_weather = gr.Button("Nice weather today", elem_classes="quick-btn")
quick_thanks = gr.Button("Thank you very much", elem_classes="quick-btn")
translate_btn = gr.Button("Translate", variant="primary")
# Example texts targeting input_text; no function is attached (fn=None)
# and example outputs are not cached.
examples = gr.Examples(
examples=[
["Hello, how are you? I hope you are doing well today."],
["The weather is beautiful today. Let's go for a walk in the park."],
],
inputs=input_text,
fn=None,
cache_examples=False,
label="Example Texts"
)
with gr.Column():
# Read-only box that receives the translation.
output_text = gr.Textbox(
label="Bengali Translation",
lines=5,
interactive=False
)
copy_btn = gr.Button("Copy Translation", variant="secondary")
gr.Markdown("### 🎨 Generate Image from Translation")
use_for_image_btn = gr.Button("Use for Image Generation", variant="primary")
# Tab 2: prompt and controls (left column), generated image and status (right column).
with gr.TabItem("🎨 Fast Image Generation"):
gr.Markdown("## AI Image Generation (Optimized for Speed)")
with gr.Row():
with gr.Column():
image_prompt = gr.Textbox(
label="Image Prompt",
placeholder="Describe the image you want to generate...",
lines=3
)
with gr.Row():
generate_btn = gr.Button("Generate Image (Fast!)", variant="primary")
clear_btn = gr.Button("Clear", variant="secondary")
# Step count forwarded to generate_image(); limited to 2-8, default 4.
steps_slider = gr.Slider(
minimum=2,
maximum=8,
value=4,
step=1,
label="Inference Steps (4 is usually enough with LCM)"
)
with gr.Column():
output_image = gr.Image(label="Generated Image", interactive=False)
status_message = gr.Textbox(label="Status", interactive=False)
gr.Markdown("### Tips for Fast Generation")
gr.Markdown("- Use 4 steps for the best speed/quality balance")
gr.Markdown("- Simple prompts work best with fast models")
# Footer
gr.Markdown("---")
gr.Markdown("""
<div style="text-align: center">
<p>This optimized app uses faster models for better performance on Hugging Face Spaces.</p>
</div>
""")
# Event handlers
def transcribe_audio(audio_path):
    """Return placeholder text for a recorded audio clip.

    No speech recognition is performed: the function only checks whether a
    recording was supplied and returns a fixed message for the text box.

    Args:
        audio_path: Path of the recorded audio, or ``None``/empty when
            nothing was recorded.

    Returns:
        A prompt to record audio when no path was given, otherwise a fixed
        placeholder message.
    """
    if not audio_path:
        return "Please record audio first."
    # Simple transcription simulation (in a real app, you'd use a speech
    # recognition library). Nothing here can raise, so no try/except is needed.
    return "I heard your voice. Please type your text for translation."
def copy_to_clipboard(text):
    """Return *text* unchanged.

    This function does not touch the clipboard itself; it only echoes its
    argument back to the caller.
    """
    unchanged = text
    return unchanged
# Connect events
# Each .click() binds a button to a handler: `inputs` are read from the listed
# components and the handler's return value(s) are written to `outputs`.
# Placeholder transcription result goes into the English text box.
transcribe_btn.click(
fn=transcribe_audio,
inputs=audio_input,
outputs=input_text
)
# Only input_text is passed, so translate_text() uses its default max_length.
translate_btn.click(
fn=translate_text,
inputs=input_text,
outputs=output_text
)
copy_btn.click(
fn=copy_to_clipboard,
inputs=output_text,
outputs=output_text # NOTE(review): this only writes the same text back to the box; no clipboard call is visible here - confirm whether a JS hook is needed
)
# Copy the translation into the image prompt box on the other tab.
use_for_image_btn.click(
fn=lambda x: x,
inputs=output_text,
outputs=image_prompt
)
# generate_image() returns (image, status), matching the two outputs.
generate_btn.click(
fn=generate_image,
inputs=[image_prompt, steps_slider],
outputs=[output_image, status_message]
)
# Reset the prompt, image and status; one None per output component.
clear_btn.click(
fn=lambda: [None, None, None],
inputs=None,
outputs=[image_prompt, output_image, status_message]
)
# Quick phrase buttons
# Each one fills input_text with a fixed sentence; the inserted text is not
# always identical to the button label.
quick_hello.click(
fn=lambda: "Hello, how are you?",
inputs=None,
outputs=input_text
)
quick_weather.click(
fn=lambda: "The weather is nice today.",
inputs=None,
outputs=input_text
)
quick_thanks.click(
fn=lambda: "Thank you very much for your help.",
inputs=None,
outputs=input_text
)
# Launch the Gradio app (with debug=True) only when this file is run as a script.
if __name__ == "__main__":
demo.launch(debug=True)