Spaces:
Sleeping
Sleeping
File size: 4,831 Bytes
081dd9e e7b8414 081dd9e a3c9c6a 935a252 081dd9e 935a252 081dd9e 935a252 081dd9e 935a252 a3c9c6a 081dd9e 935a252 b3b2606 935a252 081dd9e 935a252 b3b2606 935a252 081dd9e 935a252 081dd9e 935a252 b3b2606 935a252 081dd9e 935a252 b8dc19a b453f3e b8dc19a b453f3e b8dc19a cf3cb5d b8dc19a 935a252 b8dc19a 081dd9e b453f3e e7b8414 935a252 d307aeb e7b8414 935a252 e7b8414 935a252 081dd9e 935a252 b3b2606 081dd9e b3b2606 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import io
import datetime
import tempfile
import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Models
# Japanese ASR model (Whisper-based) and a causal-LM machine-translation model.
TRANSCRIBE_MODEL = "kotoba-tech/kotoba-whisper-v1.0"
TRANSLATE_MODEL = "tencent/Hunyuan-MT-7B"
# Load pipelines
# NOTE: both models are loaded at import time; on a Spaces ZeroGPU runtime the
# actual GPU placement happens inside the @spaces.GPU-decorated worker.
transcriber = pipeline(
"automatic-speech-recognition",
model=TRANSCRIBE_MODEL,
# Force Japanese transcription (no auto language detection / translation task).
generate_kwargs={"language": "japanese", "task": "transcribe"}
)
# Translation model: tokenizer + model; device_map="auto" lets accelerate
# shard/place the 7B model across available devices.
tokenizer_t = AutoTokenizer.from_pretrained(TRANSLATE_MODEL)
model_t = AutoModelForCausalLM.from_pretrained(TRANSLATE_MODEL, device_map="auto")
def format_srt(chunks, texts):
    """Build an SRT document from ASR chunks and their corresponding texts.

    Args:
        chunks: list of dicts carrying either explicit "start"/"end" keys or
            a "timestamp" (start, end) tuple (the HF ASR pipeline with
            ``return_timestamps=True`` emits the latter).
        texts: subtitle text for each chunk, in the same order as ``chunks``.

    Returns:
        The SRT file content as a single string. Chunks whose timestamps
        cannot be resolved are skipped; emitted entries are numbered
        sequentially (SRT indices must not have gaps).
    """
    def _fmt(seconds):
        # Render a float second count as the SRT timestamp HH:MM:SS,mmm.
        whole = int(seconds)
        millis = int((seconds - whole) * 1000)
        return f"{whole // 3600:02d}:{(whole % 3600) // 60:02d}:{whole % 60:02d},{millis:03d}"

    srt_lines = []
    idx = 0
    for chunk, text in zip(chunks, texts):
        # BUG FIX: the original used `chunk.get("start") or timestamp[0]`,
        # which treated a legitimate 0.0 start time as missing. Test for
        # None explicitly instead.
        start = chunk.get("start")
        end = chunk.get("end")
        if (start is None or end is None) and "timestamp" in chunk:
            ts = chunk["timestamp"]
            if start is None:
                start = ts[0]
            if end is None:
                end = ts[1]
        if start is None or end is None:
            continue  # cannot place this subtitle in time; drop it
        idx += 1
        srt_lines.append(str(idx))
        srt_lines.append(f"{_fmt(float(start))} --> {_fmt(float(end))}")
        srt_lines.append(text.strip())
        srt_lines.append("")
    return "\n".join(srt_lines)
def translate_text(text, target_lang="English"):
    """Translate ``text`` into ``target_lang`` using Hunyuan-MT-7B.

    Args:
        text: source-language segment to translate.
        target_lang: target language name inserted into the prompt. Defaults
            to "English", preserving the original behavior (the original
            accepted this parameter but ignored it).

    Returns:
        The translated text, stripped of surrounding whitespace.
    """
    prompt = f"Translate the following segment into {target_lang}, without explanation:\n\n{text}"
    # Tokenize; drop token_type_ids, which some tokenizers emit but
    # CausalLM generate() does not accept.
    inputs = tokenizer_t(prompt, return_tensors="pt")
    inputs.pop("token_type_ids", None)
    # Move input tensors to the same device as the (device_map="auto") model.
    inputs = {k: v.to(model_t.device) for k, v in inputs.items()}
    # Generate translation
    outputs = model_t.generate(
        **inputs,
        max_new_tokens=512
    )
    # BUG FIX: decoder-only models echo the prompt in the output sequence;
    # the original decoded outputs[0] whole, so the returned "translation"
    # began with the prompt. Decode only the newly generated tokens.
    prompt_len = inputs["input_ids"].shape[1]
    result = tokenizer_t.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return result.strip()
@spaces.GPU
def process_audio(audio_file, translate: bool):
    """Transcribe ``audio_file``, optionally translate, and return results.

    Returns:
        A 3-tuple ``(transcript, original_srt_path, translated_srt_path)``;
        the translated path is ``None`` when translation is disabled. On any
        failure the transcript slot carries an error message and both paths
        are ``None`` so the Gradio UI stays alive.
    """
    import sys
    import traceback

    def _write_srt(content):
        # Persist SRT content to a temp file that Gradio can offer for download.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as handle:
            handle.write(content.encode("utf-8"))
            return handle.name

    try:
        # --- Transcription ---
        result = transcriber(audio_file, return_timestamps=True)
        transcript = result.get("text", "")
        segments = result.get("chunks", []) or result.get("segments", [])
        source_texts = [seg.get("text", "") for seg in segments]

        # --- Original-language subtitles ---
        original_path = _write_srt(format_srt(segments, source_texts))

        # --- Optional translation ---
        translated_path = None
        if translate:
            english_texts = [translate_text(seg_text) for seg_text in source_texts]
            translated_path = _write_srt(format_srt(segments, english_texts))

        return transcript, original_path, translated_path
    except Exception as exc:
        print("🚨 GPU worker exception:", exc, file=sys.stderr)
        traceback.print_exc()
        # Return placeholders so Gradio doesn’t crash
        return f"Error: {exc}", None, None
# Gradio UI: upload audio, transcribe, optionally produce an English SRT.
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Audio → Text + (Optional) English SRT Translation")
    audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav or .mp3)")
    translate_checkbox = gr.Checkbox(label="Translate to English subtitles", value=True)
    process_button = gr.Button("Process Audio", variant="primary")
    transcript_box = gr.Textbox(label="Transcript (Original)")
    download_orig = gr.File(label="Download original .srt file")
    download_trans = gr.File(label="Download English .srt file")

    def wrapper(audio_path, translate_flag):
        """Adapt process_audio's 3-tuple to the three output components."""
        full, orig_path, trans_path = process_audio(audio_path, translate_flag)
        # Only expose a translated file when translation was requested and produced.
        if translate_flag and trans_path:
            return full, orig_path, trans_path
        return full, orig_path, None

    process_button.click(
        wrapper,
        inputs=[audio_input, translate_checkbox],
        outputs=[transcript_box, download_orig, download_trans]
    )

# BUG FIX: the original line ended with a stray " |" (scrape residue), which
# is a syntax error.
demo.launch()