import re
import spaces
import gradio as gr
from tqdm import tqdm
from huggingface_hub import snapshot_download
from argparse import ArgumentParser
from typing import Literal, List, Tuple
from fireredtts2.fireredtts2 import FireRedTTS2

# ================================================
#                FireRedTTS2 Model
# ================================================
# Global model instance, created once by `initiate_model` and shared by every
# Gradio request.
model: FireRedTTS2 | None = None

# Sample rate (Hz) reported to Gradio alongside the generated waveform.
# NOTE(review): presumably matches FireRedTTS2's output rate — confirm.
SAMPLE_RATE = 24000

# Speaker tags accepted in the dialogue input box.
DIALOGUE_SPEAKER_TAGS = ("[S1]", "[S2]", "[S3]", "[S4]")

# One dialogue turn: a speaker tag "[S<digit>]" followed by everything up to
# the next square bracket. Compiled once and reused for every request.
_DIALOGUE_SEGMENT_RE = re.compile(r"(\[S[0-9]\][^\[\]]*)")


def initiate_model(pretrained_dir: str, device="cuda"):
    """Load the dialogue-mode FireRedTTS2 model into the module-level `model`.

    Calls after the first successful load are no-ops.

    Args:
        pretrained_dir: Directory holding the pretrained checkpoint files.
        device: Device string forwarded to `FireRedTTS2`.
    """
    global model
    if model is None:
        model = FireRedTTS2(
            pretrained_dir=pretrained_dir,
            gen_type="dialogue",
            device=device,
        )


# ================================================
#                   Gradio
# ================================================

# i18n: UI string key -> {language code -> text}
_i18n_key2lang_dict = dict(
    # Title markdown
    title_md_desc=dict(
        en="FireRedTTS-2 🔥 Dialogue Generation",
        zh="FireRedTTS-2 🔥 对话生成",
    ),
    # Voice mode radio
    voice_mode_label=dict(
        en="Voice Mode",
        zh="音色模式",
    ),
    voice_model_choice1=dict(
        en="Voice Clone",
        zh="音色克隆",
    ),
    voice_model_choice2=dict(
        en="Random Voice",
        zh="随机音色",
    ),
    # Speaker1 Prompt
    spk1_prompt_audio_label=dict(
        en="Speaker 1 Prompt Audio",
        zh="说话人 1 参考语音",
    ),
    spk1_prompt_text_label=dict(
        en="Speaker 1 Prompt Text",
        zh="说话人 1 参考文本",
    ),
    spk1_prompt_text_placeholder=dict(
        en="[S1] text of speaker 1 prompt audio.",
        zh="[S1] 说话人 1 参考文本",
    ),
    # Speaker2 Prompt
    spk2_prompt_audio_label=dict(
        en="Speaker 2 Prompt Audio",
        zh="说话人 2 参考语音",
    ),
    spk2_prompt_text_label=dict(
        en="Speaker 2 Prompt Text",
        zh="说话人 2 参考文本",
    ),
    spk2_prompt_text_placeholder=dict(
        en="[S2] text of speaker 2 prompt audio.",
        zh="[S2] 说话人 2 参考文本",
    ),
    # Dialogue input textbox
    dialogue_text_input_label=dict(
        en="Dialogue Text Input",
        zh="对话文本输入",
    ),
    dialogue_text_input_placeholder=dict(
        en="[S1]text[S2]text[S1]text...",
        zh="[S1]文本[S2]文本[S1]文本...",
    ),
    # Generate button
    generate_btn_label=dict(
        en="Generate Audio",
        zh="合成",
    ),
    # Generated audio
    generated_audio_label=dict(
        en="Generated Dialogue Audio",
        zh="合成的对话音频",
    ),
    # Warning 1: invalid text for prompt
    warn_invalid_spk1_prompt_text=dict(
        en='Invalid speaker 1 prompt text, should strictly follow: "[S1]xxx"',
        zh='说话人 1 参考文本不合规,格式:"[S1]xxx"',
    ),
    warn_invalid_spk2_prompt_text=dict(
        en='Invalid speaker 2 prompt text, should strictly follow: "[S2]xxx"',
        zh='说话人 2 参考文本不合规,格式:"[S2]xxx"',
    ),
    # Warning 2: invalid text for dialogue input
    warn_invalid_dialogue_text=dict(
        en='Invalid dialogue input text, should strictly follow: "[S1]xxx[S2]xxx..."',
        zh='对话文本输入不合规,格式:"[S1]xxx[S2]xxx..."',
    ),
    # Warning 3: incomplete prompt info
    warn_incomplete_prompt=dict(
        en="Please provide prompt audio and text for both speaker 1 and speaker 2",
        zh="请提供说话人 1 与说话人 2 的参考语音与参考文本",
    ),
)

# Current display language, switched by the language radio in the UI.
# NOTE(review): this is a module-level global, so a language switch in one
# browser session presumably also changes the warnings shown to other
# sessions served by the same process; consider per-session gr.State.
global_lang: Literal["zh", "en"] = "zh"


def i18n(key):
    """Return the UI string for `key` in the current display language."""
    return _i18n_key2lang_dict[key][global_lang]


def check_monologue_text(text: str | None, prefix: str | None = None) -> bool:
    """Check that `text` is a non-empty utterance, optionally tagged.

    Args:
        text: Candidate text; `None` is treated as invalid.
        prefix: Required leading speaker tag (e.g. "[S1]"), or None to skip
            the tag check.

    Returns:
        True when `text` (stripped) starts with `prefix` (if given) and has
        non-whitespace content after it.
    """
    if text is None:
        return False
    text = text.strip()
    if prefix is not None:
        if not text.startswith(prefix):
            return False
        text = text.removeprefix(prefix).strip()
    return len(text) > 0


def check_dialogue_text(text_list: List[str]) -> bool:
    """Check that every turn starts with a known speaker tag and has content.

    An empty (or None) list is invalid.
    """
    if not text_list:
        return False
    return all(
        any(check_monologue_text(text, tag) for tag in DIALOGUE_SPEAKER_TAGS)
        for text in text_list
    )


def _split_dialogue_text(text: str | None) -> List[str] | None:
    """Split raw dialogue input into stripped turns such as "[S1]hello".

    Returns None when non-whitespace text would fall outside every turn
    (text before the first speaker tag, or text after a stray "[" / "]"),
    so the caller can report it instead of silently dropping it.
    """
    text = text or ""
    segments = _DIALOGUE_SEGMENT_RE.findall(text)
    if _DIALOGUE_SEGMENT_RE.sub("", text).strip():
        return None
    return [segment.strip() for segment in segments]


@spaces.GPU(duration=200)
def dialogue_synthesis_function(
    target_text: str,
    voice_mode: Literal[0, 1] = 0,  # 0 means voice clone
    spk1_prompt_text: str | None = "",
    spk1_prompt_audio: str | None = None,
    spk2_prompt_text: str | None = "",
    spk2_prompt_audio: str | None = None,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Synthesize a multi-speaker dialogue for the Gradio "Generate" button.

    Args:
        target_text: Dialogue text in the form "[S1]xxx[S2]xxx...".
        voice_mode: 0 = voice clone (both speaker prompts required);
            any other value = random voice (prompts ignored).
        spk1_prompt_text / spk2_prompt_text: Prompt transcripts, which must
            start with "[S1]" / "[S2]" respectively in voice clone mode.
        spk1_prompt_audio / spk2_prompt_audio: Prompt audio file paths.
        progress: Injected by Gradio. It must be a parameter default for
            Gradio to attach it; `track_tqdm=True` mirrors tqdm bars used
            during generation into the UI.

    Returns:
        `(SAMPLE_RATE, waveform ndarray)` on success, or None after showing
        a localized warning when the input is invalid.
    """
    # Voice clone mode: both speakers need prompt audio and prompt text.
    if voice_mode == 0:
        prompt_has_value = [
            bool(spk1_prompt_text),
            spk1_prompt_audio is not None,
            bool(spk2_prompt_text),
            spk2_prompt_audio is not None,
        ]
        if not all(prompt_has_value):
            gr.Warning(message=i18n("warn_incomplete_prompt"))
            return None
        if not check_monologue_text(spk1_prompt_text, "[S1]"):
            gr.Warning(message=i18n("warn_invalid_spk1_prompt_text"))
            return None
        if not check_monologue_text(spk2_prompt_text, "[S2]"):
            gr.Warning(message=i18n("warn_invalid_spk2_prompt_text"))
            return None

    # Split and validate the dialogue text.
    target_text_list = _split_dialogue_text(target_text)
    if target_text_list is None or not check_dialogue_text(target_text_list):
        gr.Warning(message=i18n("warn_invalid_dialogue_text"))
        return None

    # Run synthesis.
    if voice_mode == 0:
        prompt_wav_list = [spk1_prompt_audio, spk2_prompt_audio]
        prompt_text_list = [spk1_prompt_text, spk2_prompt_text]
    else:
        prompt_wav_list = None
        prompt_text_list = None
    target_audio = model.generate_dialogue(
        text_list=target_text_list,
        prompt_wav_list=prompt_wav_list,
        prompt_text_list=prompt_text_list,
        temperature=0.9,
        topk=30,
    )
    # `.cpu()` is a no-op for CPU tensors; it guards against the model
    # returning a tensor on the GPU device (unverified from here).
    return (SAMPLE_RATE, target_audio.squeeze(0).cpu().numpy())


# UI rendering
def render_interface() -> gr.Blocks:
    """Build the Gradio Blocks page: prompts, dialogue input, and output."""
    with gr.Blocks(title="FireRedTTS-2", theme=gr.themes.Default()) as page:
        # ======================== UI ========================
        # A large title
        title_desc = gr.Markdown(value="# {}".format(i18n("title_md_desc")))
        with gr.Row():
            lang_choice = gr.Radio(
                choices=["中文", "English"],
                value="中文",
                label="Display Language/显示语言",
                type="index",
                interactive=True,
            )
            voice_mode_choice = gr.Radio(
                choices=[i18n("voice_model_choice1"), i18n("voice_model_choice2")],
                value=i18n("voice_model_choice1"),
                label=i18n("voice_mode_label"),
                type="index",
                interactive=True,
            )
        with gr.Row():
            # ==== Speaker1 Prompt ====
            with gr.Column(scale=1):
                with gr.Group(visible=True) as spk1_prompt_group:
                    spk1_prompt_audio = gr.Audio(
                        label=i18n("spk1_prompt_audio_label"),
                        type="filepath",
                        editable=False,
                        interactive=True,
                    )  # Audio component returns tmp audio path
                    spk1_prompt_text = gr.Textbox(
                        label=i18n("spk1_prompt_text_label"),
                        placeholder=i18n("spk1_prompt_text_placeholder"),
                        lines=3,
                    )
            # ==== Speaker2 Prompt ====
            with gr.Column(scale=1):
                with gr.Group(visible=True) as spk2_prompt_group:
                    spk2_prompt_audio = gr.Audio(
                        label=i18n("spk2_prompt_audio_label"),
                        type="filepath",
                        editable=False,
                        interactive=True,
                    )
                    spk2_prompt_text = gr.Textbox(
                        label=i18n("spk2_prompt_text_label"),
                        placeholder=i18n("spk2_prompt_text_placeholder"),
                        lines=3,
                    )
            # ==== Text input ====
            with gr.Column(scale=2):
                dialogue_text_input = gr.Textbox(
                    label=i18n("dialogue_text_input_label"),
                    placeholder=i18n("dialogue_text_input_placeholder"),
                    lines=18,
                )
        # Generate button
        generate_btn = gr.Button(
            value=i18n("generate_btn_label"), variant="primary", size="lg"
        )
        # Long output audio
        generate_audio = gr.Audio(
            label=i18n("generated_audio_label"),
            interactive=False,
        )

        # ======================== Action ========================
        # Language action
        def _change_component_language(lang, voice_mode):
            """Relabel every component in the newly selected language.

            `voice_mode` (radio index) is passed in so the current selection
            survives the relabel instead of being reset to voice clone.
            """
            global global_lang
            global_lang = ["zh", "en"][lang]
            voice_mode_choices = [
                i18n("voice_model_choice1"),
                i18n("voice_model_choice2"),
            ]
            selected_mode = voice_mode if voice_mode in (0, 1) else 0
            return [
                # title_desc
                gr.update(value="# {}".format(i18n("title_md_desc"))),
                # voice_mode_choice
                gr.update(
                    choices=voice_mode_choices,
                    value=voice_mode_choices[selected_mode],
                    label=i18n("voice_mode_label"),
                ),
                # spk1_prompt_{audio,text}
                gr.update(label=i18n("spk1_prompt_audio_label")),
                gr.update(
                    label=i18n("spk1_prompt_text_label"),
                    placeholder=i18n("spk1_prompt_text_placeholder"),
                ),
                # spk2_prompt_{audio,text}
                gr.update(label=i18n("spk2_prompt_audio_label")),
                gr.update(
                    label=i18n("spk2_prompt_text_label"),
                    placeholder=i18n("spk2_prompt_text_placeholder"),
                ),
                # dialogue_text_input
                gr.update(
                    label=i18n("dialogue_text_input_label"),
                    placeholder=i18n("dialogue_text_input_placeholder"),
                ),
                # generate_btn
                gr.update(value=i18n("generate_btn_label")),
                # generate_audio
                gr.update(label=i18n("generated_audio_label")),
            ]

        lang_choice.change(
            fn=_change_component_language,
            inputs=[lang_choice, voice_mode_choice],
            outputs=[
                title_desc,
                voice_mode_choice,
                spk1_prompt_audio,
                spk1_prompt_text,
                spk2_prompt_audio,
                spk2_prompt_text,
                dialogue_text_input,
                generate_btn,
                generate_audio,
            ],
        )

        # Voice clone mode action: prompt inputs are only shown for voice clone.
        def _change_prompt_input_visibility(voice_mode):
            enable = voice_mode == 0
            return [gr.update(visible=enable), gr.update(visible=enable)]

        voice_mode_choice.change(
            fn=_change_prompt_input_visibility,
            inputs=[voice_mode_choice],
            outputs=[spk1_prompt_group, spk2_prompt_group],
        )

        generate_btn.click(
            fn=dialogue_synthesis_function,
            inputs=[
                dialogue_text_input,
                voice_mode_choice,
                spk1_prompt_text,
                spk1_prompt_audio,
                spk2_prompt_text,
                spk2_prompt_audio,
            ],
            outputs=[generate_audio],
        )
    return page


if __name__ == "__main__":
    # Download model
    snapshot_download(repo_id='FireRedTeam/FireRedTTS2', local_dir='pretrained_models/FireRedTTS2')
    # Initiate model
    initiate_model('pretrained_models/FireRedTTS2')
    # UI
    page = render_interface()
    page.queue()
    page.launch()