# RepeatAfterMe / app.py
# Author: frimelle (HF Staff)
# Commit 0d9ff36: "try to recreate the code" (8.76 kB)
import gradio as gr
import src.generate as generate
import src.process as process
import src.tts as tts
# ------------------- UI printing functions -------------------
def clear_all():
    """Return blank values that reset every output wired to the Clear button.

    The tuple is positional and follows the ``outputs`` list passed to
    ``btn_clear.click``: target, user_transcript, score_html, result_html,
    diff_html, tts_text, clone_status, tts_audio.
    """
    blank_fields = ("",) * 7
    # The trailing ``None`` lands on the last output, ``tts_audio``.
    return (*blank_fields, None)
def make_result_html(pass_threshold, passed, ratio):
    """Build the pass/fail summary line and the similarity score line.

    Args:
        pass_threshold: Required similarity as a fraction (it is multiplied
            by 100 for display).
        passed: Truthy when the reading counts as a match.
        ratio: Achieved similarity as a fraction; shown with one decimal.

    Returns:
        A ``(summary, score)`` tuple of plain-text strings.
    """
    # round() rather than int(): 0.57 * 100 evaluates to 56.99999999999999
    # in binary floating point, so truncating would display "56%".
    threshold_pct = round(pass_threshold * 100)
    if passed:
        summary = f"✅ Correct (≥ {threshold_pct}%)"
    else:
        summary = f"❌ Not a match (need ≥ {threshold_pct}%)"
    score = f"Similarity: {ratio * 100:.1f}%"
    return summary, score
def make_alignment_html(ref_tokens, hyp_tokens, alignments):
    """Render a word-level diff between the target and the recognized speech.

    Args:
        ref_tokens: Sequence of target-sentence tokens.
        hyp_tokens: Sequence of tokens recognized from the user's audio.
        alignments: Iterable of ``(op, i1, i2, j1, j2)`` tuples, where
            ``ref_tokens[i1:i2]`` and ``hyp_tokens[j1:j2]`` are the slices the
            operation covers. Handled ops are "equal", "delete", "insert" and
            "replace"; any other op contributes nothing to the output.

    Returns:
        One ``<div>`` string. Reference-only text is struck through on a red
        background, hypothesis-only text gets a green background, and equal
        text is emitted unstyled.
    """
    # Function-scope import so this block does not depend on file-level edits.
    from html import escape

    removed_open = ' <span style="background:#ffe0e0;text-decoration:line-through;">'
    added_open = ' <span style="background:#e0ffe0;">'
    span_close = "</span>"

    pieces = []
    for op, i1, i2, j1, j2 in alignments:
        # Token text is escaped because it ends up inside markup; a stray
        # "<" or "&" would otherwise corrupt or inject HTML.
        ref_text = escape(" ".join(ref_tokens[i1:i2]), quote=False)
        hyp_text = escape(" ".join(hyp_tokens[j1:j2]), quote=False)
        if op == "equal":
            pieces.append(" " + ref_text)
        elif op == "delete":
            pieces.append(removed_open + ref_text + span_close)
        elif op == "insert":
            pieces.append(added_open + hyp_text + span_close)
        elif op == "replace":
            pieces.append(removed_open + ref_text + span_close)
            pieces.append(added_open + hyp_text + span_close)

    # Every piece starts with a separating space; strip the outermost ones.
    body = "".join(pieces).strip()
    return '<div style="line-height:1.6;font-size:1rem;">' + body + "</div>"
def make_html(sentence_match):
    """Produce the score, result, and diff strings for one sentence match.

    Reads ``pass_threshold``, ``passed``, ``ratio``, ``target_tokens``,
    ``user_tokens`` and ``alignments`` from ``sentence_match``.

    Returns:
        ``(score_html, result_html, diff_html)`` in that order.
    """
    summary_text, score_text = make_result_html(
        sentence_match.pass_threshold,
        sentence_match.passed,
        sentence_match.ratio,
    )
    alignment_markup = make_alignment_html(
        sentence_match.target_tokens,
        sentence_match.user_tokens,
        sentence_match.alignments,
    )
    # The score comes first in the return even though make_result_html
    # yields the summary first.
    return score_text, summary_text, alignment_markup
# ------------------- Core Check (English-only) -------------------
def get_user_transcript(
    audio_path: "str | None",
    target_sentence: str,
    model_id: str,
    device_pref: str,
) -> "tuple[str, str]":
    """Validate the inputs, then run ASR on the recorded audio.

    Annotations are quoted so they are never evaluated at import time and
    therefore impose no minimum Python version.

    Args:
        audio_path: Recorded clip to transcribe, or ``None`` when nothing has
            been recorded. It is forwarded unchanged to ``process.run_asr``.
        target_sentence: Sentence the user was asked to read; must be non-empty.
        model_id: ASR model identifier forwarded to ``process.run_asr``.
        device_pref: Device preference forwarded to ``process.run_asr``.

    Returns:
        ``(error_message, transcript)``. Exactly one of the two is meaningful:
        on failure the transcript is ``""``; on success the error is ``""``.
    """
    if not target_sentence:
        return "Please generate a sentence first.", ""
    if audio_path is None:
        return "Please start, record, then stop the audio recording before trying to transcribe.", ""

    user_transcript = process.run_asr(audio_path, model_id, device_pref)
    # A returned (not raised) Exception instance is treated as the failure
    # signal from run_asr.
    if isinstance(user_transcript, Exception):
        return f"Transcription failed: {user_transcript}", ""
    return "", user_transcript
def transcribe_check(audio_path, target_sentence, model_id, device_pref, pass_threshold):
    """Transcribe the recording, score it against the target, and render results.

    Returns:
        ``(user_transcript, score_html, result_html, diff_html)``. When
        validation or transcription fails, the error message is returned in
        the ``result_html`` slot and the score and diff are empty strings.
    """
    problem, heard = get_user_transcript(audio_path, target_sentence, model_id, device_pref)
    if problem:
        # Nothing to score or diff; surface the message where the result goes.
        return heard, "", problem, ""

    matcher = process.SentenceMatcher(target_sentence, heard, pass_threshold)
    score_markup, result_markup, diff_markup = make_html(matcher)
    return heard, score_markup, result_markup, diff_markup
# ------------------- Voice cloning gate -------------------
def clone_if_pass(
    audio_path,       # ref voice (the same recorded clip)
    target_sentence,  # sentence user was supposed to say
    user_transcript,  # what ASR heard
    tts_text,         # what we want to synthesize (in cloned voice)
    pass_threshold,   # must meet or exceed this
    tts_model_id,     # e.g., "coqui/XTTS-v2"
    tts_language,     # e.g., "en"
):
    """Clone the user's voice only if their reading passed the match check.

    The match is recomputed from ``target_sentence`` and ``user_transcript``
    via ``process.SentenceMatcher``; when it passes, ``tts.run_tts_clone``
    is called with the recorded clip and ``tts_text``.

    Returns:
        ``(audio, status)``. ``audio`` is a ``(sr, wav)`` tuple on success and
        ``None`` on any refusal or failure; ``status`` is a human-readable
        message in both cases.
    """
    # Basic validations: the first missing input decides the message.
    if audio_path is None:
        return None, "Record audio first (reference voice is required)."
    if not target_sentence:
        return None, "Generate a target sentence first."
    if not user_transcript:
        return None, "Transcribe first to verify the sentence."
    if not tts_text:
        return None, "Enter the sentence to synthesize."

    # Recompute pass/fail here instead of trusting a flag held in UI state.
    # NOTE(review): user_transcript itself still arrives from a UI component;
    # confirm that is an acceptable trust boundary for this gate.
    sm = process.SentenceMatcher(target_sentence, user_transcript, pass_threshold)
    if not sm.passed:
        # round() rather than int(): e.g. 0.58 * 100 is 57.99999999999999 and
        # truncation would report the wrong threshold.
        return None, (
            f"❌ Cloning blocked: your reading did not reach the threshold "
            f"({sm.ratio*100:.1f}% < {round(pass_threshold*100)}%)."
        )

    # Run zero-shot cloning; a returned Exception instance signals failure.
    out = tts.run_tts_clone(audio_path, tts_text, model_id=tts_model_id, language=tts_language)
    if isinstance(out, Exception):
        return None, f"Voice cloning failed: {out}"

    sr, wav = out
    # Gradio Audio can take a tuple (sr, np.array)
    return (sr, wav), f"✅ Cloned and synthesized with {tts_model_id} ({tts_language})."
# ------------------- UI -------------------
# Builds the whole Gradio page and wires its buttons to the handlers above.
# NOTE(review): the source this was recovered from had no indentation, so the
# Row/Accordion nesting below is inferred from reading order — confirm
# against the intended layout (in particular where diff_html and
# pass_threshold belong).
with gr.Blocks(title="Say the Sentence (English)") as demo:
    # The Markdown text is kept flush-left so the string holds no leading
    # whitespace that could be rendered as an indented code block.
    gr.Markdown(
        """
# 🎤 Say the Sentence (English)
1) Generate a sentence.
2) Record yourself reading it.
3) Transcribe & check your accuracy.
4) If matched, clone your voice to speak any sentence you enter.
"""
    )

    with gr.Row():
        target = gr.Textbox(label="Target sentence", interactive=False,
                            placeholder="Click 'Generate sentence'")

    with gr.Row():
        btn_gen = gr.Button("🎲 Generate sentence", variant="primary")
        btn_clear = gr.Button("🧹 Clear")

    with gr.Row():
        # type="filepath": handlers receive a path (or None), not raw samples.
        audio = gr.Audio(sources=["microphone"], type="filepath",
                         label="Record your voice")

    with gr.Accordion("Advanced settings", open=False):
        model_id = gr.Dropdown(
            choices=[
                "openai/whisper-tiny.en",
                "openai/whisper-base.en",
                "distil-whisper/distil-small.en",
            ],
            value="openai/whisper-tiny.en",
            label="ASR model (English only)",
        )
        device_pref = gr.Radio(
            choices=["auto", "cpu", "cuda"],
            value="auto",
            label="Device preference"
        )
        pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01,
                                   label="Match threshold")

    with gr.Row():
        btn_check = gr.Button("✅ Transcribe & Check", variant="primary")

    with gr.Row():
        user_transcript = gr.Textbox(label="Transcription", interactive=False)

    with gr.Row():
        score_html = gr.Label(label="Score")
        result_html = gr.Label(label="Result")

    diff_html = gr.HTML(
        label="Word-level diff (red = expected but missing / green = extra or replacement)")

    # NOTE(review): the emoji here and on the clone button was mis-encoded
    # with its final byte lost; 🔁 is a best guess — confirm.
    gr.Markdown("## 🔁 Voice cloning (gated)")

    with gr.Row():
        tts_text = gr.Textbox(
            label="Text to synthesize (voice clone)",
            placeholder="Type the sentence you want the cloned voice to say",
        )

    with gr.Row():
        tts_model_id = gr.Dropdown(
            choices=[
                "coqui/XTTS-v2",
                # add others if you like, e.g., "myshell-ai/MeloTTS"
            ],
            value="coqui/XTTS-v2",
            label="TTS (voice cloning) model",
        )
        tts_language = gr.Dropdown(
            choices=["en", "de", "fr", "es", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh"],
            value="en",
            label="Language",
        )

    with gr.Row():
        btn_clone = gr.Button("🔁 Clone voice (if passed)", variant="secondary")

    with gr.Row():
        tts_audio = gr.Audio(label="Cloned speech output", interactive=False)
        clone_status = gr.Label(label="Cloning status")

    # -------- Events --------
    btn_gen.click(fn=generate.gen_sentence_set, outputs=target)

    # The order of this list must match the tuple returned by clear_all().
    btn_clear.click(
        fn=clear_all,
        outputs=[target, user_transcript, score_html, result_html, diff_html, tts_text, clone_status, tts_audio]
    )

    btn_check.click(
        fn=transcribe_check,
        inputs=[audio, target, model_id, device_pref, pass_threshold],
        outputs=[user_transcript, score_html, result_html, diff_html]
    )

    btn_clone.click(
        fn=clone_if_pass,
        inputs=[audio, target, user_transcript, tts_text, pass_threshold, tts_model_id, tts_language],
        outputs=[tts_audio, clone_status],
    )

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    demo.launch()