meg (HF Staff) committed on
Commit aad4cd6 · verified · 1 Parent(s): 63d0469

Modularizing, documenting, and adding LLM-generation support.

Files changed (1):
  1. app.py +147 -120
app.py CHANGED
@@ -1,118 +1,137 @@
  import gradio as gr
- import random
- import re
- import difflib
- import torch
- from functools import lru_cache
- from transformers import pipeline
+ import src.generate as generate
+ import src.process as process

- # ------------------- Sentence Bank (customize freely) -------------------
- SENTENCE_BANK = [
-     "The quick brown fox jumps over the lazy dog.",
-     "I promise to speak clearly and at a steady pace.",
-     "Open source makes AI more transparent and inclusive.",
-     "Hugging Face Spaces make demos easy to share.",
-     "Today the weather in Berlin is pleasantly cool.",
-     "Privacy and transparency should go hand in hand.",
-     "Please generate a new sentence for me to read.",
-     "Machine learning can amplify or reduce inequality.",
-     "Responsible AI requires participation from everyone.",
-     "This microphone test checks my pronunciation accuracy.",
- ]
-
- # ------------------- Utilities -------------------
- def normalize_text(t: str) -> str:
-     # English-only normalization: lowercase, keep letters/digits/' and -
-     t = t.lower()
-     t = re.sub(r"[^a-z0-9'\-]+", " ", t)
-     t = re.sub(r"\s+", " ", t).strip()
-     return t
-
- def similarity_and_diff(ref: str, hyp: str):
-     """Return similarity ratio (0..1) and HTML diff highlighting changes."""
-     ref_tokens = ref.split()
-     hyp_tokens = hyp.split()
-     sm = difflib.SequenceMatcher(a=ref_tokens, b=hyp_tokens)
-     ratio = sm.ratio()
-
-     out = []
-     for op, i1, i2, j1, j2 in sm.get_opcodes():
-         if op == "equal":
-             out.append(" " + " ".join(ref_tokens[i1:i2]))
-         elif op == "delete":
-             out.append(
-                 ' <span style="background:#ffe0e0;text-decoration:line-through;">'
-                 + " ".join(ref_tokens[i1:i2]) + "</span>"
-             )
-         elif op == "insert":
-             out.append(
-                 ' <span style="background:#e0ffe0;">'
-                 + " ".join(hyp_tokens[j1:j2]) + "</span>"
-             )
-         elif op == "replace":
-             out.append(
-                 ' <span style="background:#ffe0e0;text-decoration:line-through;">'
-                 + " ".join(ref_tokens[i1:i2]) + "</span>"
-             )
-             out.append(
-                 ' <span style="background:#e0ffe0;">'
-                 + " ".join(hyp_tokens[j1:j2]) + "</span>"
-             )
-     html = '<div style="line-height:1.6;font-size:1rem;">' + "".join(out).strip() + "</div>"
-     return ratio, html
-
- @lru_cache(maxsize=2)
- def get_asr(model_id: str, device_preference: str):
-     """Cache an ASR pipeline. device_preference: 'auto'|'cpu'|'cuda'."""
-     if device_preference == "cuda" and torch.cuda.is_available():
-         device = 0
-     elif device_preference == "auto":
-         device = 0 if torch.cuda.is_available() else -1
-     else:
-         device = -1
-     return pipeline(
-         "automatic-speech-recognition",
-         model=model_id,  # use English-only Whisper models (.en)
-         device=device,
-         chunk_length_s=30,
-         return_timestamps=False,
-     )
-
- def gen_sentence():
-     return random.choice(SENTENCE_BANK)
-
+ # ------------------- UI printing functions -------------------
  def clear_all():
-     # target, hyp_out, score_out, diff_out, summary_out
+     # target, user_transcript, score_html, diff_html, result_html
      return "", "", "", "", ""

- # ------------------- Core Check (English-only) -------------------
- def check_pronunciation(audio_path, target_sentence, model_id, device_pref, pass_threshold):
-     if not target_sentence:
-         return "", "", "", "Please generate a sentence first."
-
-     asr = get_asr(model_id, device_pref)
-
-     try:
-         # IMPORTANT: For English-only Whisper (.en), do NOT pass language/task args.
-         result = asr(audio_path)
-         hyp_raw = result["text"].strip()
-     except Exception as e:
-         return "", "", "", f"Transcription failed: {e}"
-
-     ref_norm = normalize_text(target_sentence)
-     hyp_norm = normalize_text(hyp_raw)
-
-     ratio, diff_html = similarity_and_diff(ref_norm, hyp_norm)
-     passed = ratio >= pass_threshold
-
-     summary = (
-         f"✅ Correct (≥ {int(pass_threshold*100)}%)"
-         if passed else
-         f"❌ Not a match (need ≥ {int(pass_threshold*100)}%)"
-     )
-     score = f"Similarity: {ratio*100:.1f}%"
-
-     return hyp_raw, score, diff_html, summary
+ def make_result_html(pass_threshold, passed, ratio):
+     """Returns the summary and score strings shown in the results display.
+     Parameters:
+         pass_threshold: Minimum match ratio (0-1) between the target and the recognized user utterance that counts as passing.
+         passed: Whether the match ratio is >= `pass_threshold`.
+         ratio: Sequence match ratio.
+     """
+     summary = (
+         f"✅ Correct (≥ {int(pass_threshold * 100)}%)"
+         if passed else
+         f"❌ Not a match (need ≥ {int(pass_threshold * 100)}%)"
+     )
+     score = f"Similarity: {ratio * 100:.1f}%"
+     return summary, score
+
+ def make_alignment_html(ref_tokens, hyp_tokens, alignments):
+     """Returns HTML showing the alignment between the target and the recognized user audio.
+     Parameters:
+         ref_tokens: Target sentence for the user to say, tokenized.
+         hyp_tokens: Recognized utterance from the user, tokenized.
+         alignments: Tuples of an alignment operation (equal, delete, insert, replace) and the corresponding index ranges in `ref_tokens` and `hyp_tokens`.
+     """
+     out = []
+     no_match_html = ' <span style="background:#ffe0e0;text-decoration:line-through;">'
+     match_html = ' <span style="background:#e0ffe0;">'
+     for span in alignments:
+         op, i1, i2, j1, j2 = span
+         ref_string = " ".join(ref_tokens[i1:i2])
+         hyp_string = " ".join(hyp_tokens[j1:j2])
+         if op == "equal":
+             out.append(" " + ref_string)
+         elif op == "delete":
+             out.append(no_match_html + ref_string + "</span>")
+         elif op == "insert":
+             out.append(match_html + hyp_string + "</span>")
+         elif op == "replace":
+             out.append(no_match_html + ref_string + "</span>")
+             out.append(match_html + hyp_string + "</span>")
+     html = ('<div style="line-height:1.6;font-size:1rem;">'
+             + "".join(out).strip() + "</div>")
+     return html
+
+ def make_html(sentence_match):
+     """Creates the HTML written out to the UI based on the results.
+     Parameters:
+         sentence_match: Object storing the features of the alignment between the target sentence and the user utterance.
+     Returns:
+         score_html: An HTML string displaying the similarity score.
+         result_html: An HTML string summarizing the result of the match.
+         diff_html: An HTML string showing how the target sentence and the recognized user utterance align.
+     """
+     diff_html = make_alignment_html(sentence_match.target_tokens,
+                                     sentence_match.user_tokens,
+                                     sentence_match.alignments)
+     result_html, score_html = make_result_html(sentence_match.pass_threshold,
+                                                sentence_match.passed,
+                                                sentence_match.ratio)
+     return score_html, result_html, diff_html
+
+ # ------------------- Core Check (English-only) -------------------
+ def get_user_transcript(audio_path: str, target_sentence: str, model_id: str, device_pref: str) -> tuple[str, str]:
+     """Uses the selected ASR model `model_id` to recognize words in the input `audio_path`.
+     Parameters:
+         audio_path: Path to the recorded audio file returned by the gradio Audio component.
+         target_sentence: Sentence the user needs to say.
+         model_id: Desired ASR model.
+         device_pref: Preferred ASR processing device. Can be "auto", "cpu", or "cuda".
+     Returns:
+         error_msg: If there is an error, a string describing what happened.
+         user_transcript: The recognized user utterance.
+     """
+     error_msg = ""
+     # Handles user interaction errors.
+     if not target_sentence:
+         return "Please generate a sentence first.", ""
+     # TODO: Automatically stop the recording if someone presses the Transcribe & Check button.
+     if audio_path is None:
+         return "Please start, record, then stop the audio recording before trying to transcribe.", ""
+
+     # Runs automatic speech recognition.
+     user_transcript = process.run_asr(audio_path, model_id, device_pref)
+
+     # Handles processing errors.
+     if isinstance(user_transcript, Exception):
+         return f"Transcription failed: {user_transcript}", ""
+
+     return error_msg, user_transcript
+
+ def transcribe_check(audio_path, target_sentence, model_id, device_pref,
+                      pass_threshold):
+     """Transcribes the input user audio, calculates the match to the target
+     sentence, and creates the output HTML strings displaying the results.
+     Parameters:
+         audio_path: Local path to the recorded audio.
+         target_sentence: Sentence the user needs to say.
+         model_id: Desired ASR model.
+         device_pref: Preferred ASR processing device. Can be "auto", "cpu", or "cuda".
+         pass_threshold: Minimum match ratio (0-1) that counts as passing.
+     Returns:
+         user_transcript: The recognized user utterance.
+         score_html: HTML string displaying the score.
+         result_html: HTML string summarizing the results, or an error message.
+         diff_html: HTML string displaying the differences between the target and the user utterance.
+     """
+     # Transcribe user input.
+     error_msg, user_transcript = get_user_transcript(audio_path, target_sentence,
+                                                      model_id, device_pref)
+     if error_msg != "":
+         score_html = ""
+         diff_html = ""
+         result_html = error_msg
+     else:
+         # Calculate match details between the target and the recognized user input.
+         sentence_match = process.SentenceMatcher(target_sentence, user_transcript,
+                                                  pass_threshold)
+         # Create the output to display.
+         score_html, result_html, diff_html = make_html(sentence_match)
+     return user_transcript, score_html, result_html, diff_html

  # ------------------- UI -------------------
  with gr.Blocks(title="Say the Sentence (English)") as demo:
@@ -126,21 +145,24 @@ with gr.Blocks(title="Say the Sentence (English)") as demo:
      )

      with gr.Row():
-         target = gr.Textbox(label="Target sentence", interactive=False, placeholder="Click 'Generate sentence'")
+         target = gr.Textbox(label="Target sentence", interactive=False,
+                             placeholder="Click 'Generate sentence'")

      with gr.Row():
          btn_gen = gr.Button("🎲 Generate sentence", variant="primary")
          btn_clear = gr.Button("🧹 Clear")

      with gr.Row():
-         audio = gr.Audio(sources=["microphone"], type="filepath", label="Record your voice")
+         audio = gr.Audio(sources=["microphone"], type="filepath",
+                          label="Record your voice")

      with gr.Accordion("Advanced settings", open=False):
          model_id = gr.Dropdown(
              choices=[
-                 "openai/whisper-tiny.en",         # fastest (CPU-friendly)
-                 "openai/whisper-base.en",         # better accuracy, a bit slower
-                 "distil-whisper/distil-small.en"  # optional distil English model
+                 "openai/whisper-tiny.en",          # fastest (CPU-friendly)
+                 "openai/whisper-base.en",          # better accuracy, a bit slower
+                 "distil-whisper/distil-small.en",  # optional distil English model
              ],
              value="openai/whisper-tiny.en",
              label="ASR model (English only)",
@@ -150,26 +172,31 @@ with gr.Blocks(title="Say the Sentence (English)") as demo:
              value="auto",
              label="Device preference"
          )
-         pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01, label="Match threshold")
+         pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01,
+                                    label="Match threshold")

      with gr.Row():
          btn_check = gr.Button("✅ Transcribe & Check", variant="primary")
-
      with gr.Row():
-         hyp_out = gr.Textbox(label="Transcription", interactive=False)
+         user_transcript = gr.Textbox(label="Transcription", interactive=False)
      with gr.Row():
-         score_out = gr.Label(label="Score")
-         summary_out = gr.Label(label="Result")
-         diff_out = gr.HTML(label="Word-level diff (red = expected but missing / green = extra or replacement)")
-
-     # Events
-     btn_gen.click(fn=gen_sentence, outputs=target)
-     btn_clear.click(fn=clear_all, outputs=[target, hyp_out, score_out, diff_out, summary_out])
+         score_html = gr.Label(label="Score")
+         result_html = gr.Label(label="Result")
+         diff_html = gr.HTML(
+             label="Word-level diff (red = expected but missing / green = extra or replacement)")
+
+     # -------- Events --------
+     # Default: pick a target sentence from the pre-specified set.
+     btn_gen.click(fn=generate.gen_sentence_set, outputs=target)
+     # Uncomment below (and comment out the line above) to generate target sentences with an LLM instead.
+     # btn_gen.click(fn=generate.gen_sentence_llm, outputs=target)
+     btn_clear.click(fn=clear_all,
+                     outputs=[target, user_transcript, score_html, result_html, diff_html])
      btn_check.click(
-         fn=check_pronunciation,
+         fn=transcribe_check,
          inputs=[audio, target, model_id, device_pref, pass_threshold],
-         outputs=[hyp_out, score_out, diff_out, summary_out]
+         outputs=[user_transcript, score_html, result_html, diff_html]
      )

  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
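A few notes on the new code above. First, for concreteness, what make_result_html returns for an illustrative ratio of 0.914 against the default 0.85 threshold (the values here are made up for the example, not taken from the commit):

# Illustrative call only.
summary, score = make_result_html(pass_threshold=0.85, passed=True, ratio=0.914)
print(summary)  # ✅ Correct (≥ 85%)
print(score)    # Similarity: 91.4%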
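The `alignments` tuples consumed by make_alignment_html have the shape of difflib.SequenceMatcher.get_opcodes() output, which the removed similarity_and_diff helper used directly; presumably process.SentenceMatcher computes them the same way. For reference:

import difflib

ref_tokens = "the quick brown fox".split()
hyp_tokens = "the quick red fox".split()
sm = difflib.SequenceMatcher(a=ref_tokens, b=hyp_tokens)
print(sm.ratio())        # 0.75
print(sm.get_opcodes())  # [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 4, 3, 4)]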
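src.process is not included in this one-file diff. Judging from the removed get_asr helper and from how run_asr is called (it returns either the transcript or the exception it caught), a plausible sketch of that part of the module is:

# src/process.py -- run_asr sketch only; the actual module is not shown in this diff.
import torch
from functools import lru_cache
from transformers import pipeline

@lru_cache(maxsize=2)
def get_asr(model_id: str, device_preference: str):
    """Cache an ASR pipeline. device_preference: 'auto'|'cpu'|'cuda'."""
    if device_preference == "cuda" and torch.cuda.is_available():
        device = 0
    elif device_preference == "auto":
        device = 0 if torch.cuda.is_available() else -1
    else:
        device = -1
    return pipeline(
        "automatic-speech-recognition",
        model=model_id,  # use English-only Whisper models (.en)
        device=device,
        chunk_length_s=30,
        return_timestamps=False,
    )

def run_asr(audio_path: str, model_id: str, device_pref: str):
    """Return the transcript string, or the caught Exception on failure."""
    try:
        asr = get_asr(model_id, device_pref)
        # English-only Whisper (.en): do NOT pass language/task args.
        return asr(audio_path)["text"].strip()
    except Exception as e:
        return e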
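process.SentenceMatcher is likewise visible only through the attributes app.py reads (target_tokens, user_tokens, alignments, ratio, passed, pass_threshold). A minimal sketch consistent with the removed normalize_text and similarity_and_diff helpers:

# src/process.py -- SentenceMatcher sketch; attribute names are taken from the calls in app.py.
import difflib
import re

def normalize_text(t: str) -> str:
    # English-only normalization: lowercase, keep letters/digits/' and -
    t = t.lower()
    t = re.sub(r"[^a-z0-9'\-]+", " ", t)
    return re.sub(r"\s+", " ", t).strip()

class SentenceMatcher:
    """Stores the features of the target/user-utterance alignment."""
    def __init__(self, target_sentence: str, user_transcript: str, pass_threshold: float):
        self.pass_threshold = pass_threshold
        self.target_tokens = normalize_text(target_sentence).split()
        self.user_tokens = normalize_text(user_transcript).split()
        sm = difflib.SequenceMatcher(a=self.target_tokens, b=self.user_tokens)
        self.alignments = sm.get_opcodes()
        self.ratio = sm.ratio()
        self.passed = self.ratio >= pass_threshold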
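src.generate is also not shown. gen_sentence_set presumably replaces the removed gen_sentence/SENTENCE_BANK pair, and gen_sentence_llm is the LLM-generation support named in the commit message. A sketch under those assumptions; the model choice and prompt below are hypothetical:

# src/generate.py -- sketch; the actual module is not shown in this diff.
import random
from functools import lru_cache
from transformers import pipeline

SENTENCE_BANK = [
    "The quick brown fox jumps over the lazy dog.",
    "I promise to speak clearly and at a steady pace.",
    # ... the rest of the removed SENTENCE_BANK ...
]

def gen_sentence_set() -> str:
    """Pick a target sentence from the pre-specified set."""
    return random.choice(SENTENCE_BANK)

@lru_cache(maxsize=1)
def _get_generator():
    # Hypothetical model choice; any small instruction-tuned LM works here.
    return pipeline("text-generation", model="HuggingFaceTB/SmolLM2-360M-Instruct")

def gen_sentence_llm() -> str:
    """Generate a short English target sentence with an LLM."""
    prompt = "Write one short, simple English sentence for a pronunciation test:"
    out = _get_generator()(prompt, max_new_tokens=30, do_sample=True)
    return out[0]["generated_text"].removeprefix(prompt).strip()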