|
|
|
|
|
"""Pytest-based minimal sanity tests for `perform_speaker_diarization_on_utterances`. |
|
|
|
|
|
These tests avoid heavy dependencies (sherpa_onnx/faiss/sklearn) by using a mock |
|
|
extractor and rely on the lightweight paths & heuristics implemented in |
|
|
`src.diarization`. |
|
|
|
|
|
Run: |
|
|
pytest -q tests/test_diarization_minimal.py |
|
|
|
|
|
Or standalone (still works): |
|
|
python3 tests/test_diarization_minimal.py |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Iterable, List, Tuple |
|
|
import numpy as np |
|
|
import pytest |
|
|
|
|
|
# Make the repository root importable so ``src.diarization`` resolves both
# under pytest and when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parent.parent
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    # Prepend: earlier sys.path entries take precedence during import lookup.
    sys.path[:0] = [_root_entry]
|
|
|
|
|
from src.diarization import perform_speaker_diarization_on_utterances |
|
|
|
|
|
|
|
|
EMB_DIM = 192 |
|
|
|
|
|
|
|
|
def _emb(seed: int, delta: float | None = None) -> np.ndarray: |
|
|
rng = np.random.default_rng(seed) |
|
|
v = rng.normal(size=EMB_DIM).astype(np.float32) |
|
|
if delta is not None: |
|
|
v = (v + delta).astype(np.float32) |
|
|
return v |
|
|
|
|
|
|
|
|
class MockStream:
    """Stand-in for an extractor stream: records its constructor args only."""

    def __init__(self, sample_rate: int, segment: np.ndarray | None):
        self.segment = segment
        self.sample_rate = sample_rate

    def accept_waveform(self, sr, seg):
        """Discard the waveform; the mock extractor never reads stream audio."""

    def input_finished(self):
        """No-op: nothing is buffered, so there is nothing to flush."""
|
|
|
|
|
|
|
|
class MockExtractor:
    """Mimics the subset of sherpa_onnx SpeakerEmbeddingExtractor we use.

    ``compute`` hands out the supplied embeddings in order, one per call,
    and keeps returning the last one once they are exhausted. Every call
    returns a fresh copy, so in-place modification by the code under test
    cannot alter the test fixtures or later results.
    """

    def __init__(
        self,
        embeddings_sequence: List[np.ndarray],
        sample_rate: int = 16000,
    ):
        # Snapshot the sequence so later changes to the caller's list do not
        # change what the mock returns.
        self._embs = list(embeddings_sequence)
        self._i = 0
        # Sample rate reported by streams from ``create_stream``.
        self._sample_rate = sample_rate

    def create_stream(self):
        """Return a new MockStream; its audio content is never inspected."""
        return MockStream(self._sample_rate, None)

    def compute(self, _stream):
        """Return a copy of the next queued embedding, ignoring ``_stream``.

        Raises:
            IndexError: if the mock was constructed with no embeddings.
        """
        if not self._embs:
            raise IndexError("MockExtractor was given no embeddings to return")
        if self._i < len(self._embs):
            emb = self._embs[self._i]
            self._i += 1
        else:
            # Exhausted: repeat the last embedding for any extra calls.
            emb = self._embs[-1]
        return np.copy(emb)
|
|
|
|
|
|
|
|
def _collect(gen) -> List[Tuple[float, float, int]]: |
|
|
result: List[Tuple[float, float, int]] | None = None |
|
|
for item in gen: |
|
|
if isinstance(item, list): |
|
|
result = item |
|
|
break |
|
|
if result is None: |
|
|
|
|
|
try: |
|
|
while True: |
|
|
next(gen) |
|
|
except StopIteration as e: |
|
|
result = e.value |
|
|
assert result is not None, "Generator produced no result list" |
|
|
return result |
|
|
|
|
|
|
|
|
def _run_case(
    embeddings: List[np.ndarray],
    utterances: List[Tuple[float, float, str]],
    *,
    sample_rate: int = 16000,
    config_dict: dict | None = None,
) -> List[Tuple[float, float, int]]:
    """Run diarization on silent audio with a mock extractor; sanity-check output.

    Args:
        embeddings: Embeddings handed out by ``MockExtractor`` in call order.
        utterances: ``(start, end, text)`` triples passed through unchanged;
            start/end are assumed to be seconds (used to size the audio).
        sample_rate: Sample rate of the synthetic audio (keyword-only).
        config_dict: Diarization config; defaults to
            ``{"cluster_threshold": 0.5, "num_speakers": -1}``.

    Returns:
        The segment list produced by the diarizer, after asserting each entry
        is a ``(start, end, speaker)`` tuple with ``0 <= start < end`` and an
        ``int`` speaker label.
    """
    if config_dict is None:
        config_dict = {"cluster_threshold": 0.5, "num_speakers": -1}

    extractor = MockExtractor(embeddings)

    # Size the silent buffer to reach the latest utterance end (never less
    # than 3 s) so no utterance points past the end of the audio.
    last_end = max((end for _start, end, _text in utterances), default=0.0)
    duration_s = max(3.0, last_end)
    audio = np.zeros(int(np.ceil(sample_rate * duration_s)), dtype=np.float32)

    gen = perform_speaker_diarization_on_utterances(
        audio=audio,
        sample_rate=sample_rate,
        utterances=utterances,
        embedding_extractor=extractor,
        config_dict=config_dict,
        progress_callback=None,
    )
    segments = _collect(gen)

    for seg in segments:
        assert isinstance(seg, tuple) and len(seg) == 3
        s, e, spk = seg
        assert 0 <= s < e, "Invalid time bounds"
        assert isinstance(spk, int)
    return segments
|
|
|
|
|
|
|
|
def test_single_segment():
    """A lone utterance comes back as one segment labelled speaker 0."""
    segments = _run_case(
        embeddings=[_emb(1)],
        utterances=[(0.0, 2.0, "Hello world")],
    )
    assert len(segments) == 1
    _start, _end, speaker = segments[0]
    assert speaker == 0
|
|
|
|
|
|
|
|
def test_two_similar_segments_same_speaker():
    """Near-identical embeddings are expected to share one speaker label."""
    reference = _emb(2)
    nudged = _emb(2, delta=0.001)  # same draw as ``reference`` plus 0.001
    segments = _run_case(
        [reference, nudged],
        [(0.0, 2.0, "Bonjour"), (2.1, 4.0, "Bonjour encore")],
    )
    assert len(segments) == 2
    speakers = {segment[-1] for segment in segments}
    assert len(speakers) == 1, "Should have merged speaker IDs"
|
|
|
|
|
|
|
|
def test_two_different_segments_distinct_speakers():
    """Two utterances with unrelated embeddings each keep their own segment.

    Bounds and label types are validated inside ``_run_case``.

    NOTE(review): despite the test name, speaker distinctness is not
    asserted — only the segment count. If ``src.diarization`` is expected to
    separate these embeddings, tighten this with
    ``assert len({spk for *_, spk in segs}) == 2``.
    """
    utts = [(0.0, 1.5, "Hola"), (1.6, 3.2, "Adios")]
    segs = _run_case([_emb(10), _emb(200)], utts)
    assert len(segs) == 2
|
|
|
|
|
|
|
|
def test_three_segments_enhanced_or_fallback():
    """Three short utterances must not be collapsed into fewer segments."""
    embeddings = [_emb(seed) for seed in (11, 12, 13)]
    utterances = [(0.0, 1.0, "A"), (1.1, 2.2, "B"), (2.3, 3.4, "C")]
    segments = _run_case(embeddings, utterances)
    assert len(segments) == 3, "Granularity should be preserved for small n"
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Support ``python3 tests/test_diarization_minimal.py``: delegate to pytest
    # (already imported at module level) and propagate its exit status.
    raise SystemExit(pytest.main([__file__]))
|
|
|