# vllm/examples/offline_inference/whisperx_diarization.py
"""
WhisperX with speaker diarization.
This example demonstrates how to use WhisperX with speaker diarization
to identify different speakers in multi-speaker audio.
Prerequisites:
1. Install pyannote.audio: pip install pyannote.audio
2. Accept pyannote model conditions on HuggingFace:
- https://huggingface.co/pyannote/speaker-diarization-3.1
- https://huggingface.co/pyannote/segmentation-3.0
3. Set HuggingFace token: export HF_TOKEN=your_token_here
or: huggingface-cli login
"""
import os
from vllm import LLM
from vllm.model_executor.models.whisperx_pipeline import create_whisperx_pipeline

# Check for HuggingFace token
if "HF_TOKEN" not in os.environ and "HUGGING_FACE_HUB_TOKEN" not in os.environ:
    print("Warning: HF_TOKEN not found. Diarization may fail.")
    print("Set it with: export HF_TOKEN=your_token_here")

# Initialize WhisperX model
llm = LLM(
    model="openai/whisper-large-v3",
    trust_remote_code=True,
)

# Get the underlying model instance from the engine (internal vLLM attribute path)
model = llm.llm_engine.model_executor.driver_worker.model_runner.model

# Create WhisperX pipeline with both alignment and diarization
pipeline = create_whisperx_pipeline(
    model=model,
    enable_alignment=True,  # Required for diarization to work well
    enable_diarization=True,
    language="en",
    min_speakers=1,  # Optional: minimum number of speakers
    max_speakers=5,  # Optional: maximum number of speakers
    # num_speakers=2,  # Optional: exact number of speakers if known
)

# Path to your audio file (replace with a real multi-speaker recording)
audio_path = "path/to/your/multi_speaker_audio.wav"

# Transcribe with alignment and diarization
print("Transcribing with speaker diarization...")
print("This may take a few minutes on first run (downloading models)...\n")
result = pipeline.transcribe(
    audio=audio_path,
    language="en",
    task="transcribe",
)
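
# Hedged sketch (not part of the original example): persist the raw result for
# later inspection. The result is a plain dict of text/segments/words as used
# below, so the standard json module can serialize it; default=str guards any
# non-JSON values (e.g. numpy floats). The output filename is illustrative.
import json

with open("diarized_transcript.json", "w") as f:
    json.dump(result, f, ensure_ascii=False, indent=2, default=str)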

# Print results with speaker labels
print(f"Full transcription: {result['text']}\n")
print("Segments with speaker labels:")
print("-" * 80)
for segment in result["segments"]:
    speaker = segment.get("speaker", "UNKNOWN")
    print(f"\n[Speaker {speaker}] [{segment['start']:.2f}s - {segment['end']:.2f}s]")
    print(f"  {segment['text']}")
    if "words" in segment:
        print("  Words:")
        for word in segment["words"]:
            word_speaker = word.get("speaker", speaker)
            print(
                f"    [{word['start']:.2f}s - {word['end']:.2f}s] "
                f"{word['word']} (Speaker: {word_speaker})"
            )

# Speaker embeddings (if needed for downstream tasks)
if "speaker_embeddings" in result:
    print("\n\nSpeaker Embeddings:")
    for speaker_id, embedding in result["speaker_embeddings"].items():
        print(f"  {speaker_id}: {len(embedding)}-dimensional embedding")

# Cleanup
pipeline.cleanup()