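"""Gradio app: transcribe Japanese audio with kotoba-whisper, emit .srt
subtitles, and optionally translate them into English with Hunyuan-MT-7B."""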
import sys
import traceback
import datetime
import tempfile
import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Models
TRANSCRIBE_MODEL = "kotoba-tech/kotoba-whisper-v1.0"
TRANSLATE_MODEL = "tencent/Hunyuan-MT-7B"
# Load the ASR pipeline and the translation model/tokenizer
transcriber = pipeline(
    "automatic-speech-recognition",
    model=TRANSCRIBE_MODEL,
    generate_kwargs={"language": "japanese", "task": "transcribe"},
)
tokenizer_t = AutoTokenizer.from_pretrained(TRANSLATE_MODEL)
model_t = AutoModelForCausalLM.from_pretrained(TRANSLATE_MODEL, device_map="auto")
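# Note: both models load once at import time; device_map="auto" lets
# accelerate place the 7B translator, and the @spaces.GPU decorator below
# attaches a GPU only for the duration of each process_audio call.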
def format_srt(chunks, texts):
    """Build SRT from a chunks list (with start/end keys or (start, end)
    "timestamp" pairs) and a corresponding texts list."""

    def fmt(ts):
        # Render a timedelta as an SRT timestamp: HH:MM:SS,mmm
        total_seconds = int(ts.total_seconds())
        hours = total_seconds // 3600
        minutes = (total_seconds % 3600) // 60
        seconds = total_seconds % 60
        milliseconds = int((ts.total_seconds() - total_seconds) * 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

    srt_lines = []
    index = 0
    for chunk, text in zip(chunks, texts):
        # Accept either explicit start/end keys or a (start, end) "timestamp"
        # tuple. Compare with None explicitly so a 0.0 start is not discarded
        # as falsy.
        start = chunk.get("start")
        end = chunk.get("end")
        if (start is None or end is None) and "timestamp" in chunk:
            start, end = chunk["timestamp"]
        if start is None or end is None:
            continue  # no usable timing for this chunk
        index += 1  # keep SRT cue numbers contiguous even when chunks are skipped
        srt_lines.append(f"{index}")
        srt_lines.append(
            f"{fmt(datetime.timedelta(seconds=float(start)))} --> "
            f"{fmt(datetime.timedelta(seconds=float(end)))}"
        )
        srt_lines.append(text.strip())
        srt_lines.append("")
    return "\n".join(srt_lines)
def translate_text(text, target_lang="English"):
    """Translate text into target_lang using Hunyuan-MT-7B."""
    prompt = f"Translate the following segment into {target_lang}, without explanation:\n\n{text}"
    inputs = tokenizer_t(prompt, return_tensors="pt")
    inputs.pop("token_type_ids", None)
    # Move input tensors to the same device as the model
    inputs = {k: v.to(model_t.device) for k, v in inputs.items()}
    outputs = model_t.generate(**inputs, max_new_tokens=512)
    # Decode only the newly generated tokens; decoding outputs[0] wholesale
    # would echo the prompt back into the translation.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer_t.decode(new_tokens, skip_special_tokens=True).strip()
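# Example (hypothetical input; actual model output will vary):
#   translate_text("こんにちは、世界") -> "Hello, world"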
@spaces.GPU
def process_audio(audio_file, translate: bool):
    """Transcribe, optionally translate, and return subtitle paths."""
    try:
        # --- Transcription ---
        res = transcriber(audio_file, return_timestamps=True)
        full_text = res.get("text", "")
        chunks = res.get("chunks", []) or res.get("segments", [])
        orig_texts = [c.get("text", "") for c in chunks]

        # --- Original-language subtitles ---
        orig_srt_content = format_srt(chunks, orig_texts)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as tmp_orig:
            tmp_orig.write(orig_srt_content.encode("utf-8"))
            orig_srt_path = tmp_orig.name

        # --- Optional translation ---
        if translate:
            translated_texts = [translate_text(txt) for txt in orig_texts]
            trans_srt_content = format_srt(chunks, translated_texts)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as tmp_tr:
                tmp_tr.write(trans_srt_content.encode("utf-8"))
                trans_srt_path = tmp_tr.name
        else:
            trans_srt_path = None

        return full_text, orig_srt_path, trans_srt_path
    except Exception as e:
        print("🚨 GPU worker exception:", e, file=sys.stderr)
        traceback.print_exc()
        # Return placeholders so Gradio doesn't crash
        return f"Error: {e}", None, None
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Audio → Text + (Optional) English SRT Translation")
    audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav or .mp3)")
    translate_checkbox = gr.Checkbox(label="Translate to English subtitles", value=True)
    process_button = gr.Button("Process Audio", variant="primary")
    transcript_box = gr.Textbox(label="Transcript (Original)")
    download_orig = gr.File(label="Download original .srt file")
    download_trans = gr.File(label="Download English .srt file")

    def wrapper(audio_path, translate_flag):
        # Pass the GPU worker's results through; blank out the translated-file
        # slot when translation was not requested.
        full, orig_path, trans_path = process_audio(audio_path, translate_flag)
        if translate_flag and trans_path:
            return full, orig_path, trans_path
        return full, orig_path, None

    process_button.click(
        wrapper,
        inputs=[audio_input, translate_checkbox],
        outputs=[transcript_box, download_orig, download_trans],
    )

demo.launch()