Spaces: Running on Zero
Shen Feiyu committed
Commit · faadabf
1 Parent(s): 6cdc9a3
add 1s
Browse files
- app.py +284 -0
- configs/config_24k.json +171 -0
- configs/config_24k_flow.json +125 -0
- fireredtts/models/fireredtts.py +266 -0
- fireredtts/models/token2audio.py +108 -0
- fireredtts/modules/__init__.py +0 -0
- fireredtts/modules/acoustic_codec/__init__.py +1 -0
- fireredtts/modules/acoustic_codec/alias_free_torch/__init__.py +6 -0
- fireredtts/modules/acoustic_codec/alias_free_torch/act.py +35 -0
- fireredtts/modules/acoustic_codec/alias_free_torch/filter.py +99 -0
- fireredtts/modules/acoustic_codec/alias_free_torch/resample.py +58 -0
- fireredtts/modules/acoustic_codec/bigcodec.py +698 -0
- fireredtts/modules/acoustic_codec/vector_quantization.py +580 -0
- fireredtts/modules/acoustic_llm/__init__.py +1 -0
- fireredtts/modules/acoustic_llm/acoustic_llm.py +876 -0
- fireredtts/modules/bigvgan/__init__.py +2 -0
- fireredtts/modules/bigvgan/activations.py +126 -0
- fireredtts/modules/bigvgan/alias_free_torch/__init__.py +5 -0
- fireredtts/modules/bigvgan/alias_free_torch/act.py +29 -0
- fireredtts/modules/bigvgan/alias_free_torch/filter.py +98 -0
- fireredtts/modules/bigvgan/alias_free_torch/resample.py +57 -0
- fireredtts/modules/bigvgan/bigvgan.py +369 -0
- fireredtts/modules/bigvgan/mel_spectrogram.py +111 -0
- fireredtts/modules/flowmatching/__init__.py +18 -0
- fireredtts/modules/flowmatching/estimator_dit.py +356 -0
- fireredtts/modules/flowmatching/flow.py +138 -0
- fireredtts/modules/flowmatching/upsample_encoder.py +617 -0
- fireredtts/modules/semantic_llm/llm_gpt2.py +608 -0
- fireredtts/modules/semantic_tokenizer/__init__.py +36 -0
- fireredtts/modules/semantic_tokenizer/audio.py +138 -0
- fireredtts/modules/semantic_tokenizer/ecapa_tdnn.py +931 -0
- fireredtts/modules/semantic_tokenizer/hubert.py +108 -0
- fireredtts/modules/semantic_tokenizer/semantic_tokenizer.py +877 -0
- fireredtts/modules/text_normalizer/__init__.py +0 -0
- fireredtts/modules/text_normalizer/normalize.py +183 -0
- fireredtts/modules/text_normalizer/regex_common.py +23 -0
- fireredtts/modules/text_normalizer/utils.py +171 -0
- fireredtts/setup.py +3 -0
- fireredtts/utils/__init__.py +0 -0
- fireredtts/utils/spliter.py +161 -0
- fireredtts/utils/utils.py +37 -0
- pre-requirements.txt +1 -0
- pretrained_models/README.md +3 -0
- requirements.txt +12 -0
app.py
ADDED
@@ -0,0 +1,284 @@
# Environment setup
from pathlib import Path
import os
import sys
sys.path.append(str(Path(__file__).parent))
# FIXME patch fairseq in place: add weights_only=False in /usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py#315
# (newer torch defaults torch.load to weights_only=True, which breaks fairseq checkpoint loading)
if os.path.exists('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py'):
    file_lines = []
    with open('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py', 'r') as f:
        for line in f:
            file_lines.append(line.strip('\n'))
    file_lines[314] = file_lines[314].replace(
        "state = torch.load(f, map_location=torch.device(\"cpu\"))",
        "state = torch.load(f, map_location=torch.device(\"cpu\"), weights_only=False)"
    )
    with open('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py', 'w') as f:
        for line in file_lines:
            f.write(line + '\n')
    print('[DEBUG] added weights_only=False')
# Run
import spaces
import gradio as gr
from zipfile import ZipFile
from typing import Literal
from huggingface_hub import snapshot_download
from fireredtts.models.fireredtts import FireRedTTS
# NOTE disable verbose INFO logs
import logging
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)

# NOTE Some launching setups
# - install fairseq manually ("python -m pip install pip==24.0")
# - manually add weights_only=False in /usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py#315


# ================================================
# FireRedTTS-1s Model
# ================================================
# Global model instances
tts_flow: FireRedTTS = None
tts_acollm: FireRedTTS = None

def initiate_model(pretrained_dir: str):
    global tts_flow, tts_acollm
    if tts_flow is None:
        tts_flow = FireRedTTS(
            config_path='configs/config_24k_flow.json',
            pretrained_path=pretrained_dir,
        )
    if tts_acollm is None:
        tts_acollm = FireRedTTS(
            config_path='configs/config_24k.json',
            pretrained_path=pretrained_dir,
        )


# ================================================
# Gradio
# ================================================

# i18n
_i18n_key2lang_dict = dict(
    # Title markdown
    title_md_desc=dict(
        en="FireRedTTS-1s 🔥 Streamable TTS",
        zh="FireRedTTS-1s 🔥 可流式TTS",
    ),
    # Decoder choice radio
    decoder_choice_label=dict(
        en="Decoder Choice",
        zh="解码器选择",
    ),
    decoder_choice_1=dict(
        en="Flow Matching",
        zh="Flow Matching",
    ),
    decoder_choice_2=dict(
        en="Acoustic LLM",
        zh="Acoustic LLM",
    ),
    # Speaker Prompt
    spk_prompt_audio_label=dict(
        en="Speaker Prompt Audio",
        zh="参考语音",
    ),
    spk_prompt_text_label=dict(
        en="Speaker Prompt Text",
        zh="参考语音的文本",
    ),
    spk_prompt_text_placeholder=dict(
        en="Speaker Prompt Text",
        zh="参考语音的文本",
    ),
    # Input textbox
    target_text_input_label=dict(
        en="Text To Synthesize",
        zh="待合成文本",
    ),
    target_text_input_placeholder=dict(
        en="Text To Synthesize",
        zh="待合成文本",
    ),
    # Generate button
    generate_btn_label=dict(
        en="Generate Audio",
        zh="合成",
    ),
    # Generated audio
    generated_audio_label=dict(
        en="Generated Audio",
        zh="合成的音频",
    ),
    # Warning 1: incomplete prompt info
    warn_incomplete_prompt=dict(
        en="Please provide prompt audio and text",
        zh="请提供说话人参考语音与参考文本",
    ),
    # Warning 2: invalid target text input
    warn_invalid_target_text=dict(
        en="Empty input text",
        zh="待合成文本为空",
    ),
)

global_lang: Literal['zh', 'en'] = 'zh'

def i18n(key):
    global global_lang
    return _i18n_key2lang_dict[key][global_lang]


def check_monologue_text(text: str, prefix: str = None) -> bool:
    text = text.strip()
    # Check speaker tags
    if prefix is not None and (not text.startswith(prefix)):
        return False
    # Remove prefix
    if prefix is not None:
        text = text.removeprefix(prefix)
        text = text.strip()
    # Empty text is invalid
    if len(text) == 0:
        return False
    return True


@spaces.GPU(duration=60)
def synthesis_function(
    spk_prompt_audio: str,
    spk_prompt_text: str,
    target_text: str,
    decoder_choice: Literal[0, 1] = 0,  # 0 means flow matching decoder
):
    global tts_flow, tts_acollm

    # Check prompt info
    spk_prompt_text = spk_prompt_text.strip()
    if spk_prompt_audio is None or spk_prompt_text == "":
        gr.Warning(message=i18n('warn_incomplete_prompt'))
        return None
    # Check target text
    target_text = target_text.strip()
    if target_text == "":
        gr.Warning(message=i18n('warn_invalid_target_text'))
        return None

    # Go synthesis
    if decoder_choice == 0:
        audio = tts_flow.synthesize(
            prompt_wav=spk_prompt_audio,
            prompt_text=spk_prompt_text,
            text=target_text,
            lang="zh",
            use_tn=True
        )
    else:
        audio = tts_acollm.synthesize(
            prompt_wav=spk_prompt_audio,
            prompt_text=spk_prompt_text,
            text=target_text,
            lang="zh",
            use_tn=True
        )
    return (24000, audio.detach().cpu().squeeze(0).numpy())


# UI rendering
def render_interface() -> gr.Blocks:
    with gr.Blocks(title="FireRedTTS-1S", theme=gr.themes.Default()) as page:
        # ======================== UI ========================
        # A large title
        title_desc = gr.Markdown(value="# {}".format(i18n('title_md_desc')))
        with gr.Row():
            lang_choice = gr.Radio(
                choices=['中文', 'English'],
                value='中文',
                label='Display Language/显示语言',
                type="index",
                interactive=True,
            )
            decoder_choice = gr.Radio(
                choices=[i18n('decoder_choice_1'), i18n('decoder_choice_2')],
                value=i18n('decoder_choice_1'),
                label=i18n('decoder_choice_label'),
                type="index",
                interactive=True,
            )
        with gr.Row():
            # ==== Speaker Prompt ====
            spk_prompt_text = gr.Textbox(
                label=i18n('spk_prompt_text_label'),
                placeholder=i18n('spk_prompt_text_placeholder'),
                lines=5,
            )
            spk_prompt_audio = gr.Audio(
                label=i18n('spk_prompt_audio_label'),
                type="filepath",
                editable=False,
                interactive=True,
            )  # Audio component returns tmp audio path
            # ==== Target Text ====
            target_text_input = gr.Textbox(
                label=i18n('target_text_input_label'),
                placeholder=i18n('target_text_input_placeholder'),
                lines=5,
            )
        # Generate button
        generate_btn = gr.Button(value=i18n('generate_btn_label'), variant="primary", size="lg")
        # Long output audio
        generate_audio = gr.Audio(
            label=i18n('generated_audio_label'),
            interactive=False,
        )

        # ======================== Action ========================
        # Language action
        def _change_component_language(lang):
            global global_lang
            global_lang = ['zh', 'en'][lang]
            return [
                # title_desc
                gr.update(value="# {}".format(i18n('title_md_desc'))),
                # decoder_choice
                gr.update(label=i18n('decoder_choice_label')),
                # spk_prompt_{text,audio}
                gr.update(label=i18n('spk_prompt_text_label'), placeholder=i18n('spk_prompt_text_placeholder')),
                gr.update(label=i18n('spk_prompt_audio_label')),
                # target_text_input
                gr.update(label=i18n('target_text_input_label'), placeholder=i18n('target_text_input_placeholder')),
                # generate_btn
                gr.update(value=i18n('generate_btn_label')),
                # generate_audio
                gr.update(label=i18n('generated_audio_label')),
            ]
        lang_choice.change(
            fn=_change_component_language,
            inputs=[lang_choice],
            outputs=[
                title_desc, decoder_choice,
                spk_prompt_text, spk_prompt_audio,
                target_text_input,
                generate_btn, generate_audio,
            ]
        )
        generate_btn.click(
            fn=synthesis_function,
            inputs=[spk_prompt_audio, spk_prompt_text, target_text_input, decoder_choice],
            outputs=[generate_audio]
        )
    return page


if __name__ == '__main__':
    # Download model
    snapshot_download(repo_id='FireRedTeam/FireRedTTS-1S', local_dir='pretrained_models/FireRedTTS-1S')
    # Unzip model, weights under "pretrained_models/FireRedTTS-1S/pretrained_models"
    with ZipFile('pretrained_models/FireRedTTS-1S/pretrained_models.zip', 'r') as zipf:
        zipf.extractall('pretrained_models/FireRedTTS-1S')
    # Init model
    initiate_model('pretrained_models/FireRedTTS-1S/pretrained_models')
    print('[INFO] model loaded')
    # UI
    page = render_interface()
    page.launch()
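For reference, a minimal sketch of driving the model outside the Gradio UI, assuming the weights are already downloaded and unzipped as in the `__main__` block above, a CUDA device is available, and `soundfile` is installed (an assumption; it is not necessarily pinned in requirements.txt). `prompt.wav` is a hypothetical speaker prompt recording:

import soundfile as sf
from fireredtts.models.fireredtts import FireRedTTS

tts = FireRedTTS(
    config_path='configs/config_24k_flow.json',  # flow-matching decoder variant
    pretrained_path='pretrained_models/FireRedTTS-1S/pretrained_models',
)
audio = tts.synthesize(
    prompt_wav='prompt.wav',        # hypothetical speaker prompt file
    prompt_text='参考语音的文本',     # transcript of prompt.wav
    text='你好，世界。',
    lang='zh',
    use_tn=True,
)
if audio is not None:               # synthesize() returns None on failure
    # Output is a [1, T] waveform tensor at 24 kHz
    sf.write('out.wav', audio.squeeze(0).cpu().numpy(), 24000)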
configs/config_24k.json
ADDED
@@ -0,0 +1,171 @@
{
  "semantic_llm": {
    "start_text_token": 32000,
    "stop_text_token": 32001,
    "num_text_tokens": 32002,
    "start_audio_token": 16384,
    "stop_audio_token": 16385,
    "num_audio_tokens": 16386,
    "llm_hidden_size": 1024,
    "llm_intermediate_size": 4096,
    "llm_num_layers": 30,
    "llm_num_heads": 16,
    "llm_max_audio_seq_len": 630,
    "llm_max_text_seq_len": 402,
    "llm_max_prompt_len": 250,
    "code_stride_len": 640,
    "EOS_TOKEN": 16385
  },
  "acoustic_llm": {
    "n_stacks": 1,
    "layers": 24,
    "model_dim": 1536,
    "heads": 16,
    "max_text_tokens": 2048,
    "max_speech_tokens": 2048,
    "max_conditioning_inputs": 1,
    "number_text_tokens": 16386,
    "start_text_token": 16384,
    "stop_text_token": 16385,
    "n_frames_per_step": 1,
    "n_heads_per_frame": 8,
    "delay_prediction": 1,
    "upsample_factors": 1,
    "streaming_delayed_frames": 8,
    "number_speech_tokens": 16386,
    "start_speech_token": 16384,
    "stop_speech_token": 16385,
    "speaker_embedding_pretrained": true,
    "speaker_embedding_ckpt": null,
    "speaker_embedding_dim": 512,
    "temperature": 0.5,
    "repetition_penalty": 2.0,
    "top_p": 0.5,
    "top_k": 25
  },
  "acoustic_codec": {
    "n_model_size": 1024,
    "encoder_config": {
      "ngf": 48,
      "up_ratios": [2, 4, 4, 4, 5],
      "causal": true
    },
    "decoder_config": {
      "upsample_initial_channel": 1536,
      "ngf": 48,
      "up_ratios": [6, 5, 4, 4, 2],
      "causal": true
    },
    "vq_config": {
      "n_groups": 8,
      "ordered": true,
      "codebook_size": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128],
      "codebook_dim": [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
      "requires_projection": true,
      "decay": 0.99,
      "threshold_ema_dead_code": 0,
      "commitment_weight": 0.01
    },
    "resampler_config": {
      "source_sr": 16000,
      "target_sr": 16000
    }
  },
  "semantic_tokenizer": {
    "in_dim": 1024,
    "out_dim": 80,
    "n_model_size": 512,
    "downsample_scales": [1, 1, 1, 2],
    "upsample_scales": [[2, 1], [2, 1, 1, 1]],
    "mel_config": {
      "style": "BigVGAN",
      "filter_length": 1024,
      "hop_length": 160,
      "win_length": 640,
      "n_mel_channels": 80,
      "sampling_rate": 16000
    },
    "vq_config": {
      "codebook_size": [128, 128],
      "codebook_dim": [128, 128],
      "requires_projection": true
    },
    "tree_config": [
      {
        "downsample_rate": 1,
        "n_groups": 1,
        "dropout": 0
      }
    ],
    "n_samples_per_token": 640,
    "checkpointing": true
  }
}
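A quick consistency check that can be derived from this config (illustrative, not part of the repo): the acoustic codec's encoder ratios multiply out to the 640-sample semantic token stride at 16 kHz, and its decoder ratios give the matching stride at 24 kHz, i.e. 40 ms of audio per token on both sides:

import json
from math import prod

cfg = json.load(open("configs/config_24k.json"))
enc = prod(cfg["acoustic_codec"]["encoder_config"]["up_ratios"])  # 2*4*4*4*5 = 640
dec = prod(cfg["acoustic_codec"]["decoder_config"]["up_ratios"])  # 6*5*4*4*2 = 960
assert enc == cfg["semantic_llm"]["code_stride_len"] == 640       # 640 / 16000 Hz = 40 ms
assert dec / 24000 == enc / 16000                                 # 960 / 24000 Hz = 40 ms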
configs/config_24k_flow.json
ADDED
@@ -0,0 +1,125 @@
{
  "semantic_llm": {
    "start_text_token": 32000,
    "stop_text_token": 32001,
    "num_text_tokens": 32002,
    "start_audio_token": 16384,
    "stop_audio_token": 16385,
    "num_audio_tokens": 16386,
    "llm_hidden_size": 1024,
    "llm_intermediate_size": 4096,
    "llm_num_layers": 30,
    "llm_num_heads": 16,
    "llm_max_audio_seq_len": 630,
    "llm_max_text_seq_len": 402,
    "llm_max_prompt_len": 250,
    "code_stride_len": 640,
    "EOS_TOKEN": 16385
  },
  "flow": {
    "spk_channels": 512,
    "spk_enc_channels": 80,
    "infer_cfg_rate": 0.7,
    "token_emb": {
      "channels": 512
    },
    "encoder": {
      "input_size": 512,
      "output_size": 512,
      "num_blocks": 6,
      "num_up_blocks": 4,
      "normalize_before": true,
      "up_stride": 2,
      "pre_lookahead_len": 3,
      "attention_heads": 4,
      "key_bias": true,
      "linear_units": 2048,
      "dropout_rate": 0.0,
      "positional_dropout_rate": 0.0,
      "attention_dropout_rate": 0.0
    },
    "estimator": {
      "in_channels": 320,
      "out_channels": 80,
      "mlp_ratio": 4,
      "depth": 16,
      "num_heads": 8,
      "head_dim": 64,
      "hidden_size": 512
    }
  },
  "mel": {
    "num_mels": 80,
    "n_fft": 1920,
    "hop_size": 480,
    "win_size": 1920,
    "sampling_rate": 24000,
    "fmin": 0,
    "fmax": 8000,
    "center": false
  },
  "bigvgan": {
    "num_mels": 80,
    "upsample_initial_channel": 1536,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [5, 4, 3, 2, 2, 2],
    "upsample_kernel_sizes": [11, 8, 7, 4, 4, 4],
    "resblock_type": "1",
    "snake_logscale": true,
    "activation": "snakebeta",
    "use_tanh_at_final": false,
    "use_bias_at_final": false
  },
  "semantic_tokenizer": {
    "in_dim": 1024,
    "out_dim": 80,
    "n_model_size": 512,
    "downsample_scales": [1, 1, 1, 2],
    "upsample_scales": [[2, 1], [2, 1, 1, 1]],
    "mel_config": {
      "style": "BigVGAN",
      "filter_length": 1024,
      "hop_length": 160,
      "win_length": 640,
      "n_mel_channels": 80,
      "sampling_rate": 16000
    },
    "vq_config": {
      "codebook_size": [128, 128],
      "codebook_dim": [128, 128],
      "requires_projection": true
    },
    "tree_config": [
      {
        "downsample_rate": 1,
        "n_groups": 1,
        "dropout": 0
      }
    ],
    "n_samples_per_token": 640,
    "checkpointing": true
  }
}
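The timing relations in this config also explain the hard-coded factor of 2 in FlowToken2Audio.forward (fireredtts/models/token2audio.py): one 40 ms semantic token spans two 20 ms mel frames, and BigVGAN's upsample rates multiply out to exactly one mel hop. An illustrative check (not repo code):

import json
from math import prod

cfg = json.load(open("configs/config_24k_flow.json"))
hop_s = cfg["mel"]["hop_size"] / cfg["mel"]["sampling_rate"]   # 480 / 24000 = 20 ms per mel frame
token_s = cfg["semantic_llm"]["code_stride_len"] / 16000       # 640 / 16000 = 40 ms per token
assert token_s / hop_s == cfg["flow"]["encoder"]["up_stride"]  # 2 mel frames per token
assert prod(cfg["bigvgan"]["upsample_rates"]) == cfg["mel"]["hop_size"]  # one hop of samples per frame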
fireredtts/models/fireredtts.py
ADDED
@@ -0,0 +1,266 @@
import os
import json
import time
from traceback import format_exc
import torch

from transformers import AutoTokenizer

from fireredtts.utils.utils import load_audio
from fireredtts.modules.text_normalizer.utils import text_split
from fireredtts.utils.spliter import clean_text
from fireredtts.modules.text_normalizer.normalize import TextNormalizer
from fireredtts.modules.semantic_tokenizer import SemanticTokenizer
from fireredtts.modules.semantic_llm.llm_gpt2 import Speech_LLM_GPT2
from fireredtts.models.token2audio import TwoStageCodec, FlowToken2Audio


class FireRedTTS:
    def __init__(self, config_path, pretrained_path, device="cuda"):
        self.device = device
        self.config = json.load(open(config_path))
        self.EOS_TOKEN = self.config["semantic_llm"]["EOS_TOKEN"]

        # pretrained models
        self.tokenizer_path = os.path.join(pretrained_path, "tokenizer")
        self.speech_tokenizer_path = os.path.join(pretrained_path, "speech_tokenizer")
        self.semantic_llm_path = os.path.join(pretrained_path, "semantic_llm.pt")
        assert os.path.exists(self.tokenizer_path)
        assert os.path.exists(self.speech_tokenizer_path)
        assert os.path.exists(self.semantic_llm_path)
        if 'acoustic_llm' in self.config:
            self.acoustic_llm_path = os.path.join(pretrained_path, "acoustic_llm.bin")
            self.acoustic_codec_path = os.path.join(pretrained_path, "acoustic_codec.bin")
            assert os.path.exists(self.acoustic_llm_path)
            assert os.path.exists(self.acoustic_codec_path)
        else:
            self.flow_path = os.path.join(pretrained_path, "flow.pt")
            self.bigvgan_path = os.path.join(pretrained_path, "bigvgan.pt")
            assert os.path.exists(self.flow_path)
            assert os.path.exists(self.bigvgan_path)

        # text normalizer
        self.text_normalizer = TextNormalizer()
        # text tokenizer
        self.text_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)

        # semantic llm
        self.semantic_llm = Speech_LLM_GPT2(
            start_text_token=self.config["semantic_llm"]["start_text_token"],
            stop_text_token=self.config["semantic_llm"]["stop_text_token"],
            num_text_tokens=self.config["semantic_llm"]["num_text_tokens"],
            start_audio_token=self.config["semantic_llm"]["start_audio_token"],
            stop_audio_token=self.config["semantic_llm"]["stop_audio_token"],
            num_audio_tokens=self.config["semantic_llm"]["num_audio_tokens"],
            llm_hidden_size=self.config["semantic_llm"]["llm_hidden_size"],
            llm_intermediate_size=self.config["semantic_llm"]["llm_intermediate_size"],
            llm_num_layers=self.config["semantic_llm"]["llm_num_layers"],
            llm_num_heads=self.config["semantic_llm"]["llm_num_heads"],
            llm_max_audio_seq_len=self.config["semantic_llm"]["llm_max_audio_seq_len"],
            llm_max_text_seq_len=self.config["semantic_llm"]["llm_max_text_seq_len"],
            llm_max_prompt_len=self.config["semantic_llm"]["llm_max_prompt_len"],
            code_stride_len=self.config["semantic_llm"]["code_stride_len"],
        )

        sd = torch.load(self.semantic_llm_path, map_location=device)["model"]
        self.semantic_llm.load_state_dict(sd, strict=True)
        self.semantic_llm = self.semantic_llm.to(device=device)
        self.semantic_llm.eval()
        self.semantic_llm.init_gpt_for_inference(kv_cache=True)

        # Speech tokenizer
        self.speech_tokenizer = SemanticTokenizer(
            config=self.config["semantic_tokenizer"], path=self.speech_tokenizer_path
        )

        # Acoustic decoder: two-stage codec if the config has an acoustic LLM,
        # otherwise the flow-matching + BigVGAN path
        if 'acoustic_llm' in self.config:
            self.acoustic_decoder = TwoStageCodec(self.config)
            self.acoustic_decoder.load_model(self.acoustic_llm_path, self.acoustic_codec_path)
        else:
            self.acoustic_decoder = FlowToken2Audio(self.config)
            self.acoustic_decoder.load_model(self.flow_path, self.bigvgan_path)
        self.acoustic_decoder.eval()
        self.acoustic_decoder = self.acoustic_decoder.to(device)

    def extract_spk_embeddings(self, prompt_wav):
        audio, lsr, audio_resampled = load_audio(
            audiopath=prompt_wav,
            sampling_rate=16000,
        )
        _, _, audio_resampled24k = load_audio(
            audiopath=prompt_wav,
            sampling_rate=24000,
        )

        audio_resampled = audio_resampled.to(self.device)
        audio_len = torch.tensor(
            data=[audio_resampled.shape[1]], dtype=torch.long, requires_grad=False
        )

        # spk_embeddings: [1, 512]
        prompt_tokens, token_lengths, spk_embeddings = self.speech_tokenizer(
            audio_resampled, audio_len
        )

        prompt_acoustic_tokens, acoustic_llm_spk = self.acoustic_decoder.extract(
            audio_resampled if isinstance(self.acoustic_decoder, TwoStageCodec) else audio_resampled24k,
            audio_len, spk_embeddings.unsqueeze(0)
        )

        return prompt_tokens, spk_embeddings, prompt_acoustic_tokens, acoustic_llm_spk

    def synthesize_base(
        self,
        prompt_semantic_tokens,
        prompt_acoustic_tokens,
        spk_semantic_llm,
        spk_acoustic_llm,
        prompt_text,
        text,
        lang="auto",
    ):
        """Synthesize one text chunk conditioned on the speaker prompt.

        Args:
            prompt_semantic_tokens: semantic tokens extracted from the prompt audio.
            prompt_acoustic_tokens: acoustic tokens (or prompt mel) of the prompt audio.
            spk_semantic_llm: speaker conditioning latents for the semantic LLM.
            spk_acoustic_llm: speaker conditioning for the acoustic decoder.
            prompt_text: transcript of the prompt audio.
            text: text chunk to synthesize.
            lang (str, optional): language of the text. Defaults to "auto".

        Returns:
            Reconstructed waveform tensor on CPU.
        """
        if lang == "en":
            text = prompt_text + " " + text
        else:
            text = prompt_text + text

        print("---text:\n", text)

        # Pre-process prompt tokens
        # text to tokens
        text_tokens = self.text_tokenizer.encode(
            text=text,
            add_special_tokens=False,
            max_length=10**6,
            truncation=False,
        )
        # print("---decode", [self.text_tokenizer.decode([c]) for c in text_tokens])

        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)

        assert text_tokens.shape[-1] < 200
        with torch.no_grad():
            gpt_codes = self.semantic_llm.generate_ic(
                cond_latents=spk_semantic_llm,
                text_inputs=text_tokens,
                prompt_tokens=prompt_semantic_tokens[:, :-3],
                do_sample=True,
                top_p=0.85,
                top_k=30,
                temperature=0.75,
                num_return_sequences=7,
                num_beams=1,
                length_penalty=2.0,
                repetition_penalty=5.0,
                output_attentions=False,
            )

        # Truncate each candidate sequence at its first EOS token
        seqs = []
        for seq in gpt_codes:
            index = (seq == self.EOS_TOKEN).nonzero(as_tuple=True)[0][0]
            seq = seq[:index]
            seqs.append(seq)

        sorted_seqs = sorted(seqs, key=lambda s: len(s), reverse=False)
        sorted_len = [len(s) for s in sorted_seqs]  # candidate lengths, kept for debugging

        # Keep the 3rd-shortest candidate as a guard against truncated or runaway samples
        gpt_codes = sorted_seqs[2].unsqueeze(0)

        # Acoustic decoder
        rec_wavs = self.acoustic_decoder(
            gpt_codes, prompt_semantic_tokens, prompt_acoustic_tokens, spk_acoustic_llm
        )

        rec_wavs = rec_wavs.detach().cpu()
        return rec_wavs

    @torch.no_grad()
    def synthesize(self, prompt_wav, prompt_text, text, lang="auto", use_tn=False):
        """Synthesize audio for `text` in the voice of the prompt audio.

        Args:
            prompt_wav: path to the speaker prompt audio file.
            prompt_text: transcript of the prompt audio.
            text: text to synthesize.
            lang (str, optional): "zh", "en", or "auto". Defaults to "auto".
            use_tn (bool, optional): apply text normalization and sentence splitting.

        Returns:
            Waveform tensor, or None on failure.
        """
        assert lang in ["zh", "en", "auto"]
        assert os.path.exists(prompt_wav)

        (
            prompt_semantic_tokens,
            spk_embeddings,
            prompt_acoustic_tokens,
            spk_acoustic_llm,
        ) = self.extract_spk_embeddings(prompt_wav=prompt_wav)

        spk_embeddings = spk_embeddings.unsqueeze(0)
        spk_semantic_llm = self.semantic_llm.reference_embedding(spk_embeddings)

        # print("---prompt_semantic_tokens:\n", prompt_semantic_tokens)
        # print("---spk_embeddings:\n", spk_embeddings)

        # clean text
        prompt_text = clean_text(prompt_text)
        text = clean_text(text=text)

        if use_tn:
            substrings = text_split(text=text)

            out_wavs = []
            try:
                for sub in substrings:
                    res_lang = self.text_normalizer.tn(text=sub)[1]
                    chunk = self.synthesize_base(
                        prompt_semantic_tokens=prompt_semantic_tokens,
                        prompt_acoustic_tokens=prompt_acoustic_tokens,
                        spk_semantic_llm=spk_semantic_llm,
                        spk_acoustic_llm=spk_acoustic_llm,
                        prompt_text=prompt_text,
                        text=sub,
                        lang=res_lang,
                    )
                    out_wavs.append(chunk)
                out_wav = torch.cat(out_wavs, dim=-1)
                return out_wav
            except Exception:
                print('[ERROR] ', format_exc())
                return None
        else:
            out_wavs = []
            try:
                res_lang = self.text_normalizer.tn(text=text)[1]
                chunk = self.synthesize_base(
                    prompt_semantic_tokens=prompt_semantic_tokens,
                    prompt_acoustic_tokens=prompt_acoustic_tokens,
                    spk_semantic_llm=spk_semantic_llm,
                    spk_acoustic_llm=spk_acoustic_llm,
                    prompt_text=prompt_text,
                    text=text,
                    lang=res_lang,
                )
                out_wavs.append(chunk)
                out_wav = torch.cat(out_wavs, dim=-1)
                return out_wav
            except Exception:
                print('[ERROR] ', format_exc())
                return None
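The candidate-selection step in synthesize_base is easy to miss: the semantic LLM samples 7 sequences, each is cut at its first EOS, and the 3rd shortest survives, a cheap guard against both truncated and runaway samples. A self-contained illustration on dummy data (mirrors the logic above; the toy tensors are not repo code):

import torch

EOS = 16385
# Seven dummy candidates, standing in for the output of semantic_llm.generate_ic
gpt_codes = torch.randint(0, 16384, (7, 50))
gpt_codes[:, 30:] = EOS                         # guarantee an EOS in every row

seqs = []
for seq in gpt_codes:
    idx = int((seq == EOS).nonzero(as_tuple=True)[0][0])  # first EOS position
    seqs.append(seq[:idx])                                # drop EOS and the tail

chosen = sorted(seqs, key=len)[2].unsqueeze(0)  # 3rd shortest, shape [1, T]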
fireredtts/models/token2audio.py
ADDED
@@ -0,0 +1,108 @@
import torch
import torch.nn.functional as F
from fireredtts.modules.acoustic_llm import AcousticLLM
from fireredtts.modules.acoustic_codec import AcousticCodec
from fireredtts.modules.flowmatching import FlowToken2Mel
from fireredtts.modules.bigvgan import BigVGAN, MelExtractor


class TwoStageCodec(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.acoustic_llm = AcousticLLM(**config["acoustic_llm"])
        self.acoustic_codec = AcousticCodec(**config["acoustic_codec"])

    def load_model(self, acoustic_llm_path, acoustic_codec_path):
        self.acoustic_llm.load_state_dict(
            torch.load(acoustic_llm_path, map_location="cpu"), strict=True
        )
        self.acoustic_codec.load_state_dict(
            torch.load(acoustic_codec_path, map_location="cpu"), strict=True
        )

    @torch.inference_mode()
    def forward(
        self, semantic_token, prompt_semantic_token, prompt_acoustic_token, spk_gpt
    ):
        # print('Before: ', semantic_token.shape)
        token_pred = torch.cat((prompt_semantic_token, semantic_token), dim=1)

        # Fine LLM inference
        token_pred = self.acoustic_llm.inference_speech(
            speech_conditioning_latent=spk_gpt,
            text_inputs=token_pred,
            num_return_sequences=1,
            input_tokens=prompt_acoustic_token,
        )[0]

        if isinstance(token_pred, (tuple, list)):
            token_pred = [x.unsqueeze(0) for x in token_pred]
        else:
            token_pred = token_pred.unsqueeze(0)

        acoustic_outputs = self.acoustic_codec.reconstruct_wav(token=token_pred)
        wav = acoustic_outputs["wav_pred"].squeeze(1)

        return wav

    def extract(self, wavs, wav_lengths, spk):
        if torch.cuda.is_available():
            wavs = wavs.cuda()
        cond_tok = self.acoustic_codec.extract_speech_tokens(wavs, wav_lengths)[
            "token"
        ][0]
        spk_gpt = self.acoustic_llm.get_conditioning(spk)
        return cond_tok, spk_gpt


"""For FlowToken2Audio, keep the interface consistent with TwoStageCodec to minimize code changes:
prompt_acoustic_token is an alias for prompt_mel;
spk_gpt is an alias for spk_embeddings.
"""
class FlowToken2Audio(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.flow = FlowToken2Mel(config['flow'])
        self.bigvgan = BigVGAN(**config['bigvgan'])
        self.mel_extractor = MelExtractor(**config['mel'])

    def load_model(self, flow_path, bigvgan_path):
        self.flow.load_state_dict(
            torch.load(flow_path, map_location="cpu"), strict=True
        )
        self.bigvgan.load_state_dict(
            torch.load(bigvgan_path, map_location="cpu")['generator'], strict=True
        )
        self.bigvgan.remove_weight_norm()

    @torch.inference_mode()
    def forward(
        self, semantic_token, prompt_semantic_token, prompt_acoustic_token, spk_gpt
    ):
        # Align prompt tokens & prompt mel: one semantic token covers two mel frames
        target_mel_length = prompt_semantic_token.shape[1] * 2
        if target_mel_length > prompt_acoustic_token.shape[1]:
            # Pad trailing frames with the log-mel silence floor
            prompt_acoustic_token = F.pad(
                prompt_acoustic_token, (0, 0, 0, target_mel_length - prompt_acoustic_token.shape[1]),
                mode='constant', value=-11.5
            )
        elif target_mel_length < prompt_acoustic_token.shape[1]:
            prompt_acoustic_token = prompt_acoustic_token[:, :target_mel_length]
        # prompt_acoustic_token = F.interpolate(
        #     prompt_acoustic_token.transpose(1, 2),
        #     size=prompt_semantic_token.shape[1] * 2, mode='nearest'
        # ).transpose(1, 2)
        mel_pred = self.flow.inference(
            prompt_token=prompt_semantic_token,
            prompt_xvec=spk_gpt,
            prompt_feat=prompt_acoustic_token,
            token=semantic_token
        )
        wav = self.bigvgan(mel_pred.transpose(1, 2)).squeeze(1)
        return wav

    def extract(self, wavs, wav_lengths, spk):
        mel = self.mel_extractor(wavs, 24000).transpose(1, 2)
        if torch.cuda.is_available():
            mel = mel.cuda()
        return mel, spk.squeeze(0)
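A small illustration (dummy tensors, not repo code) of the prompt alignment in FlowToken2Audio.forward: the prompt mel is padded with -11.5, roughly the log-mel floor of digital silence, or trimmed, so that it covers exactly two mel frames per prompt token:

import torch
import torch.nn.functional as F

prompt_tokens = torch.zeros(1, 100, dtype=torch.long)  # 100 semantic tokens
prompt_mel = torch.randn(1, 195, 80)                   # [B, frames, n_mels], 5 frames short

target = prompt_tokens.shape[1] * 2                    # 200 mel frames expected
if target > prompt_mel.shape[1]:
    # Pad trailing frames with the log-mel silence floor
    prompt_mel = F.pad(prompt_mel, (0, 0, 0, target - prompt_mel.shape[1]),
                       mode='constant', value=-11.5)
elif target < prompt_mel.shape[1]:
    prompt_mel = prompt_mel[:, :target]

assert prompt_mel.shape == (1, 200, 80)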
fireredtts/modules/__init__.py
ADDED
File without changes
fireredtts/modules/acoustic_codec/__init__.py
ADDED
@@ -0,0 +1 @@
from .bigcodec import BigCodec as AcousticCodec
fireredtts/modules/acoustic_codec/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
fireredtts/modules/acoustic_codec/alias_free_torch/act.py
ADDED
@@ -0,0 +1,35 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        causal: bool = False,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)
        self.causal = causal

    # x: [B,C,T]
    def forward(self, x):
        if self.causal:
            x = self.act(x)
        else:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)

        return x
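Activation1d wraps a pointwise nonlinearity in a 2x upsample/downsample sandwich so the harmonics the activation creates stay below Nyquist; the causal branch skips the resampling and applies the activation directly. A minimal usage sketch (illustrative, with an arbitrary SiLU as the wrapped activation):

import torch
import torch.nn as nn
from fireredtts.modules.acoustic_codec.alias_free_torch import Activation1d

act = Activation1d(activation=nn.SiLU())  # anti-aliased activation, 2x internal oversampling
x = torch.randn(1, 48, 16000)             # [B, C, T]
y = act(x)
assert y.shape == x.shape                 # the up-act-down sandwich preserves length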
fireredtts/modules/acoustic_codec/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,99 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
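One property worth noting (illustrative check, not repo code): kaiser_sinc_filter1d normalizes its taps to sum to 1, so a constant (DC) signal passes through LowPassFilter1d with unity gain, given the replicate edge padding:

import torch
from fireredtts.modules.acoustic_codec.alias_free_torch.filter import (
    LowPassFilter1d, kaiser_sinc_filter1d,
)

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.15, kernel_size=12)
assert taps.shape == (1, 1, 12)
assert torch.allclose(taps.sum(), torch.tensor(1.0), atol=1e-5)

lp = LowPassFilter1d(cutoff=0.25, half_width=0.15, stride=1, kernel_size=12)
x = torch.ones(1, 1, 100)                   # DC input
assert torch.allclose(lp(x), x, atol=1e-5)  # unity DC gain, same length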
fireredtts/modules/acoustic_codec/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,58 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
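A shape-level sketch of the two resamplers on dummy data: UpSample1d(r) stretches [B, C, T] to [B, C, r*T] and DownSample1d(r) brings it back, which is exactly the sandwich Activation1d builds around its activation:

import torch
from fireredtts.modules.acoustic_codec.alias_free_torch.resample import UpSample1d, DownSample1d

x = torch.randn(2, 8, 1000)      # [B, C, T]
y = UpSample1d(ratio=2)(x)
assert y.shape == (2, 8, 2000)   # 2x temporal upsampling
z = DownSample1d(ratio=2)(y)
assert z.shape == x.shape        # back to the original length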
fireredtts/modules/acoustic_codec/bigcodec.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from einops import rearrange
from torch import sin, pow
from torch.nn import Parameter
from torch.nn.utils import spectral_norm, weight_norm

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import typing as tp
import warnings

from .alias_free_torch import *
from .vector_quantization import VectorQuantization


CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return weight_norm(module)
    elif norm == "spectral_norm":
        return spectral_norm(module)
    else:
        return module


def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == "time_group_norm":
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


class NormConv1d(nn.Module):

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose1d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
        **kwargs,
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "SConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (
            kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(
            x, kernel_size, stride, padding_total
        )
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(
                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )
        return self.conv(x)


class SConvTranspose1d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y


def WNConv1d(*args, **kwargs):
    if kwargs.get("causal", False):
        kwargs["norm"] = "weight_norm"
        conv1d = SConv1d(*args, **kwargs)
    else:
        kwargs.pop("causal")
        conv1d = weight_norm(nn.Conv1d(*args, **kwargs))
    return conv1d


def WNConvTranspose1d(*args, **kwargs):
    if kwargs.get("causal", False):
        kwargs["norm"] = "weight_norm"
        transposed_conv1d = SConvTranspose1d(*args, **kwargs)
    else:
        kwargs.pop("causal")
        transposed_conv1d = weight_norm(nn.ConvTranspose1d(*args, **kwargs))
    return transposed_conv1d


class SnakeBeta(nn.Module):

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class ResidualUnit(nn.Module):

    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True), causal=causal),
            WNConv1d(
                dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal
            ),
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True), causal=causal),
            WNConv1d(dim, dim, kernel_size=1, causal=causal),
        )

    def forward(self, x):
        return x + self.block(x)


class EncoderBlock(nn.Module):

    def __init__(
        self, dim: int = 16, stride: int = 1, dilations=(1, 3, 9), causal: bool = False
    ):
        super().__init__()
        runits = [ResidualUnit(dim // 2, dilation=d, causal=causal) for d in dilations]
        self.block = nn.Sequential(
            *runits,
            Activation1d(
                activation=SnakeBeta(dim // 2, alpha_logscale=True), causal=causal
            ),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
                causal=causal,
            ),
        )

    def forward(self, x):
        return self.block(x)


class DecoderBlock(nn.Module):

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        dilations=(1, 3, 9),
        causal: bool = False,
    ):
        super().__init__()
        self.block = nn.Sequential(
            Activation1d(
                activation=SnakeBeta(input_dim, alpha_logscale=True), causal=causal
            ),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
                output_padding=stride % 2,
                causal=causal,
            ),
        )
        self.block.extend(
            [ResidualUnit(output_dim, dilation=d, causal=causal) for d in dilations]
        )

    def forward(self, x):
        return self.block(x)


class ResLSTM(nn.Module):

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        bidirectional: bool = False,
        skip: bool = True,
    ):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(
            dimension,
            dimension if not bidirectional else dimension // 2,
            num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x):
        x = rearrange(x, "b f t -> b t f")
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = rearrange(y, "b t f -> b f t")
        return y


class Resampler(nn.Module):

    def __init__(self, source_sr=24000, target_sr=24000):
        super().__init__()
        self.source_sr = source_sr
        self.target_sr = target_sr

    def forward(self, wav, wav_length):
        if self.source_sr != self.target_sr:
            wav = torchaudio.functional.resample(wav, self.source_sr, self.target_sr)
            wav_length = (wav_length * (self.source_sr / self.target_sr)).int()
        return wav, wav_length


class CodecEncoder(nn.Module):

    def __init__(
        self,
        ngf=48,
        use_rnn=True,
        rnn_bidirectional=False,
        rnn_num_layers=2,
        up_ratios=(2, 2, 2, 5, 5),
        dilations=(1, 3, 9),
        out_channels=1024,
        causal=False,
    ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        # Create first convolution
        d_model = ngf
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, causal=causal)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for i, stride in enumerate(up_ratios):
            d_model *= 2
            self.block += [
                EncoderBlock(d_model, stride=stride, dilations=dilations, causal=causal)
            ]
        # RNN
        if use_rnn:
            self.block += [
                ResLSTM(
                    d_model, num_layers=rnn_num_layers, bidirectional=rnn_bidirectional
                )
            ]
        # Create last convolution
        self.block += [
            Activation1d(
                activation=SnakeBeta(d_model, alpha_logscale=True), causal=causal
            ),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1, causal=causal),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        out = self.block(x)
        return out

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)


class CodecDecoder(nn.Module):

    def __init__(
        self,
        in_channels=1024,
        upsample_initial_channel=1536,
        ngf=48,
        use_rnn=True,
        rnn_bidirectional=False,
        rnn_num_layers=2,
        up_ratios=(5, 5, 2, 2, 2),
        dilations=(1, 3, 9),
        causal=False,
        delay=0,
    ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios
        self.delay = delay

        channels = upsample_initial_channel
        layers = [
            WNConv1d(in_channels, channels, kernel_size=7, padding=3, causal=causal)
        ]

        if use_rnn:
            layers += [
                ResLSTM(
                    channels, num_layers=rnn_num_layers, bidirectional=rnn_bidirectional
                )
            ]

        for i, stride in enumerate(up_ratios):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [
                DecoderBlock(input_dim, output_dim, stride, dilations, causal=causal)
            ]

        layers += [
            Activation1d(
                activation=SnakeBeta(output_dim, alpha_logscale=True), causal=causal
            ),
            WNConv1d(output_dim, 1, kernel_size=7, padding=3, causal=causal),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)
        self.reset_parameters()

    def forward(self, x):
        # Time delay
        if self.delay > 0:
            x = F.pad(x, (0, self.delay), mode="constant", value=0)

        x = self.model(x)

        # De-delay
        if self.delay > 0:
            x = x[..., self.delay :]

        return x

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)


class BigCodec(nn.Module):

    def __init__(
        self,
        n_model_size: int,
        encoder_config: dict,
        decoder_config: dict,
        vq_config: dict,
        resampler_config: dict = None,
    ):
        super(BigCodec, self).__init__()
        self.n_model_size = n_model_size

        self.encoder = CodecEncoder(out_channels=n_model_size, **encoder_config)
        self.decoder = CodecDecoder(in_channels=n_model_size, **decoder_config)
        self.quantizer = VectorQuantization(n_model_size, **vq_config)

        # Optional modules
        if resampler_config:
            self.resampler = Resampler(**resampler_config)

    def forward(
        self, wav, wav_length=None, enable_vq=True, decode=True, update_codebook=True
    ):
        # Preprocess wav
        if len(wav.shape) == 2:
            wav = wav.unsqueeze(1)
        if wav_length is None:
            wav_length = torch.full([wav.shape[0]], max(wav.shape)).to(wav.device)

        # (Optional) Resample
        processed_wav, processed_wav_length = wav, wav_length
        if hasattr(self, "resampler"):
            processed_wav, processed_wav_length = self.resampler(
                processed_wav, processed_wav_length
            )

        # Update VQ parameters
        quant_length = torch.ceil(processed_wav_length / self.encoder.hop_length).int()
        update_codebook = update_codebook and self.training

        # Encode
        encoder_outputs = self.encoder(processed_wav)

        # Quantize
        quant, diff, embed_ind = self.quantizer(
            encoder_outputs.transpose(1, 2),
            quant_length.clamp(max=encoder_outputs.shape[2]),
            enable_vq=enable_vq,
            update_codebook=update_codebook,
        )

        if decode:
            # Decode
            decoder_outputs = self.decoder(quant.transpose(1, 2))
        else:
            decoder_outputs = None

        output_dict = {
            "quant": quant,
            "token": embed_ind,
            "token_length": quant_length,
            "encoder_diffs": diff,
            "wav_pred": decoder_outputs,
        }
        return output_dict

    @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
    def extract_speech_tokens(
        self, wav, wav_length, serialize=True, extract_spk=True, shuffle=False
    ):
        output_dict = self.forward(wav, wav_length, enable_vq=True, decode=False)
        token_seqs, token_length = [output_dict["token"]], [output_dict["token_length"]]
        output_dict.update(
            {
                "token": token_seqs,
                "token_length": token_length,
            }
        )

        return output_dict

    @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
    def reconstruct_wav(self, token=None, quant=None, spk=None):
        if token is not None:
            # De-tokenization
            quant = self.quantizer.decode(token)

            # Speaker embedding
            if hasattr(self, "global_encoder"):
                quant = quant + spk.unsqueeze(2)
        else:
            assert quant is not None

        # Decode
        wav_pred = self.decoder(quant)

        return {
            "wav_pred": wav_pred,
        }
fireredtts/modules/acoustic_codec/vector_quantization.py ADDED
@@ -0,0 +1,580 @@
from einops import rearrange, repeat
from torch import nn
from typing import Union

import torch
import torch.nn.functional as F
import typing as tp
import numpy as np
import warnings


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def flatten(x, x_len):
    x_f = x.view(-1, *x.shape[2:])
    return x_f


def ema_inplace(moving_avg, new, decay):
    if isinstance(decay, torch.Tensor):
        moving_avg.data.mul_(decay).add_(new * (1 - decay))
    else:
        moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


class EuclideanCodebook(nn.Module):

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: float = 1.0,
        n_cache_iters: int = 1,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.update_iter = 0

        self.n_cache_iters = n_cache_iters
        self.cache_vectors = []
        self.cache_indices = []

        if isinstance(self.decay, (tuple, list)):
            self.embed_avg_cache = []
        self.register_buffer("diff_avg_long", torch.zeros(codebook_size) + 1e-5)
        self.register_buffer("diff_avg_short", torch.zeros(codebook_size) + 1e-5)
        self.register_buffer("inited", torch.Tensor([True]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

    def replace_(self, samples, mask, dists=None):
        reset_cluster_size = min(
            self.threshold_ema_dead_code + 1, self.threshold_ema_dead_code * 1.1
        )

        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        modified_codebook_avg = torch.where(
            mask[..., None], modified_codebook * reset_cluster_size, self.embed_avg
        )
        modified_cluster_size = torch.where(
            mask,
            torch.full_like(self.cluster_size, reset_cluster_size),
            self.cluster_size,
        )

        self.embed.data.copy_(modified_codebook)
        self.embed_avg.data.copy_(modified_codebook_avg)
        self.cluster_size.data.copy_(modified_cluster_size)

    def expire_codes_(self, batch_samples, dists=None):
        self.update_iter += 1
        if self.threshold_ema_dead_code == 0:
            return
        elif self.threshold_ema_dead_code < 1:
            threshold_ema_dead_code = (
                sum(self.cluster_size) * self.threshold_ema_dead_code
            )
        else:
            threshold_ema_dead_code = self.threshold_ema_dead_code

        expired_codes = self.cluster_size < threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes, dists=dists)

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices

        return embed_ind, dist

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind, dist = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)

        return embed_ind, dist

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x, x_len, enable_vq=True, update_codebook=True, masking=False):
        x_org, shape, dtype = x, x.shape, x.dtype

        x = self.preprocess(x)

        embed_ind, dist = self.quantize(x)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        dist = dist.view(shape[0], shape[1], dist.shape[-1])

        quantize = self.dequantize(embed_ind)

        if self.training and update_codebook:
            if enable_vq:
                quantize = x_org + (quantize - x_org).detach()
            else:
                quantize = x_org

            # Get flatten embedding indices and distances
            if masking:
                x_f = torch.cat(
                    [e[: int(e_len)] for e, e_len in zip(x_org, x_len)], dim=0
                )
                embed_ind_f = torch.cat(
                    [e[: int(e_len)] for e, e_len in zip(embed_ind, x_len)], dim=0
                )
                dist_f = torch.cat(
                    [e[: int(e_len)] for e, e_len in zip(dist, x_len)], dim=0
                )
                q_f = torch.cat(
                    [e[: int(e_len)] for e, e_len in zip(quantize.detach(), x_len)],
                    dim=0,
                )
                commit_loss = F.mse_loss(q_f, x_f)
            else:
                x_f = x_org.view(-1, x_org.shape[-1]).contiguous()
                embed_ind_f = embed_ind.view(-1).contiguous()
                dist_f = dist.view(-1).contiguous()
                commit_loss = F.mse_loss(quantize.detach(), x_org)
            self.init_embed_(x_f)

            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x_f, dist_f)

            # Calculate codebook statistics
            embed_onehot = F.one_hot(embed_ind_f, self.codebook_size).type(dtype)
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = x_f.t() @ embed_onehot

            # EMA updating
            ema_inplace(self.cluster_size, embed_onehot_sum, self.decay)
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)

            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
        else:
            commit_loss = torch.tensor(
                0.0, device=quantize.device, requires_grad=self.training
            )

        return quantize, commit_loss, embed_ind


class MultiHeadEuclideanCodebook(nn.Module):

    def __init__(
        self,
        dim: Union[int, list],
        codebook_size: list,
        n_groups: int = 1,
        dropout_rate_per_group: float = 0,
        ordered: bool = False,
        ordered_axis: str = "sequence",
        method: str = "product",
        **kwargs,
    ):
        super().__init__()
        self.codebook_sizes = codebook_size
        self.codebook_dims = dim
        self.n_groups = n_groups
        self.n_heads_per_group = len(codebook_size) // n_groups
        self.dropout_rate_per_group = dropout_rate_per_group
        self.ordered = ordered
        self.ordered_axis = ordered_axis
        self.method = method
        assert len(codebook_size) % n_groups == 0

        self.codebooks = nn.ModuleList()
        dim = self.codebook_dims
        for i, size in enumerate(self.codebook_sizes):
            if isinstance(self.codebook_dims, list):
                dim = (
                    self.codebook_dims[i]
                    if method == "product"
                    else sum(self.codebook_dims)
                )
            self.codebooks.append(EuclideanCodebook(dim, size, **kwargs))

    def decode(self, embed_ind):
        if self.n_groups == 1 or len(embed_ind.shape) == 2:
            embed_ind = embed_ind.unsqueeze(-1)

        actual_n_groups = embed_ind.shape[-1]
        if actual_n_groups < self.n_groups:
            print(
                f"The actual number of heads ({actual_n_groups}) is smaller than the pre-designed ({self.n_groups})!"
            )
            embed_ind = F.pad(
                embed_ind, (0, self.n_groups - actual_n_groups), "replicate"
            )
        # assert embed_ind.shape[-1] == self.n_groups

        index_heads, codebook_heads, scale_heads = zip(
            *[
                (
                    embed_ind[..., i // self.n_heads_per_group],
                    self.codebooks[i : i + self.n_heads_per_group],
                    self.codebook_sizes[i : i + self.n_heads_per_group],
                )
                for i in range(0, len(self.codebook_sizes), self.n_heads_per_group)
            ]
        )

        quantize_heads, quantize_groups = [], []
        for i in range(self.n_groups):
            embed_ind, codebooks, scales = (
                index_heads[i],
                codebook_heads[i],
                scale_heads[i],
            )

            inv_scales = list(torch.tensor([1] + scales[:-1]).cumprod(dim=0))[::-1]
            inv_quantizes = []
            for codebook, scale in zip(codebooks[::-1], inv_scales):
                index, embed_ind = embed_ind // scale, embed_ind % scale
                quantize = codebook.dequantize(index)
                inv_quantizes.append(quantize)
            quantizes = inv_quantizes[::-1]
            group_embeddings = torch.cat(quantizes, dim=-1)
            quantize_groups.append(group_embeddings)
            quantize_heads += quantizes

        if self.method == "product":
            if actual_n_groups < self.n_groups:
                for i in range(actual_n_groups, self.n_groups):
                    quantize_groups[i].zero_()
            quantize = torch.cat(quantize_groups, dim=-1)
        elif self.method == "residual":
            quantize = sum(quantize_heads)

        return quantize

    def forward(self, x, x_len, enable_vq=True, update_codebook=True):
        # Pre-process
        x = self._preprocess(x)

        # Quantize
        quants, losses, indices = self._quantize(
            x, x_len, enable_vq=enable_vq, update_codebook=update_codebook
        )

        # Integrate
        quant, loss, index = self._integrate(
            quants, losses, indices, update_codebook=update_codebook
        )

        return quant, loss, index

    def _preprocess(self, x):
        if self.method == "product" and isinstance(self.codebook_dims, (list, tuple)):
            x = torch.split(x, self.codebook_dims, dim=-1)
        return x

    def _quantize(self, x, x_len, enable_vq, update_codebook):
        if self.method == "product":
            quants, losses, indices = zip(
                *[
                    codebook(
                        chunk,
                        x_len,
                        enable_vq=enable_vq,
                        update_codebook=update_codebook,
                    )
                    for chunk, codebook in zip(x, self.codebooks)
                ]
            )
        elif self.method == "residual":
            quants, losses, indices = [], [], []
            residual = x
            for codebook in self.codebooks:
                quant, loss, index = codebook(
                    residual,
                    x_len,
                    enable_vq=enable_vq,
                    update_codebook=update_codebook,
                )
                residual = residual - quant
                quants.append(quant)
                losses.append(loss)
                indices.append(index)

        return quants, losses, indices

    def _integrate(self, quants, losses, indices, update_codebook=True):
        (B, T, D), M = quants[0].shape, len(quants)
        device = quants[0].device

        # Average loss
        loss = sum(losses) / len(losses)

        # Get indices
        if self.n_groups == 1:
            scale = (
                torch.tensor([1] + self.codebook_sizes[:-1]).cumprod(dim=0).to(device)
            )
            index = (torch.stack(indices, dim=-1) * scale).sum(dim=-1)
        else:
            index_heads, scale_heads = zip(
                *[
                    (
                        indices[i : i + self.n_heads_per_group],
                        torch.tensor(
                            [1]
                            + self.codebook_sizes[i : i + self.n_heads_per_group - 1]
                        )
                        .cumprod(dim=0)
                        .to(device),
                    )
                    for i in range(0, len(quants), self.n_heads_per_group)
                ]
            )
            index = torch.stack(
                [
                    (torch.stack(x, dim=-1) * s).sum(dim=-1)
                    for x, s in zip(index_heads, scale_heads)
                ],
                dim=-1,
            )

        # Add dropout
        quant_groups = self._dropout(quants, enabled=update_codebook)

        # Combine quantized features
        if self.method == "product":
            quant = torch.cat(quant_groups, dim=-1)
        elif self.method == "residual":
            quant = torch.cat(quant_groups, dim=-1).view(B, T, M, D).sum(dim=2)

        return quant, loss, index

    def _dropout(self, quants, enabled=True):
        if enabled and self.training and self.ordered:
            if self.dropout_rate_per_group == 0:
                threshold = [
                    (i // self.n_heads_per_group * 1.0 / self.n_groups)
                    for i in range(0, len(quants), self.n_heads_per_group)
                ]
            elif self.dropout_rate_per_group == "exp":
                x = [np.exp(4 * i / self.n_groups) for i in range(self.n_groups)]
                x = np.asarray(x) / sum(x)
                threshold = np.cumsum(np.asarray([0] + x))
            else:
                x = np.asarray(self.dropout_rate_per_group) / sum(
                    self.dropout_rate_per_group
                )
                threshold = np.cumsum(np.asarray([0] + x))

            if self.ordered_axis == "sequence":
                rate = torch.rand((quants[0].shape[0], 1, 1), device=quants[0].device)
            elif self.ordered_axis == "frame":
                rate = torch.rand(
                    (quants[0].shape[0], quants[0].shape[1], 1), device=quants[0].device
                )

            quant_groups = []
            for i in range(0, len(quants), self.n_heads_per_group):
                quant_group = torch.cat(quants[i : i + self.n_heads_per_group], dim=-1)
                is_kept = threshold[i // self.n_heads_per_group] <= rate
                quant_group = torch.where(
                    is_kept, quant_group, torch.zeros_like(quant_group)
                )
                quant_groups.append(quant_group)
        elif self.ordered:
            quant_groups = []
            for i in range(0, len(quants), self.n_heads_per_group):
                quant_group = torch.cat(quants[i : i + self.n_heads_per_group], dim=-1)
                quant_groups.append(quant_group)
        else:
            quant_groups = quants

        return quant_groups


class VectorQuantization(nn.Module):

    def __init__(
        self,
        dim: int,
        codebook_size: Union[int, list],
        codebook_dim: Union[int, list] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: float = 1.0,
        commitment_weight: float = 1.0,
        requires_projection: bool = False,
        norm: str = "none",
        **kwargs,
    ):
        super().__init__()
        _codebook_dim: Union[int, list] = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim or requires_projection
        proj_dim = (
            sum(_codebook_dim) if isinstance(_codebook_dim, list) else _codebook_dim
        )
        if requires_projection:
            self.project_in = nn.Linear(dim, proj_dim)
            self.project_out = nn.Linear(proj_dim, dim)
            if norm == "weight_norm":
                self.project_in = torch.nn.utils.weight_norm(self.project_in)
                self.project_out = torch.nn.utils.weight_norm(self.project_out)
        else:
            self.norm = None
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
        self.codebook_size = codebook_size

        codebook_class = (
            EuclideanCodebook
            if isinstance(codebook_size, int)
            else MultiHeadEuclideanCodebook
        )
        self._codebook = codebook_class(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
            **kwargs,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x, x_len=None):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind, embed_len=None):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def decode_latent(self, latent, latent_len=None):
        if latent_len is None:
            latent_len = (
                torch.Tensor([latent.shape[1]] * latent.shape[0])
                .to(latent.device)
                .int()
            )

        quantize, _, _ = self._codebook(latent, latent_len)
        quantize = self.project_out(quantize)
        return quantize

    @torch.cuda.amp.autocast(dtype=torch.float32)
    def forward(
        self,
        x,
        x_len,
        enable_vq=True,
        update_codebook=True,
        return_pre_quant=False,
        return_dict=False,
    ):
        device = x.device

        x = self.project_in(x)

        quantize, commit_loss, embed_ind = self._codebook(
            x, x_len, enable_vq=enable_vq, update_codebook=update_codebook
        )
        if self.training and update_codebook:
            loss = torch.tensor(0.0, device=device, requires_grad=True)
            if self.commitment_weight > 0:
                loss = loss + commit_loss * self.commitment_weight
        else:
            loss = torch.tensor(0.0, device=device, requires_grad=False)

        embed = quantize
        quantize = self.project_out(quantize)

        if return_dict:
            return {
                "quantize": quantize,
                "loss": loss,
                "embed": embed,
                "embed_ind": embed_ind,
            }
        elif return_pre_quant:
            pre_quantize = self.project_out(x)
            return (pre_quantize, quantize), loss, embed_ind
        else:
            return quantize, loss, embed_ind
fireredtts/modules/acoustic_llm/__init__.py ADDED
@@ -0,0 +1 @@
from .acoustic_llm import AcousticLLM
fireredtts/modules/acoustic_llm/acoustic_llm.py ADDED
@@ -0,0 +1,876 @@
| 1 |
+
from einops import rearrange
|
| 2 |
+
from time import time
|
| 3 |
+
from torch.utils.checkpoint import checkpoint
|
| 4 |
+
from transformers import (
|
| 5 |
+
GPT2Config,
|
| 6 |
+
GPT2Model,
|
| 7 |
+
GPT2PreTrainedModel,
|
| 8 |
+
LogitsProcessorList,
|
| 9 |
+
LogitsWarper,
|
| 10 |
+
StoppingCriteria,
|
| 11 |
+
StoppingCriteriaList,
|
| 12 |
+
)
|
| 13 |
+
from transformers.generation.streamers import BaseStreamer
|
| 14 |
+
from transformers.generation.utils import (
|
| 15 |
+
GenerationConfig,
|
| 16 |
+
GenerateDecoderOnlyOutput,
|
| 17 |
+
GenerateEncoderDecoderOutput,
|
| 18 |
+
GenerateNonBeamOutput,
|
| 19 |
+
)
|
| 20 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
| 21 |
+
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
| 22 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import functools
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MultiHeadRepetitionPenaltyLogitsProcessor(LogitsWarper):
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self, penalty: float = 2.0, n_heads: int = 4, n_frames: int = -1, start_index=0
|
| 35 |
+
):
|
| 36 |
+
if not isinstance(penalty, float) or not (penalty > 0):
|
| 37 |
+
raise ValueError(
|
| 38 |
+
f"`penalty` has to be a strictly positive float, but is {penalty}"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
self.penalty = penalty
|
| 42 |
+
self.n_heads = n_heads
|
| 43 |
+
self.n_frames = n_frames
|
| 44 |
+
self.start_index = start_index
|
| 45 |
+
|
| 46 |
+
def __call__(
|
| 47 |
+
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
| 48 |
+
) -> torch.FloatTensor:
|
| 49 |
+
input_ids = input_ids[:, self.start_index :]
|
| 50 |
+
if input_ids.size(1) == 0:
|
| 51 |
+
return scores
|
| 52 |
+
|
| 53 |
+
if self.n_frames <= 0:
|
| 54 |
+
input_ids = torch.flip(input_ids, [1])[:, self.n_heads - 1 :: self.n_heads]
|
| 55 |
+
else:
|
| 56 |
+
input_ids = torch.flip(input_ids, [1])[
|
| 57 |
+
:, self.n_heads - 1 : self.n_heads * self.n_frames : self.n_heads
|
| 58 |
+
]
|
| 59 |
+
score = torch.gather(scores, 1, input_ids)
|
| 60 |
+
|
| 61 |
+
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
|
| 62 |
+
if self.penalty > 100:
|
| 63 |
+
score = torch.full_like(score, -1e3)
|
| 64 |
+
else:
|
| 65 |
+
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
| 66 |
+
|
| 67 |
+
scores.scatter_(1, input_ids, score)
|
| 68 |
+
return scores
|

def null_position_embeddings(range, dim):
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


class FixedStoppingCriteria(StoppingCriteria):
    """Stops generation once `running_steps` positions past `start_index` are filled."""

    def __init__(self, running_steps, start_index=0):
        self.running_steps = running_steps
        self.start_index = start_index

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> torch.BoolTensor:
        assert input_ids.shape[0] == 1, input_ids.shape
        if input_ids.shape[1] - self.start_index >= self.running_steps:
            return torch.tensor([True]).to(input_ids.device)
        return torch.tensor([False]).to(input_ids.device)


class DelayStoppingCriteria(StoppingCriteria):
    """Stops generation `delay_steps` tokens after the first EOS token appears."""

    def __init__(self, eos_token_id, delay_steps):
        self.delay_steps = delay_steps
        self.eos_token_id = torch.tensor(eos_token_id)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> torch.BoolTensor:
        assert input_ids.shape[0] == 1, input_ids.shape

        if (input_ids == self.eos_token_id).any():
            index = (input_ids[0] == self.eos_token_id).nonzero(as_tuple=True)[0][0]
            if index + self.delay_steps < input_ids.shape[1]:
                return torch.tensor([True]).to(input_ids.device)
        return torch.tensor([False]).to(input_ids.device)

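# Wiring sketch for the two criteria with HF generate (the token id, step
# counts, and `model`/`prompt_len` names are illustrative):
#
#   criteria = StoppingCriteriaList(
#       [
#           FixedStoppingCriteria(running_steps=100, start_index=prompt_len),
#           DelayStoppingCriteria(eos_token_id=8193, delay_steps=4),
#       ]
#   )
#   model.generate(input_ids, stopping_criteria=criteria)
#
# FixedStoppingCriteria bounds the step count; DelayStoppingCriteria lets the
# delayed heads finish their pattern before the sequence is cut.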
+
class SuppressionLogitsProcessor(LogitsWarper):
|
| 109 |
+
|
| 110 |
+
def __init__(self, suppressed_ids=[]):
|
| 111 |
+
self.suppressed_ids = suppressed_ids
|
| 112 |
+
|
| 113 |
+
def __call__(
|
| 114 |
+
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
| 115 |
+
) -> torch.FloatTensor:
|
| 116 |
+
for sid in self.suppressed_ids:
|
| 117 |
+
scores[..., sid] = scores.min()
|
| 118 |
+
return scores
|
| 119 |
+
|
| 120 |
+
|
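# Usage sketch (illustrative id): pin a token's score to the running minimum so
# sampling effectively cannot select it, e.g. to block an early stop token:
#
#   processors = LogitsProcessorList(
#       [SuppressionLogitsProcessor(suppressed_ids=[8193])]
#   )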

class MHGPT2InferenceModel(GPT2PreTrainedModel):
    """GPT-2 wrapper that decodes multi-head token streams with cached conditioning embeddings."""

    def __init__(
        self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True
    ):
        super().__init__(config)
        self.transformer = gpt
        self.text_pos_embedding = text_pos_emb
        self.embeddings = embeddings
        self.lm_head = nn.ModuleList([norm, linear])  # nn.Sequential(norm, linear)
        self.kv_cache = kv_cache

        # Multi-head configuration
        self.n_heads = len(linear)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.cached_mel_emb = None
        self.cached_mel_parallel_emb = None

    def store_mel_emb(self, mel_emb):
        self.cached_mel_emb = mel_emb

    def store_mel_parallel_emb(self, mel_emb):
        self.cached_mel_parallel_emb = mel_emb

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
        if not self.kv_cache:
            past_key_values = None

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert self.cached_mel_emb is not None
        assert inputs_embeds is None  # Not supported by this inference model.
        assert labels is None  # Training not supported by this inference model.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # Create embedding
        mel_len = self.cached_mel_emb.shape[1]
        attention_mask = None
        position_ids = None

        if input_ids.shape[1] != 1 and past_key_values is None:
            # Prefill: embed the whole multi-head token prefix at once.
            text_inputs = input_ids[:, mel_len:]
            text_emb = sum(
                [self.embeddings[i](text_inputs[:, :, i]) for i in range(self.n_heads)]
            )
            text_emb = text_emb + self.text_pos_embedding(text_emb)

            if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
                mel_emb = self.cached_mel_emb.repeat_interleave(
                    text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
                )
            else:  # this outcome only occurs once per loop in most cases
                mel_emb = self.cached_mel_emb

            if self.cached_mel_parallel_emb is not None:
                text_emb = (
                    text_emb + self.cached_mel_parallel_emb[:, : text_emb.shape[1]]
                )

            emb = torch.cat([mel_emb, text_emb], dim=1)
        else:  # KV-cache mode: embed only the newest token of each head
            text_inputs = input_ids[:, mel_len:]
            emb = sum(
                [self.embeddings[i](text_inputs[:, -1, i]) for i in range(self.n_heads)]
            )
            emb = emb + self.text_pos_embedding.get_fixed_embedding(
                text_inputs.shape[1] - 1, emb.device
            )

            if self.cached_mel_parallel_emb is not None:
                emb = emb + self.cached_mel_parallel_emb[:, text_inputs.shape[1] - 1]

        transformer_outputs = self.transformer(
            inputs_embeds=emb,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=True,  # output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        past_key_values = transformer_outputs.past_key_values
        output_hidden_states = transformer_outputs.hidden_states
        output_attentions = transformer_outputs.attentions

        # Set device for model parallelism
        if self.model_parallel:
            if torch.backends.mps.is_available():
                self.to(self.transformer.first_device)
            else:
                torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        # Shared LayerNorm followed by one linear projection per head.
        lm_logits = self.lm_head[0](hidden_states)
        lm_logits = [head(lm_logits) for head in self.lm_head[1]]
        lm_logits = torch.stack(lm_logits, dim=2)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:]

        output = CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=output_hidden_states,
            attentions=output_attentions,
        )
        return output

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        cross_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = (
                model_kwargs["encoder_outputs"].get("attentions")
                if output_attentions
                else None
            )
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states")
                if output_hidden_states
                else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len, num_streams = input_ids.shape
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(
            this_peer_finished,
            synced_gpus,
            device=input_ids.device,
            cur_len=cur_len,
            max_length=max_length,
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update(
                {"output_attentions": output_attentions} if output_attentions else {}
            )
            model_inputs.update(
                {"output_hidden_states": output_hidden_states}
                if output_hidden_states
                else {}
            )

            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True)

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                continue

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits.clone()[:, -1].float()
            next_token_logits = next_token_logits.to(input_ids.device)

            # pre-process distribution: fold the head dimension into the batch
            batch_size, seq_len, num_streams = input_ids.shape
            rearrange_input_ids = rearrange(input_ids, "b l n -> (b n) l")
            next_token_logits = rearrange(next_token_logits, "b n d -> (b n) d")
            next_token_scores = logits_processor(rearrange_input_ids, next_token_logits)
            next_token_scores = rearrange(
                next_token_scores, "(b n) d -> b n d", b=batch_size
            )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,)
                        if self.config.is_encoder_decoder
                        else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                probs = probs.view(-1, probs.shape[-1])
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                next_tokens = next_tokens.view(*next_token_scores.shape[:-1])
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                input_ids, scores
            )
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids


class LearnedPositionEmbeddings(nn.Module):

    def __init__(self, seq_len, model_dim, init=0.02):
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        self.emb.weight.data.normal_(mean=0.0, std=init)

    def forward(self, x):
        sl = x.shape[1]
        return self.emb(torch.arange(0, sl, device=x.device))

    def get_fixed_embedding(self, ind, dev):
        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)

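# Shape sketch (illustrative sizes): forward returns one embedding per position
# of the current sequence, which broadcasts over the batch when added:
#
#   pos = LearnedPositionEmbeddings(seq_len=256, model_dim=1024)
#   tok = torch.randn(2, 10, 1024)
#   out = tok + pos(tok)                      # pos(tok) has shape (10, 1024)
#   nxt = pos.get_fixed_embedding(9, "cpu")   # (1, 1, 1024), for KV-cache steps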

def build_hf_gpt_transformer(
    layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing
):
    gpt_config = GPT2Config(
        vocab_size=256,  # unused: token embeddings are supplied externally
        n_positions=max_mel_seq_len + max_text_seq_len,
        n_ctx=max_mel_seq_len + max_text_seq_len,
        n_embd=model_dim,
        n_layer=layers,
        n_head=heads,
        use_cache=not checkpointing,
        scale_attn_by_inverse_layer_idx=True,
        reorder_and_upcast_attn=True,
        attn_implementation="sdpa",
    )
    gpt = GPT2Model(gpt_config)

    if checkpointing:
        gpt.gradient_checkpointing_enable()

    # Replace the built-in embeddings: positions are zeroed out (added
    # externally) and word embeddings are removed entirely.
    del gpt.wpe, gpt.wte
    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
    mel_pos_embs = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)

    return gpt, mel_pos_embs

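# Instantiation sketch using the AcousticLLM defaults defined below:
#
#   gpt, mel_pos = build_hf_gpt_transformer(
#       layers=12, model_dim=1024, heads=16,
#       max_mel_seq_len=253,   # max_speech_tokens + 2 + max_conditioning_inputs
#       max_text_seq_len=122,  # max_text_tokens + 2
#       checkpointing=False,
#   )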

class AcousticLLM(nn.Module):

    def __init__(
        self,
        # Model
        n_stacks=2,
        layers=12,
        model_dim=1024,
        heads=16,
        # Text
        max_text_tokens=120,
        number_text_tokens=8194,
        start_text_token=8192,
        stop_text_token=8193,
        # Speech
        n_frames_per_step=4,
        n_heads_per_frame=1,
        max_speech_tokens=250,
        number_speech_tokens=8194,
        start_speech_token=8192,
        stop_speech_token=8193,
        # CoS Prediction
        streaming=False,
        streaming_delayed_frames=4,
        accumulative_speech_embedding=False,
        upsample_factors=2,
        # Reference embedding
        max_conditioning_inputs=1,
        speaker_embedding_pretrained=True,
        speaker_embedding_ckpt=None,
        speaker_embedding_dim=256,
        # For training
        checkpointing=True,
        loss_weights=1.0,
        # For inference
        delay_prediction=1,
        temperature=0.3,
        length_penalty=1.0,
        repetition_penalty=2.0,
        top_p=0.2,
        top_k=50,
    ):
        super().__init__()
        self.n_stacks = n_stacks
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token
        self.number_speech_tokens = number_speech_tokens
        self.start_speech_token = start_speech_token
        self.stop_speech_token = stop_speech_token
        self.layers = layers
        self.heads = heads

        self.streaming = streaming
        self.streaming_delayed_frames = streaming_delayed_frames
        self.accumulative_speech_embedding = accumulative_speech_embedding
        self.upsample_factors = upsample_factors

        self.n_frames_per_step = n_frames_per_step
        self.n_heads_per_frame = n_heads_per_frame
        self.number_speech_heads = n_heads_per_frame * n_frames_per_step

        self.max_speech_tokens = max_speech_tokens
        self.max_text_tokens = max_text_tokens
        self.model_dim = model_dim
        self.max_conditioning_inputs = max_conditioning_inputs

        self.speaker_embedding_pretrained = speaker_embedding_pretrained
        self.speaker_embedding_ckpt = speaker_embedding_ckpt
        self.speaker_embedding_dim = speaker_embedding_dim

        # For training
        self.loss_weights = loss_weights

        # For inference
        self.delay_prediction = delay_prediction
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.top_p = top_p
        self.top_k = top_k

        # Conditional embedding
        self.reference_embedding = nn.Sequential(
            nn.Linear(speaker_embedding_dim, 256),
            nn.Tanh(),
            nn.Linear(256, model_dim),
        )

        self.text_embedding = nn.Embedding(self.number_text_tokens + 1, model_dim)
        self.text_embedding.weight.data.normal_(mean=0.0, std=0.02)

        self.mel_embedding = nn.ModuleList(
            [
                nn.Embedding(self.number_speech_tokens, model_dim)
                for _ in range(self.number_speech_heads)
            ]
        )
        for module in self.mel_embedding:
            module.weight.data.normal_(mean=0.0, std=0.02)

        # Build GPTs
        self.gpt, self.mel_pos_embedding = build_hf_gpt_transformer(
            layers,
            model_dim,
            heads,
            self.max_speech_tokens + 2 + self.max_conditioning_inputs,
            self.max_text_tokens + 2,
            checkpointing,
        )
        self.final_norm = nn.LayerNorm(model_dim)
        self.mel_head = nn.ModuleList(
            [
                nn.Linear(model_dim, self.number_speech_tokens)
                for _ in range(self.number_speech_heads)
            ]
        )

    def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=True, half=False):
        seq_length = self.max_speech_tokens + self.max_text_tokens + 2
        gpt_config = GPT2Config(
            vocab_size=self.max_speech_tokens,
            n_positions=seq_length,
            n_ctx=seq_length,
            n_embd=self.model_dim,
            n_layer=self.layers,
            n_head=self.heads,
            gradient_checkpointing=False,
            use_cache=True,
        )

        self.inference_model = MHGPT2InferenceModel(
            gpt_config,
            self.gpt,
            self.mel_pos_embedding,
            self.mel_embedding,
            self.final_norm,
            self.mel_head,
            kv_cache=kv_cache,
        )
        self.inference_model.eval()

    def build_aligned_inputs_and_targets(
        self, seqs, lens, start_token, stop_token, delay=0
    ):
        for i in range(seqs.shape[0]):
            seqs[i, lens[i] :] = stop_token

        if len(seqs.shape) == 2:
            inp = F.pad(
                seqs, (self.streaming_delayed_frames, 0), value=start_token
            ).type_as(seqs)
            inp = F.pad(inp, (0, 1), value=stop_token).type_as(inp)
            tar = F.pad(inp[:, 1:], (0, 1), value=stop_token).type_as(seqs)
        else:
            inp = F.pad(
                seqs, (0, 0, self.streaming_delayed_frames, 0), value=start_token
            ).type_as(seqs)
            inp = F.pad(inp, (0, 0, 0, 1), value=stop_token).type_as(inp)
            tar = F.pad(inp[:, 1:], (0, 0, 0, 1), value=stop_token).type_as(seqs)

        if delay > 0:
            # Delay pattern: shift head i by i * delay frames, padding with start
            # tokens on the left and stop tokens on the right.
            pad_size = delay * (inp.shape[2] - 1)
            L = inp.shape[1] + pad_size
            inp = F.pad(inp, (0, 0, pad_size, 0), value=start_token).type_as(inp)
            inp = F.pad(inp, (0, 0, 0, pad_size), value=stop_token).type_as(inp)
            inp = torch.stack(
                [
                    inp[:, pad_size - i * delay : pad_size - i * delay + L, i]
                    for i in range(inp.shape[-1])
                ],
                dim=-1,
            )

            tar = F.pad(tar, (0, 0, pad_size, 0), value=start_token).type_as(tar)
            tar = F.pad(tar, (0, 0, 0, pad_size), value=stop_token).type_as(tar)
            tar = torch.stack(
                [
                    tar[:, pad_size - i * delay : pad_size - i * delay + L, i]
                    for i in range(tar.shape[-1])
                ],
                dim=-1,
            )

            lens += pad_size

        return inp, tar, lens + self.streaming_delayed_frames + 1

    def get_logits(
        self,
        final_norm,
        first_inputs,
        first_head,
        speech_conditioning_inputs=None,
        attention_mask=None,
        get_attns=False,
        return_latent=False,
    ):
        emb = first_inputs
        if speech_conditioning_inputs is not None:
            emb = torch.cat([speech_conditioning_inputs, emb], dim=1)

        gpt_out = self.gpt(
            inputs_embeds=emb,
            return_dict=True,
            attention_mask=attention_mask,
            output_attentions=get_attns,
        )

        enc = gpt_out.last_hidden_state
        if speech_conditioning_inputs is not None:
            enc = enc[:, 1:]
        enc = final_norm(enc)

        first_logits = [head(enc).permute(0, 2, 1) for head in first_head]

        return first_logits

    @torch.cuda.amp.autocast()
    def get_conditioning(self, speech_conditioning_input):
        if hasattr(self, "reference_encoder"):
            if len(speech_conditioning_input.shape) == 2:
                speech_conditioning_input = speech_conditioning_input.unsqueeze(1)
            speech_conditioning_input = self.reference_encoder(
                speech_conditioning_input
            )
        conds = self.reference_embedding(speech_conditioning_input)
        return conds

    def inference_speech(
        self,
        speech_conditioning_latent,
        text_inputs,
        input_tokens=None,
        num_return_sequences=1,
        max_generate_length=None,
        **hf_generate_kwargs,
    ):
        if not hasattr(self, "inference_model"):
            self.post_init_gpt2_config()

        # Cond
        emb = speech_conditioning_latent
        self.inference_model.store_mel_emb(emb)

        # Text
        text = torch.repeat_interleave(text_inputs, self.upsample_factors, dim=1)
        text = F.pad(
            text,
            (0, self.streaming_delayed_frames + self.number_speech_heads - 1),
            value=self.stop_speech_token,
        )
        text_embedding = self.text_embedding(text)
        self.inference_model.store_mel_parallel_emb(text_embedding)

        fake_inputs = torch.full(
            (
                emb.shape[0],  # should be 1 for stable inference
                emb.shape[1] + 1,  # + 1 for the start_speech_token
                self.number_speech_heads,
            ),
            fill_value=self.start_speech_token,
            dtype=torch.long,
            device=text_inputs.device,
        )
        if input_tokens is None:
            inputs = fake_inputs
            prompt_index = 0
        else:
            prompt, _, _ = self.build_aligned_inputs_and_targets(
                input_tokens,
                torch.Tensor([len(input_tokens[0])]).int(),
                self.start_speech_token,
                self.stop_speech_token,
                self.delay_prediction,
            )
            prompt = prompt[:, 1 : 1 + input_tokens.shape[1]]
            inputs = torch.cat([fake_inputs, prompt], dim=1)
            prompt_index = input_tokens.shape[1]
        trunc_index = fake_inputs.shape[1]

        stop_criteria = StoppingCriteriaList(
            [FixedStoppingCriteria(text_embedding.shape[1], start_index=emb.shape[1])]
        )

        logits_processor = (
            LogitsProcessorList(
                [
                    MultiHeadRepetitionPenaltyLogitsProcessor(
                        penalty=self.repetition_penalty,
                        n_heads=self.n_heads_per_frame,
                        n_frames=-1,
                        start_index=trunc_index + prompt_index,
                    )
                ]
            )
            if self.repetition_penalty > 1.0
            else LogitsProcessorList()
        )
        logits_processor.append(
            SuppressionLogitsProcessor(suppressed_ids=[self.stop_speech_token])
        )

        max_length = (
            trunc_index + self.max_speech_tokens - 1
            if max_generate_length is None
            else trunc_index + max_generate_length
        )

        # Recommended (temperature, top_p) pairs: (0.8, 0.8), (0.5, 0.5), (0.3, 0.2), (0.2, 0.1)
        gen = self.inference_model.generate(
            inputs,
            bos_token_id=self.start_speech_token,
            pad_token_id=self.stop_speech_token,
            eos_token_id=self.stop_speech_token + 2,
            max_length=max_length,
            stopping_criteria=stop_criteria,
            logits_processor=logits_processor,
            num_return_sequences=num_return_sequences,
            do_sample=True,
            temperature=self.temperature,
            length_penalty=self.length_penalty,
            top_p=self.top_p,
            top_k=self.top_k,
            **hf_generate_kwargs,
        )

        seq = gen[0][trunc_index:]

        # Undo the per-head delay, drop everything up to the last start token,
        # and cut each head at its first stop token.
        start, heads = 0, []
        for j in range(self.number_speech_heads):
            head = seq[j * self.delay_prediction :, j]
            start_indices = (head == self.start_speech_token).nonzero(as_tuple=True)[0]
            start = max(start, start_indices[-1] + 1 if len(start_indices) > 0 else 0)
            stop = (head == self.stop_speech_token).nonzero(as_tuple=True)[0]
            stop = stop[0] if len(stop) > 0 else len(head)
            heads.append(head[:stop])

        min_length = min([len(x) for x in heads])
        seq = torch.stack(
            [head[start + prompt_index : min_length] for head in heads], dim=-1
        )
        return [seq]
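Taken together, inference is a two-step pattern: build the KV-cached wrapper once via post_init_gpt2_config, then call inference_speech per utterance. A minimal driver sketch follows; the shapes, token ranges, and random inputs are illustrative assumptions rather than values from the released checkpoints, it assumes the one-line package __init__ re-exports AcousticLLM, and running generate with 3-D multi-head inputs depends on the transformers version this Space pins:

import torch
from fireredtts.modules.acoustic_llm import AcousticLLM  # assumed re-export

llm = AcousticLLM(checkpointing=False)
llm.post_init_gpt2_config(kv_cache=True)
llm.eval()

cond = torch.randn(1, 1, 1024)          # assumed: one conditioning frame of model_dim
text = torch.randint(0, 8192, (1, 50))  # assumed: 50 upstream semantic tokens
with torch.no_grad():
    speech_tokens = llm.inference_speech(cond, text)[0]  # (T, number_speech_heads)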
fireredtts/modules/bigvgan/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .bigvgan import BigVGAN
from .mel_spectrogram import MelExtractor
fireredtts/modules/bigvgan/activations.py
ADDED
@@ -0,0 +1,126 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.


import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    """
    Implementation of a sine-based periodic activation function.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(16, 256, 100)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default; higher values give higher frequency.
            alpha will be trained along with the rest of your model.
        """
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2(a * x)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(16, 256, 100)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default; higher values give higher frequency.
            beta is initialized to 1 by default; higher values give higher magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2(a * x)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
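A quick numerical sanity check of the Snake formula above (values illustrative):

import torch
from fireredtts.modules.bigvgan.activations import Snake

act = Snake(in_features=1)                 # alpha initialized to 1
x = torch.tensor([[[0.0, torch.pi / 2]]])  # shape (B=1, C=1, T=2)
y = act(x)
# Snake(0) = 0, and Snake(pi/2) = pi/2 + sin^2(pi/2) ~= pi/2 + 1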
fireredtts/modules/bigvgan/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

from .filter import *
from .resample import *
from .act import *
fireredtts/modules/bigvgan/alias_free_torch/act.py
ADDED
@@ -0,0 +1,29 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B, C, T]
    def forward(self, x):
        # Upsample, apply the nonlinearity at the higher rate, then downsample,
        # which suppresses the aliasing the nonlinearity would introduce.
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
fireredtts/modules/bigvgan/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,98 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different from julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1, 1, kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5 / cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        # kernel_size should be an even number for the stylegan3 setup;
        # in this implementation, an odd number is also possible.
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
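A short check of the filter builder (the arguments shown are the half-band design the 2x resamplers below use):

import torch
from fireredtts.modules.bigvgan.alias_free_torch.filter import kaiser_sinc_filter1d

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(taps.shape)          # torch.Size([1, 1, 12])
print(float(taps.sum()))   # ~1.0: normalized to pass DC without leakage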
fireredtts/modules/bigvgan/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,57 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
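A round-trip sketch of the two resamplers (shapes illustrative):

import torch
from fireredtts.modules.bigvgan.alias_free_torch.resample import UpSample1d, DownSample1d

up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
x = torch.randn(1, 8, 100)
y = up(x)     # (1, 8, 200): time axis doubled
z = down(y)   # (1, 8, 100): back to the original length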
fireredtts/modules/bigvgan/bigvgan.py
ADDED
@@ -0,0 +1,369 @@
import typing as tp
import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

from .alias_free_torch import Activation1d as TorchActivation1d
from .activations import Snake, SnakeBeta


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class AMPBlock1(torch.nn.Module):
    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3, 5),
        activation=None,
        snake_logscale=True,
    ):
        super(AMPBlock1, self).__init__()

        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(
            self.convs2
        )  # total number of conv layers

        Activation1d = TorchActivation1d
        if (
            activation == "snake"
        ):  # periodic nonlinearity with snake function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(channels, alpha_logscale=snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif (
            activation == "snakebeta"
        ):  # periodic nonlinearity with snakebeta function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class AMPBlock2(torch.nn.Module):
    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3),
        activation=None,
        snake_logscale=True,
    ):
        super(AMPBlock2, self).__init__()

        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

        self.num_layers = len(self.convs)  # total number of conv layers

        Activation1d = TorchActivation1d

        if (
            activation == "snake"
        ):  # periodic nonlinearity with snake function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(channels, alpha_logscale=snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif (
            activation == "snakebeta"
        ):  # periodic nonlinearity with snakebeta function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        for c, a in zip(self.convs, self.activations):
            xt = a(x)
            xt = c(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class BigVGAN(torch.nn.Module):
    # This is the main BigVGAN model. Applies anti-aliased periodic activations in the resblocks.
    def __init__(
        self,
        num_mels: int,
        upsample_initial_channel: int,
        resblock_kernel_sizes: tp.List[int],
        resblock_dilation_sizes: tp.List[tp.List[int]],
        upsample_rates: tp.List[int],
        upsample_kernel_sizes: tp.List[int],
        resblock_type: str = "1",
        snake_logscale: bool = True,
        activation: str = "snakebeta",
        use_tanh_at_final: bool = False,
        use_bias_at_final: bool = False,
        **kwargs,
    ):
        super(BigVGAN, self).__init__()

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # pre conv
        self.conv_pre = weight_norm(
            Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3)
        )

        # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        resblock = AMPBlock1 if resblock_type == "1" else AMPBlock2

        # transposed conv-based upsamplers. does not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                nn.ModuleList(
                    [
                        weight_norm(
                            ConvTranspose1d(
                                upsample_initial_channel // (2**i),
                                upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2,
                            )
                        )
                    ]
                )
            )

        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(
                    resblock(
                        ch,
                        k,
                        d,
                        activation=activation,
                        snake_logscale=snake_logscale,
                    )
                )

        Activation1d = TorchActivation1d

        # post conv
        if (
            activation == "snake"
        ):  # periodic nonlinearity with snake function and anti-aliasing
            activation_post = Snake(ch, alpha_logscale=snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        elif (
            activation == "snakebeta"
        ):  # periodic nonlinearity with snakebeta function and anti-aliasing
            activation_post = SnakeBeta(ch, alpha_logscale=snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

        # whether to use bias for the final conv_post. Defaults to True for backward compatibility
        self.use_bias_at_final = use_bias_at_final
        self.conv_post = weight_norm(
            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
        )

        # weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

        # final tanh activation. Defaults to True for backward compatibility
        self.use_tanh_at_final = use_tanh_at_final

    def forward(self, x):
        # pre conv
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # post conv
        x = self.activation_post(x)
|
| 352 |
+
x = self.conv_post(x)
|
| 353 |
+
# final tanh activation
|
| 354 |
+
if self.use_tanh_at_final:
|
| 355 |
+
x = torch.tanh(x)
|
| 356 |
+
else:
|
| 357 |
+
x = torch.clamp(x, min=-1.0, max=1.0) # bound the output to [-1, 1]
|
| 358 |
+
|
| 359 |
+
return x
|
| 360 |
+
|
| 361 |
+
def remove_weight_norm(self):
|
| 362 |
+
print("Removing weight norm...")
|
| 363 |
+
for l in self.ups:
|
| 364 |
+
for l_i in l:
|
| 365 |
+
remove_weight_norm(l_i)
|
| 366 |
+
for l in self.resblocks:
|
| 367 |
+
l.remove_weight_norm()
|
| 368 |
+
remove_weight_norm(self.conv_pre)
|
| 369 |
+
remove_weight_norm(self.conv_post)
|
fireredtts/modules/bigvgan/mel_spectrogram.py
ADDED
@@ -0,0 +1,111 @@
import numpy as np
import torch
import torchaudio
from librosa.filters import mel as librosa_mel_fn


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[str(fmax) + "_" + str(y.device)] = (
            torch.from_numpy(mel).float().to(y.device)
        )
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec


class MelExtractor(object):
    def __init__(
        self,
        num_mels: int = 80,
        n_fft: int = 1920,
        hop_size: int = 480,
        win_size: int = 1920,
        sampling_rate: int = 24000,
        fmin: int = 0,
        fmax: int = 8000,
        center: bool = False,
    ):
        super().__init__()
        self.num_mels = num_mels
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size
        self.sampling_rate = sampling_rate
        self.fmin = fmin
        self.fmax = fmax
        self.center = center

    def __call__(self, audio: torch.Tensor, audio_sr: int):
        """Args:
            audio(torch.Tensor): shape (1, t)
        Returns:
            mel(torch.Tensor): shape (1, num_mels, t')
        """
        if audio_sr != self.sampling_rate:
            audio = torchaudio.functional.resample(
                audio, orig_freq=audio_sr, new_freq=self.sampling_rate
            )
            audio_sr = self.sampling_rate
        mel = mel_spectrogram(
            audio,
            self.n_fft,
            self.num_mels,
            self.sampling_rate,
            self.hop_size,
            self.win_size,
            self.fmin,
            self.fmax,
            self.center,
        )  # (1, num_mels, t)
        return mel
fireredtts/modules/flowmatching/__init__.py
ADDED
@@ -0,0 +1,18 @@
from .estimator_dit import DiT
from .upsample_encoder import UpsampleConformerEncoder
from .flow import CausalFmWithSpkCtx, DualEmbedding


class FlowToken2Mel(CausalFmWithSpkCtx):
    def __init__(self, config):
        token_emb = DualEmbedding(**config["token_emb"])
        encoder = UpsampleConformerEncoder(**config["encoder"])
        estimator = DiT(**config["estimator"])
        super().__init__(
            spk_channels=config["spk_channels"],
            spk_enc_channels=config["spk_enc_channels"],
            infer_cfg_rate=config["infer_cfg_rate"],
            token_emb=token_emb,
            encoder=encoder,
            estimator=estimator,
        )
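
FlowToken2Mel is configured entirely through a dict whose keys the constructor reads directly. The skeleton below is an editor's illustration: the key names come from the code above, but every value is an assumption (the shipped settings live in configs/config_24k_flow.json). Note that, given the flow.py code later in this commit, DiT's in_channels must equal its out_channels plus three times spk_enc_channels, because the estimator input concatenates the noisy mel with the packed [mu, speaker, context] condition.

# Illustrative config skeleton; values are assumptions, not the shipped JSON.
config = {
    "spk_channels": 192,        # x-vector dim (assumed)
    "spk_enc_channels": 80,     # width of spk/encoder projections (assumed)
    "infer_cfg_rate": 0.7,
    "token_emb": {"channels": 512},
    "encoder": {"input_size": 512, "output_size": 512, "up_stride": 2},
    "estimator": {
        "in_channels": 320,     # 80 (mel) + 3 * 80 (mu + spk + mel context)
        "out_channels": 80,
        "hidden_size": 256,
    },
}
flow = FlowToken2Mel(config)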
fireredtts/modules/flowmatching/estimator_dit.py
ADDED
@@ -0,0 +1,356 @@
import math
import torch
from typing import Optional
import torch.nn as nn
import torch.nn.functional as F


class MLP(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Attention(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 64,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim
        self.scale = head_dim ** -0.5

        self.to_q = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, self.inner_dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.proj = nn.Linear(self.inner_dim, dim)

    def to_heads(self, ts: torch.Tensor):
        b, t, c = ts.shape
        # (b, t, nh, c)
        ts = ts.reshape(b, t, self.num_heads, c // self.num_heads)
        ts = ts.transpose(1, 2)
        return ts

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Args:
            x(torch.Tensor): shape (b, t, c)
            attn_mask(torch.Tensor): shape (b, t, t)
        """
        b, t, c = x.shape

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.to_heads(q)  # (b, nh, t, c)
        k = self.to_heads(k)
        v = self.to_heads(v)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)

        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_drop.p if self.training else 0.0,
        )  # (b, nh, t, c)
        x = x.transpose(1, 2).reshape(b, t, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        # from SinusoidalPosEmb
        self.scale = 1000

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t * self.scale, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


# Convolution related
class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size)
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x


class CausalConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        self.block = torch.nn.Sequential(
            # conv1
            Transpose(1, 2),
            CausalConv1d(in_channels, out_channels, kernel_size),
            Transpose(1, 2),
            # norm & act
            nn.LayerNorm(out_channels),
            nn.Mish(),
            # conv2
            Transpose(1, 2),
            CausalConv1d(out_channels, out_channels, kernel_size),
            Transpose(1, 2),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, t, c)
            mask: shape (b, t, 1)
        """
        if mask is not None:
            x = x * mask
        x = self.block(x)
        if mask is not None:
            x = x * mask
        return x


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size, num_heads=num_heads, head_dim=head_dim,
            qkv_bias=True, qk_norm=True, **block_kwargs
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = MLP(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.conv = CausalConvBlock(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 9 * hidden_size, bias=True),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: torch.Tensor = None, conv_mask: torch.Tensor = None):
        """Args:
            x: shape (b, t, c)
            c: shape (b, 1, c)
            attn_mask: shape (b, t, t), bool type attention mask
            conv_mask: shape (b, 1, t), bool type non-pad mask
        """
        (
            shift_msa, scale_msa, gate_msa,
            shift_mlp, scale_mlp, gate_mlp,
            shift_conv, scale_conv, gate_conv,
        ) = self.adaLN_modulation(c).chunk(9, dim=-1)
        # attention
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask)
        # conv
        x = x + gate_conv * self.conv(modulate(self.norm3(x), shift_conv, scale_conv), mask=conv_mask)
        # mlp
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mlp_ratio: float = 4.0,
        depth: int = 28,
        num_heads: int = 8,
        head_dim: int = 64,
        hidden_size: int = 256,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.t_embedder = TimestepEmbedder(hidden_size)

        self.in_proj = nn.Linear(in_channels, hidden_size)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, head_dim, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    """For non-streaming inference.
    """
    def forward(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, attn_mask: torch.Tensor = None, conv_mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, c, t)
            c: aux condition, shape (b, c, t)
            t: shape (b,)
            attn_mask: (b, t, t)
            conv_mask: (b, 1, t)
        Returns:
            pred: shape (b, c, t)
        """
        # time
        t = self.t_embedder(t.view(-1)).unsqueeze(1)  # (b, 1, c)

        # CausalConvBlock mask is (b, t, 1)
        conv_mask = conv_mask if conv_mask is None else conv_mask.transpose(1, 2)

        x = torch.cat([x, c], dim=1)
        # forward blocks
        x = x.transpose(1, 2)
        x = self.in_proj(x)
        for block in self.blocks:
            x = block(x, t, attn_mask=attn_mask, conv_mask=conv_mask)
        x = self.final_layer(x, t)
        x = x.transpose(1, 2)
        return x
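
One property worth noting about initialize_weights above: because every adaLN_modulation layer and the final projection are zero-initialized, all gates start at zero, each DiTBlock begins as the identity, and the whole estimator outputs exactly zero at initialization (the adaLN-Zero trick from the DiT paper). A small editor's sanity sketch, with channel sizes chosen to be mutually consistent rather than taken from the shipped configs:

# adaLN-Zero sanity check: a freshly initialized DiT predicts all zeros.
# The channel sizes here are assumptions, not the shipped configuration.
import torch

dit = DiT(in_channels=320, out_channels=80, depth=2, hidden_size=256)
x = torch.randn(1, 80, 25)     # noisy mel, (b, c, t)
c = torch.randn(1, 240, 25)    # packed condition; 80 + 240 = in_channels
t = torch.rand(1)              # flow time in [0, 1]
v = dit(x, c, t)               # (1, 80, 25)
assert torch.allclose(v, torch.zeros_like(v))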
fireredtts/modules/flowmatching/flow.py
ADDED
@@ -0,0 +1,138 @@
import torch
import torch.nn.functional as F
from typing import Dict, List
from einops import pack, repeat
from .estimator_dit import DiT
from .upsample_encoder import UpsampleConformerEncoder


class DualEmbedding(torch.nn.Module):
    def __init__(
        self,
        channels: int = 512,
    ):
        super().__init__()
        self.codebook_size = 128
        self.codebook_dim = 128
        self.codebook = torch.nn.ModuleList([
            torch.nn.Embedding(self.codebook_size, self.codebook_dim),
            torch.nn.Embedding(self.codebook_size, self.codebook_dim),
        ])
        self.out_proj = torch.nn.Linear(self.codebook_dim * 2, channels)

    def forward(self, tokens):
        """
        Args:
            tokens: shape (b, t)
        Returns:
            token_embs: shape (b, t, c)
        """
        token_embs = torch.cat([
            self.codebook[0](tokens % self.codebook_size),
            self.codebook[1](tokens // self.codebook_size),
        ], dim=-1)
        token_embs = self.out_proj(token_embs)
        return token_embs


class CausalFmWithSpkCtx(torch.nn.Module):
    def __init__(
        self,
        # Basic in-out
        spk_channels: int,
        spk_enc_channels: int,  # out channels of spk & encoder projection
        # Modules
        token_emb: DualEmbedding,
        encoder: UpsampleConformerEncoder,
        estimator: DiT,
        # Flow cfg
        infer_cfg_rate: float = 0.7,
    ):
        super().__init__()
        # Variants
        self.up_stride = encoder.up_stride
        self.infer_cfg_rate = infer_cfg_rate
        # Modules
        self.spk_proj = torch.nn.Linear(spk_channels, spk_enc_channels)
        self.token_emb = token_emb
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(encoder.output_size, spk_enc_channels)
        self.estimator = estimator
        # Initial noise, enough for a maximum of 600 s at 50 frames/s
        self.register_buffer(
            "x0",
            torch.randn([1, self.estimator.out_channels, 50 * 600]),
            persistent=False,
        )

    def _euler(
        self,
        x0: torch.Tensor,
        c: torch.Tensor,
        n_timesteps: int = 10,
    ):
        # time steps
        t_span = torch.linspace(0, 1, n_timesteps + 1).to(x0)
        # cosine time scheduling
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        # Euler solver
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        xt = x0
        for step in range(1, len(t_span)):
            # pack input
            x_in = torch.cat([xt, xt], dim=0)
            c_in = torch.cat([c, torch.zeros_like(c)], dim=0)
            t_in = torch.cat([t, t], dim=0)

            # model call
            with torch.no_grad():
                vt = self.estimator.forward(x_in, c_in, t_in)
            # cfg
            vt_cond, vt_cfg = vt.chunk(2, dim=0)
            vt = (1.0 + self.infer_cfg_rate) * vt_cond - self.infer_cfg_rate * vt_cfg

            xt = xt + dt * vt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return xt

    def inference(
        self,
        prompt_token: torch.Tensor,
        prompt_xvec: torch.Tensor,
        prompt_feat: torch.Tensor,
        token: torch.Tensor,
    ):
        # NOTE: align prompt_token and prompt_feat in advance

        # Spk condition
        embedding = F.normalize(prompt_xvec, dim=1)
        spks = self.spk_proj(embedding)

        # Token condition
        token = torch.concat([prompt_token, token], dim=1)
        xs = self.token_emb(token)

        xs_lens = torch.tensor([xs.shape[1]]).to(token)
        xs = self.encoder(xs, xs_lens)
        mu = self.encoder_proj(xs)

        # Mel context
        ctx = torch.zeros_like(mu)
        ctx[:, : prompt_feat.shape[1]] = prompt_feat

        # Compose condition
        cond = mu.transpose(1, 2)
        ctx = ctx.transpose(1, 2)
        spks = repeat(spks, "b c -> b c t", t=cond.shape[-1])
        cond = pack([cond, spks, ctx], "b * t")[0]

        # FM inference
        x0 = self.x0[..., : mu.shape[1]]
        x1 = self._euler(x0, cond, n_timesteps=10)

        feat = x1.transpose(1, 2)[:, prompt_feat.shape[1] :]
        return feat
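
In equations, _euler above integrates the flow-matching ODE with classifier-free guidance: the batch is doubled into a conditional copy and a zero-condition copy, and the two velocity estimates are blended before each explicit Euler step. With guidance weight $w$ = infer_cfg_rate and $N$ = n_timesteps, the warped time grid and update rule are

$$t_k = 1 - \cos\left(\frac{\pi}{2}\cdot\frac{k}{N}\right), \qquad
\hat{v} = (1 + w)\, v_\theta(x_{t_k}, c, t_k) - w\, v_\theta(x_{t_k}, \mathbf{0}, t_k), \qquad
x_{t_{k+1}} = x_{t_k} + (t_{k+1} - t_k)\,\hat{v}.$$

The cosine warp concentrates steps near $t = 0$, where the velocity field changes fastest; with the default n_timesteps=10 this costs ten estimator calls, each on a doubled batch.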
fireredtts/modules/flowmatching/upsample_encoder.py
ADDED
@@ -0,0 +1,617 @@
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
from typing import Tuple, List, Union
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
"""Attention modules.
|
| 9 |
+
"""
|
| 10 |
+
class MultiHeadedAttention(nn.Module):
|
| 11 |
+
def __init__(self,
|
| 12 |
+
n_head: int,
|
| 13 |
+
n_feat: int,
|
| 14 |
+
dropout_rate: float,
|
| 15 |
+
key_bias: bool = True):
|
| 16 |
+
super().__init__()
|
| 17 |
+
assert n_feat % n_head == 0
|
| 18 |
+
# We assume d_v always equals d_k
|
| 19 |
+
self.d_k = n_feat // n_head
|
| 20 |
+
self.h = n_head
|
| 21 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
| 22 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
| 23 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
| 24 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
| 25 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 26 |
+
|
| 27 |
+
def forward_qkv(self,
|
| 28 |
+
query: torch.Tensor,
|
| 29 |
+
key: torch.Tensor,
|
| 30 |
+
value: torch.Tensor):
|
| 31 |
+
"""
|
| 32 |
+
Args:
|
| 33 |
+
query,key,value: shape (b, t, c)
|
| 34 |
+
Returns:
|
| 35 |
+
query,key,value: shape (b, nh, t, c//nh)
|
| 36 |
+
"""
|
| 37 |
+
n_batch = query.size(0)
|
| 38 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
| 39 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
| 40 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
| 41 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
| 42 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
| 43 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
| 44 |
+
return q, k, v
|
| 45 |
+
|
| 46 |
+
def forward_attention(self,
|
| 47 |
+
value: torch.Tensor,
|
| 48 |
+
scores: torch.Tensor,
|
| 49 |
+
mask: torch.Tensor = None):
|
| 50 |
+
"""Compute attention context vector.
|
| 51 |
+
Args:
|
| 52 |
+
value (torch.Tensor): shape: (b, nh, t2, c//nh).
|
| 53 |
+
scores (torch.Tensor): shape: (b, nh, t1, t2).
|
| 54 |
+
mask (torch.Tensor): attention padded mask, size (b, 1, t2) or (b, t1, t2)
|
| 55 |
+
Returns:
|
| 56 |
+
shape: (b, t1, c)
|
| 57 |
+
"""
|
| 58 |
+
b = value.size(0)
|
| 59 |
+
if mask is not None:
|
| 60 |
+
mask = mask.unsqueeze(1).eq(0)
|
| 61 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
| 62 |
+
attn = scores.softmax(dim=-1).masked_fill(mask, 0.0)
|
| 63 |
+
else:
|
| 64 |
+
attn = scores.softmax(dim=-1)
|
| 65 |
+
p_attn = self.dropout(attn)
|
| 66 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
| 67 |
+
x = x.transpose(1, 2).contiguous().view(b, -1, self.h * self.d_k)
|
| 68 |
+
return self.linear_out(x)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
| 72 |
+
def __init__(self,
|
| 73 |
+
n_head: int,
|
| 74 |
+
n_feat: int,
|
| 75 |
+
dropout_rate: float,
|
| 76 |
+
key_bias: bool = True):
|
| 77 |
+
"""Multi-Head Attention layer with relative position encoding.
|
| 78 |
+
Paper: https://arxiv.org/abs/1901.02860
|
| 79 |
+
Args:
|
| 80 |
+
n_head (int): The number of heads.
|
| 81 |
+
n_feat (int): The number of features.
|
| 82 |
+
dropout_rate (float): Dropout rate.
|
| 83 |
+
"""
|
| 84 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
| 85 |
+
# linear transformation for positional encoding
|
| 86 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
| 87 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 88 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 89 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 90 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 91 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
| 92 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
| 93 |
+
|
| 94 |
+
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
| 95 |
+
"""Compute relative positional encoding.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
| 99 |
+
time1 means the length of query vector.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
torch.Tensor: Output tensor.
|
| 103 |
+
|
| 104 |
+
"""
|
| 105 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
| 106 |
+
device=x.device,
|
| 107 |
+
dtype=x.dtype)
|
| 108 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
| 109 |
+
|
| 110 |
+
x_padded = x_padded.view(x.size()[0],
|
| 111 |
+
x.size()[1],
|
| 112 |
+
x.size(3) + 1, x.size(2))
|
| 113 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
| 114 |
+
:, :, :, : x.size(-1) // 2 + 1
|
| 115 |
+
] # only keep the positions from 0 to time2
|
| 116 |
+
return x
|
| 117 |
+
|
| 118 |
+
def forward(
|
| 119 |
+
self,
|
| 120 |
+
query: torch.Tensor,
|
| 121 |
+
key: torch.Tensor,
|
| 122 |
+
value: torch.Tensor,
|
| 123 |
+
pos_emb: torch.Tensor,
|
| 124 |
+
mask: torch.Tensor = None,
|
| 125 |
+
cache: torch.Tensor = None,
|
| 126 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 127 |
+
"""
|
| 128 |
+
Args:
|
| 129 |
+
query (torch.Tensor): shape (b, t1, c).
|
| 130 |
+
key (torch.Tensor): shape (b, t2, c).
|
| 131 |
+
value (torch.Tensor): shape (b, t2, c).
|
| 132 |
+
mask (torch.Tensor): attention padded mask, shape (b, 1, t2) or (b, t1, t2).
|
| 133 |
+
pos_emb (torch.Tensor): Positional embedding tensor (b, 2*t1-1, c).
|
| 134 |
+
cache (torch.Tensor): Cache tensor (1, nh, cache_t, d_k * 2).
|
| 135 |
+
Returns:
|
| 136 |
+
torch.Tensor: Output tensor (b, t1, d_model).
|
| 137 |
+
torch.Tensor: Cache tensor (1, nh, cache_t + t1, d_k * 2)
|
| 138 |
+
"""
|
| 139 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 140 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
| 141 |
+
|
| 142 |
+
if cache is not None and cache.size(0) > 0:
|
| 143 |
+
key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
|
| 144 |
+
k = torch.cat([key_cache, k], dim=2)
|
| 145 |
+
v = torch.cat([value_cache, v], dim=2)
|
| 146 |
+
new_cache = torch.cat((k, v), dim=-1)
|
| 147 |
+
|
| 148 |
+
n_batch_pos = pos_emb.size(0)
|
| 149 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) # (batch, 2*time1-1, head, d_k)
|
| 150 |
+
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
| 151 |
+
|
| 152 |
+
# (batch, head, time1, d_k)
|
| 153 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
| 154 |
+
# (batch, head, time1, d_k)
|
| 155 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
| 156 |
+
|
| 157 |
+
# compute attention score
|
| 158 |
+
# first compute matrix a and matrix c
|
| 159 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 160 |
+
# (batch, head, time1, time2)
|
| 161 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
| 162 |
+
|
| 163 |
+
# compute matrix b and matrix d
|
| 164 |
+
# matrix_bd: (batch, head, time1, 2*time1-1)
|
| 165 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
| 166 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
| 167 |
+
if matrix_ac.shape != matrix_bd.shape:
|
| 168 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
| 169 |
+
|
| 170 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
|
| 171 |
+
|
| 172 |
+
return self.forward_attention(v, scores, mask), new_cache
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class EspnetRelPositionalEncoding(torch.nn.Module):
|
| 176 |
+
"""Relative positional encoding module (new implementation).
|
| 177 |
+
|
| 178 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
| 179 |
+
|
| 180 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
d_model (int): Embedding dimension.
|
| 184 |
+
dropout_rate (float): Dropout rate.
|
| 185 |
+
max_len (int): Maximum input length.
|
| 186 |
+
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
def __init__(self, d_model: int, dropout_rate: float=0.0, max_len: int = 5000):
|
| 190 |
+
"""Construct an PositionalEncoding object."""
|
| 191 |
+
super(EspnetRelPositionalEncoding, self).__init__()
|
| 192 |
+
self.d_model = d_model
|
| 193 |
+
self.xscale = math.sqrt(self.d_model)
|
| 194 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
| 195 |
+
self.pe = None
|
| 196 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
| 197 |
+
|
| 198 |
+
def extend_pe(self, x: torch.Tensor):
|
| 199 |
+
"""Reset the positional encodings."""
|
| 200 |
+
if self.pe is not None:
|
| 201 |
+
# self.pe contains both positive and negative parts
|
| 202 |
+
# the length of self.pe is 2 * input_len - 1
|
| 203 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
| 204 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 205 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 206 |
+
return
|
| 207 |
+
# Suppose `i` means to the position of query vecotr and `j` means the
|
| 208 |
+
# position of key vector. We use position relative positions when keys
|
| 209 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
| 210 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
| 211 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
| 212 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 213 |
+
div_term = torch.exp(
|
| 214 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
| 215 |
+
* -(math.log(10000.0) / self.d_model)
|
| 216 |
+
)
|
| 217 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 218 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 219 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 220 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 221 |
+
|
| 222 |
+
# Reserve the order of positive indices and concat both positive and
|
| 223 |
+
# negative indices. This is used to support the shifting trick
|
| 224 |
+
# as in https://arxiv.org/abs/1901.02860
|
| 225 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 226 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 227 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 228 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 229 |
+
|
| 230 |
+
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
| 231 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
| 232 |
+
"""Add positional encoding.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 239 |
+
|
| 240 |
+
"""
|
| 241 |
+
self.extend_pe(x)
|
| 242 |
+
x = x * self.xscale
|
| 243 |
+
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
| 244 |
+
return self.dropout(x), self.dropout(pos_emb)
|
| 245 |
+
|
| 246 |
+
def position_encoding(self,
|
| 247 |
+
offset: Union[int, torch.Tensor],
|
| 248 |
+
size: int) -> torch.Tensor:
|
| 249 |
+
""" For getting encoding in a streaming fashion
|
| 250 |
+
|
| 251 |
+
Attention!!!!!
|
| 252 |
+
we apply dropout only once at the whole utterance level in a none
|
| 253 |
+
streaming way, but will call this function several times with
|
| 254 |
+
increasing input size in a streaming scenario, so the dropout will
|
| 255 |
+
be applied several times.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
offset (int or torch.tensor): start offset
|
| 259 |
+
size (int): required size of position encoding
|
| 260 |
+
|
| 261 |
+
Returns:
|
| 262 |
+
torch.Tensor: Corresponding encoding
|
| 263 |
+
"""
|
| 264 |
+
pos_emb = self.pe[
|
| 265 |
+
:,
|
| 266 |
+
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
|
| 267 |
+
]
|
| 268 |
+
return pos_emb
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
"""Other modules.
|
| 272 |
+
"""
|
| 273 |
+
class Upsample1D(nn.Module):
|
| 274 |
+
"""A 1D upsampling layer with an optional convolution.
|
| 275 |
+
|
| 276 |
+
Parameters:
|
| 277 |
+
channels (`int`):
|
| 278 |
+
number of channels in the inputs and outputs.
|
| 279 |
+
use_conv (`bool`, default `False`):
|
| 280 |
+
option to use a convolution.
|
| 281 |
+
use_conv_transpose (`bool`, default `False`):
|
| 282 |
+
option to use a convolution transpose.
|
| 283 |
+
out_channels (`int`, optional):
|
| 284 |
+
number of output channels. Defaults to `channels`.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
def __init__(self, channels: int, out_channels: int, stride: int = 2):
|
| 288 |
+
super().__init__()
|
| 289 |
+
self.channels = channels
|
| 290 |
+
self.out_channels = out_channels
|
| 291 |
+
self.stride = stride
|
| 292 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
| 293 |
+
|
| 294 |
+
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
|
| 295 |
+
outputs = F.interpolate(inputs, scale_factor=self.stride, mode="nearest")
|
| 296 |
+
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
| 297 |
+
outputs = self.conv(outputs)
|
| 298 |
+
return outputs, input_lengths * self.stride
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class PreLookaheadLayer(nn.Module):
|
| 302 |
+
def __init__(self, channels: int, pre_lookahead_len: int = 1):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.channels = channels
|
| 305 |
+
self.pre_lookahead_len = pre_lookahead_len
|
| 306 |
+
self.conv1 = nn.Conv1d(
|
| 307 |
+
channels, channels,
|
| 308 |
+
kernel_size=pre_lookahead_len + 1,
|
| 309 |
+
stride=1, padding=0,
|
| 310 |
+
)
|
| 311 |
+
self.conv2 = nn.Conv1d(
|
| 312 |
+
channels, channels,
|
| 313 |
+
kernel_size=3, stride=1, padding=0,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
| 317 |
+
"""
|
| 318 |
+
inputs: (batch_size, seq_len, channels)
|
| 319 |
+
"""
|
| 320 |
+
outputs = inputs.transpose(1, 2).contiguous()
|
| 321 |
+
# look ahead
|
| 322 |
+
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
| 323 |
+
outputs = F.leaky_relu(self.conv1(outputs))
|
| 324 |
+
# outputs
|
| 325 |
+
outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
|
| 326 |
+
outputs = self.conv2(outputs)
|
| 327 |
+
outputs = outputs.transpose(1, 2).contiguous()
|
| 328 |
+
|
| 329 |
+
# residual connection
|
| 330 |
+
outputs = outputs + inputs
|
| 331 |
+
return outputs
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
| 335 |
+
"""Positionwise feed forward layer.
|
| 336 |
+
|
| 337 |
+
FeedForward are appied on each position of the sequence.
|
| 338 |
+
The output dim is same with the input dim.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
idim (int): Input dimenstion.
|
| 342 |
+
hidden_units (int): The number of hidden units.
|
| 343 |
+
dropout_rate (float): Dropout rate.
|
| 344 |
+
activation (torch.nn.Module): Activation function
|
| 345 |
+
"""
|
| 346 |
+
|
| 347 |
+
def __init__(
|
| 348 |
+
self,
|
| 349 |
+
idim: int,
|
| 350 |
+
hidden_units: int,
|
| 351 |
+
dropout_rate: float,
|
| 352 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
| 353 |
+
):
|
| 354 |
+
"""Construct a PositionwiseFeedForward object."""
|
| 355 |
+
super(PositionwiseFeedForward, self).__init__()
|
| 356 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
| 357 |
+
self.activation = activation
|
| 358 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
| 359 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
| 360 |
+
|
| 361 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
| 362 |
+
"""Forward function.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
xs: input tensor (B, L, D)
|
| 366 |
+
Returns:
|
| 367 |
+
output tensor, (B, L, D)
|
| 368 |
+
"""
|
| 369 |
+
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class LinearNoSubsampling(torch.nn.Module):
|
| 373 |
+
"""Linear transform the input without subsampling
|
| 374 |
+
Args:
|
| 375 |
+
idim (int): Input dimension.
|
| 376 |
+
odim (int): Output dimension.
|
| 377 |
+
dropout_rate (float): Dropout rate.
|
| 378 |
+
"""
|
| 379 |
+
def __init__(self,
|
| 380 |
+
idim: int,
|
| 381 |
+
odim: int,
|
| 382 |
+
dropout_rate: float,
|
| 383 |
+
pos_enc_class: torch.nn.Module
|
| 384 |
+
):
|
| 385 |
+
"""Construct an linear object."""
|
| 386 |
+
super().__init__()
|
| 387 |
+
self.out = torch.nn.Sequential(
|
| 388 |
+
torch.nn.Linear(idim, odim),
|
| 389 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
| 390 |
+
torch.nn.Dropout(dropout_rate),
|
| 391 |
+
)
|
| 392 |
+
self.pos_enc = pos_enc_class
|
| 393 |
+
|
| 394 |
+
def forward(
|
| 395 |
+
self,
|
| 396 |
+
x: torch.Tensor,
|
| 397 |
+
offset: int = 0
|
| 398 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 399 |
+
"""Input x.
|
| 400 |
+
Args:
|
| 401 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 402 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
| 406 |
+
where time' = time .
|
| 407 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
| 408 |
+
where time' = time .
|
| 409 |
+
"""
|
| 410 |
+
x = self.out(x)
|
| 411 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 412 |
+
return x, pos_emb
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
"""Encoder layer & encoder
|
| 416 |
+
"""
|
| 417 |
+
class ConformerEncoderLayer(nn.Module):
|
| 418 |
+
"""Encoder layer module.
|
| 419 |
+
Args:
|
| 420 |
+
size (int): Input dimension.
|
| 421 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
| 422 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
| 423 |
+
instance can be used as the argument.
|
| 424 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
| 425 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
| 426 |
+
dropout_rate (float): Dropout rate.
|
| 427 |
+
normalize_before (bool):
|
| 428 |
+
True: use layer_norm before each sub-block.
|
| 429 |
+
False: use layer_norm after each sub-block.
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
def __init__(
|
| 433 |
+
self,
|
| 434 |
+
size: int,
|
| 435 |
+
self_attn: torch.nn.Module,
|
| 436 |
+
feed_forward: torch.nn.Module,
|
| 437 |
+
dropout_rate: float = 0.1,
|
| 438 |
+
normalize_before: bool = True,
|
| 439 |
+
):
|
| 440 |
+
"""Construct an EncoderLayer object."""
|
| 441 |
+
super().__init__()
|
| 442 |
+
self.self_attn = self_attn
|
| 443 |
+
self.feed_forward = feed_forward
|
| 444 |
+
        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FFN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
        self.ff_scale = 1.0
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        att_cache: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: shape (b, t, c)
            mask: self-attention padding mask, shape (b, 1, t) or (b, t, t)
            pos_emb: relative positional embedding, shape (b, t, 2t-1)
            att_cache: shape (1, nh, cache_t, d_k * 2)
        """
        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        # att_cache: (b, head, cache_t, d_k*2)
        x_att, new_att_cache = self.self_attn(x, x, x, pos_emb, mask, att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # feed-forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x_ffn = self.feed_forward(x)
        x = residual + self.ff_scale * self.dropout(x_ffn)
        if not self.normalize_before:
            x = self.norm_ff(x)

        return x, new_att_cache


class UpsampleConformerEncoder(torch.nn.Module):

    def __init__(
        self,
        # Common
        input_size: int = 512,
        output_size: int = 512,
        num_blocks: int = 6,
        num_up_blocks: int = 4,
        normalize_before: bool = True,
        # Input & upsampling
        up_stride: int = 2,
        pre_lookahead_len: int = 3,
        # Attention
        attention_heads: int = 4,
        key_bias: bool = True,
        # MLP
        linear_units: int = 2048,
        # Dropouts
        dropout_rate: float = 0.0,
        positional_dropout_rate: float = 0.0,
        attention_dropout_rate: float = 0.0,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.up_stride = up_stride
        # Input embedding
        self.embed = LinearNoSubsampling(
            input_size,
            output_size,
            dropout_rate,
            # Positional encoding
            EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
        )
        # Lookahead
        self.pre_lookahead_layer = PreLookaheadLayer(
            channels=output_size, pre_lookahead_len=pre_lookahead_len
        )
        # Norm
        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        # Activation
        activation = torch.nn.SiLU()
        # Self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # Feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # 1st Conformer stack (token level)
        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                # Self-attention
                RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
                # FFN
                PositionwiseFeedForward(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
        # Upsample
        self.up_layer = Upsample1D(
            channels=output_size, out_channels=output_size, stride=up_stride
        )
        # Input embedding after upsampling
        self.up_embed = LinearNoSubsampling(
            input_size,
            output_size,
            dropout_rate,
            # Positional encoding
            EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
        )
        # 2nd Conformer stack (mel level)
        self.up_encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                # Self-attention
                RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
                # FFN
                PositionwiseFeedForward(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
            ) for _ in range(num_up_blocks)
        ])

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        # attention mask BEFORE upsampling
        attn_mask1: torch.Tensor = None,
        # attention mask AFTER upsampling
        attn_mask2: torch.Tensor = None,
    ) -> torch.Tensor:
        """For non-streaming inference.

        Args:
            xs: shape (b, t, c)
            xs_lens: shape (b,)
            attn_mask1: (token level) shape (b, t, t)
            attn_mask2: (mel level) shape (b, 2t, 2t)
        """
        # Input & lookahead
        xs, pos_emb = self.embed(xs)
        xs = self.pre_lookahead_layer(xs)

        # 1st Conformer stack
        for block in self.encoders:
            xs, _ = block(xs, mask=attn_mask1, pos_emb=pos_emb)

        # Upsample to mel level
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()
        # Input embedding
        xs, pos_emb = self.up_embed(xs)

        # 2nd Conformer stack
        for block in self.up_encoders:
            xs, _ = block(xs, mask=attn_mask2, pos_emb=pos_emb)

        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs
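The forward pass above runs a token-level Conformer stack, doubles the frame rate with `Upsample1D`, then refines the result at mel level. Below is a minimal usage sketch, assuming this module's own building blocks (`LinearNoSubsampling`, `ConformerEncoderLayer`, etc.) are in scope; the full-attention boolean masks are hypothetical stand-ins for what the flow-matching caller would build from `xs_lens`:

import torch

encoder = UpsampleConformerEncoder(input_size=512, output_size=512, up_stride=2)
encoder.eval()

b, t = 2, 50
xs = torch.randn(b, t, 512)                                  # (b, t, c) token-level features
xs_lens = torch.tensor([50, 42])
attn_mask1 = torch.ones(b, t, t, dtype=torch.bool)           # token level
attn_mask2 = torch.ones(b, t * 2, t * 2, dtype=torch.bool)   # mel level

with torch.no_grad():
    ys = encoder(xs, xs_lens, attn_mask1=attn_mask1, attn_mask2=attn_mask2)
print(ys.shape)  # expected torch.Size([2, 100, 512]): frame rate doubled by up_stride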
fireredtts/modules/semantic_llm/llm_gpt2.py
ADDED
@@ -0,0 +1,608 @@
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
import functools
from transformers import GPT2PreTrainedModel, GPT2Model, GPT2Config
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


# GPT-2 NORMAL INFERENCE MODE
class GPT2InferenceModel(GPT2PreTrainedModel):
    """Override GPT2LMHeadModel to allow for prefix conditioning."""

    def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
        super().__init__(config)
        self.transformer = gpt
        self.pos_embedding = pos_emb
        self.embeddings = embeddings
        self.final_norm = norm
        self.lm_head = nn.Sequential(norm, linear)
        self.kv_cache = kv_cache
        self.cached_prefix_emb = None  # set via store_prefix_emb() before forward()

    def store_prefix_emb(self, prefix_emb):
        self.cached_prefix_emb = prefix_emb

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
        if not self.kv_cache:
            past_key_values = None

        # only keep the last input id if a past is defined in kwargs
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert self.cached_prefix_emb is not None
        assert inputs_embeds is None  # Not supported by this inference model.
        assert labels is None  # Training not supported by this inference model.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]

        # Create embeddings
        prefix_len = self.cached_prefix_emb.shape[1]
        if input_ids.shape[1] != 1:
            # First forward pass: embed everything after the prefix placeholders
            # and splice the cached prefix embedding in front.
            gen_inputs = input_ids[:, prefix_len:]
            gen_emb = self.embeddings(gen_inputs)
            gen_emb = gen_emb + self.pos_embedding(gen_emb)
            if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
                prefix_emb = self.cached_prefix_emb.repeat_interleave(
                    gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
                )
            else:
                prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
            emb = torch.cat([prefix_emb, gen_emb], dim=1)
        else:
            # Incremental decoding: embed only the newest token.
            emb = self.embeddings(input_ids)
            emb = emb + self.pos_embedding.get_fixed_embedding(
                attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
            )
        transformer_outputs = self.transformer(
            inputs_embeds=emb,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:]

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past
        )

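
# --- Illustration (not part of the original file): how the placeholder ids and
# --- the cached prefix line up on the first forward pass. compute_embeddings()
# --- (defined further down) fills the prefix slots of input_ids with dummy 1s;
# --- only the ids after the prefix are embedded, the rest is spliced from the
# --- cache. Shapes below are hypothetical.
_B, _P, _T, _D = 2, 7, 5, 64                       # batch, prefix len, tokens, hidden
_prefix = torch.randn(_B, _P, _D)                  # stands in for cached_prefix_emb
_ids = torch.full((_B, _P + _T), 1, dtype=torch.long)  # prefix slots are dummies
_gen = nn.Embedding(100, _D)(_ids[:, _P:])         # real audio ids would sit here
_emb = torch.cat([_prefix, _gen], dim=1)           # what the transformer actually sees
assert _emb.shape == (_B, _P + _T, _D)
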
# GPT-2 IN-CONTEXT (IC) INFERENCE MODE
class GPT2ICInferenceModel(GPT2PreTrainedModel):
    """Override GPT2LMHeadModel to allow for prefix conditioning."""

    def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
        super().__init__(config)
        self.transformer = gpt
        self.pos_embedding = pos_emb
        self.embeddings = embeddings
        self.final_norm = norm
        self.lm_head = nn.Sequential(norm, linear)
        self.kv_cache = kv_cache
        self.cached_prefix_emb = None  # set via store_prefix_emb() before forward()

    def store_prefix_emb(self, prefix_emb):
        self.cached_prefix_emb = prefix_emb

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
        if not self.kv_cache:
            past_key_values = None

        # only keep the last input id if a past is defined in kwargs
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert self.cached_prefix_emb is not None
        assert inputs_embeds is None  # Not supported by this inference model.
        assert labels is None  # Training not supported by this inference model.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Create embeddings
        prefix_len = self.cached_prefix_emb.shape[1]
        if input_ids.shape[1] != 1:
            # In-context mode: the cached prefix already contains the speaker,
            # text, and prompt-audio embeddings, so it is used directly; unlike
            # GPT2InferenceModel, nothing is embedded from input_ids here.
            emb = self.cached_prefix_emb
        else:
            emb = self.embeddings(input_ids)
            emb = emb + self.pos_embedding.get_fixed_embedding(
                attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
            )
        transformer_outputs = self.transformer(
            inputs_embeds=emb,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:]

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past
        )


def null_position_embeddings(range, dim):
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


class LearnedPositionEmbeddings(nn.Module):
    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
        super().__init__()
        self.emb = torch.nn.Embedding(seq_len, model_dim)
        # Initializing this way is standard for GPT-2
        self.emb.weight.data.normal_(mean=0.0, std=init)
        self.relative = relative
        self.seq_len = seq_len

    def forward(self, x):
        sl = x.shape[1]
        if self.relative:
            # Sample a random window offset so all positions get trained.
            start = random.randint(sl, self.seq_len) - sl
            return self.emb(torch.arange(start, start + sl, device=x.device))
        else:
            return self.emb(torch.arange(0, sl, device=x.device))

    def get_fixed_embedding(self, ind, dev):
        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)


def build_hf_gpt_transformer(
    layers,
    model_dim,
    heads,
    max_mel_seq_len,
    max_text_seq_len,
    max_prompt_len,
    checkpointing,
):
    """
    GPT-2 implemented by the HuggingFace library.
    """
    gpt_config = GPT2Config(
        vocab_size=256,  # Unused.
        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
        n_embd=model_dim,
        n_layer=layers,
        n_head=heads,
        gradient_checkpointing=checkpointing,
        use_cache=not checkpointing,
    )
    gpt = GPT2Model(gpt_config)
    # Override the built-in positional embeddings
    del gpt.wpe
    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
    # Built-in token embeddings are unused.
    del gpt.wte

    mel_pos_emb = (
        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
        if max_mel_seq_len != -1
        else functools.partial(null_position_embeddings, dim=model_dim)
    )
    text_pos_emb = (
        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
        if max_text_seq_len != -1
        else functools.partial(null_position_embeddings, dim=model_dim)
    )
    # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
    return gpt, mel_pos_emb, text_pos_emb, None, None

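
# --- Illustration (not part of the original file): with wpe nulled out, GPT-2
# --- adds zero positional information itself, so the per-modality
# --- LearnedPositionEmbeddings must be added to inputs_embeds by the caller.
# --- This mirrors what build_hf_gpt_transformer does above, at toy size.
_cfg = GPT2Config(
    vocab_size=256, n_positions=64, n_ctx=64, n_embd=32, n_layer=1, n_head=2
)
_gpt = GPT2Model(_cfg)
del _gpt.wpe
_gpt.wpe = functools.partial(null_position_embeddings, dim=32)
_out = _gpt(inputs_embeds=torch.randn(1, 10, 32)).last_hidden_state
assert _out.shape == (1, 10, 32)  # positions must already be baked into the input
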
class Speech_LLM_GPT2(nn.Module):
    def __init__(
        self,
        start_text_token,
        stop_text_token,
        num_text_tokens,
        start_audio_token,
        stop_audio_token,
        num_audio_tokens,
        llm_hidden_size,
        llm_intermediate_size,
        llm_num_layers,
        llm_num_heads,
        llm_max_audio_seq_len,
        llm_max_text_seq_len,
        llm_max_prompt_len,
        code_stride_len=640,
        max_conditioning_inputs=1,
        label_smoothing=0.0,
        checkpointing=False,
    ):
        super().__init__()

        self.label_smoothing = label_smoothing
        # text token config
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token
        self.num_text_tokens = num_text_tokens

        # audio token config
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token
        self.num_audio_tokens = num_audio_tokens

        # prompt token config
        self.start_prompt_token = start_audio_token
        self.stop_prompt_token = stop_audio_token

        # other config
        self.max_conditioning_inputs = max_conditioning_inputs

        # length configs
        self.max_text_len = llm_max_text_seq_len + 2  # add <bos> <eos>
        self.max_prompt_len = llm_max_prompt_len
        self.max_audio_len = llm_max_audio_seq_len + 2 + self.max_conditioning_inputs
        self.max_gen_audio_tokens = (
            llm_max_audio_seq_len - self.max_conditioning_inputs - 2
        )
        self.code_stride_len = code_stride_len

        # model config
        self.llm_hidden_size = llm_hidden_size
        self.llm_intermediate_size = llm_intermediate_size
        self.llm_num_layers = llm_num_layers
        self.llm_num_heads = llm_num_heads

        # text and audio embeddings
        self.text_embedding = nn.Embedding(self.num_text_tokens, self.llm_hidden_size)
        self.audio_embedding = nn.Embedding(self.num_audio_tokens, self.llm_hidden_size)

        # low-level LLM backbone
        self.gpt2, self.audio_pos_embedding, self.text_pos_embedding, _, _ = (
            build_hf_gpt_transformer(
                layers=self.llm_num_layers,
                model_dim=self.llm_hidden_size,
                heads=self.llm_num_heads,
                max_mel_seq_len=self.max_audio_len,
                max_text_seq_len=self.max_text_len,
                max_prompt_len=self.max_prompt_len,
                checkpointing=checkpointing,
            )
        )

        # text and audio output heads
        self.final_norm = nn.LayerNorm(self.llm_hidden_size)
        self.text_head = nn.Linear(self.llm_hidden_size, self.num_text_tokens)
        self.audio_head = nn.Linear(self.llm_hidden_size, self.num_audio_tokens)

        # speaker feature transform
        self.reference_embedding = nn.Sequential(
            nn.Linear(512, 256),
            nn.Tanh(),
            nn.Linear(256, self.llm_hidden_size),
        )

    def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
        """Wrap the trained GPT-2 backbone in the inference models above.

        Args:
            kv_cache (bool, optional): enable the key/value cache. Defaults to True.
            use_deepspeed (bool, optional): currently unused. Defaults to False.
        """
        seq_length = self.max_audio_len + self.max_text_len + self.max_prompt_len + 1

        gpt_config = GPT2Config(
            vocab_size=self.num_audio_tokens,
            n_positions=seq_length,
            n_ctx=seq_length,
            n_embd=self.llm_hidden_size,
            n_layer=self.llm_num_layers,
            n_head=self.llm_num_heads,
            gradient_checkpointing=False,
            use_cache=True,
        )

        # normal inference model
        self.gpt_inference = GPT2InferenceModel(
            config=gpt_config,
            gpt=self.gpt2,
            pos_emb=self.audio_pos_embedding,
            embeddings=self.audio_embedding,
            norm=self.final_norm,
            linear=self.audio_head,
            kv_cache=kv_cache,
        )

        # in-context inference model
        self.gpt_inference_ic = GPT2ICInferenceModel(
            config=gpt_config,
            gpt=self.gpt2,
            pos_emb=self.audio_pos_embedding,
            embeddings=self.audio_embedding,
            norm=self.final_norm,
            linear=self.audio_head,
            kv_cache=kv_cache,
        )

        self.gpt2.wte = self.audio_embedding

    # --------------------------- normal inference ---------------------------
    def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
        # generate() rebuilds the embeddings itself; the first call only warms
        # the prefix cache.
        self.compute_embeddings(cond_latents, text_inputs)
        return self.generate(cond_latents, text_inputs, **hf_generate_kwargs)

    def compute_embeddings(self, cond_latents, text_inputs):
        # Wrap the text in <bos>/<eos> and build the [speaker, text] prefix.
        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
        text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
        emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
        emb = torch.cat([cond_latents, emb], dim=1)
        self.gpt_inference.store_prefix_emb(emb)
        # Placeholder ids for the prefix positions, plus the start-of-audio token.
        gpt_inputs = torch.full(
            (
                emb.shape[0],
                emb.shape[1] + 1,  # +1 for the start_audio_token
            ),
            fill_value=1,
            dtype=torch.long,
            device=text_inputs.device,
        )
        gpt_inputs[:, -1] = self.start_audio_token
        return gpt_inputs

    def generate(self, cond_latents, text_inputs, **hf_generate_kwargs):
        gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
        gen = self.gpt_inference.generate(
            gpt_inputs,
            bos_token_id=self.start_audio_token,
            pad_token_id=self.stop_audio_token,
            eos_token_id=self.stop_audio_token,
            max_length=self.max_gen_audio_tokens + gpt_inputs.shape[-1],
            **hf_generate_kwargs,
        )
        if "return_dict_in_generate" in hf_generate_kwargs:
            return gen.sequences[:, gpt_inputs.shape[1] :], gen
        return gen[:, gpt_inputs.shape[1] :]

    # --------------------------- normal inference ---------------------------

    # --------------------------- IC inference ---------------------------
    def compute_embeddings_ic(self, cond_latents, text_inputs, prompt_tokens):
        """Build the in-context prefix [speaker, text, prompt audio].

        Args:
            cond_latents: speaker embedding, shape (b, 1, c)
            text_inputs: text tokens, shape (b, t_text)
            prompt_tokens: prompt audio tokens, shape (b, t_prompt)

        Returns:
            Placeholder input ids covering the prefix, shape (b, prefix_len).
        """
        # text embeddings
        text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)

        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
            text_inputs
        )

        # prompt audio embeddings
        prompt_tokens = F.pad(prompt_tokens, (1, 0), value=self.start_audio_token)
        audio_emb = self.audio_embedding(prompt_tokens) + self.audio_pos_embedding(
            prompt_tokens
        )

        emb = torch.cat([cond_latents, text_emb, audio_emb], dim=1)

        self.gpt_inference_ic.store_prefix_emb(emb)
        gpt_inputs = torch.full(
            (emb.shape[0], emb.shape[1]),
            fill_value=1,
            dtype=torch.long,
            device=text_inputs.device,
        )
        return gpt_inputs

    def generate_ic(
        self, cond_latents, text_inputs, prompt_tokens, **hf_generate_kwargs
    ):
        """Generate audio tokens conditioned on a prompt (voice cloning).

        Args:
            cond_latents: speaker embedding
            text_inputs: text tokens
            prompt_tokens: prompt audio tokens

        Returns:
            The generated audio tokens, plus the full generate output when
            ``return_dict_in_generate`` is requested.
        """
        gpt_inputs = self.compute_embeddings_ic(
            cond_latents, text_inputs, prompt_tokens
        )
        gen = self.gpt_inference_ic.generate(
            gpt_inputs,
            bos_token_id=self.start_audio_token,
            pad_token_id=self.stop_audio_token,
            eos_token_id=self.stop_audio_token,
            max_length=self.max_gen_audio_tokens + gpt_inputs.shape[-1],
            **hf_generate_kwargs,
        )
        if "return_dict_in_generate" in hf_generate_kwargs:
            return gen.sequences[:, gpt_inputs.shape[1] :], gen

        return gen[:, gpt_inputs.shape[1] :]

    # --------------------------- IC inference ---------------------------

    def get_generator(self, fake_inputs, **hf_generate_kwargs):
        return self.gpt_inference.generate_stream(
            fake_inputs,
            bos_token_id=self.start_audio_token,
            pad_token_id=self.stop_audio_token,
            eos_token_id=self.stop_audio_token,
            max_length=self.max_gen_audio_tokens + fake_inputs.shape[-1],
            do_stream=True,
            **hf_generate_kwargs,
        )
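The in-context path is what zero-shot voice cloning uses: the prefix packs the speaker latent, the <bos>/<eos>-wrapped text, and the prompt's audio tokens, and generation continues from there. A minimal end-to-end sketch with toy sizes follows; every hyperparameter below is illustrative, not the shipped configuration:

import torch

llm = Speech_LLM_GPT2(
    start_text_token=0, stop_text_token=1, num_text_tokens=100,
    start_audio_token=0, stop_audio_token=1, num_audio_tokens=200,
    llm_hidden_size=64, llm_intermediate_size=256,
    llm_num_layers=2, llm_num_heads=2,
    llm_max_audio_seq_len=128, llm_max_text_seq_len=64, llm_max_prompt_len=32,
)
llm.init_gpt_for_inference(kv_cache=True)
llm.eval()

spk = torch.randn(1, 512)                          # speaker vector from the tokenizer
cond = llm.reference_embedding(spk).unsqueeze(1)   # (1, 1, hidden)
text = torch.randint(2, 100, (1, 10))              # text token ids
prompt = torch.randint(2, 200, (1, 8))             # prompt audio token ids

with torch.no_grad():
    audio_tokens = llm.generate_ic(cond, text, prompt, do_sample=True, top_k=30)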
fireredtts/modules/semantic_tokenizer/__init__.py
ADDED
@@ -0,0 +1,36 @@
import os
import torch

from .hubert import HuBERT
from .semantic_tokenizer import SemanticVQVAE


class SemanticTokenizer:

    def __init__(self, config, path):
        self.model = SemanticVQVAE(**config)
        self.model.load_state_dict(
            torch.load(os.path.join(path, "codec.bin"), map_location="cpu"), strict=True
        )

        hubert = HuBERT(os.path.join(path, "hubert.pt"))
        for name, param in hubert.named_parameters():
            param.requires_grad = False
        self.model.ssl_extractor = hubert

        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

    def __call__(self, wavs, wav_lengths):
        tokens, token_lengths, spk_embeddings = self.extract(wavs, wav_lengths)
        return tokens, token_lengths, spk_embeddings

    def extract(self, wavs, wav_lengths):
        saved_features = self.model.extract_speech_tokens(wavs, wav_lengths)

        tokens = saved_features["token"]
        token_lengths = saved_features["token_length"]
        spk_embeddings = saved_features["spk"]

        return tokens, token_lengths, spk_embeddings
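The tokenizer is a thin wrapper: it loads the VQ-VAE weights, swaps in a frozen HuBERT feature extractor, and exposes a single call returning tokens, lengths, and a speaker embedding. A minimal usage sketch; the config dict name and checkpoint directory below are placeholders, not the Space's actual loading code:

import torch

# semantic_tokenizer_cfg is a placeholder for the SemanticVQVAE kwargs loaded
# from the Space's config files.
tokenizer = SemanticTokenizer(config=semantic_tokenizer_cfg, path="pretrained_models")

wavs = torch.randn(1, 16000)            # one second of 16 kHz audio
wav_lengths = torch.tensor([16000])
tokens, token_lengths, spk = tokenizer(wavs, wav_lengths)
# tokens: discrete semantic ids; spk: speaker embedding reused by the LLM.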
fireredtts/modules/semantic_tokenizer/audio.py
ADDED
@@ -0,0 +1,138 @@
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from torch.nn import functional as F

import math
import numpy as np
import torch
import torchaudio


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


class TorchMelSpectrogram(nn.Module):

    def __init__(
        self,
        filter_length=1024,
        hop_length=200,
        win_length=800,
        n_mel_channels=80,
        mel_fmin=0,
        mel_fmax=8000,
        sampling_rate=16000,
        sampling_rate_org=None,
        normalize=False,
        mel_norm_file=None,
        scale=1.0,
        padding="center",
        style="Tortoise",
    ):
        super().__init__()
        self.style = style
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.sampling_rate_org = (
            sampling_rate_org if sampling_rate_org is not None else sampling_rate
        )

        self.mel_basis = {}
        self.hann_window = {}

        self.scale = scale

    def forward(self, inp, length=None):
        if len(inp.shape) == 3:
            inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
        assert len(inp.shape) == 2

        if self.sampling_rate_org != self.sampling_rate:
            inp = torchaudio.functional.resample(
                inp, self.sampling_rate_org, self.sampling_rate
            )

        y = inp
        if len(list(self.mel_basis.keys())) == 0:
            mel = librosa_mel_fn(
                sr=self.sampling_rate,
                n_fft=self.filter_length,
                n_mels=self.n_mel_channels,
                fmin=self.mel_fmin,
                fmax=self.mel_fmax,
            )
            self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = (
                torch.from_numpy(mel).float().to(y.device)
            )
            self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(
                y.device
            )

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.filter_length - self.hop_length) / 2),
                int((self.filter_length - self.hop_length) / 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.hann_window[str(y.device)],
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(
            self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)], spec
        )
        spec = spectral_normalize_torch(spec)

        max_mel_length = math.ceil(y.shape[-1] / self.hop_length)
        spec = spec[..., :max_mel_length].transpose(1, 2)

        if length is None:
            return spec
        else:
            spec_len = torch.ceil(length / self.hop_length).clamp(max=spec.shape[1])
            return spec, spec_len
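At hop_length=200 and a 16 kHz working rate this front end produces 80 mel frames per second; inputs at another rate are resampled first. A small sketch; the 24 kHz original rate is an assumption for illustration, and `length` is interpreted here in 16 kHz samples, which is also an assumption about the caller's convention:

import torch

mel_fn = TorchMelSpectrogram(sampling_rate=16000, sampling_rate_org=24000)
wav = torch.randn(1, 48000)                                  # 2 s at 24 kHz
mel, mel_len = mel_fn(wav, length=torch.tensor([32000.0]))   # 2 s at 16 kHz
print(mel.shape, mel_len)  # torch.Size([1, 160, 80]) tensor([160.]): 80 frames/s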
fireredtts/modules/semantic_tokenizer/ecapa_tdnn.py
ADDED
@@ -0,0 +1,931 @@
"""A popular speaker recognition and diarization model.

Authors
 * Hwidong Na 2020
"""

import math
import os
import torch  # noqa: F401
import torch.nn as nn
import torch.nn.functional as F


def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Creates a binary mask for each sequence.

    Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

    Arguments
    ---------
    length : torch.LongTensor
        Containing the length of each sequence in the batch. Must be 1D.
    max_len : int
        Max length for the mask, also the size of the second dimension.
    dtype : torch.dtype, default: None
        The dtype of the generated mask.
    device : torch.device, default: None
        The device to put the mask variable.

    Returns
    -------
    mask : tensor
        The binary mask.

    Example
    -------
    >>> length=torch.Tensor([1,2,3])
    >>> mask=length_to_mask(length)
    >>> mask
    tensor([[1., 0., 0.],
            [1., 1., 0.],
            [1., 1., 1.]])
    """
    assert len(length.shape) == 1

    if max_len is None:
        max_len = length.max().long().item()  # using arange to generate mask
    mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
        len(length), max_len
    ) < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype

    if device is None:
        device = length.device

    mask = torch.as_tensor(mask, dtype=dtype, device=device)
    return mask


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add for zero-padding.

    Arguments
    ---------
    L_in : int
    stride : int
    kernel_size : int
    dilation : int
    """
    if stride > 1:
        n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
        L_out = stride * (n_steps - 1) + kernel_size * dilation
        padding = [kernel_size // 2, kernel_size // 2]
    else:
        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
        padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
    return padding

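
# --- Illustration (not part of the original file): for stride == 1 the padding
# --- computed above preserves sequence length through the convolution.
_pad = get_padding_elem(L_in=100, stride=1, kernel_size=5, dilation=2)  # -> [4, 4]
_x = F.pad(torch.randn(1, 8, 100), _pad, mode="reflect")                # (1, 8, 108)
_y = nn.Conv1d(8, 8, kernel_size=5, dilation=2)(_x)
assert _y.shape == (1, 8, 100)  # same length as the input
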
class Conv1d(nn.Module):
    """This class implements 1d convolution.

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size : int
        Kernel size of the convolutional filters.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        a decimation in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is the same as the input shape.
        "causal" results in causal (dilated) convolutions.
    padding_mode : str
        This flag specifies the type of padding. See torch.nn documentation
        for more information.
    skip_transpose : bool
        If False, uses the batch x time x channel convention of speechbrain.
        If True, uses the batch x channel x time convention.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 40, 16])
    >>> cnn_1d = Conv1d(
    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
    ... )
    >>> out_tensor = cnn_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40, 8])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        skip_transpose=True,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            Input to convolve. 2d or 3d tensors are expected.
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
        elif self.padding == "causal":
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got " + self.padding
            )

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        """This function performs zero-padding on the time axis
        such that the length is unchanged after the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.
        """
        # Detecting input shape
        L_in = x.shape[-1]

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""

        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))

        # Kernel size must be odd
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )
        return in_channels


class Fp32BatchNorm(nn.Module):
    """Batch norm whose statistics are always computed in float32, even when
    the surrounding network runs in reduced precision."""

    def __init__(self, sync=True, *args, **kwargs):
        super().__init__()

        if (
            not torch.distributed.is_initialized()
            or torch.distributed.get_world_size() == 1
        ):
            sync = False

        if sync:
            self.bn = nn.SyncBatchNorm(*args, **kwargs)
        else:
            self.bn = nn.BatchNorm1d(*args, **kwargs)

        self.sync = sync

    def forward(self, input):
        if self.bn.running_mean.dtype != torch.float:
            if self.sync:
                self.bn.running_mean = self.bn.running_mean.float()
                self.bn.running_var = self.bn.running_var.float()
                if self.bn.affine:
                    try:
                        self.bn.weight = self.bn.weight.float()
                        self.bn.bias = self.bn.bias.float()
                    except Exception:
                        self.bn.float()
            else:
                self.bn.float()

        output = self.bn(input.float())
        return output.type_as(input)

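
# --- Illustration (not part of the original file): activations may arrive in
# --- half precision, but the statistics are computed in fp32 and the result is
# --- cast back to the input dtype.
_fp32_bn = Fp32BatchNorm(sync=False, num_features=16)
_h = torch.randn(4, 16, 50).half()
assert _fp32_bn(_h).dtype == torch.float16
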
class BatchNorm1d(nn.Module):
|
| 295 |
+
"""Applies 1d batch normalization to the input tensor.
|
| 296 |
+
|
| 297 |
+
Arguments
|
| 298 |
+
---------
|
| 299 |
+
input_shape : tuple
|
| 300 |
+
The expected shape of the input. Alternatively, use ``input_size``.
|
| 301 |
+
input_size : int
|
| 302 |
+
The expected size of the input. Alternatively, use ``input_shape``.
|
| 303 |
+
eps : float
|
| 304 |
+
This value is added to std deviation estimation to improve the numerical
|
| 305 |
+
stability.
|
| 306 |
+
momentum : float
|
| 307 |
+
It is a value used for the running_mean and running_var computation.
|
| 308 |
+
affine : bool
|
| 309 |
+
When set to True, the affine parameters are learned.
|
| 310 |
+
track_running_stats : bool
|
| 311 |
+
When set to True, this module tracks the running mean and variance,
|
| 312 |
+
and when set to False, this module does not track such statistics.
|
| 313 |
+
combine_batch_time : bool
|
| 314 |
+
When true, it combines batch an time axis.
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
Example
|
| 318 |
+
-------
|
| 319 |
+
>>> input = torch.randn(100, 10)
|
| 320 |
+
>>> norm = BatchNorm1d(input_shape=input.shape)
|
| 321 |
+
>>> output = norm(input)
|
| 322 |
+
>>> output.shape
|
| 323 |
+
torch.Size([100, 10])
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
+
def __init__(
|
| 327 |
+
self,
|
| 328 |
+
input_shape=None,
|
| 329 |
+
input_size=None,
|
| 330 |
+
eps=1e-05,
|
| 331 |
+
momentum=0.1,
|
| 332 |
+
affine=True,
|
| 333 |
+
track_running_stats=True,
|
| 334 |
+
combine_batch_time=False,
|
| 335 |
+
skip_transpose=True,
|
| 336 |
+
enabled=True,
|
| 337 |
+
):
|
| 338 |
+
super().__init__()
|
| 339 |
+
self.combine_batch_time = combine_batch_time
|
| 340 |
+
self.skip_transpose = skip_transpose
|
| 341 |
+
|
| 342 |
+
if input_size is None and skip_transpose:
|
| 343 |
+
input_size = input_shape[1]
|
| 344 |
+
elif input_size is None:
|
| 345 |
+
input_size = input_shape[-1]
|
| 346 |
+
|
| 347 |
+
if enabled:
|
| 348 |
+
self.norm = Fp32BatchNorm(
|
| 349 |
+
num_features=input_size,
|
| 350 |
+
eps=eps,
|
| 351 |
+
momentum=momentum,
|
| 352 |
+
affine=affine,
|
| 353 |
+
track_running_stats=track_running_stats,
|
| 354 |
+
)
|
| 355 |
+
else:
|
| 356 |
+
self.norm = nn.Identity()
|
| 357 |
+
|
| 358 |
+
def forward(self, x):
|
| 359 |
+
"""Returns the normalized input tensor.
|
| 360 |
+
|
| 361 |
+
Arguments
|
| 362 |
+
---------
|
| 363 |
+
x : torch.Tensor (batch, time, [channels])
|
| 364 |
+
input to normalize. 2d or 3d tensors are expected in input
|
| 365 |
+
4d tensors can be used when combine_dims=True.
|
| 366 |
+
"""
|
| 367 |
+
shape_or = x.shape
|
| 368 |
+
if self.combine_batch_time:
|
| 369 |
+
if x.ndim == 3:
|
| 370 |
+
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
|
| 371 |
+
else:
|
| 372 |
+
x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
|
| 373 |
+
|
| 374 |
+
elif not self.skip_transpose:
|
| 375 |
+
x = x.transpose(-1, 1)
|
| 376 |
+
|
| 377 |
+
x_n = self.norm(x)
|
| 378 |
+
|
| 379 |
+
if self.combine_batch_time:
|
| 380 |
+
x_n = x_n.reshape(shape_or)
|
| 381 |
+
elif not self.skip_transpose:
|
| 382 |
+
x_n = x_n.transpose(1, -1)
|
| 383 |
+
|
| 384 |
+
return x_n
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class Linear(torch.nn.Module):
    """Computes a linear transformation y = wx + b.

    Arguments
    ---------
    n_neurons : int
        It is the number of output neurons (i.e., the dimensionality of the
        output).
    bias : bool
        If True, the additive bias b is adopted.
    combine_dims : bool
        If True and the input is 4D, combine 3rd and 4th dimensions of input.

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
    >>> output = lin_t(inputs)
    >>> output.shape
    torch.Size([10, 50, 100])
    """

    def __init__(
        self,
        n_neurons,
        input_shape=None,
        input_size=None,
        bias=True,
        combine_dims=False,
    ):
        super().__init__()
        self.combine_dims = combine_dims

        if input_shape is None and input_size is None:
            raise ValueError("Expected one of input_shape or input_size")

        if input_size is None:
            input_size = input_shape[-1]
            if len(input_shape) == 4 and self.combine_dims:
                input_size = input_shape[2] * input_shape[3]

        # Weights are initialized following pytorch approach
        self.w = nn.Linear(input_size, n_neurons, bias=bias)

    def forward(self, x):
        """Returns the linear transformation of input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly.
        """
        if x.ndim == 4 and self.combine_dims:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        wx = self.w(x)

        return wx


class TDNNBlock(nn.Module):
    """An implementation of TDNN.

    Arguments
    ---------
    in_channels : int
        Number of input channels.
    out_channels : int
        The number of output channels.
    kernel_size : int
        The kernel size of the TDNN blocks.
    dilation : int
        The dilation of the TDNN block.
    activation : torch class
        A class for constructing the activation layers.
    batch_norm : bool
        Whether to apply batch normalization after the activation.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        batch_norm=True,
    ):
        super(TDNNBlock, self).__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels, enabled=batch_norm)

    def forward(self, x):
        return self.norm(self.activation(self.conv(x)))


class Res2NetBlock(torch.nn.Module):
    """An implementation of Res2NetBlock w/ dilation.

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    scale : int
        The scale of the Res2Net block.
    kernel_size : int
        The kernel size of the Res2Net block.
    dilation : int
        The dilation of the Res2Net block.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        scale=8,
        kernel_size=3,
        dilation=1,
        batch_norm=True,
    ):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        self.blocks = nn.ModuleList(
            [
                TDNNBlock(
                    in_channel,
                    hidden_channel,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    batch_norm=batch_norm,
                )
                for i in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, x):
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        y = torch.cat(y, dim=1)
        return y


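# Data-flow sketch for Res2NetBlock.forward with scale=4 (chunks x0..x3 along
# the channel axis): y0 = x0, y1 = B0(x1), y2 = B1(x2 + y1), y3 = B2(x3 + y2),
# then y = cat(y0..y3). Each later chunk therefore sees a progressively larger
# receptive field before the concatenation.
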
class SEBlock(nn.Module):
    """An implementation of squeeze-and-excitation block.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    se_channels : int
        The number of output channels after squeeze.
    out_channels : int
        The number of output channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> se_layer = SEBlock(64, 16, 64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1
        )
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        L = x.shape[-1]
        if lengths is not None:
            mask = length_to_mask(lengths * L, max_len=L, device=x.device)
            mask = mask.unsqueeze(1)
            total = mask.sum(dim=2, keepdim=True)
            s = (x * mask).sum(dim=2, keepdim=True) / total
        else:
            s = x.mean(dim=2, keepdim=True)

        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))

        return s * x


class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.

    Arguments
    ---------
    channels : int
        The number of input channels.
    attention_channels : int
        The number of attention channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> asp_layer = AttentiveStatisticsPooling(64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 1, 128])
    """

    def __init__(
        self, channels, attention_channels=128, global_context=True, batch_norm=True
    ):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            self.tdnn = TDNNBlock(
                channels * 3, attention_channels, 1, 1, batch_norm=batch_norm
            )
        else:
            self.tdnn = TDNNBlock(
                channels, attention_channels, 1, 1, batch_norm=batch_norm
            )
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels, out_channels=channels, kernel_size=1
        )

    def forward(self, x, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].
        """
        L = x.shape[-1]

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            mean = (m * x).sum(dim)
            std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=2)
        mean, std = _compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats


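# Note on the masking in AttentiveStatisticsPooling.forward: padded frames are
# filled with -inf before the softmax, so they receive exactly zero attention
# weight and cannot leak into the pooled mean/std; eps merely guards the sqrt.
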
class SERes2NetBlock(nn.Module):
    """An implementation of building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SEBlock.

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    res2net_scale : int
        The scale of the Res2Net block.
    kernel_size : int
        The kernel size of the TDNN blocks.
    dilation : int
        The dilation of the Res2Net block.
    activation : torch class
        A class for constructing the activation layers.

    Example
    -------
    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
    >>> out = conv(x).transpose(1, 2)
    >>> out.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        batch_norm=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            batch_norm=batch_norm,
        )
        self.res2net_block = Res2NetBlock(
            out_channels,
            out_channels,
            res2net_scale,
            kernel_size,
            dilation,
            batch_norm=batch_norm,
        )
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            batch_norm=batch_norm,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x, lengths)

        return x + residual


class ECAPA_TDNN(torch.nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).

    Arguments
    ---------
    input_size : int
        Expected size of the input feature dimension.
    activation : torch class
        A class for constructing the activation layers.
    channels : list of ints
        Output channels for TDNN/SERes2Net layer.
    kernel_sizes : list of ints
        List of kernel sizes for each layer.
    dilations : list of ints
        List of dilations for kernels in each layer.
    lin_neurons : int
        Number of neurons in linear layers.

    Example
    -------
    >>> input_feats = torch.rand([5, 120, 80])
    >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
    >>> outputs = compute_embedding(input_feats)
    >>> outputs.shape
    torch.Size([5, 1, 192])
    """

    def __init__(
        self,
        input_size,
        lin_neurons=192,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        batch_norm=True,
    ):
        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                batch_norm=batch_norm,
            )
        )

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    batch_norm=batch_norm,
                )
            )

        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            batch_norm=batch_norm,
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
            batch_norm=batch_norm,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2, enabled=batch_norm)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    # @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
    def forward(self, x, lengths=None):
        """Returns the embedding vector.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        """
        # Minimize transpose for efficiency
        x = x.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            try:
                x = layer(x, lengths=lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)

        x = x.squeeze(-1)
        return x


if __name__ == "__main__":
    model = ECAPA_TDNN(128, batch_norm=False)
    # print(model)
fireredtts/modules/semantic_tokenizer/hubert.py
ADDED
@@ -0,0 +1,108 @@
from fairseq import checkpoint_utils
from torch.nn.utils.rnn import pad_sequence

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_mask_from_lengths(lengths, max_len=None):
    max_len = torch.max(lengths).item() if max_len is None else max_len
    ids = torch.arange(0, max_len).to(lengths.device)
    mask = ~(ids < lengths.unsqueeze(1)).bool()
    return mask


class HuBERT(nn.Module):

    def __init__(self, model_path, sampling_rate=16000):
        super().__init__()

        models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task(
            [model_path],
            suffix="",
        )

        model = models[0]
        model = model.half()
        model.eval()
        self.model = model

        for param in self.parameters():
            param.requires_grad = False

        self.sampling_rate = sampling_rate
        self.normalize = saved_cfg.task.normalize

    @torch.no_grad()
    @torch.cuda.amp.autocast(enabled=False, dtype=torch.float16)
    def forward(self, inp, length=None, split=True, split_size=4):
        self.model.eval()
        if self.training and split:
            split_size = int(math.ceil(inp.shape[0] / 4))
            outs, out_lens = [], []
            for i in range(0, inp.shape[0], split_size):
                inp_, length_ = inp[i : i + split_size], length[i : i + split_size]
                out_, out_len_ = self._extract(inp_, length_)
                outs.append(out_)
                out_lens.append(out_len_)
            max_length = max([max(ols) for ols in out_lens])

            return torch.cat(
                [F.pad(o, (0, 0, 0, max_length - o.shape[1]), value=0) for o in outs],
                dim=0,
            ), torch.cat(out_lens, dim=0)
        else:
            return self._extract(inp, length)

    @torch.no_grad()
    def _extract(self, inp, length):
        frame_samples = int(self.sampling_rate * 0.02)
        device = inp.device

        if len(inp.shape) == 3:
            inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
        assert len(inp.shape) == 2
        assert self.sampling_rate == 16000

        feats = inp

        # Padding with 0
        padding_size = 3200  # Longer to cover receptive field
        feats = F.pad(feats, (0, padding_size), mode="constant", value=0)

        # Norm volume using LN
        feats = self._postprocess(
            feats, length + padding_size, normalize=self.normalize
        )

        if length is None:
            padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        else:
            length = torch.ceil(length / 320).int()
            padding_mask = get_mask_from_lengths(length).bool()
            padding_mask = F.pad(padding_mask, (0, 9), value=True)

        inputs = {
            "source": feats.half().to(device),
            "padding_mask": padding_mask.to(device),
            "mask": False,
        }
        logits, _ = self.model.extract_features(**inputs)
        logits = logits[:, : length.max()].float()

        return logits, length

    def _postprocess(self, feats, lengths, normalize=False):
        assert feats.dim() == 2, feats.dim()

        if normalize:
            with torch.no_grad():
                feats = [
                    F.layer_norm(feat[:length], feat[:length].shape)
                    for feat, length in zip(feats, lengths)
                ]
                feats = pad_sequence(feats, batch_first=True, padding_value=0)
        return feats
fireredtts/modules/semantic_tokenizer/semantic_tokenizer.py
ADDED
@@ -0,0 +1,877 @@
from functools import reduce
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from torch.utils.checkpoint import checkpoint

import einops
import math
import numpy as np
import os
import random
import torch
import torchaudio
import typing as tp
import warnings

from .audio import TorchMelSpectrogram
from .ecapa_tdnn import ECAPA_TDNN
from .hubert import HuBERT
from ..acoustic_codec.vector_quantization import VectorQuantization


CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)
NORM = "weight_norm"


def get_mask_from_lengths(lengths, max_len=None):
    max_len = torch.max(lengths).item() if max_len is None else max_len
    ids = torch.arange(0, max_len).to(lengths.device)
    mask = ~(ids < lengths.unsqueeze(1)).bool()
    return mask


class ConvLayerNorm(nn.LayerNorm):
    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, "b ... t -> b t ...")
        x = super().forward(x)
        x = einops.rearrange(x, "b t ... -> b ... t")
        return x


def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
|
| 59 |
+
assert norm in CONV_NORMALIZATIONS
|
| 60 |
+
if norm == "weight_norm":
|
| 61 |
+
return weight_norm(module)
|
| 62 |
+
elif norm == "spectral_norm":
|
| 63 |
+
return spectral_norm(module)
|
| 64 |
+
else:
|
| 65 |
+
# We already check was in CONV_NORMALIZATION, so any other choice
|
| 66 |
+
# doesn't need reparametrization.
|
| 67 |
+
return module
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_norm_module(
|
| 71 |
+
module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
|
| 72 |
+
) -> nn.Module:
|
| 73 |
+
assert norm in CONV_NORMALIZATIONS
|
| 74 |
+
if norm == "layer_norm":
|
| 75 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
| 76 |
+
return ConvLayerNorm(module.out_channels, **norm_kwargs)
|
| 77 |
+
elif norm == "time_group_norm":
|
| 78 |
+
if causal:
|
| 79 |
+
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
| 80 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
| 81 |
+
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
| 82 |
+
else:
|
| 83 |
+
return nn.Identity()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_extra_padding_for_conv1d(
|
| 87 |
+
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
| 88 |
+
) -> int:
|
| 89 |
+
length = x.shape[-1]
|
| 90 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
| 91 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
| 92 |
+
return ideal_length - length
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def pad_for_conv1d(
|
| 96 |
+
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
| 97 |
+
):
|
| 98 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
| 99 |
+
return F.pad(x, (0, extra_padding))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def pad1d(
|
| 103 |
+
x: torch.Tensor,
|
| 104 |
+
paddings: tp.Tuple[int, int],
|
| 105 |
+
mode: str = "zero",
|
| 106 |
+
value: float = 0.0,
|
| 107 |
+
):
|
| 108 |
+
length = x.shape[-1]
|
| 109 |
+
padding_left, padding_right = paddings
|
| 110 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 111 |
+
if mode == "reflect":
|
| 112 |
+
max_pad = max(padding_left, padding_right)
|
| 113 |
+
extra_pad = 0
|
| 114 |
+
if length <= max_pad:
|
| 115 |
+
extra_pad = max_pad - length + 1
|
| 116 |
+
x = F.pad(x, (0, extra_pad))
|
| 117 |
+
padded = F.pad(x, paddings, mode, value)
|
| 118 |
+
end = padded.shape[-1] - extra_pad
|
| 119 |
+
return padded[..., :end]
|
| 120 |
+
else:
|
| 121 |
+
return F.pad(x, paddings, mode, value)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
| 125 |
+
padding_left, padding_right = paddings
|
| 126 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 127 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
| 128 |
+
end = x.shape[-1] - padding_right
|
| 129 |
+
return x[..., padding_left:end]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
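# Edge-case sketch for pad1d's reflect branch: inputs shorter than the pad are
# first zero-extended (F.pad with mode="reflect" requires pad < length), then
# the temporary samples are trimmed off the end, e.g.:
#
#     x = torch.randn(1, 1, 2)
#     y = pad1d(x, (3, 3), mode="reflect")   # ok; plain F.pad would raise
#     assert y.shape[-1] == 2 + 3 + 3
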
class NormConv1d(nn.Module):

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose1d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            nn.ConvTranspose2d(*args, **kwargs), norm
        )
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "weight_norm",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "SConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (
            kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(
            x, kernel_size, stride, padding_total
        )
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(
                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )
        return self.conv(x)


class SConvTranspose1d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "weight_norm",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y


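# Framing note for SConv1d / SConvTranspose1d: with kernel_size == 2 * stride
# (as used by EncoderBlock/DecoderBlock below), padding_total = kernel_size -
# stride = stride, so each down/up block changes the time axis by exactly its
# stride factor and the two directions stay length-compatible.
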
class SLSTM(nn.Module):

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        bidirectional: bool = False,
        skip: bool = True,
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        if bidirectional:
            self.lstm = nn.LSTM(
                dimension, dimension // 2, num_layers, bidirectional=bidirectional
            )
        else:
            self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
        super().__init__()

        self.layers = nn.Sequential(
            SConv1d(
                in_channels=in_channels,
                out_channels=out_channels // 2,
                kernel_size=kernel_size,
                groups=groups,
                norm=NORM,
            ),
            Swish(),
            SConv1d(
                in_channels=out_channels // 2,
                out_channels=out_channels,
                kernel_size=kernel_size,
                groups=groups,
                norm=NORM,
            ),
        )

    def forward(self, x):
        return x + self.layers(x)


class EncoderBlock(nn.Module):
    def __init__(self, out_channels, stride):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=out_channels, out_channels=out_channels),
            Swish(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels),
            Swish(),
            SConv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                norm=NORM,
            ),
        )

    def forward(self, x):
        return self.layers(x)


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, stride):
        super().__init__()
        out_channels = in_channels
        self.layers = nn.Sequential(
            SConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                norm=NORM,
            ),
            Swish(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels),
            Swish(),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels),
        )

    def forward(self, x):
        return self.layers(x)


class Encoder(nn.Module):
    def __init__(self, C, D, strides=[2, 2], checkpointing=True):
        super().__init__()
        self.checkpointing = checkpointing

        self.downsample_scale = np.cumprod(np.asarray(strides))[-1]
        self.layers = [
            SConv1d(in_channels=C, out_channels=D, kernel_size=3, norm=NORM),
            Swish(),
        ]
        for stride in strides:
            self.layers += [
                EncoderBlock(out_channels=D, stride=stride),
                Swish(),
            ]
        self.layers += [
            SConv1d(in_channels=D, out_channels=D, kernel_size=3, norm=NORM),
            SLSTM(D, num_layers=1, bidirectional=True),
        ]
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        if self.checkpointing:
            x = checkpoint(
                self.layers, x.transpose(1, 2), use_reentrant=False
            ).transpose(1, 2)
        else:
            x = self.layers(x.transpose(1, 2)).transpose(1, 2)
        return x


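# Shape sketch for Encoder (assuming C=1024 SSL features, D=512, strides=[2, 2]):
#
#     enc = Encoder(C=1024, D=512, strides=[2, 2], checkpointing=False)
#     x = torch.randn(1, 100, 1024)   # (batch, time, C)
#     z = enc(x)                      # roughly (1, 25, 512); downsample_scale == 4
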
class Decoder(nn.Module):
    def __init__(self, C, D, H, strides=[2, 2], checkpointing=True):
        super().__init__()
        self.checkpointing = checkpointing

        self.in_layer = nn.Sequential(
            SConv1d(in_channels=D, out_channels=H, kernel_size=3, norm=NORM),
            SLSTM(H, num_layers=1, bidirectional=True),
        )
        self.layers = nn.ModuleList()
        for stride in strides:
            self.layers.append(
                nn.Sequential(DecoderBlock(in_channels=H, stride=stride), Swish())
            )
        self.out_layer = SConv1d(
            in_channels=H, out_channels=C, kernel_size=3, norm=NORM
        )

    def forward(self, x, g=None):
        if self.checkpointing:
            y = checkpoint(self._forward, x, g, use_reentrant=False)
        else:
            y = self._forward(x, g)
        return y

    def _forward(self, x, g=None):
        h = self.in_layer(x.transpose(1, 2))

        for layer in self.layers:
            up_g = g.unsqueeze(-1).repeat(1, 1, h.shape[-1])
            h = h + up_g
            h = layer(h)

        y = self.out_layer(h)

        return y.transpose(1, 2), h.transpose(1, 2)


class TimeRegulator(nn.Module):

    def __init__(self, in_dim, scale, learnable=False):
        super().__init__()
        self.scale = scale
        self.learnable = learnable

    def forward(self, x, x_len, downsample=True):
        if downsample:
            x = self.downsample(x, x_len)
        else:
            x = self.upsample(x, x_len)
        return x

    def downsample(self, x, x_len):
        x = torch.nn.functional.avg_pool1d(
            x.transpose(1, 2), self.scale, stride=self.scale, ceil_mode=True
        ).transpose(1, 2)
        x_len = (x_len / self.scale).ceil()
        return x, x_len

    def upsample(self, x, x_len):
        if self.learnable:
            x = self.upsampler(x.transpose(1, 2)).transpose(1, 2)
        else:
            x = torch.repeat_interleave(x, self.scale, dim=1)
        return x


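# Round-trip sketch for TimeRegulator (scale=2): downsample is an avg-pool,
# upsample a repeat_interleave, so lengths are restored up to ceil rounding.
#
#     tr = TimeRegulator(in_dim=512, scale=2)
#     x = torch.randn(1, 10, 512)
#     y, y_len = tr(x, torch.tensor([10.0]), downsample=True)   # (1, 5, 512)
#     z = tr(y, y_len, downsample=False)                        # (1, 10, 512)
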
class TreeVectorQuantization(nn.Module):

    def __init__(
        self,
        in_dim,
        vq_class="VectorQuantization",
        vq_config={},
        tree_config={},
    ):
        super().__init__()
        self.vq_config = vq_config
        self.tree_config = tree_config

        self.quantizers = nn.ModuleList()
        self.time_regulators = nn.ModuleList()
        for config in self.tree_config:
            vq_config = self.vq_config.copy()
            if not isinstance(vq_config["codebook_size"], (tuple, list)):
                vq_config["codebook_size"] = [vq_config["codebook_size"]]
                vq_config["codebook_dim"] = [vq_config["codebook_dim"]]
            vq_config["codebook_size"] = vq_config["codebook_size"] * config["n_groups"]
            vq_config["codebook_dim"] = vq_config["codebook_dim"] * config["n_groups"]
            self.quantizers.append(
                VectorQuantization(
                    in_dim,
                    n_groups=config.get("n_groups", 1),
                    dropout_rate_per_group=config.get("dropout_rate_per_group", 0),
                    ordered=config.get("ordered", False),
                    **vq_config,
                )
            )
            self.time_regulators.append(
                TimeRegulator(
                    in_dim,
                    config["downsample_rate"],
                    config.get("learnable_time_regulator", False),
                )
            )

    def forward(
        self, inp, inp_len, enable_vq=True, update_codebook=True, return_pre_quant=False
    ):
        output, (quants, losses, embed_inds) = self.quantize(
            inp,
            inp_len,
            enable_vq=enable_vq,
            update_codebook=update_codebook,
            return_pre_quant=return_pre_quant,
        )
        loss = sum(losses) / len(losses)
        return output, (quants, loss, embed_inds)

    def quantize(
        self, inp, inp_len, enable_vq=True, update_codebook=True, return_pre_quant=False
    ):
        quants, losses, embed_inds = [], [], []

        pre_quant_output, quant_output, residual = 0, 0, inp
        for tree_config, quantizer, regulator in zip(
            self.tree_config, self.quantizers, self.time_regulators
        ):
            # Downsample
            x, x_len = regulator(residual, inp_len, True)

            # Quantization
            q, diff, embed_ind = quantizer(
                x,
                x_len,
                enable_vq=enable_vq,
                update_codebook=update_codebook,
                return_pre_quant=return_pre_quant,
            )
            if return_pre_quant:
                pq, q = q

            # Upsample
            x = regulator(q, x_len, False)[:, : residual.shape[1]]

            residual = residual - x
            quant_output = quant_output + x

            if return_pre_quant:
                pq = regulator(pq, x_len, False)[:, : residual.shape[1]]
                pre_quant_output = pre_quant_output + pq

            quants.append(q)
            losses.append(diff)
            embed_inds.append(embed_ind)

        if return_pre_quant:
            return (pre_quant_output, quant_output), (quants, losses, embed_inds)
        return quant_output, (quants, losses, embed_inds)

    def decode(self, seqs, seq_lens=None):
        if not isinstance(seqs, (tuple, list)):
            tokens, token_lens = self.deserialize(seqs, seq_lens)
        else:
            tokens, token_lens = seqs, seq_lens

        quant_output = 0
        for token, quantizer, regulator in zip(
            tokens, self.quantizers, self.time_regulators
        ):
            x = quantizer.decode(token).transpose(1, 2)
            x = regulator(x, None, False)
            if torch.is_tensor(quant_output):
                x = x[:, : quant_output.size(1)]
            quant_output = quant_output + x

        return quant_output, token_lens

    def serialize(self, tokens, token_lens):
        assert len(tokens) <= 2, "we only support 1 or 2-scale sequences now..."

        scale = self.tree_config[0]["downsample_rate"]
        token_lens = ((token_lens.float() / scale).ceil() * scale).int()

        seq1 = tokens[0].unsqueeze(-1)

        if len(tokens) == 1:
            seq_cat = seq1.view(seq1.shape[0], -1)
            seq_cat_lens = (token_lens / scale * seq1.shape[2]).int()
        elif len(tokens) == 2:
            seq2 = F.pad(
                tokens[1], (0, token_lens.max() - tokens[1].size(1)), "replicate"
            )
            seq2 = torch.stack([seq2[:, i::scale] for i in range(scale)], dim=-1)
            seq_cat = torch.cat((seq1, seq2), dim=-1).view(seq1.shape[0], -1)
            seq_cat_lens = (token_lens / scale + token_lens).int()

        return seq_cat, seq_cat_lens

    def deserialize(self, seqs, seq_lens):
        if len(self.tree_config) == 1:
            return [seqs], seq_lens

        max_scale = max(config["downsample_rate"] for config in self.tree_config)
        total_scale = sum(config["downsample_rate"] for config in self.tree_config)

        # Cut for aligning
        if seq_lens is None:
            seq_lens = torch.full([seqs.shape[0]], seqs.shape[1]).to(seqs.device)
        seq_lens = (seq_lens / total_scale).int() * total_scale
        token_lens = (seq_lens / total_scale).int() * max_scale
        seqs = seqs[:, : seq_lens.max()]

        # Separate
        tokens = torch.stack(
            [seqs[:, i::total_scale] for i in range(total_scale)], dim=-1
        )
        seq1 = tokens[..., 0]
        seq2 = tokens[..., 1:].contiguous().view(tokens.shape[0], -1)

        return [seq1, seq2], token_lens


class SemanticVQVAE(nn.Module):
|
| 693 |
+
|
| 694 |
+
def __init__(
|
| 695 |
+
self,
|
| 696 |
+
in_dim,
|
| 697 |
+
out_dim,
|
| 698 |
+
n_model_size,
|
| 699 |
+
downsample_scales=[1, 2],
|
| 700 |
+
        upsample_scales=[[2, 1], [2, 1]],
        mel_config={},
        ssl_config={},
        # Quantization
        vq_class="VectorQuantization",
        vq_config={},
        tree_config={},
        # Training
        checkpointing=True,
        dual_decoding=False,
        n_samples_per_token=640,
        online_extraction=True,
        ssl_extractor=None,
    ):
        super(SemanticVQVAE, self).__init__()
        self.in_dim = in_dim
        self.n_model_size = n_model_size
        self.mel_config = mel_config
        self.dual_decoding = dual_decoding
        self.vq_config = vq_config
        self.tree_config = tree_config
        self.output_feature = "mel"
        self.n_samples_per_token = n_samples_per_token
        self.checkpointing = checkpointing

        self.mel_spectrogram = TorchMelSpectrogram(**mel_config)

        # Speaker encoder
        self.speaker_encoder = ECAPA_TDNN(
            out_dim,
            n_model_size,
            channels=[512, 512, 512, 512, 1536],
            kernel_sizes=[5, 3, 3, 3, 1],
            dilations=[1, 2, 3, 4, 1],
            attention_channels=128,
            res2net_scale=4,
            se_channels=128,
            global_context=True,
            batch_norm=True,
        )

        # Encoder & decoder
        self.encoder = Encoder(
            in_dim, n_model_size, downsample_scales, checkpointing=checkpointing
        )

        # Quantization
        self.quantizer = TreeVectorQuantization(
            n_model_size,
            vq_class=vq_class,
            vq_config=vq_config,
            tree_config=tree_config,
        )

    def forward(
        self,
        wav,
        wav_length,
        enable_vq=True,
        decode=True,
        extract_spk=True,
        shuffle=False,
        **kwargs,
    ):
        output_dict = {}

        with torch.no_grad():
            # Pad waveform to a multiple of n_samples_per_token
            if wav.shape[1] % self.n_samples_per_token > 0:
                pad_size = (
                    self.n_samples_per_token - wav.shape[1] % self.n_samples_per_token
                )
                wav = F.pad(wav, (0, pad_size), value=0)
                wav_length += pad_size

            # Extract mel & ssl
            mel, mel_length = kwargs.get("mel", None), kwargs.get("mel_length", None)
            if mel is None:
                mel, mel_length = self.mel_spectrogram(wav, wav_length)
            output_dict.update({"mel": mel, "mel_length": mel_length})

            ssl, ssl_length = kwargs.get("ssl", None), kwargs.get("ssl_length", None)
            if ssl is None:
                ssl, ssl_length = self.ssl_extractor(wav, wav_length)
            output_dict.update({"ssl": ssl.float(), "ssl_length": ssl_length})

        input, input_length = ssl, ssl_length
        output, output_length = mel, mel_length

        encoder_outputs = self.encoder(input)
        quant_length = torch.ceil(input_length / self.encoder.downsample_scale)
        quant_length = quant_length.clamp(max=encoder_outputs.shape[1])

        quant, (quants, diff, embed_ind) = self.quantizer(
            encoder_outputs,
            quant_length,
            enable_vq=enable_vq,
            update_codebook=True,
            return_pre_quant=self.dual_decoding,
        )

        output_dict.update(
            {
                "quants": quants,
                "token": embed_ind,
                "token_length": quant_length.int(),
                "encoder_diffs": diff,
            }
        )

        # Speaker
        if extract_spk:
            cond, cond_length = output, output_length
            speaker_embedding = self.speaker_encoder(cond, cond_length)
            speaker_embedding_1 = speaker_embedding_2 = speaker_embedding
            output_dict["spk"] = speaker_embedding

        return output_dict

    @torch.no_grad()
    def extract_speech_tokens(
        self, wav, wav_length, serialize=True, extract_spk=True, shuffle=False
    ):
        output_dict = self.forward(
            wav, wav_length, True, False, extract_spk=extract_spk, shuffle=shuffle
        )
        token_seqs, token_length = output_dict["token"], output_dict["token_length"]

        # Align sequences to the first tree level's downsample rate
        scale = self.tree_config[0]["downsample_rate"]
        token_length = (torch.ceil(token_length / scale) * scale).int()

        new_token_seqs, new_token_lens = [], []
        for i, token_seq in enumerate(token_seqs):
            # discrete-continuous tokens
            residual = None
            if isinstance(token_seq, (tuple, list)):
                token_seq, residual = token_seq

            scale = self.tree_config[i]["downsample_rate"]
            new_token_len = token_length // scale
            pad = int(new_token_len.max()) - token_seq.shape[1]
            token_seq = F.pad(
                token_seq,
                (0, pad) if len(token_seq.shape) == 2 else (0, 0, 0, pad),
                "replicate",
            )

            if residual is not None:
                token_seq = (token_seq, residual)
            new_token_seqs.append(token_seq)
            new_token_lens.append(new_token_len)

        if len(new_token_seqs) == 1:
            new_token_seqs, new_token_lens = new_token_seqs[0], new_token_lens[0]
        elif serialize:
            new_token_seqs, new_token_lens = self.quantizer.serialize(
                new_token_seqs, new_token_lens
            )

        output_dict.update(
            {
                "embed": output_dict["quants"],
                "token": new_token_seqs,
                "token_length": new_token_lens,
            }
        )

        return output_dict

    @torch.no_grad()
    def code_to_latent(self, token, mel=None):
        quant, _ = self.quantizer.decode(token, None)
        speaker_embedding = self.speaker_encoder(mel)
        latents = quant + speaker_embedding.unsqueeze(1).repeat(1, quant.shape[1], 1)
        return {
            "latents": latents,
        }
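The padding step at the top of forward() fixes the token rate: with n_samples_per_token=640, a waveform is right-padded to the next multiple of 640 samples, so 16000 samples map to exactly 25 token slots before any per-level downsample_rate from tree_config. A minimal usage sketch follows; the constructor call is elided because the real arguments come from the config files in this commit, so treat the shapes and key names as read off the code above rather than a documented API.

import torch

# tokenizer = SemanticVQVAE(...)  # in practice built from the repo's configs

# 16000 samples is already a multiple of 640, so forward() adds no padding
# and produces 16000 / 640 = 25 token slots at the encoder input rate.
wav = torch.randn(1, 16000)
wav_length = torch.tensor([16000])

# out = tokenizer.extract_speech_tokens(wav, wav_length, serialize=True)
# out["token"]        -> discrete semantic token ids
# out["token_length"] -> lengths, ceil-aligned to the first tree level's rate
# out["spk"]          -> ECAPA-TDNN speaker embedding (when extract_spk=True)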
fireredtts/modules/text_normalizer/__init__.py
ADDED
File without changes
fireredtts/modules/text_normalizer/normalize.py
ADDED
@@ -0,0 +1,183 @@
import re
import regex
import inflect
import unicodedata
from lingua import Language, LanguageDetectorBuilder
from builtins import str as unicode

from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer

from fireredtts.modules.text_normalizer.regex_common import *
from fireredtts.modules.text_normalizer.utils import *


def preprocess_text(sentence):
    # preprocessing
    sentence = bytes(sentence, "utf-8").decode("utf-8", "ignore")
    sentence = regex.sub("[\p{Cf}--[\u200d]]", "", sentence, flags=regex.V1)
    sentence = regex.sub("\p{Co}", "", sentence)
    sentence = sentence.replace("\u00a0", " ")
    sentence = sentence.replace("\ufffd", "")
    sentence = regex.sub("\p{Zl}", "\n", sentence)
    sentence = regex.sub("\p{Zp}", "\n", sentence)

    sentence = unicode(sentence)
    sentence = "".join(
        char
        for char in unicodedata.normalize("NFD", sentence)
        if unicodedata.category(char) != "Mn"
    )  # Strip accents

    sentence = strip_kaomoji(sentence)
    # full to half with exemption (to be converted after number TN): 。,:
    sentence = f2b(sentence, exemption="。,:")

    # clean spaces
    sentence = sentence.replace("\n", ",")
    sentence = sentence.replace("\t", ",")
    sentence = sentence.replace("\r", ",")
    sentence = re.sub(r"[。.]{3,}", "…", sentence)
    sentence = re.sub(r"[…⋯]{1,}", "…", sentence)
    sentence = re.sub(r"[ ]+", " ", sentence)
    sentence = sentence.strip()

    # punctuation reduction
    result = ""
    for idx, char in enumerate(sentence):
        if char in symbol_reduction:
            char = symbol_reduction[char]

        if char == " ":
            if idx == 0:
                continue
            if is_chinese(sentence[idx + 1]) and (
                is_chinese(sentence[idx - 1]) or sentence[idx - 1] in '") '
            ):
                result += ","
            else:
                result += " "
            continue

        if is_valid_char(char):
            result += char
    result = re.sub(r"[ ]+", " ", result)
    return result


def rettt(sentence):
    # handle abbreviations for all languages
    sentence = sentence.replace("&nd", "and")
    sentence = sentence.replace("Jan.", "january")
    sentence = sentence.replace("Feb.", "february")
    sentence = sentence.replace("Mar.", "march")
    sentence = sentence.replace("Apr.", "april")
    sentence = sentence.replace("May.", "may")
    sentence = sentence.replace("Jun.", "june")
    sentence = sentence.replace("Jul.", "july")
    sentence = sentence.replace("Aug.", "august")
    sentence = sentence.replace("Sept.", "september")
    sentence = sentence.replace("Sep.", "september")
    sentence = sentence.replace("Oct.", "october")
    sentence = sentence.replace("Nov.", "november")
    sentence = sentence.replace("Dec.", "december")
    sentence = sentence.replace("Mon.", "monday")
    sentence = sentence.replace("Tues.", "tuesday")
    sentence = sentence.replace("Wed.", "wednesday")
    sentence = sentence.replace("Thur.", "thursday")
    sentence = sentence.replace("Fri.", "friday")
    sentence = sentence.replace("Sat.", "saturday")
    if sentence != "Sun.":
        sentence = sentence.replace("Sun.", "sunday")
    sentence = re.sub(r" St\. ([A-Z])", r" saint \1", sentence)
    sentence = re.sub(r" St\.", " street", sentence)
    sentence = re.sub(r" Rd\.", " road", sentence)
    sentence = re.sub(r"[Aa]\.[Mm]\.", "A_M", sentence)
    sentence = re.sub(r"[Pp]\.[Mm]\.", "P_M", sentence)
    sentence = re.sub(r"[Bb]\.[Cc]\.", "B_C", sentence)
    sentence = re.sub(r"[Aa]\.[Dd]\.", "A_D", sentence)
    sentence = sentence.replace("Mr.", "mister")
    sentence = sentence.replace("Ms.", "miss")
    sentence = sentence.replace("Mrs.", "misses")
    sentence = sentence.replace("Ph.D", "P_H_D")
    sentence = sentence.replace("i.e.", "that is")
    sentence = sentence.replace("e.g.", "for example")
    sentence = sentence.replace("btw.", "by the way")
    sentence = sentence.replace("btw", "by the way")
    sentence = sentence.replace("b.t.w.", "by the way")
    sentence = sentence.replace("@", " at ")
    return sentence


class TextNormalizer:
    def __init__(self):
        self.language_detector = LanguageDetectorBuilder.from_languages(
            Language.ENGLISH, Language.CHINESE
        ).build()
        self.zh_normalizer = ZhNormalizer()
        self.en_normalizer = EnNormalizer()
        self.inflect_parser = inflect.engine()
        self.lang2token = {Language.ENGLISH: "en", Language.CHINESE: "zh"}

    def tn(self, text):
        text = preprocess_text(text)
        text = rettt(text)  # regex replacements
        # for non chinese languages
        language = self.language_detector.detect_language_of(text)
        # enforce chinese if text contains any chinese character
        if contains_chinese(text):
            language = Language.CHINESE
        text_lang = self.lang2token.get(language, "zh")

        if is_upper_eng_and_digit(text):
            language = Language.CHINESE

        if language == Language.CHINESE:
            text = self.zh_normalizer.normalize(text)
            # print("---text after zh_normalizer:", text)
            text = text.replace("\n", "")
            text = text.replace(",", ",")
            text = text.replace(".", "。")
            text = re.sub(r"[,,]+$", "。", text)
            # print("---text after zh_normalizer 2:", text)
        else:
            text = re.sub(r"[^ 0-9A-Za-z\[\]'.,:?!_\-]", "", text)
            text = self.en_normalizer.normalize(text)
            # fallback number normalization
            pieces = re.split(r"(\d+)", text)
            text = "".join(
                [
                    self.inflect_parser.number_to_words(p) if p.isnumeric() else p
                    for p in pieces
                    if len(p) > 0
                ]
            )

        # cleanup
        text = text.replace("_", " ")
        text = re.sub(r"[ ]+", " ", text)

        # spell capital words
        pieces = re.split(r"([A-Z]{2,4}|[ ])", text)
        for idx, p in enumerate(pieces):
            if re.match("[A-Z]{2,4}", p):
                pieces[idx] = " ".join(p)
        text = " ".join([p for p in pieces if p != " "])

        # post TN full to half
        # text = text.replace("。", ".")
        # text = text.replace(",", ",")
        # text = text.replace(":", ":")

        # model limitations
        text = text.lower().strip()
        text = text.replace('"', "")
        text = text.replace("·", " ")
        # text = re.sub("[…~!,&*%$#^:;!:;]+", ",", text)
        text = re.sub("[…~!&*%$#^:;!:;]+", ",", text)
        text = re.sub("[,]+", ",", text)
        text = re.sub(r"[,. ]+$", ".", text)
        if len(text) > 0 and text[-1] not in ".?":
            text = text + "."
        text = text.replace("。.", "。")
        return text, text_lang
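A usage sketch of the normalizer; the example strings are made up and the exact outputs depend on the WeTextProcessing and lingua versions pinned in requirements.txt, so the comments describe the branches rather than guaranteed results.

normalizer = TextNormalizer()

text, lang = normalizer.tn("Mr. Smith arrives at 5 P.M. on Jan. 3")
# english branch: rettt() expands "Mr." and "Jan.", en_normalizer plus the
# inflect fallback verbalize the digits, and 2-4 letter capital runs are
# spelled out letter by letter before lowercasing
print(text, lang)  # lang == "en"

text, lang = normalizer.tn("今天是2024年1月3日")
# chinese branch (forced because the text contains CJK characters):
# ZhNormalizer verbalizes the date, and half-width ',' and '.' are mapped
# back to ',' and '。'
print(text, lang)  # lang == "zh"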
fireredtts/modules/text_normalizer/regex_common.py
ADDED
@@ -0,0 +1,23 @@
import re

kaomoji_regex = re.compile(
    r"[oヽwΣ┗╰O︿Ψ凸]?[(|≡*(].{0,4}[Д✿_▽→≧﹏`∩⊙∇☆≡๑〃′エ≦▔@﹁εヘ•́ω益‿≖ฺ皿•̀艹 ̄△|゚].{0,5}[|≡*))][┛ブ凸cdd︴oOΨ︿w╯ノ]?"
)
chinese_regex = re.compile(r"[\u4e00-\u9fa5]")
digit_regex = re.compile(r"(\d+)(\.\d+)?", re.UNICODE)

chinese_char_regex = re.compile(r"^[\u4e00-\u9fa5]$", re.UNICODE)
eng_and_digit_char_regex = re.compile(r"^[0-9.,A-Za-z]+$", re.UNICODE)
upper_eng_and_digit_regex = re.compile(r"^[ 0-9A-Z\"'.,:?!\-]+$", re.UNICODE)
valid_char_regex = re.compile(
    r"[\t\r\n ]|"
    r"[\u4e00-\u9fa5]|"
    r"\u0080|[\u20a0-\u20bf]|\u00a2|\u00a3|\u00a5|\uffe0|\uffe1|\uffe5|\uffe6|"
    r"\u3000|\u3002|\u00b7|\u2014|\u2019|\u2026|\uff01|\uff1f|\uff0e|\uff1a|\uff1b|\uff0b|\uff0c|\uff0d|\uff0f|[\ufe10-\ufe16]|[\ufe50-\ufe51]|[\ufe55-\ufe57]|\ufe6a|"
    r"[\u0030-\u0040]|"
    r"[\u0391-\u03c9]|"
    r"[\u00b0-\u00b3]|[\u2015-\u2018]|[\u3000-\u303f]|"
    r"[\u0022-\u002f\u003a-\u003e\u0040\u005b-\u0060\u007b-\u007e]|"
    r"[\uff21-\uff3a]|[\uff41-\uff5a]|[\u0041-\u005a]|[\u0061-\u007a]",
    re.UNICODE,
)
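A few illustrative checks of these patterns (the sample strings are made up):

assert chinese_regex.search("hello 世界") is not None
assert digit_regex.match("3.14").group(0) == "3.14"

# valid_char_regex acts as a per-character whitelist inside preprocess_text:
# CJK, ASCII letters/digits/punctuation, Greek letters, common full-width
# punctuation, and currency signs pass; most other symbols are dropped.
kept = "".join(c for c in "abc你好$β√①" if valid_char_regex.match(c))
print(kept)  # "abc你好$β" — '√' and '①' fall outside the listed ranges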
fireredtts/modules/text_normalizer/utils.py
ADDED
@@ -0,0 +1,171 @@
from fireredtts.modules.text_normalizer.regex_common import *
from sentencex import segment
import re


symbol_reduction = {
    "「": '"',
    "」": '"',
    "`": '"',
    "〝": '"',
    "〞": '"',
    "‟": '"',
    "„": '"',
    "{": "(",
    "}": ")",
    "【": "(",
    "】": ")",
    "〖": "(",
    "〗": ")",
    "〔": "(",
    "〕": ")",
    "〘": "(",
    "〙": ")",
    "《": "(",
    "》": ")",
    "⦅": "(",
    "⦆": ")",
    "〚": "(",
    "〛": ")",
    "『": '"',
    "』": '"',
    "「": '"',
    "」": '"',
    "{": "(",
    "}": ")",
    "〈": "(",
    "〉": ")",
    "•": "·",
    "‧": "·",
    "〰": "…",
    "﹏": "…",
    "〜": "~",
    "~": "~",
    "+": "+",
    "、": "、",
    "。": "。",
    "︐": ",",
    "﹐": ",",
    "︑": "、",
    "﹑": "、",
    "︒": "。",
    "︓": ":",
    "﹕": ":",
    "︔": ";",
    "﹔": ";",
    "︕": "!",
    "﹗": "!",
    "︖": "?",
    "﹖": "?",
    "﹙": "(",
    "﹚": ")",
    "﹪": "%",
    "﹠": "&",
    ">": ">",
    "|": "、",
    "=": "=",
    "‐": "-",
    "‑": "-",
    "‒": "-",
    "–": "-",
    "—": "-",
    "―": "-",
    "%": "%",
    "μ": "u",
}


strong_break = re.compile("([。”;;!!:…??)\)\]』】」}~\r\n]| \.)", re.UNICODE)
weak_break = re.compile(
    "["
    "\U00002702-\U000027b0\U0001f926-\U0001f937\U00010000-\U0001fbff\U00030000-\U0010ffff"
    "\u2640-\u2642\u2600-\u2b55\u23cf\u23e9\u231a\ufe0f\u3030"
    "\t,,. ]",
    re.UNICODE,
)


def contains_chinese(text):
    return bool(chinese_regex.search(text))


def strip_kaomoji(text):
    return kaomoji_regex.sub(" ", text)


def is_chinese(char):
    return chinese_char_regex.match(char)


def is_eng_and_digit(char):
    return eng_and_digit_char_regex.match(char)


def is_upper_eng_and_digit(text):
    return upper_eng_and_digit_regex.match(text)


def is_valid_char(char):
    return valid_char_regex.match(char)


def is_digit(text):
    return digit_regex.match(text)


def f2b(ustr, exemption="。,:"):
    half = []
    for u in ustr:
        num = ord(u)
        if num == 0x3000:
            half.append(" ")
        elif u in exemption:  # exemption
            half.append(u)
        elif 0xFF01 <= num <= 0xFF5E:
            num -= 0xFEE0
            half.append(chr(num))
        else:
            half.append(u)
    return "".join(half)


def zh_text_split(text, length=80):
    if length == 0:
        return []
    if length == 1:
        return [c for c in text]
    if len(text) <= length:
        return [text]

    match_strong = re.search(strong_break, text[:length][::-1])
    match_weak = re.search(weak_break, text[:length][::-1])
    end_ind_strong = length - match_strong.start() if match_strong else 0
    end_ind_weak = length - match_weak.start() if match_weak else 0

    if end_ind_strong < length // 3:
        if end_ind_weak < length // 3:
            valid_max = max(end_ind_strong, end_ind_weak)
            if valid_max >= 3:
                return [text[:valid_max]] + zh_text_split(text[valid_max:])
            else:
                return [text[:length]] + zh_text_split(text[length:])
        else:
            return [text[:end_ind_weak]] + zh_text_split(text[end_ind_weak:])
    else:
        return [text[:end_ind_strong]] + zh_text_split(text[end_ind_strong:])


def text_split(text):
    if contains_chinese(text):
        substrings = list(segment("zh", text))
        new_substrings = []
        for s in substrings:
            if len(s) > 50:
                new_substrings += zh_text_split(s, length=50)
            else:
                new_substrings.append(s)
        substrings = new_substrings
    else:
        substrings = list(segment("en", text))

    return substrings
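Two quick examples of the helpers above; the strings are illustrative. f2b's offset arithmetic maps full-width code points U+FF01–U+FF5E down by 0xFEE0 onto their ASCII twins, while zh_text_split prefers cutting at a strong break (sentence-final punctuation) near the budget and falls back to weak breaks, then a hard cut.

# Full-width "ABC123," with the default exemption: letters and digits
# become half-width, but the full-width comma stays for later handling.
print(f2b("ABC123,"))  # -> "ABC123,"

# Sentence segmentation plus length-capped re-splitting for Chinese input.
long_zh = "这是一个相当长的中文句子," * 8 + "终于结束了。"
for piece in text_split(long_zh):
    print(len(piece), piece)  # every piece should come in at or under 50 chars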
fireredtts/setup.py
ADDED
@@ -0,0 +1,3 @@
from setuptools import setup, find_packages

setup(name="fireredtts", version="0.1", packages=find_packages())
fireredtts/utils/__init__.py
ADDED
File without changes
fireredtts/utils/spliter.py
ADDED
@@ -0,0 +1,161 @@
import re
import string

SYMBOLS_MAPPING = {
    "\n": "",
    "…": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "【": "",
    "】": "",
    "[": "",
    "]": "",
    "(": "",
    ")": "",
    "(": "",
    ")": "",
    "・": "",
    "·": "",
    "「": "'",
    "」": "'",
    "《": "'",
    "》": "'",
    "—": "",
    "~": "",
    "~": "",
    ":": ",",
    ";": ",",
    ";": ",",
    ":": ",",
    '"': "",
    "!": "。",
    "!": ".",
    "————": ",",
    "——": ",",
    "—": ",",
    "……": ",",
}

REPLACE_SYMBOL_REGEX = re.compile(
    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
)


EMOJI_REGEX = re.compile(
    "["
    "\U0001f600-\U0001f64f"  # emoticons
    "\U0001f300-\U0001f5ff"  # symbols & pictographs
    "\U0001f680-\U0001f6ff"  # transport & map symbols
    "\U0001f1e0-\U0001f1ff"  # flags (iOS)
    "]+",
    flags=re.UNICODE,
)


def clean_text(text):
    # Clean the text
    text = text.strip()
    text = text.replace("\xa0", "")

    # Replace all chinese symbols with their english counterparts
    text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)

    # Remove emojis
    text = EMOJI_REGEX.sub(r"", text)

    # Remove continuous periods (...) and commas (,,,)
    text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)

    return text


def utf_8_len(text):
    return len(text.encode("utf-8"))


def break_text(texts, length, splits: set):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if char in splits:
                yield curr
                curr = ""

        if curr:
            yield curr


def break_text_by_length(texts, length):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if utf_8_len(curr) >= length:
                yield curr
                curr = ""

        if curr:
            yield curr


def add_cleaned(curr, segments):
    curr = curr.strip()
    if curr and not all(c.isspace() or c in string.punctuation for c in curr):
        segments.append(curr)


def protect_float(text):
    # Turns 3.14 into <3_f_14> to prevent splitting
    return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)


def unprotect_float(text):
    # Turns <3_f_14> into 3.14
    return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)


def split_text(text, length):
    text = clean_text(text)

    # Break the text into pieces with following rules:
    # 1. Split the text at ".", "!", "?" if text is NOT a float
    # 2. If the text is longer than length, split at ","
    # 3. If the text is still longer than length, split at " "
    # 4. If the text is still longer than length, split at any character to length

    texts = [text]
    texts = map(protect_float, texts)
    texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
    texts = map(unprotect_float, texts)
    texts = break_text(texts, length, {",", ","})
    texts = break_text(texts, length, {" "})
    texts = list(break_text_by_length(texts, length))

    # Then, merge the texts into segments with length <= length
    segments = []
    curr = ""

    for text in texts:
        if utf_8_len(curr) + utf_8_len(text) <= length:
            curr += text
        else:
            add_cleaned(curr, segments)
            curr = text

    if curr:
        add_cleaned(curr, segments)

    return segments
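A sketch of the whole pipeline; the 30-byte budget is an arbitrary choice for illustration. protect_float first rewrites "3.14159" as "<3_f_14159>" so the decimal point survives the sentence-break pass, and unprotect_float restores it before the comma and space passes run.

text = "Pi is roughly 3.14159. The rest of this sentence is long enough to split."
for seg in split_text(text, length=30):
    print(utf_8_len(seg), seg)
# Each segment stays within ~30 UTF-8 bytes wherever a break point exists,
# and add_cleaned drops fragments that are pure punctuation or whitespace.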
fireredtts/utils/utils.py
ADDED
@@ -0,0 +1,37 @@
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio


def load_audio(audiopath, sampling_rate):
    """Load an audio file, convert it to mono, and resample it.

    Args:
        audiopath: path to the audio file
        sampling_rate: target sampling rate

    Returns:
        (mono audio at the native rate, native sampling rate, resampled audio)
    """
    audio, lsr = torchaudio.load(audiopath)

    # stereo to mono if needed
    if audio.size(0) != 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # resample
    audio_resampled = torchaudio.functional.resample(audio, lsr, sampling_rate)
    if torch.any(audio > 10) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")

    if torch.any(audio_resampled > 10) or not torch.any(audio_resampled < 0):
        print(
            f"Error with {audiopath}. Max={audio_resampled.max()} min={audio_resampled.min()}"
        )
    # clip audio invalid values
    audio.clip_(-1, 1)
    audio_resampled.clip_(-1, 1)
    return audio, lsr, audio_resampled
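Usage sketch for load_audio; the file path is a placeholder.

orig, orig_sr, resampled = load_audio("prompt.wav", sampling_rate=16000)
print(orig.shape, orig_sr, resampled.shape)
# orig is mono (1, n_samples) at the file's native rate; resampled is the
# same signal at 16 kHz. Both tensors are clipped in-place to [-1, 1], and
# suspicious ranges (values above 10, or no negative samples) are printed.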
pre-requirements.txt
ADDED
@@ -0,0 +1 @@
pip==24.0
pretrained_models/README.md
ADDED
@@ -0,0 +1,3 @@
## Pretrained Models

Download the required model files and place them in the folder `pretrained_models`.
requirements.txt
ADDED
@@ -0,0 +1,12 @@
torchaudio
fairseq
diffusers==0.27.2
librosa==0.10.2
soundfile==0.12.1
einops==0.8.0
transformers==4.44.2
tiktoken==0.7.0
inflect==7.4.0
lingua-language-detector==2.0.2
WeTextProcessing==1.0.3
sentencex==0.6.1