File size: 8,490 Bytes
f329f75
 
aad4cd6
 
f329f75
 
aad4cd6
63d0469
aad4cd6
63d0469
 
f329f75
aad4cd6
 
 
 
 
 
 
 
 
 
 
 
 
 
182a1f7
f329f75
aad4cd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f329f75
 
aad4cd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f329f75
 
63d0469
 
f329f75
 
63d0469
f329f75
63d0469
 
f329f75
 
 
 
aad4cd6
 
63d0469
f329f75
 
 
 
 
aad4cd6
 
63d0469
f329f75
63d0469
 
aad4cd6
 
 
 
63d0469
 
 
 
 
 
 
 
 
aad4cd6
 
f329f75
 
 
 
aad4cd6
f329f75
aad4cd6
 
 
 
 
 
 
 
 
 
 
 
f329f75
aad4cd6
63d0469
aad4cd6
f329f75
 
 
aad4cd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import gradio as gr

import src.generate as generate
import src.process as process


# ------------------- UI printing functions -------------------
def clear_all():
    """Blank out every output widget wired to the Clear button.

    Returns:
        Five empty strings, one each for the target, user_transcript,
        score_html, result_html, and diff_html components.
    """
    return ("",) * 5


def make_result_html(pass_threshold, passed, ratio):
    """Build the pass/fail summary and similarity-score display strings.

    Parameters:
        pass_threshold: Minimum match ratio (0.0-1.0) between target and
            recognized utterance that counts as passing.
        passed: Whether the recognized utterance met `pass_threshold`.
        ratio: Sequence match ratio between target and utterance.
    Returns:
        A (summary, score) pair of display strings.
    """
    pct = int(pass_threshold * 100)
    if passed:
        summary = f"✅ Correct (≥ {pct}%)"
    else:
        summary = f"❌ Not a match (need ≥ {pct}%)"
    score = f"Similarity: {ratio * 100:.1f}%"
    return summary, score


def make_alignment_html(ref_tokens, hyp_tokens, alignments):
    """Render a word-level diff of target vs. recognized speech as HTML.

    Matched words appear plain, missing target words appear red and
    struck through, and extra/substituted recognized words appear green.

    Parameters:
        ref_tokens: Target sentence for the user to say, tokenized.
        hyp_tokens: Recognized utterance from the user, tokenized.
        alignments: Opcode tuples (op, i1, i2, j1, j2) mapping spans of
            `ref_tokens` to spans of `hyp_tokens` (difflib-style).
    Returns:
        An HTML string highlighting the alignment.
    """
    # Leading spaces in the span openers keep words visually separated.
    missing_open = ' <span style="background:#ffe0e0;text-decoration:line-through;">'
    extra_open = ' <span style="background:#e0ffe0;">'
    pieces = []
    for op, i1, i2, j1, j2 in alignments:
        ref_text = " ".join(ref_tokens[i1:i2])
        hyp_text = " ".join(hyp_tokens[j1:j2])
        if op == "equal":
            pieces.append(" " + ref_text)
        elif op == "delete":
            # Target words the user never said.
            pieces.append(missing_open + ref_text + "</span>")
        elif op == "insert":
            # Words the user said that are not in the target.
            pieces.append(extra_open + hyp_text + "</span>")
        elif op == "replace":
            # Substitution: show what was expected, then what was heard.
            pieces.append(missing_open + ref_text + "</span>")
            pieces.append(extra_open + hyp_text + "</span>")
    body = "".join(pieces).strip()
    return '<div style="line-height:1.6;font-size:1rem;">' + body + "</div>"


def make_html(sentence_match):
    """Assemble all result HTML strings from a sentence-match object.

    Parameters:
        sentence_match: Object holding the target/user alignment features
            (target_tokens, user_tokens, alignments, pass_threshold,
            passed, ratio).
    Returns:
        score_html: Similarity-score display string.
        result_html: Pass/fail summary string.
        diff_html: HTML diff of target vs. recognized utterance.
    """
    m = sentence_match
    # make_result_html yields (summary, score); note the swapped unpack.
    result_html, score_html = make_result_html(m.pass_threshold, m.passed, m.ratio)
    diff_html = make_alignment_html(m.target_tokens, m.user_tokens, m.alignments)
    return score_html, result_html, diff_html


# ------------------- Core Check (English-only) -------------------
def get_user_transcript(audio_path: str, target_sentence: str,
                        model_id: str, device_pref: str) -> tuple[str, str]:
    """Uses the selected ASR model `model_id` to recognize words in the input `audio_path`.
    Parameters:
        audio_path: Local filepath to the recorded audio (the gradio Audio
            component is configured with type="filepath").
        target_sentence: Sentence the user needs to say.
        model_id: Desired ASR model.
        device_pref: Preferred ASR processing device. Can be "auto", "cpu", "cuda".
    Returns:
        error_msg: If there's an error, a string describing what happened; "" otherwise.
        user_transcript: The recognized user utterance ("" on error).
    """
    # Handles user interaction errors.
    if not target_sentence:
        return "Please generate a sentence first.", ""
    # TODO: Automatically stop the recording if someone presses the Transcribe & Check button.
    if audio_path is None:
        return "Please start, record, then stop the audio recording before trying to transcribe.", ""

    # Runs automatic speech recognition
    user_transcript = process.run_asr(audio_path, model_id, device_pref)

    # Bug fix: run_asr signals failure by returning an exception *instance*,
    # usually a subclass of Exception. The old check `type(x) is Exception`
    # matched only the exact base class, so subclass failures (RuntimeError,
    # OSError, ...) leaked through as if they were transcripts.
    if isinstance(user_transcript, Exception):
        return f"Transcription failed: {user_transcript}", ""

    return "", user_transcript


def transcribe_check(audio_path, target_sentence, model_id, device_pref,
                     pass_threshold):
    """Transcribe the user's audio, score it against the target sentence,
    and produce the display strings for the UI.
    Parameters:
        audio_path: Local path to recorded audio.
        target_sentence: Sentence the user needs to say.
        model_id: Desired ASR model.
        device_pref: Preferred ASR processing device. Can be "auto", "cpu", "cuda".
        pass_threshold: Minimum match ratio that counts as passing.
    Returns:
        user_transcript: The recognized user utterance.
        score_html: HTML string to display the score.
        result_html: HTML string describing the results, or an error message.
        diff_html: HTML string for the target/utterance differences.
    """
    # Transcribe user input
    error_msg, user_transcript = get_user_transcript(
        audio_path, target_sentence, model_id, device_pref)

    # On any error, surface the message in the result slot and blank the rest.
    if error_msg:
        return user_transcript, "", error_msg, ""

    # Score the recognized utterance against the target sentence.
    matcher = process.SentenceMatcher(target_sentence, user_transcript,
                                      pass_threshold)
    score_html, result_html, diff_html = make_html(matcher)
    return user_transcript, score_html, result_html, diff_html


# ------------------- UI -------------------
# ------------------- UI -------------------
# Top-level Gradio app: declares the widget layout inside a Blocks context,
# then wires button clicks to the handler functions defined above.
with gr.Blocks(title="Say the Sentence (English)") as demo:
    gr.Markdown(
        """
        # 🎤 Say the Sentence (English)
        1) Generate a sentence.  
        2) Record yourself reading it.  
        3) Transcribe & check your accuracy.  
        """
    )

    # Read-only display of the sentence the user must say.
    with gr.Row():
        target = gr.Textbox(label="Target sentence", interactive=False,
                            placeholder="Click 'Generate sentence'")

    with gr.Row():
        btn_gen = gr.Button("🎲 Generate sentence", variant="primary")
        btn_clear = gr.Button("🧹 Clear")

    # Microphone recorder; type="filepath" makes handlers receive a local path.
    with gr.Row():
        audio = gr.Audio(sources=["microphone"], type="filepath",
                         label="Record your voice")

    # Collapsed-by-default knobs: ASR model, device, and pass threshold.
    with gr.Accordion("Advanced settings", open=False):
        model_id = gr.Dropdown(
            choices=[
                "openai/whisper-tiny.en",  # fastest (CPU-friendly)
                "openai/whisper-base.en",  # better accuracy, a bit slower
                "distil-whisper/distil-small.en"
                # optional distil English model
            ],
            value="openai/whisper-tiny.en",
            label="ASR model (English only)",
        )
        device_pref = gr.Radio(
            choices=["auto", "cpu", "cuda"],
            value="auto",
            label="Device preference"
        )
        pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01,
                                   label="Match threshold")

    with gr.Row():
        btn_check = gr.Button("✅ Transcribe & Check", variant="primary")
    # Output widgets: transcription text, score/result labels, and the diff.
    with gr.Row():
        user_transcript = gr.Textbox(label="Transcription", interactive=False)
    with gr.Row():
        score_html = gr.Label(label="Score")
        result_html = gr.Label(label="Result")
    diff_html = gr.HTML(
        label="Word-level diff (red = expected but missing / green = extra or replacement)")

    # -------- Events --------
    # Uncomment below if you prefer to use the pre-specified set of target sentences.
    btn_gen.click(fn=generate.gen_sentence_set, outputs=target)
    # Comment this out below if you prefer to use the pre-specified set of target sentences (above).
    # btn_gen.click(fn=generate.gen_sentence_llm, outputs=target)
    # Clear button blanks all five output components (see clear_all).
    btn_clear.click(fn=clear_all,
                    outputs=[target, user_transcript, score_html, result_html, diff_html])
    # Main action: transcribe the recording and score it against the target.
    btn_check.click(
        fn=transcribe_check,
        inputs=[audio, target, model_id, device_pref, pass_threshold],
        outputs=[user_transcript, score_html, result_html, diff_html]
    )

if __name__ == "__main__":
    demo.launch()