Spaces:
Sleeping
Sleeping
| import io | |
| import datetime | |
| import tempfile | |
| import gradio as gr | |
| import spaces | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
# Models
# Japanese ASR checkpoint plus a causal-LM machine-translation checkpoint.
# NOTE(review): both load eagerly at import time — slow/heavy on CPU-only
# Spaces; confirm this is intended rather than lazy loading.
TRANSCRIBE_MODEL = "kotoba-tech/kotoba-whisper-v1.0"
TRANSLATE_MODEL = "tencent/Hunyuan-MT-7B"
# Load pipelines
# ASR pipeline pinned to Japanese transcription (no translation task here;
# translation is handled separately by the causal LM below).
transcriber = pipeline(
    "automatic-speech-recognition",
    model=TRANSCRIBE_MODEL,
    generate_kwargs={"language": "japanese", "task": "transcribe"}
)
# Tokenizer + model for translation; device_map="auto" places weights on
# whatever accelerator is available (falls back to CPU).
tokenizer_t = AutoTokenizer.from_pretrained(TRANSLATE_MODEL)
model_t = AutoModelForCausalLM.from_pretrained(TRANSLATE_MODEL, device_map="auto")
def _format_srt_time(ts: datetime.timedelta) -> str:
    """Render a timedelta as an SRT timestamp, e.g. '00:01:02,345'."""
    total_seconds = int(ts.total_seconds())
    hours = total_seconds // 3600
    minutes = (total_seconds % 3600) // 60
    seconds = total_seconds % 60
    milliseconds = int((ts.total_seconds() - total_seconds) * 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"


def _chunk_bounds(chunk):
    """Return (start, end) seconds for a transcription chunk, or (None, None).

    Supports both {'start': s, 'end': e} and Whisper-pipeline
    {'timestamp': (s, e)} layouts. Uses explicit None checks so a
    legitimate 0.0 start is kept — the previous `chunk.get("start") or ...`
    treated 0.0 as missing and silently dropped the segment. Also tolerates
    a present-but-None 'timestamp' value, which previously raised TypeError.
    """
    start = chunk.get("start")
    end = chunk.get("end")
    timestamp = chunk.get("timestamp")
    if start is None and timestamp is not None:
        start = timestamp[0]
    if end is None and timestamp is not None:
        end = timestamp[1]
    return start, end


def format_srt(chunks, texts):
    """Build SRT using chunks list (with start/end) and corresponding texts list.

    Chunks with no resolvable start or end time are skipped; the SRT index
    still advances per zipped (chunk, text) pair, matching the original
    numbering behavior.
    """
    srt_lines = []
    for idx, (chunk, text) in enumerate(zip(chunks, texts), start=1):
        start, end = _chunk_bounds(chunk)
        if start is None or end is None:
            continue
        start_ts = datetime.timedelta(seconds=float(start))
        end_ts = datetime.timedelta(seconds=float(end))
        srt_lines.append(f"{idx}")
        srt_lines.append(f"{_format_srt_time(start_ts)} --> {_format_srt_time(end_ts)}")
        srt_lines.append(text.strip())
        srt_lines.append("")
    return "\n".join(srt_lines)
def translate_text(text, target_lang="English"):
    """Translate *text* into *target_lang* using Hunyuan-MT-7B.

    Args:
        text: Source text (typically a single subtitle segment).
        target_lang: Target language name inserted into the prompt.
            Defaults to "English", preserving the original behavior.

    Returns:
        The decoded translation, stripped of surrounding whitespace.
    """
    # Fix: target_lang was previously ignored — "English" was hard-coded,
    # making the parameter dead.
    prompt = f"Translate the following segment into {target_lang}, without explanation:\n\n{text}"
    # Tokenize
    inputs = tokenizer_t(prompt, return_tensors="pt")
    # Some tokenizers emit token_type_ids, which causal-LM generate() rejects.
    inputs.pop("token_type_ids", None)
    # Move input tensors to the same device as the model (device_map="auto").
    inputs = {k: v.to(model_t.device) for k, v in inputs.items()}
    # Generate translation
    outputs = model_t.generate(
        **inputs,
        max_new_tokens=512
    )
    # Fix: decode only the newly generated continuation. For a causal LM,
    # generate() returns prompt + completion, so decoding outputs[0] whole
    # would prepend the entire prompt to the "translation".
    prompt_len = inputs["input_ids"].shape[-1]
    result = tokenizer_t.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return result.strip()
def process_audio(audio_file, translate: bool):
    """Transcribe, optionally translate, and return subtitle paths.

    Returns a (transcript_text, original_srt_path, translated_srt_path)
    triple; the translated path is None when translation is disabled.
    Any exception is logged and converted into an error string plus two
    Nones so the Gradio UI degrades gracefully instead of crashing.
    """
    import sys
    import traceback

    def _dump_srt(content):
        # Persist SRT content to a temp file and hand back its path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as handle:
            handle.write(content.encode("utf-8"))
            return handle.name

    try:
        # --- Transcription ---
        result = transcriber(audio_file, return_timestamps=True)
        transcript = result.get("text", "")
        segments = result.get("chunks", []) or result.get("segments", [])
        source_texts = [seg.get("text", "") for seg in segments]

        # --- Original-language subtitles ---
        original_path = _dump_srt(format_srt(segments, source_texts))

        # --- Optional translated subtitles ---
        translated_path = None
        if translate:
            translated = [translate_text(txt) for txt in source_texts]
            translated_path = _dump_srt(format_srt(segments, translated))

        return transcript, original_path, translated_path
    except Exception as e:
        print("🚨 GPU worker exception:", e, file=sys.stderr)
        traceback.print_exc()
        # Return placeholders so Gradio doesn’t crash
        return f"Error: {e}", None, None
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Audio → Text + (Optional) English SRT Translation")
    audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav or .mp3)")
    translate_checkbox = gr.Checkbox(label="Translate to English subtitles", value=True)
    process_button = gr.Button("Process Audio", variant="primary")
    transcript_box = gr.Textbox(label="Transcript (Original)")
    download_orig = gr.File(label="Download original .srt file")
    download_trans = gr.File(label="Download English .srt file")

    def wrapper(audio_path, translate_flag):
        # Delegate to the worker; suppress the translated file output when
        # translation was not requested (or produced no path).
        transcript, original_srt, translated_srt = process_audio(audio_path, translate_flag)
        translated_out = translated_srt if (translate_flag and translated_srt) else None
        return transcript, original_srt, translated_out

    process_button.click(
        wrapper,
        inputs=[audio_input, translate_checkbox],
        outputs=[transcript_box, download_orig, download_trans],
    )

demo.launch()