import datetime
import sys
import tempfile
import traceback

import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Models
TRANSCRIBE_MODEL = "kotoba-tech/kotoba-whisper-v1.0"  # Japanese speech recognition
TRANSLATE_MODEL  = "tencent/Hunyuan-MT-7B"            # translation model (causal LM)

# Load pipelines
transcriber = pipeline(
    "automatic-speech-recognition",
    model=TRANSCRIBE_MODEL,
    generate_kwargs={"language": "japanese", "task": "transcribe"}
)
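
# If long recordings run out of memory or come back with sparse timestamps,
# chunked inference is an option; chunk_length_s is a standard argument of the
# transformers ASR pipeline (the 15 s value below is an assumption to tune):
#
#   transcriber = pipeline(
#       "automatic-speech-recognition",
#       model=TRANSCRIBE_MODEL,
#       chunk_length_s=15,
#       generate_kwargs={"language": "japanese", "task": "transcribe"},
#   )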

tokenizer_t = AutoTokenizer.from_pretrained(TRANSLATE_MODEL)
model_t     = AutoModelForCausalLM.from_pretrained(TRANSLATE_MODEL, device_map="auto")
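
# Note: Hunyuan-MT-7B in full float32 needs roughly 28 GB for the weights
# alone (7B params x 4 bytes). If memory is tight, half precision is a common
# option (a sketch, assuming a bfloat16-capable GPU):
#
#   import torch
#   model_t = AutoModelForCausalLM.from_pretrained(
#       TRANSLATE_MODEL, device_map="auto", torch_dtype=torch.bfloat16
#   )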

def format_srt(chunks, texts):
    """Build an SRT document from timestamped chunks and their corresponding texts."""

    def fmt(ts):
        """Format a timedelta as an SRT timestamp: HH:MM:SS,mmm."""
        total_seconds = int(ts.total_seconds())
        hours = total_seconds // 3600
        minutes = (total_seconds % 3600) // 60
        seconds = total_seconds % 60
        milliseconds = int((ts.total_seconds() - total_seconds) * 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

    srt_lines = []
    for idx, (chunk, text) in enumerate(zip(chunks, texts), start=1):
        # The pipeline reports timing either as explicit "start"/"end" keys or
        # as a "timestamp" (start, end) tuple; handle both. Compare against
        # None explicitly so a legitimate 0.0 start is not discarded as falsy.
        start = chunk.get("start")
        end = chunk.get("end")
        if (start is None or end is None) and chunk.get("timestamp") is not None:
            start, end = chunk["timestamp"]
        if start is None or end is None:
            continue  # open-ended chunk (e.g. truncated audio); skip it
        start_ts = datetime.timedelta(seconds=float(start))
        end_ts = datetime.timedelta(seconds=float(end))
        srt_lines.append(f"{idx}")
        srt_lines.append(f"{fmt(start_ts)} --> {fmt(end_ts)}")
        srt_lines.append(text.strip())
        srt_lines.append("")
    return "\n".join(srt_lines)
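
# For reference, each entry format_srt emits follows the standard SRT layout
# (index, time range, text, blank separator), e.g.:
#
#   1
#   00:00:00,000 --> 00:00:04,500
#   こんにちは、世界
#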

def translate_text(text, target_lang="English"):
    """Translate text into target_lang using Hunyuan-MT-7B."""
    prompt = f"Translate the following segment into {target_lang}, without explanation:\n\n{text}"

    # Tokenize, then move the input tensors to the same device as the model
    inputs = tokenizer_t(prompt, return_tensors="pt")
    inputs.pop("token_type_ids", None)
    inputs = {k: v.to(model_t.device) for k, v in inputs.items()}

    # Generate the translation
    outputs = model_t.generate(**inputs, max_new_tokens=512)

    # generate() returns the prompt followed by the continuation; decode only
    # the newly generated tokens so the prompt does not leak into the subtitle.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer_t.decode(new_tokens, skip_special_tokens=True).strip()
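
# translate_text is called once per subtitle segment in process_audio below.
# If per-call overhead dominates, a batched variant is possible. A sketch
# (untested here; causal LMs need left padding so every prompt ends at the
# same position, and some tokenizers ship without a pad token):
#
#   def translate_batch(texts, target_lang="English"):
#       tokenizer_t.padding_side = "left"
#       if tokenizer_t.pad_token is None:
#           tokenizer_t.pad_token = tokenizer_t.eos_token
#       prompts = [f"Translate the following segment into {target_lang}, "
#                  f"without explanation:\n\n{t}" for t in texts]
#       batch = tokenizer_t(prompts, return_tensors="pt", padding=True)
#       batch.pop("token_type_ids", None)
#       batch = {k: v.to(model_t.device) for k, v in batch.items()}
#       out = model_t.generate(**batch, max_new_tokens=512,
#                              pad_token_id=tokenizer_t.pad_token_id)
#       new_tokens = out[:, batch["input_ids"].shape[1]:]
#       return [s.strip() for s in
#               tokenizer_t.batch_decode(new_tokens, skip_special_tokens=True)]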


@spaces.GPU
def process_audio(audio_file, translate: bool):
    """Transcribe, optionally translate, and return subtitle paths."""
    try:
        # --- Transcription ---
        res = transcriber(audio_file, return_timestamps=True)
        full_text = res.get("text", "")
        chunks = res.get("chunks", []) or res.get("segments", [])
        orig_texts = [c.get("text", "") for c in chunks]

        # --- Original subtitles ---
        orig_srt_content = format_srt(chunks, orig_texts)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as tmp_orig:
            tmp_orig.write(orig_srt_content.encode("utf-8"))
            orig_srt_path = tmp_orig.name

        # --- Optional translation ---
        if translate:
            translated_texts = [translate_text(txt) for txt in orig_texts]
            trans_srt_content = format_srt(chunks, translated_texts)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as tmp_tr:
                tmp_tr.write(trans_srt_content.encode("utf-8"))
                trans_srt_path = tmp_tr.name
        else:
            trans_srt_path = None

        return full_text, orig_srt_path, trans_srt_path

    except Exception as e:
        print("🚨 GPU worker exception:", e, file=sys.stderr)
        traceback.print_exc()
        # Return placeholders so Gradio doesn’t crash
        return f"Error: {e}", None, None


with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Audio → Text + (Optional) English SRT Translation")

    audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio (.wav or .mp3)")
    translate_checkbox = gr.Checkbox(label="Translate to English subtitles", value=True)
    process_button     = gr.Button("Process Audio", variant="primary")

    transcript_box     = gr.Textbox(label="Transcript (Original)")
    download_orig      = gr.File(label="Download original .srt file")
    download_trans     = gr.File(label="Download English .srt file")

    # process_audio already returns None for the translated path when
    # translation is disabled, so it can be wired to the button directly.
    process_button.click(
        process_audio,
        inputs=[audio_input, translate_checkbox],
        outputs=[transcript_box, download_orig, download_trans]
    )

demo.launch()