Spaces:
Running
on
Zero
Running
on
Zero
| from pathlib import Path | |
| from typing import Generator | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from neucodec import NeuCodec, DistillNeuCodec | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from utils.phonemize_text import phonemize_text, phonemize_with_dict | |
| import re | |
def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    """Blend overlapping frames into one signal using a triangular cross-fade.

    Frame ``i`` is placed at offset ``i * stride`` along the last axis. Every
    frame is multiplied by a triangular window (largest in the middle, smallest
    at both edges), the weighted frames are summed, and the sum is divided by
    the accumulated weight at each position. A position covered by a single
    frame is therefore returned unchanged, while a position covered by two
    frames fades linearly from the earlier frame to the later one.

    NumPy port of ``_linear_overlap_add`` from
    https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py

    Args:
        frames: Non-empty list of arrays. All frames must share the same
            leading dimensions; only the last axis may differ in length.
        stride: Distance, in samples, between the starts of consecutive frames.

    Returns:
        Array of shape ``(*frames[0].shape[:-1], total_size)`` where
        ``total_size`` is the furthest sample reached by any frame.

    Raises:
        ValueError: If ``frames`` is empty, or if some output position is not
            covered by any frame (e.g. ``stride`` larger than a frame).
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not frames:
        raise ValueError("`frames` must contain at least one frame.")

    dtype = frames[0].dtype
    lead_shape = frames[0].shape[:-1]
    total_size = max(stride * i + frame.shape[-1] for i, frame in enumerate(frames))

    sum_weight = np.zeros(total_size, dtype=dtype)
    # np.zeros takes the shape as ONE tuple (unlike torch.zeros(*shape, n)).
    out = np.zeros((*lead_shape, total_size), dtype=dtype)

    offset = 0
    for frame in frames:
        frame_length = frame.shape[-1]
        # Drop the two end points so that no sample gets a weight of exactly 0.
        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        # Triangle peaking at t == 0.5, as in the reference implementation.
        weight = 0.5 - np.abs(t - 0.5)
        out[..., offset : offset + frame_length] += weight * frame
        sum_weight[offset : offset + frame_length] += weight
        offset += stride

    # `not (... > 0)` also rejects NaN, matching the original assertion.
    if sum_weight.size == 0 or not (sum_weight.min() > 0):
        raise ValueError(
            "Frames do not cover every output position; "
            "check that `stride` does not exceed the frame length."
        )
    return out / sum_weight
| class VieNeuTTS: | |
| def __init__( | |
| self, | |
| backbone_repo="pnnbao-ump/VieNeu-TTS", | |
| backbone_device="cpu", | |
| codec_repo="neuphonic/neucodec", | |
| codec_device="cpu", | |
| ): | |
| # Constants | |
| self.sample_rate = 24_000 | |
| self.max_context = 2048 | |
| self.hop_length = 480 | |
| self.streaming_overlap_frames = 1 | |
| self.streaming_frames_per_chunk = 25 | |
| self.streaming_lookforward = 5 | |
| self.streaming_lookback = 50 | |
| self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length | |
| # ggml & onnx flags | |
| self._is_quantized_model = False | |
| self._is_onnx_codec = False | |
| # HF tokenizer | |
| self.tokenizer = None | |
| # Load models | |
| self._load_backbone(backbone_repo, backbone_device) | |
| self._load_codec(codec_repo, codec_device) | |
| def _load_backbone(self, backbone_repo, backbone_device): | |
| print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...") | |
| if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower(): | |
| try: | |
| from llama_cpp import Llama | |
| except ImportError as e: | |
| raise ImportError( | |
| "Failed to import `llama_cpp`. " | |
| "Please install it with:\n" | |
| " pip install llama-cpp-python" | |
| ) from e | |
| self.backbone = Llama.from_pretrained( | |
| repo_id=backbone_repo, | |
| filename="*.gguf", | |
| verbose=False, | |
| n_gpu_layers=-1 if backbone_device == "gpu" else 0, | |
| n_ctx=self.max_context, | |
| mlock=True, | |
| flash_attn=True if backbone_device == "gpu" else False, | |
| ) | |
| self._is_quantized_model = True | |
| else: | |
| self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo) | |
| self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to( | |
| torch.device(backbone_device) | |
| ) | |
| def _load_codec(self, codec_repo, codec_device): | |
| print(f"Loading codec from: {codec_repo} on {codec_device} ...") | |
| match codec_repo: | |
| case "neuphonic/neucodec": | |
| self.codec = NeuCodec.from_pretrained(codec_repo) | |
| self.codec.eval().to(codec_device) | |
| case "neuphonic/distill-neucodec": | |
| self.codec = DistillNeuCodec.from_pretrained(codec_repo) | |
| self.codec.eval().to(codec_device) | |
| case "neuphonic/neucodec-onnx-decoder": | |
| if codec_device != "cpu": | |
| raise ValueError("Onnx decoder only currently runs on CPU.") | |
| try: | |
| from neucodec import NeuCodecOnnxDecoder | |
| except ImportError as e: | |
| raise ImportError( | |
| "Failed to import the onnx decoder." | |
| " Ensure you have onnxruntime installed as well as neucodec >= 0.0.4." | |
| ) from e | |
| self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo) | |
| self._is_onnx_codec = True | |
| case _: | |
| raise ValueError(f"Unsupported codec repository: {codec_repo}") | |
| def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray: | |
| """ | |
| Perform inference to generate speech from text using the TTS model and reference audio. | |
| Args: | |
| text (str): Input text to be converted to speech. | |
| ref_codes (np.ndarray | torch.tensor): Encoded reference. | |
| ref_text (str): Reference text for reference audio. Defaults to None. | |
| Returns: | |
| np.ndarray: Generated speech waveform. | |
| """ | |
| # Generate tokens | |
| if self._is_quantized_model: | |
| output_str = self._infer_ggml(ref_codes, ref_text, text) | |
| else: | |
| prompt_ids = self._apply_chat_template(ref_codes, ref_text, text) | |
| output_str = self._infer_torch(prompt_ids) | |
| # Decode | |
| wav = self._decode(output_str) | |
| return wav | |
| def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]: | |
| """ | |
| Perform streaming inference to generate speech from text using the TTS model and reference audio. | |
| Args: | |
| text (str): Input text to be converted to speech. | |
| ref_codes (np.ndarray | torch.tensor): Encoded reference. | |
| ref_text (str): Reference text for reference audio. Defaults to None. | |
| Yields: | |
| np.ndarray: Generated speech waveform. | |
| """ | |
| if self._is_quantized_model: | |
| return self._infer_stream_ggml(ref_codes, ref_text, text) | |
| else: | |
| raise NotImplementedError("Streaming is not implemented for the torch backend!") | |
| def encode_reference(self, ref_audio_path: str | Path): | |
| wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True) | |
| wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0) # [1, 1, T] | |
| with torch.no_grad(): | |
| ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0) | |
| return ref_codes | |
| def _decode(self, codes: str): | |
| """Decode speech tokens to audio waveform.""" | |
| # Extract speech token IDs using regex | |
| speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)] | |
| if len(speech_ids) == 0: | |
| raise ValueError( | |
| "No valid speech tokens found in the output. " | |
| "The model may not have generated proper speech tokens." | |
| ) | |
| # Onnx decode | |
| if self._is_onnx_codec: | |
| codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :] | |
| recon = self.codec.decode_code(codes) | |
| # Torch decode | |
| else: | |
| with torch.no_grad(): | |
| codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to( | |
| self.codec.device | |
| ) | |
| recon = self.codec.decode_code(codes).cpu().numpy() | |
| return recon[0, 0, :] | |
| def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]: | |
| input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text) | |
| speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>") | |
| speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>") | |
| text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>") | |
| text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>") | |
| text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>") | |
| input_ids = self.tokenizer.encode(input_text, add_special_tokens=False) | |
| chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>""" | |
| ids = self.tokenizer.encode(chat) | |
| text_replace_idx = ids.index(text_replace) | |
| ids = ( | |
| ids[:text_replace_idx] | |
| + [text_prompt_start] | |
| + input_ids | |
| + [text_prompt_end] | |
| + ids[text_replace_idx + 1 :] # noqa | |
| ) | |
| speech_replace_idx = ids.index(speech_replace) | |
| codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes]) | |
| codes = self.tokenizer.encode(codes_str, add_special_tokens=False) | |
| ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes) | |
| return ids | |
| def _infer_torch(self, prompt_ids: list[int]) -> str: | |
| prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device) | |
| speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>") | |
| with torch.no_grad(): | |
| output_tokens = self.backbone.generate( | |
| prompt_tensor, | |
| max_length=self.max_context, | |
| eos_token_id=speech_end_id, | |
| do_sample=True, | |
| temperature=1.0, | |
| top_k=50, | |
| use_cache=True, | |
| min_new_tokens=50, | |
| ) | |
| input_length = prompt_tensor.shape[-1] | |
| output_str = self.tokenizer.decode( | |
| output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False | |
| ) | |
| return output_str | |
| def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str: | |
| ref_text = phonemize_with_dict(ref_text) | |
| input_text = phonemize_with_dict(input_text) | |
| codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes]) | |
| prompt = ( | |
| f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}" | |
| f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}" | |
| ) | |
| output = self.backbone( | |
| prompt, | |
| max_tokens=self.max_context, | |
| temperature=1.0, | |
| top_k=50, | |
| stop=["<|SPEECH_GENERATION_END|>"], | |
| ) | |
| output_str = output["choices"][0]["text"] | |
| return output_str | |
| def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]: | |
| ref_text = phonemize_with_dict(ref_text) | |
| input_text = phonemize_with_dict(input_text) | |
| codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes]) | |
| prompt = ( | |
| f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}" | |
| f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}" | |
| ) | |
| audio_cache: list[np.ndarray] = [] | |
| token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes] | |
| n_decoded_samples: int = 0 | |
| n_decoded_tokens: int = len(ref_codes) | |
| for item in self.backbone( | |
| prompt, | |
| max_tokens=self.max_context, | |
| temperature=0.2, | |
| top_k=50, | |
| stop=["<|SPEECH_GENERATION_END|>"], | |
| stream=True | |
| ): | |
| output_str = item["choices"][0]["text"] | |
| token_cache.append(output_str) | |
| if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward: | |
| # decode chunk | |
| tokens_start = max( | |
| n_decoded_tokens | |
| - self.streaming_lookback | |
| - self.streaming_overlap_frames, | |
| 0 | |
| ) | |
| tokens_end = ( | |
| n_decoded_tokens | |
| + self.streaming_frames_per_chunk | |
| + self.streaming_lookforward | |
| + self.streaming_overlap_frames | |
| ) | |
| sample_start = ( | |
| n_decoded_tokens - tokens_start | |
| ) * self.hop_length | |
| sample_end = ( | |
| sample_start | |
| + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length | |
| ) | |
| curr_codes = token_cache[tokens_start:tokens_end] | |
| recon = self._decode("".join(curr_codes)) | |
| recon = recon[sample_start:sample_end] | |
| audio_cache.append(recon) | |
| # postprocess | |
| processed_recon = _linear_overlap_add( | |
| audio_cache, stride=self.streaming_stride_samples | |
| ) | |
| new_samples_end = len(audio_cache) * self.streaming_stride_samples | |
| processed_recon = processed_recon[ | |
| n_decoded_samples:new_samples_end | |
| ] | |
| n_decoded_samples = new_samples_end | |
| n_decoded_tokens += self.streaming_frames_per_chunk | |
| yield processed_recon | |
| # final decoding handled separately as non-constant chunk size | |
| remaining_tokens = len(token_cache) - n_decoded_tokens | |
| if len(token_cache) > n_decoded_tokens: | |
| tokens_start = max( | |
| len(token_cache) | |
| - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens), | |
| 0 | |
| ) | |
| sample_start = ( | |
| len(token_cache) | |
| - tokens_start | |
| - remaining_tokens | |
| - self.streaming_overlap_frames | |
| ) * self.hop_length | |
| curr_codes = token_cache[tokens_start:] | |
| recon = self._decode("".join(curr_codes)) | |
| recon = recon[sample_start:] | |
| audio_cache.append(recon) | |
| processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples) | |
| processed_recon = processed_recon[n_decoded_samples:] | |
| yield processed_recon | |