File size: 4,296 Bytes
913c94a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python3
"""Pytest-based minimal sanity tests for `perform_speaker_diarization_on_utterances`.

These tests avoid heavy dependencies (sherpa_onnx/faiss/sklearn) by using a mock
extractor and rely on the lightweight paths & heuristics implemented in
`src.diarization`.

Run:
  pytest -q tests/test_diarization_minimal.py

Or standalone (still works):
  python3 tests/test_diarization_minimal.py
"""
from __future__ import annotations

import sys
from pathlib import Path
from typing import Iterable, List, Tuple
import numpy as np
import pytest

# Put the directory two levels above this file on sys.path (only if it is not
# already listed) so the `src.diarization` import below can be resolved.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.diarization import perform_speaker_diarization_on_utterances  # type: ignore


# Length of the synthetic embedding vectors generated by `_emb`.
EMB_DIM = 192


def _emb(seed: int, delta: float | None = None) -> np.ndarray:
    """Return a reproducible float32 vector of length ``EMB_DIM``.

    The vector is drawn from a standard normal distribution using a generator
    seeded with ``seed``. When ``delta`` is given it is added to every
    component and the result is cast back to float32.
    """
    vector = np.random.default_rng(seed).normal(size=EMB_DIM).astype(np.float32)
    if delta is None:
        return vector
    return (vector + delta).astype(np.float32)


class MockStream:
    def __init__(self, sample_rate: int, segment: np.ndarray | None):
        self.sample_rate = sample_rate
        self.segment = segment
    def accept_waveform(self, sr, seg):  # pragma: no cover - no-op
        pass
    def input_finished(self):  # pragma: no cover - no-op
        pass


class MockExtractor:
    """Mimics the subset of sherpa_onnx SpeakerEmbeddingExtractor we use.

    Embeddings are replayed in the order given, one per ``compute`` call;
    once the sequence is used up the last embedding is returned repeatedly.
    """

    def __init__(self, embeddings_sequence: List[np.ndarray]):
        self._embs = embeddings_sequence
        self._i = 0  # index of the next embedding to hand out

    def create_stream(self):
        """Return a fresh ``MockStream`` (16 kHz, no segment attached)."""
        return MockStream(16000, None)

    def compute(self, _stream):
        """Return the next canned embedding; the stream argument is ignored."""
        if self._i < len(self._embs):
            picked = self._embs[self._i]
            self._i += 1
            return picked
        # Sequence exhausted: keep answering with the final embedding.
        return self._embs[-1]


def _collect(gen) -> List[Tuple[float, float, int]]:
    result: List[Tuple[float, float, int]] | None = None
    for item in gen:
        if isinstance(item, list):
            result = item  # final segments emitted
            break
    if result is None:
        # Drain StopIteration
        try:
            while True:
                next(gen)
        except StopIteration as e:
            result = e.value  # type: ignore
    assert result is not None, "Generator produced no result list"
    return result


def _run_case(embeddings: List[np.ndarray], utterances: List[Tuple[float, float, str]]):
    """Run the diarization entry point with canned embeddings and return its segments.

    A ``MockExtractor`` replays ``embeddings``; the audio is three seconds of
    float32 silence at 16 kHz. Each returned segment is checked to be a
    3-tuple ``(start, end, speaker)`` with ``0 <= start < end`` and an ``int``
    speaker before the list is handed back.
    """
    sample_rate = 16000
    # NOTE(review): some tests in this file pass utterances that end after
    # 3.0 s - confirm the function under test tolerates timestamps beyond
    # the end of this buffer.
    silence = np.zeros(sample_rate * 3, dtype=np.float32)  # 3s silence is enough
    segments = _collect(
        perform_speaker_diarization_on_utterances(
            audio=silence,
            sample_rate=sample_rate,
            utterances=utterances,
            embedding_extractor=MockExtractor(embeddings),
            config_dict={"cluster_threshold": 0.5, "num_speakers": -1},
            progress_callback=None,
        )
    )
    # Basic shape/type validation shared by every test case.
    for segment in segments:
        assert isinstance(segment, tuple) and len(segment) == 3
        start, end, speaker = segment
        assert 0 <= start < end, "Invalid time bounds"
        assert isinstance(speaker, int)
    return segments


def test_single_segment():
    """A lone utterance must yield exactly one segment labelled speaker 0."""
    segments = _run_case([_emb(1)], [(0.0, 2.0, "Hello world")])
    assert len(segments) == 1
    _start, _end, speaker = segments[0]
    assert speaker == 0


def test_two_similar_segments_same_speaker():
    """Two nearly identical embeddings must end up with a single speaker id."""
    reference = _emb(2)
    perturbed = (reference + 0.001).astype(np.float32)
    segments = _run_case(
        [reference, perturbed],
        [(0.0, 2.0, "Bonjour"), (2.1, 4.0, "Bonjour encore")],
    )
    assert len(segments) == 2
    speaker_ids = {segment[2] for segment in segments}
    assert len(speaker_ids) == 1, "Should have merged speaker IDs"


def test_two_different_segments_distinct_speakers():
    """Two independently seeded embeddings must still produce two segments.

    The number of distinct speakers is deliberately not pinned to 2: the
    clustering heuristic may or may not merge the two embeddings, so either
    one or two speaker ids is accepted.
    """
    utts = [(0.0, 1.5, "Hola"), (1.6, 3.2, "Adios")]
    segs = _run_case([_emb(10), _emb(200)], utts)
    assert len(segs) == 2
    # Inspect the speaker ids themselves (the previous check re-tested the
    # segment count): one or two distinct ids are both acceptable.
    speakers = {spk for *_rest, spk in segs}
    assert 1 <= len(speakers) <= 2, "Expected one or two distinct speaker IDs"


def test_three_segments_enhanced_or_fallback():
    """Three short utterances must come back as three separate segments."""
    embeddings = [_emb(seed) for seed in (11, 12, 13)]
    utterances = [(0.0, 1.0, "A"), (1.1, 2.2, "B"), (2.3, 3.4, "C")]
    segments = _run_case(embeddings, utterances)
    assert len(segments) == 3, "Granularity should be preserved for small n"


# Direct execution (``python3 <this file>``): run pytest on this file only and
# exit with the status code it reports.
if __name__ == "__main__":  # pragma: no cover
    import pytest as _pytest

    exit_code = _pytest.main([__file__])
    raise SystemExit(exit_code)