VoxSum / tests /test_diarization_minimal.py
Luigi's picture
Consolidate tests under tests/, add LLM default tests with opt-out flag, model selection, README update
913c94a
raw
history blame
4.3 kB
#!/usr/bin/env python3
"""Pytest-based minimal sanity tests for `perform_speaker_diarization_on_utterances`.
These tests avoid heavy dependencies (sherpa_onnx/faiss/sklearn) by using a mock
extractor and rely on the lightweight paths & heuristics implemented in
`src.diarization`.
Run:
pytest -q tests/test_diarization_minimal.py
Or standalone (still works):
python3 tests/test_diarization_minimal.py
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Iterable, List, Tuple
import numpy as np
import pytest
# The repository root is the parent of this file's ``tests/`` directory.
# Prepend it to ``sys.path`` (only if it is not already listed) so that
# ``src.diarization`` is importable whatever directory the tests run from.
ROOT = Path(__file__).resolve().parent.parent
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))
from src.diarization import perform_speaker_diarization_on_utterances # type: ignore
EMB_DIM = 192
def _emb(seed: int, delta: float | None = None) -> np.ndarray:
rng = np.random.default_rng(seed)
v = rng.normal(size=EMB_DIM).astype(np.float32)
if delta is not None:
v = (v + delta).astype(np.float32)
return v
class MockStream:
def __init__(self, sample_rate: int, segment: np.ndarray | None):
self.sample_rate = sample_rate
self.segment = segment
def accept_waveform(self, sr, seg): # pragma: no cover - no-op
pass
def input_finished(self): # pragma: no cover - no-op
pass
class MockExtractor:
    """Mimics the subset of sherpa_onnx SpeakerEmbeddingExtractor we use.

    ``compute`` hands out the supplied embeddings in order, one per call. Once
    they are used up, it keeps returning the last one. The stream argument is
    never inspected, so the embeddings are fully decoupled from the audio.
    """

    def __init__(self, embeddings_sequence: List[np.ndarray]):
        self._embs = embeddings_sequence
        self._i = 0  # position of the next embedding to hand out

    def create_stream(self):
        """Return a fresh no-op stream tagged with a 16 kHz sample rate."""
        return MockStream(16000, None)

    def compute(self, _stream):
        """Return the next queued embedding, or the final one once exhausted."""
        if self._i < len(self._embs):
            current = self._embs[self._i]
            self._i += 1
            return current
        return self._embs[-1]
def _collect(gen) -> List[Tuple[float, float, int]]:
result: List[Tuple[float, float, int]] | None = None
for item in gen:
if isinstance(item, list):
result = item # final segments emitted
break
if result is None:
# Drain StopIteration
try:
while True:
next(gen)
except StopIteration as e:
result = e.value # type: ignore
assert result is not None, "Generator produced no result list"
return result
def _run_case(embeddings: List[np.ndarray], utterances: List[Tuple[float, float, str]]):
    """Run the diarizer on silent audio with mocked embeddings and sanity-check it.

    Args:
        embeddings: vectors handed out by ``MockExtractor.compute``, one per call.
        utterances: ``(start_s, end_s, text)`` triples passed to the diarizer.

    Returns:
        The ``(start, end, speaker_id)`` tuples collected from
        ``perform_speaker_diarization_on_utterances``. Each tuple has already
        been checked for shape, time ordering and an ``int`` speaker ID.
    """
    sample_rate = 16000
    extractor = MockExtractor(embeddings)
    # The audio content is irrelevant (the mock stream/extractor ignore it), but
    # the buffer should span every utterance. The diarizer presumably slices
    # ``audio`` by utterance times, and a fixed 3 s buffer truncated the
    # utterances ending at 3.4 s / 4.0 s. Keep the previous 3 s as a minimum.
    longest_end = max((end for _start, end, _text in utterances), default=0.0)
    duration = max(3.0, longest_end)
    audio = np.zeros(int(np.ceil(sample_rate * duration)), dtype=np.float32)
    gen = perform_speaker_diarization_on_utterances(
        audio=audio,
        sample_rate=sample_rate,
        utterances=utterances,
        embedding_extractor=extractor,
        config_dict={"cluster_threshold": 0.5, "num_speakers": -1},
        progress_callback=None,
    )
    segments = _collect(gen)
    # Basic validation of every returned segment.
    for seg in segments:
        assert isinstance(seg, tuple) and len(seg) == 3
        s, e, spk = seg
        assert 0 <= s < e, "Invalid time bounds"
        assert isinstance(spk, int)
    return segments
def test_single_segment():
    """A lone utterance comes back as exactly one segment labelled speaker 0."""
    segments = _run_case([_emb(1)], [(0.0, 2.0, "Hello world")])
    assert len(segments) == 1
    only_speaker = segments[0][2]
    assert only_speaker == 0
def test_two_similar_segments_same_speaker():
    """Two near-identical embeddings must end up under a single speaker ID."""
    reference = _emb(2)
    # Shift every component by a tiny constant, so the vector is essentially unchanged.
    nudged = (reference + 0.001).astype(np.float32)
    utterances = [(0.0, 2.0, "Bonjour"), (2.1, 4.0, "Bonjour encore")]
    segments = _run_case([reference, nudged], utterances)
    assert len(segments) == 2
    speaker_ids = {segment[-1] for segment in segments}
    assert len(speaker_ids) == 1, "Should have merged speaker IDs"
def test_two_different_segments_distinct_speakers():
    """Two unrelated embeddings must still yield one segment per utterance.

    The clustering heuristic may assign these either one or two speaker IDs,
    so only the segment count is pinned here. The former trailing
    ``len(segs) >= 1`` check was dead (always true after ``== 2``) and did not
    inspect speaker IDs as its comment claimed, so it has been removed.
    """
    utts = [(0.0, 1.5, "Hola"), (1.6, 3.2, "Adios")]
    segs = _run_case([_emb(10), _emb(200)], utts)
    assert len(segs) == 2
    # NOTE(review): the test name promises distinct speakers. Once the heuristic
    # is confirmed to separate unrelated embeddings at cluster_threshold=0.5,
    # consider tightening to: len({spk for *_rest, spk in segs}) == 2.
def test_three_segments_enhanced_or_fallback():
    """Three short utterances must not be collapsed: exactly three segments return.

    The name suggests src.diarization may take either an "enhanced" or a
    fallback path here. Either way, granularity must be preserved.
    """
    utterances = [(0.0, 1.0, "A"), (1.1, 2.2, "B"), (2.3, 3.4, "C")]
    embeddings = [_emb(seed) for seed in (11, 12, 13)]
    segments = _run_case(embeddings, utterances)
    assert len(segments) == 3, "Granularity should be preserved for small n"
# Allow ``python3 tests/test_diarization_minimal.py``: delegate to pytest.main
# on this file and propagate its exit status. This uses the module-level
# ``pytest`` import; re-importing it as ``_pytest`` was redundant and shadowed
# the name of pytest's own internal package.
if __name__ == "__main__":  # pragma: no cover
    raise SystemExit(pytest.main([__file__]))