import io
import re

import librosa
import torch
import torchaudio
from cachetools import LRUCache, cached
CACHE_MAXSIZE = 10000  # maximum number of encode results kept in the LRU cache
MICRO_BATCH_SIZE = 8  # decode at most this many items per forward pass
ASR_SAMPLE_RATE = 16000  # sample rate (Hz) expected by the ASR model
HUGE_GAP_THRESHOLD = 4000  # silence gap (ms) treated as a "huge gap"


def batch_encode(model, audios_list: list[bytes]):
    """Encode a batch of audio clips (raw bytes or 1-D tensors) into VQ features."""
    audios: list[torch.Tensor] = [
        (
            torch.from_numpy(
                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
            )[None]
            if isinstance(audio, bytes)
            else audio
        )
        for audio in audios_list
    ]

    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
    max_length = lengths.max().item()
    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")

    # Right-pad every clip to the longest one so the batch can be stacked.
    padded = torch.stack(
        [
            torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
            for audio in audios
        ]
    ).to(model.device)

    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
    features, feature_lengths = features.cpu(), feature_lengths.cpu()

    # Trim the padding back off each encoded feature before returning.
    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
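
# Example (sketch): encoding raw audio bytes, assuming `model` is an
# already-loaded VQ-GAN model exposing the `spec_transform.sample_rate`,
# `device`, and `encode(...)` attributes used above.
#
#   with open("sample.wav", "rb") as f:
#       features = batch_encode(model, [f.read()])
#   print(len(features), features[0].shape)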


@cached(
    cache=LRUCache(maxsize=CACHE_MAXSIZE),
    # Lists are unhashable, so key the cache on the raw byte payloads (plus the
    # device, so results from different devices do not collide).
    key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    return batch_encode(model, audios)
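
# Example (sketch): repeated requests carrying identical audio bytes are served
# from the LRU cache instead of re-running the encoder.
#
#   feats_first = cached_vqgan_batch_encode(model, (wav_bytes,))
#   feats_again = cached_vqgan_batch_encode(model, (wav_bytes,))  # cache hit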


def batch_vqgan_decode(model, features):
    """Decode a list of VQ feature tensors back into waveforms (numpy arrays)."""
    lengths = torch.tensor(
        [feature.shape[-1] for feature in features], device=model.device
    )
    max_length = lengths.max().item()
    padded = torch.stack(
        [
            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
            for feature in features
        ]
    ).to(model.device)

    # Decode in micro-batches so a large batch does not exhaust GPU memory.
    audios, audio_lengths = [], []
    for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
        audio, audio_length = model.decode(
            padded[i : i + MICRO_BATCH_SIZE],
            feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
        )
        audios.append(audio)
        audio_lengths.append(audio_length)

    audios = torch.cat(audios, dim=0)
    audio_lengths = torch.cat(audio_lengths, dim=0)
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    # Trim the padding back off each decoded waveform.
    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
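
# Example (sketch): round-tripping through the codec, reusing `model` and the
# encoded `features` from the sketch above.
#
#   waveforms = batch_vqgan_decode(model, features)  # list of numpy arrays
#   print(waveforms[0].shape)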


def batch_asr(model, lock, audios, sr, language="auto"):
    """Run batched ASR on 1-D waveforms and flag transcripts with long silent gaps."""
    resampled_audios = []
    for audio in audios:
        audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    with lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        # Strip special tags of the form <|...|> from the transcript.
        text = r["text"]
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000

        huge_gap = False
        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # A silence of more than HUGE_GAP_THRESHOLD ms (4 s) between two
                # consecutive segments counts as a huge gap.
                if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
                    huge_gap = True
                    break

            # A long trailing silence after the last segment also counts as a huge gap.
            if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results
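
# Example (sketch): transcribing decoded audio with a shared ASR model, assuming
# `asr_model` is a FunASR AutoModel (e.g. SenseVoice) whose `generate()` accepts
# the keyword arguments used above, and that `waveforms` comes from the decode
# sketch earlier in this file.
#
#   import threading
#
#   lock = threading.Lock()
#   waveform = torch.from_numpy(waveforms[0]).flatten()  # batch_asr expects 1-D tensors
#   results = batch_asr(
#       asr_model, lock, [waveform], sr=model.spec_transform.sample_rate
#   )
#   print(results[0]["text"], results[0]["duration"], results[0]["huge_gap"])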