"""
WhisperX with speaker diarization.
This example demonstrates how to use WhisperX with speaker diarization
to identify different speakers in multi-speaker audio.
Prerequisites:
1. Install pyannote.audio: pip install pyannote.audio
2. Accept pyannote model conditions on HuggingFace:
- https://huggingface.co/pyannote/speaker-diarization-3.1
- https://huggingface.co/pyannote/segmentation-3.0
3. Set HuggingFace token: export HF_TOKEN=your_token_here
or: huggingface-cli login
"""
import os
from vllm import LLM
from vllm.model_executor.models.whisperx_pipeline import create_whisperx_pipeline
# The pyannote diarization models are gated on HuggingFace, so warn up
# front if no access token is visible in the environment (a cached
# ``huggingface-cli login`` session also works).
_TOKEN_ENV_VARS = ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN")
if not any(name in os.environ for name in _TOKEN_ENV_VARS):
    print("Warning: HF_TOKEN not found. Diarization may fail.")
    print("Set it with: export HF_TOKEN=your_token_here")
# Initialize the Whisper checkpoint through vLLM's standard LLM entry point.
# trust_remote_code allows model code hosted on the hub to be loaded.
llm = LLM(
    model="openai/whisper-large-v3",
    trust_remote_code=True,
)
# Get the underlying model instance by reaching through vLLM engine
# internals. NOTE(review): this attribute chain is not a public API and
# may break across vLLM versions — verify against the installed release.
model = llm.llm_engine.model_executor.driver_worker.model_runner.model
# Create a WhisperX pipeline around the raw model, enabling both word-level
# alignment and speaker diarization. Alignment supplies the per-word
# timestamps that diarization uses to attribute words to speakers.
pipeline = create_whisperx_pipeline(
    model=model,
    enable_alignment=True,  # Required for diarization to work well
    enable_diarization=True,
    language="en",
    min_speakers=1,  # Optional: minimum number of speakers
    max_speakers=5,  # Optional: maximum number of speakers
    # num_speakers=2,  # Optional: exact number of speakers if known
)
# Path to your audio file (replace with a real multi-speaker recording).
audio_path = "path/to/your/multi_speaker_audio.wav"

# Transcribe with alignment and diarization. The pipeline holds model
# resources, so the try/finally below guarantees cleanup() runs even if
# transcription or result printing raises.
print("Transcribing with speaker diarization...")
print("This may take a few minutes on first run (downloading models)...\n")
try:
    result = pipeline.transcribe(
        audio=audio_path,
        language="en",
        task="transcribe",
    )

    # Print results with speaker labels.
    print(f"Full transcription: {result['text']}\n")
    print("Segments with speaker labels:")
    print("-" * 80)
    for segment in result["segments"]:
        # Segments without an assigned speaker fall back to "UNKNOWN".
        speaker = segment.get("speaker", "UNKNOWN")
        print(f"\n[Speaker {speaker}] [{segment['start']:.2f}s - {segment['end']:.2f}s]")
        print(f"  {segment['text']}")
        # Word-level timing is only present when alignment produced it.
        if "words" in segment:
            print("  Words:")
            for word in segment["words"]:
                # A word without its own speaker inherits the segment's.
                word_speaker = word.get("speaker", speaker)
                print(
                    f"    [{word['start']:.2f}s - {word['end']:.2f}s] "
                    f"{word['word']} (Speaker: {word_speaker})"
                )

    # Speaker embeddings (if needed for downstream tasks).
    if "speaker_embeddings" in result:
        print("\n\nSpeaker Embeddings:")
        for speaker_id, embedding in result["speaker_embeddings"].items():
            print(f"  {speaker_id}: {len(embedding)}-dimensional embedding")
finally:
    # Release alignment/diarization models and any associated resources.
    pipeline.cleanup()