# asr.py
import os
import re
import tempfile
from typing import Iterable, List, Optional, Tuple
import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
# Lazy / optional imports: guard heavy or optional ASR backends
try:
    from silero_vad import load_silero_vad, VADIterator
except Exception:
    load_silero_vad = None
    VADIterator = None
try:
    from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
except Exception:
    MoonshineOnnxModel = None
    load_tokenizer = None
from .utils import load_sensevoice_model, s2tw_converter
SAMPLING_RATE = 16000
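# Silero VAD processes fixed 512-sample windows at 16 kHz, hence the chunk size below.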
CHUNK_SIZE = 512
# tokenizer will be initialized lazily when moonshine backend is used
tokenizer = None
def clean_transcript(text):
    # Strip Unicode replacement characters left over from decoding errors
    text = re.sub(r'\uFFFD', '', text)
    # Collapse a CJK character repeated three or more times into a single occurrence
    text = re.sub(r'([\u4e00-\u9fa5])\1{2,}', r'\1', text)
    # Remove stray spaces between consecutive CJK characters
    text = re.sub(r'([\u4e00-\u9fa5]) ([ \u4e00-\u9fa5])', r'\1\2', text)
    return text
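
# Illustrative example (hypothetical input): repeated CJK characters are collapsed
# and stray spaces between CJK characters are removed, e.g.
#   clean_transcript("你你你 好") -> "你好"
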
def transcribe_file(
    audio_path: str,
    vad_threshold: float,
    model_name: str,
    backend: str = "moonshine",
    language: str = "auto",
    textnorm: str = "withitn",
) -> Iterable[Tuple[Optional[Tuple[float, float, str]], List[Tuple[float, float, str]], float]]:
"""
Transcribe audio file using specified backend.
Args:
audio_path: Path to audio file
vad_threshold: VAD threshold (0-1)
model_name: Model name (backend-specific)
backend: Either "moonshine" or "sensevoice"
language: Language for sensevoice (auto or specific language code)
textnorm: Text normalization for sensevoice ("withitn" or "noitn")
Yields:
Tuple of (current_utterance, all_utterances)
"""
    if load_silero_vad is None or VADIterator is None:
        raise RuntimeError("silero_vad is not available. Please install the 'silero-vad' package.")
    vad_model = load_silero_vad(onnx=True)
    vad_iterator = VADIterator(model=vad_model, sampling_rate=SAMPLING_RATE, threshold=vad_threshold)

    # Initialize the backend model lazily and check availability
    if backend == "moonshine":
        if MoonshineOnnxModel is None or load_tokenizer is None:
            raise RuntimeError("moonshine_onnx is not available. Install the dependency or choose the 'sensevoice' backend.")
        model = MoonshineOnnxModel(model_name=f"moonshine/{model_name}")
        global tokenizer
        if tokenizer is None:
            tokenizer = load_tokenizer()
    elif backend == "sensevoice":
        model = load_sensevoice_model(model_name)
    else:
        raise ValueError(f"Unknown backend: {backend}")

    wav, orig_sr = sf.read(audio_path, dtype='float32')
    if orig_sr != SAMPLING_RATE:
        gcd = np.gcd(int(orig_sr), SAMPLING_RATE)
        up = SAMPLING_RATE // gcd
        down = orig_sr // gcd
        wav = resample_poly(wav, up, down)
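        # Note: resample_poly performs rational-rate resampling; the up/down factors
        # are the target/source rates reduced by their GCD (e.g. 44100 Hz -> 16000 Hz
        # uses up=160, down=441).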
    if wav.ndim > 1:
        wav = wav.mean(axis=1)  # Downmix multi-channel audio to mono

    utterances = []      # All utterances so far as (start, end, text)
    speech_chunks = []   # Chunks accumulated for the current speech segment
    segment_start = 0.0  # Start time of the current segment (seconds)
    i = 0
    while i < len(wav):
        chunk = wav[i:i + CHUNK_SIZE]
        if len(chunk) < CHUNK_SIZE:
            chunk = np.pad(chunk, (0, CHUNK_SIZE - len(chunk)), mode='constant')
        i += CHUNK_SIZE
        speech_dict = vad_iterator(chunk)
        speech_chunks.append(chunk)
        if speech_dict:
            if "end" in speech_dict:
                # The VAD reported the end of a speech segment: compute timestamps
                segment_end = i / SAMPLING_RATE
                # Concatenate the accumulated chunks into a single buffer
                speech_buffer = np.concatenate(speech_chunks)
                if backend == "moonshine":
                    tokens = model.generate(speech_buffer[np.newaxis, :].astype(np.float32))
                    text = tokenizer.decode_batch(tokens)[0].strip()
                elif backend == "sensevoice":
                    # sherpa-onnx accepts the waveform directly; no temporary file is needed
                    stream = model.create_stream()
                    stream.accept_waveform(SAMPLING_RATE, speech_buffer)
                    model.decode_stream(stream)
                    text = stream.result.text
                    # The detected language is reported in stream.result (read-only)
                if text:
                    # Convert to Traditional Chinese (Taiwan) and clean artifacts
                    cleaned_text = clean_transcript(s2tw_converter.convert(text))
                    utterances.append((segment_start, segment_end, cleaned_text))
                    progress = min(100, (i / len(wav)) * 100)
                    yield utterances[-1], utterances.copy(), progress
                # Reset for the next segment
                speech_chunks = []
                segment_start = i / SAMPLING_RATE  # Start of the next segment
                vad_iterator.reset_states()

    # Process any remaining audio after the last VAD-detected segment
    if speech_chunks:
        speech_buffer = np.concatenate(speech_chunks)
        # Skip trailing buffers shorter than half a second
        if len(speech_buffer) > SAMPLING_RATE * 0.5:
            segment_end = len(wav) / SAMPLING_RATE
            if backend == "moonshine":
                tokens = model.generate(speech_buffer[np.newaxis, :].astype(np.float32))
                text = tokenizer.decode_batch(tokens)[0].strip()
            elif backend == "sensevoice":
                # sherpa-onnx accepts the waveform directly; no temporary file is needed
                stream = model.create_stream()
                stream.accept_waveform(SAMPLING_RATE, speech_buffer)
                model.decode_stream(stream)
                text = stream.result.text
                # The detected language is reported in stream.result (read-only)
            if text:
                cleaned_text = clean_transcript(s2tw_converter.convert(text))
                utterances.append((segment_start, segment_end, cleaned_text))
                yield utterances[-1], utterances.copy(), 100.0

    # Final yield with the complete utterance list
    if utterances:
        yield None, utterances, 100.0
    else:
        yield None, [(-1, -1, "No speech detected")], 100.0
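

# Minimal usage sketch (illustrative, not part of the module): assumes a local
# audio file at "sample.wav" and that the optional moonshine backend is installed;
# the model name "base" is a placeholder. Run as a module (e.g. `python -m src.asr`)
# so the relative import of .utils resolves.
if __name__ == "__main__":
    final_utterances = []
    for current, all_utterances, progress in transcribe_file(
        "sample.wav", vad_threshold=0.5, model_name="base", backend="moonshine"
    ):
        if current is not None:
            start, end, text = current
            print(f"[{progress:5.1f}%] {start:7.2f}-{end:7.2f}s  {text}")
        final_utterances = all_utterances
    print(f"Transcribed {len(final_utterances)} utterance(s).")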