frimelle HF Staff committed on
Commit 0d9ff36 · 1 Parent(s): 8fa1508

try to recreate the code

Files changed (5)
  1. app.py +143 -122
  2. requirements.txt +0 -5
  3. src/generate.py +45 -0
  4. src/process.py +53 -40
  5. src/tts.py +25 -15
app.py CHANGED
@@ -1,118 +1,128 @@
1
  import gradio as gr
2
- import random
3
- import re
4
- import difflib
5
- import torch
6
- from functools import lru_cache
7
- from transformers import pipeline
8
-
9
- # ------------------- Sentence Bank (customize freely) -------------------
10
- SENTENCE_BANK = [
11
- "The quick brown fox jumps over the lazy dog.",
12
- "I promise to speak clearly and at a steady pace.",
13
- "Open source makes AI more transparent and inclusive.",
14
- "Hugging Face Spaces make demos easy to share.",
15
- "Today the weather in Berlin is pleasantly cool.",
16
- "Privacy and transparency should go hand in hand.",
17
- "Please generate a new sentence for me to read.",
18
- "Machine learning can amplify or reduce inequality.",
19
- "Responsible AI requires participation from everyone.",
20
- "This microphone test checks my pronunciation accuracy.",
21
- ]
22
-
23
- # ------------------- Utilities -------------------
24
- def normalize_text(t: str) -> str:
25
- # English-only normalization: lowercase, keep letters/digits/' and -
26
- t = t.lower()
27
- t = re.sub(r"[^a-z0-9'\-]+", " ", t)
28
- t = re.sub(r"\s+", " ", t).strip()
29
- return t
30
-
31
- def similarity_and_diff(ref: str, hyp: str):
32
- """Return similarity ratio (0..1) and HTML diff highlighting changes."""
33
- ref_tokens = ref.split()
34
- hyp_tokens = hyp.split()
35
- sm = difflib.SequenceMatcher(a=ref_tokens, b=hyp_tokens)
36
- ratio = sm.ratio()
37
 
38
  out = []
39
- for op, i1, i2, j1, j2 in sm.get_opcodes():
40
  if op == "equal":
41
- out.append(" " + " ".join(ref_tokens[i1:i2]))
42
  elif op == "delete":
43
- out.append(
44
- ' <span style="background:#ffe0e0;text-decoration:line-through;">'
45
- + " ".join(ref_tokens[i1:i2]) + "</span>"
46
- )
47
  elif op == "insert":
48
- out.append(
49
- ' <span style="background:#e0ffe0;">'
50
- + " ".join(hyp_tokens[j1:j2]) + "</span>"
51
- )
52
  elif op == "replace":
53
- out.append(
54
- ' <span style="background:#ffe0e0;text-decoration:line-through;">'
55
- + " ".join(ref_tokens[i1:i2]) + "</span>"
56
- )
57
- out.append(
58
- ' <span style="background:#e0ffe0;">'
59
- + " ".join(hyp_tokens[j1:j2]) + "</span>"
60
- )
61
  html = '<div style="line-height:1.6;font-size:1rem;">' + "".join(out).strip() + "</div>"
62
- return ratio, html
63
-
64
- @lru_cache(maxsize=2)
65
- def get_asr(model_id: str, device_preference: str):
66
- """Cache an ASR pipeline. device_preference: 'auto'|'cpu'|'cuda'."""
67
- if device_preference == "cuda" and torch.cuda.is_available():
68
- device = 0
69
- elif device_preference == "auto":
70
- device = 0 if torch.cuda.is_available() else -1
71
- else:
72
- device = -1
73
- return pipeline(
74
- "automatic-speech-recognition",
75
- model=model_id, # use English-only Whisper models (.en)
76
- device=device,
77
- chunk_length_s=30,
78
- return_timestamps=False,
79
- )
80
 
81
- def gen_sentence():
82
- return random.choice(SENTENCE_BANK)
83
 
84
- def clear_all():
85
- # target, hyp_out, score_out, diff_out, summary_out
86
- return "", "", "", "", ""
87
 
88
  # ------------------- Core Check (English-only) -------------------
89
- def check_pronunciation(audio_path, target_sentence, model_id, device_pref, pass_threshold):
 
90
  if not target_sentence:
91
- return "", "", "", "Please generate a sentence first."
92
-
93
- asr = get_asr(model_id, device_pref)
94
-
95
- try:
96
- # IMPORTANT: For English-only Whisper (.en), do NOT pass language/task args.
97
- result = asr(audio_path)
98
- hyp_raw = result["text"].strip()
99
- except Exception as e:
100
- return "", "", "", f"Transcription failed: {e}"
101
-
102
- ref_norm = normalize_text(target_sentence)
103
- hyp_norm = normalize_text(hyp_raw)
104
-
105
- ratio, diff_html = similarity_and_diff(ref_norm, hyp_norm)
106
- passed = ratio >= pass_threshold
107
 
108
- summary = (
109
- f"✅ Correct ( {int(pass_threshold*100)}%)"
110
- if passed else
111
- f" Not a match (need ≥ {int(pass_threshold*100)}%)"
112
- )
113
- score = f"Similarity: {ratio*100:.1f}%"
 
114
 
115
- return hyp_raw, score, diff_html, summary
116
 
117
  # ------------------- UI -------------------
118
  with gr.Blocks(title="Say the Sentence (English)") as demo:
@@ -122,25 +132,28 @@ with gr.Blocks(title="Say the Sentence (English)") as demo:
122
  1) Generate a sentence.
123
  2) Record yourself reading it.
124
  3) Transcribe & check your accuracy.
 
125
  """
126
  )
127
 
128
  with gr.Row():
129
- target = gr.Textbox(label="Target sentence", interactive=False, placeholder="Click 'Generate sentence'")
 
130
 
131
  with gr.Row():
132
  btn_gen = gr.Button("🎲 Generate sentence", variant="primary")
133
  btn_clear = gr.Button("🧹 Clear")
134
 
135
  with gr.Row():
136
- audio = gr.Audio(sources=["microphone"], type="filepath", label="Record your voice")
 
137
 
138
  with gr.Accordion("Advanced settings", open=False):
139
  model_id = gr.Dropdown(
140
  choices=[
141
- "openai/whisper-tiny.en", # fastest (CPU-friendly)
142
- "openai/whisper-base.en", # better accuracy, a bit slower
143
- "distil-whisper/distil-small.en" # optional distil English model
144
  ],
145
  value="openai/whisper-tiny.en",
146
  label="ASR model (English only)",
@@ -150,11 +163,11 @@ with gr.Blocks(title="Say the Sentence (English)") as demo:
150
  value="auto",
151
  label="Device preference"
152
  )
153
- pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01, label="Match threshold")
 
154
 
155
  with gr.Row():
156
  btn_check = gr.Button("✅ Transcribe & Check", variant="primary")
157
- <<<<<<< HEAD
158
  with gr.Row():
159
  user_transcript = gr.Textbox(label="Transcription", interactive=False)
160
  with gr.Row():
@@ -172,10 +185,10 @@ with gr.Blocks(title="Say the Sentence (English)") as demo:
172
  with gr.Row():
173
  tts_model_id = gr.Dropdown(
174
  choices=[
175
- "tts_models/multilingual/multi-dataset/xtts_v2",
176
- # add others if you like, e.g. "myshell-ai/MeloTTS"
177
  ],
178
- value="tts_models/multilingual/multi-dataset/xtts_v2",
179
  label="TTS (voice cloning) model",
180
  )
181
  tts_language = gr.Dropdown(
@@ -183,24 +196,32 @@ with gr.Blocks(title="Say the Sentence (English)") as demo:
183
  value="en",
184
  label="Language",
185
  )
186
- =======
187
- >>>>>>> parent of c5d4931 (add audio cloning functionality (test))
188
 
189
  with gr.Row():
190
- hyp_out = gr.Textbox(label="Transcription", interactive=False)
191
  with gr.Row():
192
- score_out = gr.Label(label="Score")
193
- summary_out = gr.Label(label="Result")
194
- diff_out = gr.HTML(label="Word-level diff (red = expected but missing / green = extra or replacement)")
195
 
196
- # Events
197
- btn_gen.click(fn=gen_sentence, outputs=target)
198
- btn_clear.click(fn=clear_all, outputs=[target, hyp_out, score_out, diff_out, summary_out])
199
  btn_check.click(
200
- fn=check_pronunciation,
201
  inputs=[audio, target, model_id, device_pref, pass_threshold],
202
- outputs=[hyp_out, score_out, diff_out, summary_out]
203
  )
204
 
205
  if __name__ == "__main__":
206
- demo.launch()
 
1
  import gradio as gr
2
 
3
+ import src.generate as generate
4
+ import src.process as process
5
+ import src.tts as tts
6
+
7
+
8
+ # ------------------- UI printing functions -------------------
9
+ def clear_all():
10
+ # target, user_transcript, score_html, diff_html, result_html,
11
+ # tts_text, clone_status, tts_audio
12
+ return "", "", "", "", "", "", "", None
13
+
14
+
15
+ def make_result_html(pass_threshold, passed, ratio):
16
+ """Returns summary and score label."""
17
+ summary = (
18
+ f"✅ Correct (≥ {int(pass_threshold * 100)}%)"
19
+ if passed else
20
+ f"❌ Not a match (need ≥ {int(pass_threshold * 100)}%)"
21
+ )
22
+ score = f"Similarity: {ratio * 100:.1f}%"
23
+ return summary, score
24
+
25
+
26
+ def make_alignment_html(ref_tokens, hyp_tokens, alignments):
27
+ """Returns HTML showing alignment between target and recognized user audio."""
28
  out = []
29
+ no_match_html = ' <span style="background:#ffe0e0;text-decoration:line-through;">'
30
+ match_html = ' <span style="background:#e0ffe0;">'
31
+ for span in alignments:
32
+ op, i1, i2, j1, j2 = span
33
+ ref_string = " ".join(ref_tokens[i1:i2])
34
+ hyp_string = " ".join(hyp_tokens[j1:j2])
35
  if op == "equal":
36
+ out.append(" " + ref_string)
37
  elif op == "delete":
38
+ out.append(no_match_html + ref_string + "</span>")
39
  elif op == "insert":
40
+ out.append(match_html + hyp_string + "</span>")
41
  elif op == "replace":
42
+ out.append(no_match_html + ref_string + "</span>")
43
+ out.append(match_html + hyp_string + "</span>")
44
  html = '<div style="line-height:1.6;font-size:1rem;">' + "".join(out).strip() + "</div>"
45
+ return html
46
 
47
 
48
+ def make_html(sentence_match):
49
+ """Build diff + results HTML."""
50
+ diff_html = make_alignment_html(sentence_match.target_tokens,
51
+ sentence_match.user_tokens,
52
+ sentence_match.alignments)
53
+ result_html, score_html = make_result_html(sentence_match.pass_threshold,
54
+ sentence_match.passed,
55
+ sentence_match.ratio)
56
+ return score_html, result_html, diff_html
57
+
58
 
59
  # ------------------- Core Check (English-only) -------------------
60
+ def get_user_transcript(audio_path: gr.Audio, target_sentence: str, model_id: str, device_pref: str) -> (str, str):
61
+ """ASR for the input audio and basic validation."""
62
  if not target_sentence:
63
+ return "Please generate a sentence first.", ""
64
+ if audio_path is None:
65
+ return "Please start, record, then stop the audio recording before trying to transcribe.", ""
66
+
67
+ user_transcript = process.run_asr(audio_path, model_id, device_pref)
68
+ if isinstance(user_transcript, Exception):
69
+ return f"Transcription failed: {user_transcript}", ""
70
+ return "", user_transcript
71
+
72
+
73
+ def transcribe_check(audio_path, target_sentence, model_id, device_pref, pass_threshold):
74
+ """Transcribe user audio, compute match, and render results."""
75
+ error_msg, user_transcript = get_user_transcript(audio_path, target_sentence, model_id, device_pref)
76
+ if error_msg:
77
+ score_html = ""
78
+ diff_html = ""
79
+ result_html = error_msg
80
+ else:
81
+ sentence_match = process.SentenceMatcher(target_sentence, user_transcript, pass_threshold)
82
+ score_html, result_html, diff_html = make_html(sentence_match)
83
+ return user_transcript, score_html, result_html, diff_html
84
+
85
+
86
+ # ------------------- Voice cloning gate -------------------
87
+ def clone_if_pass(
88
+ audio_path, # ref voice (the same recorded clip)
89
+ target_sentence, # sentence user was supposed to say
90
+ user_transcript, # what ASR heard
91
+ tts_text, # what we want to synthesize (in cloned voice)
92
+ pass_threshold, # must meet or exceed this
93
+ tts_model_id, # e.g., "coqui/XTTS-v2"
94
+ tts_language, # e.g., "en"
95
+ ):
96
+ """
97
+ If user correctly read the target (>= threshold), clone their voice from the
98
+ recorded audio and speak 'tts_text'. Otherwise, refuse.
99
+ """
100
+ # Basic validations
101
+ if audio_path is None:
102
+ return None, "Record audio first (reference voice is required)."
103
+ if not target_sentence:
104
+ return None, "Generate a target sentence first."
105
+ if not user_transcript:
106
+ return None, "Transcribe first to verify the sentence."
107
+ if not tts_text:
108
+ return None, "Enter the sentence to synthesize."
109
+
110
+ # Recompute pass/fail to avoid relying on UI state
111
+ sm = process.SentenceMatcher(target_sentence, user_transcript, pass_threshold)
112
+ if not sm.passed:
113
+ return None, (
114
+ f"❌ Cloning blocked: your reading did not reach the threshold "
115
+ f"({sm.ratio*100:.1f}% < {int(pass_threshold*100)}%)."
116
+ )
117
 
118
+ # Run zero-shot cloning
119
+ out = tts.run_tts_clone(audio_path, tts_text, model_id=tts_model_id, language=tts_language)
120
+ if isinstance(out, Exception):
121
+ return None, f"Voice cloning failed: {out}"
122
+ sr, wav = out
123
+ # Gradio Audio can take a tuple (sr, np.array)
124
+ return (sr, wav), f"✅ Cloned and synthesized with {tts_model_id} ({tts_language})."
125
 
 
126
 
127
  # ------------------- UI -------------------
128
  with gr.Blocks(title="Say the Sentence (English)") as demo:
 
132
  1) Generate a sentence.
133
  2) Record yourself reading it.
134
  3) Transcribe & check your accuracy.
135
+ 4) If matched, clone your voice to speak any sentence you enter.
136
  """
137
  )
138
 
139
  with gr.Row():
140
+ target = gr.Textbox(label="Target sentence", interactive=False,
141
+ placeholder="Click 'Generate sentence'")
142
 
143
  with gr.Row():
144
  btn_gen = gr.Button("🎲 Generate sentence", variant="primary")
145
  btn_clear = gr.Button("🧹 Clear")
146
 
147
  with gr.Row():
148
+ audio = gr.Audio(sources=["microphone"], type="filepath",
149
+ label="Record your voice")
150
 
151
  with gr.Accordion("Advanced settings", open=False):
152
  model_id = gr.Dropdown(
153
  choices=[
154
+ "openai/whisper-tiny.en",
155
+ "openai/whisper-base.en",
156
+ "distil-whisper/distil-small.en",
157
  ],
158
  value="openai/whisper-tiny.en",
159
  label="ASR model (English only)",
 
163
  value="auto",
164
  label="Device preference"
165
  )
166
+ pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01,
167
+ label="Match threshold")
168
 
169
  with gr.Row():
170
  btn_check = gr.Button("✅ Transcribe & Check", variant="primary")
 
171
  with gr.Row():
172
  user_transcript = gr.Textbox(label="Transcription", interactive=False)
173
  with gr.Row():
 
185
  with gr.Row():
186
  tts_model_id = gr.Dropdown(
187
  choices=[
188
+ "coqui/XTTS-v2",
189
+ # add others if you like, e.g., "myshell-ai/MeloTTS"
190
  ],
191
+ value="coqui/XTTS-v2",
192
  label="TTS (voice cloning) model",
193
  )
194
  tts_language = gr.Dropdown(
 
196
  value="en",
197
  label="Language",
198
  )
 
 
199
 
200
  with gr.Row():
201
+ btn_clone = gr.Button("🔁 Clone voice (if passed)", variant="secondary")
202
  with gr.Row():
203
+ tts_audio = gr.Audio(label="Cloned speech output", interactive=False)
204
+ clone_status = gr.Label(label="Cloning status")
205
+
206
+ # -------- Events --------
207
+ btn_gen.click(fn=generate.gen_sentence_set, outputs=target)
208
+
209
+ btn_clear.click(
210
+ fn=clear_all,
211
+ outputs=[target, user_transcript, score_html, result_html, diff_html, tts_text, clone_status, tts_audio]
212
+ )
213
 
214
  btn_check.click(
215
+ fn=transcribe_check,
216
  inputs=[audio, target, model_id, device_pref, pass_threshold],
217
+ outputs=[user_transcript, score_html, result_html, diff_html]
218
+ )
219
+
220
+ btn_clone.click(
221
+ fn=clone_if_pass,
222
+ inputs=[audio, target, user_transcript, tts_text, pass_threshold, tts_model_id, tts_language],
223
+ outputs=[tts_audio, clone_status],
224
  )
225
 
226
  if __name__ == "__main__":
227
+ demo.launch()
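The rewritten app.py wires these functions to the button events; the same check-then-clone flow can also be exercised headlessly. A minimal sketch, assuming a locally recorded clip at recording.wav (a hypothetical path) and the default UI settings:

```python
# Headless sketch of the app.py flow; "recording.wav" is a hypothetical local file,
# not part of the commit.
import src.generate as generate
from app import transcribe_check, clone_if_pass

target = generate.gen_sentence_set()   # sentence the user is asked to read
audio_path = "recording.wav"           # assumed recording of the user reading `target`

# Same call the "Transcribe & Check" button makes.
transcript, score_html, result_html, diff_html = transcribe_check(
    audio_path, target, "openai/whisper-tiny.en", "auto", 0.85
)
print(result_html)

# Same call the "Clone voice" button makes; it refuses unless the reading passed.
cloned_audio, status = clone_if_pass(
    audio_path, target, transcript, "Any sentence to speak in the cloned voice.",
    0.85, "coqui/XTTS-v2", "en"
)
print(status)
```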
requirements.txt CHANGED
@@ -3,12 +3,7 @@ transformers>=4.44.0
3
  torch>=2.2.0
4
  accelerate>=0.33.0
5
  sentencepiece>=0.2.0
6
- <<<<<<< HEAD
7
  numpy
8
  TTS>=0.22.0
9
  ffmpeg-python
10
  torchaudio
11
-
12
-
13
- =======
14
- >>>>>>> parent of c5d4931 (add audio cloning functionality (test))
 
3
  torch>=2.2.0
4
  accelerate>=0.33.0
5
  sentencepiece>=0.2.0
 
6
  numpy
7
  TTS>=0.22.0
8
  ffmpeg-python
9
  torchaudio
src/generate.py ADDED
@@ -0,0 +1,45 @@
1
+ import random
2
+
3
+ from transformers import pipeline, AutoTokenizer
4
+
5
+ import src.process as process
6
+
7
+ # You can choose to use either:
8
+ # (1) a list of pre-specified sentences, in SENTENCE_BANK
9
+ # (2) an LLM-generated sentence.
10
+ # SENTENCE_BANK is used in the gen_sentence_set function.
11
+ # LLM generation is used in the gen_sentence_llm function.
12
+
13
+ # ------------------- Sentence Bank (customize freely) -------------------
14
+ SENTENCE_BANK = [
15
+ "The quick brown fox jumps over the lazy dog.",
16
+ "I promise to speak clearly and at a steady pace.",
17
+ "Open source makes AI more transparent and inclusive.",
18
+ "Hugging Face Spaces make demos easy to share.",
19
+ "Today the weather in Berlin is pleasantly cool.",
20
+ "Privacy and transparency should go hand in hand.",
21
+ "Please generate a new sentence for me to read.",
22
+ "Machine learning can amplify or reduce inequality.",
23
+ "Responsible AI requires participation from everyone.",
24
+ "This microphone test checks my pronunciation accuracy.",
25
+ ]
26
+
27
+
28
+ def gen_sentence_llm():
29
+ """Generates a sentence using an LLM.
30
+ Returns:
31
+ Normalized text string to display in the UI.
32
+ """
33
+ prompt = ""
34
+ tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
35
+ generator = pipeline('text-generation', model='gpt2')
36
+ result = generator(prompt, stop_strings=[".", ], num_return_sequences=1,
37
+ tokenizer=tokenizer, pad_token_id=tokenizer.eos_token_id)
38
+ display_text = process.normalize_text(result[0]["generated_text"],
39
+ lower=False)
40
+ return display_text
41
+
42
+
43
+ def gen_sentence_set():
44
+ """Returns a sentence for the user to say using a prespecified set of options."""
45
+ return random.choice(SENTENCE_BANK)
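Both generators return a plain string, so either can back the Generate button. A short sketch of how they might be called or swapped (the btn_gen handler name comes from app.py):

```python
import src.generate as generate

# Fixed bank: instant, no model download.
print(generate.gen_sentence_set())   # e.g. "Privacy and transparency should go hand in hand."

# LLM path: downloads GPT-2 on first call and stops at the first generated period.
print(generate.gen_sentence_llm())

# app.py currently wires the button to the bank; switching is one line:
# btn_gen.click(fn=generate.gen_sentence_llm, outputs=target)
```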
src/process.py CHANGED
@@ -1,74 +1,87 @@
1
  import difflib
2
- import os
3
  import re
4
  from functools import lru_cache
5
 
 
6
  import torch
7
  from transformers import pipeline
8
 
 
9
  # ------------------- Utilities -------------------
10
  def normalize_text(t: str, lower: bool = True) -> str:
11
  if lower:
12
  t = t.lower()
 
13
  t = re.sub(r"[^a-zA-Z0-9'\-.,]+", " ", t)
14
  t = re.sub(r"\s+", " ", t).strip()
15
  return t
16
 
17
- def _pick_device(pref: str) -> int:
18
- if pref == "cuda" and torch.cuda.is_available():
19
- return 0
20
- if pref == "auto":
21
- return 0 if torch.cuda.is_available() else -1
22
- return -1
23
 
24
  @lru_cache(maxsize=2)
25
- def get_asr_pipeline(model_id: str, device_preference: str):
26
- device = _pick_device(device_preference)
27
- # IMPORTANT: For .en models do NOT set language/task
28
  return pipeline(
29
  "automatic-speech-recognition",
30
- model=model_id,
31
  device=device,
32
  chunk_length_s=30,
33
  return_timestamps=False,
34
  )
35
 
36
- def _validate_audio_path(p: str) -> None:
37
- if not isinstance(p, str):
38
- raise ValueError("Audio input is not a file path (expected type='filepath').")
39
- if not os.path.exists(p):
40
- raise FileNotFoundError(f"Recorded audio file not found: {p}")
41
- if os.path.getsize(p) < 1024:
42
- raise ValueError("Recorded audio seems empty or too short (<1KB). Try again.")
43
-
44
- def run_asr(audio_path, model_id: str, device_pref: str):
45
- """
46
- Returns the recognized text or an Exception (do NOT raise).
47
  """
 
48
  try:
49
- _validate_audio_path(audio_path)
50
- asr = get_asr_pipeline(model_id, device_pref)
51
  result = asr(audio_path)
52
- # transformers ASR returns {"text": "...", ...}
53
- hyp_raw = result.get("text", "").strip()
54
- if not hyp_raw:
55
- raise RuntimeError("ASR returned empty text.")
56
- return hyp_raw
57
  except Exception as e:
58
- # Return the real, descriptive error back to the UI
59
  return e
 
60
 
61
- # -------------- diff + matching (unchanged) --------------
62
- def similarity_and_diff(ref_tokens: list, hyp_tokens: list):
63
  sm = difflib.SequenceMatcher(a=ref_tokens, b=hyp_tokens)
64
- return sm.ratio(), sm.get_opcodes()
 
 
65
 
66
  class SentenceMatcher:
 
67
  def __init__(self, target_sentence, user_transcript, pass_threshold):
68
- self.target_sentence = target_sentence
69
- self.user_transcript = user_transcript
70
- self.pass_threshold = pass_threshold
71
- self.target_tokens = normalize_text(target_sentence).split()
72
- self.user_tokens = normalize_text(user_transcript).split()
73
- self.ratio, self.alignments = similarity_and_diff(self.target_tokens, self.user_tokens)
74
- self.passed = self.ratio >= self.pass_threshold
1
  import difflib
 
2
  import re
3
  from functools import lru_cache
4
 
5
+ import gradio.components.audio as gr_audio
6
  import torch
7
  from transformers import pipeline
8
 
9
+
10
  # ------------------- Utilities -------------------
11
  def normalize_text(t: str, lower: bool = True) -> str:
12
+ """For normalizing LLM-generated and human-generated strings.
13
+ For LLMs, this removes extraneous quote marks and spaces."""
14
+ # English-only normalization: lowercase, keep letters/digits/' and -
15
  if lower:
16
  t = t.lower()
17
+ # TODO: Previously was re.sub(r"[^a-z0-9'\-]+", " ", t); discuss normalizing for LLMs too.
18
  t = re.sub(r"[^a-zA-Z0-9'\-.,]+", " ", t)
19
  t = re.sub(r"\s+", " ", t).strip()
20
  return t
21
 
22
 
23
  @lru_cache(maxsize=2)
24
+ def get_asr_pipeline(model_id: str, device_preference: str) -> pipeline:
25
+ """Cache an ASR pipeline.
26
+ Parameters:
27
+ model_id: String of desired ASR model.
28
+ device_preference: String of desired device for ASR processing, "cuda", "cpu", or "auto".
29
+ Returns:
30
+ transformers.pipeline ASR component.
31
+ """
32
+ if device_preference == "cuda" and torch.cuda.is_available():
33
+ device = 0
34
+ elif device_preference == "auto":
35
+ device = 0 if torch.cuda.is_available() else -1
36
+ else:
37
+ device = -1
38
  return pipeline(
39
  "automatic-speech-recognition",
40
+ model=model_id, # use English-only Whisper models (.en)
41
  device=device,
42
  chunk_length_s=30,
43
  return_timestamps=False,
44
  )
45
 
46
+ def run_asr(audio_path: gr_audio, model_id: str, device_pref: str) -> str | Exception:
47
+ """Returns the recognized user utterance from the input audio stream.
48
+ Parameters:
49
+ audio_path: gradio.Audio component.
50
+ model_id: String of desired ASR model.
51
+ device_preference: String of desired device for ASR processing, "cuda", "cpu", or "auto".
52
+ Returns:
53
+ hyp_raw: Recognized user utterance.
54
  """
55
+ asr = get_asr_pipeline(model_id, device_pref)
56
  try:
57
+ # IMPORTANT: For English-only Whisper (.en), do NOT pass language/task args.
 
58
  result = asr(audio_path)
59
+ hyp_raw = result["text"].strip()
60
  except Exception as e:
 
61
  return e
62
+ return hyp_raw
63
 
64
+ def similarity_and_diff(ref_tokens: list, hyp_tokens: list) -> (float, list[str, int, int, int]):
65
+ """
66
+ Returns:
67
+ ratio: Similarity ratio (0..1).
68
+ opcodes: List of differences between target and recognized user utterance.
69
+ """
70
  sm = difflib.SequenceMatcher(a=ref_tokens, b=hyp_tokens)
71
+ ratio = sm.ratio()
72
+ opcodes = sm.get_opcodes()
73
+ return ratio, opcodes
74
 
75
  class SentenceMatcher:
76
+ """Class for keeping track of (target sentence, user utterance) match features."""
77
  def __init__(self, target_sentence, user_transcript, pass_threshold):
78
+ self.target_sentence: str = target_sentence
79
+ self.user_transcript: str = user_transcript
80
+ self.pass_threshold: float = pass_threshold
81
+ self.target_tokens: list = normalize_text(target_sentence).split()
82
+ self.user_tokens: list = normalize_text(user_transcript).split()
83
+ self.ratio: float
84
+ self.alignments: list
85
+ self.ratio, self.alignments = similarity_and_diff(self.target_tokens,
86
+ self.user_tokens)
87
+ self.passed: bool = self.ratio >= self.pass_threshold
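SentenceMatcher normalizes both sentences, tokenizes on whitespace, and stores the difflib alignment that app.py renders. A short sketch of the fields the UI consumes (values shown are what difflib.SequenceMatcher produces for this pair):

```python
import src.process as process

m = process.SentenceMatcher("The quick brown fox", "the quick red fox", 0.85)

print(m.target_tokens)  # ['the', 'quick', 'brown', 'fox']
print(m.user_tokens)    # ['the', 'quick', 'red', 'fox']
print(m.ratio)          # 0.75  -> 2 * matching tokens / total tokens = 2*3/8
print(m.alignments)     # [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 4, 3, 4)]
print(m.passed)         # False (0.75 < 0.85)
```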
src/tts.py CHANGED
@@ -1,32 +1,42 @@
1
- # src/tts.py
2
  from __future__ import annotations
3
  from typing import Tuple, Union
 
4
  import numpy as np
5
- from TTS.api import TTS # ← from the Coqui TTS package, not transformers
6
 
7
  def run_tts_clone(
8
  ref_audio_path: str,
9
  text_to_speak: str,
10
- model_id: str = "tts_models/multilingual/multi-dataset/xtts_v2",
11
  language: str = "en",
12
  ) -> Union[Tuple[int, np.ndarray], Exception]:
13
  """
14
- Synthesize `text_to_speak` in the cloned voice from `ref_audio_path`.
15
 
16
  Returns:
17
  (sampling_rate, waveform) on success, or Exception on failure.
18
  """
19
  try:
20
- try:
21
- tts = TTS(model_name=model_id, progress_bar=False, gpu=False)
22
- except KeyError as ke:
23
- # Typical message shows just 'xtts_v2' → old Coqui package
24
- return RuntimeError(
25
- f"Coqui TTS cannot find '{model_id}'. "
26
- "Please upgrade the TTS package (e.g., `pip install -U TTS>=0.22.0`)."
27
- )
28
-
29
- wav = tts.tts(text=text_to_speak, speaker_wav=ref_audio_path, language=language)
30
- return 24000, np.asarray(wav, dtype=np.float32)
31
  except Exception as e:
32
  return e
 
 
1
  from __future__ import annotations
2
  from typing import Tuple, Union
3
+
4
  import numpy as np
5
+ from transformers import pipeline
6
+
7
+ # We use the text-to-speech pipeline with XTTS v2 (zero-shot cloning)
8
+ # Example forward params: {"speaker_wav": "/path/to/ref.wav", "language": "en"}
9
+
10
+ def get_tts_pipeline(model_id: str):
11
+ """
12
+ Create a TTS pipeline for the given model.
13
+ XTTS v2 works well for zero-shot cloning and is available on the Hub.
14
+ """
15
+ # NOTE: Add device selection similar to ASR if needed
16
+ return pipeline("text-to-speech", model=model_id)
17
 
18
  def run_tts_clone(
19
  ref_audio_path: str,
20
  text_to_speak: str,
21
+ model_id: str = "coqui/XTTS-v2",
22
  language: str = "en",
23
  ) -> Union[Tuple[int, np.ndarray], Exception]:
24
  """
25
+ Synthesize 'text_to_speak' in the cloned voice from 'ref_audio_path'.
26
 
27
  Returns:
28
  (sampling_rate, waveform) on success, or Exception on failure.
29
  """
30
  try:
31
+ tts = get_tts_pipeline(model_id)
32
+ result = tts(
33
+ text_to_speak,
34
+ forward_params={"speaker_wav": ref_audio_path, "language": language},
35
+ )
36
+ # transformers TTS returns dict like: {"audio": {"array": np.ndarray, "sampling_rate": 24000}}
37
+ audio = result["audio"]
38
+ sr = int(audio["sampling_rate"])
39
+ wav = audio["array"].astype(np.float32)
40
+ return sr, wav
 
41
  except Exception as e:
42
  return e
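The new module routes cloning through the transformers text-to-speech pipeline with coqui/XTTS-v2. If that route is unavailable in a given environment, the Coqui TTS package the removed module called (and that requirements.txt still pins as TTS>=0.22.0) is one fallback; a minimal sketch, with ref.wav as a hypothetical reference clip:

```python
# Fallback sketch using the Coqui TTS API that the old src/tts.py relied on;
# "ref.wav" is a hypothetical reference recording. XTTS v2 emits 24 kHz audio.
import numpy as np
from TTS.api import TTS

tts = TTS(model_name="tts_models/multilingual/multi-dataset/xtts_v2",
          progress_bar=False, gpu=False)
wav = tts.tts(text="Hello in my cloned voice.",
              speaker_wav="ref.wav", language="en")
sr, wav = 24000, np.asarray(wav, dtype=np.float32)  # same (sampling_rate, waveform) pair run_tts_clone returns
```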