# VieNeu-TTS / vieneu_tts.py
# Author: pnnbao-ump
# Commit message: add more examples
# Commit: f01210e
from pathlib import Path
from typing import Generator
import librosa
import numpy as np
import torch
from neucodec import NeuCodec, DistillNeuCodec
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.phonemize_text import phonemize_text, phonemize_with_dict
import re
def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    """Stitch overlapping frames together with a triangular (linear) cross-fade.

    Adapted from encodec's ``_linear_overlap_add``
    (https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py).
    Frame ``i`` is placed at offset ``stride * i``. Each frame is weighted by a
    triangle peaking at its centre, and the weighted sum is divided by the sum
    of the weights at every position. A sample covered by a single frame is
    therefore returned unchanged, and two overlapping frames are cross-faded
    linearly across their overlap.

    Args:
        frames: Non-empty list of arrays. All frames share the same leading
            shape; only the last (time) axis may differ in length.
        stride: Offset, in samples, between the starts of consecutive frames.

    Returns:
        Array of shape ``(*frames[0].shape[:-1], total_size)``, where
        ``total_size`` is the furthest frame end.

    Raises:
        ValueError: If ``frames`` is empty, or if some output position is covered
            by no frame (e.g. ``stride`` longer than a frame).
    """
    if not frames:
        raise ValueError("_linear_overlap_add requires at least one frame")
    dtype = frames[0].dtype
    leading_shape = frames[0].shape[:-1]

    total_size = max(stride * i + frame.shape[-1] for i, frame in enumerate(frames))

    sum_weight = np.zeros(total_size, dtype=dtype)
    # np.zeros takes the shape as ONE tuple. Unpacking it would pass the second
    # dimension as the positional `dtype` argument and fail for
    # multi-dimensional frames.
    out = np.zeros((*leading_shape, total_size), dtype=dtype)

    offset = 0
    for frame in frames:
        frame_length = frame.shape[-1]
        # Drop the 0 and 1 endpoints so every weight is strictly positive.
        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        # Triangle with its peak at the frame centre, as in encodec.
        weight = 0.5 - np.abs(t - 0.5)
        out[..., offset : offset + frame_length] += weight * frame
        sum_weight[offset : offset + frame_length] += weight
        offset += stride

    if sum_weight.min() <= 0:
        raise ValueError("frames leave gaps: stride is larger than a frame")
    return out / sum_weight
class VieNeuTTS:
def __init__(
self,
backbone_repo="pnnbao-ump/VieNeu-TTS",
backbone_device="cpu",
codec_repo="neuphonic/neucodec",
codec_device="cpu",
):
# Constants
self.sample_rate = 24_000
self.max_context = 2048
self.hop_length = 480
self.streaming_overlap_frames = 1
self.streaming_frames_per_chunk = 25
self.streaming_lookforward = 5
self.streaming_lookback = 50
self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
# ggml & onnx flags
self._is_quantized_model = False
self._is_onnx_codec = False
# HF tokenizer
self.tokenizer = None
# Load models
self._load_backbone(backbone_repo, backbone_device)
self._load_codec(codec_repo, codec_device)
def _load_backbone(self, backbone_repo, backbone_device):
print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
try:
from llama_cpp import Llama
except ImportError as e:
raise ImportError(
"Failed to import `llama_cpp`. "
"Please install it with:\n"
" pip install llama-cpp-python"
) from e
self.backbone = Llama.from_pretrained(
repo_id=backbone_repo,
filename="*.gguf",
verbose=False,
n_gpu_layers=-1 if backbone_device == "gpu" else 0,
n_ctx=self.max_context,
mlock=True,
flash_attn=True if backbone_device == "gpu" else False,
)
self._is_quantized_model = True
else:
self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
torch.device(backbone_device)
)
def _load_codec(self, codec_repo, codec_device):
print(f"Loading codec from: {codec_repo} on {codec_device} ...")
match codec_repo:
case "neuphonic/neucodec":
self.codec = NeuCodec.from_pretrained(codec_repo)
self.codec.eval().to(codec_device)
case "neuphonic/distill-neucodec":
self.codec = DistillNeuCodec.from_pretrained(codec_repo)
self.codec.eval().to(codec_device)
case "neuphonic/neucodec-onnx-decoder":
if codec_device != "cpu":
raise ValueError("Onnx decoder only currently runs on CPU.")
try:
from neucodec import NeuCodecOnnxDecoder
except ImportError as e:
raise ImportError(
"Failed to import the onnx decoder."
" Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
) from e
self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
self._is_onnx_codec = True
case _:
raise ValueError(f"Unsupported codec repository: {codec_repo}")
def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
"""
Perform inference to generate speech from text using the TTS model and reference audio.
Args:
text (str): Input text to be converted to speech.
ref_codes (np.ndarray | torch.tensor): Encoded reference.
ref_text (str): Reference text for reference audio. Defaults to None.
Returns:
np.ndarray: Generated speech waveform.
"""
# Generate tokens
if self._is_quantized_model:
output_str = self._infer_ggml(ref_codes, ref_text, text)
else:
prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
output_str = self._infer_torch(prompt_ids)
# Decode
wav = self._decode(output_str)
return wav
def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
"""
Perform streaming inference to generate speech from text using the TTS model and reference audio.
Args:
text (str): Input text to be converted to speech.
ref_codes (np.ndarray | torch.tensor): Encoded reference.
ref_text (str): Reference text for reference audio. Defaults to None.
Yields:
np.ndarray: Generated speech waveform.
"""
if self._is_quantized_model:
return self._infer_stream_ggml(ref_codes, ref_text, text)
else:
raise NotImplementedError("Streaming is not implemented for the torch backend!")
def encode_reference(self, ref_audio_path: str | Path):
wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0) # [1, 1, T]
with torch.no_grad():
ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
return ref_codes
def _decode(self, codes: str):
"""Decode speech tokens to audio waveform."""
# Extract speech token IDs using regex
speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
if len(speech_ids) == 0:
raise ValueError(
"No valid speech tokens found in the output. "
"The model may not have generated proper speech tokens."
)
# Onnx decode
if self._is_onnx_codec:
codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
recon = self.codec.decode_code(codes)
# Torch decode
else:
with torch.no_grad():
codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
self.codec.device
)
recon = self.codec.decode_code(codes).cpu().numpy()
return recon[0, 0, :]
def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
ids = self.tokenizer.encode(chat)
text_replace_idx = ids.index(text_replace)
ids = (
ids[:text_replace_idx]
+ [text_prompt_start]
+ input_ids
+ [text_prompt_end]
+ ids[text_replace_idx + 1 :] # noqa
)
speech_replace_idx = ids.index(speech_replace)
codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
return ids
def _infer_torch(self, prompt_ids: list[int]) -> str:
prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
with torch.no_grad():
output_tokens = self.backbone.generate(
prompt_tensor,
max_length=self.max_context,
eos_token_id=speech_end_id,
do_sample=True,
temperature=1.0,
top_k=50,
use_cache=True,
min_new_tokens=50,
)
input_length = prompt_tensor.shape[-1]
output_str = self.tokenizer.decode(
output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
)
return output_str
def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
ref_text = phonemize_with_dict(ref_text)
input_text = phonemize_with_dict(input_text)
codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
prompt = (
f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
)
output = self.backbone(
prompt,
max_tokens=self.max_context,
temperature=1.0,
top_k=50,
stop=["<|SPEECH_GENERATION_END|>"],
)
output_str = output["choices"][0]["text"]
return output_str
def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
ref_text = phonemize_with_dict(ref_text)
input_text = phonemize_with_dict(input_text)
codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
prompt = (
f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
)
audio_cache: list[np.ndarray] = []
token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
n_decoded_samples: int = 0
n_decoded_tokens: int = len(ref_codes)
for item in self.backbone(
prompt,
max_tokens=self.max_context,
temperature=0.2,
top_k=50,
stop=["<|SPEECH_GENERATION_END|>"],
stream=True
):
output_str = item["choices"][0]["text"]
token_cache.append(output_str)
if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
# decode chunk
tokens_start = max(
n_decoded_tokens
- self.streaming_lookback
- self.streaming_overlap_frames,
0
)
tokens_end = (
n_decoded_tokens
+ self.streaming_frames_per_chunk
+ self.streaming_lookforward
+ self.streaming_overlap_frames
)
sample_start = (
n_decoded_tokens - tokens_start
) * self.hop_length
sample_end = (
sample_start
+ (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
)
curr_codes = token_cache[tokens_start:tokens_end]
recon = self._decode("".join(curr_codes))
recon = recon[sample_start:sample_end]
audio_cache.append(recon)
# postprocess
processed_recon = _linear_overlap_add(
audio_cache, stride=self.streaming_stride_samples
)
new_samples_end = len(audio_cache) * self.streaming_stride_samples
processed_recon = processed_recon[
n_decoded_samples:new_samples_end
]
n_decoded_samples = new_samples_end
n_decoded_tokens += self.streaming_frames_per_chunk
yield processed_recon
# final decoding handled separately as non-constant chunk size
remaining_tokens = len(token_cache) - n_decoded_tokens
if len(token_cache) > n_decoded_tokens:
tokens_start = max(
len(token_cache)
- (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
0
)
sample_start = (
len(token_cache)
- tokens_start
- remaining_tokens
- self.streaming_overlap_frames
) * self.hop_length
curr_codes = token_cache[tokens_start:]
recon = self._decode("".join(curr_codes))
recon = recon[sample_start:]
audio_cache.append(recon)
processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
processed_recon = processed_recon[n_decoded_samples:]
yield processed_recon