r3gm committed on
Commit f69f471 · verified · 1 Parent(s): 1ccbed2

Upload speech_segmentation.py

Files changed (1)
  1. soni_translate/speech_segmentation.py +496 -503
soni_translate/speech_segmentation.py CHANGED
@@ -1,503 +1,496 @@
1
- import spaces
2
- from whisperx.alignment import (
3
- DEFAULT_ALIGN_MODELS_TORCH as DAMT,
4
- DEFAULT_ALIGN_MODELS_HF as DAMHF,
5
- )
6
- from whisperx.utils import TO_LANGUAGE_CODE
7
- import whisperx
8
- import torch
9
- import gc
10
- import os
11
- import soundfile as sf
12
- from IPython.utils import capture # noqa
13
- from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
14
- from .logging_setup import logger
15
- from .postprocessor import sanitize_file_name
16
- from .utils import remove_directory_contents, run_command
17
-
18
- # ZERO GPU CONFIG
19
- import spaces
20
- import copy
21
- import random
22
- import time
23
-
24
- def random_sleep():
25
- if os.environ.get("ZERO_GPU") == "TRUE":
26
- print("Random sleep")
27
- sleep_time = round(random.uniform(7.2, 9.9), 1)
28
- time.sleep(sleep_time)
29
-
30
-
31
- @spaces.GPU
32
- def load_and_transcribe_audio(asr_model, audio, compute_type, language, asr_options, batch_size, segment_duration_limit):
33
- # Load model
34
- model = whisperx.load_model(
35
- asr_model,
36
- os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
37
- compute_type=compute_type,
38
- language=language,
39
- asr_options=asr_options,
40
- )
41
-
42
- # Transcribe audio
43
- result = model.transcribe(
44
- audio,
45
- batch_size=batch_size,
46
- chunk_size=segment_duration_limit,
47
- print_progress=True,
48
- )
49
-
50
- del model
51
- gc.collect()
52
- torch.cuda.empty_cache() # noqa
53
-
54
- return result
55
-
56
- def load_align_and_align_segments(result, audio, DAMHF):
57
-
58
- # Load alignment model
59
- model_a, metadata = whisperx.load_align_model(
60
- language_code=result["language"],
61
- device=os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cpu",
62
- model_name=None
63
- if result["language"] in DAMHF.keys()
64
- else EXTRA_ALIGN[result["language"]],
65
- )
66
-
67
- # Align segments
68
- alignment_result = whisperx.align(
69
- result["segments"],
70
- model_a,
71
- metadata,
72
- audio,
73
- os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cpu",
74
- return_char_alignments=True,
75
- print_progress=False,
76
- )
77
-
78
- # Clean up
79
- del model_a
80
- gc.collect()
81
- torch.cuda.empty_cache() # noqa
82
-
83
- return alignment_result
84
-
85
- @spaces.GPU
86
- def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
87
-
88
- if os.environ.get("ZERO_GPU") == "TRUE":
89
- diarize_model.model.to(torch.device("cuda"))
90
- diarize_segments = diarize_model(
91
- audio_wav,
92
- min_speakers=min_speakers,
93
- max_speakers=max_speakers
94
- )
95
- return diarize_segments
96
-
97
- # ZERO GPU CONFIG
98
-
99
- ASR_MODEL_OPTIONS = [
100
- "tiny",
101
- "base",
102
- "small",
103
- "medium",
104
- "large",
105
- "large-v1",
106
- "large-v2",
107
- "large-v3",
108
- "distil-large-v2",
109
- "Systran/faster-distil-whisper-large-v3",
110
- "tiny.en",
111
- "base.en",
112
- "small.en",
113
- "medium.en",
114
- "distil-small.en",
115
- "distil-medium.en",
116
- "OpenAI_API_Whisper",
117
- ]
118
-
119
- COMPUTE_TYPE_GPU = [
120
- "default",
121
- "auto",
122
- "int8",
123
- "int8_float32",
124
- "int8_float16",
125
- "int8_bfloat16",
126
- "float16",
127
- "bfloat16",
128
- "float32"
129
- ]
130
-
131
- COMPUTE_TYPE_CPU = [
132
- "default",
133
- "auto",
134
- "int8",
135
- "int8_float32",
136
- "int16",
137
- "float32",
138
- ]
139
-
140
- WHISPER_MODELS_PATH = './WHISPER_MODELS'
141
-
142
-
143
- def openai_api_whisper(
144
- input_audio_file,
145
- source_lang=None,
146
- chunk_duration=1800
147
- ):
148
-
149
- info = sf.info(input_audio_file)
150
- duration = info.duration
151
-
152
- output_directory = "./whisper_api_audio_parts"
153
- os.makedirs(output_directory, exist_ok=True)
154
- remove_directory_contents(output_directory)
155
-
156
- if duration > chunk_duration:
157
- # Split the audio file into smaller chunks with 30-minute duration
158
- cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
159
- run_command(cm)
160
- # Get list of generated chunk files
161
- chunk_files = sorted(
162
- [f"{output_directory}/{f}" for f in os.listdir(output_directory) if f.endswith('.ogg')]
163
- )
164
- else:
165
- one_file = f"{output_directory}/output000.ogg"
166
- cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
167
- run_command(cm)
168
- chunk_files = [one_file]
169
-
170
- # Transcript
171
- segments = []
172
- language = source_lang if source_lang else None
173
- for i, chunk in enumerate(chunk_files):
174
- from openai import OpenAI
175
- client = OpenAI()
176
-
177
- audio_file = open(chunk, "rb")
178
- transcription = client.audio.transcriptions.create(
179
- model="whisper-1",
180
- file=audio_file,
181
- language=language,
182
- response_format="verbose_json",
183
- timestamp_granularities=["segment"],
184
- )
185
-
186
- try:
187
- transcript_dict = transcription.model_dump()
188
- except: # noqa
189
- transcript_dict = transcription.to_dict()
190
-
191
- if language is None:
192
- logger.info(f'Language detected: {transcript_dict["language"]}')
193
- language = TO_LANGUAGE_CODE[transcript_dict["language"]]
194
-
195
- chunk_time = chunk_duration * (i)
196
-
197
- for seg in transcript_dict["segments"]:
198
-
199
- if "start" in seg.keys():
200
- segments.append(
201
- {
202
- "text": seg["text"],
203
- "start": seg["start"] + chunk_time,
204
- "end": seg["end"] + chunk_time,
205
- }
206
- )
207
-
208
- audio = whisperx.load_audio(input_audio_file)
209
- result = {"segments": segments, "language": language}
210
-
211
- return audio, result
212
-
213
-
214
- def find_whisper_models():
215
- path = WHISPER_MODELS_PATH
216
- folders = []
217
-
218
- if os.path.exists(path):
219
- for folder in os.listdir(path):
220
- folder_path = os.path.join(path, folder)
221
- if (
222
- os.path.isdir(folder_path)
223
- and 'model.bin' in os.listdir(folder_path)
224
- ):
225
- folders.append(folder)
226
- return folders
227
-
228
- def transcribe_speech(
229
- audio_wav,
230
- asr_model,
231
- compute_type,
232
- batch_size,
233
- SOURCE_LANGUAGE,
234
- literalize_numbers=True,
235
- segment_duration_limit=15,
236
- ):
237
- """
238
- Transcribe speech using a whisper model.
239
-
240
- Parameters:
241
- - audio_wav (str): Path to the audio file in WAV format.
242
- - asr_model (str): The whisper model to be loaded.
243
- - compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
244
- - batch_size (int): Batch size for transcription.
245
- - SOURCE_LANGUAGE (str): Source language for transcription.
246
-
247
- Returns:
248
- - Tuple containing:
249
- - audio: Loaded audio file.
250
- - result: Transcription result as a dictionary.
251
- """
252
-
253
- if asr_model == "OpenAI_API_Whisper":
254
- if literalize_numbers:
255
- logger.info(
256
- "OpenAI's API Whisper does not support "
257
- "the literalization of numbers."
258
- )
259
- return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)
260
-
261
- # https://github.com/openai/whisper/discussions/277
262
- prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
263
- SOURCE_LANGUAGE = (
264
- SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
265
- )
266
- asr_options = {
267
- "initial_prompt": prompt,
268
- "suppress_numerals": literalize_numbers
269
- }
270
-
271
- if asr_model not in ASR_MODEL_OPTIONS:
272
-
273
- base_dir = WHISPER_MODELS_PATH
274
- if not os.path.exists(base_dir):
275
- os.makedirs(base_dir)
276
- model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))
277
-
278
- if not os.path.exists(model_dir):
279
- from ctranslate2.converters import TransformersConverter
280
-
281
- quantization = "float32"
282
- # Download new model
283
- try:
284
- converter = TransformersConverter(
285
- asr_model,
286
- low_cpu_mem_usage=True,
287
- copy_files=[
288
- "tokenizer_config.json", "preprocessor_config.json"
289
- ]
290
- )
291
- converter.convert(
292
- model_dir,
293
- quantization=quantization,
294
- force=False
295
- )
296
- except Exception as error:
297
- if "File tokenizer_config.json does not exist" in str(error):
298
- converter._copy_files = [
299
- "tokenizer.json", "preprocessor_config.json"
300
- ]
301
- converter.convert(
302
- model_dir,
303
- quantization=quantization,
304
- force=True
305
- )
306
- else:
307
- raise error
308
-
309
- asr_model = model_dir
310
- logger.info(f"ASR Model: {str(model_dir)}")
311
-
312
- audio = whisperx.load_audio(audio_wav)
313
-
314
- result = load_and_transcribe_audio(
315
- asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options, batch_size, segment_duration_limit
316
- )
317
-
318
- if result["language"] == "zh" and not prompt:
319
- result["language"] = "zh-TW"
320
- logger.info("Chinese - Traditional (zh-TW)")
321
-
322
-
323
- return audio, result
324
-
325
- # if os.environ.get("ZERO_GPU") == "TRUE":
326
- transcribe_speech.zerogpu=True
327
-
328
-
329
- def align_speech(audio, result):
330
- """
331
- Aligns speech segments based on the provided audio and result metadata.
332
-
333
- Parameters:
334
- - audio (array): The audio data in a suitable format for alignment.
335
- - result (dict): Metadata containing information about the segments
336
- and language.
337
-
338
- Returns:
339
- - result (dict): Updated metadata after aligning the segments with
340
- the audio. This includes character-level alignments if
341
- 'return_char_alignments' is set to True.
342
-
343
- Notes:
344
- - This function uses language-specific models to align speech segments.
345
- - It performs language compatibility checks and selects the
346
- appropriate alignment model.
347
- - Cleans up memory by releasing resources after alignment.
348
- """
349
- DAMHF.update(DAMT) # lang align
350
- if (
351
- not result["language"] in DAMHF.keys()
352
- and not result["language"] in EXTRA_ALIGN.keys()
353
- ):
354
- logger.warning(
355
- "Automatic detection: Source language not compatible with align"
356
- )
357
- raise ValueError(
358
- f"Detected language {result['language']} incompatible, "
359
- "you can select the source language to avoid this error."
360
- )
361
- if (
362
- result["language"] in EXTRA_ALIGN.keys()
363
- and EXTRA_ALIGN[result["language"]] == ""
364
- ):
365
- lang_name = (
366
- INVERTED_LANGUAGES[result["language"]]
367
- if result["language"] in INVERTED_LANGUAGES.keys()
368
- else result["language"]
369
- )
370
- logger.warning(
371
- "No compatible wav2vec2 model found "
372
- f"for the language '{lang_name}', skipping alignment."
373
- )
374
- return result
375
-
376
- # random_sleep()
377
- result = load_align_and_align_segments(result, audio, DAMHF)
378
-
379
- return result
380
-
381
-
382
- diarization_models = {
383
- "pyannote_3.1": "pyannote/speaker-diarization-3.1",
384
- "pyannote_2.1": "pyannote/[email protected]",
385
- "disable": "",
386
- }
387
-
388
-
389
- def reencode_speakers(result):
390
-
391
- if result["segments"][0]["speaker"] == "SPEAKER_00":
392
- return result
393
-
394
- speaker_mapping = {}
395
- counter = 0
396
-
397
- logger.debug("Reencode speakers")
398
-
399
- for segment in result["segments"]:
400
- old_speaker = segment["speaker"]
401
- if old_speaker not in speaker_mapping:
402
- speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
403
- counter += 1
404
- segment["speaker"] = speaker_mapping[old_speaker]
405
-
406
- return result
407
-
408
-
409
- def diarize_speech(
410
- audio_wav,
411
- result,
412
- min_speakers,
413
- max_speakers,
414
- YOUR_HF_TOKEN,
415
- model_name="pyannote/speaker[email protected]",
416
- ):
417
- """
418
- Performs speaker diarization on speech segments.
419
-
420
- Parameters:
421
- - audio_wav (array): Audio data in WAV format to perform speaker
422
- diarization.
423
- - result (dict): Metadata containing information about speech segments
424
- and alignments.
425
- - min_speakers (int): Minimum number of speakers expected in the audio.
426
- - max_speakers (int): Maximum number of speakers expected in the audio.
427
- - YOUR_HF_TOKEN (str): Your Hugging Face API token for model
428
- authentication.
429
- - model_name (str): Name of the speaker diarization model to be used
430
- (default: "pyannote/[email protected]").
431
-
432
- Returns:
433
- - result_diarize (dict): Updated metadata after assigning speaker
434
- labels to segments.
435
-
436
- Notes:
437
- - This function utilizes a speaker diarization model to label speaker
438
- segments in the audio.
439
- - It assigns speakers to word-level segments based on diarization results.
440
- - Cleans up memory by releasing resources after diarization.
441
- - If only one speaker is specified, each segment is automatically assigned
442
- as the first speaker, eliminating the need for diarization inference.
443
- """
444
-
445
- if max(min_speakers, max_speakers) > 1 and model_name:
446
- try:
447
-
448
- diarize_model = whisperx.DiarizationPipeline(
449
- model_name=model_name,
450
- use_auth_token=YOUR_HF_TOKEN,
451
- device=os.environ.get("SONITR_DEVICE"),
452
- )
453
-
454
- except Exception as error:
455
- error_str = str(error)
456
- gc.collect()
457
- torch.cuda.empty_cache() # noqa
458
- if "'NoneType' object has no attribute 'to'" in error_str:
459
- if model_name == diarization_models["pyannote_2.1"]:
460
- raise ValueError(
461
- "Accept the license agreement for using Pyannote 2.1."
462
- " You need to have an account on Hugging Face and "
463
- "accept the license to use the models: "
464
- "https://huggingface.co/pyannote/speaker-diarization "
465
- "and https://huggingface.co/pyannote/segmentation "
466
- "Get your KEY TOKEN here: "
467
- "https://hf.co/settings/tokens "
468
- )
469
- elif model_name == diarization_models["pyannote_3.1"]:
470
- raise ValueError(
471
- "New Licence Pyannote 3.1: You need to have an account"
472
- " on Hugging Face and accept the license to use the "
473
- "models: https://huggingface.co/pyannote/speaker-diarization-3.1 " # noqa
474
- "and https://huggingface.co/pyannote/segmentation-3.0 "
475
- )
476
- else:
477
- raise error
478
-
479
- random_sleep()
480
- diarize_segments = diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers)
481
-
482
- result_diarize = whisperx.assign_word_speakers(
483
- diarize_segments, result
484
- )
485
-
486
- for segment in result_diarize["segments"]:
487
- if "speaker" not in segment:
488
- segment["speaker"] = "SPEAKER_00"
489
- logger.warning(
490
- f"No speaker detected in {segment['start']}. First TTS "
491
- f"will be used for the segment text: {segment['text']} "
492
- )
493
-
494
- del diarize_model
495
- gc.collect()
496
- torch.cuda.empty_cache() # noqa
497
- else:
498
- result_diarize = result
499
- result_diarize["segments"] = [
500
- {**item, "speaker": "SPEAKER_00"}
501
- for item in result_diarize["segments"]
502
- ]
503
- return reencode_speakers(result_diarize)
 
1
+ import spaces
2
+ from whisperx.alignment import (
3
+ DEFAULT_ALIGN_MODELS_TORCH as DAMT,
4
+ DEFAULT_ALIGN_MODELS_HF as DAMHF,
5
+ )
6
+ from whisperx.utils import TO_LANGUAGE_CODE
7
+ import whisperx
8
+ import torch
9
+ import gc
10
+ import os
11
+ import soundfile as sf
12
+ from IPython.utils import capture # noqa
13
+ from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
14
+ from .logging_setup import logger
15
+ from .postprocessor import sanitize_file_name
16
+ from .utils import remove_directory_contents, run_command
17
+
18
+ # ZERO GPU CONFIG
19
+ import spaces
20
+ import copy
21
+ import random
22
+ import time
23
+
24
+
25
+ @spaces.GPU(duration=45)
26
+ def load_and_transcribe_audio(asr_model, audio, compute_type, language, asr_options, batch_size, segment_duration_limit):
27
+ # Load model
28
+ model = whisperx.load_model(
29
+ asr_model,
30
+ os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
31
+ compute_type=compute_type,
32
+ language=language,
33
+ asr_options=asr_options,
34
+ )
35
+
36
+ # Transcribe audio
37
+ result = model.transcribe(
38
+ audio,
39
+ batch_size=batch_size,
40
+ chunk_size=segment_duration_limit,
41
+ print_progress=True,
42
+ )
43
+
44
+ del model
45
+ gc.collect()
46
+ torch.cuda.empty_cache() # noqa
47
+
48
+ return result
49
+
50
+
51
+ @spaces.GPU(duration=30)
52
+ def load_align_and_align_segments(result, audio, DAMHF):
53
+ # Load alignment model
54
+ model_a, metadata = whisperx.load_align_model(
55
+ language_code=result["language"],
56
+ # device=os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cpu", # cpu mode
57
+ device=os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
58
+ model_name=None
59
+ if result["language"] in DAMHF.keys()
60
+ else EXTRA_ALIGN[result["language"]],
61
+ )
62
+
63
+ # Align segments
64
+ alignment_result = whisperx.align(
65
+ result["segments"],
66
+ model_a,
67
+ metadata,
68
+ audio,
69
+ # os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cpu", # cpu mode
70
+ device=os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
71
+ return_char_alignments=True,
72
+ print_progress=False,
73
+ )
74
+
75
+ # Clean up
76
+ del model_a
77
+ gc.collect()
78
+ torch.cuda.empty_cache() # noqa
79
+
80
+ return alignment_result
81
+
82
+
83
+ @spaces.GPU(duration=35)
84
+ def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
85
+
86
+ if os.environ.get("ZERO_GPU") == "TRUE":
87
+ diarize_model.model.to(torch.device("cuda"))
88
+ diarize_segments = diarize_model(
89
+ audio_wav,
90
+ min_speakers=min_speakers,
91
+ max_speakers=max_speakers
92
+ )
93
+ return diarize_segments
94
+
95
+ # ZERO GPU CONFIG
96
+
97
+ ASR_MODEL_OPTIONS = [
98
+ "tiny",
99
+ "base",
100
+ "small",
101
+ "medium",
102
+ "large",
103
+ "large-v1",
104
+ "large-v2",
105
+ "large-v3",
106
+ "distil-large-v2",
107
+ "Systran/faster-distil-whisper-large-v3",
108
+ "tiny.en",
109
+ "base.en",
110
+ "small.en",
111
+ "medium.en",
112
+ "distil-small.en",
113
+ "distil-medium.en",
114
+ "OpenAI_API_Whisper",
115
+ ]
116
+
117
+ COMPUTE_TYPE_GPU = [
118
+ "default",
119
+ "auto",
120
+ "int8",
121
+ "int8_float32",
122
+ "int8_float16",
123
+ "int8_bfloat16",
124
+ "float16",
125
+ "bfloat16",
126
+ "float32"
127
+ ]
128
+
129
+ COMPUTE_TYPE_CPU = [
130
+ "default",
131
+ "auto",
132
+ "int8",
133
+ "int8_float32",
134
+ "int16",
135
+ "float32",
136
+ ]
137
+
138
+ WHISPER_MODELS_PATH = './WHISPER_MODELS'
139
+
140
+
141
+ def openai_api_whisper(
142
+ input_audio_file,
143
+ source_lang=None,
144
+ chunk_duration=1800
145
+ ):
146
+
147
+ info = sf.info(input_audio_file)
148
+ duration = info.duration
149
+
150
+ output_directory = "./whisper_api_audio_parts"
151
+ os.makedirs(output_directory, exist_ok=True)
152
+ remove_directory_contents(output_directory)
153
+
154
+ if duration > chunk_duration:
155
+ # Split the audio file into smaller chunks with 30-minute duration
156
+ cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
157
+ run_command(cm)
158
+ # Get list of generated chunk files
159
+ chunk_files = sorted(
160
+ [f"{output_directory}/{f}" for f in os.listdir(output_directory) if f.endswith('.ogg')]
161
+ )
162
+ else:
163
+ one_file = f"{output_directory}/output000.ogg"
164
+ cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
165
+ run_command(cm)
166
+ chunk_files = [one_file]
167
+
168
+ # Transcript
169
+ segments = []
170
+ language = source_lang if source_lang else None
171
+ for i, chunk in enumerate(chunk_files):
172
+ from openai import OpenAI
173
+ client = OpenAI()
174
+
175
+ audio_file = open(chunk, "rb")
176
+ transcription = client.audio.transcriptions.create(
177
+ model="whisper-1",
178
+ file=audio_file,
179
+ language=language,
180
+ response_format="verbose_json",
181
+ timestamp_granularities=["segment"],
182
+ )
183
+
184
+ try:
185
+ transcript_dict = transcription.model_dump()
186
+ except: # noqa
187
+ transcript_dict = transcription.to_dict()
188
+
189
+ if language is None:
190
+ logger.info(f'Language detected: {transcript_dict["language"]}')
191
+ language = TO_LANGUAGE_CODE[transcript_dict["language"]]
192
+
193
+ chunk_time = chunk_duration * (i)
194
+
195
+ for seg in transcript_dict["segments"]:
196
+
197
+ if "start" in seg.keys():
198
+ segments.append(
199
+ {
200
+ "text": seg["text"],
201
+ "start": seg["start"] + chunk_time,
202
+ "end": seg["end"] + chunk_time,
203
+ }
204
+ )
205
+
206
+ audio = whisperx.load_audio(input_audio_file)
207
+ result = {"segments": segments, "language": language}
208
+
209
+ return audio, result
210
+
211
+
212
+ def find_whisper_models():
213
+ path = WHISPER_MODELS_PATH
214
+ folders = []
215
+
216
+ if os.path.exists(path):
217
+ for folder in os.listdir(path):
218
+ folder_path = os.path.join(path, folder)
219
+ if (
220
+ os.path.isdir(folder_path)
221
+ and 'model.bin' in os.listdir(folder_path)
222
+ ):
223
+ folders.append(folder)
224
+ return folders
225
+
226
+ def transcribe_speech(
227
+ audio_wav,
228
+ asr_model,
229
+ compute_type,
230
+ batch_size,
231
+ SOURCE_LANGUAGE,
232
+ literalize_numbers=True,
233
+ segment_duration_limit=15,
234
+ ):
235
+ """
236
+ Transcribe speech using a whisper model.
237
+
238
+ Parameters:
239
+ - audio_wav (str): Path to the audio file in WAV format.
240
+ - asr_model (str): The whisper model to be loaded.
241
+ - compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
242
+ - batch_size (int): Batch size for transcription.
243
+ - SOURCE_LANGUAGE (str): Source language for transcription.
244
+
245
+ Returns:
246
+ - Tuple containing:
247
+ - audio: Loaded audio file.
248
+ - result: Transcription result as a dictionary.
249
+ """
250
+
251
+ if asr_model == "OpenAI_API_Whisper":
252
+ if literalize_numbers:
253
+ logger.info(
254
+ "OpenAI's API Whisper does not support "
255
+ "the literalization of numbers."
256
+ )
257
+ return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)
258
+
259
+ # https://github.com/openai/whisper/discussions/277
260
+ prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
261
+ SOURCE_LANGUAGE = (
262
+ SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
263
+ )
264
+ asr_options = {
265
+ "initial_prompt": prompt,
266
+ "suppress_numerals": literalize_numbers
267
+ }
268
+
269
+ if asr_model not in ASR_MODEL_OPTIONS:
270
+
271
+ base_dir = WHISPER_MODELS_PATH
272
+ if not os.path.exists(base_dir):
273
+ os.makedirs(base_dir)
274
+ model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))
275
+
276
+ if not os.path.exists(model_dir):
277
+ from ctranslate2.converters import TransformersConverter
278
+
279
+ quantization = "float32"
280
+ # Download new model
281
+ try:
282
+ converter = TransformersConverter(
283
+ asr_model,
284
+ low_cpu_mem_usage=True,
285
+ copy_files=[
286
+ "tokenizer_config.json", "preprocessor_config.json"
287
+ ]
288
+ )
289
+ converter.convert(
290
+ model_dir,
291
+ quantization=quantization,
292
+ force=False
293
+ )
294
+ except Exception as error:
295
+ if "File tokenizer_config.json does not exist" in str(error):
296
+ converter._copy_files = [
297
+ "tokenizer.json", "preprocessor_config.json"
298
+ ]
299
+ converter.convert(
300
+ model_dir,
301
+ quantization=quantization,
302
+ force=True
303
+ )
304
+ else:
305
+ raise error
306
+
307
+ asr_model = model_dir
308
+ logger.info(f"ASR Model: {str(model_dir)}")
309
+
310
+ audio = whisperx.load_audio(audio_wav)
311
+
312
+ result = load_and_transcribe_audio(
313
+ asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options, batch_size, segment_duration_limit
314
+ )
315
+
316
+ if result["language"] == "zh" and not prompt:
317
+ result["language"] = "zh-TW"
318
+ logger.info("Chinese - Traditional (zh-TW)")
319
+
320
+
321
+ return audio, result
322
+
323
+
324
+ def align_speech(audio, result):
325
+ """
326
+ Aligns speech segments based on the provided audio and result metadata.
327
+
328
+ Parameters:
329
+ - audio (array): The audio data in a suitable format for alignment.
330
+ - result (dict): Metadata containing information about the segments
331
+ and language.
332
+
333
+ Returns:
334
+ - result (dict): Updated metadata after aligning the segments with
335
+ the audio. This includes character-level alignments if
336
+ 'return_char_alignments' is set to True.
337
+
338
+ Notes:
339
+ - This function uses language-specific models to align speech segments.
340
+ - It performs language compatibility checks and selects the
341
+ appropriate alignment model.
342
+ - Cleans up memory by releasing resources after alignment.
343
+ """
344
+ DAMHF.update(DAMT) # lang align
345
+ if (
346
+ not result["language"] in DAMHF.keys()
347
+ and not result["language"] in EXTRA_ALIGN.keys()
348
+ ):
349
+ logger.warning(
350
+ "Automatic detection: Source language not compatible with align"
351
+ )
352
+ raise ValueError(
353
+ f"Detected language {result['language']} incompatible, "
354
+ "you can select the source language to avoid this error."
355
+ )
356
+ if (
357
+ result["language"] in EXTRA_ALIGN.keys()
358
+ and EXTRA_ALIGN[result["language"]] == ""
359
+ ):
360
+ lang_name = (
361
+ INVERTED_LANGUAGES[result["language"]]
362
+ if result["language"] in INVERTED_LANGUAGES.keys()
363
+ else result["language"]
364
+ )
365
+ logger.warning(
366
+ "No compatible wav2vec2 model found "
367
+ f"for the language '{lang_name}', skipping alignment."
368
+ )
369
+ return result
370
+
371
+ result = load_align_and_align_segments(result, audio, DAMHF)
372
+
373
+ return result
374
+
375
+
376
+ diarization_models = {
377
+ "pyannote_3.1": "pyannote/speaker-diarization-3.1",
378
+ "pyannote_2.1": "pyannote/[email protected]",
379
+ "disable": "",
380
+ }
381
+
382
+
383
+ def reencode_speakers(result):
384
+
385
+ if result["segments"][0]["speaker"] == "SPEAKER_00":
386
+ return result
387
+
388
+ speaker_mapping = {}
389
+ counter = 0
390
+
391
+ logger.debug("Reencode speakers")
392
+
393
+ for segment in result["segments"]:
394
+ old_speaker = segment["speaker"]
395
+ if old_speaker not in speaker_mapping:
396
+ speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
397
+ counter += 1
398
+ segment["speaker"] = speaker_mapping[old_speaker]
399
+
400
+ return result
401
+
402
+
403
+ def diarize_speech(
404
+ audio_wav,
405
+ result,
406
+ min_speakers,
407
+ max_speakers,
408
+ YOUR_HF_TOKEN,
409
+ model_name="pyannote/[email protected]",
410
+ ):
411
+ """
412
+ Performs speaker diarization on speech segments.
413
+
414
+ Parameters:
415
+ - audio_wav (array): Audio data in WAV format to perform speaker
416
+ diarization.
417
+ - result (dict): Metadata containing information about speech segments
418
+ and alignments.
419
+ - min_speakers (int): Minimum number of speakers expected in the audio.
420
+ - max_speakers (int): Maximum number of speakers expected in the audio.
421
+ - YOUR_HF_TOKEN (str): Your Hugging Face API token for model
422
+ authentication.
423
+ - model_name (str): Name of the speaker diarization model to be used
424
+ (default: "pyannote/speaker-diarization@2.1").
425
+
426
+ Returns:
427
+ - result_diarize (dict): Updated metadata after assigning speaker
428
+ labels to segments.
429
+
430
+ Notes:
431
+ - This function utilizes a speaker diarization model to label speaker
432
+ segments in the audio.
433
+ - It assigns speakers to word-level segments based on diarization results.
434
+ - Cleans up memory by releasing resources after diarization.
435
+ - If only one speaker is specified, each segment is automatically assigned
436
+ as the first speaker, eliminating the need for diarization inference.
437
+ """
438
+
439
+ if max(min_speakers, max_speakers) > 1 and model_name:
440
+ try:
441
+
442
+ diarize_model = whisperx.DiarizationPipeline(
443
+ model_name=model_name,
444
+ use_auth_token=YOUR_HF_TOKEN,
445
+ device=os.environ.get("SONITR_DEVICE"),
446
+ )
447
+
448
+ except Exception as error:
449
+ error_str = str(error)
450
+ gc.collect()
451
+ torch.cuda.empty_cache() # noqa
452
+ if "'NoneType' object has no attribute 'to'" in error_str:
453
+ if model_name == diarization_models["pyannote_2.1"]:
454
+ raise ValueError(
455
+ "Accept the license agreement for using Pyannote 2.1."
456
+ " You need to have an account on Hugging Face and "
457
+ "accept the license to use the models: "
458
+ "https://huggingface.co/pyannote/speaker-diarization "
459
+ "and https://huggingface.co/pyannote/segmentation "
460
+ "Get your KEY TOKEN here: "
461
+ "https://hf.co/settings/tokens "
462
+ )
463
+ elif model_name == diarization_models["pyannote_3.1"]:
464
+ raise ValueError(
465
+ "New Licence Pyannote 3.1: You need to have an account"
466
+ " on Hugging Face and accept the license to use the "
467
+ "models: https://huggingface.co/pyannote/speaker-diarization-3.1 " # noqa
468
+ "and https://huggingface.co/pyannote/segmentation-3.0 "
469
+ )
470
+ else:
471
+ raise error
472
+
473
+ diarize_segments = diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers)
474
+
475
+ result_diarize = whisperx.assign_word_speakers(
476
+ diarize_segments, result
477
+ )
478
+
479
+ for segment in result_diarize["segments"]:
480
+ if "speaker" not in segment:
481
+ segment["speaker"] = "SPEAKER_00"
482
+ logger.warning(
483
+ f"No speaker detected in {segment['start']}. First TTS "
484
+ f"will be used for the segment text: {segment['text']} "
485
+ )
486
+
487
+ del diarize_model
488
+ gc.collect()
489
+ torch.cuda.empty_cache() # noqa
490
+ else:
491
+ result_diarize = result
492
+ result_diarize["segments"] = [
493
+ {**item, "speaker": "SPEAKER_00"}
494
+ for item in result_diarize["segments"]
495
+ ]
496
+ return reencode_speakers(result_diarize)
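
Usage sketch (not part of the commit): the snippet below chains the three helpers defined in this file, transcribe_speech, align_speech and diarize_speech, into the transcription → alignment → diarization pipeline their docstrings describe. It is a minimal sketch assuming the package is importable as soni_translate.speech_segmentation, that the SONITR_DEVICE / ZERO_GPU environment variables are set the way the module expects, and that a Hugging Face token with the pyannote licenses accepted is available; the input file name, the token lookup and all parameter values are illustrative only.

# Minimal usage sketch (illustrative values, not part of the commit).
import os

os.environ.setdefault("SONITR_DEVICE", "cuda")   # device string the module reads
os.environ.setdefault("ZERO_GPU", "FALSE")       # "TRUE" only on ZeroGPU Spaces

from soni_translate.speech_segmentation import (
    align_speech,
    diarization_models,
    diarize_speech,
    transcribe_speech,
)

audio_wav = "sample.wav"                          # hypothetical input file
hf_token = os.environ.get("YOUR_HF_TOKEN", "")    # pyannote models need a token

# 1) ASR: loads the whisper model inside a @spaces.GPU call and transcribes.
audio, result = transcribe_speech(
    audio_wav,
    asr_model="large-v3",
    compute_type="float16",
    batch_size=8,
    SOURCE_LANGUAGE="en",
)

# 2) Word/character alignment; skipped automatically when no compatible
#    wav2vec2 model exists for the detected language.
result = align_speech(audio, result)

# 3) Speaker diarization; if min and max speakers are both 1, every segment
#    is labeled SPEAKER_00 without running pyannote at all.
result_diarize = diarize_speech(
    audio_wav,
    result,
    min_speakers=1,
    max_speakers=2,
    YOUR_HF_TOKEN=hf_token,
    model_name=diarization_models["pyannote_3.1"],
)

for seg in result_diarize["segments"]:
    print(seg["speaker"], round(seg["start"], 2), seg["text"])

On ZeroGPU Spaces, each decorated call is bounded by the @spaces.GPU(duration=...) budgets introduced in this commit (45 s for transcription, 30 s for alignment, 35 s for diarization), which is worth keeping in mind for long inputs.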