Shen Feiyu committed on
Commit faadabf · 1 Parent(s): 6cdc9a3
Files changed (44)
  1. app.py +284 -0
  2. configs/config_24k.json +171 -0
  3. configs/config_24k_flow.json +125 -0
  4. fireredtts/models/fireredtts.py +266 -0
  5. fireredtts/models/token2audio.py +108 -0
  6. fireredtts/modules/__init__.py +0 -0
  7. fireredtts/modules/acoustic_codec/__init__.py +1 -0
  8. fireredtts/modules/acoustic_codec/alias_free_torch/__init__.py +6 -0
  9. fireredtts/modules/acoustic_codec/alias_free_torch/act.py +35 -0
  10. fireredtts/modules/acoustic_codec/alias_free_torch/filter.py +99 -0
  11. fireredtts/modules/acoustic_codec/alias_free_torch/resample.py +58 -0
  12. fireredtts/modules/acoustic_codec/bigcodec.py +698 -0
  13. fireredtts/modules/acoustic_codec/vector_quantization.py +580 -0
  14. fireredtts/modules/acoustic_llm/__init__.py +1 -0
  15. fireredtts/modules/acoustic_llm/acoustic_llm.py +876 -0
  16. fireredtts/modules/bigvgan/__init__.py +2 -0
  17. fireredtts/modules/bigvgan/activations.py +126 -0
  18. fireredtts/modules/bigvgan/alias_free_torch/__init__.py +5 -0
  19. fireredtts/modules/bigvgan/alias_free_torch/act.py +29 -0
  20. fireredtts/modules/bigvgan/alias_free_torch/filter.py +98 -0
  21. fireredtts/modules/bigvgan/alias_free_torch/resample.py +57 -0
  22. fireredtts/modules/bigvgan/bigvgan.py +369 -0
  23. fireredtts/modules/bigvgan/mel_spectrogram.py +111 -0
  24. fireredtts/modules/flowmatching/__init__.py +18 -0
  25. fireredtts/modules/flowmatching/estimator_dit.py +356 -0
  26. fireredtts/modules/flowmatching/flow.py +138 -0
  27. fireredtts/modules/flowmatching/upsample_encoder.py +617 -0
  28. fireredtts/modules/semantic_llm/llm_gpt2.py +608 -0
  29. fireredtts/modules/semantic_tokenizer/__init__.py +36 -0
  30. fireredtts/modules/semantic_tokenizer/audio.py +138 -0
  31. fireredtts/modules/semantic_tokenizer/ecapa_tdnn.py +931 -0
  32. fireredtts/modules/semantic_tokenizer/hubert.py +108 -0
  33. fireredtts/modules/semantic_tokenizer/semantic_tokenizer.py +877 -0
  34. fireredtts/modules/text_normalizer/__init__.py +0 -0
  35. fireredtts/modules/text_normalizer/normalize.py +183 -0
  36. fireredtts/modules/text_normalizer/regex_common.py +23 -0
  37. fireredtts/modules/text_normalizer/utils.py +171 -0
  38. fireredtts/setup.py +3 -0
  39. fireredtts/utils/__init__.py +0 -0
  40. fireredtts/utils/spliter.py +161 -0
  41. fireredtts/utils/utils.py +37 -0
  42. pre-requirements.txt +1 -0
  43. pretrained_models/README.md +3 -0
  44. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,284 @@
1
+ # Environment setup
2
+ from pathlib import Path
3
+ import os
4
+ import sys
5
+ sys.path.append(str(Path(__file__).parent))
6
+ # FIXME add weights_only=False in /usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py#315
7
+ if os.path.exists('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py'):
8
+ file_lines = []
9
+ with open('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py', 'r') as f:
10
+ for line in f:
11
+ file_lines.append(line.strip('\n'))
12
+ file_lines[314] = file_lines[314].replace(
13
+ "state = torch.load(f, map_location=torch.device(\"cpu\"))",
14
+ "state = torch.load(f, map_location=torch.device(\"cpu\"), weights_only=False)"
15
+ )
16
+ with open('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py', 'w') as f:
17
+ for line in file_lines:
18
+ f.write(line+'\n')
19
+ print('[DEBUG] added weights_only=False')
20
+ # Run
21
+ import spaces
22
+ import gradio as gr
23
+ from zipfile import ZipFile
24
+ from typing import Literal
25
+ from huggingface_hub import snapshot_download
26
+ from fireredtts.models.fireredtts import FireRedTTS
27
+ # NOTE disable verbose INFO logs
28
+ import logging
29
+ httpx_logger = logging.getLogger("httpx")
30
+ httpx_logger.setLevel(logging.WARNING)
31
+
32
+ # NOTE Some launching setups
33
+ # - install fairseq manually ("python -m pip install pip==24.0")
34
+ # - manually add weights_only=False in /usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py#315
35
+
36
+
37
+ # ================================================
38
+ # FireRedTTS1s Model
39
+ # ================================================
40
+ # Global model instance
41
+ tts_flow: FireRedTTS = None
42
+ tts_acollm: FireRedTTS = None
43
+ def initiate_model(pretrained_dir: str):
44
+ global tts_flow, tts_acollm
45
+ if tts_flow is None:
46
+ tts_flow = FireRedTTS(
47
+ config_path='configs/config_24k_flow.json',
48
+ pretrained_path=pretrained_dir,
49
+ )
50
+ if tts_acollm is None:
51
+ tts_acollm = FireRedTTS(
52
+ config_path='configs/config_24k.json',
53
+ pretrained_path=pretrained_dir,
54
+ )
55
+
56
+
57
+ # ================================================
58
+ # Gradio
59
+ # ================================================
60
+
61
+ # i18n
62
+ _i18n_key2lang_dict = dict(
63
+ # Title markdown
64
+ title_md_desc=dict(
65
+ en="FireRedTTS-1s 🔥 Streamable TTS",
66
+ zh="FireRedTTS-1s 🔥 可流式TTS",
67
+ ),
68
+ # Decoder choice radio
69
+ decoder_choice_label=dict(
70
+ en="Decoder Choice",
71
+ zh="解码器选择",
72
+ ),
73
+ decoder_choice_1=dict(
74
+ en="Flow Matching",
75
+ zh="Flow Matching",
76
+ ),
77
+ decoder_choice_2=dict(
78
+ en="Acoustic LLM",
79
+ zh="Acoustic LLM",
80
+ ),
81
+ # Speaker Prompt
82
+ spk_prompt_audio_label=dict(
83
+ en="Speaker Prompt Audio",
84
+ zh="参考语音",
85
+ ),
86
+ spk_prompt_text_label=dict(
87
+ en="Speaker Prompt Text",
88
+ zh="参考语音的文本",
89
+ ),
90
+ spk_prompt_text_placeholder=dict(
91
+ en="Speaker Prompt Text",
92
+ zh="参考语音的文本",
93
+ ),
94
+ # Input textbox
95
+ target_text_input_label=dict(
96
+ en="Text To Synthesis",
97
+ zh="待合成文本",
98
+ ),
99
+ target_text_input_placeholder=dict(
100
+ en="Text To Synthesis",
101
+ zh="待合成文本",
102
+ ),
103
+ # Generate button
104
+ generate_btn_label=dict(
105
+ en="Generate Audio",
106
+ zh="合成",
107
+ ),
108
+ # Generated audio
109
+ generated_audio_label=dict(
110
+ en="Generated Audio",
111
+ zh="合成的音频",
112
+ ),
113
+ # Warning1: incomplete prompt info
114
+ warn_incomplete_prompt=dict(
115
+ en="Please provide prompt audio and text",
116
+ zh="请提供说话人参考语音与参考文本",
117
+ ),
118
+ # Warning2: invalid text for target text input
119
+ warn_invalid_target_text=dict(
120
+ en="Empty input text",
121
+ zh="待合成文本为空",
122
+ ),
123
+ )
124
+
125
+ global_lang: Literal['zh', 'en'] = 'zh'
126
+ def i18n(key):
127
+ global global_lang
128
+ return _i18n_key2lang_dict[key][global_lang]
129
+
130
+
131
+ def check_monologue_text(text:str, prefix:str=None)->bool:
132
+ text = text.strip()
133
+ # Check speaker tags
134
+ if prefix is not None and (not text.startswith(prefix)):
135
+ return False
136
+ # Remove prefix
137
+ if prefix is not None:
138
+ text = text.removeprefix(prefix)
139
+ text = text.strip()
140
+ # If empty?
141
+ if len(text) == 0:
142
+ return False
143
+ return True
144
+
145
+
146
+ @spaces.GPU(duration=60)
147
+ def synthesis_function(
148
+ spk_prompt_audio: str,
149
+ spk_prompt_text: str,
150
+ target_text: str,
151
+ decoder_choice: Literal[0, 1] = 0, # 0 means flow matching decoder
152
+ ):
153
+ global tts_flow, tts_acollm
154
+
155
+ # Check prompt info
156
+ spk_prompt_text = spk_prompt_text.strip()
157
+ if spk_prompt_audio is None or spk_prompt_text == "":
158
+ gr.Warning(message=i18n('warn_incomplete_prompt'))
159
+ return None
160
+ # Check target text
161
+ target_text = target_text.strip()
162
+ if target_text == "":
163
+ gr.Warning(message=i18n('warn_invalid_target_text'))
164
+ return None
165
+
166
+ # Go synthesis
167
+ if decoder_choice == 0:
168
+ audio = tts_flow.synthesize(
169
+ prompt_wav=spk_prompt_audio,
170
+ prompt_text=spk_prompt_text,
171
+ text=target_text,
172
+ lang="zh",
173
+ use_tn=True
174
+ )
175
+ else:
176
+ audio = tts_acollm.synthesize(
177
+ prompt_wav=spk_prompt_audio,
178
+ prompt_text=spk_prompt_text,
179
+ text=target_text,
180
+ lang="zh",
181
+ use_tn=True
182
+ )
183
+ return (24000, audio.detach().cpu().squeeze(0).numpy())
184
+
185
+
186
+ # UI rendering
187
+ def render_interface()->gr.Blocks:
188
+ with gr.Blocks(title="FireRedTTS-1S", theme=gr.themes.Default()) as page:
189
+ # ======================== UI ========================
190
+ # A large title
191
+ title_desc = gr.Markdown(value="# {}".format(i18n('title_md_desc')))
192
+ with gr.Row():
193
+ lang_choice = gr.Radio(
194
+ choices=['中文', 'English'],
195
+ value='中文',
196
+ label='Display Language/显示语言',
197
+ type="index",
198
+ interactive=True,
199
+ )
200
+ decoder_choice = gr.Radio(
201
+ choices=[i18n('decoder_choice_1'), i18n('decoder_choice_2')],
202
+ value=i18n('decoder_choice_1'),
203
+ label=i18n('decoder_choice_label'),
204
+ type="index",
205
+ interactive=True,
206
+ )
207
+ with gr.Row():
208
+ # ==== Speaker Prompt ====
209
+ spk_prompt_text = gr.Textbox(
210
+ label=i18n('spk_prompt_text_label'),
211
+ placeholder=i18n('spk_prompt_text_placeholder'),
212
+ lines=5,
213
+ )
214
+ spk_prompt_audio = gr.Audio(
215
+ label=i18n('spk_prompt_audio_label'),
216
+ type="filepath",
217
+ editable=False,
218
+ interactive=True,
219
+ ) # Audio component returns tmp audio path
220
+ # ==== Target Text ====
221
+ target_text_input = gr.Textbox(
222
+ label=i18n('target_text_input_label'),
223
+ placeholder=i18n('target_text_input_placeholder'),
224
+ lines=5,
225
+ )
226
+ # Generate button
227
+ generate_btn = gr.Button(value=i18n('generate_btn_label'), variant="primary", size="lg")
228
+ # Long output audio
229
+ generate_audio = gr.Audio(
230
+ label=i18n('generated_audio_label'),
231
+ interactive=False,
232
+ )
233
+
234
+ # ======================== Action ========================
235
+ # Language action
236
+ def _change_component_language(lang):
237
+ global global_lang
238
+ global_lang = ['zh', 'en'][lang]
239
+ return [
240
+ # title_desc
241
+ gr.update(value="# {}".format(i18n('title_md_desc'))),
242
+ # decoder_choice
243
+ gr.update(label=i18n('decoder_choice_label')),
244
+ # spk_prompt_{audio,text}
245
+ gr.update(label=i18n('spk_prompt_text_label'), placeholder=i18n('spk_prompt_text_placeholder')),
246
+ gr.update(label=i18n('spk_prompt_audio_label')),
247
+ # target_text_input
248
+ gr.update(label=i18n('target_text_input_label'), placeholder=i18n('target_text_input_placeholder')),
249
+ # generate_btn
250
+ gr.update(value=i18n('generate_btn_label')),
251
+ # generate_audio
252
+ gr.update(label=i18n('generated_audio_label')),
253
+ ]
254
+ lang_choice.change(
255
+ fn=_change_component_language,
256
+ inputs=[lang_choice],
257
+ outputs=[
258
+ title_desc, decoder_choice,
259
+ spk_prompt_text, spk_prompt_audio,
260
+ target_text_input,
261
+ generate_btn, generate_audio,
262
+ ]
263
+ )
264
+ generate_btn.click(
265
+ fn=synthesis_function,
266
+ inputs=[spk_prompt_audio, spk_prompt_text, target_text_input, decoder_choice],
267
+ outputs=[generate_audio]
268
+ )
269
+ return page
270
+
271
+
272
+ if __name__ == '__main__':
273
+ # Download model
274
+ snapshot_download(repo_id='FireRedTeam/FireRedTTS-1S', local_dir='pretrained_models/FireRedTTS-1S')
275
+ # Unzip model, weights under "pretrained_models/FireRedTTS-1S/pretrained_models"
276
+ with ZipFile('pretrained_models/FireRedTTS-1S/pretrained_models.zip', 'r') as zipf:
277
+ zipf.extractall('pretrained_models/FireRedTTS-1S')
278
+ # Init model
279
+ initiate_model('pretrained_models/FireRedTTS-1S/pretrained_models')
280
+ print('[INFO] model loaded')
281
+ # UI
282
+ page = render_interface()
283
+ page.launch()
284
+
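A hypothetical local smoke test for the handler above (placeholder paths; assumes a GPU environment and that the pretrained weights were downloaded and extracted as in the __main__ block):

initiate_model("pretrained_models/FireRedTTS-1S/pretrained_models")
result = synthesis_function(
    spk_prompt_audio="prompt.wav",                  # placeholder prompt audio path
    spk_prompt_text="Transcript of the prompt audio.",
    target_text="Text to synthesize.",
    decoder_choice=0,                               # 0 = flow matching, 1 = acoustic LLM
)
if result is not None:                              # None is returned when a warning is raised
    sample_rate, waveform = result                  # 24000, 1-D numpy array
    print(sample_rate, waveform.shape)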
configs/config_24k.json ADDED
@@ -0,0 +1,171 @@
1
+ {
2
+ "semantic_llm": {
3
+ "start_text_token": 32000,
4
+ "stop_text_token": 32001,
5
+ "num_text_tokens": 32002,
6
+ "start_audio_token": 16384,
7
+ "stop_audio_token": 16385,
8
+ "num_audio_tokens": 16386,
9
+ "llm_hidden_size": 1024,
10
+ "llm_intermediate_size": 4096,
11
+ "llm_num_layers": 30,
12
+ "llm_num_heads": 16,
13
+ "llm_max_audio_seq_len": 630,
14
+ "llm_max_text_seq_len": 402,
15
+ "llm_max_prompt_len": 250,
16
+ "code_stride_len": 640,
17
+ "EOS_TOKEN": 16385
18
+ },
19
+ "acoustic_llm": {
20
+ "n_stacks": 1,
21
+ "layers": 24,
22
+ "model_dim": 1536,
23
+ "heads": 16,
24
+ "max_text_tokens": 2048,
25
+ "max_speech_tokens": 2048,
26
+ "max_conditioning_inputs": 1,
27
+ "number_text_tokens": 16386,
28
+ "start_text_token": 16384,
29
+ "stop_text_token": 16385,
30
+ "n_frames_per_step": 1,
31
+ "n_heads_per_frame": 8,
32
+ "delay_prediction": 1,
33
+ "upsample_factors": 1,
34
+ "streaming_delayed_frames": 8,
35
+ "number_speech_tokens": 16386,
36
+ "start_speech_token": 16384,
37
+ "stop_speech_token": 16385,
38
+ "speaker_embedding_pretrained": true,
39
+ "speaker_embedding_ckpt": null,
40
+ "speaker_embedding_dim": 512,
41
+ "temperature": 0.5,
42
+ "repetition_penalty": 2.0,
43
+ "top_p": 0.5,
44
+ "top_k": 25
45
+ },
46
+ "acoustic_codec": {
47
+ "n_model_size": 1024,
48
+ "encoder_config": {
49
+ "ngf": 48,
50
+ "up_ratios": [
51
+ 2,
52
+ 4,
53
+ 4,
54
+ 4,
55
+ 5
56
+ ],
57
+ "causal": true
58
+ },
59
+ "decoder_config": {
60
+ "upsample_initial_channel": 1536,
61
+ "ngf": 48,
62
+ "up_ratios": [
63
+ 6,
64
+ 5,
65
+ 4,
66
+ 4,
67
+ 2
68
+ ],
69
+ "causal": true
70
+ },
71
+ "vq_config": {
72
+ "n_groups": 8,
73
+ "ordered": true,
74
+ "codebook_size": [
75
+ 128,
76
+ 128,
77
+ 128,
78
+ 128,
79
+ 128,
80
+ 128,
81
+ 128,
82
+ 128,
83
+ 128,
84
+ 128,
85
+ 128,
86
+ 128,
87
+ 128,
88
+ 128,
89
+ 128,
90
+ 128
91
+ ],
92
+ "codebook_dim": [
93
+ 8,
94
+ 8,
95
+ 8,
96
+ 8,
97
+ 8,
98
+ 8,
99
+ 8,
100
+ 8,
101
+ 8,
102
+ 8,
103
+ 8,
104
+ 8,
105
+ 8,
106
+ 8,
107
+ 8,
108
+ 8
109
+ ],
110
+ "requires_projection": true,
111
+ "decay": 0.99,
112
+ "threshold_ema_dead_code": 0,
113
+ "commitment_weight": 0.01
114
+ },
115
+ "resampler_config": {
116
+ "source_sr": 16000,
117
+ "target_sr": 16000
118
+ }
119
+ },
120
+ "semantic_tokenizer": {
121
+ "in_dim": 1024,
122
+ "out_dim": 80,
123
+ "n_model_size": 512,
124
+ "downsample_scales": [
125
+ 1,
126
+ 1,
127
+ 1,
128
+ 2
129
+ ],
130
+ "upsample_scales": [
131
+ [
132
+ 2,
133
+ 1
134
+ ],
135
+ [
136
+ 2,
137
+ 1,
138
+ 1,
139
+ 1
140
+ ]
141
+ ],
142
+ "mel_config": {
143
+ "style": "BigVGAN",
144
+ "filter_length": 1024,
145
+ "hop_length": 160,
146
+ "win_length": 640,
147
+ "n_mel_channels": 80,
148
+ "sampling_rate": 16000
149
+ },
150
+ "vq_config": {
151
+ "codebook_size": [
152
+ 128,
153
+ 128
154
+ ],
155
+ "codebook_dim": [
156
+ 128,
157
+ 128
158
+ ],
159
+ "requires_projection": true
160
+ },
161
+ "tree_config": [
162
+ {
163
+ "downsample_rate": 1,
164
+ "n_groups": 1,
165
+ "dropout": 0
166
+ }
167
+ ],
168
+ "n_samples_per_token": 640,
169
+ "checkpointing": true
170
+ }
171
+ }
configs/config_24k_flow.json ADDED
@@ -0,0 +1,125 @@
1
+ {
2
+ "semantic_llm": {
3
+ "start_text_token": 32000,
4
+ "stop_text_token": 32001,
5
+ "num_text_tokens": 32002,
6
+ "start_audio_token": 16384,
7
+ "stop_audio_token": 16385,
8
+ "num_audio_tokens": 16386,
9
+ "llm_hidden_size": 1024,
10
+ "llm_intermediate_size": 4096,
11
+ "llm_num_layers": 30,
12
+ "llm_num_heads": 16,
13
+ "llm_max_audio_seq_len": 630,
14
+ "llm_max_text_seq_len": 402,
15
+ "llm_max_prompt_len": 250,
16
+ "code_stride_len": 640,
17
+ "EOS_TOKEN": 16385
18
+ },
19
+ "flow": {
20
+ "spk_channels": 512,
21
+ "spk_enc_channels": 80,
22
+ "infer_cfg_rate": 0.7,
23
+ "token_emb": {
24
+ "channels": 512
25
+ },
26
+ "encoder": {
27
+ "input_size": 512,
28
+ "output_size": 512,
29
+ "num_blocks": 6,
30
+ "num_up_blocks": 4,
31
+ "normalize_before": true,
32
+ "up_stride": 2,
33
+ "pre_lookahead_len": 3,
34
+ "attention_heads": 4,
35
+ "key_bias": true,
36
+ "linear_units": 2048,
37
+ "dropout_rate": 0.0,
38
+ "positional_dropout_rate": 0.0,
39
+ "attention_dropout_rate": 0.0
40
+ },
41
+ "estimator": {
42
+ "in_channels": 320,
43
+ "out_channels": 80,
44
+ "mlp_ratio": 4,
45
+ "depth": 16,
46
+ "num_heads": 8,
47
+ "head_dim": 64,
48
+ "hidden_size": 512
49
+ }
50
+ },
51
+ "mel": {
52
+ "num_mels": 80,
53
+ "n_fft": 1920,
54
+ "hop_size": 480,
55
+ "win_size": 1920,
56
+ "sampling_rate": 24000,
57
+ "fmin": 0,
58
+ "fmax": 8000,
59
+ "center": false
60
+ },
61
+ "bigvgan": {
62
+ "num_mels": 80,
63
+ "upsample_initial_channel": 1536,
64
+ "resblock_kernel_sizes": [3, 7, 11],
65
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
66
+ "upsample_rates": [5, 4, 3, 2, 2, 2],
67
+ "upsample_kernel_sizes": [11, 8, 7, 4, 4, 4],
68
+ "resblock_type": "1",
69
+ "snake_logscale": true,
70
+ "activation": "snakebeta",
71
+ "use_tanh_at_final": false,
72
+ "use_bias_at_final": false
73
+ },
74
+ "semantic_tokenizer": {
75
+ "in_dim": 1024,
76
+ "out_dim": 80,
77
+ "n_model_size": 512,
78
+ "downsample_scales": [
79
+ 1,
80
+ 1,
81
+ 1,
82
+ 2
83
+ ],
84
+ "upsample_scales": [
85
+ [
86
+ 2,
87
+ 1
88
+ ],
89
+ [
90
+ 2,
91
+ 1,
92
+ 1,
93
+ 1
94
+ ]
95
+ ],
96
+ "mel_config": {
97
+ "style": "BigVGAN",
98
+ "filter_length": 1024,
99
+ "hop_length": 160,
100
+ "win_length": 640,
101
+ "n_mel_channels": 80,
102
+ "sampling_rate": 16000
103
+ },
104
+ "vq_config": {
105
+ "codebook_size": [
106
+ 128,
107
+ 128
108
+ ],
109
+ "codebook_dim": [
110
+ 128,
111
+ 128
112
+ ],
113
+ "requires_projection": true
114
+ },
115
+ "tree_config": [
116
+ {
117
+ "downsample_rate": 1,
118
+ "n_groups": 1,
119
+ "dropout": 0
120
+ }
121
+ ],
122
+ "n_samples_per_token": 640,
123
+ "checkpointing": true
124
+ }
125
+ }
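For orientation: the two configs differ only in the decoder branch. config_24k.json carries acoustic_llm and acoustic_codec sections, while config_24k_flow.json carries flow, mel, and bigvgan sections; the shared semantic_llm and semantic_tokenizer blocks are identical. A minimal sketch of how that key presence drives decoder selection (the actual check lives in fireredtts/models/fireredtts.py below):

import json

for path in ("configs/config_24k.json", "configs/config_24k_flow.json"):
    cfg = json.load(open(path))
    decoder = "acoustic LLM + codec" if "acoustic_llm" in cfg else "flow matching + BigVGAN"
    print(path, "->", decoder)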
fireredtts/models/fireredtts.py ADDED
@@ -0,0 +1,266 @@
1
+ import os
2
+ import json
3
+ import time
4
+ from traceback import format_exc
5
+ import torch
6
+
7
+ from transformers import AutoTokenizer
8
+
9
+ from fireredtts.utils.utils import load_audio
10
+ from fireredtts.modules.text_normalizer.utils import text_split
11
+ from fireredtts.utils.spliter import clean_text
12
+ from fireredtts.modules.text_normalizer.normalize import TextNormalizer
13
+ from fireredtts.modules.semantic_tokenizer import SemanticTokenizer
14
+ from fireredtts.modules.semantic_llm.llm_gpt2 import Speech_LLM_GPT2
15
+ from fireredtts.models.token2audio import TwoStageCodec, FlowToken2Audio
16
+
17
+
18
+ class FireRedTTS:
19
+ def __init__(self, config_path, pretrained_path, device="cuda"):
20
+ self.device = device
21
+ self.config = json.load(open(config_path))
22
+ self.EOS_TOKEN = self.config["semantic_llm"]["EOS_TOKEN"]
23
+
24
+ # pretrained models
25
+ self.tokenizer_path = os.path.join(pretrained_path, "tokenizer")
26
+ self.speech_tokenizer_path = os.path.join(pretrained_path, "speech_tokenizer")
27
+ self.semantic_llm_path = os.path.join(pretrained_path, "semantic_llm.pt")
28
+ assert os.path.exists(self.tokenizer_path)
29
+ assert os.path.exists(self.speech_tokenizer_path)
30
+ assert os.path.exists(self.semantic_llm_path)
31
+ if 'acoustic_llm' in self.config:
32
+ self.acoustic_llm_path = os.path.join(pretrained_path, "acoustic_llm.bin")
33
+ self.acoustic_codec_path = os.path.join(pretrained_path, "acoustic_codec.bin")
34
+ assert os.path.exists(self.acoustic_llm_path)
35
+ assert os.path.exists(self.acoustic_codec_path)
36
+ else:
37
+ self.flow_path = os.path.join(pretrained_path, "flow.pt")
38
+ self.bigvgan_path = os.path.join(pretrained_path, "bigvgan.pt")
39
+ assert os.path.exists(self.flow_path)
40
+ assert os.path.exists(self.bigvgan_path)
41
+
42
+ # text normalizer
43
+ self.text_normalizer = TextNormalizer()
44
+ # text tokenizer
45
+ self.text_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
46
+
47
+ # semantic llm
48
+ self.semantic_llm = Speech_LLM_GPT2(
49
+ start_text_token=self.config["semantic_llm"]["start_text_token"],
50
+ stop_text_token=self.config["semantic_llm"]["stop_text_token"],
51
+ num_text_tokens=self.config["semantic_llm"]["num_text_tokens"],
52
+ start_audio_token=self.config["semantic_llm"]["start_audio_token"],
53
+ stop_audio_token=self.config["semantic_llm"]["stop_audio_token"],
54
+ num_audio_tokens=self.config["semantic_llm"]["num_audio_tokens"],
55
+ llm_hidden_size=self.config["semantic_llm"]["llm_hidden_size"],
56
+ llm_intermediate_size=self.config["semantic_llm"]["llm_intermediate_size"],
57
+ llm_num_layers=self.config["semantic_llm"]["llm_num_layers"],
58
+ llm_num_heads=self.config["semantic_llm"]["llm_num_heads"],
59
+ llm_max_audio_seq_len=self.config["semantic_llm"]["llm_max_audio_seq_len"],
60
+ llm_max_text_seq_len=self.config["semantic_llm"]["llm_max_text_seq_len"],
61
+ llm_max_prompt_len=self.config["semantic_llm"]["llm_max_prompt_len"],
62
+ code_stride_len=self.config["semantic_llm"]["code_stride_len"],
63
+ )
64
+
65
+ sd = torch.load(self.semantic_llm_path, map_location=device)["model"]
66
+ self.semantic_llm.load_state_dict(sd, strict=True)
67
+ self.semantic_llm = self.semantic_llm.to(device=device)
68
+ self.semantic_llm.eval()
69
+ self.semantic_llm.init_gpt_for_inference(kv_cache=True)
70
+
71
+ # Speech tokenizer
72
+ self.speech_tokenizer = SemanticTokenizer(
73
+ config=self.config["semantic_tokenizer"], path=self.speech_tokenizer_path
74
+ )
75
+
76
+ # Acoustic decoder
77
+ if 'acoustic_llm' in self.config:
78
+ self.acoustic_decoder = TwoStageCodec(self.config)
79
+ self.acoustic_decoder.load_model(self.acoustic_llm_path, self.acoustic_codec_path)
80
+ else:
81
+ self.acoustic_decoder = FlowToken2Audio(self.config)
82
+ self.acoustic_decoder.load_model(self.flow_path, self.bigvgan_path)
83
+ self.acoustic_decoder.eval()
84
+ self.acoustic_decoder = self.acoustic_decoder.to(device)
85
+
86
+ def extract_spk_embeddings(self, prompt_wav):
87
+ audio, lsr, audio_resampled = load_audio(
88
+ audiopath=prompt_wav,
89
+ sampling_rate=16000,
90
+ )
91
+ _, _, audio_resampled24k = load_audio(
92
+ audiopath=prompt_wav,
93
+ sampling_rate=24000,
94
+ )
95
+
96
+ audio_resampled = audio_resampled.to(self.device)
97
+ audio_len = torch.tensor(
98
+ data=[audio_resampled.shape[1]], dtype=torch.long, requires_grad=False
99
+ )
100
+
101
+ # spk_embeddings:[1, 512]
102
+ prompt_tokens, token_lengths, spk_embeddings = self.speech_tokenizer(
103
+ audio_resampled, audio_len
104
+ )
105
+
106
+ prompt_acoustic_tokens, acoustic_llm_spk = self.acoustic_decoder.extract(
107
+ audio_resampled if isinstance(self.acoustic_decoder, TwoStageCodec) else audio_resampled24k,
108
+ audio_len, spk_embeddings.unsqueeze(0)
109
+ )
110
+
111
+ return prompt_tokens, spk_embeddings, prompt_acoustic_tokens, acoustic_llm_spk
112
+
113
+ def synthesize_base(
114
+ self,
115
+ prompt_semantic_tokens,
116
+ prompt_acoustic_tokens,
117
+ spk_semantic_llm,
118
+ spk_acoustic_llm,
119
+ prompt_text,
120
+ text,
121
+ lang="auto",
122
+ ):
123
+ """_summary_
124
+
125
+ Args:
126
+ prompt_wav (_type_): _description_
127
+ prompt_text (_type_): _description_
128
+ text (_type_): _description_
129
+ lang (str, optional): _description_. Defaults to "auto".
130
+
131
+ Returns:
132
+ _type_: _description_
133
+ """
134
+ if lang == "en":
135
+ text = prompt_text + " " + text
136
+ else:
137
+ text = prompt_text + text
138
+
139
+ print("---text:\n", text)
140
+
141
+ # Pre-process prompt tokens
142
+ # text to tokens
143
+ text_tokens = self.text_tokenizer.encode(
144
+ text=text,
145
+ add_special_tokens=False,
146
+ max_length=10**6,
147
+ truncation=False,
148
+ )
149
+ # print("---decode", [self.text_tokenizer.decode([c]) for c in text_tokens])
150
+
151
+ text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)
152
+
153
+ assert text_tokens.shape[-1] < 200
154
+ with torch.no_grad():
155
+ gpt_codes = self.semantic_llm.generate_ic(
156
+ cond_latents=spk_semantic_llm,
157
+ text_inputs=text_tokens,
158
+ prompt_tokens=prompt_semantic_tokens[:, :-3],
159
+ do_sample=True,
160
+ top_p=0.85,
161
+ top_k=30,
162
+ temperature=0.75,
163
+ num_return_sequences=7,
164
+ num_beams=1,
165
+ length_penalty=2.0,
166
+ repetition_penalty=5.0,
167
+ output_attentions=False,
168
+ )
169
+
170
+ seqs = []
171
+ for seq in gpt_codes:
172
+ index = (seq == self.EOS_TOKEN).nonzero(as_tuple=True)[0][0]
173
+ seq = seq[:index]
174
+ seqs.append(seq)
175
+
176
+ sorted_seqs = sorted(seqs, key=lambda i: len(i), reverse=False)
177
+ sorted_len = [len(l) for l in sorted_seqs]
178
+
179
+ gpt_codes = sorted_seqs[2].unsqueeze(0)
180
+
181
+ # Acoustic decoder
182
+ rec_wavs = self.acoustic_decoder(
183
+ gpt_codes, prompt_semantic_tokens, prompt_acoustic_tokens, spk_acoustic_llm
184
+ )
185
+
186
+ rec_wavs = rec_wavs.detach().cpu()
187
+ return rec_wavs
188
+
189
+ @torch.no_grad()
190
+ def synthesize(self, prompt_wav, prompt_text, text, lang="auto", use_tn=False):
191
+ """audio synthesize
192
+
193
+ Args:
194
+ prompt_wav (_type_): _description_
195
+ prompt_text (_type_): _description_
196
+ text (_type_): _description_
197
+ lang (str, optional): _description_. Defaults to "auto".
198
+
199
+ Returns:
200
+ _type_: _description_
201
+ """
202
+ assert lang in ["zh", "en", "auto"]
203
+ assert os.path.exists(prompt_wav)
204
+
205
+ (
206
+ prompt_semantic_tokens,
207
+ spk_embeddings,
208
+ prompt_acoustic_tokens,
209
+ spk_acoustic_llm,
210
+ ) = self.extract_spk_embeddings(prompt_wav=prompt_wav)
211
+
212
+ spk_embeddings = spk_embeddings.unsqueeze(0)
213
+ spk_semantic_llm = self.semantic_llm.reference_embedding(spk_embeddings)
214
+
215
+ # print("---prompt_semantic_tokens:\n", prompt_semantic_tokens)
216
+ # print("---spk_embeddings:\n", spk_embeddings)
217
+
218
+ # clean text
219
+ prompt_text = clean_text(prompt_text)
220
+ text = clean_text(text=text)
221
+
222
+ if use_tn:
223
+ substrings = text_split(text=text)
224
+
225
+ out_wavs = []
226
+ try:
227
+ for sub in substrings:
228
+
229
+ res_lang = self.text_normalizer.tn(text=sub)[1]
230
+
231
+ chunk = self.synthesize_base(
232
+ prompt_semantic_tokens=prompt_semantic_tokens,
233
+ prompt_acoustic_tokens=prompt_acoustic_tokens,
234
+ spk_semantic_llm=spk_semantic_llm,
235
+ spk_acoustic_llm=spk_acoustic_llm,
236
+ prompt_text=prompt_text,
237
+ text=sub,
238
+ lang=res_lang,
239
+ )
240
+
241
+ out_wavs.append(chunk)
242
+ out_wav = torch.concat(out_wavs, axis=-1)
243
+ return out_wav
244
+ except:
245
+ print('[ERROR] ', format_exc())
246
+ return None
247
+ else:
248
+ out_wavs = []
249
+ try:
250
+ res_lang = self.text_normalizer.tn(text=text)[1]
251
+
252
+ chunk = self.synthesize_base(
253
+ prompt_semantic_tokens=prompt_semantic_tokens,
254
+ prompt_acoustic_tokens=prompt_acoustic_tokens,
255
+ spk_semantic_llm=spk_semantic_llm,
256
+ spk_acoustic_llm=spk_acoustic_llm,
257
+ prompt_text=prompt_text,
258
+ text=text,
259
+ lang=res_lang,
260
+ )
261
+
262
+ out_wavs.append(chunk)
263
+ out_wav = torch.concat(out_wavs, axis=-1)
264
+ return out_wav
265
+ except:
266
+ return None
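For reference outside the Gradio demo, a minimal usage sketch of the class above (paths are placeholders; torchaudio is assumed only for saving and is not a dependency of this file):

import torchaudio
from fireredtts.models.fireredtts import FireRedTTS

tts = FireRedTTS(
    config_path="configs/config_24k_flow.json",
    pretrained_path="pretrained_models/FireRedTTS-1S/pretrained_models",
)
audio = tts.synthesize(
    prompt_wav="prompt.wav",                        # placeholder speaker prompt
    prompt_text="Transcript of the prompt audio.",
    text="Text to synthesize.",
    lang="auto",
    use_tn=True,
)
if audio is not None:                               # synthesize() returns None on failure
    torchaudio.save("output.wav", audio.cpu(), 24000)   # [1, T] tensor at 24 kHz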
fireredtts/models/token2audio.py ADDED
@@ -0,0 +1,108 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from fireredtts.modules.acoustic_llm import AcousticLLM
4
+ from fireredtts.modules.acoustic_codec import AcousticCodec
5
+ from fireredtts.modules.flowmatching import FlowToken2Mel
6
+ from fireredtts.modules.bigvgan import BigVGAN, MelExtractor
7
+
8
+
9
+ class TwoStageCodec(torch.nn.Module):
10
+ def __init__(self, config):
11
+ super().__init__()
12
+ self.acoustic_llm = AcousticLLM(**config["acoustic_llm"])
13
+ self.acoustic_codec = AcousticCodec(**config["acoustic_codec"])
14
+
15
+ def load_model(self, acoustic_llm_path, acoustic_codec_path):
16
+ self.acoustic_llm.load_state_dict(
17
+ torch.load(acoustic_llm_path, map_location="cpu"), strict=True
18
+ )
19
+ self.acoustic_codec.load_state_dict(
20
+ torch.load(acoustic_codec_path, map_location="cpu"), strict=True
21
+ )
22
+
23
+ @torch.inference_mode()
24
+ def forward(
25
+ self, semantic_token, prompt_semantic_token, prompt_acoustic_token, spk_gpt
26
+ ):
27
+ # print('Before: ', semantic_token.shape)
28
+ token_pred = torch.cat((prompt_semantic_token, semantic_token), dim=1)
29
+
30
+ # Fine LLM inference
31
+ token_pred = self.acoustic_llm.inference_speech(
32
+ speech_conditioning_latent=spk_gpt,
33
+ text_inputs=token_pred,
34
+ num_return_sequences=1,
35
+ input_tokens=prompt_acoustic_token,
36
+ )[0]
37
+
38
+ if isinstance(token_pred, (tuple, list)):
39
+ token_pred = [x.unsqueeze(0) for x in token_pred]
40
+ else:
41
+ token_pred = token_pred.unsqueeze(0)
42
+
43
+ acoustic_outputs = self.acoustic_codec.reconstruct_wav(token=token_pred)
44
+ wav = acoustic_outputs["wav_pred"].squeeze(1)
45
+
46
+ return wav
47
+
48
+ def extract(self, wavs, wav_lengths, spk):
49
+ if torch.cuda.is_available():
50
+ wavs = wavs.cuda()
51
+ cond_tok = self.acoustic_codec.extract_speech_tokens(wavs, wav_lengths)[
52
+ "token"
53
+ ][0]
54
+ spk_gpt = self.acoustic_llm.get_conditioning(spk)
55
+ return cond_tok, spk_gpt
56
+
57
+
58
+ """For FlowToken2Audio, keep interface consistant with TwoStageCodec to minimize code changes.
59
+ prompt_acoustic_token alias to prompt_mel
60
+ spk_gpt alias to spk_embeddings
61
+ """
62
+ class FlowToken2Audio(torch.nn.Module):
63
+ def __init__(self, config):
64
+ super().__init__()
65
+ self.flow = FlowToken2Mel(config['flow'])
66
+ self.bigvgan = BigVGAN(**config['bigvgan'])
67
+ self.mel_extractor = MelExtractor(**config['mel'])
68
+
69
+ def load_model(self, flow_path, bigvgan_path):
70
+ self.flow.load_state_dict(
71
+ torch.load(flow_path, map_location="cpu"), strict=True
72
+ )
73
+ self.bigvgan.load_state_dict(
74
+ torch.load(bigvgan_path, map_location="cpu")['generator'], strict=True
75
+ )
76
+ self.bigvgan.remove_weight_norm()
77
+
78
+ @torch.inference_mode()
79
+ def forward(
80
+ self, semantic_token, prompt_semantic_token, prompt_acoustic_token, spk_gpt
81
+ ):
82
+ # Align prompt token & prompt_mel
83
+ target_mel_length = prompt_semantic_token.shape[1] * 2
84
+ if target_mel_length > prompt_acoustic_token.shape[1]:
85
+ prompt_acoustic_token = F.pad(
86
+ prompt_acoustic_token, (0, 0, 0, target_mel_length-prompt_acoustic_token.shape[1]),
87
+ mode='constant', value=-11.5
88
+ )
89
+ elif target_mel_length < prompt_acoustic_token.shape[1]:
90
+ prompt_acoustic_token = prompt_acoustic_token[:, :target_mel_length]
91
+ # prompt_acoustic_token = F.interpolate(
92
+ # prompt_acoustic_token.transpose(1, 2),
93
+ # size=prompt_semantic_token.shape[1] * 2, mode='nearest'
94
+ # ).transpose(1, 2)
95
+ mel_pred = self.flow.inference(
96
+ prompt_token=prompt_semantic_token,
97
+ prompt_xvec=spk_gpt,
98
+ prompt_feat=prompt_acoustic_token,
99
+ token=semantic_token
100
+ )
101
+ wav = self.bigvgan(mel_pred.transpose(1, 2)).squeeze(1)
102
+ return wav
103
+
104
+ def extract(self, wavs, wav_lengths, spk):
105
+ mel = self.mel_extractor(wavs, 24000).transpose(1, 2)
106
+ if torch.cuda.is_available():
107
+ mel = mel.cuda()
108
+ return mel, spk.squeeze(0)
fireredtts/modules/__init__.py ADDED
File without changes
fireredtts/modules/acoustic_codec/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .bigcodec import BigCodec as AcousticCodec
fireredtts/modules/acoustic_codec/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
fireredtts/modules/acoustic_codec/alias_free_torch/act.py ADDED
@@ -0,0 +1,35 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(
10
+ self,
11
+ activation,
12
+ up_ratio: int = 2,
13
+ down_ratio: int = 2,
14
+ up_kernel_size: int = 12,
15
+ down_kernel_size: int = 12,
16
+ causal: bool = False,
17
+ ):
18
+ super().__init__()
19
+ self.up_ratio = up_ratio
20
+ self.down_ratio = down_ratio
21
+ self.act = activation
22
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
23
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
24
+ self.causal = causal
25
+
26
+ # x: [B,C,T]
27
+ def forward(self, x):
28
+ if self.causal:
29
+ x = self.act(x)
30
+ else:
31
+ x = self.upsample(x)
32
+ x = self.act(x)
33
+ x = self.downsample(x)
34
+
35
+ return x
fireredtts/modules/acoustic_codec/alias_free_torch/filter.py ADDED
@@ -0,0 +1,99 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if "sinc" in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(
21
+ x == 0,
22
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
23
+ torch.sin(math.pi * x) / math.pi / x,
24
+ )
25
+
26
+
27
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28
+ # https://adefossez.github.io/julius/julius/lowpass.html
29
+ # LICENSE is in incl_licenses directory.
30
+ def kaiser_sinc_filter1d(
31
+ cutoff, half_width, kernel_size
32
+ ): # return filter [1,1,kernel_size]
33
+ even = kernel_size % 2 == 0
34
+ half_size = kernel_size // 2
35
+
36
+ # For kaiser window
37
+ delta_f = 4 * half_width
38
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39
+ if A > 50.0:
40
+ beta = 0.1102 * (A - 8.7)
41
+ elif A >= 21.0:
42
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43
+ else:
44
+ beta = 0.0
45
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46
+
47
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48
+ if even:
49
+ time = torch.arange(-half_size, half_size) + 0.5
50
+ else:
51
+ time = torch.arange(kernel_size) - half_size
52
+ if cutoff == 0:
53
+ filter_ = torch.zeros_like(time)
54
+ else:
55
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
57
+ # of the constant component in the input signal.
58
+ filter_ /= filter_.sum()
59
+ filter = filter_.view(1, 1, kernel_size)
60
+
61
+ return filter
62
+
63
+
64
+ class LowPassFilter1d(nn.Module):
65
+ def __init__(
66
+ self,
67
+ cutoff=0.5,
68
+ half_width=0.6,
69
+ stride: int = 1,
70
+ padding: bool = True,
71
+ padding_mode: str = "replicate",
72
+ kernel_size: int = 12,
73
+ ):
74
+ # kernel_size should be even number for stylegan3 setup,
75
+ # in this implementation, odd number is also possible.
76
+ super().__init__()
77
+ if cutoff < -0.0:
78
+ raise ValueError("Minimum cutoff must be larger than zero.")
79
+ if cutoff > 0.5:
80
+ raise ValueError("A cutoff above 0.5 does not make sense.")
81
+ self.kernel_size = kernel_size
82
+ self.even = kernel_size % 2 == 0
83
+ self.pad_left = kernel_size // 2 - int(self.even)
84
+ self.pad_right = kernel_size // 2
85
+ self.stride = stride
86
+ self.padding = padding
87
+ self.padding_mode = padding_mode
88
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
89
+ self.register_buffer("filter", filter)
90
+
91
+ # input [B, C, T]
92
+ def forward(self, x):
93
+ _, C, _ = x.shape
94
+
95
+ if self.padding:
96
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
97
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
98
+
99
+ return out
fireredtts/modules/acoustic_codec/alias_free_torch/resample.py ADDED
@@ -0,0 +1,58 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = (
15
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
16
+ )
17
+ self.stride = ratio
18
+ self.pad = self.kernel_size // ratio - 1
19
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
20
+ self.pad_right = (
21
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
22
+ )
23
+ filter = kaiser_sinc_filter1d(
24
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
25
+ )
26
+ self.register_buffer("filter", filter)
27
+
28
+ # x: [B, C, T]
29
+ def forward(self, x):
30
+ _, C, _ = x.shape
31
+
32
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
33
+ x = self.ratio * F.conv_transpose1d(
34
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
35
+ )
36
+ x = x[..., self.pad_left : -self.pad_right]
37
+
38
+ return x
39
+
40
+
41
+ class DownSample1d(nn.Module):
42
+ def __init__(self, ratio=2, kernel_size=None):
43
+ super().__init__()
44
+ self.ratio = ratio
45
+ self.kernel_size = (
46
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
47
+ )
48
+ self.lowpass = LowPassFilter1d(
49
+ cutoff=0.5 / ratio,
50
+ half_width=0.6 / ratio,
51
+ stride=ratio,
52
+ kernel_size=self.kernel_size,
53
+ )
54
+
55
+ def forward(self, x):
56
+ xx = self.lowpass(x)
57
+
58
+ return xx
fireredtts/modules/acoustic_codec/bigcodec.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ from torch import sin, pow
3
+ from torch.nn import Parameter
4
+ from torch.nn.utils import spectral_norm, weight_norm
5
+
6
+ import math
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torchaudio
12
+ import typing as tp
13
+ import warnings
14
+
15
+ from .alias_free_torch import *
16
+ from .vector_quantization import VectorQuantization
17
+
18
+
19
+ CONV_NORMALIZATIONS = frozenset(
20
+ [
21
+ "none",
22
+ "weight_norm",
23
+ "spectral_norm",
24
+ "time_layer_norm",
25
+ "layer_norm",
26
+ "time_group_norm",
27
+ ]
28
+ )
29
+
30
+
31
+ def init_weights(m):
32
+ if isinstance(m, nn.Conv1d):
33
+ nn.init.trunc_normal_(m.weight, std=0.02)
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+
37
+ def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
38
+ assert norm in CONV_NORMALIZATIONS
39
+ if norm == "weight_norm":
40
+ return weight_norm(module)
41
+ elif norm == "spectral_norm":
42
+ return spectral_norm(module)
43
+ else:
44
+ return module
45
+
46
+
47
+ def get_norm_module(
48
+ module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
49
+ ) -> nn.Module:
50
+ assert norm in CONV_NORMALIZATIONS
51
+ if norm == "time_group_norm":
52
+ if causal:
53
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
54
+ assert isinstance(module, nn.modules.conv._ConvNd)
55
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
56
+ else:
57
+ return nn.Identity()
58
+
59
+
60
+ def get_extra_padding_for_conv1d(
61
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
62
+ ) -> int:
63
+ length = x.shape[-1]
64
+ n_frames = (length - kernel_size + padding_total) / stride + 1
65
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
66
+ return ideal_length - length
67
+
68
+
69
+ def pad_for_conv1d(
70
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
71
+ ):
72
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
73
+ return F.pad(x, (0, extra_padding))
74
+
75
+
76
+ def pad1d(
77
+ x: torch.Tensor,
78
+ paddings: tp.Tuple[int, int],
79
+ mode: str = "zero",
80
+ value: float = 0.0,
81
+ ):
82
+ length = x.shape[-1]
83
+ padding_left, padding_right = paddings
84
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
85
+ if mode == "reflect":
86
+ max_pad = max(padding_left, padding_right)
87
+ extra_pad = 0
88
+ if length <= max_pad:
89
+ extra_pad = max_pad - length + 1
90
+ x = F.pad(x, (0, extra_pad))
91
+ padded = F.pad(x, paddings, mode, value)
92
+ end = padded.shape[-1] - extra_pad
93
+ return padded[..., :end]
94
+ else:
95
+ return F.pad(x, paddings, mode, value)
96
+
97
+
98
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
99
+ padding_left, padding_right = paddings
100
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
101
+ assert (padding_left + padding_right) <= x.shape[-1]
102
+ end = x.shape[-1] - padding_right
103
+ return x[..., padding_left:end]
104
+
105
+
106
+ class NormConv1d(nn.Module):
107
+
108
+ def __init__(
109
+ self,
110
+ *args,
111
+ causal: bool = False,
112
+ norm: str = "none",
113
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
114
+ **kwargs,
115
+ ):
116
+ super().__init__()
117
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
118
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
119
+ self.norm_type = norm
120
+
121
+ def forward(self, x):
122
+ x = self.conv(x)
123
+ x = self.norm(x)
124
+ return x
125
+
126
+
127
+ class NormConvTranspose1d(nn.Module):
128
+
129
+ def __init__(
130
+ self,
131
+ *args,
132
+ causal: bool = False,
133
+ norm: str = "none",
134
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
135
+ **kwargs,
136
+ ):
137
+ super().__init__()
138
+ self.convtr = apply_parametrization_norm(
139
+ nn.ConvTranspose1d(*args, **kwargs), norm
140
+ )
141
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
142
+ self.norm_type = norm
143
+
144
+ def forward(self, x):
145
+ x = self.convtr(x)
146
+ x = self.norm(x)
147
+ return x
148
+
149
+
150
+ class SConv1d(nn.Module):
151
+
152
+ def __init__(
153
+ self,
154
+ in_channels: int,
155
+ out_channels: int,
156
+ kernel_size: int,
157
+ stride: int = 1,
158
+ dilation: int = 1,
159
+ groups: int = 1,
160
+ bias: bool = True,
161
+ causal: bool = False,
162
+ norm: str = "none",
163
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
164
+ pad_mode: str = "reflect",
165
+ **kwargs,
166
+ ):
167
+ super().__init__()
168
+ # warn user on unusual setup between dilation and stride
169
+ if stride > 1 and dilation > 1:
170
+ warnings.warn(
171
+ "SConv1d has been initialized with stride > 1 and dilation > 1"
172
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
173
+ )
174
+ self.conv = NormConv1d(
175
+ in_channels,
176
+ out_channels,
177
+ kernel_size,
178
+ stride,
179
+ dilation=dilation,
180
+ groups=groups,
181
+ bias=bias,
182
+ causal=causal,
183
+ norm=norm,
184
+ norm_kwargs=norm_kwargs,
185
+ )
186
+ self.causal = causal
187
+ self.pad_mode = pad_mode
188
+
189
+ def forward(self, x):
190
+ B, C, T = x.shape
191
+ kernel_size = self.conv.conv.kernel_size[0]
192
+ stride = self.conv.conv.stride[0]
193
+ dilation = self.conv.conv.dilation[0]
194
+ kernel_size = (
195
+ kernel_size - 1
196
+ ) * dilation + 1 # effective kernel size with dilations
197
+ padding_total = kernel_size - stride
198
+ extra_padding = get_extra_padding_for_conv1d(
199
+ x, kernel_size, stride, padding_total
200
+ )
201
+ if self.causal:
202
+ # Left padding for causal
203
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
204
+ else:
205
+ # Asymmetric padding required for odd strides
206
+ padding_right = padding_total // 2
207
+ padding_left = padding_total - padding_right
208
+ x = pad1d(
209
+ x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
210
+ )
211
+ return self.conv(x)
212
+
213
+
214
+ class SConvTranspose1d(nn.Module):
215
+
216
+ def __init__(
217
+ self,
218
+ in_channels: int,
219
+ out_channels: int,
220
+ kernel_size: int,
221
+ stride: int = 1,
222
+ causal: bool = False,
223
+ norm: str = "none",
224
+ trim_right_ratio: float = 1.0,
225
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
226
+ **kwargs,
227
+ ):
228
+ super().__init__()
229
+ self.convtr = NormConvTranspose1d(
230
+ in_channels,
231
+ out_channels,
232
+ kernel_size,
233
+ stride,
234
+ causal=causal,
235
+ norm=norm,
236
+ norm_kwargs=norm_kwargs,
237
+ )
238
+ self.causal = causal
239
+ self.trim_right_ratio = trim_right_ratio
240
+ assert (
241
+ self.causal or self.trim_right_ratio == 1.0
242
+ ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
243
+ assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
244
+
245
+ def forward(self, x):
246
+ kernel_size = self.convtr.convtr.kernel_size[0]
247
+ stride = self.convtr.convtr.stride[0]
248
+ padding_total = kernel_size - stride
249
+
250
+ y = self.convtr(x)
251
+
252
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
253
+ # removed at the very end, when keeping only the right length for the output,
254
+ # as removing it here would require also passing the length at the matching layer
255
+ # in the encoder.
256
+ if self.causal:
257
+ # Trim the padding on the right according to the specified ratio
258
+ # if trim_right_ratio = 1.0, trim everything from right
259
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
260
+ padding_left = padding_total - padding_right
261
+ y = unpad1d(y, (padding_left, padding_right))
262
+ else:
263
+ # Asymmetric padding required for odd strides
264
+ padding_right = padding_total // 2
265
+ padding_left = padding_total - padding_right
266
+ y = unpad1d(y, (padding_left, padding_right))
267
+ return y
268
+
269
+
270
+ def WNConv1d(*args, **kwargs):
271
+ if kwargs.get("causal", False):
272
+ kwargs["norm"] = "weight_norm"
273
+ conv1d = SConv1d(*args, **kwargs)
274
+ else:
275
+ kwargs.pop("causal")
276
+ conv1d = weight_norm(nn.Conv1d(*args, **kwargs))
277
+ return conv1d
278
+
279
+
280
+ def WNConvTranspose1d(*args, **kwargs):
281
+ if kwargs.get("causal", False):
282
+ kwargs["norm"] = "weight_norm"
283
+ transposed_conv1d = SConvTranspose1d(*args, **kwargs)
284
+ else:
285
+ kwargs.pop("causal")
286
+ transposed_conv1d = weight_norm(nn.ConvTranspose1d(*args, **kwargs))
287
+ return transposed_conv1d
288
+
289
+
290
+ class SnakeBeta(nn.Module):
291
+
292
+ def __init__(
293
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
294
+ ):
295
+ super(SnakeBeta, self).__init__()
296
+ self.in_features = in_features
297
+
298
+ # initialize alpha
299
+ self.alpha_logscale = alpha_logscale
300
+ if self.alpha_logscale: # log scale alphas initialized to zeros
301
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
302
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
303
+ else: # linear scale alphas initialized to ones
304
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
305
+ self.beta = Parameter(torch.ones(in_features) * alpha)
306
+
307
+ self.alpha.requires_grad = alpha_trainable
308
+ self.beta.requires_grad = alpha_trainable
309
+
310
+ self.no_div_by_zero = 0.000000001
311
+
312
+ def forward(self, x):
313
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
314
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
315
+ if self.alpha_logscale:
316
+ alpha = torch.exp(alpha)
317
+ beta = torch.exp(beta)
318
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
319
+
320
+ return x
321
+
322
+
323
+ class ResidualUnit(nn.Module):
324
+
325
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
326
+ super().__init__()
327
+ pad = ((7 - 1) * dilation) // 2
328
+ self.block = nn.Sequential(
329
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True), causal=causal),
330
+ WNConv1d(
331
+ dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal
332
+ ),
333
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True), causal=causal),
334
+ WNConv1d(dim, dim, kernel_size=1, causal=causal),
335
+ )
336
+
337
+ def forward(self, x):
338
+ return x + self.block(x)
339
+
340
+
341
+ class EncoderBlock(nn.Module):
342
+
343
+ def __init__(
344
+ self, dim: int = 16, stride: int = 1, dilations=(1, 3, 9), causal: bool = False
345
+ ):
346
+ super().__init__()
347
+ runits = [ResidualUnit(dim // 2, dilation=d, causal=causal) for d in dilations]
348
+ self.block = nn.Sequential(
349
+ *runits,
350
+ Activation1d(
351
+ activation=SnakeBeta(dim // 2, alpha_logscale=True), causal=causal
352
+ ),
353
+ WNConv1d(
354
+ dim // 2,
355
+ dim,
356
+ kernel_size=2 * stride,
357
+ stride=stride,
358
+ padding=stride // 2 + stride % 2,
359
+ causal=causal,
360
+ ),
361
+ )
362
+
363
+ def forward(self, x):
364
+ return self.block(x)
365
+
366
+
367
+ class DecoderBlock(nn.Module):
368
+
369
+ def __init__(
370
+ self,
371
+ input_dim: int = 16,
372
+ output_dim: int = 8,
373
+ stride: int = 1,
374
+ dilations=(1, 3, 9),
375
+ causal: bool = False,
376
+ ):
377
+ super().__init__()
378
+ self.block = nn.Sequential(
379
+ Activation1d(
380
+ activation=SnakeBeta(input_dim, alpha_logscale=True), causal=causal
381
+ ),
382
+ WNConvTranspose1d(
383
+ input_dim,
384
+ output_dim,
385
+ kernel_size=2 * stride,
386
+ stride=stride,
387
+ padding=stride // 2 + stride % 2,
388
+ output_padding=stride % 2,
389
+ causal=causal,
390
+ ),
391
+ )
392
+ self.block.extend(
393
+ [ResidualUnit(output_dim, dilation=d, causal=causal) for d in dilations]
394
+ )
395
+
396
+ def forward(self, x):
397
+ return self.block(x)
398
+
399
+
400
+ class ResLSTM(nn.Module):
401
+
402
+ def __init__(
403
+ self,
404
+ dimension: int,
405
+ num_layers: int = 2,
406
+ bidirectional: bool = False,
407
+ skip: bool = True,
408
+ ):
409
+ super().__init__()
410
+ self.skip = skip
411
+ self.lstm = nn.LSTM(
412
+ dimension,
413
+ dimension if not bidirectional else dimension // 2,
414
+ num_layers,
415
+ batch_first=True,
416
+ bidirectional=bidirectional,
417
+ )
418
+
419
+ def forward(self, x):
420
+ x = rearrange(x, "b f t -> b t f")
421
+ y, _ = self.lstm(x)
422
+ if self.skip:
423
+ y = y + x
424
+ y = rearrange(y, "b t f -> b f t")
425
+ return y
426
+
427
+
428
+ class Resampler(nn.Module):
429
+
430
+ def __init__(self, source_sr=24000, target_sr=24000):
431
+ super().__init__()
432
+ self.source_sr = source_sr
433
+ self.target_sr = target_sr
434
+
435
+ def forward(self, wav, wav_length):
436
+ if self.source_sr != self.target_sr:
437
+ wav = torchaudio.functional.resample(wav, self.source_sr, self.target_sr)
438
+ wav_length = (wav_length * (self.source_sr / self.target_sr)).int()
439
+ return wav, wav_length
440
+
441
+
442
+ class CodecEncoder(nn.Module):
443
+
444
+ def __init__(
445
+ self,
446
+ ngf=48,
447
+ use_rnn=True,
448
+ rnn_bidirectional=False,
449
+ rnn_num_layers=2,
450
+ up_ratios=(2, 2, 2, 5, 5),
451
+ dilations=(1, 3, 9),
452
+ out_channels=1024,
453
+ causal=False,
454
+ ):
455
+ super().__init__()
456
+ self.hop_length = np.prod(up_ratios)
457
+ self.ngf = ngf
458
+ self.up_ratios = up_ratios
459
+
460
+ # Create first convolution
461
+ d_model = ngf
462
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, causal=causal)]
463
+
464
+ # Create EncoderBlocks that double channels as they downsample by `stride`
465
+ for i, stride in enumerate(up_ratios):
466
+ d_model *= 2
467
+ self.block += [
468
+ EncoderBlock(d_model, stride=stride, dilations=dilations, causal=causal)
469
+ ]
470
+ # RNN
471
+ if use_rnn:
472
+ self.block += [
473
+ ResLSTM(
474
+ d_model, num_layers=rnn_num_layers, bidirectional=rnn_bidirectional
475
+ )
476
+ ]
477
+ # Create last convolution
478
+ self.block += [
479
+ Activation1d(
480
+ activation=SnakeBeta(d_model, alpha_logscale=True), causal=causal
481
+ ),
482
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1, causal=causal),
483
+ ]
484
+
485
+ # Wrap black into nn.Sequential
486
+ self.block = nn.Sequential(*self.block)
487
+ self.enc_dim = d_model
488
+
489
+ self.reset_parameters()
490
+
491
+ def forward(self, x):
492
+ out = self.block(x)
493
+ return out
494
+
495
+ def remove_weight_norm(self):
496
+ def _remove_weight_norm(m):
497
+ try:
498
+ torch.nn.utils.remove_weight_norm(m)
499
+ except ValueError: # this module didn't have weight norm
500
+ return
501
+
502
+ self.apply(_remove_weight_norm)
503
+
504
+ def apply_weight_norm(self):
505
+ def _apply_weight_norm(m):
506
+ if isinstance(m, nn.Conv1d):
507
+ torch.nn.utils.weight_norm(m)
508
+
509
+ self.apply(_apply_weight_norm)
510
+
511
+ def reset_parameters(self):
512
+ self.apply(init_weights)
513
+
514
+
515
+ class CodecDecoder(nn.Module):
516
+
517
+ def __init__(
518
+ self,
519
+ in_channels=1024,
520
+ upsample_initial_channel=1536,
521
+ ngf=48,
522
+ use_rnn=True,
523
+ rnn_bidirectional=False,
524
+ rnn_num_layers=2,
525
+ up_ratios=(5, 5, 2, 2, 2),
526
+ dilations=(1, 3, 9),
527
+ causal=False,
528
+ delay=0,
529
+ ):
530
+ super().__init__()
531
+ self.hop_length = np.prod(up_ratios)
532
+ self.ngf = ngf
533
+ self.up_ratios = up_ratios
534
+ self.delay = delay
535
+
536
+ channels = upsample_initial_channel
537
+ layers = [
538
+ WNConv1d(in_channels, channels, kernel_size=7, padding=3, causal=causal)
539
+ ]
540
+
541
+ if use_rnn:
542
+ layers += [
543
+ ResLSTM(
544
+ channels, num_layers=rnn_num_layers, bidirectional=rnn_bidirectional
545
+ )
546
+ ]
547
+
548
+ for i, stride in enumerate(up_ratios):
549
+ input_dim = channels // 2**i
550
+ output_dim = channels // 2 ** (i + 1)
551
+ layers += [
552
+ DecoderBlock(input_dim, output_dim, stride, dilations, causal=causal)
553
+ ]
554
+
555
+ layers += [
556
+ Activation1d(
557
+ activation=SnakeBeta(output_dim, alpha_logscale=True), causal=causal
558
+ ),
559
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3, causal=causal),
560
+ nn.Tanh(),
561
+ ]
562
+
563
+ self.model = nn.Sequential(*layers)
564
+ self.reset_parameters()
565
+
566
+ def forward(self, x):
567
+ # Time delay
568
+ if self.delay > 0:
569
+ x = F.pad(x, (0, self.delay), mode="constant", value=0)
570
+
571
+ x = self.model(x)
572
+
573
+ # De-delay
574
+ if self.delay > 0:
575
+ x = x[..., self.delay :]
576
+
577
+ return x
578
+
579
+ def remove_weight_norm(self):
580
+ def _remove_weight_norm(m):
581
+ try:
582
+ torch.nn.utils.remove_weight_norm(m)
583
+ except ValueError: # this module didn't have weight norm
584
+ return
585
+
586
+ self.apply(_remove_weight_norm)
587
+
588
+ def apply_weight_norm(self):
589
+ def _apply_weight_norm(m):
590
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
591
+ torch.nn.utils.weight_norm(m)
592
+
593
+ self.apply(_apply_weight_norm)
594
+
595
+ def reset_parameters(self):
596
+ self.apply(init_weights)
597
+
598
+
599
+ class BigCodec(nn.Module):
600
+
601
+ def __init__(
602
+ self,
603
+ n_model_size: int,
604
+ encoder_config: dict,
605
+ decoder_config: dict,
606
+ vq_config: dict,
607
+ resampler_config: dict = None,
608
+ ):
609
+ super(BigCodec, self).__init__()
610
+ self.n_model_size = n_model_size
611
+
612
+ self.encoder = CodecEncoder(out_channels=n_model_size, **encoder_config)
613
+ self.decoder = CodecDecoder(in_channels=n_model_size, **decoder_config)
614
+ self.quantizer = VectorQuantization(n_model_size, **vq_config)
615
+
616
+ # Optional modules
617
+ if resampler_config:
618
+ self.resampler = Resampler(**resampler_config)
619
+
620
+ def forward(
621
+ self, wav, wav_length=None, enable_vq=True, decode=True, update_codebook=True
622
+ ):
623
+ # Preprocess wav
624
+ if len(wav.shape) == 2:
625
+ wav = wav.unsqueeze(1)
626
+ if wav_length is None:
627
+ wav_length = torch.full([wav.shape[0]], max(wav.shape)).to(wav.device)
628
+
629
+ # (Optional) Resample
630
+ processed_wav, processed_wav_length = wav, wav_length
631
+ if hasattr(self, "resampler"):
632
+ processed_wav, processed_wav_length = self.resampler(
633
+ processed_wav, processed_wav_length
634
+ )
635
+
636
+ # Update VQ parameters
637
+ quant_length = torch.ceil(processed_wav_length / self.encoder.hop_length).int()
638
+ update_codebook = update_codebook and self.training
639
+
640
+ # Encode
641
+ encoder_outputs = self.encoder(processed_wav)
642
+
643
+ # Quantize
644
+ quant, diff, embed_ind = self.quantizer(
645
+ encoder_outputs.transpose(1, 2),
646
+ quant_length.clamp(max=encoder_outputs.shape[2]),
647
+ enable_vq=enable_vq,
648
+ update_codebook=update_codebook,
649
+ )
650
+
651
+ if decode:
652
+ # Decode
653
+ decoder_outputs = self.decoder(quant.transpose(1, 2))
654
+ else:
655
+ decoder_outputs = None
656
+
657
+ output_dict = {
658
+ "quant": quant,
659
+ "token": embed_ind,
660
+ "token_length": quant_length,
661
+ "encoder_diffs": diff,
662
+ "wav_pred": decoder_outputs,
663
+ }
664
+ return output_dict
665
+
666
+ @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
667
+ def extract_speech_tokens(
668
+ self, wav, wav_length, serialize=True, extract_spk=True, shuffle=False
669
+ ):
670
+ output_dict = self.forward(wav, wav_length, enable_vq=True, decode=False)
671
+ token_seqs, token_length = [output_dict["token"]], [output_dict["token_length"]]
672
+ output_dict.update(
673
+ {
674
+ "token": token_seqs,
675
+ "token_length": token_length,
676
+ }
677
+ )
678
+
679
+ return output_dict
680
+
681
+ @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
682
+ def reconstruct_wav(self, token=None, quant=None, spk=None):
683
+ if token is not None:
684
+ # De-tokenization
685
+ quant = self.quantizer.decode(token)
686
+
687
+ # Speaker embedding
688
+ if hasattr(self, "global_encoder"):
689
+ quant = quant + spk.unsqueeze(2)
690
+ else:
691
+ assert quant is not None
692
+
693
+ # Decode
694
+ wav_pred = self.decoder(quant)
695
+
696
+ return {
697
+ "wav_pred": wav_pred,
698
+ }
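For reference, a minimal standalone sketch of driving the codec above end to end (tokenize, then reconstruct). All sizes and config values here are illustrative assumptions, not the shipped configuration, which presumably lives in configs/config_24k.json.

import torch

# Illustrative construction; the codebook_size in vq_config is an assumed value.
codec = BigCodec(
    n_model_size=1024,
    encoder_config={},
    decoder_config={},
    vq_config={"codebook_size": 8192},
)
codec.eval()

wav = torch.randn(1, 24000)                  # one second of 24 kHz audio
wav_length = torch.tensor([24000])
with torch.no_grad():
    out = codec.extract_speech_tokens(wav, wav_length)
    tokens = out["token"][0]                 # (1, T) discrete speech tokens
    recon = codec.reconstruct_wav(token=tokens)["wav_pred"]
print(tokens.shape, recon.shape)             # hop_length = prod(up_ratios) = 200, so T = 120 here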
fireredtts/modules/acoustic_codec/vector_quantization.py ADDED
@@ -0,0 +1,580 @@
1
+ from einops import rearrange, repeat
2
+ from torch import nn
3
+ from typing import Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import typing as tp
8
+ import numpy as np
9
+ import warnings
10
+
11
+
12
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
13
+ return val if val is not None else d
14
+
15
+
16
+ def flatten(x, x_len):
17
+ x_f = x.view(-1, *x.shape[2:])
18
+ return x_f
19
+
20
+
21
+ def ema_inplace(moving_avg, new, decay):
22
+ if isinstance(decay, torch.Tensor):
23
+ moving_avg.data.mul_(decay).add_(new * (1 - decay))
24
+ else:
25
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
26
+
27
+
28
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
29
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
30
+
31
+
32
+ def uniform_init(*shape: int):
33
+ t = torch.empty(shape)
34
+ nn.init.kaiming_uniform_(t)
35
+ return t
36
+
37
+
38
+ def sample_vectors(samples, num: int):
39
+ num_samples, device = samples.shape[0], samples.device
40
+
41
+ if num_samples >= num:
42
+ indices = torch.randperm(num_samples, device=device)[:num]
43
+ else:
44
+ indices = torch.randint(0, num_samples, (num,), device=device)
45
+
46
+ return samples[indices]
47
+
48
+
49
+ class EuclideanCodebook(nn.Module):
50
+
51
+ def __init__(
52
+ self,
53
+ dim: int,
54
+ codebook_size: int,
55
+ decay: float = 0.99,
56
+ epsilon: float = 1e-5,
57
+ threshold_ema_dead_code: float = 1.0,
58
+ n_cache_iters: int = 1,
59
+ ):
60
+ super().__init__()
61
+ self.decay = decay
62
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init
63
+ embed = init_fn(codebook_size, dim)
64
+
65
+ self.codebook_size = codebook_size
66
+
67
+ self.epsilon = epsilon
68
+ self.threshold_ema_dead_code = threshold_ema_dead_code
69
+ self.update_iter = 0
70
+
71
+ self.n_cache_iters = n_cache_iters
72
+ self.cache_vectors = []
73
+ self.cache_indices = []
74
+
75
+ if isinstance(self.decay, (tuple, list)):
76
+ self.embed_avg_cache = []
77
+ self.register_buffer("diff_avg_long", torch.zeros(codebook_size) + 1e-5)
78
+ self.register_buffer("diff_avg_short", torch.zeros(codebook_size) + 1e-5)
79
+ self.register_buffer("inited", torch.Tensor([True]))
80
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
81
+ self.register_buffer("embed", embed)
82
+ self.register_buffer("embed_avg", embed.clone())
83
+
84
+ @torch.jit.ignore
85
+ def init_embed_(self, data):
86
+ if self.inited:
87
+ return
88
+
89
+ def replace_(self, samples, mask, dists=None):
90
+ reset_cluster_size = min(
91
+ self.threshold_ema_dead_code + 1, self.threshold_ema_dead_code * 1.1
92
+ )
93
+
94
+ modified_codebook = torch.where(
95
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
96
+ )
97
+ modified_codebook_avg = torch.where(
98
+ mask[..., None], modified_codebook * reset_cluster_size, self.embed_avg
99
+ )
100
+ modified_cluster_size = torch.where(
101
+ mask,
102
+ torch.full_like(self.cluster_size, reset_cluster_size),
103
+ self.cluster_size,
104
+ )
105
+
106
+ self.embed.data.copy_(modified_codebook)
107
+ self.embed_avg.data.copy_(modified_codebook_avg)
108
+ self.cluster_size.data.copy_(modified_cluster_size)
109
+
110
+ def expire_codes_(self, batch_samples, dists=None):
111
+ self.update_iter += 1
112
+ if self.threshold_ema_dead_code == 0:
113
+ return
114
+ elif self.threshold_ema_dead_code < 1:
115
+ threshold_ema_dead_code = (
116
+ sum(self.cluster_size) * self.threshold_ema_dead_code
117
+ )
118
+ else:
119
+ threshold_ema_dead_code = self.threshold_ema_dead_code
120
+
121
+ expired_codes = self.cluster_size < threshold_ema_dead_code
122
+ if not torch.any(expired_codes):
123
+ return
124
+
125
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
126
+ self.replace_(batch_samples, mask=expired_codes, dists=dists)
127
+
128
+ def preprocess(self, x):
129
+ x = rearrange(x, "... d -> (...) d")
130
+ return x
131
+
132
+ def quantize(self, x):
133
+ embed = self.embed.t()
134
+ dist = -(
135
+ x.pow(2).sum(1, keepdim=True)
136
+ - 2 * x @ embed
137
+ + embed.pow(2).sum(0, keepdim=True)
138
+ )
139
+ embed_ind = dist.max(dim=-1).indices
140
+
141
+ return embed_ind, dist
142
+
143
+ def postprocess_emb(self, embed_ind, shape):
144
+ return embed_ind.view(*shape[:-1])
145
+
146
+ def dequantize(self, embed_ind):
147
+ quantize = F.embedding(embed_ind, self.embed)
148
+ return quantize
149
+
150
+ def encode(self, x):
151
+ shape = x.shape
152
+ # pre-process
153
+ x = self.preprocess(x)
154
+ # quantize
155
+ embed_ind, dist = self.quantize(x)
156
+ # post-process
157
+ embed_ind = self.postprocess_emb(embed_ind, shape)
158
+
159
+ return embed_ind, dist
160
+
161
+ def decode(self, embed_ind):
162
+ quantize = self.dequantize(embed_ind)
163
+ return quantize
164
+
165
+ def forward(self, x, x_len, enable_vq=True, update_codebook=True, masking=False):
166
+ x_org, shape, dtype = x, x.shape, x.dtype
167
+
168
+ x = self.preprocess(x)
169
+
170
+ embed_ind, dist = self.quantize(x)
171
+ embed_ind = self.postprocess_emb(embed_ind, shape)
172
+ dist = dist.view(shape[0], shape[1], dist.shape[-1])
173
+
174
+ quantize = self.dequantize(embed_ind)
175
+
176
+ if self.training and update_codebook:
177
+ if enable_vq:
178
+ quantize = x_org + (quantize - x_org).detach()
179
+ else:
180
+ quantize = x_org
181
+
182
+ # Get flatten embedding indices and distances
183
+ if masking:
184
+ x_f = torch.cat(
185
+ [e[: int(e_len)] for e, e_len in zip(x_org, x_len)], dim=0
186
+ )
187
+ embed_ind_f = torch.cat(
188
+ [e[: int(e_len)] for e, e_len in zip(embed_ind, x_len)], dim=0
189
+ )
190
+ dist_f = torch.cat(
191
+ [e[: int(e_len)] for e, e_len in zip(dist, x_len)], dim=0
192
+ )
193
+ q_f = torch.cat(
194
+ [e[: int(e_len)] for e, e_len in zip(quantize.detach(), x_len)],
195
+ dim=0,
196
+ )
197
+ commit_loss = F.mse_loss(q_f, x_f)
198
+ else:
199
+ x_f = x_org.view(-1, x_org.shape[-1]).contiguous()
200
+ embed_ind_f = embed_ind.view(-1).contiguous()
201
+ dist_f = dist.view(-1).contiguous()
202
+ commit_loss = F.mse_loss(quantize.detach(), x_org)
203
+ self.init_embed_(x_f)
204
+
205
+ # We do the expiry of code at that point as buffers are in sync
206
+ # and all the workers will take the same decision.
207
+ self.expire_codes_(x_f, dist_f)
208
+
209
+ # Calculate codebook statistics
210
+ embed_onehot = F.one_hot(embed_ind_f, self.codebook_size).type(dtype)
211
+ embed_onehot_sum = embed_onehot.sum(0)
212
+ embed_sum = x_f.t() @ embed_onehot
213
+
214
+ # EMA updating
215
+ ema_inplace(self.cluster_size, embed_onehot_sum, self.decay)
216
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
217
+
218
+ cluster_size = (
219
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
220
+ * self.cluster_size.sum()
221
+ )
222
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
223
+ self.embed.data.copy_(embed_normalized)
224
+ else:
225
+ commit_loss = torch.tensor(
226
+ 0.0, device=quantize.device, requires_grad=self.training
227
+ )
228
+
229
+ return quantize, commit_loss, embed_ind
230
+
231
+
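The codebook above is updated with exponential moving averages rather than gradients. A standalone restatement of that update, written out for clarity with illustrative sizes and without the dead-code expiry step:

import torch
import torch.nn.functional as F

decay, eps, K, D = 0.99, 1e-5, 4, 2
embed = torch.randn(K, D)              # current codebook
cluster_size = torch.zeros(K)          # EMA of per-code usage counts
embed_avg = embed.clone()              # EMA of summed inputs assigned to each code

x = torch.randn(16, D)                 # flattened encoder frames
assign = (x.unsqueeze(1) - embed).pow(2).sum(-1).argmin(-1)
onehot = F.one_hot(assign, K).float()

cluster_size.mul_(decay).add_(onehot.sum(0), alpha=1 - decay)
embed_avg.mul_(decay).add_((x.t() @ onehot).t(), alpha=1 - decay)

# Laplace-smoothed counts keep rarely used codes from dividing by ~0.
smoothed = (cluster_size + eps) / (cluster_size.sum() + K * eps) * cluster_size.sum()
embed = embed_avg / smoothed.unsqueeze(1)   # new codebook = EMA mean of its assigned frames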
232
+ class MultiHeadEuclideanCodebook(nn.Module):
233
+
234
+ def __init__(
235
+ self,
236
+ dim: Union[int, list],
237
+ codebook_size: list,
238
+ n_groups: int = 1,
239
+ dropout_rate_per_group: float = 0,
240
+ ordered: bool = False,
241
+ ordered_axis: str = "sequence",
242
+ method: str = "product",
243
+ **kwargs,
244
+ ):
245
+ super().__init__()
246
+ self.codebook_sizes = codebook_size
247
+ self.codebook_dims = dim
248
+ self.n_groups = n_groups
249
+ self.n_heads_per_group = len(codebook_size) // n_groups
250
+ self.dropout_rate_per_group = dropout_rate_per_group
251
+ self.ordered = ordered
252
+ self.ordered_axis = ordered_axis
253
+ self.method = method
254
+ assert len(codebook_size) % n_groups == 0
255
+
256
+ self.codebooks = nn.ModuleList()
257
+ dim = self.codebook_dims
258
+ for i, size in enumerate(self.codebook_sizes):
259
+ if isinstance(self.codebook_dims, list):
260
+ dim = (
261
+ self.codebook_dims[i]
262
+ if method == "product"
263
+ else sum(self.codebook_dims)
264
+ )
265
+ self.codebooks.append(EuclideanCodebook(dim, size, **kwargs))
266
+
267
+ def decode(self, embed_ind):
268
+ if self.n_groups == 1 or len(embed_ind.shape) == 2:
269
+ embed_ind = embed_ind.unsqueeze(-1)
270
+
271
+ actual_n_groups = embed_ind.shape[-1]
272
+ if actual_n_groups < self.n_groups:
273
+ print(
274
+ f"The actual number of heads ({actual_n_groups}) is smaller than the pre-designed ({self.n_groups})!"
275
+ )
276
+ embed_ind = F.pad(
277
+ embed_ind, (0, self.n_groups - actual_n_groups), "replicate"
278
+ )
279
+ # assert embed_ind.shape[-1] == self.n_groups
280
+
281
+ index_heads, codebook_heads, scale_heads = zip(
282
+ *[
283
+ (
284
+ embed_ind[..., i // self.n_heads_per_group],
285
+ self.codebooks[i : i + self.n_heads_per_group],
286
+ self.codebook_sizes[i : i + self.n_heads_per_group],
287
+ )
288
+ for i in range(0, len(self.codebook_sizes), self.n_heads_per_group)
289
+ ]
290
+ )
291
+
292
+ quantize_heads, quantize_groups = [], []
293
+ for i in range(self.n_groups):
294
+ embed_ind, codebooks, scales = (
295
+ index_heads[i],
296
+ codebook_heads[i],
297
+ scale_heads[i],
298
+ )
299
+
300
+ inv_scales = list(torch.tensor([1] + scales[:-1]).cumprod(dim=0))[::-1]
301
+ inv_quantizes = []
302
+ for codebook, scale in zip(codebooks[::-1], inv_scales):
303
+ index, embed_ind = embed_ind // scale, embed_ind % scale
304
+ quantize = codebook.dequantize(index)
305
+ inv_quantizes.append(quantize)
306
+ quantizes = inv_quantizes[::-1]
307
+ group_embeddings = torch.cat(quantizes, dim=-1)
308
+ quantize_groups.append(group_embeddings)
309
+ quantize_heads += quantizes
310
+
311
+ if self.method == "product":
312
+ if actual_n_groups < self.n_groups:
313
+ for i in range(actual_n_groups, self.n_groups):
314
+ quantize_groups[i].zero_()
315
+ quantize = torch.cat(quantize_groups, dim=-1)
316
+ elif self.method == "residual":
317
+ quantize = sum(quantize_heads)
318
+
319
+ return quantize
320
+
321
+ def forward(self, x, x_len, enable_vq=True, update_codebook=True):
322
+ # Pre-process
323
+ x = self._preprocess(x)
324
+
325
+ # Quantize
326
+ quants, losses, indices = self._quantize(
327
+ x, x_len, enable_vq=enable_vq, update_codebook=update_codebook
328
+ )
329
+
330
+ # Integrate
331
+ quant, loss, index = self._integrate(
332
+ quants, losses, indices, update_codebook=update_codebook
333
+ )
334
+
335
+ return quant, loss, index
336
+
337
+ def _preprocess(self, x):
338
+ if self.method == "product" and isinstance(self.codebook_dims, (list, tuple)):
339
+ x = torch.split(x, self.codebook_dims, dim=-1)
340
+ return x
341
+
342
+ def _quantize(self, x, x_len, enable_vq, update_codebook):
343
+ if self.method == "product":
344
+ quants, losses, indices = zip(
345
+ *[
346
+ codebook(
347
+ chunk,
348
+ x_len,
349
+ enable_vq=enable_vq,
350
+ update_codebook=update_codebook,
351
+ )
352
+ for chunk, codebook in zip(x, self.codebooks)
353
+ ]
354
+ )
355
+ elif self.method == "residual":
356
+ quants, losses, indices = [], [], []
357
+ residual = x
358
+ for codebook in self.codebooks:
359
+ quant, loss, index = codebook(
360
+ residual,
361
+ x_len,
362
+ enable_vq=enable_vq,
363
+ update_codebook=update_codebook,
364
+ )
365
+ residual = residual - quant
366
+ quants.append(quant)
367
+ losses.append(loss)
368
+ indices.append(index)
369
+
370
+ return quants, losses, indices
371
+
372
+ def _integrate(self, quants, losses, indices, update_codebook=True):
373
+ (B, T, D), M = quants[0].shape, len(quants)
374
+ device = quants[0].device
375
+
376
+ # Average loss
377
+ loss = sum(losses) / len(losses)
378
+
379
+ # Get indices
380
+ if self.n_groups == 1:
381
+ scale = (
382
+ torch.tensor([1] + self.codebook_sizes[:-1]).cumprod(dim=0).to(device)
383
+ )
384
+ index = (torch.stack(indices, dim=-1) * scale).sum(dim=-1)
385
+ else:
386
+ index_heads, scale_heads = zip(
387
+ *[
388
+ (
389
+ indices[i : i + self.n_heads_per_group],
390
+ torch.tensor(
391
+ [1]
392
+ + self.codebook_sizes[i : i + self.n_heads_per_group - 1]
393
+ )
394
+ .cumprod(dim=0)
395
+ .to(device),
396
+ )
397
+ for i in range(0, len(quants), self.n_heads_per_group)
398
+ ]
399
+ )
400
+ index = torch.stack(
401
+ [
402
+ (torch.stack(x, dim=-1) * s).sum(dim=-1)
403
+ for x, s in zip(index_heads, scale_heads)
404
+ ],
405
+ dim=-1,
406
+ )
407
+
408
+ # Add dropout
409
+ quant_groups = self._dropout(quants, enabled=update_codebook)
410
+
411
+ # Combine quantized features
412
+ if self.method == "product":
413
+ quant = torch.cat(quant_groups, dim=-1)
414
+ elif self.method == "residual":
415
+ quant = torch.cat(quant_groups, dim=-1).view(B, T, M, D).sum(dim=2)
416
+
417
+ return quant, loss, index
418
+
419
+ def _dropout(self, quants, enabled=True):
420
+ if enabled and self.training and self.ordered:
421
+ if self.dropout_rate_per_group == 0:
422
+ threshold = [
423
+ (i // self.n_heads_per_group * 1.0 / self.n_groups)
424
+ for i in range(0, len(quants), self.n_heads_per_group)
425
+ ]
426
+ elif self.dropout_rate_per_group == "exp":
427
+ x = [np.exp(4 * i / self.n_groups) for i in range(self.n_groups)]
428
+ x = np.asarray(x) / sum(x)
429
+ threshold = np.cumsum(np.asarray([0] + x))
430
+ else:
431
+ x = np.asarray(self.dropout_rate_per_group) / sum(
432
+ self.dropout_rate_per_group
433
+ )
434
+ threshold = np.cumsum(np.asarray([0] + x))
435
+
436
+ if self.ordered_axis == "sequence":
437
+ rate = torch.rand((quants[0].shape[0], 1, 1), device=quants[0].device)
438
+ elif self.ordered_axis == "frame":
439
+ rate = torch.rand(
440
+ (quants[0].shape[0], quants[0].shape[1], 1), device=quants[0].device
441
+ )
442
+
443
+ quant_groups = []
444
+ for i in range(0, len(quants), self.n_heads_per_group):
445
+ quant_group = torch.cat(quants[i : i + self.n_heads_per_group], dim=-1)
446
+ is_kept = threshold[i // self.n_heads_per_group] <= rate
447
+ quant_group = torch.where(
448
+ is_kept, quant_group, torch.zeros_like(quant_group)
449
+ )
450
+ quant_groups.append(quant_group)
451
+ elif self.ordered:
452
+ quant_groups = []
453
+ for i in range(0, len(quants), self.n_heads_per_group):
454
+ quant_group = torch.cat(quants[i : i + self.n_heads_per_group], dim=-1)
455
+ quant_groups.append(quant_group)
456
+ else:
457
+ quant_groups = quants
458
+
459
+ return quant_groups
460
+
461
+
462
+ class VectorQuantization(nn.Module):
463
+
464
+ def __init__(
465
+ self,
466
+ dim: int,
467
+ codebook_size: Union[int, list],
468
+ codebook_dim: Union[int, list] = None,
469
+ decay: float = 0.99,
470
+ epsilon: float = 1e-5,
471
+ threshold_ema_dead_code: float = 1.0,
472
+ commitment_weight: float = 1.0,
473
+ requires_projection: bool = False,
474
+ norm: str = "none",
475
+ **kwargs,
476
+ ):
477
+ super().__init__()
478
+ _codebook_dim: Union[int, list] = default(codebook_dim, dim)
479
+
480
+ requires_projection = _codebook_dim != dim or requires_projection
481
+ proj_dim = (
482
+ sum(_codebook_dim) if isinstance(_codebook_dim, list) else _codebook_dim
483
+ )
484
+ if requires_projection:
485
+ self.project_in = nn.Linear(dim, proj_dim)
486
+ self.project_out = nn.Linear(proj_dim, dim)
487
+ if norm == "weight_norm":
488
+ self.project_in = torch.nn.utils.weight_norm(self.project_in)
489
+ self.project_out = torch.nn.utils.weight_norm(self.project_out)
490
+ else:
491
+ self.norm = None
492
+ self.project_in = nn.Identity()
493
+ self.project_out = nn.Identity()
494
+
495
+ self.epsilon = epsilon
496
+ self.commitment_weight = commitment_weight
497
+ self.codebook_size = codebook_size
498
+
499
+ codebook_class = (
500
+ EuclideanCodebook
501
+ if isinstance(codebook_size, int)
502
+ else MultiHeadEuclideanCodebook
503
+ )
504
+ self._codebook = codebook_class(
505
+ dim=_codebook_dim,
506
+ codebook_size=codebook_size,
507
+ decay=decay,
508
+ epsilon=epsilon,
509
+ threshold_ema_dead_code=threshold_ema_dead_code,
510
+ **kwargs,
511
+ )
512
+ self.codebook_size = codebook_size
513
+
514
+ @property
515
+ def codebook(self):
516
+ return self._codebook.embed
517
+
518
+ def encode(self, x, x_len=None):
519
+ x = rearrange(x, "b d n -> b n d")
520
+ x = self.project_in(x)
521
+ embed_in = self._codebook.encode(x)
522
+ return embed_in
523
+
524
+ def decode(self, embed_ind, embed_len=None):
525
+ quantize = self._codebook.decode(embed_ind)
526
+ quantize = self.project_out(quantize)
527
+ quantize = rearrange(quantize, "b n d -> b d n")
528
+ return quantize
529
+
530
+ def decode_latent(self, latent, latent_len=None):
531
+ if latent_len is None:
532
+ latent_len = (
533
+ torch.Tensor([latent.shape[1]] * latent.shape[0])
534
+ .to(latent.device)
535
+ .int()
536
+ )
537
+
538
+ quantize, _, _ = self._codebook(latent, latent_len)
539
+ quantize = self.project_out(quantize)
540
+ return quantize
541
+
542
+ @torch.cuda.amp.autocast(dtype=torch.float32)
543
+ def forward(
544
+ self,
545
+ x,
546
+ x_len,
547
+ enable_vq=True,
548
+ update_codebook=True,
549
+ return_pre_quant=False,
550
+ return_dict=False,
551
+ ):
552
+ device = x.device
553
+
554
+ x = self.project_in(x)
555
+
556
+ quantize, commit_loss, embed_ind = self._codebook(
557
+ x, x_len, enable_vq=enable_vq, update_codebook=update_codebook
558
+ )
559
+ if self.training and update_codebook:
560
+ loss = torch.tensor(0.0, device=device, requires_grad=True)
561
+ if self.commitment_weight > 0:
562
+ loss = loss + commit_loss * self.commitment_weight
563
+ else:
564
+ loss = torch.tensor(0.0, device=device, requires_grad=False)
565
+
566
+ embed = quantize
567
+ quantize = self.project_out(quantize)
568
+
569
+ if return_dict:
570
+ return {
571
+ "quantize": quantize,
572
+ "loss": loss,
573
+ "embed": embed,
574
+ "embed_ind": embed_ind,
575
+ }
576
+ elif return_pre_quant:
577
+ pre_quantize = self.project_out(x)
578
+ return (pre_quantize, quantize), loss, embed_ind
579
+ else:
580
+ return quantize, loss, embed_ind
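When several heads share a group, MultiHeadEuclideanCodebook packs their indices into one integer with a mixed-radix (cumulative-product) scale and recovers them in decode with repeated // and %. A toy illustration with assumed codebook sizes:

import torch

sizes = [64, 64, 32]                                 # per-head codebook sizes in one group
scale = torch.tensor([1] + sizes[:-1]).cumprod(0)    # tensor([1, 64, 4096])

heads = torch.tensor([5, 17, 3])                     # one index per head
packed = (heads * scale).sum()                       # single combined token

rest, recovered = packed.clone(), []
for s in scale.flip(0):                              # highest scale first, as in decode()
    recovered.append(rest // s)
    rest = rest % s
assert torch.stack(recovered[::-1]).equal(heads)     # back to [5, 17, 3]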
fireredtts/modules/acoustic_llm/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .acoustic_llm import AcousticLLM
fireredtts/modules/acoustic_llm/acoustic_llm.py ADDED
@@ -0,0 +1,876 @@
1
+ from einops import rearrange
2
+ from time import time
3
+ from torch.utils.checkpoint import checkpoint
4
+ from transformers import (
5
+ GPT2Config,
6
+ GPT2Model,
7
+ GPT2PreTrainedModel,
8
+ LogitsProcessorList,
9
+ LogitsWarper,
10
+ StoppingCriteria,
11
+ StoppingCriteriaList,
12
+ )
13
+ from transformers.generation.streamers import BaseStreamer
14
+ from transformers.generation.utils import (
15
+ GenerationConfig,
16
+ GenerateDecoderOnlyOutput,
17
+ GenerateEncoderDecoderOutput,
18
+ GenerateNonBeamOutput,
19
+ )
20
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
21
+ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
22
+ from typing import Any, Dict, Optional, Tuple, Union
23
+
24
+ import functools
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+
30
+
31
+ class MultiHeadRepetitionPenaltyLogitsProcessor(LogitsWarper):
32
+
33
+ def __init__(
34
+ self, penalty: float = 2.0, n_heads: int = 4, n_frames: int = -1, start_index=0
35
+ ):
36
+ if not isinstance(penalty, float) or not (penalty > 0):
37
+ raise ValueError(
38
+ f"`penalty` has to be a strictly positive float, but is {penalty}"
39
+ )
40
+
41
+ self.penalty = penalty
42
+ self.n_heads = n_heads
43
+ self.n_frames = n_frames
44
+ self.start_index = start_index
45
+
46
+ def __call__(
47
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
48
+ ) -> torch.FloatTensor:
49
+ input_ids = input_ids[:, self.start_index :]
50
+ if input_ids.size(1) == 0:
51
+ return scores
52
+
53
+ if self.n_frames <= 0:
54
+ input_ids = torch.flip(input_ids, [1])[:, self.n_heads - 1 :: self.n_heads]
55
+ else:
56
+ input_ids = torch.flip(input_ids, [1])[
57
+ :, self.n_heads - 1 : self.n_heads * self.n_frames : self.n_heads
58
+ ]
59
+ score = torch.gather(scores, 1, input_ids)
60
+
61
+ # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
62
+ if self.penalty > 100:
63
+ score = torch.full_like(score, -1e3)
64
+ else:
65
+ score = torch.where(score < 0, score * self.penalty, score / self.penalty)
66
+
67
+ scores.scatter_(1, input_ids, score)
68
+ return scores
69
+
70
+
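The processor above penalizes repeats per head even though the heads are interleaved along the sequence axis: reversing the history and taking every n_heads-th element isolates one head's past tokens. A small illustration with assumed sizes:

import torch

n_heads = 4
history = torch.arange(12).view(1, 12)                  # 3 frames x 4 interleaved heads
per_head = torch.flip(history, [1])[:, n_heads - 1 :: n_heads]
print(per_head)                                         # tensor([[8, 4, 0]]): head 0, newest frame first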
71
+ def null_position_embeddings(range, dim):
72
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
73
+
74
+
75
+ class FixedStoppingCriteria(StoppingCriteria):
76
+
77
+ def __init__(self, running_steps, start_index=0):
78
+ self.running_steps = running_steps
79
+ self.start_index = start_index
80
+
81
+ def __call__(
82
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
83
+ ) -> torch.BoolTensor:
84
+ assert input_ids.shape[0] == 1, input_ids.shape
85
+ if input_ids.shape[1] - self.start_index >= self.running_steps:
86
+ return torch.tensor([True]).to(input_ids.device)
87
+ return torch.tensor([False]).to(input_ids.device)
88
+
89
+
90
+ class DelayStoppingCriteria(StoppingCriteria):
91
+
92
+ def __init__(self, eos_token_id, delay_steps):
93
+ self.delay_steps = delay_steps
94
+ self.eos_token_id = torch.tensor(eos_token_id)
95
+
96
+ def __call__(
97
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
98
+ ) -> torch.BoolTensor:
99
+ assert input_ids.shape[0] == 1, input_ids.shape
100
+
101
+ if (input_ids == self.eos_token_id).any():
102
+ index = (input_ids[0] == self.eos_token_id).nonzero(as_tuple=True)[0][0]
103
+ if index + self.delay_steps < input_ids.shape[1]:
104
+ return torch.tensor([True]).to(input_ids.device)
105
+ return torch.tensor([False]).to(input_ids.device)
106
+
107
+
108
+ class SuppressionLogitsProcessor(LogitsWarper):
109
+
110
+ def __init__(self, suppressed_ids=[]):
111
+ self.suppressed_ids = suppressed_ids
112
+
113
+ def __call__(
114
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
115
+ ) -> torch.FloatTensor:
116
+ for sid in self.suppressed_ids:
117
+ scores[..., sid] = scores.min()
118
+ return scores
119
+
120
+
121
+ class MHGPT2InferenceModel(GPT2PreTrainedModel):
122
+
123
+ def __init__(
124
+ self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True
125
+ ):
126
+ super().__init__(config)
127
+ self.transformer = gpt
128
+ self.text_pos_embedding = text_pos_emb
129
+ self.embeddings = embeddings
130
+ self.lm_head = nn.ModuleList([norm, linear])  # [final norm, per-head linear list]; applied manually in forward
131
+ self.kv_cache = kv_cache
132
+
133
+ # Multi-head configuration
134
+ self.n_heads = len(linear)
135
+
136
+ # Model parallel
137
+ self.model_parallel = False
138
+ self.device_map = None
139
+ self.cached_mel_emb = None
140
+ self.cached_mel_parallel_emb = None
141
+
142
+ def store_mel_emb(self, mel_emb):
143
+ self.cached_mel_emb = mel_emb
144
+
145
+ def store_mel_parallel_emb(self, mel_emb):
146
+ self.cached_mel_parallel_emb = mel_emb
147
+
148
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
149
+ token_type_ids = kwargs.get("token_type_ids", None) # usually None
150
+ if not self.kv_cache:
151
+ past_key_values = None
152
+
153
+ attention_mask = kwargs.get("attention_mask", None)
154
+ position_ids = kwargs.get("position_ids", None)
155
+
156
+ if attention_mask is not None and position_ids is None:
157
+ position_ids = attention_mask.long().cumsum(-1) - 1
158
+ position_ids.masked_fill_(attention_mask == 0, 1)
159
+ if past_key_values:
160
+ position_ids = position_ids[:, -1].unsqueeze(-1)
161
+ else:
162
+ position_ids = None
163
+ return {
164
+ "input_ids": input_ids,
165
+ "past_key_values": past_key_values,
166
+ "use_cache": kwargs.get("use_cache"),
167
+ "position_ids": position_ids,
168
+ "attention_mask": attention_mask,
169
+ "token_type_ids": token_type_ids,
170
+ }
171
+
172
+ def forward(
173
+ self,
174
+ input_ids=None,
175
+ past_key_values=None,
176
+ attention_mask=None,
177
+ token_type_ids=None,
178
+ position_ids=None,
179
+ head_mask=None,
180
+ inputs_embeds=None,
181
+ encoder_hidden_states=None,
182
+ encoder_attention_mask=None,
183
+ labels=None,
184
+ use_cache=None,
185
+ output_attentions=None,
186
+ output_hidden_states=None,
187
+ return_dict=None,
188
+ ):
189
+ assert self.cached_mel_emb is not None
190
+ assert inputs_embeds is None # Not supported by this inference model.
191
+ assert labels is None # Training not supported by this inference model.
192
+ return_dict = (
193
+ return_dict if return_dict is not None else self.config.use_return_dict
194
+ )
195
+ # Create embedding
196
+ mel_len = self.cached_mel_emb.shape[1]
197
+ attention_mask = None
198
+ position_ids = None
199
+
200
+ if input_ids.shape[1] != 1 and past_key_values is None:
201
+ text_inputs = input_ids[:, mel_len:]
202
+ text_emb = sum(
203
+ [self.embeddings[i](text_inputs[:, :, i]) for i in range(self.n_heads)]
204
+ )
205
+ text_emb = text_emb + self.text_pos_embedding(text_emb)
206
+
207
+ if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
208
+ mel_emb = self.cached_mel_emb.repeat_interleave(
209
+ text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
210
+ )
211
+ else: # this outcome only occurs once per loop in most cases
212
+ mel_emb = self.cached_mel_emb
213
+
214
+ if self.cached_mel_parallel_emb is not None:
215
+ text_emb = (
216
+ text_emb + self.cached_mel_parallel_emb[:, : text_emb.shape[1]]
217
+ )
218
+
219
+ emb = torch.cat([mel_emb, text_emb], dim=1)
220
+ else: # KV-cache mode
221
+ text_inputs = input_ids[:, mel_len:]
222
+ emb = sum(
223
+ [self.embeddings[i](text_inputs[:, -1, i]) for i in range(self.n_heads)]
224
+ )
225
+ emb = emb + self.text_pos_embedding.get_fixed_embedding(
226
+ text_inputs.shape[1] - 1, emb.device
227
+ )
228
+
229
+ if self.cached_mel_parallel_emb is not None:
230
+ emb = emb + self.cached_mel_parallel_emb[:, text_inputs.shape[1] - 1]
231
+
232
+ transformer_outputs = self.transformer(
233
+ inputs_embeds=emb,
234
+ past_key_values=past_key_values,
235
+ attention_mask=attention_mask,
236
+ token_type_ids=token_type_ids,
237
+ position_ids=position_ids,
238
+ head_mask=head_mask,
239
+ encoder_hidden_states=encoder_hidden_states,
240
+ encoder_attention_mask=encoder_attention_mask,
241
+ use_cache=use_cache,
242
+ output_attentions=True, # output_attentions,
243
+ output_hidden_states=output_hidden_states,
244
+ return_dict=return_dict,
245
+ )
246
+ hidden_states = transformer_outputs[0]
247
+ past_key_values = transformer_outputs.past_key_values
248
+ output_hidden_states = transformer_outputs.hidden_states
249
+ output_attentions = transformer_outputs.attentions
250
+
251
+ # Set device for model parallelism
252
+ if self.model_parallel:
253
+ if torch.backends.mps.is_available():
254
+ self.to(self.transformer.first_device)
255
+ else:
256
+ torch.cuda.set_device(self.transformer.first_device)
257
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
258
+
259
+ lm_logits = self.lm_head[0](hidden_states)
260
+ lm_logits = [head(lm_logits) for head in self.lm_head[1]]
261
+ lm_logits = torch.stack(lm_logits, dim=2)
262
+
263
+ if not return_dict:
264
+ return (lm_logits,) + transformer_outputs[1:]
265
+
266
+ output = CausalLMOutputWithCrossAttentions(
267
+ loss=None,
268
+ logits=lm_logits,
269
+ past_key_values=past_key_values,
270
+ hidden_states=output_hidden_states,
271
+ attentions=output_attentions,
272
+ )
273
+ return output
274
+
275
+ def _sample(
276
+ self,
277
+ input_ids: torch.LongTensor,
278
+ logits_processor: LogitsProcessorList,
279
+ stopping_criteria: StoppingCriteriaList,
280
+ generation_config: GenerationConfig,
281
+ synced_gpus: bool,
282
+ streamer: Optional["BaseStreamer"],
283
+ **model_kwargs,
284
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
285
+ r"""
286
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
287
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
288
+
289
+ Parameters:
290
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
291
+ The sequence used as a prompt for the generation.
292
+ logits_processor (`LogitsProcessorList`):
293
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
294
+ used to modify the prediction scores of the language modeling head applied at each generation step.
295
+ stopping_criteria (`StoppingCriteriaList`):
296
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
297
+ used to tell if the generation loop should stop.
298
+ generation_config ([`~generation.GenerationConfig`]):
299
+ The generation configuration to be used as parametrization of the decoding method.
300
+ synced_gpus (`bool`):
301
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
302
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
303
+ streamer (`BaseStreamer`, *optional*):
304
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
305
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
306
+ model_kwargs:
307
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
308
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
309
+
310
+ Return:
311
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
312
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
313
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
314
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
315
+ `model.config.is_encoder_decoder=True`.
316
+ """
317
+ # init values
318
+ pad_token_id = generation_config._pad_token_tensor
319
+ output_attentions = generation_config.output_attentions
320
+ output_hidden_states = generation_config.output_hidden_states
321
+ output_scores = generation_config.output_scores
322
+ output_logits = generation_config.output_logits
323
+ return_dict_in_generate = generation_config.return_dict_in_generate
324
+ max_length = generation_config.max_length
325
+ has_eos_stopping_criteria = any(
326
+ hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
327
+ )
328
+ do_sample = generation_config.do_sample
329
+
330
+ # init attention / hidden states / scores tuples
331
+ scores = () if (return_dict_in_generate and output_scores) else None
332
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
333
+ decoder_attentions = (
334
+ () if (return_dict_in_generate and output_attentions) else None
335
+ )
336
+ cross_attentions = (
337
+ () if (return_dict_in_generate and output_attentions) else None
338
+ )
339
+ decoder_hidden_states = (
340
+ () if (return_dict_in_generate and output_hidden_states) else None
341
+ )
342
+
343
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
344
+ if return_dict_in_generate and self.config.is_encoder_decoder:
345
+ encoder_attentions = (
346
+ model_kwargs["encoder_outputs"].get("attentions")
347
+ if output_attentions
348
+ else None
349
+ )
350
+ encoder_hidden_states = (
351
+ model_kwargs["encoder_outputs"].get("hidden_states")
352
+ if output_hidden_states
353
+ else None
354
+ )
355
+
356
+ # keep track of which sequences are already finished
357
+ batch_size, cur_len, num_streams = input_ids.shape
358
+ this_peer_finished = False
359
+ unfinished_sequences = torch.ones(
360
+ batch_size, dtype=torch.long, device=input_ids.device
361
+ )
362
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
363
+
364
+ while self._has_unfinished_sequences(
365
+ this_peer_finished,
366
+ synced_gpus,
367
+ device=input_ids.device,
368
+ cur_len=cur_len,
369
+ max_length=max_length,
370
+ ):
371
+ # prepare model inputs
372
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
373
+
374
+ # prepare variable output controls (note: some models won't accept all output controls)
375
+ model_inputs.update(
376
+ {"output_attentions": output_attentions} if output_attentions else {}
377
+ )
378
+ model_inputs.update(
379
+ {"output_hidden_states": output_hidden_states}
380
+ if output_hidden_states
381
+ else {}
382
+ )
383
+
384
+ # forward pass to get next token
385
+ outputs = self(**model_inputs, return_dict=True)
386
+
387
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
388
+ model_kwargs = self._update_model_kwargs_for_generation(
389
+ outputs,
390
+ model_kwargs,
391
+ is_encoder_decoder=self.config.is_encoder_decoder,
392
+ )
393
+ if synced_gpus and this_peer_finished:
394
+ continue
395
+
396
+ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
397
+ # (the clone itself is always small)
398
+ next_token_logits = outputs.logits.clone()[:, -1].float()
399
+ next_token_logits = next_token_logits.to(input_ids.device)
400
+
401
+ # pre-process distribution
402
+ batch_size, seq_len, num_streams = input_ids.shape
403
+ rearrange_input_ids = rearrange(input_ids, "b l n -> (b n) l")
404
+ next_token_logits = rearrange(next_token_logits, "b n d -> (b n) d")
405
+ next_token_scores = logits_processor(rearrange_input_ids, next_token_logits)
406
+ next_token_scores = rearrange(
407
+ next_token_scores, "(b n) d -> b n d", b=batch_size
408
+ )
409
+
410
+ # Store scores, attentions and hidden_states when required
411
+ if return_dict_in_generate:
412
+ if output_scores:
413
+ scores += (next_token_scores,)
414
+ if output_logits:
415
+ raw_logits += (next_token_logits,)
416
+ if output_attentions:
417
+ decoder_attentions += (
418
+ (outputs.decoder_attentions,)
419
+ if self.config.is_encoder_decoder
420
+ else (outputs.attentions,)
421
+ )
422
+ if self.config.is_encoder_decoder:
423
+ cross_attentions += (outputs.cross_attentions,)
424
+
425
+ if output_hidden_states:
426
+ decoder_hidden_states += (
427
+ (outputs.decoder_hidden_states,)
428
+ if self.config.is_encoder_decoder
429
+ else (outputs.hidden_states,)
430
+ )
431
+
432
+ # token selection
433
+ if do_sample:
434
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
435
+ # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
436
+ probs = probs.view(-1, probs.shape[-1])
437
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
438
+ next_tokens = next_tokens.view(*next_token_scores.shape[:-1])
439
+ else:
440
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
441
+
442
+ # finished sentences should have their next token be a padding token
443
+ if has_eos_stopping_criteria:
444
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
445
+ 1 - unfinished_sequences
446
+ )
447
+
448
+ # update generated ids, model inputs, and length for next step
449
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=1)
450
+ if streamer is not None:
451
+ streamer.put(next_tokens.cpu())
452
+
453
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(
454
+ input_ids, scores
455
+ )
456
+ this_peer_finished = unfinished_sequences.max() == 0
457
+ cur_len += 1
458
+
459
+ # This is needed to properly delete outputs.logits which may be very large for first iteration
460
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
461
+ del outputs
462
+
463
+ if streamer is not None:
464
+ streamer.end()
465
+
466
+ if return_dict_in_generate:
467
+ if self.config.is_encoder_decoder:
468
+ return GenerateEncoderDecoderOutput(
469
+ sequences=input_ids,
470
+ scores=scores,
471
+ logits=raw_logits,
472
+ encoder_attentions=encoder_attentions,
473
+ encoder_hidden_states=encoder_hidden_states,
474
+ decoder_attentions=decoder_attentions,
475
+ cross_attentions=cross_attentions,
476
+ decoder_hidden_states=decoder_hidden_states,
477
+ past_key_values=model_kwargs.get("past_key_values"),
478
+ )
479
+ else:
480
+ return GenerateDecoderOnlyOutput(
481
+ sequences=input_ids,
482
+ scores=scores,
483
+ logits=raw_logits,
484
+ attentions=decoder_attentions,
485
+ hidden_states=decoder_hidden_states,
486
+ past_key_values=model_kwargs.get("past_key_values"),
487
+ )
488
+ else:
489
+ return input_ids
490
+
491
+
492
+ class LearnedPositionEmbeddings(nn.Module):
493
+
494
+ def __init__(self, seq_len, model_dim, init=0.02):
495
+ super().__init__()
496
+ self.emb = nn.Embedding(seq_len, model_dim)
497
+ self.emb.weight.data.normal_(mean=0.0, std=init)
498
+
499
+ def forward(self, x):
500
+ sl = x.shape[1]
501
+ return self.emb(torch.arange(0, sl, device=x.device))
502
+
503
+ def get_fixed_embedding(self, ind, dev):
504
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
505
+
506
+
507
+ def build_hf_gpt_transformer(
508
+ layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing
509
+ ):
510
+ gpt_config = GPT2Config(
511
+ vocab_size=256,
512
+ n_positions=max_mel_seq_len + max_text_seq_len,
513
+ n_ctx=max_mel_seq_len + max_text_seq_len,
514
+ n_embd=model_dim,
515
+ n_layer=layers,
516
+ n_head=heads,
517
+ use_cache=not checkpointing,
518
+ scale_attn_by_inverse_layer_idx=True,
519
+ reorder_and_upcast_attn=True,
520
+ attn_implementation="sdpa",
521
+ )
522
+ gpt = GPT2Model(gpt_config)
523
+
524
+ if checkpointing:
525
+ gpt.gradient_checkpointing_enable()
526
+
527
+ del gpt.wpe, gpt.wte
528
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
529
+ mel_pos_embs = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
530
+
531
+ return gpt, mel_pos_embs
532
+
533
+
534
+ class AcousticLLM(nn.Module):
535
+
536
+ def __init__(
537
+ self,
538
+ # Model
539
+ n_stacks=2,
540
+ layers=12,
541
+ model_dim=1024,
542
+ heads=16,
543
+ # Text
544
+ max_text_tokens=120,
545
+ number_text_tokens=8194,
546
+ start_text_token=8192,
547
+ stop_text_token=8193,
548
+ # Speech
549
+ n_frames_per_step=4,
550
+ n_heads_per_frame=1,
551
+ max_speech_tokens=250,
552
+ number_speech_tokens=8194,
553
+ start_speech_token=8192,
554
+ stop_speech_token=8193,
555
+ # CoS Prediction
556
+ streaming=False,
557
+ streaming_delayed_frames=4,
558
+ accumulative_speech_embedding=False,
559
+ upsample_factors=2,
560
+ # Reference embedding
561
+ max_conditioning_inputs=1,
562
+ speaker_embedding_pretrained=True,
563
+ speaker_embedding_ckpt=None,
564
+ speaker_embedding_dim=256,
565
+ # For training
566
+ checkpointing=True,
567
+ loss_weights=1.0,
568
+ # For inference
569
+ delay_prediction=1,
570
+ temperature=0.3,
571
+ length_penalty=1.0,
572
+ repetition_penalty=2.0,
573
+ top_p=0.2,
574
+ top_k=50,
575
+ ):
576
+ super().__init__()
577
+ self.n_stacks = n_stacks
578
+ self.number_text_tokens = number_text_tokens
579
+ self.start_text_token = start_text_token
580
+ self.stop_text_token = stop_text_token
581
+ self.number_speech_tokens = number_speech_tokens
582
+ self.start_speech_token = start_speech_token
583
+ self.stop_speech_token = stop_speech_token
584
+ self.layers = layers
585
+ self.heads = heads
586
+
587
+ self.streaming = streaming
588
+ self.streaming_delayed_frames = streaming_delayed_frames
589
+ self.accumulative_speech_embedding = accumulative_speech_embedding
590
+ self.upsample_factors = upsample_factors
591
+
592
+ self.n_frames_per_step = n_frames_per_step
593
+ self.n_heads_per_frame = n_heads_per_frame
594
+ self.number_speech_heads = n_heads_per_frame * n_frames_per_step
595
+
596
+ self.max_speech_tokens = max_speech_tokens
597
+ self.max_text_tokens = max_text_tokens
598
+ self.model_dim = model_dim
599
+ self.max_conditioning_inputs = max_conditioning_inputs
600
+
601
+ self.speaker_embedding_pretrained = speaker_embedding_pretrained
602
+ self.speaker_embedding_ckpt = speaker_embedding_ckpt
603
+ self.speaker_embedding_dim = speaker_embedding_dim
604
+
605
+ # For training
606
+ self.loss_weights = loss_weights
607
+
608
+ # For inference
609
+ self.delay_prediction = delay_prediction
610
+ self.temperature = temperature
611
+ self.length_penalty = length_penalty
612
+ self.repetition_penalty = repetition_penalty
613
+ self.top_p = top_p
614
+ self.top_k = top_k
615
+
616
+ # Conditional embedding
617
+ self.reference_embedding = nn.Sequential(
618
+ nn.Linear(speaker_embedding_dim, 256),
619
+ nn.Tanh(),
620
+ nn.Linear(256, model_dim),
621
+ )
622
+
623
+ self.text_embedding = nn.Embedding(self.number_text_tokens + 1, model_dim)
624
+ self.text_embedding.weight.data.normal_(mean=0.0, std=0.02)
625
+
626
+ self.mel_embedding = nn.ModuleList(
627
+ [
628
+ nn.Embedding(self.number_speech_tokens, model_dim)
629
+ for _ in range(self.number_speech_heads)
630
+ ]
631
+ )
632
+ for module in self.mel_embedding:
633
+ module.weight.data.normal_(mean=0.0, std=0.02)
634
+
635
+ # Build GPTs
636
+ self.gpt, self.mel_pos_embedding = build_hf_gpt_transformer(
637
+ layers,
638
+ model_dim,
639
+ heads,
640
+ self.max_speech_tokens + 2 + self.max_conditioning_inputs,
641
+ self.max_text_tokens + 2,
642
+ checkpointing,
643
+ )
644
+ self.final_norm = nn.LayerNorm(model_dim)
645
+ self.mel_head = nn.ModuleList(
646
+ [
647
+ nn.Linear(model_dim, self.number_speech_tokens)
648
+ for _ in range(self.number_speech_heads)
649
+ ]
650
+ )
651
+
652
+ def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=True, half=False):
653
+ seq_length = self.max_speech_tokens + self.max_text_tokens + 2
654
+ gpt_config = GPT2Config(
655
+ vocab_size=self.max_speech_tokens,
656
+ n_positions=seq_length,
657
+ n_ctx=seq_length,
658
+ n_embd=self.model_dim,
659
+ n_layer=self.layers,
660
+ n_head=self.heads,
661
+ gradient_checkpointing=False,
662
+ use_cache=True,
663
+ )
664
+
665
+ self.inference_model = MHGPT2InferenceModel(
666
+ gpt_config,
667
+ self.gpt,
668
+ self.mel_pos_embedding,
669
+ self.mel_embedding,
670
+ self.final_norm,
671
+ self.mel_head,
672
+ kv_cache=kv_cache,
673
+ )
674
+ self.inference_model.eval()
675
+
676
+ def build_aligned_inputs_and_targets(
677
+ self, seqs, lens, start_token, stop_token, delay=0
678
+ ):
679
+ for i in range(seqs.shape[0]):
680
+ seqs[i, lens[i] :] = stop_token
681
+
682
+ if len(seqs.shape) == 2:
683
+ inp = F.pad(
684
+ seqs, (self.streaming_delayed_frames, 0), value=start_token
685
+ ).type_as(seqs)
686
+ inp = F.pad(inp, (0, 1), value=stop_token).type_as(inp)
687
+ tar = F.pad(inp[:, 1:], (0, 1), value=stop_token).type_as(seqs)
688
+ else:
689
+ inp = F.pad(
690
+ seqs, (0, 0, self.streaming_delayed_frames, 0), value=start_token
691
+ ).type_as(seqs)
692
+ inp = F.pad(inp, (0, 0, 0, 1), value=stop_token).type_as(inp)
693
+ tar = F.pad(inp[:, 1:], (0, 0, 0, 1), value=stop_token).type_as(seqs)
694
+
695
+ if delay > 0:
696
+ pad_size = delay * (inp.shape[2] - 1)
697
+ L = inp.shape[1] + pad_size
698
+ inp = F.pad(inp, (0, 0, pad_size, 0), value=start_token).type_as(inp)
699
+ inp = F.pad(inp, (0, 0, 0, pad_size), value=stop_token).type_as(inp)
700
+ inp = torch.stack(
701
+ [
702
+ inp[:, pad_size - i * delay : pad_size - i * delay + L, i]
703
+ for i in range(inp.shape[-1])
704
+ ],
705
+ dim=-1,
706
+ )
707
+
708
+ tar = F.pad(tar, (0, 0, pad_size, 0), value=start_token).type_as(tar)
709
+ tar = F.pad(tar, (0, 0, 0, pad_size), value=stop_token).type_as(tar)
710
+ tar = torch.stack(
711
+ [
712
+ tar[:, pad_size - i * delay : pad_size - i * delay + L, i]
713
+ for i in range(tar.shape[-1])
714
+ ],
715
+ dim=-1,
716
+ )
717
+
718
+ lens += pad_size
719
+
720
+ return inp, tar, lens + self.streaming_delayed_frames + 1
721
+
722
+ def get_logits(
723
+ self,
724
+ final_norm,
725
+ first_inputs,
726
+ first_head,
727
+ speech_conditioning_inputs=None,
728
+ attention_mask=None,
729
+ get_attns=False,
730
+ return_latent=False,
731
+ ):
732
+ emb = first_inputs
733
+ if speech_conditioning_inputs is not None:
734
+ emb = torch.cat([speech_conditioning_inputs, emb], dim=1)
735
+
736
+ gpt_out = self.gpt(
737
+ inputs_embeds=emb,
738
+ return_dict=True,
739
+ attention_mask=attention_mask,
740
+ output_attentions=get_attns,
741
+ )
742
+
743
+ enc = gpt_out.last_hidden_state
744
+ if speech_conditioning_inputs is not None:
745
+ enc = enc[:, 1:]
746
+ enc = final_norm(enc)
747
+
748
+ first_logits = [head(enc).permute(0, 2, 1) for head in first_head]
749
+
750
+ return first_logits
751
+
752
+ @torch.cuda.amp.autocast()
753
+ def get_conditioning(self, speech_conditioning_input):
754
+ if hasattr(self, "reference_encoder"):
755
+ if len(speech_conditioning_input.shape) == 2:
756
+ speech_conditioning_input = speech_conditioning_input.unsqueeze(1)
757
+ speech_conditioning_input = self.reference_encoder(
758
+ speech_conditioning_input
759
+ )
760
+ conds = self.reference_embedding(speech_conditioning_input)
761
+ return conds
762
+
763
+ def inference_speech(
764
+ self,
765
+ speech_conditioning_latent,
766
+ text_inputs,
767
+ input_tokens=None,
768
+ num_return_sequences=1,
769
+ max_generate_length=None,
770
+ **hf_generate_kwargs,
771
+ ):
772
+ if not hasattr(self, "inference_model"):
773
+ self.post_init_gpt2_config()
774
+
775
+ # Cond
776
+ emb = speech_conditioning_latent
777
+ self.inference_model.store_mel_emb(emb)
778
+
779
+ # Text
780
+ text = torch.repeat_interleave(text_inputs, self.upsample_factors, dim=1)
781
+ text = F.pad(
782
+ text,
783
+ (0, self.streaming_delayed_frames + self.number_speech_heads - 1),
784
+ value=self.stop_speech_token,
785
+ )
786
+ text_embedding = self.text_embedding(text)
787
+ self.inference_model.store_mel_parallel_emb(text_embedding)
788
+
789
+ fake_inputs = torch.full(
790
+ (
791
+ emb.shape[0], # should be 1 for stable inference
792
+ emb.shape[1] + 1, # + 1 for the start_speech_token
793
+ self.number_speech_heads,
794
+ ),
795
+ fill_value=self.start_speech_token,
796
+ dtype=torch.long,
797
+ device=text_inputs.device,
798
+ )
799
+ if input_tokens is None:
800
+ inputs = fake_inputs
801
+ prompt_index = 0
802
+ else:
803
+ prompt, _, _ = self.build_aligned_inputs_and_targets(
804
+ input_tokens,
805
+ torch.Tensor([len(input_tokens[0])]).int(),
806
+ self.start_speech_token,
807
+ self.stop_speech_token,
808
+ self.delay_prediction,
809
+ )
810
+ prompt = prompt[:, 1 : 1 + input_tokens.shape[1]]
811
+ inputs = torch.cat([fake_inputs, prompt], dim=1)
812
+ prompt_index = input_tokens.shape[1]
813
+ trunc_index = fake_inputs.shape[1]
814
+
815
+ stop_criteria = StoppingCriteriaList(
816
+ [FixedStoppingCriteria(text_embedding.shape[1], start_index=emb.shape[1])]
817
+ )
818
+
819
+ logits_processor = (
820
+ LogitsProcessorList(
821
+ [
822
+ MultiHeadRepetitionPenaltyLogitsProcessor(
823
+ penalty=self.repetition_penalty,
824
+ n_heads=self.n_heads_per_frame,
825
+ n_frames=-1,
826
+ start_index=trunc_index + prompt_index,
827
+ )
828
+ ]
829
+ )
830
+ if self.repetition_penalty > 1.0
831
+ else LogitsProcessorList()
832
+ )
833
+ logits_processor.append(
834
+ SuppressionLogitsProcessor(suppressed_ids=[self.stop_speech_token])
835
+ )
836
+
837
+ max_length = (
838
+ trunc_index + self.max_speech_tokens - 1
839
+ if max_generate_length is None
840
+ else trunc_index + max_generate_length
841
+ )
842
+
843
+ # Recommended (temperature, top_p) pairs: (0.8, 0.8), (0.5, 0.5), (0.3, 0.2), (0.2, 0.1)
844
+ gen = self.inference_model.generate(
845
+ inputs,
846
+ bos_token_id=self.start_speech_token,
847
+ pad_token_id=self.stop_speech_token,
848
+ eos_token_id=self.stop_speech_token + 2,
849
+ max_length=max_length,
850
+ stopping_criteria=stop_criteria,
851
+ logits_processor=logits_processor,
852
+ num_return_sequences=num_return_sequences,
853
+ do_sample=True,
854
+ temperature=self.temperature,
855
+ length_penalty=self.length_penalty,
856
+ top_p=self.top_p,
857
+ top_k=self.top_k,
858
+ **hf_generate_kwargs,
859
+ )
860
+
861
+ seq = gen[0][trunc_index:]
862
+
863
+ start, heads = 0, []
864
+ for j in range(self.number_speech_heads):
865
+ head = seq[j * self.delay_prediction :, j]
866
+ start_indices = (head == self.start_speech_token).nonzero(as_tuple=True)[0]
867
+ start = max(start, start_indices[-1] + 1 if len(start_indices) > 0 else 0)
868
+ stop = (head == self.stop_speech_token).nonzero(as_tuple=True)[0]
869
+ stop = stop[0] if len(stop) > 0 else len(head)
870
+ heads.append(head[:stop])
871
+
872
+ min_length = min([len(x) for x in heads])
873
+ seq = torch.stack(
874
+ [head[start + prompt_index : min_length] for head in heads], dim=-1
875
+ )
876
+ return [seq]
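When delay_prediction > 0, build_aligned_inputs_and_targets shifts head k right by k*delay steps, so the heads form staggered streams (similar in spirit to delayed-codebook interleaving). A toy version of that shift with assumed marker values:

import torch
import torch.nn.functional as F

delay, start, stop = 1, -1, -2
seq = torch.arange(12).view(1, 4, 3)          # (batch, frames, heads)

pad = delay * (seq.shape[2] - 1)
L = seq.shape[1] + pad
padded = F.pad(seq, (0, 0, pad, 0), value=start)
padded = F.pad(padded, (0, 0, 0, pad), value=stop)
shifted = torch.stack(
    [padded[:, pad - i * delay : pad - i * delay + L, i] for i in range(seq.shape[-1])],
    dim=-1,
)
print(shifted[0])   # head 0 unshifted, head 1 delayed by one frame, head 2 by two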
fireredtts/modules/bigvgan/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .bigvgan import BigVGAN
2
+ from .mel_spectrogram import MelExtractor
fireredtts/modules/bigvgan/activations.py ADDED
@@ -0,0 +1,126 @@
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ """
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = Snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ """
25
+
26
+ def __init__(
27
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
28
+ ):
29
+ """
30
+ Initialization.
31
+ INPUT:
32
+ - in_features: shape of the input
33
+ - alpha: trainable parameter
34
+ alpha is initialized to 1 by default, higher values = higher-frequency.
35
+ alpha will be trained along with the rest of your model.
36
+ """
37
+ super(Snake, self).__init__()
38
+ self.in_features = in_features
39
+
40
+ # initialize alpha
41
+ self.alpha_logscale = alpha_logscale
42
+ if self.alpha_logscale: # log scale alphas initialized to zeros
43
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
44
+ else: # linear scale alphas initialized to ones
45
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
46
+
47
+ self.alpha.requires_grad = alpha_trainable
48
+
49
+ self.no_div_by_zero = 0.000000001
50
+
51
+ def forward(self, x):
52
+ """
53
+ Forward pass of the function.
54
+ Applies the function to the input elementwise.
55
+ Snake ∶= x + 1/a * sin^2 (xa)
56
+ """
57
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
58
+ if self.alpha_logscale:
59
+ alpha = torch.exp(alpha)
60
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
61
+
62
+ return x
63
+
64
+
65
+ class SnakeBeta(nn.Module):
66
+ """
67
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
68
+ Shape:
69
+ - Input: (B, C, T)
70
+ - Output: (B, C, T), same shape as the input
71
+ Parameters:
72
+ - alpha - trainable parameter that controls frequency
73
+ - beta - trainable parameter that controls magnitude
74
+ References:
75
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
76
+ https://arxiv.org/abs/2006.08195
77
+ Examples:
78
+ >>> a1 = SnakeBeta(256)
79
+ >>> x = torch.randn(256)
80
+ >>> x = a1(x)
81
+ """
82
+
83
+ def __init__(
84
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
85
+ ):
86
+ """
87
+ Initialization.
88
+ INPUT:
89
+ - in_features: shape of the input
90
+ - alpha - trainable parameter that controls frequency
91
+ - beta - trainable parameter that controls magnitude
92
+ alpha is initialized to 1 by default, higher values = higher-frequency.
93
+ beta is initialized to 1 by default, higher values = higher-magnitude.
94
+ alpha will be trained along with the rest of your model.
95
+ """
96
+ super(SnakeBeta, self).__init__()
97
+ self.in_features = in_features
98
+
99
+ # initialize alpha
100
+ self.alpha_logscale = alpha_logscale
101
+ if self.alpha_logscale: # log scale alphas initialized to zeros
102
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
103
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
104
+ else: # linear scale alphas initialized to ones
105
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
106
+ self.beta = Parameter(torch.ones(in_features) * alpha)
107
+
108
+ self.alpha.requires_grad = alpha_trainable
109
+ self.beta.requires_grad = alpha_trainable
110
+
111
+ self.no_div_by_zero = 0.000000001
112
+
113
+ def forward(self, x):
114
+ """
115
+ Forward pass of the function.
116
+ Applies the function to the input elementwise.
117
+ SnakeBeta := x + 1/b * sin^2 (xa)
118
+ """
119
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
120
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
121
+ if self.alpha_logscale:
122
+ alpha = torch.exp(alpha)
123
+ beta = torch.exp(beta)
124
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
125
+
126
+ return x
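
A quick shape and formula check for the activations above (tensor sizes are arbitrary, and the import assumes the repository root is on the Python path):

import torch
from fireredtts.modules.bigvgan.activations import Snake, SnakeBeta

x = torch.randn(2, 4, 16)                    # (B, C, T)
snake = Snake(in_features=4, alpha=1.0)      # one trainable alpha per channel
y = snake(x)
# elementwise definition: y = x + (1/alpha) * sin(alpha * x)^2
expected = x + (1.0 / (1.0 + 1e-9)) * torch.sin(x) ** 2
assert torch.allclose(y, expected, atol=1e-6)

snakebeta = SnakeBeta(in_features=4)         # separate alpha (frequency) and beta (magnitude)
assert snakebeta(x).shape == x.shape
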
fireredtts/modules/bigvgan/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ from .filter import *
4
+ from .resample import *
5
+ from .act import *
fireredtts/modules/bigvgan/alias_free_torch/act.py ADDED
@@ -0,0 +1,29 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from .resample import UpSample1d, DownSample1d
5
+
6
+
7
+ class Activation1d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12,
15
+ ):
16
+ super().__init__()
17
+ self.up_ratio = up_ratio
18
+ self.down_ratio = down_ratio
19
+ self.act = activation
20
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
+
23
+ # x: [B,C,T]
24
+ def forward(self, x):
25
+ x = self.upsample(x)
26
+ x = self.act(x)
27
+ x = self.downsample(x)
28
+
29
+ return x
fireredtts/modules/bigvgan/alias_free_torch/filter.py ADDED
@@ -0,0 +1,98 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ if "sinc" in dir(torch):
9
+ sinc = torch.sinc
10
+ else:
11
+ # This code is adapted from adefossez's julius.core.sinc under the MIT License
12
+ # https://adefossez.github.io/julius/julius/core.html
13
+ # LICENSE is in incl_licenses directory.
14
+ def sinc(x: torch.Tensor):
15
+ """
16
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
17
+ __Warning__: Unlike julius.sinc, the input is multiplied by `pi`!
18
+ """
19
+ return torch.where(
20
+ x == 0,
21
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
22
+ torch.sin(math.pi * x) / math.pi / x,
23
+ )
24
+
25
+
26
+ # This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
27
+ # https://adefossez.github.io/julius/julius/lowpass.html
28
+ # LICENSE is in incl_licenses directory.
29
+ def kaiser_sinc_filter1d(
30
+ cutoff, half_width, kernel_size
31
+ ): # return filter [1,1,kernel_size]
32
+ even = kernel_size % 2 == 0
33
+ half_size = kernel_size // 2
34
+
35
+ # For kaiser window
36
+ delta_f = 4 * half_width
37
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
38
+ if A > 50.0:
39
+ beta = 0.1102 * (A - 8.7)
40
+ elif A >= 21.0:
41
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
42
+ else:
43
+ beta = 0.0
44
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
45
+
46
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
47
+ if even:
48
+ time = torch.arange(-half_size, half_size) + 0.5
49
+ else:
50
+ time = torch.arange(kernel_size) - half_size
51
+ if cutoff == 0:
52
+ filter_ = torch.zeros_like(time)
53
+ else:
54
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
55
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
56
+ # of the constant component in the input signal.
57
+ filter_ /= filter_.sum()
58
+ filter = filter_.view(1, 1, kernel_size)
59
+
60
+ return filter
61
+
62
+
63
+ class LowPassFilter1d(nn.Module):
64
+ def __init__(
65
+ self,
66
+ cutoff=0.5,
67
+ half_width=0.6,
68
+ stride: int = 1,
69
+ padding: bool = True,
70
+ padding_mode: str = "replicate",
71
+ kernel_size: int = 12,
72
+ ):
73
+ # kernel_size should be even number for stylegan3 setup,
74
+ # in this implementation, odd number is also possible.
75
+ super().__init__()
76
+ if cutoff < -0.0:
77
+ raise ValueError("Minimum cutoff must be larger than zero.")
78
+ if cutoff > 0.5:
79
+ raise ValueError("A cutoff above 0.5 does not make sense.")
80
+ self.kernel_size = kernel_size
81
+ self.even = kernel_size % 2 == 0
82
+ self.pad_left = kernel_size // 2 - int(self.even)
83
+ self.pad_right = kernel_size // 2
84
+ self.stride = stride
85
+ self.padding = padding
86
+ self.padding_mode = padding_mode
87
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
88
+ self.register_buffer("filter", filter)
89
+
90
+ # input [B, C, T]
91
+ def forward(self, x):
92
+ _, C, _ = x.shape
93
+
94
+ if self.padding:
95
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
96
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
97
+
98
+ return out
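
A small usage sketch for the Kaiser-windowed low-pass filter above (cutoff, channel count and length are arbitrary values chosen for illustration):

import torch
from fireredtts.modules.bigvgan.alias_free_torch.filter import LowPassFilter1d

# half-band filter that also decimates by 2
lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(1, 2, 64)        # (B, C, T)
y = lpf(x)                       # filtered and strided
print(y.shape)                   # torch.Size([1, 2, 32])
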
fireredtts/modules/bigvgan/alias_free_torch/resample.py ADDED
@@ -0,0 +1,57 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from .filter import LowPassFilter1d
6
+ from .filter import kaiser_sinc_filter1d
7
+
8
+
9
+ class UpSample1d(nn.Module):
10
+ def __init__(self, ratio=2, kernel_size=None):
11
+ super().__init__()
12
+ self.ratio = ratio
13
+ self.kernel_size = (
14
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ )
16
+ self.stride = ratio
17
+ self.pad = self.kernel_size // ratio - 1
18
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
19
+ self.pad_right = (
20
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
21
+ )
22
+ filter = kaiser_sinc_filter1d(
23
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
24
+ )
25
+ self.register_buffer("filter", filter)
26
+
27
+ # x: [B, C, T]
28
+ def forward(self, x):
29
+ _, C, _ = x.shape
30
+
31
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
32
+ x = self.ratio * F.conv_transpose1d(
33
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
34
+ )
35
+ x = x[..., self.pad_left : -self.pad_right]
36
+
37
+ return x
38
+
39
+
40
+ class DownSample1d(nn.Module):
41
+ def __init__(self, ratio=2, kernel_size=None):
42
+ super().__init__()
43
+ self.ratio = ratio
44
+ self.kernel_size = (
45
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
46
+ )
47
+ self.lowpass = LowPassFilter1d(
48
+ cutoff=0.5 / ratio,
49
+ half_width=0.6 / ratio,
50
+ stride=ratio,
51
+ kernel_size=self.kernel_size,
52
+ )
53
+
54
+ def forward(self, x):
55
+ xx = self.lowpass(x)
56
+
57
+ return xx
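
The two resamplers compose as expected; a quick round-trip shape check (ratio and tensor sizes are arbitrary):

import torch
from fireredtts.modules.bigvgan.alias_free_torch.resample import UpSample1d, DownSample1d

up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
x = torch.randn(1, 3, 100)       # (B, C, T)
y = up(x)                        # anti-aliased 2x upsampling -> (1, 3, 200)
z = down(y)                      # low-pass + decimate by 2   -> (1, 3, 100)
print(y.shape, z.shape)
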
fireredtts/modules/bigvgan/bigvgan.py ADDED
@@ -0,0 +1,369 @@
1
+ import typing as tp
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ from .alias_free_torch import Activation1d as TorchActivation1d
8
+ from .activations import Snake, SnakeBeta
9
+
10
+
11
+ def init_weights(m, mean=0.0, std=0.01):
12
+ classname = m.__class__.__name__
13
+ if classname.find("Conv") != -1:
14
+ m.weight.data.normal_(mean, std)
15
+
16
+
17
+ def get_padding(kernel_size, dilation=1):
18
+ return int((kernel_size * dilation - dilation) / 2)
19
+
20
+
21
+ class AMPBlock1(torch.nn.Module):
22
+ def __init__(
23
+ self,
24
+ channels,
25
+ kernel_size=3,
26
+ dilation=(1, 3, 5),
27
+ activation=None,
28
+ snake_logscale=True,
29
+ ):
30
+ super(AMPBlock1, self).__init__()
31
+
32
+ self.convs1 = nn.ModuleList(
33
+ [
34
+ weight_norm(
35
+ Conv1d(
36
+ channels,
37
+ channels,
38
+ kernel_size,
39
+ 1,
40
+ dilation=dilation[0],
41
+ padding=get_padding(kernel_size, dilation[0]),
42
+ )
43
+ ),
44
+ weight_norm(
45
+ Conv1d(
46
+ channels,
47
+ channels,
48
+ kernel_size,
49
+ 1,
50
+ dilation=dilation[1],
51
+ padding=get_padding(kernel_size, dilation[1]),
52
+ )
53
+ ),
54
+ weight_norm(
55
+ Conv1d(
56
+ channels,
57
+ channels,
58
+ kernel_size,
59
+ 1,
60
+ dilation=dilation[2],
61
+ padding=get_padding(kernel_size, dilation[2]),
62
+ )
63
+ ),
64
+ ]
65
+ )
66
+ self.convs1.apply(init_weights)
67
+
68
+ self.convs2 = nn.ModuleList(
69
+ [
70
+ weight_norm(
71
+ Conv1d(
72
+ channels,
73
+ channels,
74
+ kernel_size,
75
+ 1,
76
+ dilation=1,
77
+ padding=get_padding(kernel_size, 1),
78
+ )
79
+ ),
80
+ weight_norm(
81
+ Conv1d(
82
+ channels,
83
+ channels,
84
+ kernel_size,
85
+ 1,
86
+ dilation=1,
87
+ padding=get_padding(kernel_size, 1),
88
+ )
89
+ ),
90
+ weight_norm(
91
+ Conv1d(
92
+ channels,
93
+ channels,
94
+ kernel_size,
95
+ 1,
96
+ dilation=1,
97
+ padding=get_padding(kernel_size, 1),
98
+ )
99
+ ),
100
+ ]
101
+ )
102
+ self.convs2.apply(init_weights)
103
+
104
+ self.num_layers = len(self.convs1) + len(
105
+ self.convs2
106
+ ) # total number of conv layers
107
+
108
+ Activation1d = TorchActivation1d
109
+ if (
110
+ activation == "snake"
111
+ ): # periodic nonlinearity with snake function and anti-aliasing
112
+ self.activations = nn.ModuleList(
113
+ [
114
+ Activation1d(
115
+ activation=Snake(channels, alpha_logscale=snake_logscale)
116
+ )
117
+ for _ in range(self.num_layers)
118
+ ]
119
+ )
120
+ elif (
121
+ activation == "snakebeta"
122
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
123
+ self.activations = nn.ModuleList(
124
+ [
125
+ Activation1d(
126
+ activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
127
+ )
128
+ for _ in range(self.num_layers)
129
+ ]
130
+ )
131
+ else:
132
+ raise NotImplementedError(
133
+ "activation incorrectly specified. check the config file and look for 'activation'."
134
+ )
135
+
136
+ def forward(self, x):
137
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
138
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
139
+ xt = a1(x)
140
+ xt = c1(xt)
141
+ xt = a2(xt)
142
+ xt = c2(xt)
143
+ x = xt + x
144
+
145
+ return x
146
+
147
+ def remove_weight_norm(self):
148
+ for l in self.convs1:
149
+ remove_weight_norm(l)
150
+ for l in self.convs2:
151
+ remove_weight_norm(l)
152
+
153
+
154
+ class AMPBlock2(torch.nn.Module):
155
+ def __init__(
156
+ self,
157
+ channels,
158
+ kernel_size=3,
159
+ dilation=(1, 3),
160
+ activation=None,
161
+ snake_logscale=True,
162
+ ):
163
+ super(AMPBlock2, self).__init__()
164
+
165
+ self.convs = nn.ModuleList(
166
+ [
167
+ weight_norm(
168
+ Conv1d(
169
+ channels,
170
+ channels,
171
+ kernel_size,
172
+ 1,
173
+ dilation=dilation[0],
174
+ padding=get_padding(kernel_size, dilation[0]),
175
+ )
176
+ ),
177
+ weight_norm(
178
+ Conv1d(
179
+ channels,
180
+ channels,
181
+ kernel_size,
182
+ 1,
183
+ dilation=dilation[1],
184
+ padding=get_padding(kernel_size, dilation[1]),
185
+ )
186
+ ),
187
+ ]
188
+ )
189
+ self.convs.apply(init_weights)
190
+
191
+ self.num_layers = len(self.convs) # total number of conv layers
192
+
193
+ Activation1d = TorchActivation1d
194
+
195
+ if (
196
+ activation == "snake"
197
+ ): # periodic nonlinearity with snake function and anti-aliasing
198
+ self.activations = nn.ModuleList(
199
+ [
200
+ Activation1d(
201
+ activation=Snake(channels, alpha_logscale=snake_logscale)
202
+ )
203
+ for _ in range(self.num_layers)
204
+ ]
205
+ )
206
+ elif (
207
+ activation == "snakebeta"
208
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
209
+ self.activations = nn.ModuleList(
210
+ [
211
+ Activation1d(
212
+ activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
213
+ )
214
+ for _ in range(self.num_layers)
215
+ ]
216
+ )
217
+ else:
218
+ raise NotImplementedError(
219
+ "activation incorrectly specified. check the config file and look for 'activation'."
220
+ )
221
+
222
+ def forward(self, x):
223
+ for c, a in zip(self.convs, self.activations):
224
+ xt = a(x)
225
+ xt = c(xt)
226
+ x = xt + x
227
+
228
+ return x
229
+
230
+ def remove_weight_norm(self):
231
+ for l in self.convs:
232
+ remove_weight_norm(l)
233
+
234
+
235
+ class BigVGAN(torch.nn.Module):
236
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
237
+ def __init__(
238
+ self,
239
+ num_mels: int,
240
+ upsample_initial_channel: int,
241
+ resblock_kernel_sizes: tp.List[int],
242
+ resblock_dilation_sizes: tp.List[tp.List[int]],
243
+ upsample_rates: tp.List[int],
244
+ upsample_kernel_sizes: tp.List[int],
245
+ resblock_type: str = "1",
246
+ snake_logscale: bool = True,
247
+ activation: str = "snakebeta",
248
+ use_tanh_at_final: bool = False,
249
+ use_bias_at_final: bool = False,
250
+ **kwargs,
251
+ ):
252
+ super(BigVGAN, self).__init__()
253
+
254
+ self.num_kernels = len(resblock_kernel_sizes)
255
+ self.num_upsamples = len(upsample_rates)
256
+
257
+ # pre conv
258
+ self.conv_pre = weight_norm(
259
+ Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3)
260
+ )
261
+
262
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
263
+ resblock = AMPBlock1 if resblock_type == "1" else AMPBlock2
264
+
265
+ # transposed conv-based upsamplers. does not apply anti-aliasing
266
+ self.ups = nn.ModuleList()
267
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
268
+ self.ups.append(
269
+ nn.ModuleList(
270
+ [
271
+ weight_norm(
272
+ ConvTranspose1d(
273
+ upsample_initial_channel // (2**i),
274
+ upsample_initial_channel // (2 ** (i + 1)),
275
+ k,
276
+ u,
277
+ padding=(k - u) // 2,
278
+ )
279
+ )
280
+ ]
281
+ )
282
+ )
283
+
284
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
285
+ self.resblocks = nn.ModuleList()
286
+ for i in range(len(self.ups)):
287
+ ch = upsample_initial_channel // (2 ** (i + 1))
288
+ for j, (k, d) in enumerate(
289
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
290
+ ):
291
+ self.resblocks.append(
292
+ resblock(
293
+ ch,
294
+ k,
295
+ d,
296
+ activation=activation,
297
+ snake_logscale=snake_logscale,
298
+ )
299
+ )
300
+
301
+ Activation1d = TorchActivation1d
302
+
303
+ # post conv
304
+ if (
305
+ activation == "snake"
306
+ ): # periodic nonlinearity with snake function and anti-aliasing
307
+ activation_post = Snake(ch, alpha_logscale=snake_logscale)
308
+ self.activation_post = Activation1d(activation=activation_post)
309
+ elif (
310
+ activation == "snakebeta"
311
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
312
+ activation_post = SnakeBeta(ch, alpha_logscale=snake_logscale)
313
+ self.activation_post = Activation1d(activation=activation_post)
314
+ else:
315
+ raise NotImplementedError(
316
+ "activation incorrectly specified. check the config file and look for 'activation'."
317
+ )
318
+
319
+ # whether to use bias for the final conv_post (defaults to False in this configuration)
320
+ self.use_bias_at_final = use_bias_at_final
321
+ self.conv_post = weight_norm(
322
+ Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
323
+ )
324
+
325
+ # weight initialization
326
+ for i in range(len(self.ups)):
327
+ self.ups[i].apply(init_weights)
328
+ self.conv_post.apply(init_weights)
329
+
330
+ # final tanh activation (defaults to False here; the output is clamped to [-1, 1] instead)
331
+ self.use_tanh_at_final = use_tanh_at_final
332
+
333
+ def forward(self, x):
334
+ # pre conv
335
+ x = self.conv_pre(x)
336
+
337
+ for i in range(self.num_upsamples):
338
+ # upsampling
339
+ for i_up in range(len(self.ups[i])):
340
+ x = self.ups[i][i_up](x)
341
+ # AMP blocks
342
+ xs = None
343
+ for j in range(self.num_kernels):
344
+ if xs is None:
345
+ xs = self.resblocks[i * self.num_kernels + j](x)
346
+ else:
347
+ xs += self.resblocks[i * self.num_kernels + j](x)
348
+ x = xs / self.num_kernels
349
+
350
+ # post conv
351
+ x = self.activation_post(x)
352
+ x = self.conv_post(x)
353
+ # final tanh activation
354
+ if self.use_tanh_at_final:
355
+ x = torch.tanh(x)
356
+ else:
357
+ x = torch.clamp(x, min=-1.0, max=1.0) # bound the output to [-1, 1]
358
+
359
+ return x
360
+
361
+ def remove_weight_norm(self):
362
+ print("Removing weight norm...")
363
+ for l in self.ups:
364
+ for l_i in l:
365
+ remove_weight_norm(l_i)
366
+ for l in self.resblocks:
367
+ l.remove_weight_norm()
368
+ remove_weight_norm(self.conv_pre)
369
+ remove_weight_norm(self.conv_post)
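
A shape-only sketch of running the vocoder. The hyperparameters below are illustrative stand-ins (the shipped values live in configs/config_24k.json, not reproduced here); the upsample rates are chosen so their product matches the 480-sample hop of the 24 kHz mel front end.

import torch
from fireredtts.modules.bigvgan import BigVGAN

vocoder = BigVGAN(
    num_mels=80,
    upsample_initial_channel=256,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 6, 10],           # hypothetical; product = 480
    upsample_kernel_sizes=[16, 12, 20],  # hypothetical; 2x the rates
    activation="snakebeta",
)
mel = torch.randn(1, 80, 100)            # (B, num_mels, frames)
with torch.no_grad():
    wav = vocoder(mel)                   # (B, 1, frames * prod(upsample_rates))
print(wav.shape)                         # torch.Size([1, 1, 48000])
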
fireredtts/modules/bigvgan/mel_spectrogram.py ADDED
@@ -0,0 +1,111 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchaudio
4
+ from librosa.filters import mel as librosa_mel_fn
5
+
6
+
7
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
8
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
9
+
10
+
11
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
12
+ return torch.log(torch.clamp(x, min=clip_val) * C)
13
+
14
+
15
+ def spectral_normalize_torch(magnitudes):
16
+ output = dynamic_range_compression_torch(magnitudes)
17
+ return output
18
+
19
+
20
+ mel_basis = {}
21
+ hann_window = {}
22
+
23
+
24
+ def mel_spectrogram(
25
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
26
+ ):
27
+ global mel_basis, hann_window # pylint: disable=global-statement
28
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
29
+ mel = librosa_mel_fn(
30
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
31
+ )
32
+ mel_basis[str(fmax) + "_" + str(y.device)] = (
33
+ torch.from_numpy(mel).float().to(y.device)
34
+ )
35
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
36
+
37
+ y = torch.nn.functional.pad(
38
+ y.unsqueeze(1),
39
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
40
+ mode="reflect",
41
+ )
42
+ y = y.squeeze(1)
43
+
44
+ spec = torch.view_as_real(
45
+ torch.stft(
46
+ y,
47
+ n_fft,
48
+ hop_length=hop_size,
49
+ win_length=win_size,
50
+ window=hann_window[str(y.device)],
51
+ center=center,
52
+ pad_mode="reflect",
53
+ normalized=False,
54
+ onesided=True,
55
+ return_complex=True,
56
+ )
57
+ )
58
+
59
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
60
+
61
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
62
+ spec = spectral_normalize_torch(spec)
63
+
64
+ return spec
65
+
66
+
67
+ class MelExtractor(object):
68
+ def __init__(
69
+ self,
70
+ num_mels: int = 80,
71
+ n_fft: int = 1920,
72
+ hop_size: int = 480,
73
+ win_size: int = 1920,
74
+ sampling_rate: int = 24000,
75
+ fmin: int = 0,
76
+ fmax: int = 8000,
77
+ center: bool = False,
78
+ ):
79
+ super().__init__()
80
+ self.num_mels = num_mels
81
+ self.n_fft = n_fft
82
+ self.hop_size = hop_size
83
+ self.win_size = win_size
84
+ self.sampling_rate = sampling_rate
85
+ self.fmin = fmin
86
+ self.fmax = fmax
87
+ self.center = center
88
+
89
+ def __call__(self, audio: torch.Tensor, audio_sr: int):
90
+ """Args:
91
+ audio(torch.Tensor): shape (1, t)
92
+ Returns:
93
+ mel(torch.Tensor): shape (1, num_mels, t')
94
+ """
95
+ if audio_sr != self.sampling_rate:
96
+ audio = torchaudio.functional.resample(
97
+ audio, orig_freq=audio_sr, new_freq=self.sampling_rate
98
+ )
99
+ audio_sr = self.sampling_rate
100
+ mel = mel_spectrogram(
101
+ audio,
102
+ self.n_fft,
103
+ self.num_mels,
104
+ self.sampling_rate,
105
+ self.hop_size,
106
+ self.win_size,
107
+ self.fmin,
108
+ self.fmax,
109
+ self.center,
110
+ ) # (1, num_mels, t)
111
+ return mel
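
Usage of the extractor above with its defaults (24 kHz, 80 mel bins, 480-sample hop); one second of random audio yields 50 mel frames:

import torch
from fireredtts.modules.bigvgan import MelExtractor

extractor = MelExtractor()               # defaults defined above
audio = torch.randn(1, 24000)            # (1, t), one second at 24 kHz
mel = extractor(audio, audio_sr=24000)   # resampled internally if audio_sr differs
print(mel.shape)                         # torch.Size([1, 80, 50])
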
fireredtts/modules/flowmatching/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ from .estimator_dit import DiT
2
+ from .upsample_encoder import UpsampleConformerEncoder
3
+ from .flow import CausalFmWithSpkCtx, DualEmbedding
4
+
5
+
6
+ class FlowToken2Mel(CausalFmWithSpkCtx):
7
+ def __init__(self, config):
8
+ token_emb = DualEmbedding(**config['token_emb'])
9
+ encoder = UpsampleConformerEncoder(**config['encoder'])
10
+ estimator = DiT(**config['estimator'])
11
+ super().__init__(
12
+ spk_channels=config['spk_channels'],
13
+ spk_enc_channels=config['spk_enc_channels'],
14
+ infer_cfg_rate=config['infer_cfg_rate'],
15
+ token_emb=token_emb,
16
+ encoder=encoder,
17
+ estimator=estimator,
18
+ )
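
The constructor above reads a nested config; the sketch below only illustrates the expected key layout with hypothetical placeholder values (the real ones ship in configs/config_24k_flow.json). Note that, given how flow.py packs the condition, the estimator's in_channels must equal out_channels + 3 * spk_enc_channels.

# Hypothetical key layout for FlowToken2Mel(config); values are placeholders, not the shipped config.
config = {
    "spk_channels": 192,           # x-vector dimension
    "spk_enc_channels": 80,        # shared projection size for speaker, encoder and mel context
    "infer_cfg_rate": 0.7,
    "token_emb": {"channels": 512},
    "encoder": {"input_size": 512, "output_size": 512},
    "estimator": {
        "in_channels": 320,        # out_channels + 3 * spk_enc_channels
        "out_channels": 80,        # number of mel bins
        "hidden_size": 256,
    },
}
# FlowToken2Mel(config) would then assemble the token embedding, encoder and DiT estimator.
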
fireredtts/modules/flowmatching/estimator_dit.py ADDED
@@ -0,0 +1,356 @@
1
+ import math
2
+ import torch
3
+ from typing import Optional
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class MLP(torch.nn.Module):
9
+ def __init__(
10
+ self,
11
+ in_features:int,
12
+ hidden_features:Optional[int]=None,
13
+ out_features:Optional[int]=None,
14
+ act_layer=nn.GELU,
15
+ norm_layer=None,
16
+ bias=True,
17
+ drop=0.,
18
+ ):
19
+ super().__init__()
20
+ hidden_features = hidden_features or in_features
21
+ out_features = out_features or in_features
22
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
23
+ self.act = act_layer()
24
+ self.drop1 = nn.Dropout(drop)
25
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
26
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
27
+ self.drop2 = nn.Dropout(drop)
28
+
29
+ def forward(self, x):
30
+ x = self.fc1(x)
31
+ x = self.act(x)
32
+ x = self.drop1(x)
33
+ x = self.norm(x)
34
+ x = self.fc2(x)
35
+ x = self.drop2(x)
36
+ return x
37
+
38
+
39
+ class Attention(torch.nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ num_heads: int = 8,
44
+ head_dim: int = 64,
45
+ qkv_bias: bool = False,
46
+ qk_norm: bool = False,
47
+ attn_drop: float = 0.,
48
+ proj_drop: float = 0.,
49
+ norm_layer: nn.Module = nn.LayerNorm,
50
+ ) -> None:
51
+ super().__init__()
52
+ self.num_heads = num_heads
53
+ self.head_dim = head_dim
54
+ self.inner_dim = num_heads * head_dim
55
+ self.scale = head_dim ** -0.5
56
+
57
+ self.to_q = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
58
+ self.to_k = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
59
+ self.to_v = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
60
+
61
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
62
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
63
+
64
+ self.attn_drop = nn.Dropout(attn_drop)
65
+ self.proj_drop = nn.Dropout(proj_drop)
66
+
67
+ self.proj = nn.Linear(self.inner_dim, dim)
68
+
69
+ def to_heads(self, ts:torch.Tensor):
70
+ b, t, c = ts.shape
71
+ # (b, t, nh, c)
72
+ ts = ts.reshape(b, t, self.num_heads, c // self.num_heads)
73
+ ts = ts.transpose(1, 2)
74
+ return ts
75
+
76
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
77
+ """Args:
78
+ x(torch.Tensor): shape (b, t, c)
79
+ attn_mask(torch.Tensor): shape (b, t, t)
80
+ """
81
+ b, t, c = x.shape
82
+
83
+ q = self.to_q(x)
84
+ k = self.to_k(x)
85
+ v = self.to_v(x)
86
+
87
+ q = self.to_heads(q) # (b, nh, t, c)
88
+ k = self.to_heads(k)
89
+ v = self.to_heads(v)
90
+
91
+ q = self.q_norm(q)
92
+ k = self.k_norm(k)
93
+
94
+ if attn_mask is not None:
95
+ attn_mask = attn_mask.unsqueeze(1)
96
+
97
+ x = F.scaled_dot_product_attention(
98
+ q, k, v,
99
+ attn_mask=attn_mask,
100
+ dropout_p=self.attn_drop.p if self.training else 0.,
101
+ ) # (b, nh, t, c)
102
+ x = x.transpose(1, 2).reshape(b, t, -1)
103
+ x = self.proj(x)
104
+ x = self.proj_drop(x)
105
+ return x
106
+
107
+
108
+ def modulate(x, shift, scale):
109
+ return x * (1 + scale) + shift
110
+
111
+
112
+ class TimestepEmbedder(nn.Module):
113
+ """
114
+ Embeds scalar timesteps into vector representations.
115
+ """
116
+ def __init__(self, hidden_size, frequency_embedding_size=256):
117
+ super().__init__()
118
+ self.mlp = nn.Sequential(
119
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
120
+ nn.SiLU(),
121
+ nn.Linear(hidden_size, hidden_size, bias=True),
122
+ )
123
+ self.frequency_embedding_size = frequency_embedding_size
124
+ # from SinusoidalPosEmb
125
+ self.scale = 1000
126
+
127
+ @staticmethod
128
+ def timestep_embedding(t, dim, max_period=10000):
129
+ """
130
+ Create sinusoidal timestep embeddings.
131
+ :param t: a 1-D Tensor of N indices, one per batch element.
132
+ These may be fractional.
133
+ :param dim: the dimension of the output.
134
+ :param max_period: controls the minimum frequency of the embeddings.
135
+ :return: an (N, D) Tensor of positional embeddings.
136
+ """
137
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
138
+ half = dim // 2
139
+ freqs = torch.exp(
140
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
141
+ ).to(device=t.device)
142
+ args = t[:, None].float() * freqs[None]
143
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
144
+ if dim % 2:
145
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
146
+ return embedding
147
+
148
+ def forward(self, t):
149
+ t_freq = self.timestep_embedding(t * self.scale, self.frequency_embedding_size)
150
+ t_emb = self.mlp(t_freq)
151
+ return t_emb
152
+
153
+
154
+ # Convolution related
155
+ class Transpose(torch.nn.Module):
156
+ def __init__(self, dim0: int, dim1: int):
157
+ super().__init__()
158
+ self.dim0 = dim0
159
+ self.dim1 = dim1
160
+
161
+ def forward(self, x: torch.Tensor):
162
+ x = torch.transpose(x, self.dim0, self.dim1)
163
+ return x
164
+
165
+
166
+ class CausalConv1d(torch.nn.Conv1d):
167
+ def __init__(
168
+ self,
169
+ in_channels: int,
170
+ out_channels: int,
171
+ kernel_size: int,
172
+ ) -> None:
173
+ super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size)
174
+ self.causal_padding = (kernel_size - 1, 0)
175
+
176
+ def forward(self, x: torch.Tensor):
177
+ x = F.pad(x, self.causal_padding)
178
+ x = super(CausalConv1d, self).forward(x)
179
+ return x
180
+
181
+
182
+ class CausalConvBlock(nn.Module):
183
+ def __init__(self,
184
+ in_channels: int,
185
+ out_channels: int,
186
+ kernel_size: int = 3,
187
+ ):
188
+ super().__init__()
189
+ self.in_channels = in_channels
190
+ self.out_channels = out_channels
191
+ self.kernel_size = kernel_size
192
+
193
+ self.block = torch.nn.Sequential(
194
+ # norm
195
+ # conv1
196
+ Transpose(1, 2),
197
+ CausalConv1d(in_channels, out_channels, kernel_size),
198
+ Transpose(1, 2),
199
+ # norm & act
200
+ nn.LayerNorm(out_channels),
201
+ nn.Mish(),
202
+ # conv2
203
+ Transpose(1, 2),
204
+ CausalConv1d(out_channels, out_channels, kernel_size),
205
+ Transpose(1, 2),
206
+ )
207
+
208
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
209
+ """
210
+ Args:
211
+ x: shape (b, t, c)
212
+ mask: shape (b, t, 1)
213
+ """
214
+ if mask is not None: x = x * mask
215
+ x = self.block(x)
216
+ if mask is not None: x = x * mask
217
+ return x
218
+
219
+
220
+ class DiTBlock(nn.Module):
221
+ """
222
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
223
+ """
224
+ def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0, **block_kwargs):
225
+ super().__init__()
226
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
227
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=True, **block_kwargs)
228
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
229
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
230
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
231
+ self.mlp = MLP(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
232
+ self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
233
+ self.conv = CausalConvBlock(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3)
234
+ self.adaLN_modulation = nn.Sequential(
235
+ nn.SiLU(),
236
+ nn.Linear(hidden_size, 9 * hidden_size, bias=True)
237
+ )
238
+
239
+ def forward(self, x:torch.Tensor, c:torch.Tensor, attn_mask:torch.Tensor=None, conv_mask:torch.Tensor=None):
240
+ """Args
241
+ x: shape (b, t, c)
242
+ c: shape (b, 1, c)
243
+ attn_mask: shape (b, t, t), bool type attention mask
244
+ conv_mask: shape (b, 1, t), bool type non-pad mask
245
+ """
246
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
247
+ = self.adaLN_modulation(c).chunk(9, dim=-1)
248
+ # attention
249
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask)
250
+ # conv
251
+ x = x + gate_conv * self.conv(modulate(self.norm3(x), shift_conv, scale_conv), mask=conv_mask)
252
+ # mlp
253
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
254
+ return x
255
+
256
+
257
+ class FinalLayer(nn.Module):
258
+ """
259
+ The final layer of DiT.
260
+ """
261
+ def __init__(self, hidden_size, out_channels):
262
+ super().__init__()
263
+ self.adaLN_modulation = nn.Sequential(
264
+ nn.SiLU(),
265
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
266
+ )
267
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
268
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
269
+
270
+ def forward(self, x, c):
271
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
272
+ x = modulate(self.norm_final(x), shift, scale)
273
+ x = self.linear(x)
274
+ return x
275
+
276
+
277
+ class DiT(nn.Module):
278
+ """
279
+ Diffusion model with a Transformer backbone.
280
+ """
281
+ def __init__(
282
+ self,
283
+ in_channels: int,
284
+ out_channels: int,
285
+ mlp_ratio: float = 4.0,
286
+ depth: int = 28,
287
+ num_heads: int = 8,
288
+ head_dim: int = 64,
289
+ hidden_size: int = 256,
290
+ ):
291
+ super().__init__()
292
+ self.in_channels = in_channels
293
+ self.out_channels = out_channels
294
+ self.t_embedder = TimestepEmbedder(hidden_size)
295
+
296
+ self.in_proj = nn.Linear(in_channels, hidden_size)
297
+
298
+ self.blocks = nn.ModuleList([
299
+ DiTBlock(hidden_size, num_heads, head_dim, mlp_ratio=mlp_ratio) for _ in range(depth)
300
+ ])
301
+ self.final_layer = FinalLayer(hidden_size, self.out_channels)
302
+ self.initialize_weights()
303
+
304
+ def initialize_weights(self):
305
+ # Initialize transformer layers:
306
+ def _basic_init(module):
307
+ if isinstance(module, nn.Linear):
308
+ torch.nn.init.xavier_uniform_(module.weight)
309
+ if module.bias is not None:
310
+ nn.init.constant_(module.bias, 0)
311
+ self.apply(_basic_init)
312
+
313
+ # Initialize timestep embedding MLP:
314
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
315
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
316
+
317
+ # Zero-out adaLN modulation layers in DiT blocks:
318
+ for block in self.blocks:
319
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
320
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
321
+
322
+ # Zero-out output layers:
323
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
324
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
325
+ nn.init.constant_(self.final_layer.linear.weight, 0)
326
+ nn.init.constant_(self.final_layer.linear.bias, 0)
327
+
328
+ """For non-streaming inference.
329
+ """
330
+ def forward(self, x:torch.Tensor, c:torch.Tensor, t:torch.Tensor, attn_mask:torch.Tensor=None, conv_mask:torch.Tensor=None):
331
+ """
332
+ Args:
333
+ x: shape (b, c, t)
334
+ c: aux condition, shape (b, c, t)
335
+ t: shape (b,)
336
+ attn_mask: (b, t, t)
337
+ conv_mask: (b, 1, t)
338
+ Returns:
339
+ pred: shape (b, c, t)
340
+ """
341
+ # time
342
+ t = self.t_embedder(t.view(-1)).unsqueeze(1) # (b, 1, c)
343
+
344
+ # CausalConvBlock mask is (b, t, 1)
345
+ conv_mask = conv_mask if conv_mask is None else conv_mask.transpose(1, 2)
346
+
347
+ x = torch.cat([x, c], dim=1)
348
+ # forward blocks
349
+ x = x.transpose(1, 2)
350
+ x = self.in_proj(x)
351
+ for block in self.blocks:
352
+ x = block(x, t, attn_mask=attn_mask, conv_mask=conv_mask)
353
+ x = self.final_layer(x, t)
354
+ x = x.transpose(1, 2)
355
+ return x
356
+
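
A shape check for the estimator above with small hypothetical dimensions; in_channels is the concatenation of the noisy mel and the packed condition along the channel axis:

import torch
from fireredtts.modules.flowmatching.estimator_dit import DiT

dit = DiT(in_channels=320, out_channels=80, depth=2,
          num_heads=4, head_dim=32, hidden_size=128)   # hypothetical sizes
x = torch.randn(2, 80, 50)      # noisy mel, (b, c, t)
c = torch.randn(2, 240, 50)     # packed condition, (b, c, t)
t = torch.rand(2)               # flow-matching time in [0, 1]
with torch.no_grad():
    v = dit(x, c, t)            # predicted velocity field
print(v.shape)                  # torch.Size([2, 80, 50])
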
fireredtts/modules/flowmatching/flow.py ADDED
@@ -0,0 +1,138 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Dict, List
4
+ from einops import pack, repeat
5
+ from .estimator_dit import DiT
6
+ from .upsample_encoder import UpsampleConformerEncoder
7
+
8
+
9
+ class DualEmbedding(torch.nn.Module):
10
+ def __init__(
11
+ self,
12
+ channels:int=512,
13
+ ):
14
+ super().__init__()
15
+ self.codebook_size = 128
16
+ self.codebook_dim = 128
17
+ self.codebook = torch.nn.ModuleList([
18
+ torch.nn.Embedding(self.codebook_size, self.codebook_dim),
19
+ torch.nn.Embedding(self.codebook_size, self.codebook_dim),
20
+ ])
21
+ self.out_proj = torch.nn.Linear(self.codebook_dim * 2, channels)
22
+
23
+ def forward(self, tokens):
24
+ """
25
+ Args:
26
+ tokens: shape (b, t)
27
+ Returns:
28
+ token_embs: shape (b, t, c)
29
+ """
30
+ token_embs = torch.cat([
31
+ self.codebook[0](tokens % self.codebook_size),
32
+ self.codebook[1](tokens // self.codebook_size)
33
+ ], dim=-1)
34
+ token_embs = self.out_proj(token_embs)
35
+ return token_embs
36
+
37
+
38
+ class CausalFmWithSpkCtx(torch.nn.Module):
39
+ def __init__(
40
+ self,
41
+ # Basic in-out
42
+ spk_channels: int,
43
+ spk_enc_channels: int, # out channels of spk & encoder projection
44
+ # Module
45
+ token_emb: DualEmbedding,
46
+ encoder: UpsampleConformerEncoder,
47
+ estimator: DiT,
48
+ # Flow cfg
49
+ infer_cfg_rate: float = 0.7,
50
+ ):
51
+ super().__init__()
52
+ # Variants
53
+ self.up_stride = encoder.up_stride
54
+ self.infer_cfg_rate = infer_cfg_rate
55
+ # Module
56
+ self.spk_proj = torch.nn.Linear(spk_channels, spk_enc_channels)
57
+ self.token_emb = token_emb
58
+ self.encoder = encoder
59
+ self.encoder_proj = torch.nn.Linear(encoder.output_size, spk_enc_channels)
60
+ self.estimator = estimator
61
+ # Initial noise, maximum of 600s
62
+ self.register_buffer(
63
+ "x0",
64
+ torch.randn([1, self.estimator.out_channels, 50 * 600]),
65
+ persistent=False,
66
+ )
67
+
68
+ def _euler(
69
+ self,
70
+ x0: torch.Tensor,
71
+ c: torch.Tensor,
72
+ n_timesteps: int = 10,
73
+ ):
74
+ # time steps
75
+ t_span = torch.linspace(0, 1, n_timesteps + 1).to(x0)
76
+ # cosine time scheduling
77
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
78
+ # euler solver
79
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
80
+ t = t.unsqueeze(dim=0)
81
+
82
+ xt = x0
83
+ for step in range(1, len(t_span)):
84
+ # pack input
85
+ x_in = torch.cat([xt, xt], dim=0)
86
+ c_in = torch.cat([c, torch.zeros_like(c)], dim=0)
87
+ t_in = torch.cat([t, t], dim=0)
88
+
89
+ # model call
90
+ with torch.no_grad():
91
+ vt = self.estimator.forward(x_in, c_in, t_in)
92
+ # cfg
93
+ vt_cond, vt_cfg = vt.chunk(2, dim=0)
94
+ vt = (1.0 + self.infer_cfg_rate) * vt_cond - self.infer_cfg_rate * vt_cfg
95
+
96
+ xt = xt + dt * vt
97
+ t = t + dt
98
+ if step < len(t_span) - 1:
99
+ dt = t_span[step + 1] - t
100
+ return xt
101
+
102
+ def inference(
103
+ self,
104
+ prompt_token: torch.Tensor,
105
+ prompt_xvec: torch.Tensor,
106
+ prompt_feat: torch.Tensor,
107
+ token: torch.Tensor,
108
+ ):
109
+ # NOTE align prompt_token, prompt_feat in advance
110
+
111
+ # Spk condition
112
+ embedding = F.normalize(prompt_xvec, dim=1)
113
+ spks = self.spk_proj(embedding)
114
+
115
+ # Token condition
116
+ token = torch.concat([prompt_token, token], dim=1)
117
+ xs = self.token_emb(token)
118
+
119
+ xs_lens = torch.tensor([xs.shape[1]]).to(token)
120
+ xs = self.encoder(xs, xs_lens)
121
+ mu = self.encoder_proj(xs)
122
+
123
+ # Mel context
124
+ ctx = torch.zeros_like(mu)
125
+ ctx[:, : prompt_feat.shape[1]] = prompt_feat
126
+
127
+ # Compose condition
128
+ cond = mu.transpose(1, 2)
129
+ ctx = ctx.transpose(1, 2)
130
+ spks = repeat(spks, "b c -> b c t", t=cond.shape[-1])
131
+ cond = pack([cond, spks, ctx], "b * t")[0]
132
+
133
+ # FM inference
134
+ x0 = self.x0[..., : mu.shape[1]]
135
+ x1 = self._euler(x0, cond, n_timesteps=10)
136
+
137
+ feat = x1.transpose(1, 2)[:, prompt_feat.shape[1] :]
138
+ return feat
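
The dual embedding above factorizes each acoustic token id over two 128-entry codebooks (id % 128 and id // 128) before projecting to the encoder width; a quick check:

import torch
from fireredtts.modules.flowmatching.flow import DualEmbedding

emb = DualEmbedding(channels=512)
tokens = torch.randint(0, 128 * 128, (1, 25))   # (b, t), ids in [0, 16384)
print(emb(tokens).shape)                        # torch.Size([1, 25, 512])
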
fireredtts/modules/flowmatching/upsample_encoder.py ADDED
@@ -0,0 +1,617 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from typing import Tuple, List, Union
6
+
7
+
8
+ """Attention modules.
9
+ """
10
+ class MultiHeadedAttention(nn.Module):
11
+ def __init__(self,
12
+ n_head: int,
13
+ n_feat: int,
14
+ dropout_rate: float,
15
+ key_bias: bool = True):
16
+ super().__init__()
17
+ assert n_feat % n_head == 0
18
+ # We assume d_v always equals d_k
19
+ self.d_k = n_feat // n_head
20
+ self.h = n_head
21
+ self.linear_q = nn.Linear(n_feat, n_feat)
22
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
23
+ self.linear_v = nn.Linear(n_feat, n_feat)
24
+ self.linear_out = nn.Linear(n_feat, n_feat)
25
+ self.dropout = nn.Dropout(p=dropout_rate)
26
+
27
+ def forward_qkv(self,
28
+ query: torch.Tensor,
29
+ key: torch.Tensor,
30
+ value: torch.Tensor):
31
+ """
32
+ Args:
33
+ query,key,value: shape (b, t, c)
34
+ Returns:
35
+ query,key,value: shape (b, nh, t, c//nh)
36
+ """
37
+ n_batch = query.size(0)
38
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
39
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
40
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
41
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
42
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
43
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
44
+ return q, k, v
45
+
46
+ def forward_attention(self,
47
+ value: torch.Tensor,
48
+ scores: torch.Tensor,
49
+ mask: torch.Tensor = None):
50
+ """Compute attention context vector.
51
+ Args:
52
+ value (torch.Tensor): shape: (b, nh, t2, c//nh).
53
+ scores (torch.Tensor): shape: (b, nh, t1, t2).
54
+ mask (torch.Tensor): attention padded mask, size (b, 1, t2) or (b, t1, t2)
55
+ Returns:
56
+ shape: (b, t1, c)
57
+ """
58
+ b = value.size(0)
59
+ if mask is not None:
60
+ mask = mask.unsqueeze(1).eq(0)
61
+ scores = scores.masked_fill(mask, -float('inf'))
62
+ attn = scores.softmax(dim=-1).masked_fill(mask, 0.0)
63
+ else:
64
+ attn = scores.softmax(dim=-1)
65
+ p_attn = self.dropout(attn)
66
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
67
+ x = x.transpose(1, 2).contiguous().view(b, -1, self.h * self.d_k)
68
+ return self.linear_out(x)
69
+
70
+
71
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
72
+ def __init__(self,
73
+ n_head: int,
74
+ n_feat: int,
75
+ dropout_rate: float,
76
+ key_bias: bool = True):
77
+ """Multi-Head Attention layer with relative position encoding.
78
+ Paper: https://arxiv.org/abs/1901.02860
79
+ Args:
80
+ n_head (int): The number of heads.
81
+ n_feat (int): The number of features.
82
+ dropout_rate (float): Dropout rate.
83
+ """
84
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
85
+ # linear transformation for positional encoding
86
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
87
+ # these two learnable bias are used in matrix c and matrix d
88
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
89
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
90
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
91
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
92
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
93
+
94
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
95
+ """Compute relative positional encoding.
96
+
97
+ Args:
98
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
99
+ time1 means the length of query vector.
100
+
101
+ Returns:
102
+ torch.Tensor: Output tensor.
103
+
104
+ """
105
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
106
+ device=x.device,
107
+ dtype=x.dtype)
108
+ x_padded = torch.cat([zero_pad, x], dim=-1)
109
+
110
+ x_padded = x_padded.view(x.size()[0],
111
+ x.size()[1],
112
+ x.size(3) + 1, x.size(2))
113
+ x = x_padded[:, :, 1:].view_as(x)[
114
+ :, :, :, : x.size(-1) // 2 + 1
115
+ ] # only keep the positions from 0 to time2
116
+ return x
117
+
118
+ def forward(
119
+ self,
120
+ query: torch.Tensor,
121
+ key: torch.Tensor,
122
+ value: torch.Tensor,
123
+ pos_emb: torch.Tensor,
124
+ mask: torch.Tensor = None,
125
+ cache: torch.Tensor = None,
126
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
127
+ """
128
+ Args:
129
+ query (torch.Tensor): shape (b, t1, c).
130
+ key (torch.Tensor): shape (b, t2, c).
131
+ value (torch.Tensor): shape (b, t2, c).
132
+ mask (torch.Tensor): attention padded mask, shape (b, 1, t2) or (b, t1, t2).
133
+ pos_emb (torch.Tensor): Positional embedding tensor (b, 2*t1-1, c).
134
+ cache (torch.Tensor): Cache tensor (1, nh, cache_t, d_k * 2).
135
+ Returns:
136
+ torch.Tensor: Output tensor (b, t1, d_model).
137
+ torch.Tensor: Cache tensor (1, nh, cache_t + t1, d_k * 2)
138
+ """
139
+ q, k, v = self.forward_qkv(query, key, value)
140
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
141
+
142
+ if cache is not None and cache.size(0) > 0:
143
+ key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
144
+ k = torch.cat([key_cache, k], dim=2)
145
+ v = torch.cat([value_cache, v], dim=2)
146
+ new_cache = torch.cat((k, v), dim=-1)
147
+
148
+ n_batch_pos = pos_emb.size(0)
149
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) # (batch, 2*time1-1, head, d_k)
150
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
151
+
152
+ # (batch, head, time1, d_k)
153
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
154
+ # (batch, head, time1, d_k)
155
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
156
+
157
+ # compute attention score
158
+ # first compute matrix a and matrix c
159
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
160
+ # (batch, head, time1, time2)
161
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
162
+
163
+ # compute matrix b and matrix d
164
+ # matrix_bd: (batch, head, time1, 2*time1-1)
165
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
166
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
167
+ if matrix_ac.shape != matrix_bd.shape:
168
+ matrix_bd = self.rel_shift(matrix_bd)
169
+
170
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
171
+
172
+ return self.forward_attention(v, scores, mask), new_cache
173
+
174
+
175
+ class EspnetRelPositionalEncoding(torch.nn.Module):
176
+ """Relative positional encoding module (new implementation).
177
+
178
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
179
+
180
+ See : Appendix B in https://arxiv.org/abs/1901.02860
181
+
182
+ Args:
183
+ d_model (int): Embedding dimension.
184
+ dropout_rate (float): Dropout rate.
185
+ max_len (int): Maximum input length.
186
+
187
+ """
188
+
189
+ def __init__(self, d_model: int, dropout_rate: float=0.0, max_len: int = 5000):
190
+ """Construct an PositionalEncoding object."""
191
+ super(EspnetRelPositionalEncoding, self).__init__()
192
+ self.d_model = d_model
193
+ self.xscale = math.sqrt(self.d_model)
194
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
195
+ self.pe = None
196
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
197
+
198
+ def extend_pe(self, x: torch.Tensor):
199
+ """Reset the positional encodings."""
200
+ if self.pe is not None:
201
+ # self.pe contains both positive and negative parts
202
+ # the length of self.pe is 2 * input_len - 1
203
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
204
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
205
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
206
+ return
207
+ # Suppose `i` denotes the position of the query vector and `j` denotes the
208
+ # position of the key vector. We use positive relative positions when keys
209
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
210
+ pe_positive = torch.zeros(x.size(1), self.d_model)
211
+ pe_negative = torch.zeros(x.size(1), self.d_model)
212
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
213
+ div_term = torch.exp(
214
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
215
+ * -(math.log(10000.0) / self.d_model)
216
+ )
217
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
218
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
219
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
220
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
221
+
222
+ # Reverse the order of positive indices and concat both positive and
223
+ # negative indices. This is used to support the shifting trick
224
+ # as in https://arxiv.org/abs/1901.02860
225
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
226
+ pe_negative = pe_negative[1:].unsqueeze(0)
227
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
228
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
229
+
230
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
231
+ -> Tuple[torch.Tensor, torch.Tensor]:
232
+ """Add positional encoding.
233
+
234
+ Args:
235
+ x (torch.Tensor): Input tensor (batch, time, `*`).
236
+
237
+ Returns:
238
+ torch.Tensor: Encoded tensor (batch, time, `*`).
239
+
240
+ """
241
+ self.extend_pe(x)
242
+ x = x * self.xscale
243
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
244
+ return self.dropout(x), self.dropout(pos_emb)
245
+
246
+ def position_encoding(self,
247
+ offset: Union[int, torch.Tensor],
248
+ size: int) -> torch.Tensor:
249
+ """ For getting encoding in a streaming fashion
250
+
251
+ Attention!!!!!
252
+ we apply dropout only once at the whole utterance level in a
253
+ non-streaming way, but will call this function several times with
254
+ increasing input size in a streaming scenario, so the dropout will
255
+ be applied several times.
256
+
257
+ Args:
258
+ offset (int or torch.tensor): start offset
259
+ size (int): required size of position encoding
260
+
261
+ Returns:
262
+ torch.Tensor: Corresponding encoding
263
+ """
264
+ pos_emb = self.pe[
265
+ :,
266
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
267
+ ]
268
+ return pos_emb
269
+
270
+
271
+ """Other modules.
272
+ """
273
+ class Upsample1D(nn.Module):
274
+ """A 1D upsampling layer with an optional convolution.
275
+
276
+ Parameters:
277
+ channels (`int`):
278
+ number of channels in the inputs and outputs.
279
+ use_conv (`bool`, default `False`):
280
+ option to use a convolution.
281
+ use_conv_transpose (`bool`, default `False`):
282
+ option to use a convolution transpose.
283
+ out_channels (`int`, optional):
284
+ number of output channels. Defaults to `channels`.
285
+ """
286
+
287
+ def __init__(self, channels: int, out_channels: int, stride: int = 2):
288
+ super().__init__()
289
+ self.channels = channels
290
+ self.out_channels = out_channels
291
+ self.stride = stride
292
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
293
+
294
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
295
+ outputs = F.interpolate(inputs, scale_factor=self.stride, mode="nearest")
296
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
297
+ outputs = self.conv(outputs)
298
+ return outputs, input_lengths * self.stride
299
+
300
+
301
+ class PreLookaheadLayer(nn.Module):
302
+ def __init__(self, channels: int, pre_lookahead_len: int = 1):
303
+ super().__init__()
304
+ self.channels = channels
305
+ self.pre_lookahead_len = pre_lookahead_len
306
+ self.conv1 = nn.Conv1d(
307
+ channels, channels,
308
+ kernel_size=pre_lookahead_len + 1,
309
+ stride=1, padding=0,
310
+ )
311
+ self.conv2 = nn.Conv1d(
312
+ channels, channels,
313
+ kernel_size=3, stride=1, padding=0,
314
+ )
315
+
316
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
317
+ """
318
+ inputs: (batch_size, seq_len, channels)
319
+ """
320
+ outputs = inputs.transpose(1, 2).contiguous()
321
+ # look ahead
322
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
323
+ outputs = F.leaky_relu(self.conv1(outputs))
324
+ # outputs
325
+ outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
326
+ outputs = self.conv2(outputs)
327
+ outputs = outputs.transpose(1, 2).contiguous()
328
+
329
+ # residual connection
330
+ outputs = outputs + inputs
331
+ return outputs
332
+
333
+
334
+ class PositionwiseFeedForward(torch.nn.Module):
335
+ """Positionwise feed forward layer.
336
+
337
+ The feed-forward layer is applied at each position of the sequence.
338
+ The output dim is the same as the input dim.
339
+
340
+ Args:
341
+ idim (int): Input dimension.
342
+ hidden_units (int): The number of hidden units.
343
+ dropout_rate (float): Dropout rate.
344
+ activation (torch.nn.Module): Activation function
345
+ """
346
+
347
+ def __init__(
348
+ self,
349
+ idim: int,
350
+ hidden_units: int,
351
+ dropout_rate: float,
352
+ activation: torch.nn.Module = torch.nn.ReLU(),
353
+ ):
354
+ """Construct a PositionwiseFeedForward object."""
355
+ super(PositionwiseFeedForward, self).__init__()
356
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
357
+ self.activation = activation
358
+ self.dropout = torch.nn.Dropout(dropout_rate)
359
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
360
+
361
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
362
+ """Forward function.
363
+
364
+ Args:
365
+ xs: input tensor (B, L, D)
366
+ Returns:
367
+ output tensor, (B, L, D)
368
+ """
369
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
370
+
371
+
372
+ class LinearNoSubsampling(torch.nn.Module):
373
+ """Linear transform the input without subsampling
374
+ Args:
375
+ idim (int): Input dimension.
376
+ odim (int): Output dimension.
377
+ dropout_rate (float): Dropout rate.
378
+ """
379
+ def __init__(self,
380
+ idim: int,
381
+ odim: int,
382
+ dropout_rate: float,
383
+ pos_enc_class: torch.nn.Module
384
+ ):
385
+ """Construct an linear object."""
386
+ super().__init__()
387
+ self.out = torch.nn.Sequential(
388
+ torch.nn.Linear(idim, odim),
389
+ torch.nn.LayerNorm(odim, eps=1e-5),
390
+ torch.nn.Dropout(dropout_rate),
391
+ )
392
+ self.pos_enc = pos_enc_class
393
+
394
+ def forward(
395
+ self,
396
+ x: torch.Tensor,
397
+ offset: int = 0
398
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
399
+ """Input x.
400
+ Args:
401
+ x (torch.Tensor): Input tensor (#batch, time, idim).
402
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
403
+
404
+ Returns:
405
+ torch.Tensor: linear input tensor (#batch, time', odim),
406
+ where time' = time .
407
+ torch.Tensor: linear input mask (#batch, 1, time'),
408
+ where time' = time .
409
+ """
410
+ x = self.out(x)
411
+ x, pos_emb = self.pos_enc(x, offset)
412
+ return x, pos_emb
413
+
414
+
415
+ """Encoder layer & encoder
416
+ """
417
+ class ConformerEncoderLayer(nn.Module):
418
+ """Encoder layer module.
419
+ Args:
420
+ size (int): Input dimension.
421
+ self_attn (torch.nn.Module): Self-attention module instance.
422
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
423
+ instance can be used as the argument.
424
+ feed_forward (torch.nn.Module): Feed-forward module instance.
425
+ `PositionwiseFeedForward` instance can be used as the argument.
426
+ dropout_rate (float): Dropout rate.
427
+ normalize_before (bool):
428
+ True: use layer_norm before each sub-block.
429
+ False: use layer_norm after each sub-block.
430
+ """
431
+
432
+ def __init__(
433
+ self,
434
+ size: int,
435
+ self_attn: torch.nn.Module,
436
+ feed_forward: torch.nn.Module,
437
+ dropout_rate: float = 0.1,
438
+ normalize_before: bool = True,
439
+ ):
440
+ """Construct an EncoderLayer object."""
441
+ super().__init__()
442
+ self.self_attn = self_attn
443
+ self.feed_forward = feed_forward
444
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
445
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
446
+ self.ff_scale = 1.0
447
+ self.dropout = nn.Dropout(dropout_rate)
448
+ self.size = size
449
+ self.normalize_before = normalize_before
450
+
451
+ def forward(
452
+ self,
453
+ x: torch.Tensor,
454
+ mask: torch.Tensor,
455
+ pos_emb: torch.Tensor,
456
+ att_cache: torch.Tensor = None,
457
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
458
+ """
459
+ Args:
460
+ x: shape (b, t, c)
461
+ mask: self-attention padded mask, shape (b, 1, t) or (b, t, t)
462
+ pos_emb: relative positional embedding, shape (b, t, 2t-1)
463
+ att_cache: shape (1, nh, cache_t, d_k * 2)
464
+ """
465
+ # multi-headed self-attention module
466
+ residual = x
467
+ if self.normalize_before:
468
+ x = self.norm_mha(x)
469
+ # att_cache: (b, head, cache_t, d_k*2)
470
+ x_att, new_att_cache = self.self_attn(x, x, x, pos_emb, mask, att_cache)
471
+ x = residual + self.dropout(x_att)
472
+ if not self.normalize_before:
473
+ x = self.norm_mha(x)
474
+
475
+ # feed forward module
476
+ residual = x
477
+ if self.normalize_before:
478
+ x = self.norm_ff(x)
479
+ x_ffn = self.feed_forward(x)
480
+ x = residual + self.ff_scale * self.dropout(x_ffn)
481
+ if not self.normalize_before:
482
+ x = self.norm_ff(x)
483
+
484
+ return x, new_att_cache
485
+
486
+
487
+ class UpsampleConformerEncoder(torch.nn.Module):
488
+
489
+ def __init__(
490
+ self,
491
+ # Common
492
+ input_size: int = 512,
493
+ output_size: int = 512,
494
+ num_blocks: int = 6,
495
+ num_up_blocks: int = 4,
496
+ normalize_before: bool = True,
497
+ # Input & upsampling
498
+ up_stride: int = 2,
499
+ pre_lookahead_len: int = 3,
500
+ # Attention
501
+ attention_heads: int = 4,
502
+ key_bias: bool = True,
503
+ # MLP
504
+ linear_units: int = 2048,
505
+ # Dropouts
506
+ dropout_rate: float = 0.0,
507
+ positional_dropout_rate: float = 0.0,
508
+ attention_dropout_rate: float = 0.0,
509
+ ):
510
+ super().__init__()
511
+ self.input_size = input_size
512
+ self.output_size = output_size
513
+ self.up_stride = up_stride
514
+ # Input embedding
515
+ self.embed = LinearNoSubsampling(
516
+ input_size,
517
+ output_size,
518
+ dropout_rate,
519
+ # Positional encoding
520
+ EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
521
+ )
522
+ # Look ahead
523
+ self.pre_lookahead_layer = PreLookaheadLayer(channels=output_size, pre_lookahead_len=pre_lookahead_len)
524
+ # Norm
525
+ self.normalize_before = normalize_before
526
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
527
+ # Act
528
+ activation = torch.nn.SiLU()
529
+ # Self-attention module definition
530
+ encoder_selfattn_layer_args = (
531
+ attention_heads,
532
+ output_size,
533
+ attention_dropout_rate,
534
+ key_bias,
535
+ )
536
+ # Feed-forward module definition
537
+ positionwise_layer_args = (
538
+ output_size,
539
+ linear_units,
540
+ dropout_rate,
541
+ activation,
542
+ )
543
+ # 1st Conformer
544
+ self.encoders = torch.nn.ModuleList([
545
+ ConformerEncoderLayer(
546
+ output_size,
547
+ # Self-attn
548
+ RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
549
+ # FFN
550
+ PositionwiseFeedForward(*positionwise_layer_args),
551
+ dropout_rate,
552
+ normalize_before,
553
+ ) for _ in range(num_blocks)
554
+ ])
555
+ # Upsample
556
+ self.up_layer = Upsample1D(channels=output_size, out_channels=output_size, stride=up_stride)
557
+ # Input embedding2
558
+ self.up_embed = LinearNoSubsampling(
559
+ input_size,
560
+ output_size,
561
+ dropout_rate,
562
+ # Positional encoding
563
+ EspnetRelPositionalEncoding(output_size, positional_dropout_rate),
564
+ )
565
+ # 2nd Conformer
566
+ self.up_encoders = torch.nn.ModuleList([
567
+ ConformerEncoderLayer(
568
+ output_size,
569
+ # Self-attn
570
+ RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
571
+ # FFN
572
+ PositionwiseFeedForward(*positionwise_layer_args),
573
+ dropout_rate,
574
+ normalize_before,
575
+ ) for _ in range(num_up_blocks)
576
+ ])
577
+
578
+ """For non-streaming inference.
579
+ """
580
+ def forward(
581
+ self,
582
+ xs: torch.Tensor,
583
+ xs_lens: torch.Tensor,
584
+ # attention mask BEFORE upsample
585
+ attn_mask1: torch.Tensor=None,
586
+ # attention mask AFTER upsample
587
+ attn_mask2: torch.Tensor=None,
588
+ ) -> torch.Tensor:
589
+ """
590
+ Args:
591
+ xs: shape (b, t, c)
592
+ xs_lens: shape (b,)
593
+ attn_mask1: (token level) shape (b, t, t)
594
+ attn_mask2: (mel level) shape (b, 2t, 2t)
595
+ """
596
+ # Input & lookahead
597
+ xs, pos_emb = self.embed(xs)
598
+ xs = self.pre_lookahead_layer(xs)
599
+
600
+ # 1st Conformer
601
+ for block in self.encoders:
602
+ xs, _ = block(xs, mask=attn_mask1, pos_emb=pos_emb)
603
+
604
+ # Upsample to mel-level
605
+ xs = xs.transpose(1, 2).contiguous()
606
+ xs, xs_lens = self.up_layer(xs, xs_lens)
607
+ xs = xs.transpose(1, 2).contiguous()
608
+ # Input
609
+ xs, pos_emb = self.up_embed(xs)
610
+
611
+ # 2nd Conformer
612
+ for block in self.up_encoders:
613
+ xs, _ = block(xs, mask=attn_mask2, pos_emb=pos_emb)
614
+
615
+ if self.normalize_before:
616
+ xs = self.after_norm(xs)
617
+ return xs
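A minimal end-to-end sketch of the upsampling encoder above (shapes only; an illustrative assumption is that boolean all-ones masks are acceptable to the relative-attention layers defined earlier in this file):

    import torch
    from fireredtts.modules.flowmatching.upsample_encoder import UpsampleConformerEncoder

    enc = UpsampleConformerEncoder(input_size=512, output_size=512, up_stride=2).eval()
    tokens = torch.randn(1, 40, 512)                    # (batch, token_len, dim)
    lens = torch.tensor([40])
    mask1 = torch.ones(1, 40, 40, dtype=torch.bool)     # token-level attention mask
    mask2 = torch.ones(1, 80, 80, dtype=torch.bool)     # mel-level mask after 2x upsample
    with torch.no_grad():
        out = enc(tokens, lens, attn_mask1=mask1, attn_mask2=mask2)
    # out: (1, 80, 512) -- the 40 token-level frames are upsampled to 80 mel-level frames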
fireredtts/modules/semantic_llm/llm_gpt2.py ADDED
@@ -0,0 +1,608 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
7
+ import functools
8
+ from transformers import GPT2PreTrainedModel, GPT2Model, GPT2Config
9
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
10
+
11
+
12
+ # GPT2 NORMAL INFERENCE MODE
13
+ class GPT2InferenceModel(GPT2PreTrainedModel):
14
+ """Override GPT2LMHeadModel to allow for prefix conditioning."""
15
+
16
+ def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
17
+ super().__init__(config)
18
+ self.transformer = gpt
19
+ self.pos_embedding = pos_emb
20
+ self.embeddings = embeddings
21
+ self.final_norm = norm
22
+ self.lm_head = nn.Sequential(norm, linear)
23
+ self.kv_cache = kv_cache
24
+
25
+ def store_prefix_emb(self, prefix_emb):
26
+ self.cached_prefix_emb = prefix_emb
27
+
28
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
29
+ token_type_ids = kwargs.get("token_type_ids", None) # usually None
30
+ if not self.kv_cache:
31
+ past_key_values = None
32
+
33
+ # only last token for inputs_ids if past is defined in kwargs
34
+ if past_key_values is not None:
35
+ input_ids = input_ids[:, -1].unsqueeze(-1)
36
+ if token_type_ids is not None:
37
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
38
+
39
+ attention_mask = kwargs.get("attention_mask", None)
40
+ position_ids = kwargs.get("position_ids", None)
41
+
42
+ if attention_mask is not None and position_ids is None:
43
+ # create position_ids on the fly for batch generation
44
+ position_ids = attention_mask.long().cumsum(-1) - 1
45
+ position_ids.masked_fill_(attention_mask == 0, 1)
46
+ if past_key_values is not None:
47
+ position_ids = position_ids[:, -1].unsqueeze(-1)
48
+ else:
49
+ position_ids = None
50
+ return {
51
+ "input_ids": input_ids,
52
+ "past_key_values": past_key_values,
53
+ "use_cache": kwargs.get("use_cache"),
54
+ "position_ids": position_ids,
55
+ "attention_mask": attention_mask,
56
+ "token_type_ids": token_type_ids,
57
+ }
58
+
59
+ def forward(
60
+ self,
61
+ input_ids=None,
62
+ past_key_values=None,
63
+ attention_mask=None,
64
+ token_type_ids=None,
65
+ position_ids=None,
66
+ head_mask=None,
67
+ inputs_embeds=None,
68
+ encoder_hidden_states=None,
69
+ encoder_attention_mask=None,
70
+ labels=None,
71
+ use_cache=None,
72
+ output_attentions=None,
73
+ output_hidden_states=None,
74
+ return_dict=None,
75
+ ):
76
+ assert self.cached_prefix_emb is not None
77
+ assert inputs_embeds is None # Not supported by this inference model.
78
+ assert labels is None # Training not supported by this inference model.
79
+ return_dict = (
80
+ return_dict if return_dict is not None else self.config.use_return_dict
81
+ )
82
+
83
+ # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
84
+
85
+ # Create embedding
86
+ prefix_len = self.cached_prefix_emb.shape[1]
87
+ if input_ids.shape[1] != 1:
88
+ gen_inputs = input_ids[:, prefix_len:]
89
+ gen_emb = self.embeddings(gen_inputs)
90
+ gen_emb = gen_emb + self.pos_embedding(gen_emb)
91
+ if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
92
+ prefix_emb = self.cached_prefix_emb.repeat_interleave(
93
+ gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
94
+ )
95
+ else:
96
+ prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
97
+ emb = torch.cat([prefix_emb, gen_emb], dim=1)
98
+ else:
99
+ emb = self.embeddings(input_ids)
100
+ emb = emb + self.pos_embedding.get_fixed_embedding(
101
+ attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
102
+ )
103
+ transformer_outputs = self.transformer(
104
+ inputs_embeds=emb,
105
+ past_key_values=past_key_values,
106
+ attention_mask=attention_mask,
107
+ token_type_ids=token_type_ids,
108
+ position_ids=position_ids,
109
+ head_mask=head_mask,
110
+ encoder_hidden_states=encoder_hidden_states,
111
+ encoder_attention_mask=encoder_attention_mask,
112
+ use_cache=use_cache,
113
+ output_attentions=output_attentions,
114
+ output_hidden_states=output_hidden_states,
115
+ return_dict=return_dict,
116
+ )
117
+ hidden_states = transformer_outputs[0]
118
+ lm_logits = self.lm_head(hidden_states)
119
+
120
+ if not return_dict:
121
+ return (lm_logits,) + transformer_outputs[1:]
122
+
123
+ return CausalLMOutputWithCrossAttentions(
124
+ loss=None,
125
+ logits=lm_logits,
126
+ past_key_values=transformer_outputs.past_key_values,
127
+ hidden_states=transformer_outputs.hidden_states,
128
+ attentions=transformer_outputs.attentions,
129
+ cross_attentions=transformer_outputs.cross_attentions,
130
+ )
131
+
132
+ @staticmethod
133
+ def _reorder_cache(past, beam_idx):
134
+ """
135
+ This function is used to re-order the :obj:`past_key_values` cache if
136
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
137
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
138
+ """
139
+ return tuple(
140
+ tuple(
141
+ past_state.index_select(0, beam_idx.to(past_state.device))
142
+ for past_state in layer_past
143
+ )
144
+ for layer_past in past
145
+ )
146
+
147
+
148
+ # GPT2 IN-CONTEXT INFERENCE MODE
149
+ class GPT2ICInferenceModel(GPT2PreTrainedModel):
150
+ """Override GPT2LMHeadModel to allow for prefix conditioning."""
151
+
152
+ def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
153
+ super().__init__(config)
154
+ self.transformer = gpt
155
+ self.pos_embedding = pos_emb
156
+ self.embeddings = embeddings
157
+ self.final_norm = norm
158
+ self.lm_head = nn.Sequential(norm, linear)
159
+ self.kv_cache = kv_cache
160
+
161
+ def store_prefix_emb(self, prefix_emb):
162
+ self.cached_prefix_emb = prefix_emb
163
+
164
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
165
+ token_type_ids = kwargs.get("token_type_ids", None) # usually None
166
+ if not self.kv_cache:
167
+ past_key_values = None
168
+
169
+ # only last token for inputs_ids if past is defined in kwargs
170
+ if past_key_values is not None:
171
+ input_ids = input_ids[:, -1].unsqueeze(-1)
172
+ if token_type_ids is not None:
173
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
174
+
175
+ attention_mask = kwargs.get("attention_mask", None)
176
+ position_ids = kwargs.get("position_ids", None)
177
+
178
+ if attention_mask is not None and position_ids is None:
179
+ # create position_ids on the fly for batch generation
180
+ position_ids = attention_mask.long().cumsum(-1) - 1
181
+ position_ids.masked_fill_(attention_mask == 0, 1)
182
+ if past_key_values is not None:
183
+ position_ids = position_ids[:, -1].unsqueeze(-1)
184
+ else:
185
+ position_ids = None
186
+ return {
187
+ "input_ids": input_ids,
188
+ "past_key_values": past_key_values,
189
+ "use_cache": kwargs.get("use_cache"),
190
+ "position_ids": position_ids,
191
+ "attention_mask": attention_mask,
192
+ "token_type_ids": token_type_ids,
193
+ }
194
+
195
+ def forward(
196
+ self,
197
+ input_ids=None,
198
+ past_key_values=None,
199
+ attention_mask=None,
200
+ token_type_ids=None,
201
+ position_ids=None,
202
+ head_mask=None,
203
+ inputs_embeds=None,
204
+ encoder_hidden_states=None,
205
+ encoder_attention_mask=None,
206
+ labels=None,
207
+ use_cache=None,
208
+ output_attentions=None,
209
+ output_hidden_states=None,
210
+ return_dict=None,
211
+ ):
212
+ assert self.cached_prefix_emb is not None
213
+ assert inputs_embeds is None # Not supported by this inference model.
214
+ assert labels is None # Training not supported by this inference model.
215
+ return_dict = (
216
+ return_dict if return_dict is not None else self.config.use_return_dict
217
+ )
218
+
219
+ # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
220
+
221
+ # Create embedding
222
+ prefix_len = self.cached_prefix_emb.shape[1]
223
+ if input_ids.shape[1] != 1:
224
+ # gen_inputs = input_ids[:, prefix_len:]
225
+ # gen_emb = self.embeddings(gen_inputs)
226
+ # gen_emb = gen_emb + self.pos_embedding(gen_emb)
227
+ gen_emb = self.cached_prefix_emb
228
+ if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
229
+ prefix_emb = self.cached_prefix_emb.repeat_interleave(
230
+ gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
231
+ )
232
+ else:
233
+ prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
234
+ # emb = torch.cat([prefix_emb, gen_emb], dim=1)
235
+ emb = gen_emb
236
+ else:
237
+ emb = self.embeddings(input_ids)
238
+ emb = emb + self.pos_embedding.get_fixed_embedding(
239
+ attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
240
+ )
241
+ transformer_outputs = self.transformer(
242
+ inputs_embeds=emb,
243
+ past_key_values=past_key_values,
244
+ attention_mask=attention_mask,
245
+ token_type_ids=token_type_ids,
246
+ position_ids=position_ids,
247
+ head_mask=head_mask,
248
+ encoder_hidden_states=encoder_hidden_states,
249
+ encoder_attention_mask=encoder_attention_mask,
250
+ use_cache=use_cache,
251
+ output_attentions=output_attentions,
252
+ output_hidden_states=output_hidden_states,
253
+ return_dict=return_dict,
254
+ )
255
+ hidden_states = transformer_outputs[0]
256
+ lm_logits = self.lm_head(hidden_states)
257
+
258
+ if not return_dict:
259
+ return (lm_logits,) + transformer_outputs[1:]
260
+
261
+ return CausalLMOutputWithCrossAttentions(
262
+ loss=None,
263
+ logits=lm_logits,
264
+ past_key_values=transformer_outputs.past_key_values,
265
+ hidden_states=transformer_outputs.hidden_states,
266
+ attentions=transformer_outputs.attentions,
267
+ cross_attentions=transformer_outputs.cross_attentions,
268
+ )
269
+
270
+ @staticmethod
271
+ def _reorder_cache(past, beam_idx):
272
+ """
273
+ This function is used to re-order the :obj:`past_key_values` cache if
274
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
275
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
276
+ """
277
+ return tuple(
278
+ tuple(
279
+ past_state.index_select(0, beam_idx.to(past_state.device))
280
+ for past_state in layer_past
281
+ )
282
+ for layer_past in past
283
+ )
284
+
285
+
286
+ def null_position_embeddings(range, dim):
287
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
288
+
289
+
290
+ class LearnedPositionEmbeddings(nn.Module):
291
+ def __init__(self, seq_len, model_dim, init=0.02, relative=False):
292
+ super().__init__()
293
+ # nn.Embedding
294
+ self.emb = torch.nn.Embedding(seq_len, model_dim)
295
+ # Initializing this way is standard for GPT-2
296
+ self.emb.weight.data.normal_(mean=0.0, std=init)
297
+ self.relative = relative
298
+ self.seq_len = seq_len
299
+
300
+ def forward(self, x):
301
+ sl = x.shape[1]
302
+ if self.relative:
303
+ start = random.randint(sl, self.seq_len) - sl
304
+ return self.emb(torch.arange(start, start + sl, device=x.device))
305
+ else:
306
+ return self.emb(torch.arange(0, sl, device=x.device))
307
+
308
+ def get_fixed_embedding(self, ind, dev):
309
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
310
+
311
+
312
+ def build_hf_gpt_transformer(
313
+ layers,
314
+ model_dim,
315
+ heads,
316
+ max_mel_seq_len,
317
+ max_text_seq_len,
318
+ max_prompt_len,
319
+ checkpointing,
320
+ ):
321
+ """
322
+ GPT-2 implemented by the HuggingFace library.
323
+ """
324
+
325
+ gpt_config = GPT2Config(
326
+ vocab_size=256, # Unused.
327
+ n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
328
+ n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
329
+ n_embd=model_dim,
330
+ n_layer=layers,
331
+ n_head=heads,
332
+ gradient_checkpointing=checkpointing,
333
+ use_cache=not checkpointing,
334
+ )
335
+ gpt = GPT2Model(gpt_config)
336
+ # Override the built in positional embeddings
337
+ del gpt.wpe
338
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
339
+ # Built-in token embeddings are unused.
340
+ del gpt.wte
341
+
342
+ mel_pos_emb = (
343
+ LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
344
+ if max_mel_seq_len != -1
345
+ else functools.partial(null_position_embeddings, dim=model_dim)
346
+ )
347
+ text_pos_emb = (
348
+ LearnedPositionEmbeddings(max_text_seq_len, model_dim)
349
+ if max_mel_seq_len != -1
350
+ else functools.partial(null_position_embeddings, dim=model_dim)
351
+ )
352
+ # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
353
+ return gpt, mel_pos_emb, text_pos_emb, None, None
354
+
355
+
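# A small self-contained sketch of how the pieces returned above fit together
# (illustrative only; the real wiring is done by Speech_LLM_GPT2 below). The GPT-2
# backbone has its built-in wte/wpe disabled, so the caller embeds tokens itself,
# adds the learned positional embeddings, and feeds inputs_embeds directly.
def _backbone_sketch():
    gpt, audio_pos, _text_pos, _, _ = build_hf_gpt_transformer(
        layers=2, model_dim=64, heads=4,
        max_mel_seq_len=32, max_text_seq_len=32, max_prompt_len=8,
        checkpointing=False,
    )
    tok_emb = nn.Embedding(100, 64)                # stand-in token embedding table
    ids = torch.randint(0, 100, (1, 10))
    emb = tok_emb(ids) + audio_pos(tok_emb(ids))   # positions are added manually
    hidden = gpt(inputs_embeds=emb, return_dict=True).last_hidden_state
    return hidden.shape                            # torch.Size([1, 10, 64])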
356
+ class Speech_LLM_GPT2(nn.Module):
357
+ def __init__(
358
+ self,
359
+ start_text_token,
360
+ stop_text_token,
361
+ num_text_tokens,
362
+ start_audio_token,
363
+ stop_audio_token,
364
+ num_audio_tokens,
365
+ llm_hidden_size,
366
+ llm_intermediate_size,
367
+ llm_num_layers,
368
+ llm_num_heads,
369
+ llm_max_audio_seq_len,
370
+ llm_max_text_seq_len,
371
+ llm_max_prompt_len,
372
+ code_stride_len=640,
373
+ max_conditioning_inputs=1,
374
+ label_smoothing=0.0,
375
+ checkpointing=False,
376
+ ):
377
+ """
378
+ Args:
379
+
380
+ """
381
+ super().__init__()
382
+
383
+ self.label_smoothing = label_smoothing
384
+ # text token config
385
+ self.start_text_token = start_text_token
386
+ self.stop_text_token = stop_text_token
387
+ self.num_text_tokens = num_text_tokens
388
+
389
+ # audio token config
390
+ self.start_audio_token = start_audio_token
391
+ self.stop_audio_token = stop_audio_token
392
+ self.num_audio_tokens = num_audio_tokens
393
+
394
+ # prompts token config
395
+ self.start_prompt_token = start_audio_token
396
+ self.stop_prompt_token = stop_audio_token
397
+
398
+ # other config
399
+ self.max_conditioning_inputs = max_conditioning_inputs
400
+
401
+ # length configs
402
+ self.max_text_len = llm_max_text_seq_len + 2 # add <bos> <eos>
403
+ self.max_prompt_len = llm_max_prompt_len
404
+ self.max_audio_len = llm_max_audio_seq_len + 2 + self.max_conditioning_inputs
405
+ self.max_gen_audio_tokens = (
406
+ llm_max_audio_seq_len - self.max_conditioning_inputs - 2
407
+ )
408
+ self.code_stride_len = code_stride_len
409
+
410
+ # model config
411
+ self.llm_hidden_size = llm_hidden_size
412
+ self.llm_intermediate_size = llm_intermediate_size
413
+ self.llm_num_layers = llm_num_layers
414
+ self.llm_num_heads = llm_num_heads
415
+
416
+ # text embedding and audio embeddings
417
+ self.text_embedding = nn.Embedding(self.num_text_tokens, self.llm_hidden_size)
418
+ self.audio_embedding = nn.Embedding(self.num_audio_tokens, self.llm_hidden_size)
419
+
420
+ # low-level llm model
421
+ self.gpt2, self.audio_pos_embedding, self.text_pos_embedding, _, _ = (
422
+ build_hf_gpt_transformer(
423
+ layers=self.llm_num_layers,
424
+ model_dim=self.llm_hidden_size,
425
+ heads=self.llm_num_heads,
426
+ max_mel_seq_len=self.max_audio_len,
427
+ max_text_seq_len=self.max_text_len,
428
+ max_prompt_len=self.max_prompt_len,
429
+ checkpointing=checkpointing,
430
+ )
431
+ )
432
+
433
+ # text and audio linear
434
+ self.final_norm = nn.LayerNorm(self.llm_hidden_size)
435
+ self.text_head = nn.Linear(self.llm_hidden_size, self.num_text_tokens)
436
+ self.audio_head = nn.Linear(self.llm_hidden_size, self.num_audio_tokens)
437
+
438
+ # speaker feature transformation
439
+ self.reference_embedding = nn.Sequential(
440
+ nn.Linear(512, 256),
441
+ nn.Tanh(),
442
+ nn.Linear(256, self.llm_hidden_size),
443
+ )
444
+
445
+ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
446
+ """_summary_
447
+
448
+ Args:
449
+ kv_cache (bool, optional): _description_. Defaults to True.
450
+ use_deepspeed (bool, optional): _description_. Defaults to False.
451
+ """
452
+ seq_length = self.max_audio_len + self.max_text_len + self.max_prompt_len + 1
453
+
454
+ gpt_config = GPT2Config(
455
+ vocab_size=self.num_audio_tokens,
456
+ n_positions=seq_length,
457
+ n_ctx=seq_length,
458
+ n_embd=self.llm_hidden_size,
459
+ n_layer=self.llm_num_layers,
460
+ n_head=self.llm_num_heads,
461
+ gradient_checkpointing=False,
462
+ use_cache=True,
463
+ )
464
+
465
+ # normal inference model
466
+ self.gpt_inference = GPT2InferenceModel(
467
+ config=gpt_config,
468
+ gpt=self.gpt2,
469
+ pos_emb=self.audio_pos_embedding,
470
+ embeddings=self.audio_embedding,
471
+ norm=self.final_norm,
472
+ linear=self.audio_head,
473
+ kv_cache=kv_cache,
474
+ )
475
+
476
+ # in-context inference model
477
+ self.gpt_inference_ic = GPT2ICInferenceModel(
478
+ config=gpt_config,
479
+ gpt=self.gpt2,
480
+ pos_emb=self.audio_pos_embedding,
481
+ embeddings=self.audio_embedding,
482
+ norm=self.final_norm,
483
+ linear=self.audio_head,
484
+ kv_cache=kv_cache,
485
+ )
486
+
487
+ self.gpt2.wte = self.audio_embedding
488
+
489
+ # --------------------------- normal inference ---------------------------
490
+ def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
491
+ self.compute_embeddings(cond_latents, text_inputs)
492
+ return self.generate(cond_latents, text_inputs, **hf_generate_kwargs)
493
+
494
+ def compute_embeddings(self, cond_latents, text_inputs):
495
+
496
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
497
+ text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
498
+ emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
499
+ emb = torch.cat([cond_latents, emb], dim=1)
500
+ self.gpt_inference.store_prefix_emb(emb)
501
+ gpt_inputs = torch.full(
502
+ (
503
+ emb.shape[0],
504
+ emb.shape[1] + 1, # +1 for the start_audio_token
505
+ ),
506
+ fill_value=1,
507
+ dtype=torch.long,
508
+ device=text_inputs.device,
509
+ )
510
+ gpt_inputs[:, -1] = self.start_audio_token
511
+ return gpt_inputs
512
+
513
+ def generate(self, cond_latents, text_inputs, **hf_generate_kwargs):
514
+ gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
515
+ gen = self.gpt_inference.generate(
516
+ gpt_inputs,
517
+ bos_token_id=self.start_audio_token,
518
+ pad_token_id=self.stop_audio_token,
519
+ eos_token_id=self.stop_audio_token,
520
+ max_length=self.max_gen_audio_tokens + gpt_inputs.shape[-1],
521
+ **hf_generate_kwargs,
522
+ )
523
+ if "return_dict_in_generate" in hf_generate_kwargs:
524
+ return gen.sequences[:, gpt_inputs.shape[1] :], gen
525
+ return gen[:, gpt_inputs.shape[1] :]
526
+
527
+ # --------------------------- normal inference --------------------------
528
+
529
+ # --------------------------- IC inference ---------------------------
530
+ def compute_embeddings_ic(self, cond_latents, text_inputs, prompt_tokens):
531
+ """_summary_
532
+
533
+ Args:
534
+ cond_latents (_type_): speaker embedding
535
+ text_inputs (_type_): text tokens
536
+ prompt_tokens (_type_): prompts_tokens
537
+
538
+ Returns:
539
+ _type_: _description_
540
+ """
541
+
542
+ # text embeddings
543
+ text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
544
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
545
+
546
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
547
+ text_inputs
548
+ )
549
+
550
+ # prompt_tokens
551
+
552
+ prompt_tokens = F.pad(prompt_tokens, (1, 0), value=self.start_audio_token)
553
+ audio_emb = self.audio_embedding(prompt_tokens) + self.audio_pos_embedding(
554
+ prompt_tokens
555
+ )
556
+
557
+ emb = torch.cat([cond_latents, text_emb, audio_emb], dim=1)
558
+
559
+ self.gpt_inference_ic.store_prefix_emb(emb)
560
+ gpt_inputs = torch.full(
561
+ (emb.shape[0], emb.shape[1]),
562
+ fill_value=1,
563
+ dtype=torch.long,
564
+ device=text_inputs.device,
565
+ )
566
+ return gpt_inputs
567
+
568
+ def generate_ic(
569
+ self, cond_latents, text_inputs, prompt_tokens, **hf_generate_kwargs
570
+ ):
571
+ """_summary_
572
+
573
+ Args:
574
+ cond_latents (_type_): _description_
575
+ text_inputs (_type_): _description_
576
+ prompt_tokens (_type_): _description_
577
+
578
+ Returns:
579
+ _type_: _description_
580
+ """
581
+ gpt_inputs = self.compute_embeddings_ic(
582
+ cond_latents, text_inputs, prompt_tokens
583
+ )
584
+ gen = self.gpt_inference_ic.generate(
585
+ gpt_inputs,
586
+ bos_token_id=self.start_audio_token,
587
+ pad_token_id=self.stop_audio_token,
588
+ eos_token_id=self.stop_audio_token,
589
+ max_length=self.max_gen_audio_tokens + gpt_inputs.shape[-1],
590
+ **hf_generate_kwargs,
591
+ )
592
+ if "return_dict_in_generate" in hf_generate_kwargs:
593
+ return gen.sequences[:, gpt_inputs.shape[1] :], gen
594
+
595
+ return gen[:, gpt_inputs.shape[1] :]
596
+
597
+ # --------------------------- IC inference ---------------------------
598
+
599
+ def get_generator(self, fake_inputs, **hf_generate_kwargs):
600
+ return self.gpt_inference.generate_stream(
601
+ fake_inputs,
602
+ bos_token_id=self.start_audio_token,
603
+ pad_token_id=self.stop_audio_token,
604
+ eos_token_id=self.stop_audio_token,
605
+ max_length=self.max_gen_audio_tokens + fake_inputs.shape[-1],
606
+ do_stream=True,
607
+ **hf_generate_kwargs,
608
+ )
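A minimal zero-shot inference sketch for the semantic LLM above (all sizes and token ids are illustrative placeholders, not the values used by the released checkpoints):

    import torch
    from fireredtts.modules.semantic_llm.llm_gpt2 import Speech_LLM_GPT2

    llm = Speech_LLM_GPT2(
        start_text_token=0, stop_text_token=1, num_text_tokens=1000,
        start_audio_token=0, stop_audio_token=1, num_audio_tokens=1026,
        llm_hidden_size=768, llm_intermediate_size=3072,
        llm_num_layers=12, llm_num_heads=12,
        llm_max_audio_seq_len=600, llm_max_text_seq_len=400, llm_max_prompt_len=100,
    )
    llm.init_gpt_for_inference(kv_cache=True)
    llm.eval()

    spk = torch.randn(1, 1, 512)                    # reference speaker feature
    cond = llm.reference_embedding(spk)             # -> (1, 1, llm_hidden_size)
    text = torch.randint(2, 1000, (1, 20))          # text token ids
    with torch.no_grad():
        audio_tokens = llm.generate(cond, text, do_sample=True, top_p=0.85)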
fireredtts/modules/semantic_tokenizer/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+ import torch
3
+
4
+ from .hubert import HuBERT
5
+ from .semantic_tokenizer import SemanticVQVAE
6
+
7
+
8
+ class SemanticTokenizer:
9
+
10
+ def __init__(self, config, path):
11
+ self.model = SemanticVQVAE(**config)
12
+ self.model.load_state_dict(
13
+ torch.load(os.path.join(path, "codec.bin"), map_location="cpu"), strict=True
14
+ )
15
+
16
+ hubert = HuBERT(os.path.join(path, "hubert.pt"))
17
+ for name, param in hubert.named_parameters():
18
+ param.requires_grad = False
19
+ self.model.ssl_extractor = hubert
20
+
21
+ if torch.cuda.is_available():
22
+ self.model = self.model.cuda()
23
+ self.model.eval()
24
+
25
+ def __call__(self, wavs, wav_lengths):
26
+ tokens, token_lengths, spk_embeddings = self.extract(wavs, wav_lengths)
27
+ return tokens, token_lengths, spk_embeddings
28
+
29
+ def extract(self, wavs, wav_lengths):
30
+ saved_features = self.model.extract_speech_tokens(wavs, wav_lengths)
31
+
32
+ tokens = saved_features["token"]
33
+ token_lengths = saved_features["token_length"]
34
+ spk_embeddings = saved_features["spk"]
35
+
36
+ return tokens, token_lengths, spk_embeddings
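A usage sketch for the wrapper above (`tokenizer_config` is a placeholder for the SemanticVQVAE keyword arguments carried in the repo config, and the 16 kHz mono input is an assumption based on the HuBERT front-end):

    import torch
    from fireredtts.modules.semantic_tokenizer import SemanticTokenizer

    tokenizer = SemanticTokenizer(config=tokenizer_config, path="pretrained_models")
    wavs = torch.randn(1, 16000 * 3)                # ~3 s of audio, (batch, samples)
    wav_lengths = torch.tensor([wavs.shape[-1]])
    tokens, token_lengths, spk_emb = tokenizer(wavs, wav_lengths)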
fireredtts/modules/semantic_tokenizer/audio.py ADDED
@@ -0,0 +1,138 @@
1
+ from librosa.filters import mel as librosa_mel_fn
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ import math
6
+ import numpy as np
7
+ import torch
8
+ import torchaudio
9
+
10
+
11
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
12
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
13
+
14
+
15
+ def dynamic_range_decompression(x, C=1):
16
+ return np.exp(x) / C
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ return torch.log(torch.clamp(x, min=clip_val) * C)
21
+
22
+
23
+ def dynamic_range_decompression_torch(x, C=1):
24
+ return torch.exp(x) / C
25
+
26
+
27
+ def spectral_normalize_torch(magnitudes):
28
+ output = dynamic_range_compression_torch(magnitudes)
29
+ return output
30
+
31
+
32
+ def spectral_de_normalize_torch(magnitudes):
33
+ output = dynamic_range_decompression_torch(magnitudes)
34
+ return output
35
+
36
+
37
+ class TorchMelSpectrogram(nn.Module):
38
+
39
+ def __init__(
40
+ self,
41
+ filter_length=1024,
42
+ hop_length=200,
43
+ win_length=800,
44
+ n_mel_channels=80,
45
+ mel_fmin=0,
46
+ mel_fmax=8000,
47
+ sampling_rate=16000,
48
+ sampling_rate_org=None,
49
+ normalize=False,
50
+ mel_norm_file=None,
51
+ scale=1.0,
52
+ padding="center",
53
+ style="Tortoise",
54
+ ):
55
+
56
+ super().__init__()
57
+ self.style = style
58
+ self.filter_length = filter_length
59
+ self.hop_length = hop_length
60
+ self.win_length = win_length
61
+ self.n_mel_channels = n_mel_channels
62
+ self.mel_fmin = mel_fmin
63
+ self.mel_fmax = mel_fmax
64
+ self.sampling_rate = sampling_rate
65
+ self.sampling_rate_org = (
66
+ sampling_rate_org if sampling_rate_org is not None else sampling_rate
67
+ )
68
+
69
+ self.mel_basis = {}
70
+ self.hann_window = {}
71
+
72
+ self.scale = scale
73
+
74
+ def forward(self, inp, length=None):
75
+ if len(inp.shape) == 3:
76
+ inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
77
+ assert len(inp.shape) == 2
78
+
79
+ if self.sampling_rate_org != self.sampling_rate:
80
+ inp = torchaudio.functional.resample(
81
+ inp, self.sampling_rate_org, self.sampling_rate
82
+ )
83
+
84
+ y = inp
85
+ if len(list(self.mel_basis.keys())) == 0:
86
+ mel = librosa_mel_fn(
87
+ sr=self.sampling_rate,
88
+ n_fft=self.filter_length,
89
+ n_mels=self.n_mel_channels,
90
+ fmin=self.mel_fmin,
91
+ fmax=self.mel_fmax,
92
+ )
93
+ self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = (
94
+ torch.from_numpy(mel).float().to(y.device)
95
+ )
96
+ self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(
97
+ y.device
98
+ )
99
+
100
+ y = torch.nn.functional.pad(
101
+ y.unsqueeze(1),
102
+ (
103
+ int((self.filter_length - self.hop_length) / 2),
104
+ int((self.filter_length - self.hop_length) / 2),
105
+ ),
106
+ mode="reflect",
107
+ )
108
+ y = y.squeeze(1)
109
+
110
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
111
+ spec = torch.stft(
112
+ y,
113
+ self.filter_length,
114
+ hop_length=self.hop_length,
115
+ win_length=self.win_length,
116
+ window=self.hann_window[str(y.device)],
117
+ center=False,
118
+ pad_mode="reflect",
119
+ normalized=False,
120
+ onesided=True,
121
+ return_complex=True,
122
+ )
123
+ spec = torch.view_as_real(spec)
124
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
125
+
126
+ spec = torch.matmul(
127
+ self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)], spec
128
+ )
129
+ spec = spectral_normalize_torch(spec)
130
+
131
+ max_mel_length = math.ceil(y.shape[-1] / self.hop_length)
132
+ spec = spec[..., :max_mel_length].transpose(1, 2)
133
+
134
+ if length is None:
135
+ return spec
136
+ else:
137
+ spec_len = torch.ceil(length / self.hop_length).clamp(max=spec.shape[1])
138
+ return spec, spec_len
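A shape sketch for the mel front-end above (the 24 kHz source rate is an illustrative assumption; everything else follows the defaults in this file):

    import torch
    from fireredtts.modules.semantic_tokenizer.audio import TorchMelSpectrogram

    mel_fn = TorchMelSpectrogram(sampling_rate=16000, sampling_rate_org=24000)
    wav = torch.randn(2, 24000)                     # 1 s per item at the source rate
    mel = mel_fn(wav)                               # resampled to 16 kHz internally
    # hop_length=200 at 16 kHz -> 80 frames per second, so mel has shape (2, 80, 80)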
fireredtts/modules/semantic_tokenizer/ecapa_tdnn.py ADDED
@@ -0,0 +1,931 @@
1
+ """A popular speaker recognition and diarization model.
2
+
3
+ Authors
4
+ * Hwidong Na 2020
5
+ """
6
+
7
+ import math
8
+ import os
9
+ import torch # noqa: F401
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def length_to_mask(length, max_len=None, dtype=None, device=None):
15
+ """Creates a binary mask for each sequence.
16
+
17
+ Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
18
+
19
+ Arguments
20
+ ---------
21
+ length : torch.LongTensor
22
+ Containing the length of each sequence in the batch. Must be 1D.
23
+ max_len : int
24
+ Max length for the mask, also the size of the second dimension.
25
+ dtype : torch.dtype, default: None
26
+ The dtype of the generated mask.
27
+ device: torch.device, default: None
28
+ The device to put the mask variable.
29
+
30
+ Returns
31
+ -------
32
+ mask : tensor
33
+ The binary mask.
34
+
35
+ Example
36
+ -------
37
+ >>> length=torch.Tensor([1,2,3])
38
+ >>> mask=length_to_mask(length)
39
+ >>> mask
40
+ tensor([[1., 0., 0.],
41
+ [1., 1., 0.],
42
+ [1., 1., 1.]])
43
+ """
44
+ assert len(length.shape) == 1
45
+
46
+ if max_len is None:
47
+ max_len = length.max().long().item() # using arange to generate mask
48
+ mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
49
+ len(length), max_len
50
+ ) < length.unsqueeze(1)
51
+
52
+ if dtype is None:
53
+ dtype = length.dtype
54
+
55
+ if device is None:
56
+ device = length.device
57
+
58
+ mask = torch.as_tensor(mask, dtype=dtype, device=device)
59
+ return mask
60
+
61
+
62
+ def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
63
+ """This function computes the number of elements to add for zero-padding.
64
+
65
+ Arguments
66
+ ---------
67
+ L_in : int
68
+ stride: int
69
+ kernel_size : int
70
+ dilation : int
71
+ """
72
+ if stride > 1:
73
+ n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
74
+ L_out = stride * (n_steps - 1) + kernel_size * dilation
75
+ padding = [kernel_size // 2, kernel_size // 2]
76
+
77
+ else:
78
+ L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
79
+
80
+ padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
81
+ return padding
82
+
83
+
84
+ class Conv1d(nn.Module):
85
+ """This function implements 1d convolution.
86
+
87
+ Arguments
88
+ ---------
89
+ out_channels : int
90
+ It is the number of output channels.
91
+ kernel_size : int
92
+ Kernel size of the convolutional filters.
93
+ input_shape : tuple
94
+ The shape of the input. Alternatively use ``in_channels``.
95
+ in_channels : int
96
+ The number of input channels. Alternatively use ``input_shape``.
97
+ stride : int
98
+ Stride factor of the convolutional filters. When the stride factor > 1,
99
+ a decimation in time is performed.
100
+ dilation : int
101
+ Dilation factor of the convolutional filters.
102
+ padding : str
103
+ (same, valid, causal). If "valid", no padding is performed.
104
+ If "same" and stride is 1, output shape is the same as the input shape.
105
+ "causal" results in causal (dilated) convolutions.
106
+ padding_mode : str
107
+ This flag specifies the type of padding. See torch.nn documentation
108
+ for more information.
109
+ skip_transpose : bool
110
+ If False, uses batch x time x channel convention of speechbrain.
111
+ If True, uses batch x channel x time convention.
112
+
113
+ Example
114
+ -------
115
+ >>> inp_tensor = torch.rand([10, 40, 16])
116
+ >>> cnn_1d = Conv1d(
117
+ ... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
118
+ ... )
119
+ >>> out_tensor = cnn_1d(inp_tensor)
120
+ >>> out_tensor.shape
121
+ torch.Size([10, 40, 8])
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ out_channels,
127
+ kernel_size,
128
+ input_shape=None,
129
+ in_channels=None,
130
+ stride=1,
131
+ dilation=1,
132
+ padding="same",
133
+ groups=1,
134
+ bias=True,
135
+ padding_mode="reflect",
136
+ skip_transpose=True,
137
+ ):
138
+ super().__init__()
139
+ self.kernel_size = kernel_size
140
+ self.stride = stride
141
+ self.dilation = dilation
142
+ self.padding = padding
143
+ self.padding_mode = padding_mode
144
+ self.unsqueeze = False
145
+ self.skip_transpose = skip_transpose
146
+
147
+ if input_shape is None and in_channels is None:
148
+ raise ValueError("Must provide one of input_shape or in_channels")
149
+
150
+ if in_channels is None:
151
+ in_channels = self._check_input_shape(input_shape)
152
+
153
+ self.conv = nn.Conv1d(
154
+ in_channels,
155
+ out_channels,
156
+ self.kernel_size,
157
+ stride=self.stride,
158
+ dilation=self.dilation,
159
+ padding=0,
160
+ groups=groups,
161
+ bias=bias,
162
+ )
163
+
164
+ def forward(self, x):
165
+ """Returns the output of the convolution.
166
+
167
+ Arguments
168
+ ---------
169
+ x : torch.Tensor (batch, time, channel)
170
+ input to convolve. 2d or 4d tensors are expected.
171
+ """
172
+
173
+ if not self.skip_transpose:
174
+ x = x.transpose(1, -1)
175
+
176
+ if self.unsqueeze:
177
+ x = x.unsqueeze(1)
178
+
179
+ if self.padding == "same":
180
+ x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
181
+
182
+ elif self.padding == "causal":
183
+ num_pad = (self.kernel_size - 1) * self.dilation
184
+ x = F.pad(x, (num_pad, 0))
185
+
186
+ elif self.padding == "valid":
187
+ pass
188
+
189
+ else:
190
+ raise ValueError(
191
+ "Padding must be 'same', 'valid' or 'causal'. Got " + self.padding
192
+ )
193
+
194
+ wx = self.conv(x)
195
+
196
+ if self.unsqueeze:
197
+ wx = wx.squeeze(1)
198
+
199
+ if not self.skip_transpose:
200
+ wx = wx.transpose(1, -1)
201
+
202
+ return wx
203
+
204
+ def _manage_padding(
205
+ self,
206
+ x,
207
+ kernel_size: int,
208
+ dilation: int,
209
+ stride: int,
210
+ ):
211
+ """This function performs zero-padding on the time axis
212
+ such that the length is unchanged after the convolution.
213
+
214
+ Arguments
215
+ ---------
216
+ x : torch.Tensor
217
+ Input tensor.
218
+ kernel_size : int
219
+ Size of kernel.
220
+ dilation : int
221
+ Dilation used.
222
+ stride : int
223
+ Stride.
224
+ """
225
+
226
+ # Detecting input shape
227
+ L_in = x.shape[-1]
228
+
229
+ # Time padding
230
+ padding = get_padding_elem(L_in, stride, kernel_size, dilation)
231
+
232
+ # Applying padding
233
+ x = F.pad(x, padding, mode=self.padding_mode)
234
+
235
+ return x
236
+
237
+ def _check_input_shape(self, shape):
238
+ """Checks the input shape and returns the number of input channels."""
239
+
240
+ if len(shape) == 2:
241
+ self.unsqueeze = True
242
+ in_channels = 1
243
+ elif self.skip_transpose:
244
+ in_channels = shape[1]
245
+ elif len(shape) == 3:
246
+ in_channels = shape[2]
247
+ else:
248
+ raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))
249
+
250
+ # Kernel size must be odd
251
+ if self.kernel_size % 2 == 0:
252
+ raise ValueError(
253
+ "The field kernel size must be an odd number. Got %s."
254
+ % (self.kernel_size)
255
+ )
256
+ return in_channels
257
+
258
+
259
+ class Fp32BatchNorm(nn.Module):
260
+ def __init__(self, sync=True, *args, **kwargs):
261
+ super().__init__()
262
+
263
+ if (
264
+ not torch.distributed.is_initialized()
265
+ or torch.distributed.get_world_size() == 1
266
+ ):
267
+ sync = False
268
+
269
+ if sync:
270
+ self.bn = nn.SyncBatchNorm(*args, **kwargs)
271
+ else:
272
+ self.bn = nn.BatchNorm1d(*args, **kwargs)
273
+
274
+ self.sync = sync
275
+
276
+ def forward(self, input):
277
+ if self.bn.running_mean.dtype != torch.float:
278
+ if self.sync:
279
+ self.bn.running_mean = self.bn.running_mean.float()
280
+ self.bn.running_var = self.bn.running_var.float()
281
+ if self.bn.affine:
282
+ try:
283
+ self.bn.weight = self.bn.weight.float()
284
+ self.bn.bias = self.bn.bias.float()
285
+ except Exception:
286
+ self.bn.float()
287
+ else:
288
+ self.bn.float()
289
+
290
+ output = self.bn(input.float())
291
+ return output.type_as(input)
292
+
293
+
294
+ class BatchNorm1d(nn.Module):
295
+ """Applies 1d batch normalization to the input tensor.
296
+
297
+ Arguments
298
+ ---------
299
+ input_shape : tuple
300
+ The expected shape of the input. Alternatively, use ``input_size``.
301
+ input_size : int
302
+ The expected size of the input. Alternatively, use ``input_shape``.
303
+ eps : float
304
+ This value is added to std deviation estimation to improve the numerical
305
+ stability.
306
+ momentum : float
307
+ It is a value used for the running_mean and running_var computation.
308
+ affine : bool
309
+ When set to True, the affine parameters are learned.
310
+ track_running_stats : bool
311
+ When set to True, this module tracks the running mean and variance,
312
+ and when set to False, this module does not track such statistics.
313
+ combine_batch_time : bool
314
+ When true, it combines the batch and time axes.
315
+
316
+
317
+ Example
318
+ -------
319
+ >>> input = torch.randn(100, 10)
320
+ >>> norm = BatchNorm1d(input_shape=input.shape)
321
+ >>> output = norm(input)
322
+ >>> output.shape
323
+ torch.Size([100, 10])
324
+ """
325
+
326
+ def __init__(
327
+ self,
328
+ input_shape=None,
329
+ input_size=None,
330
+ eps=1e-05,
331
+ momentum=0.1,
332
+ affine=True,
333
+ track_running_stats=True,
334
+ combine_batch_time=False,
335
+ skip_transpose=True,
336
+ enabled=True,
337
+ ):
338
+ super().__init__()
339
+ self.combine_batch_time = combine_batch_time
340
+ self.skip_transpose = skip_transpose
341
+
342
+ if input_size is None and skip_transpose:
343
+ input_size = input_shape[1]
344
+ elif input_size is None:
345
+ input_size = input_shape[-1]
346
+
347
+ if enabled:
348
+ self.norm = Fp32BatchNorm(
349
+ num_features=input_size,
350
+ eps=eps,
351
+ momentum=momentum,
352
+ affine=affine,
353
+ track_running_stats=track_running_stats,
354
+ )
355
+ else:
356
+ self.norm = nn.Identity()
357
+
358
+ def forward(self, x):
359
+ """Returns the normalized input tensor.
360
+
361
+ Arguments
362
+ ---------
363
+ x : torch.Tensor (batch, time, [channels])
364
+ input to normalize. 2d or 3d tensors are expected in input
365
+ 4d tensors can be used when combine_dims=True.
366
+ """
367
+ shape_or = x.shape
368
+ if self.combine_batch_time:
369
+ if x.ndim == 3:
370
+ x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
371
+ else:
372
+ x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
373
+
374
+ elif not self.skip_transpose:
375
+ x = x.transpose(-1, 1)
376
+
377
+ x_n = self.norm(x)
378
+
379
+ if self.combine_batch_time:
380
+ x_n = x_n.reshape(shape_or)
381
+ elif not self.skip_transpose:
382
+ x_n = x_n.transpose(1, -1)
383
+
384
+ return x_n
385
+
386
+
387
+ class Linear(torch.nn.Module):
388
+ """Computes a linear transformation y = wx + b.
389
+
390
+ Arguments
391
+ ---------
392
+ n_neurons : int
393
+ It is the number of output neurons (i.e, the dimensionality of the
394
+ output).
395
+ bias : bool
396
+ If True, the additive bias b is adopted.
397
+ combine_dims : bool
398
+ If True and the input is 4D, combine 3rd and 4th dimensions of input.
399
+
400
+ Example
401
+ -------
402
+ >>> inputs = torch.rand(10, 50, 40)
403
+ >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
404
+ >>> output = lin_t(inputs)
405
+ >>> output.shape
406
+ torch.Size([10, 50, 100])
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ n_neurons,
412
+ input_shape=None,
413
+ input_size=None,
414
+ bias=True,
415
+ combine_dims=False,
416
+ ):
417
+ super().__init__()
418
+ self.combine_dims = combine_dims
419
+
420
+ if input_shape is None and input_size is None:
421
+ raise ValueError("Expected one of input_shape or input_size")
422
+
423
+ if input_size is None:
424
+ input_size = input_shape[-1]
425
+ if len(input_shape) == 4 and self.combine_dims:
426
+ input_size = input_shape[2] * input_shape[3]
427
+
428
+ # Weights are initialized following pytorch approach
429
+ self.w = nn.Linear(input_size, n_neurons, bias=bias)
430
+
431
+ def forward(self, x):
432
+ """Returns the linear transformation of input tensor.
433
+
434
+ Arguments
435
+ ---------
436
+ x : torch.Tensor
437
+ Input to transform linearly.
438
+ """
439
+ if x.ndim == 4 and self.combine_dims:
440
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
441
+
442
+ wx = self.w(x)
443
+
444
+ return wx
445
+
446
+
447
+ class TDNNBlock(nn.Module):
448
+ """An implementation of TDNN.
449
+
450
+ Arguments
451
+ ----------
452
+ in_channels : int
453
+ Number of input channels.
454
+ out_channels : int
455
+ The number of output channels.
456
+ kernel_size : int
457
+ The kernel size of the TDNN blocks.
458
+ dilation : int
459
+ The dilation of the Res2Net block.
460
+ activation : torch class
461
+ A class for constructing the activation layers.
462
+
463
+ Example
464
+ -------
465
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
466
+ >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
467
+ >>> out_tensor = layer(inp_tensor).transpose(1, 2)
468
+ >>> out_tensor.shape
469
+ torch.Size([8, 120, 64])
470
+ """
471
+
472
+ def __init__(
473
+ self,
474
+ in_channels,
475
+ out_channels,
476
+ kernel_size,
477
+ dilation,
478
+ activation=nn.ReLU,
479
+ batch_norm=True,
480
+ ):
481
+ super(TDNNBlock, self).__init__()
482
+ self.conv = Conv1d(
483
+ in_channels=in_channels,
484
+ out_channels=out_channels,
485
+ kernel_size=kernel_size,
486
+ dilation=dilation,
487
+ )
488
+ self.activation = activation()
489
+ self.norm = BatchNorm1d(input_size=out_channels, enabled=batch_norm)
490
+
491
+ def forward(self, x):
492
+ return self.norm(self.activation(self.conv(x)))
493
+
494
+
495
+ class Res2NetBlock(torch.nn.Module):
496
+ """An implementation of Res2NetBlock w/ dilation.
497
+
498
+ Arguments
499
+ ---------
500
+ in_channels : int
501
+ The number of channels expected in the input.
502
+ out_channels : int
503
+ The number of output channels.
504
+ scale : int
505
+ The scale of the Res2Net block.
506
+ kernel_size: int
507
+ The kernel size of the Res2Net block.
508
+ dilation : int
509
+ The dilation of the Res2Net block.
510
+
511
+ Example
512
+ -------
513
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
514
+ >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
515
+ >>> out_tensor = layer(inp_tensor).transpose(1, 2)
516
+ >>> out_tensor.shape
517
+ torch.Size([8, 120, 64])
518
+ """
519
+
520
+ def __init__(
521
+ self,
522
+ in_channels,
523
+ out_channels,
524
+ scale=8,
525
+ kernel_size=3,
526
+ dilation=1,
527
+ batch_norm=True,
528
+ ):
529
+ super(Res2NetBlock, self).__init__()
530
+ assert in_channels % scale == 0
531
+ assert out_channels % scale == 0
532
+
533
+ in_channel = in_channels // scale
534
+ hidden_channel = out_channels // scale
535
+
536
+ self.blocks = nn.ModuleList(
537
+ [
538
+ TDNNBlock(
539
+ in_channel,
540
+ hidden_channel,
541
+ kernel_size=kernel_size,
542
+ dilation=dilation,
543
+ batch_norm=batch_norm,
544
+ )
545
+ for i in range(scale - 1)
546
+ ]
547
+ )
548
+ self.scale = scale
549
+
550
+ def forward(self, x):
551
+ y = []
552
+ for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
553
+ if i == 0:
554
+ y_i = x_i
555
+ elif i == 1:
556
+ y_i = self.blocks[i - 1](x_i)
557
+ else:
558
+ y_i = self.blocks[i - 1](x_i + y_i)
559
+ y.append(y_i)
560
+ y = torch.cat(y, dim=1)
561
+ return y
562
+
563
+
564
+ class SEBlock(nn.Module):
565
+ """An implementation of squeeze-and-excitation block.
566
+
567
+ Arguments
568
+ ---------
569
+ in_channels : int
570
+ The number of input channels.
571
+ se_channels : int
572
+ The number of output channels after squeeze.
573
+ out_channels : int
574
+ The number of output channels.
575
+
576
+ Example
577
+ -------
578
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
579
+ >>> se_layer = SEBlock(64, 16, 64)
580
+ >>> lengths = torch.rand((8,))
581
+ >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
582
+ >>> out_tensor.shape
583
+ torch.Size([8, 120, 64])
584
+ """
585
+
586
+ def __init__(self, in_channels, se_channels, out_channels):
587
+ super(SEBlock, self).__init__()
588
+
589
+ self.conv1 = Conv1d(
590
+ in_channels=in_channels, out_channels=se_channels, kernel_size=1
591
+ )
592
+ self.relu = torch.nn.ReLU(inplace=True)
593
+ self.conv2 = Conv1d(
594
+ in_channels=se_channels, out_channels=out_channels, kernel_size=1
595
+ )
596
+ self.sigmoid = torch.nn.Sigmoid()
597
+
598
+ def forward(self, x, lengths=None):
599
+ L = x.shape[-1]
600
+ if lengths is not None:
601
+ mask = length_to_mask(lengths * L, max_len=L, device=x.device)
602
+ mask = mask.unsqueeze(1)
603
+ total = mask.sum(dim=2, keepdim=True)
604
+ s = (x * mask).sum(dim=2, keepdim=True) / total
605
+ else:
606
+ s = x.mean(dim=2, keepdim=True)
607
+
608
+ s = self.relu(self.conv1(s))
609
+ s = self.sigmoid(self.conv2(s))
610
+
611
+ return s * x
612
+
613
+
614
+ class AttentiveStatisticsPooling(nn.Module):
615
+ """This class implements an attentive statistic pooling layer for each channel.
616
+ It returns the concatenated mean and std of the input tensor.
617
+
618
+ Arguments
619
+ ---------
620
+ channels: int
621
+ The number of input channels.
622
+ attention_channels: int
623
+ The number of attention channels.
624
+
625
+ Example
626
+ -------
627
+ >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
628
+ >>> asp_layer = AttentiveStatisticsPooling(64)
629
+ >>> lengths = torch.rand((8,))
630
+ >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
631
+ >>> out_tensor.shape
632
+ torch.Size([8, 1, 128])
633
+ """
634
+
635
+ def __init__(
636
+ self, channels, attention_channels=128, global_context=True, batch_norm=True
637
+ ):
638
+ super().__init__()
639
+
640
+ self.eps = 1e-12
641
+ self.global_context = global_context
642
+ if global_context:
643
+ self.tdnn = TDNNBlock(
644
+ channels * 3, attention_channels, 1, 1, batch_norm=batch_norm
645
+ )
646
+ else:
647
+ self.tdnn = TDNNBlock(
648
+ channels, attention_channels, 1, 1, batch_norm=batch_norm
649
+ )
650
+ self.tanh = nn.Tanh()
651
+ self.conv = Conv1d(
652
+ in_channels=attention_channels, out_channels=channels, kernel_size=1
653
+ )
654
+
655
+ def forward(self, x, lengths=None):
656
+ """Calculates mean and std for a batch (input tensor).
657
+
658
+ Arguments
659
+ ---------
660
+ x : torch.Tensor
661
+ Tensor of shape [N, C, L].
662
+ """
663
+ L = x.shape[-1]
664
+
665
+ def _compute_statistics(x, m, dim=2, eps=self.eps):
666
+ mean = (m * x).sum(dim)
667
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
668
+ return mean, std
669
+
670
+ if lengths is None:
671
+ lengths = torch.ones(x.shape[0], device=x.device)
672
+
673
+ # Make binary mask of shape [N, 1, L]
674
+ mask = length_to_mask(lengths * L, max_len=L, device=x.device)
675
+ mask = mask.unsqueeze(1)
676
+
677
+ # Expand the temporal context of the pooling layer by allowing the
678
+ # self-attention to look at global properties of the utterance.
679
+ if self.global_context:
680
+ # torch.std is unstable for backward computation
681
+ # https://github.com/pytorch/pytorch/issues/4320
682
+ total = mask.sum(dim=2, keepdim=True).float()
683
+ mean, std = _compute_statistics(x, mask / total)
684
+ mean = mean.unsqueeze(2).repeat(1, 1, L)
685
+ std = std.unsqueeze(2).repeat(1, 1, L)
686
+ attn = torch.cat([x, mean, std], dim=1)
687
+ else:
688
+ attn = x
689
+
690
+ # Apply layers
691
+ attn = self.conv(self.tanh(self.tdnn(attn)))
692
+
693
+ # Filter out zero-paddings
694
+ attn = attn.masked_fill(mask == 0, float("-inf"))
695
+
696
+ attn = F.softmax(attn, dim=2)
697
+ mean, std = _compute_statistics(x, attn)
698
+ # Append mean and std of the batch
699
+ pooled_stats = torch.cat((mean, std), dim=1)
700
+ pooled_stats = pooled_stats.unsqueeze(2)
701
+
702
+ return pooled_stats
703
+
704
+
705
+ class SERes2NetBlock(nn.Module):
706
+ """An implementation of building block in ECAPA-TDNN, i.e.,
707
+ TDNN-Res2Net-TDNN-SEBlock.
708
+
709
+ Arguments
710
+ ----------
711
+ out_channels: int
712
+ The number of output channels.
713
+ res2net_scale: int
714
+ The scale of the Res2Net block.
715
+ kernel_size: int
716
+ The kernel size of the TDNN blocks.
717
+ dilation: int
718
+ The dilation of the Res2Net block.
719
+ activation : torch class
720
+ A class for constructing the activation layers.
721
+
722
+ Example
723
+ -------
724
+ >>> x = torch.rand(8, 120, 64).transpose(1, 2)
725
+ >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
726
+ >>> out = conv(x).transpose(1, 2)
727
+ >>> out.shape
728
+ torch.Size([8, 120, 64])
729
+ """
730
+
731
+ def __init__(
732
+ self,
733
+ in_channels,
734
+ out_channels,
735
+ res2net_scale=8,
736
+ se_channels=128,
737
+ kernel_size=1,
738
+ dilation=1,
739
+ activation=torch.nn.ReLU,
740
+ batch_norm=True,
741
+ ):
742
+ super().__init__()
743
+ self.out_channels = out_channels
744
+ self.tdnn1 = TDNNBlock(
745
+ in_channels,
746
+ out_channels,
747
+ kernel_size=1,
748
+ dilation=1,
749
+ activation=activation,
750
+ batch_norm=batch_norm,
751
+ )
752
+ self.res2net_block = Res2NetBlock(
753
+ out_channels,
754
+ out_channels,
755
+ res2net_scale,
756
+ kernel_size,
757
+ dilation,
758
+ batch_norm=batch_norm,
759
+ )
760
+ self.tdnn2 = TDNNBlock(
761
+ out_channels,
762
+ out_channels,
763
+ kernel_size=1,
764
+ dilation=1,
765
+ activation=activation,
766
+ batch_norm=batch_norm,
767
+ )
768
+ self.se_block = SEBlock(out_channels, se_channels, out_channels)
769
+
770
+ self.shortcut = None
771
+ if in_channels != out_channels:
772
+ self.shortcut = Conv1d(
773
+ in_channels=in_channels,
774
+ out_channels=out_channels,
775
+ kernel_size=1,
776
+ )
777
+
778
+ def forward(self, x, lengths=None):
779
+ residual = x
780
+ if self.shortcut:
781
+ residual = self.shortcut(x)
782
+
783
+ x = self.tdnn1(x)
784
+ x = self.res2net_block(x)
785
+ x = self.tdnn2(x)
786
+ x = self.se_block(x, lengths)
787
+
788
+ return x + residual
789
+
790
+
791
+ class ECAPA_TDNN(torch.nn.Module):
792
+ """An implementation of the speaker embedding model in a paper.
793
+ "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
794
+ TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
795
+
796
+ Arguments
797
+ ---------
798
+ device : str
799
+ Device used, e.g., "cpu" or "cuda".
800
+ activation : torch class
801
+ A class for constructing the activation layers.
802
+ channels : list of ints
803
+ Output channels for TDNN/SERes2Net layer.
804
+ kernel_sizes : list of ints
805
+ List of kernel sizes for each layer.
806
+ dilations : list of ints
807
+ List of dilations for kernels in each layer.
808
+ lin_neurons : int
809
+ Number of neurons in linear layers.
810
+
811
+ Example
812
+ -------
813
+ >>> input_feats = torch.rand([5, 120, 80])
814
+ >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
815
+ >>> outputs = compute_embedding(input_feats)
816
+ >>> outputs.shape
817
+ torch.Size([5, 1, 192])
818
+ """
819
+
820
+ def __init__(
821
+ self,
822
+ input_size,
823
+ lin_neurons=192,
824
+ activation=torch.nn.ReLU,
825
+ channels=[512, 512, 512, 512, 1536],
826
+ kernel_sizes=[5, 3, 3, 3, 1],
827
+ dilations=[1, 2, 3, 4, 1],
828
+ attention_channels=128,
829
+ res2net_scale=8,
830
+ se_channels=128,
831
+ global_context=True,
832
+ batch_norm=True,
833
+ ):
834
+
835
+ super().__init__()
836
+ assert len(channels) == len(kernel_sizes)
837
+ assert len(channels) == len(dilations)
838
+ self.channels = channels
839
+ self.blocks = nn.ModuleList()
840
+
841
+ # The initial TDNN layer
842
+ self.blocks.append(
843
+ TDNNBlock(
844
+ input_size,
845
+ channels[0],
846
+ kernel_sizes[0],
847
+ dilations[0],
848
+ activation,
849
+ batch_norm=batch_norm,
850
+ )
851
+ )
852
+
853
+ # SE-Res2Net layers
854
+ for i in range(1, len(channels) - 1):
855
+ self.blocks.append(
856
+ SERes2NetBlock(
857
+ channels[i - 1],
858
+ channels[i],
859
+ res2net_scale=res2net_scale,
860
+ se_channels=se_channels,
861
+ kernel_size=kernel_sizes[i],
862
+ dilation=dilations[i],
863
+ activation=activation,
864
+ batch_norm=batch_norm,
865
+ )
866
+ )
867
+
868
+ # Multi-layer feature aggregation
869
+ self.mfa = TDNNBlock(
870
+ channels[-1],
871
+ channels[-1],
872
+ kernel_sizes[-1],
873
+ dilations[-1],
874
+ activation,
875
+ batch_norm=batch_norm,
876
+ )
877
+
878
+ # Attentive Statistical Pooling
879
+ self.asp = AttentiveStatisticsPooling(
880
+ channels[-1],
881
+ attention_channels=attention_channels,
882
+ global_context=global_context,
883
+ batch_norm=batch_norm,
884
+ )
885
+ self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2, enabled=batch_norm)
886
+
887
+ # Final linear transformation
888
+ self.fc = Conv1d(
889
+ in_channels=channels[-1] * 2,
890
+ out_channels=lin_neurons,
891
+ kernel_size=1,
892
+ )
893
+
894
+ # @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
895
+ def forward(self, x, lengths=None):
896
+ """Returns the embedding vector.
897
+
898
+ Arguments
899
+ ---------
900
+ x : torch.Tensor
901
+ Tensor of shape (batch, time, channel).
902
+ """
903
+ # Minimize transpose for efficiency
904
+ x = x.transpose(1, 2)
905
+
906
+ xl = []
907
+ for layer in self.blocks:
908
+ try:
909
+ x = layer(x, lengths=lengths)
910
+ except TypeError:
911
+ x = layer(x)
912
+ xl.append(x)
913
+
914
+ # Multi-layer feature aggregation
915
+ x = torch.cat(xl[1:], dim=1)
916
+ x = self.mfa(x)
917
+
918
+ # Attentive Statistical Pooling
919
+ x = self.asp(x, lengths=lengths)
920
+ x = self.asp_bn(x)
921
+
922
+ # Final linear transformation
923
+ x = self.fc(x)
924
+
925
+ x = x.squeeze(-1)
926
+ return x
927
+
928
+
929
+ if __name__ == "__main__":
930
+ model = ECAPA_TDNN(128, batch_norm=False)
931
+ # print(model)
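+
+ # Usage sketch (illustrative, not from the original module): run a forward pass on
+ # random 128-dim features; `lengths` holds relative utterance lengths in [0, 1].
+ feats = torch.rand(2, 200, 128)                # (batch, time, input_size)
+ lengths = torch.ones(2)
+ emb = model(feats, lengths=lengths)            # one embedding vector per utterance
+ print(emb.shape)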
fireredtts/modules/semantic_tokenizer/hubert.py ADDED
@@ -0,0 +1,108 @@
1
+ from fairseq import checkpoint_utils
2
+ from torch.nn.utils.rnn import pad_sequence
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def get_mask_from_lengths(lengths, max_len=None):
11
+ max_len = torch.max(lengths).item() if max_len is None else max_len
12
+ ids = torch.arange(0, max_len).to(lengths.device)
13
+ mask = ~(ids < lengths.unsqueeze(1)).bool()
14
+ return mask
15
+
16
+
17
+ class HuBERT(nn.Module):
18
+
19
+ def __init__(self, model_path, sampling_rate=16000):
20
+
21
+ super().__init__()
22
+
23
+ models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task(
24
+ [model_path],
25
+ suffix="",
26
+ )
27
+
28
+ model = models[0]
29
+ model = model.half()
30
+ model.eval()
31
+ self.model = model
32
+
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ self.sampling_rate = sampling_rate
37
+ self.normalize = saved_cfg.task.normalize
38
+
39
+ @torch.no_grad()
40
+ @torch.cuda.amp.autocast(enabled=False, dtype=torch.float16)
41
+ def forward(self, inp, length=None, split=True, split_size=4):
42
+ self.model.eval()
43
+ if self.training and split:
44
+ split_size = int(math.ceil(inp.shape[0] / 4))
45
+ outs, out_lens = [], []
46
+ for i in range(0, inp.shape[0], split_size):
47
+ inp_, length_ = inp[i : i + split_size], length[i : i + split_size]
48
+ out_, out_len_ = self._extract(inp_, length_)
49
+ outs.append(out_)
50
+ out_lens.append(out_len_)
51
+ max_length = max([max(ols) for ols in out_lens])
52
+
53
+ return torch.cat(
54
+ [F.pad(o, (0, 0, 0, max_length - o.shape[1]), value=0) for o in outs],
55
+ dim=0,
56
+ ), torch.cat(out_lens, dim=0)
57
+ else:
58
+ return self._extract(inp, length)
59
+
60
+ @torch.no_grad()
61
+ def _extract(self, inp, length):
62
+ frame_samples = int(self.sampling_rate * 0.02)
63
+ device = inp.device
64
+
65
+ if len(inp.shape) == 3:
66
+ inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
67
+ assert len(inp.shape) == 2
68
+ assert self.sampling_rate == 16000
69
+
70
+ feats = inp
71
+
72
+ # Padding with 0
73
+ padding_size = 3200 # Longer to cover receptive field
74
+ feats = F.pad(feats, (0, padding_size), mode="constant", value=0)
75
+
76
+ # Norm volume using LN
77
+ feats = self._postprocess(
78
+ feats, length + padding_size, normalize=self.normalize
79
+ )
80
+
81
+ if length is None:
82
+ padding_mask = torch.BoolTensor(feats.shape).fill_(False)
83
+ else:
84
+ length = torch.ceil(length / 320).int()
85
+ padding_mask = get_mask_from_lengths(length).bool()
86
+ padding_mask = F.pad(padding_mask, (0, 9), value=True)
87
+
88
+ inputs = {
89
+ "source": feats.half().to(device),
90
+ "padding_mask": padding_mask.to(device),
91
+ "mask": False,
92
+ }
93
+ logits, _ = self.model.extract_features(**inputs)
94
+ logits = logits[:, : length.max()].float()
95
+
96
+ return logits, length
97
+
98
+ def _postprocess(self, feats, lengths, normalize=False):
99
+ assert feats.dim() == 2, feats.dim()
100
+
101
+ if normalize:
102
+ with torch.no_grad():
103
+ feats = [
104
+ F.layer_norm(feat[:length], feat[:length].shape)
105
+ for feat, length in zip(feats, lengths)
106
+ ]
107
+ feats = pad_sequence(feats, batch_first=True, padding_value=0)
108
+ return feats
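+
+
+ if __name__ == "__main__":
+     # Usage sketch (illustrative, not from the original module): extract ~50 Hz HuBERT
+     # features from 16 kHz waveforms. The checkpoint path is a placeholder and a CUDA
+     # device is assumed, since the wrapped fairseq model is run in float16.
+     hubert = HuBERT("pretrained_models/hubert.pt").cuda()
+     wav = torch.randn(2, 48000).cuda()                     # two 3-second waveforms
+     length = torch.tensor([48000, 40000]).cuda()
+     feats, feat_lens = hubert(wav, length)
+     print(feats.shape, feat_lens)                          # frames = ceil(samples / 320)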
fireredtts/modules/semantic_tokenizer/semantic_tokenizer.py ADDED
@@ -0,0 +1,877 @@
1
+ from functools import reduce
3
+ from torch import nn
4
+ from torch.autograd import Function
5
+ from torch.nn import functional as F
6
+ from torch.nn.utils import spectral_norm, weight_norm
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ import einops
10
+ import math
11
+ import numpy as np
12
+ import os
13
+ import random
14
+ import torch
15
+ import torchaudio
16
+ import typing as tp
17
+ import warnings
18
+
19
+ from .audio import TorchMelSpectrogram
20
+ from .ecapa_tdnn import ECAPA_TDNN
21
+ from .hubert import HuBERT
22
+ from ..acoustic_codec.vector_quantization import VectorQuantization
23
+
24
+
25
+ CONV_NORMALIZATIONS = frozenset(
26
+ [
27
+ "none",
28
+ "weight_norm",
29
+ "spectral_norm",
30
+ "time_layer_norm",
31
+ "layer_norm",
32
+ "time_group_norm",
33
+ ]
34
+ )
35
+ NORM = "weight_norm"
36
+
37
+
38
+ def get_mask_from_lengths(lengths, max_len=None):
39
+ max_len = torch.max(lengths).item() if max_len is None else max_len
40
+ ids = torch.arange(0, max_len).to(lengths.device)
41
+ mask = ~(ids < lengths.unsqueeze(1)).bool()
42
+ return mask
43
+
44
+
45
+ class ConvLayerNorm(nn.LayerNorm):
46
+ def __init__(
47
+ self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
48
+ ):
49
+ super().__init__(normalized_shape, **kwargs)
50
+
51
+ def forward(self, x):
52
+ x = einops.rearrange(x, "b ... t -> b t ...")
53
+ x = super().forward(x)
54
+ x = einops.rearrange(x, "b t ... -> b ... t")
55
+ return x
56
+
57
+
58
+ def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
59
+ assert norm in CONV_NORMALIZATIONS
60
+ if norm == "weight_norm":
61
+ return weight_norm(module)
62
+ elif norm == "spectral_norm":
63
+ return spectral_norm(module)
64
+ else:
65
+ # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
66
+ # doesn't need reparametrization.
67
+ return module
68
+
69
+
70
+ def get_norm_module(
71
+ module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
72
+ ) -> nn.Module:
73
+ assert norm in CONV_NORMALIZATIONS
74
+ if norm == "layer_norm":
75
+ assert isinstance(module, nn.modules.conv._ConvNd)
76
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
77
+ elif norm == "time_group_norm":
78
+ if causal:
79
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
80
+ assert isinstance(module, nn.modules.conv._ConvNd)
81
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
82
+ else:
83
+ return nn.Identity()
84
+
85
+
86
+ def get_extra_padding_for_conv1d(
87
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
88
+ ) -> int:
89
+ length = x.shape[-1]
90
+ n_frames = (length - kernel_size + padding_total) / stride + 1
91
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
92
+ return ideal_length - length
93
+
94
+
95
+ def pad_for_conv1d(
96
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
97
+ ):
98
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
99
+ return F.pad(x, (0, extra_padding))
100
+
101
+
102
+ def pad1d(
103
+ x: torch.Tensor,
104
+ paddings: tp.Tuple[int, int],
105
+ mode: str = "zero",
106
+ value: float = 0.0,
107
+ ):
108
+ length = x.shape[-1]
109
+ padding_left, padding_right = paddings
110
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
111
+ if mode == "reflect":
112
+ max_pad = max(padding_left, padding_right)
113
+ extra_pad = 0
114
+ if length <= max_pad:
115
+ extra_pad = max_pad - length + 1
116
+ x = F.pad(x, (0, extra_pad))
117
+ padded = F.pad(x, paddings, mode, value)
118
+ end = padded.shape[-1] - extra_pad
119
+ return padded[..., :end]
120
+ else:
121
+ return F.pad(x, paddings, mode, value)
122
+
123
+
124
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
125
+ padding_left, padding_right = paddings
126
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
127
+ assert (padding_left + padding_right) <= x.shape[-1]
128
+ end = x.shape[-1] - padding_right
129
+ return x[..., padding_left:end]
130
+
131
+
132
+ class NormConv1d(nn.Module):
133
+
134
+ def __init__(
135
+ self,
136
+ *args,
137
+ causal: bool = False,
138
+ norm: str = "none",
139
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
140
+ **kwargs,
141
+ ):
142
+ super().__init__()
143
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
144
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
145
+ self.norm_type = norm
146
+
147
+ def forward(self, x):
148
+ x = self.conv(x)
149
+ x = self.norm(x)
150
+ return x
151
+
152
+
153
+ class NormConv2d(nn.Module):
154
+
155
+ def __init__(
156
+ self,
157
+ *args,
158
+ norm: str = "none",
159
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
160
+ **kwargs,
161
+ ):
162
+ super().__init__()
163
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
164
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
165
+ self.norm_type = norm
166
+
167
+ def forward(self, x):
168
+ x = self.conv(x)
169
+ x = self.norm(x)
170
+ return x
171
+
172
+
173
+ class NormConvTranspose1d(nn.Module):
174
+
175
+ def __init__(
176
+ self,
177
+ *args,
178
+ causal: bool = False,
179
+ norm: str = "none",
180
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
181
+ **kwargs,
182
+ ):
183
+ super().__init__()
184
+ self.convtr = apply_parametrization_norm(
185
+ nn.ConvTranspose1d(*args, **kwargs), norm
186
+ )
187
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
188
+ self.norm_type = norm
189
+
190
+ def forward(self, x):
191
+ x = self.convtr(x)
192
+ x = self.norm(x)
193
+ return x
194
+
195
+
196
+ class NormConvTranspose2d(nn.Module):
197
+
198
+ def __init__(
199
+ self,
200
+ *args,
201
+ norm: str = "none",
202
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
203
+ **kwargs,
204
+ ):
205
+ super().__init__()
206
+ self.convtr = apply_parametrization_norm(
207
+ nn.ConvTranspose2d(*args, **kwargs), norm
208
+ )
209
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
210
+
211
+ def forward(self, x):
212
+ x = self.convtr(x)
213
+ x = self.norm(x)
214
+ return x
215
+
216
+
217
+ class SConv1d(nn.Module):
218
+
219
+ def __init__(
220
+ self,
221
+ in_channels: int,
222
+ out_channels: int,
223
+ kernel_size: int,
224
+ stride: int = 1,
225
+ dilation: int = 1,
226
+ groups: int = 1,
227
+ bias: bool = True,
228
+ causal: bool = False,
229
+ norm: str = "weight_norm",
230
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
231
+ pad_mode: str = "reflect",
232
+ ):
233
+ super().__init__()
234
+ # warn user on unusual setup between dilation and stride
235
+ if stride > 1 and dilation > 1:
236
+ warnings.warn(
237
+ "SConv1d has been initialized with stride > 1 and dilation > 1"
238
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
239
+ )
240
+ self.conv = NormConv1d(
241
+ in_channels,
242
+ out_channels,
243
+ kernel_size,
244
+ stride,
245
+ dilation=dilation,
246
+ groups=groups,
247
+ bias=bias,
248
+ causal=causal,
249
+ norm=norm,
250
+ norm_kwargs=norm_kwargs,
251
+ )
252
+ self.causal = causal
253
+ self.pad_mode = pad_mode
254
+
255
+ def forward(self, x):
256
+ B, C, T = x.shape
257
+ kernel_size = self.conv.conv.kernel_size[0]
258
+ stride = self.conv.conv.stride[0]
259
+ dilation = self.conv.conv.dilation[0]
260
+ kernel_size = (
261
+ kernel_size - 1
262
+ ) * dilation + 1 # effective kernel size with dilations
263
+ padding_total = kernel_size - stride
264
+ extra_padding = get_extra_padding_for_conv1d(
265
+ x, kernel_size, stride, padding_total
266
+ )
267
+ if self.causal:
268
+ # Left padding for causal
269
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
270
+ else:
271
+ # Asymmetric padding required for odd strides
272
+ padding_right = padding_total // 2
273
+ padding_left = padding_total - padding_right
274
+ x = pad1d(
275
+ x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
276
+ )
277
+ return self.conv(x)
278
+
279
+
280
+ class SConvTranspose1d(nn.Module):
281
+
282
+ def __init__(
283
+ self,
284
+ in_channels: int,
285
+ out_channels: int,
286
+ kernel_size: int,
287
+ stride: int = 1,
288
+ causal: bool = False,
289
+ norm: str = "weight_norm",
290
+ trim_right_ratio: float = 1.0,
291
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
292
+ ):
293
+ super().__init__()
294
+ self.convtr = NormConvTranspose1d(
295
+ in_channels,
296
+ out_channels,
297
+ kernel_size,
298
+ stride,
299
+ causal=causal,
300
+ norm=norm,
301
+ norm_kwargs=norm_kwargs,
302
+ )
303
+ self.causal = causal
304
+ self.trim_right_ratio = trim_right_ratio
305
+ assert (
306
+ self.causal or self.trim_right_ratio == 1.0
307
+ ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
308
+ assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
309
+
310
+ def forward(self, x):
311
+ kernel_size = self.convtr.convtr.kernel_size[0]
312
+ stride = self.convtr.convtr.stride[0]
313
+ padding_total = kernel_size - stride
314
+
315
+ y = self.convtr(x)
316
+
317
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
318
+ # removed at the very end, when keeping only the right length for the output,
319
+ # as removing it here would require also passing the length at the matching layer
320
+ # in the encoder.
321
+ if self.causal:
322
+ # Trim the padding on the right according to the specified ratio
323
+ # if trim_right_ratio = 1.0, trim everything from right
324
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
325
+ padding_left = padding_total - padding_right
326
+ y = unpad1d(y, (padding_left, padding_right))
327
+ else:
328
+ # Asymmetric padding required for odd strides
329
+ padding_right = padding_total // 2
330
+ padding_left = padding_total - padding_right
331
+ y = unpad1d(y, (padding_left, padding_right))
332
+ return y
333
+
334
+
335
+ class SLSTM(nn.Module):
336
+
337
+ def __init__(
338
+ self,
339
+ dimension: int,
340
+ num_layers: int = 2,
341
+ bidirectional: bool = False,
342
+ skip: bool = True,
343
+ ):
344
+ super().__init__()
345
+ self.bidirectional = bidirectional
346
+ self.skip = skip
347
+ if bidirectional:
348
+ self.lstm = nn.LSTM(
349
+ dimension, dimension // 2, num_layers, bidirectional=bidirectional
350
+ )
351
+ else:
352
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
353
+
354
+ def forward(self, x):
355
+ x = x.permute(2, 0, 1)
356
+ y, _ = self.lstm(x)
357
+ if self.skip:
358
+ y = y + x
359
+ y = y.permute(1, 2, 0)
360
+ return y
361
+
362
+
363
+ class Swish(nn.Module):
364
+ def forward(self, x):
365
+ return x * torch.sigmoid(x)
366
+
367
+
368
+ class ResidualUnit(nn.Module):
369
+ def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
370
+ super().__init__()
371
+
372
+ self.layers = nn.Sequential(
373
+ SConv1d(
374
+ in_channels=in_channels,
375
+ out_channels=out_channels // 2,
376
+ kernel_size=kernel_size,
377
+ groups=groups,
378
+ norm=NORM,
379
+ ),
380
+ Swish(),
381
+ SConv1d(
382
+ in_channels=out_channels // 2,
383
+ out_channels=out_channels,
384
+ kernel_size=kernel_size,
385
+ groups=groups,
386
+ norm=NORM,
387
+ ),
388
+ )
389
+
390
+ def forward(self, x):
391
+ return x + self.layers(x)
392
+
393
+
394
+ class EncoderBlock(nn.Module):
395
+ def __init__(self, out_channels, stride):
396
+ super().__init__()
397
+
398
+ self.layers = nn.Sequential(
399
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels),
400
+ Swish(),
401
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels),
402
+ Swish(),
403
+ SConv1d(
404
+ in_channels=out_channels,
405
+ out_channels=out_channels,
406
+ kernel_size=2 * stride,
407
+ stride=stride,
408
+ norm=NORM,
409
+ ),
410
+ )
411
+
412
+ def forward(self, x):
413
+ return self.layers(x)
414
+
415
+
416
+ class DecoderBlock(nn.Module):
417
+ def __init__(self, in_channels, stride):
418
+ super().__init__()
419
+ out_channels = in_channels
420
+ self.layers = nn.Sequential(
421
+ SConvTranspose1d(
422
+ in_channels=in_channels,
423
+ out_channels=out_channels,
424
+ kernel_size=2 * stride,
425
+ stride=stride,
426
+ norm=NORM,
427
+ ),
428
+ Swish(),
429
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels),
430
+ Swish(),
431
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels),
432
+ )
433
+
434
+ def forward(self, x):
435
+ return self.layers(x)
436
+
437
+
438
+ class Encoder(nn.Module):
439
+ def __init__(self, C, D, strides=[2, 2], checkpointing=True):
440
+ super().__init__()
441
+ self.checkpointing = checkpointing
442
+
443
+ self.downsample_scale = np.cumprod(np.asarray(strides))[-1]
444
+ self.layers = [
445
+ SConv1d(in_channels=C, out_channels=D, kernel_size=3, norm=NORM),
446
+ Swish(),
447
+ ]
448
+ for stride in strides:
449
+ self.layers += [
450
+ EncoderBlock(out_channels=D, stride=stride),
451
+ Swish(),
452
+ ]
453
+ self.layers += [
454
+ SConv1d(in_channels=D, out_channels=D, kernel_size=3, norm=NORM),
455
+ SLSTM(D, num_layers=1, bidirectional=True),
456
+ ]
457
+ self.layers = nn.Sequential(*self.layers)
458
+
459
+ def forward(self, x):
460
+ if self.checkpointing:
461
+ x = checkpoint(
462
+ self.layers, x.transpose(1, 2), use_reentrant=False
463
+ ).transpose(1, 2)
464
+ else:
465
+ x = self.layers(x.transpose(1, 2)).transpose(1, 2)
466
+ return x
467
+
468
+
469
+ class Decoder(nn.Module):
470
+ def __init__(self, C, D, H, strides=[2, 2], checkpointing=True):
471
+ super().__init__()
472
+ self.checkpointing = checkpointing
473
+
474
+ self.in_layer = nn.Sequential(
475
+ SConv1d(in_channels=D, out_channels=H, kernel_size=3, norm=NORM),
476
+ SLSTM(H, num_layers=1, bidirectional=True),
477
+ )
478
+ self.layers = nn.ModuleList()
479
+ for stride in strides:
480
+ self.layers.append(
481
+ nn.Sequential(DecoderBlock(in_channels=H, stride=stride), Swish())
482
+ )
483
+ self.out_layer = SConv1d(
484
+ in_channels=H, out_channels=C, kernel_size=3, norm=NORM
485
+ )
486
+
487
+ def forward(self, x, g=None):
488
+ if self.checkpointing:
489
+ y = checkpoint(self._forward, x, g, use_reentrant=False)
490
+ else:
491
+ y = self._forward(x, g)
492
+ return y
493
+
494
+ def _forward(self, x, g=None):
495
+ h = self.in_layer(x.transpose(1, 2))
496
+
497
+ for layer in self.layers:
498
+ up_g = g.unsqueeze(-1).repeat(1, 1, h.shape[-1])
499
+ h = h + up_g
500
+ h = layer(h)
501
+
502
+ y = self.out_layer(h)
503
+
504
+ return y.transpose(1, 2), h.transpose(1, 2)
505
+
506
+
507
+ class TimeRegulator(nn.Module):
508
+
509
+ def __init__(self, in_dim, scale, learnable=False):
510
+ super().__init__()
511
+ self.scale = scale
512
+ self.learnable = learnable
513
+
514
+ def forward(self, x, x_len, downsample=True):
515
+ if downsample:
516
+ x = self.downsample(x, x_len)
517
+ else:
518
+ x = self.upsample(x, x_len)
519
+ return x
520
+
521
+ def downsample(self, x, x_len):
522
+ x = torch.nn.functional.avg_pool1d(
523
+ x.transpose(1, 2), self.scale, stride=self.scale, ceil_mode=True
524
+ ).transpose(1, 2)
525
+ x_len = (x_len / self.scale).ceil()
526
+ return x, x_len
527
+
528
+ def upsample(self, x, x_len):
529
+ if self.learnable:
530
+ x = self.upsampler(x.transpose(1, 2)).transpose(1, 2)
531
+ else:
532
+ x = torch.repeat_interleave(x, self.scale, dim=1)
533
+ return x
534
+
535
+
536
+ class TreeVectorQuantization(nn.Module):
537
+
538
+ def __init__(
539
+ self,
540
+ in_dim,
541
+ vq_class="VectorQuantization",
542
+ vq_config={},
543
+ tree_config={},
544
+ ):
545
+ super().__init__()
546
+ self.vq_config = vq_config
547
+ self.tree_config = tree_config
548
+
549
+ self.quantizers = nn.ModuleList()
550
+ self.time_regulators = nn.ModuleList()
551
+ for config in self.tree_config:
552
+ vq_config = self.vq_config.copy()
553
+ if not isinstance(vq_config["codebook_size"], (tuple, list)):
554
+ vq_config["codebook_size"] = [vq_config["codebook_size"]]
555
+ vq_config["codebook_dim"] = [vq_config["codebook_dim"]]
556
+ vq_config["codebook_size"] = vq_config["codebook_size"] * config["n_groups"]
557
+ vq_config["codebook_dim"] = vq_config["codebook_dim"] * config["n_groups"]
558
+ self.quantizers.append(
559
+ VectorQuantization(
560
+ in_dim,
561
+ n_groups=config.get("n_groups", 1),
562
+ dropout_rate_per_group=config.get("dropout_rate_per_group", 0),
563
+ ordered=config.get("ordered", False),
564
+ **vq_config,
565
+ )
566
+ )
567
+ self.time_regulators.append(
568
+ TimeRegulator(
569
+ in_dim,
570
+ config["downsample_rate"],
571
+ config.get("learnable_time_regulator", False),
572
+ )
573
+ )
574
+
575
+ def forward(
576
+ self, inp, inp_len, enable_vq=True, update_codebook=True, return_pre_quant=False
577
+ ):
578
+ output, (quants, losses, embed_inds) = self.quantize(
579
+ inp,
580
+ inp_len,
581
+ enable_vq=enable_vq,
582
+ update_codebook=update_codebook,
583
+ return_pre_quant=return_pre_quant,
584
+ )
585
+ loss = sum(losses) / len(losses)
586
+ return output, (quants, loss, embed_inds)
587
+
588
+ def quantize(
589
+ self, inp, inp_len, enable_vq=True, update_codebook=True, return_pre_quant=False
590
+ ):
591
+ quants, losses, embed_inds = [], [], []
592
+
593
+ pre_quant_output, quant_output, residual = 0, 0, inp
594
+ for tree_config, quantizer, regulator in zip(
595
+ self.tree_config, self.quantizers, self.time_regulators
596
+ ):
597
+ # Downsample
598
+ x, x_len = regulator(residual, inp_len, True)
599
+
600
+ # Quantization
601
+ q, diff, embed_ind = quantizer(
602
+ x,
603
+ x_len,
604
+ enable_vq=enable_vq,
605
+ update_codebook=update_codebook,
606
+ return_pre_quant=return_pre_quant,
607
+ )
608
+ if return_pre_quant:
609
+ pq, q = q
610
+
611
+ # Upsample
612
+ x = regulator(q, x_len, False)[:, : residual.shape[1]]
613
+
614
+ residual = residual - x
615
+ quant_output = quant_output + x
616
+
617
+ if return_pre_quant:
618
+ pq = regulator(pq, x_len, False)[:, : residual.shape[1]]
619
+ pre_quant_output = pre_quant_output + pq
620
+
621
+ quants.append(q)
622
+ losses.append(diff)
623
+ embed_inds.append(embed_ind)
624
+
625
+ if return_pre_quant:
626
+ return (pre_quant_output, quant_output), (quants, losses, embed_inds)
627
+ return quant_output, (quants, losses, embed_inds)
628
+
629
+ def decode(self, seqs, seq_lens=None):
630
+ if not isinstance(seqs, (tuple, list)):
631
+ tokens, token_lens = self.deserialize(seqs, seq_lens)
632
+ else:
633
+ tokens, token_lens = seqs, seq_lens
634
+
635
+ quant_output = 0
636
+ for token, quantizer, regulator in zip(
637
+ tokens, self.quantizers, self.time_regulators
638
+ ):
639
+ x = quantizer.decode(token).transpose(1, 2)
640
+ x = regulator(x, None, False)
641
+ if torch.is_tensor(quant_output):
642
+ x = x[:, : quant_output.size(1)]
643
+ quant_output = quant_output + x
644
+
645
+ return quant_output, token_lens
646
+
647
+ def serialize(self, tokens, token_lens):
648
+ assert len(tokens) <= 2, "we only support 1 or 2-scale sequences now..."
649
+
650
+ scale = self.tree_config[0]["downsample_rate"]
651
+ token_lens = ((token_lens.float() / scale).ceil() * scale).int()
652
+
653
+ seq1 = tokens[0].unsqueeze(-1)
654
+
655
+ if len(tokens) == 1:
656
+ seq_cat = seq1.view(seq1.shape[0], -1)
657
+ seq_cat_lens = (token_lens / scale * seq1.shape[2]).int()
658
+ elif len(tokens) == 2:
659
+ seq2 = F.pad(
660
+ tokens[1], (0, token_lens.max() - tokens[1].size(1)), "replicate"
661
+ )
662
+ seq2 = torch.stack([seq2[:, i::scale] for i in range(scale)], dim=-1)
663
+ seq_cat = torch.cat((seq1, seq2), dim=-1).view(seq1.shape[0], -1)
664
+ seq_cat_lens = (token_lens / scale + token_lens).int()
665
+
666
+ return seq_cat, seq_cat_lens
667
+
668
+ def deserialize(self, seqs, seq_lens):
669
+ if len(self.tree_config) == 1:
670
+ return [seqs], seq_lens
671
+
672
+ max_scale = max(config["downsample_rate"] for config in self.tree_config)
673
+ total_scale = sum(config["downsample_rate"] for config in self.tree_config)
674
+
675
+ # Cut for aligning
676
+ if seq_lens is None:
677
+ seq_lens = torch.full([seqs.shape[0]], seqs.shape[1]).to(seqs.device)
678
+ seq_lens = (seq_lens / total_scale).int() * total_scale
679
+ token_lens = (seq_lens / total_scale).int() * max_scale
680
+ seqs = seqs[:, : seq_lens.max()]
681
+
682
+ # Separate
683
+ tokens = torch.stack(
684
+ [seqs[:, i::total_scale] for i in range(total_scale)], dim=-1
685
+ )
686
+ seq1 = tokens[..., 0]
687
+ seq2 = tokens[..., 1:].contiguous().view(tokens.shape[0], -1)
688
+
689
+ return [seq1, seq2], token_lens
690
+
691
+
692
+ class SemanticVQVAE(nn.Module):
693
+
694
+ def __init__(
695
+ self,
696
+ in_dim,
697
+ out_dim,
698
+ n_model_size,
699
+ downsample_scales=[1, 2],
700
+ upsample_scales=[[2, 1], [2, 1]],
701
+ mel_config={},
702
+ ssl_config={},
703
+ # Quantization
704
+ vq_class="VectorQuantization",
705
+ vq_config={},
706
+ tree_config={},
707
+ # Training
708
+ checkpointing=True,
709
+ dual_decoding=False,
710
+ n_samples_per_token=640,
711
+ online_extraction=True,
712
+ ssl_extractor=None,
713
+ ):
714
+ super(SemanticVQVAE, self).__init__()
715
+ self.in_dim = in_dim
716
+ self.n_model_size = n_model_size
717
+ self.mel_config = mel_config
718
+ self.dual_decoding = dual_decoding
719
+ self.vq_config = vq_config
720
+ self.tree_config = tree_config
721
+ self.output_feature = "mel"
722
+ self.n_samples_per_token = n_samples_per_token
723
+ self.checkpointing = checkpointing
724
+
725
+ self.mel_spectrogram = TorchMelSpectrogram(**mel_config)
726
+
727
+ # Speaker encoder
728
+ self.speaker_encoder = ECAPA_TDNN(
729
+ out_dim,
730
+ n_model_size,
731
+ channels=[512, 512, 512, 512, 1536],
732
+ kernel_sizes=[5, 3, 3, 3, 1],
733
+ dilations=[1, 2, 3, 4, 1],
734
+ attention_channels=128,
735
+ res2net_scale=4,
736
+ se_channels=128,
737
+ global_context=True,
738
+ batch_norm=True,
739
+ )
740
+
741
+ # Encoder & decoder
742
+ self.encoder = Encoder(
743
+ in_dim, n_model_size, downsample_scales, checkpointing=checkpointing
744
+ )
745
+
746
+ # Quantization
747
+ self.quantizer = TreeVectorQuantization(
748
+ n_model_size,
749
+ vq_class=vq_class,
750
+ vq_config=vq_config,
751
+ tree_config=tree_config,
752
+ )
753
+
754
+ def forward(
755
+ self,
756
+ wav,
757
+ wav_length,
758
+ enable_vq=True,
759
+ decode=True,
760
+ extract_spk=True,
761
+ shuffle=False,
762
+ **kwargs,
763
+ ):
764
+ output_dict = {}
765
+
766
+ with torch.no_grad():
767
+ # Pad waveform
768
+ if wav.shape[1] % self.n_samples_per_token > 0:
769
+ pad_size = (
770
+ self.n_samples_per_token - wav.shape[1] % self.n_samples_per_token
771
+ )
772
+ wav = F.pad(wav, (0, pad_size), value=0)
773
+ wav_length += pad_size
774
+
775
+ # Extract mel & ssl
776
+ mel, mel_length = kwargs.get("mel", None), kwargs.get("mel_length", None)
777
+ if mel is None:
778
+ mel, mel_length = self.mel_spectrogram(wav, wav_length)
779
+ output_dict.update({"mel": mel, "mel_length": mel_length})
780
+
781
+ ssl, ssl_length = kwargs.get("ssl", None), kwargs.get("ssl_length", None)
782
+ if ssl is None:
783
+ ssl, ssl_length = self.ssl_extractor(wav, wav_length)
784
+ output_dict.update({"ssl": ssl.float(), "ssl_length": ssl_length})
785
+
786
+ input, input_length = ssl, ssl_length
787
+ output, output_length = mel, mel_length
788
+
789
+ encoder_outputs = self.encoder(input)
790
+ quant_length = torch.ceil(input_length / self.encoder.downsample_scale)
791
+ quant_length = quant_length.clamp(max=encoder_outputs.shape[1])
792
+
793
+ quant, (quants, diff, embed_ind) = self.quantizer(
794
+ encoder_outputs,
795
+ quant_length,
796
+ enable_vq=enable_vq,
797
+ update_codebook=True,
798
+ return_pre_quant=self.dual_decoding,
799
+ )
800
+
801
+ output_dict.update(
802
+ {
803
+ "quants": quants,
804
+ "token": embed_ind,
805
+ "token_length": quant_length.int(),
806
+ "encoder_diffs": diff,
807
+ }
808
+ )
809
+
810
+ # Speaker
811
+ if extract_spk:
812
+ cond, cond_length = output, output_length
813
+ speaker_embedding = self.speaker_encoder(cond, cond_length)
814
+ speaker_embedding_1 = speaker_embedding_2 = speaker_embedding
815
+ output_dict["spk"] = speaker_embedding
816
+
817
+ return output_dict
818
+
819
+ @torch.no_grad()
820
+ def extract_speech_tokens(
821
+ self, wav, wav_length, serialize=True, extract_spk=True, shuffle=False
822
+ ):
823
+ output_dict = self.forward(
824
+ wav, wav_length, True, False, extract_spk=extract_spk, shuffle=shuffle
825
+ )
826
+ token_seqs, token_length = output_dict["token"], output_dict["token_length"]
827
+
828
+ # Align sequences
829
+ scale = self.tree_config[0]["downsample_rate"]
830
+ token_length = (torch.ceil(token_length / scale) * scale).int()
831
+
832
+ new_token_seqs, new_token_lens = [], []
833
+ for i, token_seq in enumerate(token_seqs):
834
+ # discrete-continuous tokens
835
+ residual = None
836
+ if isinstance(token_seq, (tuple, list)):
837
+ token_seq, residual = token_seq
838
+
839
+ scale = self.tree_config[i]["downsample_rate"]
840
+ new_token_len = token_length // scale
841
+ pad = int(new_token_len.max()) - token_seq.shape[1]
842
+ token_seq = F.pad(
843
+ token_seq,
844
+ (0, pad) if len(token_seq.shape) == 2 else (0, 0, 0, pad),
845
+ "replicate",
846
+ )
847
+
848
+ if residual is not None:
849
+ token_seq = (token_seq, residual)
850
+ new_token_seqs.append(token_seq)
851
+ new_token_lens.append(new_token_len)
852
+
853
+ if len(new_token_seqs) == 1:
854
+ new_token_seqs, new_token_lens = new_token_seqs[0], new_token_lens[0]
855
+ elif serialize:
856
+ new_token_seqs, new_token_lens = self.quantizer.serialize(
857
+ new_token_seqs, new_token_lens
858
+ )
859
+
860
+ output_dict.update(
861
+ {
862
+ "embed": output_dict["quants"],
863
+ "token": new_token_seqs,
864
+ "token_length": new_token_lens,
865
+ }
866
+ )
867
+
868
+ return output_dict
869
+
870
+ @torch.no_grad()
871
+ def code_to_latent(self, token, mel=None):
872
+ quant, _ = self.quantizer.decode(token, None)
873
+ speaker_embedding = self.speaker_encoder(mel)
874
+ latents = quant + speaker_embedding.unsqueeze(1).repeat(1, quant.shape[1], 1)
875
+ return {
876
+ "latents": latents,
877
+ }
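+
+
+ # Usage sketch (illustrative, not from the original module): building a SemanticVQVAE
+ # requires project-specific mel/vq/tree config dicts, so only the call pattern for an
+ # already-constructed `tokenizer` is sketched here.
+ #
+ #     out = tokenizer.extract_speech_tokens(wav, wav_length)  # wav: (B, T) waveforms
+ #     tokens, token_lens = out["token"], out["token_length"]  # serialized semantic tokens
+ #     spk_embedding = out["spk"]                               # utterance-level speaker embedding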
fireredtts/modules/text_normalizer/__init__.py ADDED
File without changes
fireredtts/modules/text_normalizer/normalize.py ADDED
@@ -0,0 +1,183 @@
1
+ import re
2
+ import regex
3
+ import inflect
4
+ import unicodedata
5
+ from lingua import Language, LanguageDetectorBuilder
6
+ from builtins import str as unicode
7
+
8
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
9
+ from tn.english.normalizer import Normalizer as EnNormalizer
10
+
11
+ from fireredtts.modules.text_normalizer.regex_common import *
12
+ from fireredtts.modules.text_normalizer.utils import *
13
+
14
+
15
+ def preprocess_text(sentence):
16
+ # preprocessing
17
+ sentence = bytes(sentence, "utf-8").decode("utf-8", "ignore")
18
+ sentence = regex.sub("[\p{Cf}--[\u200d]]", "", sentence, flags=regex.V1)
19
+ sentence = regex.sub("\p{Co}", "", sentence)
20
+ sentence = sentence.replace("\u00a0", " ")
21
+ sentence = sentence.replace("\ufffd", "")
22
+ sentence = regex.sub("\p{Zl}", "\n", sentence)
23
+ sentence = regex.sub("\p{Zp}", "\n", sentence)
24
+
25
+ sentence = unicode(sentence)
26
+ sentence = "".join(
27
+ char
28
+ for char in unicodedata.normalize("NFD", sentence)
29
+ if unicodedata.category(char) != "Mn"
30
+ ) # Strip accents
31
+
32
+ sentence = strip_kaomoji(sentence)
33
+ # full to half with exemption (to be converted after number TN): 。,:
34
+ sentence = f2b(sentence, exemption="。,:")
35
+
36
+ # clean spaces
37
+ sentence = sentence.replace("\n", ",")
38
+ sentence = sentence.replace("\t", ",")
39
+ sentence = sentence.replace("\r", ",")
40
+ sentence = re.sub(r"[。.]{3,}", "…", sentence)
41
+ sentence = re.sub(r"[…⋯]{1,}", "…", sentence)
42
+ sentence = re.sub(r"[ ]+", " ", sentence)
43
+ sentence = sentence.strip()
44
+
45
+ # punctuation reduction
46
+ result = ""
47
+ for idx, char in enumerate(sentence):
48
+ if char in symbol_reduction:
49
+ char = symbol_reduction[char]
50
+
51
+ if char == " ":
52
+ if idx == 0:
53
+ continue
54
+ if is_chinese(sentence[idx + 1]) and (
55
+ is_chinese(sentence[idx - 1]) or sentence[idx - 1] in '") '
56
+ ):
57
+ result += ","
58
+ else:
59
+ result += " "
60
+ continue
61
+
62
+ if is_valid_char(char):
63
+ result += char
64
+ result = re.sub(r"[ ]+", " ", result)
65
+ return result
66
+
67
+
68
+ def rettt(sentence):
69
+ # handle abbreviations for all languages
70
+ sentence = sentence.replace("&nd", "and")
71
+ sentence = sentence.replace("Jan.", "january")
72
+ sentence = sentence.replace("Feb.", "febrary")
73
+ sentence = sentence.replace("Mar.", "march")
74
+ sentence = sentence.replace("Apr.", "april")
75
+ sentence = sentence.replace("May.", "may")
76
+ sentence = sentence.replace("Jun.", "june")
77
+ sentence = sentence.replace("Jul.", "july")
78
+ sentence = sentence.replace("Aug.", "august")
79
+ sentence = sentence.replace("Sept.", "september")
80
+ sentence = sentence.replace("Sep.", "september")
81
+ sentence = sentence.replace("Oct.", "october")
82
+ sentence = sentence.replace("Nov.", "november")
83
+ sentence = sentence.replace("Dec.", "december")
84
+ sentence = sentence.replace("Mon.", "monday")
85
+ sentence = sentence.replace("Tues.", "tuesday")
86
+ sentence = sentence.replace("Wed.", "wednesday")
87
+ sentence = sentence.replace("Thur.", "thursday")
88
+ sentence = sentence.replace("Fri.", "friday")
89
+ sentence = sentence.replace("Sat.", "saturday")
90
+ if sentence != "Sun.":
91
+ sentence = sentence.replace("Sun.", "sunday")
92
+ sentence = re.sub(r" St\. ([A-Z])", r" saint \1", sentence)
93
+ sentence = re.sub(r" St\.", " street", sentence)
94
+ sentence = re.sub(r" Rd\.", " road", sentence)
95
+ sentence = re.sub(r"[Aa]\.[Mm]\.", "A_M", sentence)
96
+ sentence = re.sub(r"[Pp]\.[Mm]\.", "P_M", sentence)
97
+ sentence = re.sub(r"[Bb]\.[Cc]\.", "B_C", sentence)
98
+ sentence = re.sub(r"[Ad]\.[Dd]\.", "A_D", sentence)
99
+ sentence = sentence.replace("Mr.", "mister")
100
+ sentence = sentence.replace("Ms.", "miss")
101
+ sentence = sentence.replace("Mrs.", "misses")
102
+ sentence = sentence.replace("Ph.D", "P_H_D")
103
+ sentence = sentence.replace("i.e.", "that is")
104
+ sentence = sentence.replace("e.g.", "for example")
105
+ sentence = sentence.replace("btw.", "by the way")
106
+ sentence = sentence.replace("btw", "by the way")
107
+ sentence = sentence.replace("b.t.w.", "by the way")
108
+ sentence = sentence.replace("@", " at ")
109
+ return sentence
110
+
111
+
112
+ class TextNormalizer:
113
+ def __init__(self):
114
+ self.language_detector = LanguageDetectorBuilder.from_languages(
115
+ Language.ENGLISH, Language.CHINESE
116
+ ).build()
117
+ self.zh_normalizer = ZhNormalizer()
118
+ self.en_normalizer = EnNormalizer()
119
+ self.inflect_parser = inflect.engine()
120
+ self.lang2token = {Language.ENGLISH: "en", Language.CHINESE: "zh"}
121
+
122
+ def tn(self, text):
123
+ text = preprocess_text(text)
124
+ text = rettt(text) # regex replacements
125
+ # for non chinese languages
126
+ language = self.language_detector.detect_language_of(text)
127
+ # enforce chinese if text contains any chinese character
128
+ if contains_chinese(text):
129
+ language = Language.CHINESE
130
+ text_lang = self.lang2token.get(language, "zh")
131
+
132
+ if is_upper_eng_and_digit(text):
133
+ language = Language.CHINESE
134
+
135
+ if language == Language.CHINESE:
136
+ text = self.zh_normalizer.normalize(text)
137
+ # print("---text after zh_normalizer:", text)
138
+ text = text.replace("\n", "")
139
+ text = text.replace(",", ",")
140
+ text = text.replace(".", "。")
141
+ text = re.sub(r"[,,]+$", "。", text)
142
+ # print("---text after zh_normalizer 2:", text)
143
+ else:
144
+ text = re.sub(r"[^ 0-9A-Za-z\[\]'.,:?!_\-]", "", text)
145
+ text = self.en_normalizer.normalize(text)
146
+ # fallback number normalization
147
+ pieces = re.split(r"(\d+)", text)
148
+ text = "".join(
149
+ [
150
+ self.inflect_parser.number_to_words(p) if p.isnumeric() else p
151
+ for p in pieces
152
+ if len(p) > 0
153
+ ]
154
+ )
155
+
156
+ # cleanup
157
+ text = text.replace("_", " ")
158
+ text = re.sub(r"[ ]+", " ", text)
159
+
160
+ # spell capital words
161
+ pieces = re.split(r"([A-Z]{2,4}|[ ])", text)
162
+ for idx, p in enumerate(pieces):
163
+ if re.match("[A-Z]{2,4}", p):
164
+ pieces[idx] = " ".join(p)
165
+ text = " ".join([p for p in pieces if p != " "])
166
+
167
+ # post TN full to half
168
+ # text = text.replace("。", ".")
169
+ # text = text.replace(",", ",")
170
+ # text = text.replace(":", ":")
171
+
172
+ # model limitations
173
+ text = text.lower().strip()
174
+ text = text.replace('"', "")
175
+ text = text.replace("·", " ")
176
+ # text = re.sub("[…~!,&*%$#^:;!:;]+", ",", text)
177
+ text = re.sub("[…~!&*%$#^:;!:;]+", ",", text)
178
+ text = re.sub("[,]+", ",", text)
179
+ text = re.sub(r"[,. ]+$", ".", text)
180
+ if len(text) > 0 and text[-1] not in ".?":
181
+ text = text + "."
182
+ text = text.replace("。.", "。")
183
+ return text, text_lang
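+
+
+ if __name__ == "__main__":
+     # Usage sketch (illustrative, not from the original module): `tn` returns the
+     # normalized text plus a coarse language tag ("zh" or "en").
+     normalizer = TextNormalizer()
+     print(normalizer.tn("今天是2024年1月1日,温度是零下3度。"))
+     print(normalizer.tn("The meeting is at 9 a.m. on Jan. 3rd."))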
fireredtts/modules/text_normalizer/regex_common.py ADDED
@@ -0,0 +1,23 @@
1
+ import re
2
+
3
+ kaomoji_regex = re.compile(
4
+ r"[oヽwΣ┗╰O︿Ψ凸]?[(|≡*(].{0,4}[Д✿_▽→≧﹏`∩⊙∇☆≡๑〃′エ≦▔@﹁εヘ•́ω益‿≖ฺ皿•̀艹 ̄△|゚].{0,5}[|≡*))][┛ブ凸cdd︴oOΨ︿w╯ノ]?"
5
+ )
6
+ chinese_regex = re.compile(r"[\u4e00-\u9fa5]")
7
+ digit_regex = re.compile(r"(\\d+)(\\.\\d+)?", re.UNICODE)
8
+
9
+ chinese_char_regex = re.compile(r"^[\u4e00-\u9fa5]$", re.UNICODE)
10
+ eng_and_digit_char_regex = re.compile(r"^[0-9.,A-Za-z]+$", re.UNICODE)
11
+ upper_eng_and_digit_regex = re.compile(r"^[ 0-9A-Z\"'.,:?!\-]+$", re.UNICODE)
12
+ valid_char_regex = re.compile(
13
+ r"[\t\r\n ]|"
14
+ r"[\u4e00-\u9fa5]|"
15
+ r"\u0080|[\u20a0-\u20bf]|\u00a2|\u00a3|\u00a5|\uffe0|\uffe1|\uffe5|\uffe6|"
16
+ r"\u3000|\u3002|\u00b7|\u2014|\u2019|\u2026|\uff01|\uff1f|\uff0e|\uff1a|\uff1b|\uff0b|\uff0c|\uff0d|\uff0f|[\ufe10-\ufe16]|[\ufe50-\ufe51]|[\ufe55-\ufe57]|\ufe6a|"
17
+ r"[\u0030-\u0040]|"
18
+ r"[\u0391-\u03c9]|"
19
+ r"[\u00b0-\u00b3]|[\u2015-\u2018]|[\u3000-\u303f]|"
20
+ r"[\u0022-\u002f\u003a-\u003e\u0040\u005b-\u0060\u007b-\u007e]|"
21
+ r"[\uff21-\uff3a]|[\uff41-\uff5a]|[\u0041-\u005a]|[\u0061-\u007a]",
22
+ re.UNICODE,
23
+ )
fireredtts/modules/text_normalizer/utils.py ADDED
@@ -0,0 +1,171 @@
1
+ from fireredtts.modules.text_normalizer.regex_common import *
2
+ from sentencex import segment
3
+ import re
4
+
5
+
6
+ symbol_reduction = {
7
+ "「": '"',
8
+ "」": '"',
9
+ "`": '"',
10
+ "〝": '"',
11
+ "〞": '"',
12
+ "‟": '"',
13
+ "„": '"',
14
+ "{": "(",
15
+ "}": ")",
16
+ "【": "(",
17
+ "】": ")",
18
+ "〖": "(",
19
+ "〗": ")",
20
+ "〔": "(",
21
+ "〕": ")",
22
+ "〘": "(",
23
+ "〙": ")",
24
+ "《": "(",
25
+ "》": ")",
26
+ "⦅": "(",
27
+ "⦆": ")",
28
+ "〚": "(",
29
+ "〛": ")",
30
+ "『": '"',
31
+ "』": '"',
32
+ "「": '"',
33
+ "」": '"',
34
+ "{": "(",
35
+ "}": ")",
36
+ "〈": "(",
37
+ "〉": ")",
38
+ "•": "·",
39
+ "‧": "·",
40
+ "〰": "…",
41
+ "﹏": "…",
42
+ "〜": "~",
43
+ "~": "~",
44
+ "+": "+",
45
+ "、": "、",
46
+ "。": "。",
47
+ "︐": ",",
48
+ "﹐": ",",
49
+ "︑": "、",
50
+ "﹑": "、",
51
+ "︒": "。",
52
+ "︓": ":",
53
+ "﹕": ":",
54
+ "︔": ";",
55
+ "﹔": ";",
56
+ "︕": "!",
57
+ "﹗": "!",
58
+ "︖": "?",
59
+ "﹖": "?",
60
+ "﹙": "(",
61
+ "﹚": ")",
62
+ "﹪": "%",
63
+ "﹠": "&",
64
+ ">": ">",
65
+ "|": "、",
66
+ "=": "=",
67
+ "‐": "-",
68
+ "‑": "-",
69
+ "‒": "-",
70
+ "–": "-",
71
+ "—": "-",
72
+ "―": "-",
73
+ "%": "%",
74
+ "μ": "u",
75
+ }
76
+
77
+
78
+ strong_break = re.compile("([。”;;!!:…??)\)\]』】」}~\r\n]| \.)", re.UNICODE)
79
+ weak_break = re.compile(
80
+ "["
81
+ "\U00002702-\U000027b0\U0001f926-\U0001f937\U00010000-\U0001fbff\U00030000-\U0010ffff"
82
+ "\u2640-\u2642\u2600-\u2b55\u23cf\u23e9\u231a\ufe0f\u3030"
83
+ "\t,,. ]",
84
+ re.UNICODE,
85
+ )
86
+
87
+
88
+ def contains_chinese(text):
89
+ return bool(chinese_regex.search(text))
90
+
91
+
92
+ def strip_kaomoji(text):
93
+ return kaomoji_regex.sub(" ", text)
94
+
95
+
96
+ def is_chinese(char):
97
+ return chinese_char_regex.match(char)
98
+
99
+
100
+ def is_eng_and_digit(char):
101
+ return eng_and_digit_char_regex.match(char)
102
+
103
+
104
+ def is_upper_eng_and_digit(text):
105
+ return upper_eng_and_digit_regex.match(text)
106
+
107
+
108
+ def is_valid_char(char):
109
+ return valid_char_regex.match(char)
110
+
111
+
112
+ def is_digit(text):
113
+ return digit_regex.match(text)
114
+
115
+
116
+ def f2b(ustr, exemption="。,:"):
117
+ half = []
118
+ for u in ustr:
119
+ num = ord(u)
120
+ if num == 0x3000:
121
+ half.append(" ")
122
+ elif u in exemption: # exemption
123
+ half.append(u)
124
+ elif 0xFF01 <= num <= 0xFF5E:
125
+ num -= 0xFEE0
126
+ half.append(chr(num))
127
+ else:
128
+ half.append(u)
129
+ return "".join(half)
130
+
131
+
132
+ def zh_text_split(text, length=80):
133
+ if length == 0:
134
+ return []
135
+ if length == 1:
136
+ return [c for c in text]
137
+ if len(text) <= length:
138
+ return [text]
139
+
140
+ match_strong = re.search(strong_break, text[:length][::-1])
141
+ match_weak = re.search(weak_break, text[:length][::-1])
142
+ end_ind_strong = length - match_strong.start() if match_strong else 0
143
+ end_ind_weak = length - match_weak.start() if match_weak else 0
144
+
145
+ if end_ind_strong < length // 3:
146
+ if end_ind_weak < length // 3:
147
+ valid_max = max(end_ind_strong, end_ind_weak)
148
+ if valid_max >= 3:
149
+ return [text[:valid_max]] + zh_text_split(text[valid_max:])
150
+ else:
151
+ return [text[:length]] + zh_text_split(text[length:])
152
+ else:
153
+ return [text[:end_ind_weak]] + zh_text_split(text[end_ind_weak:])
154
+ else:
155
+ return [text[:end_ind_strong]] + zh_text_split(text[end_ind_strong:])
156
+
157
+
158
+ def text_split(text):
159
+ if contains_chinese(text):
160
+ substrings = list(segment("zh", text))
161
+ new_substrings = []
162
+ for s in substrings:
163
+ if len(s) > 50:
164
+ new_substrings += zh_text_split(s, length=50)
165
+ else:
166
+ new_substrings.append(s)
167
+ substrings = new_substrings
168
+ else:
169
+ substrings = list(segment("en", text))
170
+
171
+ return substrings
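+
+
+ if __name__ == "__main__":
+     # Usage sketch (illustrative, not from the original module): split mixed-language
+     # text into sentence-level pieces; Chinese sentences longer than 50 characters are
+     # further chunked by zh_text_split.
+     for piece in text_split("今天天气不错。我们去公园散步吧。How are you today?"):
+         print(piece)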
fireredtts/setup.py ADDED
@@ -0,0 +1,3 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(name="fireredtts", version="0.1", packages=find_packages())
fireredtts/utils/__init__.py ADDED
File without changes
fireredtts/utils/spliter.py ADDED
@@ -0,0 +1,161 @@
1
+ import re
2
+ import string
3
+
4
+ SYMBOLS_MAPPING = {
5
+ "\n": "",
6
+ "…": ".",
7
+ "“": "'",
8
+ "”": "'",
9
+ "‘": "'",
10
+ "’": "'",
11
+ "【": "",
12
+ "】": "",
13
+ "[": "",
14
+ "]": "",
15
+ "(": "",
16
+ ")": "",
17
+ "(": "",
18
+ ")": "",
19
+ "・": "",
20
+ "·": "",
21
+ "「": "'",
22
+ "」": "'",
23
+ "《": "'",
24
+ "》": "'",
25
+ "—": "",
26
+ "~": "",
27
+ "~": "",
28
+ ":": ",",
29
+ ";": ",",
30
+ ";": ",",
31
+ ":": ",",
32
+ '"': "",
33
+ "!": "。",
34
+ "!": ".",
35
+ "————": ",",
36
+ "——": ",",
37
+ "—": ",",
38
+ "……": ",",
39
+ }
40
+
41
+ REPLACE_SYMBOL_REGEX = re.compile(
42
+ "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
43
+ )
44
+
45
+
46
+ EMOJI_REGEX = re.compile(
47
+ "["
48
+ "\U0001f600-\U0001f64f" # emoticons
49
+ "\U0001f300-\U0001f5ff" # symbols & pictographs
50
+ "\U0001f680-\U0001f6ff" # transport & map symbols
51
+ "\U0001f1e0-\U0001f1ff" # flags (iOS)
52
+ "]+",
53
+ flags=re.UNICODE,
54
+ )
55
+
56
+
57
+ def clean_text(text):
58
+ # Clean the text
59
+ text = text.strip()
60
+ text = text.replace("\xa0", "")
61
+
62
+ # Replace all chinese symbols with their english counterparts
63
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
64
+
65
+ # Remove emojis
66
+ text = EMOJI_REGEX.sub(r"", text)
67
+
68
+ # Remove continuous periods (...) and commas (,,,)
69
+ text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)
70
+
71
+ return text
72
+
73
+
74
+ def utf_8_len(text):
75
+ return len(text.encode("utf-8"))
76
+
77
+
78
+ def break_text(texts, length, splits: set):
79
+ for text in texts:
80
+ if utf_8_len(text) <= length:
81
+ yield text
82
+ continue
83
+
84
+ curr = ""
85
+ for char in text:
86
+ curr += char
87
+
88
+ if char in splits:
89
+ yield curr
90
+ curr = ""
91
+
92
+ if curr:
93
+ yield curr
94
+
95
+
96
+ def break_text_by_length(texts, length):
97
+ for text in texts:
98
+ if utf_8_len(text) <= length:
99
+ yield text
100
+ continue
101
+
102
+ curr = ""
103
+ for char in text:
104
+ curr += char
105
+
106
+ if utf_8_len(curr) >= length:
107
+ yield curr
108
+ curr = ""
109
+
110
+ if curr:
111
+ yield curr
112
+
113
+
114
+ def add_cleaned(curr, segments):
115
+ curr = curr.strip()
116
+ if curr and not all(c.isspace() or c in string.punctuation for c in curr):
117
+ segments.append(curr)
118
+
119
+
120
+ def protect_float(text):
121
+ # Turns 3.14 into <3_f_14> to prevent splitting
122
+ return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
123
+
124
+
125
+ def unprotect_float(text):
126
+ # Turns <3_f_14> into 3.14
127
+ return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
128
+
129
+
130
+ def split_text(text, length):
131
+ text = clean_text(text)
132
+
133
+ # Break the text into pieces with following rules:
134
+ # 1. Split the text at ".", "!", "?" if text is NOT a float
135
+ # 2. If the text is longer than length, split at ","
136
+ # 3. If the text is still longer than length, split at " "
137
+ # 4. If the text is still longer than length, split at any character to length
138
+
139
+ texts = [text]
140
+ texts = map(protect_float, texts)
141
+ texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
142
+ texts = map(unprotect_float, texts)
143
+ texts = break_text(texts, length, {",", ","})
144
+ texts = break_text(texts, length, {" "})
145
+ texts = list(break_text_by_length(texts, length))
146
+
147
+ # Then, merge the texts into segments with length <= length
148
+ segments = []
149
+ curr = ""
150
+
151
+ for text in texts:
152
+ if utf_8_len(curr) + utf_8_len(text) <= length:
153
+ curr += text
154
+ else:
155
+ add_cleaned(curr, segments)
156
+ curr = text
157
+
158
+ if curr:
159
+ add_cleaned(curr, segments)
160
+
161
+ return segments
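+
+
+ if __name__ == "__main__":
+     # Usage sketch (illustrative, not from the original module): split a long mixed
+     # sentence into segments of at most 30 UTF-8 bytes; floats such as 3.14 survive intact.
+     for seg in split_text("Pi is roughly 3.14, which is useful. 我们今天聊聊长文本的切分。", 30):
+         print(seg)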
fireredtts/utils/utils.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+ from pathlib import Path
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchaudio
7
+
8
+
9
+ def load_audio(audiopath, sampling_rate):
10
+ """_summary_
11
+
12
+ Args:
13
+ audiopath (str): path to the audio file
14
+ sampling_rate (int): target sampling rate
15
+
16
+ Returns:
17
+ (Tensor, int, Tensor): original audio, original sampling rate, resampled audio
18
+ """
19
+ audio, lsr = torchaudio.load(audiopath)
20
+
21
+ # stereo to mono if needed
22
+ if audio.size(0) != 1:
23
+ audio = torch.mean(audio, dim=0, keepdim=True)
24
+
25
+ # resample
26
+ audio_resampled = torchaudio.functional.resample(audio, lsr, sampling_rate)
27
+ if torch.any(audio > 10) or not torch.any(audio < 0):
28
+ print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
29
+
30
+ if torch.any(audio_resampled > 10) or not torch.any(audio_resampled < 0):
31
+ print(
32
+ f"Error with {audiopath}. Max={audio_resampled.max()} min={audio_resampled.min()}"
33
+ )
34
+ # clip audio invalid values
35
+ audio.clip_(-1, 1)
36
+ audio_resampled.clip_(-1, 1)
37
+ return audio, lsr, audio_resampled
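+
+
+ if __name__ == "__main__":
+     # Usage sketch (illustrative, not from the original module): "example.wav" is a
+     # placeholder path; the function returns the original audio, its sample rate, and
+     # a copy resampled to the requested rate.
+     audio, original_sr, audio_16k = load_audio("example.wav", sampling_rate=16000)
+     print(original_sr, audio.shape, audio_16k.shape)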
pre-requirements.txt ADDED
@@ -0,0 +1 @@
1
+ pip==24.0
pretrained_models/README.md ADDED
@@ -0,0 +1,3 @@
1
+ ## Pretrained Models
2
+
3
+ Download the required model files and place them in the folder `pretrained_models`.
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torchaudio
2
+ fairseq
3
+ diffusers==0.27.2
4
+ librosa==0.10.2
5
+ soundfile==0.12.1
6
+ einops==0.8.0
7
+ transformers==4.44.2
8
+ tiktoken==0.7.0
9
+ inflect==7.4.0
10
+ lingua-language-detector==2.0.2
11
+ WeTextProcessing==1.0.3
12
+ sentencex==0.6.1