frimelle (HF Staff) committed
Commit 182a1f7 · 1 Parent(s): f329f75

English-only setup

Files changed (1):
  1. app.py +17 -21
app.py CHANGED
@@ -75,15 +75,14 @@ def get_asr(model_id: str, device_preference: str):
 def gen_sentence():
     return random.choice(SENTENCE_BANK)
 
-def check_pronunciation(audio_path, target_sentence, model_id, lang, device_pref, pass_threshold):
+def check_pronunciation(audio_path, target_sentence, model_id, device_pref, pass_threshold):
     if not target_sentence:
         return gr.update(value=""), gr.update(value=""), gr.update(value=""), gr.update(value="Please generate a sentence first.")
 
     asr = get_asr(model_id, device_pref)
-    # Whisper models accept a 'generate' kwarg with language hints via tokenizer, but
-    # transformers pipeline exposes it as 'generate_kwargs' for whisper models.
+
     try:
-        result = asr(audio_path, generate_kwargs={"language": lang} if lang else None)
+        result = asr(audio_path)  # ✅ no language/task args for English-only models
         hyp_raw = result["text"].strip()
     except Exception as e:
         return "", "", "", f"Transcription failed: {e}"
@@ -122,20 +121,17 @@ with gr.Blocks(title="Say the Sentence") as demo:
     with gr.Row():
         audio = gr.Audio(sources=["microphone"], type="filepath", label="Record your voice")
     with gr.Accordion("Advanced settings", open=False):
-        model_id = gr.Dropdown(
-            choices=[
-                "openai/whisper-tiny.en",          # Fastest (English)
-                "openai/whisper-base.en",
-                "openai/whisper-small.en",
-                "distil-whisper/distil-small.en",  # Distil variant (English)
-                "openai/whisper-tiny",             # Multilingual tiny
-            ],
-            value="openai/whisper-tiny.en",
-            label="ASR model",
-        )
-        lang = gr.Textbox(value="en", label="Language hint (e.g., 'en', 'de', 'fr')", info="Whisper language code; leave as 'en' for English-only models.")
-        device_pref = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device preference")
-        pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01, label="Match threshold")
+        model_id = gr.Dropdown(
+            choices=[
+                "openai/whisper-tiny.en",          # fastest
+                "openai/whisper-base.en",          # slightly better accuracy
+                "distil-whisper/distil-small.en",  # optional
+            ],
+            value="openai/whisper-tiny.en",
+            label="ASR model (English only)",
+        )
+        device_pref = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device preference")
+        pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01, label="Match threshold")
 
     with gr.Row():
         btn_check = gr.Button("✅ Transcribe & Check", variant="primary")
@@ -151,9 +147,9 @@ with gr.Blocks(title="Say the Sentence") as demo:
     btn_gen.click(fn=gen_sentence, outputs=target)
     btn_clear.click(fn=lambda: ("", "", "", "", ""), outputs=[target, hyp_out, score_out, diff_out, summary_out])
     btn_check.click(
-        fn=check_pronunciation,
-        inputs=[audio, target, model_id, lang, device_pref, pass_threshold],
-        outputs=[hyp_out, score_out, diff_out, summary_out]
+        fn=check_pronunciation,
+        inputs=[audio, target, model_id, device_pref, pass_threshold],
+        outputs=[hyp_out, score_out, diff_out, summary_out]
    )
 
 if __name__ == "__main__":
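
For context on the change, a minimal sketch of why the language hint was dropped. Assumptions: get_asr wraps transformers.pipeline roughly as below (its real body is not part of this diff), and "sample.wav" is a placeholder recording path.

# Sketch only -- the actual get_asr() implementation is not shown in this commit.
import torch
from transformers import pipeline

def get_asr(model_id: str, device_preference: str):
    # "auto" picks the GPU when one is available, otherwise falls back to CPU.
    if device_preference == "cuda" or (device_preference == "auto" and torch.cuda.is_available()):
        device = 0
    else:
        device = -1
    return pipeline("automatic-speech-recognition", model=model_id, device=device)

asr = get_asr("openai/whisper-tiny.en", "auto")

# English-only (*.en) Whisper checkpoints carry no language/task tokens, so passing
# generate_kwargs={"language": "en"} typically makes transformers raise a ValueError.
# A plain call is all that is needed:
result = asr("sample.wav")            # placeholder path to a recorded clip
print(result["text"].strip())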