radtts-uk-vocos / app.py
Yehor's picture
Add huggingface_hub
8b0e460
raw
history blame
15.9 kB
import os
import sys
import json
import time
from os.path import getsize
from pathlib import Path
from importlib.metadata import version, PackageNotFoundError
try:
import spaces # it's for ZeroGPU
except ImportError:
...
import torch
import torchaudio
import gradio as gr
from huggingface_hub import hf_hub_download
# RAD-TTS code
from radtts import RADTTS
from data import TextProcessor
from common import update_params
from torch_env import device
# Vocoder
from vocos import Vocos
use_zerogpu = False
try:
spaces_version = version("spaces")
use_zerogpu = True
print("ZeroGPU is available, changing inference call.")
except PackageNotFoundError:
spaces_version = "N/A"
print("ZeroGPU is not available, skipping...")
def download_file_from_repo(
repo_id: str,
filename: str,
local_dir: str = ".",
repo_type: str = "model",
) -> str:
try:
os.makedirs(local_dir, exist_ok=True)
file_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=local_dir,
cache_dir=None,
force_download=False,
repo_type=repo_type,
)
return file_path
except Exception as e:
raise Exception(f"An error occurred during download: {e}") from e
download_file_from_repo(
"Yehor/radtts-uk",
"radtts-pp-dap-model/model_dap_84000_state.pt",
"./models/",
)
# Init the model
params = []
# Load the config
config = json.loads(Path("config.json").read_text())
update_params(config, params)
data_config = config["data_config"]
model_config = config["model_config"]
# Load vocoder
vocos_config = hf_hub_download(
"patriotyk/vocos-mel-hifigan-compat-44100khz", "config.yaml"
)
vocos_model = hf_hub_download(
"patriotyk/vocos-mel-hifigan-compat-44100khz", "pytorch_model.bin"
)
vocos_model_path = Path(vocos_model)
state_dict = torch.load(vocos_model_path, map_location="cpu")
vocos = Vocos.from_hparams(vocos_config).to(device)
vocos.load_state_dict(state_dict, strict=True)
vocos.eval()
# Load RAD-TTS
radtts = RADTTS(**model_config).to(device)
radtts.enable_inverse_cache() # cache inverse matrix for 1x1 invertible convs
radtts_model_path = Path("models/radtts-pp-dap-model/model_dap_84000_state.pt")
checkpoint_dict = torch.load(radtts_model_path, map_location="cpu")
state_dict = checkpoint_dict["state_dict"]
radtts.load_state_dict(state_dict, strict=False)
radtts.eval()
radtts_params = f"{sum(param.numel() for param in radtts.parameters()):,}"
vocos_params = f"{sum(param.numel() for param in vocos.parameters()):,}"
print(f"Loaded checkpoint (RAD-TTS++), number of parameters: {radtts_params}")
print(f"Loaded checkpoint (Vocos), number of parameters: {vocos_params}")
text_processor = TextProcessor(
data_config["training_files"],
**dict(
(k, v)
for k, v in data_config.items()
if k not in ["training_files", "validation_files"]
),
)
# Config
concurrency_limit = 5
title = "RAD-TTS++ Ukrainian"
# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors
Follow them on social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()
description_head = f"""
# {title}
## Overview
Type your text in Ukrainian and select a voice to synthesize speech using [the RAD-TTS++ model](https://huggingface.co/Yehor/radtts-uk) and [Vocos](https://huggingface.co/patriotyk/vocos-mel-hifigan-compat-44100khz) with 44100 Hz.
""".strip()
description_foot = f"""
{authors_table}
""".strip()
tech_env = f"""
#### Environment
- Python: {sys.version}
- Torch device: {device}
#### Models
##### Acoustic model
- Name: RAD-TTS++ (DAP)
- Parameters: {radtts_params}
- File size: {getsize(radtts_model_path) / 1e6:.2f} MB
##### Vocoder
- Name: Vocos
- Parameters: {vocos_params}
- File size: {getsize(vocos_model_path) / 1e6:.2f} MB
""".strip()
tech_libraries = f"""
#### Libraries
- gradio: {version("gradio")}
- huggingface_hub: {version("huggingface_hub")}
- spaces: {spaces_version}
- torch: {version("torch")}
- torchaudio: {version("torchaudio")}
- scipy: {version("scipy")}
- numba: {version("numba")}
- librosa: {version("librosa")}
""".strip()
voices = {
"lada": 0,
"mykyta": 1,
"tetiana": 2,
}
examples = [
[
"Прокинувся ґазда вранці. Пішов, вичистив з-під коня, вичистив з-під бика, вичистив з-під овечок, вибрав молодняк, відніс його набік.",
"Mykyta",
],
[
"Пішов взяв сіна, дав корові. Пішов взяв сіна, дав бикові. Ячміню коняці насипав. Зайшов почистив корову, зайшов почистив бика, зайшов почистив коня, за яйця його мацнув.",
"Lada",
],
[
"Кінь ногою здригнув, на хазяїна ласкавим оком подивився. Тоді дядько пішов відкрив курей, гусей, качок, повиносив їм зерна, огірків нарізаних, нагодував. Коли чує – з хати дружина кличе. Зайшов. Дітки повмивані, сидять за столом, всі чекають тата. Взяв він ложку, перехрестив дітей, перехрестив лоба, почали снідати. Поснідали, він дістав пряників, роздав дітям. Діти зібралися, пішли в школу. Дядько вийшов, сів на призьбі, взяв сапку, почав мантачити. Мантачив-мантачив, коли – жінка виходить. Він їй ту сапку дає, ласкаво за сраку вщипнув, жінка до нього лагідно всміхнулася, пішла на город – сапати. Коли – йде пастух і товар кличе в череду. Повідмикав дядько овечок, коровку, бика, коня, все відпустив. Сів попри хати, дістав табАку, відірвав шмат газети, насипав, наслинив собі гарну таку цигарку. Благодать божа – і сонечко вже здійнялося над деревами. Дядько встромив цигарку в рота, дістав сірники, тільки чиркати – коли раптом з хати: Доброе утро! Московское время – шесть часов утра! Витяг дядько цигарку с рота, сплюнув набік, і сам собі каже: Ана маєш. Прокинулись, бляді!",
"Tetiana",
],
]
def inference(
text,
voice,
n_takes,
use_latest_take,
token_dur_scaling,
f0_mean,
f0_std,
energy_mean,
energy_std,
sigma_decoder,
sigma_token_duration,
sigma_f0,
sigma_energy,
):
if not text:
raise gr.Error("Please paste your text.")
request = {
"text": text,
"voice": voice,
"n_takes": n_takes,
"use_latest_take": use_latest_take,
"token_dur_scaling": token_dur_scaling,
"f0_mean": f0_mean,
"f0_std": f0_std,
"energy_mean": energy_mean,
"energy_std": energy_std,
"sigma_decoder": sigma_decoder,
"sigma_token_duration": sigma_token_duration,
"sigma_f0": sigma_f0,
"sigma_energy": sigma_energy,
}
print(json.dumps(request, indent=2))
speaker = speaker_text = speaker_attributes = voice.lower()
tensor_text = torch.LongTensor(text_processor.tp.encode_text(text)).to(device)
speaker_tensor = torch.LongTensor([voices[speaker]]).to(device)
speaker_id = speaker_id_text = speaker_id_attributes = speaker_tensor
if speaker_text is not None:
speaker_id_text = torch.LongTensor([voices[speaker_text]]).to(device)
if speaker_attributes is not None:
speaker_id_attributes = torch.LongTensor([voices[speaker_attributes]]).to(
device
)
inference_start = time.time()
mels = []
for n_take in range(n_takes):
gr.Info(f"Inferencing take {n_take + 1}", duration=1)
with torch.autocast(device, enabled=False):
with torch.inference_mode():
outputs = radtts.infer(
speaker_id,
tensor_text[None],
sigma_decoder,
sigma_token_duration,
sigma_f0,
sigma_energy,
token_dur_scaling,
token_duration_max=100,
speaker_id_text=speaker_id_text,
speaker_id_attributes=speaker_id_attributes,
f0_mean=f0_mean,
f0_std=f0_std,
energy_mean=energy_mean,
energy_std=energy_std,
)
mels.append(outputs["mel"])
gr.Info("Synthesized MEL spectrograms, converting to WAVE.", duration=0.5)
wav_gen_all = []
for mel in mels:
wav_gen_all.append(vocos.decode(mel))
if use_latest_take:
wav_gen = wav_gen_all[-1] # Get the latest generated wav
else:
wav_gen = torch.cat(wav_gen_all, dim=1) # Concatenate all the generated wavs
duration = len(wav_gen[0]) / 44_100
torchaudio.save("audio.wav", wav_gen.cpu(), 44_100, encoding="PCM_S")
elapsed_time = time.time() - inference_start
rtf = elapsed_time / duration
speed_ratio = duration / elapsed_time
speech_rate = len(text.split(" ")) / duration
rtf_value = f"Real-Time Factor: {round(rtf, 4)}, time: {round(elapsed_time, 4)} seconds, audio duration: {round(duration, 4)} seconds. Speed ratio: {round(speed_ratio, 2)}x. Speech rate: {round(speech_rate, 4)} words-per-second."
gr.Success("Finished!", duration=0.5)
return [gr.Audio("audio.wav"), rtf_value]
try:
@spaces.GPU
def inference_zerogpu(*args):
return inference(*args)
except NameError:
def inference_cpu(*args):
return inference(*args)
demo = gr.Blocks(
title=title,
analytics_enabled=False,
theme=gr.themes.Base(),
)
with demo:
gr.Markdown(description_head)
gr.Markdown("## Usage")
with gr.Row():
with gr.Column():
audio = gr.Audio(label="Synthesized audio")
rtf = gr.Markdown(
label="Real-Time Factor",
value="Here you will see how fast the model and the speaker is.",
)
with gr.Row():
with gr.Column():
text = gr.Text(
label="Text",
value="Сл+ава Укра+їні! — українське вітання, національне гасло.",
)
voice = gr.Radio(
label="Voice",
choices=[voice.title() for voice in voices.keys()],
value="Tetiana",
)
with gr.Accordion("Advanced options", open=False):
gr.Markdown("You can change the voice, speed, and other parameters.")
with gr.Column():
n_takes = gr.Number(
label="Number of takes",
value=1,
minimum=1,
maximum=10,
step=1,
)
use_latest_take = gr.Checkbox(
label="Use the latest take",
value=False,
)
token_dur_scaling = gr.Number(
label="Token duration scaling",
value=1.0,
minimum=0.0,
maximum=10,
step=0.1,
)
with gr.Row():
f0_mean = gr.Number(
label="F0 mean",
value=0,
minimum=0.0,
maximum=1.0,
step=0.1,
)
f0_std = gr.Number(
label="F0 std",
value=0,
minimum=0.0,
maximum=1.0,
step=0.1,
)
energy_mean = gr.Number(
label="Energy mean",
value=0,
minimum=0.0,
maximum=1.0,
step=0.1,
)
energy_std = gr.Number(
label="Energy std",
value=0,
minimum=0.0,
maximum=1.0,
step=0.1,
)
with gr.Row():
sigma_decoder = gr.Number(
label="Sampling sigma for decoder",
value=0.8,
minimum=0.0,
maximum=1.0,
step=0.1,
)
sigma_token_duration = gr.Number(
label="Sampling sigma for duration",
value=0.666,
minimum=0.0,
maximum=1.0,
step=0.1,
)
sigma_f0 = gr.Number(
label="Sampling sigma for F0",
value=1.0,
minimum=0.0,
maximum=1.0,
step=0.1,
)
sigma_energy = gr.Number(
label="Sampling sigma for energy avg",
value=1.0,
minimum=0.0,
maximum=1.0,
step=0.1,
)
gr.Button("Run").click(
inference_zerogpu if use_zerogpu else inference_cpu,
concurrency_limit=concurrency_limit,
inputs=[
text,
voice,
n_takes,
use_latest_take,
token_dur_scaling,
f0_mean,
f0_std,
energy_mean,
energy_std,
sigma_decoder,
sigma_token_duration,
sigma_f0,
sigma_energy,
],
outputs=[audio, rtf],
)
with gr.Row():
gr.Examples(
label="Choose an example",
inputs=[
text,
voice,
],
examples=examples,
)
gr.Markdown(description_foot)
gr.Markdown("### Gradio app uses:")
gr.Markdown(tech_env)
gr.Markdown(tech_libraries)
if __name__ == "__main__":
demo.queue()
demo.launch()