# FireRedTTS2 / app.py
# Shen Feiyu
# init at 250916
# 71cd91e
import re
import spaces
import gradio as gr
from tqdm import tqdm
from huggingface_hub import snapshot_download
from argparse import ArgumentParser
from typing import Literal, List, Tuple
from fireredtts2.fireredtts2 import FireRedTTS2
# ================================================
# FireRedTTS2 Model
# ================================================
# Process-wide model instance; stays None until initiate_model() is called.
model: FireRedTTS2 = None


def initiate_model(pretrained_dir: str, device="cuda"):
    """Create the dialogue-mode FireRedTTS2 model and store it in ``model``.

    Only the first call builds the model; any later call returns immediately,
    even if it passes a different ``pretrained_dir`` or ``device``.

    Args:
        pretrained_dir: Local directory holding the model snapshot (the
            ``__main__`` block downloads it there before calling this).
        device: Device string forwarded to ``FireRedTTS2`` (default "cuda").
    """
    global model
    if model is not None:
        # Already loaded: keep the existing instance.
        return
    model = FireRedTTS2(
        pretrained_dir=pretrained_dir,
        gen_type="dialogue",
        device=device,
    )
# ================================================
# Gradio
# ================================================
# i18n: UI string table, indexed first by message id, then by language code.
_i18n_key2lang_dict = {
    # Page title (rendered as a markdown heading)
    "title_md_desc": {
        "en": "FireRedTTS-2 🔥 Dialogue Generation",
        "zh": "FireRedTTS-2 🔥 对话生成",
    },
    # Voice-mode radio: label and its two choices (index 0 / index 1)
    "voice_mode_label": {
        "en": "Voice Mode",
        "zh": "音色模式",
    },
    "voice_model_choice1": {
        "en": "Voice Clone",
        "zh": "音色克隆",
    },
    "voice_model_choice2": {
        "en": "Random Voice",
        "zh": "随机音色",
    },
    # Speaker 1 prompt inputs
    "spk1_prompt_audio_label": {
        "en": "Speaker 1 Prompt Audio",
        "zh": "说话人 1 参考语音",
    },
    "spk1_prompt_text_label": {
        "en": "Speaker 1 Prompt Text",
        "zh": "说话人 1 参考文本",
    },
    "spk1_prompt_text_placeholder": {
        "en": "[S1] text of speaker 1 prompt audio.",
        "zh": "[S1] 说话人 1 参考文本",
    },
    # Speaker 2 prompt inputs
    "spk2_prompt_audio_label": {
        "en": "Speaker 2 Prompt Audio",
        "zh": "说话人 2 参考语音",
    },
    "spk2_prompt_text_label": {
        "en": "Speaker 2 Prompt Text",
        "zh": "说话人 2 参考文本",
    },
    "spk2_prompt_text_placeholder": {
        "en": "[S2] text of speaker 2 prompt audio.",
        "zh": "[S2] 说话人 2 参考文本",
    },
    # Dialogue script textbox
    "dialogue_text_input_label": {
        "en": "Dialogue Text Input",
        "zh": "对话文本输入",
    },
    "dialogue_text_input_placeholder": {
        "en": "[S1]text[S2]text[S1]text...",
        "zh": "[S1]文本[S2]文本[S1]文本...",
    },
    # Generate button
    "generate_btn_label": {
        "en": "Generate Audio",
        "zh": "合成",
    },
    # Output audio player
    "generated_audio_label": {
        "en": "Generated Dialogue Audio",
        "zh": "合成的对话音频",
    },
    # Warning: a speaker prompt text is missing its required tag or content
    "warn_invalid_spk1_prompt_text": {
        "en": 'Invalid speaker 1 prompt text, should strictly follow: "[S1]xxx"',
        "zh": '说话人 1 参考文本不合规,格式:"[S1]xxx"',
    },
    "warn_invalid_spk2_prompt_text": {
        "en": 'Invalid speaker 2 prompt text, should strictly follow: "[S2]xxx"',
        "zh": '说话人 2 参考文本不合规,格式:"[S2]xxx"',
    },
    # Warning: the dialogue script is not a sequence of "[Sx]..." turns
    "warn_invalid_dialogue_text": {
        "en": 'Invalid dialogue input text, should strictly follow: "[S1]xxx[S2]xxx..."',
        "zh": '对话文本输入不合规,格式:"[S1]xxx[S2]xxx..."',
    },
    # Warning: voice-clone mode without audio + text for both speakers
    "warn_incomplete_prompt": {
        "en": "Please provide prompt audio and text for both speaker 1 and speaker 2",
        "zh": "请提供说话人 1 与说话人 2 的参考语音与参考文本",
    },
}
# Current display language for every translated string. It is a module-level
# value, reassigned by the language radio's handler in render_interface().
global_lang: Literal["zh", "en"] = "zh"


def i18n(key):
    """Return the UI string for ``key`` in the current ``global_lang``.

    Raises:
        KeyError: if ``key`` is not in ``_i18n_key2lang_dict``, or the entry
            has no text for the current language.
    """
    # Read-only access to the module globals, so no ``global`` statement is needed.
    translations = _i18n_key2lang_dict[key]
    return translations[global_lang]
def check_monologue_text(text: str, prefix: str = None) -> bool:
text = text.strip()
# Check speaker tags
if prefix is not None and (not text.startswith(prefix)):
return False
# Remove prefix
if prefix is not None:
text = text.removeprefix(prefix)
text = text.strip()
# If empty?
if len(text) == 0:
return False
return True
def check_dialogue_text(text_list: List[str]) -> bool:
    """Return True if ``text_list`` is a non-empty list of valid speaker turns.

    Every entry must pass ``check_monologue_text`` with one of the speaker
    tags ``[S1]`` .. ``[S4]`` (tried in that order).
    """
    speaker_tags = ("[S1]", "[S2]", "[S3]", "[S4]")
    if not text_list:
        return False
    return all(
        any(check_monologue_text(turn, tag) for tag in speaker_tags)
        for turn in text_list
    )
@spaces.GPU(duration=200)
def dialogue_synthesis_function(
    target_text: str,
    voice_mode: Literal[0, 1] = 0,  # 0 means voice clone
    spk1_prompt_text: str | None = "",
    spk1_prompt_audio: str | None = None,
    spk2_prompt_text: str | None = "",
    spk2_prompt_audio: str | None = None,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Validate the UI inputs and synthesize the dialogue audio.

    Bound to the generate button. On any validation failure a localized
    ``gr.Warning`` is shown and ``None`` is returned, so no audio is produced.

    Args:
        target_text: Dialogue script of the form ``"[S1]...[S2]..."``.
        voice_mode: 0 for voice clone (prompts required), 1 for random voice
            (prompts ignored).
        spk1_prompt_text: Transcript of speaker 1's prompt; in voice-clone
            mode it must start with ``"[S1]"``.
        spk1_prompt_audio: File path of speaker 1's prompt audio.
        spk2_prompt_text: Transcript of speaker 2's prompt; in voice-clone
            mode it must start with ``"[S2]"``.
        spk2_prompt_audio: File path of speaker 2's prompt audio.
        progress: Injected by Gradio, never passed by callers. Gradio only
            injects a progress tracker when it is declared as a defaulted
            parameter; ``track_tqdm=True`` mirrors tqdm bars onto the UI.

    Returns:
        ``(24000, samples)`` for ``gr.Audio`` on success, otherwise ``None``.
    """
    # Voice clone mode: both speakers need prompt audio AND non-blank prompt text.
    if voice_mode == 0:
        prompt_has_value = [
            bool(spk1_prompt_text and spk1_prompt_text.strip()),
            spk1_prompt_audio is not None,
            bool(spk2_prompt_text and spk2_prompt_text.strip()),
            spk2_prompt_audio is not None,
        ]
        if not all(prompt_has_value):
            gr.Warning(message=i18n("warn_incomplete_prompt"))
            return None
        if not check_monologue_text(spk1_prompt_text, "[S1]"):
            gr.Warning(message=i18n("warn_invalid_spk1_prompt_text"))
            return None
        if not check_monologue_text(spk2_prompt_text, "[S2]"):
            gr.Warning(message=i18n("warn_invalid_spk2_prompt_text"))
            return None

    # Split the script into "[Sx]utterance" segments.
    target_text = target_text or ""
    segment_pattern = r"(\[S[0-9]\][^\[\]]*)"
    target_text_list: List[str] = [
        segment.strip() for segment in re.findall(segment_pattern, target_text)
    ]
    # findall skips anything outside a "[Sx]..." segment, such as text before
    # the first tag or text after a stray "[" / "]". Reject that input instead
    # of silently synthesizing only part of the script.
    unmatched_text = re.sub(segment_pattern, "", target_text).strip()
    if unmatched_text or not check_dialogue_text(target_text_list):
        gr.Warning(message=i18n("warn_invalid_dialogue_text"))
        return None

    # Go synthesis; prompts are only passed in voice clone mode.
    if voice_mode == 0:
        prompt_wav_list = [spk1_prompt_audio, spk2_prompt_audio]
        prompt_text_list = [spk1_prompt_text, spk2_prompt_text]
    else:
        prompt_wav_list = None
        prompt_text_list = None
    target_audio = model.generate_dialogue(
        text_list=target_text_list,
        prompt_wav_list=prompt_wav_list,
        prompt_text_list=prompt_text_list,
        temperature=0.9,
        topk=30,
    )
    # NOTE(review): assumes target_audio is a CPU torch tensor shaped
    # (1, num_samples) at a 24 kHz sample rate -- confirm against FireRedTTS2.
    return (24000, target_audio.squeeze(0).numpy())
# UI rendering
def render_interface() -> gr.Blocks:
    """Build the FireRedTTS-2 dialogue demo page and wire its events.

    Layout:
        - a title;
        - a row with the display-language and voice-mode radios;
        - a row with both speakers' prompt inputs and the dialogue text box;
        - the generate button and the output audio player.

    Every label is looked up through ``i18n``, so it follows ``global_lang``.

    Returns:
        The configured ``gr.Blocks`` page (not yet queued or launched).
    """
    with gr.Blocks(title="FireRedTTS-2", theme=gr.themes.Default()) as page:
        # ======================== UI ========================
        # A large title
        title_desc = gr.Markdown(value="# {}".format(i18n("title_md_desc")))
        with gr.Row():
            # type="index": handlers receive 0 (Chinese) or 1 (English).
            lang_choice = gr.Radio(
                choices=["中文", "English"],
                value="中文",
                label="Display Language/显示语言",
                type="index",
                interactive=True,
            )
            # type="index": handlers receive 0 (voice clone) or 1 (random voice).
            voice_mode_choice = gr.Radio(
                choices=[i18n("voice_model_choice1"), i18n("voice_model_choice2")],
                value=i18n("voice_model_choice1"),
                label=i18n("voice_mode_label"),
                type="index",
                interactive=True,
            )
        with gr.Row():
            # ==== Speaker1 Prompt ====
            with gr.Column(scale=1):
                # Grouped so the prompt inputs can be hidden as one unit.
                with gr.Group(visible=True) as spk1_prompt_group:
                    spk1_prompt_audio = gr.Audio(
                        label=i18n("spk1_prompt_audio_label"),
                        type="filepath",
                        editable=False,
                        interactive=True,
                    )  # Audio component returns tmp audio path
                    spk1_prompt_text = gr.Textbox(
                        label=i18n("spk1_prompt_text_label"),
                        placeholder=i18n("spk1_prompt_text_placeholder"),
                        lines=3,
                    )
            # ==== Speaker2 Prompt ====
            with gr.Column(scale=1):
                with gr.Group(visible=True) as spk2_prompt_group:
                    spk2_prompt_audio = gr.Audio(
                        label=i18n("spk2_prompt_audio_label"),
                        type="filepath",
                        editable=False,
                        interactive=True,
                    )
                    spk2_prompt_text = gr.Textbox(
                        label=i18n("spk2_prompt_text_label"),
                        placeholder=i18n("spk2_prompt_text_placeholder"),
                        lines=3,
                    )
            # ==== Text input ====
            with gr.Column(scale=2):
                dialogue_text_input = gr.Textbox(
                    label=i18n("dialogue_text_input_label"),
                    placeholder=i18n("dialogue_text_input_placeholder"),
                    lines=18,
                )
        # Generate button
        generate_btn = gr.Button(
            value=i18n("generate_btn_label"), variant="primary", size="lg"
        )
        # Long output audio
        generate_audio = gr.Audio(
            label=i18n("generated_audio_label"),
            interactive=False,
        )

        # ======================== Action ========================
        # Language action
        def _change_component_language(lang, voice_mode):
            """Switch ``global_lang`` and relabel every translated component.

            The returned updates line up one-to-one with ``outputs`` in the
            ``lang_choice.change`` call below.
            """
            # NOTE(review): global_lang is a module-level variable, so this
            # switches the language for every session served by the process,
            # not only the caller's. Per-session language would need gr.State.
            global global_lang
            global_lang = ["zh", "en"][lang]
            voice_mode_labels = [
                i18n("voice_model_choice1"),
                i18n("voice_model_choice2"),
            ]
            # Keep the user's current voice mode instead of resetting it to
            # voice clone. If nothing is selected, fall back to voice clone.
            selected_mode = voice_mode if voice_mode in (0, 1) else 0
            return [
                # title_desc
                gr.update(value="# {}".format(i18n("title_md_desc"))),
                # voice_mode_choice
                gr.update(
                    choices=voice_mode_labels,
                    value=voice_mode_labels[selected_mode],
                    label=i18n("voice_mode_label"),
                ),
                # spk1_prompt_{audio,text}
                gr.update(label=i18n("spk1_prompt_audio_label")),
                gr.update(
                    label=i18n("spk1_prompt_text_label"),
                    placeholder=i18n("spk1_prompt_text_placeholder"),
                ),
                # spk2_prompt_{audio,text}
                gr.update(label=i18n("spk2_prompt_audio_label")),
                gr.update(
                    label=i18n("spk2_prompt_text_label"),
                    placeholder=i18n("spk2_prompt_text_placeholder"),
                ),
                # dialogue_text_input
                gr.update(
                    label=i18n("dialogue_text_input_label"),
                    placeholder=i18n("dialogue_text_input_placeholder"),
                ),
                # generate_btn
                gr.update(value=i18n("generate_btn_label")),
                # generate_audio
                gr.update(label=i18n("generated_audio_label")),
            ]

        lang_choice.change(
            fn=_change_component_language,
            inputs=[lang_choice, voice_mode_choice],
            outputs=[
                title_desc,
                voice_mode_choice,
                spk1_prompt_audio,
                spk1_prompt_text,
                spk2_prompt_audio,
                spk2_prompt_text,
                dialogue_text_input,
                generate_btn,
                generate_audio,
            ],
        )

        # Voice clone mode action: prompt inputs are shown only for index 0.
        def _change_prompt_input_visibility(voice_mode):
            enable = voice_mode == 0
            return [gr.update(visible=enable), gr.update(visible=enable)]

        voice_mode_choice.change(
            fn=_change_prompt_input_visibility,
            inputs=[voice_mode_choice],
            outputs=[spk1_prompt_group, spk2_prompt_group],
        )

        # Inputs map positionally onto dialogue_synthesis_function's parameters.
        generate_btn.click(
            fn=dialogue_synthesis_function,
            inputs=[
                dialogue_text_input,
                voice_mode_choice,
                spk1_prompt_text,
                spk1_prompt_audio,
                spk2_prompt_text,
                spk2_prompt_audio,
            ],
            outputs=[generate_audio],
        )
    return page
if __name__ == "__main__":
    # Command-line options; the defaults reproduce the plain `python app.py` behavior.
    parser = ArgumentParser(description="FireRedTTS-2 dialogue generation demo")
    parser.add_argument(
        "--pretrained_dir",
        default="pretrained_models/FireRedTTS2",
        help="Local directory the model snapshot is downloaded to and loaded from.",
    )
    parser.add_argument(
        "--device",
        default="cuda",
        help="Device string passed to FireRedTTS2.",
    )
    args = parser.parse_args()
    # Download model (single path variable so download and load cannot diverge)
    snapshot_download(repo_id="FireRedTeam/FireRedTTS2", local_dir=args.pretrained_dir)
    # Initiate model
    initiate_model(args.pretrained_dir, device=args.device)
    # UI
    page = render_interface()
    page.queue()
    page.launch()