FireRedTTS-1S / app.py
Shen Feiyu
add 1s
faadabf
# Environment setup
from pathlib import Path
import os
import sys
sys.path.append(str(Path(__file__).parent))

# fairseq's checkpoint loader calls ``torch.load`` without ``weights_only``.
# The FIXME in the original setup asked for ``weights_only=False`` to be added
# there (PyTorch >= 2.6 defaults it to True, which rejects fairseq's pickled
# checkpoint objects), so the installed copy is patched in place before
# anything imports fairseq.
_FAIRSEQ_CHECKPOINT_UTILS = '/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py'
_TORCH_LOAD_ORIG = 'state = torch.load(f, map_location=torch.device("cpu"))'
_TORCH_LOAD_PATCHED = 'state = torch.load(f, map_location=torch.device("cpu"), weights_only=False)'


def patch_fairseq_weights_only(path=_FAIRSEQ_CHECKPOINT_UTILS) -> bool:
    """Add ``weights_only=False`` to fairseq's checkpoint ``torch.load`` call.

    The call is located by its text rather than by a fixed line number, so the
    patch still applies if the call moves and cannot IndexError on a short file.
    Re-running is harmless: the patched text no longer matches the original.

    Args:
        path: location of fairseq's ``checkpoint_utils.py`` (str or Path).

    Returns:
        True if the file was rewritten; False if it does not exist or
        contains no unpatched call.
    """
    target = Path(path)
    if not target.is_file():
        return False
    source = target.read_text(encoding='utf-8')
    patched = source.replace(_TORCH_LOAD_ORIG, _TORCH_LOAD_PATCHED)
    if patched == source:
        # Already patched (or call not present): avoid touching site-packages.
        return False
    target.write_text(patched, encoding='utf-8')
    return True


if patch_fairseq_weights_only():
    print('[DEBUG] added weights_only=False')
# Run
import spaces
import gradio as gr
from zipfile import ZipFile
from typing import Literal
from huggingface_hub import snapshot_download
from fireredtts.models.fireredtts import FireRedTTS
# NOTE disable verbose INFO logs
import logging
# Raise the threshold of the "httpx" logger (presumably the HTTP client used by
# gradio / huggingface_hub here) so only WARNING and above are emitted.
httpx_logger = logging.getLogger(name="httpx")
httpx_logger.setLevel("WARNING")  # level name resolves to logging.WARNING
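# NOTE: launch prerequisites
# - fairseq must be installed manually; pin pip to 24.0 first ("python -m pip install pip==24.0").
# - fairseq's checkpoint_utils.py (around line 315) must call torch.load(..., weights_only=False);
#   the environment-setup code at the top of this file applies that patch automatically when
#   /usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py exists.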
# ================================================
# FireRedTTS1s Model
# ================================================
# Global model instance
# Module-wide model handles, filled in by initiate_model().
tts_flow: FireRedTTS = None  # flow-matching decoder pipeline (decoder_choice == 0)
tts_acollm: FireRedTTS = None  # acoustic-LLM decoder pipeline (any other decoder_choice)


def initiate_model(pretrained_dir: str):
    """Construct whichever FireRedTTS pipeline has not been loaded yet.

    Both pipelines read weights from ``pretrained_dir`` and differ only in the
    config file. Instances that already exist are kept, so calling this again
    does not reload them.

    Args:
        pretrained_dir: directory holding the FireRedTTS-1S weights.
    """
    global tts_flow, tts_acollm

    def _load(config_file):
        # Config paths are relative; presumably resolved against the current
        # working directory by FireRedTTS — confirm if launching elsewhere.
        return FireRedTTS(config_path=config_file, pretrained_path=pretrained_dir)

    if tts_flow is None:
        tts_flow = _load('configs/config_24k_flow.json')
    if tts_acollm is None:
        tts_acollm = _load('configs/config_24k.json')
# ================================================
# Gradio
# ================================================
# i18n
# UI strings: message id -> {language code -> text}. Looked up through i18n(),
# which indexes by the current global_lang ('zh' or 'en'), so every entry must
# provide both languages.
_i18n_key2lang_dict = dict(
    # Title markdown
    title_md_desc=dict(
        en="FireRedTTS-1s 🔥 Streamable TTS",
        zh="FireRedTTS-1s 🔥 可流式TTS",
    ),
    # Decoder choice radio (label + the two options)
    decoder_choice_label=dict(
        en="Decoder Choice",
        zh="解码器选择",
    ),
    decoder_choice_1=dict(
        en="Flow Matching",
        zh="Flow Matching",
    ),
    decoder_choice_2=dict(
        en="Acoustic LLM",
        zh="Acoustic LLM",
    ),
    # Speaker prompt (reference audio and its transcript)
    spk_prompt_audio_label=dict(
        en="Speaker Prompt Audio",
        zh="参考语音",
    ),
    spk_prompt_text_label=dict(
        en="Speaker Prompt Text",
        zh="参考语音的文本",
    ),
    spk_prompt_text_placeholder=dict(
        en="Speaker Prompt Text",
        zh="参考语音的文本",
    ),
    # Input textbox for the text to be spoken
    target_text_input_label=dict(
        en="Text to Synthesize",
        zh="待合成文本",
    ),
    target_text_input_placeholder=dict(
        en="Text to Synthesize",
        zh="待合成文本",
    ),
    # Generate button
    generate_btn_label=dict(
        en="Generate Audio",
        zh="合成",
    ),
    # Generated audio
    generated_audio_label=dict(
        en="Generated Audio",
        zh="合成的音频",
    ),
    # Warning 1: prompt audio or prompt text is missing
    warn_incomplete_prompt=dict(
        en="Please provide prompt audio and text",
        zh="请提供说话人参考语音与参考文本",
    ),
    # Warning 2: target text input is empty
    warn_invalid_target_text=dict(
        en="Empty input text",
        zh="待合成文本为空",
    ),
)
# Active display language. It is a module-level variable, so a change made by
# the language radio is seen by every caller in this process.
global_lang: Literal['zh', 'en'] = 'zh'


def i18n(key):
    """Return the UI string stored under *key* for the active ``global_lang``.

    Raises:
        KeyError: if *key* (or the active language for it) is not registered.
    """
    translations = _i18n_key2lang_dict[key]
    return translations[global_lang]
def check_monologue_text(text: str, prefix: str = None) -> bool:
    """Tell whether *text* carries real content, optionally after a required tag.

    Args:
        text: candidate text; surrounding whitespace is ignored.
        prefix: optional leading tag (e.g. a speaker tag). When given, text that
            does not start with it (after stripping) is rejected, and the tag
            itself does not count as content.

    Returns:
        False if the required tag is missing or nothing but whitespace remains,
        True otherwise.
    """
    body = text.strip()
    if prefix is not None:
        if not body.startswith(prefix):
            return False
        body = body.removeprefix(prefix).strip()
    return bool(body)
@spaces.GPU(duration=60)
def synthesis_function(
    spk_prompt_audio: str,
    spk_prompt_text: str,
    target_text: str,
    decoder_choice: Literal[0, 1] = 0,  # 0 means flow matching decoder
):
    """Synthesize *target_text* in the voice given by the speaker prompt.

    Args:
        spk_prompt_audio: filepath of the reference recording (the UI's
            gr.Audio uses type="filepath"); None when nothing was uploaded.
        spk_prompt_text: transcript of the reference recording.
        target_text: text to speak.
        decoder_choice: 0 selects the flow-matching pipeline (``tts_flow``);
            any other value selects the acoustic-LLM pipeline (``tts_acollm``).

    Returns:
        ``(24000, waveform)`` for gr.Audio, where waveform is the model output
        moved to CPU as a NumPy array; or None after showing a UI warning when
        the prompt or target text is missing.

    Raises:
        RuntimeError: if initiate_model() has not loaded the selected pipeline.
    """
    # Treat a missing (None) textbox value like an empty one instead of
    # crashing on .strip().
    spk_prompt_text = (spk_prompt_text or "").strip()
    if spk_prompt_audio is None or not spk_prompt_text:
        gr.Warning(message=i18n('warn_incomplete_prompt'))
        return None
    target_text = (target_text or "").strip()
    if not target_text:
        gr.Warning(message=i18n('warn_invalid_target_text'))
        return None

    # Both pipelines share the synthesize() call; only the instance differs.
    model = tts_flow if decoder_choice == 0 else tts_acollm
    if model is None:
        raise RuntimeError("FireRedTTS models are not loaded; call initiate_model() first.")
    audio = model.synthesize(
        prompt_wav=spk_prompt_audio,
        prompt_text=spk_prompt_text,
        text=target_text,
        lang="zh",
        use_tn=True,
    )
    # 24 kHz matches the configs/config_24k*.json model configs.
    return (24000, audio.detach().cpu().squeeze(0).numpy())
# UI rendering
def render_interface() -> gr.Blocks:
    """Build the FireRedTTS-1s Gradio page and wire up its events.

    The page has a display-language radio, a decoder radio, the speaker
    prompt (text + audio), the target text, a generate button and the output
    audio. Component labels come from ``i18n`` and are refreshed when the
    display language changes; the generate button calls ``synthesis_function``.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` page.
    """
    # Browser-tab title: this app serves FireRedTTS-1S.
    with gr.Blocks(title="FireRedTTS-1S", theme=gr.themes.Default()) as page:
        # ======================== UI ========================
        # A large title
        title_desc = gr.Markdown(value="# {}".format(i18n('title_md_desc')))
        with gr.Row():
            # type="index": the handler receives 0 (中文) or 1 (English).
            lang_choice = gr.Radio(
                choices=['中文', 'English'],
                value='中文',
                label='Display Language/显示语言',
                type="index",
                interactive=True,
            )
            # type="index": 0 -> flow matching, 1 -> acoustic LLM, matching
            # synthesis_function's decoder_choice.
            decoder_choice = gr.Radio(
                choices=[i18n('decoder_choice_1'), i18n('decoder_choice_2')],
                value=i18n('decoder_choice_1'),
                label=i18n('decoder_choice_label'),
                type="index",
                interactive=True,
            )
        with gr.Row():
            # ==== Speaker Prompt ====
            spk_prompt_text = gr.Textbox(
                label=i18n('spk_prompt_text_label'),
                placeholder=i18n('spk_prompt_text_placeholder'),
                lines=5,
            )
            spk_prompt_audio = gr.Audio(
                label=i18n('spk_prompt_audio_label'),
                type="filepath",
                editable=False,
                interactive=True,
            )  # Audio component returns tmp audio path
        # ==== Target Text ====
        target_text_input = gr.Textbox(
            label=i18n('target_text_input_label'),
            placeholder=i18n('target_text_input_placeholder'),
            lines=5,
        )
        # Generate button
        generate_btn = gr.Button(value=i18n('generate_btn_label'), variant="primary", size="lg")
        # Long output audio
        generate_audio = gr.Audio(
            label=i18n('generated_audio_label'),
            interactive=False,
        )

        # ======================== Action ========================
        # Language action
        def _change_component_language(lang):
            """Switch the UI language and return label updates for the page.

            The returned list order must match the ``outputs`` list below.
            """
            global global_lang
            # NOTE(review): global_lang is module-level, so switching language
            # here also changes i18n() results (labels, warnings) for every
            # other session served by this process.
            global_lang = ['zh', 'en'][lang]
            return [
                # title_desc
                gr.update(value="# {}".format(i18n('title_md_desc'))),
                # decoder_choice (option names are identical in both languages)
                gr.update(label=i18n('decoder_choice_label')),
                # spk_prompt_{text,audio}
                gr.update(label=i18n('spk_prompt_text_label'), placeholder=i18n('spk_prompt_text_placeholder')),
                gr.update(label=i18n('spk_prompt_audio_label')),
                # target_text_input
                gr.update(label=i18n('target_text_input_label'), placeholder=i18n('target_text_input_placeholder')),
                # generate_btn
                gr.update(value=i18n('generate_btn_label')),
                # generate_audio
                gr.update(label=i18n('generated_audio_label')),
            ]

        lang_choice.change(
            fn=_change_component_language,
            inputs=[lang_choice],
            outputs=[
                title_desc, decoder_choice,
                spk_prompt_text, spk_prompt_audio,
                target_text_input,
                generate_btn, generate_audio,
            ],
        )
        generate_btn.click(
            fn=synthesis_function,
            inputs=[spk_prompt_audio, spk_prompt_text, target_text_input, decoder_choice],
            outputs=[generate_audio],
        )
    return page
if __name__ == '__main__':
    # Local mirror of the Hugging Face checkpoint repository.
    repo_dir = 'pretrained_models/FireRedTTS-1S'
    snapshot_download(repo_id='FireRedTeam/FireRedTTS-1S', local_dir=repo_dir)
    # The weights ship zipped; unpacking yields "<repo_dir>/pretrained_models".
    with ZipFile(f'{repo_dir}/pretrained_models.zip', 'r') as archive:
        archive.extractall(repo_dir)
    # Load both decoder pipelines before serving any request.
    initiate_model(f'{repo_dir}/pretrained_models')
    print('[INFO] model loaded')
    # Build and serve the UI.
    page = render_interface()
    page.launch()