Spaces: Running on Zero

Shen Feiyu committed
Commit 71cd91e · 1 Parent(s): a4ec42e

init at 250916
Files changed:
- README.md +4 -4
- app.py +357 -0
- fireredtts2/__init__.py +0 -0
- fireredtts2/codec/__init__.py +1 -0
- fireredtts2/codec/audio.py +148 -0
- fireredtts2/codec/decoder.py +700 -0
- fireredtts2/codec/model.py +376 -0
- fireredtts2/codec/rvq.py +164 -0
- fireredtts2/codec/utils.py +38 -0
- fireredtts2/codec/whisper.py +420 -0
- fireredtts2/fireredtts2.py +459 -0
- fireredtts2/llm/__init__.py +1 -0
- fireredtts2/llm/llm.py +371 -0
- fireredtts2/llm/modules.py +90 -0
- fireredtts2/llm/utils.py +303 -0
- fireredtts2/utils/spliter.py +289 -0
- pretrained_models/README.md +1 -0
- requirements.txt +6 -0
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Tts2 Test
+emoji: 🌖
+colorFrom: pink
+colorTo: green
 sdk: gradio
 sdk_version: 5.45.0
 app_file: app.py
app.py ADDED
@@ -0,0 +1,357 @@
import re
import spaces
import gradio as gr
from tqdm import tqdm
from huggingface_hub import snapshot_download
from argparse import ArgumentParser
from typing import Literal, List, Tuple
from fireredtts2.fireredtts2 import FireRedTTS2


# ================================================
# FireRedTTS2 Model
# ================================================
# Global model instance
model: FireRedTTS2 = None


def initiate_model(pretrained_dir: str, device="cuda"):
    global model
    if model is None:
        model = FireRedTTS2(
            pretrained_dir=pretrained_dir,
            gen_type="dialogue",
            device=device,
        )


# ================================================
# Gradio
# ================================================

# i18n
_i18n_key2lang_dict = dict(
    # Title markdown
    title_md_desc=dict(
        en="FireRedTTS-2 🔥 Dialogue Generation",
        zh="FireRedTTS-2 🔥 对话生成",
    ),
    # Voice mode radio
    voice_mode_label=dict(
        en="Voice Mode",
        zh="音色模式",
    ),
    voice_model_choice1=dict(
        en="Voice Clone",
        zh="音色克隆",
    ),
    voice_model_choice2=dict(
        en="Random Voice",
        zh="随机音色",
    ),
    # Speaker1 Prompt
    spk1_prompt_audio_label=dict(
        en="Speaker 1 Prompt Audio",
        zh="说话人 1 参考语音",
    ),
    spk1_prompt_text_label=dict(
        en="Speaker 1 Prompt Text",
        zh="说话人 1 参考文本",
    ),
    spk1_prompt_text_placeholder=dict(
        en="[S1] text of speaker 1 prompt audio.",
        zh="[S1] 说话人 1 参考文本",
    ),
    # Speaker2 Prompt
    spk2_prompt_audio_label=dict(
        en="Speaker 2 Prompt Audio",
        zh="说话人 2 参考语音",
    ),
    spk2_prompt_text_label=dict(
        en="Speaker 2 Prompt Text",
        zh="说话人 2 参考文本",
    ),
    spk2_prompt_text_placeholder=dict(
        en="[S2] text of speaker 2 prompt audio.",
        zh="[S2] 说话人 2 参考文本",
    ),
    # Dialogue input textbox
    dialogue_text_input_label=dict(
        en="Dialogue Text Input",
        zh="对话文本输入",
    ),
    dialogue_text_input_placeholder=dict(
        en="[S1]text[S2]text[S1]text...",
        zh="[S1]文本[S2]文本[S1]文本...",
    ),
    # Generate button
    generate_btn_label=dict(
        en="Generate Audio",
        zh="合成",
    ),
    # Generated audio
    generated_audio_label=dict(
        en="Generated Dialogue Audio",
        zh="合成的对话音频",
    ),
    # Warning 1: invalid text for prompt
    warn_invalid_spk1_prompt_text=dict(
        en='Invalid speaker 1 prompt text, should strictly follow: "[S1]xxx"',
        zh='说话人 1 参考文本不合规,格式:"[S1]xxx"',
    ),
    warn_invalid_spk2_prompt_text=dict(
        en='Invalid speaker 2 prompt text, should strictly follow: "[S2]xxx"',
        zh='说话人 2 参考文本不合规,格式:"[S2]xxx"',
    ),
    # Warning 2: invalid text for dialogue input
    warn_invalid_dialogue_text=dict(
        en='Invalid dialogue input text, should strictly follow: "[S1]xxx[S2]xxx..."',
        zh='对话文本输入不合规,格式:"[S1]xxx[S2]xxx..."',
    ),
    # Warning 3: incomplete prompt info
    warn_incomplete_prompt=dict(
        en="Please provide prompt audio and text for both speaker 1 and speaker 2",
        zh="请提供说话人 1 与说话人 2 的参考语音与参考文本",
    ),
)

global_lang: Literal["zh", "en"] = "zh"


def i18n(key):
    global global_lang
    return _i18n_key2lang_dict[key][global_lang]


def check_monologue_text(text: str, prefix: str = None) -> bool:
    text = text.strip()
    # Check speaker tags
    if prefix is not None and (not text.startswith(prefix)):
        return False
    # Remove prefix
    if prefix is not None:
        text = text.removeprefix(prefix)
    text = text.strip()
    # If empty?
    if len(text) == 0:
        return False
    return True


def check_dialogue_text(text_list: List[str]) -> bool:
    if len(text_list) == 0:
        return False
    for text in text_list:
        if not (
            check_monologue_text(text, "[S1]")
            or check_monologue_text(text, "[S2]")
            or check_monologue_text(text, "[S3]")
            or check_monologue_text(text, "[S4]")
        ):
            return False
    return True


@spaces.GPU(duration=200)
def dialogue_synthesis_function(
    target_text: str,
    voice_mode: Literal[0, 1] = 0,  # 0 means voice clone
    spk1_prompt_text: str | None = "",
    spk1_prompt_audio: str | None = None,
    spk2_prompt_text: str | None = "",
    spk2_prompt_audio: str | None = None,
):
    # Voice clone mode, check prompt info
    if voice_mode == 0:
        prompt_has_value = [
            spk1_prompt_text != "",
            spk1_prompt_audio is not None,
            spk2_prompt_text != "",
            spk2_prompt_audio is not None,
        ]
        if not all(prompt_has_value):
            gr.Warning(message=i18n("warn_incomplete_prompt"))
            return None
        if not check_monologue_text(spk1_prompt_text, "[S1]"):
            gr.Warning(message=i18n("warn_invalid_spk1_prompt_text"))
            return None
        if not check_monologue_text(spk2_prompt_text, "[S2]"):
            gr.Warning(message=i18n("warn_invalid_spk2_prompt_text"))
            return None
    # Check dialogue text
    target_text_list: List[str] = re.findall(r"(\[S[0-9]\][^\[\]]*)", target_text)
    target_text_list = [text.strip() for text in target_text_list]
    if not check_dialogue_text(target_text_list):
        gr.Warning(message=i18n("warn_invalid_dialogue_text"))
        return None

    # Go synthesis
    progress_bar = gr.Progress(track_tqdm=True)
    prompt_wav_list = (
        None if voice_mode != 0 else [spk1_prompt_audio, spk2_prompt_audio]
    )
    prompt_text_list = None if voice_mode != 0 else [spk1_prompt_text, spk2_prompt_text]
    target_audio = model.generate_dialogue(
        text_list=target_text_list,
        prompt_wav_list=prompt_wav_list,
        prompt_text_list=prompt_text_list,
        temperature=0.9,
        topk=30,
    )
    return (24000, target_audio.squeeze(0).numpy())


# UI rendering
def render_interface() -> gr.Blocks:
    with gr.Blocks(title="FireRedTTS-2", theme=gr.themes.Default()) as page:
        # ======================== UI ========================
        # A large title
        title_desc = gr.Markdown(value="# {}".format(i18n("title_md_desc")))
        with gr.Row():
            lang_choice = gr.Radio(
                choices=["中文", "English"],
                value="中文",
                label="Display Language/显示语言",
                type="index",
                interactive=True,
            )
            voice_mode_choice = gr.Radio(
                choices=[i18n("voice_model_choice1"), i18n("voice_model_choice2")],
                value=i18n("voice_model_choice1"),
                label=i18n("voice_mode_label"),
                type="index",
                interactive=True,
            )
        with gr.Row():
            # ==== Speaker1 Prompt ====
            with gr.Column(scale=1):
                with gr.Group(visible=True) as spk1_prompt_group:
                    spk1_prompt_audio = gr.Audio(
                        label=i18n("spk1_prompt_audio_label"),
                        type="filepath",
                        editable=False,
                        interactive=True,
                    )  # Audio component returns tmp audio path
                    spk1_prompt_text = gr.Textbox(
                        label=i18n("spk1_prompt_text_label"),
                        placeholder=i18n("spk1_prompt_text_placeholder"),
                        lines=3,
                    )
            # ==== Speaker2 Prompt ====
            with gr.Column(scale=1):
                with gr.Group(visible=True) as spk2_prompt_group:
                    spk2_prompt_audio = gr.Audio(
                        label=i18n("spk2_prompt_audio_label"),
                        type="filepath",
                        editable=False,
                        interactive=True,
                    )
                    spk2_prompt_text = gr.Textbox(
                        label=i18n("spk2_prompt_text_label"),
                        placeholder=i18n("spk2_prompt_text_placeholder"),
                        lines=3,
                    )
            # ==== Text input ====
            with gr.Column(scale=2):
                dialogue_text_input = gr.Textbox(
                    label=i18n("dialogue_text_input_label"),
                    placeholder=i18n("dialogue_text_input_placeholder"),
                    lines=18,
                )
        # Generate button
        generate_btn = gr.Button(
            value=i18n("generate_btn_label"), variant="primary", size="lg"
        )
        # Long output audio
        generate_audio = gr.Audio(
            label=i18n("generated_audio_label"),
            interactive=False,
        )

        # ======================== Action ========================
        # Language action
        def _change_component_language(lang):
            global global_lang
            global_lang = ["zh", "en"][lang]
            return [
                # title_desc
                gr.update(value="# {}".format(i18n("title_md_desc"))),
                # voice_mode_choice
                gr.update(
                    choices=[i18n("voice_model_choice1"), i18n("voice_model_choice2")],
                    value=i18n("voice_model_choice1"),
                    label=i18n("voice_mode_label"),
                ),
                # spk1_prompt_{audio,text}
                gr.update(label=i18n("spk1_prompt_audio_label")),
                gr.update(
                    label=i18n("spk1_prompt_text_label"),
                    placeholder=i18n("spk1_prompt_text_placeholder"),
                ),
                # spk2_prompt_{audio,text}
                gr.update(label=i18n("spk2_prompt_audio_label")),
                gr.update(
                    label=i18n("spk2_prompt_text_label"),
                    placeholder=i18n("spk2_prompt_text_placeholder"),
                ),
                # dialogue_text_input
                gr.update(
                    label=i18n("dialogue_text_input_label"),
                    placeholder=i18n("dialogue_text_input_placeholder"),
                ),
                # generate_btn
                gr.update(value=i18n("generate_btn_label")),
                # generate_audio
                gr.update(label=i18n("generated_audio_label")),
            ]

        lang_choice.change(
            fn=_change_component_language,
            inputs=[lang_choice],
            outputs=[
                title_desc,
                voice_mode_choice,
                spk1_prompt_audio,
                spk1_prompt_text,
                spk2_prompt_audio,
                spk2_prompt_text,
                dialogue_text_input,
                generate_btn,
                generate_audio,
            ],
        )

        # Voice clone mode action
        def _change_prompt_input_visibility(voice_mode):
            enable = voice_mode == 0
            return [gr.update(visible=enable), gr.update(visible=enable)]

        voice_mode_choice.change(
            fn=_change_prompt_input_visibility,
            inputs=[voice_mode_choice],
            outputs=[spk1_prompt_group, spk2_prompt_group],
        )
        generate_btn.click(
            fn=dialogue_synthesis_function,
            inputs=[
                dialogue_text_input,
                voice_mode_choice,
                spk1_prompt_text,
                spk1_prompt_audio,
                spk2_prompt_text,
                spk2_prompt_audio,
            ],
            outputs=[generate_audio],
        )
    return page


if __name__ == "__main__":
    # Download model
    snapshot_download(repo_id='FireRedTeam/FireRedTTS2', local_dir='pretrained_models/FireRedTTS2')
    # Initiate model
    initiate_model('pretrained_models/FireRedTTS2')
    # UI
    page = render_interface()
    page.queue()
    page.launch()
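A minimal, self-contained sketch of the speaker-tag splitting performed in dialogue_synthesis_function above; the sample dialogue text is made up for illustration:

import re

# Same pattern as in app.py: each segment starts with a tag like [S1] and
# runs until the next opening bracket.
pattern = r"(\[S[0-9]\][^\[\]]*)"
text = "[S1]Hello there.[S2]Hi, how are you?[S1]Doing great."
segments = [seg.strip() for seg in re.findall(pattern, text)]
print(segments)
# ['[S1]Hello there.', '[S2]Hi, how are you?', '[S1]Doing great.']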
fireredtts2/__init__.py ADDED
File without changes
fireredtts2/codec/__init__.py ADDED
@@ -0,0 +1 @@
from fireredtts2.codec.model import RedCodecInfer
fireredtts2/codec/audio.py ADDED
@@ -0,0 +1,148 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
and remove unnecessary dependencies.
"""
import warnings
import numpy as np
from typing import Union, Optional


def hertz_to_mel(
    freq: Union[float, np.ndarray], mel_scale: str = "htk"
) -> Union[float, np.ndarray]:
    if mel_scale not in ["slaney", "htk", "kaldi"]:
        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')

    if mel_scale == "htk":
        return 2595.0 * np.log10(1.0 + (freq / 700.0))
    elif mel_scale == "kaldi":
        return 1127.0 * np.log(1.0 + (freq / 700.0))

    min_log_hertz = 1000.0
    min_log_mel = 15.0
    logstep = 27.0 / np.log(6.4)
    mels = 3.0 * freq / 200.0

    if isinstance(freq, np.ndarray):
        log_region = freq >= min_log_hertz
        mels[log_region] = (
            min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
        )
    elif freq >= min_log_hertz:
        mels = min_log_mel + np.log(freq / min_log_hertz) * logstep

    return mels


def mel_to_hertz(
    mels: Union[float, np.ndarray], mel_scale: str = "htk"
) -> Union[float, np.ndarray]:
    if mel_scale not in ["slaney", "htk", "kaldi"]:
        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')

    if mel_scale == "htk":
        return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
    elif mel_scale == "kaldi":
        return 700.0 * (np.exp(mels / 1127.0) - 1.0)

    min_log_hertz = 1000.0
    min_log_mel = 15.0
    logstep = np.log(6.4) / 27.0
    freq = 200.0 * mels / 3.0

    if isinstance(mels, np.ndarray):
        log_region = mels >= min_log_mel
        freq[log_region] = min_log_hertz * np.exp(
            logstep * (mels[log_region] - min_log_mel)
        )
    elif mels >= min_log_mel:
        freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))

    return freq


def _create_triangular_filter_bank(
    fft_freqs: np.ndarray, filter_freqs: np.ndarray
) -> np.ndarray:
    """
    Creates a triangular filter bank.

    Adapted from *torchaudio* and *librosa*.

    Args:
        fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
            Discrete frequencies of the FFT bins in Hz.
        filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
            Center frequencies of the triangular filters to create, in Hz.

    Returns:
        `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
    """
    filter_diff = np.diff(filter_freqs)
    slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
    down_slopes = -slopes[:, :-2] / filter_diff[:-1]
    up_slopes = slopes[:, 2:] / filter_diff[1:]
    return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))


def mel_filter_bank(
    num_frequency_bins: int,
    num_mel_filters: int,
    min_frequency: float,
    max_frequency: float,
    sampling_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
    triangularize_in_mel_space: bool = False,
) -> np.ndarray:
    if norm is not None and norm != "slaney":
        raise ValueError('norm must be one of None or "slaney"')

    # center points of the triangular mel filters
    mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
    mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
    mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
    filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)

    if triangularize_in_mel_space:
        # frequencies of FFT bins in Hz, but filters triangularized in mel space
        fft_bin_width = sampling_rate / (num_frequency_bins * 2)
        fft_freqs = hertz_to_mel(
            fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale
        )
        filter_freqs = mel_freqs
    else:
        # frequencies of FFT bins in Hz
        fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)

    mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)

    if norm is not None and norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (
            filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]
        )
        mel_filters *= np.expand_dims(enorm, 0)

    if (mel_filters.max(axis=0) == 0.0).any():
        warnings.warn(
            "At least one mel filter has all zero values. "
            f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
            f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
        )

    return mel_filters
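As a usage illustration for mel_filter_bank above, the following sketch builds a filter bank and applies it to a magnitude spectrogram. The Whisper-style parameter values (n_fft=400, i.e. 201 frequency bins, 16 kHz audio, 80 mel filters, Slaney norm) are assumptions chosen for the example, not values taken from this repository:

import numpy as np
from fireredtts2.codec.audio import mel_filter_bank

# Build a (num_frequency_bins, num_mel_filters) projection matrix.
filters = mel_filter_bank(
    num_frequency_bins=201,  # n_fft // 2 + 1 for an assumed n_fft of 400
    num_mel_filters=80,
    min_frequency=0.0,
    max_frequency=8000.0,
    sampling_rate=16000,
    norm="slaney",
    mel_scale="slaney",
)
print(filters.shape)  # (201, 80)

# Project a magnitude spectrogram (freq_bins, frames) onto the mel scale.
magnitudes = np.abs(np.random.randn(201, 10))
mel_spec = filters.T @ magnitudes  # (80, 10)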
fireredtts2/codec/decoder.py ADDED
@@ -0,0 +1,700 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from fireredtts2.codec.whisper import WhisperEncoderLayer
from fireredtts2.codec.utils import make_nonpad_mask, make_block_causal_mask


class ResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.block1 = nn.Sequential(
            nn.GroupNorm(
                num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
            ),
            nn.SiLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        self.block2 = nn.Sequential(
            nn.GroupNorm(
                num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
            ),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv1d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: shape (b, c, t)
        """
        h = x
        h = self.block1(h)
        h = self.block2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h


class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


# A causal variant of Conv1d
class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size)
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x

    def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None):
        if cnn_cache is None:
            cnn_cache = x.new_zeros(
                (x.shape[0], self.in_channels, self.causal_padding[0])
            )
        x = torch.cat([cnn_cache, x], dim=2)
        new_cnn_cache = x[..., -self.causal_padding[0] :]
        x = super(CausalConv1d, self).forward(x)
        return x, new_cnn_cache


# A causal variant of ResnetBlock
class CausalResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.block1 = nn.Sequential(
            Transpose(1, 2),
            nn.LayerNorm(in_channels),
            Transpose(1, 2),
            nn.SiLU(),
            CausalConv1d(in_channels, out_channels, kernel_size=3),
        )

        self.block2 = nn.Sequential(
            Transpose(1, 2),
            nn.LayerNorm(out_channels),
            Transpose(1, 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv1d(out_channels, out_channels, kernel_size=3),
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv1d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: shape (b, c, t)
        """
        h = x
        h = self.block1(h)
        h = self.block2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h

    def forward_chunk(self, x: torch.Tensor, cache: torch.Tensor = None):
        """
        Args:
            x: shape (b, c, t)
            cache: shape (b, c_in+c_out, t=2)
        """
        cache1, cache2 = (
            (None, None)
            if cache is None
            else cache.split((self.in_channels, self.out_channels), dim=1)
        )
        h = x
        # block1
        h = self.block1[:4](h)
        h, new_cache1 = self.block1[4].forward_chunk(h, cache1)
        # block2
        h = self.block2[:5](h)
        h, new_cache2 = self.block2[5].forward_chunk(h, cache2)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        new_cache = torch.cat([new_cache1, new_cache2], dim=1)
        return x + h, new_cache


# Nonstreaming Vocos backbone based on Transformer layers
class VocosBackbone(nn.Module):
    def __init__(
        self,
        embed_dim: int = 1024,
        num_layers: int = 12,
        num_heads: int = 16,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.in_proj = nn.Conv1d(embed_dim, embed_dim, kernel_size=7, padding=3)
        self.prior_net = nn.Sequential(
            ResnetBlock(embed_dim, embed_dim, dropout=dropout),
            ResnetBlock(embed_dim, embed_dim, dropout=dropout),
        )
        self.transformers = nn.ModuleList(
            [WhisperEncoderLayer(embed_dim, num_heads) for _ in range(num_layers)]
        )
        self.post_net = nn.Sequential(
            ResnetBlock(embed_dim, embed_dim, dropout=dropout),
            ResnetBlock(embed_dim, embed_dim, dropout=dropout),
        )
        self.final_norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ):
        """
        Args:
            x: shape (b, t, c)
            x_lens: shape (b,)
        """
        x = x.transpose(1, 2)
        x = self.in_proj(x)
        x = self.prior_net(x)
        x = x.transpose(1, 2)

        attention_mask = make_nonpad_mask(x_lens).unsqueeze(1)  # (b, 1, t)
        # NOTE(sfy): I think positional embedding is unnecessary
        for layer in self.transformers:
            x = layer(x, attention_mask)
        x = x.transpose(1, 2)
        x = self.post_net(x)
        x = x.transpose(1, 2)
        x = self.final_norm(x)
        return x


# Streaming Vocos backbone based on Transformer layers
class CausalVocosBackbone(nn.Module):
    def __init__(
        self,
        embed_dim: int = 1024,
        num_layers: int = 12,
        num_heads: int = 16,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.in_proj = CausalConv1d(embed_dim, embed_dim, kernel_size=7)
        self.prior_net = nn.Sequential(
            CausalResnetBlock(embed_dim, embed_dim, dropout=dropout),
            CausalResnetBlock(embed_dim, embed_dim, dropout=dropout),
        )
        self.transformers = nn.ModuleList(
            [WhisperEncoderLayer(embed_dim, num_heads) for _ in range(num_layers)]
        )
        self.post_net = nn.Sequential(
            CausalResnetBlock(embed_dim, embed_dim, dropout=dropout),
            CausalResnetBlock(embed_dim, embed_dim, dropout=dropout),
        )
        self.final_norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ):
        """
        Args:
            x: shape (b, t, c)
            x_lens: shape (b,)
        """
        x = x.transpose(1, 2)
        x = self.in_proj(x)
        x = self.prior_net(x)
        x = x.transpose(1, 2)

        # NOTE(sfy): We have no padding in training, so it is safe for SDPA attention, no NaN.
        # Also, 1 token (12.5Hz) -> 4 latents (50Hz) -> 8 latents (100Hz),
        # so we design an 8-block causal attention mask instead of a fully causal one to improve performance
        attention_mask = make_block_causal_mask(x_lens, chunk_size=8)
        for layer in self.transformers:
            x = layer(x, attention_mask)

        x = x.transpose(1, 2)
        x = self.post_net(x)
        x = x.transpose(1, 2)
        x = self.final_norm(x)
        return x

    def forward_chunk(
        self,
        x: torch.Tensor,
        conv_cache1: torch.Tensor = None,
        conv_cache2: torch.Tensor = None,
        kv_cache: torch.Tensor = None,
    ):
        # Unpack cache
        cache1 = conv_cache1
        cache2, cache3, cache4, cache5 = (
            (None, None, None, None)
            if conv_cache2 is None
            else conv_cache2.chunk(4, dim=1)
        )

        # cache1: shape (b, c=embed_dim, t=6)
        x = x.transpose(1, 2)
        x, new_cache1 = self.in_proj.forward_chunk(x, cache1)
        # cache2: shape (b, c=embed_dim*2, t=2)
        x, new_cache2 = self.prior_net[0].forward_chunk(x, cache2)
        # cache3: shape (b, c=embed_dim*2, t=2)
        x, new_cache3 = self.prior_net[1].forward_chunk(x, cache3)
        x = x.transpose(1, 2)

        # k,v-cache: shape (b, nlayer, nh, t, c*2)
        new_kv_cache = []
        for idx, layer in enumerate(self.transformers):
            kv_cache_i = None if kv_cache is None else kv_cache[:, idx]
            x, new_kv_cache_i = layer.forward_chunk(x, kv_cache=kv_cache_i)
            new_kv_cache.append(new_kv_cache_i)
        new_kv_cache = torch.stack(new_kv_cache, dim=1)

        x = x.transpose(1, 2)
        # cache4: shape (b, c=embed_dim*2, t=2)
        x, new_cache4 = self.post_net[0].forward_chunk(x, cache4)
        # cache5: shape (b, c=embed_dim*2, t=2)
        x, new_cache5 = self.post_net[1].forward_chunk(x, cache5)
        x = x.transpose(1, 2)
        x = self.final_norm(x)

        new_conv_cache1 = new_cache1
        new_conv_cache2 = torch.cat(
            [new_cache2, new_cache3, new_cache4, new_cache5], dim=1
        )
        return x, new_conv_cache1, new_conv_cache2, new_kv_cache


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        assert padding in ["center", "same"], "Padding must be 'center' or 'same'."
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y

    def forward_chunk(
        self, spec: torch.Tensor, cache: torch.Tensor = None, last_chunk: bool = False
    ):
        """Forward one chunk of frames.

        Args:
            spec: shape (B, N, T=chunk_size)
            cache: previous chunk's last ifft frames, shape (B, N, T=3)
            last_chunk: if last_chunk, will not trim the last (win-hop) segment
        Returns:
            y: shape (B, T=effective_length)
        """
        assert self.padding == "same", "Padding must be same."
        assert (
            self.win_length % self.hop_length == 0
        ), f"{self.win_length} {self.hop_length}"
        pad = (self.win_length - self.hop_length) // 2

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]  # (B, N, T=chunk_size)

        # Append previous cache
        if cache is not None:
            ifft = torch.cat([cache, ifft], dim=-1)
        new_cache_t = self.win_length // self.hop_length - 1
        new_cache = ifft[..., -new_cache_t:]

        # Overlap and Add
        output_size = (ifft.shape[-1] - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, :]

        # Window envelope
        window_sq = (
            self.window.square().expand(1, ifft.shape[-1], -1).transpose(1, 2)
        )  # (B=1, N, T)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()

        # Normalize
        # assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        # Only take effective part
        if cache is None:
            y = y[:, pad:]
        else:
            y = y[:, (self.win_length - self.hop_length) :]
        if last_chunk:
            y = y[:, :-pad]
        else:
            y = y[:, : -(self.win_length - self.hop_length)]
        return y, new_cache


class ISTFTHead(nn.Module):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        self.hop_length = hop_length
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor, x_len: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x_pred = self.out(x)
        x_pred = x_pred.transpose(1, 2)
        mag, p = x_pred.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        # wrapping happens here. These two lines produce real and imaginary value
        x = torch.cos(p)
        y = torch.sin(p)
        # recalculating phase here does not produce anything new
        # only costs time
        # phase = torch.atan2(y, x)
        # S = mag * torch.exp(phase * 1j)
        # better directly produce the complex value
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        audio_length = x_len * self.hop_length
        return audio, audio_length

    def forward_chunk(
        self, x: torch.Tensor, cache: torch.Tensor = None, last_chunk: bool = False
    ):
        """ISTFTHead can be adapted to streaming inference without retraining.

        Args:
            x: shape (B, T, C)
            cache: shape (B, N, T=3), istft cache
        Returns:
            audio: shape (B, t)
        """
        x_pred = self.out(x)
        x_pred = x_pred.transpose(1, 2)
        mag, p = x_pred.chunk(2, dim=1)
        mag = torch.exp(mag)  # (B, C, T)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        # wrapping happens here. These two lines produce real and imaginary value
        x = torch.cos(p)
        y = torch.sin(p)
        S = mag * (x + 1j * y)  # (B, C, T)
        audio, new_cache = self.istft.forward_chunk(S, cache, last_chunk)
        return audio, new_cache


# UpsampleConv(50->100Hz) + VocosBackbone + ISTFTHead
class AcousticDecoder(nn.Module):
    def __init__(
        self,
        # Transformer
        embed_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.0,
        # iSTFT
        hop_length: int = 240,
        # Causal
        causal: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hop_length = hop_length
        self.causal = causal

        # Output upsample
        self.upsample_conv = nn.Sequential(
            nn.ConvTranspose1d(
                embed_dim,
                embed_dim,
                kernel_size=3,
                stride=2,
                padding=0,  # Do not pad the input side
                output_padding=0,  # Can be adjusted to precisely control length
            ),
            nn.GELU(),
            nn.ConvTranspose1d(
                embed_dim,
                embed_dim,
                kernel_size=3,
                stride=1,
                padding=0,  # Do not pad the input side
            ),
            nn.GELU(),
        )
        self.backbone = (
            CausalVocosBackbone(embed_dim, num_layers, num_heads, dropout)
            if causal
            else VocosBackbone(embed_dim, num_layers, num_heads, dropout)
        )
        self.isift = ISTFTHead(embed_dim, hop_length * 4, hop_length, padding="same")
        # Init weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv1d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        """
        Args:
            x: shape (b, t, c)
            x_lens: shape (b,)
        """
        # Upsample
        target_length = x.shape[1] * 2
        x = x.transpose(1, 2)
        x = self.upsample_conv(x)
        x = x.transpose(1, 2)
        # NOTE strict upsampling, trim the last 3 elements
        x = x[:, :target_length]
        x_lens = x_lens * 2
        # Backbone
        x = self.backbone(x, x_lens)
        # iSTFT
        y, y_lens = self.isift(x, x_lens)
        return y, y_lens

    def forward_upsample_conv_chunk(self, x: torch.Tensor, cache: torch.Tensor = None):
        """Stream forward upsample_conv module with previous block cache.

        Args:
            x: shape (B, C, T)
            cache: shape (B, C, 3), where 3 denotes 1 history state for the 1st conv and 2 for the other conv.
        """
        # Unpack cache
        cache1, cache2 = (
            (None, None) if cache is None else torch.split(cache, [1, 2], dim=2)
        )
        # 1st conv cache
        if cache1 is not None:
            x = torch.cat([cache1, x], dim=2)
        new_cache1 = x[..., -1:]
        # 1st conv
        x = self.upsample_conv[0](x)[..., :-1]  # remove 1 extra frame
        if cache1 is not None:
            x = x[..., 2:]  # remove cache1 part
        x = self.upsample_conv[1](x)
        # 2nd conv cache
        if cache2 is not None:
            x = torch.cat([cache2, x], dim=2)
        new_cache2 = x[..., -2:]
        # 2nd conv
        x = self.upsample_conv[2](x)[..., :-2]  # remove 2 extra frames
        if cache2 is not None:
            x = x[..., 2:]  # remove cache2 part
        x = self.upsample_conv[3](x)

        new_cache = torch.cat([new_cache1, new_cache2], dim=2)
        return x, new_cache

    def forward_chunk(
        self,
        x: torch.Tensor,
        # Upsample conv cache
        up_conv_cache: torch.Tensor = None,
        # Backbone conv cache
        bb_conv_cache1: torch.Tensor = None,
        bb_conv_cache2: torch.Tensor = None,
        # Backbone attention cache
        bb_kv_cache: torch.Tensor = None,
        # iSTFT cache
        is_cache: torch.Tensor = None,
        last_chunk: bool = False,
    ):
        """
        Args:
            x: input sequence at 50Hz, length should be a multiple of 4
        """
        assert (
            self.causal
        ), "Only AcousticDecoder with causal=True supports forward_chunk method."

        x = x.transpose(1, 2)
        x, new_up_conv_cache = self.forward_upsample_conv_chunk(x, up_conv_cache)
        x = x.transpose(1, 2)
        # Backbone
        x, new_bb_conv_cache1, new_bb_conv_cache2, new_bb_kv_cache = (
            self.backbone.forward_chunk(
                x,
                bb_conv_cache1,
                bb_conv_cache2,
                bb_kv_cache,
            )
        )
        # iSTFT
        y, new_is_cache = self.isift.forward_chunk(x, is_cache, last_chunk)
        return (
            y,
            new_up_conv_cache,
            new_bb_conv_cache1,
            new_bb_conv_cache2,
            new_bb_kv_cache,
            new_is_cache,
        )
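The forward_chunk path above is what enables streaming vocoding: each cache returned for one chunk is fed back in with the next chunk. A rough usage sketch, assuming causal=True and purely illustrative hyperparameters (embed_dim=1024, num_layers=12, num_heads=16, hop_length=240 are the module defaults used here for the example, not necessarily the released checkpoint's configuration):

import torch
from fireredtts2.codec.decoder import AcousticDecoder

decoder = AcousticDecoder(
    embed_dim=1024, num_layers=12, num_heads=16, hop_length=240, causal=True
).eval()

# (B, T, C) latents at 50 Hz; each chunk length must be a multiple of 4.
latents = torch.randn(1, 16, 1024)
chunks = latents.split(4, dim=1)

# Caches: up_conv, backbone conv1, backbone conv2, backbone kv, istft.
caches = [None, None, None, None, None]
wav_pieces = []
with torch.no_grad():
    for i, chunk in enumerate(chunks):
        y, *caches = decoder.forward_chunk(
            chunk, *caches, last_chunk=(i == len(chunks) - 1)
        )
        wav_pieces.append(y)
wav = torch.cat(wav_pieces, dim=-1)  # streamed waveform, produced chunk by chunk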
fireredtts2/codec/model.py ADDED
@@ -0,0 +1,376 @@
| 1 |
+
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
from torch.nn.utils.rnn import pad_sequence

from fireredtts2.codec.rvq import ResidualVQ
from fireredtts2.codec.decoder import AcousticDecoder
from fireredtts2.codec.utils import make_nonpad_mask
from fireredtts2.codec.whisper import (
    WhisperEncoderLayer,
    PretrainedWhisperEncoder,
    WhisperAcousticEncoder,
)


class SslAdaptor(nn.Module):
    def __init__(
        self,
        in_dim: int,
        embed_dim: int,
        out_dim: int,
        num_layers: int,
        num_heads: int,
        ffn_dim: int = None,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.dropout = dropout
        # Input Projection
        self.in_proj = nn.Linear(in_dim, embed_dim)
        # Transformer
        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(
                    embed_dim, num_heads, ffn_dim, attn_dropout, dropout
                )
                for _ in range(num_layers)
            ]
        )
        # Output norm
        self.layer_norm = nn.LayerNorm(embed_dim)
        # Output projection
        self.out_proj = nn.Linear(embed_dim, out_dim)
        # Init weight
        self.apply(self._init_weights)

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_length: torch.Tensor,
    ):
        # Downsampling
        hidden_states = self.in_proj(hidden_states)
        # Transformer
        attention_mask = make_nonpad_mask(hidden_length).unsqueeze(1)  # (b, 1, t)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states, hidden_length

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class ResidualDownConv(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        avg_pooler=4,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.avg_pooler = avg_pooler
        self.intermediate_dim = embed_dim * avg_pooler
        # Convolution layer for downsampling
        self.gate_proj = nn.Conv1d(
            embed_dim, self.intermediate_dim, avg_pooler, avg_pooler, bias=False
        )
        self.up_proj = nn.Conv1d(
            embed_dim, self.intermediate_dim, avg_pooler, avg_pooler, bias=False
        )
        # Downsampled linear projection
        self.down_proj = nn.Linear(
            self.intermediate_dim, self.intermediate_dim, bias=False
        )
        # Activation function and layer normalization
        self.act_fn = nn.SiLU()
        self.layer_norm = nn.LayerNorm(self.intermediate_dim)
        # Final output projection
        self.out_proj = nn.Linear(self.intermediate_dim, embed_dim)

    def forward(self, x: torch.Tensor, input_length: torch.Tensor):
        output_length = input_length // self.avg_pooler
        batch_size, seq_len, _ = x.shape  # (B, T, D)

        xt = x.permute(0, 2, 1)  # (B, D, T)
        g = self.gate_proj(xt).permute(0, 2, 1)  # (B, T//4, D*4)
        u = self.up_proj(xt).permute(0, 2, 1)  # (B, T//4, D*4)
        x = x.reshape(batch_size, -1, self.intermediate_dim)  # (B, T//4, D*4)

        c = self.down_proj(self.act_fn(g) * u)  # (B, T//4, D*4)
        res = self.layer_norm(c + x)  # (B, T//4, D*4)

        res = self.out_proj(res)
        return res, output_length


class UpConv(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        stride: int = 4,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.stride = stride
        self.in_proj = nn.Linear(embed_dim, self.stride * embed_dim)
        # Simple transpose convolution layer to keep channel number consistent
        self.up_conv = nn.ConvTranspose1d(
            self.stride * embed_dim,
            embed_dim,
            kernel_size=stride,
            stride=stride,
            bias=False,
        )

    def forward(self, x: torch.Tensor, input_length: torch.Tensor):
        x = self.in_proj(x)
        x = x.transpose(1, 2)
        res = self.up_conv(x)
        res = res.transpose(1, 2)
        output_length = input_length * self.stride
        return res, output_length


class RedCodec(nn.Module):
    def __init__(
        self,
        ssl: PretrainedWhisperEncoder,
        ssl_adaptor: SslAdaptor,
        acoustic_encoder: WhisperAcousticEncoder,
        downsample: ResidualDownConv,
        rvq: ResidualVQ,
        upsample: UpConv,
        semantic_decoder: SslAdaptor,
        acoustic_decoder: AcousticDecoder,
    ):
        super().__init__()
        self.ssl = ssl
        self.ssl_adaptor = ssl_adaptor
        self.acoustic_encoder = acoustic_encoder
        self.downsample = downsample
        self.rvq = rvq
        self.upsample = upsample
        self.semantic_decoder = semantic_decoder
        self.acoustic_decoder = acoustic_decoder

    @classmethod
    def from_config(cls, config_json: str) -> "RedCodec":
        with open(config_json, "rb") as f:
            config = json.load(f)["codec"]
        ssl = PretrainedWhisperEncoder.from_pretrained()
        ssl_adaptor = SslAdaptor(**config["ssl_adaptor"])
        acoustic_encoder = WhisperAcousticEncoder(**config["acoustic_encoder"])
        downsample = ResidualDownConv(**config["downsample"])
        rvq = ResidualVQ(**config["rvq"])
        upsample = UpConv(**config["upsample"])
        semantic_decoder = SslAdaptor(**config["semantic_decoder"])
        acoustic_decoder = AcousticDecoder(**config["acoustic_decoder"])
        return cls(
            ssl,
            ssl_adaptor,
            acoustic_encoder,
            downsample,
            rvq,
            upsample,
            semantic_decoder,
            acoustic_decoder,
        )


class RedCodecInfer(RedCodec):
    def __init__(self, codec: RedCodec):
        super().__init__(
            codec.ssl,
            codec.ssl_adaptor,
            codec.acoustic_encoder,
            codec.downsample,
            codec.rvq,
            codec.upsample,
            codec.semantic_decoder,
            codec.acoustic_decoder,
        )

    @classmethod
    def from_pretrained(cls, conf_path: str, ckpt_path: str) -> "RedCodecInfer":
        with open(conf_path, "r") as f:
            codec = RedCodec.from_config(conf_path)
        ckpt = torch.load(ckpt_path)["generator"]
        codec.load_state_dict(ckpt)
        return cls(codec)

    def _encode_one_batch(self, audio16k: torch.Tensor):
        B, T = audio16k.shape
        audio16k_length = torch.tensor(
            [T] * B, dtype=torch.long, device=audio16k.device
        )
        # Semantic
        ssl, ssl_length = self.ssl.forward(audio16k, audio16k_length)
        ssl = ssl.clone()  # For onnx export
        sem_feats, sem_length = self.ssl_adaptor(ssl, ssl_length)
        # Acoustic
        aco_feats, aco_length = self.acoustic_encoder(audio16k, audio16k_length)
        # VQ
        vq_in_feats = torch.cat([sem_feats, aco_feats], dim=2)
        vq_in_feats, vq_in_length = self.downsample(vq_in_feats, aco_length)
        # RVQ
        indices = self.rvq.encode_codes(vq_in_feats.transpose(1, 2))  # (nq, B, L)
        indices = indices.permute(1, 0, 2)
        return indices  # (B, nq, L)

    @staticmethod
    def _pad_and_chunk(audio: torch.Tensor, chunk_size: int) -> List[torch.Tensor]:
        pad_len = math.ceil(audio.shape[1] / chunk_size) * chunk_size - audio.shape[1]
        audio = F.pad(audio, (0, pad_len), mode="constant", value=0)
        audio_chunks = audio.split(chunk_size, dim=1)
        return audio_chunks

    @torch.inference_mode()
    def encode(
        self,
        audio16k: torch.Tensor,
        audio16k_length: torch.Tensor = None,
        batch_size: int = 96,
    ):
        """
        Args:
            audio16k: shape (b, t)
            audio16k_length: (b,)
        Returns:
            token: shape (b, nq, l)
            token_length: (b,)
        """
        if audio16k_length is None:
            assert audio16k.shape[0] == 1
            audio16k_length = torch.tensor(
                [audio16k.shape[1]], dtype=torch.long, device=audio16k.device
            )

        CHUNK_SIZE = 6 * 16000
        B, T = audio16k.shape
        # Pad, chunk, and batch
        audio16k_batch = []
        batch_size_list = []
        for i in range(B):
            # Remove extra paddings
            one_audio_chunks = self._pad_and_chunk(
                audio16k[i : (i + 1), : audio16k_length[i]], CHUNK_SIZE
            )
            audio16k_batch += one_audio_chunks
            batch_size_list.append(len(one_audio_chunks))
        audio16k_batch = torch.cat(audio16k_batch, dim=0)
        # Batch encode
        token_batch = []
        for i in range(0, audio16k_batch.shape[0], batch_size):
            one_audio_batch = audio16k_batch[i : (i + batch_size)]
            one_token_batch = self._encode_one_batch(one_audio_batch)
            token_batch.append(one_token_batch)
        token_batch = torch.cat(token_batch, dim=0)
        # Recover & concat
        token_list = torch.split(
            token_batch, batch_size_list, dim=0
        )  # [(B=1, nq, l), (B=3, nq, l), ...]
        token_list = [
            torch.cat(token_ts.split(1, dim=0), dim=-1)  # (B=1, nq, l)
            for token_ts in token_list
        ]
        # Pad tokens
        token = pad_sequence(
            [ts.squeeze(0).transpose(1, 0) for ts in token_list],
            batch_first=True,
            padding_value=0,
        ).transpose(
            1, 2
        )  # (B, nq, L)
        token_length = (audio16k_length / 1280).ceil().long()
        token = token[
            ..., : token_length.max()
        ]  # Remove extra paddings (we pad to multiples of 6s)
        return token, token_length

    @torch.inference_mode()
    def decode(self, tokens: torch.Tensor):
        """
        Args:
            tokens: (B=1, nq, L)
        Returns:
            audio: (B=1, t)
        """
        tokens = tokens.permute(1, 0, 2)  # (B, nq, L) -> (nq, B, L)
        vq_out_feats = self.rvq.decode_codes(tokens)
        vq_out_feats = vq_out_feats.transpose(1, 2)
        vq_out_length = torch.tensor(
            [vq_out_feats.shape[1]], dtype=torch.long, device=vq_out_feats.device
        )
        vq_out_feats, vq_out_length = self.upsample(vq_out_feats, vq_out_length)
        # audio: (b, t)
        audio, audio_length = self.acoustic_decoder(vq_out_feats, vq_out_length)
        return audio

    @torch.inference_mode()
    def decode_one_token(
        self, token: torch.Tensor, cache_dict: Dict[str, torch.Tensor], last_token: bool
    ):
        """Decode one single token to audio.

        Args:
            token: (B=1, nq, L=1)
        Returns:
            audio: (B=1, t)
        """
        # token->latent->upsample, (naturally causal)
        token = token.permute(1, 0, 2)  # (B, nq, L) -> (nq, B, L)
        vq_out_feats = self.rvq.decode_codes(token)
        vq_out_feats = vq_out_feats.transpose(1, 2)
        vq_out_length = torch.tensor(
            [vq_out_feats.shape[1]], dtype=torch.long, device=vq_out_feats.device
        )
        vq_out_feats, vq_out_length = self.upsample(vq_out_feats, vq_out_length)
        # acoustic decoder
        up_conv_cache = cache_dict.get("up_conv_cache", None)
        bb_conv_cache1 = cache_dict.get("bb_conv_cache1", None)
        bb_conv_cache2 = cache_dict.get("bb_conv_cache2", None)
        bb_kv_cache = cache_dict.get("bb_kv_cache", None)
        is_cache = cache_dict.get("is_cache", None)

        (
            audio,
            new_up_conv_cache,
            new_bb_conv_cache1,
            new_bb_conv_cache2,
            new_bb_kv_cache,
            new_is_cache,
        ) = self.acoustic_decoder.forward_chunk(
            vq_out_feats,
            up_conv_cache,
            bb_conv_cache1,
            bb_conv_cache2,
            bb_kv_cache,
            is_cache,
            last_token,
        )

        new_cache_dict = {
            "up_conv_cache": new_up_conv_cache,
            "bb_conv_cache1": new_bb_conv_cache1,
            "bb_conv_cache2": new_bb_conv_cache2,
            "bb_kv_cache": new_bb_kv_cache,
            "is_cache": new_is_cache,
        }
        return audio, new_cache_dict
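A minimal usage sketch (not part of the commit) of the RedCodecInfer round trip; the checkpoint paths and the random 16 kHz input below are illustrative assumptions:

import torch
from fireredtts2.codec import RedCodecInfer

codec = RedCodecInfer.from_pretrained("config_codec.json", "codec.pt")  # assumed local files
codec.eval()
audio16k = torch.randn(1, 16000 * 3)  # 3 s of dummy 16 kHz mono audio, batch size 1
tokens, token_length = codec.encode(audio16k)  # (1, nq, L); one token frame per 1280 samples
recon = codec.decode(tokens)  # (1, t) reconstructed waveform
print(tokens.shape, token_length, recon.shape)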
fireredtts2/codec/rvq.py
ADDED
@@ -0,0 +1,164 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils.parametrizations import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class VectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_project = (
            WNConv1d(
                self.input_dim, self.codebook_dim, kernel_size=1
            )  # (B, D, T) -> (B, D', T)
            if self.input_dim != self.codebook_dim
            else nn.Identity()
        )
        self.out_project = (
            WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )  # (B, D', T) -> (B, D, T)
            if self.input_dim != self.codebook_dim
            else nn.Identity()
        )

        # Initialize codebook and EMA buffers
        self.register_buffer(
            "codebook", torch.zeros(codebook_size, codebook_dim).float()
        )  # (codebook_size, D'), ensure fp32
        # Place holder, not used in inference
        self.register_buffer("inited", torch.tensor([True], dtype=torch.bool))  # (1)
        self.register_buffer(
            "cluster_size", torch.zeros(codebook_size).float()
        )  # (codebook_size), ensure fp32
        self.register_buffer(
            "embed_avg", self.codebook.clone().float()
        )  # (codebook_size, D'), ensure fp32

    def decode_code(self, embed_id):  # embed_id: (B, T)
        embed = (
            F.embedding(embed_id, self.codebook).transpose(1, 2).float()
        )  # (B, D', T), ensure fp32
        return embed

    def encode_code(self, z: torch.Tensor):  # z: (B, D, T)
        # logging.info(f"{self.cluster_size = }, {self.codebook = }, {self.embed_avg = }, {self.inited = }")
        z = z.float()  # Ensure fp32
        z_e = self.in_project(z).float()  # (B, D', T), ensure fp32

        # Rearrange for quantization
        encodings = rearrange(z_e, "b d t -> (b t) d").float()  # (B*T, D'), ensure fp32

        # Quantization
        dist = (
            encodings.pow(2).sum(1, keepdim=True)  # (B*T, 1)
            - 2 * encodings @ self.codebook.float().t()  # (B*T, codebook_size)
            + self.codebook.float().pow(2).sum(1, keepdim=True).t()
        )  # (1, codebook_size)

        # dist: (B*T, codebook_size)
        indices = (-dist).max(1)[1]  # (B*T)
        indices = rearrange(indices, "(b t) -> b t", b=z.size(0))  # (B, T)

        # Get quantized vectors
        z_q = self.decode_code(indices).float()  # (B, D', T), ensure fp32

        # Straight-through estimator
        z_q = z_e + (z_q - z_e).detach()  # (B, D', T)
        z_q = self.out_project(z_q).float()  # (B, D, T), ensure fp32

        # z_q: (B, D, T), commit_loss: (B), indices: (B, T), z: (B, D', T)
        return z_q, indices


class ResidualVQ(nn.Module):
    def __init__(
        self,
        input_dim: int = 768,  # Input dimension, unrelated to RVQ
        rvq_dim=None,  # RVQ dimension. If different from input_dim/output_dim, will add input_dim->rvq_dim/rvq_dim->output_dim projection
        output_dim: int = None,  # Output dimension, unrelated to RVQ
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,  # Dimension of each codebook. If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections
    ):
        super().__init__()
        self.input_dim = input_dim

        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.rvq_dim = rvq_dim

        self.input_proj = (
            WNConv1d(input_dim, rvq_dim, kernel_size=1)
            if input_dim != rvq_dim
            else nn.Identity()
        )
        self.output_proj = (
            WNConv1d(rvq_dim, output_dim, kernel_size=1)
            if rvq_dim != output_dim
            else nn.Identity()
        )

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=rvq_dim,
                    codebook_size=self.codebook_size,
                    codebook_dim=codebook_dim,
                )
                for i in range(num_quantizers)
            ]
        )

    def encode_codes(self, z: torch.Tensor):
        z = self.input_proj(z)
        residual = z.clone().float()  # (B, D, T), ensure fp32
        all_indices = []
        # Quantize to tokens
        for i, quantizer in enumerate(self.quantizers):
            # (B, D, T), (B), scalar, (B, T), (B, D', T), ensure fp32
            z_q_i, indices_i = quantizer.encode_code(residual)
            residual = residual - z_q_i
            all_indices.append(indices_i)  # (B, T)
        all_indices = torch.stack(all_indices)  # (N, B, T)
        return all_indices

    def decode_codes(self, codes):  # codes: (nq, B, T)
        """Decode codes from multiple quantizers to embeddings.

        Args:
            codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.

        Returns:
            emb: Tensor of shape (B, D, T) representing the decoded embeddings.
        """
        nq, B, T = codes.shape
        device = codes.device
        emb = torch.zeros(
            B, self.rvq_dim, T, device=device, dtype=torch.float32
        )  # (B, D, T)
        for i, quantizer in enumerate(self.quantizers[:nq]):
            code_i = codes[i]  # (B, T)
            quantized_i = quantizer.decode_code(code_i)  # (B, D', T)
            emb += quantizer.out_project(quantized_i)  # Accumulate quantized embeddings
        emb = self.output_proj(emb)  # (B, D, T), apply output projection
        return emb  # (B, D, T)
fireredtts2/codec/utils.py
ADDED
@@ -0,0 +1,38 @@
import math
import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask  # (b, t)


def make_nonpad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    return ~make_pad_mask(lengths, max_len)


def make_block_causal_mask(
    lengths: torch.Tensor, max_len: int = 0, chunk_size: int = 4
) -> torch.Tensor:
    mask = make_nonpad_mask(lengths, max_len)  # (b, t)
    attn_mask = torch.logical_and(mask.unsqueeze(1), mask.unsqueeze(2))  # (b, t, t)

    num_blocks = math.ceil(attn_mask.shape[1] / chunk_size)
    block_mask = torch.block_diag(
        *[torch.ones(chunk_size, chunk_size) for _ in range(num_blocks)]
    )
    block_mask = block_mask[: attn_mask.shape[1], : attn_mask.shape[1]].to(
        attn_mask
    )  # (t, t)

    diag_mask = attn_mask.new_full(
        (1, attn_mask.shape[1], attn_mask.shape[2]), fill_value=True
    ).tril()  # (1, t, t)
    diag_mask = diag_mask.logical_or(block_mask)
    attn_mask = attn_mask.logical_and(diag_mask)
    return attn_mask
fireredtts2/codec/whisper.py
ADDED
@@ -0,0 +1,420 @@
# Extracted from transformers' WhisperModel to simplify package dependency
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Literal
from fireredtts2.codec.utils import make_nonpad_mask
from fireredtts2.codec.audio import mel_filter_bank


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Returns sinusoids for positional embedding"""
    if channels % 2 != 0:
        raise ValueError(
            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
        )
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


class WhisperSdpaAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.bias = bias
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            attention_mask: Bool mask or float mask. Bool mask, True indicates should attend. Float mask is added to the attention score.
        """
        bsz, tgt_len, _ = hidden_states.size()

        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
        key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
        value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)

        # NOTE sdpa needs a 4-dim attention_mask: (b, nh, tq, tv)
        if attention_mask is not None and len(attention_mask.shape) == 3:
            attention_mask = attention_mask.unsqueeze(1)

        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )  # (bsz, nh, l, d)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)
        return attn_output

    def forward_chunk(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor = None,
    ):
        """Forward self-attention with kv cache.

        Args:
            hidden_states: shape (b, t, c)
            kv_cache: shape (b, nh, t, c*2)
        """
        bsz, tgt_len, _ = hidden_states.size()

        # shape (b, nh, t, c)
        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
        key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
        value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)

        # unpack cache
        if kv_cache is not None:
            k_cache, v_cache = kv_cache.chunk(2, dim=-1)
            key_states = torch.cat([k_cache, key_states], dim=2)
            value_states = torch.cat([v_cache, value_states], dim=2)
        new_kv_cache = torch.cat([key_states, value_states], dim=-1)

        # attention
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
        )  # (bsz, nh, l, d)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)
        return attn_output, new_kv_cache


class WhisperEncoderLayer(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ffn_dim: int = None,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dropout = dropout
        # Attention
        self.self_attn = WhisperSdpaAttention(embed_dim, num_heads, attn_dropout)
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        # FFN
        ffn_dim = ffn_dim if ffn_dim is not None else embed_dim * 4
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        # Output norm
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        # Attention
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # FFN
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = F.gelu(self.fc1(hidden_states))
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        return hidden_states

    def forward_chunk(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor = None,
    ):
        """Forward self-attention with kv cache.

        Args:
            hidden_states: shape (b, t, c)
            kv_cache: shape (b, nh, t, c*2)
        """
        # Attention
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, new_kv_cache = self.self_attn.forward_chunk(
            hidden_states, kv_cache
        )
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # FFN
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = F.gelu(self.fc1(hidden_states))
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        return hidden_states, new_kv_cache


class WhisperEncoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        embed_dim: int,
        num_layers: int,
        num_heads: int,
        ffn_dim: int = None,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        max_positions: int = 1500,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.dropout = dropout
        # Input downsampling
        self.conv1 = nn.Conv1d(in_dim, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        # Fixed positional embedding
        self.max_positions = max_positions
        self.embed_positions = nn.Embedding(self.max_positions, embed_dim)
        self.embed_positions.requires_grad_(False)
        # Transformer
        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(
                    embed_dim, num_heads, ffn_dim, attn_dropout, dropout
                )
                for _ in range(num_layers)
            ]
        )
        # Output norm
        self.layer_norm = nn.LayerNorm(embed_dim)
        # Init weight
        self.apply(self._init_weights)
        # Init position embedding
        self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape))

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_length: torch.Tensor,
        apply_position: bool = True,
    ):
        # Downsampling
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = F.gelu(self.conv1(hidden_states))
        hidden_states = F.gelu(self.conv2(hidden_states))
        hidden_states = hidden_states.transpose(1, 2)
        hidden_length = hidden_length // 2  # from 100Hz -> 50Hz
        # Pos encoding
        if apply_position:
            pos_embed = self.embed_positions(
                torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
            )
            hidden_states = hidden_states + pos_embed
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        # Transformer
        attention_mask = make_nonpad_mask(hidden_length).unsqueeze(1)  # (b, 1, t)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states, hidden_length

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class WhisperMelExtractor(nn.Module):
    def __init__(
        self,
        num_mels: int = 128,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        n_fft: int = 400,
        fmin: float = 0,
        fmax: float = 8000,
        padding_value=0.0,
    ):
        super().__init__()
        self.num_mels = num_mels
        self.sampling_rate = sampling_rate
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.padding_value = padding_value
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=(1 + n_fft // 2),
            num_mel_filters=num_mels,
            min_frequency=fmin,
            max_frequency=fmax,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    def extract_fbank(self, audio: torch.Tensor):
        """
        Args:
            audio: batched audio of shape (b, t)
        """
        device = audio.device  # compute on cuda if input is on cuda
        # Mel
        window = torch.hann_window(self.n_fft).to(device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )
        magnitudes = stft[..., :-1].abs() ** 2
        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32).to(device)
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Norm
        max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
        log_spec = torch.maximum(log_spec, max_val - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

    def __call__(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
        mel = self.extract_fbank(audio16k).transpose(1, 2)
        mel_length = audio16k_length // self.hop_length
        # mel: (b, t, c=128)
        return mel, mel_length


# Pretrained encoder from whisper-large-v3
class PretrainedWhisperEncoder(WhisperEncoder):
    @classmethod
    def from_pretrained(cls, pretrained_path: str = None):
        encoder = cls(
            in_dim=128,
            embed_dim=1280,
            num_layers=32,
            num_heads=20,
            ffn_dim=5120,
            attn_dropout=0.0,
            max_positions=1500,
        )
        if pretrained_path is not None:
            ckpt = torch.load(pretrained_path, map_location="cpu")
            encoder.load_state_dict(ckpt)
        encoder.eval()
        # Disable grad
        for p in encoder.parameters():
            p.requires_grad_(False)
        # Add Mel extractor
        encoder.feature_extractor = WhisperMelExtractor(
            num_mels=128,
            sampling_rate=16000,
            hop_length=160,
            n_fft=400,
            fmin=0,
            fmax=8000,
        )
        return encoder

    @torch.inference_mode()
    def forward(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
        # Extract mel
        mel, mel_length = self.feature_extractor(audio16k, audio16k_length)
        # Forward model
        semantic_feats, semantic_length = super().forward(
            mel, mel_length, apply_position=True
        )
        return semantic_feats, semantic_length


class WhisperAcousticEncoder(WhisperEncoder):
    def __init__(
        self,
        # Mel extraction params
        num_mels: int = 128,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        n_fft: int = 400,
        fmin: float = 0.0,
        fmax: float = 8000,
        # Encoder params
        embed_dim: int = 768,
        num_layers: int = 12,
        num_heads: int = 8,
        ffn_dim: int = None,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        max_positions: int = 1500,  # 50Hz * 30s
    ):
        super().__init__(
            in_dim=num_mels,
            embed_dim=embed_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            attn_dropout=attn_dropout,
            dropout=dropout,
            max_positions=max_positions,
        )
        self.feature_extractor = WhisperMelExtractor(
            num_mels=num_mels,
            sampling_rate=sampling_rate,
            hop_length=hop_length,
            n_fft=n_fft,
            fmin=fmin,
            fmax=fmax,
        )

    def forward(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
        # Extract mel
        with torch.no_grad():
            mel, mel_length = self.feature_extractor(audio16k, audio16k_length)
        # Forward model
        hidden_states, hidden_length = super().forward(
            mel, mel_length, apply_position=True
        )
        return hidden_states, hidden_length
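A minimal sketch (not part of the commit) of running the acoustic encoder on dummy audio; the small hyper-parameters here are assumptions for illustration, not the shipped configuration:

import torch
from fireredtts2.codec.whisper import WhisperAcousticEncoder

enc = WhisperAcousticEncoder(embed_dim=256, num_layers=2, num_heads=4).eval()
audio16k = torch.randn(1, 16000)  # 1 s of dummy 16 kHz audio
length = torch.tensor([16000])
with torch.no_grad():
    feats, feat_length = enc(audio16k, length)  # 50 Hz features: (1, 50, 256), length 50
print(feats.shape, feat_length)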
fireredtts2/fireredtts2.py
ADDED
@@ -0,0 +1,459 @@
import os
import time
import json
import torch
import torchaudio

from typing import List, Tuple
from fireredtts2.codec import RedCodecInfer
from fireredtts2.llm import load_llm_model, load_custom_tokenizer
from fireredtts2.llm.utils import Segment
from fireredtts2.utils.spliter import clean_text, split_text, process_text_list
from tqdm import tqdm


class FireRedTTS2:
    def __init__(self, pretrained_dir, gen_type, device):

        assert os.path.exists(pretrained_dir)
        assert gen_type in ["monologue", "dialogue"]
        llm_config_path = os.path.join(pretrained_dir, "config_llm.json")
        if gen_type == "monologue":
            llm_ckpt_path = os.path.join(pretrained_dir, "llm_pretrain.pt")
            # llm_ckpt_path = os.path.join(pretrained_dir, "llm_posttrain.pt")
        else:
            llm_ckpt_path = os.path.join(pretrained_dir, "llm_posttrain.pt")
        codec_config_path = os.path.join(pretrained_dir, "config_codec.json")
        codec_ckpt_path = os.path.join(pretrained_dir, "codec.pt")
        pretrained_qwen_path = os.path.join(pretrained_dir, "Qwen2.5-1.5B")

        # check
        assert os.path.exists(llm_config_path)
        assert os.path.exists(llm_ckpt_path)
        assert os.path.exists(codec_config_path)
        assert os.path.exists(codec_ckpt_path)
        assert os.path.exists(pretrained_qwen_path)

        # ==== Load Torch LLM ====
        llm_config = json.load(open(llm_config_path))
        self._model = load_llm_model(
            configs=llm_config, checkpoint_path=llm_ckpt_path, device=device
        )
        self._model.eval()
        self._model.setup_caches(1)
        print("[INFO] LLM Loaded...")

        # ==== Load Qwen2.5 Text Tokenizer ====
        self._text_tokenizer = load_custom_tokenizer(pretrained_qwen_path)
        print("[INFO] Text Tokenizer Loaded...")

        # ==== Load Torch Audio Tokenizer ====
        torch_codec = RedCodecInfer.from_pretrained(codec_config_path, codec_ckpt_path)
        torch_codec.eval()
        self._audio_tokenizer = torch_codec.to(device)
        print("[INFO] Codec Loaded...")

        self.sample_rate = 16000
        self.device = device
        self.max_seq_len = 3100

    def load_prompt_audio(self, audio_path) -> torch.Tensor:
        audio, audio_sr = torchaudio.load(audio_path)
        # Audio must be single channel
        if audio.shape[0] > 1:
            audio = audio[0, :].unsqueeze(0)
        audio16k = torchaudio.functional.resample(audio, audio_sr, 16000)
        return audio16k

    def prepare_prompt(self, text, speaker, audio_path) -> Segment:
        audio_tensor = self.load_prompt_audio(audio_path)
        return Segment(text=text, speaker=speaker, audio=audio_tensor)

    def _tokenize_text_segment(
        self, text: str, speaker: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        text = speaker + "<|text_start|>" + text + "<|text_end|>"
        text_tokens = self._text_tokenizer.encode(text)
        text_frame = torch.zeros(len(text_tokens), 17).long()
        text_frame_mask = torch.zeros(len(text_tokens), 17).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True

        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        # (K, T)
        audio_length = torch.tensor([audio.shape[1]], dtype=torch.long)
        audio_tokens, token_length = self._audio_tokenizer.encode(
            audio.to(self.device),
            audio_length.to(self.device),
            batch_size=48,
        )

        audio_tokens = audio_tokens.squeeze(0)
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        audio_frame = torch.zeros(audio_tokens.size(1), 17).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 17).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True

        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len,17), (seq_len, 17)
        """
        text_tokens, text_masks = self._tokenize_text_segment(
            segment.text, segment.speaker
        )
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat(
            [text_masks, audio_masks], dim=0
        )

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: str,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 20,
    ) -> torch.Tensor:
        self._model.reset_caches()

        max_generation_len = int(max_audio_length_ms / 80)
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
            text, speaker
        )
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = (
            torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
        )

        max_seq_len = 3100
        max_context_len = max_seq_len - max_generation_len
        if curr_tokens.size(1) >= max_context_len:
            raise ValueError(
                f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
            )

        for _ in range(max_generation_len):
            sample = self._model.generate_frame(
                curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
            )
            # eos
            if torch.all(sample == 0):
                break

            samples.append(sample)

            curr_tokens = torch.cat(
                [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [
                    torch.ones_like(sample).bool(),
                    torch.zeros(1, 1).bool().to(self.device),
                ],
                dim=1,
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        audio = (
            self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0))
            .squeeze(0)
            .squeeze(0)
        )

        return audio

    def generate_single(
        self, context: List[Segment], temperature: float = 0.9, topk: int = 20
    ):
        self._model.reset_caches()
        max_generation_len = 400
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
        prompt_tokens = prompt_tokens[:-3, :]
        prompt_tokens_mask = prompt_tokens_mask[:-3, :]

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = (
            torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
        )

        num_token = 0
        start_time = time.time()
        for _ in range(max_generation_len):
            sample = self._model.generate_frame(
                curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
            )
            # eos
            if torch.all(sample == 0):
                break

            samples.append(sample)

            curr_tokens = torch.cat(
                [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [
                    torch.ones_like(sample).bool(),
                    torch.zeros(1, 1).bool().to(self.device),
                ],
                dim=1,
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1
            num_token += 1
            if num_token == 2:
                end_time = time.time()
                duration = end_time - start_time
                print("---first pack duration:", duration)

        gen_tokens = torch.stack(samples).permute(1, 2, 0)

        return gen_tokens

    # @torch.inference_mode()
    # def generate_stream(
    #     self,
    #     text: str,
    #     speaker: str,
    #     context: List[Segment],
    #     max_audio_length_ms: float = 90_000,
    #     temperature: float = 0.9,
    #     topk: int = 50,
    # ):
    #     self._model.reset_caches()

    #     max_generation_len = int(max_audio_length_ms / 80)
    #     tokens, tokens_mask = [], []
    #     for segment in context:
    #         segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
    #         tokens.append(segment_tokens)
    #         tokens_mask.append(segment_tokens_mask)

    #     gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
    #         text, speaker
    #     )
    #     tokens.append(gen_segment_tokens)
    #     tokens_mask.append(gen_segment_tokens_mask)

    #     prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
    #     prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

    #     samples = []
    #     curr_tokens = prompt_tokens.unsqueeze(0)
    #     curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
    #     curr_pos = (
    #         torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
    #     )

    #     max_seq_len = 3100
    #     max_context_len = max_seq_len - max_generation_len
    #     if curr_tokens.size(1) >= max_context_len:
    #         raise ValueError(
    #             f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
    #         )

    #     # codec cache
    #     codec_cache = {}
    #     prev_sample = None

    #     for _ in range(max_generation_len):
    #         sample = self._model.generate_frame(
    #             curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
    #         )
    #         # eos
    #         if torch.all(sample == 0):
    #             break

    #         # decode one token
    #         if prev_sample is None:
    #             prev_sample = sample  # sample: (b, nq)
    #         else:
    #             audio_chunk, codec_cache = self._audio_tokenizer.decode_one_token(
    #                 prev_sample.unsqueeze(-1),
    #                 codec_cache,
    #                 last_token=False,
    #             )
    #             yield audio_chunk.squeeze(0)
    #             prev_sample = sample
    #         samples.append(sample)  # sample: (b, nq)

    #         curr_tokens = torch.cat(
    #             [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
    #         ).unsqueeze(1)
    #         curr_tokens_mask = torch.cat(
    #             [
    #                 torch.ones_like(sample).bool(),
    #                 torch.zeros(1, 1).bool().to(self.device),
    #             ],
    #             dim=1,
    #         ).unsqueeze(1)
    #         curr_pos = curr_pos[:, -1:] + 1

    #     audio_chunk, codec_cache = self._audio_tokenizer.decode_one_token(
    #         prev_sample.unsqueeze(-1),
    #         codec_cache,
    #         last_token=True,
    #     )
    #     yield audio_chunk.squeeze(0)

    @torch.inference_mode()
    def generate_dialogue(
        self,
        text_list,
        prompt_wav_list=None,
        prompt_text_list=None,
        temperature=0.9,
        topk=20,
    ):
        all_generated_segments = []
        all_storage_segments = []
        prompt_segments = []
        text_list = process_text_list(text_list=text_list)
        if prompt_wav_list is not None:
            assert len(prompt_wav_list) == len(prompt_text_list)
            # Prepare prompts
            for i in range(len(prompt_wav_list)):
                prompt_wav = prompt_wav_list[i]
                prompt_text = prompt_text_list[i]
                speaker = prompt_text[:4]
                assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
                prompt_segments.append(
                    self.prepare_prompt(
                        text=prompt_text, speaker=speaker, audio_path=prompt_wav
|
| 369 |
+
)
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
for text in tqdm(text_list):
|
| 373 |
+
speaker = text[:4]
|
| 374 |
+
text = text[4:]
|
| 375 |
+
# print("---speaker:", speaker)
|
| 376 |
+
# print("---text:", text)
|
| 377 |
+
assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
|
| 378 |
+
|
| 379 |
+
audio_tensor = self.generate(
|
| 380 |
+
text=text,
|
| 381 |
+
speaker=speaker,
|
| 382 |
+
context=prompt_segments + all_generated_segments,
|
| 383 |
+
max_audio_length_ms=30_000,
|
| 384 |
+
temperature=temperature,
|
| 385 |
+
topk=topk,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# 做上下文管理的时候需要将audio 转到16k
|
| 389 |
+
audio_16k = torchaudio.functional.resample(
|
| 390 |
+
audio_tensor.unsqueeze(0), 24000, 16000
|
| 391 |
+
)
|
| 392 |
+
all_generated_segments.append(
|
| 393 |
+
Segment(text=text, speaker=speaker, audio=audio_16k)
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
all_storage_segments.append(
|
| 397 |
+
Segment(text=text, speaker=speaker, audio=audio_tensor.unsqueeze(0))
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
# Concatenate all generations
|
| 401 |
+
all_audio = torch.cat([seg.audio for seg in all_storage_segments], dim=1)
|
| 402 |
+
all_audio = all_audio.cpu()
|
| 403 |
+
return all_audio
|
| 404 |
+
|
| 405 |
+
@torch.inference_mode()
|
| 406 |
+
def generate_monologue(
|
| 407 |
+
self, text, prompt_wav=None, prompt_text=None, temperature=0.75, topk=20
|
| 408 |
+
):
|
| 409 |
+
# step1. construct context
|
| 410 |
+
if prompt_wav is not None:
|
| 411 |
+
assert os.path.exists(prompt_wav)
|
| 412 |
+
assert prompt_text is not None
|
| 413 |
+
|
| 414 |
+
all_generated_segments = []
|
| 415 |
+
all_storage_segments = []
|
| 416 |
+
prompt_segments = []
|
| 417 |
+
prompt_text = clean_text(text=prompt_text)
|
| 418 |
+
text = clean_text(text=text)
|
| 419 |
+
text_list = split_text(text=text, length=400)
|
| 420 |
+
|
| 421 |
+
audio_list = []
|
| 422 |
+
for text in text_list:
|
| 423 |
+
text = clean_text(text=text)
|
| 424 |
+
input_text = prompt_text[:-1] + "," + text
|
| 425 |
+
prompt_a = self.prepare_prompt(
|
| 426 |
+
text=input_text, speaker="[S1]", audio_path=prompt_wav
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
context = [prompt_a]
|
| 430 |
+
|
| 431 |
+
while True:
|
| 432 |
+
gen_tokens = self.generate_single(
|
| 433 |
+
context=context, temperature=temperature, topk=topk
|
| 434 |
+
)
|
| 435 |
+
if gen_tokens.shape[2] > 18:
|
| 436 |
+
break
|
| 437 |
+
# else:
|
| 438 |
+
# print("生成结果小于1s,重新跑")
|
| 439 |
+
|
| 440 |
+
gen_tokens = gen_tokens[:, :, 2:] # cut leading silence
|
| 441 |
+
audio = self._audio_tokenizer.decode(gen_tokens).squeeze(0).squeeze(0)
|
| 442 |
+
audio_list.append(audio.unsqueeze(0))
|
| 443 |
+
|
| 444 |
+
all_audio = torch.cat(tensors=audio_list, dim=1)
|
| 445 |
+
|
| 446 |
+
return all_audio
|
| 447 |
+
|
| 448 |
+
else:
|
| 449 |
+
# random speaker
|
| 450 |
+
text = clean_text(text=text.strip())
|
| 451 |
+
audio_tensor = self.generate(
|
| 452 |
+
text=text,
|
| 453 |
+
speaker="[S1]",
|
| 454 |
+
context=[],
|
| 455 |
+
max_audio_length_ms=30_000,
|
| 456 |
+
temperature=temperature,
|
| 457 |
+
topk=topk,
|
| 458 |
+
)
|
| 459 |
+
return audio_tensor.unsqueeze(0)
|
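For orientation, here is a minimal usage sketch of the two public entry points above; it is not part of the commit. The class name FireRedTTS2 and the constructor arguments pretrained_dir, gen_type, and device are assumptions based on the surrounding repository, not confirmed by the excerpt shown here.

import torchaudio
from fireredtts2.fireredtts2 import FireRedTTS2  # assumed class name

# Assumed constructor signature; adjust to the actual __init__ defined earlier in this file.
tts = FireRedTTS2(
    pretrained_dir="pretrained_models",
    gen_type="dialogue",
    device="cuda",
)
audio = tts.generate_dialogue(
    text_list=["[S1]Hello there.", "[S2]Hi, how are you?"],
    temperature=0.9,
    topk=20,
)
# generate_dialogue concatenates the 24 kHz storage segments into one (1, num_samples) tensor on CPU.
torchaudio.save("dialogue.wav", audio, 24000)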
fireredtts2/llm/__init__.py
ADDED
@@ -0,0 +1 @@
from fireredtts2.llm.utils import load_llm_model, load_custom_tokenizer
fireredtts2/llm/llm.py
ADDED
@@ -0,0 +1,371 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from huggingface_hub import PyTorchModelHubMixin
from fireredtts2.llm.modules import FLAVORS


def _prepare_transformer(model):
    embed_dim = model.tok_embeddings.embedding_dim
    model.tok_embeddings = nn.Identity()
    model.output = nn.Identity()
    return model, embed_dim


def _create_causal_mask(seq_len: int, device: torch.device):
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """
    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len)

    Returns:
        (batch_size, seq_len, max_seq_len)
    """
    r = mask[input_pos, :]
    return r


# Does multinomial sampling without a cuda synchronization
def _multinomial_sample_one_no_sync(probs):
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    logits = logits / temperature

    filter_value: float = -float("Inf")
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token


def sample_top_nsigma(logits: torch.Tensor, n: float, temperature: float):
    """Top-n-sigma sampling: keep only logits within n standard deviations of the maximum.

    Args:
        logits (torch.Tensor): unnormalized scores of shape (..., vocab_size)
        n (float): number of standard deviations to keep
        temperature (float): sampling temperature

    Returns:
        torch.Tensor: sampled token indices
    """
    logits = logits / temperature
    threshold = logits.max(dim=-1, keepdim=True).values - n * logits.std(
        dim=-1, keepdim=True
    )
    logits[logits < threshold] = float("-inf")
    # scores_processed = torch.nn.functional.log_softmax(logits, dim=-1)
    probs = torch.nn.functional.softmax(logits, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token


@dataclass
class ModelArgs:
    backbone_flavor: str
    decoder_flavor: str
    text_vocab_size: int
    audio_vocab_size: int
    audio_num_codebooks: int
    decoder_loss_weight: float
    use_text_loss: bool


class Model(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config

        self.backbone, backbone_dim = _prepare_transformer(
            FLAVORS[config.backbone_flavor]()
        )
        self.decoder, decoder_dim = _prepare_transformer(
            FLAVORS[config.decoder_flavor]()
        )

        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
        self.audio_embeddings = nn.Embedding(
            config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
        )

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.text_head = nn.Linear(backbone_dim, config.text_vocab_size, bias=False)
        self.codebook0_head = nn.Linear(
            backbone_dim, config.audio_vocab_size, bias=False
        )
        self.audio_head = nn.Parameter(
            torch.empty(
                config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size
            )
        )

        self.decoder_loss_weight = config.decoder_loss_weight
        self.use_text_loss = config.use_text_loss

        # debug
        # print("---backbone_dim:", backbone_dim)
        # print("---decoder_dim:", decoder_dim)
        # print("---self.decoder_loss_weight:", self.decoder_loss_weight)
        # print("---self.use_text_loss:", self.use_text_loss)

    def setup_caches(self, max_batch_size: int) -> torch.Tensor:
        """Setup KV caches and return a causal mask."""
        dtype = next(self.parameters()).dtype
        device = next(self.parameters()).device

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            self.decoder.setup_caches(
                max_batch_size,
                dtype,
                decoder_max_seq_len=self.config.audio_num_codebooks,
            )

        self.register_buffer(
            "backbone_causal_mask",
            _create_causal_mask(self.backbone.max_seq_len, device),
        )
        self.register_buffer(
            "decoder_causal_mask",
            _create_causal_mask(self.config.audio_num_codebooks, device),
        )

    def forward(self, tokens: torch.Tensor, tokens_mask: torch.Tensor):
        """
        Forward pass for Sesame's CSM model.
        This will be added to the model with `model.forward = types.MethodType(forward, model)`

        Args:
            tokens: (batch_size, seq_len, n_codebooks+1)
            tokens_mask: (batch_size, seq_len, n_codebooks+1)
        """

        dtype = next(self.parameters()).dtype
        bsz, seq_len, _ = tokens.size()
        device = tokens.device

        # print("---tokens:\n", tokens, tokens.shape)
        # print("---tokens_mask:\n", tokens_mask, tokens_mask.shape)
        # print("---bsz:", bsz)
        # print("---seq_len:", seq_len)

        # embed tokens
        embeds = self._embed_tokens(tokens)  # (bsz, seq_len, 33, 2048)
        # print("---embeds:\n", embeds, embeds.shape)

        # get targets and codebook embeddings corresponding to audio tokens
        audio_mask = tokens_mask[:, :, 0]  # [bsz, seq_len]
        target_tokens = tokens[audio_mask][:, :-1]  # [audio_len, n_codebooks]
        # [audio_len, n_codebooks, embed_dim]
        c_embeds = embeds[:, :, :-1, :][audio_mask]
        # print("---audio_mask:\n", audio_mask, audio_mask.shape)
        # print("---target_tokens:\n", target_tokens, target_tokens.shape)

        # get targets corresponding to text tokens
        text_mask = tokens_mask[:, :, -1]
        text_target_mask = torch.roll(input=text_mask, shifts=1, dims=1)
        text_target_tokens = tokens[text_target_mask][:, -1]

        # print("---text_target_mask:\n", text_target_mask, text_target_mask.shape)
        # print("---target_text_tokens:\n", text_target_tokens, text_target_tokens.shape)

        # retain just non-padding embeddings
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)

        # backbone forward pass
        # [bsz, seq_len]
        padding_mask = tokens_mask[:, :, 0] | tokens_mask[:, :, -1]
        # [seq_len, seq_len]
        backbone_attn_mask = _create_causal_mask(seq_len, device)
        # [bsz, seq_len, seq_len]
        padding_3d = padding_mask.unsqueeze(-1) * padding_mask.unsqueeze(1)
        backbone_attn_mask = backbone_attn_mask.unsqueeze(0) * padding_3d
        backbone_attn_mask = backbone_attn_mask | torch.eye(
            seq_len, device=device
        ).bool().unsqueeze(0).expand(bsz, -1, -1)
        input_pos = (
            torch.arange(0, seq_len).unsqueeze(0).expand(bsz, seq_len).long().to(device)
        )
        h = self.backbone(h, input_pos=input_pos, mask=backbone_attn_mask).to(
            dtype=dtype
        )
        # print("---h:\n", h, h.shape)

        # get backbone embeddings used for audio codebook prediction; predict first codebook and compute loss
        audio_mask = torch.roll(audio_mask, -1, 1)  # shift audio mask to the right by 1
        audio_h = h[audio_mask]  # [audio_len, embed_dim]
        # print("---audio_mask after shift:\n", audio_mask, audio_mask.shape)
        c0_logits = self.codebook0_head(audio_h)  # [audio_len, audio_vocab_size]
        c0_target = target_tokens[:, 0]  # [audio_len]
        c0_loss = F.cross_entropy(c0_logits, c0_target)

        # predict text loss
        text_h = h[text_mask]
        text_logits = self.text_head(text_h)
        text_loss = F.cross_entropy(text_logits, text_target_tokens, ignore_index=0)
        # print("---text_h:\n", text_h, text_h.shape)
        # print("---text_logits:\n", text_logits)
        # print("---text_loss:", text_loss)

        # "compute amortization": train the decoder on a random subset of audio tokens
        # (important: changed from 1/16 to 1/8)
        # indices = torch.randperm(c_embeds.size(0))[: c_embeds.size(0) // 16]
        indices = torch.randperm(c_embeds.size(0))[: c_embeds.size(0) // 8]
        # [audio_len//8, n_codebooks-1, embed_dim]
        c_embeds = c_embeds[indices][:, :-1, :]
        audio_h = audio_h[indices]  # [audio_len//8, embed_dim]
        target_tokens = target_tokens[indices][:, 1:]  # [audio_len//8, n_codebooks-1]

        # concatenate backbone embeddings and codebook embeddings for decoder input
        # [audio_len//8, n_codebooks, embed_dim]
        decoder_embeds = torch.cat([audio_h.unsqueeze(1), c_embeds], dim=1)
        N, n_codebooks, _ = decoder_embeds.size()
        c_pos = (
            torch.arange(0, n_codebooks)
            .unsqueeze(0)
            .expand(N, n_codebooks)
            .long()
            .to(device)
        )

        decoder_causal_mask = _create_causal_mask(
            decoder_embeds.size(1), device
        ).expand(N, -1, -1)
        decoder_h = self.decoder(
            self.projection(decoder_embeds), input_pos=c_pos, mask=decoder_causal_mask
        ).to(dtype=dtype)
        c_logits = torch.einsum("bsd,sdv->bsv", decoder_h[:, 1:, :], self.audio_head)

        c_loss = F.cross_entropy(
            c_logits.reshape(-1, c_logits.size(-1)), target_tokens.reshape(-1)
        )

        if self.use_text_loss:
            loss = (
                2
                * (
                    (1 - self.decoder_loss_weight) * c0_loss
                    + self.decoder_loss_weight * c_loss
                )
                + 0.01 * text_loss
            )
        else:
            loss = 2 * (
                (1 - self.decoder_loss_weight) * c0_loss
                + self.decoder_loss_weight * c_loss
            )
        return loss, text_loss, c0_loss, c_loss

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
    ) -> torch.Tensor:
        """
        Args:
            tokens: (batch_size, seq_len, audio_num_codebooks+1)
            tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
            input_pos: (batch_size, seq_len) positions for each token
            mask: (batch_size, seq_len, max_seq_len)

        Returns:
            (batch_size, audio_num_codebooks) sampled tokens
        """
        dtype = next(self.parameters()).dtype
        b, s, _ = tokens.size()

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
        embeds = self._embed_tokens(tokens)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(
            dtype=dtype
        )

        last_h = h[:, -1, :]
        c0_logits = self.codebook0_head(last_h)
        c0_sample = sample_topk(c0_logits, topk, temperature)
        c0_embed = self._embed_audio(0, c0_sample)
        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = (
            torch.arange(0, curr_h.size(1), device=curr_h.device)
            .unsqueeze(0)
            .repeat(curr_h.size(0), 1)
        )

        # Decoder caches must be reset every frame.
        self.decoder.reset_caches()
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(
                self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask
            ).to(dtype=dtype)
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            ci_sample = sample_topk(ci_logits, 10, 0.75)  # fixed to topk=10, temperature=0.75
            ci_embed = self._embed_audio(i, ci_sample)
            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)

    def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)

        audio_tokens = tokens[:, :, :-1] + (
            self.config.audio_vocab_size
            * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
        )

        return torch.cat([audio_embeds, text_embeds], dim=-2)


if __name__ == "__main__":

    MIMI_SAMPLE_RATE = 24000
    BACKBONE_FLAVOR = "qwen-3b"
    DECODER_FLAVOR = "qwen-500m"
    TEXT_VOCAB_SIZE = 128256
    AUDIO_VOCAB_SIZE = 2051
    AUDIO_NUM_CODEBOOKS = 32

    config = ModelArgs(
        backbone_flavor=BACKBONE_FLAVOR,
        decoder_flavor=DECODER_FLAVOR,
        text_vocab_size=TEXT_VOCAB_SIZE,
        audio_vocab_size=AUDIO_VOCAB_SIZE,
        audio_num_codebooks=AUDIO_NUM_CODEBOOKS,
        decoder_loss_weight=0.5,
        use_text_loss=True,
    )
    model = Model(config)
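As a quick sanity check of the sampling helper defined above, the following sketch (not part of the commit) draws one token per row from random logits; the vocabulary size 2051 simply mirrors AUDIO_VOCAB_SIZE from the __main__ block.

import torch
from fireredtts2.llm.llm import sample_topk

# Dummy logits for a batch of 2 rows over a 2051-entry audio vocabulary.
logits = torch.randn(2, 2051)
token = sample_topk(logits, 20, 0.9)  # topk=20, temperature=0.9
print(token.shape)  # torch.Size([2, 1]): one sampled index per row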
fireredtts2/llm/modules.py
ADDED
@@ -0,0 +1,90 @@
from torchtune.models.qwen2 import qwen2
from torchtune.modules.transformer import TransformerDecoder


def qwen2_200M() -> TransformerDecoder:
    return qwen2(
        vocab_size=151936,
        num_layers=4,
        num_heads=12,
        num_kv_heads=2,
        embed_dim=1536,
        intermediate_dim=8960,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-6,
        rope_base=1000000.0,
        tie_word_embeddings=True,
    )


def qwen2_500M() -> TransformerDecoder:
    return qwen2(
        vocab_size=151936,
        num_layers=24,
        num_heads=14,
        num_kv_heads=2,
        embed_dim=896,
        intermediate_dim=4864,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-6,
        rope_base=1000000.0,
        tie_word_embeddings=True,
    )


def qwen2_1_5B() -> TransformerDecoder:
    return qwen2(
        vocab_size=151936,
        num_layers=28,
        num_heads=12,
        num_kv_heads=2,
        embed_dim=1536,
        intermediate_dim=8960,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-6,
        rope_base=1000000.0,
        tie_word_embeddings=True,
    )


def qwen2_3B() -> TransformerDecoder:
    return qwen2(
        vocab_size=151936,
        num_layers=36,
        num_heads=16,
        num_kv_heads=2,
        embed_dim=2048,
        intermediate_dim=11008,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-6,
        rope_base=1000000.0,
        tie_word_embeddings=True,
    )


def qwen2_7B() -> TransformerDecoder:
    return qwen2(
        vocab_size=152064,
        num_layers=28,
        num_heads=28,
        num_kv_heads=4,
        embed_dim=3584,
        intermediate_dim=18944,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-6,
        rope_base=1000000.0,
    )


FLAVORS = {
    "qwen-200m": qwen2_200M,
    "qwen-500m": qwen2_500M,
    "qwen-1.5b": qwen2_1_5B,
    "qwen-3b": qwen2_3B,
    "qwen-7b": qwen2_7B,
}
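The FLAVORS registry above is just a mapping from a name to a torchtune constructor, so a decoder can be built and inspected as sketched here (illustrative only, not part of the commit; the parameter-count printout is an assumption about how one might use it, not repository code).

from fireredtts2.llm.modules import FLAVORS

# Build the 500M-flavor TransformerDecoder used as the CSM decoder in llm.py.
decoder = FLAVORS["qwen-500m"]()
n_params = sum(p.numel() for p in decoder.parameters())
print(f"qwen-500m parameters: {n_params / 1e6:.1f}M")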
fireredtts2/llm/utils.py
ADDED
@@ -0,0 +1,303 @@
import os
import json
import torch
import torch.nn as nn
from pathlib import Path
from dataclasses import dataclass
from typing import Union
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoTokenizer
from fireredtts2.llm.llm import Model, ModelArgs


@dataclass
class Segment:
    speaker: str
    text: str
    audio: torch.Tensor


class WarmupDecayLR(LambdaLR):
    """
    Learning rate scheduler with a linear warmup and a configurable decay.
    """

    def __init__(
        self, optimizer, warmup_steps: int, total_steps: int, decay_type: str = "linear"
    ):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_type = decay_type
        super().__init__(optimizer, self.lr_lambda, last_epoch=-1)

    def lr_lambda(self, step: int) -> float:
        if step < self.warmup_steps:
            return step / self.warmup_steps
        else:
            if self.decay_type == "linear":
                return (self.total_steps - step) / (
                    self.total_steps - self.warmup_steps
                )
            elif self.decay_type == "constant":
                return 1.0
            elif self.decay_type == "exponential":
                return 0.1 ** (
                    (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
                )
            elif self.decay_type == "cosine":
                return 0.5 * (
                    1
                    + torch.cos(
                        torch.pi
                        * torch.tensor(
                            (step - self.warmup_steps)
                            / (self.total_steps - self.warmup_steps)
                        )
                    )
                )
            else:
                raise ValueError(f"Invalid decay type: {self.decay_type}")


additional_special_tokens = [
    "<|text_start|>",
    "<|text_end|>",
    "[S1]",
    "[S2]",
    "[S3]",
    "[S4]",
    "[S5]",
    "[S6]",
    "[S7]",
    "[S8]",
    "[S9]",
    "[S10]",
    "[S11]",
    "[S12]",
    "[S13]",
    "[S14]",
    "[S15]",
    "[S16]",
    "[S17]",
    "[S18]",
    "[S19]",
    "[S20]",
    "[S21]",
    "[S22]",
    "[S23]",
    "[S24]",
    "[S25]",
    "[S26]",
    "[S27]",
    "[S28]",
    "[S29]",
    "[S30]",
    "[S31]",
    "[S32]",
    "[S33]",
    "[S34]",
    "[S35]",
    "[S36]",
    "[S37]",
    "[S38]",
    "[S39]",
    "[S40]",
    "[S_PODCAST_1]",
    "[S_PODCAST_2]",
    "[S_PODCAST_3]",
    "[S_PODCAST_4]",
    "[S_PODCAST_5]",
    "[S_PODCAST_6]",
    "[S_PODCAST_7]",
    "[S_PODCAST_8]",
    "[S_PODCAST_9]",
    "[S_PODCAST_10]",
    "[S_DIALOG_1]",
    "[S_DIALOG_2]",
    "[S_DIALOG_3]",
    "[S_DIALOG_4]",
    "[S_DIALOG_5]",
    "[S_DIALOG_6]",
    "[S_DIALOG_7]",
    "[S_DIALOG_8]",
    "[S_DIALOG_9]",
    "[S_DIALOG_10]",
    "<|emotion_neutral|>",
    "<|emotion_happy|>",
    "<|emotion_sad|>",
    "<|emotion_concern|>",
    "<|emotion_confuse|>",
    "<|emotion_angry|>",
    "<|emotion_surprise|>",
    "<|emotion_disgust|>",
    "<|emotion_nervous|>",
    "<|emotion_apology|>",
    "<|emotion_understand|>",
    "<|emotion_fear|>",
    "<|emotion_comfort|>",
    "<|emotion_shy|>",
    "<|emotion_serious|>",
    "<|emotion_extra1|>",
    "<|emotion_extra2|>",
    "<|emotion_extra3|>",
    "<|emotion_extra4|>",
    "<|emotion_extra5|>",
    "<|emotion_extra6|>",
    "<|emotion_extra7|>",
    "<|emotion_extra8|>",
    "<|emotion_extra9|>",
    "<|emotion_extra10|>",
    "<|breath|>",
    "<|humph|>",
    "<|laugh_heng|>",
    "<|hissing|>",
    "<|sniff|>",
    "<|laugh_he|>",
    "<|sigh|>",
    "<|laugh|>",
    "<|laugh_ha|>",
    "<|quick_breath|>",
    "<|laugh_hei|>",
    "<|laugh_speak|>",
    "<|/laugh_speak|>",
    "<|cry|>",
    "<|choking|>",
    "<|cry_speak|>",
    "<|/cry_speak|>",
    "<|slurp|>",
    "<|clucking|>",
    "<|yawning|>",
    "<|cough|>",
    "<|smack|>",
    "<|hem|>",
    "<|stretch|>",
    "<|sneeze|>",
    "<|paralinguistic_extra1|>",
    "<|paralinguistic_extra2|>",
    "<|paralinguistic_extra3|>",
    "<|paralinguistic_extra4|>",
    "<|paralinguistic_extra5|>",
    "<|paralinguistic_extra6|>",
    "<|paralinguistic_extra7|>",
    "<|paralinguistic_extra8|>",
    "<|paralinguistic_extra10|>",
    "<|paralinguistic_extra11|>",
    "<|paralinguistic_extra12|>",
    "<|paralinguistic_extra13|>",
]


def load_custom_tokenizer(qwen2_tokenizer_path: str):
    tok = AutoTokenizer.from_pretrained(qwen2_tokenizer_path)
    special_tokens_dict = {
        "additional_special_tokens": additional_special_tokens,
    }
    tok.add_special_tokens(special_tokens_dict)
    return tok


def init_weights(model: nn.Module):
    """
    Initialize the weights of the model.
    - Xavier uniform initialization for linear layers
    - Normal initialization for embeddings
    - Xavier uniform initialization for parameters
    """

    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, nn.Parameter):
            nn.init.xavier_uniform_(m.data)

    model.apply(_init_weights)

    # Special handling for audio_head because it's an nn.Parameter directly
    nn.init.xavier_uniform_(model.audio_head)

    return model


def load_llm_model(
    configs,
    checkpoint_path: Union[str, Path] = None,
    device: Union[str, torch.device] = "cuda",
) -> Model:
    """Build the LLM, load a checkpoint if given, and move the model to the device.

    Args:
        configs: Config dict with an "llm_models" section describing the model.
        checkpoint_path: Optional path to a pretrained checkpoint.
        device: Device to move the model to.
    """

    model_arg = ModelArgs(
        backbone_flavor=configs["llm_models"]["backbone_flavor"],
        decoder_flavor=configs["llm_models"]["decoder_flavor"],
        text_vocab_size=configs["llm_models"]["text_vocab_size"],
        audio_vocab_size=configs["llm_models"]["audio_vocab_size"],
        audio_num_codebooks=configs["llm_models"]["audio_num_codebooks"],
        decoder_loss_weight=configs["llm_models"]["decoder_loss_weight"],
        use_text_loss=True,
    )
    model = Model(model_arg)

    if checkpoint_path and os.path.exists(checkpoint_path):
        state_dict = torch.load(
            checkpoint_path, map_location="cpu", weights_only=False
        )["model"]
        model.load_state_dict(state_dict)
    else:
        model = init_weights(model)

    model = model.to(device=device)
    return model


def summarize(
    writer,
    global_step,
    scalars={},
    histograms={},
    images={},
    audios={},
    audio_sampling_rate=22050,
):
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def get_grad_norm(model):
    total_norm = 0
    num = 0
    for name, p in model.named_parameters():
        try:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
            num += 1
        except Exception:
            # Parameters without a gradient (e.g. frozen ones) are skipped.
            print(name)
    total_norm = total_norm ** (1.0 / 2)
    total_norm = total_norm / num
    return total_norm


def read_jsonl(path):
    path = os.path.expanduser(path)
    with open(path, "r") as f:
        json_str = f.read()
    data_list = []
    for line in json_str.splitlines():
        data = json.loads(line)
        data_list.append(data)
    return data_list
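A small sketch of how WarmupDecayLR behaves with a throwaway optimizer (illustrative, not part of the commit; the hyperparameters are arbitrary): the multiplier ramps linearly to 1 over the warmup steps and then follows the chosen decay.

import torch
from fireredtts2.llm.utils import WarmupDecayLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = WarmupDecayLR(
    optimizer, warmup_steps=100, total_steps=1000, decay_type="cosine"
)
for step in range(1000):
    optimizer.step()
    scheduler.step()
    if step in (0, 99, 499, 999):
        print(step, scheduler.get_last_lr())  # warmup, peak, mid-decay, end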
fireredtts2/utils/spliter.py
ADDED
@@ -0,0 +1,289 @@
import re
import string

SYMBOLS_MAPPING = {
    "\n": "",
    "\t": "",
    "…": ",",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "【": "",
    "】": "",
    "[": "",
    "]": "",
    "(": "",
    ")": "",
    "(": "",
    ")": "",
    "・": "",
    "·": "",
    "「": "'",
    "」": "'",
    "《": "'",
    "》": "'",
    "—": "",
    "~": ",",
    "~": ",",
    ":": ",",
    ";": ",",
    ";": ",",
    ":": ",",
    '"': "",
    "!": ",",
    # "!": ".",
    "————": "",
    "——": "",
    "—": "",
    "……": ",",
    "*": "",
}

REPLACE_SYMBOL_REGEX = re.compile(
    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
)


EMOJI_REGEX = re.compile(
    "["
    "\U0001f600-\U0001f64f"  # emoticons
    "\U0001f300-\U0001f5ff"  # symbols & pictographs
    "\U0001f680-\U0001f6ff"  # transport & map symbols
    "\U0001f1e0-\U0001f1ff"  # flags (iOS)
    "]+",
    flags=re.UNICODE,
)


def clean_text(text):
    # Clean the text
    text = text.strip()
    text = text.replace("\xa0", "")

    # Replace all Chinese symbols with their English counterparts
    text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)

    # Remove emojis
    text = EMOJI_REGEX.sub(r"", text)

    # Remove continuous periods (...) and commas (,,,)
    text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)

    return text


def utf_8_len(text):
    return len(text.encode("utf-8"))


def break_text(texts, length, splits: set):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if char in splits:
                yield curr
                curr = ""

        if curr:
            yield curr


def break_text_by_length(texts, length):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if utf_8_len(curr) >= length:
                yield curr
                curr = ""

        if curr:
            yield curr


def add_cleaned(curr, segments):
    curr = curr.strip()
    if curr and not all(c.isspace() or c in string.punctuation for c in curr):
        segments.append(curr)


def protect_float(text):
    # Turns 3.14 into <3_f_14> to prevent splitting
    return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)


def unprotect_float(text):
    # Turns <3_f_14> into 3.14
    return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)


def split_text(text, length):
    text = clean_text(text)

    # Break the text into pieces with the following rules:
    # 1. Split the text at ".", "!", "?" if the text is NOT a float
    # 2. If the text is longer than length, split at ","
    # 3. If the text is still longer than length, split at " "
    # 4. If the text is still longer than length, split at any character to length

    texts = [text]
    texts = map(protect_float, texts)
    texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
    texts = map(unprotect_float, texts)
    texts = break_text(texts, length, {",", ","})
    texts = break_text(texts, length, {" "})
    texts = list(break_text_by_length(texts, length))

    # Then, merge the texts into segments with length <= length
    segments = []
    curr = ""

    for text in texts:
        if utf_8_len(curr) + utf_8_len(text) <= length:
            curr += text
        else:
            add_cleaned(curr, segments)
            curr = text

    if curr:
        add_cleaned(curr, segments)

    return segments


def contains_chinese(text):
    """Check whether the text contains Chinese characters."""
    return bool(re.search(r"[\u4e00-\u9fff]", text))


def count_words_english(text):
    """Count the number of English words."""
    return len(text.split())


def count_characters_chinese(text):
    """Count the number of Chinese characters."""
    return len(text)


def split_by_punctuation_english(text):
    """Split on English sentence-ending punctuation."""
    sentences = re.split(r"([.!?])", text)
    result = []
    for i in range(0, len(sentences) - 1, 2):
        sentence = sentences[i].strip()
        if sentence:
            if i + 1 < len(sentences):
                sentence += sentences[i + 1]
            result.append(sentence)

    if len(sentences) % 2 == 1 and sentences[-1].strip():
        result.append(sentences[-1].strip())

    return result


def split_by_punctuation_chinese(text):
    """Split on Chinese sentence-ending punctuation."""
    sentences = re.split(r"([。!?])", text)
    result = []
    for i in range(0, len(sentences) - 1, 2):
        sentence = sentences[i].strip()
        if sentence:
            if i + 1 < len(sentences):
                sentence += sentences[i + 1]
            result.append(sentence)

    if len(sentences) % 2 == 1 and sentences[-1].strip():
        result.append(sentences[-1].strip())

    return result


def merge_sentences_english(sentences, max_words=80):
    """Merge English sentences into chunks of at most max_words words."""
    result = []
    current_chunk = ""

    for sentence in sentences:
        if not current_chunk:
            current_chunk = sentence
        else:
            test_chunk = current_chunk + " " + sentence
            if count_words_english(test_chunk) <= max_words:
                current_chunk = test_chunk
            else:
                result.append(current_chunk)
                current_chunk = sentence

    if current_chunk:
        result.append(current_chunk)

    return result


def merge_sentences_chinese(sentences, max_chars=100):
    """Merge Chinese sentences into chunks of at most max_chars characters."""
    result = []
    current_chunk = ""

    for sentence in sentences:
        if not current_chunk:
            current_chunk = sentence
        else:
            test_chunk = current_chunk + sentence
            if count_characters_chinese(test_chunk) <= max_chars:
                current_chunk = test_chunk
            else:
                result.append(current_chunk)
                current_chunk = sentence

    if current_chunk:
        result.append(current_chunk)

    return result


def process_text(text):
    chinese_max_limit = 150
    english_max_limit = 80
    # Remove the leading speaker tag such as [S2]
    text = re.sub(r"^\[S\d+\]", "", text).strip()
    is_chinese = contains_chinese(text)
    if is_chinese:
        if count_characters_chinese(text) <= chinese_max_limit:
            return [text]
        sentences = split_by_punctuation_chinese(text)
        result = merge_sentences_chinese(sentences, chinese_max_limit)
    else:
        if count_words_english(text) <= english_max_limit:
            return [text]
        sentences = split_by_punctuation_english(text)
        result = merge_sentences_english(sentences, english_max_limit)

    return result


def process_text_list(text_list):
    new_text_list = []
    for text in text_list:
        speaker = text[:4]
        # print("---speaker:", speaker)
        assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
        result = process_text(text=text)
        # print("---result:\n", result, len(result))
        for chunk in result:
            new_text_list.append(speaker + chunk)
    return new_text_list
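A minimal sketch of the splitting utilities above (not part of the commit; the example strings are made up). split_text chunks raw text by UTF-8 byte length, while process_text_list preserves the [S1]-style speaker tags that generate_dialogue expects.

from fireredtts2.utils.spliter import split_text, process_text_list

long_text = (
    "FireRedTTS2 is a long-form streaming TTS system. "
    "It supports multi-speaker dialogue generation! Does it also split text? Yes."
)
# Chunks of at most 60 UTF-8 bytes, split at sentence punctuation first.
print(split_text(long_text, length=60))

# Dialogue lines keep their leading speaker tag on every resulting chunk.
dialogue = ["[S1]" + long_text, "[S2]Short reply."]
print(process_text_list(dialogue))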
pretrained_models/README.md
ADDED
@@ -0,0 +1 @@
Put the pre-trained model in this folder.
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torchaudio
torchtune
torchao
transformers
einops
gradio