Spaces:
Running
on
Zero
Running
on
Zero
| # Environment setup | |
| from pathlib import Path | |
| import os | |
| import sys | |
# Make modules that live next to this script importable regardless of the working directory.
sys.path.append(str(Path(__file__).parent))

# FIXME: fairseq's checkpoint loader calls torch.load without weights_only=False. Recent PyTorch
# releases default that flag to True, which (presumably) rejects fairseq checkpoints. Patch the
# installed source on disk before fairseq gets imported below.
_FAIRSEQ_CHECKPOINT_UTILS = '/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py'
_TORCH_LOAD_ORIG = 'state = torch.load(f, map_location=torch.device("cpu"))'
_TORCH_LOAD_PATCHED = 'state = torch.load(f, map_location=torch.device("cpu"), weights_only=False)'


def _patch_fairseq_weights_only(path: str = _FAIRSEQ_CHECKPOINT_UTILS) -> bool:
    """Add ``weights_only=False`` to fairseq's ``torch.load`` checkpoint call, in place.

    Unlike a fixed line index, the statement is located by content, so a fairseq build whose
    file layout differs does not raise IndexError. The file is only rewritten when it still
    contains the unpatched statement, which makes repeated launches idempotent.

    Args:
        path: Location of fairseq's ``checkpoint_utils.py``.

    Returns:
        True if the file was modified; False if it does not exist or is already patched.
    """
    if not os.path.exists(path):
        return False
    with open(path, 'r', encoding='utf-8') as f:
        source = f.read()
    # The patched statement does not contain the original one as a substring, so this
    # check is False once the patch has been applied.
    if _TORCH_LOAD_ORIG not in source:
        return False
    with open(path, 'w', encoding='utf-8') as f:
        f.write(source.replace(_TORCH_LOAD_ORIG, _TORCH_LOAD_PATCHED))
    print('[DEBUG] added weights_only=False')
    return True


_patch_fairseq_weights_only()
| # Run | |
| import spaces | |
| import gradio as gr | |
| from zipfile import ZipFile | |
| from typing import Literal | |
| from huggingface_hub import snapshot_download | |
| from fireredtts.models.fireredtts import FireRedTTS | |
# NOTE: silence verbose INFO records from the "httpx" logger; WARNING and above still pass.
import logging

httpx_logger = logging.getLogger(name="httpx")
httpx_logger.setLevel(level=logging.WARNING)
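# NOTE: Launch prerequisites
# - Install fairseq manually; this requires pip 24.0 ("python -m pip install pip==24.0").
# - Add weights_only=False to the torch.load call in /usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py#315
#   (the environment-setup code at the top of this file attempts this automatically when that file exists).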
| # ================================================ | |
| # FireRedTTS1s Model | |
| # ================================================ | |
# Global model instances, created lazily by initiate_model(); each stays None until loaded.
# tts_flow serves decoder choice 0 ("Flow Matching", configs/config_24k_flow.json);
# tts_acollm serves the "Acoustic LLM" choice (configs/config_24k.json).
tts_flow: "FireRedTTS | None" = None
tts_acollm: "FireRedTTS | None" = None
def initiate_model(pretrained_dir: str):
    """Load both FireRedTTS variants from ``pretrained_dir`` into the module globals.

    A global that is already set is left untouched, so once both models exist,
    further calls do nothing.

    Args:
        pretrained_dir: Directory holding the extracted FireRedTTS-1S weights.
    """
    global tts_flow, tts_acollm

    def _build(config_path):
        # Both variants get the same weights directory; only the config file differs.
        return FireRedTTS(config_path=config_path, pretrained_path=pretrained_dir)

    # NOTE(review): config paths are relative to the current working directory, not to this
    # file — confirm the app is always launched from its own directory.
    if tts_flow is None:
        tts_flow = _build('configs/config_24k_flow.json')
    if tts_acollm is None:
        tts_acollm = _build('configs/config_24k.json')
| # ================================================ | |
| # Gradio | |
| # ================================================ | |
# i18n: UI strings keyed by message id, then by language code ('en' / 'zh').
_i18n_key2lang_dict = {
    # Title markdown
    "title_md_desc": {"en": "FireRedTTS-1s 🔥 Streamable TTS", "zh": "FireRedTTS-1s 🔥 可流式TTS"},
    # Decoder choice radio
    "decoder_choice_label": {"en": "Decoder Choice", "zh": "解码器选择"},
    "decoder_choice_1": {"en": "Flow Matching", "zh": "Flow Matching"},
    "decoder_choice_2": {"en": "Acoustic LLM", "zh": "Acoustic LLM"},
    # Speaker prompt
    "spk_prompt_audio_label": {"en": "Speaker Prompt Audio", "zh": "参考语音"},
    "spk_prompt_text_label": {"en": "Speaker Prompt Text", "zh": "参考语音的文本"},
    "spk_prompt_text_placeholder": {"en": "Speaker Prompt Text", "zh": "参考语音的文本"},
    # Input textbox
    "target_text_input_label": {"en": "Text To Synthesis", "zh": "待合成文本"},
    "target_text_input_placeholder": {"en": "Text To Synthesis", "zh": "待合成文本"},
    # Generate button
    "generate_btn_label": {"en": "Generate Audio", "zh": "合成"},
    # Generated audio
    "generated_audio_label": {"en": "Generated Audio", "zh": "合成的音频"},
    # Warning 1: incomplete prompt info
    "warn_incomplete_prompt": {"en": "Please provide prompt audio and text", "zh": "请提供说话人参考语音与参考文本"},
    # Warning 2: invalid text for target text input
    "warn_invalid_target_text": {"en": "Empty input text", "zh": "待合成文本为空"},
}

# Currently selected UI language. It is module-wide and is reassigned by the language
# radio's change handler in render_interface().
global_lang: Literal['zh', 'en'] = 'zh'


def i18n(key):
    """Return the UI string for ``key`` in the current ``global_lang``.

    Raises:
        KeyError: if ``key`` is not in the translation table.
    """
    translations = _i18n_key2lang_dict[key]
    return translations[global_lang]
def check_monologue_text(text: str, prefix: str = None) -> bool:
    """Return True if ``text`` has non-whitespace content after an optional prefix.

    Args:
        text: Input text; surrounding whitespace is ignored.
        prefix: Optional tag that the stripped text must start with. When given,
            it is removed before the emptiness check. ``None`` disables the prefix check.

    Returns:
        False if ``prefix`` is given and the stripped text does not start with it,
        or if nothing but whitespace remains; True otherwise.
    """
    body = text.strip()
    if prefix is not None:
        if not body.startswith(prefix):
            return False
        body = body.removeprefix(prefix).strip()
    return len(body) > 0
def synthesis_function(
    spk_prompt_audio: str,
    spk_prompt_text: str,
    target_text: str,
    decoder_choice: Literal[0, 1] = 0,  # 0 means flow matching decoder
):
    """Synthesize ``target_text`` in the voice of the speaker prompt.

    Click handler for the generate button.

    Args:
        spk_prompt_audio: Path of the reference recording (the gr.Audio input uses
            type="filepath"), or None when nothing was uploaded.
        spk_prompt_text: Transcript of the reference recording. None is treated as empty.
        target_text: Text to synthesize. None is treated as empty.
        decoder_choice: Radio index; 0 selects ``tts_flow`` (flow matching) and any
            other value selects ``tts_acollm`` (acoustic LLM).

    Returns:
        A ``(sample_rate, waveform)`` tuple for gr.Audio. Returns None after showing a
        gr.Warning when the prompt audio/text or the target text is missing.

    Raises:
        RuntimeError: if the selected model has not been loaded by initiate_model().
    """
    # NOTE(review): `spaces` is imported but no @spaces.GPU decorator is visible here. On
    # ZeroGPU hardware, GPU work normally has to run inside such a function — confirm.
    # API callers may send None for an empty textbox; treat it like "" instead of crashing.
    spk_prompt_text = (spk_prompt_text or "").strip()
    if spk_prompt_audio is None or spk_prompt_text == "":
        gr.Warning(message=i18n('warn_incomplete_prompt'))
        return None
    target_text = (target_text or "").strip()
    if target_text == "":
        gr.Warning(message=i18n('warn_invalid_target_text'))
        return None

    model = tts_flow if decoder_choice == 0 else tts_acollm
    if model is None:
        raise RuntimeError("TTS model is not loaded; call initiate_model() first")
    # NOTE(review): lang is always "zh", even for non-Chinese target text — confirm intended.
    audio = model.synthesize(
        prompt_wav=spk_prompt_audio,
        prompt_text=spk_prompt_text,
        text=target_text,
        lang="zh",
        use_tn=True,
    )
    # Assumes `audio` is a torch tensor with a leading singleton batch dim — TODO confirm.
    # 24000 matches the 24k model configs.
    return (24000, audio.detach().cpu().squeeze(0).numpy())
# UI rendering
def render_interface() -> gr.Blocks:
    """Build the FireRedTTS-1s Gradio page and wire its event handlers.

    Component labels are taken from i18n() for the current global_lang. The page is
    returned without being launched.

    Returns:
        The assembled gr.Blocks page.
    """
    # Tab title matches this app ("FireRedTTS-1s", as in title_md_desc). It previously
    # read "FireRedTTS-2".
    with gr.Blocks(title="FireRedTTS-1s", theme=gr.themes.Default()) as page:
        # ======================== UI ========================
        # A large title
        title_desc = gr.Markdown(value="# {}".format(i18n('title_md_desc')))
        with gr.Row():
            # type="index": the change handler receives 0 for '中文' and 1 for 'English'.
            lang_choice = gr.Radio(
                choices=['中文', 'English'],
                value='中文',
                label='Display Language/显示语言',
                type="index",
                interactive=True,
            )
            # type="index": synthesis_function receives 0 (flow matching) or 1 (acoustic LLM).
            decoder_choice = gr.Radio(
                choices=[i18n('decoder_choice_1'), i18n('decoder_choice_2')],
                value=i18n('decoder_choice_1'),
                label=i18n('decoder_choice_label'),
                type="index",
                interactive=True,
            )
        with gr.Row():
            # ==== Speaker Prompt ====
            spk_prompt_text = gr.Textbox(
                label=i18n('spk_prompt_text_label'),
                placeholder=i18n('spk_prompt_text_placeholder'),
                lines=5,
            )
            # type="filepath": the handler receives the path of a temporary audio file.
            spk_prompt_audio = gr.Audio(
                label=i18n('spk_prompt_audio_label'),
                type="filepath",
                editable=False,
                interactive=True,
            )
        # ==== Target Text ====
        target_text_input = gr.Textbox(
            label=i18n('target_text_input_label'),
            placeholder=i18n('target_text_input_placeholder'),
            lines=5,
        )
        # Generate button
        generate_btn = gr.Button(value=i18n('generate_btn_label'), variant="primary", size="lg")
        # Output audio (read-only)
        generate_audio = gr.Audio(
            label=i18n('generated_audio_label'),
            interactive=False,
        )

        # ======================== Action ========================
        # Language action
        def _change_component_language(lang):
            """Set global_lang from the radio index and relabel the translatable components.

            The updates are positional and must stay in the same order as the
            ``outputs`` list passed to ``lang_choice.change`` below.
            """
            # NOTE(review): global_lang is module-wide. One visitor switching language
            # also changes i18n() output (e.g. synthesis warnings) for other sessions —
            # confirm acceptable.
            global global_lang
            global_lang = ['zh', 'en'][lang]
            return [
                # title_desc
                gr.update(value="# {}".format(i18n('title_md_desc'))),
                # decoder_choice
                gr.update(label=i18n('decoder_choice_label')),
                # spk_prompt_{text,audio}
                gr.update(label=i18n('spk_prompt_text_label'), placeholder=i18n('spk_prompt_text_placeholder')),
                gr.update(label=i18n('spk_prompt_audio_label')),
                # target_text_input
                gr.update(label=i18n('target_text_input_label'), placeholder=i18n('target_text_input_placeholder')),
                # generate_btn
                gr.update(value=i18n('generate_btn_label')),
                # generate_audio
                gr.update(label=i18n('generated_audio_label')),
            ]

        lang_choice.change(
            fn=_change_component_language,
            inputs=[lang_choice],
            outputs=[
                title_desc, decoder_choice,
                spk_prompt_text, spk_prompt_audio,
                target_text_input,
                generate_btn, generate_audio,
            ]
        )
        generate_btn.click(
            fn=synthesis_function,
            inputs=[spk_prompt_audio, spk_prompt_text, target_text_input, decoder_choice],
            outputs=[generate_audio]
        )
    return page
if __name__ == '__main__':
    repo_dir = 'pretrained_models/FireRedTTS-1S'
    # Fetch the model repository from the Hugging Face Hub into repo_dir.
    snapshot_download(repo_id='FireRedTeam/FireRedTTS-1S', local_dir=repo_dir)
    # The weights ship zipped; extracting into repo_dir yields repo_dir/pretrained_models.
    archive_path = f'{repo_dir}/pretrained_models.zip'
    with ZipFile(archive_path, 'r') as archive:
        archive.extractall(repo_dir)
    # Load both decoder variants, then build and serve the UI.
    initiate_model(f'{repo_dir}/pretrained_models')
    print('[INFO] model loaded')
    page = render_interface()
    page.launch()