# wuwa-bert-vits2 / app.py
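# Gradio demo that serves Bert-VITS2 text-to-speech models downloaded from the
# Hugging Face Hub.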
import argparse
import json
import logging
import os
import sys

import gradio as gr
import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download

import commons
import re_matching
import utils
from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from models import SynthesizerTrn
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
from text.symbols import symbols
from tools.sentence import split_by_language

logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)
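
# Build the VITS synthesizer for a given checkpoint and move it to the target
# device. The checkpoint is loaded with the optimizer state skipped, since
# inference never needs it. (The `version` argument is accepted for interface
# parity but is not consulted in this function body.)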
def get_net_g(model_path: str, version: str, device: str, hps):
net_g = SynthesizerTrn(
len(symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
_ = net_g.eval()
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
return net_g
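
# Convert raw text into the tensors the model consumes: phoneme, tone, and
# language-id sequences plus a per-phoneme BERT feature matrix. When
# hps.data.add_blank is set, a blank token (0) is interspersed between symbols,
# so every aligned sequence must be stretched to match.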
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
style_text = None if style_text == "" else style_text
norm_text, phone, tone, word2ph = clean_text(text, language_str)
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
if hps.data.add_blank:
phone = commons.intersperse(phone, 0)
tone = commons.intersperse(tone, 0)
language = commons.intersperse(language, 0)
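        # Interspersing blanks doubles the sequence length, so each word's
        # phone count doubles too; the extra leading blank is charged to the
        # first word so that word2ph still sums to len(phone).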
for i in range(len(word2ph)):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1
bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
del word2ph
assert bert.shape[-1] == len(phone)
phone = torch.LongTensor(phone)
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
return bert, phone, tone, language
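
# Thin wrapper so the heavyweight `infer` module (and everything it pulls in)
# is only imported on the first synthesis request rather than at startup.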
def infer(*args, **kwargs):
from infer import infer as real_infer
return real_infer(*args, **kwargs)
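
# Read an audio file from disk, resampled to 48 kHz, and return it as the
# (sample_rate, waveform) pair that Gradio's Audio component expects.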
def load_audio(path):
    # librosa >= 0.10 takes sr as a keyword-only argument.
    audio, sr = librosa.load(path, sr=48000)
    return sr, audio
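
# Toggle visibility between the text-prompt and audio-prompt inputs, returning
# one visibility update per output component.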
def gr_util(item):
    if item == "Text prompt":
        return gr.update(visible=True), gr.update(visible=False)
    return gr.update(visible=False), gr.update(visible=True)
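
# Build a per-model TTS callback. The closure captures this model's hparams,
# network, and device so every tab in the UI can share one function signature.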
def create_tts_fn(hps, net_g, device):
def tts_fn(
text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language,
reference_audio, emotion, prompt_mode, style_text, style_weight
):
if style_text == "":
style_text = None
if prompt_mode == "Audio prompt":
if reference_audio is None:
return ("Invalid audio prompt", None)
else:
reference_audio = load_audio(reference_audio)[1]
else:
reference_audio = None
audio = infer(
text=text,
reference_audio=reference_audio,
emotion=emotion,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
sid=speaker,
language=language,
hps=hps,
net_g=net_g,
device=device,
style_text=style_text,
style_weight=style_weight,
)
return "Success", (hps.data.sampling_rate, audio)
return tts_fn
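
# Entry point: parse CLI flags, load the model registry, download each enabled
# model from the Hugging Face Hub, and serve the Gradio UI.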
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--share", default=False, help="make link public", action="store_true")
parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
args = parser.parse_args()
if args.debug:
logger.setLevel(logging.DEBUG)
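
    # Each entry in info.json is expected to provide: name (checkpoint file
    # stem), title (tab label), link (Hugging Face repo id), example (demo
    # text), and an enable flag.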
with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
models_info = json.load(f)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
models = []
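
    # Download config + weights for each enabled model and build its TTS
    # callback; hf_hub_download caches files locally, so relaunches are cheap.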
for _, info in models_info.items():
        if not info["enable"]:
            continue
        name, title, link, example = info["name"], info["title"], info["link"], info["example"]
config_path = hf_hub_download(repo_id=link, filename="config.json")
model_path = hf_hub_download(repo_id=link, filename=f"{name}.pth")
hps = utils.get_hparams_from_file(config_path)
version = hps.version if hasattr(hps, "version") else "v2"
net_g = get_net_g(model_path, version, device, hps)
fn = create_tts_fn(hps, net_g, device)
models.append((title, example, list(hps.data.spk2id.keys()), fn))
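
    # One tab per model; each tab wires its own controls to that model's tts_fn.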
    with gr.Blocks(theme="NoCrypt/miku") as app:
gr.Markdown("## ✅ All models loaded successfully. Ready to use.")
with gr.Tabs():
for (title, example, speakers, tts_fn) in models:
with gr.TabItem(title):
with gr.Row():
with gr.Column():
input_text = gr.Textbox(label="Input text", lines=5, value=example)
speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
prompt_mode = gr.Radio(["Text prompt", "Audio prompt"], label="Prompt Mode", value="Text prompt")
text_prompt = gr.Textbox(label="Text prompt", value="Happy", visible=True)
audio_prompt = gr.Audio(label="Audio prompt", type="filepath", visible=False)
sdp_ratio = gr.Slider(0, 1, 0.2, 0.1, label="SDP Ratio")
noise_scale = gr.Slider(0.1, 2.0, 0.6, 0.1, label="Noise")
noise_scale_w = gr.Slider(0.1, 2.0, 0.8, 0.1, label="Noise_W")
length_scale = gr.Slider(0.1, 2.0, 1.0, 0.1, label="Length")
language = gr.Dropdown(choices=["JP", "ZH", "EN", "mix", "auto"], value="JP", label="Language")
                        style_text = gr.Textbox(label="Style Text", placeholder="Auxiliary style text (leave blank for none)")
style_weight = gr.Slider(0, 1, 0.7, 0.1, label="Style Weight")
btn = gr.Button("Generate Audio", variant="primary")
with gr.Column():
output_msg = gr.Textbox(label="Output Message")
output_audio = gr.Audio(label="Output Audio")
                prompt_mode.change(gr_util, inputs=[prompt_mode], outputs=[text_prompt, audio_prompt])
                audio_prompt.upload(load_audio, inputs=[audio_prompt], outputs=[audio_prompt])
                btn.click(
                    tts_fn,
                    inputs=[
                        input_text, speaker, sdp_ratio, noise_scale, noise_scale_w,
                        length_scale, language, audio_prompt, text_prompt,
                        prompt_mode, style_text, style_weight,
                    ],
                    outputs=[output_msg, output_audio],
                )
app.queue().launch(share=args.share)