File size: 4,296 Bytes
913c94a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
#!/usr/bin/env python3
"""Pytest-based minimal sanity tests for `perform_speaker_diarization_on_utterances`.
These tests avoid heavy dependencies (sherpa_onnx/faiss/sklearn) by using a mock
extractor and rely on the lightweight paths & heuristics implemented in
`src.diarization`.
Run:
pytest -q tests/test_diarization_minimal.py
Or standalone (still works):
python3 tests/test_diarization_minimal.py
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Iterable, List, Tuple
import numpy as np
import pytest
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from src.diarization import perform_speaker_diarization_on_utterances # type: ignore
EMB_DIM = 192  # length of every synthetic embedding vector built by _emb


def _emb(seed: int, delta: float | None = None) -> np.ndarray:
    """Build a reproducible pseudo-random float32 vector of length ``EMB_DIM``.

    Args:
        seed: Seed for the NumPy generator; equal seeds give equal vectors.
        delta: Optional constant added to every component, useful for deriving
            a slightly shifted copy of the vector produced by the same seed.

    Returns:
        A one-dimensional ``float32`` array with ``EMB_DIM`` entries.
    """
    vector = np.random.default_rng(seed).normal(size=EMB_DIM).astype(np.float32)
    if delta is None:
        return vector
    # Re-cast after the addition so the result is float32 regardless of
    # how NumPy promotes the Python scalar.
    return (vector + delta).astype(np.float32)
class MockStream:
    """Inert stand-in for the stream object handed out by the mock extractor.

    It only records its constructor arguments; both stream methods do nothing.
    """

    def __init__(self, sample_rate: int, segment: np.ndarray | None):
        self.segment = segment
        self.sample_rate = sample_rate

    def accept_waveform(self, sr, seg):  # pragma: no cover - no-op
        """Ignore the supplied samples."""
        return None

    def input_finished(self):  # pragma: no cover - no-op
        """Ignore the end-of-input notification."""
        return None
class MockExtractor:
    """Mimics the subset of sherpa_onnx SpeakerEmbeddingExtractor we use.

    Embeddings are replayed in the order given, one per ``compute`` call.
    Once the sequence is used up, the last embedding is returned for every
    further call.
    """

    def __init__(self, embeddings_sequence: List[np.ndarray]):
        self._embs = embeddings_sequence
        self._i = 0  # index of the next embedding to hand out

    def create_stream(self):
        """Return a fresh no-op stream (16 kHz, no attached segment)."""
        return MockStream(16000, None)

    def compute(self, _stream):
        """Return the next scripted embedding; the stream argument is ignored."""
        if self._i < len(self._embs):
            current = self._embs[self._i]
            self._i += 1
            return current
        # Sequence exhausted: keep repeating the final embedding.
        return self._embs[-1]
def _collect(gen) -> List[Tuple[float, float, int]]:
    """Drive *gen* until it delivers the final segment list.

    The list may arrive in either of two ways:

    * it is yielded (the first yielded ``list`` is taken as the result and
      iteration stops there), or
    * it is the generator's ``return`` value, carried by ``StopIteration``.

    Args:
        gen: Iterator/generator to consume; non-list items it yields are
            skipped.

    Returns:
        The segment list produced by *gen*.

    Raises:
        AssertionError: If *gen* finishes without yielding or returning a
            result.
    """
    iterator = iter(gen)
    result: List[Tuple[float, float, int]] | None = None
    # Step the iterator manually: a ``for`` loop would swallow StopIteration
    # and discard the generator's return value.
    while True:
        try:
            item = next(iterator)
        except StopIteration as stop:
            result = stop.value  # None when the generator had a bare return
            break
        if isinstance(item, list):
            result = item  # final segments emitted
            break
    assert result is not None, "Generator produced no result list"
    return result
def _run_case(embeddings: List[np.ndarray], utterances: List[Tuple[float, float, str]]):
    """Run the diarization entry point on silent audio with scripted embeddings.

    Args:
        embeddings: Vectors the mock extractor returns, one per ``compute``
            call (the last one repeats once the list is used up).
        utterances: ``(start, end, text)`` tuples passed through unchanged.
            Times are presumably seconds, matching the 16 kHz buffer sizing
            below -- confirm against ``src.diarization``.

    Returns:
        The collected segments, each checked to be a ``(start, end, speaker)``
        tuple with ``0 <= start < end`` and an ``int`` speaker id.
    """
    sample_rate = 16000
    extractor = MockExtractor(embeddings)

    # The silent buffer must span every utterance. A fixed 3 s clip is shorter
    # than utterances ending after 3 s, so any code that slices the audio by
    # utterance bounds would see truncated segments. Keep 3 s as the minimum
    # so short cases get the same buffer as before.
    last_end = max((float(end) for _start, end, _text in utterances), default=0.0)
    duration = max(3.0, last_end)
    audio = np.zeros(int(np.ceil(sample_rate * duration)), dtype=np.float32)

    gen = perform_speaker_diarization_on_utterances(
        audio=audio,
        sample_rate=sample_rate,
        utterances=utterances,
        embedding_extractor=extractor,
        config_dict={"cluster_threshold": 0.5, "num_speakers": -1},
        progress_callback=None,
    )
    segments = _collect(gen)

    # Basic validation
    for seg in segments:
        assert isinstance(seg, tuple) and len(seg) == 3
        s, e, spk = seg
        assert 0 <= s < e, "Invalid time bounds"
        assert isinstance(spk, int)
    return segments
def test_single_segment():
    """A lone utterance comes back as exactly one segment with speaker id 0."""
    segments = _run_case([_emb(1)], [(0.0, 2.0, "Hello world")])
    assert len(segments) == 1
    first = segments[0]
    assert first[2] == 0
def test_two_similar_segments_same_speaker():
    """Two nearly identical embeddings must end up with one shared speaker id."""
    reference = _emb(2)
    # Same seed plus a tiny constant offset: a near-duplicate of `reference`.
    near_duplicate = _emb(2, delta=0.001)
    segments = _run_case(
        [reference, near_duplicate],
        [(0.0, 2.0, "Bonjour"), (2.1, 4.0, "Bonjour encore")],
    )
    assert len(segments) == 2
    speaker_ids = {segment[-1] for segment in segments}
    assert len(speaker_ids) == 1, "Should have merged speaker IDs"
def test_two_different_segments_distinct_speakers():
    """Two unrelated embeddings yield two segments carrying one or two speaker ids.

    The test deliberately does not require two distinct speakers: depending on
    the similarity heuristic the two segments may still be merged.
    """
    utts = [(0.0, 1.5, "Hola"), (1.6, 3.2, "Adios")]
    segs = _run_case([_emb(10), _emb(200)], utts)
    assert len(segs) == 2
    # Check the speaker ids rather than the segment count, which is already
    # pinned by the assertion above.
    speaker_ids = {spk for *_rest, spk in segs}
    # Can be 1 or 2 depending on heuristic similarity
    assert 1 <= len(speaker_ids) <= 2, "Unexpected number of speaker IDs"
def test_three_segments_enhanced_or_fallback():
    """Three short utterances must come back as three separate segments."""
    embeddings = [_emb(seed) for seed in (11, 12, 13)]
    segments = _run_case(
        embeddings,
        [(0.0, 1.0, "A"), (1.1, 2.2, "B"), (2.3, 3.4, "C")],
    )
    assert len(segments) == 3, "Granularity should be preserved for small n"
# Support direct execution of this file by delegating to pytest's runner
if __name__ == "__main__":  # pragma: no cover
    import pytest as _pytest

    # Propagate pytest's exit status as the process exit code.
    _status = _pytest.main([__file__])
    raise SystemExit(_status)
|