Shen Feiyu committed on
Commit
71cd91e
·
1 Parent(s): a4ec42e

init at 250916

README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: FireRedTTS2
3
- emoji: 🐨
4
- colorFrom: gray
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.45.0
8
  app_file: app.py
 
1
  ---
2
+ title: Tts2 Test
3
+ emoji: 🌖
4
+ colorFrom: pink
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.45.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,357 @@
1
+ import re
2
+ import spaces
3
+ import gradio as gr
4
+ from tqdm import tqdm
5
+ from huggingface_hub import snapshot_download
6
+ from argparse import ArgumentParser
7
+ from typing import Literal, List, Tuple
8
+ from fireredtts2.fireredtts2 import FireRedTTS2
9
+
10
+
11
+ # ================================================
12
+ # FireRedTTS2 Model
13
+ # ================================================
14
+ # Global model instance
15
+ model: FireRedTTS2 = None
16
+
17
+
18
+ def initiate_model(pretrained_dir: str, device="cuda"):
19
+ global model
20
+ if model is None:
21
+ model = FireRedTTS2(
22
+ pretrained_dir=pretrained_dir,
23
+ gen_type="dialogue",
24
+ device=device,
25
+ )
26
+
27
+
28
+ # ================================================
29
+ # Gradio
30
+ # ================================================
31
+
32
+ # i18n
33
+ _i18n_key2lang_dict = dict(
34
+ # Title markdown
35
+ title_md_desc=dict(
36
+ en="FireRedTTS-2 🔥 Dialogue Generation",
37
+ zh="FireRedTTS-2 🔥 对话生成",
38
+ ),
39
+ # Voice mode radio
40
+ voice_mode_label=dict(
41
+ en="Voice Mode",
42
+ zh="音色模式",
43
+ ),
44
+ voice_model_choice1=dict(
45
+ en="Voice Clone",
46
+ zh="音色克隆",
47
+ ),
48
+ voice_model_choice2=dict(
49
+ en="Random Voice",
50
+ zh="随机音色",
51
+ ),
52
+ # Speaker1 Prompt
53
+ spk1_prompt_audio_label=dict(
54
+ en="Speaker 1 Prompt Audio",
55
+ zh="说话人 1 参考语音",
56
+ ),
57
+ spk1_prompt_text_label=dict(
58
+ en="Speaker 1 Prompt Text",
59
+ zh="说话人 1 参考文本",
60
+ ),
61
+ spk1_prompt_text_placeholder=dict(
62
+ en="[S1] text of speaker 1 prompt audio.",
63
+ zh="[S1] 说话人 1 参考文本",
64
+ ),
65
+ # Speaker2 Prompt
66
+ spk2_prompt_audio_label=dict(
67
+ en="Speaker 2 Prompt Audio",
68
+ zh="说话人 2 参考语音",
69
+ ),
70
+ spk2_prompt_text_label=dict(
71
+ en="Speaker 2 Prompt Text",
72
+ zh="说话人 2 参考文本",
73
+ ),
74
+ spk2_prompt_text_placeholder=dict(
75
+ en="[S2] text of speaker 2 prompt audio.",
76
+ zh="[S2] 说话人 2 参考文本",
77
+ ),
78
+ # Dialogue input textbox
79
+ dialogue_text_input_label=dict(
80
+ en="Dialogue Text Input",
81
+ zh="对话文本输入",
82
+ ),
83
+ dialogue_text_input_placeholder=dict(
84
+ en="[S1]text[S2]text[S1]text...",
85
+ zh="[S1]文本[S2]文本[S1]文本...",
86
+ ),
87
+ # Generate button
88
+ generate_btn_label=dict(
89
+ en="Generate Audio",
90
+ zh="合成",
91
+ ),
92
+ # Generated audio
93
+ generated_audio_label=dict(
94
+ en="Generated Dialogue Audio",
95
+ zh="合成的对话音频",
96
+ ),
97
+ # Warning 1: invalid prompt text
98
+ warn_invalid_spk1_prompt_text=dict(
99
+ en='Invalid speaker 1 prompt text, should strictly follow: "[S1]xxx"',
100
+ zh='说话人 1 参考文本不合规,格式:"[S1]xxx"',
101
+ ),
102
+ warn_invalid_spk2_prompt_text=dict(
103
+ en='Invalid speaker 2 prompt text, should strictly follow: "[S2]xxx"',
104
+ zh='说话人 2 参考文本不合规,格式:"[S2]xxx"',
105
+ ),
106
+ # Warning 2: invalid dialogue input text
107
+ warn_invalid_dialogue_text=dict(
108
+ en='Invalid dialogue input text, should strictly follow: "[S1]xxx[S2]xxx..."',
109
+ zh='对话文本输入不合规,格式:"[S1]xxx[S2]xxx..."',
110
+ ),
111
+ # Warning 3: incomplete prompt info
112
+ warn_incomplete_prompt=dict(
113
+ en="Please provide prompt audio and text for both speaker 1 and speaker 2",
114
+ zh="请提供说话人 1 与说话人 2 的参考语音与参考文本",
115
+ ),
116
+ )
117
+
118
+ global_lang: Literal["zh", "en"] = "zh"
119
+
120
+
121
+ def i18n(key):
122
+ global global_lang
123
+ return _i18n_key2lang_dict[key][global_lang]
124
+
125
+
126
+ def check_monologue_text(text: str, prefix: str = None) -> bool:
127
+ text = text.strip()
128
+ # Check speaker tags
129
+ if prefix is not None and (not text.startswith(prefix)):
130
+ return False
131
+ # Remove prefix
132
+ if prefix is not None:
133
+ text = text.removeprefix(prefix)
134
+ text = text.strip()
135
+ # If empty?
136
+ if len(text) == 0:
137
+ return False
138
+ return True
139
+
140
+
141
+ def check_dialogue_text(text_list: List[str]) -> bool:
142
+ if len(text_list) == 0:
143
+ return False
144
+ for text in text_list:
145
+ if not (
146
+ check_monologue_text(text, "[S1]")
147
+ or check_monologue_text(text, "[S2]")
148
+ or check_monologue_text(text, "[S3]")
149
+ or check_monologue_text(text, "[S4]")
150
+ ):
151
+ return False
152
+ return True
153
+
154
+
155
+ @spaces.GPU(duration=200)
156
+ def dialogue_synthesis_function(
157
+ target_text: str,
158
+ voice_mode: Literal[0, 1] = 0, # 0 means voice clone
159
+ spk1_prompt_text: str | None = "",
160
+ spk1_prompt_audio: str | None = None,
161
+ spk2_prompt_text: str | None = "",
162
+ spk2_prompt_audio: str | None = None,
163
+ ):
164
+ # Voice clone mode, check prompt info
165
+ if voice_mode == 0:
166
+ prompt_has_value = [
167
+ spk1_prompt_text != "",
168
+ spk1_prompt_audio is not None,
169
+ spk2_prompt_text != "",
170
+ spk2_prompt_audio is not None,
171
+ ]
172
+ if not all(prompt_has_value):
173
+ gr.Warning(message=i18n("warn_incomplete_prompt"))
174
+ return None
175
+ if not check_monologue_text(spk1_prompt_text, "[S1]"):
176
+ gr.Warning(message=i18n("warn_invalid_spk1_prompt_text"))
177
+ return None
178
+ if not check_monologue_text(spk2_prompt_text, "[S2]"):
179
+ gr.Warning(message=i18n("warn_invalid_spk2_prompt_text"))
180
+ return None
181
+ # Check dialogue text
182
+ target_text_list: List[str] = re.findall(r"(\[S[0-9]\][^\[\]]*)", target_text)
183
+ target_text_list = [text.strip() for text in target_text_list]
184
+ if not check_dialogue_text(target_text_list):
185
+ gr.Warning(message=i18n("warn_invalid_dialogue_text"))
186
+ return None
187
+
188
+ # Go synthesis
189
+ progress_bar = gr.Progress(track_tqdm=True)
190
+ prompt_wav_list = (
191
+ None if voice_mode != 0 else [spk1_prompt_audio, spk2_prompt_audio]
192
+ )
193
+ prompt_text_list = None if voice_mode != 0 else [spk1_prompt_text, spk2_prompt_text]
194
+ target_audio = model.generate_dialogue(
195
+ text_list=target_text_list,
196
+ prompt_wav_list=prompt_wav_list,
197
+ prompt_text_list=prompt_text_list,
198
+ temperature=0.9,
199
+ topk=30,
200
+ )
201
+ return (24000, target_audio.squeeze(0).numpy())
202
+
203
+
204
+ # UI rendering
205
+ def render_interface() -> gr.Blocks:
206
+ with gr.Blocks(title="FireRedTTS-2", theme=gr.themes.Default()) as page:
207
+ # ======================== UI ========================
208
+ # A large title
209
+ title_desc = gr.Markdown(value="# {}".format(i18n("title_md_desc")))
210
+ with gr.Row():
211
+ lang_choice = gr.Radio(
212
+ choices=["中文", "English"],
213
+ value="中文",
214
+ label="Display Language/显示语言",
215
+ type="index",
216
+ interactive=True,
217
+ )
218
+ voice_mode_choice = gr.Radio(
219
+ choices=[i18n("voice_model_choice1"), i18n("voice_model_choice2")],
220
+ value=i18n("voice_model_choice1"),
221
+ label=i18n("voice_mode_label"),
222
+ type="index",
223
+ interactive=True,
224
+ )
225
+ with gr.Row():
226
+ # ==== Speaker1 Prompt ====
227
+ with gr.Column(scale=1):
228
+ with gr.Group(visible=True) as spk1_prompt_group:
229
+ spk1_prompt_audio = gr.Audio(
230
+ label=i18n("spk1_prompt_audio_label"),
231
+ type="filepath",
232
+ editable=False,
233
+ interactive=True,
234
+ ) # Audio component returns tmp audio path
235
+ spk1_prompt_text = gr.Textbox(
236
+ label=i18n("spk1_prompt_text_label"),
237
+ placeholder=i18n("spk1_prompt_text_placeholder"),
238
+ lines=3,
239
+ )
240
+ # ==== Speaker2 Prompt ====
241
+ with gr.Column(scale=1):
242
+ with gr.Group(visible=True) as spk2_prompt_group:
243
+ spk2_prompt_audio = gr.Audio(
244
+ label=i18n("spk2_prompt_audio_label"),
245
+ type="filepath",
246
+ editable=False,
247
+ interactive=True,
248
+ )
249
+ spk2_prompt_text = gr.Textbox(
250
+ label=i18n("spk2_prompt_text_label"),
251
+ placeholder=i18n("spk2_prompt_text_placeholder"),
252
+ lines=3,
253
+ )
254
+ # ==== Text input ====
255
+ with gr.Column(scale=2):
256
+ dialogue_text_input = gr.Textbox(
257
+ label=i18n("dialogue_text_input_label"),
258
+ placeholder=i18n("dialogue_text_input_placeholder"),
259
+ lines=18,
260
+ )
261
+ # Generate button
262
+ generate_btn = gr.Button(
263
+ value=i18n("generate_btn_label"), variant="primary", size="lg"
264
+ )
265
+ # Long output audio
266
+ generate_audio = gr.Audio(
267
+ label=i18n("generated_audio_label"),
268
+ interactive=False,
269
+ )
270
+
271
+ # ======================== Action ========================
272
+ # Language action
273
+ def _change_component_language(lang):
274
+ global global_lang
275
+ global_lang = ["zh", "en"][lang]
276
+ return [
277
+ # title_desc
278
+ gr.update(value="# {}".format(i18n("title_md_desc"))),
279
+ # voice_mode_choice
280
+ gr.update(
281
+ choices=[i18n("voice_model_choice1"), i18n("voice_model_choice2")],
282
+ value=i18n("voice_model_choice1"),
283
+ label=i18n("voice_mode_label"),
284
+ ),
285
+ # spk1_prompt_{audio,text}
286
+ gr.update(label=i18n("spk1_prompt_audio_label")),
287
+ gr.update(
288
+ label=i18n("spk1_prompt_text_label"),
289
+ placeholder=i18n("spk1_prompt_text_placeholder"),
290
+ ),
291
+ # spk2_prompt_{audio,text}
292
+ gr.update(label=i18n("spk2_prompt_audio_label")),
293
+ gr.update(
294
+ label=i18n("spk2_prompt_text_label"),
295
+ placeholder=i18n("spk2_prompt_text_placeholder"),
296
+ ),
297
+ # dialogue_text_input
298
+ gr.update(
299
+ label=i18n("dialogue_text_input_label"),
300
+ placeholder=i18n("dialogue_text_input_placeholder"),
301
+ ),
302
+ # generate_btn
303
+ gr.update(value=i18n("generate_btn_label")),
304
+ # generate_audio
305
+ gr.update(label=i18n("generated_audio_label")),
306
+ ]
307
+
308
+ lang_choice.change(
309
+ fn=_change_component_language,
310
+ inputs=[lang_choice],
311
+ outputs=[
312
+ title_desc,
313
+ voice_mode_choice,
314
+ spk1_prompt_audio,
315
+ spk1_prompt_text,
316
+ spk2_prompt_audio,
317
+ spk2_prompt_text,
318
+ dialogue_text_input,
319
+ generate_btn,
320
+ generate_audio,
321
+ ],
322
+ )
323
+
324
+ # Voice clone mode action
325
+ def _change_prompt_input_visibility(voice_mode):
326
+ enable = voice_mode == 0
327
+ return [gr.update(visible=enable), gr.update(visible=enable)]
328
+
329
+ voice_mode_choice.change(
330
+ fn=_change_prompt_input_visibility,
331
+ inputs=[voice_mode_choice],
332
+ outputs=[spk1_prompt_group, spk2_prompt_group],
333
+ )
334
+ generate_btn.click(
335
+ fn=dialogue_synthesis_function,
336
+ inputs=[
337
+ dialogue_text_input,
338
+ voice_mode_choice,
339
+ spk1_prompt_text,
340
+ spk1_prompt_audio,
341
+ spk2_prompt_text,
342
+ spk2_prompt_audio,
343
+ ],
344
+ outputs=[generate_audio],
345
+ )
346
+ return page
347
+
348
+
349
+ if __name__ == "__main__":
350
+ # Download model
351
+ snapshot_download(repo_id='FireRedTeam/FireRedTTS2', local_dir='pretrained_models/FireRedTTS2')
352
+ # Initiate model
353
+ initiate_model('pretrained_models/FireRedTTS2')
354
+ # UI
355
+ page = render_interface()
356
+ page.queue()
357
+ page.launch()
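Not part of the committed files — a short sketch of how app.py parses the dialogue input (the example sentence is made up; the regex and check_dialogue_text are the ones defined above):

import re
text = "[S1]Hello there.[S2]Hi, how are you?[S1]Doing well."
chunks = [c.strip() for c in re.findall(r"(\[S[0-9]\][^\[\]]*)", text)]
# chunks == ['[S1]Hello there.', '[S2]Hi, how are you?', '[S1]Doing well.']
assert check_dialogue_text(chunks)  # every chunk starts with an [S1]-[S4] tag and carries non-empty text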
fireredtts2/__init__.py ADDED
File without changes
fireredtts2/codec/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from fireredtts2.codec.model import RedCodecInfer
fireredtts2/codec/audio.py ADDED
@@ -0,0 +1,148 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
17
+ and remove unnecessary dependencies.
18
+ """
19
+ import warnings
20
+ import numpy as np
21
+ from typing import Union, Optional
22
+
23
+
24
+ def hertz_to_mel(
25
+ freq: Union[float, np.ndarray], mel_scale: str = "htk"
26
+ ) -> Union[float, np.ndarray]:
27
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
28
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
29
+
30
+ if mel_scale == "htk":
31
+ return 2595.0 * np.log10(1.0 + (freq / 700.0))
32
+ elif mel_scale == "kaldi":
33
+ return 1127.0 * np.log(1.0 + (freq / 700.0))
34
+
35
+ min_log_hertz = 1000.0
36
+ min_log_mel = 15.0
37
+ logstep = 27.0 / np.log(6.4)
38
+ mels = 3.0 * freq / 200.0
39
+
40
+ if isinstance(freq, np.ndarray):
41
+ log_region = freq >= min_log_hertz
42
+ mels[log_region] = (
43
+ min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
44
+ )
45
+ elif freq >= min_log_hertz:
46
+ mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
47
+
48
+ return mels
49
+
50
+
51
+ def mel_to_hertz(
52
+ mels: Union[float, np.ndarray], mel_scale: str = "htk"
53
+ ) -> Union[float, np.ndarray]:
54
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
55
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
56
+
57
+ if mel_scale == "htk":
58
+ return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
59
+ elif mel_scale == "kaldi":
60
+ return 700.0 * (np.exp(mels / 1127.0) - 1.0)
61
+
62
+ min_log_hertz = 1000.0
63
+ min_log_mel = 15.0
64
+ logstep = np.log(6.4) / 27.0
65
+ freq = 200.0 * mels / 3.0
66
+
67
+ if isinstance(mels, np.ndarray):
68
+ log_region = mels >= min_log_mel
69
+ freq[log_region] = min_log_hertz * np.exp(
70
+ logstep * (mels[log_region] - min_log_mel)
71
+ )
72
+ elif mels >= min_log_mel:
73
+ freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
74
+
75
+ return freq
76
+
77
+
78
+ def _create_triangular_filter_bank(
79
+ fft_freqs: np.ndarray, filter_freqs: np.ndarray
80
+ ) -> np.ndarray:
81
+ """
82
+ Creates a triangular filter bank.
83
+
84
+ Adapted from *torchaudio* and *librosa*.
85
+
86
+ Args:
87
+ fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
88
+ Discrete frequencies of the FFT bins in Hz.
89
+ filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
90
+ Center frequencies of the triangular filters to create, in Hz.
91
+
92
+ Returns:
93
+ `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
94
+ """
95
+ filter_diff = np.diff(filter_freqs)
96
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
97
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
98
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
99
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
100
+
101
+
102
+ def mel_filter_bank(
103
+ num_frequency_bins: int,
104
+ num_mel_filters: int,
105
+ min_frequency: float,
106
+ max_frequency: float,
107
+ sampling_rate: int,
108
+ norm: Optional[str] = None,
109
+ mel_scale: str = "htk",
110
+ triangularize_in_mel_space: bool = False,
111
+ ) -> np.ndarray:
112
+ if norm is not None and norm != "slaney":
113
+ raise ValueError('norm must be one of None or "slaney"')
114
+
115
+ # center points of the triangular mel filters
116
+ mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
117
+ mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
118
+ mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
119
+ filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
120
+
121
+ if triangularize_in_mel_space:
122
+ # frequencies of FFT bins in Hz, but filters triangularized in mel space
123
+ fft_bin_width = sampling_rate / (num_frequency_bins * 2)
124
+ fft_freqs = hertz_to_mel(
125
+ fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale
126
+ )
127
+ filter_freqs = mel_freqs
128
+ else:
129
+ # frequencies of FFT bins in Hz
130
+ fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
131
+
132
+ mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
133
+
134
+ if norm is not None and norm == "slaney":
135
+ # Slaney-style mel is scaled to be approx constant energy per channel
136
+ enorm = 2.0 / (
137
+ filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]
138
+ )
139
+ mel_filters *= np.expand_dims(enorm, 0)
140
+
141
+ if (mel_filters.max(axis=0) == 0.0).any():
142
+ warnings.warn(
143
+ "At least one mel filter has all zero values. "
144
+ f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
145
+ f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
146
+ )
147
+
148
+ return mel_filters
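Not part of the committed files — a sketch of calling mel_filter_bank; the n_fft, band count, and frequency range are assumed values, not taken from this repo:

filters = mel_filter_bank(
    num_frequency_bins=201,   # n_fft // 2 + 1 for an assumed n_fft of 400
    num_mel_filters=80,
    min_frequency=0.0,
    max_frequency=8000.0,
    sampling_rate=16000,
    norm="slaney",
    mel_scale="slaney",
)
print(filters.shape)  # (201, 80): projects a magnitude spectrogram of shape (frames, 201) onto 80 mel bands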
fireredtts2/codec/decoder.py ADDED
@@ -0,0 +1,700 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from fireredtts2.codec.whisper import WhisperEncoderLayer
5
+ from fireredtts2.codec.utils import make_nonpad_mask, make_block_causal_mask
6
+
7
+
8
+ class ResnetBlock(nn.Module):
9
+ def __init__(
10
+ self,
11
+ in_channels: int,
12
+ out_channels: int = None,
13
+ conv_shortcut: bool = False,
14
+ dropout: float = 0.0,
15
+ ):
16
+ super().__init__()
17
+ self.in_channels = in_channels
18
+ out_channels = in_channels if out_channels is None else out_channels
19
+ self.out_channels = out_channels
20
+ self.use_conv_shortcut = conv_shortcut
21
+
22
+ self.block1 = nn.Sequential(
23
+ nn.GroupNorm(
24
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
25
+ ),
26
+ nn.SiLU(),
27
+ nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
28
+ )
29
+
30
+ self.block2 = nn.Sequential(
31
+ nn.GroupNorm(
32
+ num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
33
+ ),
34
+ nn.SiLU(),
35
+ nn.Dropout(dropout),
36
+ nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
37
+ )
38
+
39
+ if self.in_channels != self.out_channels:
40
+ if self.use_conv_shortcut:
41
+ self.conv_shortcut = torch.nn.Conv1d(
42
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
43
+ )
44
+ else:
45
+ self.nin_shortcut = torch.nn.Conv1d(
46
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
47
+ )
48
+
49
+ def forward(self, x: torch.Tensor):
50
+ """
51
+ Args:
52
+ x: shape (b, c, t)
53
+ """
54
+ h = x
55
+ h = self.block1(h)
56
+ h = self.block2(h)
57
+
58
+ if self.in_channels != self.out_channels:
59
+ if self.use_conv_shortcut:
60
+ x = self.conv_shortcut(x)
61
+ else:
62
+ x = self.nin_shortcut(x)
63
+ return x + h
64
+
65
+
66
+ class Transpose(torch.nn.Module):
67
+ def __init__(self, dim0: int, dim1: int):
68
+ super().__init__()
69
+ self.dim0 = dim0
70
+ self.dim1 = dim1
71
+
72
+ def forward(self, x: torch.Tensor):
73
+ x = torch.transpose(x, self.dim0, self.dim1)
74
+ return x
75
+
76
+
77
+ # A causal variant of Conv1d
78
+ class CausalConv1d(torch.nn.Conv1d):
79
+ def __init__(
80
+ self,
81
+ in_channels: int,
82
+ out_channels: int,
83
+ kernel_size: int,
84
+ ) -> None:
85
+ super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size)
86
+ self.causal_padding = (kernel_size - 1, 0)
87
+
88
+ def forward(self, x: torch.Tensor):
89
+ x = F.pad(x, self.causal_padding)
90
+ x = super(CausalConv1d, self).forward(x)
91
+ return x
92
+
93
+ def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None):
94
+ if cnn_cache is None:
95
+ cnn_cache = x.new_zeros(
96
+ (x.shape[0], self.in_channels, self.causal_padding[0])
97
+ )
98
+ x = torch.cat([cnn_cache, x], dim=2)
99
+ new_cnn_cache = x[..., -self.causal_padding[0] :]
100
+ x = super(CausalConv1d, self).forward(x)
101
+ return x, new_cnn_cache
102
+
103
+
104
+ # A causal variant of ResnetBlock
105
+ class CausalResnetBlock(nn.Module):
106
+ def __init__(
107
+ self,
108
+ in_channels: int,
109
+ out_channels: int = None,
110
+ dropout: float = 0.0,
111
+ ):
112
+ super().__init__()
113
+ self.in_channels = in_channels
114
+ out_channels = in_channels if out_channels is None else out_channels
115
+ self.out_channels = out_channels
116
+
117
+ self.block1 = nn.Sequential(
118
+ Transpose(1, 2),
119
+ nn.LayerNorm(in_channels),
120
+ Transpose(1, 2),
121
+ nn.SiLU(),
122
+ CausalConv1d(in_channels, out_channels, kernel_size=3),
123
+ )
124
+
125
+ self.block2 = nn.Sequential(
126
+ Transpose(1, 2),
127
+ nn.LayerNorm(out_channels),
128
+ Transpose(1, 2),
129
+ nn.SiLU(),
130
+ nn.Dropout(dropout),
131
+ CausalConv1d(out_channels, out_channels, kernel_size=3),
132
+ )
133
+ if self.in_channels != self.out_channels:
134
+ self.nin_shortcut = torch.nn.Conv1d(
135
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
136
+ )
137
+
138
+ def forward(self, x: torch.Tensor):
139
+ """
140
+ Args:
141
+ x: shape (b, c, t)
142
+ """
143
+ h = x
144
+ h = self.block1(h)
145
+ h = self.block2(h)
146
+ if self.in_channels != self.out_channels:
147
+ x = self.nin_shortcut(x)
148
+ return x + h
149
+
150
+ def forward_chunk(self, x: torch.Tensor, cache: torch.Tensor = None):
151
+ """
152
+ Args:
153
+ x: shape (b, c, t)
154
+ cache: shape (b, c_in+c_out, t=2)
155
+ """
156
+ cache1, cache2 = (
157
+ (None, None)
158
+ if cache is None
159
+ else cache.split((self.in_channels, self.out_channels), dim=1)
160
+ )
161
+ h = x
162
+ # block1
163
+ h = self.block1[:4](h)
164
+ h, new_cache1 = self.block1[4].forward_chunk(h, cache1)
165
+ # block2
166
+ h = self.block2[:5](h)
167
+ h, new_cache2 = self.block2[5].forward_chunk(h, cache2)
168
+ if self.in_channels != self.out_channels:
169
+ x = self.nin_shortcut(x)
170
+ new_cache = torch.cat([new_cache1, new_cache2], dim=1)
171
+ return x + h, new_cache
172
+
173
+
174
+ # Nonstreaming Vocos backbone based on Transformer layers
175
+ class VocosBackbone(nn.Module):
176
+ def __init__(
177
+ self,
178
+ embed_dim: int = 1024,
179
+ num_layers: int = 12,
180
+ num_heads: int = 16,
181
+ dropout: float = 0.1,
182
+ ):
183
+ super().__init__()
184
+ self.in_proj = nn.Conv1d(embed_dim, embed_dim, kernel_size=7, padding=3)
185
+ self.prior_net = nn.Sequential(
186
+ ResnetBlock(embed_dim, embed_dim, dropout=dropout),
187
+ ResnetBlock(embed_dim, embed_dim, dropout=dropout),
188
+ )
189
+ self.transformers = nn.ModuleList(
190
+ [WhisperEncoderLayer(embed_dim, num_heads) for _ in range(num_layers)]
191
+ )
192
+ self.post_net = nn.Sequential(
193
+ ResnetBlock(embed_dim, embed_dim, dropout=dropout),
194
+ ResnetBlock(embed_dim, embed_dim, dropout=dropout),
195
+ )
196
+ self.final_norm = nn.LayerNorm(embed_dim, eps=1e-6)
197
+
198
+ def forward(
199
+ self,
200
+ x: torch.Tensor,
201
+ x_lens: torch.Tensor,
202
+ ):
203
+ """
204
+ Args:
205
+ x: shape (b, t, c)
206
+ x_lens: shape (b,)
207
+ """
208
+ x = x.transpose(1, 2)
209
+ x = self.in_proj(x)
210
+ x = self.prior_net(x)
211
+ x = x.transpose(1, 2)
212
+
213
+ attention_mask = make_nonpad_mask(x_lens).unsqueeze(1) # (b, 1, t)
214
+ # NOTE(sfy): I think positional embedding is unnecessary
215
+ for layer in self.transformers:
216
+ x = layer(x, attention_mask)
217
+ x = x.transpose(1, 2)
218
+ x = self.post_net(x)
219
+ x = x.transpose(1, 2)
220
+ x = self.final_norm(x)
221
+ return x
222
+
223
+
224
+ # Streaming Vocos backbone based on Transformer layers
225
+ class CausalVocosBackbone(nn.Module):
226
+ def __init__(
227
+ self,
228
+ embed_dim: int = 1024,
229
+ num_layers: int = 12,
230
+ num_heads: int = 16,
231
+ dropout: float = 0.1,
232
+ ):
233
+ super().__init__()
234
+ self.in_proj = CausalConv1d(embed_dim, embed_dim, kernel_size=7)
235
+ self.prior_net = nn.Sequential(
236
+ CausalResnetBlock(embed_dim, embed_dim, dropout=dropout),
237
+ CausalResnetBlock(embed_dim, embed_dim, dropout=dropout),
238
+ )
239
+ self.transformers = nn.ModuleList(
240
+ [WhisperEncoderLayer(embed_dim, num_heads) for _ in range(num_layers)]
241
+ )
242
+ self.post_net = nn.Sequential(
243
+ CausalResnetBlock(embed_dim, embed_dim, dropout=dropout),
244
+ CausalResnetBlock(embed_dim, embed_dim, dropout=dropout),
245
+ )
246
+ self.final_norm = nn.LayerNorm(embed_dim, eps=1e-6)
247
+
248
+ def forward(
249
+ self,
250
+ x: torch.Tensor,
251
+ x_lens: torch.Tensor,
252
+ ):
253
+ """
254
+ Args:
255
+ x: shape (b, t, c)
256
+ x_lens: shape (b,)
257
+ """
258
+ x = x.transpose(1, 2)
259
+ x = self.in_proj(x)
260
+ x = self.prior_net(x)
261
+ x = x.transpose(1, 2)
262
+
263
+ # NOTE(sfy): There is no padding during training, so SDPA attention is safe (no NaN).
264
+ # Also, 1 token(12.5Hz) -> 4 latents(50Hz) -> 8 latents(100Hz),
265
+ # so we design an 8-block causal attention mask instead of a fully causal one to improve performance
266
+ attention_mask = make_block_causal_mask(x_lens, chunk_size=8)
267
+ for layer in self.transformers:
268
+ x = layer(x, attention_mask)
269
+
270
+ x = x.transpose(1, 2)
271
+ x = self.post_net(x)
272
+ x = x.transpose(1, 2)
273
+ x = self.final_norm(x)
274
+ return x
275
+
276
+ def forward_chunk(
277
+ self,
278
+ x: torch.Tensor,
279
+ conv_cache1: torch.Tensor = None,
280
+ conv_cache2: torch.Tensor = None,
281
+ kv_cache: torch.Tensor = None,
282
+ ):
283
+ # Unpack cache
284
+ cache1 = conv_cache1
285
+ cache2, cache3, cache4, cache5 = (
286
+ (None, None, None, None)
287
+ if conv_cache2 is None
288
+ else conv_cache2.chunk(4, dim=1)
289
+ )
290
+
291
+ # cache1: shape (b, c=embed_dim, t=6)
292
+ x = x.transpose(1, 2)
293
+ x, new_cache1 = self.in_proj.forward_chunk(x, cache1)
294
+ # cache2: shape (b, c=embed_dim*2, t=2)
295
+ x, new_cache2 = self.prior_net[0].forward_chunk(x, cache2)
296
+ # cache3: shape (b, c=embed_dim*2, t=2)
297
+ x, new_cache3 = self.prior_net[1].forward_chunk(x, cache3)
298
+ x = x.transpose(1, 2)
299
+
300
+ # k,v-cache: shape (b, nlayer, nh, t, c*2)
301
+ new_kv_cache = []
302
+ for idx, layer in enumerate(self.transformers):
303
+ kv_cache_i = None if kv_cache is None else kv_cache[:, idx]
304
+ x, new_kv_cache_i = layer.forward_chunk(x, kv_cache=kv_cache_i)
305
+ new_kv_cache.append(new_kv_cache_i)
306
+ new_kv_cache = torch.stack(new_kv_cache, dim=1)
307
+
308
+ x = x.transpose(1, 2)
309
+ # cache4: shape (b, c=embed_dim*2, t=2)
310
+ x, new_cache4 = self.post_net[0].forward_chunk(x, cache4)
311
+ # cache5: shape (b, c=embed_dim*2, t=2)
312
+ x, new_cache5 = self.post_net[1].forward_chunk(x, cache5)
313
+ x = x.transpose(1, 2)
314
+ x = self.final_norm(x)
315
+
316
+ new_conv_cache1 = new_cache1
317
+ new_conv_cache2 = torch.cat(
318
+ [new_cache2, new_cache3, new_cache4, new_cache5], dim=1
319
+ )
320
+ return x, new_conv_cache1, new_conv_cache2, new_kv_cache
321
+
322
+
323
+ class ISTFT(nn.Module):
324
+ """
325
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
326
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
327
+ See issue: https://github.com/pytorch/pytorch/issues/62323
328
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
329
+ The NOLA constraint is met as we trim padded samples anyway.
330
+
331
+ Args:
332
+ n_fft (int): Size of Fourier transform.
333
+ hop_length (int): The distance between neighboring sliding window frames.
334
+ win_length (int): The size of window frame and STFT filter.
335
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
336
+ """
337
+
338
+ def __init__(
339
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
340
+ ):
341
+ super().__init__()
342
+ assert padding in ["center", "same"], "Padding must be 'center' or 'same'."
343
+ self.padding = padding
344
+ self.n_fft = n_fft
345
+ self.hop_length = hop_length
346
+ self.win_length = win_length
347
+ window = torch.hann_window(win_length)
348
+ self.register_buffer("window", window)
349
+
350
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
351
+ """
352
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
353
+
354
+ Args:
355
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
356
+ N is the number of frequency bins, and T is the number of time frames.
357
+
358
+ Returns:
359
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
360
+ """
361
+ if self.padding == "center":
362
+ # Fallback to pytorch native implementation
363
+ return torch.istft(
364
+ spec,
365
+ self.n_fft,
366
+ self.hop_length,
367
+ self.win_length,
368
+ self.window,
369
+ center=True,
370
+ )
371
+ elif self.padding == "same":
372
+ pad = (self.win_length - self.hop_length) // 2
373
+ else:
374
+ raise ValueError("Padding must be 'center' or 'same'.")
375
+
376
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
377
+ B, N, T = spec.shape
378
+
379
+ # Inverse FFT
380
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
381
+ ifft = ifft * self.window[None, :, None]
382
+
383
+ # Overlap and Add
384
+ output_size = (T - 1) * self.hop_length + self.win_length
385
+ y = torch.nn.functional.fold(
386
+ ifft,
387
+ output_size=(1, output_size),
388
+ kernel_size=(1, self.win_length),
389
+ stride=(1, self.hop_length),
390
+ )[:, 0, 0, pad:-pad]
391
+
392
+ # Window envelope
393
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
394
+ window_envelope = torch.nn.functional.fold(
395
+ window_sq,
396
+ output_size=(1, output_size),
397
+ kernel_size=(1, self.win_length),
398
+ stride=(1, self.hop_length),
399
+ ).squeeze()[pad:-pad]
400
+
401
+ # Normalize
402
+ assert (window_envelope > 1e-11).all()
403
+ y = y / window_envelope
404
+
405
+ return y
406
+
407
+ def forward_chunk(
408
+ self, spec: torch.Tensor, cache: torch.Tensor = None, last_chunk: bool = False
409
+ ):
410
+ """Forward one chunk of spectrogram frames.
411
+
412
+ Args:
413
+ spec: shape (B, N, T=chunk_size)
414
+ cache: previous chunk's last ifft frame, shape (B, N, T=3)
415
+ last_chunk: if last_chunk, will not trim the last (win-hop) segment
416
+ Returns:
417
+ y: shape (B, T=effective_length)
418
+ """
419
+ assert self.padding == "same", "Padding must be same."
420
+ assert (
421
+ self.win_length % self.hop_length == 0
422
+ ), f"{self.win_length} {self.hop_length}"
423
+ pad = (self.win_length - self.hop_length) // 2
424
+
425
+ # Inverse FFT
426
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
427
+ ifft = ifft * self.window[None, :, None] # (B, N, T=chunk_size)
428
+
429
+ # Append previous cache
430
+ if cache is not None:
431
+ ifft = torch.cat([cache, ifft], dim=-1)
432
+ new_cache_t = self.win_length // self.hop_length - 1
433
+ new_cache = ifft[..., -new_cache_t:]
434
+
435
+ # Overlap and Add
436
+ output_size = (ifft.shape[-1] - 1) * self.hop_length + self.win_length
437
+ y = torch.nn.functional.fold(
438
+ ifft,
439
+ output_size=(1, output_size),
440
+ kernel_size=(1, self.win_length),
441
+ stride=(1, self.hop_length),
442
+ )[:, 0, 0, :]
443
+
444
+ # Window envelope
445
+ window_sq = (
446
+ self.window.square().expand(1, ifft.shape[-1], -1).transpose(1, 2)
447
+ ) # (B=1, N, T)
448
+ window_envelope = torch.nn.functional.fold(
449
+ window_sq,
450
+ output_size=(1, output_size),
451
+ kernel_size=(1, self.win_length),
452
+ stride=(1, self.hop_length),
453
+ ).squeeze()
454
+
455
+ # Normalize
456
+ # assert (window_envelope > 1e-11).all()
457
+ y = y / window_envelope
458
+
459
+ # Only take effective part
460
+ if cache is None:
461
+ y = y[:, pad:]
462
+ else:
463
+ y = y[:, (self.win_length - self.hop_length) :]
464
+ if last_chunk:
465
+ y = y[:, :-pad]
466
+ else:
467
+ y = y[:, : -(self.win_length - self.hop_length)]
468
+ return y, new_cache
469
+
470
+
471
+ class ISTFTHead(nn.Module):
472
+ """
473
+ ISTFT Head module for predicting STFT complex coefficients.
474
+
475
+ Args:
476
+ dim (int): Hidden dimension of the model.
477
+ n_fft (int): Size of Fourier transform.
478
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
479
+ the resolution of the input features.
480
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
481
+ """
482
+
483
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
484
+ super().__init__()
485
+ self.hop_length = hop_length
486
+ out_dim = n_fft + 2
487
+ self.out = torch.nn.Linear(dim, out_dim)
488
+ self.istft = ISTFT(
489
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
490
+ )
491
+
492
+ def forward(self, x: torch.Tensor, x_len: torch.Tensor) -> torch.Tensor:
493
+ """
494
+ Forward pass of the ISTFTHead module.
495
+
496
+ Args:
497
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
498
+ L is the sequence length, and H denotes the model dimension.
499
+
500
+ Returns:
501
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
502
+ """
503
+ x_pred = self.out(x)
504
+ x_pred = x_pred.transpose(1, 2)
505
+ mag, p = x_pred.chunk(2, dim=1)
506
+ mag = torch.exp(mag)
507
+ mag = torch.clip(
508
+ mag, max=1e2
509
+ ) # safeguard to prevent excessively large magnitudes
510
+ # wrapping happens here. These two lines produce real and imaginary value
511
+ x = torch.cos(p)
512
+ y = torch.sin(p)
513
+ # recalculating phase here does not produce anything new
514
+ # only costs time
515
+ # phase = torch.atan2(y, x)
516
+ # S = mag * torch.exp(phase * 1j)
517
+ # better directly produce the complex value
518
+ S = mag * (x + 1j * y)
519
+ audio = self.istft(S)
520
+ audio_length = x_len * self.hop_length
521
+ return audio, audio_length
522
+
523
+ def forward_chunk(
524
+ self, x: torch.Tensor, cache: torch.Tensor = None, last_chunk: bool = False
525
+ ):
526
+ """ISTFTHead can be adapted to streaming inference without retraining.
527
+
528
+ Args:
529
+ x: shape (B, T, C)
530
+ cache: shape (B, N, T=3), istft cache
531
+ Returns:
532
+ audio: shape (B, t)
533
+ """
534
+ x_pred = self.out(x)
535
+ x_pred = x_pred.transpose(1, 2)
536
+ mag, p = x_pred.chunk(2, dim=1)
537
+ mag = torch.exp(mag) # (B, C, T)
538
+ mag = torch.clip(
539
+ mag, max=1e2
540
+ ) # safeguard to prevent excessively large magnitudes
541
+ # wrapping happens here. These two lines produce real and imaginary value
542
+ x = torch.cos(p)
543
+ y = torch.sin(p)
544
+ S = mag * (x + 1j * y) # (B, C, T)
545
+ audio, new_cache = self.istft.forward_chunk(S, cache, last_chunk)
546
+ return audio, new_cache
547
+
548
+
549
+ # UpsampleConv(50->100Hz) + VocosBackbone + ISTFTHead
550
+ class AcousticDecoder(nn.Module):
551
+ def __init__(
552
+ self,
553
+ # Transformer
554
+ embed_dim: int,
555
+ num_layers: int,
556
+ num_heads: int,
557
+ dropout: float = 0.0,
558
+ # iSTFT
559
+ hop_length: int = 240,
560
+ # Causal
561
+ causal: bool = False,
562
+ ):
563
+ super().__init__()
564
+ self.embed_dim = embed_dim
565
+ self.num_layers = num_layers
566
+ self.num_heads = num_heads
567
+ self.hop_length = hop_length
568
+ self.causal = causal
569
+
570
+ # Output upsample
571
+ self.upsample_conv = nn.Sequential(
572
+ nn.ConvTranspose1d(
573
+ embed_dim,
574
+ embed_dim,
575
+ kernel_size=3,
576
+ stride=2,
577
+ padding=0,  # no padding on the input side
578
+ output_padding=0, # Can be adjusted to precisely control length
579
+ ),
580
+ nn.GELU(),
581
+ nn.ConvTranspose1d(
582
+ embed_dim,
583
+ embed_dim,
584
+ kernel_size=3,
585
+ stride=1,
586
+ padding=0,  # no padding on the input side
587
+ ),
588
+ nn.GELU(),
589
+ )
590
+ self.backbone = (
591
+ CausalVocosBackbone(embed_dim, num_layers, num_heads, dropout)
592
+ if causal
593
+ else VocosBackbone(embed_dim, num_layers, num_heads, dropout)
594
+ )
595
+ self.isift = ISTFTHead(embed_dim, hop_length * 4, hop_length, padding="same")
596
+ # Init weights
597
+ self.apply(self._init_weights)
598
+
599
+ def _init_weights(self, m):
600
+ if isinstance(m, nn.Conv1d):
601
+ nn.init.trunc_normal_(m.weight, std=0.02)
602
+ nn.init.constant_(m.bias, 0)
603
+
604
+ def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
605
+ """
606
+ Args:
607
+ x: shape (b, t, c)
608
+ x_lens: shape (b,)
609
+ """
610
+ # Upsample
611
+ target_length = x.shape[1] * 2
612
+ x = x.transpose(1, 2)
613
+ x = self.upsample_conv(x)
614
+ x = x.transpose(1, 2)
615
+ # NOTE strict upsampling, trim the last 3 elements
616
+ x = x[:, :target_length]
617
+ x_lens = x_lens * 2
618
+ # Backbone
619
+ x = self.backbone(x, x_lens)
620
+ # iSTFT
621
+ y, y_lens = self.isift(x, x_lens)
622
+ return y, y_lens
623
+
624
+ def forward_upsample_conv_chunk(self, x: torch.Tensor, cache: torch.Tensor = None):
625
+ """Stream forward upsample_conv module with previous block cache.
626
+
627
+ Args:
628
+ x: shape (B, C, T)
629
+ cache: shape (B, C, 3): 1 history frame for the 1st conv and 2 for the 2nd conv.
630
+ """
631
+ # Unpack cache
632
+ cache1, cache2 = (
633
+ (None, None) if cache is None else torch.split(cache, [1, 2], dim=2)
634
+ )
635
+ # 1st conv cache
636
+ if cache1 is not None:
637
+ x = torch.cat([cache1, x], dim=2)
638
+ new_cache1 = x[..., -1:]
639
+ # 1st conv
640
+ x = self.upsample_conv[0](x)[..., :-1] # remove 1 extra frame
641
+ if cache1 is not None:
642
+ x = x[..., 2:] # remove cache1 part
643
+ x = self.upsample_conv[1](x)
644
+ # 2nd conv cache
645
+ if cache2 is not None:
646
+ x = torch.cat([cache2, x], dim=2)
647
+ new_cache2 = x[..., -2:]
648
+ # 2nd conv
649
+ x = self.upsample_conv[2](x)[..., :-2] # remove 2 extra frames
650
+ if cache2 is not None:
651
+ x = x[..., 2:] # remove cache2 part
652
+ x = self.upsample_conv[3](x)
653
+
654
+ new_cache = torch.cat([new_cache1, new_cache2], dim=2)
655
+ return x, new_cache
656
+
657
+ def forward_chunk(
658
+ self,
659
+ x: torch.Tensor,
660
+ # Upsample conv cache
661
+ up_conv_cache: torch.Tensor = None,
662
+ # Backbone conv cache
663
+ bb_conv_cache1: torch.Tensor = None,
664
+ bb_conv_cache2: torch.Tensor = None,
665
+ # Backbone attention cache
666
+ bb_kv_cache: torch.Tensor = None,
667
+ # iSTFT cache
668
+ is_cache: torch.Tensor = None,
669
+ last_chunk: bool = False,
670
+ ):
671
+ """
672
+ Args:
673
+ x: input sequence at 50Hz, length should be multiples of 4
674
+ """
675
+ assert (
676
+ self.causal
677
+ ), "Only AcousticDecoder with causal=True supports forward_chunk method."
678
+
679
+ x = x.transpose(1, 2)
680
+ x, new_up_conv_cache = self.forward_upsample_conv_chunk(x, up_conv_cache)
681
+ x = x.transpose(1, 2)
682
+ # Backbone
683
+ x, new_bb_conv_cache1, new_bb_conv_cache2, new_bb_kv_cache = (
684
+ self.backbone.forward_chunk(
685
+ x,
686
+ bb_conv_cache1,
687
+ bb_conv_cache2,
688
+ bb_kv_cache,
689
+ )
690
+ )
691
+ # iSTFT
692
+ y, new_is_cache = self.isift.forward_chunk(x, is_cache, last_chunk)
693
+ return (
694
+ y,
695
+ new_up_conv_cache,
696
+ new_bb_conv_cache1,
697
+ new_bb_conv_cache2,
698
+ new_bb_kv_cache,
699
+ new_is_cache,
700
+ )
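Not part of the committed files — a shape check for the standalone ISTFT module; the values mirror AcousticDecoder's default hop_length=240 with n_fft = 4 * hop_length, and the 100-frame spectrogram is random test input:

import torch
istft = ISTFT(n_fft=960, hop_length=240, win_length=960, padding="same")
spec = torch.randn(1, 481, 100, dtype=torch.complex64)  # (B, n_fft // 2 + 1, T) complex spectrogram
audio = istft(spec)
print(audio.shape)  # torch.Size([1, 24000]); "same" padding yields T * hop_length samples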
fireredtts2/codec/model.py ADDED
@@ -0,0 +1,376 @@
1
+ import math
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import List, Dict
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+ from fireredtts2.codec.rvq import ResidualVQ
10
+ from fireredtts2.codec.decoder import AcousticDecoder
11
+ from fireredtts2.codec.utils import make_nonpad_mask
12
+ from fireredtts2.codec.whisper import (
13
+ WhisperEncoderLayer,
14
+ PretrainedWhisperEncoder,
15
+ WhisperAcousticEncoder,
16
+ )
17
+
18
+
19
+ class SslAdaptor(nn.Module):
20
+ def __init__(
21
+ self,
22
+ in_dim: int,
23
+ embed_dim: int,
24
+ out_dim: int,
25
+ num_layers: int,
26
+ num_heads: int,
27
+ ffn_dim: int = None,
28
+ attn_dropout: float = 0.0,
29
+ dropout: float = 0.0,
30
+ ):
31
+ super().__init__()
32
+ self.in_dim = in_dim
33
+ self.embed_dim = embed_dim
34
+ self.dropout = dropout
35
+ # Input Projection
36
+ self.in_proj = nn.Linear(in_dim, embed_dim)
37
+ # Transformer
38
+ self.layers = nn.ModuleList(
39
+ [
40
+ WhisperEncoderLayer(
41
+ embed_dim, num_heads, ffn_dim, attn_dropout, dropout
42
+ )
43
+ for _ in range(num_layers)
44
+ ]
45
+ )
46
+ # Output norm
47
+ self.layer_norm = nn.LayerNorm(embed_dim)
48
+ # Output projection
49
+ self.out_proj = nn.Linear(embed_dim, out_dim)
50
+ # Init weight
51
+ self.apply(self._init_weights)
52
+
53
+ def forward(
54
+ self,
55
+ hidden_states: torch.Tensor,
56
+ hidden_length: torch.Tensor,
57
+ ):
58
+ # Downsampling
59
+ hidden_states = self.in_proj(hidden_states)
60
+ # Transformer
61
+ attention_mask = make_nonpad_mask(hidden_length).unsqueeze(1) # (b, 1, t)
62
+ for layer in self.layers:
63
+ hidden_states = layer(hidden_states, attention_mask)
64
+ hidden_states = self.layer_norm(hidden_states)
65
+ hidden_states = self.out_proj(hidden_states)
66
+ return hidden_states, hidden_length
67
+
68
+ def _init_weights(self, module):
69
+ std = 0.02
70
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
71
+ module.weight.data.normal_(mean=0.0, std=std)
72
+ if module.bias is not None:
73
+ module.bias.data.zero_()
74
+ elif isinstance(module, nn.Embedding):
75
+ module.weight.data.normal_(mean=0.0, std=std)
76
+ if module.padding_idx is not None:
77
+ module.weight.data[module.padding_idx].zero_()
78
+
79
+
80
+ class ResidualDownConv(nn.Module):
81
+ def __init__(
82
+ self,
83
+ embed_dim: int = 768,
84
+ avg_pooler=4,
85
+ ):
86
+ super().__init__()
87
+ self.embed_dim = embed_dim
88
+ self.avg_pooler = avg_pooler
89
+ self.intermediate_dim = embed_dim * avg_pooler
90
+ # Convolution layer for downsampling
91
+ self.gate_proj = nn.Conv1d(
92
+ embed_dim, self.intermediate_dim, avg_pooler, avg_pooler, bias=False
93
+ )
94
+ self.up_proj = nn.Conv1d(
95
+ embed_dim, self.intermediate_dim, avg_pooler, avg_pooler, bias=False
96
+ )
97
+ # Downsampled linear projection
98
+ self.down_proj = nn.Linear(
99
+ self.intermediate_dim, self.intermediate_dim, bias=False
100
+ )
101
+ # Activation function and layer normalization
102
+ self.act_fn = nn.SiLU()
103
+ self.layer_norm = nn.LayerNorm(self.intermediate_dim)
104
+ # Final output projection
105
+ self.out_proj = nn.Linear(self.intermediate_dim, embed_dim)
106
+
107
+ def forward(self, x: torch.Tensor, input_length: torch.Tensor):
108
+ output_length = input_length // self.avg_pooler
109
+ batch_size, seq_len, _ = x.shape # (B, T, D)
110
+
111
+ xt = x.permute(0, 2, 1) # (B, D, T)
112
+ g = self.gate_proj(xt).permute(0, 2, 1) # (B, T//4, D*4)
113
+ u = self.up_proj(xt).permute(0, 2, 1) # (B, T//4, D*4)
114
+ x = x.reshape(batch_size, -1, self.intermediate_dim) # (B, T//4, D*4)
115
+
116
+ c = self.down_proj(self.act_fn(g) * u) # (B, T//4, D*4)
117
+ res = self.layer_norm(c + x) # (B, T//4, D*4)
118
+
119
+ res = self.out_proj(res)
120
+ return res, output_length
121
+
122
+
123
+ class UpConv(nn.Module):
124
+ def __init__(
125
+ self,
126
+ embed_dim: int = 768,
127
+ stride: int = 4,
128
+ ):
129
+ super().__init__()
130
+ self.embed_dim = embed_dim
131
+ self.stride = stride
132
+ self.in_proj = nn.Linear(embed_dim, self.stride * embed_dim)
133
+ # Simple transpose convolution layer to keep channel number consistent
134
+ self.up_conv = nn.ConvTranspose1d(
135
+ self.stride * embed_dim,
136
+ embed_dim,
137
+ kernel_size=stride,
138
+ stride=stride,
139
+ bias=False,
140
+ )
141
+
142
+ def forward(self, x: torch.Tensor, input_length: torch.Tensor):
143
+ x = self.in_proj(x)
144
+ x = x.transpose(1, 2)
145
+ res = self.up_conv(x)
146
+ res = res.transpose(1, 2)
147
+ output_length = input_length * self.stride
148
+ return res, output_length
149
+
150
+
151
+ class RedCodec(nn.Module):
152
+ def __init__(
153
+ self,
154
+ ssl: PretrainedWhisperEncoder,
155
+ ssl_adaptor: SslAdaptor,
156
+ acoustic_encoder: WhisperAcousticEncoder,
157
+ downsample: ResidualDownConv,
158
+ rvq: ResidualVQ,
159
+ upsample: UpConv,
160
+ semantic_decoder: SslAdaptor,
161
+ acoustic_decoder: AcousticDecoder,
162
+ ):
163
+ super().__init__()
164
+ self.ssl = ssl
165
+ self.ssl_adaptor = ssl_adaptor
166
+ self.acoustic_encoder = acoustic_encoder
167
+ self.downsample = downsample
168
+ self.rvq = rvq
169
+ self.upsample = upsample
170
+ self.semantic_decoder = semantic_decoder
171
+ self.acoustic_decoder = acoustic_decoder
172
+
173
+ @classmethod
174
+ def from_config(cls, config_json: str) -> "RedCodec":
175
+ with open(config_json, "rb") as f:
176
+ config = json.load(f)["codec"]
177
+ ssl = PretrainedWhisperEncoder.from_pretrained()
178
+ ssl_adaptor = SslAdaptor(**config["ssl_adaptor"])
179
+ acoustic_encoder = WhisperAcousticEncoder(**config["acoustic_encoder"])
180
+ downsample = ResidualDownConv(**config["downsample"])
181
+ rvq = ResidualVQ(**config["rvq"])
182
+ upsample = UpConv(**config["upsample"])
183
+ semantic_decoder = SslAdaptor(**config["semantic_decoder"])
184
+ acoustic_decoder = AcousticDecoder(**config["acoustic_decoder"])
185
+ return cls(
186
+ ssl,
187
+ ssl_adaptor,
188
+ acoustic_encoder,
189
+ downsample,
190
+ rvq,
191
+ upsample,
192
+ semantic_decoder,
193
+ acoustic_decoder,
194
+ )
195
+
196
+
197
+ class RedCodecInfer(RedCodec):
198
+ def __init__(self, codec: RedCodec):
199
+ super().__init__(
200
+ codec.ssl,
201
+ codec.ssl_adaptor,
202
+ codec.acoustic_encoder,
203
+ codec.downsample,
204
+ codec.rvq,
205
+ codec.upsample,
206
+ codec.semantic_decoder,
207
+ codec.acoustic_decoder,
208
+ )
209
+
210
+ @classmethod
211
+ def from_pretrained(cls, conf_path: str, ckpt_path: str) -> "RedCodecInfer":
212
+ with open(conf_path, "r") as f:
213
+ codec = RedCodec.from_config(conf_path)
214
+ ckpt = torch.load(ckpt_path)["generator"]
215
+ codec.load_state_dict(ckpt)
216
+ return cls(codec)
217
+
218
+ def _encode_one_batch(self, audio16k: torch.Tensor):
219
+ B, T = audio16k.shape
220
+ audio16k_length = torch.tensor(
221
+ [T] * B, dtype=torch.long, device=audio16k.device
222
+ )
223
+ # Semantic
224
+ ssl, ssl_length = self.ssl.forward(audio16k, audio16k_length)
225
+ ssl = ssl.clone() # For onnx export
226
+ sem_feats, sem_length = self.ssl_adaptor(ssl, ssl_length)
227
+ # Acoustic
228
+ aco_feats, aco_length = self.acoustic_encoder(audio16k, audio16k_length)
229
+ # VQ
230
+ vq_in_feats = torch.cat([sem_feats, aco_feats], dim=2)
231
+ vq_in_feats, vq_in_length = self.downsample(vq_in_feats, aco_length)
232
+ # RVQ,
233
+ indices = self.rvq.encode_codes(vq_in_feats.transpose(1, 2)) # (nq, B, L)
234
+ indices = indices.permute(1, 0, 2)
235
+ return indices # (B, nq, L)
236
+
237
+ @staticmethod
238
+ def _pad_and_chunk(audio: torch.Tensor, chunk_size: int) -> List[torch.Tensor]:
239
+ pad_len = math.ceil(audio.shape[1] / chunk_size) * chunk_size - audio.shape[1]
240
+ audio = F.pad(audio, (0, pad_len), mode="constant", value=0)
241
+ audio_chunks = audio.split(chunk_size, dim=1)
242
+ return audio_chunks
243
+
244
+ @torch.inference_mode()
245
+ def encode(
246
+ self,
247
+ audio16k: torch.Tensor,
248
+ audio16k_length: torch.Tensor = None,
249
+ batch_size: int = 96,
250
+ ):
251
+ """
252
+ Args:
253
+ audio16k: shape (b, t)
254
+ audio16k_length: (b,)
255
+ Returns:
256
+ token: shape (b, nq, l)
257
+ token_length: (b,)
258
+ """
259
+ if audio16k_length is None:
260
+ assert audio16k.shape[0] == 1
261
+ audio16k_length = torch.tensor(
262
+ [audio16k.shape[1]], dtype=torch.long, device=audio16k.device
263
+ )
264
+
265
+ CHUNK_SIZE = 6 * 16000
266
+ B, T = audio16k.shape
267
+ # Pad, chunk, and batch
268
+ audio16k_batch = []
269
+ batch_size_list = []
270
+ for i in range(B):
271
+ # Remove extra paddings
272
+ one_audio_chunks = self._pad_and_chunk(
273
+ audio16k[i : (i + 1), : audio16k_length[i]], CHUNK_SIZE
274
+ )
275
+ audio16k_batch += one_audio_chunks
276
+ batch_size_list.append(len(one_audio_chunks))
277
+ audio16k_batch = torch.cat(audio16k_batch, dim=0)
278
+ # Batch encode
279
+ token_batch = []
280
+ for i in range(0, audio16k_batch.shape[0], batch_size):
281
+ one_audio_batch = audio16k_batch[i : (i + batch_size)]
282
+ one_token_batch = self._encode_one_batch(one_audio_batch)
283
+ token_batch.append(one_token_batch)
284
+ token_batch = torch.cat(token_batch, dim=0)
285
+ # Recover & concat
286
+ token_list = torch.split(
287
+ token_batch, batch_size_list, dim=0
288
+ ) # [(B=1, nq, l), (B=3, nq, l), ...]
289
+ token_list = [
290
+ torch.cat(token_ts.split(1, dim=0), dim=-1) # (B=1, nq, l)
291
+ for token_ts in token_list
292
+ ]
293
+ # Pad tokens
294
+ token = pad_sequence(
295
+ [ts.squeeze(0).transpose(1, 0) for ts in token_list],
296
+ batch_first=True,
297
+ padding_value=0,
298
+ ).transpose(
299
+ 1, 2
300
+ ) # (B, nq, L)
301
+ token_length = (audio16k_length / 1280).ceil().long()
302
+ token = token[
303
+ ..., : token_length.max()
304
+ ] # Remove extra paddings (we pad to multiples of 6s)
305
+ return token, token_length
306
+
307
+ @torch.inference_mode()
308
+ def decode(self, tokens: torch.Tensor):
309
+ """
310
+ Args:
311
+ tokens: (B=1, nq, L)
312
+ Returns:
313
+ audio: (B=1, t)
314
+ """
315
+ tokens = tokens.permute(1, 0, 2) # (B, nq, L) -> (nq, B, L)
316
+ vq_out_feats = self.rvq.decode_codes(tokens)
317
+ vq_out_feats = vq_out_feats.transpose(1, 2)
318
+ vq_out_length = torch.tensor(
319
+ [vq_out_feats.shape[1]], dtype=torch.long, device=vq_out_feats.device
320
+ )
321
+ vq_out_feats, vq_out_length = self.upsample(vq_out_feats, vq_out_length)
322
+ # audio: (b, t)
323
+ audio, audio_length = self.acoustic_decoder(vq_out_feats, vq_out_length)
324
+ return audio
325
+
326
+ @torch.inference_mode()
327
+ def decode_one_token(
328
+ self, token: torch.Tensor, cache_dict: Dict[str, torch.Tensor], last_token: bool
329
+ ):
330
+ """Decode one single token to audio.
331
+
332
+ Args:
333
+ token: (B=1, nq, L=1)
334
+ Returns:
335
+ audio: (B=1, t)
336
+ """
337
+ # token->latent->upsample, (naturally causal)
338
+ token = token.permute(1, 0, 2) # (B, nq, L) -> (nq, B, L)
339
+ vq_out_feats = self.rvq.decode_codes(token)
340
+ vq_out_feats = vq_out_feats.transpose(1, 2)
341
+ vq_out_length = torch.tensor(
342
+ [vq_out_feats.shape[1]], dtype=torch.long, device=vq_out_feats.device
343
+ )
344
+ vq_out_feats, vq_out_length = self.upsample(vq_out_feats, vq_out_length)
345
+ # acoustic decoder
346
+ up_conv_cache = cache_dict.get("up_conv_cache", None)
347
+ bb_conv_cache1 = cache_dict.get("bb_conv_cache1", None)
348
+ bb_conv_cache2 = cache_dict.get("bb_conv_cache2", None)
349
+ bb_kv_cache = cache_dict.get("bb_kv_cache", None)
350
+ is_cache = cache_dict.get("is_cache", None)
351
+
352
+ (
353
+ audio,
354
+ new_up_conv_cache,
355
+ new_bb_conv_cache1,
356
+ new_bb_conv_cache2,
357
+ new_bb_kv_cache,
358
+ new_is_cache,
359
+ ) = self.acoustic_decoder.forward_chunk(
360
+ vq_out_feats,
361
+ up_conv_cache,
362
+ bb_conv_cache1,
363
+ bb_conv_cache2,
364
+ bb_kv_cache,
365
+ is_cache,
366
+ last_token,
367
+ )
368
+
369
+ new_cache_dict = {
370
+ "up_conv_cache": new_up_conv_cache,
371
+ "bb_conv_cache1": new_bb_conv_cache1,
372
+ "bb_conv_cache2": new_bb_conv_cache2,
373
+ "bb_kv_cache": new_bb_kv_cache,
374
+ "is_cache": new_is_cache,
375
+ }
376
+ return audio, new_cache_dict
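Not part of the committed files — a minimal encode/decode sketch for RedCodecInfer; the config/checkpoint paths are hypothetical and the input is random 16 kHz audio:

import torch
codec = RedCodecInfer.from_pretrained("config.json", "codec.pt").eval()
audio16k = torch.randn(1, 3 * 16000)           # 3 seconds of 16 kHz audio
tokens, token_length = codec.encode(audio16k)  # tokens: (1, nq, L) at 12.5 Hz, L = ceil(48000 / 1280) = 38
recon = codec.decode(tokens)                   # recon: (1, t) waveform at the codec's 24 kHz output rate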
fireredtts2/codec/rvq.py ADDED
@@ -0,0 +1,164 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from torch.nn.utils.parametrizations import weight_norm
6
+
7
+
8
+ def WNConv1d(*args, **kwargs):
9
+ return weight_norm(nn.Conv1d(*args, **kwargs))
10
+
11
+
12
+ def WNConvTranspose1d(*args, **kwargs):
13
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
14
+
15
+
16
+ class VectorQuantize(nn.Module):
17
+ def __init__(
18
+ self,
19
+ input_dim: int,
20
+ codebook_size: int,
21
+ codebook_dim: int,
22
+ ):
23
+ super().__init__()
24
+ self.input_dim = input_dim
25
+ self.codebook_size = codebook_size
26
+ self.codebook_dim = codebook_dim
27
+
28
+ self.in_project = (
29
+ WNConv1d(
30
+ self.input_dim, self.codebook_dim, kernel_size=1
31
+ ) # (B, D, T) -> (B, D', T)
32
+ if self.input_dim != self.codebook_dim
33
+ else nn.Identity()
34
+ )
35
+ self.out_project = (
36
+ WNConv1d(
37
+ self.codebook_dim, self.input_dim, kernel_size=1
38
+ ) # (B, D', T) -> (B, D, T)
39
+ if self.input_dim != self.codebook_dim
40
+ else nn.Identity()
41
+ )
42
+
43
+ # Initialize codebook and EMA buffers
44
+ self.register_buffer(
45
+ "codebook", torch.zeros(codebook_size, codebook_dim).float()
46
+ ) # (codebook_size, D'), ensure fp32
47
+ # Place holder, not used in inference
48
+ self.register_buffer("inited", torch.tensor([True], dtype=torch.bool)) # (1)
49
+ self.register_buffer(
50
+ "cluster_size", torch.zeros(codebook_size).float()
51
+ ) # (codebook_size), ensure fp32
52
+ self.register_buffer(
53
+ "embed_avg", self.codebook.clone().float()
54
+ ) # (codebook_size, D'), ensure fp32
55
+
56
+ def decode_code(self, embed_id): # embed_id: (B, T)
57
+ embed = (
58
+ F.embedding(embed_id, self.codebook).transpose(1, 2).float()
59
+ ) # (B, D', T), ensure fp32
60
+ return embed
61
+
62
+ def encode_code(self, z: torch.Tensor): # z: (B, D, T)
63
+ # logging.info(f"{self.cluster_size = }, {self.codebook = }, {self.embed_avg = }, {self.inited = }")
64
+ z = z.float() # Ensure fp32
65
+ z_e = self.in_project(z).float() # (B, D', T), ensure fp32
66
+
67
+ # Rearrange for quantization
68
+ encodings = rearrange(z_e, "b d t -> (b t) d").float() # (B*T, D'), ensure fp32
69
+
70
+ # Quantization
71
+ dist = (
72
+ encodings.pow(2).sum(1, keepdim=True) # (B*T, 1)
73
+ - 2 * encodings @ self.codebook.float().t() # (B*T, codebook_size)
74
+ + self.codebook.float().pow(2).sum(1, keepdim=True).t()
75
+ ) # (1, codebook_size)
76
+
77
+ # dist: (B*T, codebook_size)
78
+ indices = (-dist).max(1)[1] # (B*T)
79
+ indices = rearrange(indices, "(b t) -> b t", b=z.size(0)) # (B, T)
80
+
81
+ # Get quantized vectors
82
+ z_q = self.decode_code(indices).float() # (B, D', T), ensure fp32
83
+
84
+ # Straight-through estimator
85
+ z_q = z_e + (z_q - z_e).detach() # (B, D', T)
86
+ z_q = self.out_project(z_q).float() # (B, D, T), ensure fp32
87
+
88
+ # z_q: (B, D, T), indices: (B, T)
89
+ return z_q, indices
90
+
91
+
92
+ class ResidualVQ(nn.Module):
93
+ def __init__(
94
+ self,
95
+ input_dim: int = 768, # Input feature dimension (independent of the RVQ dimension)
96
+ rvq_dim=None, # RVQ dimension; if it differs from input_dim/output_dim, input_dim->rvq_dim and rvq_dim->output_dim projections are added
97
+ output_dim: int = None, # Output feature dimension (independent of the RVQ dimension)
98
+ num_quantizers: int = 8,
99
+ codebook_size: int = 1024,
100
+ codebook_dim: int = 256, # Dimension of each codebook. If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections
101
+ ):
102
+ super().__init__()
103
+ self.input_dim = input_dim
104
+
105
+ self.num_quantizers = num_quantizers
106
+ self.codebook_size = codebook_size
107
+ self.codebook_dim = codebook_dim
108
+ self.rvq_dim = rvq_dim
109
+
110
+ self.input_proj = (
111
+ WNConv1d(input_dim, rvq_dim, kernel_size=1)
112
+ if input_dim != rvq_dim
113
+ else nn.Identity()
114
+ )
115
+ self.output_proj = (
116
+ WNConv1d(rvq_dim, output_dim, kernel_size=1)
117
+ if rvq_dim != output_dim
118
+ else nn.Identity()
119
+ )
120
+
121
+ self.quantizers = nn.ModuleList(
122
+ [
123
+ VectorQuantize(
124
+ input_dim=rvq_dim,
125
+ codebook_size=self.codebook_size,
126
+ codebook_dim=codebook_dim,
127
+ )
128
+ for i in range(num_quantizers)
129
+ ]
130
+ )
131
+
132
+ def encode_codes(self, z: torch.Tensor):
133
+ z = self.input_proj(z)
134
+ residual = z.clone().float() # (B, D, T), ensure fp32
135
+ all_indices = []
136
+ # Quantize to tokens
137
+ for i, quantizer in enumerate(self.quantizers):
138
+ # z_q_i: (B, D, T), indices_i: (B, T), ensure fp32
139
+ z_q_i, indices_i = quantizer.encode_code(residual)
140
+ residual = residual - z_q_i
141
+ all_indices.append(indices_i) # (B, T)
142
+ all_indices = torch.stack(all_indices) # (N, B, T)
143
+ return all_indices
144
+
145
+ def decode_codes(self, codes): # codes: (nq, B, T)
146
+ """Decode codes from multiple quantizers to embeddings.
147
+
148
+ Args:
149
+ codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.
150
+
151
+ Returns:
152
+ emb: Tensor of shape (B, D, T) representing the decoded embeddings.
153
+ """
154
+ nq, B, T = codes.shape
155
+ device = codes.device
156
+ emb = torch.zeros(
157
+ B, self.rvq_dim, T, device=device, dtype=torch.float32
158
+ ) # (B, D, T)
159
+ for i, quantizer in enumerate(self.quantizers[:nq]):
160
+ code_i = codes[i] # (B, T)
161
+ quantized_i = quantizer.decode_code(code_i) # (B, D', T)
162
+ emb += quantizer.out_project(quantized_i) # Accumulate quantized embeddings
163
+ emb = self.output_proj(emb) # (B, D, T), apply output projection
164
+ return emb # (B, D, T)
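
A minimal round-trip sketch for the ResidualVQ module above; the dimensions are illustrative, and in the released codec the projections and codebooks come from the checkpoint rather than random initialization.

import torch
from fireredtts2.codec.rvq import ResidualVQ

rvq = ResidualVQ(
    input_dim=768, rvq_dim=256, output_dim=768,
    num_quantizers=8, codebook_size=1024, codebook_dim=256,
)
feats = torch.randn(2, 768, 50)   # (B, D, T) continuous features
codes = rvq.encode_codes(feats)   # (nq, B, T) integer code indices
recon = rvq.decode_codes(codes)   # (B, D, T) quantized reconstruction
print(codes.shape, recon.shape)   # torch.Size([8, 2, 50]) torch.Size([2, 768, 50])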
fireredtts2/codec/utils.py ADDED
@@ -0,0 +1,38 @@
1
+ import math
2
+ import torch
3
+
4
+
5
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
6
+ batch_size = lengths.size(0)
7
+ max_len = max_len if max_len > 0 else lengths.max().item()
8
+ seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
9
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
10
+ seq_length_expand = lengths.unsqueeze(-1)
11
+ mask = seq_range_expand >= seq_length_expand
12
+ return mask # (b, t)
13
+
14
+
15
+ def make_nonpad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
16
+ return ~make_pad_mask(lengths, max_len)
17
+
18
+
19
+ def make_block_causal_mask(
20
+ lengths: torch.Tensor, max_len: int = 0, chunk_size: int = 4
21
+ ) -> torch.Tensor:
22
+ mask = make_nonpad_mask(lengths, max_len) # (b, t)
23
+ attn_mask = torch.logical_and(mask.unsqueeze(1), mask.unsqueeze(2)) # (b, t, t)
24
+
25
+ num_blocks = math.ceil(attn_mask.shape[1] / chunk_size)
26
+ block_mask = torch.block_diag(
27
+ *[torch.ones(chunk_size, chunk_size) for _ in range(num_blocks)]
28
+ )
29
+ block_mask = block_mask[: attn_mask.shape[1], : attn_mask.shape[1]].to(
30
+ attn_mask
31
+ ) # (t, t)
32
+
33
+ diag_mask = attn_mask.new_full(
34
+ (1, attn_mask.shape[1], attn_mask.shape[2]), fill_value=True
35
+ ).tril() # (1, t, t)
36
+ diag_mask = diag_mask.logical_or(block_mask)
37
+ attn_mask = attn_mask.logical_and(diag_mask)
38
+ return attn_mask
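
A small worked example of make_block_causal_mask: with chunk_size=2 every frame may attend to all previous frames plus the rest of its own 2-frame block.

import torch
from fireredtts2.codec.utils import make_block_causal_mask

mask = make_block_causal_mask(torch.tensor([4]), chunk_size=2)  # (1, 4, 4) bool
print(mask[0].long())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])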
fireredtts2/codec/whisper.py ADDED
@@ -0,0 +1,420 @@
1
+ # Extracted from transformers' WhisperModel to simplify package dependency
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Optional, Literal
7
+ from fireredtts2.codec.utils import make_nonpad_mask
8
+ from fireredtts2.codec.audio import mel_filter_bank
9
+
10
+
11
+ def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
12
+ """Returns sinusoids for positional embedding"""
13
+ if channels % 2 != 0:
14
+ raise ValueError(
15
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
16
+ )
17
+ log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
18
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
19
+ scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
20
+ return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
21
+
22
+
23
+ class WhisperSdpaAttention(nn.Module):
24
+ def __init__(
25
+ self,
26
+ embed_dim: int,
27
+ num_heads: int,
28
+ dropout: float = 0.0,
29
+ bias: bool = True,
30
+ ):
31
+ super().__init__()
32
+ self.embed_dim = embed_dim
33
+ self.num_heads = num_heads
34
+ self.dropout = dropout
35
+ self.head_dim = embed_dim // num_heads
36
+ self.bias = bias
37
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
38
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
39
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
40
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
41
+
42
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
43
+ return (
44
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
45
+ .transpose(1, 2)
46
+ .contiguous()
47
+ )
48
+
49
+ def forward(
50
+ self,
51
+ hidden_states: torch.Tensor,
52
+ attention_mask: Optional[torch.Tensor] = None,
53
+ ):
54
+ """
55
+ Args:
56
+ attention_mask: bool or float mask. For a bool mask, True marks positions that may be attended to; a float mask is added to the attention scores.
57
+ """
58
+ bsz, tgt_len, _ = hidden_states.size()
59
+
60
+ query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
61
+ key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
62
+ value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)
63
+
64
+ # NOTE sdpa needs a 4-dim attention_mask: (b, nh, tq, tv)
65
+ if attention_mask is not None and len(attention_mask.shape) == 3:
66
+ attention_mask = attention_mask.unsqueeze(1)
67
+
68
+ attn_output = F.scaled_dot_product_attention(
69
+ query_states,
70
+ key_states,
71
+ value_states,
72
+ attn_mask=attention_mask,
73
+ dropout_p=self.dropout if self.training else 0.0,
74
+ ) # (bsz, nh, l, d)
75
+ attn_output = attn_output.transpose(1, 2)
76
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
77
+
78
+ attn_output = self.out_proj(attn_output)
79
+ return attn_output
80
+
81
+ def forward_chunk(
82
+ self,
83
+ hidden_states: torch.Tensor,
84
+ kv_cache: torch.Tensor = None,
85
+ ):
86
+ """Forward self-attention with kv cache.
87
+
88
+ Args:
89
+ hidden_states: shape (b, t, c)
90
+ kv_cache: shape (b, nh, t, c*2)
91
+ """
92
+ bsz, tgt_len, _ = hidden_states.size()
93
+
94
+ # shape (b, nh, t, c)
95
+ query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
96
+ key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz)
97
+ value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz)
98
+
99
+ # unpack cache
100
+ if kv_cache is not None:
101
+ k_cache, v_cache = kv_cache.chunk(2, dim=-1)
102
+ key_states = torch.cat([k_cache, key_states], dim=2)
103
+ value_states = torch.cat([v_cache, value_states], dim=2)
104
+ new_kv_cache = torch.cat([key_states, value_states], dim=-1)
105
+
106
+ # attention
107
+ attn_output = F.scaled_dot_product_attention(
108
+ query_states,
109
+ key_states,
110
+ value_states,
111
+ attn_mask=None,
112
+ dropout_p=0.0,
113
+ ) # (bsz, nh, l, d)
114
+ attn_output = attn_output.transpose(1, 2)
115
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
116
+
117
+ attn_output = self.out_proj(attn_output)
118
+ return attn_output, new_kv_cache
119
+
120
+
121
+ class WhisperEncoderLayer(nn.Module):
122
+ def __init__(
123
+ self,
124
+ embed_dim: int,
125
+ num_heads: int,
126
+ ffn_dim: int = None,
127
+ attn_dropout: float = 0.0,
128
+ dropout: float = 0.0,
129
+ ):
130
+ super().__init__()
131
+ self.dropout = dropout
132
+ # Attention
133
+ self.self_attn = WhisperSdpaAttention(embed_dim, num_heads, attn_dropout)
134
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
135
+ # FFN
136
+ ffn_dim = ffn_dim if ffn_dim is not None else embed_dim * 4
137
+ self.fc1 = nn.Linear(embed_dim, ffn_dim)
138
+ self.fc2 = nn.Linear(ffn_dim, embed_dim)
139
+ # Output norm
140
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: torch.Tensor,
145
+ attention_mask: torch.Tensor,
146
+ ):
147
+ # Attention
148
+ residual = hidden_states
149
+ hidden_states = self.self_attn_layer_norm(hidden_states)
150
+ hidden_states = self.self_attn(hidden_states, attention_mask)
151
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
152
+ hidden_states = residual + hidden_states
153
+
154
+ # FFN
155
+ residual = hidden_states
156
+ hidden_states = self.final_layer_norm(hidden_states)
157
+ hidden_states = F.gelu(self.fc1(hidden_states))
158
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
159
+ hidden_states = self.fc2(hidden_states)
160
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
161
+ hidden_states = residual + hidden_states
162
+ return hidden_states
163
+
164
+ def forward_chunk(
165
+ self,
166
+ hidden_states: torch.Tensor,
167
+ kv_cache: torch.Tensor = None,
168
+ ):
169
+ """Forward self-attention with kv cache.
170
+
171
+ Args:
172
+ hidden_states: shape (b, t, c)
173
+ kv_cache: shape (b, nh, t, c*2)
174
+ """
175
+ # Attention
176
+ residual = hidden_states
177
+ hidden_states = self.self_attn_layer_norm(hidden_states)
178
+ hidden_states, new_kv_cache = self.self_attn.forward_chunk(
179
+ hidden_states, kv_cache
180
+ )
181
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
182
+ hidden_states = residual + hidden_states
183
+
184
+ # FFN
185
+ residual = hidden_states
186
+ hidden_states = self.final_layer_norm(hidden_states)
187
+ hidden_states = F.gelu(self.fc1(hidden_states))
188
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
189
+ hidden_states = self.fc2(hidden_states)
190
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
191
+ hidden_states = residual + hidden_states
192
+ return hidden_states, new_kv_cache
193
+
194
+
195
+ class WhisperEncoder(nn.Module):
196
+ def __init__(
197
+ self,
198
+ in_dim: int,
199
+ embed_dim: int,
200
+ num_layers: int,
201
+ num_heads: int,
202
+ ffn_dim: int = None,
203
+ attn_dropout: float = 0.0,
204
+ dropout: float = 0.0,
205
+ max_positions: int = 1500,
206
+ ):
207
+ super().__init__()
208
+ self.in_dim = in_dim
209
+ self.embed_dim = embed_dim
210
+ self.dropout = dropout
211
+ # Input downsampling
212
+ self.conv1 = nn.Conv1d(in_dim, embed_dim, kernel_size=3, padding=1)
213
+ self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
214
+ # Fixed positional embedding
215
+ self.max_positions = max_positions
216
+ self.embed_positions = nn.Embedding(self.max_positions, embed_dim)
217
+ self.embed_positions.requires_grad_(False)
218
+ # Transformer
219
+ self.layers = nn.ModuleList(
220
+ [
221
+ WhisperEncoderLayer(
222
+ embed_dim, num_heads, ffn_dim, attn_dropout, dropout
223
+ )
224
+ for _ in range(num_layers)
225
+ ]
226
+ )
227
+ # Output norm
228
+ self.layer_norm = nn.LayerNorm(embed_dim)
229
+ # Init weight
230
+ self.apply(self._init_weights)
231
+ # Init position embedding
232
+ self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape))
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states: torch.Tensor,
237
+ hidden_length: torch.Tensor,
238
+ apply_position: bool = True,
239
+ ):
240
+ # Downsampling
241
+ hidden_states = hidden_states.transpose(1, 2)
242
+ hidden_states = F.gelu(self.conv1(hidden_states))
243
+ hidden_states = F.gelu(self.conv2(hidden_states))
244
+ hidden_states = hidden_states.transpose(1, 2)
245
+ hidden_length = hidden_length // 2 # from 100Hz -> 50Hz
246
+ # Pos encoding
247
+ if apply_position:
248
+ pos_embed = self.embed_positions(
249
+ torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
250
+ )
251
+ hidden_states = hidden_states + pos_embed
252
+ hidden_states = nn.functional.dropout(
253
+ hidden_states, p=self.dropout, training=self.training
254
+ )
255
+ # Transformer
256
+ attention_mask = make_nonpad_mask(hidden_length).unsqueeze(1) # (b, 1, t)
257
+ for layer in self.layers:
258
+ hidden_states = layer(hidden_states, attention_mask)
259
+
260
+ hidden_states = self.layer_norm(hidden_states)
261
+ return hidden_states, hidden_length
262
+
263
+ def _init_weights(self, module):
264
+ std = 0.02
265
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
266
+ module.weight.data.normal_(mean=0.0, std=std)
267
+ if module.bias is not None:
268
+ module.bias.data.zero_()
269
+ elif isinstance(module, nn.Embedding):
270
+ module.weight.data.normal_(mean=0.0, std=std)
271
+ if module.padding_idx is not None:
272
+ module.weight.data[module.padding_idx].zero_()
273
+
274
+
275
+ class WhisperMelExtractor(nn.Module):
276
+ def __init__(
277
+ self,
278
+ num_mels: int = 128,
279
+ sampling_rate: int = 16000,
280
+ hop_length: int = 160,
281
+ n_fft: int = 400,
282
+ fmin: float = 0,
283
+ fmax: float = 8000,
284
+ padding_value=0.0,
285
+ ):
286
+ super().__init__()
287
+ self.num_mels = num_mels
288
+ self.sampling_rate = sampling_rate
289
+ self.hop_length = hop_length
290
+ self.n_fft = n_fft
291
+ self.fmin = fmin
292
+ self.fmax = fmax
293
+ self.padding_value = padding_value
294
+ self.mel_filters = mel_filter_bank(
295
+ num_frequency_bins=(1 + n_fft // 2),
296
+ num_mel_filters=num_mels,
297
+ min_frequency=fmin,
298
+ max_frequency=fmax,
299
+ sampling_rate=sampling_rate,
300
+ norm="slaney",
301
+ mel_scale="slaney",
302
+ )
303
+
304
+ def extract_fbank(self, audio: torch.Tensor):
305
+ """
306
+ Args:
307
+ audio: batched audio of shape (b, t)
308
+ """
309
+ device = audio.device # compute on cuda if input is on cuda
310
+ # Mel
311
+ window = torch.hann_window(self.n_fft).to(device)
312
+ stft = torch.stft(
313
+ audio, self.n_fft, self.hop_length, window=window, return_complex=True
314
+ )
315
+ magnitudes = stft[..., :-1].abs() ** 2
316
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32).to(device)
317
+ mel_spec = mel_filters.T @ magnitudes
318
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
319
+ # Norm
320
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
321
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
322
+ log_spec = (log_spec + 4.0) / 4.0
323
+ return log_spec
324
+
325
+ def __call__(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
326
+ mel = self.extract_fbank(audio16k).transpose(1, 2)
327
+ mel_length = audio16k_length // self.hop_length
328
+ # mel: (b, t, c=128)
329
+ return mel, mel_length
330
+
331
+
332
+ # Pretrained encoder from whisper-large-v3
333
+ class PretrainedWhisperEncoder(WhisperEncoder):
334
+ @classmethod
335
+ def from_pretrained(cls, pretrained_path: str = None):
336
+ encoder = cls(
337
+ in_dim=128,
338
+ embed_dim=1280,
339
+ num_layers=32,
340
+ num_heads=20,
341
+ ffn_dim=5120,
342
+ attn_dropout=0.0,
343
+ max_positions=1500,
344
+ )
345
+ if pretrained_path is not None:
346
+ ckpt = torch.load(pretrained_path, map_location="cpu")
347
+ encoder.load_state_dict(ckpt)
348
+ encoder.eval()
349
+ # Disable grad
350
+ for p in encoder.parameters():
351
+ p.requires_grad_(False)
352
+ # Add Mel extractor
353
+ encoder.feature_extractor = WhisperMelExtractor(
354
+ num_mels=128,
355
+ sampling_rate=16000,
356
+ hop_length=160,
357
+ n_fft=400,
358
+ fmin=0,
359
+ fmax=8000,
360
+ )
361
+ return encoder
362
+
363
+ @torch.inference_mode()
364
+ def forward(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
365
+ # Extract mel
366
+ mel, mel_length = self.feature_extractor(audio16k, audio16k_length)
367
+ # Forward model
368
+ semantic_feats, semantic_length = super().forward(
369
+ mel, mel_length, apply_position=True
370
+ )
371
+ return semantic_feats, semantic_length
372
+
373
+
374
+ class WhisperAcousticEncoder(WhisperEncoder):
375
+ def __init__(
376
+ self,
377
+ # Mel extraction params
378
+ num_mels: int = 128,
379
+ sampling_rate: int = 16000,
380
+ hop_length: int = 160,
381
+ n_fft: int = 400,
382
+ fmin: float = 0.0,
383
+ fmax: float = 8000,
384
+ # Encoder params
385
+ embed_dim: int = 768,
386
+ num_layers: int = 12,
387
+ num_heads: int = 8,
388
+ ffn_dim: int = None,
389
+ attn_dropout: float = 0.0,
390
+ dropout: float = 0.0,
391
+ max_positions: int = 1500, # 50Hz * 30s
392
+ ):
393
+ super().__init__(
394
+ in_dim=num_mels,
395
+ embed_dim=embed_dim,
396
+ num_layers=num_layers,
397
+ num_heads=num_heads,
398
+ ffn_dim=ffn_dim,
399
+ attn_dropout=attn_dropout,
400
+ dropout=dropout,
401
+ max_positions=max_positions,
402
+ )
403
+ self.feature_extractor = WhisperMelExtractor(
404
+ num_mels=num_mels,
405
+ sampling_rate=sampling_rate,
406
+ hop_length=hop_length,
407
+ n_fft=n_fft,
408
+ fmin=fmin,
409
+ fmax=fmax,
410
+ )
411
+
412
+ def forward(self, audio16k: torch.Tensor, audio16k_length: torch.Tensor):
413
+ # Extract mel
414
+ with torch.no_grad():
415
+ mel, mel_length = self.feature_extractor(audio16k, audio16k_length)
416
+ # Forward model
417
+ hidden_states, hidden_length = super().forward(
418
+ mel, mel_length, apply_position=True
419
+ )
420
+ return hidden_states, hidden_length
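
A minimal sketch of running the acoustic encoder variant above on one second of 16 kHz audio with its default configuration; PretrainedWhisperEncoder.from_pretrained follows the same call pattern once the whisper-large-v3 weights are available.

import torch
from fireredtts2.codec.whisper import WhisperAcousticEncoder

encoder = WhisperAcousticEncoder()        # defaults: 128 mel bins in, 768-dim, 12 layers
audio16k = torch.randn(1, 16000)          # (B, T), one second at 16 kHz
length = torch.tensor([16000], dtype=torch.long)
feats, feat_length = encoder(audio16k, length)
print(feats.shape, feat_length)           # torch.Size([1, 50, 768]) tensor([50]), i.e. 50 Hz frames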
fireredtts2/fireredtts2.py ADDED
@@ -0,0 +1,459 @@
1
+ import os
2
+ import time
3
+ import json
4
+ import torch
5
+ import torchaudio
6
+
7
+ from typing import List, Tuple
8
+ from fireredtts2.codec import RedCodecInfer
9
+ from fireredtts2.llm import load_llm_model, load_custom_tokenizer
10
+ from fireredtts2.llm.utils import Segment
11
+ from fireredtts2.utils.spliter import clean_text, split_text, process_text_list
12
+ from tqdm import tqdm
13
+
14
+
15
+ class FireRedTTS2:
16
+ def __init__(self, pretrained_dir, gen_type, device):
17
+
18
+ assert os.path.exists(pretrained_dir)
19
+ assert gen_type in ["monologue", "dialogue"]
20
+ llm_config_path = os.path.join(pretrained_dir, "config_llm.json")
21
+ if gen_type == "monologue":
22
+ llm_ckpt_path = os.path.join(pretrained_dir, "llm_pretrain.pt")
23
+ # llm_ckpt_path = os.path.join(pretrained_dir, "llm_posttrain.pt")
24
+ else:
25
+ llm_ckpt_path = os.path.join(pretrained_dir, "llm_posttrain.pt")
26
+ codec_config_path = os.path.join(pretrained_dir, "config_codec.json")
27
+ codec_ckpt_path = os.path.join(pretrained_dir, "codec.pt")
28
+ pretrained_qwen_path = os.path.join(pretrained_dir, "Qwen2.5-1.5B")
29
+
30
+ # check
31
+ assert os.path.exists(llm_config_path)
32
+ assert os.path.exists(llm_ckpt_path)
33
+ assert os.path.exists(codec_config_path)
34
+ assert os.path.exists(codec_ckpt_path)
35
+ assert os.path.exists(pretrained_qwen_path)
36
+
37
+ # ==== Load Torch LLM ====
38
+ llm_config = json.load(open(llm_config_path))
39
+ self._model = load_llm_model(
40
+ configs=llm_config, checkpoint_path=llm_ckpt_path, device=device
41
+ )
42
+ self._model.eval()
43
+ self._model.setup_caches(1)
44
+ print("[INFO] LLM Loaded...")
45
+
46
+ # ==== Load Qwen2.5 Text Tokenizer ====
47
+ self._text_tokenizer = load_custom_tokenizer(pretrained_qwen_path)
48
+ print("[INFO] Text Tokenizer Loaded...")
49
+
50
+ # ==== Load Torch Audio Tokenizer ====
51
+ torch_codec = RedCodecInfer.from_pretrained(codec_config_path, codec_ckpt_path)
52
+ torch_codec.eval()
53
+ self._audio_tokenizer = torch_codec.to(device)
54
+ print("[INFO] Codec Loaded...")
55
+
56
+ self.sample_rate = 16000
57
+ self.device = device
58
+ self.max_seq_len = 3100
59
+
60
+ def load_prompt_audio(self, audio_path) -> torch.Tensor:
61
+ audio, audio_sr = torchaudio.load(audio_path)
62
+ # Downmix to a single channel by keeping the first channel
63
+ if audio.shape[0] > 1:
64
+ audio = audio[0, :].unsqueeze(0)
65
+ audio16k = torchaudio.functional.resample(audio, audio_sr, 16000)
66
+ return audio16k
67
+
68
+ def prepare_prompt(self, text, speaker, audio_path) -> Segment:
69
+ audio_tensor = self.load_prompt_audio(audio_path)
70
+ return Segment(text=text, speaker=speaker, audio=audio_tensor)
71
+
72
+ def _tokenize_text_segment(
73
+ self, text: str, speaker: str
74
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
75
+ frame_tokens = []
76
+ frame_masks = []
77
+
78
+ text = speaker + "<|text_start|>" + text + "<|text_end|>"
79
+ text_tokens = self._text_tokenizer.encode(text)
80
+ text_frame = torch.zeros(len(text_tokens), 17).long()
81
+ text_frame_mask = torch.zeros(len(text_tokens), 17).bool()
82
+ text_frame[:, -1] = torch.tensor(text_tokens)
83
+ text_frame_mask[:, -1] = True
84
+
85
+ frame_tokens.append(text_frame.to(self.device))
86
+ frame_masks.append(text_frame_mask.to(self.device))
87
+
88
+ return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
89
+
90
+ def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
91
+ frame_tokens = []
92
+ frame_masks = []
93
+
94
+ # (K, T)
95
+ audio_length = torch.tensor([audio.shape[1]], dtype=torch.long)
96
+ audio_tokens, token_length = self._audio_tokenizer.encode(
97
+ audio.to(self.device),
98
+ audio_length.to(self.device),
99
+ batch_size=48,
100
+ )
101
+
102
+ audio_tokens = audio_tokens.squeeze(0)
103
+ # add EOS frame
104
+ eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
105
+ audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
106
+
107
+ audio_frame = torch.zeros(audio_tokens.size(1), 17).long().to(self.device)
108
+ audio_frame_mask = torch.zeros(audio_tokens.size(1), 17).bool().to(self.device)
109
+ audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
110
+ audio_frame_mask[:, :-1] = True
111
+
112
+ frame_tokens.append(audio_frame)
113
+ frame_masks.append(audio_frame_mask)
114
+
115
+ return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
116
+
117
+ def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
118
+ """
119
+ Returns:
120
+ (seq_len,17), (seq_len, 17)
121
+ """
122
+ text_tokens, text_masks = self._tokenize_text_segment(
123
+ segment.text, segment.speaker
124
+ )
125
+ audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
126
+
127
+ return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat(
128
+ [text_masks, audio_masks], dim=0
129
+ )
130
+
131
+ @torch.inference_mode()
132
+ def generate(
133
+ self,
134
+ text: str,
135
+ speaker: str,
136
+ context: List[Segment],
137
+ max_audio_length_ms: float = 90_000,
138
+ temperature: float = 0.9,
139
+ topk: int = 20,
140
+ ) -> torch.Tensor:
141
+ self._model.reset_caches()
142
+
143
+ max_generation_len = int(max_audio_length_ms / 80)
144
+ tokens, tokens_mask = [], []
145
+ for segment in context:
146
+ segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
147
+ tokens.append(segment_tokens)
148
+ tokens_mask.append(segment_tokens_mask)
149
+
150
+ gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
151
+ text, speaker
152
+ )
153
+ tokens.append(gen_segment_tokens)
154
+ tokens_mask.append(gen_segment_tokens_mask)
155
+
156
+ prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
157
+ prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
158
+
159
+ samples = []
160
+ curr_tokens = prompt_tokens.unsqueeze(0)
161
+ curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
162
+ curr_pos = (
163
+ torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
164
+ )
165
+
166
+ max_seq_len = 3100
167
+ max_context_len = max_seq_len - max_generation_len
168
+ if curr_tokens.size(1) >= max_context_len:
169
+ raise ValueError(
170
+ f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
171
+ )
172
+
173
+ for _ in range(max_generation_len):
174
+ sample = self._model.generate_frame(
175
+ curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
176
+ )
177
+ # eos
178
+ if torch.all(sample == 0):
179
+ break
180
+
181
+ samples.append(sample)
182
+
183
+ curr_tokens = torch.cat(
184
+ [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
185
+ ).unsqueeze(1)
186
+ curr_tokens_mask = torch.cat(
187
+ [
188
+ torch.ones_like(sample).bool(),
189
+ torch.zeros(1, 1).bool().to(self.device),
190
+ ],
191
+ dim=1,
192
+ ).unsqueeze(1)
193
+ curr_pos = curr_pos[:, -1:] + 1
194
+
195
+ audio = (
196
+ self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0))
197
+ .squeeze(0)
198
+ .squeeze(0)
199
+ )
200
+
201
+ return audio
202
+
203
+ def generate_single(
204
+ self, context: List[Segment], temperature: float = 0.9, topk: int = 20
205
+ ):
206
+ self._model.reset_caches()
207
+ max_generation_len = 400
208
+ tokens, tokens_mask = [], []
209
+ for segment in context:
210
+ segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
211
+ tokens.append(segment_tokens)
212
+ tokens_mask.append(segment_tokens_mask)
213
+
214
+ prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
215
+ prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
216
+ prompt_tokens = prompt_tokens[:-3, :]
217
+ prompt_tokens_mask = prompt_tokens_mask[:-3, :]
218
+
219
+ samples = []
220
+ curr_tokens = prompt_tokens.unsqueeze(0)
221
+ curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
222
+ curr_pos = (
223
+ torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
224
+ )
225
+
226
+ num_token = 0
227
+ start_time = time.time()
228
+ for _ in range(max_generation_len):
229
+ sample = self._model.generate_frame(
230
+ curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
231
+ )
232
+ # eos
233
+ if torch.all(sample == 0):
234
+ break
235
+
236
+ samples.append(sample)
237
+
238
+ curr_tokens = torch.cat(
239
+ [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
240
+ ).unsqueeze(1)
241
+ curr_tokens_mask = torch.cat(
242
+ [
243
+ torch.ones_like(sample).bool(),
244
+ torch.zeros(1, 1).bool().to(self.device),
245
+ ],
246
+ dim=1,
247
+ ).unsqueeze(1)
248
+ curr_pos = curr_pos[:, -1:] + 1
249
+ num_token += 1
250
+ if num_token == 2:
251
+ end_time = time.time()
252
+ duration = end_time - start_time
253
+ print("---first pack duration:", duration)
254
+
255
+ gen_tokens = torch.stack(samples).permute(1, 2, 0)
256
+
257
+ return gen_tokens
258
+
259
+ # @torch.inference_mode()
260
+ # def generate_stream(
261
+ # self,
262
+ # text: str,
263
+ # speaker: str,
264
+ # context: List[Segment],
265
+ # max_audio_length_ms: float = 90_000,
266
+ # temperature: float = 0.9,
267
+ # topk: int = 50,
268
+ # ):
269
+ # self._model.reset_caches()
270
+
271
+ # max_generation_len = int(max_audio_length_ms / 80)
272
+ # tokens, tokens_mask = [], []
273
+ # for segment in context:
274
+ # segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
275
+ # tokens.append(segment_tokens)
276
+ # tokens_mask.append(segment_tokens_mask)
277
+
278
+ # gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
279
+ # text, speaker
280
+ # )
281
+ # tokens.append(gen_segment_tokens)
282
+ # tokens_mask.append(gen_segment_tokens_mask)
283
+
284
+ # prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
285
+ # prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
286
+
287
+ # samples = []
288
+ # curr_tokens = prompt_tokens.unsqueeze(0)
289
+ # curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
290
+ # curr_pos = (
291
+ # torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
292
+ # )
293
+
294
+ # max_seq_len = 3100
295
+ # max_context_len = max_seq_len - max_generation_len
296
+ # if curr_tokens.size(1) >= max_context_len:
297
+ # raise ValueError(
298
+ # f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
299
+ # )
300
+
301
+ # # codec cache
302
+ # codec_cache = {}
303
+ # prev_sample = None
304
+
305
+ # for _ in range(max_generation_len):
306
+ # sample = self._model.generate_frame(
307
+ # curr_tokens, curr_tokens_mask, curr_pos, temperature, topk
308
+ # )
309
+ # # eos
310
+ # if torch.all(sample == 0):
311
+ # break
312
+
313
+ # # decode one token
314
+ # if prev_sample is None:
315
+ # prev_sample = sample # sample: (b, nq)
316
+ # else:
317
+ # audio_chunk, codec_cache = self._audio_tokenizer.decode_one_token(
318
+ # prev_sample.unsqueeze(-1),
319
+ # codec_cache,
320
+ # last_token=False,
321
+ # )
322
+ # yield audio_chunk.squeeze(0)
323
+ # prev_sample = sample
324
+ # samples.append(sample) # sample: (b, nq)
325
+
326
+ # curr_tokens = torch.cat(
327
+ # [sample, torch.zeros(1, 1).long().to(self.device)], dim=1
328
+ # ).unsqueeze(1)
329
+ # curr_tokens_mask = torch.cat(
330
+ # [
331
+ # torch.ones_like(sample).bool(),
332
+ # torch.zeros(1, 1).bool().to(self.device),
333
+ # ],
334
+ # dim=1,
335
+ # ).unsqueeze(1)
336
+ # curr_pos = curr_pos[:, -1:] + 1
337
+
338
+ # audio_chunk, codec_cache = self._audio_tokenizer.decode_one_token(
339
+ # prev_sample.unsqueeze(-1),
340
+ # codec_cache,
341
+ # last_token=True,
342
+ # )
343
+ # yield audio_chunk.squeeze(0)
344
+
345
+ @torch.inference_mode()
346
+ def generate_dialogue(
347
+ self,
348
+ text_list,
349
+ prompt_wav_list=None,
350
+ prompt_text_list=None,
351
+ temperature=0.9,
352
+ topk=20,
353
+ ):
354
+ all_generated_segments = []
355
+ all_storage_segments = []
356
+ prompt_segments = []
357
+ text_list = process_text_list(text_list=text_list)
358
+ if prompt_wav_list is not None:
359
+ assert len(prompt_wav_list) == len(prompt_text_list)
360
+ # Prepare prompts
361
+ for i in range(len(prompt_wav_list)):
362
+ prompt_wav = prompt_wav_list[i]
363
+ prompt_text = prompt_text_list[i]
364
+ speaker = prompt_text[:4]
365
+ assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
366
+ prompt_segments.append(
367
+ self.prepare_prompt(
368
+ text=prompt_text, speaker=speaker, audio_path=prompt_wav
369
+ )
370
+ )
371
+
372
+ for text in tqdm(text_list):
373
+ speaker = text[:4]
374
+ text = text[4:]
375
+ # print("---speaker:", speaker)
376
+ # print("---text:", text)
377
+ assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
378
+
379
+ audio_tensor = self.generate(
380
+ text=text,
381
+ speaker=speaker,
382
+ context=prompt_segments + all_generated_segments,
383
+ max_audio_length_ms=30_000,
384
+ temperature=temperature,
385
+ topk=topk,
386
+ )
387
+
388
+ # Resample the generated audio to 16 kHz before adding it to the context
389
+ audio_16k = torchaudio.functional.resample(
390
+ audio_tensor.unsqueeze(0), 24000, 16000
391
+ )
392
+ all_generated_segments.append(
393
+ Segment(text=text, speaker=speaker, audio=audio_16k)
394
+ )
395
+
396
+ all_storage_segments.append(
397
+ Segment(text=text, speaker=speaker, audio=audio_tensor.unsqueeze(0))
398
+ )
399
+
400
+ # Concatenate all generations
401
+ all_audio = torch.cat([seg.audio for seg in all_storage_segments], dim=1)
402
+ all_audio = all_audio.cpu()
403
+ return all_audio
404
+
405
+ @torch.inference_mode()
406
+ def generate_monologue(
407
+ self, text, prompt_wav=None, prompt_text=None, temperature=0.75, topk=20
408
+ ):
409
+ # step1. construct context
410
+ if prompt_wav is not None:
411
+ assert os.path.exists(prompt_wav)
412
+ assert prompt_text is not None
413
+
414
+ all_generated_segments = []
415
+ all_storage_segments = []
416
+ prompt_segments = []
417
+ prompt_text = clean_text(text=prompt_text)
418
+ text = clean_text(text=text)
419
+ text_list = split_text(text=text, length=400)
420
+
421
+ audio_list = []
422
+ for text in text_list:
423
+ text = clean_text(text=text)
424
+ input_text = prompt_text[:-1] + "," + text
425
+ prompt_a = self.prepare_prompt(
426
+ text=input_text, speaker="[S1]", audio_path=prompt_wav
427
+ )
428
+
429
+ context = [prompt_a]
430
+
431
+ while True:
432
+ gen_tokens = self.generate_single(
433
+ context=context, temperature=temperature, topk=topk
434
+ )
435
+ if gen_tokens.shape[2] > 18:
436
+ break
437
+ # else:
438
+ # print("生成结果小于1s,重新跑")
439
+
440
+ gen_tokens = gen_tokens[:, :, 2:] # cut leading silence
441
+ audio = self._audio_tokenizer.decode(gen_tokens).squeeze(0).squeeze(0)
442
+ audio_list.append(audio.unsqueeze(0))
443
+
444
+ all_audio = torch.cat(tensors=audio_list, dim=1)
445
+
446
+ return all_audio
447
+
448
+ else:
449
+ # random speaker
450
+ text = clean_text(text=text.strip())
451
+ audio_tensor = self.generate(
452
+ text=text,
453
+ speaker="[S1]",
454
+ context=[],
455
+ max_audio_length_ms=30_000,
456
+ temperature=temperature,
457
+ topk=topk,
458
+ )
459
+ return audio_tensor.unsqueeze(0)
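
A minimal end-to-end sketch of the dialogue path defined above; the pretrained directory and prompt files are placeholders for whatever the Space downloads and uploads at runtime.

import torchaudio
from fireredtts2.fireredtts2 import FireRedTTS2

tts = FireRedTTS2(
    pretrained_dir="pretrained_models/FireRedTTS2",  # placeholder path
    gen_type="dialogue",
    device="cuda",
)
audio = tts.generate_dialogue(
    text_list=["[S1]Hello, how was your day?", "[S2]Pretty good, thanks for asking."],
    prompt_wav_list=["spk1_prompt.wav", "spk2_prompt.wav"],  # placeholder prompts
    prompt_text_list=[
        "[S1] transcript of the speaker 1 prompt.",
        "[S2] transcript of the speaker 2 prompt.",
    ],
    temperature=0.9,
    topk=20,
)
torchaudio.save("dialogue.wav", audio, 24000)  # decoded audio is 24 kHz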
fireredtts2/llm/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from fireredtts2.llm.utils import load_llm_model, load_custom_tokenizer
fireredtts2/llm/llm.py ADDED
@@ -0,0 +1,371 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ from huggingface_hub import PyTorchModelHubMixin
6
+ from fireredtts2.llm.modules import FLAVORS
7
+
8
+
9
+ def _prepare_transformer(model):
10
+ embed_dim = model.tok_embeddings.embedding_dim
11
+ model.tok_embeddings = nn.Identity()
12
+ model.output = nn.Identity()
13
+ return model, embed_dim
14
+
15
+
16
+ def _create_causal_mask(seq_len: int, device: torch.device):
17
+ return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
18
+
19
+
20
+ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
21
+ """
22
+ Args:
23
+ mask: (max_seq_len, max_seq_len)
24
+ input_pos: (batch_size, seq_len)
25
+
26
+ Returns:
27
+ (batch_size, seq_len, max_seq_len)
28
+ """
29
+ r = mask[input_pos, :]
30
+ return r
31
+
32
+
33
+ # Does multinomial sampling without a cuda synchronization
34
+ def _multinomial_sample_one_no_sync(probs):
35
+ q = torch.empty_like(probs).exponential_(1)
36
+ return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
37
+
38
+
39
+ def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
40
+ logits = logits / temperature
41
+
42
+ filter_value: float = -float("Inf")
43
+ indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
44
+ scores_processed = logits.masked_fill(indices_to_remove, filter_value)
45
+ scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
46
+ probs = torch.nn.functional.softmax(scores_processed, dim=-1)
47
+
48
+ sample_token = _multinomial_sample_one_no_sync(probs)
49
+ return sample_token
50
+
51
+
52
+ def sample_top_nsigma(logits: torch.Tensor, n: float, temperature: float):
53
+ """_summary_
54
+
55
+ Args:
56
+ logits (torch.Tensor): _description_
57
+ n (float): _description_
58
+ temperature (float): _description_
59
+
60
+ Returns:
61
+ _type_: _description_
62
+ """
63
+ logits = logits / temperature
64
+ threshold = logits.max(dim=-1, keepdim=True).values - n * logits.std(
65
+ dim=-1, keepdim=True
66
+ )
67
+ logits[logits < threshold] = float("-inf")
68
+ # scores_processed = torch.nn.functional.log_softmax(logits, dim=-1)
69
+ probs = torch.nn.functional.softmax(logits, dim=-1)
70
+
71
+ sample_token = _multinomial_sample_one_no_sync(probs)
72
+ return sample_token
73
+
74
+
75
+ @dataclass
76
+ class ModelArgs:
77
+ backbone_flavor: str
78
+ decoder_flavor: str
79
+ text_vocab_size: int
80
+ audio_vocab_size: int
81
+ audio_num_codebooks: int
82
+ decoder_loss_weight: float
83
+ use_text_loss: bool
84
+
85
+
86
+ class Model(nn.Module, PyTorchModelHubMixin):
87
+ def __init__(self, config: ModelArgs):
88
+ super().__init__()
89
+ self.config = config
90
+
91
+ self.backbone, backbone_dim = _prepare_transformer(
92
+ FLAVORS[config.backbone_flavor]()
93
+ )
94
+ self.decoder, decoder_dim = _prepare_transformer(
95
+ FLAVORS[config.decoder_flavor]()
96
+ )
97
+
98
+ self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
99
+ self.audio_embeddings = nn.Embedding(
100
+ config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
101
+ )
102
+
103
+ self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
104
+ self.text_head = nn.Linear(backbone_dim, config.text_vocab_size, bias=False)
105
+ self.codebook0_head = nn.Linear(
106
+ backbone_dim, config.audio_vocab_size, bias=False
107
+ )
108
+ self.audio_head = nn.Parameter(
109
+ torch.empty(
110
+ config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size
111
+ )
112
+ )
113
+
114
+ self.decoder_loss_weight = config.decoder_loss_weight
115
+ self.use_text_loss = config.use_text_loss
116
+
117
+ # debug
118
+ # print("---backbone_dim:", backbone_dim)
119
+ # print("---decoder_dim:", decoder_dim)
120
+ # print("---self.decoder_loss_weight:", self.decoder_loss_weight)
121
+ # print("---self.use_text_loss:", self.use_text_loss)
122
+
123
+ def setup_caches(self, max_batch_size: int) -> torch.Tensor:
124
+ """Setup KV caches and return a causal mask."""
125
+ dtype = next(self.parameters()).dtype
126
+ device = next(self.parameters()).device
127
+
128
+ with device:
129
+ self.backbone.setup_caches(max_batch_size, dtype)
130
+ self.decoder.setup_caches(
131
+ max_batch_size,
132
+ dtype,
133
+ decoder_max_seq_len=self.config.audio_num_codebooks,
134
+ )
135
+
136
+ self.register_buffer(
137
+ "backbone_causal_mask",
138
+ _create_causal_mask(self.backbone.max_seq_len, device),
139
+ )
140
+ self.register_buffer(
141
+ "decoder_causal_mask",
142
+ _create_causal_mask(self.config.audio_num_codebooks, device),
143
+ )
144
+
145
+ def forward(self, tokens: torch.Tensor, tokens_mask: torch.Tensor):
146
+ """
147
+ Forward pass of the Sesame CSM-style model; computes the combined training loss.
148
+ The backbone predicts text and the zeroth audio codebook; the decoder predicts the remaining codebooks.
149
+
150
+ Args:
151
+ tokens: (batch_size, seq_len, n_codebooks+1)
152
+ tokens_mask: (batch_size, seq_len, n_codebooks+1)
153
+ """
154
+
155
+ dtype = next(self.parameters()).dtype
156
+ bsz, seq_len, _ = tokens.size()
157
+ device = tokens.device
158
+
159
+ # print("---tokens:\n", tokens, tokens.shape)
160
+ # print("---tokens_mask:\n", tokens_mask, tokens_mask.shape)
161
+ # print("---bsz:", bsz)
162
+ # print("---seq_len:", seq_len)
163
+
164
+ # embed tokens
165
+ embeds = self._embed_tokens(tokens) # (bsz,seq_len,33,2048)
166
+ # print("---embeds:\n", embeds, embeds.shape)
167
+
168
+ # get targets and codebook embeddings corresponding to audio tokens
169
+ audio_mask = tokens_mask[:, :, 0] # [bsz, seq_len]
170
+ target_tokens = tokens[audio_mask][:, :-1] # [audio_len, n_codebooks]
171
+ # [audio_len, n_codebooks, embed_dim]
172
+ c_embeds = embeds[:, :, :-1, :][audio_mask]
173
+ # print("---audio_mask:\n", audio_mask, audio_mask.shape)
174
+ # print("---target_tokens:\n", target_tokens, target_tokens.shape)
175
+
176
+ # get targets corresponding to text tokens
177
+ text_mask = tokens_mask[:, :, -1]
178
+ text_target_mask = torch.roll(input=text_mask, shifts=1, dims=1)
179
+ text_target_tokens = tokens[text_target_mask][:, -1]
180
+
181
+ # print("---text_target_mask:\n", text_target_mask, text_target_mask.shape)
182
+ # print("---target_text_tokens:\n", text_target_tokens, text_target_tokens.shape)
183
+
184
+ # print("\n\n")
185
+
186
+ # retain just non-padding embeddings
187
+ masked_embeds = embeds * tokens_mask.unsqueeze(-1)
188
+ h = masked_embeds.sum(dim=2)
189
+
190
+ # backbone forward pass
191
+ # [bsz, seq_len]
192
+ padding_mask = tokens_mask[:, :, 0] | tokens_mask[:, :, -1]
193
+ # [seq_len, seq_len]
194
+ backbone_attn_mask = _create_causal_mask(seq_len, device)
195
+ # [bsz, seq_len, seq_len]
196
+ padding_3d = padding_mask.unsqueeze(-1) * padding_mask.unsqueeze(1)
197
+ backbone_attn_mask = backbone_attn_mask.unsqueeze(0) * padding_3d
198
+ backbone_attn_mask = backbone_attn_mask | torch.eye(
199
+ seq_len, device=device
200
+ ).bool().unsqueeze(0).expand(bsz, -1, -1)
201
+ input_pos = (
202
+ torch.arange(0, seq_len).unsqueeze(0).expand(bsz, seq_len).long().to(device)
203
+ )
204
+ h = self.backbone(h, input_pos=input_pos, mask=backbone_attn_mask).to(
205
+ dtype=dtype
206
+ )
207
+ # print("---h:\n", h, h.shape)
208
+
209
+ # get backbone embeddings used for audio codebook prediction predict first codebook and compute loss
210
+ audio_mask = torch.roll(audio_mask, -1, 1) # shift audio mask left by 1 so each hidden state predicts the next frame
211
+ audio_h = h[audio_mask] # [audio_len, embed_dim]
212
+ # print("---audio_mask after shift:\n", audio_mask, audio_mask.shape)
213
+ c0_logits = self.codebook0_head(audio_h) # [audio_len, audio_vocab_size]
214
+ c0_target = target_tokens[:, 0] # [audio_len]
215
+ c0_loss = F.cross_entropy(c0_logits, c0_target)
216
+
217
+ # predict text loss
218
+ text_h = h[text_mask]
219
+ text_logits = self.text_head(text_h)
220
+ text_loss = F.cross_entropy(text_logits, text_target_tokens, ignore_index=0)
221
+ # print("---text_h:\n", text_h, text_h.shape)
222
+ # print("---text_logits:\n", text_logits)
223
+ # print("---text_loss:", text_loss)
224
+
225
+ # "compute amortization" (train decoder on random 1/16 subset of audio tokens)
226
+ # important change to 1/8
227
+ # indices = torch.randperm(c_embeds.size(0))[: c_embeds.size(0) // 16]
228
+ indices = torch.randperm(c_embeds.size(0))[: c_embeds.size(0) // 8]
229
+ # [audio_len//16, n_codebooks-1, embed_dim]
230
+ c_embeds = c_embeds[indices][:, :-1, :]
231
+ audio_h = audio_h[indices] # [audio_len//16, embed_dim]
232
+ target_tokens = target_tokens[indices][:, 1:] # [audio_len//16, n_codebooks-1]
233
+
234
+ # concatenate backbone embeddings and codebook embeddings for decoder input
235
+ # [audio_len//16, n_codebooks, embed_dim]
236
+ decoder_embeds = torch.cat([audio_h.unsqueeze(1), c_embeds], dim=1)
237
+ N, n_codebooks, _ = decoder_embeds.size()
238
+ c_pos = (
239
+ torch.arange(0, n_codebooks)
240
+ .unsqueeze(0)
241
+ .expand(N, n_codebooks)
242
+ .long()
243
+ .to(device)
244
+ )
245
+
246
+ decoder_causal_mask = _create_causal_mask(
247
+ decoder_embeds.size(1), device
248
+ ).expand(N, -1, -1)
249
+ decoder_h = self.decoder(
250
+ self.projection(decoder_embeds), input_pos=c_pos, mask=decoder_causal_mask
251
+ ).to(dtype=dtype)
252
+ c_logits = torch.einsum("bsd,sdv->bsv", decoder_h[:, 1:, :], self.audio_head)
253
+
254
+ c_loss = F.cross_entropy(
255
+ c_logits.reshape(-1, c_logits.size(-1)), target_tokens.reshape(-1)
256
+ )
257
+
258
+ if self.use_text_loss:
259
+ loss = (
260
+ 2
261
+ * (
262
+ (1 - self.decoder_loss_weight) * c0_loss
263
+ + self.decoder_loss_weight * c_loss
264
+ )
265
+ + 0.01 * text_loss
266
+ )
267
+ else:
268
+ loss = 2 * (
269
+ (1 - self.decoder_loss_weight) * c0_loss
270
+ + self.decoder_loss_weight * c_loss
271
+ )
272
+ return loss, text_loss, c0_loss, c_loss
273
+
274
+ def generate_frame(
275
+ self,
276
+ tokens: torch.Tensor,
277
+ tokens_mask: torch.Tensor,
278
+ input_pos: torch.Tensor,
279
+ temperature: float,
280
+ topk: int,
281
+ ) -> torch.Tensor:
282
+ """
283
+ Args:
284
+ tokens: (batch_size, seq_len, audio_num_codebooks+1)
285
+ tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
286
+ input_pos: (batch_size, seq_len) positions for each token
287
+ mask: (batch_size, seq_len, max_seq_len
288
+
289
+ Returns:
290
+ (batch_size, audio_num_codebooks) sampled tokens
291
+ """
292
+ dtype = next(self.parameters()).dtype
293
+ b, s, _ = tokens.size()
294
+
295
+ assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
296
+ curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
297
+ embeds = self._embed_tokens(tokens)
298
+ masked_embeds = embeds * tokens_mask.unsqueeze(-1)
299
+ h = masked_embeds.sum(dim=2)
300
+ h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(
301
+ dtype=dtype
302
+ )
303
+
304
+ last_h = h[:, -1, :]
305
+ c0_logits = self.codebook0_head(last_h)
306
+ c0_sample = sample_topk(c0_logits, topk, temperature)
307
+ c0_embed = self._embed_audio(0, c0_sample)
308
+ curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
309
+ curr_sample = c0_sample.clone()
310
+ curr_pos = (
311
+ torch.arange(0, curr_h.size(1), device=curr_h.device)
312
+ .unsqueeze(0)
313
+ .repeat(curr_h.size(0), 1)
314
+ )
315
+
316
+ # Decoder caches must be reset every frame.
317
+ self.decoder.reset_caches()
318
+ for i in range(1, self.config.audio_num_codebooks):
319
+ curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
320
+ decoder_h = self.decoder(
321
+ self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask
322
+ ).to(dtype=dtype)
323
+ ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
324
+ ci_sample = sample_topk(ci_logits, 10, 0.75) # topk/temperature hard-coded to 10 and 0.75 for the remaining codebooks
325
+ ci_embed = self._embed_audio(i, ci_sample)
326
+ curr_h = ci_embed
327
+ curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
328
+ curr_pos = curr_pos[:, -1:] + 1
329
+
330
+ return curr_sample
331
+
332
+ def reset_caches(self):
333
+ self.backbone.reset_caches()
334
+ self.decoder.reset_caches()
335
+
336
+ def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
337
+ return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
338
+
339
+ def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
340
+ text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
341
+
342
+ audio_tokens = tokens[:, :, :-1] + (
343
+ self.config.audio_vocab_size
344
+ * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
345
+ )
346
+ audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
347
+ tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
348
+ )
349
+
350
+ return torch.cat([audio_embeds, text_embeds], dim=-2)
351
+
352
+
353
+ if __name__ == "__main__":
354
+
355
+ MIMI_SAMPLE_RATE = 24000
356
+ BACKBONE_FLAVOR = "qwen-3b"
357
+ DECODER_FLAVOR = "qwen-500m"
358
+ TEXT_VOCAB_SIZE = 128256
359
+ AUDIO_VOCAB_SIZE = 2051
360
+ AUDIO_NUM_CODEBOOKS = 32
361
+
362
+ config = ModelArgs(
363
+ backbone_flavor=BACKBONE_FLAVOR,
364
+ decoder_flavor=DECODER_FLAVOR,
365
+ text_vocab_size=TEXT_VOCAB_SIZE,
366
+ audio_vocab_size=AUDIO_VOCAB_SIZE,
367
+ audio_num_codebooks=AUDIO_NUM_CODEBOOKS,
368
+ decoder_loss_weight=0.5,
369
+ use_text_loss=True,
370
+ )
371
+ model = Model(config)
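
A quick shape check for the two samplers defined above, on dummy logits sized to the audio vocabulary from the __main__ config; both return integer indices of shape (batch, 1).

import torch
from fireredtts2.llm.llm import sample_topk, sample_top_nsigma

logits = torch.randn(2, 2051)              # (batch, audio_vocab_size)
tok_topk = sample_topk(logits, topk=20, temperature=0.9)
tok_nsigma = sample_top_nsigma(logits, n=1.0, temperature=0.9)
print(tok_topk.shape, tok_nsigma.shape)    # torch.Size([2, 1]) torch.Size([2, 1])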
fireredtts2/llm/modules.py ADDED
@@ -0,0 +1,90 @@
1
+ from torchtune.models.qwen2 import qwen2
2
+ from torchtune.modules.transformer import TransformerDecoder
3
+
4
+
5
+ def qwen2_200M() -> TransformerDecoder:
6
+ return qwen2(
7
+ vocab_size=151936,
8
+ num_layers=4,
9
+ num_heads=12,
10
+ num_kv_heads=2,
11
+ embed_dim=1536,
12
+ intermediate_dim=8960,
13
+ max_seq_len=4096,
14
+ attn_dropout=0.0,
15
+ norm_eps=1e-6,
16
+ rope_base=1000000.0,
17
+ tie_word_embeddings=True,
18
+ )
19
+
20
+
21
+ def qwen2_500M() -> TransformerDecoder:
22
+ return qwen2(
23
+ vocab_size=151936,
24
+ num_layers=24,
25
+ num_heads=14,
26
+ num_kv_heads=2,
27
+ embed_dim=896,
28
+ intermediate_dim=4864,
29
+ max_seq_len=4096,
30
+ attn_dropout=0.0,
31
+ norm_eps=1e-6,
32
+ rope_base=1000000.0,
33
+ tie_word_embeddings=True,
34
+ )
35
+
36
+
37
+ def qwen2_1_5B() -> TransformerDecoder:
38
+ return qwen2(
39
+ vocab_size=151936,
40
+ num_layers=28,
41
+ num_heads=12,
42
+ num_kv_heads=2,
43
+ embed_dim=1536,
44
+ intermediate_dim=8960,
45
+ max_seq_len=4096,
46
+ attn_dropout=0.0,
47
+ norm_eps=1e-6,
48
+ rope_base=1000000.0,
49
+ tie_word_embeddings=True,
50
+ )
51
+
52
+
53
+ def qwen2_3B() -> TransformerDecoder:
54
+ return qwen2(
55
+ vocab_size=151936,
56
+ num_layers=36,
57
+ num_heads=16,
58
+ num_kv_heads=2,
59
+ embed_dim=2048,
60
+ intermediate_dim=11008,
61
+ max_seq_len=4096,
62
+ attn_dropout=0.0,
63
+ norm_eps=1e-6,
64
+ rope_base=1000000.0,
65
+ tie_word_embeddings=True,
66
+ )
67
+
68
+
69
+ def qwen2_7B() -> TransformerDecoder:
70
+ return qwen2(
71
+ vocab_size=152064,
72
+ num_layers=28,
73
+ num_heads=28,
74
+ num_kv_heads=4,
75
+ embed_dim=3584,
76
+ intermediate_dim=18944,
77
+ max_seq_len=4096,
78
+ attn_dropout=0.0,
79
+ norm_eps=1e-6,
80
+ rope_base=1000000.0,
81
+ )
82
+
83
+
84
+ FLAVORS = {
85
+ "qwen-200m": qwen2_200M,
86
+ "qwen-500m": qwen2_500M,
87
+ "qwen-1.5b": qwen2_1_5B,
88
+ "qwen-3b": qwen2_3B,
89
+ "qwen-7b": qwen2_7B,
90
+ }
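
The FLAVORS registry above is what llm.py's _prepare_transformer consumes to build the backbone and decoder; a quick smoke test of the smallest flavor (requires torchtune):

from fireredtts2.llm.modules import FLAVORS

backbone = FLAVORS["qwen-200m"]()   # a torchtune TransformerDecoder
n_params = sum(p.numel() for p in backbone.parameters())
print(f"qwen-200m backbone: {n_params / 1e6:.1f}M parameters")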
fireredtts2/llm/utils.py ADDED
@@ -0,0 +1,303 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ from pathlib import Path
6
+ from dataclasses import dataclass
7
+ from typing import Union
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ from transformers import AutoTokenizer
10
+ from fireredtts2.llm.llm import Model, ModelArgs
11
+
12
+
13
+ @dataclass
14
+ class Segment:
15
+ speaker: str
16
+ text: str
17
+ audio: torch.Tensor
18
+
19
+
20
+ class WarmupDecayLR(LambdaLR):
21
+ """
22
+ Learning rate scheduler with a linear warmup and specificable decay.
23
+ """
24
+
25
+ def __init__(
26
+ self, optimizer, warmup_steps: int, total_steps: int, decay_type: str = "linear"
27
+ ):
28
+ self.warmup_steps = warmup_steps
29
+ self.total_steps = total_steps
30
+ self.decay_type = decay_type
31
+ super().__init__(optimizer, self.lr_lambda, last_epoch=-1)
32
+
33
+ def lr_lambda(self, step: int) -> float:
34
+ if step < self.warmup_steps:
35
+ return step / self.warmup_steps
36
+ else:
37
+ if self.decay_type == "linear":
38
+ return (self.total_steps - step) / (
39
+ self.total_steps - self.warmup_steps
40
+ )
41
+ elif self.decay_type == "constant":
42
+ return 1.0
43
+ elif self.decay_type == "exponential":
44
+ return 0.1 ** (
45
+ (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
46
+ )
47
+ elif self.decay_type == "cosine":
48
+ return 0.5 * (
49
+ 1
50
+ + torch.cos(
51
+ torch.pi
52
+ * torch.tensor(
53
+ (step - self.warmup_steps)
54
+ / (self.total_steps - self.warmup_steps)
55
+ )
56
+ )
57
+ )
58
+ else:
59
+ raise ValueError(f"Invalid decay type: {self.decay_type}")
+
+
+ # Special tokens added on top of the base Qwen2 tokenizer: speaker tags,
+ # emotion tags, and paralinguistic (non-verbal) event tags.
+ additional_special_tokens = [
+     # Text span markers
+     "<|text_start|>",
+     "<|text_end|>",
+     # Generic speaker tags [S1] ... [S40]
+     "[S1]", "[S2]", "[S3]", "[S4]", "[S5]", "[S6]", "[S7]", "[S8]", "[S9]", "[S10]",
+     "[S11]", "[S12]", "[S13]", "[S14]", "[S15]", "[S16]", "[S17]", "[S18]", "[S19]", "[S20]",
+     "[S21]", "[S22]", "[S23]", "[S24]", "[S25]", "[S26]", "[S27]", "[S28]", "[S29]", "[S30]",
+     "[S31]", "[S32]", "[S33]", "[S34]", "[S35]", "[S36]", "[S37]", "[S38]", "[S39]", "[S40]",
+     # Podcast and dialogue speaker tags
+     "[S_PODCAST_1]", "[S_PODCAST_2]", "[S_PODCAST_3]", "[S_PODCAST_4]", "[S_PODCAST_5]",
+     "[S_PODCAST_6]", "[S_PODCAST_7]", "[S_PODCAST_8]", "[S_PODCAST_9]", "[S_PODCAST_10]",
+     "[S_DIALOG_1]", "[S_DIALOG_2]", "[S_DIALOG_3]", "[S_DIALOG_4]", "[S_DIALOG_5]",
+     "[S_DIALOG_6]", "[S_DIALOG_7]", "[S_DIALOG_8]", "[S_DIALOG_9]", "[S_DIALOG_10]",
+     # Emotion tags
+     "<|emotion_neutral|>", "<|emotion_happy|>", "<|emotion_sad|>", "<|emotion_concern|>",
+     "<|emotion_confuse|>", "<|emotion_angry|>", "<|emotion_surprise|>", "<|emotion_disgust|>",
+     "<|emotion_nervous|>", "<|emotion_apology|>", "<|emotion_understand|>", "<|emotion_fear|>",
+     "<|emotion_comfort|>", "<|emotion_shy|>", "<|emotion_serious|>",
+     "<|emotion_extra1|>", "<|emotion_extra2|>", "<|emotion_extra3|>", "<|emotion_extra4|>",
+     "<|emotion_extra5|>", "<|emotion_extra6|>", "<|emotion_extra7|>", "<|emotion_extra8|>",
+     "<|emotion_extra9|>", "<|emotion_extra10|>",
+     # Paralinguistic event tags
+     "<|breath|>", "<|humph|>", "<|laugh_heng|>", "<|hissing|>", "<|sniff|>", "<|laugh_he|>",
+     "<|sigh|>", "<|laugh|>", "<|laugh_ha|>", "<|quick_breath|>", "<|laugh_hei|>",
+     "<|laugh_speak|>", "<|/laugh_speak|>", "<|cry|>", "<|choking|>", "<|cry_speak|>",
+     "<|/cry_speak|>", "<|slurp|>", "<|clucking|>", "<|yawning|>", "<|cough|>", "<|smack|>",
+     "<|hem|>", "<|stretch|>", "<|sneeze|>",
+     # Reserved paralinguistic extras (there is no <|paralinguistic_extra9|>)
+     "<|paralinguistic_extra1|>", "<|paralinguistic_extra2|>", "<|paralinguistic_extra3|>",
+     "<|paralinguistic_extra4|>", "<|paralinguistic_extra5|>", "<|paralinguistic_extra6|>",
+     "<|paralinguistic_extra7|>", "<|paralinguistic_extra8|>", "<|paralinguistic_extra10|>",
+     "<|paralinguistic_extra11|>", "<|paralinguistic_extra12|>", "<|paralinguistic_extra13|>",
+ ]
+
+
+ def load_custom_tokenizer(qwen2_tokenizer_path: str):
+     tok = AutoTokenizer.from_pretrained(qwen2_tokenizer_path)
+     special_tokens_dict = {
+         "additional_special_tokens": additional_special_tokens,
+     }
+     tok.add_special_tokens(special_tokens_dict)
+     return tok
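+
+ # Illustrative use of the extended tokenizer (the path and text below are
+ # placeholders): because the tags above are registered as special tokens,
+ # speaker / emotion / paralinguistic markers stay atomic after tokenization.
+ #
+ #     tok = load_custom_tokenizer("Qwen/Qwen2.5-1.5B")   # any Qwen2-style tokenizer dir or repo id
+ #     ids = tok("[S1]<|laugh|>Hello there[S2]Hi!").input_ids
+ #     # "[S1]", "<|laugh|>" and "[S2]" each map to a single token id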
+
+
+ def init_weights(model: nn.Module):
+     """
+     Initialize the weights of the model.
+     - Xavier uniform initialization for linear layers
+     - Normal initialization for embeddings
+     - Xavier uniform initialization for plain parameters (e.g. audio_head)
+     """
+
+     def _init_weights(m):
+         if isinstance(m, nn.Linear):
+             nn.init.xavier_uniform_(m.weight)
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+         elif isinstance(m, nn.Embedding):
+             nn.init.normal_(m.weight, mean=0.0, std=0.02)
+         elif isinstance(m, nn.Parameter):
+             nn.init.xavier_uniform_(m.data)
+
+     model.apply(_init_weights)
+
+     # Special handling for audio_head: it is an nn.Parameter rather than a
+     # submodule, so model.apply() above does not reach it.
+     nn.init.xavier_uniform_(model.audio_head)
+
+     return model
+
+
+ def load_llm_model(
+     configs,
+     checkpoint_path: Union[str, Path] = None,
+     device: Union[str, torch.device] = "cuda",
+ ) -> Model:
+     """Build the LLM, optionally restore a checkpoint, and move it to the target device.
+
+     Args:
+         configs: Configuration dict whose "llm_models" section holds the model hyperparameters.
+         checkpoint_path: Optional checkpoint containing a "model" state dict; if absent,
+             the model is randomly initialized via init_weights.
+         device: Device to move the model to.
+     """
+
+     model_arg = ModelArgs(
+         backbone_flavor=configs["llm_models"]["backbone_flavor"],
+         decoder_flavor=configs["llm_models"]["decoder_flavor"],
+         text_vocab_size=configs["llm_models"]["text_vocab_size"],
+         audio_vocab_size=configs["llm_models"]["audio_vocab_size"],
+         audio_num_codebooks=configs["llm_models"]["audio_num_codebooks"],
+         decoder_loss_weight=configs["llm_models"]["decoder_loss_weight"],
+         use_text_loss=True,
+     )
+     model = Model(model_arg)
+
+     if checkpoint_path and os.path.exists(checkpoint_path):
+         state_dict = torch.load(
+             checkpoint_path, map_location="cpu", weights_only=False
+         )["model"]
+         model.load_state_dict(state_dict)
+     else:
+         model = init_weights(model)
+
+     model = model.to(device=device)
+     return model
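+
+ # Sketch of how this loader is typically driven (the key names follow the config
+ # accesses above; the values and checkpoint path are placeholders, not the
+ # shipped configuration):
+ #
+ #     configs = {
+ #         "llm_models": {
+ #             "backbone_flavor": ...,
+ #             "decoder_flavor": ...,
+ #             "text_vocab_size": ...,
+ #             "audio_vocab_size": ...,
+ #             "audio_num_codebooks": ...,
+ #             "decoder_loss_weight": ...,
+ #         }
+ #     }
+ #     llm = load_llm_model(configs, checkpoint_path="<path to checkpoint>.pt", device="cuda")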
+
+
+ def summarize(
+     writer,
+     global_step,
+     scalars={},
+     histograms={},
+     images={},
+     audios={},
+     audio_sampling_rate=22050,
+ ):
+     """Log scalars, histograms, images, and audio clips to a TensorBoard-style summary writer."""
+     for k, v in scalars.items():
+         writer.add_scalar(k, v, global_step)
+     for k, v in histograms.items():
+         writer.add_histogram(k, v, global_step)
+     for k, v in images.items():
+         writer.add_image(k, v, global_step, dataformats="HWC")
+     for k, v in audios.items():
+         writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+ def get_grad_norm(model):
+     """Overall L2 norm of the gradients, averaged over the parameters that have one."""
+     total_norm = 0
+     num = 0
+     for name, p in model.named_parameters():
+         try:
+             param_norm = p.grad.data.norm(2)
+             total_norm += param_norm.item() ** 2
+             num += 1
+         except AttributeError:
+             # p.grad is None for parameters that did not receive a gradient
+             print(name)
+     total_norm = total_norm ** (1.0 / 2)
+     total_norm = total_norm / num
+     return total_norm
+
+
+ def read_jsonl(path):
+     """Read a JSONL file into a list of parsed records."""
+     path = os.path.expanduser(path)
+     with open(path, "r") as f:
+         json_str = f.read()
+     data_list = []
+     for line in json_str.splitlines():
+         data = json.loads(line)
+         data_list.append(data)
+     return data_list
fireredtts2/utils/spliter.py ADDED
@@ -0,0 +1,289 @@
+ import re
+ import string
+
+ # Punctuation and symbols normalized away before splitting: most Chinese
+ # punctuation is mapped to an ASCII equivalent; decorative symbols are dropped.
+ SYMBOLS_MAPPING = {
+     "\n": "",
+     "\t": "",
+     "…": ",",
+     "“": "'",
+     "”": "'",
+     "‘": "'",
+     "’": "'",
+     "【": "",
+     "】": "",
+     "[": "",
+     "]": "",
+     "(": "",
+     ")": "",
+     "(": "",
+     ")": "",
+     "・": "",
+     "·": "",
+     "「": "'",
+     "」": "'",
+     "《": "'",
+     "》": "'",
+     "—": "",
+     "~": ",",
+     "~": ",",
+     ":": ",",
+     ";": ",",
+     ";": ",",
+     ":": ",",
+     '"': "",
+     "!": ",",
+     # "!": ".",
+     "————": "",
+     "——": "",
+     "—": "",
+     "……": ",",
+     "*": "",
+ }
+
+ REPLACE_SYMBOL_REGEX = re.compile(
+     "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+ )
+
+
+ EMOJI_REGEX = re.compile(
+     "["
+     "\U0001f600-\U0001f64f"  # emoticons
+     "\U0001f300-\U0001f5ff"  # symbols & pictographs
+     "\U0001f680-\U0001f6ff"  # transport & map symbols
+     "\U0001f1e0-\U0001f1ff"  # flags (iOS)
+     "]+",
+     flags=re.UNICODE,
+ )
+
+
+ def clean_text(text):
+     # Clean the text
+     text = text.strip()
+     text = text.replace("\xa0", "")
+
+     # Replace Chinese symbols with their English counterparts
+     text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+
+     # Remove emojis
+     text = EMOJI_REGEX.sub(r"", text)
+
+     # Collapse runs of periods (...) and commas (,,,) into a single mark
+     text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)
+
+     return text
+
+
+ def utf_8_len(text):
+     return len(text.encode("utf-8"))
+
+
+ def break_text(texts, length, splits: set):
+     # Cut each over-long text after any character in `splits`
+     for text in texts:
+         if utf_8_len(text) <= length:
+             yield text
+             continue
+
+         curr = ""
+         for char in text:
+             curr += char
+
+             if char in splits:
+                 yield curr
+                 curr = ""
+
+         if curr:
+             yield curr
+
+
+ def break_text_by_length(texts, length):
+     # Hard cut each over-long text every `length` UTF-8 bytes
+     for text in texts:
+         if utf_8_len(text) <= length:
+             yield text
+             continue
+
+         curr = ""
+         for char in text:
+             curr += char
+
+             if utf_8_len(curr) >= length:
+                 yield curr
+                 curr = ""
+
+         if curr:
+             yield curr
+
+
+ def add_cleaned(curr, segments):
+     curr = curr.strip()
+     if curr and not all(c.isspace() or c in string.punctuation for c in curr):
+         segments.append(curr)
+
+
+ def protect_float(text):
+     # Turns 3.14 into <3_f_14> to prevent splitting
+     return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
+
+
+ def unprotect_float(text):
+     # Turns <3_f_14> back into 3.14
+     return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
+
+
+ def split_text(text, length):
+     text = clean_text(text)
+
+     # Break the text into pieces with the following rules:
+     # 1. Split at ".", "!", "?" if the text is NOT a float
+     # 2. If a piece is longer than `length`, split at ","
+     # 3. If a piece is still longer than `length`, split at " "
+     # 4. If a piece is still longer than `length`, split at any character
+
+     texts = [text]
+     texts = map(protect_float, texts)
+     texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
+     texts = map(unprotect_float, texts)
+     texts = break_text(texts, length, {",", ","})
+     texts = break_text(texts, length, {" "})
+     texts = list(break_text_by_length(texts, length))
+
+     # Then, merge the pieces back into segments of at most `length` bytes
+     segments = []
+     curr = ""
+
+     for text in texts:
+         if utf_8_len(curr) + utf_8_len(text) <= length:
+             curr += text
+         else:
+             add_cleaned(curr, segments)
+             curr = text
+
+     if curr:
+         add_cleaned(curr, segments)
+
+     return segments
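+
+ # Illustrative behavior of split_text (lengths are measured in UTF-8 bytes;
+ # the example input is hypothetical):
+ #
+ #     split_text("Hello world. This is a long sentence, with a comma.", 30)
+ #     # -> ["Hello world.", "This is a long sentence,", "with a comma."]
+ #
+ # Splits are attempted at sentence punctuation first, then commas, spaces,
+ # and finally at the raw byte budget.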
+
+
+ def contains_chinese(text):
+     """Return True if the text contains any Chinese character."""
+     return bool(re.search(r"[\u4e00-\u9fff]", text))
+
+
+ def count_words_english(text):
+     """Count whitespace-separated English words."""
+     return len(text.split())
+
+
+ def count_characters_chinese(text):
+     """Count characters (used as the length measure for Chinese text)."""
+     return len(text)
+
+
+ def split_by_punctuation_english(text):
+     """Split on English sentence punctuation (. ! ?), keeping the punctuation mark."""
+     sentences = re.split(r"([.!?])", text)
+     result = []
+     for i in range(0, len(sentences) - 1, 2):
+         sentence = sentences[i].strip()
+         if sentence:
+             if i + 1 < len(sentences):
+                 sentence += sentences[i + 1]
+             result.append(sentence)
+
+     if len(sentences) % 2 == 1 and sentences[-1].strip():
+         result.append(sentences[-1].strip())
+
+     return result
+
+
+ def split_by_punctuation_chinese(text):
+     """Split on Chinese sentence punctuation (。!?), keeping the punctuation mark."""
+     sentences = re.split(r"([。!?])", text)
+     result = []
+     for i in range(0, len(sentences) - 1, 2):
+         sentence = sentences[i].strip()
+         if sentence:
+             if i + 1 < len(sentences):
+                 sentence += sentences[i + 1]
+             result.append(sentence)
+
+     if len(sentences) % 2 == 1 and sentences[-1].strip():
+         result.append(sentences[-1].strip())
+
+     return result
+
+
+ def merge_sentences_english(sentences, max_words=80):
+     """Greedily merge English sentences into chunks of at most max_words words."""
+     result = []
+     current_chunk = ""
+
+     for sentence in sentences:
+         if not current_chunk:
+             current_chunk = sentence
+         else:
+             test_chunk = current_chunk + " " + sentence
+             if count_words_english(test_chunk) <= max_words:
+                 current_chunk = test_chunk
+             else:
+                 result.append(current_chunk)
+                 current_chunk = sentence
+
+     if current_chunk:
+         result.append(current_chunk)
+
+     return result
+
+
+ def merge_sentences_chinese(sentences, max_chars=100):
+     """Greedily merge Chinese sentences into chunks of at most max_chars characters."""
+     result = []
+     current_chunk = ""
+
+     for sentence in sentences:
+         if not current_chunk:
+             current_chunk = sentence
+         else:
+             test_chunk = current_chunk + sentence
+             if count_characters_chinese(test_chunk) <= max_chars:
+                 current_chunk = test_chunk
+             else:
+                 result.append(current_chunk)
+                 current_chunk = sentence
+
+     if current_chunk:
+         result.append(current_chunk)
+
+     return result
+
+
+ def process_text(text):
+     """Split a single dialogue turn into chunks short enough for synthesis."""
+     chinese_max_limit = 150
+     english_max_limit = 80
+     # Strip a leading speaker tag such as [S2]
+     text = re.sub(r"^\[S\d+\]", "", text).strip()
+     is_chinese = contains_chinese(text)
+     if is_chinese:
+         if count_characters_chinese(text) <= chinese_max_limit:
+             return [text]
+         sentences = split_by_punctuation_chinese(text)
+         result = merge_sentences_chinese(sentences, chinese_max_limit)
+     else:
+         if count_words_english(text) <= english_max_limit:
+             return [text]
+         sentences = split_by_punctuation_english(text)
+         result = merge_sentences_english(sentences, english_max_limit)
+
+     return result
+
+
+ def process_text_list(text_list):
+     """Split every turn in a dialogue, re-attaching the speaker tag to each chunk."""
+     new_text_list = []
+     for text in text_list:
+         speaker = text[:4]
+         # print("---speaker:", speaker)
+         assert speaker in ["[S1]", "[S2]", "[S3]", "[S4]"]
+         result = process_text(text=text)
+         # print("---result:\n", result, len(result))
+         for chunk in result:
+             new_text_list.append(speaker + chunk)
+     return new_text_list
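+
+ # Hypothetical round trip for process_text_list: over-long turns are split on
+ # sentence punctuation and every resulting chunk keeps its speaker tag, e.g.
+ #
+ #     process_text_list(["[S1]Short line.", "[S2]" + some_long_english_paragraph])
+ #     # -> ["[S1]Short line.", "[S2]First ~80-word chunk ...", "[S2]Next chunk ..."]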
pretrained_models/README.md ADDED
@@ -0,0 +1 @@
+ Put the pre-trained model in this folder.
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torchaudio
+ torchtune
+ torchao
+ transformers
+ einops
+ gradio